import torch
import torch.nn as nn
import torchvision.models as models
from torch import Tensor
from torch.nn import functional as F
from typing import Optional
from functools import partial

from ResNet import ResNet50
from sa import SpatialAttention
from edge_fn import Edge_Module
from cmt_module import CMTStem, Patch_Aggregate, CMTBlock


def conv3x3(in_planes, out_planes, stride=1, has_bias=False):
    """3x3 convolution with padding."""
    return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
                     padding=1, bias=has_bias)


def conv3x3_bn_relu(in_planes, out_planes, stride=1):
    return nn.Sequential(
        conv3x3(in_planes, out_planes, stride),
        nn.BatchNorm2d(out_planes),
        nn.ReLU(inplace=True),
    )


class ConvNormAct(nn.Sequential):
    """Conv2d followed by a normalization layer and an activation."""

    def __init__(
        self,
        in_features: int,
        out_features: int,
        kernel_size: int,
        norm: nn.Module = nn.BatchNorm2d,
        act: nn.Module = nn.ReLU,
        **kwargs
    ):
        super().__init__(
            nn.Conv2d(
                in_features,
                out_features,
                kernel_size=kernel_size,
                padding=kernel_size // 2,
            ),
            norm(out_features),
            act(),
        )


Conv1X1BnReLU = partial(ConvNormAct, kernel_size=1)
Conv3X3BnReLU = partial(ConvNormAct, kernel_size=3)


class ResidualAdd(nn.Module):
    def __init__(self, block: nn.Module, shortcut: Optional[nn.Module] = None):
        super().__init__()
        self.block = block
        self.shortcut = shortcut

    def forward(self, x: Tensor) -> Tensor:
        res = x
        x = self.block(x)
        if self.shortcut is not None:
            res = self.shortcut(res)
        x = x + res  # out-of-place add; safe even if block returns a view
        return x


class BottleNeck(nn.Module):
    """1x1 reduce -> 3x3 -> 1x1 expand bottleneck (no internal residual)."""

    def __init__(self, in_features: int, out_features: int, reduction: int = 4):
        super().__init__()
        reduced_features = out_features // reduction
        self.conv = nn.Sequential(
            Conv1X1BnReLU(in_features, reduced_features),
            Conv3X3BnReLU(reduced_features, reduced_features),
            Conv1X1BnReLU(reduced_features, out_features, act=nn.Identity),
        )

    def forward(self, x):
        return self.conv(x)
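
# A minimal sanity-check sketch, not part of the model. It shows how
# BottleNeck preserves spatial size while squeezing channels internally, and
# how ResidualAdd wraps it with an identity skip. The shapes here are
# illustrative assumptions, not values taken from the training pipeline.
def _demo_bottleneck():
    blk = ResidualAdd(BottleNeck(64, 64, reduction=4))
    y = blk(torch.randn(1, 64, 56, 56))
    assert y.shape == (1, 64, 56, 56)
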
class IR_Net(nn.Module):
    def __init__(
        self,
        in_channels=3,
        stem_channel=32,
        cmt_channel=[46, 92, 184, 368],
        patch_channel=[46, 92, 184, 368],
        block_layer=[2, 2, 10, 2],
        R=3.6,
        img_size=224,
        num_class=10,
    ):
        super(IR_Net, self).__init__()

        # Feature-map size for each stage
        size = [img_size // 4, img_size // 8, img_size // 16, img_size // 32]

        # Backbone models, one per modality
        self.resnet = ResNet50('rgb')
        self.resnet_depth = ResNet50('rgb')

        self.upsample2 = nn.Upsample(scale_factor=4, mode='bilinear')
        self.upsample1 = nn.Upsample(scale_factor=2, mode='bilinear')
        self.low_conv1 = nn.Conv2d(320, 64, 3, 1, 1)
        self.bn_low = nn.BatchNorm2d(64)
        self.bn_last = nn.BatchNorm2d(56)
        self.spa_att = SpatialAttention(in_channels=64)
        self.feature_conv = nn.Conv2d(3584, 46, 3, stride=1, padding=1)
        self.bn_f = nn.BatchNorm2d(46)
        self.bn_fd = nn.BatchNorm2d(46)
        self.bn_44 = nn.BatchNorm2d(368)
        self.bn_l0 = nn.BatchNorm2d(1472)
        self.bn_l1 = nn.BatchNorm2d(736)
        self.bn_l2 = nn.BatchNorm2d(92)
        self.initialize_weights()

        # Stem layer
        self.stem = CMTStem(in_channels, stem_channel)

        # Patch aggregation layers
        self.patch1 = Patch_Aggregate(stem_channel, patch_channel[0])
        self.patch2 = Patch_Aggregate(patch_channel[0], patch_channel[1])
        self.patch3 = Patch_Aggregate(patch_channel[1], patch_channel[2])
        self.patch4 = Patch_Aggregate(patch_channel[2], patch_channel[3])

        # CMT block stages
        stage1 = []
        for _ in range(block_layer[0]):
            cmt_layer = CMTBlock(
                img_size=size[0],
                stride=8,
                d_k=cmt_channel[0],
                d_v=cmt_channel[0],
                num_heads=1,
                R=R,
                in_channels=patch_channel[0],
            )
            stage1.append(cmt_layer)
        self.stage1 = nn.Sequential(*stage1)

        stage2 = []
        for _ in range(block_layer[1]):
            cmt_layer = CMTBlock(
                img_size=size[1],
                stride=4,
                d_k=cmt_channel[1] // 2,
                d_v=cmt_channel[1] // 2,
                num_heads=2,
                R=R,
                in_channels=patch_channel[1],
            )
            stage2.append(cmt_layer)
        self.stage2 = nn.Sequential(*stage2)

        stage3 = []
        for _ in range(block_layer[2]):
            cmt_layer = CMTBlock(
                img_size=size[2],
                stride=2,
                d_k=cmt_channel[2] // 4,
                d_v=cmt_channel[2] // 4,
                num_heads=4,
                R=R,
                in_channels=patch_channel[2],
            )
            stage3.append(cmt_layer)
        self.stage3 = nn.Sequential(*stage3)

        stage4 = []
        for _ in range(block_layer[3]):
            cmt_layer = CMTBlock(
                img_size=size[3],
                stride=1,
                d_k=cmt_channel[3] // 8,
                d_v=cmt_channel[3] // 8,
                num_heads=8,
                R=R,
                in_channels=patch_channel[3],
            )
            stage4.append(cmt_layer)
        self.stage4 = nn.Sequential(*stage4)

        # Global average pooling and classifier head (unused in this forward pass)
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.fc = nn.Sequential(
            nn.Linear(cmt_channel[3], 1280),
            nn.ReLU(inplace=True),
        )
        self.classifier = nn.Linear(1280, num_class)

        # Prediction head over the concatenated decoder/attention features
        self.last_conv = nn.Conv2d(3044, 1472, 1, 1, 0)
        self.last_conv1 = nn.Conv2d(1472, 736, 1, 1, 0)
        self.last_conv2 = nn.Conv2d(736, 92, 1, 1, 0)
        self.last_conv3 = nn.Conv2d(92, 1, 1, 1, 0)

        # Edge branch
        self.edge_layer = Edge_Module()
        self.edge_feature = conv3x3_bn_relu(1, 32)
        self.up_edge = nn.Sequential(
            nn.UpsamplingBilinear2d(scale_factor=2),
            conv3x3(32, 1),
        )
        self.up_edge4 = nn.Sequential(
            nn.UpsamplingBilinear2d(scale_factor=4),
            conv3x3(32, 1),
        )

        self.relu = nn.ReLU(True)
        self.up4 = nn.UpsamplingBilinear2d(scale_factor=4)
        self.conv1381_1 = conv3x3(1381, 1)

        # Cross-modal fusion bottlenecks, one per stage
        self.b4 = BottleNeck(368, 368, 4)
        self.b3 = BottleNeck(184, 184, 4)
        self.b2 = BottleNeck(92, 92, 2)
        self.b1 = BottleNeck(46, 46, 2)

        # Decoder 1x1 convs that squeeze each concatenation back to 368 channels
        self.conv_dec4 = nn.Conv2d(736, 368, 1, 1, 0)
        self.conv_dec3 = nn.Conv2d(1472, 368, 1, 1, 0)
        self.conv_dec2 = nn.Conv2d(2024, 368, 1, 1, 0)
        self.conv_dec1 = nn.Conv2d(2484, 368, 1, 1, 0)

        self.sigmoid = nn.Sigmoid()

        # Side-output heads (currently unused in forward)
        self.convv4 = nn.Conv2d(1104, 1, 3, 1, 1)
        self.convv3 = nn.Conv2d(1840, 1, 3, 1, 1)
        self.convv2 = nn.Conv2d(2392, 1, 3, 1, 1)
        self.convv1 = nn.Conv2d(2852, 1, 3, 1, 1)

    def forward(self, x, x_depth):
        # --- RGB branch: ResNet stem and low-level features ---
        x = self.resnet.conv1(x)
        x = self.resnet.bn1(x)
        x = self.resnet.relu(x)          # 64 x 112 x 112
        x1 = self.resnet.layer1(x)       # 256 x 112 x 112

        # Fuse stem and layer1 features, then apply spatial attention on RGB
        low_concatenated = torch.cat((x, x1), 1)     # 320 channels
        low_conv1 = self.low_conv1(low_concatenated)
        low_conv1 = self.bn_low(low_conv1)
        low_conv1 = F.relu(low_conv1)
        spa_out = self.spa_att(low_conv1)
        spa_out = torch.mul(spa_out, low_conv1)
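
        # Shape sketch for the block above (assuming 224x224 inputs and the
        # default channel widths; these numbers are inferred, not asserted):
        #   x:   (B, 64, 112, 112) after conv1/bn1/relu
        #   x1:  (B, 256, 112, 112) after layer1 (no max-pool is applied)
        #   low_concatenated: (B, 320, 112, 112) -> low_conv1: (B, 64, 112, 112)
        #   spa_out: (B, 64, 112, 112), the attention map gated onto low_conv1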

        # --- RGB branch: high-level ResNet features ---
        x2 = self.resnet.layer2(x1)          # 512 x 56 x 56
        x3_1 = self.resnet.layer3_1(x2)      # 1024 x 28 x 28
        x4_1 = self.resnet.layer4_1(x3_1)    # 2048 x 14 x 14

        # Concatenate layers 2-4, all upsampled to 56 x 56: 3584 channels
        feature_out = torch.cat((x2, self.upsample1(x3_1), self.upsample2(x4_1)), 1)
        feature_out_conv = self.feature_conv(feature_out)   # 46 x 56 x 56
        feature_out_conv = self.bn_f(feature_out_conv)
        feature_out_conv = F.relu(feature_out_conv)

        # CMT stages on the RGB stream
        x11 = self.stage1(feature_out_conv)  # stage 1: block_layer[0] CMT blocks
        x22 = self.patch2(x11)
        x22 = self.stage2(x22)               # stage 2
        x33 = self.patch3(x22)
        x33 = self.stage3(x33)               # stage 3 (10 blocks in the original CMT paper)
        x44 = self.patch4(x33)
        x44 = self.stage4(x44)               # stage 4

        # --- Depth branch ---
        x_depth = self.resnet_depth.conv1(x_depth)
        x_depth = self.resnet_depth.bn1(x_depth)
        x_depth = self.resnet_depth.relu(x_depth)
        x1_depth = self.resnet_depth.layer1(x_depth)

        # Edge prediction from the low-level depth features
        edge_map = self.edge_layer(x_depth, x1_depth)   # 1 x 112 x 112
        edge_feature = self.edge_feature(edge_map)      # 32 x 112 x 112
        up_edge = self.up_edge(edge_feature)

        # Spatial attention on the depth stream
        low_concatenated_depth = torch.cat((x_depth, x1_depth), 1)
        low_conv1_depth = self.low_conv1(low_concatenated_depth)
        low_conv1_depth = self.bn_low(low_conv1_depth)
        low_conv1_depth = F.relu(low_conv1_depth)
        spa_out_depth = self.spa_att(low_conv1_depth)
        spa_out_depth = torch.mul(spa_out_depth, low_conv1_depth)

        # Cross-modal spatial-attention fusion: 192 channels at 112 x 112
        spa_final = torch.mul(spa_out, spa_out_depth)
        spa_final1 = torch.cat((spa_final, spa_out, spa_out_depth), 1)

        x2_depth = self.resnet_depth.layer2(x1_depth)
        x3_1_depth = self.resnet_depth.layer3_1(x2_depth)
        x4_1_depth = self.resnet_depth.layer4_1(x3_1_depth)

        feature_out_depth = torch.cat(
            (x2_depth, self.upsample1(x3_1_depth), self.upsample2(x4_1_depth)), 1)
        feature_out_conv_depth = self.feature_conv(feature_out_depth)
        feature_out_conv_depth = self.bn_fd(feature_out_conv_depth)
        feature_out_conv_depth = F.relu(feature_out_conv_depth)

        # CMT stages on the depth stream (weights shared with the RGB stream)
        x11_depth = self.stage1(feature_out_conv_depth)
        x22_depth = self.patch2(x11_depth)
        x22_depth = self.stage2(x22_depth)
        x33_depth = self.patch3(x22_depth)
        x33_depth = self.stage3(x33_depth)
        x44_depth = self.patch4(x33_depth)
        x44_depth = self.stage4(x44_depth)

        # --- Cross-modal fusion (CMF): elementwise product, two bottleneck
        # passes, then a residual connection back to the product ---
        mul_4 = torch.mul(x44, x44_depth)    # 368 x 7 x 7
        a = self.b4(mul_4)
        b = self.b4(a)
        cmf4 = b + mul_4

        mul_3 = torch.mul(x33, x33_depth)    # 184 x 14 x 14
        a = self.b3(mul_3)
        b = self.b3(a)
        cmf3 = b + mul_3

        mul_2 = torch.mul(x22, x22_depth)    # 92 x 28 x 28
        a = self.b2(mul_2)
        b = self.b2(a)
        cmf2 = b + mul_2

        mul_1 = torch.mul(x11, x11_depth)    # 46 x 56 x 56
        a = self.b1(mul_1)
        b = self.b1(a)
        cmf1 = b + mul_1

        # --- Decoder: at each level, concat -> 1x1 conv -> BN -> ReLU ->
        # concat again -> upsample by 2 ---
        dec4 = torch.cat((x44, cmf4), 1)               # 736 x 7 x 7
        dec4_conv = self.conv_dec4(dec4)               # 368
        dec4_conv = self.bn_44(dec4_conv)
        dec4_conv = F.relu(dec4_conv)
        dec4_conv = torch.cat((dec4, dec4_conv), 1)    # 1104 x 7 x 7
        dec4 = self.upsample1(dec4_conv)               # 1104 x 14 x 14

        dec3 = torch.cat((x33, cmf3, dec4), 1)         # 1472 x 14 x 14
        dec3_conv = self.conv_dec3(dec3)               # 368
        dec3_conv = self.bn_44(dec3_conv)
        dec3_conv = F.relu(dec3_conv)
        dec3_conv = torch.cat((dec3, dec3_conv), 1)    # 1840 x 14 x 14
        dec3 = self.upsample1(dec3_conv)               # 1840 x 28 x 28
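
        # Channel bookkeeping for the remaining levels (inferred from the layer
        # definitions above, assuming 224x224 inputs):
        #   dec2: cat(x22[92], cmf2[92], dec3[1840]) = 2024 -> cat with 368 -> 2392
        #   dec1: cat(x11[46], cmf1[46], dec2[2392]) = 2484 -> cat with 368 -> 2852
        #   spa_con: cat(dec1[2852] at 112x112, spa_final1[192]) = 3044,
        #   matching last_conv's input width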
        dec2 = torch.cat((x22, cmf2, dec3), 1)         # 2024 x 28 x 28
        dec2_conv = self.conv_dec2(dec2)               # 368
        dec2_conv = self.bn_44(dec2_conv)
        dec2_conv = F.relu(dec2_conv)
        dec2_conv = torch.cat((dec2, dec2_conv), 1)    # 2392 x 28 x 28
        dec2 = self.upsample1(dec2_conv)               # 2392 x 56 x 56

        dec1 = torch.cat((x11, cmf1, dec2), 1)         # 2484 x 56 x 56
        dec1_conv = self.conv_dec1(dec1)               # 368
        dec1_conv = self.bn_44(dec1_conv)
        dec1_conv = F.relu(dec1_conv)
        dec1_conv = torch.cat((dec1, dec1_conv), 1)    # 2852 x 56 x 56
        dec1 = self.upsample1(dec1_conv)               # 2852 x 112 x 112

        # Fuse decoder output with the cross-modal spatial-attention features
        spa_con = torch.cat((dec1, spa_final1), 1)     # 3044 x 112 x 112
        out = self.last_conv(spa_con)
        out = self.bn_l0(out)
        out = F.relu(out)
        out = self.last_conv1(out)
        out = self.bn_l1(out)
        out = F.relu(out)
        out = self.last_conv2(out)
        out = self.bn_l2(out)
        out = F.relu(out)
        out = self.last_conv3(out)                     # 1-channel saliency logits
        out = self.upsample1(out)                      # 224 x 224

        return out, up_edge, self.sigmoid(out)

    def initialize_weights(self):
        # Copy ImageNet-pretrained ResNet-50 weights into both backbones.
        # Keys containing '_1' or '_2' (the duplicated layer3/layer4 branches)
        # are mapped back to the corresponding original ResNet keys.
        res50 = models.resnet50(pretrained=True)
        pretrained_dict = res50.state_dict()

        all_params = {}
        for k, v in self.resnet.state_dict().items():
            if k in pretrained_dict.keys():
                v = pretrained_dict[k]
                all_params[k] = v
            elif '_1' in k:
                name = k.split('_1')[0] + k.split('_1')[1]
                v = pretrained_dict[name]
                all_params[k] = v
            elif '_2' in k:
                name = k.split('_2')[0] + k.split('_2')[1]
                v = pretrained_dict[name]
                all_params[k] = v
        assert len(all_params.keys()) == len(self.resnet.state_dict().keys())
        self.resnet.load_state_dict(all_params)

        all_params = {}
        for k, v in self.resnet_depth.state_dict().items():
            if k == 'conv1.weight':
                # The depth stem is re-initialized rather than copied
                all_params[k] = torch.nn.init.normal_(v, mean=0, std=1)
            elif k in pretrained_dict.keys():
                v = pretrained_dict[k]
                all_params[k] = v
            elif '_1' in k:
                name = k.split('_1')[0] + k.split('_1')[1]
                v = pretrained_dict[name]
                all_params[k] = v
            elif '_2' in k:
                name = k.split('_2')[0] + k.split('_2')[1]
                v = pretrained_dict[name]
                all_params[k] = v
        assert len(all_params.keys()) == len(self.resnet_depth.state_dict().keys())
        self.resnet_depth.load_state_dict(all_params)
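

if __name__ == "__main__":
    # Minimal smoke-test sketch, not part of the model. It assumes the sibling
    # modules (ResNet, sa, edge_fn, cmt_module) are importable and that both
    # inputs are 3-channel 224x224 images (a depth map would be replicated to
    # three channels). The first run downloads the pretrained ResNet-50 used
    # by initialize_weights().
    net = IR_Net().eval()
    rgb = torch.randn(1, 3, 224, 224)
    depth = torch.randn(1, 3, 224, 224)
    with torch.no_grad():
        sal, edge, sal_prob = net(rgb, depth)
    print(sal.shape, edge.shape, sal_prob.shape)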