import io
import torch
def tensor_to_bytes(tensor: torch.Tensor) -> bytes:
    """
    Serializes a PyTorch tensor into a raw byte sequence so it can be
    transmitted over a gRPC network connection.

    ``torch.save`` writes a tensor's whole underlying storage, not just the
    elements the tensor covers. A view such as ``batch[0]`` or ``x[:, :8]``
    would therefore carry every byte of its parent tensor onto the wire.
    Such views are compacted (cloned) before saving so the payload only
    contains the data the tensor actually refers to.

    Args:
        tensor (torch.Tensor): The tensor to serialize (e.g., smashed activation or gradient).
    Returns:
        bytes: A compact byte string representation of the tensor.
    """
    if tensor.layout == torch.strided:
        needed_bytes = tensor.numel() * tensor.element_size()
        # Copy only when the backing storage is larger than the tensor needs.
        # Expanded (stride-0) tensors have a *smaller* storage than numel
        # implies; materializing them would make the payload bigger, so they
        # are left alone.
        if tensor.untyped_storage().nbytes() > needed_bytes:
            tensor = tensor.clone()
    buffer = io.BytesIO()
    # torch.save natively handles serialization to file-like objects
    torch.save(tensor, buffer)
    # Extract the raw binary data
    return buffer.getvalue()
def bytes_to_tensor(data: bytes) -> torch.Tensor:
    """
    Deserializes a byte sequence received from the network back into
    a workable PyTorch tensor. Automatically maps the tensor to CPU initially.

    Args:
        data (bytes): The raw byte string received from gRPC.
    Returns:
        torch.Tensor: The reconstructed PyTorch tensor.
    Raises:
        TypeError: If the payload deserializes to something other than a
            tensor (``weights_only=True`` also permits containers and
            primitive values, which would violate this function's contract).
    """
    buffer = io.BytesIO(data)
    # weights_only=True restricts unpickling to safe types, since the bytes
    # arrive from a remote peer. map_location='cpu' loads onto CPU first to
    # avoid device-map issues across nodes.
    tensor = torch.load(buffer, weights_only=True, map_location='cpu')
    # Fail at the network boundary rather than deep inside the model code.
    if not isinstance(tensor, torch.Tensor):
        raise TypeError(
            f"expected serialized torch.Tensor, got {type(tensor).__name__}"
        )
    return tensor
if __name__ == "__main__":
    # --- Quick Sanity Check ---
    print("Testing Tensor Serialization Engine...")
    # 1. Create a mock tensor locally (similar to a batch of 32 smashed activations)
    original_tensor = torch.randn(32, 64)
    print(f"Original Tensor Shape: {original_tensor.shape}")
    # 2. Serialize to bytes (Mocking the process of packing it into the proto message)
    encoded_bytes = tensor_to_bytes(original_tensor)
    print(f"Serialized Byte Length: {len(encoded_bytes)} bytes")
    # 3. Deserialize back to tensor (Mocking the receiving end)
    reconstructed_tensor = bytes_to_tensor(encoded_bytes)
    # 4. Verify exact fidelity. Serialization is lossless, so require the same
    # dtype and bit-identical values (torch.equal also checks shape) instead
    # of the tolerance-based torch.allclose, which cannot justify "100%".
    is_identical = (
        reconstructed_tensor.dtype == original_tensor.dtype
        and torch.equal(original_tensor, reconstructed_tensor)
    )
    print(f"Reconstruction 100% Identical: {is_identical}")