import grpc
import glob
import os
import re
import socket
import time
import psutil
from dataclasses import dataclass, field
import torch
from proto import fsl_pb2
from proto import fsl_pb2_grpc
from src.client.checkpointing import CheckpointState, evaluate_epoch
from src.client.data_pipeline import partition_client_files
from src.client.reporting import print_summary, save_progress, save_results, summarize_logs, summarize_phase
from src.client.scheduler_state import SchedulerState
from src.client.sync import fed_avg_sync
from src.client.training_loop import (
build_eval_index_cache,
compute_feature_stats,
preload_sensor_data,
run_eval_epoch,
run_train_epoch,
)
from src.models.split_lstm import ClientLSTM
from src.shared.common import cfg, feature_cols_from_cfg, project_root
from src.shared.runtime import create_grpc_channel, resolve_device, resolve_server_address, set_global_seed
from src.shared.targets import rain_threshold_mm
# Feature column names, resolved once at import time via feature_cols_from_cfg().
FEATURE_COLS = feature_cols_from_cfg()
# Maximum reconnect attempts on transient gRPC failures.
_MAX_RECONNECT = 5
# Seconds to wait after failed attempt N before attempt N+1 (indexed by N).
# The final attempt never schedules a retry, so only the first
# _MAX_RECONNECT - 1 entries are consumed. A tuple keeps this constant
# immutable at runtime.
_RECONNECT_BACKOFF: tuple[int, ...] = (5, 15, 30, 60, 120)
# ── Helpers ───────────────────────────────────────────────────────────────────
def _resolve_requested_client_id() -> int:
    """Return the client ID this process should request from the server.

    Resolution order:
      1. The ``CLIENT_ID`` environment variable, when set and parseable as int.
      2. The number in a hostname of the exact form ``fsl-client-<N>``; the
         hostname is ``$HOSTNAME`` if set, else ``socket.gethostname()``.
      3. ``0`` when neither source yields an ID.
    """
    explicit = os.getenv("CLIENT_ID", "").strip()
    if explicit:
        try:
            return int(explicit)
        except ValueError:
            pass  # malformed CLIENT_ID: fall through to the hostname heuristic
    host = os.getenv("HOSTNAME") or socket.gethostname()
    if (found := re.fullmatch(r"fsl-client-(\d+)", host)) is not None:
        return int(found.group(1))
    return 0
def _is_retriable(exc: grpc.RpcError) -> bool:
    """Tell whether *exc* looks transient, so the caller may reconnect and retry.

    Only the ``UNAVAILABLE`` and ``UNKNOWN`` status codes count as retriable;
    every other status code is treated as permanent.
    """
    transient_codes = {grpc.StatusCode.UNAVAILABLE, grpc.StatusCode.UNKNOWN}
    return exc.code() in transient_codes
# ── Metric helpers and persistent client state ────────────────────────────────
def get_system_metrics(self):
    """Return system CPU and memory utilisation percentages.

    ``self`` is accepted but never read. The CPU sample uses
    ``interval=None``, so the call does not block.

    NOTE(review): no caller of this module-level function is visible in this
    file; the ``_ClientState`` method of the same name is what the training
    code calls.
    """
    cpu = psutil.cpu_percent(interval=None)
    mem = psutil.virtual_memory().percent
    return {"CPU_Percent": cpu, "Mem_Percent": mem}
def get_model_size_bytes(self):
    """Return the in-memory size of ``self.client_model`` in bytes.

    Sums ``nelement() * element_size()`` over every parameter and every buffer
    of the model. Returns 0 when ``self.client_model`` is None (model not built
    yet), mirroring the None guard of the ``_ClientState`` method of the same
    name instead of raising AttributeError.
    """
    model = self.client_model
    if model is None:
        return 0
    param_bytes = sum(p.nelement() * p.element_size() for p in model.parameters())
    buffer_bytes = sum(b.nelement() * b.element_size() for b in model.buffers())
    return param_bytes + buffer_bytes
@dataclass
class _ClientState:
    """
    All mutable state for one client process.

    Fields set after registration and initialisation survive gRPC reconnects
    so training can resume from the last committed epoch.
    """

    # Set after registration
    client_id: int | None = None
    session_id: str | None = None
    num_clients: int = 0
    session_dir: str | None = None
    periodic_dir: str | None = None
    actual_seed: int | None = None
    # Set by _init_local (once per run, not per connection)
    client_model: ClientLSTM | None = None
    optimizer: torch.optim.Optimizer | None = None
    device: torch.device | None = None
    client_files: list | None = None
    sensor_data_cache: object = None
    feat_stats: object = None
    val_index_cache: object = None
    seq_len: int = 0
    target_horizon: int = 0
    train_state: SchedulerState | None = None
    val_state: SchedulerState | None = None
    # Progress (must survive reconnects)
    start_epoch: int = 0
    current_round: int = 0
    model_round: int = 0
    completed_epochs: int = 0
    total_steps: int = 0
    checkpoint_state: CheckpointState = field(default_factory=CheckpointState)
    experimental_logs: list = field(default_factory=list)
    finalized: bool = False
    run_start_time: float = field(default_factory=time.time)
    # Sync transmission tracking
    total_sync_bytes_sent: int = 0
    total_sync_bytes_recv: int = 0
    # Highest process RSS (MiB) seen by get_system_metrics so far. Declared as a
    # real field (kept out of __init__, repr and equality) rather than attached
    # lazily via hasattr, so every attribute of the state is listed here.
    _peak_rss_mb: float = field(default=0.0, init=False, repr=False, compare=False)

    def get_system_metrics(self):
        """Sample CPU, memory and network counters and update the peak-RSS tracker.

        Returned keys:
          CPU_Percent  -- psutil.cpu_percent(interval=None) (non-blocking)
          Mem_Percent  -- psutil.virtual_memory().percent
          Mem_RSS_MB   -- this process's current resident set size in MiB
          Mem_Peak_MB  -- max Mem_RSS_MB observed across calls on this object;
                          only as fine-grained as the calls to this method
          Net_Sent_MB / Net_Recv_MB -- psutil.net_io_counters() byte totals in MiB.
            NOTE(review): these are interface-level cumulative counters, not
            per-process or per-run deltas; confirm that is what reports expect.
        All MiB values are rounded to 2 decimals.
        """
        current_rss_mb = psutil.Process().memory_info().rss / (1024 * 1024)
        net_info = psutil.net_io_counters()
        self._peak_rss_mb = max(self._peak_rss_mb, current_rss_mb)
        return {
            "CPU_Percent": psutil.cpu_percent(interval=None),
            "Mem_Percent": psutil.virtual_memory().percent,
            "Mem_RSS_MB": round(current_rss_mb, 2),
            "Mem_Peak_MB": round(self._peak_rss_mb, 2),
            "Net_Sent_MB": round(net_info.bytes_sent / (1024 * 1024), 2),
            "Net_Recv_MB": round(net_info.bytes_recv / (1024 * 1024), 2),
        }

    def get_model_size_bytes(self):
        """Return the byte size of the client model's parameters and buffers (0 if unbuilt)."""
        if self.client_model is None:
            return 0
        param_size = sum(p.nelement() * p.element_size() for p in self.client_model.parameters())
        buffer_size = sum(b.nelement() * b.element_size() for b in self.client_model.buffers())
        return param_size + buffer_size
# ── Step 1: Register ──────────────────────────────────────────────────────────
def _register(stub, state: _ClientState, client_name: str, requested_client_id: int) -> None:
    """
    Register with the server to obtain a logical client ID and session ID.

    On the first successful registration this also stores the identity in
    *state*, derives and applies the per-client seed via set_global_seed, and
    creates the session's checkpoint and results directories. On later calls
    (after a reconnect) it only checks that the server returned the same
    client/session identity.

    Retries indefinitely while the server reports a scenario mismatch, or while
    the Register RPC fails with a transient status (see _is_retriable).

    Raises:
        grpc.RpcError: if Register fails with a non-transient status code, so
            the caller's reconnect logic can treat it as fatal instead of
            this function spinning forever.
        RuntimeError: if a reconnect returns a different client or session ID.
    """
    scenario_id = os.environ.get("SCENARIO_ID", "")
    new_client_id, num_clients, new_session_id = _register_until_accepted(
        stub, client_name, requested_client_id, scenario_id,
    )
    if state.client_id is None:
        state.client_id = new_client_id
        state.session_id = new_session_id
        state.num_clients = num_clients
        base_seed = cfg.get("training", {}).get("seed", 42)
        client_seed = (int(base_seed) + int(state.client_id)) if base_seed is not None else None
        state.actual_seed = client_seed
        set_global_seed(client_seed, role=f"client-{state.client_id}")
        _prepare_session_dirs(state, scenario_id)
    elif new_client_id != state.client_id or new_session_id != state.session_id:
        raise RuntimeError(
            f"Session mismatch on reconnect: "
            f"expected client={state.client_id}/session={state.session_id}, "
            f"got client={new_client_id}/session={new_session_id}. "
            "The server may have been restarted. Aborting."
        )
    print(
        f"[CLIENT] Registered name: {client_name} | requested_id: {requested_client_id or 'auto'} "
        f"| assigned_id: {state.client_id} / {num_clients} | session: {state.session_id}"
        f" | seed: {state.actual_seed}"
        + (" [resumed]" if state.start_epoch > 0 else "")
    )


def _register_until_accepted(stub, client_name: str, requested_client_id: int, scenario_id: str):
    """Call Register until the server accepts us; return (client_id, total_clients, session_id)."""
    while True:
        try:
            reg = stub.Register(
                fsl_pb2.RegisterRequest(
                    client_name=client_name,
                    requested_client_id=requested_client_id,
                ),
                metadata=[("scenario-id", scenario_id)],
            )
        except grpc.RpcError as e:
            if not _is_retriable(e):
                raise  # permanent failure: let the reconnect loop report it
            print(f"[CLIENT] Registration failed (server might be restarting): {e.code()}")
            time.sleep(10)
            continue
        if reg.session_id == "ERROR_SCENARIO_MISMATCH":
            print(f"[CLIENT] Waiting for server to switch to scenario: {scenario_id}...")
            time.sleep(10)
            continue
        return reg.client_id, reg.total_clients, reg.session_id


def _prepare_session_dirs(state: _ClientState, scenario_id: str) -> None:
    """Set and create bestweights/<session>[/<scenario>][/periodic] and results/<session>[/<scenario>]."""
    scenario_parts = (scenario_id,) if scenario_id else ()
    state.session_dir = os.path.join(project_root, "bestweights", state.session_id, *scenario_parts)
    state.periodic_dir = os.path.join(state.session_dir, "periodic")
    os.makedirs(state.session_dir, exist_ok=True)
    os.makedirs(state.periodic_dir, exist_ok=True)
    # Pre-create results/<session_id>[/<scenario_id>] so save_results never
    # hits FileNotFoundError on the second (or later) seed run.
    os.makedirs(os.path.join(project_root, "results", state.session_id, *scenario_parts), exist_ok=True)
# ── Step 2: One-time local setup ──────────────────────────────────────────────
def _init_local(state: _ClientState, data_dir: str, compression_mode: str) -> None:
    """
    Build everything local to this client in *state*: its sensor-file shard,
    model and Adam optimizer, preloaded sensor data, feature statistics, the
    VAL index cache and the train/val scheduler states.

    If ``state.client_model`` is already set (e.g. on a call after a gRPC
    reconnect) it only prints the resume point and returns.

    Raises:
        RuntimeError: if this client is assigned no sensor files, or if the
            validation split yields zero samples.
    """
    if state.client_model is not None:
        print(f"[CLIENT {state.client_id}] Resuming from epoch {state.start_epoch + 1}")
        return

    cid = state.client_id
    model_section = cfg.get("model", {})
    training_section = cfg.get("training", {})

    # Sensor shard for this client
    sensor_paths = sorted(glob.glob(os.path.join(project_root, data_dir, "*.parquet")))
    state.client_files = partition_client_files(
        sensor_paths, client_id=cid, num_clients=state.num_clients,
    )
    print(f"[CLIENT {cid}] Allocated {len(state.client_files)}/{len(sensor_paths)} sensors")
    if not state.client_files:
        raise RuntimeError(
            f"Client {cid} was assigned 0 sensors "
            f"(total_sensors={len(sensor_paths)}, total_clients={state.num_clients}). "
            "Reduce num_clients or provide more sensor files."
        )

    # Model and optimizer
    state.device = resolve_device()
    print(f"[CLIENT {cid}] Using device: {state.device}")
    dropout = float(model_section.get("lstm_dropout", model_section.get("dropout", 0.3)))
    state.client_model = ClientLSTM(
        input_size=model_section.get("input_size", len(FEATURE_COLS)),
        hidden_size=model_section.get("hidden_size", 64),
        num_layers=model_section.get("num_layers", 1),
        lstm_dropout=dropout,
    ).to(state.device)
    state.optimizer = torch.optim.Adam(
        state.client_model.parameters(),
        lr=training_section.get("lr", 0.001),
    )

    # Data caches and statistics
    state.target_horizon = max(1, int(model_section.get("horizon", 3)))
    state.sensor_data_cache = preload_sensor_data(
        cid, state.client_files, horizon=state.target_horizon,
    )
    state.feat_stats = compute_feature_stats(
        client_id=cid,
        sensor_data_cache=state.sensor_data_cache,
        feature_cols=FEATURE_COLS,
    )
    state.seq_len = int(model_section.get("seq_len", 24))
    eval_cap = max(0, int(training_section.get("eval_max_samples_per_sensor", 0)))
    state.val_index_cache, n_val_samples, _ = build_eval_index_cache(
        client_id=cid,
        sensor_data_cache=state.sensor_data_cache,
        target_phase="VAL",
        eval_max_samples=eval_cap,
        seq_len=state.seq_len,
        label="VAL",
        horizon=state.target_horizon,
    )
    if n_val_samples == 0:
        raise RuntimeError(
            f"Client {cid} has 0 validation samples. "
            "Check dataset timestamps and train_end/val_end configuration."
        )

    # The split boundaries are only echoed here for the log.
    split_section = cfg.get("data", {})
    train_end = split_section.get("train_end", "2024-12-31")
    val_end = split_section.get("val_end", "2025-06-30")
    print(
        f"[CLIENT {cid}] Chronological split | TRAIN: <{train_end} "
        f"| VAL: {train_end}→<{val_end} | TEST: >={val_end} | horizon={state.target_horizon}h"
    )

    rho = max(1, int(cfg.get("federated", {}).get("rho", 1)))
    state.train_state, state.val_state = (
        SchedulerState(compression_mode=compression_mode, rho=rho),
        SchedulerState(compression_mode=compression_mode, rho=rho),
    )
# ── Step 3a: Validation (called after each sync) ──────────────────────────────
def _run_validation(
    stub,
    state: _ClientState,
    epoch: int,
    epoch_logs: list,
    patience: int,
    ckpt_interval: int,
) -> tuple[bool, int]:
    """
    Evaluate the VAL split for one epoch, log a summary line, and let
    evaluate_epoch decide on checkpointing and early stopping.

    Returns:
        (should_stop, val_step_count). When the evaluation produced no losses
        the result is ``(False, 0)`` and evaluate_epoch is not called.
    """
    cid = state.client_id
    print(f"--- [VALIDATION] Client {cid} Epoch {epoch+1} ---")
    state.client_model.eval()
    started = time.time()
    val_losses, metrics = run_eval_epoch(
        stub=stub,
        client_id=cid,
        client_model=state.client_model,
        optimizer=state.optimizer,
        client_files=state.client_files,
        sensor_data_cache=state.sensor_data_cache,
        eval_index_cache=state.val_index_cache,
        eval_state=state.val_state,
        feature_cols=FEATURE_COLS,
        feat_stats=state.feat_stats,
        device=state.device,
        seq_len=state.seq_len,
        epoch=epoch,
        experimental_logs=state.experimental_logs,
        epoch_logs=epoch_logs,
        phase_label="VAL",
    )
    if not val_losses:
        return False, 0

    n_steps = len(val_losses)
    mean_loss = sum(val_losses) / n_steps
    phase = summarize_phase(epoch_logs, "VAL")
    tp, fn, fp, tn = (int(metrics[key]) for key in ("tp", "fn", "fp", "tn"))
    elapsed = max(1e-9, time.time() - started)

    print(
        f"[CLIENT {cid}] Epoch {epoch+1} val summary | "
        f"steps={n_steps} avg_loss={mean_loss:.4f} "
        f"rain_acc={phase['rain_acc']:.3f} "
        f"cls_loss={phase['avg_cls_loss']:.4f} reg_loss={phase['avg_reg_loss']:.4f} "
        f"positive_count={tp + fn} "
        f"recall={metrics['recall']:.3f} precision={metrics['precision']:.3f} "
        f"f1={metrics['f1']:.3f} "
        f"thr={metrics['selected_threshold']:.3f} "
        f"(default={metrics['default_threshold']:.3f}, "
        f"default_r/p/f1={metrics['default_recall']:.3f}/"
        f"{metrics['default_precision']:.3f}/{metrics['default_f1']:.3f}) "
        f"cm=TP:{tp}/FN:{fn}/FP:{fp}/TN:{tn} "
        f"val_time={elapsed:.2f}s "
        f"val_throughput={n_steps / elapsed:.2f} steps/s"
    )

    stop_now = evaluate_epoch(
        client_id=cid,
        client_model=state.client_model,
        optimizer=state.optimizer,
        current_round=state.current_round,
        epoch=epoch,
        avg_val_loss=mean_loss,
        val_metrics=metrics,
        session_id=state.session_id,
        session_dir=state.session_dir,
        periodic_dir=state.periodic_dir,
        patience=patience,
        ckpt_interval=ckpt_interval,
        state=state.checkpoint_state,
    )
    return stop_now, n_steps
# ── Step 3b: Single training epoch ───────────────────────────────────────────
def _run_single_epoch(stub, state: _ClientState, epoch: int, epochs: int) -> bool:
    """
    Run one full epoch: train → optional FedAvg sync → optional validation.

    Sync and validation only happen on epochs where ``epoch + 1`` is a
    multiple of ``state.train_state.rho``. Updates *state* in place.

    Returns True when training should stop. This happens when the sync times
    out waiting for global aggregation, or when validation's early-stopping
    check fires. In the early-stopping case the epoch is still committed
    (start_epoch, completed_epochs, total_steps, save_progress), so the final
    summary and the completion report count the epoch that actually ran.

    Raises:
        Exception: any sync failure other than the aggregation timeout is
            re-raised so the reconnect loop can handle gRPC errors.
    """
    training_cfg = cfg.get("training", {})
    patience = training_cfg.get("early_stopping_patience", 15)
    ckpt_interval = training_cfg.get("checkpoint_interval", 10)
    local_steps = max(1, int(training_cfg.get("local_steps", 1)))
    rain_sample_ratio = float(training_cfg.get("rain_sample_ratio", 0.35))
    rain_threshold = rain_threshold_mm()

    epoch_start = time.time()
    state.client_model.train()
    print(f"[EPOCH {epoch+1}/{epochs}] Client {state.client_id} starting...")
    epoch_logs: list[dict] = []

    # --- Train ---
    epoch_train_steps = run_train_epoch(
        stub=stub,
        client_id=state.client_id,
        client_model=state.client_model,
        optimizer=state.optimizer,
        client_files=state.client_files,
        sensor_data_cache=state.sensor_data_cache,
        train_state=state.train_state,
        feature_cols=FEATURE_COLS,
        feat_stats=state.feat_stats,
        device=state.device,
        local_steps=local_steps,
        rain_sample_ratio=rain_sample_ratio,
        seq_len=state.seq_len,
        epoch=epoch,
        experimental_logs=state.experimental_logs,
        epoch_logs=epoch_logs,
        horizon=state.target_horizon,
        rain_threshold=rain_threshold,
    )
    # Measured unconditionally so the metrics below never reference an unbound
    # name when no training step ran.
    train_elapsed = max(1e-9, time.time() - epoch_start)
    if epoch_train_steps:
        summary = summarize_phase(epoch_logs, "TRAIN")
        print(
            f"[CLIENT {state.client_id}] Epoch {epoch+1} train summary | "
            f"steps={epoch_train_steps} avg_loss={summary['avg_loss']:.4f} "
            f"rain_acc={summary['rain_acc']:.3f} "
            f"cls_loss={summary['avg_cls_loss']:.4f} reg_loss={summary['avg_reg_loss']:.4f} "
            f"train_time={train_elapsed:.2f}s "
            f"train_throughput={epoch_train_steps / train_elapsed:.2f} steps/s"
        )

    # Attach system metrics to this epoch's (training) log entries.
    # NOTE: "Epoch_Time_s" is the training-phase duration only; sync and
    # validation happen afterwards and are not included.
    metrics = state.get_system_metrics()
    metrics["Epoch_Time_s"] = train_elapsed
    metrics["Model_Size_Bytes"] = state.get_model_size_bytes()
    for log in epoch_logs:
        log.update(metrics)

    # --- Sync + validate (every rho epochs) ---
    sync_interval = max(1, int(state.train_state.rho))
    val_steps = 0
    should_stop = False
    if (epoch + 1) % sync_interval == 0:
        print(f"[CLIENT {state.client_id}] Epoch {epoch+1} done. Synchronizing (rho={sync_interval})...")
        try:
            sync_result = fed_avg_sync(
                stub, state.client_id, state.client_model,
                model_round=state.model_round, local_epochs=sync_interval,
            )
        except Exception as exc:
            print(f"[CLIENT {state.client_id}] Sync failed: {exc}")
            if "Timeout waiting for global model aggregation" in str(exc):
                state.completed_epochs = epoch + 1
                state.total_steps += epoch_train_steps
                print(f"[CLIENT {state.client_id}] Stopping due to sync timeout after epoch {epoch+1}.")
                return True
            raise  # re-raise gRPC errors so the reconnect loop can catch them
        _adopt_synced_model(state, sync_result.client_model)
        state.current_round = max(state.current_round, sync_result.round_number)
        state.model_round = max(state.model_round, sync_result.round_number)
        state.total_sync_bytes_sent += sync_result.sync_bytes_sent
        state.total_sync_bytes_recv += sync_result.sync_bytes_recv
        should_stop, val_steps = _run_validation(
            stub, state, epoch, epoch_logs, patience, ckpt_interval,
        )
    else:
        print(
            f"[CLIENT {state.client_id}] Epoch {epoch+1} done. "
            f"Skip sync (rho={sync_interval}); continuing local training."
        )

    # --- Commit epoch progress (also when early stopping fired) ---
    _commit_epoch_progress(state, epoch, epoch_start, epoch_train_steps + val_steps)
    return bool(should_stop)


def _adopt_synced_model(state: _ClientState, synced_model) -> None:
    """Make *synced_model*'s weights live in ``state.client_model`` without orphaning the optimizer."""
    if synced_model is state.client_model:
        return  # the sync updated the live model in place
    # The Adam optimizer was built over state.client_model.parameters();
    # copying weights into those same tensors keeps it (and its moment
    # estimates) bound to the model that is trained and checkpointed next.
    state.client_model.load_state_dict(synced_model.state_dict())


def _commit_epoch_progress(state: _ClientState, epoch: int, epoch_start: float, epoch_steps: int) -> None:
    """Log epoch timing, advance the resume point and counters, and persist progress."""
    epoch_elapsed = max(1e-9, time.time() - epoch_start)
    print(
        f"[CLIENT {state.client_id}] Epoch {epoch+1} timing | "
        f"total_time={epoch_elapsed:.2f}s total_steps={epoch_steps} "
        f"throughput={epoch_steps / epoch_elapsed:.2f} steps/s"
    )
    state.start_epoch = epoch + 1  # reconnect resumes here
    state.completed_epochs = epoch + 1
    state.total_steps += epoch_steps
    avg_latency, avg_bytes = summarize_logs(state.experimental_logs)
    best_loss = state.checkpoint_state.best_test_loss
    save_progress(
        state.client_id, state.experimental_logs, state.session_id,
        epoch=epoch + 1,
        best_model_path=state.checkpoint_state.best_model_path,
        best_test_loss=best_loss if best_loss != float("inf") else None,
        avg_latency=avg_latency,
        avg_bytes=avg_bytes,
        model_size_bytes=state.get_model_size_bytes(),
    )
# ── Step 4: Finalise session ──────────────────────────────────────────────────
def _finalize_session(stub, state: _ClientState, epochs: int) -> None:
    """
    Write final results, print the run summary and tell the server we are done.

    In order, this:
      1. averages the CPU/memory percentages stored in the experimental logs;
      2. takes one more system-metrics sample for network totals and peak RSS;
      3. calls save_results, then print_summary, then NotifyCompletion.
    ``state.finalized`` becomes True only after the server acknowledges.

    ``epochs`` is reported as the epoch count only when
    ``state.completed_epochs`` is 0.
    """
    logs = state.experimental_logs

    def _mean_of(key: str) -> float:
        # Average over the log entries that recorded *key*; 0.0 when none did.
        samples = [entry[key] for entry in logs if key in entry]
        return sum(samples) / len(samples) if samples else 0.0

    avg_cpu = _mean_of("CPU_Percent")
    avg_mem = _mean_of("Mem_Percent")

    # Peak memory and cumulative network counters come from a final sample.
    snapshot = state.get_system_metrics()
    total_runtime = time.time() - state.run_start_time
    avg_latency, avg_bytes = summarize_logs(logs)

    ckpt = state.checkpoint_state
    best_loss = ckpt.best_test_loss
    reported_epochs = state.completed_epochs or epochs
    mib = 1024 * 1024

    save_results(
        state.client_id, logs, state.session_id,
        best_model_path=ckpt.best_model_path,
        best_test_loss=None if best_loss == float("inf") else best_loss,
        avg_latency=avg_latency,
        avg_bytes=avg_bytes,
        avg_cpu=avg_cpu,
        avg_mem=avg_mem,
        total_runtime_s=total_runtime,
        model_size_bytes=state.get_model_size_bytes(),
        net_sent_mb=snapshot["Net_Sent_MB"],
        net_recv_mb=snapshot["Net_Recv_MB"],
        mem_peak_mb=snapshot["Mem_Peak_MB"],
        sync_bytes_sent_mb=round(state.total_sync_bytes_sent / mib, 4),
        sync_bytes_recv_mb=round(state.total_sync_bytes_recv / mib, 4),
        actual_seed=state.actual_seed,
    )
    print_summary(
        client_id=state.client_id,
        epochs=reported_epochs,
        num_logs=len(logs),
        best_test_loss=best_loss,
        avg_latency=avg_latency,
        avg_bytes=avg_bytes,
        best_model_path=ckpt.best_model_path,
        total_runtime_s=total_runtime,
        avg_steps_per_s=state.total_steps / max(1e-9, total_runtime),
        avg_cpu=avg_cpu,
        avg_mem=avg_mem,
        actual_seed=state.actual_seed,
    )

    request = fsl_pb2.CompletionRequest(
        client_id=state.client_id,
        completed_epochs=reported_epochs,
        total_steps=state.total_steps,
        session_id=state.session_id,
    )
    ack = stub.NotifyCompletion(
        request,
        metadata=[("scenario-id", os.environ.get("SCENARIO_ID", ""))],
    )
    print(
        f"[CLIENT {state.client_id}] Completion acknowledged by server "
        f"({ack.completed_clients}/{ack.total_clients})"
    )
    state.finalized = True
# ── Entry point ───────────────────────────────────────────────────────────────
def run_all_client(data_dir: str = "dataset/processed", epochs: int = 10) -> None:
    """
    Orchestrate the full client lifecycle:
      1. Connect → register → init local data & model (once)
      2. Train for N epochs, syncing with the server every rho epochs
      3. Finalise results and notify completion

    On transient gRPC failures the connection is re-established, up to
    _MAX_RECONNECT attempts. Training resumes from the last committed epoch.
    Once the epoch loop has ended (all epochs, early stop, or sync timeout),
    later reconnects only retry finalisation and never restart training.

    ``epochs`` is a fallback: ``training.num_rounds`` from the config takes
    precedence when present.

    If the run is not finalised (fatal error, retries exhausted, or Ctrl-C,
    including during a backoff sleep), the collected results are still
    written via save_results.
    """
    # Prevent PyTorch from spawning multiple BLAS threads per process.
    # With 11 client processes on 12 vCPUs, each process should own exactly
    # one core; letting PyTorch use all 12 threads causes 132-thread contention
    # and turns a 15 ms LSTM backward into 7+ seconds.
    torch.set_num_threads(max(1, int(cfg.get("training", {}).get("torch_num_threads", 1))))
    target_address = resolve_server_address()
    compression_mode = cfg.get("compression", {}).get("mode", "float32")
    epochs = int(cfg.get("training", {}).get("num_rounds", epochs))
    requested_client_id = _resolve_requested_client_id()
    # Fixed name so the server can match this client across reconnects.
    # (Docker container hostnames change on restart.)
    client_name = (
        f"fsl-client-cid{requested_client_id}"
        if requested_client_id > 0
        else (os.getenv("HOSTNAME") or socket.gethostname())
    )
    time.sleep(cfg.get("training", {}).get("start_delay", 8))
    state = _ClientState()
    try:
        _run_with_reconnect(
            state, target_address, client_name, requested_client_id,
            data_dir, compression_mode, epochs,
        )
    except KeyboardInterrupt:
        print("[CLIENT] Interrupted by user; saving partial results...")
    if not state.finalized and state.client_id is not None and state.session_id is not None:
        _save_partial_results(state)


def _run_with_reconnect(
    state: _ClientState,
    target_address: str,
    client_name: str,
    requested_client_id: int,
    data_dir: str,
    compression_mode: str,
    epochs: int,
) -> None:
    """Connect/register/train/finalise, reconnecting on transient gRPC errors.

    Returns once the session is finalised or a non-recoverable error has been
    logged. KeyboardInterrupt propagates to the caller.
    """
    training_done = False
    for attempt in range(_MAX_RECONNECT):
        try:
            print(
                f"[CLIENT] Connecting to {target_address}"
                + (f" (attempt {attempt + 1}/{_MAX_RECONNECT})" if attempt > 0 else "") + "..."
            )
            with create_grpc_channel(target_address) as channel:
                stub = fsl_pb2_grpc.FSLServiceStub(channel)
                _register(stub, state, client_name, requested_client_id)
                _init_local(state, data_dir, compression_mode)
                if not training_done:
                    for epoch in range(state.start_epoch, epochs):
                        if _run_single_epoch(stub, state, epoch, epochs):
                            break
                    # After an early stop / sync timeout / last epoch, a
                    # reconnect (e.g. NotifyCompletion failed) must only
                    # retry finalisation, not resume training.
                    training_done = True
                _finalize_session(stub, state, epochs)
            return  # success
        except grpc.RpcError as exc:
            if _is_retriable(exc) and attempt < _MAX_RECONNECT - 1:
                # Clamp so a shorter backoff table can never raise IndexError.
                backoff = _RECONNECT_BACKOFF[min(attempt, len(_RECONNECT_BACKOFF) - 1)]
                print(
                    f"[CLIENT] Connection lost (attempt {attempt + 1}/{_MAX_RECONNECT}): "
                    f"{exc.details()}. Reconnecting in {backoff}s..."
                )
                time.sleep(backoff)
            else:
                print(f"[CLIENT] Fatal gRPC error after {attempt + 1} attempt(s): {exc.details()}")
                return
        except Exception as exc:
            print(f"[CLIENT] Fatal error: {exc}")
            return


def _save_partial_results(state: _ClientState) -> None:
    """Write results for a run that was not finalised, with the same fields _finalize_session saves."""
    avg_latency, avg_bytes = summarize_logs(state.experimental_logs)
    cpus = [log["CPU_Percent"] for log in state.experimental_logs if "CPU_Percent" in log]
    mems = [log["Mem_Percent"] for log in state.experimental_logs if "Mem_Percent" in log]
    avg_cpu = sum(cpus) / len(cpus) if cpus else 0.0
    avg_mem = sum(mems) / len(mems) if mems else 0.0
    final_metrics = state.get_system_metrics()
    total_runtime = time.time() - state.run_start_time
    best_loss = state.checkpoint_state.best_test_loss
    save_results(
        state.client_id, state.experimental_logs, state.session_id,
        best_model_path=state.checkpoint_state.best_model_path,
        best_test_loss=best_loss if best_loss != float("inf") else None,
        avg_latency=avg_latency,
        avg_bytes=avg_bytes,
        avg_cpu=avg_cpu,
        avg_mem=avg_mem,
        total_runtime_s=total_runtime,
        model_size_bytes=state.get_model_size_bytes(),
        # Same network / peak-memory / sync-traffic fields as a completed run,
        # so a fallback save never drops columns a completed run wrote.
        net_sent_mb=final_metrics["Net_Sent_MB"],
        net_recv_mb=final_metrics["Net_Recv_MB"],
        mem_peak_mb=final_metrics["Mem_Peak_MB"],
        sync_bytes_sent_mb=round(state.total_sync_bytes_sent / (1024 * 1024), 4),
        sync_bytes_recv_mb=round(state.total_sync_bytes_recv / (1024 * 1024), 4),
        actual_seed=state.actual_seed,
    )
if __name__ == "__main__":
    # training.num_rounds from the config wins; 10 is only the fallback.
    run_all_client(epochs=cfg.get("training", {}).get("num_rounds", 10))