"""
plot_server_metrics.py
======================
Reads the server_log_<session>.csv and produces a multi-panel
dashboard covering:
Panel 1 — Loss over rounds (train vs test, per client)
Panel 2 — Rain classification acc (rain_correct per round)
Panel 3 — Latency breakdown (decomp + comp time per round)
Panel 4 — Gradient magnitude (per round, train only)
Usage:
uv run python src/data/plot_server_metrics.py
uv run python src/data/plot_server_metrics.py --log results/2026-03-12_15-05-55/server_log_2026-03-12_15-05-55.csv
"""
import sys
import glob
import argparse
from pathlib import Path
import numpy as np
import pandas as pd
import matplotlib
matplotlib.use("Agg")
import matplotlib.pyplot as plt
import matplotlib.ticker as mticker
# ── paths ────────────────────────────────────────────────────────────────────
# Directory two levels above the folder that contains this file; with the
# layout shown in the usage examples (src/data/<this file>) that is the
# project root.
project_root = Path(__file__).resolve().parents[2]
# Make the root importable so the absolute `src.` import below resolves when
# this file is executed directly as a script.
if str(project_root) not in sys.path:
    sys.path.append(str(project_root))
# Must stay after the sys.path adjustment above.
from src.shared.targets import is_rain
def _find_latest_log() -> Path:
    """Locate the newest ``server_log_*.csv`` anywhere below ``results/``.

    Candidates are collected recursively from ``project_root / "results"``
    and ranked by modification time, with the file name as tie-breaker.

    Returns:
        Path of the highest-ranked log file.

    Raises:
        FileNotFoundError: if no matching CSV exists.
    """
    results_dir = project_root / "results"
    candidates = list(results_dir.rglob("server_log_*.csv"))
    if not candidates:
        raise FileNotFoundError(
            f"No server_log_*.csv found in {results_dir}"
        )

    def _recency(path: Path):
        # (mtime, name): newest file last; name breaks mtime ties.
        return path.stat().st_mtime, path.name

    candidates.sort(key=_recency)
    newest = candidates[-1]
    print(f"[DEBUG] Picking latest log: {newest.relative_to(project_root)}")
    return newest
# ─────────────────────────────────────────────────────────────────────────────
# Constants used for dashboard labels/heuristics
# NOTE(review): none of these names is referenced in the visible part of this
# file — confirm whether other modules import them before relying on or
# removing them.
TRAIN_DAYS = 20
VAL_DAYS = 5
TEST_DAYS = 5
SEQ_LEN = 48
HORIZON = 1
DEVICE = "cpu"
def plot_server_metrics(log_path: Path) -> None:
    """Render the four-panel server metrics dashboard for one session log.

    Reads the CSV at *log_path*, plots per-round loss, rain-classification
    accuracy, latency and gradient magnitude, saves the figure as
    ``results/<session>/server_metrics_<session>.png`` below ``project_root``
    and prints a short text summary.

    Columns that may be absent and are handled here:

    * ``round``        — rows are binned into 20 equal segments instead.
    * ``rain_correct`` — derived from ``target`` / ``prediction`` via ``is_rain``.
    * ``client_id``    — every row is attributed to a single client ``0``.
    * ``is_training``  — all rows are treated as training rows.
    * ``rain_flag``    — rain rows are selected with ``is_rain(target)``.

    ``loss``, ``decompression_time_ms``, ``computation_time_ms`` and
    ``gradient_magnitude`` are required.

    Args:
        log_path: Path to a ``server_log_<session>.csv`` file.
    """
    df = pd.read_csv(log_path)
    print(f"[INFO] Loaded {len(df):,} rows from {log_path.name}")

    # ── Compatibility: older logs without 'round' column ─────────────────────
    if "round" not in df.columns:
        # Bin rows into N equal segments so the chart stays readable
        N_BINS = 20
        df = df.copy()
        if df.empty:
            # pd.cut rejects empty input; an empty column keeps the rest of
            # the pipeline working on a header-only log.
            df["round"] = np.zeros(0, dtype=int)
        else:
            df["round"] = pd.cut(
                np.arange(len(df)), bins=N_BINS, labels=False
            )
        print(f"[WARN] 'round' column missing — binned {len(df):,} rows into {N_BINS} segments")

    if "rain_correct" not in df.columns:
        # Derive on-the-fly from target & prediction
        df["rain_correct"] = [
            int(is_rain(t) == is_rain(p))
            for t, p in zip(df["target"], df["prediction"])
        ]

    if "client_id" not in df.columns:
        # The per-client panels filter on this column, so a log without it
        # is treated as coming from one client with id 0.
        df["client_id"] = 0

    session_id = log_path.stem.replace("server_log_", "")

    # ── Split train / test ────────────────────────────────────────────────────
    if "is_training" in df.columns:
        df_train = df[df["is_training"] == 1]
        df_test = df[df["is_training"] == 0]
    else:
        df_train = df
        df_test = pd.DataFrame()

    palette = plt.cm.tab10.colors
    client_ids = sorted(df["client_id"].unique())

    # ── Helpers ───────────────────────────────────────────────────────────────
    def per_round(sub: pd.DataFrame, cols: list[str]) -> pd.DataFrame:
        """Mean of *cols* for each round of *sub*."""
        return sub.groupby("round")[cols].mean().reset_index()

    def client_label(cid) -> str:
        """Legend label for a client; non-numeric IDs are shown as-is."""
        try:
            return f"Client {int(cid)}"
        except (TypeError, ValueError):
            return f"Client {cid}"

    def style_axis(ax, title: str, ylabel: str,
                   grid_alpha: float = 0.3, grid_axis: str = "both") -> None:
        """Apply the title, labels, legend and grid shared by all panels."""
        ax.set_title(title, fontsize=11, fontweight="bold")
        ax.set_xlabel("Round")
        ax.set_ylabel(ylabel)
        ax.legend(fontsize=7)
        ax.grid(True, alpha=grid_alpha, axis=grid_axis, linestyle="--")
        ax.set_facecolor("#f8f8f8")
        ax.xaxis.set_major_locator(mticker.MaxNLocator(integer=True))

    def draw_loss(ax) -> None:
        """Panel 1: mean loss per round, train and test, one colour per client."""
        for i, cid in enumerate(client_ids):
            color = palette[i % len(palette)]
            sub = df_train[df_train["client_id"] == cid]
            if sub.empty:
                continue
            agg = per_round(sub, ["loss"])
            ax.plot(agg["round"], agg["loss"], "o-",
                    color=color, linewidth=1.8, markersize=4,
                    label=f"{client_label(cid)} (train)")
            if df_test.empty:
                continue
            sub_t = df_test[df_test["client_id"] == cid]
            if not sub_t.empty:
                agg_t = per_round(sub_t, ["loss"])
                ax.plot(agg_t["round"], agg_t["loss"], "s--",
                        color=color, linewidth=1.2, markersize=3, alpha=0.6,
                        label=f"{client_label(cid)} (test)")
        style_axis(ax, "📉 Loss per Round", "MSE Loss")

    def draw_accuracy(ax) -> None:
        """Panel 2: rain-classification accuracy per round (train rows)."""
        for i, cid in enumerate(client_ids):
            color = palette[i % len(palette)]
            sub = df_train[df_train["client_id"] == cid]
            if sub.empty:
                continue
            agg = per_round(sub, ["rain_correct"])
            ax.plot(agg["round"], agg["rain_correct"] * 100, "o-",
                    color=color, linewidth=1.8, markersize=4,
                    label=client_label(cid))
        # Rain-only accuracy
        if "rain_flag" in df.columns:
            rain_df = df_train[df_train["rain_flag"] == 1]
        else:
            rain_df = df_train[df_train["target"].apply(is_rain)]
        if not rain_df.empty:
            agg_rain = per_round(rain_df, ["rain_correct"])
            ax.plot(agg_rain["round"], agg_rain["rain_correct"] * 100, "k^--",
                    linewidth=1.2, markersize=4, alpha=0.5, label="🌧 Rain only")
        ax.axhline(50, color="red", linestyle=":", linewidth=1, alpha=0.5,
                   label="Random baseline")
        ax.set_ylim(0, 105)
        style_axis(ax, "☔ Rain Classification Accuracy", "Accuracy (%)")

    def draw_latency(ax) -> None:
        """Panel 3: stacked decompression + computation time per round."""
        agg_latency = per_round(
            df_train, ["decompression_time_ms", "computation_time_ms"]
        )
        x = agg_latency["round"].values
        decomp = agg_latency["decompression_time_ms"].values
        compute = agg_latency["computation_time_ms"].values
        ax.bar(x, decomp, label="Decomp (ms)", color="#5b8dee", alpha=0.8, width=0.4)
        ax.bar(x, compute, bottom=decomp,
               label="Compute (ms)", color="#f9a825", alpha=0.8, width=0.4)
        total_avg = (decomp + compute).mean()
        ax.axhline(total_avg, color="red", linestyle="--", linewidth=1.2,
                   label=f"Avg total = {total_avg:.2f} ms")
        style_axis(ax, "⚡ Latency per Round (Train)", "Time (ms)",
                   grid_alpha=0.2, grid_axis="y")

    def draw_gradient(ax) -> None:
        """Panel 4: mean gradient magnitude per round (train rows)."""
        agg_grad = per_round(df_train, ["gradient_magnitude"])
        ax.plot(agg_grad["round"], agg_grad["gradient_magnitude"], "D-",
                color="#7b2d8b", linewidth=1.8, markersize=4)
        ax.fill_between(agg_grad["round"], 0, agg_grad["gradient_magnitude"],
                        color="#7b2d8b", alpha=0.1)
        grad_mean = agg_grad["gradient_magnitude"].mean()
        ax.axhline(grad_mean, color="grey", linestyle="--", linewidth=1,
                   label=f"Mean = {grad_mean:.4f}")
        # Warn if many zeros
        zero_pct = (df_train["gradient_magnitude"] == 0).mean() * 100
        if zero_pct > 20:
            ax.text(0.5, 0.85, f"⚠️ {zero_pct:.0f}% zeros (test passes included?)",
                    transform=ax.transAxes, ha="center", fontsize=8, color="red")
        style_axis(ax, "∇ Gradient Magnitude per Round (Train)", "||grad||")

    # ─────────────────────────────────────────────────────────────────────────
    # Figure: 2 rows × 2 cols
    # ─────────────────────────────────────────────────────────────────────────
    fig, axes = plt.subplots(2, 2, figsize=(14, 9))
    try:
        fig.suptitle(
            f"Server Metrics Dashboard — Session {session_id}",
            fontsize=14, fontweight="bold", y=1.01
        )
        draw_loss(axes[0, 0])
        draw_accuracy(axes[0, 1])
        draw_latency(axes[1, 0])
        draw_gradient(axes[1, 1])
        fig.tight_layout()

        out_dir = project_root / "results" / session_id
        out_dir.mkdir(parents=True, exist_ok=True)
        out_path = out_dir / f"server_metrics_{session_id}.png"
        fig.savefig(out_path, dpi=150, bbox_inches="tight")
    finally:
        # Release the figure even when plotting or saving fails, so pyplot
        # does not keep it registered.
        plt.close(fig)
    print(f"\n✅ Dashboard saved: {out_path}")

    # ── Text summary ─────────────────────────────────────────────────────────
    print("\n── Server Metrics Summary ────────────────────────────────────────")
    print(f" Total log entries : {len(df):,} "
          f"(train={len(df_train):,}, test={len(df_test):,})")
    if not df_train.empty:
        avg_loss = df_train["loss"].mean()
        avg_acc = df_train["rain_correct"].mean() * 100
        avg_decomp = df_train["decompression_time_ms"].mean()
        avg_compute = df_train["computation_time_ms"].mean()
        avg_grad = df_train["gradient_magnitude"].mean()
        print(f" Avg Train Loss : {avg_loss:.4f}")
        print(f" Avg Accuracy : {avg_acc:.2f}%")
        print(f" Avg Latency : decomp={avg_decomp:.2f} ms "
              f"compute={avg_compute:.2f} ms "
              f"total={avg_decomp+avg_compute:.2f} ms")
        print(f" Avg Gradient Mag : {avg_grad:.4f}")
    print()
# ─────────────────────────────────────────────────────────────────────────────
def main():
    """Command-line entry point: parse ``--log`` and build the dashboard.

    Without ``--log`` the newest log found by ``_find_latest_log`` is used.
    Exits with status 1 when the chosen log path does not exist.
    """
    cli = argparse.ArgumentParser()
    cli.add_argument(
        "--log",
        type=str,
        default=None,
        help="Path to server_log_*.csv. Defaults to the latest one in results/",
    )
    options = cli.parse_args()

    if options.log:
        log_path = Path(options.log)
    else:
        log_path = _find_latest_log()

    if log_path.exists():
        plot_server_metrics(log_path)
    else:
        print(f"[ERROR] Log file not found: {log_path}")
        sys.exit(1)
# Run the CLI only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()