csc8114 / code / src / server / scheduler.py
scheduler.py
Raw
import threading


class CompressionScheduler:
    def __init__(
        self,
        default_mode: str = "float32",
        *,
        enabled: bool = True,
        float16_threshold: float = 4.0,
        int8_threshold: float = 10.0,
        base_rho: int = 1,
        min_rho: int = 1,
        max_rho: int = 20,
        rho_step: int = 1,
        topk_multiplier: float = 1.5,
        latency_ema_alpha: float = 0.2,
    ):
        self.default_mode = default_mode
        self.enabled = enabled
        self.float16_threshold = float16_threshold
        self.int8_threshold = int8_threshold
        self.base_rho = max(1, int(base_rho))
        self.min_rho = max(1, int(min_rho))
        self.max_rho = max(self.min_rho, int(max_rho))
        self.rho_step = max(0, int(rho_step))
        self.topk_threshold = float(int8_threshold) * float(topk_multiplier)
        self.latency_ema_alpha = float(latency_ema_alpha)
        self._client_state: dict[int, dict[str, float | int | str]] = {}
        self._lock = threading.Lock()

    def assign(self, client_id: int, reported_latency: float) -> tuple[str, int]:
        with self._lock:
            if client_id not in self._client_state:
                self._client_state[client_id] = {
                    "mode": self.default_mode,
                    "rho": self.base_rho,
                    "latency_ema": 0.0,
                }
            state = self._client_state[client_id]

            if not self.enabled:
                state["mode"] = self.default_mode
                state["rho"] = self.base_rho
            elif reported_latency > 0:
                prev_ema = float(state["latency_ema"])
                alpha = min(max(self.latency_ema_alpha, 0.0), 1.0)
                if prev_ema <= 0.0:
                    latency_ema = float(reported_latency)
                else:
                    latency_ema = alpha * float(reported_latency) + (1.0 - alpha) * prev_ema
                state["latency_ema"] = latency_ema

                # Escalate compression level and synchronization interval with network pressure.
                # Three levels: float32 → float16 → int8 (severity 0/1/2).
                if latency_ema > self.int8_threshold:
                    severity = 2
                    mode = "int8"
                elif latency_ema > self.float16_threshold:
                    severity = 1
                    mode = "float16"
                else:
                    severity = 0
                    mode = self.default_mode

                rho = self.base_rho + severity * self.rho_step
                state["mode"] = mode
                state["rho"] = max(self.min_rho, min(self.max_rho, int(rho)))

            return str(state["mode"]), int(state["rho"])