# -*- coding: utf-8 -*- import os import argparse import torch import torch.nn as nn from torch.autograd import Variable from torchvision import datasets from torchvision import transforms from tqdm import tqdm from models import MnistCNN, CifarCNN, Generator from utils import fgsm, accuracy from torchvision.utils import save_image def load_dataset(args): if args.data == "cifar": test_loader = torch.utils.data.DataLoader( datasets.CIFAR10(os.path.expanduser("~/.torch/data/cifar10"), train=False, download=False, transform=transforms.Compose([ transforms.ToTensor()])), batch_size=128, shuffle=False) return test_loader def load_cnn(args): if args.data == "mnist": return MnistCNN elif args.data == "cifar": return CifarCNN def defend_image_with_gan(starting_img): gan_path = "./APE_GAN/checkpoint/cifar/10.tar" #torch.load(args.gan_path) gan_point = torch.load(gan_path) G = Generator(3) G.load_state_dict(gan_point["generator"]) loss_cre = nn.CrossEntropyLoss() # model.eval(), G.eval() G.eval() x_ape = G(starting_img) # save_image(x_ape, str(i)+"_image_ape.png") return x_ape def main(args): eps = args.eps test_loader = load_dataset(args) model_point = torch.load("cnn.tar") gan_point = torch.load(args.gan_path) CNN = load_cnn(args) model = CNN() model.load_state_dict(model_point["state_dict"]) in_ch = 1 if args.data == "mnist" else 3 G = Generator(in_ch) G.load_state_dict(gan_point["generator"]) loss_cre = nn.CrossEntropyLoss() # model.eval(), G.eval() G.eval() normal_acc, adv_acc, ape_acc, n = 0, 0, 0, 0 i = 0 for x, t in tqdm(test_loader, total=len(test_loader), leave=False): x, t = Variable(x), Variable(t) #y = model(x) #normal_acc += accuracy(y, t) x_adv = fgsm(model, x, t, loss_cre, eps) # y_adv = model(x_adv) # adv_acc += accuracy(y_adv, t) x_ape = G(x_adv) print(f"{type(x_ape)}") save_image(x_ape, str(i)+"_image_ape.png") i = i +1 """ y_ape = model(x_ape) ape_acc += accuracy(y_ape, t) n += t.size(0) print("Accuracy: normal {:.6f}, fgsm {:.6f}, ape {:.6f}".format( normal_acc / n * 100, adv_acc / n * 100, ape_acc / n * 100)) """ if __name__ == "__main__": parser = argparse.ArgumentParser() parser.add_argument("--data", type=str, default="mnist") parser.add_argument("--eps", type=float, default=0.15) parser.add_argument("--gan_path", type=str, default="./checkpoint/test/3.tar") args = parser.parse_args() main(args)