"""Tensor compression codecs (src/shared/compression.py in the csc8114 project).

Supported modes: 'float32', 'float16', 'int8', 'topk', 'topk_int8'.
"""
import torch
import numpy as np


def _resolve_topk_ratio() -> float:
    """Return the top-k keep ratio configured at ``cfg["compression"]["topk_ratio"]``.

    Falls back to 0.5 when the setting is missing, not convertible to float,
    or NaN. The result is clamped into ``[1e-6, 1.0]``.

    Returns:
        The fraction of elements that top-k compression should keep.
    """
    # Local import avoids introducing a hard dependency at module import time.
    from src.shared.common import cfg

    default = 0.5
    # The "compression" key may exist but hold None (an empty section); treat
    # that the same as a missing section instead of failing on ``None.get``.
    section = cfg.get("compression") or {}
    raw = section.get("topk_ratio", default)
    try:
        ratio = float(raw)
    except (TypeError, ValueError):
        ratio = default
    # float("nan") parses successfully, and np.clip passes NaN through
    # unchanged. Without this check, round(n * ratio) in _topk_select would raise.
    if np.isnan(ratio):
        ratio = default
    return float(np.clip(ratio, 1e-6, 1.0))


def _topk_select(flat: np.ndarray) -> tuple[np.ndarray, np.ndarray, int]:
    """Pick the largest-magnitude entries of a 1-D array.

    The number of entries kept is ``round(size * ratio)``, bounded to
    ``[1, size]``. The ratio comes from ``_resolve_topk_ratio``.

    Returns:
        ``(indices, values, n)``:
        - ``indices``: int32 positions of the kept entries. They are in
          ascending order only when every element is kept; otherwise the order
          is whatever ``np.argpartition`` produces.
        - ``values``: the entries of *flat* at those positions.
        - ``n``: the total element count of *flat*.
    """
    total = int(flat.size)
    if not total:
        return np.array([], dtype=np.int32), np.array([], dtype=np.float32), 0

    keep = round(total * _resolve_topk_ratio())
    # total >= 1 here, so clamping low-then-high equals clamping high-then-low.
    keep = int(min(max(keep, 1), total))

    if keep >= total:
        # Nothing to drop: keep everything in natural order.
        return np.arange(total, dtype=np.int32), flat.copy(), total

    magnitudes = np.abs(flat)
    chosen = np.argpartition(magnitudes, -keep)[-keep:].astype(np.int32, copy=False)
    return chosen, flat[chosen], total


def _quantize_int8(values: np.ndarray) -> tuple[float, np.ndarray]:
    """Symmetrically quantize *values* to int8 and return ``(scale, quantized)``.

    ``scale`` maps the largest magnitude to 127. An empty or all-zero input
    uses ``scale = 1.0``, so the decoder never multiplies by a degenerate step.
    """
    max_abs = float(np.max(np.abs(values))) if values.size > 0 else 0.0
    scale = max_abs / 127.0 if max_abs > 0 else 1.0
    # Quantize with the float32-rounded scale. That is exactly the value the
    # decoder reads back, so encode and decode use the same step size.
    scale = float(np.float32(scale))
    # Clip before the cast: NumPy's float -> int8 conversion of out-of-range
    # values is undefined, and rounding error could nudge |x / scale| past 127.
    quantized = np.clip(np.round(values / scale), -127, 127).astype(np.int8)
    return scale, quantized


def compress(tensor: torch.Tensor, mode: str) -> bytes:
    """
    Compresses a PyTorch tensor into bytes based on the specified mode.
    Modes: 'float32', 'float16', 'int8', 'topk', 'topk_int8'

    Byte layouts (all multi-byte values are in native byte order):
      * 'float32' (and any unrecognised mode): raw float32 values.
      * 'float16': raw float16 values.
      * 'int8': scale(4B float32) + int8 values (1B each).
      * 'topk': header(8B: int32 n, int32 k) + indices(4k B int32)
        + values(4k B float32).
      * 'topk_int8': header(8B) + scale(4B float32) + indices(4k B int32)
        + int8 values(k B).

    'topk' keeps only the largest-magnitude fraction of elements (see
    ``_topk_select``). 'topk_int8' combines that sparsification with int8
    quantization of the selected values. It gives the highest compression at
    the cost of some accuracy.

    Values are serialized in C order. The caller must pass the original shape
    to ``decompress``.
    """
    t = tensor.detach().cpu()
    # NumPy has no bfloat16 dtype, so Tensor.numpy() rejects it. Upcasting to
    # float32 is lossless.
    if t.dtype == torch.bfloat16:
        t = t.float()
    arr = t.numpy()

    if mode == "float16":
        return arr.astype(np.float16).tobytes()

    if mode == "int8":
        scale, quantized = _quantize_int8(arr)
        return np.array([scale], dtype=np.float32).tobytes() + quantized.tobytes()

    if mode == "topk":
        flat = arr.astype(np.float32, copy=False).reshape(-1)
        indices, values, n = _topk_select(flat)
        if n == 0:
            return np.array([0, 0], dtype=np.int32).tobytes()
        header = np.array([n, indices.size], dtype=np.int32).tobytes()
        return header + indices.tobytes() + values.astype(np.float32, copy=False).tobytes()

    if mode == "topk_int8":
        # Sparsify (top-k) first, then quantize only the selected values.
        flat = arr.astype(np.float32, copy=False).reshape(-1)
        indices, values, n = _topk_select(flat)
        if n == 0:
            return np.array([0, 0], dtype=np.int32).tobytes() + np.array([1.0], dtype=np.float32).tobytes()
        scale, quantized = _quantize_int8(values)
        header = np.array([n, indices.size], dtype=np.int32).tobytes()
        scale_bytes = np.array([scale], dtype=np.float32).tobytes()
        return header + scale_bytes + indices.tobytes() + quantized.tobytes()

    # 'float32' or any unrecognised mode: raw float32 bytes.
    return arr.astype(np.float32).tobytes()

def _check_sparse_header(n: int, k: int, mode: str) -> None:
    """Reject sparse-payload headers that ``compress`` could never produce."""
    # compress always writes 0 <= k <= n; anything else is a corrupted payload.
    if n < 0 or k < 0 or k > n:
        raise ValueError(f"Invalid {mode} payload header: n={n}, k={k}")


def _check_sparse_indices(indices: np.ndarray, n: int, mode: str) -> None:
    """Ensure every decoded sparse index lies in ``[0, n)``."""
    # NumPy fancy assignment silently wraps negative indices to the end of the
    # array, which would scatter values into the wrong positions instead of failing.
    if indices.size and (int(indices.min()) < 0 or int(indices.max()) >= n):
        raise ValueError(f"Invalid {mode} payload: indices must lie in [0, {n}).")


def decompress(data_bytes: bytes, shape: tuple, mode: str) -> torch.Tensor:
    """
    Decompresses bytes back into a PyTorch tensor.

    Args:
        data_bytes: Payload in the byte layout that ``compress`` writes for *mode*.
        shape: Target shape. ``Tensor.view`` raises if it does not match the
            number of decoded elements.
        mode: One of 'float32', 'float16', 'int8', 'topk', 'topk_int8'. Any
            other value is decoded as 'float32'.

    Returns:
        A float32 tensor of the given shape, backed by a fresh writable buffer.

    Raises:
        ValueError: If the payload is truncated, has an inconsistent header,
            or (in the sparse modes) contains indices outside ``[0, n)``.
    """
    if mode == "float16":
        arr = np.frombuffer(data_bytes, dtype=np.float16).astype(np.float32, copy=True)

    elif mode == "int8":
        # Layout: scale(4B float32) + int8 values (1B each).
        if len(data_bytes) < 4:
            raise ValueError("Invalid int8 payload: scale is missing.")
        scale = np.frombuffer(data_bytes, dtype=np.float32, count=1)[0]
        quantized = np.frombuffer(data_bytes, dtype=np.int8, offset=4).astype(np.float32, copy=True)
        arr = quantized * scale

    elif mode == "topk":
        # Layout: header(8B: int32 n, int32 k) + indices(4k B) + float32 values(4k B).
        if len(data_bytes) < 8:
            raise ValueError("Invalid topk payload: header is incomplete.")
        n, k = (int(v) for v in np.frombuffer(data_bytes, dtype=np.int32, count=2))
        _check_sparse_header(n, k, "topk")
        expected = 8 + 4 * k + 4 * k
        if len(data_bytes) != expected:
            raise ValueError(
                f"Invalid topk payload length: expected {expected} bytes, got {len(data_bytes)} bytes."
            )
        dense = np.zeros(n, dtype=np.float32)
        if k > 0:
            # offset/count read straight from the buffer without slicing copies.
            indices = np.frombuffer(data_bytes, dtype=np.int32, count=k, offset=8)
            _check_sparse_indices(indices, n, "topk")
            values = np.frombuffer(data_bytes, dtype=np.float32, count=k, offset=8 + 4 * k)
            dense[indices] = values
        arr = dense

    elif mode == "topk_int8":
        # Layout: header(8B) + scale(4B) + indices(4k B) + int8_values(k B)
        if len(data_bytes) < 12:
            raise ValueError("Invalid topk_int8 payload: header incomplete.")
        n, k = (int(v) for v in np.frombuffer(data_bytes, dtype=np.int32, count=2))
        _check_sparse_header(n, k, "topk_int8")
        scale = float(np.frombuffer(data_bytes, dtype=np.float32, count=1, offset=8)[0])
        expected = 12 + 4 * k + k
        if len(data_bytes) != expected:
            raise ValueError(
                f"Invalid topk_int8 payload length: expected {expected} bytes, got {len(data_bytes)} bytes."
            )
        dense = np.zeros(n, dtype=np.float32)
        if k > 0:
            indices = np.frombuffer(data_bytes, dtype=np.int32, count=k, offset=12)
            _check_sparse_indices(indices, n, "topk_int8")
            quantized = np.frombuffer(data_bytes, dtype=np.int8, count=k, offset=12 + 4 * k).astype(
                np.float32, copy=True
            )
            dense[indices] = quantized * scale
        arr = dense

    else:  # float32 or default
        arr = np.frombuffer(data_bytes, dtype=np.float32).copy()

    return torch.from_numpy(arr).view(shape)