# FedSecurity/defense/slsgd_defense.py
import math
from collections import OrderedDict
from typing import Callable, List, Tuple, Dict, Any
from ..common.utils import trimmed_mean
from ..defense.defense_base import BaseDefenseMethod


class SLSGDDefense(BaseDefenseMethod):
    """Defense that can trim client updates and damps the aggregated update.

    Two hooks are provided:

    * ``defend_before_aggregation`` validates the configuration against the
      number of received client updates and, for option type 2, passes the
      list through ``trimmed_mean`` with parameter ``b``.
    * ``defend_on_aggregation`` runs the supplied aggregation function and
      then blends its result with the current global model:
      ``(1 - alpha) * global + alpha * aggregated``.

    NOTE(review): the exact filtering performed by ``trimmed_mean`` is defined
    in ``common.utils`` and is not visible from this file.
    """

    def __init__(self, config):
        """Read the defense settings from ``config``.

        Args:
            config: object exposing ``trim_param_b``, ``alpha`` and
                ``option_type`` attributes. It is also kept and later
                forwarded to the aggregation function as ``args``.

        Raises:
            ValueError: if ``config.alpha`` is outside ``[0, 1]`` (NaN
                included).
        """
        self.b = config.trim_param_b  # parameter of trimmed mean
        # Negated chained comparison so that NaN is rejected as well; a plain
        # "alpha > 1 or alpha < 0" test silently lets NaN through.
        if not 0 <= config.alpha <= 1:
            raise ValueError("the bound of alpha is [0, 1]")
        self.alpha = config.alpha
        self.option_type = config.option_type
        self.config = config

    def defend_before_aggregation(
        self,
        raw_client_grad_list: List[Tuple[float, OrderedDict]],
        extra_auxiliary_info: Any = None,
    ):
        """Validate the settings and optionally trim the client update list.

        Args:
            raw_client_grad_list: list of ``(weight, params)`` tuples, one per
                client (presumably the weight is the local sample count;
                verify against the caller).
            extra_auxiliary_info: unused by this hook.

        Returns:
            The input list unchanged for option type 1, or the result of
            ``trimmed_mean(raw_client_grad_list, b)`` for option type 2.

        Raises:
            ValueError: if ``b`` is outside ``[0, ceil(n / 2) - 1]`` where
                ``n`` is the number of client updates, or if the option type
                is neither 1 nor 2.
        """
        # Upper bound on b: ceil(n / 2) - 1, computed once for check + message.
        max_b = math.ceil(len(raw_client_grad_list) / 2) - 1
        if not 0 <= self.b <= max_b:
            raise ValueError("the bound of b is [0, {}]".format(max_b))
        if self.option_type not in (1, 2):
            # ValueError is a subclass of Exception, so callers that catch
            # the generic Exception keep working.
            raise ValueError("Such option type does not exist!")
        if self.option_type == 2:
            # process model list
            return trimmed_mean(raw_client_grad_list, self.b)
        return raw_client_grad_list

    def defend_on_aggregation(
        self,
        raw_client_grad_list: List[Tuple[float, OrderedDict]],
        base_aggregation_func: Callable = None,
        extra_auxiliary_info: Any = None,
    ):
        """Aggregate the client updates and blend them with the global model.

        Args:
            raw_client_grad_list: list of ``(weight, params)`` tuples passed
                straight to ``base_aggregation_func`` as ``raw_grad_list``.
            base_aggregation_func: callable invoked as
                ``base_aggregation_func(args=config, raw_grad_list=...)``; it
                must return a mutable mapping of parameter name to value.
            extra_auxiliary_info: the current global model, a mapping that
                contains every key of the aggregated result.

        Returns:
            The mapping returned by ``base_aggregation_func``, updated in
            place so that each entry equals
            ``(1 - alpha) * global_model[k] + alpha * aggregated[k]``.

        Raises:
            TypeError: if ``base_aggregation_func`` or
                ``extra_auxiliary_info`` is ``None``.
        """
        # Fail early with a clear message instead of an opaque
        # "'NoneType' object is not callable/subscriptable" further down.
        if base_aggregation_func is None:
            raise TypeError("base_aggregation_func must be a callable, got None")
        global_model = extra_auxiliary_info
        if global_model is None:
            raise TypeError(
                "extra_auxiliary_info must hold the global model parameters, got None"
            )

        avg_params = base_aggregation_func(
            args=self.config, raw_grad_list=raw_client_grad_list
        )
        # Only existing keys are reassigned, so mutating while iterating is safe.
        for k in avg_params:
            avg_params[k] = (1 - self.alpha) * global_model[k] + self.alpha * avg_params[k]
        return avg_params