import torch
import torch.nn as nn


class ClientLSTM(nn.Module):
    """Client-side half of the split model, intended for the Edge/IoT device.

    Wraps a batch-first ``nn.LSTM`` that consumes a window of raw weather
    features and emits the final hidden state of its top layer. That vector
    is the "smashed activation" handed over to the server-side head.

    Args:
        input_size: Number of features per time step.
        hidden_size: Width of the LSTM hidden state (and of the output).
        num_layers: Number of stacked LSTM layers.
        lstm_dropout: Dropout rate stored on the module; it is forwarded to
            ``nn.LSTM`` only when ``num_layers > 1``.
        dropout: Legacy alias for ``lstm_dropout``. When not ``None`` it
            takes precedence over ``lstm_dropout``.
    """

    def __init__(
        self,
        input_size=5,
        hidden_size=64,
        num_layers=1,
        lstm_dropout=0.3,
        dropout=None,
    ):
        super().__init__()

        # Older call sites pass `dropout=`; honour it over `lstm_dropout`.
        requested_dropout = lstm_dropout if dropout is None else dropout

        self.input_size = input_size
        self.hidden_size = hidden_size
        self.num_layers = num_layers
        self.lstm_dropout = float(requested_dropout)

        # A single-layer LSTM is always built with dropout 0.0; the stored
        # rate is only handed to nn.LSTM for stacked configurations.
        if num_layers > 1:
            inter_layer_dropout = self.lstm_dropout
        else:
            inter_layer_dropout = 0.0

        # Core feature extractor.
        self.lstm = nn.LSTM(
            input_size=input_size,
            hidden_size=hidden_size,
            num_layers=num_layers,
            batch_first=True,
            dropout=inter_layer_dropout,
        )

    def forward(self, x):
        """Encode a feature window into the smashed activation.

        Args:
            x: Tensor of shape (batch_size, seq_length, input_size).

        Returns:
            Tensor of shape (batch_size, hidden_size): the last layer's
            final hidden state.
        """
        # nn.LSTM yields (output, (h_n, c_n)); only h_n is needed here.
        _, recurrent_state = self.lstm(x)
        final_hidden = recurrent_state[0]

        # h_n is (num_layers, batch_size, hidden_size); keep the top layer.
        return final_hidden[-1]


class ServerHead(nn.Module):
    """Server-side half of the split model, residing on the cloud server.

    Receives the smashed activations produced by the clients, passes them
    through a shared MLP backbone, and finishes the prediction with two
    linear heads: a rain/no-rain classifier and a rainfall-amount regressor.

    Args:
        hidden_size: Width of the incoming smashed activation.
        output_size: Number of outputs of the regression head.
        head_width: Width of the backbone's hidden layers.
        dropout: Dropout probability used inside the backbone.
    """

    def __init__(self, hidden_size=64, output_size=1, head_width=64, dropout=0.1):
        super().__init__()

        # Positions in this list become the Sequential indices, and therefore
        # the parameter names in the state dict.
        backbone_layers = [
            nn.Linear(hidden_size, head_width),
            nn.LayerNorm(head_width),
            nn.SiLU(),
            nn.Dropout(dropout),
            nn.Linear(head_width, head_width),
            nn.SiLU(),
        ]
        self.backbone = nn.Sequential(*backbone_layers)

        # Two task heads on top of the shared features.
        self.rain_classifier = nn.Linear(head_width, 1)
        self.rain_regressor = nn.Linear(head_width, output_size)

    def forward(self, smashed_activation):
        """Run the shared backbone and both prediction heads.

        Args:
            smashed_activation: Tensor of shape (batch_size, hidden_size).

        Returns:
            Tuple ``(rain_logit, rain_amount)``:
                rain_logit: unnormalized rain/no-rain score.
                rain_amount: predicted rainfall amount in transformed space.
        """
        shared_features = self.backbone(smashed_activation)
        return (
            self.rain_classifier(shared_features),
            self.rain_regressor(shared_features),
        )