"""White-box adversarial attacks (FGSM, PGD, C&W-L2, Patch-Fool,
Attention-Fool, APGD) that turn a clean data loader into an adversarial one.

Every public ``*Attack`` function iterates over ``dataLoader``, perturbs each
batch against ``model`` and returns a new loader built with
``DMP.TensorToDataLoader`` that keeps the original batch size.
"""
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import sys
import numpy as np
import time
# sys.path.append('/home/artifacts/ViTGuard-main/')
sys.path.append('/home/native/ViTGuard_artifacts/ViTGuard-main/')
import DataManagerPytorch as DMP
sys.path.append('target_models/')
from utils import clamp, PCGrad


# --------------------------------------------------------------------------
# Shared private helpers
# --------------------------------------------------------------------------
def _allocate_adv_buffers(dataLoader):
    """Return zero-filled CPU buffers ``(xAdv, yClean)`` sized for the dataset.

    The first three entries of ``DMP.GetOutputShape(dataLoader)`` are used as
    the per-sample shape (presumably channels, height, width -- confirm
    against DataManagerPytorch).
    """
    numSamples = len(dataLoader.dataset)
    xShape = DMP.GetOutputShape(dataLoader)
    xAdv = torch.zeros(numSamples, xShape[0], xShape[1], xShape[2])
    yClean = torch.zeros(numSamples)
    return xAdv, yClean


def _store_batch(xAdv, yClean, start, xBatch, yBatch):
    """Copy one batch into the output buffers and return the next write index.

    Slice assignment copies across devices/dtypes, so ``xBatch``/``yBatch``
    may live on the GPU while the buffers stay on the CPU.
    """
    stop = start + xBatch.shape[0]
    xAdv[start:stop] = xBatch.detach()
    yClean[start:stop] = yBatch.detach()
    return stop


def _build_adv_loader(dataLoader, xAdv, yClean):
    """Wrap the buffers in a loader; returns ``(advLoader, yClean_as_long)``."""
    yClean = yClean.type(torch.LongTensor)
    # Use the same batch size as the original loader.
    advLoader = DMP.TensorToDataLoader(xAdv, yClean, transforms=None,
                                       batchSize=dataLoader.batch_size,
                                       randomizer=None)
    return advLoader, yClean


# --------------------------------------------------------------------------
# FGSM
# --------------------------------------------------------------------------
def FGSMAttack(device, dataLoader, model, epsilonMax, clipMin=0, clipMax=1):
    """Single-step FGSM attack with cross-entropy loss.

    Args:
        device: device the model forward/backward pass runs on.
        dataLoader: loader yielding ``(x, y)`` batches.
        model: classifier; it is switched to eval mode.
        epsilonMax: step size multiplied with the sign of the input gradient.
        clipMin, clipMax: valid pixel range the result is clamped to.

    Returns:
        A data loader with the adversarial images and the clean labels.
    """
    model.eval()
    xAdv, yClean = _allocate_adv_buffers(dataLoader)
    loss = torch.nn.CrossEntropyLoss()
    advSampleIndex = 0
    for xData, yData in dataLoader:
        # Independent copy on the device so that gradients can be attached
        # without touching the loader's tensor.
        xDataTemp = xData.detach().clone().to(device)
        yData = yData.type(torch.LongTensor).to(device)
        xDataTemp.requires_grad = True
        output = model(xDataTemp)
        model.zero_grad()
        cost = loss(output, yData)
        cost.backward()
        # Element-wise sign of the gradient w.r.t. the input.
        signDataGrad = xDataTemp.grad.data.sign()
        perturbedImage = xData + epsilonMax * signDataGrad.cpu().detach()
        perturbedImage = torch.clamp(perturbedImage, clipMin, clipMax)
        advSampleIndex = _store_batch(xAdv, yClean, advSampleIndex,
                                      perturbedImage, yData)
        # Release per-batch GPU memory before the next batch.
        del xDataTemp
        del signDataGrad
        torch.cuda.empty_cache()
    advLoader, _ = _build_adv_loader(dataLoader, xAdv, yClean)
    return advLoader


# --------------------------------------------------------------------------
# PGD (shared loop for the plain and the DeiT variant)
# --------------------------------------------------------------------------
def _run_pgd(device, dataLoader, model, eps, num_steps, step_size,
             clipMin, clipMax, compute_cost):
    """Iterative sign-gradient attack projected onto the L-inf ball of radius eps.

    ``compute_cost(x, y)`` must run the model on ``x`` and return the scalar
    loss to maximise.  With ``num_steps == 0`` the clean images are returned
    unchanged.

    Returns:
        ``(advLoader, xAdv as numpy array, yClean as numpy array)``.
    """
    model.eval()
    xAdv, yClean = _allocate_adv_buffers(dataLoader)
    advSampleIndex = 0
    for xData, yData in dataLoader:
        xData = xData.to(device)
        yData = yData.type(torch.LongTensor).to(device)
        xDataTemp = xData.detach().clone()
        for _ in range(num_steps):
            # The tensor is re-created every step, so the flag must be reset.
            xDataTemp.requires_grad = True
            model.zero_grad()
            cost = compute_cost(xDataTemp, yData)
            cost.backward()
            signDataGrad = xDataTemp.grad.data.sign()
            xDataTemp = xDataTemp.detach() + step_size * signDataGrad
            # Project back into the eps-ball, then into the valid pixel range.
            delta = torch.clamp(xDataTemp - xData, min=-eps, max=eps)
            xDataTemp = torch.clamp(xData + delta, clipMin, clipMax)
        advSampleIndex = _store_batch(xAdv, yClean, advSampleIndex,
                                      xDataTemp, yData)
        del xDataTemp
        torch.cuda.empty_cache()
    advLoader, yClean = _build_adv_loader(dataLoader, xAdv, yClean)
    return advLoader, xAdv.numpy(), yClean.numpy()


def PGDAttack(device, dataLoader, model, eps, num_steps, step_size, clipMin=0, clipMax=1):
    """PGD attack with cross-entropy loss on a single-output classifier.

    Args:
        device: device the attack runs on.
        dataLoader: loader yielding ``(x, y)`` batches.
        model: classifier; it is switched to eval mode.
        eps: L-inf radius of the allowed perturbation.
        num_steps: number of gradient steps.
        step_size: size of each signed-gradient step.
        clipMin, clipMax: valid pixel range.

    Returns:
        ``(advLoader, xAdv as numpy array, yClean as numpy array)``.
    """
    loss = torch.nn.CrossEntropyLoss()

    def compute_cost(x, y):
        return loss(model(x), y)

    return _run_pgd(device, dataLoader, model, eps, num_steps, step_size,
                    clipMin, clipMax, compute_cost)


# --------------------------------------------------------------------------
# Carlini & Wagner L2
# --------------------------------------------------------------------------
class CW:
    """Untargeted Carlini-Wagner L2 attack optimised in tanh space with Adam.

    Args:
        model: classifier returning logits.
        device: device the attack runs on.
        c: weight of the misclassification term relative to the L2 term.
        kappa: confidence margin of the f-function.
        steps: number of Adam steps.
        lr: Adam learning rate.
    """

    def __init__(self, model, device=None, c=1, kappa=0, steps=50, lr=0.01):
        self.c = c
        self.kappa = kappa
        self.steps = steps
        self.lr = lr
        self.device = device
        self.model = model

    def forward(self, images, labels):
        """Return adversarial images for ``images`` (values assumed in [0, 1]).

        The images produced by the final optimisation step are returned; no
        best-so-far tracking is performed.
        """
        images = images.clone().detach().to(self.device)
        labels = labels.clone().detach().to(self.device)

        # Optimise w, where adv = tanh_space(w), so adv stays inside [0, 1].
        w = self.inverse_tanh_space(images).detach()
        w.requires_grad = True

        MSELoss = nn.MSELoss(reduction='none')
        Flatten = nn.Flatten()
        optimizer = optim.Adam([w], lr=self.lr)

        for _ in range(self.steps):
            adv_images = self.tanh_space(w)
            # Squared L2 distance per sample, summed over the batch.
            current_L2 = MSELoss(Flatten(adv_images), Flatten(images)).sum(dim=1)
            L2_loss = current_L2.sum()

            outputs = self.model(adv_images)
            f_loss = self.f(outputs, labels).sum()
            cost = L2_loss + self.c * f_loss

            optimizer.zero_grad()
            cost.backward()
            optimizer.step()

        return self.tanh_space(w).detach()

    def tanh_space(self, x):
        """Map an unconstrained tensor into [0, 1]."""
        return 1/2*(torch.tanh(x) + 1)

    def inverse_tanh_space(self, x):
        """Inverse of :meth:`tanh_space` for inputs in [0, 1].

        NOTE(review): pixels that are exactly 0 or 1 map to -inf/+inf, where
        the tanh gradient is zero, so they can never be perturbed.  Kept
        as-is to preserve the published attack behaviour.
        """
        # atanh is defined in the range -1 to 1
        return self.atanh(torch.clamp(x*2-1, min=-1, max=1))

    def atanh(self, x):
        """Hand-rolled atanh (torch.atanh only exists for torch >= 1.7.0)."""
        return 0.5*torch.log((1+x)/(1-x))

    def f(self, outputs, labels):
        """f-function of the paper: ``max(Z_true - max_{i != true} Z_i, -kappa)``.

        The previous implementation picked both logits with
        ``max(mask * logits)``; the zeros introduced by masking win the max
        whenever the relevant logits are negative, which yields a wrong
        margin.  Here the true logit is read with a masked sum and the true
        class is pushed out of the runner-up search with a large penalty.
        """
        one_hot_labels = torch.eye(outputs.shape[1]).to(self.device)[labels]
        # logit of the true class
        real = (one_hot_labels * outputs).sum(dim=1)
        # largest logit among all other classes
        other = torch.max(outputs - 1e12 * one_hot_labels, dim=1)[0]
        return torch.clamp((real-other), min=-self.kappa)


def CWAttack_L2(device, dataLoader, model, c=1e-4, kappa=0, max_iter=30, learning_rate=0.001):
    """Run the :class:`CW` attack on every batch of ``dataLoader``.

    Args:
        device: device the attack runs on.
        dataLoader: loader yielding ``(images, labels)`` batches.
        model: classifier; it is switched to eval mode.
        c, kappa, max_iter, learning_rate: forwarded to :class:`CW` as
            ``c``, ``kappa``, ``steps`` and ``lr``.

    Returns:
        A data loader with the adversarial images and the clean labels.
    """
    model.eval()
    xAdv, yClean = _allocate_adv_buffers(dataLoader)
    advSampleIndex = 0
    cw_attack = CW(model=model, device=device, c=c, kappa=kappa,
                   steps=max_iter, lr=learning_rate)
    for images, labels in dataLoader:
        adversarial_images = cw_attack.forward(images, labels)
        advSampleIndex = _store_batch(xAdv, yClean, advSampleIndex,
                                      adversarial_images, labels)
        del adversarial_images
        torch.cuda.empty_cache()
    advLoader, _ = _build_adv_loader(dataLoader, xAdv, yClean)
    return advLoader


# --------------------------------------------------------------------------
# Patch-Fool / Attention-Fool (same algorithm, different hyper-parameters)
# --------------------------------------------------------------------------
def _attention_patch_attack(dataLoader, model, atten_select, patch_size,
                            num_patch, n_tokens, attack_learning_rate,
                            atten_loss_weight):
    """Patch attack that combines a CE gradient with attention-loss gradients.

    For every batch the ``num_patch`` patches receiving the most attention in
    layer ``atten_select`` are selected, their pixels are replaced by a
    learnable perturbation, and that perturbation is optimised with Adam.
    The update direction is the cross-entropy gradient plus, for the layers
    ``1 .. len(atten)//2 - 1``, an attention-loss gradient that is combined
    with it through ``PCGrad`` (imported from ``utils``).

    Runs on CUDA (the tensors are moved with ``.cuda()``).  The model must
    accept ``output_attentions=True`` and then return ``(logits, attentions)``.
    The model's train/eval mode is not changed here.
    """
    criterion = nn.CrossEntropyLoss().cuda()
    step_size = 1            # StepLR period (in attack iterations)
    gamma = 0.95             # StepLR decay factor
    train_attack_iters = 250
    # Identity normalisation: the perturbation is clamped to [0, 1].
    mu = torch.tensor([0, 0, 0]).view(3, 1, 1).cuda()
    std = torch.tensor([1, 1, 1]).view(3, 1, 1).cuda()

    xAdv, yClean = _allocate_adv_buffers(dataLoader)
    advSampleIndex = 0
    for X, y in dataLoader:
        X, y = X.cuda(), y.cuda()
        patch_num_per_line = int(X.size(-1) / patch_size)

        # --- choose the patches to attack (no gradient needed) ---
        with torch.no_grad():
            _, atten = model(X, output_attentions=True)
            atten_layer = atten[atten_select].mean(dim=1)  # average across heads
            # Average over the query axis and drop the first token column.
            atten_layer = atten_layer.mean(dim=-2)[:, 1:]
            max_patch_index = atten_layer.argsort(descending=True)[:, :num_patch]

        # --- build the pixel mask of the selected patches ---
        mask = torch.zeros([X.size(0), 1, X.size(2), X.size(3)]).cuda()
        for j, index_list in enumerate(max_patch_index.tolist()):
            for index in index_list:
                row = (index // patch_num_per_line) * patch_size
                column = (index % patch_num_per_line) * patch_size
                mask[j, :, row:row + patch_size, column:column + patch_size] = 1

        # Target index for the attention loss: each sample's top patch,
        # repeated once per token row of the flattened attention map.
        max_patch_index_matrix = max_patch_index[:, 0]
        max_patch_index_matrix = max_patch_index_matrix.repeat(n_tokens, 1)
        max_patch_index_matrix = max_patch_index_matrix.permute(1, 0)
        max_patch_index_matrix = max_patch_index_matrix.flatten().long()

        # Random start inside the patches; original pixels there are zeroed.
        delta = torch.rand_like(X)
        X = torch.mul(X, 1 - mask)
        delta = delta.cuda()
        delta.requires_grad = True

        opt = torch.optim.Adam([delta], lr=attack_learning_rate)
        scheduler = torch.optim.lr_scheduler.StepLR(opt, step_size=step_size, gamma=gamma)

        for _ in range(train_attack_iters):
            model.zero_grad()
            opt.zero_grad()

            out, atten = model(X + torch.mul(delta, mask), output_attentions=True)

            # final CE-loss
            loss = criterion(out, y)
            grad = torch.autograd.grad(loss, delta, retain_graph=True)[0]
            ce_loss_grad_temp = grad.view(X.size(0), -1).detach().clone()

            # Attack the attention of the first half of the layers
            # (layer 0 is skipped).
            for atten_num in range(len(atten)//2):
                if atten_num == 0:
                    continue
                atten_map = atten[atten_num]
                atten_map = atten_map.mean(dim=1)
                atten_map = atten_map.view(-1, atten_map.size(-1))
                atten_map = -torch.log(atten_map)
                atten_loss = F.nll_loss(atten_map, max_patch_index_matrix)

                atten_grad = torch.autograd.grad(atten_loss, delta, retain_graph=True)[0]
                atten_grad_temp = atten_grad.view(X.size(0), -1)
                cos_sim = F.cosine_similarity(atten_grad_temp, ce_loss_grad_temp, dim=1)
                atten_grad = PCGrad(atten_grad_temp, ce_loss_grad_temp, cos_sim, grad.shape)
                grad += atten_grad * atten_loss_weight

            opt.zero_grad()
            # Adam minimises, so feed the negated gradient to ascend the loss.
            delta.grad = -grad
            opt.step()
            scheduler.step()
            delta.data = clamp(delta, (0 - mu) / std, (1 - mu) / std)

        with torch.no_grad():
            perturb_x = X + torch.mul(delta, mask)

        advSampleIndex = _store_batch(xAdv, yClean, advSampleIndex, perturb_x, y)

    advLoader, _ = _build_adv_loader(dataLoader, xAdv, yClean)
    return advLoader


def PatchFool(dataLoader, model, atten_select=5, patch_size=16, num_patch=4, n_tokens=197):
    '''
    Patch-Fool attack (learning rate 0.05, attention-loss weight 0.002).

    atten_select: Select patch based on which attention layer
    patch_size: side length, in pixels, of one square patch
    num_patch: the number of patches selected
    n_tokens: number of token rows per sample in the flattened attention map

    Returns a data loader with the adversarial images and the clean labels.
    '''
    return _attention_patch_attack(dataLoader, model, atten_select, patch_size,
                                   num_patch, n_tokens,
                                   attack_learning_rate=0.05,
                                   atten_loss_weight=0.002)


def AttentionFool(dataLoader, model, atten_select=5, patch_size=16, num_patch=4, n_tokens=197):
    '''
    Attention-Fool variant of the patch attack (learning rate 0.04,
    attention-loss weight 0.00001).

    atten_select: Select patch based on which attention layer
    patch_size: side length, in pixels, of one square patch
    num_patch: the number of patches selected
    n_tokens: number of token rows per sample in the flattened attention map

    Returns a data loader with the adversarial images and the clean labels.
    '''
    return _attention_patch_attack(dataLoader, model, atten_select, patch_size,
                                   num_patch, n_tokens,
                                   attack_learning_rate=0.04,
                                   atten_loss_weight=0.00001)


# --------------------------------------------------------------------------
# Auto-PGD
# --------------------------------------------------------------------------
class APGD:
    """Auto-PGD attack with adaptive step size.

    Args:
        model: classifier returning logits.
        device: device the attack runs on.
        norm: ``'Linf'`` or ``'L2'``.
        eps: radius of the perturbation ball.
        steps: number of attack iterations per run.
        n_restarts: number of random restarts.
        seed: seed set at the start of :meth:`perturb`.
        loss: ``'ce'`` or ``'dlr'``.
        eot_iter: number of forward/backward passes averaged per gradient.
        rho: threshold used by :meth:`check_oscillation`.
        verbose: print progress information.
    """

    def __init__(self, model, device=None, norm='Linf', eps=8/255, steps=10, n_restarts=1,
                 seed=0, loss='ce', eot_iter=1, rho=.75, verbose=False):
        self.eps = eps
        self.steps = steps
        self.norm = norm
        self.n_restarts = n_restarts
        self.seed = seed
        self.loss = loss
        self.eot_iter = eot_iter
        self.thr_decr = rho
        self.verbose = verbose
        self.supported_mode = ['default']
        self.device = device
        self.model = model

    def forward(self, images, labels):
        """Return adversarial images for a batch (only the images, not acc)."""
        images = images.clone().detach().to(self.device)
        labels = labels.clone().detach().to(self.device)
        _, adv_images = self.perturb(images, labels, cheap=True)
        return adv_images

    def perturb(self, x_in, y_in, best_loss=False, cheap=True):
        """Attack the correctly classified samples of a batch.

        Returns:
            ``(acc, adv)``: ``acc`` is the boolean correctness mask of the
            *clean* inputs (it is not updated by the attack), ``adv`` holds
            the adversarial example where one was found and the clean input
            otherwise.  With ``best_loss=True`` no attack is run.

        Raises:
            ValueError: if ``cheap`` is False (not implemented).
        """
        x = x_in.clone() if len(x_in.shape) == 4 else x_in.clone().unsqueeze(0)
        y = y_in.clone() if len(y_in.shape) == 1 else y_in.clone().unsqueeze(0)

        adv = x.clone()
        # Only the predictions are needed here, so skip building a graph.
        with torch.no_grad():
            logits = self.model(x)
        acc = logits.max(1)[1] == y
        if self.verbose:
            print('-------------------------- running {}-attack with epsilon {:.4f} --------------------------'.format(self.norm, self.eps))
            print('initial accuracy: {:.2%}'.format(acc.float().mean()))
        startt = time.time()

        if not best_loss:
            torch.random.manual_seed(self.seed)
            torch.cuda.random.manual_seed(self.seed)

            if not cheap:
                raise ValueError('not implemented yet')

            for counter in range(self.n_restarts):
                # Indices of the correctly classified samples.
                ind_to_fool = acc.nonzero().squeeze()
                if len(ind_to_fool.shape) == 0:
                    ind_to_fool = ind_to_fool.unsqueeze(0)
                if ind_to_fool.numel() != 0:
                    x_to_fool, y_to_fool = x[ind_to_fool].clone(), y[ind_to_fool].clone()
                    best_curr, acc_curr, loss_curr, adv_curr = self.attack_single_run(x_to_fool, y_to_fool)
                    ind_curr = (acc_curr == 0).nonzero().squeeze()
                    # NOTE: the reference implementation also clears
                    # acc[ind_to_fool[ind_curr]] here; it is left disabled so
                    # that acc keeps reporting the clean accuracy.
                    # acc[ind_to_fool[ind_curr]] = 0
                    adv[ind_to_fool[ind_curr]] = adv_curr[ind_curr].clone()
                    if self.verbose:
                        print('restart {} - robust accuracy: {:.2%} - cum. time: {:.1f} s'.format(
                            counter, acc.float().mean(), time.time() - startt))

        return acc, adv

    def dlr_loss(self, x, y):
        """Difference-of-Logits-Ratio loss, one value per sample.

        Needs at least three classes because it reads the third-largest logit.
        """
        x_sorted, ind_sorted = x.sort(dim=1)
        # 1.0 where the top-1 prediction is the true class.
        ind = (ind_sorted[:, -1] == y).float()
        rows = torch.arange(x.shape[0], device=x.device)
        return -(x[rows, y] - x_sorted[:, -2] * ind - x_sorted[:, -1] * (1. - ind)) / (
            x_sorted[:, -1] - x_sorted[:, -3] + 1e-12)

    def attack_single_run(self, x_in, y_in):
        """Run one APGD restart.

        Returns:
            ``(x_best, acc, loss_best, x_best_adv)``: highest-loss points,
            per-sample robustness mask, best per-sample loss, and the most
            recent misclassified point found for each sample.

        Raises:
            ValueError: for an unsupported ``norm`` or ``loss``.
        """
        x = x_in.clone() if len(x_in.shape) == 4 else x_in.clone().unsqueeze(0)
        y = y_in.clone() if len(y_in.shape) == 1 else y_in.clone().unsqueeze(0)

        # Checkpoint schedule for the step-size control.
        self.steps_2, self.steps_min, self.size_decr = max(int(
            0.22 * self.steps), 1), max(int(0.06 * self.steps), 1), max(int(0.03 * self.steps), 1)
        if self.verbose:
            print('parameters: ', self.steps, self.steps_2,
                  self.steps_min, self.size_decr)

        # Random start inside the eps-ball.
        if self.norm == 'Linf':
            t = 2 * torch.rand(x.shape).to(self.device).detach() - 1
            x_adv = x.detach() + self.eps * torch.ones([x.shape[0], 1, 1, 1]).to(self.device).detach() * t / (
                t.reshape([t.shape[0], -1]).abs().max(dim=1, keepdim=True)[0].reshape([-1, 1, 1, 1]))
        elif self.norm == 'L2':
            t = torch.randn(x.shape).to(self.device).detach()
            x_adv = x.detach() + self.eps * torch.ones([x.shape[0], 1, 1, 1]).to(self.device).detach() * t / (
                (t ** 2).sum(dim=(1, 2, 3), keepdim=True).sqrt() + 1e-12)
        else:
            raise ValueError('unknown norm')
        x_adv = x_adv.clamp(0., 1.)
        x_best = x_adv.clone()
        x_best_adv = x_adv.clone()
        loss_steps = torch.zeros([self.steps, x.shape[0]])
        loss_best_steps = torch.zeros([self.steps + 1, x.shape[0]])
        acc_steps = torch.zeros_like(loss_best_steps)

        if self.loss == 'ce':
            criterion_indiv = nn.CrossEntropyLoss(reduction='none')
        elif self.loss == 'dlr':
            criterion_indiv = self.dlr_loss
        else:
            raise ValueError('unknown loss')

        x_adv.requires_grad_()
        grad = torch.zeros_like(x)
        for _ in range(self.eot_iter):
            with torch.enable_grad():
                # 1 forward pass (eot_iter = 1)
                logits = self.model(x_adv)
                loss_indiv = criterion_indiv(logits, y)
                loss = loss_indiv.sum()
            grad += torch.autograd.grad(loss, [x_adv])[0].detach()

        grad /= float(self.eot_iter)
        grad_best = grad.clone()

        acc = logits.detach().max(1)[1] == y
        acc_steps[0] = acc + 0  # "+ 0" turns the boolean mask into numbers
        loss_best = loss_indiv.detach().clone()

        # Initial per-sample step size is 2 * eps.
        step_size = self.eps * torch.ones([x.shape[0], 1, 1, 1]).to(self.device).detach() * \
            torch.Tensor([2.0]).to(self.device).detach().reshape([1, 1, 1, 1])
        x_adv_old = x_adv.clone()
        k = self.steps_2 + 0
        u = np.arange(x.shape[0])
        counter3 = 0

        loss_best_last_check = loss_best.clone()
        # All-True array: every sample counts as "reduced" at the first check.
        reduced_last_check = np.zeros(loss_best.shape) == np.zeros(loss_best.shape)

        for i in range(self.steps):
            # --- gradient step with momentum ---
            with torch.no_grad():
                x_adv = x_adv.detach()
                grad2 = x_adv - x_adv_old
                x_adv_old = x_adv.clone()

                a = 0.75 if i > 0 else 1.0

                if self.norm == 'Linf':
                    x_adv_1 = x_adv + step_size * torch.sign(grad)
                    x_adv_1 = torch.clamp(
                        torch.min(torch.max(x_adv_1, x - self.eps), x + self.eps), 0.0, 1.0)
                    x_adv_1 = torch.clamp(torch.min(torch.max(
                        x_adv + (x_adv_1 - x_adv) * a + grad2 * (1 - a), x - self.eps), x + self.eps), 0.0, 1.0)

                elif self.norm == 'L2':
                    eps_full = self.eps * torch.ones(x.shape).to(self.device).detach()
                    x_adv_1 = x_adv + step_size * grad / (
                        (grad ** 2).sum(dim=(1, 2, 3), keepdim=True).sqrt() + 1e-12)
                    diff = x_adv_1 - x
                    diff_norm = (diff ** 2).sum(dim=(1, 2, 3), keepdim=True).sqrt()
                    x_adv_1 = torch.clamp(
                        x + diff / (diff_norm + 1e-12) * torch.min(eps_full, diff_norm), 0.0, 1.0)
                    x_adv_1 = x_adv + (x_adv_1 - x_adv) * a + grad2 * (1 - a)
                    diff = x_adv_1 - x
                    diff_norm = (diff ** 2).sum(dim=(1, 2, 3), keepdim=True).sqrt()
                    x_adv_1 = torch.clamp(
                        x + diff / (diff_norm + 1e-12) * torch.min(eps_full, diff_norm + 1e-12), 0.0, 1.0)

                x_adv = x_adv_1 + 0.

            # --- gradient at the new point ---
            x_adv.requires_grad_()
            grad = torch.zeros_like(x)
            for _ in range(self.eot_iter):
                with torch.enable_grad():
                    # 1 forward pass (eot_iter = 1)
                    logits = self.model(x_adv)
                    loss_indiv = criterion_indiv(logits, y)
                    loss = loss_indiv.sum()
                # 1 backward pass (eot_iter = 1)
                grad += torch.autograd.grad(loss, [x_adv])[0].detach()

            grad /= float(self.eot_iter)

            pred = logits.detach().max(1)[1] == y
            # A sample stays "robust" only if it was never misclassified.
            acc = torch.min(acc, pred)
            acc_steps[i + 1] = acc + 0
            # Remember the current point for every misclassified sample.
            fooled = (pred == 0).nonzero().squeeze()
            x_best_adv[fooled] = x_adv[fooled] + 0.
            if self.verbose:
                print('iteration: {} - Best loss: {:.6f}'.format(i, loss_best.sum()))

            # --- step-size control ---
            with torch.no_grad():
                y1 = loss_indiv.detach().clone()
                loss_steps[i] = y1.cpu() + 0
                ind = (y1 > loss_best).nonzero().squeeze()  # loss increased
                x_best[ind] = x_adv[ind].clone()
                grad_best[ind] = grad[ind].clone()
                loss_best[ind] = y1[ind] + 0
                loss_best_steps[i + 1] = loss_best + 0

                counter3 += 1

                if counter3 == k:
                    fl_oscillation = self.check_oscillation(
                        loss_steps.detach().cpu().numpy(), i, k,
                        loss_best.detach().cpu().numpy(), k3=self.thr_decr)
                    fl_reduce_no_impr = (~reduced_last_check) * \
                        (loss_best_last_check.cpu().numpy() >= loss_best.cpu().numpy())
                    # Logical OR of the two conditions.
                    fl_oscillation = ~(~fl_oscillation * ~fl_reduce_no_impr)
                    reduced_last_check = np.copy(fl_oscillation)
                    loss_best_last_check = loss_best.clone()

                    if np.sum(fl_oscillation) > 0:
                        # Halve the step and restart from the best point.
                        step_size[u[fl_oscillation]] /= 2.0

                        fl_oscillation = np.where(fl_oscillation)

                        x_adv[fl_oscillation] = x_best[fl_oscillation].clone()
                        grad[fl_oscillation] = grad_best[fl_oscillation].clone()

                    counter3 = 0
                    k = np.maximum(k - self.size_decr, self.steps_min)

        return x_best, acc, loss_best, x_best_adv

    def check_oscillation(self, x, j, k, y5, k3=0.75):
        """Flag samples whose loss rose in at most ``k * k3`` of the last ``k`` steps.

        Args:
            x: array of shape (steps, batch) with the per-step losses.
            j: index of the current step.
            k: number of most recent steps to inspect.
            y5: unused.
            k3: fraction threshold.

        Returns:
            Boolean array with one entry per sample.
        """
        t = np.zeros(x.shape[1])
        for counter5 in range(k):
            t += x[j - counter5] > x[j - counter5 - 1]

        return t <= k*k3*np.ones(t.shape)


def APGDAttack(device, dataLoader, model, norm='Linf', eps=0.03):
    """Run :class:`APGD` (default steps/restarts) on every batch of ``dataLoader``.

    Args:
        device: device the attack runs on.
        dataLoader: loader yielding ``(images, labels)`` batches.
        model: classifier; it is switched to eval mode.
        norm: ``'Linf'`` or ``'L2'``.
        eps: radius of the perturbation ball.

    Returns:
        A data loader with the adversarial images and the clean labels.
    """
    model.eval()
    xAdv, yClean = _allocate_adv_buffers(dataLoader)
    advSampleIndex = 0
    apgd_attack = APGD(model, device=device, norm=norm, eps=eps)
    for images, labels in dataLoader:
        adversarial_images = apgd_attack.forward(images, labels).detach().cpu()
        advSampleIndex = _store_batch(xAdv, yClean, advSampleIndex,
                                      adversarial_images, labels)
        del adversarial_images
        torch.cuda.empty_cache()
    advLoader, _ = _build_adv_loader(dataLoader, xAdv, yClean)
    return advLoader


def PGDAttack_deit(device, dataLoader, model, eps, num_steps, step_size, clipMin=0, clipMax=1, vis=False):
    """PGD attack for models that return a sequence of logit tensors.

    The loss is the sum of the cross-entropy losses of every element of the
    model output.

    Args:
        device: device the attack runs on.
        dataLoader: loader yielding ``(x, y)`` batches.
        model: classifier; it is switched to eval mode.
        eps: L-inf radius of the allowed perturbation.
        num_steps: number of gradient steps.
        step_size: size of each signed-gradient step.
        clipMin, clipMax: valid pixel range.
        vis: set when the model returns ``(outputs, extra)``; the second
            element is then discarded.

    Returns:
        ``(advLoader, xAdv as numpy array, yClean as numpy array)``.
    """
    loss = torch.nn.CrossEntropyLoss()

    def compute_cost(x, y):
        if vis:
            output, _ = model(x)
        else:
            output = model(x)
        cost = 0
        for idx in range(len(output)):
            cost += loss(output[idx], y)
        return cost

    return _run_pgd(device, dataLoader, model, eps, num_steps, step_size,
                    clipMin, clipMax, compute_cost)