diff --git a/routers/api/v1/user/user.go b/routers/api/v1/user/user.go index d67cc77..6f1b10e 100644 --- a/routers/api/v1/user/user.go +++ b/routers/api/v1/user/user.go @@ -4,7 +4,7 @@ import ( "crypto/rand" "database/sql" "encoding/base64" - "fmt" + "errors" "math/big" "regexp" "strings" @@ -12,11 +12,14 @@ import ( "unicode" "git.umbach.dev/app-idea/rest-api/modules/database" + "git.umbach.dev/app-idea/rest-api/modules/serversettings" + "git.umbach.dev/app-idea/rest-api/modules/structs" "github.com/gofiber/fiber/v2" "github.com/google/uuid" ua "github.com/mileusna/useragent" log "github.com/sirupsen/logrus" "golang.org/x/crypto/bcrypt" + "gorm.io/gorm" ) type LoginInput struct { @@ -27,7 +30,7 @@ type LoginInput struct { } func NewUser(c *fiber.Ctx) error { - // swagger:operation POST /users usersNewUser + // swagger:operation POST /users User usersNewUser // --- // summary: Create new user // produces: @@ -50,7 +53,7 @@ func NewUser(c *fiber.Ctx) error { // required: true // - name: hashtag // in: query - // description: hashtag of the client (length 2-6, UPPERCASE) + // description: hashtag of the client (length 2-6, UPPERCASE (Letters, Numbers)) // type: string // - name: avatar_url // in: query @@ -63,6 +66,7 @@ func NewUser(c *fiber.Ctx) error { // responses: // '201': // description: user created + // "$ref": "#/definitions/User" // '400': // description: format is not correct // '422': @@ -88,25 +92,24 @@ func NewUser(c *fiber.Ctx) error { return c.SendStatus(fiber.StatusForbidden) } - db, err := database.GetDatabase() + user := structs.User{} - if db == nil || err != nil { - return c.SendStatus(fiber.StatusInternalServerError) - } + user.Email = input.Email - defer db.Close() + //db = database.DB - if !isEmailAvailable(db, input.Email) { + if !isEmailAvailable(database.DB, user, input.Email) { return c.SendStatus(fiber.StatusUnprocessableEntity) } if input.Hashtag == "" { - input.Hashtag, err = generateRandomHashtag(db, 6) + input.Hashtag, err = generateRandomString(6) + //input.Hashtag, err = generateRandomHashtag(database.DB, 6) if err != nil { return c.SendStatus(fiber.StatusInternalServerError) } - } else if !isHashtagValid(db, input.Hashtag) { + } else if !isHashtagValid(database.DB, input.Hashtag) { return c.SendStatus(fiber.StatusUnprocessableEntity) } @@ -118,19 +121,32 @@ func NewUser(c *fiber.Ctx) error { } userId := strings.Replace(uuid.New().String(), "-", "", -1) - created := time.Now().Format("2006-01-02 15:04:05") // YYYY-MM-DD hh:mm:ss + //created := time.Now().Format("2006-01-02 15:04:05") // YYYY-MM-DD hh:mm:ss - stmt, err := db.Prepare("INSERT INTO users (user_id, user_hashtag, username, email, password, created) VALUES (?, ?, ?, ?, ?, ?);") - stmt.Exec(userId, input.Hashtag, input.Username, input.Email, hashedPassword, created) + //user := structs.User{ID: userId, Hashtag: input.Hashtag, Name: input.Username, Email: input.Email, Password: string(hashedPassword), LastLogin: time.Now(), CreatedAt: time.Now()} - stmt.Close() + user.ID = userId + user.Hashtag = input.Hashtag + user.Name = input.Username + user.Password = string(hashedPassword) + user.LastLogin = time.Now() + user.CreatedAt = time.Now() + + res := database.DB.Create(&user) + + log.Infoln("inserted", res, user) + + //stmt, err := db.Prepare("INSERT INTO users (user_id, user_hashtag, username, email, password, created) VALUES (?, ?, ?, ?, ?, ?);") + //stmt.Exec(userId, input.Hashtag, input.Username, input.Email, hashedPassword, created) + + //stmt.Close() if err != nil { log.Warnln("Failed to insert user to db:", err.Error()) return c.SendStatus(fiber.StatusInternalServerError) } - sessionId, err := createUserSession(db, userId, c.IP(), string(c.Context().UserAgent())) + sessionId, err := createUserSession(database.DB, userId, c.IP(), string(c.Context().UserAgent())) if err != nil { return c.SendStatus(fiber.StatusInternalServerError) @@ -162,7 +178,8 @@ func generateRandomString(n int) (string, error) { return string(r), nil } -func generateRandomHashtag(db *sql.DB, n int) (string, error) { +/* +func generateRandomHashtag(db *gorm.DB, n int) (string, error) { c := make(chan bool) var s string var err error @@ -176,13 +193,15 @@ func generateRandomHashtag(db *sql.DB, n int) (string, error) { } go func() { - err := db.QueryRow("SELECT user_hashtag FROM users WHERE user_hashtag = ?", s).Scan(&s) + /*err := db.QueryRow("SELECT user_hashtag FROM users WHERE user_hashtag = ?", s).Scan(&s) if err == sql.ErrNoRows { c <- true } else { c <- false } + + c <- true }() if msg := <-c; msg { @@ -191,19 +210,21 @@ func generateRandomHashtag(db *sql.DB, n int) (string, error) { } return s, nil -} +} */ -func isHashtagValid(db *sql.DB, h string) bool { +func isHashtagValid(db *gorm.DB, h string) bool { if !isUpper(h) || len(h) < 2 || len(h) > 6 { return false } + /* + err := db.QueryRow("SELECT user_hashtag FROM users WHERE user_hashtag = ?", h).Scan(&h) - err := db.QueryRow("SELECT user_hashtag FROM users WHERE user_hashtag = ?", h).Scan(&h) + if err == sql.ErrNoRows { + return true + } + return false */ - if err == sql.ErrNoRows { - return true - } - return false + return true } func isUpper(s string) bool { @@ -231,13 +252,24 @@ func isEmailValid(e string) bool { return emailRegex.MatchString(e) } -func isEmailAvailable(db *sql.DB, e string) bool { - err := db.QueryRow("SELECT email FROM users WHERE email = ?", e).Scan(&e) +func isEmailAvailable(db *gorm.DB, user structs.User, email string) bool { + log.Infoln("isEmailAvailable email", user.Email) - if err == sql.ErrNoRows { + err := db.First(&user, "email = ?", user.Email).Error + + log.Warnln("isEmailErr", errors.Is(err, gorm.ErrRecordNotFound)) + + if errors.Is(err, gorm.ErrRecordNotFound) { return true } return false + + /*err := db.QueryRow("SELECT email FROM users WHERE email = ?", e).Scan(&e) + + if err == sql.ErrNoRows { + return true + } */ + //return false } func SessionIdCheck(c *fiber.Ctx) error { @@ -261,23 +293,18 @@ func SessionIdCheck(c *fiber.Ctx) error { } func isSessionIdValid(sessionId string) (bool, error) { - db, err := database.GetDatabase() - if db == nil || err != nil { - return false, err - } + //defer db.Close() - defer db.Close() + deleteExpiredSessions(database.DB) - deleteExpiredSessions(db) - - err = db.QueryRow("SELECT session_id FROM sessions WHERE session_id = ?", sessionId).Scan(&sessionId) + /*err = db.QueryRow("SELECT session_id FROM sessions WHERE session_id = ?", sessionId).Scan(&sessionId) if err == sql.ErrNoRows { return false, nil } - /*sessionExpires, err := time.Parse("2006-01-02 15:04:05", expires) + sessionExpires, err := time.Parse("2006-01-02 15:04:05", expires) log.Infoln("expires", expires, time.Now().Add(time.Hour*72).Unix(), sessionExpires.Unix()) @@ -306,15 +333,21 @@ func deleteSession(db *sql.DB, sessionId string) { } } -func deleteExpiredSessions(db *sql.DB) { - _, err := db.Exec("DELETE FROM `sessions` WHERE expires < CURRENT_TIMESTAMP()") +func deleteExpiredSessions(db *gorm.DB) { + err := db.Exec("DELETE FROM sessions WHERE expires < CURRENT_TIMESTAMP()") + + if err.Error != nil { + log.Warnln("err deleting expired sessions:", err.Error) + } + + /*_, err := db.Exec("DELETE FROM `sessions` WHERE expires < CURRENT_TIMESTAMP()") if err != nil { log.Warnln("err deleting expired sessions:", err) - } + } */ } -func createUserSession(db *sql.DB, userId string, ip string, userAgent string) (string, error) { +func createUserSession(db *gorm.DB, userId string, ip string, userAgent string) (string, error) { sessionId, err := generateRandomString(32) if err != nil { @@ -322,29 +355,40 @@ func createUserSession(db *sql.DB, userId string, ip string, userAgent string) ( return "", err } - stmt, err := db.Prepare("INSERT INTO sessions (user_id, session_id, ip, user_agent, last_login, expires) VALUES (?, ?, ?, ?, ?, ?);") + ua := ua.Parse(userAgent) + session := structs.Session{UserId: userId, SessionId: sessionId, IP: ip, UserAgent: ua.OS + " " + ua.Name, LastLogin: time.Now(), Expires: getExpiresTime()} + + res := db.Create(&session) + + log.Infoln("res", res.Error) + + if res.Error != nil { + log.Warnln("Failed to create session:", res.Error) + return "", err + } + + /*stmt, err := db.Prepare("INSERT INTO sessions (user_id, session_id, ip, user_agent, last_login, expires) VALUES (?, ?, ?, ?, ?, ?);") if err != nil { log.Warnln("Failed to insert user into db:", err) return "", err - } + } */ - ua := ua.Parse(userAgent) + //ua := ua.Parse(userAgent) - stmt.Exec(userId, sessionId, ip, ua.OS+" "+ua.Name, time.Now(), getExpiresTime()) - stmt.Close() + //stmt.Exec(userId, sessionId, ip, ua.OS+" "+ua.Name, time.Now(), getExpiresTime()) + //stmt.Close() return sessionId, nil } func getExpiresTime() time.Time { // TODO: db default - - return time.Now().Add(time.Hour * 72) + return time.Now().Add(time.Hour * time.Duration(serversettings.Settings.ExpiredTime)) } func Login(c *fiber.Ctx) error { - // swagger:operation POST /user/login userLogin + // swagger:operation POST /user/login User userLogin // --- // summary: Login a user // produces: @@ -387,37 +431,39 @@ func Login(c *fiber.Ctx) error { input.Password = string(decodedPassword) - db, err := database.GetDatabase() - - if db == nil || err != nil { - return c.SendStatus(fiber.StatusInternalServerError) - } - - defer db.Close() - - var userId string - var username string - var userHashtag string - var hashedPassword string - - if input.Username != "" { + /*if input.Username != "" { err = db.QueryRow("SELECT user_id, user_hashtag, password FROM users WHERE username = ?", input.Username).Scan(&userId, &userHashtag, &hashedPassword) } else { err = db.QueryRow("SELECT user_id, user_hashtag, username, password FROM users WHERE email = ?", input.Email).Scan(&userId, &userHashtag, &username, &hashedPassword) + } */ + /* + if err != nil { + return c.SendStatus(fiber.StatusUnauthorized) + } */ + + db := database.DB + user := structs.User{} + + if input.Username != "" { + db.Select("id, hashtag, password").Where("name = ?", input.Username).Find(&user) + + log.Infoln("a", user) + } else { + db.Select("id, hashtag, name, password").Where("email = ?", input.Email).Find(&user) + + log.Infoln("a", user) } - if err != nil { - return c.SendStatus(fiber.StatusUnauthorized) - } + log.Infoln("pass", input.Password, user.Password) - err = bcrypt.CompareHashAndPassword([]byte(hashedPassword), []byte(input.Password)) + err = bcrypt.CompareHashAndPassword([]byte(user.Password), []byte(input.Password)) if err != nil { log.Warnln("Failed to comapare bcrypt password", err) return c.SendStatus(fiber.StatusUnauthorized) } - sessionId, err := createUserSession(db, userId, c.IP(), string(c.Context().UserAgent())) + sessionId, err := createUserSession(database.DB, user.ID, c.IP(), string(c.Context().UserAgent())) if err != nil { return c.SendStatus(fiber.StatusInternalServerError) @@ -426,10 +472,10 @@ func Login(c *fiber.Ctx) error { expires := getExpiresTime() c.Cookie(&fiber.Cookie{Name: "session_id", Value: sessionId, Secure: true, HTTPOnly: true, Expires: expires}) - if username != "" { - c.Cookie(&fiber.Cookie{Name: "username", Value: username, Secure: true, Expires: expires}) + if user.Name != "" { + c.Cookie(&fiber.Cookie{Name: "name", Value: user.Name, Secure: true, Expires: expires}) } - c.Cookie(&fiber.Cookie{Name: "user_hashtag", Value: userHashtag, Secure: true, Expires: expires}) + c.Cookie(&fiber.Cookie{Name: "hashtag", Value: user.Hashtag, Secure: true, Expires: expires}) return c.SendStatus(fiber.StatusCreated) } @@ -439,37 +485,23 @@ func GetUser(c *fiber.Ctx) error { } func GetUsers(c *fiber.Ctx) error { - cookie := c.Cookies("session_id") - - log.Infoln("cookies", cookie) - - db, err := database.GetDatabase() - - // c.Params("id") - - if db == nil || err != nil { - return c.SendStatus(fiber.StatusInternalServerError) - } - - defer db.Close() list := []string{} - - var ( - name string - ) - - rows, err := db.Query("SELECT username FROM users;") - fmt.Println("err", err) - defer rows.Close() - fmt.Println("reading data:") - for rows.Next() { - err := rows.Scan(&name) - fmt.Printf("Data row = (%s, %s)\n", name, err) - list = append(list, name) - } - err = rows.Err() - fmt.Println("Done") + /* + var ( + name string + )*/ + /* + rows, err := db.Query("SELECT username FROM users;") + fmt.Println("err", err) + defer rows.Close() + fmt.Println("reading data:") + for rows.Next() { + err := rows.Scan(&name) + fmt.Printf("Data row = (%s, %s)\n", name, err) + list = append(list, name) + } + err = rows.Err()*/ return c.JSON(list) }