# ImageKernelsForPatchAttackDefence / APE-GAN / train.py
# -*- coding: utf-8 -*-

import os
import argparse

import torch
import torch.nn as nn
import torch.optim as optim
from torch.autograd import Variable
from torch.utils.data import TensorDataset
import torch.backends.cudnn as cudnn

from tqdm import tqdm
import matplotlib.pyplot as plt

from models import Generator, Discriminator


def show_images(e, x, x_adv, x_fake, save_dir):
    """Save a 3x5 preview grid of clean, adversarial and APE-GAN-restored images.

    Row 0 shows ``x`` ("Normal"), row 1 ``x_adv`` ("Adv") and row 2
    ``x_fake`` ("APE-GAN"), one sample per column. At most five samples are
    drawn; if fewer are available the remaining cells are left blank.

    Args:
        e: Epoch index; used only in the output file name ``result_{e}.png``.
        x, x_adv, x_fake: Image batches indexed as ``[sample, channel, H, W]``
            (each sample is transposed channel-first -> channel-last for
            display). Presumably values are floats in [0, 1] as matplotlib
            expects for RGB data -- TODO confirm against the data pipeline.
        save_dir: Existing directory the PNG is written into.
    """
    rows = ((x, "Normal"), (x_adv, "Adv"), (x_fake, "APE-GAN"))
    # Guard against batches with fewer than five samples.
    n_cols = min(5, len(x), len(x_adv), len(x_fake))

    fig, axes = plt.subplots(3, 5, figsize=(10, 6))
    try:
        for ax in axes.flat:
            ax.axis("off")
        for r, (batch, title) in enumerate(rows):
            for i in range(n_cols):
                img = batch[i].detach().cpu().numpy().transpose((1, 2, 0))
                if img.shape[-1] == 1:
                    # Single-channel (e.g. MNIST): draw as grayscale rather
                    # than with the default colormap.
                    axes[r, i].imshow(img[..., 0], cmap="gray")
                else:
                    axes[r, i].imshow(img)
                axes[r, i].set_title(title)
        fig.savefig(os.path.join(save_dir, "result_{}.png".format(e)))
    finally:
        # Called once per epoch: without closing, every figure stays alive
        # in pyplot's registry and memory grows for the whole run.
        plt.close(fig)


def _save_preview(G, epoch, x_clean, x_adv, save_dir, device):
    """Run G in eval mode on the fixed preview batch, save the grid, restore train mode."""
    G.eval()
    with torch.no_grad():
        x_fake = G(x_adv.to(device)).cpu()
    show_images(epoch, x_clean, x_adv, x_fake, save_dir)
    G.train()


def _train_epoch(G, D, loader, opt_G, opt_D, xi1, xi2, device):
    """Run one APE-GAN epoch; return per-sample mean (generator, discriminator) loss.

    Per batch: one discriminator update, then two generator updates. The
    generator loss is ``xi1 * MSE(G(x_adv), x) + xi2 * BCE(D(G(x_adv)), 1)``.
    The reported generator loss is the one from the second G update.
    """
    loss_bce = nn.BCELoss()
    loss_mse = nn.MSELoss()
    gen_total, dis_total, n_seen = 0.0, 0.0, 0

    for x, x_adv in tqdm(loader, total=len(loader), leave=False):
        x, x_adv = x.to(device), x_adv.to(device)
        current_size = x.size(0)
        t_real = torch.ones(current_size, device=device)
        t_fake = torch.zeros(current_size, device=device)

        # Train D. reshape(-1) (not squeeze) keeps a size-1 final batch as
        # shape (1,) so it matches the targets; BCELoss rejects mismatched
        # shapes. The fake batch is detached so this backward pass does not
        # also fill G's gradients (they would be cleared before G's step).
        y_real = D(x).reshape(-1)
        y_fake = D(G(x_adv).detach()).reshape(-1)
        loss_D = loss_bce(y_real, t_real) + loss_bce(y_fake, t_fake)
        opt_D.zero_grad()
        loss_D.backward()
        opt_D.step()

        # Train G (two steps per D step).
        for _ in range(2):
            x_fake = G(x_adv)
            y_fake = D(x_fake).reshape(-1)
            loss_G = xi1 * loss_mse(x_fake, x) + xi2 * loss_bce(y_fake, t_real)
            opt_G.zero_grad()
            loss_G.backward()
            opt_G.step()

        # Previously these totals were swapped (G total accumulated loss_D).
        gen_total += loss_G.item() * current_size
        dis_total += loss_D.item() * current_size
        n_seen += current_size

    n_seen = max(n_seen, 1)  # avoid ZeroDivisionError on an empty loader
    return gen_total / n_seen, dis_total / n_seen


def main(args):
    """Train an APE-GAN generator/discriminator pair and checkpoint every epoch.

    Loads ``data.tar`` (a dict holding ``"normal"`` and ``"adv"`` tensors of
    paired clean/adversarial images), trains for ``args.epochs`` epochs, and
    writes ``{epoch}.tar`` checkpoints plus ``result_{epoch}.png`` previews
    into ``args.checkpoint``. Uses CUDA when available, otherwise CPU.

    Args:
        args: Namespace with ``lr``, ``epochs``, ``xi1``, ``xi2``,
            ``checkpoint`` and ``data`` (``"mnist"`` -> 1 input channel,
            anything else -> 3).
    """
    lr = args.lr
    epochs = args.epochs
    batch_size = 128
    xi1, xi2 = args.xi1, args.xi2
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    check_path = args.checkpoint
    os.makedirs(check_path, exist_ok=True)

    # NOTE(review): torch.load unpickles the file -- only load trusted data.
    # map_location="cpu" lets a file saved from GPU tensors load on any
    # machine; batches are moved to `device` inside the training loop.
    train_data = torch.load("data.tar", map_location="cpu")
    x_tmp = train_data["normal"][:5]
    x_adv_tmp = train_data["adv"][:5]

    dataset = TensorDataset(train_data["normal"], train_data["adv"])
    train_loader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=True)

    in_ch = 1 if args.data == "mnist" else 3
    G = Generator(in_ch).to(device)
    D = Discriminator(in_ch).to(device)

    opt_G = optim.Adam(G.parameters(), lr=lr, betas=(0.5, 0.999))
    opt_D = optim.Adam(D.parameters(), lr=lr, betas=(0.5, 0.999))
    cudnn.benchmark = True  # only has an effect when running on CUDA

    print_str = "\t".join(["{}"] + ["{:.6f}"] * 2)
    print("\t".join(["{:}"] * 3).format("Epoch", "Gen_Loss", "Dis_Loss"))
    for e in range(epochs):
        _save_preview(G, e, x_tmp, x_adv_tmp, check_path, device)
        gen_loss, dis_loss = _train_epoch(
            G, D, train_loader, opt_G, opt_D, xi1, xi2, device
        )
        print(print_str.format(e, gen_loss, dis_loss))
        torch.save({"generator": G.state_dict(), "discriminator": D.state_dict()},
                   os.path.join(check_path, "{}.tar".format(e + 1)))

    # Final preview after the last epoch.
    _save_preview(G, epochs, x_tmp, x_adv_tmp, check_path, device)


if __name__ == "__main__":
    # Each (flag, type, default) triple becomes an attribute that main() reads;
    # "--data" only selects the input channel count (mnist -> 1, else 3).
    cli = argparse.ArgumentParser()
    option_specs = (
        ("--data", str, "mnist"),
        ("--lr", float, 0.0002),
        ("--epochs", int, 2),
        ("--xi1", float, 0.7),
        ("--xi2", float, 0.3),
        ("--checkpoint", str, "./checkpoint/test"),
    )
    for flag, converter, default_value in option_specs:
        cli.add_argument(flag, type=converter, default=default_value)
    main(cli.parse_args())