# -*- coding: utf-8 -*-
import os
import argparse
import torch
import torch.nn as nn
import torch.optim as optim
from torch.autograd import Variable
from torch.utils.data import TensorDataset
import torch.backends.cudnn as cudnn
from tqdm import tqdm
import matplotlib.pyplot as plt
from models import Generator, Discriminator
def show_images(e, x, x_adv, x_fake, save_dir):
    """Save a 3x5 grid comparing clean, adversarial and APE-GAN-restored images.

    Args:
        e: epoch index, used in the output filename.
        x: batch of clean images, (N, C, H, W); only the first 5 are drawn.
        x_adv: batch of adversarial images, same shape as ``x``.
        x_fake: batch of images restored by the generator, same shape as ``x``.
        save_dir: directory where ``result_{e}.png`` is written.
    """
    fig, axes = plt.subplots(3, 5, figsize=(10, 6))
    for i in range(5):
        axes[0, i].axis("off"), axes[1, i].axis("off"), axes[2, i].axis("off")
        # (C, H, W) -> (H, W, C) for matplotlib; assumes 3-channel input.
        # For 1-channel (mnist) data, use the cmap="gray" variants below.
        axes[0, i].imshow(x[i].cpu().numpy().transpose((1, 2, 0)))
        # axes[0, i].imshow(x[i, 0].cpu().numpy(), cmap="gray")
        axes[0, i].set_title("Normal")
        axes[1, i].imshow(x_adv[i].cpu().numpy().transpose((1, 2, 0)))
        # axes[1, i].imshow(x_adv[i, 0].cpu().numpy(), cmap="gray")
        axes[1, i].set_title("Adv")
        axes[2, i].imshow(x_fake[i].cpu().numpy().transpose((1, 2, 0)))
        # axes[2, i].imshow(x_fake[i, 0].cpu().numpy(), cmap="gray")
        axes[2, i].set_title("APE-GAN")
    plt.savefig(os.path.join(save_dir, "result_{}.png".format(e)))
    # BUGFIX: close the figure; otherwise one open figure per epoch
    # accumulates in matplotlib's registry and leaks memory.
    plt.close(fig)
def main(args):
    """Train APE-GAN: G restores adversarial images, D separates real from restored.

    Expects a ``data.tar`` file (a ``torch.save``-d dict) holding "normal" and
    "adv" image tensors of identical shape.  Each epoch writes a sample grid
    and a ``{epoch}.tar`` checkpoint into ``args.checkpoint``.

    Args:
        args: parsed CLI namespace with ``lr``, ``epochs``, ``xi1``, ``xi2``,
            ``data`` ("mnist" selects 1 input channel) and ``checkpoint``.
    """
    lr = args.lr
    epochs = args.epochs
    batch_size = 128
    xi1, xi2 = args.xi1, args.xi2

    check_path = args.checkpoint
    os.makedirs(check_path, exist_ok=True)

    train_data = torch.load("data.tar")
    # Keep the first 5 clean/adversarial pairs aside for visualization.
    x_tmp = train_data["normal"][:5]
    x_adv_tmp = train_data["adv"][:5]

    train_data = TensorDataset(train_data["normal"], train_data["adv"])
    train_loader = torch.utils.data.DataLoader(train_data, batch_size=batch_size, shuffle=True)

    in_ch = 1 if args.data == "mnist" else 3
    G = Generator(in_ch)
    D = Discriminator(in_ch)

    opt_G = optim.Adam(G.parameters(), lr=lr, betas=(0.5, 0.999))
    opt_D = optim.Adam(D.parameters(), lr=lr, betas=(0.5, 0.999))

    loss_bce = nn.BCELoss()
    loss_mse = nn.MSELoss()
    cudnn.benchmark = True

    print_str = "\t".join(["{}"] + ["{:.6f}"] * 2)
    print("\t".join(["{:}"] * 3).format("Epoch", "Gen_Loss", "Dis_Loss"))
    for e in range(epochs):
        # Snapshot G's current restorations before this epoch's updates.
        G.eval()
        x_fake = G(Variable(x_adv_tmp)).data
        show_images(e, x_tmp, x_adv_tmp, x_fake, check_path)
        G.train()

        gen_loss, dis_loss, n = 0, 0, 0
        for x, x_adv in tqdm(train_loader, total=len(train_loader), leave=False):
            current_size = x.size(0)
            x, x_adv = Variable(x), Variable(x_adv)
            # Train D: real images -> 1, restored (fake) images -> 0.
            t_real = Variable(torch.ones(current_size))
            t_fake = Variable(torch.zeros(current_size))

            y_real = D(x).squeeze()
            x_fake = G(x_adv)
            y_fake = D(x_fake).squeeze()

            loss_D = loss_bce(y_real, t_real) + loss_bce(y_fake, t_fake)
            opt_D.zero_grad()
            loss_D.backward()
            opt_D.step()

            # Train G twice per D step: stay close to the clean image (MSE)
            # while fooling D (BCE against the "real" target).
            for _ in range(2):
                x_fake = G(x_adv)
                y_fake = D(x_fake).squeeze()

                loss_G = xi1 * loss_mse(x_fake, x) + xi2 * loss_bce(y_fake, t_real)
                opt_G.zero_grad()
                loss_G.backward()
                opt_G.step()

            # BUGFIX: the original accumulated loss_D into gen_loss and
            # loss_G into dis_loss, so the printed columns were swapped.
            gen_loss += loss_G.data.item() * x.size(0)
            dis_loss += loss_D.data.item() * x.size(0)
            n += x.size(0)
        print(print_str.format(e, gen_loss / n, dis_loss / n))

        torch.save({"generator": G.state_dict(), "discriminator": D.state_dict()},
                   os.path.join(check_path, "{}.tar".format(e + 1)))
    # Final visualization after the last epoch's updates.
    G.eval()
    x_fake = G(Variable(x_adv_tmp)).data
    show_images(epochs, x_tmp, x_adv_tmp, x_fake, check_path)
    G.train()
if __name__ == "__main__":
    # CLI entry point: register all flags table-driven, then hand off to main.
    parser = argparse.ArgumentParser()
    for flag, kind, default in (
        ("--data", str, "mnist"),
        ("--lr", float, 0.0002),
        ("--epochs", int, 2),
        ("--xi1", float, 0.7),
        ("--xi2", float, 0.3),
        ("--checkpoint", str, "./checkpoint/test"),
    ):
        parser.add_argument(flag, type=kind, default=default)
    main(parser.parse_args())