package socketclients import ( "encoding/base64" "janex/admin-dashboard-backend/modules/cache" "janex/admin-dashboard-backend/modules/database" "janex/admin-dashboard-backend/modules/structs" "janex/admin-dashboard-backend/modules/utils" "time" "github.com/gofiber/websocket/v2" "github.com/rs/zerolog/log" "golang.org/x/crypto/bcrypt" ) func BroadcastMessage(sendSocketMessage structs.SendSocketMessage) { for _, client := range cache.GetSocketClients() { client.SendMessage(sendSocketMessage) } } func BroadcastMessageExceptUserSessionId(ignoreUserSessionId string, sendSocketMessage structs.SendSocketMessage) { for _, client := range cache.GetSocketClients() { if client.SessionId != ignoreUserSessionId { client.SendMessage(sendSocketMessage) } } } func UpdateConnectedUsers(userId string) { var user structs.User database.DB.First(&user, "id = ?", userId) BroadcastMessage(structs.SendSocketMessage{ Cmd: utils.SentCmdUpdateConnectedUsers, Body: struct { WebSocketUsersCount int UserId string ConnectionStatus uint8 LastOnline time.Time }{ WebSocketUsersCount: len(cache.GetSocketClients()), UserId: userId, ConnectionStatus: isUserGenerallyConnected(userId), LastOnline: user.LastOnline, }, }) } func SendMessageToUser(userId string, ignoreUserSessionId string, sendSocketMessage structs.SendSocketMessage) { for _, client := range cache.GetSocketClients() { if client.UserId == userId && client.SessionId != ignoreUserSessionId { client.SendMessage(sendSocketMessage) } } } func SendMessageOnlyToSessionId(sessionId string, sendSocketMessage structs.SendSocketMessage) { for _, client := range cache.GetSocketClients() { if client.SessionId == sessionId { client.SendMessage(sendSocketMessage) } } } // This close all connections that are connected with one session id. // For example when a user has two browser tabs opened func CloseAllUserSessionConnections(sessionId string) { for _, client := range cache.GetSocketClients() { if client.SessionId == sessionId { client.SendSessionClosedMessage() } } } // Used to close all user connections // For example when a user changed his password func CloseAndDeleteAllUserConnections(userId string) { for _, client := range cache.GetSocketClients() { if client.UserId == userId { client.SendSessionClosedMessage() } } database.DB.Where("user_id = ?", userId).Delete(&structs.UserSession{}) } func GetUserSessions(userId string) []structs.UserSessionSocket { var userSessions []structs.UserSession database.DB.Where("user_id = ?", userId).Find(&userSessions) var userSessionsSocket []structs.UserSessionSocket socketClients := cache.GetSocketClients() for _, userSession := range userSessions { userSessionsSocket = append(userSessionsSocket, structs.UserSessionSocket{ IdForDeletion: userSession.IdForDeletion, UserAgent: userSession.UserAgent, ConnectionStatus: isUserSessionConnected(userSession.Id, socketClients), LastUsed: userSession.LastUsed, ExpiresAt: userSession.ExpiresAt, }) } return userSessionsSocket } func UpdateUserSessionsForUser(userId string, ignoreUserSessionId string) { GetUserSessions(userId) SendMessageToUser(userId, ignoreUserSessionId, structs.SendSocketMessage{ Cmd: utils.SentCmdUpdateUserSessions, Body: GetUserSessions(userId), }) } func isUserSessionConnected(userSessionId string, socketClients []*structs.SocketClient) uint8 { for _, socketClient := range socketClients { if socketClient.SessionId == userSessionId { return 1 } } return 0 } // Used to determine if a user is connected regardless of the session used func isUserGenerallyConnected(userId string) uint8 { for _, socketClient := range cache.GetSocketClients() { if socketClient.UserId == userId { return 1 } } return 0 } // Get all users from database. // This is used in the UI to display all users. func GetAllUsers() []structs.AllUsers { var users []structs.User var allUsers []structs.AllUsers database.DB.Find(&users) for _, user := range users { allUsers = append(allUsers, structs.AllUsers{ Id: user.Id, RoleId: user.RoleId, Avatar: user.Avatar, Username: user.Username, ConnectionStatus: isUserGenerallyConnected(user.Id), LastOnline: user.LastOnline, }) } return allUsers } func GetAllScanners() []structs.Scanner { var scanners []structs.Scanner var allScanners []structs.Scanner database.DB.Find(&scanners) for _, scanner := range scanners { // clear session to prevent leaking and sending to ui scanner.Session = "" allScanners = append(allScanners, scanner) } return allScanners } func isUsernameAvailable(username string) bool { var user structs.User database.DB.Select("username").Where("username = ?", username).Find(&user) return user.Username == "" } func isEmailAvailable(email string) bool { var user structs.User database.DB.Select("email").Where("email = ?", email).Find(&user) return user.Email == "" } func UpdateUserProfile(conn *websocket.Conn, changes map[string]interface{}) { sessionId := conn.Locals("sessionId").(string) userId := conn.Locals("userId").(string) var user structs.User var updates = make(map[string]interface{}) var changesResult = make(map[string]uint8) if changes["username"] != nil { username := changes["username"].(string) if isUsernameLengthValid(username) { // only affected if username was manipulated as min and max is provided in web ui if isUsernameAvailable(username) { user.Username = username updates["Username"] = username changesResult["Username"] = 0 } else { changesResult["Username"] = 1 } } } if changes["email"] != nil { email := changes["email"].(string) if isEmailAvailable(email) { user.Email = email updates["Email"] = email changesResult["Email"] = 0 } else { changesResult["Email"] = 1 } } if changes["oldPassword"] != nil && changes["newPassword"] != nil { oldPassword := changes["oldPassword"].(string) newPassword := changes["newPassword"].(string) decodedOldPassword, err := base64.StdEncoding.DecodeString(oldPassword) decodedNewPassword, err1 := base64.StdEncoding.DecodeString(newPassword) if err == nil && err1 == nil { if utils.IsPasswordLengthValid(string(decodedOldPassword)) { // only affected if username was manipulated as min and max is provided in web ui database.DB.Select("password").First(&user, "id = ?", userId) if err := bcrypt.CompareHashAndPassword([]byte(user.Password), decodedOldPassword); err == nil { hashedPassword, err := bcrypt.GenerateFromPassword([]byte(decodedNewPassword), bcrypt.DefaultCost) if err == nil { user.Password = string(hashedPassword) } else { log.Error().Msgf("Failed to generate hash password %s", err.Error()) } } else { log.Error().Msg("Incorrect password") changesResult["Password"] = 1 } } } else { if err != nil { log.Error().Msgf("Failed to decode old password %s", err.Error()) } if err1 != nil { log.Error().Msgf("Failed to decode new password %s", err1.Error()) } } } if len(changes) > 0 { user.UpdatedAt = time.Now() database.DB.Model(&structs.User{}).Where("id = ?", userId).Updates(user) if changes["username"] != nil || changes["email"] != nil || changes["oldPassword"] != nil && changes["newPassword"] != nil { if changes["oldPassword"] != nil && changes["newPassword"] != nil { // user has changed password - logout all his sessions CloseAndDeleteAllUserConnections(userId) } else { SendMessageOnlyToSessionId(sessionId, structs.SendSocketMessage{ Cmd: utils.SentCmdUserProfileUpdated, Body: struct { UserId string Changes map[string]interface{} Result map[string]uint8 }{ UserId: userId, Changes: updates, Result: changesResult, }, }) } BroadcastMessageExceptUserSessionId(sessionId, structs.SendSocketMessage{ Cmd: utils.SentCmdUserProfileUpdated, Body: struct { UserId string Changes map[string]interface{} }{ UserId: userId, Changes: updates, }, }) } } } func isUsernameLengthValid(username string) bool { l := len(username) return l > utils.MinUsername && l < utils.MaxUsername } func GetAllRoles() []structs.Role { var roles []structs.Role database.DB.Find(&roles) return roles } func GetPermissionsByRoleId(roleId string) []string { var rolePermissions []structs.RolePermission database.DB.Where("role_id = ?", roleId).Find(&rolePermissions) var permissions []string for _, rolePermission := range rolePermissions { permissions = append(permissions, rolePermission.PermissionId) } return permissions } // Retrieve all roles with a list of all permissions for each role func GetAdminAreaRolesPermissions() []structs.RolePermissions { roles := GetAllRoles() var rolesPermissions []structs.RolePermission database.DB.Find(&rolesPermissions) log.Debug().Msgf("rolePermissions: %v", rolesPermissions) var rolePermissions []structs.RolePermissions for _, role := range roles { var permissions []string for _, rolePermission := range rolesPermissions { if rolePermission.RoleId == role.Id { permissions = append(permissions, rolePermission.PermissionId) } } log.Debug().Msgf("permissions %v", permissions) rolePermissions = append(rolePermissions, structs.RolePermissions{ RoleId: role.Id, Permissions: permissions, }) } log.Debug().Msgf("role permissions: %v", rolePermissions) return rolePermissions }