ViTGuard / target_models / run / attack.py
attack.py
Raw
import torch
import torch.nn as nn
import torchvision.models as models
import sys
import os
import argparse
from source_model.deit_ensemble import ModifiedDeiT
from source_model.deit_tr import VisionTransformer_hierarchical

sys.path.append('../')
from WhiteBox import PGDAttack, CWAttack_L2, PatchFool, FGSMAttack, APGDAttack, AttentionFool, PGDAttack_deit
from utils import register_hook_for_resnet
import TransformerConfigs_pretrain as configs
from TransformerModels_pretrain import ViTModel_custom, ViTForImageClassification

sys.path.append('../../')
from load_data import load_tiny, GetCIFAR100Validation, GetCIFAR10Validation
from Evaluations import test_vit



# Select the compute device once: CUDA when a GPU is visible, otherwise CPU.
use_cuda = torch.cuda.is_available()
device = torch.device("cuda" if use_cuda else "cpu")
# The GPU name is only queried when CUDA is available; on CPU the label is empty.
device_label = f"({torch.cuda.get_device_name(device)})" if use_cuda else ""
print("Using device: ", device, device_label)

def get_aug() -> argparse.Namespace:
    """Parse the command-line options of the attack script.

    Options:
        --dataset: dataset the target model was trained on
            (default ``'TinyImagenet'``).
        --attack: name of the adversarial attack to run (default ``'PGD'``).

    The accepted attack names are the union of the attacks dispatched further
    down in this script; not every attack is implemented for every dataset.

    Returns:
        The parsed namespace with ``dataset`` and ``attack`` attributes.

    Raises:
        SystemExit: raised by argparse when an option value is not one of the
            listed choices (previously such a value only failed later with a
            ``NameError``).
    """
    datasets = ('CIFAR10', 'CIFAR100', 'TinyImagenet')
    attacks = ('PGD', 'CW', 'SGM', 'PatchFool', 'FGSM', 'APGD',
               'AttentionFool', 'SE', 'TR')

    parser = argparse.ArgumentParser()
    parser.add_argument('--dataset', default='TinyImagenet', type=str,
                        choices=datasets,
                        help='dataset of the target model')
    parser.add_argument('--attack', default='PGD', type=str,
                        choices=attacks,
                        help='adversarial attack to generate')
    return parser.parse_args()

# Parse the command line once; `args` is read by the rest of the script.
args = get_aug()

# Pick the evaluation loader and the number of classes for the requested
# dataset. The CIFAR loaders are requested at imgSize=224 with ratio=0.2
# (presumably the fraction of the validation set to keep -- confirm in load_data).
if args.dataset == 'CIFAR10':
    test_loader = GetCIFAR10Validation(imgSize=224, ratio=0.2)
    num_labels = 10
elif args.dataset == 'CIFAR100':
    test_loader = GetCIFAR100Validation(imgSize=224, ratio=0.2)
    num_labels = 100
elif args.dataset == 'TinyImagenet':
    test_loader = load_tiny()
    num_labels = 200
else:
    # Fail here with a clear message instead of a NameError on
    # `num_labels` / `test_loader` further down.
    raise ValueError(
        "Unsupported dataset '{}'; expected one of CIFAR10, CIFAR100, "
        "TinyImagenet".format(args.dataset)
    )

# Load the classification (target) model.
model_arch = 'ViT-16'
config = configs.get_b16_config()
num_patch = 4  # Number of altered patches in the PatchFool / AttentionFool attacks
patch_size = 16

model = ViTModel_custom(config=config)
model = ViTForImageClassification(config, model, num_labels)
filename = "../results/{}/{}/weights.pth".format(model_arch, args.dataset)
# Deserialize onto the CPU so a checkpoint saved from a GPU also loads on a
# CPU-only host; the model is moved to `device` right after.
# strict=False: missing / unexpected keys in the checkpoint are tolerated.
model.load_state_dict(torch.load(filename, map_location="cpu"), strict=False)
# Use the device selected at the top of the script rather than a hard-coded
# .cuda(), which raises when no GPU is available.
model = nn.DataParallel(model).to(device)
model.eval()

# Directory that receives the generated adversarial loader.
adv_filepath = "../results/{}/{}/adv_results/".format(model_arch, args.dataset)
# makedirs(exist_ok=True) replaces the check-then-mkdir pair: it has no race
# between the check and the creation, and it also creates missing parents.
os.makedirs(adv_filepath, exist_ok=True)
def _load_sgm_surrogate(weights_path):
    """Build the ResNet-18 surrogate used as the source model for the SGM attack.

    The torchvision ResNet-18 gets a new `num_labels`-way classification head,
    the checkpoint at `weights_path` is loaded non-strictly, the model is
    wrapped in DataParallel, moved to `device`, put in eval mode and passed to
    `register_hook_for_resnet` with arch='resnet18' and gamma=0.5.

    Args:
        weights_path: path of the surrogate checkpoint.

    Returns:
        The prepared surrogate model.
    """
    surrogate = models.resnet18(weights='DEFAULT')
    surrogate.fc = nn.Linear(512, num_labels)
    # map_location="cpu" lets GPU-saved checkpoints load on CPU-only hosts.
    surrogate.load_state_dict(torch.load(weights_path, map_location="cpu"), strict=False)
    surrogate = nn.DataParallel(surrogate).to(device)
    surrogate.eval()
    register_hook_for_resnet(surrogate, arch='resnet18', gamma=0.5)
    return surrogate


# Token count handed to PatchFool / AttentionFool: number of patch_size x
# patch_size patches in a 224x224 image, plus one (presumably the class
# token -- confirm against the attack implementations).
n_tokens = (224 // patch_size) ** 2 + 1

# Generate the adversarial loader `advLoader` for the requested attack.
# The perturbation budgets differ per dataset family (eps 0.03 for CIFAR,
# 0.06 for TinyImagenet). Every chain ends in an explicit error so that an
# unsupported combination does not surface later as a NameError on `advLoader`.
if args.dataset in ['CIFAR10', 'CIFAR100']:
    if args.attack == 'PGD':
        advLoader, _, _ = PGDAttack(device=device, dataLoader=test_loader, model=model,
                                    eps=0.03, num_steps=10, step_size=0.003)

    elif args.attack == 'CW':
        advLoader = CWAttack_L2(device=device, dataLoader=test_loader, model=model,
                                c=1, kappa=50, max_iter=30, learning_rate=0.01)

    elif args.attack == 'SGM':
        # Adversarial examples are crafted on the surrogate ResNet, not on `model`.
        sur_filename = {'CIFAR10': '../results/BlackBox/cifar10_resnet18_weights.pth',
                        'CIFAR100': '../results/BlackBox/cifar100_resnet18_weights.pth'}
        resnet_model = _load_sgm_surrogate(sur_filename[args.dataset])
        advLoader, _, _ = PGDAttack(device=device, dataLoader=test_loader, model=resnet_model,
                                    eps=0.03, num_steps=10, step_size=0.003)

    elif args.attack == 'PatchFool':
        advLoader = PatchFool(dataLoader=test_loader, model=model, patch_size=patch_size,
                              num_patch=num_patch, n_tokens=n_tokens)

    else:
        raise ValueError(
            "Attack '{}' is not supported for dataset '{}'".format(args.attack, args.dataset)
        )

elif args.dataset == 'TinyImagenet':
    if args.attack == 'FGSM':
        advLoader = FGSMAttack(device=device, dataLoader=test_loader, model=model, epsilonMax=0.06)

    elif args.attack == 'APGD':
        advLoader = APGDAttack(device, test_loader, model, norm='Linf', eps=0.06)

    elif args.attack == 'PGD':
        advLoader, _, _ = PGDAttack(device=device, dataLoader=test_loader, model=model,
                                    eps=0.06, num_steps=10, step_size=0.006)

    elif args.attack == 'CW':
        advLoader = CWAttack_L2(device=device, dataLoader=test_loader, model=model,
                                c=1, kappa=50, max_iter=30, learning_rate=0.02)

    elif args.attack == 'SGM':
        # Adversarial examples are crafted on the surrogate ResNet, not on `model`.
        resnet_model = _load_sgm_surrogate('../results/BlackBox/tiny_resnet18_weights.pth')
        advLoader, _, _ = PGDAttack(device=device, dataLoader=test_loader, model=resnet_model,
                                    eps=0.06, num_steps=10, step_size=0.006)

    elif args.attack == 'PatchFool':
        advLoader = PatchFool(dataLoader=test_loader, model=model, patch_size=patch_size,
                              num_patch=num_patch, n_tokens=n_tokens)

    elif args.attack == 'AttentionFool':
        # This import rebinds the module-level ViTModel_custom name to the
        # variant from TransformerModels_pretrain_presoftmax, and `model` is
        # rebuilt from it with the same checkpoint. The rebuilt `model` is also
        # the one evaluated after this block.
        from TransformerModels_pretrain_presoftmax import ViTModel_custom
        model = ViTModel_custom(config=config)
        model = ViTForImageClassification(config, model, num_labels)
        filename = "../results/{}/{}/weights.pth".format(model_arch, args.dataset)
        model.load_state_dict(torch.load(filename, map_location="cpu"), strict=False)
        model = nn.DataParallel(model).to(device)
        model.eval()
        advLoader = AttentionFool(dataLoader=test_loader, model=model, patch_size=patch_size,
                                  num_patch=num_patch, n_tokens=n_tokens)

    elif args.attack == 'SE':
        # Source model for the attack; no checkpoint is loaded for it in this script.
        src_model = ModifiedDeiT(num_classes=200)
        src_model = src_model.to(device)
        src_model.eval()
        advLoader, _, _ = PGDAttack_deit(device=device, dataLoader=test_loader, model=src_model,
                                         eps=0.06, num_steps=50, step_size=0.0012)

    elif args.attack == 'TR':
        src_model = VisionTransformer_hierarchical(num_classes=200)
        src_model = src_model.to(device)
        src_model.eval()
        src_model.load_state_dict(
            torch.load('../results/BlackBox/trm_weights_tiny.pth', map_location="cpu"),
            strict=False,
        )
        advLoader, _, _ = PGDAttack_deit(device=device, dataLoader=test_loader, model=src_model,
                                         eps=0.06, num_steps=50, step_size=0.0012)

    else:
        raise ValueError(
            "Attack '{}' is not supported for dataset '{}'".format(args.attack, args.dataset)
        )

else:
    raise ValueError("Unsupported dataset '{}'".format(args.dataset))



# Persist the adversarial loader under the results directory, named after the attack.
adv_save_path = f"{adv_filepath}{args.attack}_advLoader.pth"
torch.save(advLoader, adv_save_path)

# Classification accuracy on the adversarial examples
# (the first value returned by test_vit is discarded).
_, adv_acc = test_vit(model=model, test_loader=advLoader, device=device)