csc8114 / code / src / client / sync.py
sync.py
Raw
import io
import os
import time
from dataclasses import dataclass

import torch

from proto import fsl_pb2
from src.models.split_lstm import ClientLSTM


@dataclass
class SyncResult:
    """Outcome of one synchronization round-trip, as returned by ``fed_avg_sync``.

    The flag, round and status fields are copied from the server's response;
    the byte counters are measured on the client side.
    """

    # The model instance that was passed in; its weights are overwritten in
    # place when the response carries global weights.
    client_model: ClientLSTM
    # ``round_number`` reported by the server (falls back to the caller's
    # ``model_round`` when the response has no such attribute).
    round_number: int
    # Server-side ``accepted`` flag (False when absent from the response).
    accepted: bool
    # ``applied_round`` reported by the server (0 when absent).
    applied_round: int
    # Server-side ``refresh_only`` flag (False when absent); logged by
    # ``fed_avg_sync`` as "refreshed ... without contributing this update".
    refresh_only: bool
    # Whitespace-stripped status text from the server; may be empty.
    status_message: str
    sync_bytes_sent: int = 0   # bytes sent (serialized client weights upload)
    sync_bytes_recv: int = 0   # bytes received (serialized global weights download)


def fed_avg_sync(
    stub,
    client_id: int,
    client_model: ClientLSTM,
    *,
    model_round: int,
    local_epochs: int,
) -> SyncResult:
    """
    Synchronize local client weights with the server and load the returned global model.

    The client's ``state_dict`` is serialized with ``torch.save`` and sent in a
    ``SyncRequest`` via ``stub.Synchronize``. The ``SCENARIO_ID`` environment
    variable (empty string when unset) is attached as ``scenario-id`` call
    metadata. If the response carries ``global_weights``, they are deserialized
    onto the CPU and loaded into ``client_model`` in place.

    Args:
        stub: Object exposing ``Synchronize(request, metadata=...)``.
        client_id: Identifier sent in the request and used as the log prefix.
        client_model: Model whose weights are uploaded; mutated in place when
            the response contains global weights.
        model_round: Round the local weights are based on (sent as ``base_round``).
        local_epochs: Local epoch count reported to the server.

    Returns:
        A ``SyncResult`` holding the same ``client_model`` instance, the flags
        and round numbers read from the response, and the number of bytes
        uploaded and downloaded.

    Raises:
        Nothing is caught here: errors from ``stub.Synchronize``, ``torch.load``
        or ``load_state_dict`` propagate to the caller.
    """
    buffer = io.BytesIO()
    torch.save(client_model.state_dict(), buffer)
    client_weights_bytes = buffer.getvalue()

    # Normalize once so the request and the log line always agree.
    base_round = int(model_round)
    epochs = int(local_epochs)

    sync_req = fsl_pb2.SyncRequest(
        client_id=client_id,
        client_weights=client_weights_bytes,
        base_round=base_round,
        local_epochs=epochs,
    )

    print(
        f"[CLIENT {client_id}] Waiting for global aggregation... "
        f"(base_round={base_round} local_epochs={epochs})"
    )
    # perf_counter() is monotonic: unlike time.time(), the measured wait cannot
    # go negative or jump if the system clock is adjusted during the call.
    wait_start = time.perf_counter()
    sync_res = stub.Synchronize(sync_req, metadata=[("scenario-id", os.environ.get("SCENARIO_ID", ""))])
    wait_elapsed_s = time.perf_counter() - wait_start

    # Read response fields defensively so a response lacking the newer
    # fields still yields sensible defaults.
    round_number = int(getattr(sync_res, "round_number", model_round))
    accepted = bool(getattr(sync_res, "accepted", False))
    applied_round = int(getattr(sync_res, "applied_round", 0))
    refresh_only = bool(getattr(sync_res, "refresh_only", False))
    status_message = str(getattr(sync_res, "status_message", "")).strip()

    global_weights = sync_res.global_weights
    if global_weights:
        # weights_only=True restricts unpickling to tensor data, so a payload
        # from the network cannot execute arbitrary code on load.
        global_state_dict = torch.load(
            io.BytesIO(global_weights), weights_only=True, map_location="cpu"
        )
        client_model.load_state_dict(global_state_dict)

    if accepted:
        print(
            f"[CLIENT {client_id}] Global model updated to Round {round_number} "
            f"(applied_round={applied_round} sync_wait={wait_elapsed_s:.2f}s)"
        )
    elif refresh_only:
        print(
            f"[CLIENT {client_id}] Refreshed local model to Round {round_number} "
            f"without contributing this update (sync_wait={wait_elapsed_s:.2f}s)"
        )
    else:
        print(
            f"[CLIENT {client_id}] Synchronization completed without model update "
            f"(sync_wait={wait_elapsed_s:.2f}s)"
        )

    if status_message:
        print(f"[CLIENT {client_id}] Sync status: {status_message}")

    return SyncResult(
        client_model=client_model,
        round_number=round_number,
        accepted=accepted,
        applied_round=applied_round,
        refresh_only=refresh_only,
        status_message=status_message,
        sync_bytes_sent=len(client_weights_bytes),
        sync_bytes_recv=len(global_weights) if global_weights else 0,
    )