FedSecurity / defense / RFA_defense.py
RFA_defense.py
Raw
from collections import OrderedDict
from typing import Callable, List, Tuple, Dict, Any
from ..common.utils import compute_geometric_median
from ...security.defense.defense_base import BaseDefenseMethod


class RFADefense(BaseDefenseMethod):
    def __init__(self, config):
        self.device = device.get_device(config)

    def defend_on_aggregation(
            self,
            raw_client_grad_list: List[Tuple[float, OrderedDict]],
            base_aggregation_func: Callable = None,
            extra_auxiliary_info: Any = None,
    ):
        (num0, avg_params) = raw_client_grad_list[0]
        weights = {num for (num, params) in raw_client_grad_list}
        weights = {weight / sum(weights, 0.0) for weight in weights}
        for k in avg_params.keys():
            client_grads = [params[k] for (_, params) in raw_client_grad_list]
            avg_params[k] = compute_geometric_median(weights, client_grads)
        return avg_params