# csc8114/code/src/nodes/server_node.py
import os
import threading
from datetime import datetime

import grpc
import torch
from proto import fsl_pb2
from proto import fsl_pb2_grpc
from src.models.split_lstm import ServerHead
from src.server.bootstrap import run_server
from src.server.fedavg import FedAvgCoordinator
from src.server.forward_service import handle_forward_request
from src.server.reporting import ServerReporter
from src.server.scheduler import CompressionScheduler

from src.shared.common import cfg, project_root
from src.shared.runtime import resolve_device, set_global_seed
from src.shared.serialization import bytes_to_tensor


class FSLServerServicer(fsl_pb2_grpc.FSLServiceServicer):
    """gRPC servicer for federated split learning.

    Owns the server-side model head and its optimizer, tracks which clients
    have registered and which have reported completion, and delegates the
    heavy lifting to the collaborators built in ``__init__``:
    ``FedAvgCoordinator`` (synchronization), ``CompressionScheduler``
    (per-client compression / rho assignment) and ``ServerReporter`` (logs).

    Locking:
        * ``_reg_lock`` guards the registration tables
          (``_client_name_to_id``, ``_assigned_ids``, ``_next_client_id``,
          ``_registered_clients``).
        * ``_completion_lock`` guards ``_completed_clients``.
        * ``sync_lock`` is handed to ``handle_forward_request``.
      ``_reg_lock`` may be taken while no other lock is held, or nested
      inside nothing else; ``Register`` never takes ``_completion_lock``.
    """

    def __init__(self):
        """Build model, registration state, output directories and helpers from ``cfg``."""
        # ``or {}`` (rather than a ``.get`` default) also covers config
        # sections that are present but empty, i.e. map to ``None``.
        model_cfg = cfg.get("model") or {}
        training_cfg = cfg.get("training") or {}
        federated_cfg = cfg.get("federated") or {}
        scheduler_cfg = cfg.get("scheduler") or {}
        compression_cfg = cfg.get("compression") or {}
        console_cfg = cfg.get("console") or {}
        profiler_cfg = cfg.get("profiler") or {}

        self.hidden_size = model_cfg.get("hidden_size", 64)
        self.server_head_width = model_cfg.get("server_head_width", 64)
        self.server_head_dropout = model_cfg.get("server_head_dropout", 0.1)
        lr = training_cfg.get("lr", 0.001)
        # Seeding happens before the model is constructed; keep this order
        # (presumably so weight initialisation is reproducible - confirm
        # against set_global_seed).
        self.seed = set_global_seed(training_cfg.get("seed", 42), role="server")
        self.device = resolve_device()
        print(f"[SERVER] Using device: {self.device}")

        self.server_model = ServerHead(
            hidden_size=self.hidden_size,
            output_size=1,
            head_width=self.server_head_width,
            dropout=self.server_head_dropout,
        ).to(self.device)
        self.optimizer = torch.optim.Adam(self.server_model.parameters(), lr=lr)

        self.num_clients = federated_cfg.get("num_clients", 3)
        self.min_clients_per_round = federated_cfg.get("min_clients_per_round", 2)
        self.round_timeout_sec = federated_cfg.get("round_timeout_sec", 120.0)
        self.grace_period_sec = federated_cfg.get("grace_period_sec", 0.0)
        self.sync_lock = threading.Lock()

        # Registration bookkeeping (guarded by _reg_lock).
        self._next_client_id = 1
        self._reg_lock = threading.Lock()
        self._client_name_to_id: dict[str, int] = {}
        self._assigned_ids: set[int] = set()
        self._registered_clients: set[int] = set()
        # Completion bookkeeping (guarded by _completion_lock).
        self._completion_lock = threading.Lock()
        self._completed_clients: set[int] = set()
        self._shutdown_event = threading.Event()

        # Fall back to a timestamp when SESSION_ID is unset or empty.
        self.session_id = os.environ.get("SESSION_ID") or datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
        self.scenario_id = os.environ.get("SCENARIO_ID")

        # Weights root: bestweights/<session_id>[/<scenario_id>]
        self.session_dir = self._session_path("bestweights")
        self.periodic_dir = os.path.join(self.session_dir, "periodic")
        # Results/logs root: results/<session_id>[/<scenario_id>]
        self.results_dir = self._session_path("results")
        for directory in (self.session_dir, self.periodic_dir, self.results_dir):
            os.makedirs(directory, exist_ok=True)

        self.ckpt_interval = training_cfg.get("checkpoint_interval", 1)
        print(f"[SERVER] Session ID: {self.session_id} | Scenario: {self.scenario_id or 'None'}")
        print(f"[SERVER] Weights dir: {self.session_dir}")
        print(f"[SERVER] Results dir: {self.results_dir}")
        print(f"[SERVER] Periodic checkpoint every {self.ckpt_interval} rounds -> {self.periodic_dir}")
        self.fedavg = FedAvgCoordinator(
            num_clients=self.num_clients,
            hidden_size=self.hidden_size,
            session_id=self.session_id,
            session_dir=self.session_dir,
            periodic_dir=self.periodic_dir,
            ckpt_interval=self.ckpt_interval,
            min_clients_per_round=self.min_clients_per_round,
            round_timeout_sec=self.round_timeout_sec,
            grace_period_sec=self.grace_period_sec,
            max_staleness=federated_cfg.get("max_staleness", 0),
        )

        self.scheduler_enabled = scheduler_cfg.get("enabled", True)
        self.scheduler = CompressionScheduler(
            default_mode=compression_cfg.get("mode", "float32"),
            enabled=self.scheduler_enabled,
            float16_threshold=scheduler_cfg.get("latency_threshold", 4.0),
            int8_threshold=scheduler_cfg.get("int8_latency_threshold", 10.0),
            base_rho=federated_cfg.get("rho", 1),
            min_rho=scheduler_cfg.get("min_rho", 1),
            max_rho=scheduler_cfg.get("max_rho", 20),
            rho_step=scheduler_cfg.get("rho_step", 1),
            topk_multiplier=scheduler_cfg.get("topk_multiplier", 1.5),
            latency_ema_alpha=scheduler_cfg.get("latency_ema_alpha", 0.2),
        )
        self.log_server_requests = console_cfg.get("log_server_requests", False)
        self.profiler_enabled = profiler_cfg.get("enabled", True)
        self.reporter = ServerReporter(session_id=self.session_id, session_dir=self.results_dir)

    def _session_path(self, root_name):
        """Return ``<project_root>/<root_name>/<session_id>[/<scenario_id>]``."""
        parts = [project_root, root_name, self.session_id]
        if self.scenario_id:
            parts.append(self.scenario_id)
        return os.path.join(*parts)

    def _mismatched_scenario(self, context):
        """Return the caller's ``scenario-id`` metadata if it conflicts with ours, else ``None``.

        A missing/empty value on either side is treated as "no conflict".
        """
        client_scenario = dict(context.invocation_metadata()).get("scenario-id")
        if client_scenario and self.scenario_id and client_scenario != self.scenario_id:
            return client_scenario
        return None

    def _completion_response(self, acknowledged, completed):
        """Build a ``CompletionResponse`` carrying the current completion count."""
        return fsl_pb2.CompletionResponse(
            acknowledged=acknowledged,
            completed_clients=completed,
            total_clients=self.num_clients,
        )

    def GetInfo(self, request, context):
        """Return basic server status metadata."""
        return fsl_pb2.ServerInfo(
            scenario_id=self.scenario_id or "unknown",
            session_id=self.session_id or "unknown",
            current_round=self.fedavg.current_round,
        )

    def Register(self, request, context):
        """Assign a client id and return the shared session id.

        Id selection, in priority order:
          1. a name seen before keeps its previously assigned id;
          2. a positive, still-free ``requested_client_id`` is honoured;
          3. otherwise the next free auto-incremented id is used.

        On a scenario mismatch the response carries ``client_id=-1`` and
        ``session_id="ERROR_SCENARIO_MISMATCH"`` and nothing is registered.
        """
        client_scenario = self._mismatched_scenario(context)
        if client_scenario is not None:
            print(f"[SERVER] Rejecting registration from mismatched scenario: {client_scenario} != {self.scenario_id}")
            return fsl_pb2.RegisterResponse(client_id=-1, total_clients=0, session_id="ERROR_SCENARIO_MISMATCH")

        with self._reg_lock:
            client_name = request.client_name or f"client-{self._next_client_id}"
            requested_id = request.requested_client_id

            assigned_id = self._client_name_to_id.get(client_name)
            if assigned_id is None:
                if requested_id > 0 and requested_id not in self._assigned_ids:
                    assigned_id = requested_id
                else:
                    # Skip over ids already taken by explicit requests.
                    while self._next_client_id in self._assigned_ids:
                        self._next_client_id += 1
                    assigned_id = self._next_client_id
                self._client_name_to_id[client_name] = assigned_id
                self._assigned_ids.add(assigned_id)

            # Keep the auto-increment cursor strictly above every issued id.
            if assigned_id >= self._next_client_id:
                self._next_client_id = assigned_id + 1

            # Mutated under the same lock as the other registration tables so
            # readers holding _reg_lock always see a consistent view.
            self._registered_clients.add(assigned_id)

        print(
            f"[SERVER] Client registered - name: {client_name} | requested_id: {requested_id or 'auto'} "
            f"| assigned_id: {assigned_id} | session: {self.session_id}"
        )
        self.fedavg.register_client(assigned_id)
        return fsl_pb2.RegisterResponse(
            client_id=assigned_id,
            total_clients=self.num_clients,
            session_id=self.session_id,
        )

    def Forward(self, request, context):
        """Handle one forward request from a client.

        Failures (scenario mismatch or any exception while processing) are
        reported in-band through ``ForwardResponse(success=False, ...)``.
        """
        client_scenario = self._mismatched_scenario(context)
        if client_scenario is not None:
            return fsl_pb2.ForwardResponse(
                status_message=f"Scenario mismatch: {client_scenario} != {self.scenario_id}",
                success=False,
            )

        try:
            client_id = getattr(request, "client_id", -1)
            reported_latency = getattr(request, "latency_ms", 0.0)
            assigned_compression, assigned_rho = self.scheduler.assign(client_id, reported_latency)
            result = handle_forward_request(
                request,
                hidden_size=self.hidden_size,
                device=self.device,
                server_model=self.server_model,
                optimizer=self.optimizer,
                sync_lock=self.sync_lock,
                current_round=self.fedavg.current_round,
                assigned_compression=assigned_compression,
                assigned_rho=assigned_rho,
                profiler_enabled=self.profiler_enabled,
                scheduler_enabled=self.scheduler_enabled,
            )
            self.reporter.record(result.log_entry)
            if self.log_server_requests:
                print(result.monitor_message)
            return result.response
        except Exception as e:
            # RPC boundary: convert any failure into an in-band error response.
            print(f"[SERVER ERROR] Processing failed: {e}")
            return fsl_pb2.ForwardResponse(
                status_message=f"Error: {e}",
                success=False,
            )

    def flush_logs(self):
        """Flush buffered log entries through the reporter."""
        self.reporter.flush()

    def Synchronize(self, request, context):
        """Aggregate client weights and return the latest global model.

        Errors are signalled through the gRPC status: ``INVALID_ARGUMENT`` on
        a scenario mismatch, ``INTERNAL`` on any exception; in both cases an
        empty ``SyncResponse`` is returned.
        """
        client_scenario = self._mismatched_scenario(context)
        if client_scenario is not None:
            context.set_details(f"Scenario mismatch: {client_scenario} != {self.scenario_id}")
            context.set_code(grpc.StatusCode.INVALID_ARGUMENT)
            return fsl_pb2.SyncResponse()

        try:
            local_weights = bytes_to_tensor(request.client_weights)
            return self.fedavg.synchronize(
                request,
                local_weights=local_weights,
                server_model=self.server_model,
                optimizer=self.optimizer,
            )
        except Exception as e:
            print(f"[FED AVG ERROR] Synchronization failed: {e}")
            context.set_details(str(e))
            context.set_code(grpc.StatusCode.INTERNAL)
            return fsl_pb2.SyncResponse()

    def NotifyCompletion(self, request, context):
        """Record client completion and emit a server-side all-finished signal.

        The notification is rejected (``acknowledged=False``) when the session
        id differs from ours or the client never registered. A repeated
        notification from the same client is acknowledged but otherwise
        ignored. Once every registered client has completed, logs are flushed
        and the shutdown event is set.
        """
        client_id = request.client_id

        if request.session_id != self.session_id:
            print(
                f"[SERVER] Ignoring completion from client {client_id}: "
                f"session mismatch ({request.session_id} != {self.session_id})"
            )
            with self._completion_lock:
                completed = len(self._completed_clients)
            return self._completion_response(False, completed)

        # Registration state is owned by _reg_lock; read it under that lock.
        with self._reg_lock:
            is_registered = client_id in self._registered_clients

        with self._completion_lock:
            if not is_registered:
                print(
                    f"[SERVER] Ignoring completion from unregistered client {client_id} "
                    f"(session={self.session_id})"
                )
                return self._completion_response(False, len(self._completed_clients))

            if client_id in self._completed_clients:
                completed = len(self._completed_clients)
                print(
                    f"[SERVER] Duplicate completion ignored for client {client_id} "
                    f"(completed={completed}/{self.num_clients})"
                )
                return self._completion_response(True, completed)

            self._completed_clients.add(client_id)
            completed = len(self._completed_clients)

        self.fedavg.mark_client_completed(
            client_id,
            server_model=self.server_model,
            optimizer=self.optimizer,
        )
        print(
            f"[SERVER] Client {client_id} finished training | "
            f"epochs={request.completed_epochs} steps={request.total_steps} | "
            f"completed={completed}/{self.num_clients}"
        )

        # Shutdown if everyone who showed up is finished. The registered count
        # is read once so the decision and the log line cannot disagree.
        with self._reg_lock:
            registered = len(self._registered_clients)
        if registered > 0 and completed >= registered:
            print(f"[SERVER] ALL REGISTERED CLIENTS FINISHED ({completed}/{registered}) | session={self.session_id}")
            self.flush_logs()
            self._shutdown_event.set()

        return self._completion_response(True, completed)

    def should_shutdown(self) -> bool:
        """Signal bootstrap loop to stop server once all clients are done."""
        return self._shutdown_event.is_set()


def serve():
    """Cap torch's CPU thread count from config, then run the gRPC server.

    Reads ``training.torch_num_threads`` from ``cfg``; a missing or empty
    ``training`` section, or a missing/null/zero value, falls back to one
    thread. The value is clamped to at least 1 before being applied.
    """
    # ``or {}`` / ``or 1`` also cover keys that exist but hold ``None``.
    training_cfg = cfg.get("training") or {}
    num_threads = int(training_cfg.get("torch_num_threads") or 1)
    torch.set_num_threads(max(1, num_threads))
    run_server(FSLServerServicer())


# Script entry point: start the server only when this module is run directly,
# not when it is imported.
if __name__ == "__main__":
    serve()