from pathlib import Path
import numpy as np
import pandas as pd
import torch
from src.client.data_pipeline import (
FUTURE_RAIN_COL,
collect_eval_indices_capped,
load_sensor_data,
resolve_split_pos,
sample_index,
)
from src.client.forward_step import run_forward_step
from src.client.scheduler_state import SchedulerState
from src.shared.targets import is_rain, rain_probability_threshold
def _binary_metrics(y_true: np.ndarray, y_pred: np.ndarray) -> dict[str, float | int]:
tp = int(((y_true == 1) & (y_pred == 1)).sum())
fn = int(((y_true == 1) & (y_pred == 0)).sum())
fp = int(((y_true == 0) & (y_pred == 1)).sum())
tn = int(((y_true == 0) & (y_pred == 0)).sum())
recall = float(tp / (tp + fn)) if (tp + fn) > 0 else 0.0
precision = float(tp / (tp + fp)) if (tp + fp) > 0 else 0.0
f1 = float(2.0 * recall * precision / (recall + precision)) if (recall + precision) > 0 else 0.0
total = tp + fn + fp + tn
accuracy = float((tp + tn) / total) if total > 0 else 0.0
return {
"tp": tp,
"fn": fn,
"fp": fp,
"tn": tn,
"recall": recall,
"precision": precision,
"f1": f1,
"accuracy": accuracy,
}
def _select_best_threshold(
y_true: np.ndarray,
probs: np.ndarray,
*,
default_threshold: float,
) -> tuple[float, dict[str, float | int], dict[str, float | int]]:
default_pred = (probs >= default_threshold).astype(np.int32)
default_metrics = _binary_metrics(y_true, default_pred)
best_threshold = float(default_threshold)
best_metrics = default_metrics
best_score = (
float(default_metrics["f1"]),
float(default_metrics["precision"]),
-abs(float(default_threshold) - float(default_threshold)),
)
for threshold in np.linspace(0.0, 1.0, 201):
y_pred = (probs >= threshold).astype(np.int32)
metrics = _binary_metrics(y_true, y_pred)
score = (
float(metrics["f1"]),
float(metrics["precision"]),
-abs(float(threshold) - float(default_threshold)),
)
if score > best_score:
best_score = score
best_threshold = float(threshold)
best_metrics = metrics
return best_threshold, best_metrics, default_metrics
def preload_sensor_data(
client_id: int,
client_files: list[str],
*,
horizon: int | None = None,
) -> dict[str, pd.DataFrame]:
print(f"[CLIENT {client_id}] Pre-loading sensor data into memory...")
sensor_data_cache: dict[str, pd.DataFrame] = {}
for file_path in client_files:
sensor_id = Path(file_path).stem
try:
sensor_data_cache[file_path] = load_sensor_data(file_path, horizon=horizon)
except Exception as e:
print(f"[CLIENT {client_id} ERROR] Failed to load {sensor_id}: {e}")
if not sensor_data_cache:
raise RuntimeError(
f"Client {client_id} could not load any sensor data from {len(client_files)} assigned files. "
"Check dataset contents and preprocessing outputs."
)
return sensor_data_cache
def compute_feature_stats(
*,
client_id: int,
sensor_data_cache: dict[str, pd.DataFrame],
feature_cols: list[str],
) -> tuple[np.ndarray, np.ndarray]:
"""Calculate normalisation statistics using only TRAIN-phase rows (2023-01-01 to 2024-12-31)."""
from src.client.data_pipeline import get_dataset_split
print(f"[CLIENT {client_id}] Calculating feature statistics for normalisation (TRAIN phase only)...")
train_frames: list[pd.DataFrame] = []
train_rows = 0
for df in sensor_data_cache.values():
if not isinstance(df.index, pd.DatetimeIndex):
raise TypeError("compute_feature_stats requires a DatetimeIndex.")
timestamps = pd.to_datetime(df.index)
train_mask = np.array([get_dataset_split(ts) == "TRAIN" for ts in timestamps])
train_df = df[train_mask]
if not train_df.empty:
train_frames.append(train_df)
train_rows += int(len(train_df))
if not train_frames:
raise RuntimeError(
f"Client {client_id} has 0 TRAIN-phase rows for feature normalisation. "
"Check dataset timestamps and monthly_cycle configuration."
)
all_combined = pd.concat(train_frames)
feat_mean = all_combined[feature_cols].mean().values
feat_std = all_combined[feature_cols].std().values + 1e-9
print(f"[CLIENT {client_id}] Normalisation stats rows={train_rows}")
return feat_mean, feat_std
def build_eval_index_cache(
*,
client_id: int,
sensor_data_cache: dict[str, pd.DataFrame],
target_phase: str,
eval_max_samples: int,
seq_len: int,
label: str,
horizon: int = 3,
) -> tuple[dict[str, np.ndarray], int, int]:
eval_index_cache: dict[str, np.ndarray] = {}
total_eval_samples = 0
total_eval_positive = 0
for file_path, df in sensor_data_cache.items():
eval_indices = collect_eval_indices_capped(
df,
target_phase=target_phase,
eval_max_samples=eval_max_samples,
min_history=seq_len,
horizon=horizon,
)
eval_index_cache[file_path] = eval_indices
total_eval_samples += int(len(eval_indices))
if len(eval_indices) > 0:
total_eval_positive += int(df[FUTURE_RAIN_COL].iloc[eval_indices].apply(is_rain).sum())
print(
f"[CLIENT {client_id}] Fixed {label.lower()} set prepared: "
f"samples={total_eval_samples} positives={total_eval_positive} "
f"(per_sensor_cap={eval_max_samples if eval_max_samples > 0 else 'FULL'})"
)
return eval_index_cache, total_eval_samples, total_eval_positive
def run_train_epoch(
*,
stub,
client_id: int,
client_model,
optimizer,
client_files: list[str],
sensor_data_cache: dict[str, pd.DataFrame],
train_state: SchedulerState,
feature_cols: list[str],
feat_stats: tuple[np.ndarray, np.ndarray],
device: torch.device,
local_steps: int,
rain_sample_ratio: float,
seq_len: int,
epoch: int,
experimental_logs: list[dict],
epoch_logs: list[dict],
horizon: int = 3,
rain_threshold: float | None = None,
) -> int:
epoch_train_steps = 0
for file_path in client_files:
sensor_id = Path(file_path).stem
try:
df = sensor_data_cache.get(file_path)
if df is None:
continue
for _ in range(local_steps):
optimizer.zero_grad()
result = sample_index(
df,
None,
is_training=True,
rain_sample_ratio=rain_sample_ratio,
min_history=seq_len,
horizon=horizon,
rain_threshold=rain_threshold,
)
if result is None:
continue
target_idx, mode = result
target_value = float(df[FUTURE_RAIN_COL].iloc[target_idx])
log_entry = run_forward_step(
stub,
client_id,
client_model,
optimizer,
df,
target_idx,
target_value,
mode,
sensor_id,
train_state.compression_mode,
feature_cols,
feat_stats,
device,
is_training=True,
last_latency_ms=train_state.last_latency_ms,
seq_len=seq_len,
)
train_state.update(log_entry)
epoch_train_steps += 1
epoch_record = {"Epoch": epoch + 1, "Status": "TRAIN", "Sensor": sensor_id, **log_entry}
experimental_logs.append(epoch_record)
epoch_logs.append(epoch_record)
except Exception as e:
print(f"[CLIENT {client_id} ERROR] {sensor_id}: {e}")
return epoch_train_steps
def run_eval_epoch(
*,
stub,
client_id: int,
client_model,
optimizer,
client_files: list[str],
sensor_data_cache: dict[str, pd.DataFrame],
eval_index_cache: dict[str, np.ndarray],
eval_state: SchedulerState,
feature_cols: list[str],
feat_stats: tuple[np.ndarray, np.ndarray],
device: torch.device,
seq_len: int,
epoch: int,
experimental_logs: list[dict],
epoch_logs: list[dict],
phase_label: str = "VAL",
) -> tuple[list[float], dict[str, float | int]]:
epoch_eval_losses: list[float] = []
eval_targets: list[float] = []
eval_probs: list[float] = []
default_threshold = float(rain_probability_threshold())
with torch.no_grad():
for file_path in client_files:
sensor_id = Path(file_path).stem
try:
df = sensor_data_cache.get(file_path)
if df is None:
continue
eval_indices = eval_index_cache.get(file_path)
if eval_indices is None or len(eval_indices) == 0:
continue
for target_idx in eval_indices:
target_value = float(df[FUTURE_RAIN_COL].iloc[target_idx])
log_entry = run_forward_step(
stub,
client_id,
client_model,
optimizer,
df,
int(target_idx),
target_value,
f"FIXED_{phase_label}",
sensor_id,
eval_state.compression_mode,
feature_cols,
feat_stats,
device,
is_training=False,
last_latency_ms=eval_state.last_latency_ms,
seq_len=seq_len,
)
eval_state.update(log_entry)
epoch_record = {"Epoch": epoch + 1, "Status": phase_label, "Sensor": sensor_id, **log_entry}
experimental_logs.append(epoch_record)
epoch_logs.append(epoch_record)
if log_entry["Loss"] is not None:
epoch_eval_losses.append(float(log_entry["Loss"]))
eval_targets.append(float(log_entry["Target"]))
eval_probs.append(float(log_entry.get("RainProbability", 0.0)))
except Exception as e:
print(f"[CLIENT {client_id} WARN] Eval failed on {sensor_id}: {e}")
if not eval_targets:
empty_metrics = {
"tp": 0,
"fn": 0,
"fp": 0,
"tn": 0,
"recall": 0.0,
"precision": 0.0,
"f1": 0.0,
"accuracy": 0.0,
"selected_threshold": default_threshold,
"default_threshold": default_threshold,
"default_recall": 0.0,
"default_precision": 0.0,
"default_f1": 0.0,
"default_accuracy": 0.0,
}
return epoch_eval_losses, empty_metrics
y_true = np.array([1 if is_rain(target) else 0 for target in eval_targets], dtype=np.int32)
probs_arr = np.array(eval_probs, dtype=np.float32)
selected_threshold, selected_metrics, default_metrics = _select_best_threshold(
y_true,
probs_arr,
default_threshold=default_threshold,
)
eval_metrics: dict[str, float | int] = {
**selected_metrics,
"phase": phase_label,
"selected_threshold": float(selected_threshold),
"default_threshold": default_threshold,
"default_recall": float(default_metrics["recall"]),
"default_precision": float(default_metrics["precision"]),
"default_f1": float(default_metrics["f1"]),
"default_accuracy": float(default_metrics["accuracy"]),
}
return epoch_eval_losses, eval_metrics