from __future__ import absolute_import import torch from torch import nn from torch.nn import functional as F import torchvision __all__ = ['DenseNet121'] class DenseNet121(nn.Module): def __init__(self, num_classes, loss={'xent'}, **kwargs): super(DenseNet121, self).__init__() self.loss = loss densenet121 = torchvision.models.densenet121(pretrained=True) self.base = densenet121.features self.classifier = nn.Linear(1024, num_classes) self.feat_dim = 1024 # feature dimension def forward(self, x, is_training): x = self.base(x) x = F.avg_pool2d(x, x.size()[2:]) f = x.view(x.size(0), -1) if not is_training: return f y = self.classifier(f) if self.loss == {'xent'}: return [y] elif self.loss == {'xent', 'htri'}: return [y, f] elif self.loss == {'cent'}: return [y, f] elif self.loss == {'ring'}: return [y, f] else: raise KeyError("Unsupported loss: {}".format(self.loss))