Switched to new database method

master
Alex 2021-05-28 22:09:14 +02:00
parent 5e806774ed
commit 45d1759fff
1 changed files with 132 additions and 100 deletions

View File

@ -4,7 +4,7 @@ import (
"crypto/rand" "crypto/rand"
"database/sql" "database/sql"
"encoding/base64" "encoding/base64"
"fmt" "errors"
"math/big" "math/big"
"regexp" "regexp"
"strings" "strings"
@ -12,11 +12,14 @@ import (
"unicode" "unicode"
"git.umbach.dev/app-idea/rest-api/modules/database" "git.umbach.dev/app-idea/rest-api/modules/database"
"git.umbach.dev/app-idea/rest-api/modules/serversettings"
"git.umbach.dev/app-idea/rest-api/modules/structs"
"github.com/gofiber/fiber/v2" "github.com/gofiber/fiber/v2"
"github.com/google/uuid" "github.com/google/uuid"
ua "github.com/mileusna/useragent" ua "github.com/mileusna/useragent"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
"golang.org/x/crypto/bcrypt" "golang.org/x/crypto/bcrypt"
"gorm.io/gorm"
) )
type LoginInput struct { type LoginInput struct {
@ -27,7 +30,7 @@ type LoginInput struct {
} }
func NewUser(c *fiber.Ctx) error { func NewUser(c *fiber.Ctx) error {
// swagger:operation POST /users usersNewUser // swagger:operation POST /users User usersNewUser
// --- // ---
// summary: Create new user // summary: Create new user
// produces: // produces:
@ -50,7 +53,7 @@ func NewUser(c *fiber.Ctx) error {
// required: true // required: true
// - name: hashtag // - name: hashtag
// in: query // in: query
// description: hashtag of the client (length 2-6, UPPERCASE) // description: hashtag of the client (length 2-6, UPPERCASE (Letters, Numbers))
// type: string // type: string
// - name: avatar_url // - name: avatar_url
// in: query // in: query
@ -63,6 +66,7 @@ func NewUser(c *fiber.Ctx) error {
// responses: // responses:
// '201': // '201':
// description: user created // description: user created
// "$ref": "#/definitions/User"
// '400': // '400':
// description: format is not correct // description: format is not correct
// '422': // '422':
@ -88,25 +92,24 @@ func NewUser(c *fiber.Ctx) error {
return c.SendStatus(fiber.StatusForbidden) return c.SendStatus(fiber.StatusForbidden)
} }
db, err := database.GetDatabase() user := structs.User{}
if db == nil || err != nil { user.Email = input.Email
return c.SendStatus(fiber.StatusInternalServerError)
}
defer db.Close() //db = database.DB
if !isEmailAvailable(db, input.Email) { if !isEmailAvailable(database.DB, user, input.Email) {
return c.SendStatus(fiber.StatusUnprocessableEntity) return c.SendStatus(fiber.StatusUnprocessableEntity)
} }
if input.Hashtag == "" { if input.Hashtag == "" {
input.Hashtag, err = generateRandomHashtag(db, 6) input.Hashtag, err = generateRandomString(6)
//input.Hashtag, err = generateRandomHashtag(database.DB, 6)
if err != nil { if err != nil {
return c.SendStatus(fiber.StatusInternalServerError) return c.SendStatus(fiber.StatusInternalServerError)
} }
} else if !isHashtagValid(db, input.Hashtag) { } else if !isHashtagValid(database.DB, input.Hashtag) {
return c.SendStatus(fiber.StatusUnprocessableEntity) return c.SendStatus(fiber.StatusUnprocessableEntity)
} }
@ -118,19 +121,32 @@ func NewUser(c *fiber.Ctx) error {
} }
userId := strings.Replace(uuid.New().String(), "-", "", -1) userId := strings.Replace(uuid.New().String(), "-", "", -1)
created := time.Now().Format("2006-01-02 15:04:05") // YYYY-MM-DD hh:mm:ss //created := time.Now().Format("2006-01-02 15:04:05") // YYYY-MM-DD hh:mm:ss
stmt, err := db.Prepare("INSERT INTO users (user_id, user_hashtag, username, email, password, created) VALUES (?, ?, ?, ?, ?, ?);") //user := structs.User{ID: userId, Hashtag: input.Hashtag, Name: input.Username, Email: input.Email, Password: string(hashedPassword), LastLogin: time.Now(), CreatedAt: time.Now()}
stmt.Exec(userId, input.Hashtag, input.Username, input.Email, hashedPassword, created)
stmt.Close() user.ID = userId
user.Hashtag = input.Hashtag
user.Name = input.Username
user.Password = string(hashedPassword)
user.LastLogin = time.Now()
user.CreatedAt = time.Now()
res := database.DB.Create(&user)
log.Infoln("inserted", res, user)
//stmt, err := db.Prepare("INSERT INTO users (user_id, user_hashtag, username, email, password, created) VALUES (?, ?, ?, ?, ?, ?);")
//stmt.Exec(userId, input.Hashtag, input.Username, input.Email, hashedPassword, created)
//stmt.Close()
if err != nil { if err != nil {
log.Warnln("Failed to insert user to db:", err.Error()) log.Warnln("Failed to insert user to db:", err.Error())
return c.SendStatus(fiber.StatusInternalServerError) return c.SendStatus(fiber.StatusInternalServerError)
} }
sessionId, err := createUserSession(db, userId, c.IP(), string(c.Context().UserAgent())) sessionId, err := createUserSession(database.DB, userId, c.IP(), string(c.Context().UserAgent()))
if err != nil { if err != nil {
return c.SendStatus(fiber.StatusInternalServerError) return c.SendStatus(fiber.StatusInternalServerError)
@ -162,7 +178,8 @@ func generateRandomString(n int) (string, error) {
return string(r), nil return string(r), nil
} }
func generateRandomHashtag(db *sql.DB, n int) (string, error) { /*
func generateRandomHashtag(db *gorm.DB, n int) (string, error) {
c := make(chan bool) c := make(chan bool)
var s string var s string
var err error var err error
@ -176,13 +193,15 @@ func generateRandomHashtag(db *sql.DB, n int) (string, error) {
} }
go func() { go func() {
err := db.QueryRow("SELECT user_hashtag FROM users WHERE user_hashtag = ?", s).Scan(&s) /*err := db.QueryRow("SELECT user_hashtag FROM users WHERE user_hashtag = ?", s).Scan(&s)
if err == sql.ErrNoRows { if err == sql.ErrNoRows {
c <- true c <- true
} else { } else {
c <- false c <- false
} }
c <- true
}() }()
if msg := <-c; msg { if msg := <-c; msg {
@ -191,19 +210,21 @@ func generateRandomHashtag(db *sql.DB, n int) (string, error) {
} }
return s, nil return s, nil
} } */
func isHashtagValid(db *sql.DB, h string) bool { func isHashtagValid(db *gorm.DB, h string) bool {
if !isUpper(h) || len(h) < 2 || len(h) > 6 { if !isUpper(h) || len(h) < 2 || len(h) > 6 {
return false return false
} }
/*
err := db.QueryRow("SELECT user_hashtag FROM users WHERE user_hashtag = ?", h).Scan(&h)
err := db.QueryRow("SELECT user_hashtag FROM users WHERE user_hashtag = ?", h).Scan(&h) if err == sql.ErrNoRows {
return true
}
return false */
if err == sql.ErrNoRows { return true
return true
}
return false
} }
func isUpper(s string) bool { func isUpper(s string) bool {
@ -231,13 +252,24 @@ func isEmailValid(e string) bool {
return emailRegex.MatchString(e) return emailRegex.MatchString(e)
} }
func isEmailAvailable(db *sql.DB, e string) bool { func isEmailAvailable(db *gorm.DB, user structs.User, email string) bool {
err := db.QueryRow("SELECT email FROM users WHERE email = ?", e).Scan(&e) log.Infoln("isEmailAvailable email", user.Email)
if err == sql.ErrNoRows { err := db.First(&user, "email = ?", user.Email).Error
log.Warnln("isEmailErr", errors.Is(err, gorm.ErrRecordNotFound))
if errors.Is(err, gorm.ErrRecordNotFound) {
return true return true
} }
return false return false
/*err := db.QueryRow("SELECT email FROM users WHERE email = ?", e).Scan(&e)
if err == sql.ErrNoRows {
return true
} */
//return false
} }
func SessionIdCheck(c *fiber.Ctx) error { func SessionIdCheck(c *fiber.Ctx) error {
@ -261,23 +293,18 @@ func SessionIdCheck(c *fiber.Ctx) error {
} }
func isSessionIdValid(sessionId string) (bool, error) { func isSessionIdValid(sessionId string) (bool, error) {
db, err := database.GetDatabase()
if db == nil || err != nil { //defer db.Close()
return false, err
}
defer db.Close() deleteExpiredSessions(database.DB)
deleteExpiredSessions(db) /*err = db.QueryRow("SELECT session_id FROM sessions WHERE session_id = ?", sessionId).Scan(&sessionId)
err = db.QueryRow("SELECT session_id FROM sessions WHERE session_id = ?", sessionId).Scan(&sessionId)
if err == sql.ErrNoRows { if err == sql.ErrNoRows {
return false, nil return false, nil
} }
/*sessionExpires, err := time.Parse("2006-01-02 15:04:05", expires) sessionExpires, err := time.Parse("2006-01-02 15:04:05", expires)
log.Infoln("expires", expires, time.Now().Add(time.Hour*72).Unix(), sessionExpires.Unix()) log.Infoln("expires", expires, time.Now().Add(time.Hour*72).Unix(), sessionExpires.Unix())
@ -306,15 +333,21 @@ func deleteSession(db *sql.DB, sessionId string) {
} }
} }
func deleteExpiredSessions(db *sql.DB) { func deleteExpiredSessions(db *gorm.DB) {
_, err := db.Exec("DELETE FROM `sessions` WHERE expires < CURRENT_TIMESTAMP()") err := db.Exec("DELETE FROM sessions WHERE expires < CURRENT_TIMESTAMP()")
if err.Error != nil {
log.Warnln("err deleting expired sessions:", err.Error)
}
/*_, err := db.Exec("DELETE FROM `sessions` WHERE expires < CURRENT_TIMESTAMP()")
if err != nil { if err != nil {
log.Warnln("err deleting expired sessions:", err) log.Warnln("err deleting expired sessions:", err)
} } */
} }
func createUserSession(db *sql.DB, userId string, ip string, userAgent string) (string, error) { func createUserSession(db *gorm.DB, userId string, ip string, userAgent string) (string, error) {
sessionId, err := generateRandomString(32) sessionId, err := generateRandomString(32)
if err != nil { if err != nil {
@ -322,29 +355,40 @@ func createUserSession(db *sql.DB, userId string, ip string, userAgent string) (
return "", err return "", err
} }
stmt, err := db.Prepare("INSERT INTO sessions (user_id, session_id, ip, user_agent, last_login, expires) VALUES (?, ?, ?, ?, ?, ?);") ua := ua.Parse(userAgent)
session := structs.Session{UserId: userId, SessionId: sessionId, IP: ip, UserAgent: ua.OS + " " + ua.Name, LastLogin: time.Now(), Expires: getExpiresTime()}
res := db.Create(&session)
log.Infoln("res", res.Error)
if res.Error != nil {
log.Warnln("Failed to create session:", res.Error)
return "", err
}
/*stmt, err := db.Prepare("INSERT INTO sessions (user_id, session_id, ip, user_agent, last_login, expires) VALUES (?, ?, ?, ?, ?, ?);")
if err != nil { if err != nil {
log.Warnln("Failed to insert user into db:", err) log.Warnln("Failed to insert user into db:", err)
return "", err return "", err
} } */
ua := ua.Parse(userAgent) //ua := ua.Parse(userAgent)
stmt.Exec(userId, sessionId, ip, ua.OS+" "+ua.Name, time.Now(), getExpiresTime()) //stmt.Exec(userId, sessionId, ip, ua.OS+" "+ua.Name, time.Now(), getExpiresTime())
stmt.Close() //stmt.Close()
return sessionId, nil return sessionId, nil
} }
func getExpiresTime() time.Time { func getExpiresTime() time.Time {
// TODO: db default // TODO: db default
return time.Now().Add(time.Hour * time.Duration(serversettings.Settings.ExpiredTime))
return time.Now().Add(time.Hour * 72)
} }
func Login(c *fiber.Ctx) error { func Login(c *fiber.Ctx) error {
// swagger:operation POST /user/login userLogin // swagger:operation POST /user/login User userLogin
// --- // ---
// summary: Login a user // summary: Login a user
// produces: // produces:
@ -387,37 +431,39 @@ func Login(c *fiber.Ctx) error {
input.Password = string(decodedPassword) input.Password = string(decodedPassword)
db, err := database.GetDatabase() /*if input.Username != "" {
if db == nil || err != nil {
return c.SendStatus(fiber.StatusInternalServerError)
}
defer db.Close()
var userId string
var username string
var userHashtag string
var hashedPassword string
if input.Username != "" {
err = db.QueryRow("SELECT user_id, user_hashtag, password FROM users WHERE username = ?", input.Username).Scan(&userId, &userHashtag, &hashedPassword) err = db.QueryRow("SELECT user_id, user_hashtag, password FROM users WHERE username = ?", input.Username).Scan(&userId, &userHashtag, &hashedPassword)
} else { } else {
err = db.QueryRow("SELECT user_id, user_hashtag, username, password FROM users WHERE email = ?", input.Email).Scan(&userId, &userHashtag, &username, &hashedPassword) err = db.QueryRow("SELECT user_id, user_hashtag, username, password FROM users WHERE email = ?", input.Email).Scan(&userId, &userHashtag, &username, &hashedPassword)
} */
/*
if err != nil {
return c.SendStatus(fiber.StatusUnauthorized)
} */
db := database.DB
user := structs.User{}
if input.Username != "" {
db.Select("id, hashtag, password").Where("name = ?", input.Username).Find(&user)
log.Infoln("a", user)
} else {
db.Select("id, hashtag, name, password").Where("email = ?", input.Email).Find(&user)
log.Infoln("a", user)
} }
if err != nil { log.Infoln("pass", input.Password, user.Password)
return c.SendStatus(fiber.StatusUnauthorized)
}
err = bcrypt.CompareHashAndPassword([]byte(hashedPassword), []byte(input.Password)) err = bcrypt.CompareHashAndPassword([]byte(user.Password), []byte(input.Password))
if err != nil { if err != nil {
log.Warnln("Failed to comapare bcrypt password", err) log.Warnln("Failed to comapare bcrypt password", err)
return c.SendStatus(fiber.StatusUnauthorized) return c.SendStatus(fiber.StatusUnauthorized)
} }
sessionId, err := createUserSession(db, userId, c.IP(), string(c.Context().UserAgent())) sessionId, err := createUserSession(database.DB, user.ID, c.IP(), string(c.Context().UserAgent()))
if err != nil { if err != nil {
return c.SendStatus(fiber.StatusInternalServerError) return c.SendStatus(fiber.StatusInternalServerError)
@ -426,10 +472,10 @@ func Login(c *fiber.Ctx) error {
expires := getExpiresTime() expires := getExpiresTime()
c.Cookie(&fiber.Cookie{Name: "session_id", Value: sessionId, Secure: true, HTTPOnly: true, Expires: expires}) c.Cookie(&fiber.Cookie{Name: "session_id", Value: sessionId, Secure: true, HTTPOnly: true, Expires: expires})
if username != "" { if user.Name != "" {
c.Cookie(&fiber.Cookie{Name: "username", Value: username, Secure: true, Expires: expires}) c.Cookie(&fiber.Cookie{Name: "name", Value: user.Name, Secure: true, Expires: expires})
} }
c.Cookie(&fiber.Cookie{Name: "user_hashtag", Value: userHashtag, Secure: true, Expires: expires}) c.Cookie(&fiber.Cookie{Name: "hashtag", Value: user.Hashtag, Secure: true, Expires: expires})
return c.SendStatus(fiber.StatusCreated) return c.SendStatus(fiber.StatusCreated)
} }
@ -439,37 +485,23 @@ func GetUser(c *fiber.Ctx) error {
} }
func GetUsers(c *fiber.Ctx) error { func GetUsers(c *fiber.Ctx) error {
cookie := c.Cookies("session_id")
log.Infoln("cookies", cookie)
db, err := database.GetDatabase()
// c.Params("id")
if db == nil || err != nil {
return c.SendStatus(fiber.StatusInternalServerError)
}
defer db.Close()
list := []string{} list := []string{}
/*
var ( var (
name string name string
) )*/
/*
rows, err := db.Query("SELECT username FROM users;") rows, err := db.Query("SELECT username FROM users;")
fmt.Println("err", err) fmt.Println("err", err)
defer rows.Close() defer rows.Close()
fmt.Println("reading data:") fmt.Println("reading data:")
for rows.Next() { for rows.Next() {
err := rows.Scan(&name) err := rows.Scan(&name)
fmt.Printf("Data row = (%s, %s)\n", name, err) fmt.Printf("Data row = (%s, %s)\n", name, err)
list = append(list, name) list = append(list, name)
} }
err = rows.Err() err = rows.Err()*/
fmt.Println("Done")
return c.JSON(list) return c.JSON(list)
} }