import torch
import torch.nn as nn
import torchvision.models as models
from torch import Tensor
from torch.nn import functional as F
from functools import partial
from typing import Optional

from ResNet import ResNet50
from sa import SpatialAttention
from edge_fn import Edge_Module
from cmt_module import CMTStem, Patch_Aggregate, CMTBlock

def conv3x3(in_planes, out_planes, stride=1, has_bias=False):
    """3x3 convolution with padding."""
    return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
                     padding=1, bias=has_bias)


def conv3x3_bn_relu(in_planes, out_planes, stride=1):
    return nn.Sequential(
            conv3x3(in_planes, out_planes, stride),
            nn.BatchNorm2d(out_planes),
            nn.ReLU(inplace=True),
            )

class ConvNormAct(nn.Sequential):
    def __init__(
        self,
        in_features: int,
        out_features: int,
        kernel_size: int,
        norm: nn.Module = nn.BatchNorm2d,
        act: nn.Module = nn.ReLU,
        **kwargs
    ):
        super().__init__(
            nn.Conv2d(
                in_features,
                out_features,
                kernel_size=kernel_size,
                padding=kernel_size // 2,
                **kwargs,  # forward extra Conv2d arguments (e.g. stride, bias)
            ),
            norm(out_features),
            act(),
        )


Conv1X1BnReLU = partial(ConvNormAct, kernel_size=1)
Conv3X3BnReLU = partial(ConvNormAct, kernel_size=3)
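
# The partials above pre-bind kernel_size, so e.g. Conv1X1BnReLU(46, 12)
# builds Conv2d(46, 12, kernel_size=1, padding=0) -> BatchNorm2d(12) -> ReLU().
# Illustrative shape check (not part of the model):
#
#   feats = torch.randn(2, 46, 56, 56)
#   assert Conv1X1BnReLU(46, 12)(feats).shape == (2, 12, 56, 56)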

class ResidualAdd(nn.Module):
    def __init__(self, block: nn.Module, shortcut: Optional[nn.Module] = None):
        super().__init__()
        self.block = block
        self.shortcut = shortcut

    def forward(self, x: Tensor) -> Tensor:
        res = x
        x = self.block(x)
        if self.shortcut is not None:
            res = self.shortcut(res)
        x = x + res  # out-of-place add keeps autograd safe if res is reused
        return x
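
# Illustrative use of ResidualAdd (defined here but not wired into the model
# below): wrap a block and supply a 1x1 shortcut when channel counts differ.
#
#   block = ResidualAdd(
#       Conv3X3BnReLU(46, 92),
#       shortcut=Conv1X1BnReLU(46, 92, act=nn.Identity),
#   )
#   y = block(torch.randn(2, 46, 56, 56))   # -> (2, 92, 56, 56)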


class BottleNeck(nn.Module):
    """1x1 reduce -> 3x3 -> 1x1 expand; the residual add is applied by the caller."""

    def __init__(self, in_features: int, out_features: int, reduction: int = 4):
        super().__init__()
        reduced_features = out_features // reduction
        self.conv = nn.Sequential(
            Conv1X1BnReLU(in_features, reduced_features),
            Conv3X3BnReLU(reduced_features, reduced_features),
            Conv1X1BnReLU(reduced_features, out_features, act=nn.Identity),
        )

    def forward(self, x):
        return self.conv(x)
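
# Illustrative shape check: with in_features == out_features the channel count
# and spatial size are preserved, so the caller can add the input back directly
# (as the CMF fusion in IR_Net.forward does):
#
#   x = torch.randn(1, 368, 7, 7)
#   assert BottleNeck(368, 368, reduction=4)(x).shape == x.shape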

class IR_Net(nn.Module):
    def __init__(self,
        in_channels = 3,
        stem_channel = 32,
        cmt_channel = [46, 92, 184, 368],
        patch_channel = [46, 92, 184, 368],
        block_layer = [2, 2, 10, 2],
        R = 3.6,
        img_size = 224,
        num_class = 10
    ):
        super(IR_Net, self).__init__()

        # Feature-map size for each stage
        size = [img_size // 4, img_size // 8, img_size // 16, img_size // 32]

        # Backbone models (same architecture for the RGB and depth branches)
        self.resnet = ResNet50('rgb')
        self.resnet_depth = ResNet50('rgb')

        self.upsample2 = nn.Upsample(scale_factor=4, mode='bilinear', align_corners=False)
        self.upsample1 = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False)
        self.low_conv1 = nn.Conv2d(320, 64, 3, 1, 1)
        self.bn_low = nn.BatchNorm2d(64)
        self.bn_last = nn.BatchNorm2d(56)
        self.spa_att = SpatialAttention(in_channels=64)
        self.feature_conv = nn.Conv2d(3584, 46, 3, stride=1, padding=1)
        self.bn_f = nn.BatchNorm2d(46)
        self.bn_fd = nn.BatchNorm2d(46)
        self.bn_44 = nn.BatchNorm2d(368)
        self.bn_l0 = nn.BatchNorm2d(1472)
        self.bn_l1 = nn.BatchNorm2d(736)
        self.bn_l2 = nn.BatchNorm2d(92)

        self.initialize_weights()
        # Stem layer
        self.stem = CMTStem(in_channels, stem_channel)

        # Patch Aggregation Layer
        self.patch1 = Patch_Aggregate(stem_channel, patch_channel[0])
        self.patch2 = Patch_Aggregate(patch_channel[0], patch_channel[1])
        self.patch3 = Patch_Aggregate(patch_channel[1], patch_channel[2])
        self.patch4 = Patch_Aggregate(patch_channel[2], patch_channel[3])

        # CMT Block Layers (size[0] == 56 at the default img_size of 224)
        stage1 = []
        for _ in range(block_layer[0]):
            cmt_layer = CMTBlock(
                img_size = size[0],
                stride = 8,
                d_k = cmt_channel[0],
                d_v = cmt_channel[0],
                num_heads = 1,
                R = R,
                in_channels = patch_channel[0]
            )
            stage1.append(cmt_layer)
        self.stage1 = nn.Sequential(*stage1)

        stage2 = []
        for _ in range(block_layer[1]):
            cmt_layer = CMTBlock(
                img_size = size[1],
                stride = 4,
                d_k = cmt_channel[1] // 2,
                d_v = cmt_channel[1] // 2,
                num_heads = 2,
                R = R,
                in_channels = patch_channel[1]
            )
            stage2.append(cmt_layer)
        self.stage2 = nn.Sequential(*stage2)

        stage3 = []
        for _ in range(block_layer[2]):
            cmt_layer = CMTBlock(
                img_size = size[2],
                stride = 2,
                d_k = cmt_channel[2] // 4,
                d_v = cmt_channel[2] // 4,
                num_heads = 4,
                R = R,
                in_channels = patch_channel[2]
            )
            stage3.append(cmt_layer)
        self.stage3 = nn.Sequential(*stage3)

        stage4 = []
        for _ in range(block_layer[3]):
            cmt_layer = CMTBlock(
                img_size = size[3],
                stride = 1,
                d_k = cmt_channel[3] // 8,
                d_v = cmt_channel[3] // 8,
                num_heads = 8,
                R = R,
                in_channels = patch_channel[3]
            )
            stage4.append(cmt_layer)
        self.stage4 = nn.Sequential(*stage4)

        # Global Average Pooling and FC head (kept from the CMT backbone;
        # not used by the saliency forward pass)
        self.avg_pool = nn.AdaptiveAvgPool2d(1)

        self.fc = nn.Sequential(
            nn.Linear(cmt_channel[3], 1280),
            nn.ReLU(inplace = True),
        )
        
        self.last_conv = nn.Conv2d(3044, 1472, 1, 1, 0)
        self.last_conv1 = nn.Conv2d(1472, 736, 1, 1, 0)
        self.last_conv2 = nn.Conv2d(736, 92, 1, 1, 0)
        self.last_conv3 = nn.Conv2d(92, 1, 1, 1, 0)

        # Final Classifier (also unused by the saliency forward pass)
        self.classifier = nn.Linear(1280, num_class)

        self.edge_layer = Edge_Module()
        self.edge_feature = conv3x3_bn_relu(1, 32)
        self.up_edge = nn.Sequential(
            nn.UpsamplingBilinear2d(scale_factor = 2),
            conv3x3(32, 1)
        )
        self.up_edge4 = nn.Sequential(
            nn.UpsamplingBilinear2d(scale_factor = 4),
            conv3x3(32, 1)
        )
        self.relu = nn.ReLU(True)
        self.up4 = nn.UpsamplingBilinear2d(scale_factor = 4)
        self.conv1381_1 = conv3x3(1381, 1)

        self.b4 = BottleNeck(368, 368, 4)
        self.b3 = BottleNeck(184, 184, 4)
        self.b2 = BottleNeck(92, 92, 2)
        self.b1 = BottleNeck(46, 46, 2)

        self.conv_dec4 = nn.Conv2d(736, 368, 1, 1, 0)
        self.conv_dec3 = nn.Conv2d(1472, 368, 1, 1, 0)
        self.conv_dec2 = nn.Conv2d(2024, 368, 1, 1, 0)
        self.conv_dec1 = nn.Conv2d(2484, 368, 1, 1, 0)
        self.sigmoid = nn.Sigmoid()

        # Side-output heads (only used by the commented-out deep-supervision return)
        self.convv4 = nn.Conv2d(1104, 1, 3, 1, 1)
        self.convv3 = nn.Conv2d(1840, 1, 3, 1, 1)
        self.convv2 = nn.Conv2d(2392, 1, 3, 1, 1)
        self.convv1 = nn.Conv2d(2852, 1, 3, 1, 1)
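
    # Decoder channel bookkeeping for img_size = 224, derived from the
    # concatenations in forward():
    #   dec4 = cat(x44, cmf4)           -> 368 + 368        =  736 @ 7x7
    #   re-cat after conv_dec4          -> 736 + 368        = 1104 @ 7x7   -> up to 14x14
    #   dec3 = cat(x33, cmf3, dec4)     -> 184 + 184 + 1104 = 1472 @ 14x14
    #   re-cat after conv_dec3          -> 1472 + 368       = 1840 @ 14x14 -> up to 28x28
    #   dec2 = cat(x22, cmf2, dec3)     -> 92 + 92 + 1840   = 2024 @ 28x28
    #   re-cat after conv_dec2          -> 2024 + 368       = 2392 @ 28x28 -> up to 56x56
    #   dec1 = cat(x11, cmf1, dec2)     -> 46 + 46 + 2392   = 2484 @ 56x56
    #   re-cat after conv_dec1          -> 2484 + 368       = 2852 @ 56x56 -> up to 112x112
    #   spa_con = cat(dec1, spa_final1) -> 2852 + 3 * 64    = 3044 @ 112x112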

    def forward(self, x, x_depth):
        # RGB BRANCH
        x = self.resnet.conv1(x)
        x = self.resnet.bn1(x)
        x = self.resnet.relu(x)            # 64, 112, 112

        x1 = self.resnet.layer1(x)         # 256, 112, 112

        low_concatenated = torch.cat((x, x1), 1)     # 320, 112, 112
        low_conv1 = self.low_conv1(low_concatenated)
        low_conv1 = self.bn_low(low_conv1)
        low_conv1 = F.relu(low_conv1)

        # Spatial attention over the low-level RGB features
        spa_out = self.spa_att(low_conv1)
        spa_out = torch.mul(spa_out, low_conv1)

        x2 = self.resnet.layer2(x1)          # 512, 56, 56
        x2_1 = x2

        x3_1 = self.resnet.layer3_1(x2_1)    # 1024, 28, 28
        x4_1 = self.resnet.layer4_1(x3_1)    # 2048, 14, 14

        # Concatenate the three deepest RGB stages at 56 x 56 resolution
        feature_out = torch.cat((x2, self.upsample1(x3_1), self.upsample2(x4_1)), 1)   # 3584, 56, 56

        feature_out_conv = self.feature_conv(feature_out)    # 46, 56, 56
        feature_out_conv = self.bn_f(feature_out_conv)
        feature_out_conv = F.relu(feature_out_conv)

        x11 = self.stage1(feature_out_conv)   # stage 1 CMT blocks: 46, 56, 56

        x22 = self.patch2(x11)
        x22 = self.stage2(x22)                # stage 2 CMT blocks: 92, 28, 28

        x33 = self.patch3(x22)
        x33 = self.stage3(x33)                # stage 3 CMT blocks: 184, 14, 14

        x44 = self.patch4(x33)
        x44 = self.stage4(x44)                # stage 4 CMT blocks: 368, 7, 7

        # DEPTH BRANCH
        x_depth = self.resnet_depth.conv1(x_depth)
        x_depth = self.resnet_depth.bn1(x_depth)
        x_depth = self.resnet_depth.relu(x_depth)

        x1_depth = self.resnet_depth.layer1(x_depth)

        # Edge map predicted from the low-level depth features
        edge_map = self.edge_layer(x_depth, x1_depth)    # 1, 112, 112
        edge_feature = self.edge_feature(edge_map)       # 32, 112, 112
        up_edge = self.up_edge(edge_feature)

        low_concatenated_depth = torch.cat((x_depth, x1_depth), 1)
        low_conv1_depth = self.low_conv1(low_concatenated_depth)
        low_conv1_depth = self.bn_low(low_conv1_depth)
        low_conv1_depth = F.relu(low_conv1_depth)

        # Spatial attention over the low-level depth features
        spa_out_depth = self.spa_att(low_conv1_depth)
        spa_out_depth = torch.mul(spa_out_depth, low_conv1_depth)

        # Fuse the RGB and depth attention features
        spa_final = torch.mul(spa_out, spa_out_depth)
        spa_final1 = torch.cat((spa_final, spa_out, spa_out_depth), 1)   # 192, 112, 112

        x2_depth = self.resnet_depth.layer2(x1_depth)
        x3_1_depth = self.resnet_depth.layer3_1(x2_depth)
        x4_1_depth = self.resnet_depth.layer4_1(x3_1_depth)

        feature_out_depth = torch.cat((x2_depth, self.upsample1(x3_1_depth), self.upsample2(x4_1_depth)), 1)

        feature_out_conv_depth = self.feature_conv(feature_out_depth)
        feature_out_conv_depth = self.bn_fd(feature_out_conv_depth)
        feature_out_conv_depth = F.relu(feature_out_conv_depth)

        # The depth branch runs through the same CMT stages as the RGB branch
        x11_depth = self.stage1(feature_out_conv_depth)

        x22_depth = self.patch2(x11_depth)
        x22_depth = self.stage2(x22_depth)

        x33_depth = self.patch3(x22_depth)
        x33_depth = self.stage3(x33_depth)

        x44_depth = self.patch4(x33_depth)
        x44_depth = self.stage4(x44_depth)

        # Cross-modal fusion at each stage: elementwise RGB x depth product,
        # two bottleneck blocks, then a residual add of the product.
        mul_4 = torch.mul(x44, x44_depth)    # 368, 7, 7
        a = self.b4(mul_4)
        b = self.b4(a)
        cmf4 = b + mul_4

        mul_3 = torch.mul(x33, x33_depth)    # 184, 14, 14
        a = self.b3(mul_3)
        b = self.b3(a)
        cmf3 = b + mul_3

        mul_2 = torch.mul(x22, x22_depth)    # 92, 28, 28
        a = self.b2(mul_2)
        b = self.b2(a)
        cmf2 = b + mul_2

        mul_1 = torch.mul(x11, x11_depth)    # 46, 56, 56
        a = self.b1(mul_1)
        b = self.b1(a)
        cmf1 = b + mul_1

      
        # DECODER: at each stage, concatenate the encoder and CMF features,
        # reduce to 368 channels, re-concatenate, then upsample by 2.
        dec4 = torch.cat((x44, cmf4), 1)              # 736, 7, 7
        dec4_conv = self.conv_dec4(dec4)              # 368, 7, 7
        dec4_conv = self.bn_44(dec4_conv)
        dec4_conv = F.relu(dec4_conv)
        dec4_conv = torch.cat((dec4, dec4_conv), 1)   # 1104, 7, 7
        dec4 = self.upsample1(dec4_conv)              # 1104, 14, 14
        #dec44 = self.convv4(dec4)                    # optional side output

        dec3 = torch.cat((x33, cmf3, dec4), 1)        # 1472, 14, 14
        dec3_conv = self.conv_dec3(dec3)              # 368, 14, 14
        dec3_conv = self.bn_44(dec3_conv)
        dec3_conv = F.relu(dec3_conv)
        dec3_conv = torch.cat((dec3, dec3_conv), 1)   # 1840, 14, 14
        dec3 = self.upsample1(dec3_conv)              # 1840, 28, 28
        #dec33 = self.convv3(dec3)                    # optional side output

        dec2 = torch.cat((x22, cmf2, dec3), 1)        # 2024, 28, 28
        dec2_conv = self.conv_dec2(dec2)              # 368, 28, 28
        dec2_conv = self.bn_44(dec2_conv)
        dec2_conv = F.relu(dec2_conv)
        dec2_conv = torch.cat((dec2, dec2_conv), 1)   # 2392, 28, 28
        dec2 = self.upsample1(dec2_conv)              # 2392, 56, 56
        #dec22 = self.convv2(dec2)                    # optional side output

        dec1 = torch.cat((x11, cmf1, dec2), 1)        # 2484, 56, 56
        dec1_conv = self.conv_dec1(dec1)              # 368, 56, 56
        dec1_conv = self.bn_44(dec1_conv)
        dec1_conv = F.relu(dec1_conv)
        dec1_conv = torch.cat((dec1, dec1_conv), 1)   # 2852, 56, 56
        dec1 = self.upsample1(dec1_conv)              # 2852, 112, 112
        #dec11 = self.convv1(dec1)                    # optional side output


        # Fuse the decoder output with the low-level spatial-attention features
        spa_con = torch.cat((dec1, spa_final1), 1)    # 3044, 112, 112
        out = self.last_conv(spa_con)
        out = self.bn_l0(out)
        out = F.relu(out)

        out = self.last_conv1(out)
        out = self.bn_l1(out)
        out = F.relu(out)

        out = self.last_conv2(out)
        out = self.bn_l2(out)
        out = F.relu(out)

        out = self.last_conv3(out)                    # 1, 112, 112
        out = self.upsample1(out)                     # 1, 224, 224

        return out, up_edge, self.sigmoid(out)
        # Deep-supervision variant (requires the dec11..dec44 side outputs above):
        #return out, up_edge, dec11, dec22, dec33, dec44, self.sigmoid(out), self.sigmoid(dec11), self.sigmoid(dec22), self.sigmoid(dec33), self.sigmoid(dec44)


    def initialize_weights(self):
        """Copy ImageNet-pretrained ResNet-50 weights into both backbones.

        Keys containing '_1' or '_2' (the duplicated layer3/layer4 branches)
        are mapped back to the corresponding pretrained layer names; the
        depth backbone's conv1 is randomly re-initialized instead.
        """
        res50 = models.resnet50(pretrained=True)
        pretrained_dict = res50.state_dict()
        all_params = {}
        for k, v in self.resnet.state_dict().items():
            if k in pretrained_dict.keys():
                v = pretrained_dict[k]
                all_params[k] = v
            elif '_1' in k:
                name = k.split('_1')[0] + k.split('_1')[1]
                v = pretrained_dict[name]
                all_params[k] = v
            elif '_2' in k:
                name = k.split('_2')[0] + k.split('_2')[1]
                v = pretrained_dict[name]
                all_params[k] = v
        assert len(all_params.keys()) == len(self.resnet.state_dict().keys())
        self.resnet.load_state_dict(all_params)

        all_params = {}
        for k, v in self.resnet_depth.state_dict().items():
            if k=='conv1.weight':
                all_params[k]=torch.nn.init.normal_(v, mean=0, std=1)
            elif k in pretrained_dict.keys():
                v = pretrained_dict[k]
                all_params[k] = v
            elif '_1' in k:
                name = k.split('_1')[0] + k.split('_1')[1]
                v = pretrained_dict[name]
                all_params[k] = v
            elif '_2' in k:
                name = k.split('_2')[0] + k.split('_2')[1]
                v = pretrained_dict[name]
                all_params[k] = v
        assert len(all_params.keys()) == len(self.resnet_depth.state_dict().keys())
        self.resnet_depth.load_state_dict(all_params)
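

if __name__ == "__main__":
    # Minimal smoke test -- a sketch that assumes the companion modules
    # (ResNet.py, sa.py, edge_fn.py, cmt_module.py) are importable and that
    # the pretrained ResNet-50 weights can be downloaded. The depth input is
    # given 3 channels because the depth backbone is built as ResNet50('rgb').
    model = IR_Net(img_size=224)
    model.eval()
    rgb = torch.randn(1, 3, 224, 224)
    depth = torch.randn(1, 3, 224, 224)
    with torch.no_grad():
        sal, edge, sal_sigmoid = model(rgb, depth)
    # Expected: (1, 1, 224, 224) for each output
    print(sal.shape, edge.shape, sal_sigmoid.shape)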