import os
import random
from contextlib import nullcontext

import numpy as np
import torch

from src.shared.common import cfg

# Spellings accepted as "true" when the legacy ``training.use_gpu`` flag arrives
# as a string (e.g. from an env-substituted config) instead of a YAML bool;
# needed because ``bool("false")`` is ``True``.
_TRUTHY_STRINGS = frozenset({"1", "true", "yes", "on"})


def _best_available_device() -> torch.device:
    """Return CUDA if available, else MPS if available, else CPU."""
    if torch.cuda.is_available():
        return torch.device("cuda")
    if torch.backends.mps.is_available():
        return torch.device("mps")
    return torch.device("cpu")


def _legacy_use_gpu(value) -> bool:
    """Interpret a legacy ``training.use_gpu`` value (bool, int or string) as a bool."""
    if isinstance(value, str):
        return value.strip().lower() in _TRUTHY_STRINGS
    return bool(value)


def resolve_device() -> torch.device:
    """
    Resolve runtime device in a cross-platform way.

    Priority:
      1) FSL_DEVICE env
      2) training.device in config
      3) legacy training.use_gpu flag (only when neither of the above is set)
      4) "auto"

    Accepted values for 1) and 2): ``auto`` (CUDA, then MPS, then CPU),
    ``cpu``, ``cuda``/``gpu``, ``cuda:<index>`` and ``mps``. A requested
    accelerator that is not available falls back to CPU with a printed notice.
    Unrecognised values print a notice and defer to the legacy flag.

    Returns:
        The ``torch.device`` to run on.
    """
    training_cfg = cfg.get("training") or {}

    requested = os.getenv("FSL_DEVICE", "").strip().lower()
    if not requested:
        configured = training_cfg.get("device")
        if configured is not None:
            requested = str(configured).strip().lower()

    if not requested:
        # Nothing explicit: honour the legacy flag when the config has one.
        # Defaulting straight to "auto" here (the old behaviour) made
        # ``use_gpu: false`` silently select the GPU.
        if "use_gpu" in training_cfg:
            if _legacy_use_gpu(training_cfg["use_gpu"]):
                return _best_available_device()
            return torch.device("cpu")
        requested = "auto"

    if requested in {"cuda", "gpu"}:
        if torch.cuda.is_available():
            return torch.device("cuda")
        print("[RUNTIME] CUDA requested but unavailable; falling back to CPU.")
        return torch.device("cpu")

    prefix, sep, index_text = requested.partition(":")
    if prefix == "cuda" and sep and index_text.isdigit():
        # Explicit GPU index, e.g. "cuda:1"; must exist on this machine.
        if torch.cuda.is_available() and int(index_text) < torch.cuda.device_count():
            return torch.device("cuda", int(index_text))
        print(f"[RUNTIME] {requested} requested but unavailable; falling back to CPU.")
        return torch.device("cpu")

    if requested == "mps":
        if torch.backends.mps.is_available():
            return torch.device("mps")
        print("[RUNTIME] MPS requested but unavailable; falling back to CPU.")
        return torch.device("cpu")

    if requested == "cpu":
        return torch.device("cpu")

    if requested == "auto":
        return _best_available_device()

    print(f"[RUNTIME] Unrecognised device {requested!r}; using legacy training.use_gpu flag.")
    if _legacy_use_gpu(training_cfg.get("use_gpu", False)):
        return _best_available_device()
    return torch.device("cpu")


def grpc_channel_options() -> list[tuple[str, int]]:
    """
    Build gRPC channel options: message size limits + keepalive for real-network stability.

    ``grpc.max_message_mb`` (default 50) caps both the send and receive message size.
    """
    grpc_cfg = cfg.get("grpc") or {}
    max_mb = int(grpc_cfg.get("max_message_mb", 50))
    max_bytes = max_mb * 1024 * 1024
    return [
        ("grpc.max_send_message_length", max_bytes),
        ("grpc.max_receive_message_length", max_bytes),
        # Send a ping every 30 s even when there are no active RPCs, so both sides
        # detect a dead TCP connection before the next training step hits it.
        # NOTE(review): the server must permit pings this frequent (its
        # min-ping-interval options), otherwise it may close the connection
        # with "too_many_pings" — confirm the server's channel options.
        ("grpc.keepalive_time_ms", 30_000),
        # If the peer does not respond to the ping within 10 s, declare the link dead.
        ("grpc.keepalive_timeout_ms", 10_000),
        # Allow pings even when no RPC is in flight (required for the above to work).
        ("grpc.keepalive_permit_without_calls", 1),
        # 0 = no cap on pings sent without intervening data frames.
        ("grpc.http2.max_pings_without_data", 0),
    ]


def create_grpc_channel(address: str):
    """
    Create a gRPC channel to *address* with keepalive options.

    When ``grpc.tls_enabled`` is true in config, an SSL/TLS channel is returned
    and the server certificate (or CA bundle) is read from ``grpc.tls_cert_path``.
    If no path is configured, gRPC's default root certificates are used.
    Otherwise an insecure channel is returned (safe when running over Tailscale).

    Raises:
        OSError: if ``grpc.tls_cert_path`` is set but cannot be read.
    """
    import grpc as _grpc  # local import to avoid circular dependency at module level

    options = grpc_channel_options()
    grpc_cfg = cfg.get("grpc") or {}
    if grpc_cfg.get("tls_enabled", False):
        ca_cert_path = grpc_cfg.get("tls_cert_path")
        if ca_cert_path:
            with open(ca_cert_path, "rb") as fh:
                root_certificates = fh.read()
        else:
            root_certificates = None  # fall back to system CA bundle
        credentials = _grpc.ssl_channel_credentials(root_certificates=root_certificates)
        return _grpc.secure_channel(address, credentials, options=options)
    return _grpc.insecure_channel(address, options=options)


def _env_override(name: str, default: str) -> str:
    """Return env var *name* (stripped), or *default* when unset or blank."""
    value = os.getenv(name, "").strip()
    return value or default


def resolve_server_address() -> str:
    """
    Resolve the ``host:port`` server address with env override support.

    ``FSL_SERVER_HOST`` / ``FSL_SERVER_PORT`` override ``grpc.server_host``
    (default ``"fsl-server"``) and ``grpc.server_port`` (default 50051).
    Blank env vars are treated as unset.

    Raises:
        ValueError: if the resolved port is not an integer in 1..65535.
    """
    grpc_cfg = cfg.get("grpc") or {}
    host = _env_override("FSL_SERVER_HOST", str(grpc_cfg.get("server_host", "fsl-server")).strip())
    port_text = _env_override("FSL_SERVER_PORT", str(grpc_cfg.get("server_port", 50051)).strip())
    try:
        port = int(port_text)
    except ValueError as err:
        raise ValueError(
            f"Invalid gRPC server port {port_text!r} (FSL_SERVER_PORT / grpc.server_port)."
        ) from err
    if not 0 < port < 65536:
        raise ValueError(f"gRPC server port out of range: {port}")
    return f"{host}:{port}"


def maybe_autocast(device: torch.device):
    """
    Optional AMP autocast context.

    Enabled when training.mixed_precision is set to 'auto' or 'bf16':
      - CUDA: autocast to bfloat16.
      - MPS:  autocast to float16 (note: float16 even when 'bf16' is requested).
    Any other value or device type yields a no-op ``nullcontext()``.
    """
    mixed = str((cfg.get("training") or {}).get("mixed_precision", "none")).lower().strip()
    if mixed == "none":
        return nullcontext()
    if mixed in {"auto", "bf16"}:
        if device.type == "cuda":
            # NOTE(review): GPUs without bf16 support may have autocast disabled
            # by PyTorch with a warning — confirm target hardware.
            return torch.autocast(device_type="cuda", dtype=torch.bfloat16)
        if device.type == "mps":
            return torch.autocast(device_type="mps", dtype=torch.float16)
    return nullcontext()


def set_global_seed(seed: int | None, *, role: str = "runtime") -> int | None:
    """
    Set Python/NumPy/PyTorch random seeds for reproducibility.

    When CUDA is available, all CUDA devices are seeded and cuDNN is switched to
    deterministic mode with benchmarking disabled (slower but reproducible).

    Args:
        seed: Seed value (coerced with ``int``), or None to skip seeding.
            NOTE(review): NumPy's legacy seeding only accepts 0 <= seed < 2**32.
        role: Label used only in the printed log line.

    Returns:
        The applied seed, or None if seeding is disabled.
    """
    if seed is None:
        print(f"[RUNTIME] Seed disabled for {role}.")
        return None
    seed = int(seed)
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)
        if hasattr(torch.backends, "cudnn"):
            torch.backends.cudnn.deterministic = True
            torch.backends.cudnn.benchmark = False
    print(f"[RUNTIME] Seed set for {role}: {seed}")
    return seed