# ViTGuard/detection/run/table3.py
import torch
import torch.nn as nn
import numpy as np
import matplotlib.pyplot as plt
from sklearn.metrics import roc_auc_score, roc_curve
import torch.nn.functional as F
import sys
import argparse
sys.path.append('../')
import ViTMAEConfigs_pretrain as configs
from ViTMAEModels_pretrain import ViTMAEForPreTraining_custom
from ViTMAEModels_salient import ViTMAEForPreTraining_salient
from utils import get_attn, get_success_adv_index, l2_distance, get_cls, remove_nan_from_dataset 
sys.path.append('../../')
from load_data import load_tiny, GetCIFAR100Validation, GetCIFAR10Validation
import DataManagerPytorch as DMP
sys.path.append('../../target_models/')
from TransformerModels_pretrain import ViTModel_custom, ViTForImageClassification


# Run on the GPU when CUDA is available, otherwise on the CPU.
_cuda_available = torch.cuda.is_available()
device = torch.device("cuda" if _cuda_available else "cpu")
_device_label = f"({torch.cuda.get_device_name(device)})" if _cuda_available else ""
print("Using device: ", device, _device_label)

# Seed NumPy and PyTorch (CPU and CUDA generators) with the same fixed value.
seed = 0
for _seed_rng in (np.random.seed, torch.manual_seed, torch.cuda.manual_seed):
    _seed_rng(seed)

def get_aug(argv=None):
    """Parse the command-line options of this detection script.

    Args:
        argv: Optional list of argument strings. ``None`` (the default) lets
            argparse read the process command line, as before.

    Returns:
        argparse.Namespace with ``dataset``, ``attack``, ``detector`` and
        ``ratio`` (the masking ratio, always a float).
    """
    parser = argparse.ArgumentParser()
    # Only these datasets are handled by the loading code in this script.
    parser.add_argument('--dataset', default='TinyImagenet', type=str,
                        choices=['CIFAR10', 'CIFAR100', 'TinyImagenet'])
    parser.add_argument('--attack', default='PGD', type=str)
    # Only these detectors are handled by the scoring code in this script.
    parser.add_argument('--detector', default='Attention', type=str,
                        choices=['Attention', 'CLS', 'RL', 'PD_T10', 'PD_T40'])
    # type=float: without it a ratio given on the command line would be kept
    # as a string, while the default would be a float.
    parser.add_argument('--ratio', default=0.5, type=float)  # masking ratio
    args = parser.parse_args(argv)
    return args

args = get_aug()
print("Loading the dataset.")
# Each supported dataset fixes three things used further down: the clean
# evaluation loader, the number of classes of the target classifier, and the
# layer_index handed to get_attn / get_cls.
if args.dataset == 'CIFAR10':
    test_loader = GetCIFAR10Validation(imgSize=224, ratio=0.2)
    num_labels = 10
    layer_index = -1
elif args.dataset == 'CIFAR100':
    test_loader = GetCIFAR100Validation(imgSize=224, ratio=0.2)
    num_labels = 100
    layer_index = -1
elif args.dataset == 'TinyImagenet':
    test_loader = load_tiny()
    num_labels = 200
    layer_index = 1
else:
    # Fail here with a clear message; otherwise an unsupported name would only
    # surface later as a NameError on num_labels / test_loader.
    raise ValueError(
        "Unsupported dataset {!r}; expected 'CIFAR10', 'CIFAR100' or 'TinyImagenet'.".format(args.dataset)
    )
print("Loading the target model.")
# Build the classifier that the adversarial examples were generated against
# and restore its weights from the results directory of the chosen dataset.
model_arch = 'ViT-16'
config = configs.get_b16_config()

backbone = ViTModel_custom(config=config)
model = ViTForImageClassification(config, backbone, num_labels)
filename = f"../../target_models/results/{model_arch}/{args.dataset}/weights.pth"
target_state = torch.load(filename)
# strict=False: keys that do not match between checkpoint and model are ignored.
model.load_state_dict(target_state, strict=False)
model = nn.DataParallel(model).cuda()
# Switch to evaluation mode.
model.eval()

print("Loading the adversarial examples.")
# The adversarial examples are stored as one serialized loader per attack.
# NOTE(review): this unpickles a whole object, not just tensors; recent
# PyTorch releases may need an explicit weights_only=False here -- confirm
# against the installed version.
adv_filepath = f"../../target_models/results/{model_arch}/{args.dataset}/adv_results/"
adv_loader_file = adv_filepath + args.attack + '_advLoader.pth'
advLoader = torch.load(adv_loader_file)
# Set explicitly on the restored loader (presumably because the saved object
# may not carry this attribute -- TODO confirm).
advLoader.pin_memory_device = 'cuda'

print("Loading the MAE model.")
# The masked autoencoder is configured with the requested masking ratio.
# PatchFool / AttentionFool use the "salient" variant, every other attack the
# "custom" one.
config = configs.ViTMAEConfig(ratio=args.ratio)
uses_salient_mae = args.attack in ['PatchFool', 'AttentionFool']
mae_class = ViTMAEForPreTraining_salient if uses_salient_mae else ViTMAEForPreTraining_custom
model_mae = mae_class(config=config)

weights_path = f'../results/{args.dataset}/'
mae_state = torch.load(weights_path + 'weights.pth')
# strict=False: keys that do not match between checkpoint and model are ignored.
model_mae.load_state_dict(mae_state, strict=False)
# NOTE(review): unlike the target model, model_mae is not put in eval mode
# here -- confirm whether the DMP reconstruction helpers do that themselves.
model_mae = nn.DataParallel(model_mae).cuda()

# Select the samples on which detection is evaluated.
print('Extracting successful adv examples.')
# For AttentionFool both loaders are first passed through
# remove_nan_from_dataset, which returns replacement loaders.
needs_nan_filtering = args.attack == 'AttentionFool'
if needs_nan_filtering:
    advLoader, test_loader = remove_nan_from_dataset(advLoader, test_loader)
# detect_index is used below to index the per-sample detector scores.
detect_index = get_success_adv_index(test_loader, advLoader, model, device)

print("Running the detector.")

# Attacks whose reconstructions come from DMP.recoverall with a random patch
# mask instead of the seeded DMP.get_reconstructed_dataset calls.
_MASKED_RECOVERY_ATTACKS = ('PatchFool', 'AttentionFool')


def get_error(inputs, outputs, p=1):
    """Return the mean of ``|inputs - outputs| ** p`` over dims 1-3.

    One value is produced per entry along dim 0 (assumes 4-D image batches --
    TODO confirm against the loaders).
    """
    diff = torch.abs(inputs - outputs)
    return torch.mean(torch.pow(diff, p), dim=(1, 2, 3))


def get_jsd(softmax_tensor1, softmax_tensor2):
    """Return the Jensen-Shannon divergence between two sets of distributions.

    Both arguments are softmax outputs at the call sites. The divergence is
    summed over dim 1, so one value is returned per row.
    """
    middle_distribution = 0.5 * (softmax_tensor1 + softmax_tensor2)
    # For probabilities the midpoint is zero only where both inputs are zero;
    # substitute 1 there so the ratios below are 0/1 instead of 0/0.
    safe_middle = torch.where(
        middle_distribution > 0,
        middle_distribution,
        torch.ones_like(middle_distribution),
    )

    # KL divergence of each input with respect to the middle distribution.
    # xlogy(p, r) is p * log(r) but evaluates to 0 where p == 0, so a
    # probability that is exactly zero contributes 0 rather than NaN.
    kl_divergence1 = torch.sum(torch.xlogy(softmax_tensor1, softmax_tensor1 / safe_middle), dim=1)
    kl_divergence2 = torch.sum(torch.xlogy(softmax_tensor2, softmax_tensor2 / safe_middle), dim=1)

    return 0.5 * (kl_divergence1 + kl_divergence2)


def _reconstruction_errors(loader, re_loader):
    """Concatenate get_error (p=1) over paired batches of a loader and its reconstruction."""
    return np.concatenate([
        get_error(inputs, inputs_re, p=1)
        for (inputs, _), (inputs_re, _) in zip(loader, re_loader)
    ])


def _prediction_divergences(loader, re_loader, temperature):
    """Concatenate the JSD between temperature-scaled predictions on paired batches.

    Each batch and its reconstruction are classified by the global ``model``;
    the logits are divided by ``temperature`` before the softmax.
    """
    divergences = []
    # Inference only: no_grad keeps autograd from recording each forward pass.
    with torch.no_grad():
        for (inputs, _), (inputs_re, _) in zip(loader, re_loader):
            inputs, inputs_re = inputs.to(device), inputs_re.to(device)
            input_probs = F.softmax(model(inputs).detach().cpu() / temperature, dim=1)
            output_probs = F.softmax(model(inputs_re).detach().cpu() / temperature, dim=1)
            divergences.append(get_jsd(input_probs, output_probs))
    return np.concatenate(divergences)


def _iter_reconstructed_loaders():
    """Yield ``(reconstructed_test, reconstructed_adv)`` loader pairs.

    For the attacks in ``_MASKED_RECOVERY_ATTACKS`` a single pair is produced
    by DMP.recoverall using one random mask per test sample with half of the
    patches set to 1. Otherwise one pair per random seed (0 and 1) is produced
    by DMP.get_reconstructed_dataset.
    """
    if args.attack in _MASKED_RECOVERY_ATTACKS:
        n_patches = (224 // 16) ** 2
        random_index = np.zeros((len(test_loader.dataset), n_patches))
        for patch_mask in random_index:
            # Each row is a view, so this marks the chosen patches in random_index.
            patch_mask[np.random.choice(n_patches, n_patches // 2, replace=False)] = 1
        re_test = DMP.recoverall(model_mae, test_loader, device, salient_index=random_index)
        re_adv = DMP.recoverall(model_mae, advLoader, device, salient_index=random_index)
        yield re_test, re_adv
    else:
        for random_seed in range(2):
            re_adv = DMP.get_reconstructed_dataset(model_mae, advLoader, device, random_seed)
            re_test = DMP.get_reconstructed_dataset(model_mae, test_loader, device, random_seed)
            yield re_test, re_adv


def _feature_distances(extract_features):
    """Return per-reconstruction l2_distance lists for clean and adversarial data.

    ``extract_features`` is get_attn or get_cls. It is applied to the original
    loaders once and to every reconstructed loader pair; the distance between
    original and reconstructed features is collected for each pair.

    Returns:
        ``(test_distances, adv_distances)``, two lists with one entry per
        reconstructed loader pair.
    """
    features_test = extract_features(test_loader, model, device, layer_index)
    features_adv = extract_features(advLoader, model, device, layer_index)
    test_distances, adv_distances = [], []
    for re_test, re_adv in _iter_reconstructed_loaders():
        features_adv_noise = extract_features(re_adv, model, device, layer_index)
        features_test_noise = extract_features(re_test, model, device, layer_index)
        test_distances.append(l2_distance(features_test, features_test_noise))
        adv_distances.append(l2_distance(features_adv, features_adv_noise))
    return test_distances, adv_distances


# RL and PD compare each input with a single reconstruction (random_seed=1).
if args.detector in ['RL', 'PD_T10', 'PD_T40']:
    reLoader_adv = DMP.get_reconstructed_dataset(model_mae, advLoader, device, random_seed=1)
    reLoader_test = DMP.get_reconstructed_dataset(model_mae, test_loader, device, random_seed=1)

if args.detector == 'RL':
    # Score: reconstruction error between each input and its reconstruction.
    error_clean_all = _reconstruction_errors(test_loader, reLoader_test)
    error_adv_all = _reconstruction_errors(advLoader, reLoader_adv)

elif args.detector in ['PD_T10', 'PD_T40']:
    # Score: divergence between the classifier's softened predictions on an
    # input and on its reconstruction; the detector name picks the temperature.
    T1 = 10 if args.detector == 'PD_T10' else 40
    div_clean_all = _prediction_divergences(test_loader, reLoader_test, T1)
    div_adv_all = _prediction_divergences(advLoader, reLoader_adv, T1)

elif args.detector == 'Attention':
    sim_test_noise_all, sim_adv_noise_all = _feature_distances(get_attn)

elif args.detector == 'CLS':
    sim_test_noise_all, sim_adv_noise_all = _feature_distances(get_cls)

else:
    # Fail here rather than with a NameError when the scores are evaluated.
    raise ValueError(
        "Unsupported detector {!r}; expected 'Attention', 'CLS', 'RL', 'PD_T10' or 'PD_T40'.".format(args.detector)
    )

# Pick the per-sample scores of the chosen detector, restricted to the samples
# selected by detect_index.
if args.detector in ['Attention', 'CLS']:
    # Average the distances over the reconstructions (axis 0 of the stack).
    sim_test = np.mean(np.asarray(sim_test_noise_all), axis=0)
    sim_adv = np.mean(np.asarray(sim_adv_noise_all), axis=0)
    scores_clean, scores_adv = sim_test[detect_index], sim_adv[detect_index]

elif args.detector == 'RL':
    scores_clean, scores_adv = error_clean_all[detect_index], error_adv_all[detect_index]

elif args.detector in ['PD_T10', 'PD_T40']:
    scores_clean, scores_adv = div_clean_all[detect_index], div_adv_all[detect_index]

# Clean samples are labelled 0 and adversarial samples 1; the scores are laid
# out in the same order so that labels and scores line up.
sim_all_correct = np.concatenate((scores_clean, scores_adv), axis=0)
true_label_correct = np.asarray([0] * len(scores_clean) + [1] * len(scores_adv))

auc1 = roc_auc_score(true_label_correct, sim_all_correct)
print('AUC score is', auc1)