# Source: ViTGuard / target_models / WhiteBox.py
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import sys
import numpy as np
import time
# sys.path.append('/home/artifacts/ViTGuard-main/')
sys.path.append('/home/native/ViTGuard_artifacts/ViTGuard-main/')
import DataManagerPytorch as DMP

sys.path.append('target_models/')
from utils import clamp, PCGrad


def FGSMAttack(device, dataLoader, model, epsilonMax, clipMin=0, clipMax=1):
    """Generate single-step FGSM adversarial examples for ``dataLoader``.

    Every input is moved by ``epsilonMax`` along the sign of the gradient of
    the cross-entropy loss with respect to that input, then clipped to
    ``[clipMin, clipMax]``.

    Args:
        device: Device on which the forward pass and gradient are computed.
        dataLoader: Loader yielding ``(inputs, labels)`` batches.
        model: Classifier under attack. It is switched to eval mode.
        epsilonMax: Magnitude of the signed-gradient step.
        clipMin: Lower bound applied to the perturbed inputs.
        clipMax: Upper bound applied to the perturbed inputs.

    Returns:
        The loader produced by ``DMP.TensorToDataLoader`` holding the
        perturbed inputs together with the original labels, using the batch
        size of ``dataLoader``.
    """
    model.eval()
    numSamples = len(dataLoader.dataset)
    xShape = DMP.GetOutputShape(dataLoader)
    xAdv = torch.zeros(numSamples, xShape[0], xShape[1], xShape[2])
    yClean = torch.zeros(numSamples, dtype=torch.long)
    loss = torch.nn.CrossEntropyLoss()
    advSampleIndex = 0

    for xData, yData in dataLoader:
        batchSize = xData.shape[0]
        xClean = xData.detach().cpu()
        yData = yData.type(torch.LongTensor).to(device)

        # Fresh leaf tensor on the attack device so the input gradient exists.
        xDataTemp = xClean.clone().to(device)
        xDataTemp.requires_grad = True

        cost = loss(model(xDataTemp), yData)
        # autograd.grad returns the input gradient directly and leaves the
        # model parameters' .grad buffers untouched.
        signDataGrad = torch.autograd.grad(cost, xDataTemp)[0].sign()

        perturbedImage = xClean + epsilonMax * signDataGrad.cpu()  # FGSM step
        perturbedImage = torch.clamp(perturbedImage, clipMin, clipMax)

        # Store the whole batch with one slice copy.
        batchEnd = advSampleIndex + batchSize
        xAdv[advSampleIndex:batchEnd] = perturbedImage
        yClean[advSampleIndex:batchEnd] = yData.cpu()
        advSampleIndex = batchEnd

        del xDataTemp, signDataGrad
        torch.cuda.empty_cache()

    # If the loader served fewer samples than len(dataset) (e.g. drop_last or a
    # subset sampler), drop the never-written rows instead of returning
    # all-zero images labelled 0.
    xAdv, yClean = xAdv[:advSampleIndex], yClean[:advSampleIndex]
    advLoader = DMP.TensorToDataLoader(xAdv, yClean, transforms= None, batchSize= dataLoader.batch_size, randomizer=None) #use the same batch size as the original loader
    return advLoader


def PGDAttack(device, dataLoader, model, eps, num_steps, step_size, clipMin=0, clipMax=1):
    """Generate L-infinity PGD adversarial examples for ``dataLoader``.

    Starting from the clean input, each of the ``num_steps`` iterations takes a
    signed-gradient step of size ``step_size`` on the cross-entropy loss,
    projects the total perturbation back into ``[-eps, eps]`` and clips the
    result to ``[clipMin, clipMax]``. With ``num_steps == 0`` the clean inputs
    are returned unchanged.

    Args:
        device: Device on which the attack is computed.
        dataLoader: Loader yielding ``(inputs, labels)`` batches.
        model: Classifier under attack. It is switched to eval mode.
        eps: Maximum absolute perturbation per element.
        num_steps: Number of gradient steps.
        step_size: Size of each signed-gradient step.
        clipMin: Lower bound applied to the perturbed inputs.
        clipMax: Upper bound applied to the perturbed inputs.

    Returns:
        A tuple ``(advLoader, xAdv, yClean)``: the loader produced by
        ``DMP.TensorToDataLoader`` plus the adversarial inputs and the clean
        labels as numpy arrays.
    """
    model.eval()
    numSamples = len(dataLoader.dataset)
    xShape = DMP.GetOutputShape(dataLoader)
    xAdv = torch.zeros(numSamples, xShape[0], xShape[1], xShape[2])
    yClean = torch.zeros(numSamples, dtype=torch.long)
    loss = torch.nn.CrossEntropyLoss()
    advSampleIndex = 0

    for xData, yData in dataLoader:
        batchSize = xData.shape[0]
        xData = xData.detach().to(device)
        yData = yData.type(torch.LongTensor).to(device)
        # Working copy that is updated by the attack; xData stays the anchor
        # for the eps-ball projection.
        xDataTemp = xData.clone()

        for _ in range(num_steps):
            xDataTemp.requires_grad = True
            cost = loss(model(xDataTemp), yData)
            # Input gradient only; model parameter .grad buffers are untouched.
            signDataGrad = torch.autograd.grad(cost, xDataTemp)[0].sign()
            xDataTemp = xDataTemp.detach() + step_size * signDataGrad

            # Project back into the eps-ball around the clean input, then
            # into the valid value range.
            delta = torch.clamp(xDataTemp - xData, min=-eps, max=eps)
            xDataTemp = torch.clamp(xData + delta, clipMin, clipMax)

        # Store the whole batch with one device-to-host copy.
        batchEnd = advSampleIndex + batchSize
        xAdv[advSampleIndex:batchEnd] = xDataTemp.detach().cpu()
        yClean[advSampleIndex:batchEnd] = yData.cpu()
        advSampleIndex = batchEnd

        del xDataTemp
        torch.cuda.empty_cache()

    # If the loader served fewer samples than len(dataset) (e.g. drop_last or a
    # subset sampler), drop the never-written rows instead of returning
    # all-zero images labelled 0.
    xAdv, yClean = xAdv[:advSampleIndex], yClean[:advSampleIndex]
    advLoader = DMP.TensorToDataLoader(xAdv, yClean, transforms= None, batchSize= dataLoader.batch_size, randomizer=None) #use the same batch size as the original loader
    return advLoader, xAdv.numpy(), yClean.numpy()




class CW:
    """Untargeted Carlini-Wagner L2 attack optimised in tanh space.

    The adversarial image is parameterised as ``0.5 * (tanh(w) + 1)`` so it
    always lies in ``[0, 1]``; ``w`` is optimised with Adam to minimise
    ``L2_distance + c * f(logits, labels)``.

    Args:
        model: Callable mapping an image batch to logits.
        device: Device the images, labels and optimisation variable live on.
        c: Weight of the misclassification term ``f``.
        kappa: Confidence margin; ``f`` is clamped from below at ``-kappa``.
        steps: Number of Adam steps.
        lr: Adam learning rate.
    """

    # Keeps atanh finite for pixels that are exactly 0 or 1. Without it those
    # pixels start at +/-inf in tanh space, receive a zero gradient and can
    # never be perturbed.
    _ATANH_BOUND = 1 - 1e-6

    def __init__(self, model, device=None, c=1, kappa=0, steps=50, lr=0.01):
        self.c = c
        self.kappa = kappa
        self.steps = steps
        self.lr = lr
        self.device = device
        self.model = model

    def forward(self, images, labels):
        """Run the attack on one batch.

        Args:
            images: Batch of inputs with values in ``[0, 1]``.
            labels: True class index of every input.

        Returns:
            Detached adversarial images after ``steps`` optimisation steps,
            same shape as ``images``.
        """
        images = images.clone().detach().to(self.device)
        labels = labels.clone().detach().to(self.device)

        w = self.inverse_tanh_space(images).detach()
        w.requires_grad = True

        MSELoss = nn.MSELoss(reduction='none')
        Flatten = nn.Flatten()
        optimizer = optim.Adam([w], lr=self.lr)

        for _ in range(self.steps):
            adv_images = self.tanh_space(w)

            # Per-sample squared L2 distance to the clean image.
            current_L2 = MSELoss(Flatten(adv_images),
                                 Flatten(images)).sum(dim=1)
            L2_loss = current_L2.sum()

            outputs = self.model(adv_images)
            f_loss = self.f(outputs, labels).sum()
            cost = L2_loss + self.c*f_loss

            optimizer.zero_grad()
            # Only the gradient w.r.t. w is needed; autograd.grad avoids
            # accumulating gradients on the model parameters.
            w.grad = torch.autograd.grad(cost, [w])[0]
            optimizer.step()

        return self.tanh_space(w).detach()

    def tanh_space(self, x):
        """Map an unconstrained tensor into ``[0, 1]``."""
        return 1/2*(torch.tanh(x) + 1)

    def inverse_tanh_space(self, x):
        """Map values in ``[0, 1]`` to tanh space, kept finite at the borders."""
        # atanh is only defined on (-1, 1); clamp slightly inside the interval.
        bound = self._ATANH_BOUND
        return self.atanh(torch.clamp(x*2-1, min=-bound, max=bound))

    def atanh(self, x):
        """Inverse hyperbolic tangent (``torch.atanh`` needs torch >= 1.7.0)."""
        return 0.5*torch.log((1+x)/(1-x))

    # f-function in the paper
    def f(self, outputs, labels):
        """Return ``max(logit_true - max_other_logit, -kappa)`` per sample.

        Args:
            outputs: Logits of shape ``(batch, num_classes)``.
            labels: True class index per sample.

        Returns:
            Tensor of shape ``(batch,)``.
        """
        index = labels.long().view(-1, 1)
        # Logit of the true class, taken directly so that a negative value is
        # not replaced by a masked-out zero.
        real = outputs.gather(1, index).squeeze(1)
        # Largest logit among the other classes: the true class is set to -inf
        # so it can never win the max, whatever the sign of the logits.
        other = outputs.scatter(1, index, float('-inf')).max(dim=1)[0]

        return torch.clamp((real-other), min=-self.kappa)


def CWAttack_L2(device, dataLoader, model, c=1e-4, kappa=0, max_iter=30, learning_rate=0.001):
    """Run the ``CW`` attack over every batch of ``dataLoader``.

    Args:
        device: Device on which the attack is computed.
        dataLoader: Loader yielding ``(images, labels)`` batches.
        model: Classifier under attack. It is switched to eval mode.
        c: Weight of the misclassification term, forwarded to ``CW``.
        kappa: Confidence margin, forwarded to ``CW``.
        max_iter: Number of optimisation steps, forwarded to ``CW`` as ``steps``.
        learning_rate: Adam learning rate, forwarded to ``CW`` as ``lr``.

    Returns:
        The loader produced by ``DMP.TensorToDataLoader`` holding the
        adversarial images together with the original labels, using the batch
        size of ``dataLoader``.
    """
    model.eval()
    numSamples = len(dataLoader.dataset)
    xShape = DMP.GetOutputShape(dataLoader)
    xAdv = torch.zeros(numSamples, xShape[0], xShape[1], xShape[2])
    yClean = torch.zeros(numSamples, dtype=torch.long)
    advSampleIndex = 0

    cw_attack = CW(model=model, device=device, c=c, kappa=kappa, steps=max_iter, lr=learning_rate)

    for images, labels in dataLoader:
        batchSize = images.shape[0]
        adversarial_images = cw_attack.forward(images, labels)

        # Store the whole batch with one device-to-host copy.
        batchEnd = advSampleIndex + batchSize
        xAdv[advSampleIndex:batchEnd] = adversarial_images.detach().cpu()
        yClean[advSampleIndex:batchEnd] = labels.detach().cpu()
        advSampleIndex = batchEnd

        del adversarial_images
        torch.cuda.empty_cache()

    # If the loader served fewer samples than len(dataset) (e.g. drop_last or a
    # subset sampler), drop the never-written rows instead of returning
    # all-zero images labelled 0.
    xAdv, yClean = xAdv[:advSampleIndex], yClean[:advSampleIndex]
    advLoader = DMP.TensorToDataLoader(xAdv, yClean, transforms= None, batchSize= dataLoader.batch_size, randomizer=None) #use the same batch size as the original loader
    return advLoader

def PatchFool(dataLoader, model, atten_select=5, patch_size=16, num_patch=4, n_tokens=197):
    '''
    Patch-wise attack: replace the ``num_patch`` most-attended patches of each
    image with an optimised perturbation.

    For every batch the patches receiving the highest head-averaged attention
    in layer ``atten_select`` are selected, the image is zeroed inside those
    patches, and a perturbation restricted to them is optimised with Adam.
    The update direction combines the cross-entropy gradient with attention
    loss gradients from the first half of the attention layers (layer 0
    skipped), merged through ``PCGrad``.

    Requires CUDA: data and helper tensors are moved with ``.cuda()``.

    dataLoader: loader yielding (images, labels) batches
    model: called as ``model(x, output_attentions=True)`` and expected to
        return ``(logits, attentions)``; token 0 of the attention maps is
        treated as a non-patch token (its column is dropped for selection)
    atten_select: Select patch based on which attention layer
    patch_size: side length of a square patch in pixels
    num_patch: the number of patches selected
    n_tokens: number of tokens per attention row (patches plus token 0)

    Returns the loader produced by ``DMP.TensorToDataLoader`` with the
    perturbed images and the original labels.
    '''
    # NOTE(review): unlike the other attacks in this file, the model is not
    # switched to eval mode here - confirm callers do that.
    criterion = nn.CrossEntropyLoss().cuda()
    attack_learning_rate = 0.05
    step_size = 1
    gamma = 0.95
    train_attack_iters = 250
    atten_loss_weight = 0.002

    mu = [0, 0, 0]
    std = [1,1,1]
    mu = torch.tensor(mu).view(3, 1, 1).cuda()
    std = torch.tensor(std).view(3, 1, 1).cuda()

    numSamples = len(dataLoader.dataset)
    xShape = DMP.GetOutputShape(dataLoader)
    xAdv = torch.zeros(numSamples, xShape[0], xShape[1], xShape[2])
    yClean = torch.zeros(numSamples, dtype=torch.long)
    advSampleIndex = 0

    for X, y in dataLoader:
        batchSize = X.shape[0]
        X, y = X.cuda(), y.cuda()
        patch_num_per_line = int(X.size(-1) / patch_size)

        ### choose patch
        # Selection only needs attention values, so no autograd graph is built.
        with torch.no_grad():
            _, atten = model(X, output_attentions=True)
        atten_layer = atten[atten_select].mean(dim=1) #average across heads
        # Average over the second-to-last axis and drop column 0 (token 0), so
        # the resulting indices count patches starting from 0.
        atten_layer = atten_layer.mean(dim=-2)[:, 1:]
        max_patch_index = atten_layer.argsort(descending=True)[:, :num_patch]

        #build mask
        mask = torch.zeros([X.size(0), 1, X.size(2), X.size(3)]).cuda()
        # Plain Python ints avoid one device sync per patch when slicing.
        for j, index_list in enumerate(max_patch_index.tolist()):
            for index in index_list:
                patch_row, patch_col = divmod(index, patch_num_per_line)
                row = patch_row * patch_size
                column = patch_col * patch_size
                mask[j, :, row:row + patch_size, column:column + patch_size] = 1

        # Attention-loss target: one entry per (sample, attention row) holding
        # the token index of the top patch. The +1 converts the patch index
        # (computed after dropping column 0) into a column of the full-width
        # attention map.
        max_patch_index_matrix = max_patch_index[:, 0] + 1
        max_patch_index_matrix = max_patch_index_matrix.repeat(n_tokens, 1)
        max_patch_index_matrix = max_patch_index_matrix.permute(1, 0)
        max_patch_index_matrix = max_patch_index_matrix.flatten().long()

        # Random initial patch content; the clean image is zeroed inside the
        # selected patches so they contain only the perturbation.
        delta = torch.rand_like(X)
        X = torch.mul(X, 1 - mask)

        delta = delta.cuda()
        delta.requires_grad = True
        opt = torch.optim.Adam([delta], lr=attack_learning_rate)
        scheduler = torch.optim.lr_scheduler.StepLR(opt, step_size=step_size, gamma=gamma)

        # start adv attack
        for _ in range(train_attack_iters):
            model.zero_grad()
            opt.zero_grad()
            out, atten = model(X + torch.mul(delta, mask), output_attentions=True)

            # final CE-loss
            loss = criterion(out, y)
            grad = torch.autograd.grad(loss, delta, retain_graph=True)[0]
            ce_loss_grad_temp = grad.view(X.size(0), -1).detach().clone()

            # Attack the first half of the attention layers, skipping layer 0.
            for atten_num in range(1, len(atten)//2):
                atten_map = atten[atten_num]
                atten_map = atten_map.mean(dim=1)
                atten_map = atten_map.view(-1, atten_map.size(-1))
                atten_map = -torch.log(atten_map)
                atten_loss = F.nll_loss(atten_map, max_patch_index_matrix)
                atten_grad = torch.autograd.grad(atten_loss, delta, retain_graph=True)[0]
                atten_grad_temp = atten_grad.view(X.size(0), -1)
                cos_sim = F.cosine_similarity(atten_grad_temp, ce_loss_grad_temp, dim=1)
                atten_grad = PCGrad(atten_grad_temp, ce_loss_grad_temp, cos_sim, grad.shape)
                grad += atten_grad * atten_loss_weight

            # Adam minimises, so the negated gradient ascends the objective.
            delta.grad = -grad
            opt.step()
            scheduler.step()

        # NOTE(review): the clamp is applied once after the optimisation loop,
        # not after every step - confirm this is intended.
        delta.data = clamp(delta, (0 - mu) / std, (1 - mu) / std)

        with torch.no_grad():
            perturb_x = X + torch.mul(delta, mask)

        # Store the whole batch with one device-to-host copy.
        batchEnd = advSampleIndex + batchSize
        xAdv[advSampleIndex:batchEnd] = perturb_x.detach().cpu()
        yClean[advSampleIndex:batchEnd] = y.detach().cpu()
        advSampleIndex = batchEnd

    # Build the result loader once, after all batches. Rows never written
    # (loader served fewer samples than len(dataset)) are dropped.
    xAdv, yClean = xAdv[:advSampleIndex], yClean[:advSampleIndex]
    advLoader = DMP.TensorToDataLoader(xAdv, yClean, transforms= None, batchSize= dataLoader.batch_size, randomizer=None) #use the same batch size as the original loader
    return advLoader

def AttentionFool(dataLoader, model, atten_select=5, patch_size=16, num_patch=4, n_tokens=197):
    '''
    Attention-guided patch attack: replace the ``num_patch`` most-attended
    patches of each image with an optimised perturbation.

    For every batch the patches receiving the highest head-averaged attention
    in layer ``atten_select`` are selected, the image is zeroed inside those
    patches, and a perturbation restricted to them is optimised with Adam.
    The update direction combines the cross-entropy gradient with attention
    loss gradients from the first half of the attention layers (layer 0
    skipped), merged through ``PCGrad``. The procedure is the same as
    ``PatchFool`` in this file with a learning rate of 0.04 and an attention
    loss weight of 1e-5.

    Requires CUDA: data and helper tensors are moved with ``.cuda()``.

    dataLoader: loader yielding (images, labels) batches
    model: called as ``model(x, output_attentions=True)`` and expected to
        return ``(logits, attentions)``; token 0 of the attention maps is
        treated as a non-patch token (its column is dropped for selection)
    atten_select: Select patch based on which attention layer
    patch_size: side length of a square patch in pixels
    num_patch: the number of patches selected
    n_tokens: number of tokens per attention row (patches plus token 0)

    Returns the loader produced by ``DMP.TensorToDataLoader`` with the
    perturbed images and the original labels.
    '''
    # NOTE(review): unlike the other attacks in this file, the model is not
    # switched to eval mode here - confirm callers do that.
    criterion = nn.CrossEntropyLoss().cuda()
    attack_learning_rate = 0.04
    step_size = 1
    gamma = 0.95
    train_attack_iters = 250
    atten_loss_weight = 0.00001

    mu = [0, 0, 0]
    std = [1,1,1]
    mu = torch.tensor(mu).view(3, 1, 1).cuda()
    std = torch.tensor(std).view(3, 1, 1).cuda()

    numSamples = len(dataLoader.dataset)
    xShape = DMP.GetOutputShape(dataLoader)
    xAdv = torch.zeros(numSamples, xShape[0], xShape[1], xShape[2])
    yClean = torch.zeros(numSamples, dtype=torch.long)
    advSampleIndex = 0

    for X, y in dataLoader:
        batchSize = X.shape[0]
        X, y = X.cuda(), y.cuda()
        patch_num_per_line = int(X.size(-1) / patch_size)

        ### choose patch
        # Selection only needs attention values, so no autograd graph is built.
        with torch.no_grad():
            _, atten = model(X, output_attentions=True)
        atten_layer = atten[atten_select].mean(dim=1) #average across heads
        # Average over the second-to-last axis and drop column 0 (token 0), so
        # the resulting indices count patches starting from 0.
        atten_layer = atten_layer.mean(dim=-2)[:, 1:]
        max_patch_index = atten_layer.argsort(descending=True)[:, :num_patch]

        #build mask
        mask = torch.zeros([X.size(0), 1, X.size(2), X.size(3)]).cuda()
        # Plain Python ints avoid one device sync per patch when slicing.
        for j, index_list in enumerate(max_patch_index.tolist()):
            for index in index_list:
                patch_row, patch_col = divmod(index, patch_num_per_line)
                row = patch_row * patch_size
                column = patch_col * patch_size
                mask[j, :, row:row + patch_size, column:column + patch_size] = 1

        # Attention-loss target: one entry per (sample, attention row) holding
        # the token index of the top patch. The +1 converts the patch index
        # (computed after dropping column 0) into a column of the full-width
        # attention map.
        max_patch_index_matrix = max_patch_index[:, 0] + 1
        max_patch_index_matrix = max_patch_index_matrix.repeat(n_tokens, 1)
        max_patch_index_matrix = max_patch_index_matrix.permute(1, 0)
        max_patch_index_matrix = max_patch_index_matrix.flatten().long()

        # Random initial patch content; the clean image is zeroed inside the
        # selected patches so they contain only the perturbation.
        delta = torch.rand_like(X)
        X = torch.mul(X, 1 - mask)

        delta = delta.cuda()
        delta.requires_grad = True
        opt = torch.optim.Adam([delta], lr=attack_learning_rate)
        scheduler = torch.optim.lr_scheduler.StepLR(opt, step_size=step_size, gamma=gamma)

        # start adv attack
        for _ in range(train_attack_iters):
            model.zero_grad()
            opt.zero_grad()
            out, atten = model(X + torch.mul(delta, mask), output_attentions=True)

            # final CE-loss
            loss = criterion(out, y)
            grad = torch.autograd.grad(loss, delta, retain_graph=True)[0]
            ce_loss_grad_temp = grad.view(X.size(0), -1).detach().clone()

            # Attack the first half of the attention layers, skipping layer 0.
            for atten_num in range(1, len(atten)//2):
                atten_map = atten[atten_num]
                atten_map = atten_map.mean(dim=1)
                atten_map = atten_map.view(-1, atten_map.size(-1))
                atten_map = -torch.log(atten_map)
                atten_loss = F.nll_loss(atten_map, max_patch_index_matrix)
                atten_grad = torch.autograd.grad(atten_loss, delta, retain_graph=True)[0]
                atten_grad_temp = atten_grad.view(X.size(0), -1)
                cos_sim = F.cosine_similarity(atten_grad_temp, ce_loss_grad_temp, dim=1)
                atten_grad = PCGrad(atten_grad_temp, ce_loss_grad_temp, cos_sim, grad.shape)
                grad += atten_grad * atten_loss_weight

            # Adam minimises, so the negated gradient ascends the objective.
            delta.grad = -grad
            opt.step()
            scheduler.step()

        # NOTE(review): the clamp is applied once after the optimisation loop,
        # not after every step - confirm this is intended.
        delta.data = clamp(delta, (0 - mu) / std, (1 - mu) / std)

        with torch.no_grad():
            perturb_x = X + torch.mul(delta, mask)

        # Store the whole batch with one device-to-host copy.
        batchEnd = advSampleIndex + batchSize
        xAdv[advSampleIndex:batchEnd] = perturb_x.detach().cpu()
        yClean[advSampleIndex:batchEnd] = y.detach().cpu()
        advSampleIndex = batchEnd

    # Build the result loader once, after all batches. Rows never written
    # (loader served fewer samples than len(dataset)) are dropped.
    xAdv, yClean = xAdv[:advSampleIndex], yClean[:advSampleIndex]
    advLoader = DMP.TensorToDataLoader(xAdv, yClean, transforms= None, batchSize= dataLoader.batch_size, randomizer=None) #use the same batch size as the original loader
    return advLoader


class APGD:
    """Auto-PGD attack with momentum and automatic step-size halving.

    Args:
        model: Callable mapping an image batch to logits.
        device: Device on which the attack tensors are created.
        norm: Threat model, ``'Linf'`` or ``'L2'``.
        eps: Radius of the perturbation ball.
        steps: Number of attack iterations per restart.
        n_restarts: Number of random restarts.
        seed: Seed set on the torch RNGs at the start of ``perturb``.
        loss: ``'ce'`` (cross-entropy) or ``'dlr'`` (difference of logits
            ratio; needs at least three classes).
        eot_iter: Number of forward/backward passes averaged per gradient.
        rho: Threshold used by ``check_oscillation`` when deciding whether to
            halve the step size.
        verbose: Print progress information.
    """

    def __init__(self, model, device=None, norm='Linf', eps=8/255, steps=10, n_restarts=1, seed=0, loss='ce', eot_iter=1, rho=.75, verbose=False):
        self.eps = eps
        self.steps = steps
        self.norm = norm
        self.n_restarts = n_restarts
        self.seed = seed
        self.loss = loss
        self.eot_iter = eot_iter
        self.thr_decr = rho
        self.verbose = verbose
        self.supported_mode = ['default']

        self.device = device
        self.model = model

    def forward(self, images, labels):
        """Return adversarial versions of ``images`` (clean where not fooled)."""
        images = images.clone().detach().to(self.device)
        labels = labels.clone().detach().to(self.device)
        _, adv_images = self.perturb(images, labels, cheap=True)

        return adv_images

    def perturb(self, x_in, y_in, best_loss=False, cheap=True):
        """Attack the correctly classified samples of a batch.

        Args:
            x_in: Image batch (a single 3-D image is promoted to a batch).
            y_in: Labels (a scalar label is promoted to a batch).
            best_loss: Not implemented; must be false.
            cheap: Must be true; the non-cheap mode is not implemented.

        Returns:
            ``(acc, adv)``: a boolean tensor marking the samples that are still
            classified correctly, and the batch in which every fooled sample is
            replaced by its adversarial example.

        Raises:
            ValueError: If ``best_loss`` is true or ``cheap`` is false.
        """
        if best_loss or not cheap:
            raise ValueError('not implemented yet')

        x = x_in.clone() if len(x_in.shape) == 4 else x_in.clone().unsqueeze(0)
        y = y_in.clone() if len(y_in.shape) == 1 else y_in.clone().unsqueeze(0)

        adv = x.clone()
        # Only the predictions are needed here, so no autograd graph is built.
        with torch.no_grad():
            logits = self.model(x)
        acc = logits.max(1)[1] == y
        if self.verbose:
            print('-------------------------- running {}-attack with epsilon {:.4f} --------------------------'.format(self.norm, self.eps))
            print('initial accuracy: {:.2%}'.format(acc.float().mean()))
        startt = time.time()

        torch.random.manual_seed(self.seed)
        torch.cuda.random.manual_seed(self.seed)
        for counter in range(self.n_restarts):
            # Only samples that are still classified correctly are attacked.
            ind_to_fool = acc.nonzero().squeeze()
            if len(ind_to_fool.shape) == 0:
                ind_to_fool = ind_to_fool.unsqueeze(0)
            if ind_to_fool.numel() != 0:
                x_to_fool, y_to_fool = x[ind_to_fool].clone(), y[ind_to_fool].clone()
                # Returns x_best, acc, loss_best, x_best_adv.
                best_curr, acc_curr, loss_curr, adv_curr = self.attack_single_run(x_to_fool, y_to_fool)
                ind_curr = (acc_curr == 0).nonzero().squeeze()

                acc[ind_to_fool[ind_curr]] = 0
                adv[ind_to_fool[ind_curr]] = adv_curr[ind_curr].clone()
                if self.verbose:
                    print('restart {} - robust accuracy: {:.2%} - cum. time: {:.1f} s'.format(
                        counter, acc.float().mean(), time.time() - startt))
        return acc, adv

    def dlr_loss(self, x, y):
        """Difference-of-logits-ratio loss per sample.

        Args:
            x: Logits of shape ``(batch, num_classes)`` with at least three
                classes.
            y: True class index per sample.

        Returns:
            Tensor of shape ``(batch,)``.
        """
        x_sorted, ind_sorted = x.sort(dim=1)
        # 1 where the top prediction is the true class, else 0.
        ind = (ind_sorted[:, -1] == y).float()
        true_logit = x.gather(1, y.long().view(-1, 1)).squeeze(1)
        # Largest logit of a class other than the true one.
        best_other = x_sorted[:, -2] * ind + x_sorted[:, -1] * (1. - ind)
        return -(true_logit - best_other) / (x_sorted[:, -1] - x_sorted[:, -3] + 1e-12)

    def _eot_gradient(self, x_adv, y, criterion_indiv):
        """Return logits, per-sample loss and the loss gradient w.r.t. ``x_adv``.

        The gradient is averaged over ``eot_iter`` forward/backward passes;
        logits and loss come from the last pass.
        """
        grad = torch.zeros_like(x_adv)
        for _ in range(self.eot_iter):
            with torch.enable_grad():
                logits = self.model(x_adv)
                loss_indiv = criterion_indiv(logits, y)
                loss = loss_indiv.sum()
            grad += torch.autograd.grad(loss, [x_adv])[0].detach()
        grad /= float(self.eot_iter)
        return logits, loss_indiv, grad

    def attack_single_run(self, x_in, y_in):
        """Run one restart of the attack.

        Args:
            x_in: Image batch to attack.
            y_in: True labels.

        Returns:
            ``(x_best, acc, loss_best, x_best_adv)``: the highest-loss iterate
            per sample, a boolean tensor that is true where no iterate was
            misclassified, the best loss per sample, and for each sample the
            most recent misclassified iterate (the random start if none).

        Raises:
            ValueError: If ``norm`` or ``loss`` is not supported.
        """
        x = x_in.clone() if len(x_in.shape) == 4 else x_in.clone().unsqueeze(0)
        y = y_in.clone() if len(y_in.shape) == 1 else y_in.clone().unsqueeze(0)
        # Checkpoint schedule derived from the iteration budget.
        self.steps_2, self.steps_min, self.size_decr = max(int(
            0.22 * self.steps), 1), max(int(0.06 * self.steps), 1), max(int(0.03 * self.steps), 1)
        if self.verbose:
            print('parameters: ', self.steps, self.steps_2,
                  self.steps_min, self.size_decr)

        # Random start inside the eps-ball.
        eps_col = self.eps * torch.ones([x.shape[0], 1, 1, 1]).to(self.device).detach()
        if self.norm == 'Linf':
            t = 2 * torch.rand(x.shape).to(self.device).detach() - 1
            t_max = t.reshape([t.shape[0], -1]).abs().max(dim=1, keepdim=True)[0].reshape([-1, 1, 1, 1])
            x_adv = x.detach() + eps_col * t / t_max
        elif self.norm == 'L2':
            t = torch.randn(x.shape).to(self.device).detach()
            x_adv = x.detach() + eps_col * t / ((t ** 2).sum(dim=(1, 2, 3), keepdim=True).sqrt() + 1e-12)
        else:
            raise ValueError('unknown norm')
        x_adv = x_adv.clamp(0., 1.)
        x_best = x_adv.clone()
        x_best_adv = x_adv.clone()
        loss_steps = torch.zeros([self.steps, x.shape[0]])
        loss_best_steps = torch.zeros([self.steps + 1, x.shape[0]])
        acc_steps = torch.zeros_like(loss_best_steps)
        if self.loss == 'ce':
            criterion_indiv = nn.CrossEntropyLoss(reduction='none')
        elif self.loss == 'dlr':
            criterion_indiv = self.dlr_loss
        else:
            raise ValueError('unknown loss')

        x_adv.requires_grad_()
        logits, loss_indiv, grad = self._eot_gradient(x_adv, y, criterion_indiv)
        grad_best = grad.clone()

        acc = logits.detach().max(1)[1] == y
        acc_steps[0] = acc + 0
        loss_best = loss_indiv.detach().clone()

        # Initial step size is 2 * eps per sample; halved at checkpoints below.
        step_size = eps_col * torch.Tensor([2.0]).to(self.device).detach().reshape([1, 1, 1, 1])
        eps_full = self.eps * torch.ones(x.shape).to(self.device).detach()
        x_adv_old = x_adv.detach().clone()
        k = self.steps_2 + 0
        u = np.arange(x.shape[0])
        counter3 = 0

        loss_best_last_check = loss_best.clone()
        # All-true boolean array: every sample counts as "reduced" initially.
        reduced_last_check = np.zeros(loss_best.shape) == np.zeros(loss_best.shape)

        for i in range(self.steps):
            # gradient step
            with torch.no_grad():
                x_adv = x_adv.detach()
                grad2 = x_adv - x_adv_old
                x_adv_old = x_adv.clone()

                # Momentum weight; the first iteration is a plain step.
                a = 0.75 if i > 0 else 1.0

                if self.norm == 'Linf':
                    x_adv_1 = x_adv + step_size * torch.sign(grad)
                    x_adv_1 = torch.clamp(
                        torch.min(torch.max(x_adv_1, x - self.eps), x + self.eps), 0.0, 1.0)
                    x_adv_1 = torch.clamp(torch.min(torch.max(
                        x_adv + (x_adv_1 - x_adv) * a + grad2 * (1 - a), x - self.eps), x + self.eps), 0.0, 1.0)

                elif self.norm == 'L2':
                    x_adv_1 = x_adv + step_size * grad / ((grad ** 2).sum(dim=(1, 2, 3), keepdim=True).sqrt() + 1e-12)
                    d = x_adv_1 - x
                    d_norm = (d ** 2).sum(dim=(1, 2, 3), keepdim=True).sqrt()
                    x_adv_1 = torch.clamp(x + d / (d_norm + 1e-12) * torch.min(eps_full, d_norm), 0.0, 1.0)
                    x_adv_1 = x_adv + (x_adv_1 - x_adv) * a + grad2 * (1 - a)
                    d = x_adv_1 - x
                    d_norm = (d ** 2).sum(dim=(1, 2, 3), keepdim=True).sqrt()
                    x_adv_1 = torch.clamp(x + d / (d_norm + 1e-12) * torch.min(eps_full, d_norm + 1e-12), 0.0, 1.0)

                x_adv = x_adv_1 + 0.

            # get gradient
            x_adv.requires_grad_()
            logits, loss_indiv, grad = self._eot_gradient(x_adv, y, criterion_indiv)

            pred = logits.detach().max(1)[1] == y
            acc = torch.min(acc, pred)
            acc_steps[i + 1] = acc + 0
            # Remember the latest misclassified iterate of every fooled sample;
            # detach so the stored images carry no autograd history.
            fooled = (pred == 0).nonzero().squeeze()
            x_best_adv[fooled] = x_adv.detach()[fooled] + 0.
            if self.verbose:
                print('iteration: {} - Best loss: {:.6f}'.format(i, loss_best.sum()))

            # check step size
            with torch.no_grad():
                y1 = loss_indiv.detach().clone()
                loss_steps[i] = y1.cpu() + 0
                ind = (y1 > loss_best).nonzero().squeeze()  # loss increased
                x_best[ind] = x_adv[ind].clone()
                grad_best[ind] = grad[ind].clone()
                loss_best[ind] = y1[ind] + 0
                loss_best_steps[i + 1] = loss_best + 0

                counter3 += 1

                if counter3 == k:
                    fl_oscillation = self.check_oscillation(loss_steps.detach().cpu(
                    ).numpy(), i, k, loss_best.detach().cpu().numpy(), k3=self.thr_decr)
                    fl_reduce_no_impr = (~reduced_last_check) * (loss_best_last_check.cpu().numpy() >= loss_best.cpu().numpy())
                    fl_oscillation = ~(~fl_oscillation * ~fl_reduce_no_impr)
                    reduced_last_check = np.copy(fl_oscillation)
                    loss_best_last_check = loss_best.clone()

                    if np.sum(fl_oscillation) > 0:
                        # Halve the step and restart from the best point for
                        # the flagged samples.
                        step_size[u[fl_oscillation]] /= 2.0

                        fl_oscillation = np.where(fl_oscillation)

                        x_adv[fl_oscillation] = x_best[fl_oscillation].clone()
                        grad[fl_oscillation] = grad_best[fl_oscillation].clone()

                    counter3 = 0
                    k = np.maximum(k - self.size_decr, self.steps_min)

        return x_best, acc, loss_best, x_best_adv

    def check_oscillation(self, x, j, k, y5, k3=0.75):
        """Flag samples whose loss increased in at most ``k * k3`` of the last ``k`` steps.

        Args:
            x: Array of per-step losses, shape ``(steps, batch)``.
            j: Index of the most recent step.
            k: Number of most recent steps to inspect.
            y5: Unused.
            k3: Fraction threshold.

        Returns:
            Boolean array of shape ``(batch,)``.
        """
        t = np.zeros(x.shape[1])
        for counter5 in range(k):
            t += x[j - counter5] > x[j - counter5 - 1]

        return t <= k*k3*np.ones(t.shape)

def APGDAttack(device, dataLoader, model, norm='Linf', eps=0.03):
    """Run the ``APGD`` attack over every batch of ``dataLoader``.

    Args:
        device: Device on which the attack is computed.
        dataLoader: Loader yielding ``(images, labels)`` batches.
        model: Classifier under attack. It is switched to eval mode.
        norm: Threat model forwarded to ``APGD``.
        eps: Perturbation radius forwarded to ``APGD``.

    Returns:
        The loader produced by ``DMP.TensorToDataLoader`` holding the attacked
        images together with the original labels, using the batch size of
        ``dataLoader``.
    """
    model.eval()
    numSamples = len(dataLoader.dataset)
    xShape = DMP.GetOutputShape(dataLoader)
    xAdv = torch.zeros(numSamples, xShape[0], xShape[1], xShape[2])
    yClean = torch.zeros(numSamples, dtype=torch.long)
    advSampleIndex = 0
    apgd_attack = APGD(model, device=device, norm=norm, eps=eps)

    for images, labels in dataLoader:
        batchSize = images.shape[0]
        adversarial_images = apgd_attack.forward(images, labels).detach().cpu()

        # Store the whole batch with one slice copy.
        batchEnd = advSampleIndex + batchSize
        xAdv[advSampleIndex:batchEnd] = adversarial_images
        yClean[advSampleIndex:batchEnd] = labels.detach().cpu()
        advSampleIndex = batchEnd

        del adversarial_images
        torch.cuda.empty_cache()

    # If the loader served fewer samples than len(dataset) (e.g. drop_last or a
    # subset sampler), drop the never-written rows instead of returning
    # all-zero images labelled 0.
    xAdv, yClean = xAdv[:advSampleIndex], yClean[:advSampleIndex]
    advLoader = DMP.TensorToDataLoader(xAdv, yClean, transforms= None, batchSize= dataLoader.batch_size, randomizer=None) #use the same batch size as the original loader
    return advLoader


def PGDAttack_deit(device, dataLoader, model, eps, num_steps, step_size, clipMin=0, clipMax=1, vis=False):
    """Generate L-infinity PGD adversarial examples against a multi-output model.

    Works like a standard PGD attack, but the loss is the sum of the
    cross-entropy of every element of the model output. With
    ``num_steps == 0`` the clean inputs are returned unchanged.

    Args:
        device: Device on which the attack is computed.
        dataLoader: Loader yielding ``(inputs, labels)`` batches.
        model: Model under attack. It is switched to eval mode. Its output is
            iterated element by element - presumably a sequence of logit
            tensors, one per classification head; confirm against the model.
        eps: Maximum absolute perturbation per element.
        num_steps: Number of gradient steps.
        step_size: Size of each signed-gradient step.
        clipMin: Lower bound applied to the perturbed inputs.
        clipMax: Upper bound applied to the perturbed inputs.
        vis: When true, the model is expected to return a pair whose first
            element is the output used for the loss; the second is ignored.

    Returns:
        A tuple ``(advLoader, xAdv, yClean)``: the loader produced by
        ``DMP.TensorToDataLoader`` plus the adversarial inputs and the clean
        labels as numpy arrays.
    """
    model.eval()
    numSamples = len(dataLoader.dataset)
    xShape = DMP.GetOutputShape(dataLoader)
    xAdv = torch.zeros(numSamples, xShape[0], xShape[1], xShape[2])
    yClean = torch.zeros(numSamples, dtype=torch.long)
    loss = torch.nn.CrossEntropyLoss()
    advSampleIndex = 0

    for xData, yData in dataLoader:
        batchSize = xData.shape[0]
        xData = xData.detach().to(device)
        yData = yData.type(torch.LongTensor).to(device)
        # Working copy that is updated by the attack; xData stays the anchor
        # for the eps-ball projection.
        xDataTemp = xData.clone()

        for _ in range(num_steps):
            xDataTemp.requires_grad = True
            if vis:
                output, _ = model(xDataTemp)
            else:
                output = model(xDataTemp)

            # Sum the loss over every element of the model output.
            cost = 0
            for headOutput in output:
                cost += loss(headOutput, yData)

            # Input gradient only; model parameter .grad buffers are untouched.
            signDataGrad = torch.autograd.grad(cost, xDataTemp)[0].sign()
            xDataTemp = xDataTemp.detach() + step_size * signDataGrad

            # Project back into the eps-ball around the clean input, then
            # into the valid value range.
            delta = torch.clamp(xDataTemp - xData, min=-eps, max=eps)
            xDataTemp = torch.clamp(xData + delta, clipMin, clipMax)

        # Store the whole batch with one device-to-host copy.
        batchEnd = advSampleIndex + batchSize
        xAdv[advSampleIndex:batchEnd] = xDataTemp.detach().cpu()
        yClean[advSampleIndex:batchEnd] = yData.cpu()
        advSampleIndex = batchEnd

        del xDataTemp
        torch.cuda.empty_cache()

    # If the loader served fewer samples than len(dataset) (e.g. drop_last or a
    # subset sampler), drop the never-written rows instead of returning
    # all-zero images labelled 0.
    xAdv, yClean = xAdv[:advSampleIndex], yClean[:advSampleIndex]
    advLoader = DMP.TensorToDataLoader(xAdv, yClean, transforms= None, batchSize= dataLoader.batch_size, randomizer=None) #use the same batch size as the original loader
    return advLoader, xAdv.numpy(), yClean.numpy()