import io import threading import torch from proto import fsl_pb2 from src.server.fedavg import FedAvgCoordinator def _client_state(fill_value: float) -> dict[str, torch.Tensor]: return { "lstm.weight_ih_l0": torch.full((2, 2), fill_value=fill_value), "lstm.weight_hh_l0": torch.full((2, 2), fill_value=fill_value), } def _decode_state_dict(payload: bytes) -> dict[str, torch.Tensor]: return torch.load(io.BytesIO(payload), weights_only=True, map_location="cpu") def test_partial_participation_and_stale_refresh(tmp_path): session_dir = tmp_path / "session" periodic_dir = session_dir / "periodic" session_dir.mkdir(parents=True) periodic_dir.mkdir(parents=True) server_model = torch.nn.Linear(2, 1) optimizer = torch.optim.Adam(server_model.parameters(), lr=0.001) coordinator = FedAvgCoordinator( num_clients=3, hidden_size=64, session_id="test-session", session_dir=str(session_dir), periodic_dir=str(periodic_dir), ckpt_interval=10, min_clients_per_round=2, round_timeout_sec=5.0, ) for client_id in (1, 2, 3): coordinator.register_client(client_id) responses: dict[int, fsl_pb2.SyncResponse] = {} def do_sync(client_id: int, fill_value: float) -> None: request = fsl_pb2.SyncRequest( client_id=client_id, client_weights=b"placeholder", base_round=0, local_epochs=1, ) responses[client_id] = coordinator.synchronize( request, local_weights=_client_state(fill_value), server_model=server_model, optimizer=optimizer, ) t1 = threading.Thread(target=do_sync, args=(1, 1.0), daemon=True) t1.start() do_sync(2, 3.0) t1.join(timeout=2.0) assert not t1.is_alive() res1 = responses[1] res2 = responses[2] assert res1.accepted is True assert res2.accepted is True assert res1.round_number == 1 assert res2.round_number == 1 aggregated = _decode_state_dict(res1.global_weights) assert torch.allclose(aggregated["lstm.weight_ih_l0"], torch.full((2, 2), 2.0)) assert torch.allclose(aggregated["lstm.weight_hh_l0"], torch.full((2, 2), 2.0)) stale_response = coordinator.synchronize( fsl_pb2.SyncRequest( client_id=3, client_weights=b"placeholder", base_round=0, local_epochs=3, ), local_weights=_client_state(5.0), server_model=server_model, optimizer=optimizer, ) assert stale_response.accepted is False assert stale_response.refresh_only is True assert stale_response.round_number == 1 refreshed = _decode_state_dict(stale_response.global_weights) assert torch.allclose(refreshed["lstm.weight_ih_l0"], torch.full((2, 2), 2.0))