# csc8114/code/src/shared/serialization.py
import io
import torch

def _has_oversized_storage(tensor) -> bool:
    """Returns True when a strided tensor is backed by more bytes than it needs."""
    if not isinstance(tensor, torch.Tensor) or tensor.layout != torch.strided:
        return False
    try:
        storage_bytes = tensor.untyped_storage().nbytes()
    except (AttributeError, NotImplementedError, RuntimeError):
        # Storage cannot be inspected (e.g. older PyTorch without
        # untyped_storage): fall back to the plain torch.save path.
        return False
    return storage_bytes > tensor.numel() * tensor.element_size()


def tensor_to_bytes(tensor: torch.Tensor) -> bytes:
    """
    Serializes a PyTorch tensor into a raw byte sequence so it can be
    transmitted over a gRPC network connection.

    torch.save writes a tensor's entire underlying storage, so a small view
    (for example a slice) of a large tensor would otherwise carry the whole
    parent buffer in the payload. When the storage is larger than the tensor
    itself, a compact copy is serialized instead. Values, shape, dtype and
    requires_grad are preserved; the strides of such a view may differ after
    the round trip.

    Args:
        tensor (torch.Tensor): The tensor to serialize (e.g., smashed activation or gradient).

    Returns:
        bytes: The torch.save payload for the tensor.
    """
    if _has_oversized_storage(tensor):
        # detach() keeps the copy out of the autograd graph; the flag is
        # restored so the receiver sees the same requires_grad as before.
        compact = tensor.detach().clone()
        if tensor.requires_grad:
            compact.requires_grad_(True)
        tensor = compact

    buffer = io.BytesIO()
    # torch.save natively handles serialization to file-like objects
    torch.save(tensor, buffer)
    # Extract the raw binary data
    return buffer.getvalue()

def bytes_to_tensor(data: bytes) -> torch.Tensor:
    """
    Rebuilds a PyTorch tensor from a byte string, such as one received
    over gRPC.

    The payload is wrapped in an in-memory stream and handed to torch.load
    with weights_only=True. map_location='cpu' is passed so the result is
    mapped to CPU regardless of the device it was saved from.

    Args:
        data (bytes): The serialized payload to decode.

    Returns:
        torch.Tensor: The object torch.load reconstructs from the payload.
    """
    # Loading onto CPU first avoids device-map issues across nodes
    return torch.load(
        io.BytesIO(data),
        map_location='cpu',
        weights_only=True,
    )

if __name__ == "__main__":
    # --- Quick Sanity Check ---
    # Round-trips one random tensor through both helpers and prints the outcome.
    print("Testing Tensor Serialization Engine...")

    # 1. Create a mock tensor locally (similar to a batch of 32 smashed activations)
    original_tensor = torch.randn(32, 64)
    print(f"Original Tensor Shape: {original_tensor.shape}")

    # 2. Serialize to bytes (Mocking the process of packing it into the proto message)
    encoded_bytes = tensor_to_bytes(original_tensor)
    print(f"Serialized Byte Length: {len(encoded_bytes)} bytes")

    # 3. Deserialize back to tensor (Mocking the receiving end)
    reconstructed_tensor = bytes_to_tensor(encoded_bytes)

    # 4. Verify exact fidelity. torch.equal requires the same shape and the
    #    same element values; torch.allclose would also accept small numeric
    #    differences, which does not match the "100% Identical" label.
    is_identical = torch.equal(original_tensor, reconstructed_tensor)
    print(f"Reconstruction 100% Identical: {is_identical}")