LCYE / models / DenseNet.py
DenseNet.py
Raw
from __future__ import absolute_import

import torch
from torch import nn
from torch.nn import functional as F
import torchvision

__all__ = ['DenseNet121']


class DenseNet121(nn.Module):
    """DenseNet-121 backbone with a linear classification head.

    The torchvision DenseNet-121 ``features`` trunk is followed by global
    average pooling and a single ``nn.Linear`` classifier.

    Args:
        num_classes: Number of output logits produced by the classifier.
        loss: Set of loss names that decides what ``forward`` returns in
            training mode. Supported values are ``{'xent'}``,
            ``{'xent', 'htri'}``, ``{'cent'}`` and ``{'ring'}``.
            Defaults to ``{'xent'}``.
        **kwargs: Accepted for interface compatibility; ignored.
    """

    # Loss configurations that need the pooled features in addition to logits.
    _FEATURE_LOSSES = (
        frozenset({'xent', 'htri'}),
        frozenset({'cent'}),
        frozenset({'ring'}),
    )

    def __init__(self, num_classes, loss=None, **kwargs):
        super(DenseNet121, self).__init__()
        # A fresh set per instance: a shared mutable default would leak any
        # mutation of ``model.loss`` into every later model.
        self.loss = {'xent'} if loss is None else loss
        # NOTE(review): ``pretrained=`` is deprecated in torchvision >= 0.13 in
        # favour of ``weights=``; kept here for compatibility with older versions.
        densenet121 = torchvision.models.densenet121(pretrained=True)
        self.base = densenet121.features
        # Width of the pooled feature vector (1024 for DenseNet-121), taken
        # from the backbone's own head rather than hard-coded.
        self.feat_dim = densenet121.classifier.in_features
        self.classifier = nn.Linear(self.feat_dim, num_classes)

    def forward(self, x, is_training=None):
        """Run the backbone and, in training mode, the classifier.

        Args:
            x: Input image batch passed to the DenseNet ``features`` trunk.
            is_training: If falsy, only the pooled features are returned.
                When omitted (``None``), the module's own ``self.training``
                flag (set via ``train()`` / ``eval()``) is used.

        Returns:
            In inference mode, the feature tensor ``f`` of shape
            ``(batch, feat_dim)``. In training mode, ``[y]`` for
            ``{'xent'}``; otherwise ``[y, f]``, where ``y`` are the logits.

        Raises:
            KeyError: In training mode, if ``self.loss`` is not a supported
                configuration.
        """
        if is_training is None:
            is_training = self.training

        x = self.base(x)
        # Global average pooling over the full spatial extent.
        x = F.avg_pool2d(x, x.size()[2:])
        f = x.view(x.size(0), -1)
        if not is_training:
            return f

        y = self.classifier(f)
        if self.loss == {'xent'}:
            return [y]
        # ``in`` on a tuple compares with ==, so plain sets match frozensets.
        if self.loss in self._FEATURE_LOSSES:
            return [y, f]
        raise KeyError("Unsupported loss: {}".format(self.loss))