csc8114 / code / src / data / plot_confusion_matrix.py
plot_confusion_matrix.py
Raw
"""
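plot_confusion_matrix.py
========================
Build rain/no-rain confusion matrix plots from client training logs.

For each requested phase (TRAIN and/or TEST) the script draws one 2x2 matrix
per client plus one aggregated over all clients. It then saves a PNG figure
and a CSV of the corresponding metrics into the session folder under
results/. Without --session, the newest session folder (by timestamped name)
is used.

Usage:
    python -m src.data.plot_confusion_matrix
    python -m src.data.plot_confusion_matrix --session 2026-03-13_01-53-07
    python -m src.data.plot_confusion_matrix --session 2026-03-13_01-53-07 --phase TEST
    python -m src.data.plot_confusion_matrix --scope all --decision prediction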
"""
from __future__ import annotations

import argparse
import re
import sys
from pathlib import Path

import matplotlib
matplotlib.use("Agg")
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd


# Put the directory that contains the ``src`` package on sys.path so the
# ``src.shared.targets`` import below resolves. This file lives at
# <project_root>/src/data/, hence parents[2].
# ``project_root`` is also used to locate the ``results/`` folder.
# The path is appended, so entries already on sys.path take precedence.
project_root = Path(__file__).resolve().parents[2]
if str(project_root) not in sys.path:
    sys.path.append(str(project_root))

from src.shared.targets import rain_probability_threshold, rain_threshold_mm


def _find_session(session_id: str | None) -> Path:
    results_dir = project_root / "results"
    sessions = sorted([d for d in results_dir.iterdir() if d.is_dir() and d.name.startswith("20")])
    if not sessions:
        raise FileNotFoundError(f"No session folders found under {results_dir}")
    if session_id:
        target = results_dir / session_id
        if not target.is_dir():
            raise FileNotFoundError(f"Session not found: {target}")
        return target
    return sessions[-1]


def _latest_client_logs(session_dir: Path) -> dict[int, Path]:
    """Scans for client logs in the session directory and recursive subdirectories."""
    # Recursively find training logs, which might be in pi01/, pi02/ subdirs
    # created by 'ansible fetch' or in the root of the session dir.
    all_logs = list(session_dir.rglob("training_log_client*.csv"))
    by_client: dict[int, Path] = {}
    for path in sorted(all_logs):
        if "progress" in path.name:
            continue
        m = re.search(r"training_log_client(\d+)_\d{8}_\d{6}\.csv$", path.name)
        if not m:
            continue
        cid = int(m.group(1))
        by_client[cid] = path
    return by_client


def _confusion_counts(
    df: pd.DataFrame,
    *,
    threshold_mm: float,
    decision: str,
    prob_threshold: float,
) -> tuple[int, int, int, int]:
    y_true = (pd.to_numeric(df["Target"], errors="coerce") > threshold_mm).astype(int)
    if decision == "probability" and "RainProbability" in df.columns:
        probs = pd.to_numeric(df["RainProbability"], errors="coerce").fillna(0.0)
        y_pred = (probs >= prob_threshold).astype(int)
    else:
        preds = pd.to_numeric(df["Prediction"], errors="coerce").fillna(0.0)
        y_pred = (preds > threshold_mm).astype(int)
    tp = int(((y_true == 1) & (y_pred == 1)).sum())
    fn = int(((y_true == 1) & (y_pred == 0)).sum())
    fp = int(((y_true == 0) & (y_pred == 1)).sum())
    tn = int(((y_true == 0) & (y_pred == 0)).sum())
    return tp, fn, fp, tn


def _select_scope(df: pd.DataFrame, *, scope: str) -> tuple[pd.DataFrame, int | None]:
    if scope != "latest":
        return df, None
    if "Epoch" not in df.columns or df.empty:
        return df, None
    epoch_values = pd.to_numeric(df["Epoch"], errors="coerce").dropna()
    if epoch_values.empty:
        return df, None
    latest_epoch = int(epoch_values.max())
    return df[pd.to_numeric(df["Epoch"], errors="coerce") == latest_epoch], latest_epoch


def _draw_cm(ax, tp: int, fn: int, fp: int, tn: int, *, title: str) -> None:
    """Draw one annotated 2x2 confusion matrix, with a colour bar, on ``ax``.

    Rows are the true class and columns the predicted class, both ordered
    (Dry, Rain), so the plotted matrix is ``[[tn, fp], [fn, tp]]``.

    Args:
        ax: matplotlib Axes to draw into; its parent figure receives the colour bar.
        tp, fn, fp, tn: Confusion-matrix counts.
        title: Subplot title (may contain newlines).
    """
    mat = np.array([[tn, fp], [fn, tp]], dtype=float)
    im = ax.imshow(mat, cmap="Blues")
    ax.set_xticks([0, 1], labels=["Pred: Dry", "Pred: Rain"])
    ax.set_yticks([0, 1], labels=["True: Dry", "True: Rain"])
    ax.set_title(title, fontsize=10, fontweight="bold")
    ax.set_xlabel("Prediction")
    ax.set_ylabel("Target")
    # Dark cells of the "Blues" map make black text unreadable. Switch to white
    # above half the maximum count; an all-zero matrix keeps black text.
    contrast_cutoff = mat.max() / 2.0
    for i in range(2):
        for j in range(2):
            text_colour = "white" if mat[i, j] > contrast_cutoff else "black"
            ax.text(j, i, f"{int(mat[i, j])}", ha="center", va="center", color=text_colour, fontsize=11)
    # Attach the colour bar to this Axes' own figure rather than pyplot's
    # "current figure", which is not guaranteed to be the same one.
    ax.figure.colorbar(im, ax=ax, fraction=0.046, pad=0.04)


def _phase_metrics(tp: int, fn: int, fp: int, tn: int) -> str:
    """Return a one-line metric summary used in subplot titles.

    The arithmetic is delegated to ``_metric_values`` so plot titles and the
    saved CSV / console metrics are always computed the same way. Ratios with
    a zero denominator (NaN) are shown as ``N/A``.

    Example output: ``"acc=0.800 | recall=0.750 | precision=0.750 | f1=0.750"``.
    """
    metrics = _metric_values(tp, fn, fp, tn)

    def fmt(value: float) -> str:
        return "N/A" if np.isnan(value) else f"{value:.3f}"

    return (
        f"acc={fmt(metrics['accuracy'])} | recall={fmt(metrics['recall'])} | "
        f"precision={fmt(metrics['precision'])} | f1={fmt(metrics['f1'])}"
    )


def _metric_values(tp: int, fn: int, fp: int, tn: int) -> dict[str, float | int]:
    total = tp + fn + fp + tn
    acc = (tp + tn) / total if total > 0 else float("nan")
    recall = tp / (tp + fn) if (tp + fn) > 0 else float("nan")
    precision = tp / (tp + fp) if (tp + fp) > 0 else float("nan")
    f1 = (
        2 * recall * precision / (recall + precision)
        if not np.isnan(recall) and not np.isnan(precision) and (recall + precision) > 0
        else float("nan")
    )
    return {
        "tp": tp,
        "fn": fn,
        "fp": fp,
        "tn": tn,
        "accuracy": acc,
        "recall": recall,
        "precision": precision,
        "f1": f1,
        "positive_count": tp + fn,
        "predicted_positive_count": tp + fp,
        "total_samples": total,
    }


def _parse_args(argv: list[str] | None = None) -> argparse.Namespace:
    """Parse command-line options (``argv`` defaults to ``sys.argv[1:]``)."""
    parser = argparse.ArgumentParser(description="Build confusion matrix plots from client training logs.")
    parser.add_argument("--session", type=str, default=None, help="Session folder under results/")
    parser.add_argument("--phase", type=str, default="both", choices=["TRAIN", "TEST", "both"])
    parser.add_argument("--scope", type=str, default="latest", choices=["latest", "all"], help="Use latest epoch only or all epochs")
    parser.add_argument(
        "--decision",
        type=str,
        default="probability",
        choices=["probability", "prediction"],
        help="Use RainProbability or final Prediction for rain/no-rain decision.",
    )
    parser.add_argument("--threshold-mm", type=float, default=None, help="Rain classification threshold in mm")
    parser.add_argument("--prob-threshold", type=float, default=None, help="Probability threshold when decision=probability")
    return parser.parse_args(argv)


def _print_summary(metrics_df: pd.DataFrame, phases: list[str]) -> None:
    """Print a human-readable per-phase table of the confusion metrics."""

    def fmt(value: float) -> str:
        # Ratios with a zero denominator are NaN; show them as N/A.
        return "N/A" if np.isnan(value) else f"{value:.3f}"

    print("\nConfusion Matrix Metrics")
    for phase in phases:
        print(f"\n[{phase}]")
        for _, r in metrics_df[metrics_df["phase"] == phase].iterrows():
            print(
                f"  client={r['client_id']:<3} TP={int(r['tp']):>6} FN={int(r['fn']):>6} "
                f"FP={int(r['fp']):>6} TN={int(r['tn']):>6} "
                f"acc={fmt(r['accuracy'])} recall={fmt(r['recall'])} "
                f"precision={fmt(r['precision'])} f1={fmt(r['f1'])} "
                f"pos={int(r['positive_count'])}"
            )


def main() -> None:
    """Build confusion matrices for one session and save the figure and metrics.

    The session is resolved from ``--session`` (newest by default), and the
    newest training log per client is used. For each requested phase, counts
    are computed per client and aggregated over all clients. The function then
    writes ``confusion_matrix_<session>.png`` and
    ``confusion_matrix_metrics_<session>.csv`` into the session folder and
    prints a text summary.

    Raises:
        FileNotFoundError: if no session folder or no client logs are found.
    """
    args = _parse_args()
    threshold_mm = rain_threshold_mm() if args.threshold_mm is None else float(args.threshold_mm)
    prob_threshold = rain_probability_threshold() if args.prob_threshold is None else float(args.prob_threshold)

    session_dir = _find_session(args.session)
    logs = _latest_client_logs(session_dir)
    if not logs:
        raise FileNotFoundError(f"No final client logs found in {session_dir}")

    client_ids = sorted(logs)
    # Load each client log once; the same frame is reused for every phase.
    frames = {cid: pd.read_csv(logs[cid]) for cid in client_ids}
    for cid, frame in frames.items():
        if "Status" not in frame.columns:
            # No row can be assigned to a phase, so this client's matrices
            # will be all zeros; make that visible instead of silent.
            print(f"Warning: {logs[cid]} has no 'Status' column; its counts will be zero.", file=sys.stderr)

    phases = ["TRAIN", "TEST"] if args.phase == "both" else [args.phase]
    n_rows, n_cols = len(phases), len(client_ids) + 1  # +1 column for the aggregate
    fig, axes = plt.subplots(n_rows, n_cols, figsize=(4.2 * n_cols, 3.8 * n_rows), squeeze=False)
    rows: list[dict[str, str | int | float | None]] = []

    for row_idx, phase in enumerate(phases):
        total_tp = total_fn = total_fp = total_tn = 0
        for col_idx, cid in enumerate(client_ids):
            df = frames[cid]
            phase_df = df[df["Status"] == phase] if "Status" in df.columns else pd.DataFrame()
            # With scope="latest" each client uses its own latest epoch, so the
            # aggregate column may combine different epoch numbers.
            phase_df, selected_epoch = _select_scope(phase_df, scope=args.scope)
            if phase_df.empty:
                tp = fn = fp = tn = 0
            else:
                tp, fn, fp, tn = _confusion_counts(
                    phase_df,
                    threshold_mm=threshold_mm,
                    decision=args.decision,
                    prob_threshold=prob_threshold,
                )
            total_tp += tp
            total_fn += fn
            total_fp += fp
            total_tn += tn
            epoch_tag = f" | E{selected_epoch}" if selected_epoch is not None else ""
            title = f"Client {cid} - {phase}{epoch_tag}\n{_phase_metrics(tp, fn, fp, tn)}"
            _draw_cm(axes[row_idx, col_idx], tp, fn, fp, tn, title=title)
            row = {
                "session": session_dir.name,
                "phase": phase,
                "client_id": cid,
                "scope": args.scope,
                "decision": args.decision,
                "selected_epoch": selected_epoch,
            }
            row.update(_metric_values(tp, fn, fp, tn))
            rows.append(row)

        agg_title = f"ALL Clients - {phase}\n{_phase_metrics(total_tp, total_fn, total_fp, total_tn)}"
        _draw_cm(axes[row_idx, n_cols - 1], total_tp, total_fn, total_fp, total_tn, title=agg_title)
        agg_row = {
            "session": session_dir.name,
            "phase": phase,
            "client_id": "ALL",
            "scope": args.scope,
            "decision": args.decision,
            "selected_epoch": "ALL",
        }
        agg_row.update(_metric_values(total_tp, total_fn, total_fp, total_tn))
        rows.append(agg_row)

    fig.suptitle(
        f"Confusion Matrices — {session_dir.name} | scope={args.scope} | decision={args.decision} | "
        f"rain>{threshold_mm:.2f}mm | p>={prob_threshold:.2f}",
        fontsize=13,
        fontweight="bold",
        y=1.02,
    )
    fig.tight_layout()

    out_path = session_dir / f"confusion_matrix_{session_dir.name}.png"
    fig.savefig(out_path, dpi=150, bbox_inches="tight")
    plt.close(fig)
    print(f"Saved confusion matrix figure: {out_path}")

    metrics_df = pd.DataFrame(rows)
    metrics_path = session_dir / f"confusion_matrix_metrics_{session_dir.name}.csv"
    metrics_df.to_csv(metrics_path, index=False)
    print(f"Saved confusion matrix metrics: {metrics_path}")

    # Human-readable text summary for quick inspection without opening images.
    _print_summary(metrics_df, phases)


# Run the CLI only when executed as a script (e.g. via
# ``python -m src.data.plot_confusion_matrix``), not when imported.
if __name__ == "__main__":
    main()