csc8114 / code / src / models / split_lstm.py
split_lstm.py
Raw
import torch
import torch.nn as nn

class ClientLSTM(nn.Module):
    """
    Client-side half of the split model, i.e. the part that lives on the Edge/IoT device.

    It reads a time window of raw weather features and condenses each sample
    into one hidden-state vector -- the 'smashed activation' that is passed on
    to the server-side head.

    Args:
        input_size: number of features per time step.
        hidden_size: width of the LSTM hidden state (and of the returned vector).
        num_layers: number of stacked LSTM layers.
        lstm_dropout: dropout rate handed to the LSTM; only used when
            num_layers > 1.
        dropout: legacy name for lstm_dropout. When it is not None it wins
            over lstm_dropout.
    """
    def __init__(
        self,
        input_size=5,
        hidden_size=64,
        num_layers=1,
        lstm_dropout=0.3,
        dropout=None,
    ):
        super().__init__()

        # Legacy call sites still pass `dropout`; honour it whenever it is given.
        requested_rate = lstm_dropout if dropout is None else dropout

        self.input_size = input_size
        self.hidden_size = hidden_size
        self.num_layers = num_layers
        self.lstm_dropout = float(requested_rate)

        # nn.LSTM applies dropout between stacked layers only, so a
        # single-layer model is built with the rate forced to zero.
        if num_layers > 1:
            inter_layer_dropout = self.lstm_dropout
        else:
            inter_layer_dropout = 0.0

        # Core feature extractor
        lstm_config = dict(
            input_size=input_size,
            hidden_size=hidden_size,
            num_layers=num_layers,
            batch_first=True,
            dropout=inter_layer_dropout,
        )
        self.lstm = nn.LSTM(**lstm_config)

    def forward(self, x):
        """
        Encode a batch of sequences into smashed activations.

        x shape: (batch_size, seq_length, input_size)
        output shape: (batch_size, hidden_size)
        """
        # self.lstm(x) returns (per-step outputs, (h_n, c_n)); only the final
        # hidden state h_n is sent to the server.
        final_hidden, _ = self.lstm(x)[1]

        # h_n holds one entry per stacked layer: (num_layers, batch_size, hidden_size).
        # The last entry belongs to the top layer.
        return final_hidden[-1]


class ServerHead(nn.Module):
    """
    Server-side half of the split model, residing on the central cloud server.

    It receives the clients' 'smashed activations' and completes the
    prediction: a shared MLP trunk feeds two linear heads, one producing a
    rain/no-rain logit and one producing the rainfall amount.

    Args:
        hidden_size: width of the incoming smashed activation.
        output_size: number of outputs of the rainfall-amount head.
        head_width: width of the shared trunk.
        dropout: dropout probability used inside the trunk.
    """
    def __init__(self, hidden_size=64, output_size=1, head_width=64, dropout=0.1):
        super().__init__()

        # Shared trunk; the layers are listed in the order they are applied.
        trunk_layers = [
            nn.Linear(hidden_size, head_width),
            nn.LayerNorm(head_width),
            nn.SiLU(),
            nn.Dropout(dropout),
            nn.Linear(head_width, head_width),
            nn.SiLU(),
        ]
        self.backbone = nn.Sequential(*trunk_layers)

        # Task-specific linear heads on top of the shared trunk.
        self.rain_classifier = nn.Linear(head_width, 1)
        self.rain_regressor = nn.Linear(head_width, output_size)

    def forward(self, smashed_activation):
        """
        Run the trunk once and evaluate both heads on its output.

        smashed_activation shape: (batch_size, hidden_size)
        returns a tuple (rain_logit, rain_amount):
          rain_logit: unnormalized rain/no-rain score
          rain_amount: predicted rainfall amount in transformed space
        """
        shared = self.backbone(smashed_activation)
        return self.rain_classifier(shared), self.rain_regressor(shared)