import torch
import torch.nn as nn
import numpy as np
import matplotlib.pyplot as plt
from sklearn.metrics import roc_auc_score, roc_curve
import torch.nn.functional as F
import sys
import argparse
from model_dino import dino_small
sys.path.append('../')
import ViTMAEConfigs_pretrain as configs
# from ViTMAEModels_pretrain import ViTMAEForPreTraining_custom
from ViTMAEModels_salient import ViTMAEForPreTraining_salient
from utils import get_attn, get_success_adv_index, l2_distance, get_cls, remove_nan_from_dataset
sys.path.append('../../')
from load_data import load_tiny, GetCIFAR100Validation, GetCIFAR10Validation
import DataManagerPytorch as DMP
sys.path.append('../../target_models/')
from TransformerModels_pretrain import ViTModel_custom, ViTForImageClassification
# Run on the GPU when one is visible; models and batches below are moved to
# this device.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
if torch.cuda.is_available():
    print("Using device: ", device, f"({torch.cuda.get_device_name(device)})")
else:
    print("Using device: ", device, "")

# Fix the NumPy, PyTorch CPU and CUDA RNG seeds (in that order) for reproducibility.
seed = 0
for _seed_rng in (np.random.seed, torch.manual_seed, torch.cuda.manual_seed):
    _seed_rng(seed)
del _seed_rng
def normalize(t, mean, std):
    """Normalize channels 0, 1 and 2 of ``t`` in place and return ``t``.

    ``t`` is indexed as ``t[:, c, :, :]`` (presumably an NCHW image batch).
    For each of the first three channels ``c`` it is replaced by
    ``(t[:, c] - mean[c]) / std[c]``; any further channels are left untouched.
    Callers that must keep their input should pass a clone.
    """
    for channel in range(3):
        shifted = t[:, channel, :, :] - mean[channel]
        t[:, channel, :, :] = shifted / std[channel]
    return t
def get_aug(argv=None):
    """Parse the command-line options of this detection script.

    Args:
        argv: Optional list of argument strings. ``None`` (the default) parses
            ``sys.argv[1:]``, exactly like the original zero-argument call.

    Returns:
        argparse.Namespace with ``dataset``, ``attack``, ``detector``,
        ``ratio`` (float) and ``masking``.
    """
    parser = argparse.ArgumentParser()
    # The dataset-loading code below only knows these three datasets.
    parser.add_argument('--dataset', default='TinyImagenet', type=str,
                        choices=['CIFAR10', 'CIFAR100', 'TinyImagenet'])
    parser.add_argument('--attack', default='PGD', type=str)
    # This script implements the 'Attention' and 'CLS' detectors.
    parser.add_argument('--detector', default='Attention', type=str)
    # Masking ratio forwarded to the MAE config. Without type=float a value
    # given on the command line would stay a string.
    parser.add_argument('--ratio', default=0.5, type=float)
    # 'salient' / 'non-salient' use DINO attention masks; 'random' uses
    # the MAE's random masking.
    parser.add_argument('--masking', default='salient',
                        choices=['salient', 'non-salient', 'random'])
    return parser.parse_args(argv)
args = get_aug()
print("Loading the dataset.")
# Per dataset, set:
# - test_loader: the clean test loader;
# - num_labels: the class count passed to the target classifier;
# - layer_index: the layer index passed to get_attn / get_cls further down.
if args.dataset == 'CIFAR10':
    test_loader = GetCIFAR10Validation(imgSize=224, ratio=0.2)
    num_labels = 10
    layer_index = -1
elif args.dataset == 'CIFAR100':
    test_loader = GetCIFAR100Validation(imgSize=224, ratio=0.2)
    num_labels = 100
    layer_index = -1
elif args.dataset == 'TinyImagenet':
    test_loader = load_tiny()
    num_labels = 200
    layer_index = 1
else:
    # Fail here instead of with a NameError much later on test_loader.
    raise ValueError(
        f"Unsupported --dataset {args.dataset!r}; "
        "expected CIFAR10, CIFAR100 or TinyImagenet"
    )
print("Loading the target model.")
# Build the target ViT classifier from the b16 config and load its weights
# for the selected dataset.
model_arch = 'ViT-16'
config = configs.get_b16_config()
model = ViTModel_custom(config=config)
model = ViTForImageClassification(config, model, num_labels)
filename = "../../target_models/results/{}/{}/weights.pth".format(model_arch, args.dataset)
# The model is still on the CPU here and load_state_dict copies the tensors
# anyway. Loading onto the CPU avoids a transient GPU copy and works without CUDA.
state_dict = torch.load(filename, map_location="cpu")
load_info = model.load_state_dict(state_dict, strict=False)
# strict=False silently skips keys that do not match (e.g. a "module." prefix
# from a checkpoint saved through DataParallel), so report them.
if load_info.missing_keys or load_info.unexpected_keys:
    print(
        f"Warning: {filename}: {len(load_info.missing_keys)} missing and "
        f"{len(load_info.unexpected_keys)} unexpected keys, e.g. "
        f"{(list(load_info.missing_keys) + list(load_info.unexpected_keys))[:5]}"
    )
model = nn.DataParallel(model).to(device)
model.eval()
print("Loading the adversarial examples.")
import inspect  # local import: only used for the torch.load compatibility check below

adv_filepath = "../../target_models/results/{}/{}/adv_results/".format(model_arch, args.dataset)
# The file stores a whole pickled loader object, not just tensors.
# - PyTorch >= 2.6 defaults torch.load to weights_only=True, which refuses such
#   objects, so opt out explicitly wherever the argument exists (PyTorch >= 1.13).
# - SECURITY: weights_only=False executes the pickle. Only load files you produced.
adv_load_kwargs = {}
if 'weights_only' in inspect.signature(torch.load).parameters:
    adv_load_kwargs['weights_only'] = False
advLoader = torch.load(adv_filepath + args.attack + '_advLoader.pth', **adv_load_kwargs)
# NOTE(review): presumably restores an attribute that newer DataLoader versions
# read but a loader pickled by an older PyTorch lacks.
advLoader.pin_memory_device = 'cuda'
# Both implemented detectors use l2_distance to compare a representation of
# the target model on each input with the same representation on its MAE
# reconstruction. Validate --detector before any heavy work is done.
if args.detector == 'Attention':
    extract_features = get_attn
elif args.detector == 'CLS':
    extract_features = get_cls
else:
    raise ValueError(
        f"Unsupported --detector {args.detector!r}; "
        "this script implements 'Attention' and 'CLS'"
    )

# For AttentionFool, remove the NaN samples BEFORE computing the masks.
# The masks below are built in loader order and are presumably consumed by
# position in DMP.get_reconstructed_dataset_salient. Computing them before the
# loaders were filtered (the previous order) could misalign the two.
# NOTE(review): assumes remove_nan_from_dataset drops the NaN pairs from both
# loaders; confirm in utils.
if args.attack == 'AttentionFool':
    advLoader, test_loader = remove_nan_from_dataset(advLoader, test_loader)

print("Generating masks.")
# Small DINO ViT (patch size 16) initialised from the DeiT-small checkpoint.
# Only its self-attention maps are used, to rank image patches.
dino_model = dino_small(patch_size=16, pretrained=True)
ckpt = torch.hub.load_state_dict_from_url(
    url="https://dl.fbaipublicfiles.com/deit/deit_small_patch16_224-cd65a155.pth",
    map_location="cpu",
)
dino_model.load_state_dict(ckpt["model"])
dino_model.to(device)
dino_model.eval()
mean = (0.485, 0.456, 0.406)  # ImageNet channel mean used to normalise DINO input
std = (0.229, 0.224, 0.225)   # ImageNet channel std
head_number = 1    # attention head whose attention ranks the patches
drop_lambda = 0.5  # int(N * (1 - drop_lambda)) lowest-attention patches are flagged


def _attention_masks(loader, invert):
    """Return a 0/1 patch mask for every image of ``loader``, stacked in loader order.

    Steps for each image:
    - Take the attention of token 0 of head ``head_number`` over the patch
      tokens and sort it ascending.
    - Set the ``int(N * (1 - drop_lambda))`` lowest-attention patches to 1 and
      the rest to 0, where N = int(sqrt(P)) ** 2 for P patch tokens.
    - ``invert=True`` flips the mask (0 <-> 1).

    NOTE(review): matching masks to samples by position assumes the loader
    iterates in a fixed order (no shuffling); confirm.
    """
    masks = []
    with torch.no_grad():
        for img, _ in loader:
            img = img.to(device)
            attentions = dino_model.forward_selfattention(
                normalize(img.clone(), mean=mean, std=std)
            ).detach().cpu()
            # NOTE(review): assumes (batch, heads, tokens, tokens) with token 0 = CLS.
            attentions = attentions[:, head_number, 0, 1:]
            side = int(np.sqrt(attentions.shape[-1]))
            num_low = int(side * side * (1 - drop_lambda))
            _, order = torch.sort(attentions)  # ascending along the patch axis
            mask = torch.zeros_like(order)
            rows = torch.arange(len(img)).unsqueeze(1)
            mask[rows, order[:, :num_low]] = 1
            if invert:
                mask = 1 - mask
            masks.append(mask.numpy())
    return np.concatenate(masks)


if args.masking in ['salient', 'non-salient']:
    invert_masks = args.masking == 'non-salient'
    salient_index_adv = _attention_masks(advLoader, invert_masks)
    salient_index_test = _attention_masks(test_loader, invert_masks)

print("Loading the MAE model.")
# float(): guards against a --ratio that arrived as a string from argparse.
config = configs.ViTMAEConfig(ratio=float(args.ratio))
model_mae = ViTMAEForPreTraining_salient(config=config)
weights_path = '../results/{}/'.format(args.dataset)
mae_load_info = model_mae.load_state_dict(
    torch.load(weights_path + 'weights.pth', map_location="cpu"), strict=False
)
# strict=False silently skips mismatched keys, so report them.
if mae_load_info.missing_keys or mae_load_info.unexpected_keys:
    print(
        f"Warning: {weights_path}weights.pth: {len(mae_load_info.missing_keys)} missing and "
        f"{len(mae_load_info.unexpected_keys)} unexpected keys, e.g. "
        f"{(list(mae_load_info.missing_keys) + list(mae_load_info.unexpected_keys))[:5]}"
    )
# NOTE(review): the MAE is never switched to eval(); confirm whether the DMP
# reconstruction helpers do that themselves.
model_mae = nn.DataParallel(model_mae).to(device)

print('Extracting successful adv examples.')
detect_index = get_success_adv_index(test_loader, advLoader, model, device)

feat_test = extract_features(test_loader, model, device, layer_index)
feat_adv = extract_features(advLoader, model, device, layer_index)
sim_test_noise_all, sim_adv_noise_all = [], []
# Random masking is averaged over two seeds; salient masking is run once.
rounds = 2 if args.masking == 'random' else 1
for random_seed in range(rounds):
    if args.masking == 'random':
        reLoader_adv = DMP.get_reconstructed_dataset(model_mae, advLoader, device, random_seed)
        reLoader_test = DMP.get_reconstructed_dataset(model_mae, test_loader, device, random_seed)
    else:
        reLoader_adv = DMP.get_reconstructed_dataset_salient(
            model_mae, advLoader, device, salient_index=salient_index_adv)
        reLoader_test = DMP.get_reconstructed_dataset_salient(
            model_mae, test_loader, device, salient_index=salient_index_test)
    feat_adv_noise = extract_features(reLoader_adv, model, device, layer_index)
    feat_test_noise = extract_features(reLoader_test, model, device, layer_index)
    sim_test_noise_all.append(l2_distance(feat_test, feat_test_noise))
    sim_adv_noise_all.append(l2_distance(feat_adv, feat_adv_noise))

sim_test_noise_all = np.asarray(sim_test_noise_all)
sim_adv_noise_all = np.asarray(sim_adv_noise_all)
sim_test = np.mean(sim_test_noise_all, axis=0)  # average over rounds
sim_adv = np.mean(sim_adv_noise_all, axis=0)
# Keep only the pairs selected by get_success_adv_index (presumably the
# successful attacks). Clean inputs are labelled 0, adversarial inputs 1.
sim_test_correct, sim_adv_correct = sim_test[detect_index], sim_adv[detect_index]
sim_all_correct = np.concatenate((sim_test_correct, sim_adv_correct), axis=0)
true_label_correct = np.asarray([0] * len(sim_test_correct) + [1] * len(sim_adv_correct))
auc1 = roc_auc_score(true_label_correct, sim_all_correct)
print('AUC score is', auc1)