# ------------------------------------------------------------------------------
# Written by Jiacong Xu (jiacong.xu@tamu.edu)
# ------------------------------------------------------------------------------
import math
import os
import sys
from time import time

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from einops.layers.torch import Rearrange
from torchsummary import summary
from transformers import Swinv2Config

sys.path.insert(0, './models/')
from pidnet_utils import BasicBlock, Bottleneck, segmenthead, DAPPM, PAPPM, PagFM, Bag, Light_Bag
from Swin2 import Swinv2Model, Swinv2Stage, Swinv2PatchMerging, Swinv2Embeddings, LambdaLayer
from pidnet import PIDNet

BatchNorm2d = nn.BatchNorm2d

bn_mom = 0.1
algc = False
modality_paths = True  # randomly drop whole modality branches during training
short_path = True      # enable the multi-scale RGB+IR shortcut decoder
swin_blocks = True     # use SwinV2 stages for layer3/layer4 of the fusion branch


class ChannelAttentionModule(nn.Module):
    """Squeeze-and-excitation-style channel attention followed by a 1x1
    projection to `out_channels` (used to fuse concatenated RGB/IR features)."""

    def __init__(self, in_channels, out_channels, reduction=16):
        super(ChannelAttentionModule, self).__init__()
        reduced_channels = max(1, in_channels // reduction)
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.max_pool = nn.AdaptiveMaxPool2d(1)
        self.channel_attention = nn.Sequential(
            nn.Linear(in_channels, reduced_channels),
            nn.ReLU(inplace=True),
            nn.Linear(reduced_channels, in_channels),
            nn.Sigmoid()
        )
        # Convolution to map to output channels
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, bias=False)

    def forward(self, x):
        b, c, _, _ = x.size()
        avg_pooled = self.avg_pool(x).view(b, c)
        max_pooled = self.max_pool(x).view(b, c)
        avg_weights = self.channel_attention(avg_pooled)
        max_weights = self.channel_attention(max_pooled)
        channel_weights = (avg_weights + max_weights).view(b, c, 1, 1)
        x = x * channel_weights
        x = self.conv(x)
        return x


def drop_path(x, drop_prob=0.2, training=True):
    """Stochastic depth: zero out whole samples with probability `drop_prob`
    and rescale survivors by 1/keep_prob so the expected activation is unchanged."""
    if drop_prob == 0 or not training:
        return x
    batch_size = x.shape[0]
    keep_prob = 1 - drop_prob
    # keep_prob + U[0, 1) floors to 1 with probability keep_prob, else to 0
    random_tensor = keep_prob + torch.rand(batch_size, 1, 1, 1, device=x.device)
    binary_mask = random_tensor.floor()
    x = x / keep_prob * binary_mask
    return x
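# Quick sanity sketch for drop_path (illustrative, not part of the model):
# with drop_prob=0.5 each sample is either zeroed or scaled by 1/keep_prob = 2,
# so E[drop_path(x)] == x. Verifiable in a REPL:
#
#   >>> x = torch.ones(1000, 1, 1, 1)
#   >>> y = drop_path(x, drop_prob=0.5, training=True)
#   >>> sorted(y.unique().tolist())             # only 0.0 and 2.0 survive
#   [0.0, 2.0]
#   >>> abs(y.mean().item() - 1.0) < 0.1        # expectation preserved (approx.)
#   True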
class RoboFireFuseNet(nn.Module):
    def __init__(self, m=2, n=3, num_classes=19, planes=64, ppm_planes=96, head_planes=128,
                 augment=True, channels=3, input_resolution=(480, 640), window_size=(10, 5),
                 tf_depths=(2, 6)):
        super(RoboFireFuseNet, self).__init__()
        self.augment = augment
        self.channels = channels
        self.window_size = window_size
        self.planes = planes
        input_resolution = np.array(input_resolution)
        self.norm = 'group'
        # Stochastic-depth rates, active only in the 3-class setting.
        self.drop_paths_modalities = [0.2, 0.2, 0] if num_classes == 3 else [0, 0, 0]  # rgb, ir, fusion
        self.drop_paths_shortcuts = [0.15, 0.15, 0.15, 0.15] if num_classes == 3 else [0, 0, 0, 0]  # tf_conv0..tf_conv3

        # I Branch
        self.conv1_rgb_0 = nn.Sequential(
            nn.Conv2d(3, planes, kernel_size=3, stride=1, padding=1),
            nn.GroupNorm(8, planes) if self.norm == 'group' else BatchNorm2d(planes),
            nn.ReLU(inplace=True),
        )
        self.conv1_rgb_1 = nn.Sequential(
            nn.Conv2d(planes, planes, kernel_size=3, stride=2, padding=1),
            nn.GroupNorm(8, planes) if self.norm == 'group' else BatchNorm2d(planes),
            nn.ReLU(inplace=True),
        )
        self.conv1_rgb_2 = nn.Sequential(
            nn.Conv2d(planes, planes, kernel_size=3, stride=2, padding=1),
            BatchNorm2d(planes, momentum=bn_mom),
            nn.ReLU(inplace=True),
        )
        self.conv1_ir_0 = nn.Sequential(
            nn.Conv2d(1, planes, kernel_size=3, stride=1, padding=1),
            nn.GroupNorm(8, planes) if self.norm == 'group' else BatchNorm2d(planes),
            nn.ReLU(inplace=True),
        )
        self.conv1_ir_1 = nn.Sequential(
            nn.Conv2d(planes, planes, kernel_size=3, stride=2, padding=1),
            nn.GroupNorm(8, planes) if self.norm == 'group' else BatchNorm2d(planes),
            nn.ReLU(inplace=True),
        )
        self.conv1_ir_2 = nn.Sequential(
            nn.Conv2d(planes, planes, kernel_size=3, stride=2, padding=1),
            BatchNorm2d(planes, momentum=bn_mom),
            nn.ReLU(inplace=True),
        )

        # 1x1 projections for the multi-scale RGB+IR shortcut decoder
        self.tf_conv0 = nn.Sequential(
            nn.Conv2d(2 * planes, 4 * planes, kernel_size=1),
            nn.GroupNorm(16, 4 * planes) if self.norm == 'group' else BatchNorm2d(4 * planes),
            nn.ReLU(inplace=True),
        )
        self.tf_conv1 = nn.Sequential(
            nn.Conv2d(2 * planes, 4 * planes, kernel_size=1),
            nn.GroupNorm(16, 4 * planes) if self.norm == 'group' else BatchNorm2d(4 * planes),
            nn.ReLU(inplace=True),
        )
        self.tf_conv2 = nn.Sequential(
            nn.Conv2d(2 * planes, 4 * planes, kernel_size=1),
            nn.GroupNorm(16, 4 * planes) if self.norm == 'group' else BatchNorm2d(4 * planes),
            nn.ReLU(inplace=True),
        )
        self.tf_conv3 = nn.Sequential(
            nn.Conv2d(4 * planes, 4 * planes, kernel_size=1),
            nn.GroupNorm(16, 4 * planes) if self.norm == 'group' else BatchNorm2d(4 * planes),
            nn.ReLU(inplace=True),
        )

        self.relu = nn.ReLU(inplace=True)
        self.layer1_rgb = self._make_layer(BasicBlock, planes, planes, m)
        self.layer2_rgb = self._make_layer(BasicBlock, planes, planes * 2, m, stride=2)
        self.layer1_ir = self._make_layer(BasicBlock, planes, planes, m)
        self.layer2_ir = self._make_layer(BasicBlock, planes, planes * 2, m, stride=2)
        self.layer3_rgb = self._make_layer(BasicBlock, planes * 2, planes * 4, n, stride=2)
        self.layer4_rgb = self._make_layer(BasicBlock, planes * 4, planes * 8, n, stride=2)
        self.layer5_rgb = self._make_layer(Bottleneck, planes * 8, planes * 8, 2, stride=2)
        self.layer3_ir = self._make_layer(BasicBlock, planes * 2, planes * 4, n, stride=2)
        self.layer4_ir = self._make_layer(BasicBlock, planes * 4, planes * 8, n, stride=2)
        self.layer5_ir = self._make_layer(Bottleneck, planes * 8, planes * 8, 2, stride=2)

        self.weight_channels_pre = ChannelAttentionModule(self.planes * 2 * 2, self.planes * 2)
        self.weight_channels_post = ChannelAttentionModule(3 * self.planes * 16, self.planes * 16, 32)

        tf_configuration = Swinv2Config(window_size=window_size, image_size=input_resolution,
                                        num_channels=self.planes * 2)
        self.stage3 = Swinv2Stage(
            config=tf_configuration,
            dim=int(tf_configuration.embed_dim * 2 ** 1),
            input_resolution=(tf_configuration.image_size[0] // (2 ** 1),
                              tf_configuration.image_size[1] // (2 ** 1)),
            depth=tf_depths[0],
            num_heads=tf_configuration.num_heads[1],
            drop_path=0.05,
            downsample=Swinv2PatchMerging,
            pretrained_window_size=tf_configuration.pretrained_window_sizes[1],
        )
        self.stage4 = Swinv2Stage(
            config=tf_configuration,
            dim=int(tf_configuration.embed_dim * 2 ** 2),
            input_resolution=(tf_configuration.image_size[0] // (2 ** 2),
                              tf_configuration.image_size[1] // (2 ** 2)),
            depth=tf_depths[1],
            num_heads=tf_configuration.num_heads[2],
            drop_path=0.05,
            downsample=Swinv2PatchMerging,
            pretrained_window_size=tf_configuration.pretrained_window_sizes[2],
        )
        self.layer3 = torch.nn.Sequential(
            Swinv2Embeddings(Swinv2Config(window_size=window_size[0], num_channels=2 * planes,
                                          patch_size=1, embed_dim=96 * 2)),
            LambdaLayer(lambda xinp: self.stage3(xinp[0], xinp[1])),
        )
        self.layer3_unpatch = torch.nn.Sequential(
            LambdaLayer(lambda xinp: Rearrange('b (h w) c -> b c h w', h=xinp[2][-2]).forward(xinp[0])),
            torch.nn.Conv2d(96 * 4, 4 * planes, kernel_size=1),
        )
        self.layer4 = torch.nn.Sequential(
            LambdaLayer(lambda xinp: self.stage4(xinp[0], xinp[2][-2:])),
            LambdaLayer(lambda xinp: Rearrange('b (h w) c -> b c h w', h=xinp[2][-2]).forward(xinp[0])),
            torch.nn.Conv2d(96 * 8, 8 * planes, kernel_size=1),
        )
        if not swin_blocks:
            self.layer3 = self._make_layer(BasicBlock, planes * 2, planes * 4, n, stride=2)
            self.layer3_unpatch = nn.Identity()
            self.layer4 = self._make_layer(BasicBlock, planes * 4, planes * 8, n, stride=2)
        self.layer5 = self._make_layer(Bottleneck, planes * 8, planes * 8, 2, stride=2)

        # P Branch
        self.compression3 = nn.Sequential(
            nn.Conv2d(planes * 4, planes * 2, kernel_size=1, bias=False),
            BatchNorm2d(planes * 2, momentum=bn_mom),
        )
        self.compression4 = nn.Sequential(
            nn.Conv2d(planes * 8, planes * 2, kernel_size=1, bias=False),
            BatchNorm2d(planes * 2, momentum=bn_mom),
        )
        self.pag3 = PagFM(planes * 2, planes)
        self.pag4 = PagFM(planes * 2, planes)
        self.layer3_ = self._make_layer(BasicBlock, planes * 2, planes * 2, m)
        self.layer4_ = self._make_layer(BasicBlock, planes * 2, planes * 2, m)
        self.layer5_ = self._make_layer(Bottleneck, planes * 2, planes * 2, 1)

        # D Branch
        if m == 2:
            self.layer3_d = self._make_single_layer(BasicBlock, planes * 2, planes)
            self.layer4_d = self._make_layer(Bottleneck, planes, planes, 1)
            self.diff3 = nn.Sequential(
                nn.Conv2d(planes * 4, planes, kernel_size=3, padding=1, bias=False),
                BatchNorm2d(planes, momentum=bn_mom),
            )
            self.diff4 = nn.Sequential(
                nn.Conv2d(planes * 8, planes * 2, kernel_size=3, padding=1, bias=False),
                BatchNorm2d(planes * 2, momentum=bn_mom),
            )
            self.spp = PAPPM(planes * 16, ppm_planes, planes * 4)
            self.dfm = Light_Bag(planes * 4, planes * 4)
        else:
            self.layer3_d = self._make_single_layer(BasicBlock, planes * 2, planes * 2)
            self.layer4_d = self._make_single_layer(BasicBlock, planes * 2, planes * 2)
            self.diff3 = nn.Sequential(
                nn.Conv2d(planes * 4, planes * 2, kernel_size=3, padding=1, bias=False),
                BatchNorm2d(planes * 2, momentum=bn_mom),
            )
            self.diff4 = nn.Sequential(
                nn.Conv2d(planes * 8, planes * 2, kernel_size=3, padding=1, bias=False),
                BatchNorm2d(planes * 2, momentum=bn_mom),
            )
            self.spp = DAPPM(planes * 16, ppm_planes, planes * 4)
            self.dfm = Bag(planes * 4, planes * 4)
        self.layer5_d = self._make_layer(Bottleneck, planes * 2, planes * 2, 1)

        # Prediction Head
        if self.augment:
            self.seghead_p = segmenthead(planes * 2, head_planes, num_classes)
            self.seghead_d = segmenthead(planes * 2, planes, 1)
        self.final_layer = segmenthead(planes * 4, head_planes, num_classes)

        # `mod` avoids shadowing the block-count argument `m` above
        for mod in self.modules():
            if isinstance(mod, nn.Conv2d):
                nn.init.kaiming_normal_(mod.weight, mode='fan_out', nonlinearity='relu')
            elif isinstance(mod, BatchNorm2d):
                nn.init.constant_(mod.weight, 1)
                nn.init.constant_(mod.bias, 0)
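    # Expected input (see forward() below): one (B, 4, H, W) tensor with RGB in
    # channels 0-2 and IR in channel 3. With augment=True the model returns
    # [x_extra_p, x_, x_extra_d]: P-branch aux logits, fused logits, D-branch edge map.
    # Construction sketch mirroring the __main__ block at the bottom of this file:
    #
    #   >>> net = RoboFireFuseNet(m=2, n=3, num_classes=3, planes=32, ppm_planes=96,
    #   ...                       head_planes=128, augment=True, channels=4,
    #   ...                       input_resolution=(256, 256), window_size=(8, 8),
    #   ...                       tf_depths=(2, 6))
    #   >>> p, out, d = net(torch.randn(1, 4, 256, 256))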
    def _make_layer(self, block, inplanes, planes, blocks, stride=1):
        downsample = None
        if stride != 1 or inplanes != planes * block.expansion:
            downsample = nn.Sequential(
                nn.Conv2d(inplanes, planes * block.expansion, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(planes * block.expansion, momentum=bn_mom),
            )
        layers = []
        layers.append(block(inplanes, planes, stride, downsample))
        inplanes = planes * block.expansion
        for i in range(1, blocks):
            if i == (blocks - 1):
                layers.append(block(inplanes, planes, stride=1, no_relu=True))
            else:
                layers.append(block(inplanes, planes, stride=1, no_relu=False))
        return nn.Sequential(*layers)

    def _make_single_layer(self, block, inplanes, planes, stride=1):
        downsample = None
        if stride != 1 or inplanes != planes * block.expansion:
            downsample = nn.Sequential(
                nn.Conv2d(inplanes, planes * block.expansion, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(planes * block.expansion, momentum=bn_mom),
            )
        layer = block(inplanes, planes, stride, downsample, no_relu=True)
        return layer

    def imgnet_pretrain(self, path):
        pretrained_state = torch.load(path, map_location='cpu')
        if 'state_dict' in pretrained_state.keys():
            pretrained_state = pretrained_state['state_dict']
        elif 'model_state_dict' in pretrained_state.keys():
            pretrained_state = pretrained_state['model_state_dict']
        model_dict = self.state_dict()
        pretrained_state = {k: v for k, v in pretrained_state.items()
                            if (k in model_dict and v.shape == model_dict[k].shape)}
        model_dict.update(pretrained_state)
        msg = 'PIDnet: Loaded {}% parameters!'.format(len(pretrained_state) * 100 / len(model_dict))
        self.load_state_dict(model_dict, strict=False)
        print(msg)
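    # Warm-start sketch (the checkpoint path is the caller's choice, not fixed
    # by this file): imgnet_pretrain() silently skips any tensor whose name or
    # shape does not match, so it is safe to point it at a PIDNet-style
    # ImageNet checkpoint such as the one used in custom_pretrained() below.
    #
    #   >>> net = RoboFireFuseNet(num_classes=3, planes=32)
    #   >>> net.imgnet_pretrain('./weights/PIDNet_S_ImageNet.pth.tar')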
    def forward(self, x):
        # layer0 rgb
        x_rgb_0 = self.conv1_rgb_0(x[:, :3])  # /1
        x_rgb_1 = self.conv1_rgb_1(x_rgb_0)   # /2
        x_rgb_2 = self.conv1_rgb_2(x_rgb_1)   # /4
        # layer1 rgb
        x_rgb_2 = self.layer1_rgb(x_rgb_2)    # /4
        # layer2 rgb
        x_rgb_3 = self.relu(self.layer2_rgb(self.relu(x_rgb_2)))  # /8
        x_rgb = x_rgb_3

        # layer0 ir
        x_irinp = x[:, -1:]
        x_ir_0 = self.conv1_ir_0(x_irinp)
        x_ir_1 = self.conv1_ir_1(x_ir_0)
        x_ir_2 = self.conv1_ir_2(x_ir_1)
        # layer1 ir
        x_ir_2 = self.layer1_ir(x_ir_2)
        # layer2 ir
        x_ir_3 = self.relu(self.layer2_ir(self.relu(x_ir_2)))
        x_ir = x_ir_3

        # Fuse RGB and IR features with channel attention
        x = self.weight_channels_pre(torch.cat((x_rgb, x_ir), dim=1))
        x_ = self.layer3_(x)
        x_d = self.layer3_d(x)
        width_output = x_d.shape[-1]
        height_output = x_d.shape[-2]

        x_patched = self.layer3(x)
        x = self.layer3_unpatch(x_patched)
        x_rgb = self.relu(self.layer3_rgb(x_rgb))
        x_ir = self.relu(self.layer3_ir(x_ir))
        x_ = self.pag3(x_, self.compression3(x))
        x_d = x_d + F.interpolate(
            self.diff3(x),
            size=[height_output, width_output],
            mode='bilinear', align_corners=algc)
        if self.augment:
            temp_p = x_

        x = self.layer4(x_patched)
        x_rgb = self.relu(self.layer4_rgb(x_rgb))
        x_ir = self.relu(self.layer4_ir(x_ir))
        x_ = self.layer4_(self.relu(x_))
        x_d = self.layer4_d(self.relu(x_d))
        x_ = self.pag4(x_, self.compression4(x))
        x_d = x_d + F.interpolate(
            self.diff4(x),
            size=[height_output, width_output],
            mode='bilinear', align_corners=algc)
        if self.augment:
            temp_d = x_d

        x_ = self.layer5_(self.relu(x_))
        x_d = self.layer5_d(self.relu(x_d))

        if modality_paths:
            # Randomly pick one branch (rgb=0, ir=1, fusion=2) that is always kept;
            # the other two may be dropped with their configured probabilities.
            path_id = torch.randint(0, 3, (1,), device=x.device).item()
            x = self.weight_channels_post(torch.cat([
                drop_path(self.layer5(x), self.drop_paths_modalities[2] * (0 if path_id == 2 else 1), self.training),
                drop_path(self.layer5_rgb(x_rgb), self.drop_paths_modalities[0] * (0 if path_id == 0 else 1), self.training),
                drop_path(self.layer5_ir(x_ir), self.drop_paths_modalities[1] * (0 if path_id == 1 else 1), self.training),
            ], dim=1))
        else:
            x = self.layer5(x)

        x = F.interpolate(
            self.spp(x),
            size=[height_output, width_output],
            mode='bilinear', align_corners=algc)

        x_ = self.dfm(x_, x, x_d)

        if short_path:
            # Decode back to full resolution, injecting RGB+IR features at each scale
            x_ = F.interpolate(x_, size=x_rgb_3.shape[-2:], mode='bilinear', align_corners=algc) + \
                drop_path(self.tf_conv3(torch.cat((x_rgb_3, x_ir_3), dim=1)), self.drop_paths_shortcuts[3], self.training)
            x_ = F.interpolate(x_, size=x_rgb_2.shape[-2:], mode='bilinear', align_corners=algc) + \
                drop_path(self.tf_conv2(torch.cat((x_rgb_2, x_ir_2), dim=1)), self.drop_paths_shortcuts[2], self.training)
            x_ = F.interpolate(x_, size=x_rgb_1.shape[-2:], mode='bilinear', align_corners=algc) + \
                drop_path(self.tf_conv1(torch.cat((x_rgb_1, x_ir_1), dim=1)), self.drop_paths_shortcuts[1], self.training)
            x_ = F.interpolate(x_, size=x_rgb_0.shape[-2:], mode='bilinear', align_corners=algc) + \
                drop_path(self.tf_conv0(torch.cat((x_rgb_0, x_ir_0), dim=1)), self.drop_paths_shortcuts[0], self.training)

        x_ = self.final_layer(x_)

        if self.augment:
            x_extra_p = self.seghead_p(temp_p)
            x_extra_d = self.seghead_d(temp_d)
            return [x_extra_p, x_, x_extra_d]
        else:
            return x_

    def save_model(self, path, epoch):
        os.makedirs(f'{path}', exist_ok=True)
        torch.save(self.state_dict(), os.path.join(path, f'Epoch{epoch}.pt'))
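# Checkpointing sketch (directory and epoch number are illustrative):
# save_model() writes `<path>/Epoch<epoch>.pt` containing only the state dict,
# so reloading goes through load_state_dict rather than torch.load of a module.
#
#   >>> net.save_model('./checkpoints', epoch=10)   # -> ./checkpoints/Epoch10.pt
#   >>> net.load_state_dict(torch.load('./checkpoints/Epoch10.pt', map_location='cpu'))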
def custom_pretrained(input_res, num_classes, depths, windows_size):
    model = RoboFireFuseNet(m=2, n=3, num_classes=num_classes, planes=32, ppm_planes=96,
                            head_planes=128, augment=True, channels=4, input_resolution=input_res,
                            window_size=(windows_size, windows_size), tf_depths=depths)

    configuration = Swinv2Config(image_size=input_res, window_size=windows_size,
                                 num_channels=3, depths=[2, *depths, 2])
    tf_model = Swinv2Model.from_pretrained("microsoft/swinv2-tiny-patch4-window8-256",
                                           config=configuration, ignore_mismatched_sizes=True)
    model.stage3.load_state_dict(tf_model.encoder.layers[1].state_dict())
    model.stage4.load_state_dict(tf_model.encoder.layers[2].state_dict())

    pid_pretrained = torch.load('./weights/PIDNet_S_ImageNet.pth.tar')['state_dict']
    model_pid = PIDNet(2, 3, 19, 32, 96, 128, True, 3)
    model_pid.load_state_dict(pid_pretrained, False)

    for i in range(3):
        model.conv1_ir_1[i].load_state_dict(model_pid.conv1[3 + i].state_dict(), strict=True)
        model.conv1_ir_2[i].load_state_dict(model_pid.conv1[3 + i].state_dict(), strict=True)
        model.conv1_rgb_1[i].load_state_dict(model_pid.conv1[3 + i].state_dict(), strict=True)
        model.conv1_rgb_2[i].load_state_dict(model_pid.conv1[3 + i].state_dict(), strict=True)

    model.layer1_ir.load_state_dict(model_pid.layer1.state_dict(), strict=True)
    model.layer1_rgb.load_state_dict(model_pid.layer1.state_dict(), strict=True)
    model.layer2_ir.load_state_dict(model_pid.layer2.state_dict(), strict=True)
    model.layer2_rgb.load_state_dict(model_pid.layer2.state_dict(), strict=True)
    model.layer3_ir.load_state_dict(model_pid.layer3.state_dict(), strict=True)
    model.layer3_rgb.load_state_dict(model_pid.layer3.state_dict(), strict=True)
    model.layer4_ir.load_state_dict(model_pid.layer4.state_dict(), strict=True)
    model.layer4_rgb.load_state_dict(model_pid.layer4.state_dict(), strict=True)
    model.layer5_ir.load_state_dict(model_pid.layer5.state_dict(), strict=True)
    model.layer5_rgb.load_state_dict(model_pid.layer5.state_dict(), strict=True)
    model.layer3_.load_state_dict(model_pid.layer3_.state_dict(), strict=True)
    model.layer3_d.load_state_dict(model_pid.layer3_d.state_dict(), strict=True)
    model.layer4_.load_state_dict(model_pid.layer4_.state_dict(), strict=True)
    model.layer4_d.load_state_dict(model_pid.layer4_d.state_dict(), strict=True)
    model.layer5_.load_state_dict(model_pid.layer5_.state_dict(), strict=True)
    model.layer5_d.load_state_dict(model_pid.layer5_d.state_dict(), strict=True)
    model.compression3.load_state_dict(model_pid.compression3.state_dict(), strict=True)
    model.compression4.load_state_dict(model_pid.compression4.state_dict(), strict=True)
    model.seghead_d.load_state_dict(model_pid.seghead_d.state_dict(), strict=True)
    model.pag3.load_state_dict(model_pid.pag3.state_dict(), strict=True)
    model.pag4.load_state_dict(model_pid.pag4.state_dict(), strict=True)
    model.diff3.load_state_dict(model_pid.diff3.state_dict(), strict=True)
    model.diff4.load_state_dict(model_pid.diff4.state_dict(), strict=True)

    torch.save(model.state_dict(), 'pretrained_480x640_w8_2_6.pth')
    return model
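# Usage sketch for the warm-start builder above. The argument values here are
# inferred from the saved filename 'pretrained_480x640_w8_2_6.pth'; both the
# SwinV2 hub checkpoint and the local PIDNet weights must be available.
#
#   >>> model = custom_pretrained(input_res=(480, 640), num_classes=3,
#   ...                           depths=[2, 6], windows_size=8)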
if __name__ == '__main__':
    device = 'cuda:0'
    windows_size = 8
    input_res = (256, 256)
    num_classes = 3
    depths = [2, 6]

    model = RoboFireFuseNet(m=2, n=3, num_classes=num_classes, planes=32, ppm_planes=96,
                            head_planes=128, augment=True, channels=4, input_resolution=input_res,
                            window_size=(windows_size, windows_size), tf_depths=depths).to(device)
    model.eval()

    # Benchmark: 100 forward passes, averaging the last 80 to skip warm-up.
    # cuda.synchronize() brackets the timed region so async kernel launches
    # are actually counted; no_grad avoids building graphs during timing.
    input = torch.randn(1, 4, *input_res).to(device)
    avg = 0
    with torch.no_grad():
        for i in range(100):
            torch.cuda.synchronize()
            t0 = time()
            tmp = model(input)
            torch.cuda.synchronize()
            t1 = time()
            if i >= 20:  # was `i > 20`, which averaged 79 samples over 80
                avg += (t1 - t0) / 80
            del tmp
    print(avg)