package services
import (
"context"
"fmt"
"log"
"time"
"github.com/danielrhuynh/busybar/internal/models"
"github.com/danielrhuynh/busybar/pkg/database"
"github.com/Masterminds/squirrel"
)
// CreateReport inserts report on behalf of userID inside a transaction bounded
// by a 3-second timeout. On success, the database-generated report_id and
// report_time are written back into report.ReportID and report.ReportTime.
//
// After the transaction commits, achievement processing is started in a
// background goroutine. Its errors are only logged and never reach the
// caller. The goroutine works on its own copy of the report, so the caller
// may keep using (or mutating) report after CreateReport returns without
// racing with the background work.
func CreateReport(userID int, report *models.Report) error {
	ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second)
	defer cancel()

	tx, err := database.DB.Begin(ctx)
	if err != nil {
		return fmt.Errorf("failed to begin transaction: %w", err)
	}
	// The Rollback error is deliberately ignored. Rollback only matters on the
	// early-return error paths below; after a successful Commit there is
	// nothing left to undo.
	defer tx.Rollback(ctx)

	query := database.PSQL.Insert("reports").
		Columns("user_id", "bar_id", "wait_time", "busyness", "music").
		Values(userID, report.BarID, report.WaitTime, report.Busyness, report.Music).
		Suffix("RETURNING report_id, report_time")

	sqlStr, args, err := query.ToSql()
	if err != nil {
		return fmt.Errorf("error generating SQL: %w", err)
	}

	// RETURNING hands back the generated key and timestamp, which are stored
	// straight into the caller's report.
	row := tx.QueryRow(ctx, sqlStr, args...)
	if err := row.Scan(&report.ReportID, &report.ReportTime); err != nil {
		return fmt.Errorf("error creating report: %w", err)
	}

	if err := tx.Commit(ctx); err != nil {
		return fmt.Errorf("error committing transaction: %w", err)
	}

	// Snapshot the report before handing it to the goroutine. The caller owns
	// *report and may read or modify it as soon as we return. Music is cloned
	// too, because copying the struct alone would still share its backing array.
	snapshot := *report
	snapshot.Music = append([]string(nil), report.Music...)

	// NOTE(review): this goroutine is not tied to any context or WaitGroup, so
	// it may outlive the request and is not awaited on shutdown.
	go func() {
		if err := ProcessReportAchievements(userID, &snapshot); err != nil {
			log.Printf("Error processing achievements: %v", err)
		}
	}()
	return nil
}
// GetWaitTimesByUserID returns every report submitted by userID, with all
// selected columns of the reports row populated. Despite its name, it returns
// full reports, not only wait times. The query runs under a 3-second timeout.
//
// Rows come back in whatever order the database produces, because the query
// has no ORDER BY. When the user has no reports, the result is a nil slice and
// a nil error. Errors are wrapped with the step that failed; use errors.Is or
// errors.As to inspect the underlying cause.
func GetWaitTimesByUserID(userID int) ([]models.Report, error) {
	ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second)
	defer cancel()

	query := database.PSQL.Select(
		"report_id",
		"user_id",
		"bar_id",
		"wait_time",
		"busyness",
		"music",
		"report_time",
	).From("reports").
		Where(squirrel.Eq{"user_id": userID})

	sqlStr, args, err := query.ToSql()
	if err != nil {
		return nil, fmt.Errorf("error generating SQL: %w", err)
	}

	rows, err := database.DB.Query(ctx, sqlStr, args...)
	if err != nil {
		return nil, fmt.Errorf("error querying reports for user %d: %w", userID, err)
	}
	defer rows.Close()

	var reports []models.Report
	for rows.Next() {
		var report models.Report
		// The music column is scanned into a plain []string first and then
		// assigned to report.Music.
		var musicArray []string
		if err := rows.Scan(
			&report.ReportID,
			&report.UserID,
			&report.BarID,
			&report.WaitTime,
			&report.Busyness,
			&musicArray,
			&report.ReportTime,
		); err != nil {
			return nil, fmt.Errorf("error scanning report: %w", err)
		}
		report.Music = musicArray
		reports = append(reports, report)
	}
	// rows.Next returning false can mean either exhaustion or failure;
	// rows.Err distinguishes the two.
	if err := rows.Err(); err != nil {
		return nil, fmt.Errorf("error iterating reports: %w", err)
	}
	return reports, nil
}
// GetAverageWaitTimes returns one current wait-time estimate per bar. The query
// runs under a 3-second timeout.
//
// For each row in bars, the estimate is:
//   - the average wait_time of reports from the last 20 minutes, if there is one;
//   - otherwise, the wait_time of that bar's most recent report from the last
//     3 hours.
//
// The fallback is cast to numeric so both COALESCE branches share a type. Bars
// with neither value are omitted. Rows come back in database order (there is
// no ORDER BY), and the result is a nil slice when no bar qualifies.
func GetAverageWaitTimes() ([]models.AverageWaitTime, error) {
	ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second)
	defer cancel()

	// CTE 1: per-bar average over the short "recent" window.
	recentWaitTimes := database.PSQL.
		Select("bar_id", "AVG(wait_time) AS avg_wait_time").
		From("reports").
		Where("report_time >= NOW() - INTERVAL '20 minutes'").
		GroupBy("bar_id")

	// CTE 2: the latest report per bar within the longer fallback window.
	// ROW_NUMBER ranks each bar's reports newest-first, and rn = 1 keeps the
	// newest one.
	lastWaitTimesInner := database.PSQL.
		Select(
			"bar_id",
			"wait_time",
			"report_time",
			"ROW_NUMBER() OVER (PARTITION BY bar_id ORDER BY report_time DESC) AS rn",
		).
		From("reports").
		Where("report_time >= NOW() - INTERVAL '3 hours'")
	lastWaitTimes := database.PSQL.
		Select("bar_id", "wait_time").
		FromSelect(lastWaitTimesInner, "sub").
		Where("rn = 1")

	recentWaitTimesSQL, recentArgs, err := recentWaitTimes.ToSql()
	if err != nil {
		log.Printf("Error generating SQL for recentWaitTimes: %v", err)
		return nil, err
	}
	lastWaitTimesSQL, lastArgs, err := lastWaitTimes.ToSql()
	if err != nil {
		log.Printf("Error generating SQL for lastWaitTimes: %v", err)
		return nil, err
	}
	// The CTE bodies are spliced into Prefix as plain text below, so any bind
	// arguments they produced would be dropped. With numbered placeholders,
	// they could also collide with the outer query's numbering. Neither
	// sub-query uses placeholders today; fail loudly if that ever changes
	// instead of sending a malformed statement.
	if len(recentArgs) > 0 || len(lastArgs) > 0 {
		return nil, fmt.Errorf("wait-time CTE sub-queries must not use bind arguments (got %d and %d)", len(recentArgs), len(lastArgs))
	}

	fullQuery := database.PSQL.
		Select("b.bar_id").
		Column("COALESCE(rwt.avg_wait_time, lwt.wait_time::numeric) AS wait_time").
		From("bars b").
		LeftJoin("recent_wait_times rwt ON b.bar_id = rwt.bar_id").
		LeftJoin("last_wait_times lwt ON b.bar_id = lwt.bar_id").
		Where(
			squirrel.Or{
				squirrel.Expr("rwt.avg_wait_time IS NOT NULL"),
				squirrel.Expr("lwt.wait_time IS NOT NULL"),
			},
		).
		Prefix(
			fmt.Sprintf("WITH recent_wait_times AS (%s), last_wait_times AS (%s)", recentWaitTimesSQL, lastWaitTimesSQL),
		)

	sqlStr, args, err := fullQuery.ToSql()
	if err != nil {
		log.Printf("Error generating SQL: %v", err)
		return nil, err
	}
	log.Printf("Generated SQL:\n%s\nArguments: %v", sqlStr, args)

	rows, err := database.DB.Query(ctx, sqlStr, args...)
	if err != nil {
		log.Printf("Error executing query: %v", err)
		return nil, err
	}
	defer rows.Close()

	var averages []models.AverageWaitTime
	for rows.Next() {
		var avg models.AverageWaitTime
		if err := rows.Scan(&avg.BarID, &avg.WaitTime); err != nil {
			log.Printf("Error scanning row: %v", err)
			return nil, err
		}
		averages = append(averages, avg)
	}
	if err := rows.Err(); err != nil {
		log.Printf("Row iteration error: %v", err)
		return nil, err
	}
	return averages, nil
}