// busybar/internal/services/auth.service.go
package services

import (
	"context"
	"database/sql"
	"encoding/json"
	"fmt"
	"io"
	"net/http"
	"strconv"
	"time"

	"github.com/danielrhuynh/busybar/internal/models"
	"github.com/danielrhuynh/busybar/pkg/config"
	"github.com/danielrhuynh/busybar/pkg/database"

	"github.com/labstack/echo/v4"
	"golang.org/x/oauth2"

	"github.com/Masterminds/squirrel"
)

// FetchAccountInfoFromProvider requests provider.UserInfoURL with a GET
// request authorized by token and returns the decoded JSON object.
//
// The request times out after 10 seconds. Every failure is logged to stdout
// and returned as an *echo.HTTPError with status 500 so handlers can return it
// directly. Failures include:
//   - a nil token
//   - a transport error
//   - a non-2xx status
//   - an unreadable or non-JSON body
func FetchAccountInfoFromProvider(provider config.OAuthProvider, token *oauth2.Token) (map[string]interface{}, error) {
	// Upper bound on the user-info payload we are willing to buffer in memory.
	const maxUserInfoBytes = 1 << 20

	if token == nil {
		fmt.Println("Failed to get user info: missing oauth token")
		return nil, echo.NewHTTPError(http.StatusInternalServerError, "failed to get user info")
	}

	client := &http.Client{Timeout: 10 * time.Second}

	req, err := http.NewRequest(http.MethodGet, provider.UserInfoURL, nil)
	if err != nil {
		fmt.Printf("Failed to create user info request: %v\n", err)
		return nil, echo.NewHTTPError(http.StatusInternalServerError, "failed to create user info request")
	}

	// SetAuthHeader formats the header with token.Type(). That falls back to
	// "Bearer" when the provider omitted token_type and canonicalizes its case.
	// Formatting TokenType directly would send " <token>" for an empty type.
	token.SetAuthHeader(req)

	res, err := client.Do(req)
	if err != nil {
		fmt.Printf("Failed to get user info: %v\n", err)
		return nil, echo.NewHTTPError(http.StatusInternalServerError, "failed to get user info")
	}
	defer res.Body.Close()

	// Error responses (expired or revoked token, ...) may still carry a JSON
	// body. Without this check it would be decoded and returned as user info.
	if res.StatusCode < 200 || res.StatusCode > 299 {
		fmt.Printf("Failed to get user info: provider returned status %d\n", res.StatusCode)
		return nil, echo.NewHTTPError(http.StatusInternalServerError, "failed to get user info")
	}

	body, err := io.ReadAll(io.LimitReader(res.Body, maxUserInfoBytes))
	if err != nil {
		fmt.Printf("Failed to read user info body: %v\n", err)
		return nil, echo.NewHTTPError(http.StatusInternalServerError, "failed to read user info body")
	}

	var userInfo map[string]interface{}
	if err := json.Unmarshal(body, &userInfo); err != nil {
		fmt.Printf("Failed to unmarshal user info: %v\n", err)
		return nil, echo.NewHTTPError(http.StatusInternalServerError, "failed to unmarshal user info")
	}

	return userInfo, nil
}

// GetProviderAccountID extracts the provider's account identifier (the "id"
// field) from a decoded user-info payload and returns it as a string.
//
// Conversion rules for the "id" value:
//   - missing or JSON null: returns "", so callers never persist the literal
//     "<nil>" as an account ID.
//   - string: returned as-is.
//   - json.Number (payload decoded with Decoder.UseNumber): returned verbatim,
//     which keeps IDs above 2^53 exact.
//   - float64 (plain json.Unmarshal into interface{}): rendered with no
//     fractional part or exponent.
//   - anything else: formatted with %v.
func GetProviderAccountID(userInfo map[string]interface{}) string {
	// Indexing a nil map or a missing key yields a nil interface, handled below.
	switch id := userInfo["id"].(type) {
	case nil:
		return ""
	case string:
		return id
	case json.Number:
		return id.String()
	case float64:
		return strconv.FormatFloat(id, 'f', 0, 64)
	default:
		return fmt.Sprintf("%v", id)
	}
}

// AuthGetUserSessionFromProvider looks up the user linked (via
// users.account_id) to the accounts row identified by providerName and
// providerAccountID, and returns it as a UserSession.
//
// first_name, last_name and email are nullable and become "" when NULL.
// IsRegistered is true only when the user has a non-empty email. Any error
// from building or running the query is returned unchanged. That includes the
// driver's "no rows" error when no user is linked to the account, so callers
// can keep matching on it.
func AuthGetUserSessionFromProvider(providerName string, providerAccountID string) (*models.UserSession, error) {
	// Named query (not sql) so the database/sql package stays usable below.
	query, args, err := database.PSQL.
		Select("users.id", "users.first_name", "users.last_name", "users.email", "users.is_admin").
		From("users").
		Join("accounts ON users.account_id = accounts.id").
		Where(squirrel.Eq{
			"accounts.provider":            providerName,
			"accounts.provider_account_id": providerAccountID,
		}).
		ToSql()
	if err != nil {
		fmt.Println("Error constructing SQL for user lookup by provider account:", err)
		return nil, err
	}

	var user models.UserSession
	var firstName, lastName, email sql.NullString
	if err := database.DB.QueryRow(context.Background(), query, args...).Scan(
		&user.ID,
		&firstName,
		&lastName,
		&email,
		&user.IsAdmin,
	); err != nil {
		return nil, err
	}

	user.FirstName = firstName.String
	user.LastName = lastName.String
	user.Email = email.String

	// Users are first inserted with only account_id (see AuthCreateUserSession),
	// so an email presumably appears only once registration is completed.
	user.IsRegistered = user.Email != ""
	return &user, nil
}

// AuthGetUserSessionFromToken resolves an access token to the user that owns
// the session, returning it as a UserSession.
//
// Only sessions whose access_token_expires_at lies in the future match. An
// expired token is treated exactly like an unknown one and yields the driver's
// "no rows" error from Scan. first_name, last_name and email are nullable and
// become "" when NULL. IsRegistered is true only when the user has a
// non-empty email.
func AuthGetUserSessionFromToken(accessToken string) (*models.UserSession, error) {
	// Named query (not sql) so the database/sql package stays usable below.
	query, args, err := database.PSQL.
		Select("users.id", "users.first_name", "users.last_name", "users.email", "users.is_admin").
		From("users").
		Join("sessions ON users.id = sessions.user_id").
		Where(squirrel.Eq{
			"sessions.access_token": accessToken,
		}).
		// AuthCreateSession stores an expiry next to every access token.
		// Without this predicate, expired tokens would authenticate forever.
		// time.Now() is bound as a parameter, the same way the expiry was
		// written, so both sides are encoded consistently by the driver.
		Where(squirrel.Gt{
			"sessions.access_token_expires_at": time.Now(),
		}).
		ToSql()
	if err != nil {
		fmt.Println("Error constructing SQL for user lookup by access token:", err)
		return nil, err
	}

	var user models.UserSession
	var firstName, lastName, email sql.NullString
	if err := database.DB.QueryRow(context.Background(), query, args...).Scan(
		&user.ID,
		&firstName,
		&lastName,
		&email,
		&user.IsAdmin,
	); err != nil {
		return nil, err
	}

	user.FirstName = firstName.String
	user.LastName = lastName.String
	user.Email = email.String

	user.IsRegistered = user.Email != ""
	return &user, nil
}

// AuthCreateUserSession records a first-time OAuth sign-in. It writes:
//   - an accounts row with the provider identity and the provider's tokens,
//     including token.Expiry as expires_at;
//   - a users row pointing at that account.
//
// It returns the new user, which has only its ID set and IsRegistered false.
// Both rows are written by one INSERT statement. A data-modifying CTE chains
// the users insert onto the accounts insert, so the write is atomic: a
// failure can no longer leave an accounts row without a matching users row.
// Such an orphan would be invisible to AuthGetUserSessionFromProvider, which
// joins through users.
func AuthCreateUserSession(providerName, providerAccountID string, token *oauth2.Token) (*models.UserSession, error) {
	accountInsert, args, err := database.PSQL.Insert("accounts").
		Columns(
			"provider",
			"provider_account_id",
			"refresh_token",
			"access_token",
			"expires_at",
		).
		Values(
			providerName,
			providerAccountID,
			token.RefreshToken,
			token.AccessToken,
			token.Expiry,
		).
		Suffix("RETURNING id").
		ToSql()
	if err != nil {
		fmt.Println("Error constructing SQL for account insertion:", err)
		return nil, err
	}

	// accountInsert already carries numbered placeholders for args. The outer
	// INSERT adds none, so args can be passed through unchanged.
	query := "WITH new_account AS (" + accountInsert + ") " +
		"INSERT INTO users (account_id) SELECT id FROM new_account RETURNING id"

	var user models.UserSession
	if err := database.DB.QueryRow(context.Background(), query, args...).Scan(&user.ID); err != nil {
		fmt.Println("Error inserting account and user into database:", err)
		return nil, err
	}

	user.IsRegistered = false
	return &user, nil
}

// AuthCreateSession generates a new access/refresh token pair for userID,
// stores it in the sessions table and returns the stored session.
//
// The access token expires 24 hours after creation and the refresh token
// 7 days after creation. The returned Session.ID is the id of the inserted
// sessions row. Token-generation errors are wrapped; database errors are
// returned unchanged.
func AuthCreateSession(userID int) (*models.Session, error) {
	accessToken, err := GenerateToken(32)
	if err != nil {
		return nil, fmt.Errorf("failed to generate access token: %w", err)
	}

	refreshToken, err := GenerateToken(64)
	if err != nil {
		return nil, fmt.Errorf("failed to generate refresh token: %w", err)
	}

	// One timestamp so both expiries are measured from the same instant.
	now := time.Now()
	// ID is left unset here; it is filled from RETURNING id below.
	session := &models.Session{
		AccessToken:           accessToken,
		RefreshToken:          refreshToken,
		AccessTokenExpiresAt:  now.Add(24 * time.Hour),
		RefreshTokenExpiresAt: now.Add(7 * 24 * time.Hour),
	}

	query, args, err := database.PSQL.Insert("sessions").
		Columns(
			"user_id",
			"access_token",
			"refresh_token",
			"access_token_expires_at",
			"refresh_token_expires_at",
		).
		Values(
			userID,
			session.AccessToken,
			session.RefreshToken,
			session.AccessTokenExpiresAt,
			session.RefreshTokenExpiresAt,
		).
		Suffix("RETURNING id").
		ToSql()
	if err != nil {
		fmt.Println("Error constructing SQL for session insertion:", err)
		return nil, err
	}

	if err := database.DB.QueryRow(context.Background(), query, args...).Scan(&session.ID); err != nil {
		fmt.Println("Error inserting session into database:", err)
		return nil, err
	}

	return session, nil
}

// AuthDeleteSession deletes every row in the sessions table whose user_id is
// userID. Errors from building or executing the DELETE are returned unchanged.
func AuthDeleteSession(userID int) error {
	query, params, buildErr := database.PSQL.
		Delete("sessions").
		Where(squirrel.Eq{"user_id": userID}).
		ToSql()
	if buildErr != nil {
		return buildErr
	}

	if _, execErr := database.DB.Exec(context.Background(), query, params...); execErr != nil {
		return execErr
	}
	return nil
}