# csc8114/code/src/client/checkpointing.py
from dataclasses import dataclass
from datetime import datetime
import os
from pathlib import Path

import torch

from src.shared.common import cfg
from src.shared.config_artifacts import build_config_ref, build_config_snapshot


@dataclass
class CheckpointState:
    best_test_loss: float = float("inf")
    best_test_f1: float = float("-inf")
    no_improvement_count: int = 0
    best_model_path: str | None = None


def evaluate_epoch(
    *,
    client_id: int,
    client_model,
    optimizer,
    current_round: int,
    epoch: int,
    avg_val_loss: float,
    val_metrics: dict[str, float | int],
    session_id: str,
    session_dir: str,
    periodic_dir: str,
    patience: int,
    ckpt_interval: int,
    state: CheckpointState,
) -> bool:
    """Update best-model tracking, write checkpoints, and report early stop.

    The current F1 is compared against ``state.best_test_f1``. The evaluation
    counts as an improvement when F1 rises by more than 1e-9, or when F1 ties
    (within 1e-9) and ``avg_val_loss`` drops by more than 1e-9. On improvement
    a "best" checkpoint is written into ``session_dir`` and ``state`` is
    updated; otherwise ``state.no_improvement_count`` is incremented.

    Args:
        client_id: Identifier embedded in file names, logs and the checkpoint.
        client_model: Object exposing ``state_dict()``; its parameters are
            stored in the checkpoint.
        optimizer: Object exposing ``state_dict()``; stored in the checkpoint.
        current_round: Round number used in file names and periodic saving.
        epoch: Zero-based epoch index; stored as ``epoch + 1``.
        avg_val_loss: Loss stored in the checkpoint and used as tie-breaker.
        val_metrics: Metric mapping. Missing keys default to 0 (``"phase"``
            to ``"VAL"``, ``"selected_threshold"`` to the configured
            ``training.rain_probability_threshold`` or 0.5).
        session_id: Stored in the checkpoint and shown in the log line.
        session_dir: Directory receiving "best" checkpoints.
        periodic_dir: Directory receiving periodic checkpoints.
        patience: Early stop triggers once ``state.no_improvement_count``
            reaches this value.
        ckpt_interval: A periodic checkpoint is written when
            ``current_round`` is a positive multiple of this value. A value
            ``<= 0`` disables periodic checkpoints.
        state: Mutable tracking state; updated in place.

    Returns:
        ``True`` if early stopping was triggered, else ``False``. When early
        stopping triggers, no periodic checkpoint is written for this call.
    """
    eps = 1e-9  # tolerance for "equal" F1 / loss comparisons

    current_f1 = float(val_metrics.get("f1", 0.0))
    current_precision = float(val_metrics.get("precision", 0.0))
    current_recall = float(val_metrics.get("recall", 0.0))
    current_threshold = float(
        val_metrics.get(
            "selected_threshold",
            cfg.get("training", {}).get("rain_probability_threshold", 0.5),
        )
    )
    config_snapshot, snapshot_policy = build_config_snapshot()

    # Fetch the state dict once; it is used both for layer counting and as
    # the checkpoint payload.
    model_state = client_model.state_dict()
    # One "weight_ih_l{k}" tensor exists per LSTM layer. Bidirectional LSTMs
    # add a "weight_ih_l{k}_reverse" twin per layer, which must not be counted
    # or num_layers would be doubled.
    # NOTE(review): assumes the recurrent module is registered as "lstm".
    num_layers_ckpt = sum(
        1
        for key in model_state
        if key.startswith("lstm.weight_ih_l") and not key.endswith("_reverse")
    )

    tp, fn, fp, tn = (int(val_metrics.get(name, 0)) for name in ("tp", "fn", "fp", "tn"))
    model_cfg = cfg.get("model", {})

    base_ckpt = {
        "round": current_round,
        "epoch": epoch + 1,
        "model_state_dict": model_state,
        "optimizer_state_dict": optimizer.state_dict(),
        "loss": avg_val_loss,
        "classification_metrics": {
            "phase": str(val_metrics.get("phase", "VAL")),
            "f1": current_f1,
            "precision": current_precision,
            "recall": current_recall,
            "accuracy": float(val_metrics.get("accuracy", 0.0)),
            "threshold": current_threshold,
            "tp": tp,
            "fn": fn,
            "fp": fp,
            "tn": tn,
            "samples": tp + fn + fp + tn,
        },
        "config": {
            "hidden_size": model_cfg.get("hidden_size", 64),
            "num_layers": num_layers_ckpt,
            "input_size": model_cfg.get("input_size", 5),
        },
        "config_snapshot_policy": snapshot_policy,
        "config_ref": build_config_ref(),
        "session_id": session_id,
        "client_id": client_id,
    }
    if config_snapshot is not None:
        base_ckpt["config_snapshot"] = config_snapshot

    score_improved = current_f1 > state.best_test_f1 + eps
    tie_break_improved = (
        abs(current_f1 - state.best_test_f1) <= eps
        and avg_val_loss < state.best_test_loss - eps
    )
    if score_improved or tie_break_improved:
        state.best_test_f1 = current_f1
        state.best_test_loss = avg_val_loss
        state.no_improvement_count = 0
        stamp = datetime.now().strftime("%Y%m%d%H%M%S")
        state.best_model_path = os.path.join(
            session_dir,
            f"best_client_{client_id}_round_{current_round}_model_{stamp}.pth",
        )
        # torch.save does not create missing directories.
        Path(state.best_model_path).parent.mkdir(parents=True, exist_ok=True)
        torch.save(base_ckpt, state.best_model_path)
        print(
            f"[CLIENT {client_id}] New Best! Round {current_round}, "
            f"F1={state.best_test_f1:.4f}, Loss={state.best_test_loss:.4f}, "
            f"Threshold={current_threshold:.3f} -> {session_id}/{Path(state.best_model_path).name}"
        )
    else:
        state.no_improvement_count += 1
        print(f"[CLIENT {client_id}] No improvement for {state.no_improvement_count}/{patience} rounds.")

    if state.no_improvement_count >= patience:
        print(f"\n[EARLY STOP] Client {client_id} triggered at round {current_round} (Patience={patience})")
        return True

    # Guard ckpt_interval first: a zero interval would otherwise raise
    # ZeroDivisionError in the modulo below.
    if ckpt_interval > 0 and current_round > 0 and current_round % ckpt_interval == 0:
        periodic_path = os.path.join(
            periodic_dir,
            f"client_{client_id}_round_{current_round:04d}.pth",
        )
        Path(periodic_path).parent.mkdir(parents=True, exist_ok=True)
        torch.save(base_ckpt, periodic_path)
        print(f"[CLIENT {client_id}] Periodic ckpt saved: round {current_round:04d}")

    return False