// busybar/internal/services/achievements.service.go

package services

// NOTE: The SQL below is messy but works. It may be cleaned up later to be better formatted and friendlier to dependency injection.

import (
	"context"
	"errors"
	"fmt"
	"time"

	"github.com/danielrhuynh/busybar/pkg/database"

	"github.com/danielrhuynh/busybar/internal/models"

	"github.com/jackc/pgx/v5"
)

// getISOWeek returns the ISO 8601 week number of t, discarding the ISO year
// that time.Time.ISOWeek reports alongside it.
func getISOWeek(t time.Time) (week int) {
	_, week = t.ISOWeek()
	return week
}

// ProcessReportAchievements updates userID's progress on every achievement
// that applies to report and then tries to unlock each of them.
//
// Achievements tied to a bar (non-nil BarID) are skipped unless the report was
// filed for that bar. The progress update is chosen by the achievement's
// RequiredAction ("time_restricted", "specific_date" or "report_wait_time").
//
// A failure while processing a single achievement is printed and does not stop
// the remaining achievements from being processed; the only error returned to
// the caller is a failure to load the achievement list.
func ProcessReportAchievements(userID int, report *models.Report) error {
	ctx := context.Background()

	achievements, err := getApplicableAchievements(ctx)
	if err != nil {
		return fmt.Errorf("failed to get applicable achievements: %w", err)
	}

	for _, achievement := range achievements {
		if achievement.BarID != nil && *achievement.BarID != report.BarID {
			continue
		}

		// Scoped to this iteration so that an error recorded for one
		// achievement can never be attributed to the next one.
		var progressErr error

		// We can enum the required_action in the future
		switch achievement.RequiredAction {
		case "time_restricted":
			progressErr = processTimeRestrictedAchievement(userID, achievement.ID, report)
		case "specific_date":
			progressErr = processDateSpecificAchievement(userID, achievement.ID, report)
		case "report_wait_time":
			progressErr = processCumulativeAchievement(userID, achievement.ID)
		}

		if progressErr != nil {
			fmt.Printf("Error processing achievement %d: %v\n", achievement.ID, progressErr)
			continue
		}

		if unlockErr := checkAndUnlockAchievement(userID, achievement.ID); unlockErr != nil {
			fmt.Printf("Error checking/unlocking achievement %d: %v\n", achievement.ID, unlockErr)
		}
	}

	return nil
}

// getApplicableAchievements reads every row of the achievements table and
// returns them as models.Achievement values. The query has no WHERE clause, so
// any narrowing to a particular report is left to the caller.
//
// A query or scan failure yields a nil slice and the error. An error surfaced
// by rows.Err() after iteration is returned together with the rows scanned so
// far.
func getApplicableAchievements(ctx context.Context) ([]models.Achievement, error) {
	const selectAll = `
        SELECT achievement_id, name, description, required_action, threshold,
               bar_id, required_days, specific_date, time_restricted
        FROM achievements`

	rows, queryErr := database.DB.Query(ctx, selectAll)
	if queryErr != nil {
		return nil, queryErr
	}
	defer rows.Close()

	var loaded []models.Achievement
	for rows.Next() {
		var row models.Achievement
		// Destination order must match the column order of the SELECT above.
		if scanErr := rows.Scan(
			&row.ID,
			&row.Name,
			&row.Description,
			&row.RequiredAction,
			&row.Threshold,
			&row.BarID,
			&row.RequiredDays,
			&row.SpecificDate,
			&row.TimeRestricted,
		); scanErr != nil {
			return nil, scanErr
		}
		loaded = append(loaded, row)
	}

	return loaded, rows.Err()
}

// processCumulativeAchievement bumps the user's running total for an
// achievement by one. The first call for a (user_id, achievement_id) pair
// inserts a row with progress_total = 1; later calls hit the ON CONFLICT
// branch and increment the stored progress_total.
func processCumulativeAchievement(userID int, achievementID int) error {
	const upsert = `
        INSERT INTO user_achievements_mapping (user_id, achievement_id, progress_total)
        VALUES ($1, $2, 1)
        ON CONFLICT (user_id, achievement_id)
        DO UPDATE SET progress_total = user_achievements_mapping.progress_total + 1
    `

	if _, execErr := database.DB.Exec(context.Background(), upsert, userID, achievementID); execErr != nil {
		return execErr
	}
	return nil
}

// processTimeRestrictedAchievement records a report against a weekly,
// day-of-week based achievement for the given user.
//
// The report's weekday (time.Weekday numbering, Sunday = 0) is stored in
// progress_days together with the ISO 8601 week and ISO year of the report:
//
//   - First report for the (user, achievement) pair: a row is inserted with
//     progress_days = [weekday] and progress_total = 1 if the weekday is in the
//     achievement's required_days, otherwise 0.
//   - Later report in the same stored week/year: the weekday is appended to
//     progress_days if not already present, and progress_total is incremented
//     when the weekday is one of required_days.
//   - Report in a different week or year: progress_days is reset to [weekday]
//     and progress_total restarts at 1 or 0 by the same required_days test.
//
// NOTE(review): this assumes achievements.required_days uses the same
// Sunday = 0 numbering as time.Weekday; confirm against the data.
func processTimeRestrictedAchievement(userID int, achievementID int, report *models.Report) error {
	ctx := context.Background()
	reportDay := int(report.ReportTime.Weekday())

	// Take the week AND the year from ISOWeek so the pair is consistent. Mixing
	// the ISO week with the calendar year breaks around New Year: Dec 31 and
	// Jan 1 can share an ISO week but differ in calendar year, which would
	// reset the weekly progress in the middle of a week.
	reportYear, reportWeek := report.ReportTime.ISOWeek()

	query := `
        INSERT INTO user_achievements_mapping 
            (user_id, achievement_id, progress_days, progress_week, progress_year, progress_total)
        VALUES ($1, $2, ARRAY[$3]::integer[], $4, $5, 
            CASE 
                WHEN $3 = ANY(SELECT unnest(required_days) FROM achievements WHERE achievement_id = $2)
                THEN 1 
                ELSE 0 
            END)
        ON CONFLICT (user_id, achievement_id) DO UPDATE
        SET progress_days = 
            CASE 
                WHEN user_achievements_mapping.progress_week = $4 
                     AND user_achievements_mapping.progress_year = $5
                     AND NOT ARRAY[$3]::integer[] <@ user_achievements_mapping.progress_days
                THEN array_append(user_achievements_mapping.progress_days, $3)
                WHEN user_achievements_mapping.progress_week != $4 
                     OR user_achievements_mapping.progress_year != $5
                THEN ARRAY[$3]::integer[]
                ELSE user_achievements_mapping.progress_days
            END,
            progress_week = $4,
            progress_year = $5,
            progress_total = 
                CASE 
                    WHEN user_achievements_mapping.progress_week = $4 
                         AND user_achievements_mapping.progress_year = $5
                         AND $3 = ANY(SELECT unnest(required_days) FROM achievements WHERE achievement_id = $2)
                    THEN user_achievements_mapping.progress_total + 1
                    WHEN user_achievements_mapping.progress_week != $4 
                         OR user_achievements_mapping.progress_year != $5
                    THEN CASE 
                            WHEN $3 = ANY(SELECT unnest(required_days) FROM achievements WHERE achievement_id = $2)
                            THEN 1 
                            ELSE 0 
                         END
                    ELSE user_achievements_mapping.progress_total
                END
    `
	_, err := database.DB.Exec(ctx, query, userID, achievementID, reportDay, reportWeek, reportYear)
	return err
}

// processDateSpecificAchievement records a report against an achievement that
// is tied to a particular calendar date.
//
// The upsert stores the report's day of month in progress_days along with the
// report's ISO week number and calendar year. progress_total is driven by
// whether the report's day and month equal the day and month of the
// achievement's specific_date:
//
//   - no match: progress_total is set to 0;
//   - match on a new row, or on a row whose stored week/year differ from the
//     report's: progress_total is set to 1;
//   - match on a row already holding the same week and year: progress_total
//     is incremented.
//
// NOTE(review): the week is an ISO week while the year is the calendar year.
// checkAndUnlockAchievement compares these columns against
// EXTRACT(WEEK ...) / EXTRACT(YEAR ...) of specific_date, so the two should be
// changed together if this pairing is ever revisited.
func processDateSpecificAchievement(userID int, achievementID int, report *models.Report) error {
	const upsert = `
        INSERT INTO user_achievements_mapping 
            (user_id, achievement_id, progress_days, progress_week, progress_year, progress_total)
        VALUES ($1, $2, ARRAY[$3]::integer[], $4, $5,
            CASE 
                WHEN EXISTS (
                    SELECT 1 
                    FROM achievements 
                    WHERE achievement_id = $2 
                    AND EXTRACT(DAY FROM specific_date) = $3
                    AND EXTRACT(MONTH FROM specific_date) = $6
                ) THEN 1
                ELSE 0
            END)
        ON CONFLICT (user_id, achievement_id) DO UPDATE
        SET 
            progress_days = ARRAY[$3]::integer[],
            progress_week = $4,
            progress_year = $5,
            progress_total = 
                CASE 
                    WHEN EXISTS (
                        SELECT 1 
                        FROM achievements 
                        WHERE achievement_id = $2 
                        AND EXTRACT(DAY FROM specific_date) = $3
                        AND EXTRACT(MONTH FROM specific_date) = $6
                    ) THEN 
                        CASE 
                            WHEN user_achievements_mapping.progress_week = $4 
                                 AND user_achievements_mapping.progress_year = $5
                            THEN user_achievements_mapping.progress_total + 1
                            ELSE 1
                        END
                    ELSE 0
                END
    `

	// Break the report timestamp into the pieces bound to $3..$6.
	dayOfMonth := int(report.ReportTime.Day())
	_, isoWeek := report.ReportTime.ISOWeek()
	calendarYear := report.ReportTime.Year()
	monthOfYear := int(report.ReportTime.Month())

	_, execErr := database.DB.Exec(
		context.Background(),
		upsert,
		userID,
		achievementID,
		dayOfMonth,
		isoWeek,
		calendarYear,
		monthOfYear,
	)
	return execErr
}

// checkAndUnlockAchievement marks an achievement as achieved for a user when
// the stored progress satisfies the achievement's rule.
//
// In one statement it sets achieved_at = NOW() on the user's mapping row,
// provided achieved_at is still NULL and one of the following holds:
//
//   - required_action = 'report_wait_time' and progress_total >= threshold;
//   - required_action = 'time_restricted', progress_total >= threshold and
//     progress_days contains every entry of required_days;
//   - required_action = 'specific_date', progress_total >= threshold and the
//     stored progress_week / progress_year equal the week and year extracted
//     from specific_date.
//
// It returns nil both when the achievement was unlocked and when nothing
// qualified (no row updated); any other database error is returned as is.
func checkAndUnlockAchievement(userID int, achievementID int) error {
	ctx := context.Background()
	query := `
        WITH achievement_check AS (
            SELECT 
                a.threshold,
                a.time_restricted,
                a.required_days,
                a.required_action,
                a.specific_date,
                uam.progress_total,
                uam.progress_days,
                uam.progress_week,
                uam.progress_year,
                uam.achieved_at
            FROM achievements a
            JOIN user_achievements_mapping uam 
                ON a.achievement_id = uam.achievement_id
            WHERE uam.user_id = $1 
                AND uam.achievement_id = $2
                AND uam.achieved_at IS NULL
        )
        UPDATE user_achievements_mapping
        SET achieved_at = NOW()
        WHERE user_id = $1
          AND achievement_id = $2
          AND achieved_at IS NULL
          AND EXISTS (
            SELECT 1 
            FROM achievement_check ac
            WHERE (
                -- Regular cumulative achievements
                (ac.required_action = 'report_wait_time' 
                 AND ac.progress_total >= ac.threshold)
                OR
                -- Time-restricted achievements (specific days of week)
                (ac.required_action = 'time_restricted' 
                 AND ac.progress_total >= ac.threshold
                 AND ac.progress_days @> ac.required_days)
                OR
                -- Specific date achievements
                (ac.required_action = 'specific_date'
                 AND ac.progress_total >= ac.threshold
                 AND ac.progress_week = EXTRACT(WEEK FROM ac.specific_date)::int
                 AND ac.progress_year = EXTRACT(YEAR FROM ac.specific_date)::int)
            )
          )
        RETURNING achievement_id
    `
	// The RETURNING value is only scanned to learn whether a row was updated.
	var achievementUnlocked int
	err := database.DB.QueryRow(ctx, query, userID, achievementID).Scan(&achievementUnlocked)
	if errors.Is(err, pgx.ErrNoRows) {
		// No row updated: already achieved, no progress row, or the rule is
		// not met yet. errors.Is also matches a wrapped ErrNoRows.
		return nil
	}
	return err
}

// GetUserAchievements returns the achievements userID has already unlocked
// (rows whose achieved_at is set), most recently achieved first.
//
// Only ID, Name, Description, RequiredAction, Threshold and AchievedAt are
// populated on the returned values. When the user has no unlocked
// achievements the returned slice is nil. Errors from the query, from scanning
// a row, or from row iteration are wrapped with %w so callers can inspect the
// underlying cause with errors.Is / errors.As.
func GetUserAchievements(userID int) ([]models.Achievement, error) {
	ctx := context.Background()
	query := `
        SELECT 
            a.achievement_id,
            a.name,
            a.description,
            a.required_action,
            a.threshold,
            uam.achieved_at
        FROM achievements a
        INNER JOIN user_achievements_mapping uam 
            ON a.achievement_id = uam.achievement_id 
            AND uam.user_id = $1
            AND uam.achieved_at IS NOT NULL
        ORDER BY uam.achieved_at DESC`

	rows, err := database.DB.Query(ctx, query, userID)
	if err != nil {
		return nil, fmt.Errorf("error querying achievements: %w", err)
	}
	defer rows.Close()

	var achievements []models.Achievement
	for rows.Next() {
		var achievement models.Achievement
		// Destination order must match the column order of the SELECT above.
		if err := rows.Scan(
			&achievement.ID,
			&achievement.Name,
			&achievement.Description,
			&achievement.RequiredAction,
			&achievement.Threshold,
			&achievement.AchievedAt,
		); err != nil {
			return nil, fmt.Errorf("error scanning achievement: %w", err)
		}
		achievements = append(achievements, achievement)
	}

	if err := rows.Err(); err != nil {
		return nil, fmt.Errorf("error iterating achievements: %w", err)
	}

	return achievements, nil
}