FedSecurity / defense / crfl_defense.py
crfl_defense.py
Raw
from collections import OrderedDict
from .defense_base import BaseDefenseMethod
from ..common import utils
from ...dp.mechanisms import Gaussian


class CRFLDefense(BaseDefenseMethod):
    def __init__(self, config):
        self.config = config
        self.epoch = 1
        if hasattr(config, "clip_threshold"):
            self.user_defined_clip_threshold = config.clip_threshold
        else:
            self.user_defined_clip_threshold = None
        if hasattr(config, "sigma") and isinstance(config.sigma, float):
            self.sigma = config.sigma
        else:
            self.sigma = 0.01  # in the code of CRFL, the author set sigma to 0.01
        self.total_ite_num = config.comm_round
        if config.dataset == "mnist":
            self.dataset_param_function = self._crfl_compute_param_for_mnist
        elif config.dataset == "emnist":
            self.dataset_param_function = self._crfl_compute_param_for_emnist
        elif config.dataset == "lending_club_loan":
            self.dataset_param_function = self._crfl_compute_param_for_loan
        elif self.user_defined_clip_threshold is not None:
            self.dataset_param_function = self._crfl_self_defined_dataset_param
        else:
            raise Exception(f"dataset not supported: {config.dataset} and clip_threshold not defined ")

    def defend_after_aggregation(self, global_model):
        """
        clip the global model; dynamic threshold is adjusted according to the dataset;
        in the experiment, the authors set the dynamic threshold as follows:
           dataset == MNIST: dynamic_thres = epoch * 0.1 + 2
           dataseet == LOAN: dynamic_thres = epoch * 0.025 + 2
           datset == EMNIST: dynamic_thres = epoch * 0.25 + 4
        """
        clip_threshold = self.dataset_param_function()
        if self.user_defined_clip_threshold is not None and self.user_defined_clip_threshold < clip_threshold:
            clip_threshold = self.user_defined_clip_threshold

        global_model = self.clip_weight_norm(global_model, clip_threshold)
        if self.epoch == self.total_ite_num:
            return global_model
        self.epoch += 1
        new_global_model = OrderedDict()
        for k in global_model.keys():
            new_global_model[k] = global_model[k] + Gaussian.compute_noise_using_sigma(self.sigma, global_model[k].shape)
        return new_global_model

    def _crfl_self_defined_dataset_param(self):
        return self.user_defined_clip_threshold

    def _crfl_compute_param_for_mnist(self):
        return self.epoch * 0.1 + 2

    def _crfl_compute_param_for_loan(self):
        return self.epoch * 0.025 + 2

    def _crfl_compute_param_for_emnist(self):
        return self.epoch * 0.25 + 4

    @staticmethod
    def clip_weight_norm(model, clip_threshold):
        total_norm = utils.compute_model_norm(model)
        print(f"total_norm = {total_norm}")
        if total_norm > clip_threshold:
            clip_coef = clip_threshold / (total_norm + 1e-6)
            new_model = OrderedDict()
            for k in model.keys():
                new_model[k] = model[k] * clip_coef
            return new_model
        return model