import random from dataclasses import dataclass from src.shared.common import cfg def _to_float(value, default: float) -> float: try: return float(value) except (TypeError, ValueError): return float(default) def _to_int(value, default: int) -> int: try: return int(value) except (TypeError, ValueError): return int(default) def _offset_for_client(client_id: int, offsets: list[float]) -> float: if not offsets: return 0.0 if client_id <= 0: return float(offsets[0]) idx = min(client_id - 1, len(offsets) - 1) return float(offsets[idx]) @dataclass class LatencyConfig: base_latency_ms: float jitter_ms: float burst_every_steps: int burst_latency_ms: float sleep_fraction: float max_sleep_ms: float measured_floor_ratio: float client_offsets_ms: list[float] def load_latency_config() -> LatencyConfig: profiler_cfg = cfg.get("profiler", {}) if isinstance(cfg, dict) else {} latency_cfg = profiler_cfg.get("latency_generator", {}) if not isinstance(latency_cfg, dict): latency_cfg = {} offsets_raw = latency_cfg.get("client_offsets_ms", [0.0, 4.0, 9.0]) if isinstance(offsets_raw, list) and offsets_raw: offsets = [_to_float(v, 0.0) for v in offsets_raw] else: offsets = [0.0, 4.0, 9.0] return LatencyConfig( base_latency_ms=max(0.0, _to_float(latency_cfg.get("base_latency_ms", 1.5), 1.5)), jitter_ms=max(0.0, _to_float(latency_cfg.get("jitter_ms", 0.8), 0.8)), burst_every_steps=max(0, _to_int(latency_cfg.get("burst_every_steps", 0), 0)), burst_latency_ms=max(0.0, _to_float(latency_cfg.get("burst_latency_ms", 0.0), 0.0)), sleep_fraction=max(0.0, _to_float(latency_cfg.get("sleep_fraction", 0.0), 0.0)), max_sleep_ms=max(0.0, _to_float(latency_cfg.get("max_sleep_ms", 0.0), 0.0)), measured_floor_ratio=max(0.0, _to_float(latency_cfg.get("measured_floor_ratio", 0.25), 0.25)), client_offsets_ms=offsets, ) class LatencyGenerator: """Client-side synthetic latency generator used for scheduler experiments.""" def __init__(self, *, client_id: int): self.client_id = int(client_id) self.cfg = load_latency_config() base_seed = _to_int(cfg.get("training", {}).get("seed", 42), 42) self._rng = random.Random(base_seed + 9973 + self.client_id) self._step = 0 def next_latency_ms(self, *, measured_latency_ms: float) -> float: self._step += 1 offset = _offset_for_client(self.client_id, self.cfg.client_offsets_ms) latency = self.cfg.base_latency_ms + offset if self.cfg.jitter_ms > 0.0: latency += self._rng.gauss(0.0, self.cfg.jitter_ms) if self.cfg.burst_every_steps > 0 and (self._step % self.cfg.burst_every_steps) == 0: latency += self.cfg.burst_latency_ms # Keep synthetic latency from going unrealistically below observed local RTT. measured_floor = max(0.0, float(measured_latency_ms)) * self.cfg.measured_floor_ratio latency = max(latency, measured_floor, 0.0) return float(latency) def suggested_sleep_ms(self, *, reported_latency_ms: float) -> float: if self.cfg.sleep_fraction <= 0.0: return 0.0 sleep_ms = float(reported_latency_ms) * self.cfg.sleep_fraction if self.cfg.max_sleep_ms > 0.0: sleep_ms = min(sleep_ms, self.cfg.max_sleep_ms) return max(0.0, sleep_ms)