from __future__ import absolute_import

import torch
from torch import nn
from torch.nn import functional as F

__all__ = ['MuDeep']


class ConvBlock(nn.Module):
    """Basic convolutional block: convolution + batch normalization + relu.

    Args (following http://pytorch.org/docs/master/nn.html#torch.nn.Conv2d):
        in_c (int): number of input channels.
        out_c (int): number of output channels.
        k (int or tuple): kernel size.
        s (int or tuple): stride.
        p (int or tuple): padding.
    """
    def __init__(self, in_c, out_c, k, s, p):
        super(ConvBlock, self).__init__()
        self.conv = nn.Conv2d(in_c, out_c, k, stride=s, padding=p)
        self.bn = nn.BatchNorm2d(out_c)

    def forward(self, x):
        return F.relu(self.bn(self.conv(x)))


class ConvLayers(nn.Module):
    """Preprocessing layers."""
    def __init__(self):
        super(ConvLayers, self).__init__()
        self.conv1 = ConvBlock(3, 48, k=3, s=1, p=1)
        self.conv2 = ConvBlock(48, 96, k=3, s=1, p=1)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)

    def forward(self, x):
        x = self.conv1(x)
        x = self.conv2(x)
        x = self.maxpool(x)
        return x


class MultiScaleA(nn.Module):
    """Multi-scale stream layer A (Sec. 3.1)."""
    def __init__(self):
        super(MultiScaleA, self).__init__()
        self.stream1 = nn.Sequential(
            ConvBlock(96, 96, k=1, s=1, p=0),
            ConvBlock(96, 24, k=3, s=1, p=1),
        )
        self.stream2 = nn.Sequential(
            nn.AvgPool2d(kernel_size=3, stride=1, padding=1),
            ConvBlock(96, 24, k=1, s=1, p=0),
        )
        self.stream3 = ConvBlock(96, 24, k=1, s=1, p=0)
        self.stream4 = nn.Sequential(
            ConvBlock(96, 16, k=1, s=1, p=0),
            ConvBlock(16, 24, k=3, s=1, p=1),
            ConvBlock(24, 24, k=3, s=1, p=1),
        )

    def forward(self, x):
        s1 = self.stream1(x)
        s2 = self.stream2(x)
        s3 = self.stream3(x)
        s4 = self.stream4(x)
        y = torch.cat([s1, s2, s3, s4], dim=1)
        return y


class Reduction(nn.Module):
    """Reduction layer (Sec. 3.1)."""
    def __init__(self):
        super(Reduction, self).__init__()
        self.stream1 = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        self.stream2 = ConvBlock(96, 96, k=3, s=2, p=1)
        self.stream3 = nn.Sequential(
            ConvBlock(96, 48, k=1, s=1, p=0),
            ConvBlock(48, 56, k=3, s=1, p=1),
            ConvBlock(56, 64, k=3, s=2, p=1),
        )

    def forward(self, x):
        s1 = self.stream1(x)
        s2 = self.stream2(x)
        s3 = self.stream3(x)
        y = torch.cat([s1, s2, s3], dim=1)
        return y


class MultiScaleB(nn.Module):
    """Multi-scale stream layer B (Sec. 3.1)."""
    def __init__(self):
        super(MultiScaleB, self).__init__()
        self.stream1 = nn.Sequential(
            nn.AvgPool2d(kernel_size=3, stride=1, padding=1),
            ConvBlock(256, 256, k=1, s=1, p=0),
        )
        self.stream2 = nn.Sequential(
            ConvBlock(256, 64, k=1, s=1, p=0),
            ConvBlock(64, 128, k=(1, 3), s=1, p=(0, 1)),
            ConvBlock(128, 256, k=(3, 1), s=1, p=(1, 0)),
        )
        self.stream3 = ConvBlock(256, 256, k=1, s=1, p=0)
        self.stream4 = nn.Sequential(
            ConvBlock(256, 64, k=1, s=1, p=0),
            ConvBlock(64, 64, k=(1, 3), s=1, p=(0, 1)),
            ConvBlock(64, 128, k=(3, 1), s=1, p=(1, 0)),
            ConvBlock(128, 128, k=(1, 3), s=1, p=(0, 1)),
            ConvBlock(128, 256, k=(3, 1), s=1, p=(1, 0)),
        )

    def forward(self, x):
        s1 = self.stream1(x)
        s2 = self.stream2(x)
        s3 = self.stream3(x)
        s4 = self.stream4(x)
        return s1, s2, s3, s4


class Fusion(nn.Module):
    """Saliency-based learning fusion layer (Sec. 3.2)."""
    def __init__(self):
        super(Fusion, self).__init__()
        self.a1 = nn.Parameter(torch.rand(1, 256, 1, 1))
        self.a2 = nn.Parameter(torch.rand(1, 256, 1, 1))
        self.a3 = nn.Parameter(torch.rand(1, 256, 1, 1))
        self.a4 = nn.Parameter(torch.rand(1, 256, 1, 1))

        # We add an average pooling layer to reduce the spatial dimension
        # of feature maps, which differs from the original paper.
        self.avgpool = nn.AvgPool2d(kernel_size=4, stride=4, padding=0)

    def forward(self, x1, x2, x3, x4):
        s1 = self.a1.expand_as(x1) * x1
        s2 = self.a2.expand_as(x2) * x2
        s3 = self.a3.expand_as(x3) * x3
        s4 = self.a4.expand_as(x4) * x4
        y = self.avgpool(s1 + s2 + s3 + s4)
        return y
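

# A brief note on Fusion (illustrative, not from the original file): each
# a_i has shape (1, 256, 1, 1), so expand_as() broadcasts it over the batch
# and spatial dimensions, giving every channel its own learned saliency
# weight shared across all spatial positions. For example:
#
#     a = torch.rand(1, 256, 1, 1)
#     x = torch.rand(8, 256, 16, 8)
#     (a.expand_as(x) * x).shape  # -> torch.Size([8, 256, 16, 8])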


class MuDeep(nn.Module):
    """Multi-scale deep neural network.

    Reference:
        Qian et al. Multi-scale Deep Learning Architectures for
        Person Re-identification. ICCV 2017.
    """
    def __init__(self, num_classes, loss={'xent', 'htri'}, **kwargs):
        super(MuDeep, self).__init__()
        self.loss = loss

        self.block1 = ConvLayers()
        self.block2 = MultiScaleA()
        self.block3 = Reduction()
        self.block4 = MultiScaleB()
        self.block5 = Fusion()

        # Due to this fully connected layer, the input image has to be fixed
        # in shape, i.e. (3, 256, 128), such that the last convolutional
        # feature maps are of shape (256, 16, 8). If the input shape is
        # changed, the input dimension of this layer has to be changed
        # accordingly.
        self.fc = nn.Sequential(
            nn.Linear(256 * 16 * 8, 4096),
            nn.BatchNorm1d(4096),
            nn.ReLU(),
        )
        self.classifier = nn.Linear(4096, num_classes)
        self.feat_dim = 4096  # feature dimension

    def forward(self, x):
        x = self.block1(x)
        x = self.block2(x)
        x = self.block3(x)
        x = self.block4(x)
        x = self.block5(*x)
        x = x.view(x.size(0), -1)
        x = self.fc(x)
        y = self.classifier(x)

        if self.loss == {'xent'}:
            return [y]
        elif self.loss == {'xent', 'htri'}:
            return [y, x]
        elif self.loss == {'cent'}:
            return [y, x]
        else:
            raise KeyError("Unsupported loss: {}".format(self.loss))
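

# Minimal smoke test (a sketch, not part of the original module): run a
# forward pass with the fixed (3, 256, 128) input shape that the fully
# connected layer expects. num_classes=751 is an arbitrary example value.
if __name__ == '__main__':
    model = MuDeep(num_classes=751, loss={'xent', 'htri'})
    model.eval()
    imgs = torch.rand(2, 3, 256, 128)  # batch of two RGB images
    with torch.no_grad():
        logits, feats = model(imgs)
    print(logits.shape)  # expected: torch.Size([2, 751])
    print(feats.shape)   # expected: torch.Size([2, 4096])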