# csc8114/code/test_fedavg_partial.py
# Tests FedAvgCoordinator partial participation and stale-client refresh.
import io
import threading

import torch

from proto import fsl_pb2
from src.server.fedavg import FedAvgCoordinator


def _client_state(fill_value: float) -> dict[str, torch.Tensor]:
    return {
        "lstm.weight_ih_l0": torch.full((2, 2), fill_value=fill_value),
        "lstm.weight_hh_l0": torch.full((2, 2), fill_value=fill_value),
    }


def _decode_state_dict(payload: bytes) -> dict[str, torch.Tensor]:
    return torch.load(io.BytesIO(payload), weights_only=True, map_location="cpu")


def test_partial_participation_and_stale_refresh(tmp_path):
    """A round closes with 2 of 3 clients, and a late third client is refreshed.

    Clients 1 and 2 sync concurrently against base round 0. With
    ``min_clients_per_round=2`` both are expected to be accepted into round 1
    and to receive the average of their uploads. Client 3 then syncs against
    the outdated base round 0. It is expected to be rejected with a
    refresh-only response that carries the round-1 global weights.
    """
    session_dir = tmp_path / "session"
    periodic_dir = session_dir / "periodic"
    # parents=True also creates session_dir.
    periodic_dir.mkdir(parents=True)

    server_model = torch.nn.Linear(2, 1)
    optimizer = torch.optim.Adam(server_model.parameters(), lr=0.001)
    coordinator = FedAvgCoordinator(
        num_clients=3,
        hidden_size=64,
        session_id="test-session",
        session_dir=str(session_dir),
        periodic_dir=str(periodic_dir),
        ckpt_interval=10,
        min_clients_per_round=2,
        round_timeout_sec=5.0,
    )
    for client_id in (1, 2, 3):
        coordinator.register_client(client_id)

    weight_keys = ("lstm.weight_ih_l0", "lstm.weight_hh_l0")
    responses: dict[int, fsl_pb2.SyncResponse] = {}
    thread_errors: list[BaseException] = []

    def do_sync(client_id: int, fill_value: float) -> None:
        """Send one sync request for ``client_id`` and store the response."""
        request = fsl_pb2.SyncRequest(
            client_id=client_id,
            client_weights=b"placeholder",
            base_round=0,
            local_epochs=1,
        )
        responses[client_id] = coordinator.synchronize(
            request,
            local_weights=_client_state(fill_value),
            server_model=server_model,
            optimizer=optimizer,
        )

    def do_sync_in_thread(client_id: int, fill_value: float) -> None:
        """Run ``do_sync`` in a worker thread, recording any exception it raises."""
        # Thread.join() does not re-raise worker exceptions. Record them so
        # the test reports the real cause instead of a KeyError on `responses`.
        try:
            do_sync(client_id, fill_value)
        except BaseException as exc:  # noqa: BLE001 - surfaced via assert below
            thread_errors.append(exc)

    t1 = threading.Thread(target=do_sync_in_thread, args=(1, 1.0), daemon=True)
    t1.start()

    # Client 2 syncs on the main thread while client 1 runs concurrently.
    do_sync(2, 3.0)
    t1.join(timeout=2.0)
    assert not t1.is_alive(), "client 1 sync did not finish in time"
    assert not thread_errors, f"client 1 sync raised: {thread_errors[0]!r}"

    res1 = responses[1]
    res2 = responses[2]
    assert res1.accepted is True
    assert res2.accepted is True
    assert res1.round_number == 1
    assert res2.round_number == 1

    # The FedAvg mean of fill values 1.0 and 3.0.
    expected = torch.full((2, 2), 2.0)
    for client_id, response in ((1, res1), (2, res2)):
        aggregated = _decode_state_dict(response.global_weights)
        for key in weight_keys:
            assert torch.allclose(aggregated[key], expected), (client_id, key)

    # Client 3 still reports base_round=0 although round 1 has already closed.
    stale_response = coordinator.synchronize(
        fsl_pb2.SyncRequest(
            client_id=3,
            client_weights=b"placeholder",
            base_round=0,
            local_epochs=3,
        ),
        local_weights=_client_state(5.0),
        server_model=server_model,
        optimizer=optimizer,
    )
    assert stale_response.accepted is False
    assert stale_response.refresh_only is True
    assert stale_response.round_number == 1

    # The rejected upload (5.0) must not leak into the refreshed global model.
    refreshed = _decode_state_dict(stale_response.global_weights)
    for key in weight_keys:
        assert torch.allclose(refreshed[key], expected), key