# File: FedSecurity/attack/dlg_attack.py (dlg_attack.py, raw view)
from collections import OrderedDict
import numpy as np
import torch
import torch.nn.functional as F
from torchvision import transforms
from typing import List, Tuple, Any
from model.cv.resnet import resnet56, resnet20
from .attack_base import BaseAttackMethod
from ..common.net import LeNet
from ..common.utils import cross_entropy_for_onehot
from typing import Any


class DLGAttack(BaseAttackMethod):
    """Gradient-matching ("Deep Leakage from Gradients" style) reconstruction attack.

    For every client update observed in a targeted communication round, the
    attack optimises a random dummy sample and dummy label with L-BFGS so that
    the gradients they induce on a freshly built model match the update
    derived from the client's weights.

    Attributes read from ``args``:
        model: model name understood by :meth:`get_model`.
        dataset: one of ``"cifar10"``, ``"cifar100"``, ``"mnist"``.
        attack_round_num: number of L-BFGS outer steps per reconstruction.
        protected_layers (optional): indices (counted over the state-dict
            entries whose key contains "weight" or "bias") whose update the
            adversary cannot see; they are replaced by zeros.
        attack_iteration_idxs (optional): indices of the calls to
            :meth:`reconstruct_data` on which the attack is actually run.
            Defaults to ``[args.comm_round - 1]``.
    """

    # Shapes of the dummy (sample, one-hot label) pair per supported dataset.
    # NOTE(review): the mnist sample shape has no explicit channel dimension;
    # confirm that this is what the configured model expects.
    _DATASET_SHAPES = {
        "cifar10": (torch.Size([1, 3, 32, 32]), torch.Size([1, 10])),
        "cifar100": (torch.Size([1, 3, 32, 32]), torch.Size([1, 100])),
        "mnist": (torch.Size([1, 28, 28]), torch.Size([1, 10])),
    }

    def __init__(self, args):
        """Configure the attack from the run arguments.

        Raises:
            Exception: if ``args.dataset`` is not a supported dataset.
        """
        self.model = None
        self.model_type = args.model

        shapes = self._DATASET_SHAPES.get(args.dataset) if isinstance(args.dataset, str) else None
        if shapes is None:
            raise Exception(f"do not support this dataset for DLG attack: {args.dataset}")
        self.original_data_size, self.original_label_size = shapes

        self.criterion = cross_entropy_for_onehot
        self.attack_round_num = args.attack_round_num
        # NOTE(review): `device` is not imported anywhere in the visible part
        # of this file; confirm where it is supposed to come from.
        self.device = device.get_device(args)

        # Extension: users may protect some layers, i.e. the adversary cannot
        # get access to the update of those layers. Expected: list of ints.
        self.protected_layers = getattr(args, "protected_layers", None)

        # Counts the calls to reconstruct_data() made so far.
        self.iteration_num = 0
        attack_iteration_idxs = getattr(args, "attack_iteration_idxs", None)
        if attack_iteration_idxs is None:
            # Attack the last iteration, as it contains more information.
            attack_iteration_idxs = [args.comm_round - 1]
        self.attack_iteration_idxs = attack_iteration_idxs

    def get_model(self):
        """Build a fresh, untrained instance of the configured model.

        The number of output classes of the ResNets follows the label size of
        the configured dataset, so that the prediction and the dummy one-hot
        label have matching widths (e.g. 100 for cifar100).

        Raises:
            Exception: if ``self.model_type`` is not supported.
        """
        num_classes = self.original_label_size[-1]
        if self.model_type == "LeNet":
            # NOTE(review): LeNet() takes no class count here; verify that its
            # output width matches the dataset's label size (e.g. cifar100).
            return LeNet()
        if self.model_type == "resnet56":
            return resnet56(num_classes, pretrained=False, path=None)
        if self.model_type == "resnet20":
            return resnet20(num_classes)
        raise Exception(f"do not support this model: {self.model_type}")

    def reconstruct_data(self, raw_client_grad_list: List[Tuple[float, OrderedDict]], extra_auxiliary_info: Any = None):
        """Run the attack on every client update if this call is targeted.

        Args:
            raw_client_grad_list: ``(weight, state_dict)`` pairs, one per
                client; only the state dict is used.
            extra_auxiliary_info: state dict of the global model of the
                previous round, forwarded to
                :meth:`reconstruct_data_using_a_model`.

        The internal call counter is advanced on every call, whether or not
        the attack was run. The reconstructions are not returned.
        """
        if self.iteration_num in self.attack_iteration_idxs:
            for _, local_model in raw_client_grad_list:
                print("-----------attack---------------")
                self.reconstruct_data_using_a_model(a_model=local_model, extra_auxiliary_info=extra_auxiliary_info)
        self.iteration_num += 1

    def reconstruct_data_using_a_model(self, a_model: OrderedDict, extra_auxiliary_info: Any = None):
        """Reconstruct a (sample, label) pair from one client's model update.

        Args:
            a_model: state dict of the client's local model.
            extra_auxiliary_info: state dict of the global model of the
                previous round (required).

        Returns:
            Tuple ``(dummy_data, dummy_label)`` of detached tensors holding
            the optimised dummy sample and the (pre-softmax) dummy label.

        Raises:
            ValueError: if ``extra_auxiliary_info`` is ``None``.
        """
        if extra_auxiliary_info is None:
            raise ValueError(
                "DLG attack requires the global model of the last round as extra_auxiliary_info"
            )
        global_model_of_last_round = extra_auxiliary_info

        # Keep the model, the targets and the dummy tensors on one device.
        self.model = self.get_model().to(self.device)
        model_params = tuple(self.model.parameters())

        protected = set(self.protected_layers) if self.protected_layers is not None else set()
        target_grads = []
        layer_counter = 0
        for name, global_param in global_model_of_last_round.items():
            if "weight" not in name and "bias" not in name:
                continue
            global_param = global_param.to(self.device)
            if layer_counter in protected:
                # A protected layer is invisible to the adversary: use zeros.
                target_grads.append(torch.zeros_like(global_param, dtype=torch.float))
            else:
                # NOTE(review): this is the weight delta (local - global), used
                # as a stand-in for the gradient; sign and learning-rate scale
                # still need to be double checked (carried-over TODO).
                target_grads.append((a_model[name].to(self.device) - global_param).detach())
            layer_counter += 1
        target_grads = tuple(target_grads)

        # The dummy tensors are the optimisation variables, so they must be
        # leaf tensors that require grad; otherwise L-BFGS never updates them.
        dummy_data = torch.randn(self.original_data_size).to(self.device).requires_grad_(True)
        dummy_label = torch.randn(self.original_label_size).to(self.device).requires_grad_(True)
        optimizer = torch.optim.LBFGS([dummy_data, dummy_label])

        def closure():
            optimizer.zero_grad()
            dummy_pred = self.model(dummy_data)
            dummy_loss = self.criterion(dummy_pred, F.softmax(dummy_label, dim=-1))
            # create_graph=True keeps the gradient differentiable w.r.t. the
            # dummy tensors, which is what the matching loss backpropagates to.
            dummy_grad = torch.autograd.grad(dummy_loss, model_params, create_graph=True)

            grad_diff = 0
            # Pairs are matched by position; zip stops at the shorter sequence.
            for gx, gy in zip(dummy_grad, target_grads):
                grad_diff += ((gx - gy) ** 2).sum()
            grad_diff.backward()
            return grad_diff

        for iters in range(self.attack_round_num):
            optimizer.step(closure)
            if iters % 50 == 0:
                current_loss = closure()
                print(iters, "%.4f" % current_loss.item())

        return dummy_data.detach(), dummy_label.detach()