# Source: ImageKernelsForPatchAttackDefence / APE-GAN / test_model.py
# -*- coding: utf-8 -*-

import os
import argparse

import torch
import torch.nn as nn
from torch.autograd import Variable

from torchvision import datasets
from torchvision import transforms

from tqdm import tqdm

from models import MnistCNN, CifarCNN, Generator
from utils import fgsm, accuracy

from torchvision.utils import save_image


def load_dataset(args, batch_size=128):
    """Build an unshuffled DataLoader over the test split named by ``args.data``.

    Images are converted with ``transforms.ToTensor()`` only. Data is read from
    ``~/.torch/data/<name>`` with ``download=False``, so it must already be on disk.

    Args:
        args: namespace whose ``data`` attribute is ``"mnist"`` or ``"cifar"``.
        batch_size: number of samples per batch (default 128).

    Returns:
        A ``torch.utils.data.DataLoader`` over the test split.

    Raises:
        ValueError: if ``args.data`` is not a supported dataset name.
    """
    # Resolve the dataset class lazily so an unsupported name fails fast with a
    # clear error instead of an UnboundLocalError further down.
    if args.data == "mnist":
        dataset_cls, subdir = datasets.MNIST, "mnist"
    elif args.data == "cifar":
        dataset_cls, subdir = datasets.CIFAR10, "cifar10"
    else:
        raise ValueError(
            f"Unsupported dataset {args.data!r}; expected 'mnist' or 'cifar'")

    test_set = dataset_cls(
        os.path.expanduser(os.path.join("~/.torch/data", subdir)),
        train=False,
        download=False,
        transform=transforms.Compose([transforms.ToTensor()]),
    )
    return torch.utils.data.DataLoader(test_set, batch_size=batch_size, shuffle=False)


def load_cnn(args):
    """Return the classifier class (not an instance) matching ``args.data``.

    Args:
        args: namespace whose ``data`` attribute is ``"mnist"`` or ``"cifar"``.

    Returns:
        ``MnistCNN`` for ``"mnist"``, ``CifarCNN`` for ``"cifar"``.

    Raises:
        ValueError: if ``args.data`` is not a supported dataset name (previously
            ``None`` was returned, which failed later when called).
    """
    if args.data == "mnist":
        return MnistCNN
    if args.data == "cifar":
        return CifarCNN
    raise ValueError(
        f"Unsupported dataset {args.data!r}; expected 'mnist' or 'cifar'")

def defend_image_with_gan(starting_img, gan_path="./APE_GAN/checkpoint/cifar/10.tar", in_ch=3):
    """Purify ``starting_img`` with a pretrained APE-GAN generator.

    The checkpoint is loaded from disk on every call. Callers processing many
    batches may prefer to load the generator once themselves.

    Args:
        starting_img: image tensor passed straight to the generator. Presumably
            (N, C, H, W) with C == ``in_ch`` — TODO confirm against ``Generator``.
        gan_path: checkpoint file holding the generator weights under the
            ``"generator"`` key. Relative paths resolve against the current
            working directory.
        in_ch: number of input channels the generator is built with (3 matches
            the default CIFAR checkpoint).

    Returns:
        The generator output for ``starting_img``.
    """
    # Generator is constructed on the CPU and load_state_dict copies the weights
    # into it, so mapping the checkpoint to CPU changes nothing for CUDA hosts but
    # lets GPU-saved checkpoints load on CPU-only machines.
    gan_point = torch.load(gan_path, map_location="cpu")
    G = Generator(in_ch)
    G.load_state_dict(gan_point["generator"])
    G.eval()
    return G(starting_img)


def main(args):
    """Attack the test set with FGSM, purify with APE-GAN, and report accuracy.

    For every test batch this:
      1. classifies the clean images,
      2. crafts FGSM adversarial examples with ``args.eps``,
      3. passes them through the APE-GAN generator and saves the result to
         ``<batch_index>_image_ape.png`` in the current working directory,
      4. classifies the adversarial and the purified images.
    Finally it prints clean / FGSM / APE-GAN accuracy in percent.

    Args:
        args: parsed command-line namespace. Reads ``data``, ``eps`` and
            ``gan_path``. An optional ``model_path`` attribute overrides the
            classifier checkpoint (default ``"cnn.tar"``).
    """
    eps = args.eps
    test_loader = load_dataset(args)

    # Nothing below moves the networks to a GPU, so map both checkpoints to CPU;
    # this also lets GPU-saved checkpoints load on CPU-only hosts.
    model_point = torch.load(getattr(args, "model_path", "cnn.tar"), map_location="cpu")
    gan_point = torch.load(args.gan_path, map_location="cpu")

    model = load_cnn(args)()
    model.load_state_dict(model_point["state_dict"])
    in_ch = 1 if args.data == "mnist" else 3
    G = Generator(in_ch)
    G.load_state_dict(gan_point["generator"])
    loss_cre = nn.CrossEntropyLoss()

    # Put both networks in inference mode; otherwise dropout stays active and
    # batch-norm layers update their running statistics during the attack.
    model.eval()
    G.eval()

    # NOTE(review): assumes accuracy() returns the number of correct predictions
    # in the batch (it is divided by the sample count below) — confirm in utils.
    normal_acc, adv_acc, ape_acc, n = 0, 0, 0, 0
    batches = tqdm(test_loader, total=len(test_loader), leave=False)
    for i, (x, t) in enumerate(batches):
        # FGSM needs input gradients, so it must run outside torch.no_grad().
        x_adv = fgsm(model, x, t, loss_cre, eps)

        with torch.no_grad():
            normal_acc += accuracy(model(x), t)
            adv_acc += accuracy(model(x_adv), t)
            x_ape = G(x_adv)
            ape_acc += accuracy(model(x_ape), t)

        save_image(x_ape, f"{i}_image_ape.png")
        n += t.size(0)

    if n == 0:
        print("Test set is empty; no accuracy to report.")
        return

    # float() guards against integer-tensor division truncating on older torch.
    print("Accuracy: normal {:.6f}, fgsm {:.6f}, ape {:.6f}".format(
        float(normal_acc) / n * 100,
        float(adv_acc) / n * 100,
        float(ape_acc) / n * 100))

if __name__ == "__main__":
    # Command-line entry point: parse options and hand them to main().
    parser = argparse.ArgumentParser(
        description="Run FGSM against a CNN and purify the adversarial images with APE-GAN.")
    # Restrict to the datasets the loaders/models in this file know about, so a
    # typo fails at parse time rather than deep inside main().
    parser.add_argument("--data", type=str, default="mnist", choices=["mnist", "cifar"],
                        help="dataset to evaluate on")
    parser.add_argument("--eps", type=float, default=0.15,
                        help="epsilon passed to fgsm()")
    parser.add_argument("--gan_path", type=str, default="./checkpoint/test/3.tar",
                        help="APE-GAN checkpoint containing a 'generator' state dict")
    args = parser.parse_args()
    main(args)