import torch
import torch.nn as nn
from src.models.split_lstm import ClientLSTM, ServerHead
def test_fsl_communication_loop():
    """Smoke-test one split-learning (FSL) training step end to end.

    The step is simulated in-process:

    1. The client runs its forward pass and produces the cut-layer
       ("smashed") activation.
    2. The server receives a detached copy, runs its own forward pass,
       computes the loss, backpropagates and updates its weights.
    3. The gradient at the cut layer is handed back to the client, which
       resumes backpropagation into its own weights and updates them.

    Besides running the step, the test asserts that the prediction and
    target shapes agree, that the loss is finite, and that gradients
    actually reached both the cut layer and the model parameters on each
    side. Without these checks the test would pass whenever no exception
    happened to be raised.

    Raises:
        AssertionError: if any of the checks above fails.
    """
    print("Initializing FSL Models...")
    # 1. Initialize models
    client_model = ClientLSTM()
    server_model = ServerHead()

    # Each side owns its optimizer; neither ever touches the other's weights.
    client_optim = torch.optim.Adam(client_model.parameters(), lr=0.001)
    server_optim = torch.optim.Adam(server_model.parameters(), lr=0.001)
    # NOTE(review): MSE on a value named "logit" against random real-valued
    # targets only exercises the graph; a real classifier objective would
    # presumably be BCEWithLogitsLoss -- confirm against the training code.
    criterion = nn.MSELoss()

    # 2. Setup mock data (batch_size=32, seq_len=24, input_size=5)
    print("Generating mock data from dataloader (32, 24, 5)...")
    mock_x = torch.randn(32, 24, 5)
    mock_y = torch.randn(32, 1)

    # --- SIMULATE FSL TRAINING STEP ---
    print("\n--- Starting FSL Step ---")

    # [CLIENT SIDE]
    # Client executes forward pass
    client_optim.zero_grad()
    smashed_activation = client_model(mock_x)
    print(f"Client Output Shape (Smashed Activation): {smashed_activation.shape}")

    # Client 'sends' this tensor to the server.
    # The tensor is detached so the server's graph is independent of the
    # client's (over a real network there is no shared autograd graph), and
    # requires_grad_ makes autograd record a gradient for this leaf tensor.
    received_activation = smashed_activation.detach().clone()
    received_activation.requires_grad_(True)

    # [SERVER SIDE]
    # Server executes forward pass.
    # ServerHead returns (rain_logit, rain_amount); use the classifier head for the loss.
    server_optim.zero_grad()
    rain_logit, _ = server_model(received_activation)
    print(f"Server Prediction Shape: {rain_logit.shape}")

    # MSELoss broadcasts mismatched shapes (emitting only a warning), which
    # would silently compute the wrong loss -- fail loudly instead.
    assert rain_logit.shape == mock_y.shape, (
        f"Prediction shape {tuple(rain_logit.shape)} does not match "
        f"target shape {tuple(mock_y.shape)}"
    )

    # Server calculates Loss
    loss = criterion(rain_logit, mock_y)
    print(f"Loss computed: {loss.item():.4f}")
    assert torch.isfinite(loss).item(), "Loss is not finite"

    # Server executes backward pass
    loss.backward()
    assert any(p.grad is not None for p in server_model.parameters()), (
        "No server parameter received a gradient"
    )

    # Server updates its own weights
    server_optim.step()

    # [NETWORK TRANSFER BACK]
    # Server extracts the gradient at the cut-layer (the input to the ServerHead).
    # This is what gets sent back across the network to the Client.
    cut_layer_grad = received_activation.grad
    assert cut_layer_grad is not None, (
        "Server backward pass produced no gradient at the cut layer"
    )
    gradient_for_client = cut_layer_grad.clone()
    print(f"Gradient extracted for Client. Shape: {gradient_for_client.shape}")
    assert gradient_for_client.shape == smashed_activation.shape, (
        "Cut-layer gradient shape does not match the smashed activation"
    )

    # [CLIENT SIDE]
    # Client receives the gradient and continues the backward pass into its
    # own weights by seeding backward() with the received gradient.
    smashed_activation.backward(gradient_for_client)
    assert any(p.grad is not None for p in client_model.parameters()), (
        "No client parameter received a gradient"
    )

    # Client updates its own weights
    client_optim.step()

    print("\n[SUCCESS] FSL Forward and Backward Pass completed without errors!")
if __name__ == "__main__":
    # Allow running the smoke test directly as a script, outside a test runner.
    test_fsl_communication_loop()