# LCYE / models / __init__.py
from __future__ import absolute_import
import torch
import torch.nn as nn

from .DenseNet import *
from .MuDeep import *
from .AlignedReID import *
from .PCB import *
from .HACNN import *
from .IDE import *
from .LSRO import *

# Registry mapping a model name (as accepted by init_model) to the
# class used to build it. Several names deliberately share one class:
# 'ide', 'cam', 'hhl' and 'spgan' all build IDE, while 'densenet121' and
# 'lsro' both build DenseNet121.
# NOTE(review): the three numbered groups are kept from the original layout;
# what distinguishes the groups is not documented in this file.
__factory = dict(
  # group 1
  hacnn=HACNN,
  densenet121=DenseNet121,
  ide=IDE,
  # group 2
  aligned=ResNet50,  # presumably the ResNet50 exported by .AlignedReID -- confirm
  pcb=PCB,
  mudeep=MuDeep,
  # group 3
  cam=IDE,
  hhl=IDE,
  lsro=DenseNet121,
  spgan=IDE,
)

def get_names():
  """Return the keys of the model registry, i.e. every name init_model accepts."""
  registered = __factory.keys()
  return registered

def _strip_module_prefix(state_dict):
  """Return state_dict with a leading 'module.' removed from each key that has it.

  nn.DataParallel saves parameters under 'module.<name>'. Keys that do not
  carry that exact prefix are copied through unchanged, so a checkpoint with
  mixed keys is not corrupted. If no key has the prefix, the input mapping is
  returned as-is.
  """
  prefix = 'module.'
  if not any(key.startswith(prefix) for key in state_dict):
    return state_dict

  from collections import OrderedDict
  cleaned = OrderedDict()
  for key, value in state_dict.items():
    if key.startswith(prefix):
      key = key[len(prefix):]
    cleaned[key] = value
  return cleaned


def init_model(name, pre_dir, *args, **kwargs):
  """Build the registered model `name` and load pretrained weights into it.

  Args:
    name: key of the model in the module registry (see get_names()).
    pre_dir: path of the checkpoint file passed to torch.load. The file may
      hold either a bare state dict or a dict with a 'state_dict' entry.
    *args, **kwargs: forwarded unchanged to the model constructor.

  Returns:
    The constructed network, with weights loaded and switched to eval mode.

  Raises:
    KeyError: if `name` is not a registered model.
  """
  if name not in __factory:
    raise KeyError("Unknown model: {}".format(name))

  print("Initializing model: {}".format(name))
  net = __factory[name](*args, **kwargs)

  # Load the pretrained checkpoint. Storages are kept on the CPU so a
  # checkpoint saved from a GPU also loads on a CPU-only host;
  # load_state_dict copies the values into the network's own parameters.
  # NOTE: the original code remarks that checkpoints written under Python 2
  # may need torch.load(pre_dir, encoding="latin1") when read under Python 3.
  # NOTE(security): torch.load unpickles the file -- only pass trusted paths.
  checkpoint = torch.load(pre_dir, map_location=lambda storage, loc: storage)
  if isinstance(checkpoint, dict) and 'state_dict' in checkpoint:
    state_dict = checkpoint['state_dict']
  else:
    state_dict = checkpoint

  net.load_state_dict(_strip_module_prefix(state_dict))

  # Freeze: inference mode for dropout / batch-norm layers.
  net.eval()
  # NOTE(review): this only stores a plain attribute on the module; it was
  # presumably meant to mirror the old Variable.volatile flag. Kept for
  # compatibility -- confirm whether parameters should instead be frozen
  # via requires_grad = False.
  net.volatile = True
  return net