import threading


class CompressionScheduler:
    """Assign a compression mode and sync interval (rho) per client.

    Each client's reported latency is smoothed with an exponential moving
    average (EMA). The EMA is compared against two thresholds to pick one of
    three severity levels:

    * severity 0 -> ``default_mode``
    * severity 1 -> ``"float16"`` (EMA above ``float16_threshold``)
    * severity 2 -> ``"int8"`` (EMA above ``int8_threshold``)

    The returned rho is ``base_rho + severity * rho_step`` clamped to
    ``[min_rho, max_rho]``.

    Per-client state is kept in a dict guarded by a ``threading.Lock``; every
    read and write in :meth:`assign` happens while that lock is held.
    """

    def __init__(
        self,
        default_mode: str = "float32",
        *,
        enabled: bool = True,
        float16_threshold: float = 4.0,
        int8_threshold: float = 10.0,
        base_rho: int = 1,
        min_rho: int = 1,
        max_rho: int = 20,
        rho_step: int = 1,
        topk_multiplier: float = 1.5,
        latency_ema_alpha: float = 0.2,
    ):
        """Store the configuration and normalise the integer settings.

        Args:
            default_mode: Mode used at severity 0, for new clients, and
                whenever the scheduler is disabled.
            enabled: When false, :meth:`assign` always answers with
                ``(default_mode, base_rho)``.
            float16_threshold: EMA value above which ``"float16"`` is chosen.
            int8_threshold: EMA value above which ``"int8"`` is chosen.
            base_rho: Rho at severity 0; forced to be at least 1.
            min_rho: Lower clamp for rho; forced to be at least 1.
            max_rho: Upper clamp for rho; forced to be at least ``min_rho``.
            rho_step: Rho increment per severity level; negative values
                are treated as 0.
            topk_multiplier: Scales ``int8_threshold`` into
                ``topk_threshold``. That attribute is only stored here; it
                is not read anywhere else in this class.
            latency_ema_alpha: Weight of the newest sample in the EMA. It is
                clamped to ``[0, 1]`` each time it is used.
        """
        self.default_mode = default_mode
        self.enabled = enabled
        self.float16_threshold = float16_threshold
        self.int8_threshold = int8_threshold
        self.base_rho = max(1, int(base_rho))
        self.min_rho = max(1, int(min_rho))
        self.max_rho = max(self.min_rho, int(max_rho))
        self.rho_step = max(0, int(rho_step))
        self.topk_threshold = float(int8_threshold) * float(topk_multiplier)
        self.latency_ema_alpha = float(latency_ema_alpha)
        self._client_state: dict[int, dict[str, float | int | str]] = {}
        self._lock = threading.Lock()

    def assign(self, client_id: int, reported_latency: float) -> tuple[str, int]:
        """Update one client's state and return its ``(mode, rho)``.

        Args:
            client_id: Key identifying the client; unseen ids get a fresh
                state of ``default_mode`` / ``base_rho`` / EMA 0.
            reported_latency: Latest latency sample. Only finite, strictly
                positive values update the EMA; zero, negative, NaN and
                infinite samples leave the previous assignment untouched.

        Returns:
            The client's current mode string and rho.
        """
        with self._lock:
            state = self._client_state.get(client_id)
            if state is None:
                state = {
                    "mode": self.default_mode,
                    "rho": self.base_rho,
                    "latency_ema": 0.0,
                }
                self._client_state[client_id] = state

            if not self.enabled:
                state["mode"] = self.default_mode
                state["rho"] = self.base_rho
            elif 0.0 < reported_latency < float("inf"):
                # Non-finite samples are rejected: a single inf would turn the
                # EMA into inf for good, because blending any later finite
                # value with inf is still inf, pinning the client at int8.
                sample = float(reported_latency)
                prev_ema = float(state["latency_ema"])
                alpha = min(max(self.latency_ema_alpha, 0.0), 1.0)
                if prev_ema <= 0.0:
                    # No usable history yet: seed the EMA with the sample.
                    latency_ema = sample
                else:
                    latency_ema = alpha * sample + (1.0 - alpha) * prev_ema
                state["latency_ema"] = latency_ema

                # Escalate compression level and synchronization interval
                # with network pressure.
                # Three levels: default_mode -> float16 -> int8
                # (severity 0/1/2).
                if latency_ema > self.int8_threshold:
                    severity = 2
                    mode = "int8"
                elif latency_ema > self.float16_threshold:
                    severity = 1
                    mode = "float16"
                else:
                    severity = 0
                    mode = self.default_mode

                rho = self.base_rho + severity * self.rho_step
                state["mode"] = mode
                state["rho"] = max(self.min_rho, min(self.max_rho, int(rho)))

            return str(state["mode"]), int(state["rho"])