# FedSecurity / defense / coordinate_wise_trimmed_mean_defense.py
from collections import OrderedDict
from typing import Callable, List, Tuple, Dict, Any
from .defense_base import BaseDefenseMethod
from ..common.utils import trimmed_mean


class CoordinateWiseTrimmedMeanDefense(BaseDefenseMethod):
    """Trimmed-mean defense (presumably applied coordinate-wise, per the name).

    Delegates to ``trimmed_mean`` with a trim count of
    ``int(beta * client_num)``. By design, that many values are discarded
    from each end before averaging.
    """

    def __init__(self, config):
        # Fraction of clients trimmed from each end. The trim count actually
        # used is int(client_num * beta) per side, so at most
        # 2 * int(client_num * beta) values are discarded in total.
        self.beta = config.beta

    def defend_before_aggregation(
        self,
        raw_client_grad_list: List[Tuple[float, OrderedDict]],
        extra_auxiliary_info: Any = None,
    ):
        """Aggregate the clients' updates with a trimmed mean.

        Args:
            raw_client_grad_list: per-client ``(float, OrderedDict)`` pairs,
                presumably (sample count, model parameters). TODO: confirm
                against the caller and ``trimmed_mean``.
            extra_auxiliary_info: unused by this defense.

        Returns:
            The result of ``trimmed_mean`` for the list and the computed
            trim count.

        Raises:
            ValueError: if ``beta`` is not in ``[0, 1/2)``.
        """
        # beta must be strictly below 1/2. With beta == 1/2 and an even
        # number of clients, int(beta * n) == n / 2 values would be trimmed
        # from each side, leaving nothing to average. The chained comparison
        # also rejects NaN.
        if not 0 <= self.beta < 1 / 2:
            raise ValueError("the bound of beta is [0, 1/2)")
        trim_count = int(self.beta * len(raw_client_grad_list))
        return trimmed_mean(raw_client_grad_list, trim_count)