syntax = "proto3";

// Namespace / Package declaration to keep things organized
package fsl;

// ==========================================
// 1. MESSAGES (DATA PACKETS)
// ==========================================

// Client -> Server payload for a single forward pass.
message ForwardRequest {
    // ID of the client sending the data.
    int32 client_id = 1;

    // Serialized PyTorch tensor holding the smashed activation.
    bytes activation_data = 2;

    // Target value used for the loss calculation.
    float true_target = 3;

    // Latency measured locally, in milliseconds.
    float latency_ms = 4;

    // Compression mode used.
    string compression_mode = 5;

    // True for a train pass (backprop is performed); false for a test pass.
    bool is_training = 6;

    // Size of the transmitted activation, in bytes.
    int32 payload_bytes = 7;

    // Original rainfall value in mm; used for logging and metrics.
    float raw_target = 8;

    // Weight of the classification loss term.
    // NOTE(review): how this weight is applied is not visible in this file -
    // confirm against the server implementation.
    float classification_loss_weight = 9;

    // Weight of the regression loss term.
    // NOTE(review): how this weight is applied is not visible in this file -
    // confirm against the server implementation.
    float regression_loss_weight = 10;
}

// Sent by a client to register with the server and obtain an assigned ID.
message RegisterRequest {
    // Stable client/container name, used for logging and re-registration.
    string client_name = 1;

    // Preferred logical client ID; typically provided by orchestration.
    int32 requested_client_id = 2;
}

// Reply to a RegisterRequest (returned by the Register RPC).
message RegisterResponse {
    // ID assigned by the server (1-indexed).
    int32 client_id = 1;

    // Total number of clients expected, taken from the server config.
    int32 total_clients = 2;

    // Shared session folder name (YYYYMMDDHHMMSS); generated by the server.
    string session_id = 3;
}

// Server -> Client reply for the backward pass of a Forward call.
message ForwardResponse {
    // Serialized PyTorch gradient tensor, used by the client to continue
    // backprop locally.
    bytes gradient_data = 1;

    // Status or command text from the server (e.g., "OK" or "STOP").
    string status_message = 2;

    // Compression mode the Scheduler assigned for the next round.
    string next_compression_mode = 3;

    // Whether the server completed the request successfully.
    bool success = 4;

    // Structured loss value for the current request.
    float loss = 5;

    // Structured prediction for the current request.
    float prediction = 6;

    // Structured rain probability produced by the classifier head.
    float rain_probability = 7;

    // Structured classifier loss for the current request.
    float classification_loss = 8;

    // Structured regressor loss for the current request.
    float regression_loss = 9;

    // Synchronization interval, in local epochs, assigned to the client.
    int32 next_rho = 10;
}

// Client -> Server payload for Global Sync (Federated Averaging).
message SyncRequest {
    // ID of the client sending this request.
    int32 client_id = 1;

    // Serialized state_dict of the client's local model.
    bytes client_weights = 2;

    // Global round the local model was based on before this sync.
    int32 base_round = 3;

    // Local epochs trained since the last successful refresh/sync.
    int32 local_epochs = 4;
}

// Server -> Client reply for Global Sync.
message SyncResponse {
    // Serialized state_dict of the aggregated global model.
    bytes global_weights = 1;

    // Latest server-side global round after this request was processed.
    int32 round_number = 2;

    // True if this client's update was included in an aggregation.
    bool accepted = 3;

    // Round the update was applied to, or 0 if it was rejected.
    int32 applied_round = 4;

    // True when the client should refresh without contributing its update.
    bool refresh_only = 5;

    // Human-readable explanation, intended for logging/debugging.
    string status_message = 6;
}

// Request type of the NotifyCompletion RPC.
message CompletionRequest {
    // ID of the client sending this notification.
    int32 client_id = 1;

    // NOTE(review): presumably the number of epochs the client finished -
    // exact semantics are not shown in this file; confirm with the client.
    int32 completed_epochs = 2;

    // NOTE(review): presumably the total number of training steps the client
    // ran - exact semantics are not shown in this file; confirm.
    int32 total_steps = 3;

    // Session claimed by the client; must match the current server session.
    string session_id = 4;
}

// Response type of the NotifyCompletion RPC.
message CompletionResponse {
    // NOTE(review): presumably true when the server accepted the completion
    // notice - the conditions are not visible in this file; confirm.
    bool acknowledged = 1;

    // NOTE(review): presumably how many clients have reported completion so
    // far - confirm against the server implementation.
    int32 completed_clients = 2;

    // NOTE(review): presumably the number of clients expected in the session
    // - confirm against the server implementation.
    int32 total_clients = 3;
}

// ==========================================
// 2. SERVICES (NETWORK API ENDPOINTS)
// ==========================================

// RPC endpoints that the Server exposes to the Clients.
service FSLService {
    // Registration: the client registers and receives an assigned ID.
    rpc Register(RegisterRequest) returns (RegisterResponse);

    // Returns the current server status as a ServerInfo message
    // (scenario ID, session ID, current round).
    rpc GetInfo(Empty) returns (ServerInfo);

    // Core call: the client sends activations and gets gradients back.
    rpc Forward(ForwardRequest) returns (ForwardResponse);

    // Aggregation call: the client sends its local model and gets the
    // global model back.
    rpc Synchronize(SyncRequest) returns (SyncResponse);

    // Final notification, sent by a client once it has completed training.
    rpc NotifyCompletion(CompletionRequest) returns (CompletionResponse);
}

// Message with no fields; used as the request type of the GetInfo RPC.
message Empty {}

// Server status returned by the GetInfo RPC.
message ServerInfo {
    // Scenario ID reported by the server.
    string scenario_id = 1;

    // Session ID reported by the server.
    string session_id = 2;

    // NOTE(review): presumably the server's current global round - exact
    // semantics are not shown in this file; confirm against the server.
    int32 current_round = 3;
}
