FedSecurity / defense / norm_diff_clipping_defense.py
norm_diff_clipping_defense.py
Raw
from collections import OrderedDict
import torch
from .defense_base import BaseDefenseMethod
from ..common import utils
from typing import Callable, List, Tuple, Dict, Any


class NormDiffClippingDefense(BaseDefenseMethod):
    def __init__(self, config):
        self.config = config
        self.norm_bound = config.norm_bound  # for norm diff clipping; in the paper, they set it to 0.1, 0.17, and 0.33.

    def defend_before_aggregation(
            self,
            raw_client_grad_list: List[Tuple[float, OrderedDict]],
            extra_auxiliary_info: Any = None,
    ):
        global_model = extra_auxiliary_info
        vec_global_w = utils.vectorize_weight(global_model)
        new_grad_list = []
        for (sample_num, local_w) in raw_client_grad_list:
            vec_local_w = utils.vectorize_weight(local_w)
            clipped_weight_diff = self._get_clipped_norm_diff(vec_local_w, vec_global_w)
            clipped_w = self._get_clipped_weights(
                local_w, global_model, clipped_weight_diff
            )
            new_grad_list.append((sample_num, clipped_w))
        return new_grad_list

    def _get_clipped_norm_diff(self, vec_local_w, vec_global_w):
        vec_diff = vec_local_w - vec_global_w
        weight_diff_norm = torch.norm(vec_diff).item()
        clipped_weight_diff = vec_diff / max(1, weight_diff_norm / self.norm_bound)
        return clipped_weight_diff

    @staticmethod
    def _get_clipped_weights(local_w, global_w, weight_diff):
        #  rule: global_w + clipped(local_w - global_w)
        recons_local_w = OrderedDict()
        index_bias = 0
        for item_index, (k, v) in enumerate(local_w.items()):
            if utils.is_weight_param(k):
                recons_local_w[k] = (
                    weight_diff[index_bias: index_bias + v.numel()].view(v.size())
                    + global_w[k]
                )
                index_bias += v.numel()
            else:
                recons_local_w[k] = v
        return recons_local_w