# FedSecurity / defense / geometric_median_defense.py
import math
from collections import OrderedDict
from typing import Callable, List, Tuple, Dict, Any

from ..common.bucket import Bucket
from ..common.utils import compute_geometric_median
from ...security.defense.defense_base import BaseDefenseMethod


class GeometricMedianDefense(BaseDefenseMethod):
    """Robust aggregation that replaces plain averaging with a geometric median.

    Client updates are first grouped with ``Bucket.bucketization``. The
    aggregate for every parameter key is then computed by
    ``compute_geometric_median`` over the per-bucket values, weighted by each
    bucket's normalized weight.
    """

    def __init__(self, config):
        """Read the defense hyper-parameters from ``config``.

        Args:
            config: object exposing ``byzantine_client_num``,
                ``client_num_per_round`` and ``batch_num`` attributes.
        """
        self.byzantine_client_num = config.byzantine_client_num
        self.client_num_per_round = config.client_num_per_round
        # 2(1 + ε)q ≤ batch_num ≤ client_num_per_round
        # trade-off between accuracy & robustness:
        #       larger batch_num --> more Byzantine robustness, larger estimation error.
        self.batch_num = config.batch_num
        if self.byzantine_client_num == 0:
            # No Byzantine clients expected: collapse to a single bucket.
            self.batch_num = 1
        # Clients per bucket. Rounding up means batch_num buckets of this size
        # can hold client_num_per_round clients.
        self.batch_size = math.ceil(self.client_num_per_round / self.batch_num)

    def defend_on_aggregation(
            self,
            raw_client_grad_list: List[Tuple[float, OrderedDict]],
            base_aggregation_func: Callable = None,
            extra_auxiliary_info: Any = None,
    ):
        """Aggregate client updates with a bucketed, weighted geometric median.

        Args:
            raw_client_grad_list: ``(weight, params)`` pairs, one per client.
                Presumably the weight is the client's sample count; confirm
                against the caller.
            base_aggregation_func: unused by this defense.
            extra_auxiliary_info: unused by this defense.

        Returns:
            A new ``OrderedDict`` with the same keys, in the same order, as the
            first bucket's params. Each value is the geometric median of that
            key across buckets, weighted by the normalized bucket weights.

        Raises:
            IndexError: if bucketization yields no buckets.
            ZeroDivisionError: if the bucket weights sum to zero.
        """
        batch_grad_list = Bucket.bucketization(raw_client_grad_list, self.batch_size)
        # The weights must stay a list aligned index-by-index with the buckets.
        # A set would drop buckets that happen to share a weight and would lose
        # the ordering that pairs each weight with its bucket's params.
        weights = [weight for (weight, _) in batch_grad_list]
        total_weight = sum(weights, 0.0)  # computed once, not per element
        alphas = [weight / total_weight for weight in weights]
        _, first_params = batch_grad_list[0]
        # Build a fresh dict instead of overwriting the first bucket's params in
        # place, so no input structure is mutated as a side effect.
        aggregated = OrderedDict()
        for k in first_params.keys():
            batch_grads = [params[k] for (_, params) in batch_grad_list]
            aggregated[k] = compute_geometric_median(alphas, batch_grads)
        return aggregated