import sys
import glob
import json
import argparse
from pathlib import Path
import pandas as pd
import numpy as np
import torch
import torch.nn as nn
from sklearn.metrics import accuracy_score, average_precision_score, f1_score, roc_auc_score
# --- Configuration & Paths ---
# Repository root: ``parents[2]`` is the directory two levels above the one
# that contains this file. NOTE(review): assumes this script sits two package
# levels below the repo root (e.g. <root>/src/<pkg>/<file>.py) — confirm.
project_root = Path(__file__).resolve().parents[2]
# Put the repository root on the import path so the ``src.*`` imports below
# resolve (presumably needed when this file is executed directly as a script).
if str(project_root) not in sys.path:
    sys.path.append(str(project_root))
from src.shared.common import cfg, feature_cols_from_cfg, get_nested
from src.client.data_pipeline import (
FUTURE_RAIN_COL,
collect_eval_indices_capped,
collect_test_indices_capped,
get_dataset_split,
load_sensor_data,
partition_client_files,
)
from src.models.split_lstm import ClientLSTM, ServerHead
from src.shared.targets import (
inverse_target_scalar,
is_rain,
rain_probability_threshold,
rain_threshold_mm,
target_transform_mode,
)
def _normalize_report_tag(tag: str) -> str:
raw = str(tag).strip()
if not raw:
return ""
safe = "".join(ch if (ch.isalnum() or ch in {"_", "-"}) else "_" for ch in raw)
return safe.strip("_")
def _parse_threshold_list(spec: str) -> list[float]:
raw = str(spec).strip()
if not raw:
return []
if ":" in raw:
parts = [p.strip() for p in raw.split(":")]
if len(parts) != 3:
raise ValueError("--scan-thresholds with ':' must be start:end:step")
start, end, step = (float(parts[0]), float(parts[1]), float(parts[2]))
if step <= 0:
raise ValueError("--scan-thresholds step must be > 0")
values: list[float] = []
x = start
# epsilon for float endpoint inclusion
while x <= end + 1e-12:
values.append(float(round(x, 6)))
x += step
return [v for v in values if 0.0 <= v <= 1.0]
values = []
for token in raw.split(","):
t = token.strip()
if not t:
continue
v = float(t)
if not (0.0 <= v <= 1.0):
raise ValueError(f"threshold out of range [0,1]: {v}")
values.append(float(round(v, 6)))
return sorted(set(values))
def _parse_timestamp(path: str) -> str:
"""
Extract the trailing timestamp token from any model filename.
server_head_round_9_20260312131020.pth -> '20260312131020'
best_client_1_round_9_model_20260312131018.pth -> '20260312131018'
"""
return Path(path).stem.split("_")[-1]
def _parse_round(path: str) -> int:
"""
Extract the round number from a client or server filename.
best_client_1_round_9_model_20260312131018.pth -> 9
server_head_round_9_20260312131020.pth -> 9
Returns 0 if not parseable.
"""
parts = Path(path).stem.split("_")
try:
# 'round' token is always followed by the numeric round number
idx = parts.index("round")
return int(parts[idx + 1])
except (ValueError, IndexError):
return 0
def _find_latest_session_id() -> str:
    """Return the name of the lexicographically greatest ``bestweights/20*`` folder.

    NOTE(review): session folder names presumably start with a timestamp, so
    the greatest name is the newest session — confirm against the writer.

    Raises:
        FileNotFoundError: no matching session directory exists.
    """
    bw_dir = project_root / "bestweights"
    newest = max(
        (entry for entry in bw_dir.glob("20*") if entry.is_dir()),
        key=lambda entry: entry.name,
        default=None,
    )
    if newest is None:
        raise FileNotFoundError(f"No session folders found under {bw_dir}")
    return newest.name
def find_periodic_pair(
    *,
    session_id: str,
    num_clients: int | None = None,
    target_round: int | None = None,
    scenario_id: str | None = None,
) -> tuple[int, str, dict[int, str]] | None:
    """
    Locate a complete, round-aligned set of periodic checkpoints.

    A set for round R is ``server_round_<R>.pth`` plus one
    ``client_<cid>_round_<R>.pth`` per expected client, all inside the same
    ``periodic`` directory. Expected clients are ``1..num_clients`` when
    ``num_clients`` is given, otherwise every client with a file for R.

    Directories searched: only ``<session>/<scenario_id>/periodic`` when
    ``scenario_id`` is set; otherwise ``<session>/periodic`` plus every
    ``<session>/*/periodic``. Inside a directory only ``target_round`` is tried
    if given, else rounds from newest to oldest, stopping at the first
    complete one. Across directories the higher round wins, then the newer
    directory mtime; a full tie keeps the directory visited first.

    Returns ``(round, server_path, {cid: client_path})`` or ``None``.
    """
    session_dir = project_root / "bestweights" / session_id
    if not session_dir.is_dir():
        return None

    if scenario_id:
        # Only the requested scenario sub-directory is eligible.
        search_dirs = [session_dir / scenario_id / "periodic"]
    else:
        search_dirs = [session_dir / "periodic", *sorted(session_dir.glob("*/periodic"))]
    roots = [d for d in search_dirs if d.is_dir()]
    if not roots:
        return None
    # Newest-first (stable sort), so ties prefer the most recent scenario folder.
    roots.sort(key=lambda d: d.stat().st_mtime, reverse=True)

    best: tuple[int, float, str, dict[int, str]] | None = None
    for root in roots:
        servers: dict[int, str] = {}
        for ckpt in root.glob("server_round_*.pth"):
            try:
                servers[int(ckpt.stem.split("_")[-1])] = str(ckpt)
            except ValueError:
                pass
        if not servers:
            continue

        clients_by_id: dict[int, dict[int, str]] = {}
        for ckpt in root.glob("client_*_round_*.pth"):
            tokens = ckpt.stem.split("_")
            try:
                cid, rnd = int(tokens[1]), int(tokens[3])
            except (IndexError, ValueError):
                continue
            clients_by_id.setdefault(cid, {})[rnd] = str(ckpt)

        rounds = [target_round] if target_round is not None else sorted(servers, reverse=True)
        for rnd in rounds:
            server_file = servers.get(rnd)
            if not server_file:
                continue
            if num_clients is None:
                wanted = sorted(cid for cid, per_round in clients_by_id.items() if rnd in per_round)
            else:
                wanted = list(range(1, num_clients + 1))
            if not wanted:
                continue
            chosen = {cid: clients_by_id.get(cid, {}).get(rnd) for cid in wanted}
            if not all(chosen.values()):
                continue  # at least one expected client lacks this round
            key = (rnd, root.stat().st_mtime)
            if best is None or key > best[:2]:
                best = (rnd, key[1], server_file, chosen)
            break  # highest complete round in this directory found

    if best is None:
        return None
    return best[0], best[2], best[3]
def find_best_server(session_id: str | None = None) -> str:
    """
    Return the path of the most recent ``server_head_round_*.pth`` checkpoint.

    Search scope:
      * ``session_id`` given: files directly in ``bestweights/<session_id>/``;
        if there are none, any depth below it (recursive ``**`` glob).
      * otherwise: the flat ``bestweights/`` root plus one level of
        sub-directories.

    Ranking key, compared in order: mtime of the containing folder (0.0 for
    files in the flat root, so any sub-directory file outranks them), the
    trailing timestamp token (as a string), then the round number. The
    winner is printed as ``<folder>/<file>``.

    Raises:
        FileNotFoundError: ``bestweights/`` or the session folder is missing,
            or no matching checkpoint exists.
    """
    bw_dir = project_root / "bestweights"
    if not bw_dir.exists():
        raise FileNotFoundError(f"Cannot find bestweights directory at {bw_dir}")

    pattern = "server_head_round_*.pth"
    if session_id:
        session_dir = bw_dir / session_id
        if not session_dir.is_dir():
            raise FileNotFoundError(f"Session not found under bestweights/: {session_id}")
        # Flat session files first; otherwise nested scenario folders
        # (e.g. session/01_seed42/server_head...).
        found = glob.glob(str(session_dir / pattern)) or glob.glob(
            str(session_dir / "**" / pattern), recursive=True
        )
    else:
        # Flat root plus exactly one level of sub-directories.
        found = glob.glob(str(bw_dir / pattern)) + glob.glob(str(bw_dir / "*" / pattern))

    if not found:
        raise FileNotFoundError(
            f"No server weights found in {bw_dir}.\n"
            "Expected: server_head_round_<N>_<stamp>.pth (flat or inside a session subdir)"
        )

    def _rank(candidate: str):
        folder = Path(candidate).parent
        folder_mtime = 0.0 if folder == bw_dir else folder.stat().st_mtime
        return (folder_mtime, _parse_timestamp(candidate), _parse_round(candidate))

    latest_server = sorted(found, key=_rank)[-1]
    folder = Path(latest_server).parent
    if folder == bw_dir:
        rel = Path(latest_server).name
    else:
        rel = f"{folder.name}/{Path(latest_server).name}"
    print(f"\u2705 Latest Server Model : {rel}")
    return latest_server
def find_matching_clients(server_path: str, *, allow_cross_session_fallback: bool = False) -> dict[int, str]:
    """
    Find the best client checkpoint for every client ID next to a server checkpoint.

    Search order for ``best_client_*_round_*_model_*.pth`` files:
      1. Directly in the server checkpoint's directory.
      2. If none and that directory is a session folder (not the flat
         ``bestweights/`` root), recursively below it (nested scenario layout).
      3. If still none and ``allow_cross_session_fallback`` is set, the flat
         ``bestweights/`` root plus one level of sub-directories.

    The client ID is the third ``_``-separated token of the file stem. For
    each client, the file with the highest round wins, and ties are broken by the
    trailing timestamp token (compared as a string). Candidates are taken from
    exactly the files the search found. Previously, clients discovered by the
    recursive search were looked up again only at the top level. That lookup
    found nothing, so every client was skipped and the call failed.

    Args:
        server_path: Path of the selected server checkpoint.
        allow_cross_session_fallback: Permit search step 3.

    Returns:
        Mapping ``client_id -> checkpoint path``.

    Raises:
        FileNotFoundError: No client checkpoint was found by the permitted
            search steps.
    """
    bw_dir = project_root / "bestweights"
    server_dir = Path(server_path).parent  # flat root OR session subdir
    is_session = server_dir != bw_dir
    location = server_dir.name if is_session else "(flat bestweights/)"
    print(f"\U0001f511 Looking for clients in : {location}")

    def _group_by_client(file_list: list[str]) -> dict[int, list[str]]:
        """Bucket checkpoint paths by the numeric client ID in their stem."""
        groups: dict[int, list[str]] = {}
        for f in file_list:
            parts = Path(f).stem.split("_")
            try:
                cid = int(parts[2])  # best_client_<ID>_... -> token index 2
            except (IndexError, ValueError):
                continue
            groups.setdefault(cid, []).append(f)
        return groups

    # Primary search: the server's own directory, then (session folders only)
    # recursively below it. A flat-root server must not recurse, otherwise it
    # would silently pull clients from other sessions without the opt-in flag.
    primary_files = glob.glob(str(server_dir / "best_client_*_round_*_model_*.pth"))
    if not primary_files and is_session:
        primary_files = glob.glob(str(server_dir / "**/best_client_*_round_*_model_*.pth"), recursive=True)
    files_by_client = _group_by_client(primary_files)

    if not files_by_client:
        if not allow_cross_session_fallback:
            raise FileNotFoundError(
                f"No client weights found in server session directory: {server_dir}\n"
                "Use --allow-cross-session-fallback to scan all bestweights/ (less strict)."
            )
        # Fallback: the entire bestweights tree (flat root + one level down)
        print("\u26a0\ufe0f No client files in server dir. Searching all of bestweights/...")
        fallback_files = (
            glob.glob(str(bw_dir / "best_client_*_round_*_model_*.pth")) +
            glob.glob(str(bw_dir / "*" / "best_client_*_round_*_model_*.pth"))
        )
        files_by_client = _group_by_client(fallback_files)
        if not files_by_client:
            raise FileNotFoundError(
                "No client weight files found anywhere in bestweights/.\n"
                "Expected: best_client_<ID>_round_<N>_model_<stamp>.pth"
            )

    matched: dict[int, str] = {}
    for cid in sorted(files_by_client):
        # Highest round number wins; ties broken by latest timestamp
        # (lexicographic on the stamp token).
        candidates = sorted(
            files_by_client[cid],
            key=lambda p: (_parse_round(p), _parse_timestamp(p)),
        )
        chosen = candidates[-1]
        tag = f"round {_parse_round(chosen)}, ts={_parse_timestamp(chosen)}"
        print(f"\U0001f464 Client {cid} Best Model : {Path(chosen).name} ({tag})")
        matched[cid] = chosen

    # Recursive/global searches can pick clients from different folders
    # (e.g. different scenarios); make that visible rather than silent.
    source_dirs = {str(Path(p).parent) for p in matched.values()}
    if len(source_dirs) > 1:
        print(
            f"\u26a0\ufe0f Selected client checkpoints span {len(source_dirs)} directories "
            "- pairing may mix different runs"
        )
    return matched
def _metrics_from_cm(tp: int, fn: int, fp: int, tn: int) -> dict[str, float | int]:
total = max(1, tp + fn + fp + tn)
accuracy = float((tp + tn) / total * 100.0)
recall = float(tp / (tp + fn)) if (tp + fn) > 0 else 0.0
precision = float(tp / (tp + fp)) if (tp + fp) > 0 else 0.0
f1 = (
float(2.0 * recall * precision / (recall + precision))
if (recall + precision) > 0
else 0.0
)
return {
"accuracy": accuracy,
"recall": recall,
"precision": precision,
"f1": f1,
"tp": int(tp),
"fn": int(fn),
"fp": int(fp),
"tn": int(tn),
}
def _class_metrics_at_threshold(
    *,
    y_true: np.ndarray,
    probs: np.ndarray,
    threshold: float,
) -> dict[str, float | int]:
    """Score probabilistic predictions after binarising them at ``threshold``.

    A sample is predicted positive when ``prob >= threshold``. Labels equal to
    1 count as positives and labels equal to 0 as negatives. The returned dict
    is the output of ``_metrics_from_cm`` extended with ``pred_positive_rate``
    (share of samples predicted positive) and the ``threshold`` used.
    """
    cut = float(threshold)
    y_hat = np.where(probs >= cut, 1, 0).astype(np.int32)
    is_pos, is_neg = y_true == 1, y_true == 0
    hit, miss = y_hat == 1, y_hat == 0
    counts = {
        "tp": int(np.sum(is_pos & hit)),
        "fn": int(np.sum(is_pos & miss)),
        "fp": int(np.sum(is_neg & hit)),
        "tn": int(np.sum(is_neg & miss)),
    }
    result = _metrics_from_cm(**counts)
    result["pred_positive_rate"] = float(y_hat.mean())
    result["threshold"] = cut
    return result
def _resolve_eval_settings(server_raw: dict | object) -> tuple[dict, str]:
    """Assemble evaluation settings from a server checkpoint and the runtime config.

    If ``server_raw`` is a dict holding a dict under ``"config_snapshot"``, each
    setting is looked up in that snapshot first. It falls back to the runtime
    ``cfg`` (and then the hard-coded default) when the snapshot lookup yields
    None. Without a snapshot every setting comes from ``cfg``.

    Extra rules:
      * ``active_features`` is ``feature_cols_from_cfg(snapshot)`` (called with
        ``None`` when there is no snapshot).
      * ``hidden_size`` is overridden by ``server_raw["config"]["hidden_size"]``
        when that value is present and not None.
      * ``split_date`` is ``data.val_end`` parsed as a ``pd.Timestamp``.

    Returns:
        ``(settings, source)`` where ``source`` is ``"checkpoint_snapshot"`` or
        ``"runtime_config"``.
    """
    snapshot = None
    if isinstance(server_raw, dict):
        maybe_snapshot = server_raw.get("config_snapshot")
        if isinstance(maybe_snapshot, dict):
            snapshot = maybe_snapshot
    source = "runtime_config" if snapshot is None else "checkpoint_snapshot"

    def _setting(path: tuple[str, ...], default):
        # The snapshot value wins unless it is absent/None; then use the live config.
        if snapshot is not None:
            from_snapshot = get_nested(snapshot, path, None)
            if from_snapshot is not None:
                return from_snapshot
        return get_nested(cfg, path, default)

    settings: dict = {}
    settings["end_date_str"] = str(_setting(("data_download", "end_date"), "2026-03-10T00:00:00"))
    settings["eval_max_samples"] = max(0, int(_setting(("training", "eval_max_samples_per_sensor"), 0)))
    settings["seq_len"] = int(_setting(("model", "seq_len"), 24))
    settings["horizon"] = max(1, int(_setting(("model", "horizon"), 3)))
    settings["input_size"] = int(_setting(("model", "input_size"), 5))
    # "model.dropout" is the fallback when "model.lstm_dropout" is unset.
    generic_dropout = _setting(("model", "dropout"), 0.3)
    settings["lstm_dropout"] = float(_setting(("model", "lstm_dropout"), generic_dropout))
    settings["hidden_size"] = int(_setting(("model", "hidden_size"), 64))
    settings["head_width"] = int(_setting(("model", "server_head_width"), 64))
    settings["head_dropout"] = float(_setting(("model", "server_head_dropout"), 0.1))
    settings["num_clients"] = max(1, int(_setting(("federated", "num_clients"), 1)))
    settings["processed_dir"] = str(_setting(("data", "processed_dir"), "dataset/processed"))
    settings["active_features"] = feature_cols_from_cfg(snapshot)
    settings["prob_threshold"] = float(
        _setting(("training", "rain_probability_threshold"), rain_probability_threshold())
    )
    settings["rain_threshold"] = float(
        _setting(("training", "rain_threshold_mm"), rain_threshold_mm())
    )
    settings["target_mode"] = (
        str(_setting(("training", "target_transform"), target_transform_mode())).strip().lower()
    )
    settings["config_snapshot_used"] = snapshot is not None

    server_cfg = server_raw.get("config", {}) if isinstance(server_raw, dict) else {}
    if isinstance(server_cfg, dict) and server_cfg.get("hidden_size") is not None:
        settings["hidden_size"] = int(server_cfg.get("hidden_size"))

    val_end_str = str(_setting(("data", "val_end"), "2025-07-01"))
    settings["split_date"] = pd.Timestamp(val_end_str)
    return settings, source
def evaluate_client(
    client_id: int,
    client_path: str,
    server_model: nn.Module,
    device: torch.device,
    split_date: pd.Timestamp,
    eval_max_samples: int,
    seq_len: int,
    horizon: int,
    input_size: int,
    lstm_dropout: float,
    hidden_size: int,
    num_clients: int,
    processed_dir: str,
    active_features: list[str],
    prob_threshold: float,
    force_prob_threshold: float | None,
    prefer_checkpoint_threshold: bool,
    eval_phase: str,
    scan_thresholds: list[float] | None,
    rain_threshold: float,
    target_mode: str,
) -> dict:
    """
    Load one client encoder, pair it with ``server_model`` and evaluate it on
    the sensor files assigned to that client.

    Steps:
      1. Load the client checkpoint. It is either a dict with
         ``model_state_dict`` or a legacy bare state_dict. A
         ``config.hidden_size`` stored in the checkpoint overrides
         ``hidden_size``. The LSTM layer count is inferred from the
         ``lstm.weight_ih_l*`` keys.
      2. Choose the rain-probability threshold, in order of precedence:
         ``force_prob_threshold`` if given; the checkpoint's
         ``classification_metrics.threshold`` when
         ``prefer_checkpoint_threshold`` is set and one is stored; otherwise
         ``prob_threshold``.
      3. Select this client's parquet files via ``partition_client_files``. All
         files are used if ``client_id`` is outside ``1..num_clients``.
         Per-feature mean/std are computed over those files' TRAIN-period rows.
      4. For every index from ``collect_eval_indices_capped`` (when
         ``eval_phase == "VAL"``) or ``collect_test_indices_capped`` (otherwise),
         run client -> server on one window at a time. The predicted amount is
         zeroed when p(rain) is below the threshold.

    ``split_date`` is accepted for interface compatibility but not used here.

    Returns:
        A flat metrics dict containing:
          * regression errors (overall and on rainy targets);
          * PR-AUC, ROC-AUC and Brier score of p(rain);
          * classification metrics at the chosen threshold;
          * "op_*" metrics obtained by applying ``is_rain`` to the gated amount
            prediction;
          * per-month details and the optional threshold scan.
        Returns ``{}`` when nothing could be evaluated.
    """
    criterion = nn.MSELoss()

    # ── Load checkpoint (supports both new dict format and old bare state_dict) ──
    raw = torch.load(client_path, map_location="cpu", weights_only=True)
    client_prob_threshold: float | None = None
    if isinstance(raw, dict) and "model_state_dict" in raw:
        # New format: full checkpoint dict
        state_dict = raw["model_state_dict"]
        saved_round = raw.get("round", "?")  # best round recorded
        # Coerce to float: a stored None (or non-scalar) would otherwise break the
        # ":.4f" format below and the NaN check in the report table.
        try:
            saved_loss = float(raw.get("loss", float("nan")))
        except (TypeError, ValueError, RuntimeError):
            saved_loss = float("nan")
        ckpt_cfg = raw.get("config", {})
        if isinstance(ckpt_cfg, dict):
            hidden_size = ckpt_cfg.get("hidden_size", hidden_size)  # override if saved
        cls_metrics = raw.get("classification_metrics", {})
        if isinstance(cls_metrics, dict) and cls_metrics.get("threshold") is not None:
            try:
                client_prob_threshold = float(cls_metrics.get("threshold"))
            except (TypeError, ValueError):
                client_prob_threshold = None
        print(f"\n[Client {client_id}] \U0001f4c4 Checkpoint dict \u2014 best_round={saved_round}, train_loss={saved_loss:.4f}")
    else:
        # Old format: bare state_dict
        state_dict = raw
        saved_round = "N/A"
        saved_loss = float("nan")
        print(f"\n[Client {client_id}] \U0001f4c4 Legacy checkpoint (bare state_dict)")

    if force_prob_threshold is not None:
        effective_prob_threshold = float(force_prob_threshold)
        prob_threshold_source = "forced_cli"
    elif prefer_checkpoint_threshold and client_prob_threshold is not None:
        effective_prob_threshold = client_prob_threshold
        prob_threshold_source = "client_checkpoint"
    else:
        effective_prob_threshold = prob_threshold
        prob_threshold_source = "config"

    # nn.LSTM stores one "weight_ih_l<k>" per layer. NOTE(review): a bidirectional
    # LSTM would also have "*_reverse" keys and be double-counted here.
    num_layers = sum(1 for k in state_dict if k.startswith("lstm.weight_ih_l"))
    client_model = ClientLSTM(
        input_size=input_size,
        hidden_size=hidden_size,
        num_layers=num_layers,
        lstm_dropout=lstm_dropout,
    ).to(device)
    client_model.load_state_dict(state_dict)
    client_model.eval()
    print(f"[Client {client_id}] Architecture: hidden={hidden_size}, layers={num_layers}")

    data_dir = project_root / processed_dir
    all_files = sorted(data_dir.glob("*.parquet"))
    if not all_files:
        print(f" [ERROR] No parquet files found in {data_dir}")
        return {}

    if 1 <= client_id <= num_clients:
        client_files = partition_client_files(
            [str(p) for p in all_files],
            client_id=client_id,
            num_clients=num_clients,
        )
        client_files = [Path(p) for p in client_files]
    else:
        print(
            f" [WARNING] Client ID {client_id} is outside configured range 1..{num_clients}. "
            "Falling back to full dataset for evaluation."
        )
        client_files = all_files
    print(f"[Client {client_id}] Evaluating {len(client_files)}/{len(all_files)} sensor files")
    if not client_files:
        print(f" [ERROR] Client {client_id}: assigned 0 sensor files for evaluation.")
        return {}

    total_loss, total_batches = 0.0, 0
    all_targets, all_preds, all_probs = [], [], []
    # Per-calendar-month accumulators (keys 1..12)
    month_stats = {m: {"loss": 0.0, "batches": 0, "targets": [], "probs": []} for m in range(1, 13)}

    sensor_data_cache: dict[Path, pd.DataFrame] = {}
    for file in client_files:
        sensor_data_cache[file] = load_sensor_data(str(file), horizon=horizon)

    # Use TRAIN-period rows only — matches the normalization used during training.
    # dtype=bool keeps the mask boolean even for an empty frame.
    train_frames = [
        df[np.array([get_dataset_split(ts) == "TRAIN" for ts in df.index], dtype=bool)]
        for df in sensor_data_cache.values()
    ]
    train_non_empty = [f for f in train_frames if not f.empty]
    if not train_non_empty:
        # pd.concat([]) would raise "No objects to concatenate"; report and skip
        # this client like the other early exits.
        print(f" [ERROR] Client {client_id}: no TRAIN-period rows available for feature normalization.")
        return {}
    train_combined = pd.concat(train_non_empty)
    feat_mean = train_combined[active_features].mean().values
    feat_std = train_combined[active_features].std().values + 1e-9

    with torch.no_grad():
        for file in client_files:
            sensor_id = file.stem
            df = sensor_data_cache[file]
            if eval_phase == "VAL":
                eval_indices = collect_eval_indices_capped(
                    df,
                    target_phase="VAL",
                    eval_max_samples=eval_max_samples,
                    min_history=seq_len,
                    horizon=horizon,
                )
            else:
                eval_indices = collect_test_indices_capped(
                    df,
                    eval_max_samples=eval_max_samples,
                    min_history=seq_len,
                    horizon=horizon,
                )
            if len(eval_indices) == 0:
                print(f" [WARNING] No valid {eval_phase} indices in {sensor_id} — skipped")
                continue

            target_series = df[FUTURE_RAIN_COL]  # loop-invariant column lookup
            for idx in eval_indices:
                pos = int(idx)
                # Month bucket from the row's timestamp. Falls back to January
                # when the index is not datetime-like. The exceptions are narrow
                # so that KeyboardInterrupt etc. are no longer swallowed.
                try:
                    m_idx = df.index[pos].month
                except (AttributeError, IndexError, TypeError):
                    m_idx = 1  # Fallback
                if m_idx not in month_stats:  # e.g. NaT -> NaN month
                    m_idx = 1

                target_val = float(target_series.iloc[pos])
                # History window ends just before the target row.
                window_data = df.iloc[pos - seq_len : pos]
                features = (
                    window_data[active_features]
                    .apply(pd.to_numeric, errors="coerce")
                    .fillna(0)
                    .values
                )
                features = (features - feat_mean) / feat_std
                x = torch.tensor(features, dtype=torch.float32).unsqueeze(0).to(device)
                y = torch.tensor([[target_val]], dtype=torch.float32).to(device)

                smashed = client_model(x)
                rain_logit, rain_amount = server_model(smashed)
                rain_prob = torch.sigmoid(rain_logit).item()
                raw_pred_val = inverse_target_scalar(rain_amount.item(), mode=target_mode)
                # Gate the amount prediction with the rain classifier.
                pred_val = raw_pred_val if rain_prob >= effective_prob_threshold else 0.0
                loss = criterion(torch.tensor([[pred_val]], device=device), y).item()

                # Monthly collectors
                month_stats[m_idx]["loss"] += loss
                month_stats[m_idx]["batches"] += 1
                month_stats[m_idx]["targets"].append(target_val)
                month_stats[m_idx]["probs"].append(rain_prob)

                # Global collectors
                all_targets.append(target_val)
                all_probs.append(rain_prob)
                all_preds.append(pred_val)
                total_loss += loss
                total_batches += 1

    # Calculate Final Monthly Details
    monthly_details = []
    for m in range(1, 13):
        m_data = month_stats[m]
        if m_data["batches"] > 0:
            m_mse = m_data["loss"] / m_data["batches"]
            m_y_true = [1 if is_rain(t, threshold=rain_threshold) else 0 for t in m_data["targets"]]
            m_y_pred = [1 if p >= effective_prob_threshold else 0 for p in m_data["probs"]]
            m_acc = accuracy_score(m_y_true, m_y_pred)
            m_f1 = f1_score(m_y_true, m_y_pred, zero_division=0)
            monthly_details.append({
                "Month": f"{m:02d}",
                "MSE": round(m_mse, 6),
                "Acc": round(m_acc, 4),
                "F1": round(m_f1, 4),
                "Samples": m_data["batches"]
            })

    if total_batches == 0:
        print(f" [ERROR] Client {client_id}: No test samples evaluated.")
        return {}

    targets_arr = np.array(all_targets)
    preds_arr = np.array(all_preds)
    mse = total_loss / total_batches
    mae = float(np.mean(np.abs(targets_arr - preds_arr)))

    rain_mask = np.array([is_rain(t, threshold=rain_threshold) for t in targets_arr])
    if rain_mask.any():
        rain_mse = float(np.mean((targets_arr[rain_mask] - preds_arr[rain_mask]) ** 2))
        rain_mae = float(np.mean(np.abs(targets_arr[rain_mask] - preds_arr[rain_mask])))
    else:
        rain_mse = float("nan")
        rain_mae = float("nan")

    probs_arr = np.array(all_probs)
    y_true = np.array([1 if is_rain(t, threshold=rain_threshold) else 0 for t in targets_arr], dtype=np.int32)
    positive_rate = float(y_true.mean())
    prob_mean = float(probs_arr.mean())
    prob_std = float(probs_arr.std())
    brier = float(np.mean((probs_arr - y_true) ** 2))
    # Ranking metrics are undefined for a single class; use neutral fallbacks.
    if np.unique(y_true).size >= 2:
        roc_auc = float(roc_auc_score(y_true, probs_arr))
    else:
        roc_auc = 0.5
    if int(y_true.sum()) > 0:
        auprc = float(average_precision_score(y_true, probs_arr))
    else:
        auprc = 0.0

    cls_metrics = _class_metrics_at_threshold(
        y_true=y_true,
        probs=probs_arr,
        threshold=effective_prob_threshold,
    )
    pred_positive_rate = float(cls_metrics["pred_positive_rate"])
    cls_source = "offline_recomputed"
    selected_cls = cls_metrics

    # "Operational" metrics: classify the gated amount prediction with the mm rule.
    y_pred_op = np.array([1 if is_rain(p, threshold=rain_threshold) else 0 for p in preds_arr], dtype=np.int32)
    op_tp = int(((y_true == 1) & (y_pred_op == 1)).sum())
    op_fn = int(((y_true == 1) & (y_pred_op == 0)).sum())
    op_fp = int(((y_true == 0) & (y_pred_op == 1)).sum())
    op_tn = int(((y_true == 0) & (y_pred_op == 0)).sum())
    op_metrics = _metrics_from_cm(op_tp, op_fn, op_fp, op_tn)

    threshold_scan: list[dict[str, float | int]] = []
    if scan_thresholds:
        for thr in scan_thresholds:
            threshold_scan.append(
                _class_metrics_at_threshold(
                    y_true=y_true,
                    probs=probs_arr,
                    threshold=float(thr),
                )
            )

    return {
        "client_id": client_id,
        "best_round": saved_round,
        "train_loss": saved_loss,
        "samples": total_batches,
        "mse": mse,
        "mae": mae,
        "rain_mse": rain_mse,
        "rain_mae": rain_mae,
        "auprc": auprc,
        "roc_auc": roc_auc,
        "brier": brier,
        "positive_rate": positive_rate,
        "pred_positive_rate": pred_positive_rate,
        "prob_mean": prob_mean,
        "prob_std": prob_std,
        "accuracy": selected_cls["accuracy"],
        "recall": selected_cls["recall"],
        "precision": selected_cls["precision"],
        "f1": selected_cls["f1"],
        "tp": selected_cls["tp"],
        "fn": selected_cls["fn"],
        "fp": selected_cls["fp"],
        "tn": selected_cls["tn"],
        "cls_metric_source": cls_source,
        "offline_accuracy": cls_metrics["accuracy"],
        "offline_recall": cls_metrics["recall"],
        "offline_precision": cls_metrics["precision"],
        "offline_f1": cls_metrics["f1"],
        "offline_tp": cls_metrics["tp"],
        "offline_fn": cls_metrics["fn"],
        "offline_fp": cls_metrics["fp"],
        "offline_tn": cls_metrics["tn"],
        "op_accuracy": op_metrics["accuracy"],
        "op_recall": op_metrics["recall"],
        "op_precision": op_metrics["precision"],
        "op_f1": op_metrics["f1"],
        "op_tp": op_metrics["tp"],
        "op_fn": op_metrics["fn"],
        "op_fp": op_metrics["fp"],
        "op_tn": op_metrics["tn"],
        "prob_threshold": effective_prob_threshold,
        "prob_threshold_source": prob_threshold_source,
        "rain_threshold_mm": rain_threshold,
        "target_transform": target_mode,
        "monthly_details": monthly_details,
        "threshold_scan": threshold_scan,
    }
def evaluate():
parser = argparse.ArgumentParser()
parser.add_argument(
"--device", type=str, default="cpu",
help="Device to run on: 'cpu' or 'mps'"
)
parser.add_argument(
"--session",
type=str,
default=None,
help="Evaluate a specific session ID under bestweights/ and results/ (recommended).",
)
parser.add_argument(
"--scenario",
type=str,
default=None,
help="Restrict checkpoint search to a specific scenario subdirectory (e.g. '07'). "
"Required when multiple scenarios share the same session (matrix runs).",
)
parser.add_argument(
"--eval-phase",
type=str,
default="test",
choices=["test", "val"],
help="Evaluation split to use: test or val (threshold tuning should use val).",
)
parser.add_argument(
"--round",
type=int,
default=None,
help="Evaluate a specific round. Requires a strict paired periodic checkpoint set.",
)
parser.add_argument(
"--allow-cross-session-fallback",
action="store_true",
help="Allow searching client checkpoints across all sessions if current session lacks clients.",
)
parser.add_argument(
"--allow-latest-best-fallback",
action="store_true",
help="Allow non-strict latest/best server-client pairing (debug only; less reliable).",
)
parser.add_argument(
"--force-prob-threshold",
type=float,
default=None,
help="Force one shared probability threshold in [0,1] for all clients (ignores checkpoint threshold).",
)
parser.add_argument(
"--prefer-checkpoint-threshold",
action="store_true",
help="Prefer per-client checkpoint threshold when available (default uses config threshold).",
)
parser.add_argument(
"--scan-thresholds",
type=str,
default="",
help="Threshold grid for scan (e.g. '0.1,0.2,0.3' or '0.1:0.9:0.05').",
)
parser.add_argument(
"--report-tag",
type=str,
default="",
help="Optional suffix for output report files (e.g. fixedthr034).",
)
parser.add_argument(
"--eval-max-samples",
type=int,
default=None,
help="Override eval_max_samples_per_sensor from config (0 = full dataset).",
)
args = parser.parse_args()
forced_prob_threshold = None
if args.force_prob_threshold is not None:
forced_prob_threshold = float(args.force_prob_threshold)
if not (0.0 <= forced_prob_threshold <= 1.0):
raise ValueError("--force-prob-threshold must be in [0, 1].")
report_tag = _normalize_report_tag(args.report_tag)
if args.report_tag and not report_tag:
raise ValueError("Invalid --report-tag: must contain at least one alphanumeric character.")
eval_phase = str(args.eval_phase).strip().upper()
scan_thresholds = _parse_threshold_list(args.scan_thresholds)
print("\n" + "=" * 55)
print(f" 🚀 AUTO EVALUATION — ALL CLIENTS ({eval_phase} DATA) ")
print("=" * 55)
device = torch.device(args.device)
# ── Step 1: Select session and checkpoint pairing mode ────────────
selected_session = args.session or _find_latest_session_id()
pairing_mode = "strict_periodic"
num_clients_hint = max(1, int(get_nested(cfg, ("federated", "num_clients"), 1)))
paired = find_periodic_pair(
session_id=selected_session,
num_clients=num_clients_hint,
target_round=args.round,
scenario_id=args.scenario or None,
)
if paired is not None:
paired_round, server_path, client_map = paired
pairing_mode = f"periodic_round_{paired_round}"
print(f"[INFO] Using strict paired periodic checkpoints at round {paired_round}")
else:
round_hint = f", round={args.round}" if args.round is not None else ""
if not args.allow_latest_best_fallback:
raise FileNotFoundError(
f"No strict paired periodic checkpoints found for session={selected_session}{round_hint}. "
"Use --allow-latest-best-fallback for debug-only non-strict pairing."
)
if args.round is not None:
raise FileNotFoundError(
f"No strict paired periodic checkpoints found for session={selected_session}, round={args.round}."
)
server_path = find_best_server(session_id=selected_session)
client_map = find_matching_clients(
server_path,
allow_cross_session_fallback=args.allow_cross_session_fallback,
)
pairing_mode = "latest_best"
print("[WARN] Falling back to latest/best checkpoint selection (non-strict pairing).")
# ── Step 2: Summary of selected client checkpoints ───────────────
print(f"\n[INFO] Found {len(client_map)} client model(s) to evaluate: {sorted(client_map.keys())}")
# ── Step 3: Load server model once (shared across all clients) ───
server_raw = torch.load(server_path, map_location=device, weights_only=True)
if isinstance(server_raw, dict) and "model_state_dict" in server_raw:
server_state = server_raw["model_state_dict"]
server_round = server_raw.get("round", "?")
print(f"[INFO] Server checkpoint dict \u2014 round={server_round}")
else:
server_state = server_raw
server_round = "N/A"
print("[INFO] Server: legacy checkpoint (bare state_dict)")
eval_settings, eval_cfg_source = _resolve_eval_settings(server_raw)
split_date = eval_settings["split_date"]
seq_len = eval_settings["seq_len"]
horizon = eval_settings["horizon"]
input_size = eval_settings["input_size"]
lstm_dropout = eval_settings["lstm_dropout"]
hidden_size = eval_settings["hidden_size"]
head_width = eval_settings["head_width"]
head_dropout = eval_settings["head_dropout"]
eval_max_samples = eval_settings["eval_max_samples"]
if args.eval_max_samples is not None:
eval_max_samples = max(0, args.eval_max_samples)
num_clients = eval_settings["num_clients"]
processed_dir = eval_settings["processed_dir"]
active_features = eval_settings["active_features"]
prob_threshold = eval_settings["prob_threshold"]
rain_threshold = eval_settings["rain_threshold"]
target_mode = eval_settings["target_mode"]
if pairing_mode.startswith("periodic_round_"):
expected_clients = list(range(1, num_clients + 1))
if sorted(client_map.keys()) != expected_clients:
raise FileNotFoundError(
f"Strict periodic pairing requires all clients {expected_clients}, "
f"but found {sorted(client_map.keys())} in {pairing_mode}."
)
threshold_text = (
f"p(rain)>={forced_prob_threshold:.2f} (forced)"
if forced_prob_threshold is not None
else f"p(rain)>={prob_threshold:.2f}"
)
print(
f"[INFO] Eval config source: {eval_cfg_source} | split_date={split_date} | "
f"seq_len={seq_len} horizon={horizon} | per_sensor_cap={eval_max_samples if eval_max_samples > 0 else 'FULL'} | "
f"{threshold_text} | rain_mm>{rain_threshold:.2f} | target_transform={target_mode} | phase={eval_phase} | "
f"threshold_source={'checkpoint_preferred' if args.prefer_checkpoint_threshold else 'config'} | device={device}"
)
if forced_prob_threshold is not None:
print(
f"[INFO] Forced probability threshold active: p(rain)>={forced_prob_threshold:.2f} "
"(checkpoint thresholds ignored)"
)
server_model = ServerHead(
hidden_size=hidden_size,
output_size=1,
head_width=head_width,
dropout=head_dropout,
).to(device)
server_model.load_state_dict(server_state)
server_model.eval()
print(f"[INFO] Server model loaded and set to eval mode.")
# ── Step 4: Evaluate each client ─────────────────────────────────
all_results = []
for cid, cpath in sorted(client_map.items()):
print(f"\n{'─'*55}")
print(f" Evaluating Client {cid} ...")
print(f"{'─'*55}")
result = evaluate_client(
client_id=cid,
client_path=cpath,
server_model=server_model,
device=device,
split_date=split_date,
eval_max_samples=eval_max_samples,
seq_len=seq_len,
horizon=horizon,
input_size=input_size,
lstm_dropout=lstm_dropout,
hidden_size=hidden_size,
num_clients=num_clients,
processed_dir=processed_dir,
active_features=active_features,
prob_threshold=prob_threshold,
force_prob_threshold=forced_prob_threshold,
prefer_checkpoint_threshold=args.prefer_checkpoint_threshold,
eval_phase=eval_phase,
scan_thresholds=scan_thresholds,
rain_threshold=rain_threshold,
target_mode=target_mode,
)
if result:
all_results.append(result)
print(f" ✅ Client {cid} | Samples={result['samples']:,} | "
f"MSE={result['mse']:.4f} | MAE={result['mae']:.4f} mm | "
f"ClsF1={result['f1']:.3f} | PR-AUC={result['auprc']:.3f} | "
f"ROC-AUC={result['roc_auc']:.3f} | Brier={result['brier']:.4f}")
# ── Step 5: Print combined summary ─────────────────────────────────
if not all_results:
print("\n[ERROR] No clients were successfully evaluated.")
return
threshold_scan_summary: list[dict[str, float | int]] = []
recommended_threshold: float | None = None
if scan_thresholds:
for thr in scan_thresholds:
tp = fn = fp = tn = 0
total = 0
for r in all_results:
total += int(r["samples"])
scan_rows = r.get("threshold_scan", [])
found = next((m for m in scan_rows if abs(float(m["threshold"]) - float(thr)) <= 1e-12), None)
if found is None:
continue
tp += int(found["tp"])
fn += int(found["fn"])
fp += int(found["fp"])
tn += int(found["tn"])
metrics = _metrics_from_cm(tp, fn, fp, tn)
threshold_scan_summary.append(
{
"threshold": float(thr),
"samples": int(total),
"accuracy": float(metrics["accuracy"]),
"recall": float(metrics["recall"]),
"precision": float(metrics["precision"]),
"f1": float(metrics["f1"]),
"tp": int(tp),
"fn": int(fn),
"fp": int(fp),
"tn": int(tn),
"pred_positive_rate": float((tp + fp) / max(1, tp + fn + fp + tn)),
}
)
if threshold_scan_summary:
best = max(
threshold_scan_summary,
key=lambda x: (
float(x["f1"]),
float(x["precision"]),
-abs(float(x["threshold"]) - float(prob_threshold)),
),
)
recommended_threshold = float(best["threshold"])
W = 128 # table width
print("\n" + "=" * W)
print(f"\U0001f3c6 FINAL EVALUATION REPORT ({eval_phase} SET)")
print(f" Session : {selected_session}")
print(f" Pairing mode : {pairing_mode}")
print(f" Server checkpoint : {Path(server_path).name} (round={server_round})")
print(f" Eval cfg source : {eval_cfg_source}")
if forced_prob_threshold is not None:
print(f" Forced threshold : {forced_prob_threshold:.2f}")
if recommended_threshold is not None:
print(f" Scan best threshold (by F1): {recommended_threshold:.2f}")
print("=" * W)
hdr = (f" {'Client':<8} {'BestRound':>10} {'TrainLoss':>10}"
f" {'Samples':>8} {'MSE':>8} {'MAE':>8} {'ClsAcc':>8} {'ClsF1':>8}"
f" {'PR-AUC':>8} {'ROC-AUC':>8} {'Brier':>8} {'OpF1':>8}")
sep = (f" {'──────':<8} {'─────────':>10} {'─────────':>10}"
f" {'───────':>8} {'───────':>8} {'───────':>8} {'──────':>8} {'──────':>8}"
f" {'──────':>8} {'───────':>8} {'──────':>8} {'──────':>8}")
print(hdr)
print(sep)
for r in all_results:
br = str(r['best_round'])
tl = f"{r['train_loss']:.4f}" if not np.isnan(r['train_loss']) else "N/A"
print(f" {r['client_id']:<8} {br:>10} {tl:>10}"
f" {r['samples']:>8,} {r['mse']:>8.4f} {r['mae']:>8.4f} {r['accuracy']:>7.2f}% {r['f1']:>8.4f}"
f" {r['auprc']:>8.4f} {r['roc_auc']:>8.4f} {r['brier']:>8.4f} {r['op_f1']:>8.4f}")
if len(all_results) > 1:
avg_mse = np.mean([r["mse"] for r in all_results])
avg_mae = np.mean([r["mae"] for r in all_results])
avg_acc = np.mean([r["accuracy"] for r in all_results])
avg_f1 = np.mean([r["f1"] for r in all_results])
avg_auprc = np.mean([r["auprc"] for r in all_results])
avg_roc_auc = np.mean([r["roc_auc"] for r in all_results])
avg_brier = np.mean([r["brier"] for r in all_results])
avg_op_f1 = np.mean([r["op_f1"] for r in all_results])
tot = sum( r["samples"] for r in all_results)
print(sep)
print(f" {'AVERAGE':<8} {'':>10} {'':>10}"
f" {tot:>8,} {avg_mse:>8.4f} {avg_mae:>8.4f} {avg_acc:>7.2f}% {avg_f1:>8.4f}"
f" {avg_auprc:>8.4f} {avg_roc_auc:>8.4f} {avg_brier:>8.4f} {avg_op_f1:>8.4f}")
print("=" * W + "\n")
if threshold_scan_summary:
print("Threshold scan summary (weighted/global confusion):")
print(" threshold f1 precision recall pred_pos")
for row in threshold_scan_summary:
print(
f" {row['threshold']:>8.2f} {row['f1']:>6.4f} {row['precision']:>9.4f} "
f"{row['recall']:>7.4f} {row['pred_positive_rate']:>8.4f}"
)
bw_session_dir = project_root / "bestweights" / selected_session
try:
rel_path = Path(server_path).relative_to(bw_session_dir)
scenario_part = rel_path.parts[0] if len(rel_path.parts) > 1 and rel_path.parts[0] != "periodic" else ""
except ValueError:
scenario_part = ""
save_dir = project_root / "results" / selected_session
save_dir.mkdir(parents=True, exist_ok=True)
results_dir = save_dir / scenario_part if scenario_part else save_dir
if args.scenario:
report_stem = f"{args.scenario}_eval_report"
elif scenario_part:
report_stem = f"{scenario_part}_eval_report"
else:
safe_session_str = str(selected_session).replace("/", "_").replace("\\", "_")
report_stem = f"{safe_session_str}_eval_report"
if report_tag:
report_stem = f"{report_stem}_{report_tag}"
import re
training_meta_files = list(results_dir.glob("*_meta.json"))
client_telemetry = {}
for mf in training_meta_files:
if "progress" in mf.name:
continue
try:
m = re.search(r'client(\d+)', mf.name)
if m:
with open(mf, "r") as f:
client_telemetry[int(m.group(1))] = json.load(f)
except Exception:
pass
for r in all_results:
cid = r["client_id"]
tm = client_telemetry.get(cid, {})
r["cpu_percent"] = round(float(tm["avg_cpu_percent"]), 1) if tm.get("avg_cpu_percent") is not None else np.nan
r["mem_percent"] = round(float(tm["avg_mem_percent"]), 1) if tm.get("avg_mem_percent") is not None else np.nan
r["runtime_s"] = round(float(tm["total_runtime_s"]), 1) if tm.get("total_runtime_s") is not None else np.nan
msb = tm.get("model_size_bytes")
r["model_size_mb"] = round(msb / (1024 * 1024), 2) if msb is not None else np.nan
apb = tm.get("avg_payload_bytes")
r["payload_bytes"] = round(float(apb), 1) if apb is not None else np.nan
alm = tm.get("avg_latency_ms")
r["latency_ms"] = round(float(alm), 1) if alm is not None else np.nan
rt = tm.get("total_runtime_s")
r["throughput_sps"] = round(r.get("samples", 0) / rt, 1) if rt and float(rt) > 0 else np.nan
r["net_sent_mb"] = tm.get("net_sent_mb", np.nan)
r["net_recv_mb"] = tm.get("net_recv_mb", np.nan)
r["mem_peak_mb"] = tm.get("mem_peak_mb", np.nan)
r["sync_bytes_sent_mb"] = tm.get("sync_bytes_sent_mb", np.nan)
r["sync_bytes_recv_mb"] = tm.get("sync_bytes_recv_mb", np.nan)
# 數據量吞吐量: (平均每步傳輸量 × 總步數) / 訓練時間 (KB/s)
num_records = tm.get("num_records")
if apb and num_records and rt and float(rt) > 0:
r["data_throughput_kbps"] = round(
(float(apb) * int(num_records) / 1024) / float(rt), 2
)
else:
r["data_throughput_kbps"] = np.nan
# 總傳輸量 (MB):全訓練過程所有 Forward Pass 的 Payload 合計
# 對應 Results Table 的 Traff (payload) 欄位
if apb and num_records:
r["total_payload_mb"] = round(float(apb) * int(num_records) / (1024 * 1024), 4)
else:
r["total_payload_mb"] = np.nan
# 🚀 CALCULATE AGGREGATED TELEMETRY & WEIGHTED METRICS (Move up for CSV inclusion)
total_samples = float(sum(r["samples"] for r in all_results))
weighted = {
"mse": float(sum(r["mse"] * r["samples"] for r in all_results) / total_samples),
"mae": float(sum(r["mae"] * r["samples"] for r in all_results) / total_samples),
"accuracy": float(sum(r["accuracy"] * r["samples"] for r in all_results) / total_samples),
"f1": float(sum(r["f1"] * r["samples"] for r in all_results) / total_samples),
"auprc": float(sum(r["auprc"] * r["samples"] for r in all_results) / total_samples),
"roc_auc": float(sum(r["roc_auc"] * r["samples"] for r in all_results) / total_samples),
"brier": float(sum(r["brier"] * r["samples"] for r in all_results) / total_samples),
"positive_rate": float(sum(r["positive_rate"] * r["samples"] for r in all_results) / total_samples),
"pred_positive_rate": float(sum(r["pred_positive_rate"] * r["samples"] for r in all_results) / total_samples),
"tp": int(sum(r["tp"] for r in all_results)),
"fn": int(sum(r["fn"] for r in all_results)),
"fp": int(sum(r["fp"] for r in all_results)),
"tn": int(sum(r["tn"] for r in all_results)),
"rain_mse": float(np.nanmean([r.get("rain_mse", float("nan")) for r in all_results])),
"rain_mae": float(np.nanmean([r.get("rain_mae", float("nan")) for r in all_results])),
}
# Fetch and aggregate hardware telemetry
sys_telemetry = {}
if training_meta_files:
try:
runtimes, cpus, mems = [], [], []
net_sent, net_recv, mem_peaks = [], [], []
sync_sent, sync_recv = [], []
for mf in training_meta_files:
if "progress" in mf.name: continue
with open(mf, "r") as f:
mt = json.load(f)
if mt.get("total_runtime_s"): runtimes.append(float(mt["total_runtime_s"]))
if mt.get("avg_cpu_percent"): cpus.append(float(mt["avg_cpu_percent"]))
if mt.get("avg_mem_percent"): mems.append(float(mt["avg_mem_percent"]))
if mt.get("net_sent_mb"): net_sent.append(float(mt["net_sent_mb"]))
if mt.get("net_recv_mb"): net_recv.append(float(mt["net_recv_mb"]))
if mt.get("mem_peak_mb"): mem_peaks.append(float(mt["mem_peak_mb"]))
if mt.get("sync_bytes_sent_mb"): sync_sent.append(float(mt["sync_bytes_sent_mb"]))
if mt.get("sync_bytes_recv_mb"): sync_recv.append(float(mt["sync_bytes_recv_mb"]))
if runtimes: sys_telemetry["avg_runtime_s"] = round(sum(runtimes)/len(runtimes), 2)
if cpus: sys_telemetry["avg_cpu_percent"] = round(sum(cpus)/len(cpus), 1)
if mems: sys_telemetry["avg_mem_percent"] = round(sum(mems)/len(mems), 1)
if net_sent: sys_telemetry["total_net_sent_mb"] = round(sum(net_sent), 2)
if net_recv: sys_telemetry["total_net_recv_mb"] = round(sum(net_recv), 2)
if mem_peaks: sys_telemetry["avg_mem_peak_mb"] = round(sum(mem_peaks)/len(mem_peaks), 2)
if sync_sent: sys_telemetry["total_sync_sent_mb"] = round(sum(sync_sent), 4)
if sync_recv: sys_telemetry["total_sync_recv_mb"] = round(sum(sync_recv), 4)
if sys_telemetry.get("avg_runtime_s", 0) > 0:
sys_telemetry["throughput_sps"] = round(total_samples / sys_telemetry["avg_runtime_s"], 2)
# 系統級數據量吞吐量:所有 Client 傳輸 Bytes 總和 / 平均訓練時間
total_payload_bytes = sum(
float(r.get("payload_bytes", 0) or 0)
* float(client_telemetry.get(cid, {}).get("num_records", 0) or 0)
for cid, r in zip(
[res.get("client_id") for res in all_results],
all_results
)
)
if total_payload_bytes > 0:
sys_telemetry["data_throughput_kbps"] = round(
(total_payload_bytes / 1024) / sys_telemetry["avg_runtime_s"], 2
)
except Exception:
pass
report_csv = save_dir / f"{report_stem}.csv"
# 重新整理每一行的順序,確保 monthly_details 在最後
csv_rows = []
for r in all_results:
# 先抓取所有非 monthly_details 且非 threshold_scan 的欄位
row = {k: v for k, v in r.items() if k not in ["threshold_scan", "monthly_details"]}
# 最後補上 monthly_details
row["monthly_details"] = r.get("monthly_details")
csv_rows.append(row)
# 準備 Summary 行
summary_row = {
"client_id": "SUMMARY",
"samples": int(total_samples),
"mse": weighted["mse"],
"mae": weighted["mae"],
"rain_mse": weighted["rain_mse"],
"rain_mae": weighted["rain_mae"],
"accuracy": weighted["accuracy"],
"f1": weighted["f1"],
"auprc": weighted["auprc"],
"roc_auc": weighted["roc_auc"],
"brier": weighted["brier"],
"tp": weighted["tp"],
"fn": weighted["fn"],
"fp": weighted["fp"],
"tn": weighted["tn"],
"cpu_percent": sys_telemetry.get("avg_cpu_percent", np.nan),
"mem_percent": sys_telemetry.get("avg_mem_percent", np.nan),
"runtime_s": sys_telemetry.get("avg_runtime_s", np.nan),
"throughput_sps": sys_telemetry.get("throughput_sps", np.nan),
"data_throughput_kbps": sys_telemetry.get("data_throughput_kbps", np.nan),
"total_payload_mb": round(sum(r.get("total_payload_mb", 0) or 0 for r in all_results), 4),
"net_sent_mb": sys_telemetry.get("total_net_sent_mb", np.nan),
"net_recv_mb": sys_telemetry.get("total_net_recv_mb", np.nan),
"mem_peak_mb": sys_telemetry.get("avg_mem_peak_mb", np.nan),
"sync_bytes_sent_mb": sys_telemetry.get("total_sync_sent_mb", np.nan),
"sync_bytes_recv_mb": sys_telemetry.get("total_sync_recv_mb", np.nan),
"monthly_details": "" # Summary 行留空
}
csv_rows.append(summary_row)
# 直接儲存 (Pandas 會遵循第一個 dict 的 keys 順序)
pd.DataFrame(csv_rows).to_csv(report_csv, index=False)
report_json = save_dir / f"{report_stem}.json"
summary = {
"session": selected_session,
"pairing_mode": pairing_mode,
"server_checkpoint": Path(server_path).name,
"server_round": server_round,
"report_tag": report_tag,
"strict_pairing_default": True,
"allow_latest_best_fallback": bool(args.allow_latest_best_fallback),
"prefer_checkpoint_threshold": bool(args.prefer_checkpoint_threshold),
"forced_prob_threshold": forced_prob_threshold,
"device": str(device),
"eval_phase": eval_phase,
"split_date": str(split_date),
"eval_max_samples_per_sensor": eval_max_samples,
"eval_config_source": eval_cfg_source,
"eval_config": {
"num_clients": num_clients,
"processed_dir": processed_dir,
"feature_cols": active_features,
"seq_len": seq_len,
"input_size": input_size,
"hidden_size": hidden_size,
"server_head_width": head_width,
"server_head_dropout": head_dropout,
"rain_probability_threshold": (
forced_prob_threshold if forced_prob_threshold is not None else prob_threshold
),
"rain_probability_threshold_source": (
"forced_cli"
if forced_prob_threshold is not None
else ("client_checkpoint" if args.prefer_checkpoint_threshold else "config")
),
"rain_threshold_mm": rain_threshold,
"target_transform": target_mode,
},
"num_clients_evaluated": len(all_results),
"weighted_overall": weighted,
"threshold_scan_summary": threshold_scan_summary,
"recommended_threshold_by_f1": recommended_threshold,
"clients": all_results,
}
# 🚀 Inject Training Phase Metadata into Summary
if sys_telemetry:
summary.update({
"training_total_runtime_s": sys_telemetry.get("avg_runtime_s"),
"training_avg_cpu_percent": sys_telemetry.get("avg_cpu_percent"),
"training_avg_mem_percent": sys_telemetry.get("avg_mem_percent"),
"training_throughput_samples_s": sys_telemetry.get("throughput_sps"),
})
print("\n" + "=" * W)
print(f"🌡️ HARDWARE TELEMETRY & EFFICIENCY")
print("=" * W)
print(f" Total Runtime : {sys_telemetry.get('avg_runtime_s', 'N/A')} s")
print(f" System Throughput : {sys_telemetry.get('throughput_sps', 'N/A')} samples/s")
print(f" Average CPU Usage : {sys_telemetry.get('avg_cpu_percent', 'N/A')} %")
print(f" Average Memory : {sys_telemetry.get('avg_mem_percent', 'N/A')} %")
print("=" * W + "\n")
else:
print(f"[WARN] No training metadata found in {results_dir}")
with open(report_json, "w") as f:
json.dump(summary, f, indent=2)
print(f"[INFO] Saved evaluation CSV : {report_csv}")
print(f"[INFO] Saved evaluation JSON: {report_json}")
if __name__ == "__main__":
    # Script entry point: run the full evaluation pipeline only when this file
    # is executed directly, never as a side effect of importing it.
    evaluate()