diff --git a/main.go b/main.go index b19214f..4f073a9 100644 --- a/main.go +++ b/main.go @@ -53,25 +53,21 @@ func main() { database.DB.First(&userSession, "id = ?", sessionId) - if userSession.Id == "" { - return c.SendStatus(fiber.StatusUnauthorized) + if userSession.Id != "" { + log.Info().Msg("session id: " + userSession.Id + " user: " + userSession.Id) + + var user structs.User + + database.DB.First(&user, "id = ?", userSession.UserId) + + if user.Id != "" { + log.Info().Msg("user " + user.Id + user.Username) + + c.Locals("sessionId", sessionId) + c.Locals("userId", user.Id) + } } - log.Info().Msg("session id: " + userSession.Id + " user: " + userSession.Id) - - var user structs.User - - database.DB.First(&user, "id = ?", userSession.UserId) - - if user.Id == "" { - return c.SendStatus(fiber.StatusInternalServerError) - } - - log.Info().Msg("user " + user.Id + user.Username) - - c.Locals("sessionId", sessionId) - c.Locals("userId", user.Id) - return c.Next() } diff --git a/modules/structs/socket.go b/modules/structs/socket.go index 2fbece1..fba6f18 100644 --- a/modules/structs/socket.go +++ b/modules/structs/socket.go @@ -35,6 +35,10 @@ func (socketClient *SocketClient) SendCloseMessage() error { return socketClient.writeMessage(websocket.CloseMessage, SendSocketMessage{}, true) } +func (socketClient *SocketClient) SendUnauthorizedCloseMessage() error { + return socketClient.writeMessage(websocket.CloseMessage, SendSocketMessage{}, true) +} + func (socketClient *SocketClient) SendMessage(message SendSocketMessage) error { return socketClient.writeMessage(websocket.TextMessage, message, false) } @@ -44,7 +48,8 @@ func (socketClient *SocketClient) writeMessage(messageType int, message SendSock var err error if closeMessage { - //marshaledMessage = websocket.FormatCloseMessage(utils.WsCloseCodeNewConnectionWasMade, "") + // Status codes in the range 4000-4999 are reserved for private use + marshaledMessage = websocket.FormatCloseMessage(4001, "") } else { marshaledMessage, err = json.Marshal(message) @@ -71,3 +76,8 @@ func (socketClient *SocketClient) writeMessage(messageType int, message SendSock return nil } + +type InitUserSocketConnection struct { + Username string + Email string +} diff --git a/modules/utils/globals.go b/modules/utils/globals.go index 482b96c..b66d382 100644 --- a/modules/utils/globals.go +++ b/modules/utils/globals.go @@ -14,6 +14,12 @@ const ( HeaderXAuthorization = "X-Authorization" ) +// commands sent to web clients +const ( + SentInitUserSocketConnection = 1 + SentCmdUpdateConnectedUsers = 2 +) + var ( generalRules = map[string]string{ "Username": "required,min=" + minUsername + ",max=" + maxUsername, diff --git a/socketclients/socketclients.go b/socketclients/socketclients.go index 5722889..ae310b2 100644 --- a/socketclients/socketclients.go +++ b/socketclients/socketclients.go @@ -3,11 +3,7 @@ package socketclients import ( "janex/admin-dashboard-backend/modules/cache" "janex/admin-dashboard-backend/modules/structs" -) - -// commands sent to web clients -const ( - SentCmdUpdateConnectedUsers = 1 + "janex/admin-dashboard-backend/modules/utils" ) func BroadcastMessage(sendSocketMessage structs.SendSocketMessage) { @@ -18,7 +14,7 @@ func BroadcastMessage(sendSocketMessage structs.SendSocketMessage) { func UpdateConnectedUsers() { BroadcastMessage(structs.SendSocketMessage{ - Cmd: SentCmdUpdateConnectedUsers, + Cmd: utils.SentCmdUpdateConnectedUsers, Body: len(cache.GetSocketClients()), }) } diff --git a/socketserver/hub.go b/socketserver/hub.go index 0d31067..2a15a04 100644 --- a/socketserver/hub.go +++ b/socketserver/hub.go @@ -4,7 +4,9 @@ import ( "encoding/json" "fmt" "janex/admin-dashboard-backend/modules/cache" + "janex/admin-dashboard-backend/modules/database" "janex/admin-dashboard-backend/modules/structs" + "janex/admin-dashboard-backend/modules/utils" "janex/admin-dashboard-backend/socketclients" "github.com/gofiber/websocket/v2" @@ -22,6 +24,12 @@ func RunHub() { userId := fmt.Sprintf("%v", newSocketClient.Conn.Locals("userId")) sessionId := fmt.Sprintf("%v", newSocketClient.Conn.Locals("sessionId")) + // close connection instantly if sessionId is empty + if sessionId == "" { + newSocketClient.SendUnauthorizedCloseMessage() + continue + } + newSocketClient.SessionId = sessionId newSocketClient.UserId = userId @@ -30,6 +38,18 @@ func RunHub() { log.Debug().Msgf("clients: %d", len(cache.GetSocketClients())) log.Debug().Msgf("REGISTER CLIENT: %s", sessionId) + var user structs.User + + database.DB.First(&user, "id = ?", userId) + + newSocketClient.SendMessage(structs.SendSocketMessage{ + Cmd: utils.SentInitUserSocketConnection, + Body: structs.InitUserSocketConnection{ + Username: user.Username, + Email: user.Email, + }, + }) + socketclients.UpdateConnectedUsers() case data := <-broadcast: