# FedSecurity/defense/three_sigma_defense.py
import logging
import math
from collections import OrderedDict
from typing import Any, List, Tuple

import numpy as np
import torch
from fedml import device  # fedml.device.get_device(args) selects the training device

from .defense_base import BaseDefenseMethod
from ..common.utils import (
    compute_euclidean_distance,
    compute_middle_point,
    compute_krum_score,
    compute_gaussian_distribution,
)
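
# A minimal sketch of what compute_gaussian_distribution (from ..common.utils)
# is assumed to compute: the sample mean and (Bessel-corrected) standard
# deviation of a score list. Illustrative only; the real helper may differ.
#
#   def compute_gaussian_distribution(score_list):
#       n = len(score_list)
#       mu = sum(score_list) / n
#       sigma = math.sqrt(sum((s - mu) ** 2 for s in score_list) / (n - 1))
#       return mu, sigma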


class ThreeSigmaDefense(BaseDefenseMethod):
    def __init__(self, config):
        self.average = None
        self.upper_bound = 0
        self.malicious_client_idxs = []
        # Set via set_potential_malicious_clients() (e.g., by OutlierDetectionDefense).
        # When None, detected malicious models are kicked out directly; otherwise only
        # models that also appear in this list are removed.
        self.potential_malicious_client_idxs = None
        self.device = device.get_device(config)
        # number of standard deviations above the mean used as the outlier bound
        self.lambda_value = 0.5
        # a configurable bound (config.bound_param) was used in earlier versions:
        # if hasattr(config, "bound_param") and isinstance(config.bound_param, float):
        #     self.bound_param = config.bound_param
        # else:
        #     self.bound_param = 1

    ###################### version 3: re-compute the Gaussian distribution each round
    def defend_before_aggregation(
            self,
            raw_client_grad_list: List[Tuple[float, OrderedDict]],
            extra_auxiliary_info: Any = None,
    ):
        # score each client's update by its L2 distance to a robust average,
        # then drop updates whose score exceeds mu + lambda * sigma
        importance_feature_list = self._get_importance_feature(raw_client_grad_list)
        if self.average is None:
            self.average = self.compute_avg_with_krum(importance_feature_list)
        client_scores = self.compute_l2_scores(importance_feature_list)
        mu, sigma = compute_gaussian_distribution(client_scores)
        # alternative multipliers (0.1, 0.3, 1.5, 2, 2.5, and 3 sigma) were
        # explored during tuning; 0.5 sigma is used via lambda_value
        self.upper_bound = mu + self.lambda_value * sigma
        new_client_models, _ = self.kick_out_poisoned_local_models(
            client_scores, raw_client_grad_list
        )
        importance_feature_list = self._get_importance_feature(new_client_models)
        self.average = self.compute_an_average_feature(importance_feature_list)
        return new_client_models
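
    # A worked toy example of the bound above (illustrative numbers only):
    # client_scores = [1.0, 1.1, 0.9, 5.0] gives mu = 2.0 and sigma ~= 2.0, so
    # with lambda_value = 0.5 the bound is 2.0 + 0.5 * 2.0 ~= 3.0 and only the
    # client with score 5.0 is kicked out.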

    def compute_an_average_feature(self, importance_feature_list):
        alphas = [1 / len(importance_feature_list)] * len(importance_feature_list)
        return compute_middle_point(alphas, importance_feature_list)
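
    # A minimal sketch of the assumed behavior of compute_middle_point (from
    # ..common.utils): a weighted sum of the feature vectors, which with the
    # uniform alphas above reduces to their element-wise mean.
    #
    #   def compute_middle_point(alphas, features):
    #       middle = torch.zeros_like(features[0])
    #       for alpha, feature in zip(alphas, features):
    #           middle += alpha * feature
    #       return middle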

    ##################### version 2: remove poisoned model scores in score list
    # def defend_before_aggregation(
    #     self,
    #     raw_client_grad_list: List[Tuple[float, OrderedDict]],
    #     extra_auxiliary_info: Any = None,
    # ):
    #     if self.median is None:
    #         self.median = self.compute_median_with_krum(raw_client_grad_list)
    #     client_scores = self.compute_scores(raw_client_grad_list)
    #     print(f"client scores = {client_scores}")
    #     if self.iteration_num < self.pretraining_round_number:
    #         mu, sigma = compute_gaussian_distribution(self.score_list, client_scores)
    #         self.upper_bound = mu + self.bound_param * sigma
    #         print(f"mu = {mu}, sigma = {sigma}, upperbound = {self.upper_bound}")
    #         new_client_models, client_scores = self.kick_out_poisoned_local_models(client_scores, raw_client_grad_list)
    #         print(f"new scores after kicking out = {client_scores}")
    #         self.score_list.extend(list(client_scores))
    #         mu, sigma = compute_gaussian_distribution(self.score_list, [])
    #         self.upper_bound = mu + self.bound_param * sigma
    #         print(f"mu = {mu}, sigma = {sigma}, upperbound = {self.upper_bound}")
    #     else:
    #         new_client_models, _ = self.kick_out_poisoned_local_models(client_scores, raw_client_grad_list)
    #     self.iteration_num += 1
    #     return new_client_models

    ###################### version 1: do not remove poisoned model scores in score list
    # def defend_before_aggregation(
    #     self,
    #     raw_client_grad_list: List[Tuple[float, OrderedDict]],
    #     extra_auxiliary_info: Any = None,
    # ):
    #     if self.median is None:
    #         self.median = self.compute_median_with_krum(raw_client_grad_list)
    #     client_scores = self.compute_scores(raw_client_grad_list)
    #     print(f"client scores = {client_scores}")
    #
    #     if self.iteration_num < self.pretraining_round_number:
    #         self.score_list.extend(list(client_scores))
    #         self.mu, self.sigma = compute_gaussian_distribution(self.score_list)
    #         self.upper_bound = self.mu + self.bound_param * self.sigma
    #         self.iteration_num += 1
    #
    #     for i in range(len(client_scores) - 1, -1, -1):
    #         if client_scores[i] > self.upper_bound:
    #             # keep the score in self.score_list to avoid mis-deletion under severe non-IID data across clients
    #             raw_client_grad_list.pop(i)
    #             print(f"pop -- i = {i}")
    #     return raw_client_grad_list

    def kick_out_poisoned_local_models(self, client_scores, raw_client_grad_list):
        # traverse the score list in reverse order so that pop(i) does not
        # shift the indices that are still to be visited
        self.malicious_client_idxs = []
        # per-multiplier index lists (0.1, 0.3, 0.5, 1.5, 2, 2.5, and 3 sigma)
        # were also collected here during tuning
        for i in range(len(client_scores) - 1, -1, -1):
            if client_scores[i] > self.upper_bound:
                logging.info(f"second phase: score of client {i} exceeds the bound")
                if self.potential_malicious_client_idxs is None or i in self.potential_malicious_client_idxs:
                    raw_client_grad_list.pop(i)
                    self.malicious_client_idxs.append(i)
        logging.info(f"second phase: detected malicious clients: {self.malicious_client_idxs}")
        return raw_client_grad_list, client_scores

    def get_malicious_client_idxs(self):
        return self.malicious_client_idxs

    def set_potential_malicious_clients(self, potential_malicious_client_idxs):
        self.potential_malicious_client_idxs = potential_malicious_client_idxs

    def compute_avg_with_krum(self, importance_feature_list):
        krum_scores = compute_krum_score(
            importance_feature_list,
            client_num_after_trim=math.ceil(len(importance_feature_list) / 2) - 1,
        )
        score_index = torch.argsort(
            torch.Tensor(krum_scores)
        ).tolist()  # indices; ascending
        score_index = score_index[0: math.ceil(len(importance_feature_list) / 2) - 1]
        honest_importance_feature_list = [
            importance_feature_list[i] for i in score_index
        ]
        return self.compute_an_average_feature(honest_importance_feature_list)
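
    # compute_krum_score (from ..common.utils) is assumed to follow Krum
    # (Blanchard et al., 2017): each client's score is the sum of squared
    # distances to its closest neighbors, so low-score clients are the most
    # central ones. Roughly half of the clients (those with the lowest scores)
    # are kept here to bootstrap a robust average.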

    def compute_l2_scores(self, importance_feature_list):
        scores = []
        for feature in importance_feature_list:
            score = compute_euclidean_distance(torch.Tensor(feature), self.average)
            scores.append(score)
        return scores
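
    # compute_euclidean_distance (from ..common.utils) is assumed to be the L2
    # norm of the difference, e.g. torch.norm(a - b), so each score measures
    # how far a client's feature vector lies from the current robust average.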

    # alternative scoring: average pairwise cosine similarity (commented out;
    # would require `from scipy import spatial` and should call
    # self._get_importance_feature):
    # def compute_client_cosine_scores(self, raw_client_grad_list):
    #     importance_feature_list = self._get_importance_feature(raw_client_grad_list)
    #     cosine_scores = []
    #     num_client = len(importance_feature_list)
    #     for i in range(0, num_client):
    #         dists = []
    #         for j in range(0, num_client):
    #             if i != j:
    #                 dists.append(
    #                     1
    #                     - spatial.distance.cosine(
    #                         importance_feature_list[i], importance_feature_list[j]
    #                     )
    #                 )
    #         cosine_scores.append(sum(dists) / len(dists))
    #     return cosine_scores

    def _get_importance_feature(self, raw_client_grad_list):
        # Following FoolsGold, use the last layer's weights as the importance feature.
        ret_feature_vector_list = []
        for idx in range(len(raw_client_grad_list)):
            raw_grad = raw_client_grad_list[idx]
            (p, grads) = raw_grad

            # take the second-to-last entry of the state dict: typically the
            # last layer's weight (the final entry is usually its bias)
            (weight_name, importance_feature) = list(grads.items())[-2]
            # flatten the tensor into a 1-D feature vector
            feature_vector = importance_feature.cpu().detach().view(-1)
            ret_feature_vector_list.append(feature_vector)
        return ret_feature_vector_list
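

# A hypothetical usage sketch (the config object and client updates are
# assumptions for illustration, not part of this module): each client update
# is a (sample_num, OrderedDict of named tensors) pair, and the defense
# filters the list before aggregation.
#
#   defense = ThreeSigmaDefense(config)
#   filtered_grad_list = defense.defend_before_aggregation(raw_client_grad_list)
#   logging.info(defense.get_malicious_client_idxs())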