diff --git a/go.mod b/go.mod index 566d4d7..f958784 100644 --- a/go.mod +++ b/go.mod @@ -5,6 +5,8 @@ go 1.19 require ( git.clickandjoin.umbach.dev/ClickandJoin/go-rabbitmq-client v1.0.24 // indirect github.com/andybalholm/brotli v1.0.4 // indirect + github.com/cespare/xxhash/v2 v2.2.0 // indirect + github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f // indirect github.com/fasthttp/websocket v1.5.0 // indirect github.com/gocql/gocql v1.3.1 // indirect github.com/gofiber/fiber/v2 v2.40.1 // indirect @@ -17,6 +19,7 @@ require ( github.com/mattn/go-isatty v0.0.16 // indirect github.com/mattn/go-runewidth v0.0.14 // indirect github.com/rabbitmq/amqp091-go v1.5.0 // indirect + github.com/redis/go-redis/v9 v9.0.2 // indirect github.com/rivo/uniseg v0.4.3 // indirect github.com/savsgio/gotils v0.0.0-20220530130905-52f3993e8d6d // indirect github.com/scylladb/go-reflectx v1.0.1 // indirect diff --git a/go.sum b/go.sum index 213120c..ec4ba12 100644 --- a/go.sum +++ b/go.sum @@ -44,8 +44,12 @@ github.com/andybalholm/brotli v1.0.4 h1:V7DdXeJtZscaqfNuAdSRuRFzuiKlHSC/Zh3zl9qY github.com/andybalholm/brotli v1.0.4/go.mod h1:fO7iG3H7G2nSZ7m0zPUDn85XEX2GTukHGRSepvi9Eig= github.com/bitly/go-hostpool v0.0.0-20171023180738-a3a6125de932/go.mod h1:NOuUCSz6Q9T7+igc/hlvDOUdtWKryOrtFyIVABv/p7k= github.com/bmizerany/assert v0.0.0-20160611221934-b7ed37b82869/go.mod h1:Ekp36dRnpXw/yCqJaO+ZrUyxD+3VXMFFr56k5XYrpB4= +github.com/cespare/xxhash/v2 v2.2.0 h1:DC2CZ1Ep5Y4k3ZQ899DldepgrayRUGE6BBZ/cd9Cj44= +github.com/cespare/xxhash/v2 v2.2.0/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs= github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f h1:lO4WD4F/rVNCu3HqELle0jiPLLBs70cWOduZpkS1E78= +github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f/go.mod h1:cuUVRXasLTGF7a8hSLbxyZXjz+1KgoB3wDUb6vlszIc= github.com/fasthttp/websocket v1.5.0 h1:B4zbe3xXyvIdnqjOZrafVFklCUq5ZLo/TqCt5JA1wLE= github.com/fasthttp/websocket v1.5.0/go.mod h1:n0BlOQvJdPbTuBkZT0O5+jk/sp/1/VCzquR1BehI2F4= github.com/gocql/gocql v1.3.1 h1:BTwM4rux+ah5G3oH6/MQa+tur/TDd/XAAOXDxBBs7rg= @@ -78,6 +82,8 @@ github.com/mattn/go-runewidth v0.0.14/go.mod h1:Jdepj2loyihRzMpdS35Xk/zdY8IAYHsh github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= github.com/rabbitmq/amqp091-go v1.5.0 h1:VouyHPBu1CrKyJVfteGknGOGCzmOz0zcv/tONLkb7rg= github.com/rabbitmq/amqp091-go v1.5.0/go.mod h1:JsV0ofX5f1nwOGafb8L5rBItt9GyhfQfcJj+oyz0dGg= +github.com/redis/go-redis/v9 v9.0.2 h1:BA426Zqe/7r56kCcvxYLWe1mkaz71LKF77GwgFzSxfE= +github.com/redis/go-redis/v9 v9.0.2/go.mod h1:/xDTe9EF1LM61hek62Poq2nzQSGj0xSrEtEHbBQevps= github.com/rivo/uniseg v0.2.0 h1:S1pD9weZBuJdFmowNwbpi7BJ8TNftyUImj/0WQi72jY= github.com/rivo/uniseg v0.2.0/go.mod h1:J6wj4VEh+S6ZtnVlnTBMWIodfgj8LQOQFoIToxlJtxc= github.com/rivo/uniseg v0.4.3 h1:utMvzDsuh3suAEnhH0RdHmoPbU648o6CvXxTx4SBMOw= diff --git a/main.go b/main.go index 67325e8..877112c 100644 --- a/main.go +++ b/main.go @@ -6,6 +6,7 @@ import ( "clickandjoin.app/websocketserver/modules/config" "clickandjoin.app/websocketserver/modules/rabbitmq" + "clickandjoin.app/websocketserver/modules/redis" "clickandjoin.app/websocketserver/modules/scylladb" "clickandjoin.app/websocketserver/modules/structs" "clickandjoin.app/websocketserver/modules/utils" @@ -27,6 +28,8 @@ func init() { scylladb.InitDatabase() + redis.Init() + go rabbitmq.Init() } @@ -41,27 +44,37 @@ func main() { // IsWebSocketUpgrade returns true if the client // requested upgrade to the WebSocket protocol. if websocket.IsWebSocketUpgrade(c) { - wsSession := c.Query("auth") + wsSessionId := c.Query("auth") // no auth query available - if len(wsSession) != utils.LenWebSocketSession { + if len(wsSessionId) != utils.LenWebSocketSession { return c.SendStatus(fiber.StatusUnauthorized) } // validate ws session - foundWsSession := structs.UserWebSocketSession{Id: wsSession} + foundWsSession := structs.UserWebSocketSession{Id: wsSessionId} - q := scylladb.Session.Query(scylladb.WebSocketSessions.Get("id")).BindStruct(foundWsSession) + q := scylladb.Session.Query(scylladb.WebSocketSessions.Get("id", "user_id")).BindStruct(foundWsSession) if err := q.GetRelease(&foundWsSession); err != nil { - logrus.Errorln("Failed to find ws session:", wsSession, "err:", err) + logrus.Errorln("Failed to find ws session:", wsSessionId, "err:", err) return c.SendStatus(fiber.StatusUnauthorized) } + // + + if redis.ExistsUserWebSocketSessionId(foundWsSession.UserId, wsSessionId) { + logrus.Println("ws id already in list") + // TODO: kick out the other connected socket with this session + } else { + redis.AddUserWebSocketSessionId(foundWsSession.UserId, wsSessionId) + } // TODO: Further security checks such as the change of IP, user agents or whether the session ID has already opened another connection. - c.Locals("allowed", true) + c.Locals("wsSessionId", wsSessionId) + c.Locals("userId", foundWsSession.UserId) + return c.Next() } diff --git a/modules/redis/redis.go b/modules/redis/redis.go new file mode 100644 index 0000000..cc9b00b --- /dev/null +++ b/modules/redis/redis.go @@ -0,0 +1,63 @@ +package redis + +import ( + "context" + + "github.com/redis/go-redis/v9" + "github.com/sirupsen/logrus" +) + +var ctx = context.Background() +var rdb *redis.Client + +func Init() { + rdb = redis.NewClient(&redis.Options{ + Addr: "localhost:6379", + Password: "", + DB: 0, + }) + + err := rdb.Ping(ctx).Err() + + if err != nil { + logrus.Fatalln("Redis ping failed") + } +} + +func AddUserWebSocketSessionId(userId string, wsSessionId string) { + cmd := rdb.LPush(ctx, userId, wsSessionId) + + logrus.Println("b", cmd) +} + +func RemoveUserWebSocketSessionId(userId string, wsSessionId string) { + cmd := rdb.LRem(ctx, userId, -1, wsSessionId) + + logrus.Println("rem", cmd) +} + +func IsUserConnectedToAnyWebSocketServer(userId string) bool { + cmd := rdb.Exists(ctx, userId) + + logrus.Println("exists b", cmd) + + return cmd.Val() == 1 +} + +func ExistsUserWebSocketSessionId(userId string, wsSessionId string) bool { + wsSessions := rdb.LRange(ctx, userId, 0, -1) + + logrus.Println("found ws", wsSessions.Val()) + + return isWsSessionIdInList(wsSessions.Val(), wsSessionId) +} + +func isWsSessionIdInList(wsSessions []string, wsSessionId string) bool { + for _, item := range wsSessions { + if item == wsSessionId { + return true + } + } + + return false +} diff --git a/modules/structs/socket.go b/modules/structs/socket.go index 7079182..9f08443 100644 --- a/modules/structs/socket.go +++ b/modules/structs/socket.go @@ -10,6 +10,7 @@ import ( ) type SocketClient struct { + UserId string Conn *websocket.Conn connMu sync.Mutex RabbitMqQueueName string diff --git a/socketserver/hub.go b/socketserver/hub.go index 3702147..4938c19 100644 --- a/socketserver/hub.go +++ b/socketserver/hub.go @@ -1,13 +1,15 @@ package socketserver import ( + "fmt" + "clickandjoin.app/websocketserver/modules/cache" "clickandjoin.app/websocketserver/modules/rabbitmq" + "clickandjoin.app/websocketserver/modules/redis" "clickandjoin.app/websocketserver/modules/structs" "clickandjoin.app/websocketserver/modules/utils" "clickandjoin.app/websocketserver/socketclients" "github.com/gofiber/websocket/v2" - "github.com/google/uuid" "github.com/sirupsen/logrus" ) @@ -19,21 +21,26 @@ func RunHub() { for { select { case newSocketClient := <-register: - uuid := uuid.New().String() + userId := fmt.Sprintf("%v", newSocketClient.Conn.Locals("userId")) + wsSessionId := fmt.Sprintf("%v", newSocketClient.Conn.Locals("wsSessionId")) - err := rabbitmq.CreateWSClientBinding(newSocketClient, uuid) + err := rabbitmq.CreateWSClientBinding(newSocketClient, userId) if err != nil { logrus.Errorln("Failed to create client binding, err:", err) break } - cache.SocketClients[uuid] = newSocketClient + newSocketClient.UserId = userId - logrus.Debugln("REGISTER CLIENT:", uuid) + cache.SocketClients[wsSessionId] = newSocketClient + + logrus.Println("clients:", len(cache.SocketClients), cache.SocketClients) + + logrus.Debugln("REGISTER CLIENT:", wsSessionId) // for testing - marshaled, err := utils.MarshalMessage(structs.SocketMessageTest{Cmd: 99999, Body: uuid}) + marshaled, err := utils.MarshalMessage(structs.SocketMessageTest{Cmd: 99999, Body: userId}) if err != nil { logrus.Errorln("Failed to marshal uuid, err:", err) @@ -53,7 +60,7 @@ func RunHub() { logrus.Debugln("RECEIVED WEBSOCKET MESSAGE:", receivedMessage) - if receivedMessage.Rec != "" { + if len(receivedMessage.Rec) == utils.LenWebSocketSession { isConnected, recSocketClient := socketclients.IsReceiverConnectedToThisServer(receivedMessage.Rec) // send message to target receiver when connected to this server @@ -63,6 +70,11 @@ func RunHub() { } else { logrus.Debugln("FORWARDING MESSAGE: receiver connected to other server") + if !redis.IsUserConnectedToAnyWebSocketServer(receivedMessage.Rec) { + logrus.Warnln("rec user not connected to any websocket server") + break + } + err = rabbitmq.PublishClientMessage(structs.RabbitMqMessage{Cmd: receivedMessage.Cmd, Rec: receivedMessage.Rec, Body: receivedMessage.Body}) if err != nil { @@ -79,6 +91,7 @@ func RunHub() { client.CancelFunc() rabbitmq.DeleteWSClient(client.RabbitMqConsumerId, client.RabbitMqQueueName) + redis.RemoveUserWebSocketSessionId(client.UserId, id) } } }