""" plot_training_curve.py ====================== Loads every periodic checkpoint (bestweights//periodic/) and evaluates it on the 14-day test set. Produces one figure with: - One subplot per client : train_loss (from ckpt dict) vs test_mse (computed) - A combined overlay plot : all clients test_mse on the same axes Usage: uv run python src/data/plot_training_curve.py uv run python src/data/plot_training_curve.py --session 2026-03-12_15-05-55 uv run python src/data/plot_training_curve.py --device mps """ import os import sys import glob import argparse from pathlib import Path import numpy as np import pandas as pd import torch import torch.nn as nn import matplotlib matplotlib.use("Agg") # headless (no display required) import matplotlib.pyplot as plt import matplotlib.ticker as mticker # ── project root & shared imports ──────────────────────────────────────────── project_root = Path(__file__).resolve().parents[2] if str(project_root) not in sys.path: sys.path.append(str(project_root)) from src.shared.common import cfg from src.client.data_pipeline import FUTURE_RAIN_COL from src.models.split_lstm import ClientLSTM, ServerHead from src.shared.targets import inverse_target_scalar, rain_probability_threshold # ───────────────────────────────────────────────────────────────────────────── # Helpers # ───────────────────────────────────────────────────────────────────────────── def _parse_round(path: str) -> int: """Extract zero-padded round from periodic filename: client_1_round_0090.pth → 90""" stem = Path(path).stem # e.g. 'client_1_round_0090' parts = stem.split("_") try: idx = parts.index("round") return int(parts[idx + 1]) except (ValueError, IndexError): return 0 def _load_ckpt(path: str, device: torch.device): """Load a checkpoint and return (state_dict, train_loss, round_num).""" raw = torch.load(path, map_location=device, weights_only=True) if isinstance(raw, dict) and "model_state_dict" in raw: return raw["model_state_dict"], float(raw.get("loss", float("nan"))), raw.get("round", 0) # Legacy bare state_dict return raw, float("nan"), _parse_round(path) def _find_session(session_id: str | None) -> Path: """Return the session directory to use.""" bw = project_root / "bestweights" if session_id: p = bw / session_id if not p.is_dir(): raise FileNotFoundError(f"Session directory not found: {p}") return p sessions = sorted( [d for d in bw.iterdir() if d.is_dir()], key=lambda d: (d.stat().st_mtime, d.name), ) if not sessions: raise FileNotFoundError(f"No session directories found in {bw}") latest = sessions[-1] print(f"[DEBUG] Picking latest session: {latest.name}") return latest # ───────────────────────────────────────────────────────────────────────────── # Evaluation helper (lightweight — mirrors evaluate_client in run_evaluation.py) # ───────────────────────────────────────────────────────────────────────────── def _eval_pair(client_model, server_model, test_data_cache, device, batch_size=512): """ Run evaluation using a pre-loaded cache of test samples. test_data_cache: list of (input_tensor, target_tensor) tuples Uses batched inference for speed. """ prob_threshold = rain_probability_threshold() total_se, total_n = 0.0, 0 xs = torch.cat([x for x, _ in test_data_cache], dim=0) ys = torch.cat([y for _, y in test_data_cache], dim=0).squeeze(1) # (N,) with torch.no_grad(): for start in range(0, len(xs), batch_size): xb = xs[start:start + batch_size].to(device) yb = ys[start:start + batch_size].to(device) smashed = client_model(xb) rain_logit, rain_amount = server_model(smashed) rain_probs = torch.sigmoid(rain_logit).squeeze(1) # (B,) amounts = rain_amount.squeeze(1) # (B,) preds = torch.where( rain_probs >= prob_threshold, torch.tensor([inverse_target_scalar(v.item()) for v in amounts], dtype=torch.float32, device=device), torch.zeros(len(amounts), device=device), ) total_se += ((preds - yb) ** 2).sum().item() total_n += len(yb) return total_se / total_n if total_n > 0 else float("nan") # ───────────────────────────────────────────────────────────────────────────── # Main # ───────────────────────────────────────────────────────────────────────────── def main(): parser = argparse.ArgumentParser() parser.add_argument("--session", type=str, default=None, help="Session ID (folder name). Defaults to latest.") parser.add_argument("--device", type=str, default="cpu", help="'cpu' or 'mps'") args = parser.parse_args() device = torch.device(args.device) # Config seq_len = cfg.get("model", {}).get("seq_len", 48) horizon = max(1, int(cfg.get("model", {}).get("horizon", 3))) input_size = cfg.get("model", {}).get("input_size", 5) lstm_dropout = float(cfg.get("model", {}).get("lstm_dropout", cfg.get("model", {}).get("dropout", 0.3))) hidden_size = cfg.get("model", {}).get("hidden_size", 64) session_dir = _find_session(args.session) periodic_dir = session_dir / "periodic" # Preserve full relative path like "2026-04-09_08-11-48/01_seed42" bw = project_root / "bestweights" try: session_rel_path = session_dir.relative_to(bw) session_name = str(session_rel_path) except ValueError: session_name = session_dir.name session_id = session_name if not periodic_dir.is_dir(): print(f"[ERROR] No periodic/ dir found in {session_dir}") sys.exit(1) print(f"📅 Session : {session_id}") print(f"⚙️ Device : {device}") # 🕵️‍♂️ UI Update: Explain that this is the Cyclic test window per month print(f"📅 Test window: 2025-07-01 → end | seq_len={seq_len} horizon={horizon}\n") # ── Collect server periodic checkpoints ────────────────────────────────── server_ckpts = { _parse_round(p): p for p in sorted(glob.glob(str(periodic_dir / "server_round_*.pth"))) } if not server_ckpts: print("[ERROR] No server periodic checkpoints found.") sys.exit(1) server_rounds = sorted(server_ckpts.keys()) # ── Collect client IDs ─────────────────────────────────────────────────── client_files = sorted(glob.glob(str(periodic_dir / "client_*_round_*.pth"))) client_ids: set[int] = set() for f in client_files: parts = Path(f).stem.split("_") try: client_ids.add(int(parts[1])) except (IndexError, ValueError): pass client_ids_sorted = sorted(client_ids) if not client_ids_sorted: print("[ERROR] No client periodic checkpoints found.") sys.exit(1) print(f"Clients found: {client_ids_sorted}") # 📉 Optimization: In Federated Learning, clients share the same global model. # Evaluating CLIENT 1 is sufficient to see the trend. eval_client_ids = [1] if 1 in client_ids_sorted else [client_ids_sorted[0]] print(f"\u26a1\ufe0f Plotting representative Client {eval_client_ids[0]} for speed...") print(f"Server rounds: {server_rounds}\n") # ── Pre-load Test Data Cache (Speed Hack!) ────────────────────────────── print(f"⏳ Pre-loading test samples from 3-year dataset (Monthly Cyclic)...") from src.client.data_pipeline import collect_test_indices, load_sensor_data data_dir = None for pd_name in [cfg.get("data", {}).get("processed_dir", "dataset/processed"), "data/processed", "dataset/processed"]: candidate = project_root / pd_name if candidate.is_dir() and any(candidate.glob("*.parquet")): data_dir = candidate break test_data_cache = [] if data_dir: features_cfg = cfg.get("data", {}).get("feature_cols", ["Temperature", "Humidity", "Pressure", "Wind Speed", "Rain"]) eval_cid = eval_client_ids[0] all_files = sorted(data_dir.glob("*.parquet")) # Each client is assigned one file by sorted order; use that client's file only client_file = all_files[eval_cid - 1] if eval_cid - 1 < len(all_files) else all_files[0] print(f"Loading test data from: {client_file.name}") df = load_sensor_data(str(client_file), horizon=horizon) test_indices = collect_test_indices(df, min_history=seq_len, horizon=horizon) for idx in test_indices: target_val = float(df.iloc[idx][FUTURE_RAIN_COL]) window = df.iloc[idx - seq_len : idx] feat = window[features_cfg].apply(pd.to_numeric, errors="coerce").fillna(0).values test_data_cache.append(( torch.tensor(feat, dtype=torch.float32).unsqueeze(0), torch.tensor([[target_val]], dtype=torch.float32) )) print(f"✅ Cached {len(test_data_cache)} test samples.\n") # ── Evaluate every (client, round) pair ────────────────────────────────── # Structure: results[client_id] = [(round, train_loss, test_mse), ...] results: dict[int, list[tuple[int, float, float]]] = {} for cid in eval_client_ids: ckpts = sorted( glob.glob(str(periodic_dir / f"client_{cid}_round_*.pth")), key=_parse_round ) if not ckpts: print(f" [SKIP] Client {cid}: no periodic checkpoints") continue curve = [] for ckpt_path in ckpts: r = _parse_round(ckpt_path) # Find closest server round (prefer same, else nearest earlier) srv_r = max((sr for sr in server_rounds if sr <= r), default=server_rounds[0]) srv_path = server_ckpts[srv_r] client_state, train_loss, _ = _load_ckpt(ckpt_path, device) server_state, _, _ = _load_ckpt(srv_path, device) # Initialize models num_layers = sum(1 for k in client_state if k.startswith("lstm.weight_ih_l")) c_model = ClientLSTM(input_size=input_size, hidden_size=hidden_size, num_layers=num_layers, lstm_dropout=lstm_dropout).to(device) c_model.load_state_dict(client_state) c_model.eval() head_width = cfg.get("model", {}).get("server_head_width", 64) head_dropout = cfg.get("model", {}).get("server_head_dropout", 0.1) s_model = ServerHead(hidden_size=hidden_size, output_size=1, head_width=head_width, dropout=head_dropout).to(device) s_model.load_state_dict(server_state) s_model.eval() print(f" Client {cid} | round {r:04d} | train_loss={train_loss:.4f} | evaluating...", end="\r", flush=True) # 🔥 Real-time calculation on TEST SET test_mse = _eval_pair(c_model, s_model, test_data_cache, device) print(f" Client {cid} | round {r:04d} | train_loss={train_loss:.4f} | test_mse={test_mse:.4f}") curve.append((r, train_loss, test_mse)) results[cid] = curve if not results: print("[ERROR] No results to plot.") sys.exit(1) # ── Plot ───────────────────────────────────────────────────────────────── n_clients = len(results) fig_w = max(10, n_clients * 5) fig, axes = plt.subplots( 1, n_clients + 1, figsize=(fig_w + 5, 5), squeeze=False ) axes = axes[0] # flatten # Colour palette palette = plt.cm.tab10.colors # Per-client subplots for ax_idx, (cid, curve) in enumerate(sorted(results.items())): ax = axes[ax_idx] color = palette[ax_idx % len(palette)] rounds = [c[0] for c in curve] train_loss = [c[1] for c in curve] test_mse = [c[2] for c in curve] # Train loss line ax.plot(rounds, train_loss, "o--", color=color, alpha=0.6, linewidth=1.5, markersize=4, label="Train Loss") # Test MSE line ax.plot(rounds, test_mse, "s-", color=color, linewidth=2, markersize=5, label="Test MSE") # Mark best round (lowest test MSE) valid = [(r, v) for r, v in zip(rounds, test_mse) if not np.isnan(v)] if valid: best_r, best_v = min(valid, key=lambda x: x[1]) ax.axvline(best_r, color="red", linestyle=":", linewidth=1.2, alpha=0.7) ax.annotate(f"Best\nR={best_r}", xy=(best_r, best_v), xytext=(best_r + 2, best_v * 1.05), fontsize=7, color="red", arrowprops=dict(arrowstyle="->", color="red", lw=0.8)) ax.set_title(f"Client {cid}", fontsize=12, fontweight="bold") ax.set_xlabel("Round") ax.set_ylabel("Loss / MSE") ax.legend(fontsize=8) ax.xaxis.set_major_locator(mticker.MaxNLocator(integer=True)) ax.grid(True, alpha=0.3, linestyle="--") ax.set_facecolor("#f9f9f9") # Combined overlay (last subplot) ax_combined = axes[-1] for idx, (cid, curve) in enumerate(sorted(results.items())): color = palette[idx % len(palette)] rounds = [c[0] for c in curve] test_mse = [c[2] for c in curve] ax_combined.plot(rounds, test_mse, "s-", color=color, linewidth=2, markersize=4, label=f"Client {cid}") ax_combined.set_title("All Clients — Test MSE", fontsize=12, fontweight="bold") ax_combined.set_xlabel("Round") ax_combined.set_ylabel("Test MSE") ax_combined.legend(fontsize=8) ax_combined.xaxis.set_major_locator(mticker.MaxNLocator(integer=True)) ax_combined.grid(True, alpha=0.3, linestyle="--") ax_combined.set_facecolor("#f9f9f9") plt.suptitle( f"Training Curve ─ Session {session_name}\n" f"Dashed = Train Loss | Solid = Test MSE | Red line = Best Round", fontsize=11, y=1.02 ) plt.tight_layout() out_dir = project_root / "results" / session_name out_dir.mkdir(parents=True, exist_ok=True) safe_session_name = session_name.replace("/", "_").replace("\\", "_") out_path = out_dir / f"training_curve_{safe_session_name}.png" fig.savefig(out_path, dpi=150, bbox_inches="tight") plt.close(fig) print(f"\nPlot saved: {out_path}") print("\n── Overfitting Summary ──────────────────────────────────────") for cid, curve in sorted(results.items()): valid = [(r, tl, tm) for r, tl, tm in curve if not np.isnan(tl) and not np.isnan(tm)] if not valid: continue best_r, _, best_tm = min(valid, key=lambda x: x[2]) last_r, _, last_tm = valid[-1] gap = last_tm - best_tm flag = "⚠️ Overfitting likely" if gap > 0.005 else "✅ Stable" print(f" Client {cid}: best_round={best_r} best_testMSE={best_tm:.4f}" f" final_testMSE={last_tm:.4f} gap={gap:+.4f} {flag}") print() if __name__ == "__main__": main()