from dataclasses import dataclass from datetime import datetime import time import threading import torch import torch.nn.functional as F from proto import fsl_pb2 from src.shared.compression import compress, decompress from src.shared.common import cfg from src.shared.runtime import maybe_autocast from src.shared.targets import ( inverse_target_scalar, is_rain, rain_probability_threshold, rain_threshold_mm, ) @dataclass class ForwardPassResult: response: fsl_pb2.ForwardResponse log_entry: dict monitor_message: str def _classification_loss( rain_logit: torch.Tensor, rain_target: torch.Tensor, *, pos_weight: torch.Tensor, ) -> torch.Tensor: training_cfg = cfg.get("training", {}) loss_type = str(training_cfg.get("classification_loss_type", "weighted_bce")).strip().lower() focal_gamma = float(training_cfg.get("focal_gamma", 2.0)) focal_alpha = float(training_cfg.get("focal_alpha", -1.0)) bce_loss = F.binary_cross_entropy_with_logits( rain_logit, rain_target, pos_weight=pos_weight, reduction="none", ) if loss_type != "focal": return bce_loss.mean() prob = torch.sigmoid(rain_logit) pt = torch.where(rain_target > 0.5, prob, 1.0 - prob) focal_factor = torch.pow(1.0 - pt, focal_gamma) loss = focal_factor * bce_loss if 0.0 <= focal_alpha <= 1.0: alpha_t = rain_target * focal_alpha + (1.0 - rain_target) * (1.0 - focal_alpha) loss = alpha_t * loss return loss.mean() def handle_forward_request( request, *, hidden_size: int, device: torch.device, server_model: torch.nn.Module, optimizer: torch.optim.Optimizer, sync_lock: threading.Lock, current_round: int, assigned_compression: str, assigned_rho: int, profiler_enabled: bool, scheduler_enabled: bool, ) -> ForwardPassResult: compression_mode = request.compression_mode if hasattr(request, "compression_mode") and request.compression_mode else "float32" start_decomp_time = time.time() smashed_activation = decompress(request.activation_data, (-1, hidden_size), compression_mode).to(device) smashed_activation = smashed_activation.detach().clone().requires_grad_(True) decomp_time = (time.time() - start_decomp_time) * 1000.0 target = torch.tensor([[request.true_target]], dtype=torch.float32, device=device) is_training = getattr(request, "is_training", True) raw_target_val = getattr(request, "raw_target", request.true_target) rain_threshold = rain_threshold_mm() prob_threshold = rain_probability_threshold() rain_target = torch.tensor([[1.0 if is_rain(raw_target_val, threshold=rain_threshold) else 0.0]], dtype=torch.float32, device=device) cls_weight = float(request.classification_loss_weight) if hasattr(request, "classification_loss_weight") and request.classification_loss_weight > 0 else 1.0 reg_weight = float(request.regression_loss_weight) if hasattr(request, "regression_loss_weight") and request.regression_loss_weight > 0 else 1.0 pos_weight_value = float(cfg.get("training", {}).get("classification_positive_weight", 1.0)) pos_weight = torch.tensor([pos_weight_value], dtype=torch.float32, device=device) with sync_lock: start_comp_time = time.time() previous_mode = server_model.training server_model.train(mode=is_training) try: with maybe_autocast(device): rain_logit, rain_amount = server_model(smashed_activation) cls_loss = _classification_loss( rain_logit, rain_target, pos_weight=pos_weight, ) if rain_target.item() > 0.5: reg_loss = F.smooth_l1_loss(rain_amount, target) else: reg_loss = torch.zeros((), dtype=torch.float32, device=device) # total loss loss = cls_weight * cls_loss + reg_weight * reg_loss if is_training: optimizer.zero_grad() loss.backward() if smashed_activation.grad is None: raise ValueError("Gradient calculation failed on the smashed activation.") grad_mag = torch.norm(smashed_activation.grad).item() activation_gradient = compress(smashed_activation.grad, compression_mode) optimizer.step() else: grad_mag = 0.0 activation_gradient = b"" finally: server_model.train(mode=previous_mode) comp_time = (time.time() - start_comp_time) * 1000.0 rain_prob = torch.sigmoid(rain_logit).item() pred_val = rain_amount.item() raw_pred_val = inverse_target_scalar(pred_val) if rain_prob >= prob_threshold else 0.0 loss_val = loss.item() cls_loss_val = cls_loss.item() reg_loss_val = reg_loss.item() if isinstance(reg_loss, torch.Tensor) else float(reg_loss) current_lr = optimizer.param_groups[0]["lr"] rain_correct = int( is_rain(raw_target_val, threshold=rain_threshold) == is_rain(raw_pred_val, threshold=rain_threshold) ) reported_latency = getattr(request, "latency_ms", 0.0) payload_bytes = getattr(request, "payload_bytes", 0) log_entry = { "timestamp": datetime.now().isoformat(), "round": current_round, "client_id": request.client_id, "is_training": int(is_training), "rain_flag": int(is_rain(raw_target_val, threshold=rain_threshold)), "rain_correct": rain_correct, "compression_mode": compression_mode, "next_compression": assigned_compression, "next_rho": int(assigned_rho), "profiler_enabled": int(profiler_enabled), "scheduler_enabled": int(scheduler_enabled), "reported_latency_ms": reported_latency, "payload_bytes": payload_bytes, "target": raw_target_val, "prediction": raw_pred_val, "target_transformed": request.true_target, "prediction_transformed": pred_val, "rain_probability": rain_prob, "loss": loss_val, "classification_loss": cls_loss_val, "regression_loss": reg_loss_val, "learning_rate": current_lr, "decompression_time_ms": decomp_time, "computation_time_ms": comp_time, "gradient_magnitude": grad_mag, } if raw_target_val > 0: monitor_message = ( f"[💧] ID:{request.client_id} | Tgt:{raw_target_val:.2f} | " f"Pred:{raw_pred_val:.2f} | P(rain):{rain_prob:.2f} | Loss:{loss_val:.4f}" ) else: monitor_message = ( f"[☁️] ID:{request.client_id} | Pred:{raw_pred_val:.2f} | " f"P(rain):{rain_prob:.2f} | Loss:{loss_val:.4f}" ) monitor_message = ( f"{monitor_message} | {compression_mode}->{assigned_compression}/rho{int(assigned_rho)} " f"[D:{decomp_time:.1f}ms, C:{comp_time:.1f}ms, G:{grad_mag:.3f}]" ) response = fsl_pb2.ForwardResponse( gradient_data=activation_gradient, status_message=f"Success: Loss {loss_val:.4f} Pred {raw_pred_val:.4f} P(rain) {rain_prob:.4f}", next_compression_mode=assigned_compression, success=True, loss=loss_val, prediction=raw_pred_val, rain_probability=rain_prob, classification_loss=cls_loss_val, regression_loss=reg_loss_val, next_rho=int(assigned_rho), ) return ForwardPassResult( response=response, log_entry=log_entry, monitor_message=monitor_message, )