Switched to new database method
parent
5e806774ed
commit
45d1759fff
|
@ -4,7 +4,7 @@ import (
|
||||||
"crypto/rand"
|
"crypto/rand"
|
||||||
"database/sql"
|
"database/sql"
|
||||||
"encoding/base64"
|
"encoding/base64"
|
||||||
"fmt"
|
"errors"
|
||||||
"math/big"
|
"math/big"
|
||||||
"regexp"
|
"regexp"
|
||||||
"strings"
|
"strings"
|
||||||
|
@ -12,11 +12,14 @@ import (
|
||||||
"unicode"
|
"unicode"
|
||||||
|
|
||||||
"git.umbach.dev/app-idea/rest-api/modules/database"
|
"git.umbach.dev/app-idea/rest-api/modules/database"
|
||||||
|
"git.umbach.dev/app-idea/rest-api/modules/serversettings"
|
||||||
|
"git.umbach.dev/app-idea/rest-api/modules/structs"
|
||||||
"github.com/gofiber/fiber/v2"
|
"github.com/gofiber/fiber/v2"
|
||||||
"github.com/google/uuid"
|
"github.com/google/uuid"
|
||||||
ua "github.com/mileusna/useragent"
|
ua "github.com/mileusna/useragent"
|
||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
"golang.org/x/crypto/bcrypt"
|
"golang.org/x/crypto/bcrypt"
|
||||||
|
"gorm.io/gorm"
|
||||||
)
|
)
|
||||||
|
|
||||||
type LoginInput struct {
|
type LoginInput struct {
|
||||||
|
@ -27,7 +30,7 @@ type LoginInput struct {
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewUser(c *fiber.Ctx) error {
|
func NewUser(c *fiber.Ctx) error {
|
||||||
// swagger:operation POST /users usersNewUser
|
// swagger:operation POST /users User usersNewUser
|
||||||
// ---
|
// ---
|
||||||
// summary: Create new user
|
// summary: Create new user
|
||||||
// produces:
|
// produces:
|
||||||
|
@ -50,7 +53,7 @@ func NewUser(c *fiber.Ctx) error {
|
||||||
// required: true
|
// required: true
|
||||||
// - name: hashtag
|
// - name: hashtag
|
||||||
// in: query
|
// in: query
|
||||||
// description: hashtag of the client (length 2-6, UPPERCASE)
|
// description: hashtag of the client (length 2-6, UPPERCASE (Letters, Numbers))
|
||||||
// type: string
|
// type: string
|
||||||
// - name: avatar_url
|
// - name: avatar_url
|
||||||
// in: query
|
// in: query
|
||||||
|
@ -63,6 +66,7 @@ func NewUser(c *fiber.Ctx) error {
|
||||||
// responses:
|
// responses:
|
||||||
// '201':
|
// '201':
|
||||||
// description: user created
|
// description: user created
|
||||||
|
// "$ref": "#/definitions/User"
|
||||||
// '400':
|
// '400':
|
||||||
// description: format is not correct
|
// description: format is not correct
|
||||||
// '422':
|
// '422':
|
||||||
|
@ -88,25 +92,24 @@ func NewUser(c *fiber.Ctx) error {
|
||||||
return c.SendStatus(fiber.StatusForbidden)
|
return c.SendStatus(fiber.StatusForbidden)
|
||||||
}
|
}
|
||||||
|
|
||||||
db, err := database.GetDatabase()
|
user := structs.User{}
|
||||||
|
|
||||||
if db == nil || err != nil {
|
user.Email = input.Email
|
||||||
return c.SendStatus(fiber.StatusInternalServerError)
|
|
||||||
}
|
|
||||||
|
|
||||||
defer db.Close()
|
//db = database.DB
|
||||||
|
|
||||||
if !isEmailAvailable(db, input.Email) {
|
if !isEmailAvailable(database.DB, user, input.Email) {
|
||||||
return c.SendStatus(fiber.StatusUnprocessableEntity)
|
return c.SendStatus(fiber.StatusUnprocessableEntity)
|
||||||
}
|
}
|
||||||
|
|
||||||
if input.Hashtag == "" {
|
if input.Hashtag == "" {
|
||||||
input.Hashtag, err = generateRandomHashtag(db, 6)
|
input.Hashtag, err = generateRandomString(6)
|
||||||
|
//input.Hashtag, err = generateRandomHashtag(database.DB, 6)
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return c.SendStatus(fiber.StatusInternalServerError)
|
return c.SendStatus(fiber.StatusInternalServerError)
|
||||||
}
|
}
|
||||||
} else if !isHashtagValid(db, input.Hashtag) {
|
} else if !isHashtagValid(database.DB, input.Hashtag) {
|
||||||
return c.SendStatus(fiber.StatusUnprocessableEntity)
|
return c.SendStatus(fiber.StatusUnprocessableEntity)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -118,19 +121,32 @@ func NewUser(c *fiber.Ctx) error {
|
||||||
}
|
}
|
||||||
|
|
||||||
userId := strings.Replace(uuid.New().String(), "-", "", -1)
|
userId := strings.Replace(uuid.New().String(), "-", "", -1)
|
||||||
created := time.Now().Format("2006-01-02 15:04:05") // YYYY-MM-DD hh:mm:ss
|
//created := time.Now().Format("2006-01-02 15:04:05") // YYYY-MM-DD hh:mm:ss
|
||||||
|
|
||||||
stmt, err := db.Prepare("INSERT INTO users (user_id, user_hashtag, username, email, password, created) VALUES (?, ?, ?, ?, ?, ?);")
|
//user := structs.User{ID: userId, Hashtag: input.Hashtag, Name: input.Username, Email: input.Email, Password: string(hashedPassword), LastLogin: time.Now(), CreatedAt: time.Now()}
|
||||||
stmt.Exec(userId, input.Hashtag, input.Username, input.Email, hashedPassword, created)
|
|
||||||
|
|
||||||
stmt.Close()
|
user.ID = userId
|
||||||
|
user.Hashtag = input.Hashtag
|
||||||
|
user.Name = input.Username
|
||||||
|
user.Password = string(hashedPassword)
|
||||||
|
user.LastLogin = time.Now()
|
||||||
|
user.CreatedAt = time.Now()
|
||||||
|
|
||||||
|
res := database.DB.Create(&user)
|
||||||
|
|
||||||
|
log.Infoln("inserted", res, user)
|
||||||
|
|
||||||
|
//stmt, err := db.Prepare("INSERT INTO users (user_id, user_hashtag, username, email, password, created) VALUES (?, ?, ?, ?, ?, ?);")
|
||||||
|
//stmt.Exec(userId, input.Hashtag, input.Username, input.Email, hashedPassword, created)
|
||||||
|
|
||||||
|
//stmt.Close()
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Warnln("Failed to insert user to db:", err.Error())
|
log.Warnln("Failed to insert user to db:", err.Error())
|
||||||
return c.SendStatus(fiber.StatusInternalServerError)
|
return c.SendStatus(fiber.StatusInternalServerError)
|
||||||
}
|
}
|
||||||
|
|
||||||
sessionId, err := createUserSession(db, userId, c.IP(), string(c.Context().UserAgent()))
|
sessionId, err := createUserSession(database.DB, userId, c.IP(), string(c.Context().UserAgent()))
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return c.SendStatus(fiber.StatusInternalServerError)
|
return c.SendStatus(fiber.StatusInternalServerError)
|
||||||
|
@ -162,7 +178,8 @@ func generateRandomString(n int) (string, error) {
|
||||||
return string(r), nil
|
return string(r), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func generateRandomHashtag(db *sql.DB, n int) (string, error) {
|
/*
|
||||||
|
func generateRandomHashtag(db *gorm.DB, n int) (string, error) {
|
||||||
c := make(chan bool)
|
c := make(chan bool)
|
||||||
var s string
|
var s string
|
||||||
var err error
|
var err error
|
||||||
|
@ -176,13 +193,15 @@ func generateRandomHashtag(db *sql.DB, n int) (string, error) {
|
||||||
}
|
}
|
||||||
|
|
||||||
go func() {
|
go func() {
|
||||||
err := db.QueryRow("SELECT user_hashtag FROM users WHERE user_hashtag = ?", s).Scan(&s)
|
/*err := db.QueryRow("SELECT user_hashtag FROM users WHERE user_hashtag = ?", s).Scan(&s)
|
||||||
|
|
||||||
if err == sql.ErrNoRows {
|
if err == sql.ErrNoRows {
|
||||||
c <- true
|
c <- true
|
||||||
} else {
|
} else {
|
||||||
c <- false
|
c <- false
|
||||||
}
|
}
|
||||||
|
|
||||||
|
c <- true
|
||||||
}()
|
}()
|
||||||
|
|
||||||
if msg := <-c; msg {
|
if msg := <-c; msg {
|
||||||
|
@ -191,19 +210,21 @@ func generateRandomHashtag(db *sql.DB, n int) (string, error) {
|
||||||
}
|
}
|
||||||
|
|
||||||
return s, nil
|
return s, nil
|
||||||
}
|
} */
|
||||||
|
|
||||||
func isHashtagValid(db *sql.DB, h string) bool {
|
func isHashtagValid(db *gorm.DB, h string) bool {
|
||||||
if !isUpper(h) || len(h) < 2 || len(h) > 6 {
|
if !isUpper(h) || len(h) < 2 || len(h) > 6 {
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
/*
|
||||||
|
err := db.QueryRow("SELECT user_hashtag FROM users WHERE user_hashtag = ?", h).Scan(&h)
|
||||||
|
|
||||||
err := db.QueryRow("SELECT user_hashtag FROM users WHERE user_hashtag = ?", h).Scan(&h)
|
if err == sql.ErrNoRows {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
return false */
|
||||||
|
|
||||||
if err == sql.ErrNoRows {
|
return true
|
||||||
return true
|
|
||||||
}
|
|
||||||
return false
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func isUpper(s string) bool {
|
func isUpper(s string) bool {
|
||||||
|
@ -231,13 +252,24 @@ func isEmailValid(e string) bool {
|
||||||
return emailRegex.MatchString(e)
|
return emailRegex.MatchString(e)
|
||||||
}
|
}
|
||||||
|
|
||||||
func isEmailAvailable(db *sql.DB, e string) bool {
|
func isEmailAvailable(db *gorm.DB, user structs.User, email string) bool {
|
||||||
err := db.QueryRow("SELECT email FROM users WHERE email = ?", e).Scan(&e)
|
log.Infoln("isEmailAvailable email", user.Email)
|
||||||
|
|
||||||
if err == sql.ErrNoRows {
|
err := db.First(&user, "email = ?", user.Email).Error
|
||||||
|
|
||||||
|
log.Warnln("isEmailErr", errors.Is(err, gorm.ErrRecordNotFound))
|
||||||
|
|
||||||
|
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
return false
|
return false
|
||||||
|
|
||||||
|
/*err := db.QueryRow("SELECT email FROM users WHERE email = ?", e).Scan(&e)
|
||||||
|
|
||||||
|
if err == sql.ErrNoRows {
|
||||||
|
return true
|
||||||
|
} */
|
||||||
|
//return false
|
||||||
}
|
}
|
||||||
|
|
||||||
func SessionIdCheck(c *fiber.Ctx) error {
|
func SessionIdCheck(c *fiber.Ctx) error {
|
||||||
|
@ -261,23 +293,18 @@ func SessionIdCheck(c *fiber.Ctx) error {
|
||||||
}
|
}
|
||||||
|
|
||||||
func isSessionIdValid(sessionId string) (bool, error) {
|
func isSessionIdValid(sessionId string) (bool, error) {
|
||||||
db, err := database.GetDatabase()
|
|
||||||
|
|
||||||
if db == nil || err != nil {
|
//defer db.Close()
|
||||||
return false, err
|
|
||||||
}
|
|
||||||
|
|
||||||
defer db.Close()
|
deleteExpiredSessions(database.DB)
|
||||||
|
|
||||||
deleteExpiredSessions(db)
|
/*err = db.QueryRow("SELECT session_id FROM sessions WHERE session_id = ?", sessionId).Scan(&sessionId)
|
||||||
|
|
||||||
err = db.QueryRow("SELECT session_id FROM sessions WHERE session_id = ?", sessionId).Scan(&sessionId)
|
|
||||||
|
|
||||||
if err == sql.ErrNoRows {
|
if err == sql.ErrNoRows {
|
||||||
return false, nil
|
return false, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
/*sessionExpires, err := time.Parse("2006-01-02 15:04:05", expires)
|
sessionExpires, err := time.Parse("2006-01-02 15:04:05", expires)
|
||||||
|
|
||||||
log.Infoln("expires", expires, time.Now().Add(time.Hour*72).Unix(), sessionExpires.Unix())
|
log.Infoln("expires", expires, time.Now().Add(time.Hour*72).Unix(), sessionExpires.Unix())
|
||||||
|
|
||||||
|
@ -306,15 +333,21 @@ func deleteSession(db *sql.DB, sessionId string) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func deleteExpiredSessions(db *sql.DB) {
|
func deleteExpiredSessions(db *gorm.DB) {
|
||||||
_, err := db.Exec("DELETE FROM `sessions` WHERE expires < CURRENT_TIMESTAMP()")
|
err := db.Exec("DELETE FROM sessions WHERE expires < CURRENT_TIMESTAMP()")
|
||||||
|
|
||||||
|
if err.Error != nil {
|
||||||
|
log.Warnln("err deleting expired sessions:", err.Error)
|
||||||
|
}
|
||||||
|
|
||||||
|
/*_, err := db.Exec("DELETE FROM `sessions` WHERE expires < CURRENT_TIMESTAMP()")
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Warnln("err deleting expired sessions:", err)
|
log.Warnln("err deleting expired sessions:", err)
|
||||||
}
|
} */
|
||||||
}
|
}
|
||||||
|
|
||||||
func createUserSession(db *sql.DB, userId string, ip string, userAgent string) (string, error) {
|
func createUserSession(db *gorm.DB, userId string, ip string, userAgent string) (string, error) {
|
||||||
sessionId, err := generateRandomString(32)
|
sessionId, err := generateRandomString(32)
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
@ -322,29 +355,40 @@ func createUserSession(db *sql.DB, userId string, ip string, userAgent string) (
|
||||||
return "", err
|
return "", err
|
||||||
}
|
}
|
||||||
|
|
||||||
stmt, err := db.Prepare("INSERT INTO sessions (user_id, session_id, ip, user_agent, last_login, expires) VALUES (?, ?, ?, ?, ?, ?);")
|
ua := ua.Parse(userAgent)
|
||||||
|
session := structs.Session{UserId: userId, SessionId: sessionId, IP: ip, UserAgent: ua.OS + " " + ua.Name, LastLogin: time.Now(), Expires: getExpiresTime()}
|
||||||
|
|
||||||
|
res := db.Create(&session)
|
||||||
|
|
||||||
|
log.Infoln("res", res.Error)
|
||||||
|
|
||||||
|
if res.Error != nil {
|
||||||
|
log.Warnln("Failed to create session:", res.Error)
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
|
||||||
|
/*stmt, err := db.Prepare("INSERT INTO sessions (user_id, session_id, ip, user_agent, last_login, expires) VALUES (?, ?, ?, ?, ?, ?);")
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Warnln("Failed to insert user into db:", err)
|
log.Warnln("Failed to insert user into db:", err)
|
||||||
return "", err
|
return "", err
|
||||||
}
|
} */
|
||||||
|
|
||||||
ua := ua.Parse(userAgent)
|
//ua := ua.Parse(userAgent)
|
||||||
|
|
||||||
stmt.Exec(userId, sessionId, ip, ua.OS+" "+ua.Name, time.Now(), getExpiresTime())
|
//stmt.Exec(userId, sessionId, ip, ua.OS+" "+ua.Name, time.Now(), getExpiresTime())
|
||||||
stmt.Close()
|
//stmt.Close()
|
||||||
|
|
||||||
return sessionId, nil
|
return sessionId, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func getExpiresTime() time.Time {
|
func getExpiresTime() time.Time {
|
||||||
// TODO: db default
|
// TODO: db default
|
||||||
|
return time.Now().Add(time.Hour * time.Duration(serversettings.Settings.ExpiredTime))
|
||||||
return time.Now().Add(time.Hour * 72)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func Login(c *fiber.Ctx) error {
|
func Login(c *fiber.Ctx) error {
|
||||||
// swagger:operation POST /user/login userLogin
|
// swagger:operation POST /user/login User userLogin
|
||||||
// ---
|
// ---
|
||||||
// summary: Login a user
|
// summary: Login a user
|
||||||
// produces:
|
// produces:
|
||||||
|
@ -387,37 +431,39 @@ func Login(c *fiber.Ctx) error {
|
||||||
|
|
||||||
input.Password = string(decodedPassword)
|
input.Password = string(decodedPassword)
|
||||||
|
|
||||||
db, err := database.GetDatabase()
|
/*if input.Username != "" {
|
||||||
|
|
||||||
if db == nil || err != nil {
|
|
||||||
return c.SendStatus(fiber.StatusInternalServerError)
|
|
||||||
}
|
|
||||||
|
|
||||||
defer db.Close()
|
|
||||||
|
|
||||||
var userId string
|
|
||||||
var username string
|
|
||||||
var userHashtag string
|
|
||||||
var hashedPassword string
|
|
||||||
|
|
||||||
if input.Username != "" {
|
|
||||||
err = db.QueryRow("SELECT user_id, user_hashtag, password FROM users WHERE username = ?", input.Username).Scan(&userId, &userHashtag, &hashedPassword)
|
err = db.QueryRow("SELECT user_id, user_hashtag, password FROM users WHERE username = ?", input.Username).Scan(&userId, &userHashtag, &hashedPassword)
|
||||||
} else {
|
} else {
|
||||||
err = db.QueryRow("SELECT user_id, user_hashtag, username, password FROM users WHERE email = ?", input.Email).Scan(&userId, &userHashtag, &username, &hashedPassword)
|
err = db.QueryRow("SELECT user_id, user_hashtag, username, password FROM users WHERE email = ?", input.Email).Scan(&userId, &userHashtag, &username, &hashedPassword)
|
||||||
|
} */
|
||||||
|
/*
|
||||||
|
if err != nil {
|
||||||
|
return c.SendStatus(fiber.StatusUnauthorized)
|
||||||
|
} */
|
||||||
|
|
||||||
|
db := database.DB
|
||||||
|
user := structs.User{}
|
||||||
|
|
||||||
|
if input.Username != "" {
|
||||||
|
db.Select("id, hashtag, password").Where("name = ?", input.Username).Find(&user)
|
||||||
|
|
||||||
|
log.Infoln("a", user)
|
||||||
|
} else {
|
||||||
|
db.Select("id, hashtag, name, password").Where("email = ?", input.Email).Find(&user)
|
||||||
|
|
||||||
|
log.Infoln("a", user)
|
||||||
}
|
}
|
||||||
|
|
||||||
if err != nil {
|
log.Infoln("pass", input.Password, user.Password)
|
||||||
return c.SendStatus(fiber.StatusUnauthorized)
|
|
||||||
}
|
|
||||||
|
|
||||||
err = bcrypt.CompareHashAndPassword([]byte(hashedPassword), []byte(input.Password))
|
err = bcrypt.CompareHashAndPassword([]byte(user.Password), []byte(input.Password))
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Warnln("Failed to comapare bcrypt password", err)
|
log.Warnln("Failed to comapare bcrypt password", err)
|
||||||
return c.SendStatus(fiber.StatusUnauthorized)
|
return c.SendStatus(fiber.StatusUnauthorized)
|
||||||
}
|
}
|
||||||
|
|
||||||
sessionId, err := createUserSession(db, userId, c.IP(), string(c.Context().UserAgent()))
|
sessionId, err := createUserSession(database.DB, user.ID, c.IP(), string(c.Context().UserAgent()))
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return c.SendStatus(fiber.StatusInternalServerError)
|
return c.SendStatus(fiber.StatusInternalServerError)
|
||||||
|
@ -426,10 +472,10 @@ func Login(c *fiber.Ctx) error {
|
||||||
expires := getExpiresTime()
|
expires := getExpiresTime()
|
||||||
|
|
||||||
c.Cookie(&fiber.Cookie{Name: "session_id", Value: sessionId, Secure: true, HTTPOnly: true, Expires: expires})
|
c.Cookie(&fiber.Cookie{Name: "session_id", Value: sessionId, Secure: true, HTTPOnly: true, Expires: expires})
|
||||||
if username != "" {
|
if user.Name != "" {
|
||||||
c.Cookie(&fiber.Cookie{Name: "username", Value: username, Secure: true, Expires: expires})
|
c.Cookie(&fiber.Cookie{Name: "name", Value: user.Name, Secure: true, Expires: expires})
|
||||||
}
|
}
|
||||||
c.Cookie(&fiber.Cookie{Name: "user_hashtag", Value: userHashtag, Secure: true, Expires: expires})
|
c.Cookie(&fiber.Cookie{Name: "hashtag", Value: user.Hashtag, Secure: true, Expires: expires})
|
||||||
|
|
||||||
return c.SendStatus(fiber.StatusCreated)
|
return c.SendStatus(fiber.StatusCreated)
|
||||||
}
|
}
|
||||||
|
@ -439,37 +485,23 @@ func GetUser(c *fiber.Ctx) error {
|
||||||
}
|
}
|
||||||
|
|
||||||
func GetUsers(c *fiber.Ctx) error {
|
func GetUsers(c *fiber.Ctx) error {
|
||||||
cookie := c.Cookies("session_id")
|
|
||||||
|
|
||||||
log.Infoln("cookies", cookie)
|
|
||||||
|
|
||||||
db, err := database.GetDatabase()
|
|
||||||
|
|
||||||
// c.Params("id")
|
|
||||||
|
|
||||||
if db == nil || err != nil {
|
|
||||||
return c.SendStatus(fiber.StatusInternalServerError)
|
|
||||||
}
|
|
||||||
|
|
||||||
defer db.Close()
|
|
||||||
|
|
||||||
list := []string{}
|
list := []string{}
|
||||||
|
/*
|
||||||
var (
|
var (
|
||||||
name string
|
name string
|
||||||
)
|
)*/
|
||||||
|
/*
|
||||||
rows, err := db.Query("SELECT username FROM users;")
|
rows, err := db.Query("SELECT username FROM users;")
|
||||||
fmt.Println("err", err)
|
fmt.Println("err", err)
|
||||||
defer rows.Close()
|
defer rows.Close()
|
||||||
fmt.Println("reading data:")
|
fmt.Println("reading data:")
|
||||||
for rows.Next() {
|
for rows.Next() {
|
||||||
err := rows.Scan(&name)
|
err := rows.Scan(&name)
|
||||||
fmt.Printf("Data row = (%s, %s)\n", name, err)
|
fmt.Printf("Data row = (%s, %s)\n", name, err)
|
||||||
list = append(list, name)
|
list = append(list, name)
|
||||||
}
|
}
|
||||||
err = rows.Err()
|
err = rows.Err()*/
|
||||||
fmt.Println("Done")
|
|
||||||
|
|
||||||
return c.JSON(list)
|
return c.JSON(list)
|
||||||
}
|
}
|
||||||
|
|
Loading…
Reference in New Issue