import logging
from collections import OrderedDict
from typing import Callable, List, Tuple, Dict, Any

import numpy as np
import torch
from scipy import spatial

from .defense_base import BaseDefenseMethod
from ..common.utils import (
    compute_euclidean_distance,
)


class CrossRoundDefense(BaseDefenseMethod):
    """Flag clients whose update drifts from their own history or from the global model.

    The first call to ``defend_before_aggregation`` only records state. Every
    later call takes an "importance feature" from each client's update: the
    flattened second-to-last entry of its state dict. It then computes two
    cosine similarities:

    * against the feature cached for that client (``client_cache``);
    * against the same feature taken from the current global model.

    A client is reported as potentially poisoned when either similarity is
    below ``config.cosine_similarity_bound``. ``renew_cache`` caches the
    current features of every client not listed as really poisoned. It is
    presumably called by a later defense phase with the confirmed ids; confirm
    against the caller.
    """

    def __init__(self, config):
        # Client indices flagged by the latest defend_before_aggregation call.
        self.potentially_poisoned_worker_list = []
        # Reset to [] every round after the first; nothing in this class appends to it.
        self.lazy_worker_list = None
        # Cosine similarity below this bound => attack may be present for that client.
        self.lowerbound = config.cosine_similarity_bound
        # client index -> last trusted importance-feature vector (numpy array).
        self.client_cache = dict()
        self.training_round = 1
        # Assume an attack exists until the first comparison round says otherwise.
        self.is_attack_existing = True
        # Importance features of the clients seen in the current round.
        self.temp_client_features = None
        self.global_model_feature = None
        self.total_client_num = -1
        self.zero_reference = None
        # Not used by this class. The intended meaning was: similarity above
        # this bound means very limited difference, i.e. a lazy worker.
        self.upperbound = 1

    def defend_before_aggregation(
        self,
        raw_client_grad_list: List[Tuple[float, OrderedDict]],
        extra_auxiliary_info: Any = None,
    ):
        """Score every client against its cached history and the global model.

        Args:
            raw_client_grad_list: one ``(float, state_dict)`` pair per client.
                The client's position in the list is used as its id; the float
                is not used here.
            extra_auxiliary_info: the global model's state dict. It is ignored
                in the first round.

        Returns:
            ``raw_client_grad_list`` unchanged. This phase only flags clients;
            read the result via ``get_potential_poisoned_clients``.
        """
        self.temp_client_features = self._get_importance_feature(raw_client_grad_list)

        if self.training_round == 1:
            # No history yet: treat every client as suspect and defer the
            # real comparison to later rounds.
            self.training_round += 1
            self.total_client_num = len(raw_client_grad_list)
            # A list (not a range) so the attribute has one type in every round.
            self.potentially_poisoned_worker_list = list(range(self.total_client_num))
            self.zero_reference = np.zeros(self.temp_client_features[0].shape)
            return raw_client_grad_list

        self.is_attack_existing = False
        self.lazy_worker_list = []
        self.potentially_poisoned_worker_list = []
        self.global_model_feature = self._get_importance_feature_of_a_model(
            extra_auxiliary_info
        )
        if self.training_round == 2:
            # Clients with no trusted history yet are compared against the
            # global model instead.
            for i in range(self.total_client_num):
                if i not in self.client_cache:
                    self.client_cache[i] = self.global_model_feature

        client_wise_scores, global_wise_scores, _ = self.compute_client_cosine_scores(
            client_features=self.temp_client_features,
            global_model_feature=self.global_model_feature,
            zero_reference=self.zero_reference,
        )
        for i, (client_score, global_score) in enumerate(
            zip(client_wise_scores, global_wise_scores)
        ):
            if client_score < self.lowerbound or global_score < self.lowerbound:
                self.is_attack_existing = True
                self.potentially_poisoned_worker_list.append(i)

        self.training_round += 1
        logging.info(
            "first phase: potentially_poisoned_worker_list = %s",
            self.potentially_poisoned_worker_list,
        )
        return raw_client_grad_list

    def compute_l2_scores(self, importance_feature_list):
        """Return per-client Euclidean distances to the cached and global features.

        Args:
            importance_feature_list: one feature vector per client. The list
                index is the client id used to look up ``client_cache``.

        Returns:
            ``(client_wise_distance_scores, global_wise_distance_scores)``.
        """
        client_wise_distance_scores = []
        global_wise_distance_scores = []
        for i, feature in enumerate(importance_feature_list):
            feature_tensor = torch.Tensor(feature)
            # Cached and global features are numpy arrays. Convert them the
            # same way as the client feature so the helper gets two tensors.
            client_wise_distance_scores.append(
                compute_euclidean_distance(
                    feature_tensor, torch.Tensor(self.client_cache[i])
                )
            )
            global_wise_distance_scores.append(
                compute_euclidean_distance(
                    feature_tensor, torch.Tensor(self.global_model_feature)
                )
            )
        return client_wise_distance_scores, global_wise_distance_scores

    def renew_cache(self, real_poisoned_client_ids):
        """Cache this round's features for every client not listed as poisoned.

        Clients in ``real_poisoned_client_ids`` keep their previous cache
        entry. If they have none yet, the global model feature is cached
        instead (when it is known).
        """
        poisoned_ids = set(real_poisoned_client_ids)  # O(1) membership per client
        for i in range(self.total_client_num):
            if i not in poisoned_ids:
                self.client_cache[i] = self.temp_client_features[i]
            elif i not in self.client_cache and self.global_model_feature is not None:
                self.client_cache[i] = self.global_model_feature

    def get_potential_poisoned_clients(self):
        """Return the client indices flagged by the latest ``defend_before_aggregation``."""
        return self.potentially_poisoned_worker_list

    def compute_client_cosine_scores(self, client_features, global_model_feature, zero_reference):
        """Return per-client cosine similarities (each in [-1, 1]).

        Args:
            client_features: one feature vector per client. The index is the
                client id used to look up ``client_cache``.
            global_model_feature: the global model's feature vector.
            zero_reference: accepted for interface compatibility but not used.

        Returns:
            A tuple ``(client_wise_scores, global_wise_scores, zero_wise_scores)``:

            * similarity to the client's cached feature;
            * similarity to ``global_model_feature``;
            * similarity to an all-zero vector.
        """
        client_wise_scores = []
        global_wise_scores = []
        zero_wise_scores = []
        for i, feature in enumerate(client_features):
            # scipy's cosine *distance* lies in [0, 2]; 1 - distance is the similarity.
            client_wise_scores.append(
                1 - spatial.distance.cosine(feature, self.client_cache[i])
            )
            global_wise_scores.append(
                1 - spatial.distance.cosine(feature, global_model_feature)
            )
            # NOTE(review): a zero vector has no direction, so this similarity
            # is undefined. scipy is expected to return NaN here. It is kept
            # only to preserve the 3-tuple return; defend_before_aggregation
            # discards it.
            zero_wise_scores.append(
                1 - spatial.distance.cosine(feature, np.zeros(feature.shape))
            )
        return client_wise_scores, global_wise_scores, zero_wise_scores

    def _get_importance_feature(self, raw_client_grad_list):
        """Extract one importance-feature vector per ``(float, state_dict)`` pair."""
        return [
            self._get_importance_feature_of_a_model(grad)
            for _, grad in raw_client_grad_list
        ]

    @classmethod
    def _get_importance_feature_of_a_model(cls, grad):
        """Flatten the second-to-last entry of ``grad`` into a 1-D numpy array.

        Args:
            grad: an ordered name -> tensor mapping (a client update or the
                global model's state dict). It must have at least two entries.

        Returns:
            A 1-D numpy array that owns its memory.
        """
        # The second-to-last entry is used, not the last. For a model ending in
        # a linear layer this is presumably its weight; confirm for the models in use.
        _, importance_feature = list(grad.items())[-2]
        # .numpy() on a CPU tensor shares storage with that tensor. flatten()
        # always copies, so an in-place update of the model cannot silently
        # change cached features. It also handles 0-d tensors, which the
        # previous shape-product/reshape code rejected.
        return importance_feature.detach().cpu().numpy().flatten()