# csc8114/code/src/client/forward_step.py
import os
import time

import pandas as pd
import torch

from proto import fsl_pb2
from src.client.latency_generator import LatencyGenerator
from src.models.split_lstm import ClientLSTM
from src.shared.common import cfg
from src.shared.compression import compress, decompress
from src.shared.runtime import maybe_autocast
from src.shared.targets import is_rain, transform_target_scalar

# Process-wide cache: one LatencyGenerator per (int-normalised) client id, so
# repeated forward steps for the same client reuse the same instance.
_LATENCY_GENERATORS: dict[int, LatencyGenerator] = {}


def _latency_generator_for(client_id: int) -> LatencyGenerator:
    """Return the cached LatencyGenerator for ``client_id``, creating it on first use.

    The id is passed through ``int()`` before lookup, so values that convert to
    the same integer share a single generator.
    """
    key = int(client_id)
    if key not in _LATENCY_GENERATORS:
        _LATENCY_GENERATORS[key] = LatencyGenerator(client_id=key)
    return _LATENCY_GENERATORS[key]


def _input_window(
    df: pd.DataFrame,
    feature_cols: list[str],
    target_idx: int,
    seq_len: int,
    feat_stats: tuple | None,
):
    """Return the ``seq_len`` feature rows ending just before ``target_idx``, optionally normalised."""
    start = target_idx - seq_len
    # A negative start makes ``iloc`` wrap/clamp and a stop past the end truncates;
    # both would silently produce a window of the wrong length, so reject them.
    if seq_len <= 0 or start < 0 or target_idx > len(df):
        raise ValueError(
            f"Cannot build a {seq_len}-step input window ending at "
            f"target_idx={target_idx} from {len(df)} rows."
        )
    window = df[feature_cols].iloc[start:target_idx].to_numpy()
    if feat_stats:
        mean, std = feat_stats
        window = (window - mean) / std
    return window


def _reported_latency_ms(client_id: int, last_latency_ms: float, profiler_enabled: bool) -> float:
    """Return the latency value to report to the server, sleeping first if the generator suggests it."""
    if not profiler_enabled:
        # Report last step's measured RTT so the scheduler sees real network latency.
        return last_latency_ms
    generator = _latency_generator_for(client_id)
    reported = generator.next_latency_ms(measured_latency_ms=last_latency_ms)
    sleep_ms = generator.suggested_sleep_ms(reported_latency_ms=reported)
    if sleep_ms > 0.0:
        time.sleep(sleep_ms / 1000.0)
    return reported


def _sparsity_ratio(compression_mode: str) -> float:
    """Fraction of the activation that was transmitted: config ``topk_ratio`` for top-k modes, else 1.0."""
    if "topk" not in compression_mode:
        return 1.0
    raw_ratio = cfg.get("compression", {}).get("topk_ratio", 0.125)
    try:
        return float(raw_ratio)
    except (TypeError, ValueError):
        # NOTE(review): this fallback (0.5) differs from the config default (0.125);
        # kept as-is in case it mirrors the fallback used by ``compress`` — confirm.
        return 0.5


def run_forward_step(
    stub,
    client_id: int,
    client_model: ClientLSTM,
    optimizer: torch.optim.Optimizer,
    df: pd.DataFrame,
    target_idx: int,
    target_value: float,
    mode: str,
    sensor_id: str,
    compression_mode: str,
    feature_cols: list[str],
    feat_stats: tuple | None = None,
    device: torch.device = torch.device("cpu"),
    is_training: bool = True,
    last_latency_ms: float = 0.0,
    seq_len: int | None = None,
) -> dict:
    """
    Run one split-learning request/response cycle and return the step log.

    The client half of the model encodes the ``seq_len`` rows of ``feature_cols``
    that end just before ``target_idx``. The activation is compressed and sent to
    the server via ``stub.Forward``. On training steps, the gradient returned by
    the server is back-propagated through ``client_model``, clipped to norm 1.0,
    and applied with one ``optimizer`` step.

    Args:
        stub: RPC stub exposing ``Forward(request, metadata=...)`` (presumably gRPC).
        client_id: Id of this client; also keys the per-client latency generator.
        client_model: Client-side half of the split model.
        optimizer: Optimizer over the client model's parameters.
        df: Source frame holding ``feature_cols``.
        target_idx: Positional index of the target row; the input window is
            ``df.iloc[target_idx - seq_len:target_idx]``.
        target_value: Raw target; sent untransformed as ``raw_target`` and via
            ``transform_target_scalar`` as ``true_target``.
        mode: Label used only in console logging.
        sensor_id: Sensor identifier used only in console logging.
        compression_mode: Compression scheme for the activation and gradient.
        feature_cols: Columns fed to the client model.
        feat_stats: Optional ``(mean, std)`` used to normalise the window.
        device: Device on which the input tensor is created.
        is_training: If False, no autograd graph is recorded and no update is made.
        last_latency_ms: RTT measured on the previous step, fed to the latency report.
        seq_len: Window length; defaults to ``cfg["model"]["seq_len"]`` (24).

    Returns:
        A flat dict step log: target/prediction/loss values from the server, the
        measured RPC latency and payload size, this step's compression mode and
        sparsity ratio, the compression mode/rho to use next, and feature flags.

    Raises:
        ValueError: If there are not enough rows to build a full input window.
        RuntimeError: If the server reports an unsuccessful forward pass.
    """
    training_cfg = cfg.get("training", {})
    log_step_details = cfg.get("console", {}).get("log_step_details", False)
    profiler_enabled = cfg.get("profiler", {}).get("enabled", True)
    cls_weight = float(training_cfg.get("classification_loss_weight", 1.0))
    reg_weight = float(training_cfg.get("regression_loss_weight", 1.0))
    seq_len = int(seq_len if seq_len is not None else cfg.get("model", {}).get("seq_len", 24))

    window = _input_window(df, feature_cols, target_idx, seq_len, feat_stats)
    input_tensor = torch.tensor(window, dtype=torch.float32, device=device).unsqueeze(0)

    # Only record an autograd graph when the server's gradient will be
    # back-propagated; evaluation steps skip it to save memory and compute.
    with torch.set_grad_enabled(is_training), maybe_autocast(device):
        smashed_activation = client_model(input_tensor)

    activation_bytes = compress(smashed_activation, compression_mode)
    payload_size = len(activation_bytes)
    reported_latency_ms = _reported_latency_ms(client_id, last_latency_ms, profiler_enabled)

    request = fsl_pb2.ForwardRequest(
        client_id=client_id,
        activation_data=activation_bytes,
        true_target=transform_target_scalar(target_value),
        latency_ms=reported_latency_ms,
        compression_mode=compression_mode,
        is_training=is_training,
        payload_bytes=payload_size,
        raw_target=target_value,
        classification_loss_weight=cls_weight,
        regression_loss_weight=reg_weight,
    )
    phase = "TRAIN" if is_training else "TEST"
    if log_step_details:
        print(f"[{phase}] Transmitting activations for {sensor_id}... Payload: {payload_size} bytes")

    # perf_counter is monotonic, so the RTT is not skewed by wall-clock adjustments.
    rpc_start = time.perf_counter()
    response = stub.Forward(request, metadata=[("scenario-id", os.environ.get("SCENARIO_ID", ""))])
    latency_ms = (time.perf_counter() - rpc_start) * 1000.0

    if not response.success:
        raise RuntimeError(response.status_message or "Server forward pass failed.")

    current_loss = float(response.loss)
    prediction_val = float(response.prediction)
    # Optional response fields; default to 0.0 when the message does not carry them.
    rain_probability = float(getattr(response, "rain_probability", 0.0))
    classification_loss = float(getattr(response, "classification_loss", 0.0))
    regression_loss = float(getattr(response, "regression_loss", 0.0))

    if log_step_details:
        icon = "RAIN" if is_rain(target_value) else "DRY"
        print(f"[{icon}] [{mode}] {sensor_id[:10]} | 3h Target: {target_value:.2f} | Loss: {current_loss:.6f}")

    if is_training:
        received_grad = decompress(
            response.gradient_data, smashed_activation.shape, compression_mode
        ).to(device=device, dtype=smashed_activation.dtype)
        # Clear gradients left over from the previous step; otherwise they would
        # accumulate into this update since ``optimizer.step()`` runs every call.
        optimizer.zero_grad(set_to_none=True)
        smashed_activation.backward(received_grad)
        torch.nn.utils.clip_grad_norm_(client_model.parameters(), max_norm=1.0)
        optimizer.step()

    if log_step_details:
        print(f"[SERVER] Feedback processed | {response.status_message} | Latency: {latency_ms:.2f} ms")

    # The server's scheduler may override compression mode and rho for the next
    # step, but only when the scheduler is enabled in config.
    scheduler_enabled = cfg.get("scheduler", {}).get("enabled", True)
    next_compression_mode = compression_mode
    next_rho = int(cfg.get("federated", {}).get("rho", 1))
    if scheduler_enabled:
        if response.next_compression_mode:
            next_compression_mode = response.next_compression_mode
        if getattr(response, "next_rho", 0) > 0:
            next_rho = int(response.next_rho)

    return {
        "Target": target_value,
        "Prediction": prediction_val,
        "RainFlag": int(is_rain(target_value)),
        "Loss": current_loss,
        "RainProbability": rain_probability,
        "ClassificationLoss": classification_loss,
        "RegressionLoss": regression_loss,
        "LatencyMs": float(latency_ms),
        "PayloadBytes": payload_size,
        "CompressionMode": compression_mode,  # compression mode actually used for this step
        "SparsityRatio": round(_sparsity_ratio(compression_mode), 4),  # fraction of values sent (1.0 = all)
        "NextCompression": next_compression_mode,
        "NextRho": int(next_rho),
        "ProfilerEnabled": int(profiler_enabled),
        "SchedulerEnabled": int(scheduler_enabled),
    }