""" plot_confusion_matrix.py ======================== Build confusion matrix plots from client training logs. Usage: python -m src.data.plot_confusion_matrix python -m src.data.plot_confusion_matrix --session 2026-03-13_01-53-07 python -m src.data.plot_confusion_matrix --session 2026-03-13_01-53-07 --phase TEST """ from __future__ import annotations import argparse import re import sys from pathlib import Path import matplotlib matplotlib.use("Agg") import matplotlib.pyplot as plt import numpy as np import pandas as pd project_root = Path(__file__).resolve().parents[2] if str(project_root) not in sys.path: sys.path.append(str(project_root)) from src.shared.targets import rain_probability_threshold, rain_threshold_mm def _find_session(session_id: str | None) -> Path: results_dir = project_root / "results" sessions = sorted([d for d in results_dir.iterdir() if d.is_dir() and d.name.startswith("20")]) if not sessions: raise FileNotFoundError(f"No session folders found under {results_dir}") if session_id: target = results_dir / session_id if not target.is_dir(): raise FileNotFoundError(f"Session not found: {target}") return target return sessions[-1] def _latest_client_logs(session_dir: Path) -> dict[int, Path]: """Scans for client logs in the session directory and recursive subdirectories.""" # Recursively find training logs, which might be in pi01/, pi02/ subdirs # created by 'ansible fetch' or in the root of the session dir. all_logs = list(session_dir.rglob("training_log_client*.csv")) by_client: dict[int, Path] = {} for path in sorted(all_logs): if "progress" in path.name: continue m = re.search(r"training_log_client(\d+)_\d{8}_\d{6}\.csv$", path.name) if not m: continue cid = int(m.group(1)) by_client[cid] = path return by_client def _confusion_counts( df: pd.DataFrame, *, threshold_mm: float, decision: str, prob_threshold: float, ) -> tuple[int, int, int, int]: y_true = (pd.to_numeric(df["Target"], errors="coerce") > threshold_mm).astype(int) if decision == "probability" and "RainProbability" in df.columns: probs = pd.to_numeric(df["RainProbability"], errors="coerce").fillna(0.0) y_pred = (probs >= prob_threshold).astype(int) else: preds = pd.to_numeric(df["Prediction"], errors="coerce").fillna(0.0) y_pred = (preds > threshold_mm).astype(int) tp = int(((y_true == 1) & (y_pred == 1)).sum()) fn = int(((y_true == 1) & (y_pred == 0)).sum()) fp = int(((y_true == 0) & (y_pred == 1)).sum()) tn = int(((y_true == 0) & (y_pred == 0)).sum()) return tp, fn, fp, tn def _select_scope(df: pd.DataFrame, *, scope: str) -> tuple[pd.DataFrame, int | None]: if scope != "latest": return df, None if "Epoch" not in df.columns or df.empty: return df, None epoch_values = pd.to_numeric(df["Epoch"], errors="coerce").dropna() if epoch_values.empty: return df, None latest_epoch = int(epoch_values.max()) return df[pd.to_numeric(df["Epoch"], errors="coerce") == latest_epoch], latest_epoch def _draw_cm(ax, tp: int, fn: int, fp: int, tn: int, *, title: str) -> None: mat = np.array([[tn, fp], [fn, tp]], dtype=float) im = ax.imshow(mat, cmap="Blues") ax.set_xticks([0, 1], labels=["Pred: Dry", "Pred: Rain"]) ax.set_yticks([0, 1], labels=["True: Dry", "True: Rain"]) ax.set_title(title, fontsize=10, fontweight="bold") ax.set_xlabel("Prediction") ax.set_ylabel("Target") for i in range(2): for j in range(2): ax.text(j, i, f"{int(mat[i, j])}", ha="center", va="center", color="black", fontsize=11) plt.colorbar(im, ax=ax, fraction=0.046, pad=0.04) def _phase_metrics(tp: int, fn: int, fp: int, tn: int) -> str: total = tp + fn + fp + tn acc = (tp + tn) / total if total > 0 else float("nan") recall = tp / (tp + fn) if (tp + fn) > 0 else float("nan") precision = tp / (tp + fp) if (tp + fp) > 0 else float("nan") if np.isnan(recall) or np.isnan(precision) or (recall + precision) == 0: f1 = float("nan") else: f1 = 2 * recall * precision / (recall + precision) return ( f"acc={acc:.3f} | recall={recall:.3f} | " f"precision={precision:.3f} | f1={f1:.3f}" ) def _metric_values(tp: int, fn: int, fp: int, tn: int) -> dict[str, float | int]: total = tp + fn + fp + tn acc = (tp + tn) / total if total > 0 else float("nan") recall = tp / (tp + fn) if (tp + fn) > 0 else float("nan") precision = tp / (tp + fp) if (tp + fp) > 0 else float("nan") f1 = ( 2 * recall * precision / (recall + precision) if not np.isnan(recall) and not np.isnan(precision) and (recall + precision) > 0 else float("nan") ) return { "tp": tp, "fn": fn, "fp": fp, "tn": tn, "accuracy": acc, "recall": recall, "precision": precision, "f1": f1, "positive_count": tp + fn, "predicted_positive_count": tp + fp, "total_samples": total, } def main() -> None: parser = argparse.ArgumentParser() parser.add_argument("--session", type=str, default=None, help="Session folder under results/") parser.add_argument("--phase", type=str, default="both", choices=["TRAIN", "TEST", "both"]) parser.add_argument("--scope", type=str, default="latest", choices=["latest", "all"], help="Use latest epoch only or all epochs") parser.add_argument( "--decision", type=str, default="probability", choices=["probability", "prediction"], help="Use RainProbability or final Prediction for rain/no-rain decision.", ) parser.add_argument("--threshold-mm", type=float, default=None, help="Rain classification threshold in mm") parser.add_argument("--prob-threshold", type=float, default=None, help="Probability threshold when decision=probability") args = parser.parse_args() threshold_mm = rain_threshold_mm() if args.threshold_mm is None else float(args.threshold_mm) prob_threshold = rain_probability_threshold() if args.prob_threshold is None else float(args.prob_threshold) session_dir = _find_session(args.session) logs = _latest_client_logs(session_dir) if not logs: raise FileNotFoundError(f"No final client logs found in {session_dir}") phases = ["TRAIN", "TEST"] if args.phase == "both" else [args.phase] n_rows, n_cols = len(phases), len(logs) + 1 # +1 for aggregated fig, axes = plt.subplots(n_rows, n_cols, figsize=(4.2 * n_cols, 3.8 * n_rows), squeeze=False) rows: list[dict[str, str | int | float]] = [] for row_idx, phase in enumerate(phases): total_tp = total_fn = total_fp = total_tn = 0 for col_idx, cid in enumerate(sorted(logs)): df = pd.read_csv(logs[cid]) phase_df = df[df["Status"] == phase] if "Status" in df.columns else pd.DataFrame() phase_df, selected_epoch = _select_scope(phase_df, scope=args.scope) if phase_df.empty: tp = fn = fp = tn = 0 else: tp, fn, fp, tn = _confusion_counts( phase_df, threshold_mm=threshold_mm, decision=args.decision, prob_threshold=prob_threshold, ) total_tp += tp total_fn += fn total_fp += fp total_tn += tn epoch_tag = f" | E{selected_epoch}" if selected_epoch is not None else "" title = f"Client {cid} - {phase}{epoch_tag}\n{_phase_metrics(tp, fn, fp, tn)}" _draw_cm(axes[row_idx, col_idx], tp, fn, fp, tn, title=title) row = { "session": session_dir.name, "phase": phase, "client_id": cid, "scope": args.scope, "decision": args.decision, "selected_epoch": selected_epoch, } row.update(_metric_values(tp, fn, fp, tn)) rows.append(row) agg_title = f"ALL Clients - {phase}\n{_phase_metrics(total_tp, total_fn, total_fp, total_tn)}" _draw_cm(axes[row_idx, n_cols - 1], total_tp, total_fn, total_fp, total_tn, title=agg_title) agg_row = { "session": session_dir.name, "phase": phase, "client_id": "ALL", "scope": args.scope, "decision": args.decision, "selected_epoch": "ALL", } agg_row.update(_metric_values(total_tp, total_fn, total_fp, total_tn)) rows.append(agg_row) plt.suptitle( f"Confusion Matrices — {session_dir.name} | scope={args.scope} | decision={args.decision} | " f"rain>{threshold_mm:.2f}mm | p>={prob_threshold:.2f}", fontsize=13, fontweight="bold", y=1.02, ) plt.tight_layout() out_path = session_dir / f"confusion_matrix_{session_dir.name}.png" fig.savefig(out_path, dpi=150, bbox_inches="tight") plt.close(fig) print(f"Saved confusion matrix figure: {out_path}") metrics_df = pd.DataFrame(rows) metrics_path = session_dir / f"confusion_matrix_metrics_{session_dir.name}.csv" metrics_df.to_csv(metrics_path, index=False) print(f"Saved confusion matrix metrics: {metrics_path}") # Human-readable text summary for quick inspection without opening images. print("\nConfusion Matrix Metrics") for phase in phases: phase_df = metrics_df[metrics_df["phase"] == phase] print(f"\n[{phase}]") for _, r in phase_df.iterrows(): cid = r["client_id"] recall_txt = "N/A" if np.isnan(r["recall"]) else f"{r['recall']:.3f}" precision_txt = "N/A" if np.isnan(r["precision"]) else f"{r['precision']:.3f}" f1_txt = "N/A" if np.isnan(r["f1"]) else f"{r['f1']:.3f}" print( f" client={cid:<3} TP={int(r['tp']):>6} FN={int(r['fn']):>6} " f"FP={int(r['fp']):>6} TN={int(r['tn']):>6} " f"acc={r['accuracy']:.3f} recall={recall_txt} " f"precision={precision_txt} f1={f1_txt} " f"pos={int(r['positive_count'])}" ) if __name__ == "__main__": main()