"""
plot_training_curve.py
======================
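Loads every periodic checkpoint (bestweights/<session>/periodic/)
and evaluates it on the 14-day test set.
Produces one figure with:
- One subplot per evaluated client: train_loss (from ckpt dict) vs test_mse (computed)
- A combined overlay plot: test_mse of all evaluated clients on the same axes
Note: for speed, only one representative client (client 1 if present, otherwise
the lowest client ID found) is currently evaluated.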
Usage:
uv run python src/data/plot_training_curve.py
uv run python src/data/plot_training_curve.py --session 2026-03-12_15-05-55
uv run python src/data/plot_training_curve.py --device mps
"""
import os
import sys
import glob
import argparse
from pathlib import Path
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import matplotlib
matplotlib.use("Agg") # headless (no display required)
import matplotlib.pyplot as plt
import matplotlib.ticker as mticker
# ── project root & shared imports ────────────────────────────────────────────
project_root = Path(__file__).resolve().parents[2]
if str(project_root) not in sys.path:
sys.path.append(str(project_root))
from src.shared.common import cfg
from src.client.data_pipeline import FUTURE_RAIN_COL
from src.models.split_lstm import ClientLSTM, ServerHead
from src.shared.targets import inverse_target_scalar, rain_probability_threshold
# ─────────────────────────────────────────────────────────────────────────────
# Helpers
# ─────────────────────────────────────────────────────────────────────────────
def _parse_round(path: str) -> int:
    """Return the round number encoded in a periodic checkpoint filename.

    The file stem is split on underscores and the token that directly
    follows the first ``round`` token is parsed as an integer, e.g.
    ``client_1_round_0090.pth`` -> ``90``.

    Returns ``0`` when the stem has no ``round`` token, when nothing follows
    it, or when the following token is not a valid integer.
    """
    tokens = Path(path).stem.split("_")  # e.g. ['client', '1', 'round', '0090']
    if "round" not in tokens:
        return 0

    following = tokens[tokens.index("round") + 1:]
    if not following:
        return 0

    try:
        return int(following[0])
    except ValueError:
        return 0
def _load_ckpt(path: str, device: torch.device):
    """Load a checkpoint and return ``(state_dict, train_loss, round_num)``.

    Two on-disk layouts are handled:

    * a dict containing ``"model_state_dict"``, optionally with ``"loss"``
      and ``"round"`` entries;
    * a legacy bare ``state_dict``.

    Args:
        path: Checkpoint file to load.
        device: Passed to ``torch.load`` as ``map_location``.

    Returns:
        A tuple ``(state_dict, train_loss, round_num)``. ``train_loss`` is
        NaN when the checkpoint has no usable ``"loss"`` (missing or
        ``None``). ``round_num`` is the stored ``"round"`` when present,
        otherwise the round parsed from the filename.
    """
    raw = torch.load(path, map_location=device, weights_only=True)

    # Legacy bare state_dict: no metadata, so everything comes from defaults
    # and the filename.
    if not (isinstance(raw, dict) and "model_state_dict" in raw):
        return raw, float("nan"), _parse_round(path)

    # A stored ``None`` loss would make float() raise; treat it as "unknown".
    stored_loss = raw.get("loss")
    train_loss = float("nan") if stored_loss is None else float(stored_loss)

    # Fall back to the filename (same as the legacy branch) rather than
    # reporting round 0 when the dict carries no round.
    round_num = raw.get("round")
    if round_num is None:
        round_num = _parse_round(path)

    return raw["model_state_dict"], train_loss, round_num
def _find_session(session_id: str | None) -> Path:
    """Resolve the session directory under ``<project_root>/bestweights``.

    Args:
        session_id: Folder name of the session to use. When falsy, the
            sub-directory with the greatest ``(mtime, name)`` is chosen.

    Returns:
        Path of the selected session directory.

    Raises:
        FileNotFoundError: If the requested session is not a directory, or
            if no session directory exists at all.
    """
    weights_root = project_root / "bestweights"

    if not session_id:
        candidates = [entry for entry in weights_root.iterdir() if entry.is_dir()]
        if not candidates:
            raise FileNotFoundError(f"No session directories found in {weights_root}")
        # Modification time first; the name only breaks mtime ties.
        newest = max(candidates, key=lambda entry: (entry.stat().st_mtime, entry.name))
        print(f"[DEBUG] Picking latest session: {newest.name}")
        return newest

    requested = weights_root / session_id
    if requested.is_dir():
        return requested
    raise FileNotFoundError(f"Session directory not found: {requested}")
# ─────────────────────────────────────────────────────────────────────────────
# Evaluation helper (lightweight — mirrors evaluate_client in run_evaluation.py)
# ─────────────────────────────────────────────────────────────────────────────
def _eval_pair(client_model, server_model, test_data_cache, device, batch_size=512):
    """Compute the test MSE of a client/server model pair on cached samples.

    Every cached sample is concatenated along dim 0 and pushed through
    ``client_model`` then ``server_model`` in batches. The prediction is the
    inverse-transformed rain amount where ``sigmoid(rain_logit)`` reaches
    ``rain_probability_threshold()``, and ``0`` elsewhere.

    Args:
        client_model: Callable producing the intermediate activations.
        server_model: Callable returning ``(rain_logit, rain_amount)``.
        test_data_cache: List of ``(input_tensor, target_tensor)`` tuples.
            Targets are squeezed on dim 1, so they are expected to be
            ``(1, 1)``-shaped per sample (as built by ``main``).
        device: Device the batches are moved to.
        batch_size: Number of samples per forward pass.

    Returns:
        Mean squared error as a float, or NaN when there is nothing to
        evaluate.
    """
    # torch.cat raises on an empty list, so bail out before touching it.
    if not test_data_cache:
        return float("nan")

    prob_threshold = rain_probability_threshold()
    xs = torch.cat([x for x, _ in test_data_cache], dim=0)
    ys = torch.cat([y for _, y in test_data_cache], dim=0).squeeze(1)  # (N,)

    total_se, total_n = 0.0, 0
    with torch.no_grad():
        for start in range(0, len(xs), batch_size):
            xb = xs[start:start + batch_size].to(device)
            yb = ys[start:start + batch_size].to(device)

            rain_logit, rain_amount = server_model(client_model(xb))
            rain_probs = torch.sigmoid(rain_logit).squeeze(1)  # (B,)
            amounts = rain_amount.squeeze(1)                   # (B,)

            # inverse_target_scalar works on plain scalars; pull the whole
            # batch to Python once instead of one .item() call per element.
            restored = torch.tensor(
                [inverse_target_scalar(v) for v in amounts.tolist()],
                dtype=torch.float32,
                device=device,
            )
            preds = torch.where(
                rain_probs >= prob_threshold,
                restored,
                torch.zeros(len(amounts), device=device),
            )

            total_se += ((preds - yb) ** 2).sum().item()
            total_n += len(yb)

    return total_se / total_n if total_n > 0 else float("nan")
# ─────────────────────────────────────────────────────────────────────────────
# Main
# ─────────────────────────────────────────────────────────────────────────────
def _build_test_cache(eval_cid, seq_len, horizon):
    """Build the cached ``(input, target)`` test tensors for one client.

    Looks for a processed-data directory containing ``*.parquet`` files,
    picks the file at sorted position ``eval_cid - 1`` (first file when that
    position is out of range), and turns every index returned by
    ``collect_test_indices`` into a ``(1, seq_len, n_features)`` input and a
    ``(1, 1)`` target tensor.

    Returns:
        The list of tuples; empty when no data directory is found or the
        test window yields no indices.
    """
    from src.client.data_pipeline import collect_test_indices, load_sensor_data

    data_cfg = cfg.get("data", {})
    data_dir = None
    for dir_name in [data_cfg.get("processed_dir", "dataset/processed"), "data/processed", "dataset/processed"]:
        candidate = project_root / dir_name
        if candidate.is_dir() and any(candidate.glob("*.parquet")):
            data_dir = candidate
            break
    if data_dir is None:
        return []

    feature_cols = data_cfg.get(
        "feature_cols", ["Temperature", "Humidity", "Pressure", "Wind Speed", "Rain"]
    )
    all_files = sorted(data_dir.glob("*.parquet"))
    # Each client is assigned one file by sorted order. The explicit lower
    # bound stops client id 0 from wrapping around to the last file.
    file_pos = eval_cid - 1
    client_file = all_files[file_pos] if 0 <= file_pos < len(all_files) else all_files[0]
    print(f"Loading test data from: {client_file.name}")

    df = load_sensor_data(str(client_file), horizon=horizon)
    test_indices = collect_test_indices(df, min_history=seq_len, horizon=horizon)

    # Coerce the feature columns once for the whole frame; the per-window
    # result is just a row slice of this array.
    features = df[feature_cols].apply(pd.to_numeric, errors="coerce").fillna(0).values
    targets = df[FUTURE_RAIN_COL].values

    return [
        (
            torch.tensor(features[idx - seq_len:idx], dtype=torch.float32).unsqueeze(0),
            torch.tensor([[float(targets[idx])]], dtype=torch.float32),
        )
        for idx in test_indices
    ]


def _evaluate_client_curve(cid, periodic_dir, server_ckpts, test_data_cache, device, *,
                           input_size, hidden_size, lstm_dropout, head_width, head_dropout):
    """Evaluate every periodic checkpoint of one client.

    Each client checkpoint is paired with the server checkpoint of the same
    round, or the nearest earlier one (the earliest available if none is
    earlier).

    Returns:
        List of ``(round, train_loss, test_mse)`` tuples ordered by round;
        empty when the client has no periodic checkpoints.
    """
    ckpts = sorted(
        glob.glob(str(periodic_dir / f"client_{cid}_round_*.pth")),
        key=_parse_round,
    )
    if not ckpts:
        print(f" [SKIP] Client {cid}: no periodic checkpoints")
        return []

    server_rounds = sorted(server_ckpts)
    curve = []
    for ckpt_path in ckpts:
        r = _parse_round(ckpt_path)
        srv_r = max((sr for sr in server_rounds if sr <= r), default=server_rounds[0])

        client_state, train_loss, _ = _load_ckpt(ckpt_path, device)
        server_state, _, _ = _load_ckpt(server_ckpts[srv_r], device)

        # The LSTM depth is read off the checkpoint itself: one
        # 'lstm.weight_ih_l*' entry is counted per layer.
        num_layers = sum(1 for k in client_state if k.startswith("lstm.weight_ih_l"))
        c_model = ClientLSTM(input_size=input_size, hidden_size=hidden_size,
                             num_layers=num_layers, lstm_dropout=lstm_dropout).to(device)
        c_model.load_state_dict(client_state)
        c_model.eval()

        s_model = ServerHead(hidden_size=hidden_size, output_size=1,
                             head_width=head_width, dropout=head_dropout).to(device)
        s_model.load_state_dict(server_state)
        s_model.eval()

        print(f" Client {cid} | round {r:04d} | train_loss={train_loss:.4f} | evaluating...", end="\r", flush=True)
        test_mse = _eval_pair(c_model, s_model, test_data_cache, device)
        print(f" Client {cid} | round {r:04d} | train_loss={train_loss:.4f} | test_mse={test_mse:.4f}")
        curve.append((r, train_loss, test_mse))
    return curve


def _plot_curves(results, session_name):
    """Draw per-client subplots plus a combined overlay and save the PNG.

    Returns:
        Path of the written image under ``<project_root>/results/<session>``.
    """
    n_clients = len(results)
    fig_w = max(10, n_clients * 5)
    fig, axes = plt.subplots(1, n_clients + 1, figsize=(fig_w + 5, 5), squeeze=False)
    axes = axes[0]  # single row
    palette = plt.cm.tab10.colors
    ordered = sorted(results.items())

    # Per-client subplots
    for ax_idx, (cid, curve) in enumerate(ordered):
        ax = axes[ax_idx]
        color = palette[ax_idx % len(palette)]
        rounds = [c[0] for c in curve]
        train_loss = [c[1] for c in curve]
        test_mse = [c[2] for c in curve]

        ax.plot(rounds, train_loss, "o--", color=color, alpha=0.6,
                linewidth=1.5, markersize=4, label="Train Loss")
        ax.plot(rounds, test_mse, "s-", color=color,
                linewidth=2, markersize=5, label="Test MSE")

        # Mark the round with the lowest (non-NaN) test MSE.
        valid = [(r, v) for r, v in zip(rounds, test_mse) if not np.isnan(v)]
        if valid:
            best_r, best_v = min(valid, key=lambda x: x[1])
            ax.axvline(best_r, color="red", linestyle=":", linewidth=1.2, alpha=0.7)
            ax.annotate(f"Best\nR={best_r}", xy=(best_r, best_v),
                        xytext=(best_r + 2, best_v * 1.05),
                        fontsize=7, color="red",
                        arrowprops=dict(arrowstyle="->", color="red", lw=0.8))

        ax.set_title(f"Client {cid}", fontsize=12, fontweight="bold")
        ax.set_xlabel("Round")
        ax.set_ylabel("Loss / MSE")
        ax.legend(fontsize=8)
        ax.xaxis.set_major_locator(mticker.MaxNLocator(integer=True))
        ax.grid(True, alpha=0.3, linestyle="--")
        ax.set_facecolor("#f9f9f9")

    # Combined overlay (last subplot)
    ax_combined = axes[-1]
    for idx, (cid, curve) in enumerate(ordered):
        ax_combined.plot([c[0] for c in curve], [c[2] for c in curve], "s-",
                         color=palette[idx % len(palette)],
                         linewidth=2, markersize=4, label=f"Client {cid}")
    ax_combined.set_title("All Clients — Test MSE", fontsize=12, fontweight="bold")
    ax_combined.set_xlabel("Round")
    ax_combined.set_ylabel("Test MSE")
    ax_combined.legend(fontsize=8)
    ax_combined.xaxis.set_major_locator(mticker.MaxNLocator(integer=True))
    ax_combined.grid(True, alpha=0.3, linestyle="--")
    ax_combined.set_facecolor("#f9f9f9")

    plt.suptitle(
        f"Training Curve ─ Session {session_name}\n"
        f"Dashed = Train Loss | Solid = Test MSE | Red line = Best Round",
        fontsize=11, y=1.02
    )
    plt.tight_layout()

    out_dir = project_root / "results" / session_name
    out_dir.mkdir(parents=True, exist_ok=True)
    # session_name may contain path separators (nested sessions); flatten
    # them for the file name only.
    safe_session_name = session_name.replace("/", "_").replace("\\", "_")
    out_path = out_dir / f"training_curve_{safe_session_name}.png"
    fig.savefig(out_path, dpi=150, bbox_inches="tight")
    plt.close(fig)
    return out_path


def _print_overfitting_summary(results):
    """Print, per client, best vs final test MSE and an overfitting flag."""
    print("\n── Overfitting Summary ──────────────────────────────────────")
    for cid, curve in sorted(results.items()):
        # Only the test MSE is used below, so a missing (NaN) train loss,
        # as produced by legacy checkpoints, must not drop the point.
        valid = [(r, tm) for r, _, tm in curve if not np.isnan(tm)]
        if not valid:
            continue
        best_r, best_tm = min(valid, key=lambda x: x[1])
        _, last_tm = valid[-1]
        gap = last_tm - best_tm
        flag = "⚠️ Overfitting likely" if gap > 0.005 else "✅ Stable"
        print(f" Client {cid}: best_round={best_r} best_testMSE={best_tm:.4f}"
              f" final_testMSE={last_tm:.4f} gap={gap:+.4f} {flag}")
    print()


def main():
    """CLI entry point: evaluate periodic checkpoints and plot the curves.

    Steps: parse ``--session`` / ``--device``, locate the session's
    ``periodic/`` directory, collect server and client checkpoints, cache the
    test samples of one representative client, evaluate each of that client's
    checkpoints, then save the figure and print an overfitting summary.

    Exits with status 1 when the periodic directory, the checkpoints, the
    test data, or the results are missing.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--session", type=str, default=None,
                        help="Session ID (folder name). Defaults to latest.")
    parser.add_argument("--device", type=str, default="cpu",
                        help="'cpu' or 'mps'")
    args = parser.parse_args()
    device = torch.device(args.device)

    # Config
    model_cfg = cfg.get("model", {})
    seq_len = model_cfg.get("seq_len", 48)
    horizon = max(1, int(model_cfg.get("horizon", 3)))
    input_size = model_cfg.get("input_size", 5)
    lstm_dropout = float(model_cfg.get("lstm_dropout", model_cfg.get("dropout", 0.3)))
    hidden_size = model_cfg.get("hidden_size", 64)
    head_width = model_cfg.get("server_head_width", 64)
    head_dropout = model_cfg.get("server_head_dropout", 0.1)

    session_dir = _find_session(args.session)
    periodic_dir = session_dir / "periodic"
    # Preserve the full relative path, e.g. "2026-04-09_08-11-48/01_seed42".
    try:
        session_name = str(session_dir.relative_to(project_root / "bestweights"))
    except ValueError:
        session_name = session_dir.name

    if not periodic_dir.is_dir():
        print(f"[ERROR] No periodic/ dir found in {session_dir}")
        sys.exit(1)

    print(f"📅 Session : {session_name}")
    print(f"⚙️ Device : {device}")
    print(f"📅 Test window: 2025-07-01 → end | seq_len={seq_len} horizon={horizon}\n")

    # ── Collect server periodic checkpoints ──────────────────────────────────
    server_ckpts = {
        _parse_round(p): p
        for p in sorted(glob.glob(str(periodic_dir / "server_round_*.pth")))
    }
    if not server_ckpts:
        print("[ERROR] No server periodic checkpoints found.")
        sys.exit(1)
    server_rounds = sorted(server_ckpts)

    # ── Collect client IDs ───────────────────────────────────────────────────
    client_ids: set[int] = set()
    for f in glob.glob(str(periodic_dir / "client_*_round_*.pth")):
        try:
            client_ids.add(int(Path(f).stem.split("_")[1]))
        except (IndexError, ValueError):
            # File name does not follow client_<id>_round_<n>; ignore it.
            continue
    client_ids_sorted = sorted(client_ids)
    if not client_ids_sorted:
        print("[ERROR] No client periodic checkpoints found.")
        sys.exit(1)
    print(f"Clients found: {client_ids_sorted}")

    # Only one representative client is evaluated for speed (the original
    # author's rationale: clients share the same global model).
    eval_client_ids = [1] if 1 in client_ids_sorted else [client_ids_sorted[0]]
    print(f"\u26a1\ufe0f Plotting representative Client {eval_client_ids[0]} for speed...")
    print(f"Server rounds: {server_rounds}\n")

    # ── Pre-load test data cache ─────────────────────────────────────────────
    print("⏳ Pre-loading test samples from 3-year dataset (Monthly Cyclic)...")
    test_data_cache = _build_test_cache(eval_client_ids[0], seq_len, horizon)
    if not test_data_cache:
        # Without samples there is nothing to evaluate; stop here instead of
        # failing later inside the evaluation loop.
        print("[ERROR] No test samples available (no processed *.parquet data found or empty test window).")
        sys.exit(1)
    print(f"✅ Cached {len(test_data_cache)} test samples.\n")

    # ── Evaluate every (client, round) pair ──────────────────────────────────
    # results[client_id] = [(round, train_loss, test_mse), ...]
    results: dict[int, list[tuple[int, float, float]]] = {}
    for cid in eval_client_ids:
        curve = _evaluate_client_curve(
            cid, periodic_dir, server_ckpts, test_data_cache, device,
            input_size=input_size, hidden_size=hidden_size, lstm_dropout=lstm_dropout,
            head_width=head_width, head_dropout=head_dropout,
        )
        if curve:
            results[cid] = curve
    if not results:
        print("[ERROR] No results to plot.")
        sys.exit(1)

    # ── Plot & report ────────────────────────────────────────────────────────
    out_path = _plot_curves(results, session_name)
    print(f"\nPlot saved: {out_path}")
    _print_overfitting_summary(results)


if __name__ == "__main__":
    main()