import torch
import numpy as np
def _resolve_topk_ratio() -> float:
# Local import avoids introducing a hard dependency at module import time.
from src.shared.common import cfg
raw = cfg.get("compression", {}).get("topk_ratio", 0.5)
try:
ratio = float(raw)
except (TypeError, ValueError):
ratio = 0.5
return float(np.clip(ratio, 1e-6, 1.0))
def _topk_select(flat: np.ndarray) -> tuple[np.ndarray, np.ndarray, int]:
    """Select the largest-magnitude entries of a 1-D array.

    The number of entries kept is ``round(n * ratio)`` clamped to ``[1, n]``,
    where ``ratio`` comes from ``_resolve_topk_ratio()`` (only consulted when
    the input is non-empty).

    Returns:
        A tuple ``(indices, values, n)``: int32 positions of the kept entries
        (in no particular order), the entries of *flat* at those positions,
        and the total element count ``n``. When every entry is kept the
        indices are ``0..n-1`` and the values are a copy of *flat*.
    """
    total = int(flat.size)
    if total == 0:
        no_indices = np.array([], dtype=np.int32)
        no_values = np.array([], dtype=np.float32)
        return no_indices, no_values, 0

    keep = max(1, min(total, round(total * _resolve_topk_ratio())))
    if keep >= total:
        return np.arange(total, dtype=np.int32), flat.copy(), total

    cut = total - keep
    # argpartition places the `keep` largest magnitudes at positions >= cut.
    chosen = np.argpartition(np.abs(flat), cut)[cut:]
    chosen = chosen.astype(np.int32, copy=False)
    return chosen, flat[chosen], total
def _tensor_to_numpy(tensor: torch.Tensor) -> np.ndarray:
    """Detach *tensor*, move it to CPU and return it as a NumPy array."""
    host = tensor.detach().cpu()
    # NumPy has no bfloat16 dtype, so ``Tensor.numpy()`` rejects it; upcast
    # losslessly to float32 first.
    if host.dtype == torch.bfloat16:
        host = host.to(torch.float32)
    return host.numpy()


def _flatten_for_topk(arr: np.ndarray) -> np.ndarray:
    """Flatten *arr* to 1-D float32, checking it fits the int32 sparse format."""
    flat = arr.astype(np.float32, copy=False).reshape(-1)
    # n, k and every index are serialised as int32; larger inputs would be
    # silently corrupted by wrap-around.
    if flat.size > np.iinfo(np.int32).max:
        raise ValueError(
            f"Tensor too large for top-k encoding: {flat.size} elements."
        )
    return flat


def _quantize_symmetric_int8(values: np.ndarray) -> tuple[np.float32, np.ndarray]:
    """Symmetric per-tensor int8 quantization.

    Returns ``(scale, quantized)`` with ``quantized * scale ~= values``. Work
    is done in float32 so integer inputs cannot overflow in ``np.abs`` and the
    scale used for rounding is exactly the float32 value that gets stored.
    Empty input yields scale 1.0 and an empty int8 array.
    """
    work = values.astype(np.float32, copy=False)
    max_abs = float(np.max(np.abs(work))) if work.size else 0.0
    scale = np.float32(max_abs / 127.0)
    # Covers all-zero input, NaN maxima and float32 underflow of tiny scales.
    if not scale > 0:
        scale = np.float32(1.0)
    quantized = np.clip(np.round(work / scale), -127, 127).astype(np.int8)
    return scale, quantized


def compress(tensor: torch.Tensor, mode: str) -> bytes:
    """
    Compresses a PyTorch tensor into bytes based on the specified mode.

    Modes: 'float32', 'float16', 'int8', 'topk', 'topk_int8'. Any other value
    falls back to 'float32'. The tensor's shape is not stored; callers must
    pass it to ``decompress``. Buffers use NumPy's native byte order.

    Layouts:
      - float32 / float16: raw element buffer.
      - int8: scale (float32, 4B) + int8 values (1B each).
      - topk: header n,k (int32, 8B) + indices (int32, 4k B)
        + values (float32, 4k B).
      - topk_int8: header n,k (8B) + scale (float32, 4B)
        + indices (int32, 4k B) + int8 values (k B).

    'topk_int8' combines sparsification (top-k by magnitude) with int8
    quantization of the selected values, giving the highest compression
    at the cost of some accuracy.

    Raises:
        ValueError: in the top-k modes, if the tensor has more elements than
            an int32 header/index can represent.
    """
    arr = _tensor_to_numpy(tensor)

    if mode == "float16":
        return arr.astype(np.float16).tobytes()

    if mode == "int8":
        scale, quantized = _quantize_symmetric_int8(arr)
        return np.array([scale], dtype=np.float32).tobytes() + quantized.tobytes()

    if mode == "topk":
        flat = _flatten_for_topk(arr)
        indices, values, n = _topk_select(flat)
        if n == 0:
            return np.array([0, 0], dtype=np.int32).tobytes()
        header = np.array([n, indices.size], dtype=np.int32).tobytes()
        return (
            header
            + indices.tobytes()
            + values.astype(np.float32, copy=False).tobytes()
        )

    if mode == "topk_int8":
        # Sparsify (top-k) then quantize the selected values with int8.
        flat = _flatten_for_topk(arr)
        indices, values, n = _topk_select(flat)
        if n == 0:
            return (
                np.array([0, 0], dtype=np.int32).tobytes()
                + np.array([1.0], dtype=np.float32).tobytes()
            )
        scale, quantized = _quantize_symmetric_int8(values)
        header = np.array([n, indices.size], dtype=np.int32).tobytes()
        scale_bytes = np.array([scale], dtype=np.float32).tobytes()
        return header + scale_bytes + indices.tobytes() + quantized.tobytes()

    # float32 or default
    return arr.astype(np.float32).tobytes()
def _validate_sparse_header(n: int, k: int, mode: str) -> None:
    """Reject sparse headers that ``compress`` can never produce (k <= n, both >= 0)."""
    if n < 0 or k < 0 or k > n:
        raise ValueError(f"Invalid {mode} payload header: n={n}, k={k}")


def _scatter_sparse(
    n: int, indices: np.ndarray, values: np.ndarray, mode: str
) -> np.ndarray:
    """Expand ``(indices, values)`` into a dense float32 vector of length *n*.

    Indices are range-checked so a corrupt payload raises ``ValueError``
    instead of silently wrapping (negative indices) or raising ``IndexError``.
    """
    dense = np.zeros(n, dtype=np.float32)
    if indices.size:
        if int(indices.min()) < 0 or int(indices.max()) >= n:
            raise ValueError(f"Invalid {mode} payload: index out of range [0, {n}).")
        dense[indices] = values
    return dense


def decompress(data_bytes: bytes, shape: tuple, mode: str) -> torch.Tensor:
    """
    Decompresses bytes back into a PyTorch tensor.

    Inverts ``compress`` for the same *mode* (unknown modes are read as raw
    float32). The result is always float32 and is reshaped with
    ``Tensor.view(shape)``; for sparse modes, entries not present in the
    payload are zero.

    Raises:
        ValueError: if the payload is truncated, has the wrong length, or has
            an inconsistent sparse header or out-of-range index.
    """
    if mode == "float16":
        arr = np.frombuffer(data_bytes, dtype=np.float16).astype(np.float32, copy=True)
    elif mode == "int8":
        # Layout: scale(4B float32) + int8 values.
        if len(data_bytes) < 4:
            raise ValueError("Invalid int8 payload: scale is missing.")
        scale = np.frombuffer(data_bytes[:4], dtype=np.float32)[0]
        quantized = np.frombuffer(data_bytes[4:], dtype=np.int8).astype(np.float32, copy=True)
        arr = quantized * scale
    elif mode == "topk":
        # Layout: header(8B: n, k int32) + indices(4k B int32) + values(4k B float32)
        if len(data_bytes) < 8:
            raise ValueError("Invalid topk payload: header is incomplete.")
        n, k = (int(v) for v in np.frombuffer(data_bytes[:8], dtype=np.int32))
        _validate_sparse_header(n, k, mode)
        expected = 8 + 4 * k + 4 * k
        if len(data_bytes) != expected:
            raise ValueError(
                f"Invalid topk payload length: expected {expected} bytes, got {len(data_bytes)} bytes."
            )
        indices = np.frombuffer(data_bytes[8 : 8 + 4 * k], dtype=np.int32)
        values = np.frombuffer(data_bytes[8 + 4 * k :], dtype=np.float32)
        arr = _scatter_sparse(n, indices, values, mode)
    elif mode == "topk_int8":
        # Layout: header(8B) + scale(4B) + indices(4k B) + int8_values(k B)
        if len(data_bytes) < 12:
            raise ValueError("Invalid topk_int8 payload: header incomplete.")
        n, k = (int(v) for v in np.frombuffer(data_bytes[:8], dtype=np.int32))
        _validate_sparse_header(n, k, mode)
        scale = float(np.frombuffer(data_bytes[8:12], dtype=np.float32)[0])
        expected = 12 + 4 * k + k
        if len(data_bytes) != expected:
            raise ValueError(
                f"Invalid topk_int8 payload length: expected {expected} bytes, got {len(data_bytes)} bytes."
            )
        indices = np.frombuffer(data_bytes[12 : 12 + 4 * k], dtype=np.int32)
        quantized = np.frombuffer(data_bytes[12 + 4 * k :], dtype=np.int8).astype(np.float32, copy=True)
        arr = _scatter_sparse(n, indices, quantized * scale, mode)
    else:  # float32 or default
        arr = np.frombuffer(data_bytes, dtype=np.float32).copy()
    return torch.from_numpy(arr).view(shape)