admin-dashboard-backend/socketclients/socketclients.go

358 lines
9.5 KiB
Go

package socketclients
import (
"encoding/base64"
"janex/admin-dashboard-backend/modules/cache"
"janex/admin-dashboard-backend/modules/database"
"janex/admin-dashboard-backend/modules/structs"
"janex/admin-dashboard-backend/modules/utils"
"time"
"github.com/gofiber/websocket/v2"
"github.com/rs/zerolog/log"
"golang.org/x/crypto/bcrypt"
)
// BroadcastMessage sends sendSocketMessage to every socket client currently
// registered in the cache.
func BroadcastMessage(sendSocketMessage structs.SendSocketMessage) {
	clients := cache.GetSocketClients()
	for i := range clients {
		clients[i].SendMessage(sendSocketMessage)
	}
}
// BroadcastMessageExceptUserSessionId sends sendSocketMessage to every socket
// client in the cache except the connections whose SessionId equals
// ignoreUserSessionId.
func BroadcastMessageExceptUserSessionId(ignoreUserSessionId string, sendSocketMessage structs.SendSocketMessage) {
	for _, c := range cache.GetSocketClients() {
		if c.SessionId == ignoreUserSessionId {
			continue
		}
		c.SendMessage(sendSocketMessage)
	}
}
// UpdateConnectedUsers broadcasts the connection state of userId to all socket
// clients. The message body carries the total number of connected sockets,
// whether the user has at least one open connection (1) or none (0), and the
// user's LastOnline value as loaded from the database.
func UpdateConnectedUsers(userId string) {
	var u structs.User
	database.DB.First(&u, "id = ?", userId)

	clientCount := len(cache.GetSocketClients())
	status := isUserGenerallyConnected(userId)

	body := struct {
		WebSocketUsersCount int
		UserId              string
		ConnectionStatus    uint8
		LastOnline          time.Time
	}{
		WebSocketUsersCount: clientCount,
		UserId:              userId,
		ConnectionStatus:    status,
		LastOnline:          u.LastOnline,
	}

	BroadcastMessage(structs.SendSocketMessage{
		Cmd:  utils.SentCmdUpdateConnectedUsers,
		Body: body,
	})
}
// SendMessageToUser delivers sendSocketMessage to every socket connection
// owned by userId, skipping connections that belong to the session
// ignoreUserSessionId.
func SendMessageToUser(userId string, ignoreUserSessionId string, sendSocketMessage structs.SendSocketMessage) {
	for _, c := range cache.GetSocketClients() {
		if c.UserId != userId || c.SessionId == ignoreUserSessionId {
			continue
		}
		c.SendMessage(sendSocketMessage)
	}
}
// SendMessageOnlyToSessionId delivers sendSocketMessage to every socket
// connection belonging to sessionId; a session may own several connections.
func SendMessageOnlyToSessionId(sessionId string, sendSocketMessage structs.SendSocketMessage) {
	clients := cache.GetSocketClients()
	for i := range clients {
		if clients[i].SessionId != sessionId {
			continue
		}
		clients[i].SendMessage(sendSocketMessage)
	}
}
// CloseAllUserSessionConnections sends a session-closed message (via
// SendSessionClosedMessage) to every connection that belongs to sessionId,
// for example when the same session is open in two browser tabs.
func CloseAllUserSessionConnections(sessionId string) {
	clients := cache.GetSocketClients()
	for i := range clients {
		if clients[i].SessionId == sessionId {
			clients[i].SendSessionClosedMessage()
		}
	}
}
// CloseAndDeleteAllUserConnections sends a session-closed message to every
// connection of userId, across all of the user's sessions. It then deletes
// all UserSession rows of that user from the database. It is used, for
// example, after the user changed the password.
func CloseAndDeleteAllUserConnections(userId string) {
	for _, c := range cache.GetSocketClients() {
		if c.UserId != userId {
			continue
		}
		c.SendSessionClosedMessage()
	}

	database.DB.Where("user_id = ?", userId).Delete(&structs.UserSession{})
}
// GetUserSessions loads all stored sessions of userId and converts them to
// the socket representation sent to the UI. ConnectionStatus is 1 when at
// least one currently connected socket belongs to the session, otherwise 0.
// The result is nil when the user has no sessions.
func GetUserSessions(userId string) []structs.UserSessionSocket {
	var sessions []structs.UserSession
	database.DB.Where("user_id = ?", userId).Find(&sessions)

	// One snapshot of the connected clients is shared by all sessions.
	clients := cache.GetSocketClients()

	var result []structs.UserSessionSocket
	for i := range sessions {
		s := sessions[i]
		entry := structs.UserSessionSocket{
			IdForDeletion:    s.IdForDeletion,
			UserAgent:        s.UserAgent,
			ConnectionStatus: isUserSessionConnected(s.Id, clients),
			LastUsed:         s.LastUsed,
			ExpiresAt:        s.ExpiresAt,
		}
		result = append(result, entry)
	}
	return result
}
// UpdateUserSessionsForUser pushes the current list of userId's sessions, as
// built by GetUserSessions, to all of that user's connections. Connections of
// the session ignoreUserSessionId are skipped.
func UpdateUserSessionsForUser(userId string, ignoreUserSessionId string) {
	// Build the session list exactly once. GetUserSessions hits the database
	// and scans the socket cache, so an extra, discarded call is pure overhead.
	sessions := GetUserSessions(userId)

	SendMessageToUser(userId, ignoreUserSessionId, structs.SendSocketMessage{
		Cmd:  utils.SentCmdUpdateUserSessions,
		Body: sessions,
	})
}
// isUserSessionConnected returns 1 if any client in socketClients belongs to
// the session userSessionId, and 0 otherwise. The uint8 result is used
// directly as a ConnectionStatus value.
func isUserSessionConnected(userSessionId string, socketClients []*structs.SocketClient) uint8 {
	var connected uint8
	for i := 0; i < len(socketClients) && connected == 0; i++ {
		if socketClients[i].SessionId == userSessionId {
			connected = 1
		}
	}
	return connected
}
// isUserGenerallyConnected returns 1 if userId has at least one socket
// connection in the cache, regardless of which session it belongs to, and
// 0 otherwise.
func isUserGenerallyConnected(userId string) uint8 {
	clients := cache.GetSocketClients()
	for i := range clients {
		if clients[i].UserId != userId {
			continue
		}
		return 1
	}
	return 0
}
// GetAllUsers loads every user from the database and returns the fields the
// UI uses to list them. The result also carries each user's connection
// status: 1 when the user has at least one socket connection, 0 otherwise.
// The result is nil when there are no users.
func GetAllUsers() []structs.AllUsers {
	var users []structs.User
	database.DB.Find(&users)

	// Take a single snapshot of the connected clients and index it by user id.
	// Every user's status then comes from the same snapshot. The cost is
	// O(users + clients) instead of rescanning the client list per user.
	clients := cache.GetSocketClients()
	connectedUserIds := make(map[string]struct{}, len(clients))
	for _, client := range clients {
		connectedUserIds[client.UserId] = struct{}{}
	}

	var allUsers []structs.AllUsers
	for _, user := range users {
		var connectionStatus uint8
		if _, ok := connectedUserIds[user.Id]; ok {
			connectionStatus = 1
		}
		allUsers = append(allUsers, structs.AllUsers{
			Id:               user.Id,
			RoleId:           user.RoleId,
			Avatar:           user.Avatar,
			Username:         user.Username,
			ConnectionStatus: connectionStatus,
			LastOnline:       user.LastOnline,
		})
	}
	return allUsers
}
// GetAllScanners returns every scanner stored in the database with its
// Session field cleared, so the session value is never sent to the UI.
// The result is nil when there are no scanners.
func GetAllScanners() []structs.Scanner {
	var stored []structs.Scanner
	database.DB.Find(&stored)

	var sanitized []structs.Scanner
	for i := range stored {
		// Blank the session before the value leaves this function.
		stored[i].Session = ""
		sanitized = append(sanitized, stored[i])
	}
	return sanitized
}
// isUsernameAvailable reports whether no user row has the given username.
// Only the username column is fetched; an empty result means the name is free.
func isUsernameAvailable(username string) bool {
	var existing structs.User
	database.DB.Select("username").Where("username = ?", username).Find(&existing)
	taken := existing.Username != ""
	return !taken
}
// isEmailAvailable reports whether no user row has the given email address.
// Only the email column is fetched; an empty result means the address is free.
func isEmailAvailable(email string) bool {
	var existing structs.User
	database.DB.Select("email").Where("email = ?", email).Find(&existing)
	taken := existing.Email != ""
	return !taken
}
// UpdateUserProfile applies profile changes requested over the websocket
// connection conn. The acting user and session come from
// conn.Locals("userId") and conn.Locals("sessionId").
//
// Keys read from changes (values must be strings; values of any other type
// are ignored):
//   - "username": applied if isUsernameLengthValid accepts it and it is not
//     already taken.
//   - "email": applied if it is not already taken.
//   - "oldPassword" and "newPassword": base64-encoded, both required. The new
//     password is bcrypt-hashed and stored only when the old password matches
//     the stored hash.
//
// The result map uses 0 for an applied change and 1 for a rejected one:
// "Username"/"Email" already taken, or "Password" incorrect.
//
// Only when the password was actually changed is every connection and stored
// session of the user closed and deleted. Otherwise the requesting session
// receives the applied changes plus the result map. In both cases all other
// sessions are sent the applied changes.
func UpdateUserProfile(conn *websocket.Conn, changes map[string]interface{}) {
	sessionId := conn.Locals("sessionId").(string)
	userId := conn.Locals("userId").(string)

	var user structs.User
	updates := make(map[string]interface{})
	changesResult := make(map[string]uint8)

	// changes comes straight from the client, so use checked type assertions.
	// A value of an unexpected type is ignored instead of panicking the handler.
	username, hasUsername := changes["username"].(string)
	email, hasEmail := changes["email"].(string)
	oldPassword, hasOldPassword := changes["oldPassword"].(string)
	newPassword, hasNewPassword := changes["newPassword"].(string)
	hasPasswordChange := hasOldPassword && hasNewPassword
	passwordChanged := false

	// The web UI enforces min/max length, so the length check only rejects
	// manipulated requests.
	if hasUsername && isUsernameLengthValid(username) {
		if isUsernameAvailable(username) {
			user.Username = username
			updates["Username"] = username
			changesResult["Username"] = 0
		} else {
			changesResult["Username"] = 1
		}
	}

	if hasEmail {
		if isEmailAvailable(email) {
			user.Email = email
			updates["Email"] = email
			changesResult["Email"] = 0
		} else {
			changesResult["Email"] = 1
		}
	}

	if hasPasswordChange {
		decodedOldPassword, errOld := base64.StdEncoding.DecodeString(oldPassword)
		decodedNewPassword, errNew := base64.StdEncoding.DecodeString(newPassword)
		switch {
		case errOld != nil || errNew != nil:
			if errOld != nil {
				log.Error().Msgf("Failed to decode old password %s", errOld.Error())
			}
			if errNew != nil {
				log.Error().Msgf("Failed to decode new password %s", errNew.Error())
			}
		case !utils.IsPasswordLengthValid(string(decodedOldPassword)) ||
			!utils.IsPasswordLengthValid(string(decodedNewPassword)):
			// The web UI enforces min/max length, so only manipulated
			// requests reach this case. The new password is validated as
			// well, so a manipulated request cannot store one outside the
			// limits.
		default:
			// Read the stored hash into a separate value so that `user` only
			// carries a password when a new one is actually set.
			var stored structs.User
			database.DB.Select("password").First(&stored, "id = ?", userId)
			if err := bcrypt.CompareHashAndPassword([]byte(stored.Password), decodedOldPassword); err != nil {
				log.Error().Msg("Incorrect password")
				changesResult["Password"] = 1
			} else if hashedPassword, err := bcrypt.GenerateFromPassword(decodedNewPassword, bcrypt.DefaultCost); err != nil {
				log.Error().Msgf("Failed to generate hash password %s", err.Error())
			} else {
				user.Password = string(hashedPassword)
				passwordChanged = true
			}
		}
	}

	if len(changes) == 0 {
		return
	}

	// NOTE(review): this relies on Updates(struct) skipping zero-value fields
	// (GORM behaviour). The username/email-only path already depends on that,
	// since `user` then has no password set.
	user.UpdatedAt = time.Now()
	database.DB.Model(&structs.User{}).Where("id = ?", userId).Updates(user)

	if !hasUsername && !hasEmail && !hasPasswordChange {
		return
	}

	if passwordChanged {
		// The password really changed: log the user out of every session.
		CloseAndDeleteAllUserConnections(userId)
	} else {
		// No logout is required (this includes a rejected password change),
		// so report the outcome back to the requesting session.
		SendMessageOnlyToSessionId(sessionId, structs.SendSocketMessage{
			Cmd: utils.SentCmdUserProfileUpdated,
			Body: struct {
				UserId  string
				Changes map[string]interface{}
				Result  map[string]uint8
			}{
				UserId:  userId,
				Changes: updates,
				Result:  changesResult,
			},
		})
	}

	BroadcastMessageExceptUserSessionId(sessionId, structs.SendSocketMessage{
		Cmd: utils.SentCmdUserProfileUpdated,
		Body: struct {
			UserId  string
			Changes map[string]interface{}
		}{
			UserId:  userId,
			Changes: updates,
		},
	})
}
// isUsernameLengthValid reports whether the byte length of username lies
// strictly between utils.MinUsername and utils.MaxUsername (both bounds
// exclusive).
// NOTE(review): len counts bytes, not characters, and both bounds are
// exclusive. Confirm this matches the limits the web UI enforces.
func isUsernameLengthValid(username string) bool {
	n := len(username)
	if n <= utils.MinUsername {
		return false
	}
	return n < utils.MaxUsername
}
// GetAllRoles returns every role stored in the database.
func GetAllRoles() (roles []structs.Role) {
	database.DB.Find(&roles)
	return
}
// GetPermissionsByRoleId returns the permission ids assigned to roleId, in the
// order the role_permission rows are returned by the database. It returns nil
// when the role has no permissions.
func GetPermissionsByRoleId(roleId string) []string {
	var rows []structs.RolePermission
	database.DB.Where("role_id = ?", roleId).Find(&rows)

	var permissionIds []string
	for i := range rows {
		permissionIds = append(permissionIds, rows[i].PermissionId)
	}
	return permissionIds
}
// GetAdminAreaRolesPermissions returns one entry per role. Each entry holds
// the role id and the permission ids assigned to it. All role_permission rows
// are loaded with a single query and matched to the roles in memory; a role
// without permissions gets a nil Permissions slice.
func GetAdminAreaRolesPermissions() []structs.RolePermissions {
	roles := GetAllRoles()

	var allRolePermissions []structs.RolePermission
	database.DB.Find(&allRolePermissions)
	log.Debug().Msgf("rolePermissions: %v", allRolePermissions)

	var result []structs.RolePermissions
	for r := range roles {
		roleId := roles[r].Id

		var permissions []string
		for p := range allRolePermissions {
			if allRolePermissions[p].RoleId == roleId {
				permissions = append(permissions, allRolePermissions[p].PermissionId)
			}
		}
		log.Debug().Msgf("permissions %v", permissions)

		result = append(result, structs.RolePermissions{
			RoleId:      roleId,
			Permissions: permissions,
		})
	}
	log.Debug().Msgf("role permissions: %v", result)
	return result
}