from dataclasses import dataclass
from datetime import datetime
import os
from pathlib import Path
import torch
from src.shared.common import cfg
from src.shared.config_artifacts import build_config_ref, build_config_snapshot
@dataclass
class CheckpointState:
best_test_loss: float = float("inf")
best_test_f1: float = float("-inf")
no_improvement_count: int = 0
best_model_path: str | None = None
def _build_checkpoint(
    *,
    client_id: int,
    client_model,
    optimizer,
    current_round: int,
    epoch: int,
    avg_val_loss: float,
    val_metrics: dict[str, float | int],
    current_f1: float,
    current_threshold: float,
    session_id: str,
) -> dict:
    """Assemble the payload written by both the best-model and periodic saves.

    The epoch is stored as ``epoch + 1``. ``config_snapshot`` is only added
    when ``build_config_snapshot()`` returns a non-None snapshot.
    """
    model_state = client_model.state_dict()
    # PyTorch names LSTM input-hidden weights "weight_ih_l{k}" and, for the
    # backward direction of a bidirectional layer, "weight_ih_l{k}_reverse";
    # skip the latter so each layer is counted exactly once.
    # NOTE(review): only matches a submodule registered as "lstm" at the top
    # level (a wrapper prefix such as "module." would give 0) — confirm.
    num_layers = sum(
        1
        for key in model_state
        if key.startswith("lstm.weight_ih_l") and not key.endswith("_reverse")
    )
    tp, fn, fp, tn = (
        int(val_metrics.get(name, 0)) for name in ("tp", "fn", "fp", "tn")
    )
    config_snapshot, snapshot_policy = build_config_snapshot()
    model_cfg = cfg.get("model", {})
    checkpoint = {
        "round": current_round,
        "epoch": epoch + 1,
        "model_state_dict": model_state,
        "optimizer_state_dict": optimizer.state_dict(),
        "loss": avg_val_loss,
        "classification_metrics": {
            "phase": str(val_metrics.get("phase", "VAL")),
            "f1": current_f1,
            "precision": float(val_metrics.get("precision", 0.0)),
            "recall": float(val_metrics.get("recall", 0.0)),
            "accuracy": float(val_metrics.get("accuracy", 0.0)),
            "threshold": current_threshold,
            "tp": tp,
            "fn": fn,
            "fp": fp,
            "tn": tn,
            "samples": tp + fn + fp + tn,
        },
        "config": {
            "hidden_size": model_cfg.get("hidden_size", 64),
            "num_layers": num_layers,
            "input_size": model_cfg.get("input_size", 5),
        },
        "config_snapshot_policy": snapshot_policy,
        "config_ref": build_config_ref(),
        "session_id": session_id,
        "client_id": client_id,
    }
    if config_snapshot is not None:
        checkpoint["config_snapshot"] = config_snapshot
    return checkpoint


def evaluate_epoch(
    *,
    client_id: int,
    client_model,
    optimizer,
    current_round: int,
    epoch: int,
    avg_val_loss: float,
    val_metrics: dict[str, float | int],
    session_id: str,
    session_dir: str,
    periodic_dir: str,
    patience: int,
    ckpt_interval: int,
    state: CheckpointState,
) -> bool:
    """Update best-model / early-stopping state for one client and save checkpoints.

    A result is an improvement when its F1 exceeds ``state.best_test_f1`` by
    more than 1e-9, or ties it (within 1e-9) with ``avg_val_loss`` lower than
    ``state.best_test_loss`` by more than 1e-9. On improvement a checkpoint is
    written to ``session_dir`` under a timestamped name and ``state`` is
    updated. Otherwise ``state.no_improvement_count`` is incremented.

    When ``current_round`` is a positive multiple of ``ckpt_interval``, a
    periodic checkpoint is also written to ``periodic_dir``. No periodic save
    happens when early stopping fires or when ``ckpt_interval`` is not positive.

    Args:
        client_id: Identifier used in filenames, logs and the checkpoint.
        client_model: Model whose ``state_dict()`` is saved.
        optimizer: Optimizer whose ``state_dict()`` is saved.
        current_round: Round number used in filenames and the periodic check.
        epoch: Epoch index; stored in the checkpoint as ``epoch + 1``.
        avg_val_loss: Validation loss, used as the F1 tie-breaker.
        val_metrics: Metrics dict; reads "f1", "precision", "recall",
            "accuracy", "selected_threshold", "phase", "tp", "fn", "fp", "tn".
        session_id: Session label stored in the checkpoint and logs.
        session_dir: Directory for best-model checkpoints.
        periodic_dir: Directory for periodic checkpoints.
        patience: Non-improving evaluations tolerated before stopping.
        ckpt_interval: Round interval for periodic checkpoints (<= 0 disables).
        state: Mutable tracking state, updated in place.

    Returns:
        True when ``state.no_improvement_count`` has reached ``patience`` (the
        caller should stop this client); otherwise False.
    """
    current_f1 = float(val_metrics.get("f1", 0.0))
    current_threshold = float(
        val_metrics.get(
            "selected_threshold",
            cfg.get("training", {}).get("rain_probability_threshold", 0.5),
        )
    )

    checkpoint: dict | None = None

    def get_checkpoint() -> dict:
        # Built lazily and at most once per call. Epochs that save nothing skip
        # the config snapshot and state_dict() work entirely. An epoch that is
        # both a new best and a periodic round reuses the same payload.
        nonlocal checkpoint
        if checkpoint is None:
            checkpoint = _build_checkpoint(
                client_id=client_id,
                client_model=client_model,
                optimizer=optimizer,
                current_round=current_round,
                epoch=epoch,
                avg_val_loss=avg_val_loss,
                val_metrics=val_metrics,
                current_f1=current_f1,
                current_threshold=current_threshold,
                session_id=session_id,
            )
        return checkpoint

    eps = 1e-9
    score_improved = current_f1 > state.best_test_f1 + eps
    tie_break_improved = (
        abs(current_f1 - state.best_test_f1) <= eps
        and avg_val_loss < state.best_test_loss - eps
    )

    if score_improved or tie_break_improved:
        stamp = datetime.now().strftime("%Y%m%d%H%M%S")
        best_path = os.path.join(
            session_dir,
            f"best_client_{client_id}_round_{current_round}_model_{stamp}.pth",
        )
        # Save before mutating ``state`` so a failed write cannot leave
        # best_model_path pointing at a file that was never created.
        torch.save(get_checkpoint(), best_path)
        state.best_test_f1 = current_f1
        state.best_test_loss = avg_val_loss
        state.no_improvement_count = 0
        state.best_model_path = best_path
        print(
            f"[CLIENT {client_id}] New Best! Round {current_round}, "
            f"F1={state.best_test_f1:.4f}, Loss={state.best_test_loss:.4f}, "
            f"Threshold={current_threshold:.3f} -> {session_id}/{Path(state.best_model_path).name}"
        )
    else:
        state.no_improvement_count += 1
        print(f"[CLIENT {client_id}] No improvement for {state.no_improvement_count}/{patience} rounds.")
        if state.no_improvement_count >= patience:
            print(f"\n[EARLY STOP] Client {client_id} triggered at round {current_round} (Patience={patience})")
            return True

    # The ``ckpt_interval > 0`` guard avoids ZeroDivisionError in the modulo
    # and treats a non-positive interval as "periodic checkpoints disabled".
    if ckpt_interval > 0 and current_round > 0 and current_round % ckpt_interval == 0:
        periodic_path = os.path.join(
            periodic_dir,
            f"client_{client_id}_round_{current_round:04d}.pth",
        )
        torch.save(get_checkpoint(), periodic_path)
        print(f"[CLIENT {client_id}] Periodic ckpt saved: round {current_round:04d}")
    return False