package services
import (
"context"
"database/sql"
"encoding/json"
"fmt"
"io"
"net/http"
"strconv"
"time"
"github.com/danielrhuynh/busybar/internal/models"
"github.com/danielrhuynh/busybar/pkg/config"
"github.com/danielrhuynh/busybar/pkg/database"
"github.com/labstack/echo/v4"
"golang.org/x/oauth2"
"github.com/Masterminds/squirrel"
)
// FetchAccountInfoFromProvider requests the authenticated user's profile from
// provider.UserInfoURL using token and returns the decoded JSON object.
//
// Every failure is reported as an *echo.HTTPError with status 500 so HTTP
// handlers can return it directly. Failures include:
//   - a nil token;
//   - a transport error;
//   - a non-2xx response from the provider;
//   - a body that is not a JSON object.
func FetchAccountInfoFromProvider(provider config.OAuthProvider, token *oauth2.Token) (map[string]interface{}, error) {
	// Upper bound on how much of the user-info response is read into memory.
	const maxUserInfoResponseBytes = 1 << 20

	if token == nil {
		fmt.Println("Failed to get user info: nil oauth token")
		return nil, echo.NewHTTPError(http.StatusInternalServerError, "failed to get user info")
	}
	client := &http.Client{Timeout: 10 * time.Second}
	req, err := http.NewRequest(http.MethodGet, provider.UserInfoURL, nil)
	if err != nil {
		fmt.Printf("Failed to create user info request: %v\n", err)
		return nil, echo.NewHTTPError(http.StatusInternalServerError, "failed to create user info request")
	}
	// SetAuthHeader writes "<Type> <AccessToken>". Type() falls back to
	// "Bearer" when TokenType is empty. Formatting TokenType directly would
	// produce a header with an empty scheme in that case.
	token.SetAuthHeader(req)
	res, err := client.Do(req)
	if err != nil {
		fmt.Printf("Failed to get user info: %v\n", err)
		return nil, echo.NewHTTPError(http.StatusInternalServerError, "failed to get user info")
	}
	defer res.Body.Close()
	// An error status (expired/revoked token, provider outage) usually carries
	// an error document, not profile data. It must not be decoded as user info.
	if res.StatusCode < http.StatusOK || res.StatusCode >= http.StatusMultipleChoices {
		fmt.Printf("User info request returned status %d\n", res.StatusCode)
		return nil, echo.NewHTTPError(http.StatusInternalServerError, "failed to get user info")
	}
	body, err := io.ReadAll(io.LimitReader(res.Body, maxUserInfoResponseBytes))
	if err != nil {
		fmt.Printf("Failed to read user info body: %v\n", err)
		return nil, echo.NewHTTPError(http.StatusInternalServerError, "failed to read user info body")
	}
	var userInfo map[string]interface{}
	if err := json.Unmarshal(body, &userInfo); err != nil {
		fmt.Printf("Failed to unmarshal user info: %v\n", err)
		return nil, echo.NewHTTPError(http.StatusInternalServerError, "failed to unmarshal user info")
	}
	// A literal JSON `null` body decodes without error into a nil map.
	if userInfo == nil {
		fmt.Println("Failed to unmarshal user info: response body was null")
		return nil, echo.NewHTTPError(http.StatusInternalServerError, "failed to unmarshal user info")
	}
	return userInfo, nil
}
// GetProviderAccountID extracts the provider-side account identifier from a
// decoded user-info object.
//
// The "id" value is rendered as follows:
//   - a string is returned unchanged;
//   - a float64 (how encoding/json decodes JSON numbers into interface{}) is
//     formatted without a fractional part or exponent;
//   - any other type is formatted with %v.
//
// An absent key, an explicit JSON null, and a nil map all yield "".
// NOTE(review): numeric IDs above 2^53 lose precision once decoded as
// float64; decoding with json.Decoder.UseNumber would preserve them.
func GetProviderAccountID(userInfo map[string]interface{}) string {
	switch v := userInfo["id"].(type) {
	case nil:
		// Missing key or "id": null. Without this case, %v would render
		// the string "<nil>" and every such user would share one ID.
		return ""
	case string:
		return v
	case float64:
		return strconv.FormatFloat(v, 'f', 0, 64)
	default:
		return fmt.Sprintf("%v", v)
	}
}
// AuthGetUserSessionFromProvider loads the user linked to the account
// identified by (providerName, providerAccountID). It joins users to accounts
// through users.account_id.
//
// first_name, last_name and email are scanned through sql.NullString, so NULL
// columns become "". IsRegistered is true only when the email is non-empty.
// Errors from building the query or from QueryRow/Scan are returned unwrapped.
// Presumably that includes the driver's no-rows error when nothing matches;
// confirm against callers.
func AuthGetUserSessionFromProvider(providerName string, providerAccountID string) (*models.UserSession, error) {
	// Named `query` rather than `sql` so the database/sql package stays visible.
	query, args, err := database.PSQL.
		Select("users.id", "users.first_name", "users.last_name", "users.email", "users.is_admin").
		From("users").
		Join("accounts ON users.account_id = accounts.id").
		Where(squirrel.Eq{
			"accounts.provider":            providerName,
			"accounts.provider_account_id": providerAccountID,
		}).
		ToSql()
	if err != nil {
		fmt.Println("Error constructing SQL for user lookup by provider account:", err)
		return nil, err
	}

	var user models.UserSession
	var firstName, lastName, email sql.NullString
	err = database.DB.QueryRow(context.Background(), query, args...).Scan(
		&user.ID,
		&firstName,
		&lastName,
		&email,
		&user.IsAdmin,
	)
	if err != nil {
		return nil, err
	}
	user.FirstName = firstName.String
	user.LastName = lastName.String
	user.Email = email.String
	user.IsRegistered = user.Email != ""
	return &user, nil
}
// AuthGetUserSessionFromToken loads the user that owns the session whose
// access_token equals accessToken. It joins users to sessions through
// sessions.user_id.
//
// first_name, last_name and email are scanned through sql.NullString, so NULL
// columns become "". IsRegistered is true only when the email is non-empty.
// Errors from building the query or from QueryRow/Scan are returned unwrapped.
// Presumably that includes the driver's no-rows error for an unknown token;
// confirm against callers.
//
// NOTE(review): sessions.access_token_expires_at is not consulted here, so an
// expired access token still resolves to its user. Confirm that callers
// enforce expiry elsewhere, or add a
// `squirrel.Gt{"sessions.access_token_expires_at": time.Now()}` filter.
func AuthGetUserSessionFromToken(accessToken string) (*models.UserSession, error) {
	// Named `query` rather than `sql` so the database/sql package stays visible.
	query, args, err := database.PSQL.
		Select("users.id", "users.first_name", "users.last_name", "users.email", "users.is_admin").
		From("users").
		Join("sessions ON users.id = sessions.user_id").
		Where(squirrel.Eq{
			"sessions.access_token": accessToken,
		}).
		ToSql()
	if err != nil {
		fmt.Println("Error constructing SQL for user lookup by access token:", err)
		return nil, err
	}

	var user models.UserSession
	var firstName, lastName, email sql.NullString
	err = database.DB.QueryRow(context.Background(), query, args...).Scan(
		&user.ID,
		&firstName,
		&lastName,
		&email,
		&user.IsAdmin,
	)
	if err != nil {
		return nil, err
	}
	user.FirstName = firstName.String
	user.LastName = lastName.String
	user.Email = email.String
	user.IsRegistered = user.Email != ""
	return &user, nil
}
// AuthCreateUserSession records a new OAuth account and creates the user row
// that owns it. The accounts row stores the provider name, the provider's
// account ID, and the token's refresh token, access token and expiry. The new
// users row is linked to it through users.account_id.
//
// Both rows are written by one SQL statement (a data-modifying CTE), so the
// operation is atomic. If the users insert fails, the accounts insert is
// rolled back with it, and no accounts row is left without a matching user.
//
// The returned session carries only the new user's ID, and IsRegistered is
// false. A nil token is rejected with an error. Driver errors are returned
// unwrapped.
func AuthCreateUserSession(providerName, providerAccountID string, token *oauth2.Token) (*models.UserSession, error) {
	if token == nil {
		return nil, fmt.Errorf("creating user session for %s account %q: nil oauth token", providerName, providerAccountID)
	}

	// Raw SQL because squirrel has no first-class CTE support. This assumes
	// PostgreSQL ($n placeholders, RETURNING, data-modifying WITH), matching
	// the database.PSQL builder used elsewhere in this file.
	const query = `
WITH new_account AS (
	INSERT INTO accounts (provider, provider_account_id, refresh_token, access_token, expires_at)
	VALUES ($1, $2, $3, $4, $5)
	RETURNING id
)
INSERT INTO users (account_id)
SELECT id FROM new_account
RETURNING id`

	var user models.UserSession
	err := database.DB.QueryRow(context.Background(), query,
		providerName,
		providerAccountID,
		token.RefreshToken,
		token.AccessToken,
		token.Expiry,
	).Scan(&user.ID)
	if err != nil {
		fmt.Println("Error inserting account and user into database:", err)
		return nil, err
	}
	user.IsRegistered = false
	return &user, nil
}
// AuthCreateSession generates a fresh access/refresh token pair for userID,
// inserts it into the sessions table and returns it.
//
// The access token expires 24 hours and the refresh token 7 days after a
// single creation timestamp. The returned Session's ID is the sessions row id
// produced by the INSERT's RETURNING clause.
// Token-generation failures are wrapped. Query-building and insert errors are
// logged and returned unwrapped.
//
// NOTE(review): the userID is not stored on the returned Session. If
// models.Session has a user-id field, it should probably be set here; confirm
// against the model.
func AuthCreateSession(userID int) (*models.Session, error) {
	accessToken, err := GenerateToken(32)
	if err != nil {
		return nil, fmt.Errorf("failed to generate access token: %w", err)
	}
	refreshToken, err := GenerateToken(64)
	if err != nil {
		return nil, fmt.Errorf("failed to generate refresh token: %w", err)
	}

	// One timestamp for both expiries, so they are exactly 24h and 7d apart
	// from the same instant.
	now := time.Now()
	session := &models.Session{
		AccessToken:           accessToken,
		RefreshToken:          refreshToken,
		AccessTokenExpiresAt:  now.Add(24 * time.Hour),
		RefreshTokenExpiresAt: now.Add(7 * 24 * time.Hour),
	}

	query, args, err := database.PSQL.Insert("sessions").
		Columns(
			"user_id",
			"access_token",
			"refresh_token",
			"access_token_expires_at",
			"refresh_token_expires_at",
		).
		Values(userID, accessToken, refreshToken, session.AccessTokenExpiresAt, session.RefreshTokenExpiresAt).
		Suffix("RETURNING id").
		ToSql()
	if err != nil {
		fmt.Println("Error constructing SQL for session insertion:", err)
		return nil, err
	}
	err = database.DB.QueryRow(context.Background(), query, args...).Scan(&session.ID)
	if err != nil {
		fmt.Println("Error inserting session into database:", err)
		return nil, err
	}
	return session, nil
}
// AuthDeleteSession deletes every row in the sessions table whose user_id is
// userID. That removes all of the user's sessions, not a single one.
// It returns the error from building or executing the DELETE, or nil.
func AuthDeleteSession(userID int) error {
	query, params, buildErr := database.PSQL.
		Delete("sessions").
		Where(squirrel.Eq{"user_id": userID}).
		ToSql()
	if buildErr != nil {
		return buildErr
	}
	_, execErr := database.DB.Exec(context.Background(), query, params...)
	return execErr
}