# ImageKernelsForPatchAttackDefence / APE-GAN / generate.py
# -*- coding: utf-8 -*-

import os
import argparse

import torch
import torch.nn as nn
import torch.optim as optim
import torch.optim.lr_scheduler as lr_scheduler
import torch.backends.cudnn as cudnn
from torch.autograd import Variable

from torchvision import datasets
from torchvision import transforms

from tqdm import tqdm

from models import MnistCNN, CifarCNN
from utils import accuracy, fgsm


def load_dataset(args):
    """Build the (train, test) DataLoader pair for the dataset named by ``args.data``.

    Args:
        args: Namespace with a ``data`` attribute, ``"mnist"`` or ``"cifar"``.

    Returns:
        ``(train_loader, test_loader)``. Both use batch size 128 and a bare
        ``ToTensor`` transform; only the training loader shuffles.

    Raises:
        ValueError: if ``args.data`` names an unsupported dataset (previously
            this fell through and crashed with ``UnboundLocalError``).
    """
    if args.data == "mnist":
        dataset_cls, root = datasets.MNIST, "~/.torch/data/mnist"
    elif args.data == "cifar":
        dataset_cls, root = datasets.CIFAR10, "~/.torch/data/cifar10"
    else:
        raise ValueError(
            "unsupported dataset {!r}; expected 'mnist' or 'cifar'".format(args.data))

    root = os.path.expanduser(root)
    transform = transforms.Compose([transforms.ToTensor()])

    def make_loader(train, shuffle, download):
        # One place to build a loader so both splits share identical settings.
        return torch.utils.data.DataLoader(
            dataset_cls(root, train=train, download=download, transform=transform),
            batch_size=128, shuffle=shuffle)

    # The training split is built first with download=True. The test split keeps
    # download=False.
    # NOTE(review): this relies on the train-split download also fetching the
    # test files; confirm for the torchvision version in use.
    train_loader = make_loader(train=True, shuffle=True, download=True)
    test_loader = make_loader(train=False, shuffle=False, download=False)
    return train_loader, test_loader


def load_cnn(args):
    """Return the CNN class (not an instance) that matches ``args.data``.

    Args:
        args: Namespace with a ``data`` attribute, ``"mnist"`` or ``"cifar"``.

    Returns:
        ``MnistCNN`` for ``"mnist"``, ``CifarCNN`` for ``"cifar"``.

    Raises:
        ValueError: if ``args.data`` names an unsupported dataset. Previously
            ``None`` was returned, which failed later with an unhelpful
            "'NoneType' object is not callable" when instantiated.
    """
    if args.data == "mnist":
        return MnistCNN
    if args.data == "cifar":
        return CifarCNN
    raise ValueError(
        "unsupported dataset {!r}; expected 'mnist' or 'cifar'".format(args.data))


def _train_epoch(model, loader, opt, loss_func, device):
    """Run one optimisation pass over ``loader``.

    Returns ``(summed_loss, summed_accuracy, n_samples)``. The loss is
    weighted by batch size; accuracy is ``accuracy(y, t)`` summed over batches.
    """
    model.train()
    total_loss, total_acc, n_seen = 0.0, 0, 0
    for x, t in tqdm(loader, total=len(loader), leave=False):
        x, t = x.to(device), t.to(device)
        y = model(x)
        loss = loss_func(y, t)
        opt.zero_grad()
        loss.backward()
        opt.step()
        total_loss += loss.item() * t.size(0)
        total_acc += accuracy(y, t)
        n_seen += t.size(0)
    return total_loss, total_acc, n_seen


def _evaluate(model, loader, loss_func, device):
    """Evaluate ``model`` on ``loader`` without tracking gradients.

    Returns ``(summed_loss, summed_accuracy, n_samples)``, in the same form as
    ``_train_epoch``.
    """
    model.eval()
    total_loss, total_acc, n_seen = 0.0, 0, 0
    # No autograd graph is needed for evaluation; skipping it saves memory/time.
    with torch.no_grad():
        for x, t in tqdm(loader, total=len(loader), leave=False):
            x, t = x.to(device), t.to(device)
            y = model(x)
            loss = loss_func(y, t)
            total_loss += loss.item() * t.size(0)
            total_acc += accuracy(y, t)
            n_seen += t.size(0)
    return total_loss, total_acc, n_seen


def _generate_adversarial(model, loader, loss_func, eps, device):
    """Create FGSM counterparts for every input in ``loader``.

    Returns ``(clean_acc, adv_acc, n_samples, normal_data, adv_data)``.
    ``normal_data`` and ``adv_data`` are row-aligned CPU tensors concatenated
    along dim 0, or ``None`` if ``loader`` is empty.
    """
    # Inference mode regardless of how many training epochs ran (even zero).
    model.eval()
    clean_acc, adv_acc, n_seen = 0, 0, 0
    normal_batches, adv_batches = [], []
    for x, t in tqdm(loader, total=len(loader), leave=False):
        x, t = x.to(device), t.to(device)
        with torch.no_grad():
            clean_acc += accuracy(model(x), t)

        # fgsm presumably back-propagates through the model, so it must run
        # outside no_grad.
        x_adv = fgsm(model, x, t, loss_func, eps)
        with torch.no_grad():
            adv_acc += accuracy(model(x_adv), t)
        n_seen += t.size(0)

        normal_batches.append(x.detach().cpu())
        adv_batches.append(x_adv.detach().cpu())

    # Concatenate once at the end. Growing a tensor with torch.cat on every
    # batch re-copies everything each time, which is quadratic.
    normal_data = torch.cat(normal_batches) if normal_batches else None
    adv_data = torch.cat(adv_batches) if adv_batches else None
    return clean_acc, adv_acc, n_seen, normal_data, adv_data


def main(args):
    """Train a CNN on ``args.data``, then generate FGSM adversarial examples.

    Reads ``args.data``, ``args.lr``, ``args.milestones``, ``args.gamma``,
    ``args.epochs`` and ``args.eps``.

    Writes two files to the working directory:
        - ``data.tar``: ``{"normal": ..., "adv": ...}``, the training inputs
          and their FGSM counterparts.
        - ``cnn.tar``: ``{"state_dict": ...}``, the trained weights as CPU tensors.
    """
    print("Generating Model ...")
    print("-" * 30)

    # The loss was moved with .cuda() and cudnn.benchmark was enabled, but the
    # model and data stayed on the CPU. Use one device consistently.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    train_loader, test_loader = load_dataset(args)
    CNN = load_cnn(args)
    model = CNN().to(device)
    cudnn.benchmark = True

    opt = optim.SGD(model.parameters(), lr=args.lr, momentum=0.9, weight_decay=0.001)
    scheduler = lr_scheduler.MultiStepLR(opt, milestones=args.milestones, gamma=args.gamma)
    loss_func = nn.CrossEntropyLoss().to(device)

    print_str = "\t".join(["{}"] + ["{:.6f}"] * 4)
    print("\t".join(["{:}"] * 5).format("Epoch", "TrainLoss", "TestLoss", "TrainAcc.", "TestAcc."))
    for e in range(args.epochs):
        train_loss, train_acc, train_n = _train_epoch(model, train_loader, opt, loss_func, device)
        test_loss, test_acc, test_n = _evaluate(model, test_loader, loss_func, device)
        scheduler.step()
        print(print_str.format(e, train_loss / train_n, test_loss / test_n,
                               train_acc / train_n * 100, test_acc / test_n * 100))

    # Generate Adversarial Examples
    print("-" * 30)
    print("Generating Adversarial Examples ...")
    train_acc, adv_acc, train_n, normal_data, adv_data = _generate_adversarial(
        model, train_loader, loss_func, args.eps, device)

    print("Accuracy(normal) {:.6f}, Accuracy(FGSM) {:.6f}".format(train_acc / train_n * 100, adv_acc / train_n * 100))
    torch.save({"normal": normal_data, "adv": adv_data}, "data.tar")
    # Save CPU weights so the checkpoint loads without a GPU, like the original.
    torch.save({"state_dict": model.cpu().state_dict()}, "cnn.tar")


def parse_args(argv=None):
    """Parse command-line options for this script.

    Args:
        argv: Argument list to parse; ``None`` means ``sys.argv[1:]``.

    Returns:
        argparse.Namespace with ``data``, ``epochs``, ``lr``, ``milestones``,
        ``gamma`` and ``eps``.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--data", type=str, default="mnist", choices=["mnist", "cifar"])
    parser.add_argument("--epochs", type=int, default=100)
    parser.add_argument("--lr", type=float, default=0.01)
    # ``type=list`` split a value like "50,75" into single characters, which
    # broke MultiStepLR. Accept space-separated ints instead: --milestones 50 75
    parser.add_argument("--milestones", type=int, nargs="+", default=[50, 75])
    parser.add_argument("--gamma", type=float, default=0.1)
    parser.add_argument("--eps", type=float, default=0.15)
    return parser.parse_args(argv)


if __name__ == "__main__":
    main(parse_args())