LCYE / train.py
train.py
Raw
from __future__ import absolute_import
from __future__ import print_function, division
import sys
import time
import datetime
import argparse
import os
import numpy as np
import os.path as osp
import math
from random import sample 
from scipy import io, spatial
import random

import torchvision
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
import torch.backends.cudnn as cudnn

import models
from models.PCB import PCB_test
# from ReID_attr import get_target_withattr # Need Attribute file
from opts import get_opts, Imagenet_mean, Imagenet_stddev
from GD import Generator, MS_Discriminator, Pat_Discriminator, GANLoss, weights_init, CollaspeNet
from advloss import DeepSupervision, adv_CrossEntropyLoss, adv_CrossEntropyLabelSmooth, adv_TripletLoss, CrossEntropyLoss, TripletLoss
from util import data_manager
from util.dataset_loader import ImageDataset
from util.utils import fliplr, Logger, save_checkpoint, visualize_ranked_results
from util.eval_metrics import make_results
from util.samplers import RandomIdentitySampler, AttrPool

# Training settings
def _str2bool(value):
  """Convert a command-line token to a bool.

  argparse's ``type=bool`` calls ``bool()`` on the raw string, so every
  non-empty value (including "False") becomes True. This converter accepts the
  usual spellings instead; the empty string still maps to False, as it did
  under ``bool('')``.

  Raises:
    argparse.ArgumentTypeError: if the token is not a recognised boolean.
  """
  if isinstance(value, bool):
    return value
  token = value.strip().lower()
  if token in ('true', 't', 'yes', 'y', '1'):
    return True
  if token in ('false', 'f', 'no', 'n', '0', ''):
    return False
  raise argparse.ArgumentTypeError("expected a boolean value, got {!r}".format(value))


parser = argparse.ArgumentParser(description='adversarial attack')
parser.add_argument('--root', type=str, default='/home/chenfeng10/dataset/', help="root path to data directory")
parser.add_argument('--targetmodel', type=str, default='aligned', choices=models.get_names())
parser.add_argument('--dataset', type=str, default='market1501', choices=data_manager.get_names())
# PATH
# The description used to be passed as `metavar`, which put it into the usage line instead of the help text.
parser.add_argument('--G_resume_dir', type=str, default='', help='path to resume G')
parser.add_argument('--pre_dir', type=str, default='models', help='path to be attacked model')
parser.add_argument('--attr_dir', type=str, default='', help='path to attribute file')
parser.add_argument('--save_dir', type=str, default='logs', help='path to save model')
parser.add_argument('--vis_dir', type=str, default='vis', help='path to save visualization result')
parser.add_argument('--ablation', type=str, default='', help='for ablation study')
# var
parser.add_argument('--mode', type=str, default='train', help='train/test')
parser.add_argument('--D', type=str, default='MSGAN', help='Type of discriminator: PatchGAN or Multi-stage GAN')
parser.add_argument('--normalization', type=str, default='bn', help='bn or in')
parser.add_argument('--loss', type=str, default='xent_htri', choices=['cent', 'xent', 'htri', 'xent_htri'])
parser.add_argument('--ak_type', type=int, default=-1, help='-1 if non-targeted, 1 if attribute attack')
parser.add_argument('--attr_key', type=str, default='upwhite', help='[attribute, value]')
parser.add_argument('--attr_value', type=int, default=2, help='[attribute, value]')
parser.add_argument('--mag_in', type=float, default=16.0, help='l_inf magnitude of perturbation')
parser.add_argument('--temperature', type=float, default=-1, help="tau in paper")
parser.add_argument('--usegumbel', action='store_true', default=False, help='whether to use gumbel softmax')
parser.add_argument('--use_SSIM', type=int, default=2, help="0: None, 1: SSIM, 2: MS-SSIM ")
parser.add_argument('--reid_loss', type=str, default='softmax', help="loss for reid supervision: softmax or triplet or both")
# `type=bool` made "--target False" evaluate to True; parse the text explicitly instead.
parser.add_argument('--target', type=_str2bool, default=True, help='target attack')
# Base
parser.add_argument('--train_batch', default=20, type=int,help="train batch size")
parser.add_argument('--test_batch', default=32, type=int, help="test batch size")
parser.add_argument('--epoch', type=int, default=50, help='number of epochs to train for')

parser.add_argument('--margin', type=float, default=0.3, help="margin for triplet loss")
parser.add_argument('--num_ker', type=int, default=32, help='generator filters in first conv layer')
parser.add_argument('--lr', type=float, default=0.0002, help='Learning Rate. Default=0.0002')
# Help text now states the actual default (it previously said 0.01).
parser.add_argument('--lr_reid', type=float, default=0.0003, help='Learning Rate for reid. Default=0.0003')
parser.add_argument('--beta1', type=float, default=0.5, help='beta1 for adam. default=0.5')
parser.add_argument('--seed', type=int, default=123, help='random seed to use. Default=123')
parser.add_argument('--print_freq', type=int, default=20, help="print frequency")
parser.add_argument('--eval_freq', type=int, default=2, help="eval frequency")
parser.add_argument('--usevis', type=_str2bool, default=True, help='whether to save vis')

# Parsed at import time; the functions below read `args` as a module global.
args = parser.parse_args()
# Module-level run configuration derived from the parsed arguments.
# `is_training` is later overwritten by train() and test() via `global`.
is_training = (args.mode == 'train')
attr_list = [args.attr_key, args.attr_value]
attr_matrix = None
if args.attr_dir:
  # An attribute file is only accepted together with one of these two datasets.
  assert args.dataset in ['dukemtmcreid', 'market1501']
  attr_matrix = io.loadmat(args.attr_dir)
  # Attribute runs get their own output sub-folder, e.g. attr/upwhite=2.
  args.ablation = osp.join('attr', '{}={}'.format(args.attr_key, args.attr_value))

# <targetmodel>/<dataset>/<ablation>, shared by the checkpoint and visualisation folders.
_run_subdir = osp.join(args.targetmodel, args.dataset, args.ablation)
pre_dir = osp.join(args.pre_dir, args.targetmodel, '{}.pth.tar'.format(args.dataset))
save_dir = osp.join(args.save_dir, _run_subdir)
vis_dir = osp.join(args.vis_dir, _run_subdir)

# Value passed to Generator as `pool_dim`, keyed by target model name.
pool_dim = dict(aligned=2048)

def main(opt):
  """Build loaders, networks and optimizers, then train or evaluate.

  Args:
    opt: mapping of model-specific options. This function reads the keys
      'split_id', 'cuhk03_labeled', 'cuhk03_classic_split', 'transform_train',
      'transform_test', 'num_instances' and 'workers'.

  Returns:
    0 when ``args.mode == 'test'``; None after a training run.

  Raises:
    ValueError: for an unknown ``--D`` value, or for ``--ak_type 0`` outside
      test mode (no training sampler exists for it).
  """
  # Fail with a clear message instead of an UnboundLocalError further down.
  if args.D not in ('PatchGAN', 'MSGAN'):
    raise ValueError("--D must be 'PatchGAN' or 'MSGAN', got {!r}".format(args.D))
  if args.mode != 'test' and args.ak_type == 0:
    raise ValueError("--ak_type must be negative or positive for training, got 0")

  if not osp.exists(save_dir): os.makedirs(save_dir)
  if not osp.exists(vis_dir): os.makedirs(vis_dir)

  use_gpu = torch.cuda.is_available()
  pin_memory = use_gpu

  log_name = 'log_train.txt' if args.mode == 'train' else 'log_test.txt'
  sys.stdout = Logger(osp.join(save_dir, log_name))
  print("==========\nArgs:{}\n==========".format(args))

  # --seed used to reach only the CUDA generator; generate_labels() draws from
  # Python's `random`, so seed the CPU-side generators as well.
  random.seed(args.seed)
  np.random.seed(args.seed)
  torch.manual_seed(args.seed)
  if use_gpu:
    print("GPU mode")
    cudnn.benchmark = True
    torch.cuda.manual_seed(args.seed)
  else:
    print("CPU mode")

  ### Setup dataset loader ###
  print("Initializing dataset {}".format(args.dataset))
  dataset = data_manager.init_img_dataset(root=args.root, name=args.dataset, split_id=opt['split_id'], cuhk03_labeled=opt['cuhk03_labeled'], cuhk03_classic_split=opt['cuhk03_classic_split'])
  if args.ak_type < 0:
    train_sampler = RandomIdentitySampler(dataset.train, num_instances=opt['num_instances'])
  elif args.ak_type > 0:
    train_sampler = AttrPool(dataset.train, args.dataset, attr_matrix, attr_list, sample_num=16)
  else:
    train_sampler = None  # only reachable in test mode, where no train loader is needed
  trainloader = None
  if train_sampler is not None:
    trainloader = DataLoader(ImageDataset(dataset.train, transform=opt['transform_train']), sampler=train_sampler, batch_size=args.train_batch, num_workers=opt['workers'], pin_memory=pin_memory, drop_last=True)
  queryloader = DataLoader(ImageDataset(dataset.query, transform=opt['transform_test']), batch_size=args.test_batch, shuffle=False, num_workers=opt['workers'], pin_memory=pin_memory, drop_last=False)
  galleryloader = DataLoader(ImageDataset(dataset.gallery, transform=opt['transform_test']), batch_size=args.test_batch, shuffle=False, num_workers=opt['workers'], pin_memory=pin_memory, drop_last=False)

  ### Prepare criterion ###
  if args.ak_type<0:
    clf_criterion = adv_CrossEntropyLabelSmooth(num_classes=dataset.num_train_pids, use_gpu=use_gpu) if args.loss in ['xent', 'xent_htri'] else adv_CrossEntropyLoss(use_gpu=use_gpu)
  else:
    clf_criterion = nn.MultiLabelSoftMarginLoss()
  metric_criterion = adv_TripletLoss(margin=args.margin, ak_type=args.ak_type)
  criterionGAN = GANLoss()

  # Re-ID criteria that train() applies to the outputs of C (keys: 'ce' and/or 'tri').
  criterionReID = {}
  if args.reid_loss in ('softmax', 'both'):
    criterionReID['ce'] = CrossEntropyLoss(num_classes=dataset.num_train_pids, use_gpu=use_gpu, label_smooth=True)
  if args.reid_loss in ('triplet', 'both'):
    criterionReID['tri'] = TripletLoss(margin=0.3)

  # for target attack
  criterionTarget = CrossEntropyLoss(num_classes=dataset.num_train_pids, use_gpu=use_gpu, label_smooth=True)

  ### Prepare pretrained model ###
  target_net = models.init_model(name=args.targetmodel, pre_dir=pre_dir, num_classes=dataset.num_train_pids)
  check_freezen(target_net, need_modified=True, after_modified=False)

  ### Prepare main net ###
  G = Generator(3, 3, args.num_ker, norm=args.normalization, pool_dim=pool_dim[args.targetmodel]).apply(weights_init)
  C = CollaspeNet(num_classes=dataset.num_train_pids, pool_dim=2048, mem_dim=dataset.num_train_pids, n_upsampling=5 ,temperature=args.temperature, use_gumbel=args.usegumbel).apply(weights_init)
  if args.D == 'PatchGAN':
    D = Pat_Discriminator(input_nc=6, norm=args.normalization).apply(weights_init)
  else:  # 'MSGAN', validated above
    D = MS_Discriminator(input_nc=6, num_classes=dataset.num_train_pids, norm=args.normalization).apply(weights_init)
  # Mark the parameters of G, D and C (those reachable through child modules) as trainable.
  check_freezen(G, need_modified=True, after_modified=True)
  check_freezen(D, need_modified=True, after_modified=True)
  check_freezen(C, need_modified=True, after_modified=True)
  model_size = sum(g.numel() for g in G.parameters()) + sum(d.numel() for d in D.parameters()) + sum(c.numel() for c in C.parameters())
  model_size = model_size/1000000
  print("Model size: {:.5f}M".format((model_size)))
  # setup optimizer
  optimizer_G = optim.Adam(G.parameters(), lr=args.lr, betas=(args.beta1, 0.999))
  optimizer_C = optim.Adam(C.parameters(), lr=args.lr, betas=(args.beta1, 0.999))
  optimizer_D = optim.Adam(D.parameters(), lr=args.lr_reid, betas=(args.beta1, 0.999))

  # The evaluation network is defined on every path; it used to exist only when
  # a GPU was available, which made CPU runs fail with a NameError.
  test_target_net = PCB_test(target_net) if args.targetmodel == 'pcb' else target_net
  if use_gpu:
    test_target_net = nn.DataParallel(test_target_net).cuda()
    target_net = nn.DataParallel(target_net).cuda()
    G = nn.DataParallel(G).cuda()
    D = nn.DataParallel(D).cuda()
    C = nn.DataParallel(C).cuda()

  if args.mode == 'test':
    epoch = 'test'
    test(G, D, C, test_target_net, dataset, queryloader, galleryloader, epoch, use_gpu, is_test=True)
    return 0

  # Ready
  start_time = time.time()
  train_time = 0
  worst_mAP, worst_rank1, worst_epoch = np.inf, np.inf, 0
  best_hit, best_epoch = -np.inf, 0
  print("==> Start training")

  for epoch in range(1,args.epoch+1):
    start_train_time = time.time()
    train(epoch, G, D, C, target_net, criterionGAN, clf_criterion, metric_criterion, criterionReID, criterionTarget, optimizer_G, optimizer_D, optimizer_C, trainloader, use_gpu, num_classes= dataset.num_train_pids,target=args.target)
    train_time += round(time.time() - start_train_time)

    if epoch % args.eval_freq == 0:
      print("==> Eval at epoch {}".format(epoch))
      if args.ak_type < 0:
        cmc, mAP = test(G, D, C, test_target_net, dataset, queryloader, galleryloader, epoch, use_gpu, is_test=False)
        # Non-targeted attack: lower retrieval scores are better. The former
        # rank-5/rank-10 comparisons were against thresholds that stayed at
        # +inf forever, so only rank-1 and mAP ever decided this flag.
        is_milestone = cmc[0] <= worst_rank1 and mAP <= worst_mAP
        if is_milestone:
          worst_mAP, worst_rank1, worst_epoch = mAP, cmc[0], epoch
        print("==> Worst_epoch is {}, Worst mAP {:.1%}, Worst rank-1 {:.1%}".format(worst_epoch, worst_mAP, worst_rank1))
      else:
        # Uses the same evaluation wrapper as the other test() calls (it only
        # differs from target_net for 'pcb').
        all_hits = test(G, D, C, test_target_net, dataset, queryloader, galleryloader, epoch, use_gpu, is_test=False)
        is_milestone = all_hits[0] >= best_hit
        if is_milestone:
          best_hit, best_epoch = all_hits[0], epoch
        print("==> Best_epoch is {}, Best rank-1 {:.1%}".format(best_epoch, best_hit))

      for tag, net in (('G', G), ('D', D), ('C', C)):
        save_checkpoint(net.state_dict(), is_milestone, tag, osp.join(save_dir, tag + '_ep' + str(epoch) + '.pth.tar'))

  elapsed = round(time.time() - start_time)
  elapsed = str(datetime.timedelta(seconds=elapsed))
  train_time = str(datetime.timedelta(seconds=train_time))
  print("Finished. Total elapsed time (h:m:s): {}. Training time (h:m:s): {}.".format(elapsed, train_time))

def generate_labels(labels, num_classes):
  """Draw, for every label, a random class index that differs from it.

  Each target is sampled with ``random.randint(0, num_classes - 1)`` and
  redrawn until it differs from the corresponding entry of ``labels``.

  Args:
    labels: 1-D integer tensor of class indices.
    num_classes: number of classes to draw from; must be at least 2.

  Returns:
    A tensor with the shape, dtype and device of ``labels`` holding the
    sampled targets.

  Raises:
    ValueError: if ``num_classes`` is smaller than 2 (no different class
      exists, so the redraw loop could never terminate).
  """
  if num_classes < 2:
    raise ValueError("num_classes must be >= 2 to pick a different class, got {}".format(num_classes))

  # Pull the labels to the host once; comparing tensor elements one by one
  # would force a device synchronisation per draw.
  picks = []
  for true_id in labels.tolist():
    candidate = random.randint(0, num_classes - 1)
    while candidate == true_id:
      candidate = random.randint(0, num_classes - 1)
    picks.append(candidate)
  return torch.tensor(picks, dtype=labels.dtype, device=labels.device)

def train(epoch, G, D, C, target_net, criterionGAN, clf_criterion, metric_criterion, criterionReID, criterionTarget, optimizer_G, optimizer_D, optimizer_C, trainloader, use_gpu, num_classes, target):
  """Run one epoch of training for G, D and C against the fixed target network.

  Besides its arguments this function reads the module globals ``args`` and
  ``opt`` (``opt['ReID_factor']``; ``opt`` is assigned in the ``__main__``
  guard) and sets the global ``is_training`` to True.

  Args:
    epoch: epoch number, used for logging only.
    G, D, C: generator, discriminator and memory network being trained.
    target_net: network under attack; called as ``target_net(x, is_training)``.
    criterionGAN, clf_criterion, metric_criterion, criterionTarget: loss callables.
    criterionReID: dict that may hold 'ce' and/or 'tri' criteria for C.
    optimizer_G, optimizer_D, optimizer_C: optimizers stepped in that order.
    trainloader: yields ``(imgs, pids, _, pids_raw)`` batches.
    use_gpu: move each batch to CUDA when True.
    num_classes: number of identity classes (width of the one-hot labels).
    target: when True, sample a random wrong identity per image as attack target.
  """
  G.train()
  D.train()
  C.train()
  global is_training
  is_training = True

  # Resolve the optional SSIM term once instead of re-importing on every batch.
  ssim_fn = None
  if not args.use_SSIM == 0:
    from util.ms_ssim import msssim, ssim
    ssim_fn = msssim if args.use_SSIM == 2 else ssim

  for batch_idx, (imgs, pids, _, pids_raw) in enumerate(trainloader):
    if use_gpu:
      imgs, pids, pids_raw = imgs.cuda(), pids.cuda(), pids_raw.cuda()

    # build memory subnet
    spatial_real = target_net(imgs,is_training,spatial= True)# no update [B,2048, 8 ,4]
    if target:
      target_fake_label = generate_labels(pids, num_classes)
      target_fake_label_onehot = torch.eye(num_classes).index_select(0, target_fake_label.cpu()).cuda()
    else:
      # No target identity: there is nothing for the recheck loss to compare
      # against, and the all-ones conditioning tensor goes to the same device
      # as the one-hot labels of the targeted path.
      target_fake_label = None
      target_fake_label_onehot = torch.ones([pids.size(0), num_classes]).cuda()

    new_imgs, mask, img_real_pool, img_real_feat = perturb(imgs, spatial_real, G, D, C, target_fake_label_onehot, train_or_test='train')
    new_imgs = new_imgs.cuda()
    mask = mask.cuda()

    # real_one_hot_label
    pids_onehot = torch.eye(num_classes).index_select(0, pids.cpu()).cuda()

    # Fake Detection and Loss (new_imgs detached: only D receives gradient here)
    pred_fake_pool, _ = D(torch.cat((imgs, new_imgs.detach()), 1), C)
    loss_D_fake = criterionGAN(pred_fake_pool,target_fake_label_onehot, False)

    # Real Detection and Loss: the two halves of the batch form the real pair.
    num = args.train_batch//2
    pred_real, _ = D(torch.cat((imgs[0:num,:,:,:], imgs[num:,:,:,:].detach()), 1), C)
    loss_D_real_1 = criterionGAN(pred_real, pids_onehot[0:num], True)
    loss_D_real_2 = criterionGAN(pred_real, pids_onehot[num:], True)
    loss_D_real = (loss_D_real_2 + loss_D_real_1)/2

    # GAN loss (Fake Passability Loss)
    pred_fake, _ = D(torch.cat((imgs, new_imgs), 1), C)
    loss_G_GAN = criterionGAN(pred_fake, target_fake_label_onehot,True)

    # Re-ID loss (update memory)
    loss_C = 0.
    if criterionReID.get('ce') is not None:
      loss_C = loss_C + criterionReID['ce'](img_real_feat, pids)
    if criterionReID.get('tri') is not None:
      loss_C = loss_C + criterionReID['tri'](img_real_pool,pids)

    # Re-ID advloss
    ls = target_net(new_imgs, is_training)
    new_features = None  # stays None when the net returns logits only
    if len(ls) == 1: new_outputs = ls[0]
    elif len(ls) == 2: new_outputs, new_features = ls
    elif len(ls) == 3: new_outputs, new_features, _local_features = ls
    xent_loss, global_loss, loss_G_ssim = 0, 0, 0
    targets = None

    # target recheck loss (skipped when no target identity was sampled)
    loss_target = 0
    if target_fake_label is not None:
      loss_target = DeepSupervision(criterionTarget, new_outputs, target_fake_label) if isinstance(new_features, (tuple, list)) else criterionTarget(new_outputs, target_fake_label)

    if args.loss in ['cent', 'xent', 'xent_htri']:
      if args.ak_type < 0:
        xent_loss = DeepSupervision(clf_criterion, new_outputs, pids) if isinstance(new_features, (tuple, list)) else clf_criterion(new_outputs, pids)

      elif args.ak_type > 0:
        # The module-level import of this helper is commented out, so bring it
        # in here, on the only path that uses it.
        from ReID_attr import get_target_withattr
        targets = get_target_withattr(attr_matrix, args.dataset, attr_list, pids, pids_raw).float().cuda()
        xent_loss = 0  # the classification term is disabled for attribute attacks

    if args.loss in ['htri', 'xent_htri']:
      assert len(ls) >= 2
      global_loss = DeepSupervision(metric_criterion, new_features, pids, targets) if isinstance(new_features, (tuple, list)) else metric_criterion(new_features, pids, targets)

    loss_G_ReID = (xent_loss+ global_loss)*opt['ReID_factor'] + loss_target

    # SSIM loss
    if ssim_fn is not None:
      loss_G_ssim = (1-ssim_fn(imgs, new_imgs))*0.1

    ############## Forward ###############
    loss_D = (loss_D_fake + loss_D_real)/2
    loss_G = loss_G_GAN + loss_G_ReID + loss_G_ssim
    ############## Backward #############
    # update generator weights
    optimizer_G.zero_grad()
    loss_G.backward()
    optimizer_G.step()
    # update discriminator weights
    optimizer_D.zero_grad()
    loss_D.backward()
    optimizer_D.step()
    # update collaspe memory network; loss_C is still the float 0. when
    # criterionReID is empty, in which case there is nothing to back-propagate.
    if torch.is_tensor(loss_C):
      optimizer_C.zero_grad()
      loss_C.backward()
      optimizer_C.step()
    if (batch_idx+1) % args.print_freq == 0:
      print("===> Epoch[{}]({}/{}) loss_D: {:.4f} loss_G_GAN: {:.4f} loss_G_ReID: {:.4f} loss_G_SSIM: {:.4f}  loss_Mem: {:.4f}".format(epoch, batch_idx, len(trainloader), loss_D.item(), loss_G_GAN.item(), float(loss_G_ReID), loss_G_ssim, float(loss_C)))

def _sibling_checkpoint(path, tag):
  """Return `path` with every 'G' in its file name (not its folders) replaced by `tag`."""
  folder, filename = osp.split(path)
  return osp.join(folder, filename.replace('G', tag))


def _unwrap_state_dict(checkpoint):
  """Return checkpoint['state_dict'] when present, else the checkpoint itself."""
  if isinstance(checkpoint, dict) and 'state_dict' in checkpoint:
    return checkpoint['state_dict']
  return checkpoint


def _format_ranks(scores, ranks):
  """Format the first four entries of `ranks` as 'Rank-k: xx.x%' joined by ', '."""
  return ", ".join("Rank-{}: {:.1%}".format(r, scores[r - 1]) for r in ranks[:4])


def test(G, D, C, target_net, dataset, queryloader, galleryloader, epoch, use_gpu, is_test=False, ranks=(1, 5, 10, 20)):
  """Evaluate the attack on the query/gallery split and print the metrics.

  Sets the global ``is_training`` to False. When ``args.mode == 'test'`` and
  ``args.G_resume_dir`` is set, weights for G, D and C are loaded first; the D
  and C files are located by replacing 'G' in the file name of the G checkpoint.

  Args:
    G, D, C: generator, discriminator and memory network.
    target_net: network under attack.
    dataset: dataset object; ``num_train_pids`` is read here and the object is
      passed on to the ranking visualisation.
    queryloader, galleryloader: loaders for the two evaluation splits.
    epoch: epoch tag forwarded to the image dump.
    use_gpu: forwarded to feature extraction.
    is_test: True for a final test run (prints before/after results and may
      write visualisations), False for in-training evaluation.
    ranks: 1-based ranks to report; the first four are printed.

  Returns:
    ``hits`` when ``args.ak_type > 0`` and not ``is_test``;
    ``(new_cmc, new_mAP)`` when ``args.ak_type <= 0`` and not ``is_test``;
    None otherwise.
  """
  global is_training
  is_training = False
  if args.mode == 'test' and args.G_resume_dir:
    G_resume_dir = args.G_resume_dir
    # Swap the tag in the file name only: a plain str.replace over the whole
    # path would also rewrite any 'G' that occurs in a directory name.
    D_resume_dir = _sibling_checkpoint(G_resume_dir, 'D')
    C_resume_dir = _sibling_checkpoint(G_resume_dir, 'C')
    G_checkpoint, D_checkpoint, C_checkpoint = torch.load(G_resume_dir), torch.load(D_resume_dir), torch.load(C_resume_dir)

    G.load_state_dict(_unwrap_state_dict(G_checkpoint))
    D.load_state_dict(_unwrap_state_dict(D_checkpoint))
    C.load_state_dict(_unwrap_state_dict(C_checkpoint))
    print("Sucessfully, loading {} and {}".format(G_resume_dir, D_resume_dir))

  with torch.no_grad():
    # target mode only work on query
    qf, lqf, new_qf, new_lqf, q_pids, q_camids = extract_and_perturb(queryloader, G, D, C, target_net, use_gpu, query_or_gallery='query', is_test=is_test, epoch=epoch, num_classes = dataset.num_train_pids,target= args.target)
    gf, lgf, g_pids, g_camids = extract_and_perturb(galleryloader, G, D, C, target_net, use_gpu, query_or_gallery='gallery', is_test=is_test, epoch=epoch, num_classes = dataset.num_train_pids,target= False)

    if args.ak_type > 0:
      distmat, hits, ignore_list = make_results(new_qf, gf, new_lqf, lgf, q_pids, g_pids, q_camids, g_camids, args.targetmodel, args.ak_type, attr_matrix, args.dataset, attr_list)
      print("Hits rate, " + _format_ranks(hits, ranks))
      if not is_test:
        return hits

    else:
      if is_test:
        # Final test: report retrieval quality on clean and on perturbed queries.
        distmat, cmc, mAP = make_results(qf, gf, lqf, lgf, q_pids, g_pids, q_camids, g_camids, args.targetmodel, args.ak_type)
        new_distmat, new_cmc, new_mAP = make_results(new_qf, gf, new_lqf, lgf, q_pids, g_pids, q_camids, g_camids, args.targetmodel, args.ak_type)
        print("Results ----------")
        print("Before, mAP: {:.1%}, ".format(mAP) + _format_ranks(cmc, ranks))
        print("After , mAP: {:.1%}, ".format(new_mAP) + _format_ranks(new_cmc, ranks))
        if args.usevis:
          visualize_ranked_results(distmat, dataset, save_dir=osp.join(vis_dir, 'origin_results'), topk=20)
          visualize_ranked_results(new_distmat, dataset, save_dir=osp.join(vis_dir, 'polluted_results'), topk=20)
      else:
        _, new_cmc, new_mAP = make_results(new_qf, gf, new_lqf, lgf, q_pids, g_pids, q_camids, g_camids, args.targetmodel, args.ak_type)
        print("mAP: {:.1%}, ".format(new_mAP) + _format_ranks(new_cmc, ranks))
        return new_cmc, new_mAP

def extract_and_perturb(loader, G, D, C, target_net, use_gpu, query_or_gallery, is_test, epoch, num_classes, target=True):
  """Extract target-network features for a split; for queries also attack them.

  For ``query_or_gallery == 'query'`` every batch is additionally perturbed
  with G/C and the features of the perturbed images are collected. When
  ``is_test`` is True the original/perturbed images, delta and mask of each
  query batch are handed to ``save_img``.

  Args:
    loader: yields ``(imgs, pids, camids, pids_raw)`` batches.
    G, D, C: generator, discriminator and memory network (switched to eval
      mode for the query split).
    target_net: network whose features are extracted.
    use_gpu: move images to CUDA when True.
    query_or_gallery: 'query' or 'gallery'.
    is_test: dump images for each query batch when True.
    epoch: tag forwarded to ``save_img``.
    num_classes: number of identity classes for the target one-hot labels.
    target: sample a random wrong identity per query image when True.

  Returns:
    ``[f, lf, l_pids, l_camids]`` for the gallery, or
    ``[f, lf, new_f, new_lf, l_pids, l_camids]`` for the query split. ``lf`` and
    ``new_lf`` stay empty lists when the network returns no local features.

  Raises:
    ValueError: if ``query_or_gallery`` is neither 'query' nor 'gallery'.
  """
  if query_or_gallery not in ('query', 'gallery'):
    raise ValueError("query_or_gallery must be 'query' or 'gallery', got {!r}".format(query_or_gallery))
  is_query = query_or_gallery == 'query'

  f, lf, new_f, new_lf, l_pids, l_camids = [], [], [], [], [], []
  if is_query:
    # Switching modes once is enough; nothing inside the loop changes them.
    G.eval()
    D.eval()
    C.eval()

  for batch_idx, (imgs, pids, camids, _) in enumerate(loader):
    if use_gpu:
      imgs = imgs.cuda()
    ls = extract(imgs, target_net)
    if len(ls) == 1: features = ls[0]
    if len(ls) == 2:
      features, local_features = ls
      lf.append(local_features.detach().data.cpu())

    f.append(features.detach().data.cpu())
    l_pids.extend(pids)
    l_camids.extend(camids)

    if not is_query:
      continue

    if target:
      target_fake_label = generate_labels(pids, num_classes)
      target_fake_label_onehot = torch.eye(num_classes).index_select(0, target_fake_label.cpu()).to(imgs.device)
    else:
      # Keep the conditioning tensor on the same device as the images.
      target_fake_label_onehot = torch.ones([pids.size(0), num_classes]).to(imgs.device)

    spatial_real = target_net(imgs,is_training,spatial= True)
    new_imgs, delta, mask = perturb(imgs, spatial_real, G, D, C, target_fake_label_onehot, train_or_test='test')

    ls = extract(new_imgs, target_net)
    if len(ls) == 1: new_features = ls[0]
    if len(ls) == 2:
      new_features, new_local_features = ls
      new_lf.append(new_local_features.detach().data.cpu())
    new_f.append(new_features.detach().data.cpu())

    if is_test:
      save_img([imgs, new_imgs, delta, mask], pids, camids, epoch, batch_idx)

  f = torch.cat(f, 0)
  if lf: lf = torch.cat(lf, 0)
  l_pids, l_camids = np.asarray(l_pids), np.asarray(l_camids)

  print("Extracted features for {} set, obtained {}-by-{} matrix".format(query_or_gallery, f.size(0), f.size(1)))
  if not is_query:
    return [f, lf, l_pids, l_camids]

  new_f = torch.cat(new_f, 0)
  if new_lf:
    new_lf = torch.cat(new_lf, 0)
  return [f, lf, new_f, new_lf, l_pids, l_camids]

def extract(imgs, target_net):
  """Run `target_net` on `imgs` and return its outputs as a list of CPU tensors.

  For the 'pcb' and 'lsro' target models the first output for the images and
  for their horizontally flipped copies are summed into a single entry.
  Reads the module globals ``args`` and ``is_training``.

  Args:
    imgs: image batch passed straight to ``target_net``.
    target_net: callable invoked as ``target_net(imgs, is_training)``; expected
      to return a list or tuple of tensors.

  Returns:
    A new list with ``.data.cpu()`` of every output.
  """
  if args.targetmodel in ['pcb', 'lsro']:
    outputs = [target_net(imgs, is_training)[0] + target_net(fliplr(imgs), is_training)[0]]
  else:
    outputs = target_net(imgs, is_training)
  # Build a fresh list instead of assigning in place, so tuple outputs work too.
  return [out.data.cpu() for out in outputs]


def perturb(imgs, spatial_real, G, D, C,target_label, train_or_test='test'):
  """Generate perturbed images from `imgs` using G and the memory network C.

  Steps: C is run on `spatial_real` to obtain pooled features, features and a
  mask; G encodes the images ('E'), C transforms the encoding given
  `target_label` ('G'), and G decodes the result ('D') into a perturbation,
  which is passed through L_norm, multiplied by the mask and added to the
  images. `D` is accepted but not used here.

  Returns:
    ``(new_imgs, mask, img_real_pool, img_real_feat)`` when
    ``train_or_test == 'train'``; ``(new_imgs, delta, mask)`` when it is
    ``'test'``; None for any other value.
  """
  batch, _, _, _ = imgs.size()  # also requires imgs to be 4-D

  # Memory branch on the target network's spatial features.
  img_real_pool, img_real_feat, mask = C(spatial_real, G_or_D='C')

  # Encoder -> memory -> decoder yields the raw perturbation.
  encoded = G(imgs, 'E')
  delta = G(C(encoded, target_label, 'G'), 'D')
  delta = L_norm(delta, train_or_test) * mask

  new_imgs = torch.add(imgs.cuda(), delta[0:batch].cuda())

  # do clamping per channel (bounds are each channel's own min and max)
  for ch in range(3):
    plane = new_imgs.data[:, ch, :, :]
    new_imgs.data[:, ch, :, :] = plane.clamp(plane.min(), plane.max())

  if train_or_test == 'train':
    return new_imgs, mask, img_real_pool, img_real_feat
  if train_or_test == 'test':
    return new_imgs, delta, mask

def L_norm(delta, mode='train', mag_in=None, mean=None, std=None):
  """Normalize a perturbation in place and cap its per-channel L-inf norm.

  The values of ``delta`` are mapped through ``(x + 1) * 0.5``, the first three
  channels are standardised with ``mean``/``std``, and then every
  (sample, channel) plane whose largest absolute value exceeds
  ``mag_in / (255 * std[c])`` is scaled down to that bound. All updates are
  applied to ``delta.data``, i.e. they are not recorded by autograd.

  Args:
    delta: 4-D tensor (batch, channels >= 3, height, width); modified in place.
    mode: kept for backward compatibility. The actual batch size of ``delta``
      is used, so the value no longer matters.
    mag_in: L-inf budget on a 0-255 scale; defaults to ``args.mag_in``.
    mean: per-channel mean; defaults to ``Imagenet_mean``.
    std: per-channel standard deviation; defaults to ``Imagenet_stddev``.

  Returns:
    The same ``delta`` object.
  """
  mag_in = args.mag_in if mag_in is None else mag_in
  mean = Imagenet_mean if mean is None else mean
  std = Imagenet_stddev if std is None else std

  data = delta.data
  data += 1
  data *= 0.5

  rgb = data[:, 0:3, :, :]  # view: in-place ops below write through to delta
  mean_t = data.new_tensor([float(mean[c]) for c in range(3)]).view(1, 3, 1, 1)
  std_t = data.new_tensor([float(std[c]) for c in range(3)]).view(1, 3, 1, 1)
  rgb.sub_(mean_t).div_(std_t)

  if rgb.numel() == 0:
    return delta

  # do per channel l_inf normalization, for all samples at once
  n = rgb.size(0)
  peak = rgb.abs().reshape(n, 3, -1).max(dim=2)[0].view(n, 3, 1, 1)
  limit = mag_in / (255.0 * std_t)
  # A zero peak gives limit / 0 = inf, which the clamp turns into a factor of 1.
  rgb.mul_(torch.clamp(limit / peak, max=1.0))
  return delta

def save_img(ls, pids, camids, epoch, batch_idx):
  """Write the original, perturbed, delta and mask images of one batch to `vis_dir`.

  Nothing is written when ``args.usevis`` is false. The inputs are left
  untouched: un-normalisation is done on copies.

  Args:
    ls: sequence ``(image, new_image, delta, mask)`` of tensors.
    pids, camids: accepted for interface compatibility; not used.
    epoch, batch_idx: used to build the output file names.
  """
  image, new_image, delta, mask = ls
  if not args.usevis:
    return

  def _denormalized_copy(batch):
    # undo normalize image color channels, on a copy of the data
    restored = batch.data.clone()
    for ch in range(3):
      restored[:, ch, :, :] = restored[:, ch, :, :] * Imagenet_stddev[ch] + Imagenet_mean[ch]
    return restored

  suffix = '_epoch{}_batch{}.png'.format(epoch, batch_idx)
  torchvision.utils.save_image(_denormalized_copy(image), osp.join(vis_dir, 'original' + suffix))
  torchvision.utils.save_image(_denormalized_copy(new_image), osp.join(vis_dir, 'polluted' + suffix))
  torchvision.utils.save_image(_denormalized_copy(delta), osp.join(vis_dir, 'delta' + suffix))
  torchvision.utils.save_image(mask.data*255, osp.join(vis_dir, 'mask' + suffix))

def check_freezen(net, need_modified=False, after_modified=None):
  """Optionally set `requires_grad` on the parameters of `net`'s child modules.

  Only parameters reachable through ``net.children()`` are visited; parameters
  registered directly on `net` itself are left as they are.

  Args:
    net: module whose children are walked.
    need_modified: when False the function changes nothing.
    after_modified: value assigned to ``requires_grad`` when `need_modified`
      is True.
  """
  if not need_modified:
    return
  for submodule in net.children():
    for weight in submodule.parameters():
      weight.requires_grad = after_modified

# Script entry point.
if __name__ == '__main__':
  # `opt` is deliberately a module-level name: train() reads opt['ReID_factor']
  # as a global, in addition to main() receiving it as an argument.
  opt = get_opts(args.targetmodel)
  main(opt)