# LCYE / models / HACNN.py
from __future__ import absolute_import

import torch
from torch import nn
from torch.nn import functional as F
import torchvision

__all__ = ['HACNN']

class ConvBlock(nn.Module):
    """Conv2d followed by BatchNorm2d and a ReLU activation.

    Arguments mirror ``torch.nn.Conv2d``
    (http://pytorch.org/docs/master/nn.html#torch.nn.Conv2d):

    Args:
        in_c (int): channels of the incoming tensor.
        out_c (int): channels produced by the convolution (and normalized by BN).
        k (int or tuple): convolution kernel size.
        s (int or tuple): convolution stride, 1 by default.
        p (int or tuple): zero padding on each side, 0 by default.
    """
    def __init__(self, in_c, out_c, k, s=1, p=0):
        super(ConvBlock, self).__init__()
        self.conv = nn.Conv2d(in_channels=in_c, out_channels=out_c,
                              kernel_size=k, stride=s, padding=p)
        self.bn = nn.BatchNorm2d(num_features=out_c)

    def forward(self, x):
        features = self.conv(x)
        normalized = self.bn(features)
        return F.relu(normalized)

class InceptionA(nn.Module):
    """Inception block that keeps the spatial resolution of its input.

    Four parallel streams, each producing ``out_channels // 4`` channels, are
    concatenated along the channel axis:
    streams 1-3 are 1x1 conv -> 3x3 conv (padding 1), stream 4 is 3x3 average
    pooling (stride 1, padding 1) -> 1x1 conv.

    Args:
        in_channels (int): number of input channels
        out_channels (int): number of output channels AFTER concatenation;
            must be a multiple of 4 so the four streams add up to exactly it.

    Raises:
        ValueError: if ``out_channels`` is not a multiple of 4 (previously the
            block silently produced ``4 * (out_channels // 4)`` channels).
    """
    def __init__(self, in_channels, out_channels):
        super(InceptionA, self).__init__()
        if out_channels % 4 != 0:
            raise ValueError(
                "out_channels must be a multiple of 4, got {}".format(out_channels))
        single_out_channels = out_channels // 4

        self.stream1 = nn.Sequential(
            ConvBlock(in_channels, single_out_channels, 1),
            ConvBlock(single_out_channels, single_out_channels, 3, p=1),
        )
        self.stream2 = nn.Sequential(
            ConvBlock(in_channels, single_out_channels, 1),
            ConvBlock(single_out_channels, single_out_channels, 3, p=1),
        )
        self.stream3 = nn.Sequential(
            ConvBlock(in_channels, single_out_channels, 1),
            ConvBlock(single_out_channels, single_out_channels, 3, p=1),
        )
        self.stream4 = nn.Sequential(
            nn.AvgPool2d(3, stride=1, padding=1),
            ConvBlock(in_channels, single_out_channels, 1),
        )

    def forward(self, x):
        s1 = self.stream1(x)
        s2 = self.stream2(x)
        s3 = self.stream3(x)
        s4 = self.stream4(x)
        # channel-wise concatenation: 4 * (out_channels // 4) == out_channels
        y = torch.cat([s1, s2, s3, s4], dim=1)
        return y

class InceptionB(nn.Module):
    """Inception block that halves the spatial resolution (stride 2).

    Three parallel streams are concatenated along the channel axis:
    stream 1 is 1x1 conv -> strided 3x3 conv (``out_channels // 4`` channels),
    stream 2 is 1x1 conv -> 3x3 conv -> strided 3x3 conv
    (``out_channels // 4`` channels), and stream 3 is strided 3x3 max pooling
    -> 1x1 conv (``2 * (out_channels // 4)`` channels).

    Args:
        in_channels (int): number of input channels
        out_channels (int): number of output channels AFTER concatenation;
            must be a multiple of 4 so the streams add up to exactly it.

    Raises:
        ValueError: if ``out_channels`` is not a multiple of 4 (previously the
            block silently produced ``4 * (out_channels // 4)`` channels).
    """
    def __init__(self, in_channels, out_channels):
        super(InceptionB, self).__init__()
        if out_channels % 4 != 0:
            raise ValueError(
                "out_channels must be a multiple of 4, got {}".format(out_channels))
        single_out_channels = out_channels // 4

        self.stream1 = nn.Sequential(
            ConvBlock(in_channels, single_out_channels, 1),
            ConvBlock(single_out_channels, single_out_channels, 3, s=2, p=1),
        )
        self.stream2 = nn.Sequential(
            ConvBlock(in_channels, single_out_channels, 1),
            ConvBlock(single_out_channels, single_out_channels, 3, p=1),
            ConvBlock(single_out_channels, single_out_channels, 3, s=2, p=1),
        )
        self.stream3 = nn.Sequential(
            nn.MaxPool2d(3, stride=2, padding=1),
            # twice as wide as the conv streams so the total is out_channels
            ConvBlock(in_channels, single_out_channels*2, 1),
        )

    def forward(self, x):
        s1 = self.stream1(x)
        s2 = self.stream2(x)
        s3 = self.stream3(x)
        y = torch.cat([s1, s2, s3], dim=1)
        return y

class SpatialAttn(nn.Module):
    """Spatial Attention (Sec. 3.1.I.1)

    Produces a single-channel spatial map from the input:
    cross-channel mean -> strided 3x3 conv -> bilinear resize back to the
    input's height and width -> 1x1 scaling conv.
    """
    def __init__(self):
        super(SpatialAttn, self).__init__()
        self.conv1 = ConvBlock(1, 1, 3, s=2, p=1)
        self.conv2 = ConvBlock(1, 1, 1)

    def forward(self, x):
        # Resize back to the exact input resolution. Doubling the strided
        # output (the previous approach) overshoots by one pixel whenever the
        # input height/width is odd; for even sizes the result is identical.
        height, width = x.size(2), x.size(3)
        # global cross-channel averaging
        x = x.mean(1, keepdim=True)
        # 3-by-3 conv
        x = self.conv1(x)
        # bilinear resizing (F.upsample is deprecated in favour of F.interpolate)
        x = F.interpolate(x, size=(height, width), mode='bilinear', align_corners=True)
        # scaling conv
        x = self.conv2(x)
        return x

class ChannelAttn(nn.Module):
    """Channel Attention (Sec. 3.1.I.2)

    Squeezes the input to one value per channel with a full-size average
    pool, then excites it through a bottleneck of two 1x1 ConvBlocks
    (``in_channels -> in_channels // reduction_rate -> in_channels``).

    Args:
        in_channels (int): channels of the input; must be divisible by
            ``reduction_rate``.
        reduction_rate (int): bottleneck reduction factor (default 16).
    """
    def __init__(self, in_channels, reduction_rate=16):
        super(ChannelAttn, self).__init__()
        assert in_channels % reduction_rate == 0
        bottleneck = in_channels // reduction_rate
        self.conv1 = ConvBlock(in_channels, bottleneck, 1)
        self.conv2 = ConvBlock(bottleneck, in_channels, 1)

    def forward(self, x):
        # kernel covering the whole spatial extent -> one value per channel
        squeezed = F.avg_pool2d(x, x.shape[2:])
        return self.conv2(self.conv1(squeezed))

class SoftAttn(nn.Module):
    """Soft Attention (Sec. 3.1.I)
    Aim: Spatial Attention + Channel Attention
    Output: attention maps with shape identical to input, squashed to (0, 1).

    Args:
        in_channels (int): channels of the input feature map.
    """
    def __init__(self, in_channels):
        super(SoftAttn, self).__init__()
        self.spatial_attn = SpatialAttn()
        self.channel_attn = ChannelAttn(in_channels)
        self.conv = ConvBlock(in_channels, in_channels, 1)

    def forward(self, x):
        # single-channel spatial map and 1x1 per-channel map; their product
        # broadcasts to one value per (channel, position)
        y_spatial = self.spatial_attn(x)
        y_channel = self.channel_attn(x)
        y = y_spatial * y_channel
        # torch.sigmoid replaces the deprecated F.sigmoid (same computation)
        y = torch.sigmoid(self.conv(y))
        return y

class HardAttn(nn.Module):
    """Hard Attention (Sec. 3.1.II)

    Predicts four 2-vectors of transformation parameters (one per region)
    from the globally average-pooled input.

    Args:
        in_channels (int): channels of the input feature map.
    """
    def __init__(self, in_channels):
        super(HardAttn, self).__init__()
        self.fc = nn.Linear(in_channels, 4*2)
        self.init_params()

    def init_params(self):
        # Zero weights make the initial prediction input-independent, so theta
        # starts at tanh(bias). The bias holds four pairs (reshaped to
        # (batch, 4, 2) in forward); presumably (t_x, t_y) offsets with the
        # second entry spread over {-0.75, -0.25, 0.25, 0.75} - confirm.
        # Write under no_grad instead of through the discouraged `.data`.
        with torch.no_grad():
            self.fc.weight.zero_()
            self.fc.bias.copy_(torch.tensor([0, -0.75, 0, -0.25, 0, 0.25, 0, 0.75], dtype=torch.float))

    def forward(self, x):
        """Return theta of shape (batch, 4, 2) with entries in (-1, 1)."""
        # squeeze operation (global average pooling) -> (batch, channels)
        x = F.avg_pool2d(x, x.size()[2:]).view(x.size(0), x.size(1))
        # predict transformation parameters; torch.tanh replaces the
        # deprecated F.tanh and bounds the output to (-1, 1)
        theta = torch.tanh(self.fc(x))
        theta = theta.view(-1, 4, 2)
        return theta

class HarmAttn(nn.Module):
    """Harmonious Attention (Sec. 3.1)

    Runs soft attention and hard attention on the same input and returns
    both results.

    Args:
        in_channels (int): channels of the input feature map.
    """
    def __init__(self, in_channels):
        super(HarmAttn, self).__init__()
        self.soft_attn = SoftAttn(in_channels)
        self.hard_attn = HardAttn(in_channels)

    def forward(self, x):
        """Return the tuple ``(soft attention map, hard attention theta)``."""
        return self.soft_attn(x), self.hard_attn(x)

class HACNN(nn.Module):
    """
    Harmonious Attention Convolutional Neural Network

    Reference:
    Li et al. Harmonious Attention Network for Person Re-identification. CVPR 2018.

    Args:
        num_classes (int): number of classes to predict
        loss (set of str, iterable of str, or str): training objective(s);
            ``forward`` supports ``{'xent'}`` and ``{'xent', 'htri'}``.
            Defaults to ``{'xent', 'htri'}``.
        nchannels (list): number of channels AFTER concatenation for the three
            Inception blocks. Defaults to ``[128, 256, 384]``.
        feat_dim (int): feature dimension for a single stream
        learn_region (bool): whether to learn region features (i.e. local branch)
        use_gpu (bool): kept for backward compatibility; the affine matrices
            built internally now follow the device of the network's tensors.
    """
    def __init__(self, num_classes, loss=None, nchannels=None, feat_dim=512,
                 learn_region=True, use_gpu=True, **kwargs):
        super(HACNN, self).__init__()
        # None sentinels replace the former mutable set/list defaults, which
        # were shared by every instance built without these arguments.
        if loss is None:
            loss = {'xent', 'htri'}
        if nchannels is None:
            nchannels = [128, 256, 384]
        # Keep a private copy; a bare string such as 'xent' is one objective.
        self.loss = {loss} if isinstance(loss, str) else set(loss)
        self.learn_region = learn_region
        self.use_gpu = use_gpu

        self.conv = ConvBlock(3, 32, 3, s=2, p=1)

        # Construct Inception + HarmAttn blocks
        # ============== Block 1 ==============
        self.inception1 = nn.Sequential(
            InceptionA(32, nchannels[0]),
            InceptionB(nchannels[0], nchannels[0]),
        )
        self.ha1 = HarmAttn(nchannels[0])

        # ============== Block 2 ==============
        self.inception2 = nn.Sequential(
            InceptionA(nchannels[0], nchannels[1]),
            InceptionB(nchannels[1], nchannels[1]),
        )
        self.ha2 = HarmAttn(nchannels[1])

        # ============== Block 3 ==============
        self.inception3 = nn.Sequential(
            InceptionA(nchannels[1], nchannels[2]),
            InceptionB(nchannels[2], nchannels[2]),
        )
        self.ha3 = HarmAttn(nchannels[2])

        self.fc_global = nn.Sequential(
            nn.Linear(nchannels[2], feat_dim),
            nn.BatchNorm1d(feat_dim),
            nn.ReLU(),
        )
        self.classifier_global = nn.Linear(feat_dim, num_classes)

        if self.learn_region:
            self.init_scale_factors()
            self.local_conv1 = InceptionB(32, nchannels[0])
            self.local_conv2 = InceptionB(nchannels[0], nchannels[1])
            self.local_conv3 = InceptionB(nchannels[1], nchannels[2])
            self.fc_local = nn.Sequential(
                # four regions, each globally pooled to nchannels[2] values
                nn.Linear(nchannels[2]*4, feat_dim),
                nn.BatchNorm1d(feat_dim),
                nn.ReLU(),
            )
            self.classifier_local = nn.Linear(feat_dim, num_classes)
            self.feat_dim = feat_dim * 2
        else:
            self.feat_dim = feat_dim

    def init_scale_factors(self):
        """Initialize the fixed 2x2 scale part (s_w, s_h) of each region's affine matrix.

        Stored as a plain list (not registered buffers) so checkpoint keys are
        unchanged; ``transform_theta`` moves them to the right device.
        """
        self.scale_factors = [
            torch.tensor([[1, 0], [0, 0.25]], dtype=torch.float)
            for _ in range(4)
        ]

    def stn(self, x, theta):
        """Perform spatial transform
        x: (batch, channel, height, width)
        theta: (batch, 2, 3)
        Output has the same size as ``x``.
        """
        grid = F.affine_grid(theta, x.size())
        x = F.grid_sample(x, grid)
        return x

    def transform_theta(self, theta_i, region_idx):
        """Transform theta to include (s_w, s_h),
        resulting in (batch, 2, 3).

        ``theta_i`` is (batch, 2) and becomes the last column of the affine
        matrix; the region's fixed scale factors fill the first two columns.
        The result is created on ``theta_i``'s device/dtype, so it works
        whatever ``use_gpu`` says and wherever the model actually lives.
        """
        scale_factors = self.scale_factors[region_idx].to(
            device=theta_i.device, dtype=theta_i.dtype)
        theta = theta_i.new_zeros(theta_i.size(0), 2, 3)
        theta[:, :, :2] = scale_factors
        theta[:, :, -1] = theta_i
        return theta

    def _local_branch(self, feat, theta, out_size, local_conv, prev_locals=None):
        """Run one stage of the local (region) branch for all four regions.

        Args:
            feat: feature map the regions are sampled from.
            theta: (batch, 4, 2) parameters from a HardAttn module.
            out_size (tuple): (height, width) every sampled region is resized to.
            local_conv: module applied to each region.
            prev_locals (list, optional): the previous stage's four region
                features, added to the resized regions before ``local_conv``.

        Returns:
            list of four region feature tensors.
        """
        outputs = []
        for region_idx in range(4):
            theta_i = self.transform_theta(theta[:, region_idx, :], region_idx)
            region = self.stn(feat, theta_i)
            region = F.interpolate(region, size=out_size, mode='bilinear', align_corners=True)
            if prev_locals is not None:
                region = region + prev_locals[region_idx]
            outputs.append(local_conv(region))
        return outputs

    def forward(self, x, is_training=None):
        """Compute HA-CNN outputs.

        Args:
            x: input images; height must be 160 and width 64.
            is_training (bool, optional): selects the output format (see
                Returns). Defaults to ``self.training``.

        Returns:
            list:
                * not training: ``[embedding]``, where embedding is the
                  concatenation of L2-normalized global and local features
                  (or the unnormalized global feature without the local branch);
                * loss ``{'xent'}``: ``[prelogits_global, prelogits_local]``
                  (or ``[prelogits_global]``);
                * loss ``{'xent', 'htri'}``:
                  ``[(prelogits_global, prelogits_local), (x_global, x_local)]``
                  (or ``[prelogits_global, x_global]``).

        Raises:
            KeyError: in training mode, for an unsupported ``self.loss``.
        """
        if is_training is None:
            is_training = self.training
        assert x.size(2) == 160 and x.size(3) == 64, \
            "Input size does not match, expected (160, 64) but got ({}, {})".format(x.size(2), x.size(3))
        x = self.conv(x)

        # Each local_conv is an InceptionB (stride 2), halving its input, so the
        # resize targets below line up for the residual add at the next stage:
        # (24, 28) -> (12, 14) -> (6, 7) -> (3, 4).

        # ============== Block 1 ==============
        # global branch
        x1 = self.inception1(x)
        x1_attn, x1_theta = self.ha1(x1)
        x1_out = x1 * x1_attn
        # local branch
        if self.learn_region:
            x1_local_list = self._local_branch(x, x1_theta, (24, 28), self.local_conv1)

        # ============== Block 2 ==============
        # global branch
        x2 = self.inception2(x1_out)
        x2_attn, x2_theta = self.ha2(x2)
        x2_out = x2 * x2_attn
        # local branch
        if self.learn_region:
            x2_local_list = self._local_branch(
                x1_out, x2_theta, (12, 14), self.local_conv2, x1_local_list)

        # ============== Block 3 ==============
        # global branch
        x3 = self.inception3(x2_out)
        x3_attn, x3_theta = self.ha3(x3)
        x3_out = x3 * x3_attn
        # local branch
        if self.learn_region:
            x3_local_list = self._local_branch(
                x2_out, x3_theta, (6, 7), self.local_conv3, x2_local_list)

        # ============== Feature generation ==============
        # global branch
        x_global = F.avg_pool2d(x3_out, x3_out.size()[2:]).view(x3_out.size(0), x3_out.size(1))
        x_global = self.fc_global(x_global)
        # local branch
        if self.learn_region:
            x_local_list = [
                F.avg_pool2d(feat, feat.size()[2:]).view(feat.size(0), -1)
                for feat in x3_local_list
            ]
            x_local = torch.cat(x_local_list, 1)
            x_local = self.fc_local(x_local)

        if not is_training:
            # l2 normalization before concatenation; F.normalize clamps the
            # norm with an eps, so an all-zero (post-ReLU) vector cannot
            # produce NaN as the plain division did.
            if self.learn_region:
                x_global = F.normalize(x_global, p=2, dim=1)
                x_local = F.normalize(x_local, p=2, dim=1)
                return [torch.cat([x_global, x_local], 1)]
            else:
                return [x_global]

        prelogits_global = self.classifier_global(x_global)
        if self.learn_region:
            prelogits_local = self.classifier_local(x_local)

        if self.loss == {'xent'}:
            if self.learn_region:
                return [prelogits_global, prelogits_local]
            else:
                return [prelogits_global]
        elif self.loss == {'xent', 'htri'}:
            if self.learn_region:
                return [(prelogits_global, prelogits_local), (x_global, x_local)]
            else:
                return [prelogits_global, x_global]
        else:
            raise KeyError("Unsupported loss: {}".format(self.loss))