484 lines
11 KiB
Go
484 lines
11 KiB
Go
package user
|
|
|
|
import (
|
|
"crypto/rand"
|
|
"database/sql"
|
|
"encoding/base64"
|
|
"fmt"
|
|
"math/big"
|
|
"regexp"
|
|
"strings"
|
|
"time"
|
|
"unicode"
|
|
|
|
"git.umbach.dev/app-idea/rest-api/modules/database"
|
|
"github.com/gofiber/fiber/v2"
|
|
"github.com/google/uuid"
|
|
ua "github.com/mileusna/useragent"
|
|
log "github.com/sirupsen/logrus"
|
|
"golang.org/x/crypto/bcrypt"
|
|
)
|
|
|
|
// LoginInput is the request body accepted by NewUser and Login, filled in by
// fiber's BodyParser. The JSON tags give the expected field names.
type LoginInput struct {
	// Username is validated to 3-30 bytes by the handlers.
	Username string `json:"username"`
	// Email is validated by isEmailValid.
	Email string `json:"email"`
	// Password arrives base64-encoded; the handlers overwrite it with the
	// decoded value before using it.
	Password string `json:"password"`
	// Hashtag is optional on registration; NewUser generates one when empty.
	Hashtag string `json:"hashtag"`
}
|
|
|
|
// NewUser registers a new user. It validates the submitted username, email and
// base64-encoded password, stores the user with a bcrypt-hashed password, opens
// a session and sets the session_id, username and user_hashtag cookies.
//
// When no hashtag is supplied, a random one is generated. Otherwise the
// supplied hashtag must pass isHashtagValid.
//
// NOTE(review): avatar_url and location are documented in the swagger block
// below, but LoginInput has no such fields, so they are currently ignored.
func NewUser(c *fiber.Ctx) error {
	// swagger:operation POST /users usersNewUser
	// ---
	// summary: Create new user
	// produces:
	// - application/json
	// parameters:
	// - name: username
	//   in: query
	//   description: username of the user (length 3-30)
	//   type: string
	//   required: true
	// - name: email
	//   in: query
	//   description: email of the user (length 3-255)
	//   type: string
	//   required: true
	// - name: password
	//   in: query
	//   description: password (base64) of the user (length 6-250)
	//   type: string
	//   required: true
	// - name: hashtag
	//   in: query
	//   description: hashtag of the client (length 2-6, UPPERCASE)
	//   type: string
	// - name: avatar_url
	//   in: query
	//   description: avatar url of the client
	//   type: string
	// - name: location
	//   in: query
	//   description: location of the client (length 1-20) (for example Frankfurt)
	//   type: string
	// responses:
	//   '201':
	//     description: user created
	//   '400':
	//     description: format is not correct
	//   '403':
	//     description: username, email or password has an invalid length or format
	//   '422':
	//     description: username, email or/and hashtag already assigned
	//   '500':
	//     description: internal server error

	var input LoginInput

	if err := c.BodyParser(&input); err != nil {
		log.Debugln("bodyParser failed:", err)
		return c.SendStatus(fiber.StatusBadRequest)
	}

	decodedPassword, err := base64.StdEncoding.DecodeString(input.Password)
	if err != nil {
		log.Debugln("base64 decoding failed:", err)
		return c.SendStatus(fiber.StatusBadRequest)
	}
	input.Password = string(decodedPassword)

	if !isValid(input.Username, 3, 30) || !isEmailValid(input.Email) || !isValid(input.Password, 6, 250) {
		return c.SendStatus(fiber.StatusForbidden)
	}

	db, err := database.GetDatabase()
	if db == nil || err != nil {
		return c.SendStatus(fiber.StatusInternalServerError)
	}
	defer db.Close()

	if !isEmailAvailable(db, input.Email) {
		return c.SendStatus(fiber.StatusUnprocessableEntity)
	}

	if input.Hashtag == "" {
		input.Hashtag, err = generateRandomHashtag(db, 6)
		if err != nil {
			return c.SendStatus(fiber.StatusInternalServerError)
		}
	} else if !isHashtagValid(db, input.Hashtag) {
		return c.SendStatus(fiber.StatusUnprocessableEntity)
	}

	hashedPassword, err := bcrypt.GenerateFromPassword([]byte(input.Password), bcrypt.DefaultCost)
	if err != nil {
		log.Warnln("Failed to bcrypt password", err)
		return c.SendStatus(fiber.StatusInternalServerError)
	}

	userId := strings.ReplaceAll(uuid.New().String(), "-", "")
	created := time.Now().Format("2006-01-02 15:04:05") // YYYY-MM-DD hh:mm:ss

	// The insert's own error must be checked. Otherwise a failed insert would
	// still create a session and answer 201 for a user that does not exist.
	if _, err = db.Exec(
		"INSERT INTO users (user_id, user_hashtag, username, email, password, created) VALUES (?, ?, ?, ?, ?, ?);",
		userId, input.Hashtag, input.Username, input.Email, hashedPassword, created,
	); err != nil {
		log.Warnln("Failed to insert user to db:", err.Error())
		return c.SendStatus(fiber.StatusInternalServerError)
	}

	sessionId, err := createUserSession(db, userId, c.IP(), string(c.Context().UserAgent()))
	if err != nil {
		return c.SendStatus(fiber.StatusInternalServerError)
	}

	expires := getExpiresTime()

	c.Cookie(&fiber.Cookie{Name: "session_id", Value: sessionId, Secure: true, HTTPOnly: true, Expires: expires})
	c.Cookie(&fiber.Cookie{Name: "username", Value: input.Username, Secure: true, Expires: expires})
	c.Cookie(&fiber.Cookie{Name: "user_hashtag", Value: input.Hashtag, Secure: true, Expires: expires})

	log.Debugln("user created", userId, input.Hashtag, input.Username, input.Email)

	return c.SendStatus(fiber.StatusCreated)
}
|
|
|
|
// generateRandomString returns n characters drawn uniformly from the digits
// and the ASCII upper- and lower-case letters, using crypto/rand as the
// entropy source. The first error from the random reader is returned as-is.
func generateRandomString(n int) (string, error) {
	const alphabet = "0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz"
	upper := big.NewInt(int64(len(alphabet)))

	out := make([]byte, n)
	for i := range out {
		pick, err := rand.Int(rand.Reader, upper)
		if err != nil {
			return "", err
		}
		out[i] = alphabet[pick.Int64()]
	}

	return string(out), nil
}
|
|
|
|
// generateRandomHashtag returns a random n-character hashtag (alphabet as in
// generateRandomString) that no row in the users table currently uses.
//
// It retries on collisions up to a fixed number of attempts. It returns an
// error if random generation fails, if the availability query fails with
// anything other than sql.ErrNoRows, or if every attempt collided.
//
// NOTE(review): the alphabet includes lower-case letters, while
// isHashtagValid rejects user-supplied hashtags containing lower-case
// letters. Confirm whether generated hashtags should be upper-case only.
func generateRandomHashtag(db *sql.DB, n int) (string, error) {
	const maxAttempts = 100

	if n <= 0 {
		return "", fmt.Errorf("invalid hashtag length %d", n)
	}

	for attempt := 0; attempt < maxAttempts; attempt++ {
		candidate, err := generateRandomString(n)
		if err != nil {
			log.Warnln("Error generating Hashtag:", err)
			return "", err
		}

		var existing string
		err = db.QueryRow("SELECT user_hashtag FROM users WHERE user_hashtag = ?", candidate).Scan(&existing)
		switch {
		case err == sql.ErrNoRows:
			// Nobody uses this hashtag yet.
			return candidate, nil
		case err != nil:
			// A real DB failure: retrying would loop forever, so give up.
			return "", fmt.Errorf("checking hashtag availability: %w", err)
		}
		// err == nil: the candidate is taken, try another one.
	}

	return "", fmt.Errorf("no free hashtag of length %d found after %d attempts", n, maxAttempts)
}
|
|
|
|
func isHashtagValid(db *sql.DB, h string) bool {
|
|
if !isUpper(h) || len(h) < 2 || len(h) > 6 {
|
|
return false
|
|
}
|
|
|
|
err := db.QueryRow("SELECT user_hashtag FROM users WHERE user_hashtag = ?", h).Scan(&h)
|
|
|
|
if err == sql.ErrNoRows {
|
|
return true
|
|
}
|
|
return false
|
|
}
|
|
|
|
// isUpper reports whether s contains no letter that is not upper-case.
// Non-letters (digits, punctuation, spaces) are ignored, so "" and "123" both
// report true.
func isUpper(s string) bool {
	isNonUpperLetter := func(r rune) bool {
		return unicode.IsLetter(r) && !unicode.IsUpper(r)
	}
	return strings.IndexFunc(s, isNonUpperLetter) < 0
}
|
|
|
|
// isValid reports whether the byte length of s lies in the inclusive range
// [min, max].
func isValid(s string, min int, max int) bool {
	n := len(s)
	return n >= min && n <= max
}
|
|
|
|
var emailRegex = regexp.MustCompile("^[a-zA-Z0-9.!#$%&'*+\\/=?^_`{|}~-]+@[a-zA-Z0-9](?:[a-zA-Z0-9-]{0,61}[a-zA-Z0-9])?(?:\\.[a-zA-Z0-9](?:[a-zA-Z0-9-]{0,61}[a-zA-Z0-9])?)*$")
|
|
|
|
func isEmailValid(e string) bool {
|
|
if len(e) < 3 || len(e) > 255 {
|
|
return false
|
|
}
|
|
return emailRegex.MatchString(e)
|
|
}
|
|
|
|
func isEmailAvailable(db *sql.DB, e string) bool {
|
|
err := db.QueryRow("SELECT email FROM users WHERE email = ?", e).Scan(&e)
|
|
|
|
if err == sql.ErrNoRows {
|
|
return true
|
|
}
|
|
return false
|
|
}
|
|
|
|
func SessionIdCheck(c *fiber.Ctx) error {
|
|
sessionId := c.Cookies("session_id")
|
|
|
|
if sessionId == "" {
|
|
return fiber.ErrUnauthorized
|
|
}
|
|
|
|
valid, err := isSessionIdValid(sessionId)
|
|
|
|
if err != nil {
|
|
return fiber.ErrInternalServerError
|
|
}
|
|
|
|
if valid {
|
|
return c.Next()
|
|
}
|
|
|
|
return fiber.ErrUnauthorized
|
|
}
|
|
|
|
func isSessionIdValid(sessionId string) (bool, error) {
|
|
db, err := database.GetDatabase()
|
|
|
|
if db == nil || err != nil {
|
|
return false, err
|
|
}
|
|
|
|
defer db.Close()
|
|
|
|
deleteExpiredSessions(db)
|
|
|
|
err = db.QueryRow("SELECT session_id FROM sessions WHERE session_id = ?", sessionId).Scan(&sessionId)
|
|
|
|
if err == sql.ErrNoRows {
|
|
return false, nil
|
|
}
|
|
|
|
/*sessionExpires, err := time.Parse("2006-01-02 15:04:05", expires)
|
|
|
|
log.Infoln("expires", expires, time.Now().Add(time.Hour*72).Unix(), sessionExpires.Unix())
|
|
|
|
if err != nil {
|
|
log.Warn("Failed to parse session datetime", err)
|
|
return false, err
|
|
}
|
|
|
|
// session has expired
|
|
if time.Now().Unix() > sessionExpires.Unix() {
|
|
log.Info("bigger")
|
|
deleteSession(db, sessionId)
|
|
|
|
return false, err
|
|
}
|
|
|
|
log.Info("not bigger") */
|
|
return true, nil
|
|
}
|
|
|
|
// deleteSession removes the session row with the given id. It is best-effort:
// failures are logged and otherwise ignored.
func deleteSession(db *sql.DB, sessionId string) {
	res, err := db.Exec("DELETE FROM sessions WHERE session_id = ?", sessionId)
	if err != nil {
		log.Warnln("err deleting session:", err)
		// res is nil when Exec fails; calling RowsAffected on it would panic.
		return
	}

	affected, affectedErr := res.RowsAffected()
	log.Debugln("delete session res", affected, affectedErr)
}
|
|
|
|
// deleteExpiredSessions removes every session whose expires column is earlier
// than the database's CURRENT_TIMESTAMP(). It is best-effort: failures are
// logged and otherwise ignored.
func deleteExpiredSessions(db *sql.DB) {
	res, err := db.Exec("DELETE FROM `sessions` WHERE expires < CURRENT_TIMESTAMP()")
	if err != nil {
		log.Warnln("err deleting expired sessions:", err)
		// res is nil when Exec fails; calling RowsAffected on it would panic.
		return
	}

	affected, affectedErr := res.RowsAffected()
	log.Debugln("Delete expired sessions:", affected, affectedErr)
}
|
|
|
|
// createUserSession stores a new session for userId and returns its id, a
// 32-character string from generateRandomString.
//
// The stored row records the client ip and a short "<OS> <browser>" label
// parsed from userAgent. last_login is set to now and expires to
// getExpiresTime(). An error is returned if id generation or the insert fails;
// the id is only returned once it has actually been stored.
func createUserSession(db *sql.DB, userId string, ip string, userAgent string) (string, error) {
	sessionId, err := generateRandomString(32)
	if err != nil {
		log.Warnln("Failed to generate user session:", err)
		return "", err
	}

	// Named "agent" so the ua package identifier is not shadowed.
	agent := ua.Parse(userAgent)

	_, err = db.Exec(
		"INSERT INTO sessions (user_id, session_id, ip, user_agent, last_login, expires) VALUES (?, ?, ?, ?, ?, ?);",
		userId, sessionId, ip, agent.OS+" "+agent.Name, time.Now(), getExpiresTime(),
	)
	if err != nil {
		log.Warnln("Failed to insert user session into db:", err)
		return "", err
	}

	return sessionId, nil
}
|
|
|
|
func getExpiresTime() time.Time {
|
|
// TODO: db default
|
|
|
|
return time.Now().Add(time.Hour * 72)
|
|
}
|
|
|
|
// Login authenticates a user by username or email plus a base64-encoded
// password. On success it opens a session and sets the session_id, username
// and user_hashtag cookies. When both username and email are supplied, the
// username is used for the lookup.
func Login(c *fiber.Ctx) error {
	// swagger:operation POST /user/login userLogin
	// ---
	// summary: Login a user
	// produces:
	// - application/json
	// parameters:
	// - name: username or email
	//   in: query
	//   description: username or email
	//   type: string
	//   required: true
	// - name: password
	//   in: query
	//   description: password (base64) of the user
	//   type: string
	//   required: true
	// responses:
	//   '201':
	//     description: login success (a new session was created)
	//   '400':
	//     description: format is not correct
	//   '401':
	//     description: login credentials not correct
	//   '500':
	//     description: internal server error

	var input LoginInput

	if err := c.BodyParser(&input); err != nil {
		return c.SendStatus(fiber.StatusBadRequest)
	}

	// The input is intentionally not logged: it carries the user's password.

	usernameInvalid := input.Username != "" && !isValid(input.Username, 3, 30)
	emailInvalid := input.Email != "" && !isEmailValid(input.Email)
	noIdentifier := input.Username == "" && input.Email == ""

	if usernameInvalid || emailInvalid || noIdentifier || input.Password == "" {
		return c.SendStatus(fiber.StatusBadRequest)
	}

	decodedPassword, err := base64.StdEncoding.DecodeString(input.Password)
	if err != nil {
		log.Debugln("base64 decoding failed:", err)
		return c.SendStatus(fiber.StatusBadRequest)
	}
	input.Password = string(decodedPassword)

	db, err := database.GetDatabase()
	if db == nil || err != nil {
		return c.SendStatus(fiber.StatusInternalServerError)
	}
	defer db.Close()

	var userId, username, userHashtag, hashedPassword string

	// Both lookups read the stored username so the username cookie is set
	// however the user identified themselves.
	if input.Username != "" {
		err = db.QueryRow("SELECT user_id, user_hashtag, username, password FROM users WHERE username = ?", input.Username).Scan(&userId, &userHashtag, &username, &hashedPassword)
	} else {
		err = db.QueryRow("SELECT user_id, user_hashtag, username, password FROM users WHERE email = ?", input.Email).Scan(&userId, &userHashtag, &username, &hashedPassword)
	}

	switch {
	case err == sql.ErrNoRows:
		return c.SendStatus(fiber.StatusUnauthorized)
	case err != nil:
		// A DB failure is not a credentials problem.
		log.Warnln("Failed to look up user for login:", err)
		return c.SendStatus(fiber.StatusInternalServerError)
	}

	if err = bcrypt.CompareHashAndPassword([]byte(hashedPassword), []byte(input.Password)); err != nil {
		log.Warnln("Failed to comapare bcrypt password", err)
		return c.SendStatus(fiber.StatusUnauthorized)
	}

	sessionId, err := createUserSession(db, userId, c.IP(), string(c.Context().UserAgent()))
	if err != nil {
		return c.SendStatus(fiber.StatusInternalServerError)
	}

	expires := getExpiresTime()

	c.Cookie(&fiber.Cookie{Name: "session_id", Value: sessionId, Secure: true, HTTPOnly: true, Expires: expires})

	if username != "" {
		c.Cookie(&fiber.Cookie{Name: "username", Value: username, Secure: true, Expires: expires})
	}

	c.Cookie(&fiber.Cookie{Name: "user_hashtag", Value: userHashtag, Secure: true, Expires: expires})

	return c.SendStatus(fiber.StatusCreated)
}
|
|
|
|
func GetUser(c *fiber.Ctx) error {
|
|
return c.SendString("user")
|
|
}
|
|
|
|
// GetUsers responds with a JSON array of every username in the users table.
// An empty table yields [] rather than null. Any database error yields 500.
func GetUsers(c *fiber.Ctx) error {
	db, err := database.GetDatabase()
	if db == nil || err != nil {
		return c.SendStatus(fiber.StatusInternalServerError)
	}
	defer db.Close()

	rows, err := db.Query("SELECT username FROM users;")
	if err != nil {
		// rows is nil here; deferring rows.Close() before this check would panic.
		log.Warnln(fmt.Errorf("querying usernames: %w", err))
		return c.SendStatus(fiber.StatusInternalServerError)
	}
	defer rows.Close()

	// Non-nil so an empty result encodes as [] instead of null.
	list := []string{}

	for rows.Next() {
		var name string
		if err := rows.Scan(&name); err != nil {
			log.Warnln(fmt.Errorf("scanning username: %w", err))
			return c.SendStatus(fiber.StatusInternalServerError)
		}
		list = append(list, name)
	}

	// Next returns false on both end-of-rows and failure; Err tells them apart.
	if err := rows.Err(); err != nil {
		log.Warnln(fmt.Errorf("iterating usernames: %w", err))
		return c.SendStatus(fiber.StatusInternalServerError)
	}

	return c.JSON(list)
}
|