csc8114 / code / src / server / forward_service.py
forward_service.py
Raw
from dataclasses import dataclass
from datetime import datetime
import time
import threading

import torch
import torch.nn.functional as F

from proto import fsl_pb2
from src.shared.compression import compress, decompress
from src.shared.common import cfg
from src.shared.runtime import maybe_autocast
from src.shared.targets import (
    inverse_target_scalar,
    is_rain,
    rain_probability_threshold,
    rain_threshold_mm,
)


@dataclass
class ForwardPassResult:
    response: fsl_pb2.ForwardResponse
    log_entry: dict
    monitor_message: str


def _classification_loss(
    rain_logit: torch.Tensor,
    rain_target: torch.Tensor,
    *,
    pos_weight: torch.Tensor,
) -> torch.Tensor:
    training_cfg = cfg.get("training", {})
    loss_type = str(training_cfg.get("classification_loss_type", "weighted_bce")).strip().lower()
    focal_gamma = float(training_cfg.get("focal_gamma", 2.0))
    focal_alpha = float(training_cfg.get("focal_alpha", -1.0))

    bce_loss = F.binary_cross_entropy_with_logits(
        rain_logit,
        rain_target,
        pos_weight=pos_weight,
        reduction="none",
    )
    if loss_type != "focal":
        return bce_loss.mean()

    prob = torch.sigmoid(rain_logit)
    pt = torch.where(rain_target > 0.5, prob, 1.0 - prob)
    focal_factor = torch.pow(1.0 - pt, focal_gamma)
    loss = focal_factor * bce_loss

    if 0.0 <= focal_alpha <= 1.0:
        alpha_t = rain_target * focal_alpha + (1.0 - rain_target) * (1.0 - focal_alpha)
        loss = alpha_t * loss

    return loss.mean()


def handle_forward_request(
    request,
    *,
    hidden_size: int,
    device: torch.device,
    server_model: torch.nn.Module,
    optimizer: torch.optim.Optimizer,
    sync_lock: threading.Lock,
    current_round: int,
    assigned_compression: str,
    assigned_rho: int,
    profiler_enabled: bool,
    scheduler_enabled: bool,
) -> ForwardPassResult:
    compression_mode = request.compression_mode if hasattr(request, "compression_mode") and request.compression_mode else "float32"
    start_decomp_time = time.time()
    smashed_activation = decompress(request.activation_data, (-1, hidden_size), compression_mode).to(device)
    smashed_activation = smashed_activation.detach().clone().requires_grad_(True)
    decomp_time = (time.time() - start_decomp_time) * 1000.0

    target = torch.tensor([[request.true_target]], dtype=torch.float32, device=device)
    is_training = getattr(request, "is_training", True)
    raw_target_val = getattr(request, "raw_target", request.true_target)
    rain_threshold = rain_threshold_mm()
    prob_threshold = rain_probability_threshold()
    rain_target = torch.tensor([[1.0 if is_rain(raw_target_val, threshold=rain_threshold) else 0.0]], dtype=torch.float32, device=device)
    cls_weight = float(request.classification_loss_weight) if hasattr(request, "classification_loss_weight") and request.classification_loss_weight > 0 else 1.0
    reg_weight = float(request.regression_loss_weight) if hasattr(request, "regression_loss_weight") and request.regression_loss_weight > 0 else 1.0
    pos_weight_value = float(cfg.get("training", {}).get("classification_positive_weight", 1.0))
    pos_weight = torch.tensor([pos_weight_value], dtype=torch.float32, device=device)

    with sync_lock:
        start_comp_time = time.time()
        previous_mode = server_model.training
        server_model.train(mode=is_training)
        try:
            with maybe_autocast(device):
                rain_logit, rain_amount = server_model(smashed_activation)
                cls_loss = _classification_loss(
                    rain_logit,
                    rain_target,
                    pos_weight=pos_weight,
                )
                if rain_target.item() > 0.5:
                    reg_loss = F.smooth_l1_loss(rain_amount, target)
                else:
                    reg_loss = torch.zeros((), dtype=torch.float32, device=device)
                #  total loss
                loss = cls_weight * cls_loss + reg_weight * reg_loss

            if is_training:
                optimizer.zero_grad()
                loss.backward()

                if smashed_activation.grad is None:
                    raise ValueError("Gradient calculation failed on the smashed activation.")

                grad_mag = torch.norm(smashed_activation.grad).item()
                activation_gradient = compress(smashed_activation.grad, compression_mode)
                optimizer.step()
            else:
                grad_mag = 0.0
                activation_gradient = b""
        finally:
            server_model.train(mode=previous_mode)

        comp_time = (time.time() - start_comp_time) * 1000.0

    rain_prob = torch.sigmoid(rain_logit).item()
    pred_val = rain_amount.item()
    raw_pred_val = inverse_target_scalar(pred_val) if rain_prob >= prob_threshold else 0.0
    loss_val = loss.item()
    cls_loss_val = cls_loss.item()
    reg_loss_val = reg_loss.item() if isinstance(reg_loss, torch.Tensor) else float(reg_loss)
    current_lr = optimizer.param_groups[0]["lr"]
    rain_correct = int(
        is_rain(raw_target_val, threshold=rain_threshold)
        == is_rain(raw_pred_val, threshold=rain_threshold)
    )
    reported_latency = getattr(request, "latency_ms", 0.0)
    payload_bytes = getattr(request, "payload_bytes", 0)

    log_entry = {
        "timestamp": datetime.now().isoformat(),
        "round": current_round,
        "client_id": request.client_id,
        "is_training": int(is_training),
        "rain_flag": int(is_rain(raw_target_val, threshold=rain_threshold)),
        "rain_correct": rain_correct,
        "compression_mode": compression_mode,
        "next_compression": assigned_compression,
        "next_rho": int(assigned_rho),
        "profiler_enabled": int(profiler_enabled),
        "scheduler_enabled": int(scheduler_enabled),
        "reported_latency_ms": reported_latency,
        "payload_bytes": payload_bytes,
        "target": raw_target_val,
        "prediction": raw_pred_val,
        "target_transformed": request.true_target,
        "prediction_transformed": pred_val,
        "rain_probability": rain_prob,
        "loss": loss_val,
        "classification_loss": cls_loss_val,
        "regression_loss": reg_loss_val,
        "learning_rate": current_lr,
        "decompression_time_ms": decomp_time,
        "computation_time_ms": comp_time,
        "gradient_magnitude": grad_mag,
    }

    if raw_target_val > 0:
        monitor_message = (
            f"[💧] ID:{request.client_id} | Tgt:{raw_target_val:.2f} | "
            f"Pred:{raw_pred_val:.2f} | P(rain):{rain_prob:.2f} | Loss:{loss_val:.4f}"
        )
    else:
        monitor_message = (
            f"[☁️] ID:{request.client_id} | Pred:{raw_pred_val:.2f} | "
            f"P(rain):{rain_prob:.2f} | Loss:{loss_val:.4f}"
        )
    monitor_message = (
        f"{monitor_message} | {compression_mode}->{assigned_compression}/rho{int(assigned_rho)} "
        f"[D:{decomp_time:.1f}ms, C:{comp_time:.1f}ms, G:{grad_mag:.3f}]"
    )

    response = fsl_pb2.ForwardResponse(
        gradient_data=activation_gradient,
        status_message=f"Success: Loss {loss_val:.4f} Pred {raw_pred_val:.4f} P(rain) {rain_prob:.4f}",
        next_compression_mode=assigned_compression,
        success=True,
        loss=loss_val,
        prediction=raw_pred_val,
        rain_probability=rain_prob,
        classification_loss=cls_loss_val,
        regression_loss=reg_loss_val,
        next_rho=int(assigned_rho),
    )

    return ForwardPassResult(
        response=response,
        log_entry=log_entry,
        monitor_message=monitor_message,
    )