// File: busybar/internal/services/reports.service.go

package services

import (
	"context"
	"fmt"
	"log"
	"time"

	"github.com/danielrhuynh/busybar/internal/models"
	"github.com/danielrhuynh/busybar/pkg/database"

	"github.com/Masterminds/squirrel"
)

// CreateReport inserts report into the reports table on behalf of userID and,
// on success, fills in report.ReportID and report.ReportTime from the values
// returned by the INSERT ... RETURNING clause.
//
// The insert runs inside a transaction bounded by a 3-second timeout. Once the
// commit succeeds, achievement processing is started in a detached goroutine.
// That work is best-effort: its errors and panics are logged, never returned.
// The returned error therefore reflects only the insert itself.
func CreateReport(userID int, report *models.Report) error {
	ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second)
	defer cancel()

	tx, err := database.DB.Begin(ctx)
	if err != nil {
		return fmt.Errorf("failed to begin transaction: %w", err)
	}
	// Undoes the insert on any error path below. Its error is deliberately
	// ignored. NOTE(review): this assumes the driver treats Rollback after a
	// successful Commit as a harmless no-op (pgx does).
	defer tx.Rollback(ctx)

	query := database.PSQL.Insert("reports").
		Columns("user_id", "bar_id", "wait_time", "busyness", "music").
		Values(userID, report.BarID, report.WaitTime, report.Busyness, report.Music).
		Suffix("RETURNING report_id, report_time")

	sqlStr, args, err := query.ToSql()
	if err != nil {
		return fmt.Errorf("error generating SQL: %w", err)
	}

	row := tx.QueryRow(ctx, sqlStr, args...)
	if err := row.Scan(&report.ReportID, &report.ReportTime); err != nil {
		return fmt.Errorf("error creating report: %w", err)
	}

	if err := tx.Commit(ctx); err != nil {
		return fmt.Errorf("error committing transaction: %w", err)
	}

	// The caller keeps (and may modify) *report after we return, so the
	// background goroutine works on its own copy. Music is a slice and would
	// otherwise still share its backing array with the caller.
	snapshot := *report
	snapshot.Music = append(report.Music[:0:0], report.Music...)

	// Process achievements asynchronously. This goroutine is not tied to ctx
	// (which is cancelled when CreateReport returns). A panic here would
	// otherwise take down the whole process, so it is recovered and logged
	// like any other failure.
	go func() {
		defer func() {
			if p := recover(); p != nil {
				log.Printf("Panic processing achievements: %v", p)
			}
		}()
		if err := ProcessReportAchievements(userID, &snapshot); err != nil {
			log.Printf("Error processing achievements: %v", err)
		}
	}()

	return nil
}

// GetWaitTimesByUserID returns every report submitted by userID, with all
// selected columns (report_id, user_id, bar_id, wait_time, busyness, music,
// report_time) populated. The query is bounded by a 3-second timeout.
//
// Rows are returned in whatever order the database produces, because the
// query has no ORDER BY. When the user has no reports, the result is a nil
// slice and a nil error. Errors are wrapped with %w, so errors.Is/As still
// see the underlying cause.
func GetWaitTimesByUserID(userID int) ([]models.Report, error) {
	ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second)
	defer cancel()

	query := database.PSQL.Select(
		"report_id",
		"user_id",
		"bar_id",
		"wait_time",
		"busyness",
		"music",
		"report_time",
	).From("reports").
		Where(squirrel.Eq{"user_id": userID})

	sqlStr, args, err := query.ToSql()
	if err != nil {
		return nil, fmt.Errorf("error generating SQL: %w", err)
	}

	rows, err := database.DB.Query(ctx, sqlStr, args...)
	if err != nil {
		return nil, fmt.Errorf("error querying reports for user %d: %w", userID, err)
	}
	defer rows.Close()

	var reports []models.Report
	for rows.Next() {
		var report models.Report
		// The music column is scanned into a plain []string and then
		// assigned. Presumably this lets the driver decode the array
		// regardless of how models.Report declares Music. TODO confirm.
		var musicArray []string
		if err := rows.Scan(
			&report.ReportID,
			&report.UserID,
			&report.BarID,
			&report.WaitTime,
			&report.Busyness,
			&musicArray,
			&report.ReportTime,
		); err != nil {
			return nil, fmt.Errorf("error scanning report row: %w", err)
		}
		report.Music = musicArray
		reports = append(reports, report)
	}

	if err := rows.Err(); err != nil {
		return nil, fmt.Errorf("error iterating report rows: %w", err)
	}

	return reports, nil
}

// GetAverageWaitTimes returns one entry for each row of the bars table that
// has a recent report. Each entry pairs the bar's ID with an estimated wait
// time:
//
//   - if the bar has reports from the last 20 minutes, the average of their
//     wait_time values is used;
//   - otherwise, the wait_time of the bar's most recent report within the
//     last 3 hours is used.
//
// Bars with no report in the last 3 hours are omitted. The query is bounded
// by a 3-second timeout. When no bar qualifies, the result is a nil slice and
// a nil error. Errors are wrapped with %w rather than logged here, leaving the
// decision to log to the caller.
func GetAverageWaitTimes() ([]models.AverageWaitTime, error) {
	ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second)
	defer cancel()

	// CTE recent_wait_times: average wait time per bar over the last 20 minutes.
	recentWaitTimes := database.PSQL.
		Select("bar_id", "AVG(wait_time) AS avg_wait_time").
		From("reports").
		Where("report_time >= NOW() - INTERVAL '20 minutes'").
		GroupBy("bar_id")

	// Number each bar's reports from the last 3 hours newest-first, so that
	// rn = 1 picks the single most recent report per bar.
	lastWaitTimesInner := database.PSQL.
		Select(
			"bar_id",
			"wait_time",
			"report_time",
			"ROW_NUMBER() OVER (PARTITION BY bar_id ORDER BY report_time DESC) AS rn",
		).
		From("reports").
		Where("report_time >= NOW() - INTERVAL '3 hours'")

	// CTE last_wait_times: wait time of each bar's latest report in that window.
	lastWaitTimes := database.PSQL.
		Select("bar_id", "wait_time").
		FromSelect(lastWaitTimesInner, "sub").
		Where("rn = 1")

	// The CTE bodies are spliced into the outer query's prefix as raw text.
	// Render them with '?' placeholders and forward their arguments, so the
	// outer ToSql numbers every placeholder (in prefix order) in one pass.
	// Neither CTE has bound parameters today, but discarding their args would
	// silently break the query if one ever gained a parameter.
	recentSQL, recentArgs, err := recentWaitTimes.PlaceholderFormat(squirrel.Question).ToSql()
	if err != nil {
		return nil, fmt.Errorf("error generating SQL for recentWaitTimes: %w", err)
	}

	lastSQL, lastArgs, err := lastWaitTimes.PlaceholderFormat(squirrel.Question).ToSql()
	if err != nil {
		return nil, fmt.Errorf("error generating SQL for lastWaitTimes: %w", err)
	}

	cteArgs := make([]interface{}, 0, len(recentArgs)+len(lastArgs))
	cteArgs = append(cteArgs, recentArgs...)
	cteArgs = append(cteArgs, lastArgs...)

	// Prefer the 20-minute average and fall back to the latest report. The
	// cast gives both COALESCE branches a numeric type. NOTE(review): this
	// assumes wait_time is an integer column, whose AVG is numeric in
	// PostgreSQL.
	fullQuery := database.PSQL.
		Select("b.bar_id").
		Column("COALESCE(rwt.avg_wait_time, lwt.wait_time::numeric) AS wait_time").
		From("bars b").
		LeftJoin("recent_wait_times rwt ON b.bar_id = rwt.bar_id").
		LeftJoin("last_wait_times lwt ON b.bar_id = lwt.bar_id").
		Where(
			squirrel.Or{
				squirrel.Expr("rwt.avg_wait_time IS NOT NULL"),
				squirrel.Expr("lwt.wait_time IS NOT NULL"),
			},
		).
		Prefix(
			fmt.Sprintf("WITH recent_wait_times AS (%s), last_wait_times AS (%s)", recentSQL, lastSQL),
			cteArgs...,
		)

	sqlStr, args, err := fullQuery.ToSql()
	if err != nil {
		return nil, fmt.Errorf("error generating SQL: %w", err)
	}

	rows, err := database.DB.Query(ctx, sqlStr, args...)
	if err != nil {
		return nil, fmt.Errorf("error executing average wait times query: %w", err)
	}
	defer rows.Close()

	var averages []models.AverageWaitTime
	for rows.Next() {
		var avg models.AverageWaitTime
		if err := rows.Scan(&avg.BarID, &avg.WaitTime); err != nil {
			return nil, fmt.Errorf("error scanning average wait time row: %w", err)
		}
		averages = append(averages, avg)
	}

	if err := rows.Err(); err != nil {
		return nil, fmt.Errorf("error iterating average wait time rows: %w", err)
	}

	return averages, nil
}