FedSecurity / attack / lazy_worker.py
lazy_worker.py
Raw
import random
from collections import OrderedDict

import numpy as np
from typing import Callable
import torch
from .attack_base import BaseAttackMethod
from ..common.utils import is_weight_param, sample_some_clients
from typing import List, Tuple, Dict, Any


class LazyWorkerAttack(BaseAttackMethod):
    def __init__(self, config):
        self.lazy_worker_num = config.lazy_worker_num
        self.attack_mode = (
            config.attack_mode
        )  # random: randomly generate a weight; zero: set the weight to 0
        self.device = device.get_device(config)
        if self.attack_mode == "gaussian":
            if (
                hasattr(config, "gaussian_mu")
                and isinstance(config.gaussian_mu, float)
                and hasattr(config, "gaussian_sigma")
                and isinstance(config.gaussian_sigma, float)
            ):
                self.gaussian_mu = config.gaussian_mu
                self.gaussian_sigma = config.gaussian_sigma
            else:
                self.gaussian_mu = 0.0
                self.gaussian_sigma = 1.0
        self.client_cache = []
        self.round = 1
        # global or client, indicating which model of the last round to use
        self.attack_base = config.attack_base

    def attack_model(
        self,
        raw_client_grad_list: List[Tuple[float, OrderedDict]],
        extra_auxiliary_info: Any = None,
    ):
        if self.round == 1:
            self.client_cache = [grad for (_, grad) in raw_client_grad_list]
            return raw_client_grad_list
        self.round = self.round + 1

        if self.attack_mode == "uniform":
            mask_func = self.uniform_mask
        elif self.attack_mode == "random":
            mask_func = self.random_mask
        elif self.attack_mode == "gaussian":
            mask_func = self.gaussian_mask
        elif self.attack_mode == "no_mask":
            mask_func = self.no_mask
        else:
            raise NotImplementedError("Method not implemented!")

        if len(raw_client_grad_list) < self.lazy_worker_num:
            self.lazy_worker_num = len(raw_client_grad_list)
        lazy_worker_idxs = sample_some_clients(
            len(raw_client_grad_list), self.lazy_worker_num
        )
        print(f"lazy_worker_idxs={lazy_worker_idxs}")
        new_model_list = []
        for i in range(0, len(raw_client_grad_list)):
            if i not in lazy_worker_idxs:
                new_model_list.append(raw_client_grad_list[i])
                _, params = raw_client_grad_list[i]
                self.client_cache[i] = params
            else:
                local_sample_number, _ = raw_client_grad_list[i]
                if self.attack_base == "global":
                    previous_model_params = extra_auxiliary_info  # global model
                else:  # client
                    previous_model_params = self.client_cache[i]
                previous_model_params = mask_func(previous_model_params)
                new_model_list.append((local_sample_number, previous_model_params))
        return new_model_list

    def _add_a_mask_on_clients(self, model_list, lazy_worker_idxs, mask_func: Callable):
        new_model_list = []
        for i in range(0, len(model_list)):
            if i not in lazy_worker_idxs:
                new_model_list.append(model_list[i])
                _, params = model_list[i]
                self.client_cache[i] = params
            else:
                local_sample_number, _ = model_list[i]
                previous_model_params = self.client_cache[i]
                previous_model_params = mask_func(previous_model_params)
                new_model_list.append((local_sample_number, previous_model_params))
        return new_model_list

    def _add_a_mask_on_global(self, model_list, lazy_worker_idxs, mask_func: Callable):
        new_model_list = []
        for i in range(0, len(model_list)):
            if i not in lazy_worker_idxs:
                new_model_list.append(model_list[i])
            else:
                local_sample_number, _ = model_list[i]
                previous_model_params = self.client_cache[i]
                previous_model_params = mask_func(previous_model_params)
                new_model_list.append((local_sample_number, previous_model_params))
        return new_model_list

    def random_mask(self, previous_model_params):
        # add a random mask in [-1, 1]
        for k in previous_model_params.keys():
            if is_weight_param(k):
                previous_model_params[k] = (
                    torch.from_numpy(
                        2
                        * np.random.random_sample(size=previous_model_params[k].size())
                        - 1
                        + previous_model_params[k]
                    )
                    .float()
                    .to(self.device)
                )
        return previous_model_params

    def gaussian_mask(self, previous_model_params):
        # add a gaussian mask
        for k in previous_model_params.keys():
            if is_weight_param(k):
                previous_model_params[k] = previous_model_params[k] + torch.normal(
                    mean=self.gaussian_mu,
                    std=self.gaussian_sigma,
                    size=previous_model_params[k].size(),
                )
        return previous_model_params

    def uniform_mask(self, previous_model):
        # randomly generate a uniform mask
        unif_param = random.uniform(-1, 1)
        print(f"unif_mode_param = {unif_param}")
        for k in previous_model.keys():
            if is_weight_param(k):
                previous_model[k] = (
                    torch.from_numpy(
                        np.full(previous_model[k].size(), unif_param)
                        + previous_model[k]
                    )
                    .float()
                    .to(self.device)
                )
        return previous_model

    def no_mask(self, previous_model_params):
        # directly return the model in the last round
        return previous_model_params