import io import os import time from dataclasses import dataclass import torch from proto import fsl_pb2 from src.models.split_lstm import ClientLSTM @dataclass class SyncResult: client_model: ClientLSTM round_number: int accepted: bool applied_round: int refresh_only: bool status_message: str sync_bytes_sent: int = 0 # bytes sent (client weights upload) sync_bytes_recv: int = 0 # bytes received (global weights download) def fed_avg_sync( stub, client_id: int, client_model: ClientLSTM, *, model_round: int, local_epochs: int, ) -> SyncResult: """ Synchronize local client weights with the server and load the returned global model. """ buffer = io.BytesIO() torch.save(client_model.state_dict(), buffer) client_weights_bytes = buffer.getvalue() sync_req = fsl_pb2.SyncRequest( client_id=client_id, client_weights=client_weights_bytes, base_round=int(model_round), local_epochs=int(local_epochs), ) print( f"[CLIENT {client_id}] Waiting for global aggregation... " f"(base_round={int(model_round)} local_epochs={int(local_epochs)})" ) wait_start = time.time() sync_res = stub.Synchronize(sync_req, metadata=[("scenario-id", os.environ.get("SCENARIO_ID", ""))]) wait_elapsed_s = time.time() - wait_start round_number = int(getattr(sync_res, "round_number", model_round)) accepted = bool(getattr(sync_res, "accepted", False)) applied_round = int(getattr(sync_res, "applied_round", 0)) refresh_only = bool(getattr(sync_res, "refresh_only", False)) status_message = str(getattr(sync_res, "status_message", "")).strip() if sync_res.global_weights: global_buffer = io.BytesIO(sync_res.global_weights) global_state_dict = torch.load(global_buffer, weights_only=True, map_location="cpu") client_model.load_state_dict(global_state_dict) if accepted: print( f"[CLIENT {client_id}] Global model updated to Round {round_number} " f"(applied_round={applied_round} sync_wait={wait_elapsed_s:.2f}s)" ) elif refresh_only: print( f"[CLIENT {client_id}] Refreshed local model to Round {round_number} " f"without contributing this update (sync_wait={wait_elapsed_s:.2f}s)" ) else: print( f"[CLIENT {client_id}] Synchronization completed without model update " f"(sync_wait={wait_elapsed_s:.2f}s)" ) if status_message: print(f"[CLIENT {client_id}] Sync status: {status_message}") sync_bytes_sent = len(client_weights_bytes) sync_bytes_recv = len(sync_res.global_weights) if sync_res.global_weights else 0 return SyncResult( client_model=client_model, round_number=round_number, accepted=accepted, applied_round=applied_round, refresh_only=refresh_only, status_message=status_message, sync_bytes_sent=sync_bytes_sent, sync_bytes_recv=sync_bytes_recv, )