import copy
import os
import time
from dataclasses import dataclass
from datetime import datetime
from pathlib import Path
import threading
from typing import Any
import torch
from proto import fsl_pb2
from src.shared.config_artifacts import build_config_ref, build_config_snapshot
from src.shared.serialization import tensor_to_bytes
from src.shared.common import cfg
@dataclass()
class PendingUpdate:
    """A single client's weight update buffered until the next aggregation."""

    # Submitting client's id; the coordinator also uses it as the buffer key.
    client_id: int
    # The client's validated state dict ({parameter name: tensor}).
    weights: dict[str, torch.Tensor]
    # Server round the client says it trained from; drives the staleness discount.
    base_round: int
    # Number of local epochs reported by the client; scales its aggregation weight.
    local_epochs: int
    # time.time() value recorded when the coordinator buffered this update.
    arrived_at: float
class FedAvgCoordinator:
    """Round-based FedAvg barrier for client-side model weights.

    Callers (presumably one thread per client RPC -- confirm against the
    server) invoke :meth:`synchronize` with a client's locally trained state
    dict. Updates are buffered until enough active clients have submitted
    (quorum, optionally followed by a grace period for stragglers) or until
    ``round_timeout_sec`` elapses. The buffer is then averaged into
    ``global_weights``, ``current_round`` advances and a server checkpoint is
    written. All mutable state is guarded by ``self.lock`` (wrapped by the
    condition ``self.round_cond``); methods whose names end in ``_locked``
    expect that lock to be held by the caller.
    """

    def __init__(
        self,
        *,
        num_clients: int,
        hidden_size: int,
        session_id: str,
        session_dir: str,
        periodic_dir: str,
        ckpt_interval: int,
        min_clients_per_round: int = 2,
        round_timeout_sec: float = 15.0,
        grace_period_sec: float = 0.0,
        max_staleness: int = 0,
    ):
        """Store configuration and initialise empty round state.

        Args:
            num_clients: Stored as-is; not read elsewhere in this class.
            hidden_size: Recorded in every checkpoint's ``config`` entry.
            session_id: Recorded in checkpoints and used in log output.
            session_dir: Directory for the rolling latest server checkpoint.
            periodic_dir: Directory for the every-``ckpt_interval`` checkpoints.
            ckpt_interval: Periodic checkpoint interval in rounds; ``0``
                disables periodic checkpoints.
            min_clients_per_round: Quorum size, clamped to >= 1 and capped at
                the number of currently active clients.
            round_timeout_sec: Per-round wait (clamped to >= 0.1s) before the
                buffered updates are aggregated regardless of quorum; also the
                length of the round-0 startup wait.
            grace_period_sec: Extra wait after quorum for stragglers (>= 0).
            max_staleness: How many rounds behind the server an update may be.
        """
        self.num_clients = num_clients
        self.hidden_size = hidden_size
        self.session_id = session_id
        self.session_dir = session_dir
        self.periodic_dir = periodic_dir
        self.ckpt_interval = ckpt_interval
        self.min_clients_per_round = max(1, int(min_clients_per_round))
        self.round_timeout_sec = max(0.1, float(round_timeout_sec))
        self.grace_period_sec = max(0.0, float(grace_period_sec))
        # max_staleness=0: strict synchronous FedAvg (legacy default).
        # max_staleness=N: accept updates up to N rounds stale; weight by
        # local_epochs / (1 + staleness) to discount older contributions.
        self.max_staleness = max(0, int(max_staleness))
        # Time quorum was first reached in the current round (grace period anchor).
        self._quorum_reached_at: float | None = None
        # Absolute deadline of the round-0 startup wait; set by the first waiter.
        self._startup_deadline: float | None = None
        self.client_weights_buffer: dict[int, PendingUpdate] = {}
        self.global_weights = None
        self.current_round = 0
        self.lock = threading.Lock()
        self.round_cond = threading.Condition(self.lock)
        # Time the first update of the current round was buffered.
        self.round_start_time: float | None = None
        # Set when aggregation of the pending round fails; waiters raise it.
        self.round_error: str | None = None
        # Key -> (shape, dtype) fixed by the first accepted state dict.
        self._expected_schema: dict[str, tuple[tuple[int, ...], torch.dtype]] | None = None
        self._active_clients: set[int] = set()
        self._completed_clients: set[int] = set()

    @staticmethod
    def _validate_weights_object(local_weights: Any) -> dict[str, torch.Tensor]:
        """Check that ``local_weights`` is a non-empty ``{str: torch.Tensor}`` dict.

        Returns:
            A shallow copy preserving key order.

        Raises:
            ValueError: If the object is not a non-empty dict, a key is not a
                string, or a value is not a ``torch.Tensor``.
        """
        if not isinstance(local_weights, dict) or not local_weights:
            raise ValueError("Client weights must be a non-empty state_dict-like dict.")
        normalized: dict[str, torch.Tensor] = {}
        for key, value in local_weights.items():
            if not isinstance(key, str):
                raise ValueError("State dict keys must be strings.")
            if not isinstance(value, torch.Tensor):
                raise ValueError(f"State dict value for key '{key}' is not a torch.Tensor.")
            normalized[key] = value
        return normalized

    @staticmethod
    def _build_schema(weights: dict[str, torch.Tensor]) -> dict[str, tuple[tuple[int, ...], torch.dtype]]:
        """Map every key to its tensor's ``(shape, dtype)``."""
        return {key: (tuple(tensor.shape), tensor.dtype) for key, tensor in weights.items()}

    def _validate_against_schema(self, weights: dict[str, torch.Tensor]) -> None:
        """Ensure ``weights`` matches the schema of the first accepted update.

        The first call records the schema instead of checking it.

        Raises:
            ValueError: On key-set, shape, or dtype mismatch.
        """
        schema = self._expected_schema
        if schema is None:
            self._expected_schema = self._build_schema(weights)
            return
        current_keys = set(weights.keys())
        expected_keys = set(schema.keys())
        if current_keys != expected_keys:
            missing = sorted(expected_keys - current_keys)
            extra = sorted(current_keys - expected_keys)
            raise ValueError(
                f"State dict key mismatch. missing={missing[:5]} extra={extra[:5]}"
            )
        for key, tensor in weights.items():
            expected_shape, expected_dtype = schema[key]
            if tuple(tensor.shape) != expected_shape:
                raise ValueError(
                    f"Tensor shape mismatch for '{key}': got {tuple(tensor.shape)} expected {expected_shape}"
                )
            if tensor.dtype != expected_dtype:
                raise ValueError(
                    f"Tensor dtype mismatch for '{key}': got {tensor.dtype} expected {expected_dtype}"
                )

    def register_client(self, client_id: int) -> None:
        """Mark ``client_id`` as active unless it has already completed.

        Always wakes waiters so a startup wait can re-check its client count.
        """
        # Normalise before the membership test so "3" and 3 are the same client.
        client_id = int(client_id)
        with self.round_cond:
            if client_id not in self._completed_clients:
                self._active_clients.add(client_id)
            self.round_cond.notify_all()

    def mark_client_completed(self, client_id: int, *, server_model, optimizer) -> None:
        """Move ``client_id`` from the active set to the completed set.

        Shrinking the active set can lower the required quorum, so buffered
        updates that now meet it are aggregated immediately (the grace period
        is not applied on this path). Waiters are woken either way.
        """
        client_id = int(client_id)
        with self.round_cond:
            self._completed_clients.add(client_id)
            self._active_clients.discard(client_id)
            print(
                f"[FED AVG] Client {client_id} marked complete. "
                f"active_clients={len(self._active_clients)}"
            )
            if self.client_weights_buffer and self._has_quorum_locked():
                self._aggregate_locked(
                    server_model=server_model,
                    optimizer=optimizer,
                    reason="active-set-updated",
                )
            self.round_cond.notify_all()

    def synchronize(self, request, *, local_weights, server_model, optimizer) -> "fsl_pb2.SyncResponse":
        """Submit one client's update and block until its round is aggregated.

        Args:
            request: Object with ``client_id`` and optionally ``base_round``
                (defaults to the server's current round) and ``local_epochs``
                (defaults to 0).
            local_weights: The client's ``{str: torch.Tensor}`` state dict.
            server_model: Object whose ``state_dict()`` is checkpointed.
            optimizer: Object whose ``state_dict()`` is checkpointed.

        Returns:
            A ``SyncResponse``. It is either a rejection (training finished,
            update too stale, or client ahead of the server) or, on success,
            the new global weights with ``applied_round`` set to the round this
            update was averaged into.

        Raises:
            ValueError: If ``local_weights`` fails validation.
            RuntimeError: If aggregating this update's round failed.
            TimeoutError: If the round did not complete.
        """
        client_id = int(request.client_id)
        # The base_round default is resolved under the lock below so it
        # reflects the server state the update is admitted against.
        requested_base_round = getattr(request, "base_round", None)
        local_epochs = int(getattr(request, "local_epochs", 0))
        with self.round_cond:
            base_round = (
                self.current_round if requested_base_round is None else int(requested_base_round)
            )
            if client_id not in self._completed_clients and client_id not in self._active_clients:
                self._active_clients.add(client_id)
                # Wake any startup waiter now instead of at its next poll.
                self.round_cond.notify_all()

            rejection = self._admission_response_locked(client_id, base_round)
            if rejection is not None:
                return rejection

            if self.current_round == 0:
                self._wait_for_startup_clients_locked()
                if self.current_round != 0:
                    # wait() released the lock and other clients completed a
                    # round meanwhile: re-run the finished/staleness checks
                    # against the advanced round before buffering.
                    rejection = self._admission_response_locked(client_id, base_round)
                    if rejection is not None:
                        return rejection

            validated_weights = self._validate_weights_object(local_weights)
            self._validate_against_schema(validated_weights)

            target_round = self.current_round + 1
            now = time.time()
            if not self.client_weights_buffer:
                # First update of a new round: start its barrier clock and
                # clear any error left from an earlier round.
                self.round_start_time = now
                self.round_error = None
            self.client_weights_buffer[client_id] = PendingUpdate(
                client_id=client_id,
                weights=validated_weights,
                base_round=base_round,
                local_epochs=local_epochs,
                arrived_at=now,
            )
            barrier_elapsed_s = (now - self.round_start_time) if self.round_start_time is not None else 0.0
            required = self._required_clients_locked()
            print(
                f"[FED AVG] Received weights from Client:{client_id}. "
                f"Buffer size: {len(self.client_weights_buffer)}/{required} "
                f"| active_clients={len(self._active_clients)} "
                f"| base_round={base_round} local_epochs={local_epochs} "
                f"| barrier_elapsed={barrier_elapsed_s:.2f}s"
            )
            self._maybe_aggregate_locked(
                server_model=server_model,
                optimizer=optimizer,
                reason="quorum-reached",
            )
            while self.current_round < target_round and not self.round_error:
                remaining = self._remaining_window_locked()
                if remaining <= 0:
                    if self.client_weights_buffer:
                        # The window closed either because the grace period
                        # after quorum elapsed or because the round timed out.
                        reason = (
                            "quorum-reached"
                            if self._has_quorum_locked() and self._grace_period_elapsed_locked()
                            else "timeout"
                        )
                        self._aggregate_locked(
                            server_model=server_model,
                            optimizer=optimizer,
                            reason=reason,
                        )
                    break
                self.round_cond.wait(timeout=remaining)
                if self.current_round < target_round and not self.round_error:
                    self._maybe_aggregate_locked(
                        server_model=server_model,
                        optimizer=optimizer,
                        reason="quorum-reached",
                    )

            # Success is checked first: an error from a LATER round must not
            # be reported for an update whose own round already completed.
            if self.current_round >= target_round and self.global_weights is not None:
                return self._build_sync_response_locked(
                    accepted=True,
                    applied_round=target_round,
                    refresh_only=False,
                    status_message=(
                        f"Accepted into aggregation round {target_round} "
                        f"(active_clients={len(self._active_clients)})."
                    ),
                )
            if self.round_error:
                raise RuntimeError(self.round_error)
            raise TimeoutError("Timeout waiting for global model aggregation.")

    def _admission_response_locked(self, client_id: int, base_round: int) -> "fsl_pb2.SyncResponse | None":
        """Return a rejection response if this update cannot join the next round.

        Returns ``None`` when the update is admissible. An update that is stale
        by at most ``max_staleness`` rounds is also admissible.
        """
        # --- Safety Catch: Prevent rounds exceeding config ---
        num_rounds = cfg.get("training", {}).get("num_rounds", 30)
        if self.current_round >= num_rounds:
            print(f"[FED AVG] Target rounds ({num_rounds}) reached. Rejecting update from Client:{client_id}")
            return self._build_sync_response_locked(
                accepted=False,
                applied_round=self.current_round,
                status_message="FINISHED",
                refresh_only=True,
            )
        if base_round < self.current_round:
            staleness = self.current_round - base_round
            if staleness > self.max_staleness:
                print(
                    f"[FED AVG] Rejecting overly stale update from Client:{client_id} "
                    f"(staleness={staleness} > max_staleness={self.max_staleness})"
                )
                return self._build_sync_response_locked(
                    accepted=False,
                    applied_round=0,
                    refresh_only=True,
                    status_message=(
                        f"Stale update rejected: staleness={staleness} "
                        f"exceeds max_staleness={self.max_staleness}."
                    ),
                )
            print(
                f"[FED AVG] Accepting stale update from Client:{client_id} "
                f"(staleness={staleness}, max_staleness={self.max_staleness})"
            )
        if base_round > self.current_round:
            print(
                f"[FED AVG] Client:{client_id} is ahead of server state "
                f"(base_round={base_round}, current_round={self.current_round})"
            )
            return self._build_sync_response_locked(
                accepted=False,
                applied_round=0,
                refresh_only=self.global_weights is not None,
                status_message=(
                    f"Client is ahead of server state: client base_round={base_round}, "
                    f"server current_round={self.current_round}."
                ),
            )
        return None

    def _wait_for_startup_clients_locked(self) -> None:
        """Before round 1, wait until ``min_clients_per_round`` clients are active.

        The deadline is measured from when the FIRST client enters this wait,
        not from server startup: the server may have been up for a long time
        before clients are deployed via Ansible. The wait also ends as soon as
        another client's update advances the round.
        """
        if self._startup_deadline is None:
            self._startup_deadline = time.time() + self.round_timeout_sec
            print(
                f"[FED AVG] Startup wait begun: waiting up to {self.round_timeout_sec:.0f}s "
                f"for {self.min_clients_per_round} clients "
                f"(currently {len(self._active_clients)} registered)."
            )
        while self.current_round == 0 and len(self._active_clients) < self.min_clients_per_round:
            remaining = self._startup_deadline - time.time()
            if remaining <= 0:
                print(
                    f"[FED AVG] Startup wait timed out: only "
                    f"{len(self._active_clients)}/{self.min_clients_per_round} "
                    f"clients connected. Proceeding with current active set."
                )
                break
            # Poll at least every 5s in case a notify is missed.
            self.round_cond.wait(timeout=min(remaining, 5.0))

    def _active_client_count_locked(self) -> int:
        """Number of clients currently in the active set."""
        return len(self._active_clients)

    def _required_clients_locked(self) -> int:
        """Quorum size: ``min_clients_per_round`` capped by the active set, >= 1."""
        active_count = self._active_client_count_locked()
        if active_count <= 0:
            return 1
        return max(1, min(self.min_clients_per_round, active_count))

    def _remaining_window_locked(self) -> float:
        """Seconds until this round's timeout, or until the grace period ends if sooner."""
        if self.round_start_time is None:
            return self.round_timeout_sec
        elapsed = time.time() - self.round_start_time
        timeout_remaining = max(0.0, self.round_timeout_sec - elapsed)
        # If quorum is reached and grace period is active, wake up sooner to check it.
        if self._quorum_reached_at is not None and self.grace_period_sec > 0.0:
            grace_remaining = max(0.0, self.grace_period_sec - (time.time() - self._quorum_reached_at))
            return min(timeout_remaining, grace_remaining)
        return timeout_remaining

    def _has_quorum_locked(self) -> bool:
        """True if the buffer holds at least the required number of updates."""
        return len(self.client_weights_buffer) >= self._required_clients_locked()

    def _grace_period_elapsed_locked(self) -> bool:
        """True if grace period has passed since quorum was first reached."""
        if self.grace_period_sec <= 0.0:
            return True
        if self._quorum_reached_at is None:
            return False
        return (time.time() - self._quorum_reached_at) >= self.grace_period_sec

    def _maybe_aggregate_locked(self, *, server_model, optimizer, reason: str) -> None:
        """Aggregate if quorum is met and its grace period (if any) has elapsed.

        The first time quorum is observed in a round, its timestamp is
        recorded to anchor the grace period.
        """
        if not (self.client_weights_buffer and self._has_quorum_locked()):
            return
        if self._quorum_reached_at is None:
            self._quorum_reached_at = time.time()
            if self.grace_period_sec > 0.0:
                print(
                    f"[FED AVG] Quorum reached ({len(self.client_weights_buffer)}/"
                    f"{self._required_clients_locked()}), "
                    f"waiting grace period {self.grace_period_sec:.0f}s for stragglers..."
                )
        if self._grace_period_elapsed_locked():
            self._aggregate_locked(server_model=server_model, optimizer=optimizer, reason=reason)

    def _build_sync_response_locked(
        self,
        *,
        accepted: bool,
        applied_round: int,
        refresh_only: bool,
        status_message: str,
    ) -> "fsl_pb2.SyncResponse":
        """Build a ``SyncResponse`` carrying the current global weights (empty bytes if none yet)."""
        global_weights_bytes = tensor_to_bytes(self.global_weights) if self.global_weights is not None else b""
        return fsl_pb2.SyncResponse(
            global_weights=global_weights_bytes,
            round_number=int(self.current_round),
            accepted=bool(accepted),
            applied_round=int(applied_round),
            refresh_only=bool(refresh_only),
            status_message=status_message,
        )

    @staticmethod
    def _aggregation_coefficients(pending_updates: list["PendingUpdate"], stalenesses: list[int]) -> list[float]:
        """Normalised mixing coefficients ``w_i = local_epochs_i / (1 + staleness_i)``.

        When max_staleness=0 all staleness values are 0, reducing to plain
        local-epoch-weighted FedAvg. Negative epoch counts are treated as 0.
        If no update carries positive weight (e.g. every client reported 0
        epochs), the coefficients fall back to a uniform average; without
        this the averaged model would be all zeros.
        """
        raw_weights = [
            max(0, update.local_epochs) / (1.0 + staleness)
            for update, staleness in zip(pending_updates, stalenesses)
        ]
        total_weight = sum(raw_weights)
        if total_weight <= 0.0:
            return [1.0 / len(raw_weights)] * len(raw_weights)
        return [weight / total_weight for weight in raw_weights]

    @staticmethod
    def _weighted_average(pending_updates: list["PendingUpdate"], coefficients: list[float]) -> dict[str, torch.Tensor]:
        """Combine the updates' tensors key by key using ``coefficients``.

        Key order follows the first update. Non-floating, non-complex tensors
        (integer/bool buffers) are rounded back to their original dtype, so
        the result keeps the clients' state-dict dtypes.
        """
        template = pending_updates[0].weights
        averaged: dict[str, torch.Tensor] = {}
        for key, reference in template.items():
            combined = sum(
                update.weights[key] * coefficient
                for update, coefficient in zip(pending_updates, coefficients)
            )
            if not (reference.is_floating_point() or reference.is_complex()):
                combined = combined.round().to(reference.dtype)
            averaged[key] = combined
        return averaged

    def _aggregate_locked(self, *, server_model, optimizer, reason: str) -> None:
        """Average the buffered updates into ``global_weights`` and advance the round.

        A no-op when the buffer is empty. After the new round is committed
        (and waiters are woken), the server checkpoint is written.

        Raises:
            Exception: Whatever the averaging step raised. Before re-raising,
                ``round_error`` is set and waiters are woken so they fail fast
                instead of waiting out the round timeout. The buffer is kept
                so a later call can retry.
        """
        pending_updates = list(self.client_weights_buffer.values())
        if not pending_updates:
            return
        aggregate_start = time.time()
        target_round = self.current_round + 1
        client_ids = [update.client_id for update in pending_updates]
        barrier_elapsed_s = (aggregate_start - self.round_start_time) if self.round_start_time is not None else 0.0
        print(
            f"[FED AVG] Round {target_round}: Aggregating {len(pending_updates)} models "
            f"from clients={client_ids} (active={len(self._active_clients)} "
            f"quorum={self._required_clients_locked()} reason={reason} "
            f"waited={barrier_elapsed_s:.2f}s)"
        )
        stalenesses = [max(0, self.current_round - update.base_round) for update in pending_updates]
        print(
            "[FED AVG] Staleness per client: "
            + ", ".join(f"C{u.client_id}={s}" for u, s in zip(pending_updates, stalenesses))
        )
        try:
            coefficients = self._aggregation_coefficients(pending_updates, stalenesses)
            new_global_weights = self._weighted_average(pending_updates, coefficients)
        except Exception as exc:
            self.round_error = f"Aggregation for round {target_round} failed: {exc!r}"
            print(f"[FED AVG] {self.round_error}")
            self.round_cond.notify_all()
            raise

        # Commit the round and wake waiters BEFORE checkpoint I/O, so that a
        # failing save cannot leave waiters un-notified for a round that has
        # already advanced (they run once the lock is released).
        self.global_weights = new_global_weights
        self.client_weights_buffer = {}
        self._quorum_reached_at = None
        self.current_round = target_round
        self.round_start_time = None
        self.round_error = None
        self.round_cond.notify_all()

        server_model_path = self._save_checkpoints_locked(
            server_model=server_model,
            optimizer=optimizer,
            client_ids=client_ids,
            reason=reason,
        )
        aggregate_elapsed_s = time.time() - aggregate_start
        print(
            f"[FED AVG] Successfully updated global model to Round {self.current_round} "
            f"(aggregate_time={aggregate_elapsed_s:.2f}s)"
        )
        print(f"[SERVER] Best ckpt: {self.session_id}/{Path(server_model_path).name}")

    def _save_checkpoints_locked(self, *, server_model, optimizer, client_ids: list[int], reason: str) -> str:
        """Write the rolling server checkpoint (plus a periodic copy when due).

        Returns the path of the checkpoint written to ``session_dir``.
        NOTE(review): only the server model/optimizer state is saved here;
        ``global_weights`` (the aggregated weights sent back to clients) is not
        part of the checkpoint -- confirm this is intended.
        """
        stamp = datetime.now().strftime("%Y%m%d%H%M%S")
        config_snapshot, snapshot_policy = build_config_snapshot()
        server_ckpt = {
            "round": self.current_round,
            "model_state_dict": server_model.state_dict(),
            "optimizer_state_dict": optimizer.state_dict(),
            "config": {
                "hidden_size": self.hidden_size,
            },
            "config_snapshot_policy": snapshot_policy,
            "config_ref": build_config_ref(),
            "session_id": self.session_id,
            "aggregated_client_ids": client_ids,
            "aggregation_reason": reason,
        }
        if config_snapshot is not None:
            server_ckpt["config_snapshot"] = config_snapshot
        server_model_path = os.path.join(
            self.session_dir,
            f"server_head_round_{self.current_round}_{stamp}.pth",
        )
        torch.save(server_ckpt, server_model_path)
        self._prune_old_checkpoints()
        # ckpt_interval == 0 disables periodic checkpoints (and avoids a
        # ZeroDivisionError in the modulo below).
        if self.ckpt_interval and self.current_round % self.ckpt_interval == 0:
            periodic_path = os.path.join(
                self.periodic_dir,
                f"server_round_{self.current_round:04d}.pth",
            )
            torch.save(server_ckpt, periodic_path)
            print(f"[SERVER] Periodic ckpt saved: round {self.current_round:04d}")
        return server_model_path

    def _prune_old_checkpoints(self) -> None:
        """Delete all but the highest-round ``server_head_round_*.pth`` in ``session_dir``.

        Best effort: a file that cannot be removed is logged and skipped
        rather than failing the round.
        """
        checkpoints = sorted(
            Path(self.session_dir).glob("server_head_round_*.pth"),
            key=self._checkpoint_sort_key,
        )
        for old_ckpt in checkpoints[:-1]:
            try:
                os.remove(old_ckpt)
            except OSError as exc:
                print(f"[SERVER] Could not remove old ckpt {old_ckpt.name}: {exc}")

    @staticmethod
    def _checkpoint_sort_key(path: Path) -> tuple[int, str]:
        """Sort key ``(round, filename)`` for ``server_head_round_<N>_<stamp>.pth``.

        Names whose fourth ``_``-separated part is not an integer sort first
        (round ``-1``).
        """
        stem_parts = path.stem.split("_")
        try:
            round_number = int(stem_parts[3])
        except (IndexError, ValueError):
            round_number = -1
        return round_number, path.name