import torch
import torch.nn as nn
import torchvision.models as models
import sys
import os
import argparse
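# Project-local modules (surrogate source models, white-box attack
# implementations, ViT configs/models, data loaders, evaluation helpers);
# the parent directories are appended to sys.path below so the imports resolve.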
from source_model.deit_ensemble import ModifiedDeiT
from source_model.deit_tr import VisionTransformer_hierarchical
sys.path.append('../')
from WhiteBox import PGDAttack, CWAttack_L2, PatchFool, FGSMAttack, APGDAttack, AttentionFool, PGDAttack_deit
from utils import register_hook_for_resnet
import TransformerConfigs_pretrain as configs
from TransformerModels_pretrain import ViTModel_custom, ViTForImageClassification
sys.path.append('../../')
from load_data import load_tiny, GetCIFAR100Validation, GetCIFAR10Validation
from Evaluations import test_vit
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Using device: ", device, f"({torch.cuda.get_device_name(device)})" if torch.cuda.is_available() else "")
def get_args():
    parser = argparse.ArgumentParser()
    parser.add_argument('--dataset', default='TinyImagenet', type=str,
                        choices=['CIFAR10', 'CIFAR100', 'TinyImagenet'])
    parser.add_argument('--attack', default='PGD', type=str)
    return parser.parse_args()

args = get_args()
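# Build the held-out evaluation loader for the chosen dataset; CIFAR images are
# upsampled to 224x224 to match the ViT input resolution.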
if args.dataset == 'CIFAR10':
test_loader = GetCIFAR10Validation(imgSize=224, ratio=0.2)
num_labels = 10
elif args.dataset == 'CIFAR100':
test_loader = GetCIFAR100Validation(imgSize=224, ratio=0.2)
num_labels = 100
elif args.dataset == 'TinyImagenet':
test_loader = load_tiny()
num_labels = 200
# Load the classification model (ViT-B/16 target) and its fine-tuned weights
model_arch = 'ViT-16'
config = configs.get_b16_config()
num_patch = 4  # number of patches perturbed by the Patch-Fool / Attention-Fool attacks
patch_size = 16
model = ViTModel_custom(config=config)
model = ViTForImageClassification(config, model, num_labels)
filename = "../results/{}/{}/weights.pth".format(model_arch, args.dataset)
model.load_state_dict(torch.load(filename, map_location=device), strict=False)
model = nn.DataParallel(model).cuda()
model.eval()
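# Directory where the generated adversarial loaders are cached (torch.save below).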
adv_filepath = "../results/{}/{}/adv_results/".format(model_arch, args.dataset)
os.makedirs(adv_filepath, exist_ok=True)
if args.dataset in ['CIFAR10', 'CIFAR100']:
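    # CIFAR attacks use an L-inf budget of eps = 0.03 (roughly 8/255) for the
    # gradient attacks; CW instead minimizes an L2-penalized margin objective.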
if args.attack == 'PGD':
        advLoader, _, _ = PGDAttack(device=device, dataLoader=test_loader, model=model,
                                    eps=0.03, num_steps=10, step_size=0.003)
elif args.attack == 'CW':
        advLoader = CWAttack_L2(device=device, dataLoader=test_loader, model=model,
                                c=1, kappa=50, max_iter=30, learning_rate=0.01)
elif args.attack == 'SGM':
# load source model
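        # SGM (Skip Gradient Method): backward hooks scale gradients flowing through
        # the ResNet's skip connections by gamma, which is known to improve the
        # transferability of examples crafted on the surrogate.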
resnet_model = models.resnet18(weights='DEFAULT')
resnet_model.fc = nn.Linear(512, num_labels)
sur_filename = {'CIFAR10':'../results/BlackBox/cifar10_resnet18_weights.pth',
'CIFAR100':'../results/BlackBox/cifar100_resnet18_weights.pth'}
        resnet_model.load_state_dict(torch.load(sur_filename[args.dataset], map_location=device), strict=False)
resnet_model = nn.DataParallel(resnet_model).cuda()
resnet_model.eval()
arch = 'resnet18'
gamma = 0.5
register_hook_for_resnet(resnet_model, arch=arch, gamma=gamma)
        advLoader, _, _ = PGDAttack(device=device, dataLoader=test_loader, model=resnet_model,
                                    eps=0.03, num_steps=10, step_size=0.003)
elif args.attack == 'PatchFool':
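        # Patch-Fool perturbs only num_patch of the input patches rather than the
        # whole image; n_tokens counts the patch tokens plus the class token.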
        n_tokens = (224 // patch_size) ** 2 + 1  # patch tokens plus the class token
        advLoader = PatchFool(dataLoader=test_loader, model=model, patch_size=patch_size,
                              num_patch=num_patch, n_tokens=n_tokens)
    else:
        raise ValueError(f"Unsupported attack for {args.dataset}: {args.attack}")
elif args.dataset == 'TinyImagenet':
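    # TinyImagenet uses a larger L-inf budget (eps = 0.06, roughly 16/255) and adds
    # FGSM, APGD, Attention-Fool, and the DeiT-based transfer attacks SE and TR.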
if args.attack == 'FGSM':
advLoader = FGSMAttack(device=device, dataLoader=test_loader, model=model, epsilonMax=0.06)
    elif args.attack == 'APGD':
advLoader = APGDAttack(device, test_loader, model, norm='Linf', eps=0.06)
    elif args.attack == 'PGD':
        advLoader, _, _ = PGDAttack(device=device, dataLoader=test_loader, model=model,
                                    eps=0.06, num_steps=10, step_size=0.006)
    elif args.attack == 'CW':
        advLoader = CWAttack_L2(device=device, dataLoader=test_loader, model=model,
                                c=1, kappa=50, max_iter=30, learning_rate=0.02)
    elif args.attack == 'SGM':
# load source model
resnet_model = models.resnet18(weights='DEFAULT')
resnet_model.fc = nn.Linear(512, num_labels)
        resnet_model.load_state_dict(torch.load('../results/BlackBox/tiny_resnet18_weights.pth', map_location=device), strict=False)
resnet_model = nn.DataParallel(resnet_model).cuda()
resnet_model.eval()
arch = 'resnet18'
gamma = 0.5
register_hook_for_resnet(resnet_model, arch=arch, gamma=gamma)
        advLoader, _, _ = PGDAttack(device=device, dataLoader=test_loader, model=resnet_model,
                                    eps=0.06, num_steps=10, step_size=0.006)
    elif args.attack == 'PatchFool':
        n_tokens = (224 // patch_size) ** 2 + 1  # patch tokens plus the class token
advLoader = PatchFool(dataLoader=test_loader, model=model, patch_size=patch_size, num_patch=num_patch, n_tokens=n_tokens)
    elif args.attack == 'AttentionFool':
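        # Attention-Fool attacks the pre-softmax attention scores, so the target
        # model is rebuilt from the pre-softmax variant of the ViT definition.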
from TransformerModels_pretrain_presoftmax import ViTModel_custom
model = ViTModel_custom(config=config)
model = ViTForImageClassification(config, model, num_labels)
filename = "../results/{}/{}/weights.pth".format(model_arch, args.dataset)
        model.load_state_dict(torch.load(filename, map_location=device), strict=False)
model = nn.DataParallel(model).cuda()
model.eval()
        n_tokens = (224 // patch_size) ** 2 + 1  # patch tokens plus the class token
advLoader = AttentionFool(dataLoader=test_loader, model=model, patch_size=patch_size, num_patch=num_patch, n_tokens=n_tokens)
    elif args.attack == 'SE':
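        # SE transfer attack: examples are crafted on a modified DeiT source model
        # (an ensemble-style variant, per source_model.deit_ensemble). ModifiedDeiT
        # appears to load its own pretrained weights; no checkpoint is loaded here.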
src_model = ModifiedDeiT(num_classes=200)
src_model = src_model.to(device)
src_model.eval()
advLoader, _, _ = PGDAttack_deit(device=device, dataLoader=test_loader, model=src_model, eps=0.06, num_steps=50, step_size=0.0012)
    elif args.attack == 'TR':
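        # TR transfer attack: a hierarchical DeiT source model; the trm_* checkpoint
        # name suggests token-refinement fine-tuning.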
        src_model = VisionTransformer_hierarchical(num_classes=200)
        src_model = src_model.to(device)
        src_model.load_state_dict(torch.load('../results/BlackBox/trm_weights_tiny.pth', map_location=device), strict=False)
        src_model.eval()
        advLoader, _, _ = PGDAttack_deit(device=device, dataLoader=test_loader, model=src_model,
                                         eps=0.06, num_steps=50, step_size=0.0012)
    else:
        raise ValueError(f"Unsupported attack for {args.dataset}: {args.attack}")
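# Cache the adversarial loader so the attack does not need to be re-run later.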
torch.save(advLoader, adv_filepath + args.attack + '_advLoader.pth')
# Classification accuracy on the adversarial examples
_, adv_acc = test_vit(model=model, test_loader=advLoader, device=device)
print("{} attack on {}: adversarial accuracy = {}".format(args.attack, args.dataset, adv_acc))