import torch
import numpy as np
from torch.utils.data import DataLoader, Dataset


def vis_from_attn(att_mat, layer_index=-1):
    att_mat = att_mat.to('cpu')
    # Average the attention weights across all heads.
    att_mat = torch.mean(att_mat, dim=1)

    # Add an identity matrix to account for residual connections, then re-normalize.
    residual_att = torch.eye(att_mat.size(-1))
    aug_att_mat = att_mat + residual_att
    aug_att_mat = aug_att_mat / aug_att_mat.sum(dim=-1).unsqueeze(-1)  # [num_layers, N, num_tokens, num_tokens]

    # Recursively multiply the weight matrices (attention rollout).
    joint_attentions = torch.zeros(aug_att_mat.size())
    joint_attentions[0] = aug_att_mat[0]
    for n in range(1, aug_att_mat.size(0)):
        joint_attentions[n] = torch.matmul(aug_att_mat[n], joint_attentions[n - 1])

    # Attention from the output (CLS) token to the input space.
    v = joint_attentions[layer_index]
    att_matrix = v[:, 0, 1:].detach().numpy()
    return att_matrix  # (N, 196)


# Get the attention vectors from the specified transformer layer.
def get_attn(dataloader, model, device, layer_index):
    attn_all = []
    for x_batch, y_batch in dataloader:
        x_batch = x_batch.to(device)
        preds = model(x_batch, output_attentions=True)
        # preds[-1] is a tuple of per-layer attention tensors, each [N, num_heads, num_tokens, num_tokens].
        att_mat = [t.unsqueeze(0) for t in preds[-1]]
        att_mat = torch.cat(att_mat, dim=0)
        att_mat = torch.transpose(att_mat, 1, 2).detach().cpu()
        del preds
        att_matrix = vis_from_attn(att_mat, layer_index)
        attn_all.append(att_matrix)
    attn_all = np.vstack(attn_all)
    return attn_all


# Extract the indices of successful adversarial samples: samples that the model
# classifies correctly when clean and misclassifies when perturbed.
def get_success_adv_index(test_loader, advLoader, model, device):
    correct_index = []
    batch_size = 64
    tracker = 0
    for x_batch, y_batch in test_loader:
        x_batch = x_batch.to(device)
        output = model(x_batch)
        pred = torch.argmax(output, 1).detach().cpu()
        index_temp = np.where(pred.numpy() == y_batch.numpy())[0]
        for i in index_temp:
            correct_index.append(i + batch_size * tracker)
        tracker += 1

    adv_index = []
    tracker = 0
    for x_batch, y_batch in advLoader:
        x_batch = x_batch.to(device)
        output = model(x_batch)
        pred = torch.argmax(output, 1).detach().cpu()
        index_temp = np.where(pred.numpy() != y_batch.numpy())[0]
        for i in index_temp:
            adv_index.append(i + batch_size * tracker)
        tracker += 1

    detect_index = np.intersect1d(np.asarray(correct_index), np.asarray(adv_index))
    return detect_index


def l2_distance(matrix1, matrix2):
    dis = np.linalg.norm(matrix1 - matrix2, axis=1)
    return dis


def get_threshold(error_adv_all, drop_rate):  # drop_rate is the false positive rate (FPR)
    num = int(len(error_adv_all) * drop_rate)
    marks = np.sort(error_adv_all)
    thrs = marks[-num]
    return thrs, marks


# Get the CLS representations from the specified transformer layer.
def get_cls(test_loader, model, device, layer_index):
    features = {}

    def get_features(name):
        def hook(_, input, output):
            # Keep only the CLS token embedding, keyed by the device that produced it.
            features[output[0].get_device()] = output[0][:, 0].detach().cpu()
        return hook

    model.module.vit.encoder.layer[layer_index].register_forward_hook(get_features('cls'))

    cls_outputs = []
    for x_batch, y_batch in test_loader:
        x_batch, y_batch = x_batch.to(device), y_batch.to(device)
        preds = model(x_batch)
        for index in list(features.keys()):
            cls_outputs.append(features[index])
        del preds
    cls_outputs = np.vstack(cls_outputs)
    return cls_outputs


class CustomDataset(Dataset):
    def __init__(self, data, labels):
        self.data = data
        self.labels = labels

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        return self.data[idx], self.labels[idx]


# Drop samples with NaN values from the adversarial set (and the corresponding
# clean samples), then rebuild both loaders with a fixed batch size of 64.
def remove_nan_from_dataset(dataset_adv, dataset_clean):
    adv_data = []
    adv_labels = []
    clean_data = []
    clean_labels = []
    for (data, label), (data_c, label_c) in zip(dataset_adv, dataset_clean):
        for i in range(len(data)):
            if not torch.isnan(data[i]).any() and not torch.isnan(label[i]).any():
                adv_data.append(data[i])
                adv_labels.append(label[i])
                clean_data.append(data_c[i])
                clean_labels.append(label_c[i])
    advLoader = CustomDataset(torch.stack(adv_data), torch.stack(adv_labels))
    test_loader = CustomDataset(torch.stack(clean_data), torch.stack(clean_labels))
    test_loader = DataLoader(test_loader, batch_size=64, shuffle=False, pin_memory=True)
    advLoader = DataLoader(advLoader, batch_size=64, shuffle=False, pin_memory=True)
    return advLoader, test_loader
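

# ---------------------------------------------------------------------------
# Illustrative usage sketch (not part of the original code): one way the
# helpers above could be chained, assuming a ViT classifier `model` on
# `device` and paired clean/adversarial loaders `test_loader` / `advLoader`
# prepared elsewhere. All variable names and the layer index below are
# hypothetical.
#
#   advLoader, test_loader = remove_nan_from_dataset(advLoader, test_loader)
#   success_idx = get_success_adv_index(test_loader, advLoader, model, device)
#
#   # Attention-rollout features from, e.g., layer 6 for both loaders.
#   attn_clean = get_attn(test_loader, model, device, layer_index=6)
#   attn_adv = get_attn(advLoader, model, device, layer_index=6)
#
#   # Per-sample L2 distance between clean and adversarial attention maps,
#   # restricted to the successfully attacked samples.
#   dist = l2_distance(attn_clean[success_idx], attn_adv[success_idx])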