import os
import threading
from datetime import datetime

import grpc
import torch

from proto import fsl_pb2
from proto import fsl_pb2_grpc
from src.models.split_lstm import ServerHead
from src.server.bootstrap import run_server
from src.server.fedavg import FedAvgCoordinator
from src.server.forward_service import handle_forward_request
from src.server.reporting import ServerReporter
from src.server.scheduler import CompressionScheduler
from src.shared.common import cfg, project_root
from src.shared.runtime import resolve_device, set_global_seed
from src.shared.serialization import bytes_to_tensor


class FSLServerServicer(fsl_pb2_grpc.FSLServiceServicer):
    """gRPC servicer for federated split learning.

    Owns the server-side model head and its optimizer, hands out client ids,
    delegates forward passes to ``handle_forward_request`` and weight
    aggregation to ``FedAvgCoordinator``, and raises a shutdown flag once
    every registered client has reported completion.
    """

    def __init__(self):
        """Build model, coordinator, scheduler and reporter from ``cfg``.

        Every config section is read as ``cfg.get(name) or {}`` so that a
        section which is present but null behaves like a missing one.
        """
        model_cfg = cfg.get("model") or {}
        training_cfg = cfg.get("training") or {}
        federated_cfg = cfg.get("federated") or {}
        scheduler_cfg = cfg.get("scheduler") or {}
        compression_cfg = cfg.get("compression") or {}
        console_cfg = cfg.get("console") or {}
        profiler_cfg = cfg.get("profiler") or {}

        self.hidden_size = model_cfg.get("hidden_size", 64)
        self.server_head_width = model_cfg.get("server_head_width", 64)
        self.server_head_dropout = model_cfg.get("server_head_dropout", 0.1)
        lr = training_cfg.get("lr", 0.001)

        # Seed before the model is constructed.
        self.seed = set_global_seed(training_cfg.get("seed", 42), role="server")
        self.device = resolve_device()
        print(f"[SERVER] Using device: {self.device}")

        self.server_model = ServerHead(
            hidden_size=self.hidden_size,
            output_size=1,
            head_width=self.server_head_width,
            dropout=self.server_head_dropout,
        ).to(self.device)
        self.optimizer = torch.optim.Adam(self.server_model.parameters(), lr=lr)

        self.num_clients = federated_cfg.get("num_clients", 3)
        self.min_clients_per_round = federated_cfg.get("min_clients_per_round", 2)
        self.round_timeout_sec = federated_cfg.get("round_timeout_sec", 120.0)
        self.grace_period_sec = federated_cfg.get("grace_period_sec", 0.0)

        # Passed to handle_forward_request on every Forward call.
        self.sync_lock = threading.Lock()

        # Registration state, guarded by _reg_lock.
        self._next_client_id = 1
        self._reg_lock = threading.Lock()
        self._client_name_to_id: dict[str, int] = {}
        self._assigned_ids: set[int] = set()
        self._registered_clients: set[int] = set()

        # Completion state, guarded by _completion_lock.
        self._completion_lock = threading.Lock()
        self._completed_clients: set[int] = set()
        self._shutdown_event = threading.Event()

        self.session_id = os.environ.get("SESSION_ID") or datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
        self.scenario_id = os.environ.get("SCENARIO_ID")

        # Weights root: bestweights/SESSION_ID[/SCENARIO_ID]
        self.session_dir = self._scoped_dir("bestweights")
        self.periodic_dir = os.path.join(self.session_dir, "periodic")
        os.makedirs(self.session_dir, exist_ok=True)
        os.makedirs(self.periodic_dir, exist_ok=True)

        # Results/logs root: results/SESSION_ID[/SCENARIO_ID]
        self.results_dir = self._scoped_dir("results")
        os.makedirs(self.results_dir, exist_ok=True)

        self.ckpt_interval = training_cfg.get("checkpoint_interval", 1)
        print(f"[SERVER] Session ID: {self.session_id} | Scenario: {self.scenario_id or 'None'}")
        print(f"[SERVER] Weights dir: {self.session_dir}")
        print(f"[SERVER] Results dir: {self.results_dir}")
        print(f"[SERVER] Periodic checkpoint every {self.ckpt_interval} rounds -> {self.periodic_dir}")

        self.fedavg = FedAvgCoordinator(
            num_clients=self.num_clients,
            hidden_size=self.hidden_size,
            session_id=self.session_id,
            session_dir=self.session_dir,
            periodic_dir=self.periodic_dir,
            ckpt_interval=self.ckpt_interval,
            min_clients_per_round=self.min_clients_per_round,
            round_timeout_sec=self.round_timeout_sec,
            grace_period_sec=self.grace_period_sec,
            max_staleness=federated_cfg.get("max_staleness", 0),
        )

        self.scheduler = CompressionScheduler(
            default_mode=compression_cfg.get("mode", "float32"),
            enabled=scheduler_cfg.get("enabled", True),
            float16_threshold=scheduler_cfg.get("latency_threshold", 4.0),
            int8_threshold=scheduler_cfg.get("int8_latency_threshold", 10.0),
            base_rho=federated_cfg.get("rho", 1),
            min_rho=scheduler_cfg.get("min_rho", 1),
            max_rho=scheduler_cfg.get("max_rho", 20),
            rho_step=scheduler_cfg.get("rho_step", 1),
            topk_multiplier=scheduler_cfg.get("topk_multiplier", 1.5),
            latency_ema_alpha=scheduler_cfg.get("latency_ema_alpha", 0.2),
        )

        self.log_server_requests = console_cfg.get("log_server_requests", False)
        self.profiler_enabled = profiler_cfg.get("enabled", True)
        self.scheduler_enabled = scheduler_cfg.get("enabled", True)
        self.reporter = ServerReporter(session_id=self.session_id, session_dir=self.results_dir)

    def _scoped_dir(self, root_name):
        """Return ``project_root/root_name/session_id[/scenario_id]``.

        The scenario component is appended only when ``self.scenario_id`` is
        set. The directory is not created here.
        """
        parts = [project_root, root_name, self.session_id]
        if self.scenario_id:
            parts.append(self.scenario_id)
        return os.path.join(*parts)

    def _mismatched_scenario(self, context):
        """Return the caller's scenario id if it conflicts with the server's.

        Reads the ``scenario-id`` entry from the call's invocation metadata.
        Returns ``None`` when either side has no scenario id or when both
        agree, so the check is skipped unless both ends declare a scenario.
        """
        metadata = dict(context.invocation_metadata())
        client_scenario = metadata.get("scenario-id")
        if client_scenario and self.scenario_id and client_scenario != self.scenario_id:
            return client_scenario
        return None

    def _completion_response(self, acknowledged, completed):
        """Build a ``CompletionResponse`` carrying the configured client total."""
        return fsl_pb2.CompletionResponse(
            acknowledged=acknowledged,
            completed_clients=completed,
            total_clients=self.num_clients,
        )

    def GetInfo(self, request, context):
        """Return basic server status metadata."""
        return fsl_pb2.ServerInfo(
            scenario_id=self.scenario_id or "unknown",
            session_id=self.session_id or "unknown",
            current_round=self.fedavg.current_round,
        )

    def Register(self, request, context):
        """Assign a client id and return the shared session id.

        A caller whose ``scenario-id`` metadata conflicts with the server's is
        rejected with ``client_id=-1`` and ``session_id="ERROR_SCENARIO_MISMATCH"``.
        Otherwise the id is chosen in this order: the id already bound to the
        client name, the requested id if positive and unused, else the next
        free auto-incremented id.
        """
        # The scenario check touches no registration state, so it runs before
        # the lock is taken.
        client_scenario = self._mismatched_scenario(context)
        if client_scenario is not None:
            print(f"[SERVER] Rejecting registration from mismatched scenario: {client_scenario} != {self.scenario_id}")
            return fsl_pb2.RegisterResponse(client_id=-1, total_clients=0, session_id="ERROR_SCENARIO_MISMATCH")

        with self._reg_lock:
            client_name = request.client_name or f"client-{self._next_client_id}"
            requested_id = request.requested_client_id

            if client_name in self._client_name_to_id:
                # Re-registration under a known name keeps the original id.
                assigned_id = self._client_name_to_id[client_name]
            else:
                if requested_id > 0 and requested_id not in self._assigned_ids:
                    assigned_id = requested_id
                else:
                    # Skip ids that were claimed explicitly by other clients.
                    while self._next_client_id in self._assigned_ids:
                        self._next_client_id += 1
                    assigned_id = self._next_client_id
                    self._next_client_id += 1
                self._client_name_to_id[client_name] = assigned_id
                self._assigned_ids.add(assigned_id)

            # Keep the auto-increment counter ahead of any explicitly chosen id.
            if assigned_id >= self._next_client_id:
                self._next_client_id = assigned_id + 1

            print(
                f"[SERVER] Client registered - name: {client_name} | requested_id: {requested_id or 'auto'} "
                f"| assigned_id: {assigned_id} | session: {self.session_id}"
            )
            self._registered_clients.add(assigned_id)
            self.fedavg.register_client(assigned_id)
            return fsl_pb2.RegisterResponse(
                client_id=assigned_id,
                total_clients=self.num_clients,
                session_id=self.session_id,
            )

    def Forward(self, request, context):
        """Handle one forward request from a client.

        Asks the scheduler for this client's compression mode and rho, then
        delegates to ``handle_forward_request``. Any exception is reported to
        the client as ``success=False`` with the error text rather than being
        raised through gRPC.
        """
        client_scenario = self._mismatched_scenario(context)
        if client_scenario is not None:
            return fsl_pb2.ForwardResponse(
                status_message=f"Scenario mismatch: {client_scenario} != {self.scenario_id}",
                success=False,
            )

        try:
            client_id = getattr(request, "client_id", -1)
            reported_latency = getattr(request, "latency_ms", 0.0)
            assigned_compression, assigned_rho = self.scheduler.assign(client_id, reported_latency)
            result = handle_forward_request(
                request,
                hidden_size=self.hidden_size,
                device=self.device,
                server_model=self.server_model,
                optimizer=self.optimizer,
                sync_lock=self.sync_lock,
                current_round=self.fedavg.current_round,
                assigned_compression=assigned_compression,
                assigned_rho=assigned_rho,
                profiler_enabled=self.profiler_enabled,
                scheduler_enabled=self.scheduler_enabled,
            )
            self.reporter.record(result.log_entry)
            if self.log_server_requests:
                print(result.monitor_message)
            return result.response
        except Exception as e:
            # RPC boundary: convert every failure into an error response.
            print(f"[SERVER ERROR] Processing failed: {e}")
            return fsl_pb2.ForwardResponse(
                status_message=f"Error: {e}",
                success=False,
            )

    def flush_logs(self):
        """Flush the reporter's buffered log entries."""
        self.reporter.flush()

    def Synchronize(self, request, context):
        """Aggregate client weights and return the latest global model.

        On a scenario mismatch the call is answered with an empty
        ``SyncResponse`` and status ``INVALID_ARGUMENT``; on any other failure
        with an empty ``SyncResponse`` and status ``INTERNAL``.
        """
        client_scenario = self._mismatched_scenario(context)
        if client_scenario is not None:
            context.set_details(f"Scenario mismatch: {client_scenario} != {self.scenario_id}")
            context.set_code(grpc.StatusCode.INVALID_ARGUMENT)
            return fsl_pb2.SyncResponse()

        try:
            local_weights = bytes_to_tensor(request.client_weights)
            return self.fedavg.synchronize(
                request,
                local_weights=local_weights,
                server_model=self.server_model,
                optimizer=self.optimizer,
            )
        except Exception as e:
            # RPC boundary: surface the failure through the gRPC status.
            print(f"[FED AVG ERROR] Synchronization failed: {e}")
            context.set_details(str(e))
            context.set_code(grpc.StatusCode.INTERNAL)
            return fsl_pb2.SyncResponse()

    def NotifyCompletion(self, request, context):
        """Record client completion and emit a server-side all-finished signal.

        Completions from another session or from an unregistered client are
        answered with ``acknowledged=False``. A repeated completion from the
        same client is acknowledged without being counted again. When the
        number of completed clients reaches the number of registered clients,
        logs are flushed and the shutdown event is set.
        """
        if request.session_id != self.session_id:
            print(
                f"[SERVER] Ignoring completion from client {request.client_id}: "
                f"session mismatch ({request.session_id} != {self.session_id})"
            )
            # NOTE(review): this count is read without _completion_lock, as in
            # the original; it is informational only.
            return self._completion_response(False, len(self._completed_clients))

        with self._completion_lock:
            if request.client_id not in self._registered_clients:
                print(
                    f"[SERVER] Ignoring completion from unregistered client {request.client_id} "
                    f"(session={self.session_id})"
                )
                return self._completion_response(False, len(self._completed_clients))

            if request.client_id in self._completed_clients:
                completed = len(self._completed_clients)
                print(
                    f"[SERVER] Duplicate completion ignored for client {request.client_id} "
                    f"(completed={completed}/{self.num_clients})"
                )
                return self._completion_response(True, completed)

            self._completed_clients.add(request.client_id)
            completed = len(self._completed_clients)
            self.fedavg.mark_client_completed(
                request.client_id,
                server_model=self.server_model,
                optimizer=self.optimizer,
            )
            print(
                f"[SERVER] Client {request.client_id} finished training | "
                f"epochs={request.completed_epochs} steps={request.total_steps} | "
                f"completed={completed}/{self.num_clients}"
            )

            # Shutdown if everyone who showed up is finished
            registered = len(self._registered_clients)
            if completed >= registered and registered > 0:
                print(f"[SERVER] ALL REGISTERED CLIENTS FINISHED ({completed}/{registered}) | session={self.session_id}")
                self.flush_logs()
                self._shutdown_event.set()

            return self._completion_response(True, completed)

    def should_shutdown(self) -> bool:
        """Signal bootstrap loop to stop server once all clients are done."""
        return self._shutdown_event.is_set()


def serve():
    """Cap torch's thread count from config and run the gRPC server.

    ``training.torch_num_threads`` defaults to 1 and is clamped to at least 1.
    """
    training_cfg = cfg.get("training") or {}
    torch.set_num_threads(max(1, int(training_cfg.get("torch_num_threads", 1))))
    run_server(FSLServerServicer())


if __name__ == "__main__":
    serve()