"""FedAvg round coordinator: buffers client updates and aggregates them per round."""

import copy
import os
import threading
import time
from dataclasses import dataclass
from datetime import datetime
from pathlib import Path
from typing import Any

import torch

from proto import fsl_pb2
from src.shared.config_artifacts import build_config_ref, build_config_snapshot
from src.shared.serialization import tensor_to_bytes
from src.shared.common import cfg


@dataclass
class PendingUpdate:
    """One client's model update, buffered until the next aggregation."""

    client_id: int
    weights: dict[str, torch.Tensor]
    base_round: int  # round the client reported in request.base_round
    local_epochs: int  # numerator of the client's aggregation weight
    arrived_at: float  # time.time() when the update was buffered


class FedAvgCoordinator:
    """Collect client state dicts and average them into a global model.

    All mutable state is guarded by ``self.lock``; ``self.round_cond`` is a
    condition variable on that same lock. Methods whose name ends in
    ``_locked`` expect the caller to already hold the lock.
    """

    def __init__(
        self,
        *,
        num_clients: int,
        hidden_size: int,
        session_id: str,
        session_dir: str,
        periodic_dir: str,
        ckpt_interval: int,
        min_clients_per_round: int = 2,
        round_timeout_sec: float = 15.0,
        grace_period_sec: float = 0.0,
        max_staleness: int = 0,
    ):
        """Store configuration and initialise empty round state.

        Args:
            num_clients: Stored as-is; not otherwise used by this class.
            hidden_size: Recorded in each checkpoint's ``config`` entry.
            session_id: Recorded in checkpoints and log lines.
            session_dir: Directory for the per-round server checkpoint.
            periodic_dir: Directory for periodic checkpoints.
            ckpt_interval: A periodic checkpoint is written when the round
                number is a multiple of this value; 0 disables them.
            min_clients_per_round: Quorum size, clamped to at least 1.
            round_timeout_sec: Round window in seconds, clamped to >= 0.1.
            grace_period_sec: Extra wait after quorum, clamped to >= 0.
            max_staleness: Maximum accepted ``current_round - base_round``,
                clamped to >= 0.
        """
        self.num_clients = num_clients
        self.hidden_size = hidden_size
        self.session_id = session_id
        self.session_dir = session_dir
        self.periodic_dir = periodic_dir
        self.ckpt_interval = ckpt_interval
        self.min_clients_per_round = max(1, int(min_clients_per_round))
        self.round_timeout_sec = max(0.1, float(round_timeout_sec))
        self.grace_period_sec = max(0.0, float(grace_period_sec))
        # max_staleness=0: strict synchronous FedAvg (legacy default).
        # max_staleness=N: accept updates up to N rounds stale; weight by
        # local_epochs / (1 + staleness) to discount older contributions.
        self.max_staleness = max(0, int(max_staleness))
        self._quorum_reached_at: float | None = None
        self._startup_deadline: float | None = None

        self.client_weights_buffer: dict[int, PendingUpdate] = {}
        self.global_weights = None
        self.current_round = 0
        self.lock = threading.Lock()
        self.round_cond = threading.Condition(self.lock)
        self.round_start_time: float | None = None
        self.round_error: str | None = None
        self._expected_schema: dict[str, tuple[tuple[int, ...], torch.dtype]] | None = None
        self._active_clients: set[int] = set()
        self._completed_clients: set[int] = set()

    @staticmethod
    def _validate_weights_object(local_weights: Any) -> dict[str, torch.Tensor]:
        """Return a shallow copy of *local_weights* after type-checking it.

        Raises:
            ValueError: If the object is not a non-empty dict, a key is not
                a string, or a value is not a ``torch.Tensor``.
        """
        if not isinstance(local_weights, dict) or not local_weights:
            raise ValueError("Client weights must be a non-empty state_dict-like dict.")
        normalized: dict[str, torch.Tensor] = {}
        for key, value in local_weights.items():
            if not isinstance(key, str):
                raise ValueError("State dict keys must be strings.")
            if not isinstance(value, torch.Tensor):
                raise ValueError(f"State dict value for key '{key}' is not a torch.Tensor.")
            normalized[key] = value
        return normalized

    @staticmethod
    def _build_schema(weights: dict[str, torch.Tensor]) -> dict[str, tuple[tuple[int, ...], torch.dtype]]:
        """Map each key to its ``(shape, dtype)`` pair."""
        return {key: (tuple(tensor.shape), tensor.dtype) for key, tensor in weights.items()}

    def _validate_against_schema(self, weights: dict[str, torch.Tensor]) -> None:
        """Check *weights* against the recorded schema.

        The first state dict seen defines the schema; later ones must match
        its keys, shapes and dtypes exactly.

        Raises:
            ValueError: On any key, shape or dtype mismatch.
        """
        schema = self._expected_schema
        if schema is None:
            self._expected_schema = self._build_schema(weights)
            return

        current_keys = set(weights.keys())
        expected_keys = set(schema.keys())
        if current_keys != expected_keys:
            missing = sorted(expected_keys - current_keys)
            extra = sorted(current_keys - expected_keys)
            raise ValueError(
                f"State dict key mismatch. missing={missing[:5]} extra={extra[:5]}"
            )

        for key, tensor in weights.items():
            expected_shape, expected_dtype = schema[key]
            if tuple(tensor.shape) != expected_shape:
                raise ValueError(
                    f"Tensor shape mismatch for '{key}': got {tuple(tensor.shape)} expected {expected_shape}"
                )
            if tensor.dtype != expected_dtype:
                raise ValueError(
                    f"Tensor dtype mismatch for '{key}': got {tensor.dtype} expected {expected_dtype}"
                )

    def register_client(self, client_id: int) -> None:
        """Add *client_id* to the active set unless it already completed."""
        # Normalise first so the membership test and the insert use the same
        # value (both sets only ever hold ints).
        client_id = int(client_id)
        with self.round_cond:
            if client_id not in self._completed_clients:
                self._active_clients.add(client_id)
            self.round_cond.notify_all()

    def mark_client_completed(self, client_id: int, *, server_model, optimizer) -> None:
        """Move a client from the active set to the completed set.

        Shrinking the active set can lower the quorum, so an aggregation is
        triggered immediately if the buffered updates now satisfy it.
        """
        with self.round_cond:
            client_id = int(client_id)
            self._completed_clients.add(client_id)
            self._active_clients.discard(client_id)
            print(
                f"[FED AVG] Client {client_id} marked complete. "
                f"active_clients={len(self._active_clients)}"
            )
            if self.client_weights_buffer and self._has_quorum_locked():
                self._aggregate_locked(
                    server_model=server_model,
                    optimizer=optimizer,
                    reason="active-set-updated",
                )
            self.round_cond.notify_all()

    def synchronize(self, request, *, local_weights, server_model, optimizer) -> fsl_pb2.SyncResponse:
        """Submit a client update and block until its round is aggregated.

        Args:
            request: Object exposing ``client_id`` and, optionally,
                ``base_round`` and ``local_epochs``.
            local_weights: The client's state-dict-like mapping.
            server_model: Model whose ``state_dict()`` is checkpointed.
            optimizer: Optimizer whose ``state_dict()`` is checkpointed.

        Returns:
            A ``SyncResponse``. ``accepted`` is False when the configured
            number of rounds has been reached, the update is too stale, or
            the client is ahead of the server.

        Raises:
            ValueError: If ``local_weights`` fails validation.
            RuntimeError: If ``round_error`` is set while waiting.
            TimeoutError: If no aggregation completed for this update.
        """
        client_id = int(request.client_id)
        local_epochs = int(getattr(request, "local_epochs", 0))

        with self.round_cond:
            # Read current_round under the lock so the default is consistent
            # with the staleness checks below.
            base_round = int(getattr(request, "base_round", self.current_round))

            if client_id not in self._completed_clients:
                self._active_clients.add(client_id)

            # --- Safety Catch: Prevent rounds exceeding config ---
            num_rounds = cfg.get("training", {}).get("num_rounds", 30)
            if self.current_round >= num_rounds:
                print(f"[FED AVG] Target rounds ({num_rounds}) reached. Rejecting update from Client:{client_id}")
                return self._build_sync_response_locked(
                    accepted=False,
                    applied_round=self.current_round,
                    status_message="FINISHED",
                    refresh_only=True,
                )

            if base_round < self.current_round:
                staleness = self.current_round - base_round
                if staleness > self.max_staleness:
                    print(
                        f"[FED AVG] Rejecting overly stale update from Client:{client_id} "
                        f"(staleness={staleness} > max_staleness={self.max_staleness})"
                    )
                    return self._build_sync_response_locked(
                        accepted=False,
                        applied_round=0,
                        refresh_only=True,
                        status_message=(
                            f"Stale update rejected: staleness={staleness} "
                            f"exceeds max_staleness={self.max_staleness}."
                        ),
                    )
                print(
                    f"[FED AVG] Accepting stale update from Client:{client_id} "
                    f"(staleness={staleness}, max_staleness={self.max_staleness})"
                )

            if base_round > self.current_round:
                print(
                    f"[FED AVG] Client:{client_id} is ahead of server state "
                    f"(base_round={base_round}, current_round={self.current_round})"
                )
                return self._build_sync_response_locked(
                    accepted=False,
                    applied_round=0,
                    refresh_only=self.global_weights is not None,
                    status_message=(
                        f"Client is ahead of server state: client base_round={base_round}, "
                        f"server current_round={self.current_round}."
                    ),
                )

            # Before the first round, wait until enough clients have connected.
            # Deadline is measured from when the FIRST client enters this wait,
            # not from server startup — the server may have been up for a long time
            # before clients are deployed via Ansible.
            if self.current_round == 0:
                if self._startup_deadline is None:
                    self._startup_deadline = time.time() + self.round_timeout_sec
                    print(
                        f"[FED AVG] Startup wait begun: waiting up to {self.round_timeout_sec:.0f}s "
                        f"for {self.min_clients_per_round} clients "
                        f"(currently {len(self._active_clients)} registered)."
                    )
                while len(self._active_clients) < self.min_clients_per_round:
                    remaining = self._startup_deadline - time.time()
                    if remaining <= 0:
                        print(
                            f"[FED AVG] Startup wait timed out: only "
                            f"{len(self._active_clients)}/{self.min_clients_per_round} "
                            f"clients connected. Proceeding with current active set."
                        )
                        break
                    self.round_cond.wait(timeout=min(remaining, 5.0))

            # Validate before touching round state so a rejected payload
            # never starts the round timer.
            validated_weights = self._validate_weights_object(local_weights)
            self._validate_against_schema(validated_weights)

            target_round = self.current_round + 1
            now = time.time()
            if not self.client_weights_buffer:
                # First update of a new round starts the round window.
                self.round_start_time = now
                self.round_error = None

            self.client_weights_buffer[client_id] = PendingUpdate(
                client_id=client_id,
                weights=validated_weights,
                base_round=base_round,
                local_epochs=local_epochs,
                arrived_at=now,
            )
            barrier_elapsed_s = (now - self.round_start_time) if self.round_start_time is not None else 0.0
            required = self._required_clients_locked()
            print(
                f"[FED AVG] Received weights from Client:{client_id}. "
                f"Buffer size: {len(self.client_weights_buffer)}/{required} "
                f"| active_clients={len(self._active_clients)} "
                f"| base_round={base_round} local_epochs={local_epochs} "
                f"| barrier_elapsed={barrier_elapsed_s:.2f}s"
            )

            self._maybe_aggregate_locked(
                server_model=server_model,
                optimizer=optimizer,
                reason="quorum-reached",
            )

            while self.current_round < target_round and not self.round_error:
                remaining = self._remaining_window_locked()
                if remaining <= 0:
                    # Window expired: aggregate whatever is buffered.
                    if self.client_weights_buffer and self.current_round < target_round:
                        self._aggregate_locked(
                            server_model=server_model,
                            optimizer=optimizer,
                            reason="timeout",
                        )
                    break
                self.round_cond.wait(timeout=remaining)
                # Re-check after waking: quorum or the grace period may now
                # be satisfied even though nobody else aggregated.
                if self.current_round < target_round and not self.round_error:
                    self._maybe_aggregate_locked(
                        server_model=server_model,
                        optimizer=optimizer,
                        reason="quorum-reached",
                    )

            if self.round_error:
                raise RuntimeError(self.round_error)
            if self.current_round < target_round or self.global_weights is None:
                raise TimeoutError("Timeout waiting for global model aggregation.")

            return self._build_sync_response_locked(
                accepted=True,
                applied_round=self.current_round,
                refresh_only=False,
                status_message=(
                    f"Accepted into aggregation round {self.current_round} "
                    f"(active_clients={len(self._active_clients)})."
                ),
            )

    def _active_client_count_locked(self) -> int:
        """Return the number of currently active clients."""
        return len(self._active_clients)

    def _required_clients_locked(self) -> int:
        """Return the quorum: ``min_clients_per_round`` capped by the active count, at least 1."""
        active_count = self._active_client_count_locked()
        if active_count <= 0:
            return 1
        return max(1, min(self.min_clients_per_round, active_count))

    def _remaining_window_locked(self) -> float:
        """Return how many seconds a waiter should sleep before re-checking.

        With no round in progress this is the full round timeout. Otherwise
        it is the time left in the round window, shortened to the time left
        in the grace period once quorum has been reached.
        """
        if self.round_start_time is None:
            return self.round_timeout_sec
        elapsed = time.time() - self.round_start_time
        timeout_remaining = max(0.0, self.round_timeout_sec - elapsed)
        # If quorum is reached and grace period is active, wake up sooner to check it.
        if self._quorum_reached_at is not None and self.grace_period_sec > 0.0:
            grace_remaining = max(0.0, self.grace_period_sec - (time.time() - self._quorum_reached_at))
            return min(timeout_remaining, grace_remaining)
        return timeout_remaining

    def _has_quorum_locked(self) -> bool:
        """True if the buffer holds at least the required number of updates."""
        return len(self.client_weights_buffer) >= self._required_clients_locked()

    def _grace_period_elapsed_locked(self) -> bool:
        """True if grace period has passed since quorum was first reached."""
        if self.grace_period_sec <= 0.0:
            return True
        if self._quorum_reached_at is None:
            return False
        return (time.time() - self._quorum_reached_at) >= self.grace_period_sec

    def _maybe_aggregate_locked(self, *, server_model, optimizer, reason: str) -> None:
        """Aggregate if quorum is met and the grace period (if any) is over.

        The first call that observes quorum records the time, which starts
        the grace period.
        """
        if not (self.client_weights_buffer and self._has_quorum_locked()):
            return
        if self._quorum_reached_at is None:
            self._quorum_reached_at = time.time()
            if self.grace_period_sec > 0.0:
                print(
                    f"[FED AVG] Quorum reached ({len(self.client_weights_buffer)}/"
                    f"{self._required_clients_locked()}), "
                    f"waiting grace period {self.grace_period_sec:.0f}s for stragglers..."
                )
        if self._grace_period_elapsed_locked():
            self._aggregate_locked(server_model=server_model, optimizer=optimizer, reason=reason)

    def _build_sync_response_locked(
        self,
        *,
        accepted: bool,
        applied_round: int,
        refresh_only: bool,
        status_message: str,
    ) -> fsl_pb2.SyncResponse:
        """Build a ``SyncResponse`` carrying the current global weights.

        The weights field is empty bytes while no aggregation has happened.
        """
        global_weights_bytes = tensor_to_bytes(self.global_weights) if self.global_weights is not None else b""
        return fsl_pb2.SyncResponse(
            global_weights=global_weights_bytes,
            round_number=int(self.current_round),
            accepted=bool(accepted),
            applied_round=int(applied_round),
            refresh_only=bool(refresh_only),
            status_message=status_message,
        )

    @staticmethod
    def _weighted_average(
        updates: list[PendingUpdate],
        coefficients: list[float],
    ) -> dict[str, torch.Tensor]:
        """Return the coefficient-weighted sum of the updates' tensors.

        Each result keeps the dtype of the first update's tensor: multiplying
        an integer or bool tensor by a float promotes it to floating point,
        so those are rounded and cast back.
        """
        averaged: dict[str, torch.Tensor] = {}
        for key, reference in updates[0].weights.items():
            mixed = sum(u.weights[key] * c for u, c in zip(updates, coefficients))
            if not (reference.is_floating_point() or reference.is_complex()):
                mixed = torch.round(mixed)
            averaged[key] = mixed.to(reference.dtype)
        return averaged

    def _save_round_checkpoints_locked(
        self,
        *,
        server_model,
        optimizer,
        client_ids: list[int],
        reason: str,
    ) -> str:
        """Write the round checkpoint(s) and return the per-round path.

        Older ``server_head_round_*.pth`` files in the session directory are
        removed on a best-effort basis so that only the newest remains.
        """
        stamp = datetime.now().strftime("%Y%m%d%H%M%S")
        config_snapshot, snapshot_policy = build_config_snapshot()
        server_ckpt = {
            "round": self.current_round,
            "model_state_dict": server_model.state_dict(),
            "optimizer_state_dict": optimizer.state_dict(),
            "config": {
                "hidden_size": self.hidden_size,
            },
            "config_snapshot_policy": snapshot_policy,
            "config_ref": build_config_ref(),
            "session_id": self.session_id,
            "aggregated_client_ids": client_ids,
            "aggregation_reason": reason,
        }
        if config_snapshot is not None:
            server_ckpt["config_snapshot"] = config_snapshot

        server_model_path = os.path.join(
            self.session_dir,
            f"server_head_round_{self.current_round}_{stamp}.pth",
        )
        torch.save(server_ckpt, server_model_path)

        old_checkpoints = sorted(
            Path(self.session_dir).glob("server_head_round_*.pth"),
            key=self._checkpoint_sort_key,
        )
        for old_ckpt in old_checkpoints[:-1]:
            try:
                os.remove(old_ckpt)
            except OSError as exc:
                # Cleanup is best-effort; a leftover file must not fail the round.
                print(f"[SERVER] Could not remove old checkpoint {old_ckpt}: {exc}")

        # A zero interval disables periodic checkpoints (and would otherwise
        # raise ZeroDivisionError in the modulo).
        if self.ckpt_interval and self.current_round % self.ckpt_interval == 0:
            periodic_path = os.path.join(
                self.periodic_dir,
                f"server_round_{self.current_round:04d}.pth",
            )
            torch.save(server_ckpt, periodic_path)
            print(f"[SERVER] Periodic ckpt saved: round {self.current_round:04d}")

        return server_model_path

    def _aggregate_locked(self, *, server_model, optimizer, reason: str) -> None:
        """Average the buffered updates into ``global_weights`` and advance the round.

        Clears the buffer, increments ``current_round``, writes checkpoints
        and wakes every waiter. Does nothing if the buffer is empty.
        """
        pending_updates = list(self.client_weights_buffer.values())
        if not pending_updates:
            return

        aggregate_start = time.time()
        client_ids = [update.client_id for update in pending_updates]
        barrier_elapsed_s = (time.time() - self.round_start_time) if self.round_start_time is not None else 0.0
        print(
            f"[FED AVG] Round {self.current_round + 1}: Aggregating {len(pending_updates)} models "
            f"from clients={client_ids} (active={len(self._active_clients)} "
            f"quorum={self._required_clients_locked()} reason={reason} "
            f"waited={barrier_elapsed_s:.2f}s)"
        )

        # Staleness-discounted weighting: w_i = local_epochs_i / (1 + staleness_i)
        # When max_staleness=0 all staleness values are 0, reducing to plain
        # local-epoch-weighted FedAvg (identical to prior behaviour).
        stalenesses = [max(0, self.current_round - u.base_round) for u in pending_updates]
        raw_weights = [
            max(0, u.local_epochs) / (1.0 + s)
            for u, s in zip(pending_updates, stalenesses)
        ]
        total_weight = sum(raw_weights)
        if total_weight <= 0:
            # No client reported a positive epoch count. Fall back to a plain
            # mean; scaling only the denominator would zero the whole model.
            raw_weights = [1.0] * len(pending_updates)
            total_weight = float(len(pending_updates))
        coefficients = [w / total_weight for w in raw_weights]

        print(
            f"[FED AVG] Staleness per client: "
            + ", ".join(f"C{u.client_id}={s}" for u, s in zip(pending_updates, stalenesses))
        )

        self.global_weights = self._weighted_average(pending_updates, coefficients)

        self.client_weights_buffer = {}
        self._quorum_reached_at = None
        self.current_round += 1

        try:
            server_model_path = self._save_round_checkpoints_locked(
                server_model=server_model,
                optimizer=optimizer,
                client_ids=client_ids,
                reason=reason,
            )
        finally:
            # The round has already advanced, so waiters must be released and
            # the round timer reset even if writing a checkpoint failed.
            self.round_start_time = None
            self.round_error = None
            self.round_cond.notify_all()

        aggregate_elapsed_s = time.time() - aggregate_start
        print(
            f"[FED AVG] Successfully updated global model to Round {self.current_round} "
            f"(aggregate_time={aggregate_elapsed_s:.2f}s)"
        )
        print(f"[SERVER] Best ckpt: {self.session_id}/{Path(server_model_path).name}")

    @staticmethod
    def _checkpoint_sort_key(path: Path) -> tuple[int, str]:
        """Sort key ``(round_number, filename)`` for ``server_head_round_<N>_<stamp>`` files.

        Files whose fourth underscore-separated field is not an integer get
        round -1 and therefore sort first.
        """
        stem_parts = path.stem.split("_")
        try:
            round_number = int(stem_parts[3])
        except (IndexError, ValueError):
            round_number = -1
        return round_number, path.name