SalFormer_GitFront / ttrain.py
ttrain.py
Raw
# -*- coding: utf-8 -*-
"""ttrain.ipynb

Automatically generated by Colaboratory.

Original file is located at
    https://colab.research.google.com/drive/1uwgolSLC-Ubd0o3PPxPFB7LnZhXYmQxr
"""

# Notebook shell commands (IPython `!` syntax, not valid in plain Python):
# install third-party packages. tensorboardX is imported below; einops and
# timm are presumably required by the model code -- TODO confirm.
! pip install einops

! pip install timm

! pip install tensorboardX

# Not referenced anywhere else in the visible part of this script.
load=None

# Base learning rate handed to the Adam optimizer and to adjust_lr() below.
lr= 1e-4

import os
import torch
import torch.nn.functional as F
import sys

import numpy as np
from datetime import datetime
from torchvision.utils import make_grid
from my_model_19 import IR_Net
from data_edge import get_loader,test_dataset
from utils_new import clip_gradient, adjust_lr
from tensorboardX import SummaryWriter
import logging
import torch.backends.cudnn as cudnn
import ResNet
from sa import SpatialAttention

gpu_id = 0

# Set the device for training.
# CUDA_VISIBLE_DEVICES is a string, so the id is normalised with str() before
# it is checked. Comparing the integer 0 directly against '0'/'1' never
# matches, which would leave the environment variable unset.
_gpu = str(gpu_id)
if _gpu in ('0', '1'):
    os.environ["CUDA_VISIBLE_DEVICES"] = _gpu
    print('USE GPU ' + _gpu)
cudnn.benchmark = True

# Commented out IPython magic to ensure Python compatibility.
import torch
from torch import nn
from torch import functional as F
from torch import optim
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
from tqdm import tqdm
import math
import matplotlib.pyplot as plt
import os
from pytorch_iou import IOU

# %matplotlib inline

import os
# Create the local working directories. exist_ok=True keeps this re-runnable:
# a plain os.makedirs() raises FileExistsError when the directory is already
# there (e.g. when the notebook cell is executed a second time).
for _work_dir in ('train_file', 'test_file', 'test_in_train_file'):
    os.makedirs(_work_dir, exist_ok=True)

# Destination of the training depth archive.
save_path='/content/train_file/RGBD_for_train/depth_after_HHA/'
os.makedirs(save_path, exist_ok=True)

# Destination of the validation depth archive.
save_path1='/content/test_in_train_file/test_in_train/depth_after_HHA/'
os.makedirs(save_path1, exist_ok=True)

# Destination of the training edge maps.
edge_path_train='/content/train_file/RGBD_for_train/edge_maps/'
os.makedirs(edge_path_train, exist_ok=True)

# Directory reserved for validation edge maps.
edge_path_val='/content/test_in_train_file/test_in_train/edge_maps/'
os.makedirs(edge_path_val, exist_ok=True)

# Notebook shell commands (IPython `!` syntax): extract the dataset archives
# from the mounted Google Drive into the local working directories.

# Training edge maps.
! unzip "/content/drive/MyDrive/Colab Notebooks/RGBD_SOD/Datasets/edge_training_Canny_fromGTsalmap.zip" -d "/content/train_file/RGBD_for_train/edge_maps/"

# Disabled: the same edge-map archive from a different Drive folder.
#! unzip "/content/drive/MyDrive/Colab Notebooks/image_retargeting/Dataset_fromBBS_Net/edge_training_Canny_fromGTsalmap.zip" -d "/content/train_file/RGBD_for_train/edge_maps/"
# Training images.
! unzip "/content/drive/MyDrive/Colab Notebooks/RGBD_SOD/Datasets/RGBD_for_train.zip" -d "/content/train_file/"
# Training depth maps (train_depth_after_HHA.zip).
! unzip "/content/drive/MyDrive/Colab Notebooks/RGBD_SOD/Datasets/train_depth_after_HHA.zip" -d "/content/train_file/RGBD_for_train/depth_after_HHA/"
# Validation images.
! unzip "/content/drive/MyDrive/Colab Notebooks/RGBD_SOD/Datasets/test_in_train.zip" -d "/content/test_in_train_file/"
# Validation depth maps (val_depth_after_HHA.zip).
! unzip "/content/drive/MyDrive/Colab Notebooks/RGBD_SOD/Datasets/val_depth_after_HHA.zip" -d "/content/test_in_train_file/test_in_train/depth_after_HHA/"
# Test images.
! unzip "/content/drive/MyDrive/Colab Notebooks/RGBD_SOD/Datasets/RGBD_for_test.zip" -d "/content/test_file/"

# Build the network. Both channel arguments receive the same four values,
# so they are spelled out once and copied for each argument.
_stage_channels = [46, 92, 184, 368]
model = IR_Net(
    in_channels=3,
    cmt_channel=list(_stage_channels),
    patch_channel=list(_stage_channels),
    block_layer=[2, 2, 10, 2],
    R=3.6,
    img_size=224,
)

# Dump the module tree for inspection.
print(model)

# Move the model to the GPU and build the Adam optimizer over its parameters.
model.cuda()

params = model.parameters()
optimizer = torch.optim.Adam(params, lr)

# NOTE: the parameters must stay trainable at this point. Setting
# requires_grad=False on every parameter (and switching to eval mode) before
# training makes loss.backward() in train() fail, because nothing in the
# graph requires a gradient. train()/test() select train/eval mode themselves.
dim=(224,224)

# Destination of the NLPR depth archive extracted below.
# Note: this rebinds `save_path`, which was assigned a different directory
# earlier in the script.
save_path='/content/test_file/RGBD_for_test/NLPR/depth_HHA/'
if not os.path.exists(save_path):
      os.makedirs(save_path)

# Notebook shell command: extract depth_HHA_NLPR.zip into that directory.
! unzip "/content/drive/MyDrive/Colab Notebooks/RGBD_SOD/Datasets/depth_HHA_NLPR.zip" -d "/content/test_file/RGBD_for_test/NLPR/depth_HHA/"

# Move the model to the GPU again and create a fresh Adam optimizer; this
# rebinds `params` and `optimizer`, replacing the objects created earlier.
model.cuda()

params = model.parameters()
optimizer = torch.optim.Adam(params, lr)

# Training data: RGB images, ground-truth maps, depth maps and edge maps.
image_root, gt_root, depth_root, edge_root = (
    '/content/train_file/RGBD_for_train/RGB/',
    '/content/train_file/RGBD_for_train/GT/',
    '/content/train_file/RGBD_for_train/depth_after_HHA/',
    '/content/train_file/RGBD_for_train/edge_maps/',
)

# Validation data (no edge-map directory is configured for validation).
test_image_root, test_gt_root, test_depth_root = (
    '/content/test_in_train_file/test_in_train/RGB/',
    '/content/test_in_train_file/test_in_train/GT/',
    '/content/test_in_train_file/test_in_train/depth_after_HHA/',
)

# Directory for checkpoints, the log file and the TensorBoard summaries.
save_path_cpoints = '/content/chk_points/'

# Create the checkpoint directory if needed. exist_ok=True replaces the
# separate exists() check, so there is no window between check and creation.
os.makedirs(save_path_cpoints, exist_ok=True)

# Training hyper-parameters.
batchsize, trainsize = 5, 224       # batch size / size passed to the loaders
epoch = 125                         # upper bound used by the epoch loop
clip = 0.5                          # value passed to clip_gradient()
decay_rate, decay_epoch = 0.1, 50   # arguments forwarded to adjust_lr()

# Build the training loader and the validation dataset.
print('load data...')
train_loader = get_loader(
    image_root,
    gt_root,
    depth_root,
    batchsize=batchsize,
    trainsize=trainsize,
)
test_loader = test_dataset(
    test_image_root,
    test_gt_root,
    test_depth_root,
    trainsize,
)
# Number of batches per epoch; train() uses it in its progress messages.
total_step = len(train_loader)
print("Loading done")

# Append log records to log.log inside the checkpoint directory.
logging.basicConfig(
    filename=save_path_cpoints+'log.log',
    format='[%(asctime)s-%(filename)s-%(levelname)s:%(message)s]',
    level=logging.INFO,
    filemode='a',
    datefmt='%Y-%m-%d %I:%M:%S %p',
)

# Header lines, followed by the run configuration on a single line.
for _banner in ("BBSNet-Train", "Config"):
    logging.info(_banner)
_config_line = 'epoch:{};lr:{};batchsize:{};trainsize:{};clip:{};decay_rate:{};save_path:{};decay_epoch:{}'.format(
    epoch, lr, batchsize, trainsize, clip, decay_rate, save_path_cpoints, decay_epoch)
logging.info(_config_line)

# Loss functions used by train().
# CE is binary cross-entropy on raw logits (the sigmoid is applied inside
# BCEWithLogitsLoss).
CE = nn.BCEWithLogitsLoss()

# Rebinds the imported IOU class name to a configured instance; from here on
# `IOU` is the callable loss object, not the class.
IOU = IOU(size_average=True)

# Bind F to torch.nn.functional from this point on.
import torch.nn.functional as F

# A second BCE-with-logits instance, kept under the name `criterion`.
criterion = nn.BCEWithLogitsLoss()

## train function with edge maps added, no side supervision

# Module-level state shared by train() and test().
step = 0                        # running count of optimizer steps
best_mae, best_epoch = 1, 0     # best validation MAE so far and its epoch
clip = 0.5                      # value passed to clip_gradient()

# TensorBoard event writer; summaries go under the checkpoint directory.
writer = SummaryWriter(save_path_cpoints+'summary')

#train function
def train(train_loader, model, optimizer, epoch,save_path):
    """Run one training epoch using only the final saliency output.

    NOTE: this file defines ``train`` a second time further down; when the
    script runs, the later definition replaces this one.

    Args:
        train_loader: iterable yielding ``(images, gts, depths)`` batches.
        model: network called as ``model(images, depths)``; it must return a
            pair ``(s1, s1_sig)``. ``s1`` goes to the BCE-with-logits loss and
            ``s1_sig`` to the IoU loss (presumably the sigmoid of ``s1`` --
            confirm in the model code).
        optimizer: optimizer stepped once per batch.
        epoch: current epoch number; used in log messages and as the
            TensorBoard step of the per-epoch loss.
        save_path: unused; kept so existing calls keep working.

    Uses the module-level ``CE``, ``IOU``, ``clip``, ``total_step``,
    ``writer`` and ``step`` (``step`` is incremented once per batch).
    A KeyboardInterrupt is not intercepted and propagates to the caller.
    """
    global step
    model.cuda()
    model.train()

    loss_all = 0
    epoch_step = 0

    for i, (images, gts, depths) in enumerate(train_loader, start=1):
        optimizer.zero_grad()

        images = images.cuda()
        gts = gts.cuda()
        depths = depths.cuda()

        s1, s1_sig = model(images, depths)

        loss = CE(s1, gts) + IOU(s1_sig, gts)
        loss.backward()

        clip_gradient(optimizer, clip)
        optimizer.step()
        step += 1
        epoch_step += 1
        # Plain Python float: safe to accumulate and format without keeping
        # the autograd graph alive.
        loss_value = loss.item()
        loss_all += loss_value

        # Report on the first batch, every 100th batch and the last batch.
        if i % 100 == 0 or i == total_step or i == 1:
            print('{} Epoch [{:03d}/{:03d}], Step [{:04d}/{:04d}], Loss1: {:.4f} '.
                  format(datetime.now(), epoch, epoch, i, total_step, loss_value))
            logging.info('#TRAIN#:Epoch [{:03d}/{:03d}], Step [{:04d}/{:04d}], Loss1: {:.4f} '.
                         format(epoch, epoch, i, total_step, loss_value))
            writer.add_scalar('Loss', loss_value, global_step=step)

            # First sample of the batch: input image, ground truth, prediction.
            grid_image = make_grid(images[0].detach().clone().cpu(), 1, normalize=True)
            writer.add_image('RGB', grid_image, step)
            grid_image = make_grid(gts[0].detach().clone().cpu(), 1, normalize=True)
            writer.add_image('Ground_truth', grid_image, step)

            res = s1[0].detach().sigmoid().cpu().numpy().squeeze()
            # Min-max normalise for display; the epsilon guards a flat map.
            res = (res - res.min()) / (res.max() - res.min() + 1e-8)
            writer.add_image('s1', torch.tensor(res), step, dataformats='HW')

    if epoch_step == 0:
        # An empty loader would otherwise end in a ZeroDivisionError below.
        logging.warning('#TRAIN#:Epoch [{:03d}/{:03d}], train_loader produced no batches'.format(epoch, epoch))
        return

    loss_all /= epoch_step
    logging.info('#TRAIN#:Epoch [{:03d}/{:03d}], Loss_AVG: {:.4f}'.format(epoch, epoch, loss_all))
    writer.add_scalar('Loss-epoch', loss_all, global_step=epoch)

#train function with side supervision 
# Reset the training state shared by train() and test().
# The TensorBoard `writer` created earlier in this script is reused on
# purpose: opening a second SummaryWriter on the same directory would leave
# the first one unclosed and add an extra event file.
step=0
best_mae=1
best_epoch=0
clip=0.5

#train function
def train(train_loader, model, optimizer, epoch,save_path):
    """Run one training epoch with side supervision on four extra outputs.

    Args:
        train_loader: iterable yielding ``(images, gts, depths)`` batches.
        model: network called as ``model(images, depths)``; it must return ten
            values ``(s1, level1, level2, level3, level4, s1_sig, lev1_sig,
            lev2_sig, lev3_sig, lev4_sig)``. The first five go to the
            BCE-with-logits loss, the ``*_sig`` ones to the IoU loss
            (presumably the sigmoid of the matching output -- confirm in the
            model code).
        optimizer: optimizer stepped once per batch.
        epoch: current epoch number; used in log messages and as the
            TensorBoard step of the per-epoch loss.
        save_path: unused; kept so existing calls keep working.

    Uses the module-level ``CE``, ``IOU``, ``clip``, ``total_step``,
    ``writer`` and ``step`` (``step`` is incremented once per batch).
    A KeyboardInterrupt is not intercepted and propagates to the caller.
    """
    global step
    model.cuda()
    model.train()

    loss_all = 0
    epoch_step = 0
    # Ground-truth resolution paired with level1 .. level4, in that order.
    side_sizes = ((112, 112), (56, 56), (28, 28), (14, 14))

    for i, (images, gts, depths) in enumerate(train_loader, start=1):
        optimizer.zero_grad()

        images = images.cuda()
        gts = gts.cuda()
        depths = depths.cuda()

        (s1, level1, level2, level3, level4,
         s1_sig, lev1_sig, lev2_sig, lev3_sig, lev4_sig) = model(images, depths)

        # Full-resolution loss first, then one BCE + IoU term per side output.
        loss = CE(s1, gts) + IOU(s1_sig, gts)
        side_outputs = (
            (level1, lev1_sig),
            (level2, lev2_sig),
            (level3, lev3_sig),
            (level4, lev4_sig),
        )
        for size, (side_out, side_sig) in zip(side_sizes, side_outputs):
            # gts already lives on the GPU, so the resized copy does too.
            # align_corners=False is what bilinear mode uses by default;
            # stating it avoids the "default behaviour" warning.
            side_gt = nn.functional.interpolate(
                gts, size=size, mode='bilinear', align_corners=False)
            loss = loss + (CE(side_out, side_gt) + IOU(side_sig, side_gt))

        loss.backward()

        clip_gradient(optimizer, clip)
        optimizer.step()
        step += 1
        epoch_step += 1
        # Plain Python float: safe to accumulate and format without keeping
        # the autograd graph alive.
        loss_value = loss.item()
        loss_all += loss_value

        # Report on the first batch, every 100th batch and the last batch.
        if i % 100 == 0 or i == total_step or i == 1:
            print('{} Epoch [{:03d}/{:03d}], Step [{:04d}/{:04d}], Loss1: {:.4f} '.
                  format(datetime.now(), epoch, epoch, i, total_step, loss_value))
            logging.info('#TRAIN#:Epoch [{:03d}/{:03d}], Step [{:04d}/{:04d}], Loss1: {:.4f} '.
                         format(epoch, epoch, i, total_step, loss_value))
            writer.add_scalar('Loss', loss_value, global_step=step)

            # First sample of the batch: input image, ground truth, prediction.
            grid_image = make_grid(images[0].detach().clone().cpu(), 1, normalize=True)
            writer.add_image('RGB', grid_image, step)
            grid_image = make_grid(gts[0].detach().clone().cpu(), 1, normalize=True)
            writer.add_image('Ground_truth', grid_image, step)

            res = s1[0].detach().sigmoid().cpu().numpy().squeeze()
            # Min-max normalise for display; the epsilon guards a flat map.
            res = (res - res.min()) / (res.max() - res.min() + 1e-8)
            writer.add_image('s1', torch.tensor(res), step, dataformats='HW')

    if epoch_step == 0:
        # An empty loader would otherwise end in a ZeroDivisionError below.
        logging.warning('#TRAIN#:Epoch [{:03d}/{:03d}], train_loader produced no batches'.format(epoch, epoch))
        return

    loss_all /= epoch_step
    logging.info('#TRAIN#:Epoch [{:03d}/{:03d}], Loss_AVG: {:.4f}'.format(epoch, epoch, loss_all))
    writer.add_scalar('Loss-epoch', loss_all, global_step=epoch)

#validation test function with edges

def test(test_loader,model,epoch,save_path):
    """Evaluate on the validation set, log the MAE and write checkpoints.

    NOTE: this file defines ``test`` a second time further down; when the
    script runs, the later definition replaces this one.

    Args:
        test_loader: object exposing ``size`` and ``load_data()``; each call
            returns ``(image, gt, depth, name, img_for_post)``.
        model: network called as ``model(image, depth)``; the first element of
            its return value is taken as the prediction.
        epoch: current epoch number (TensorBoard step, stored in checkpoints).
        save_path: directory that receives ``current_ckpoint.pt`` and
            ``best_modell.pt``.

    Reads the module-level ``optimizer`` and ``writer`` and updates the
    module-level ``best_mae`` / ``best_epoch``.
    """
    global best_mae,best_epoch
    model.eval()
    with torch.no_grad():
        mae_sum=0
        for _ in range(test_loader.size):
            image, gt, depth, name, img_for_post = test_loader.load_data()
            gt = np.asarray(gt, np.float32)
            gt /= (gt.max() + 1e-8)
            image = image.cuda()
            depth = depth.cuda()

            # Only the first output is evaluated; any further outputs are ignored.
            res = model(image, depth)[0]
            # Resize the prediction to the ground-truth size before comparing.
            m = nn.Upsample(size=gt.shape, mode='bilinear', align_corners=False)
            res = m(res)
            res = res.sigmoid().data.cpu().numpy().squeeze()
            # Min-max normalise; the epsilon guards a flat prediction.
            res = (res - res.min()) / (res.max() - res.min() + 1e-8)

            mae_sum+=np.sum(np.abs(res-gt))*1.0/(gt.shape[0]*gt.shape[1])
        mae=mae_sum/test_loader.size
        writer.add_scalar('MAE', torch.tensor(mae), global_step=epoch)
        # Printed before the update below, so "best" refers to earlier epochs.
        print('Epoch: {} MAE: {} ####  bestMAE: {} bestEpoch: {}'.format(epoch,mae,best_mae,best_epoch))

        checkpoint = {
            'epoch': epoch,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'loss': mae,
        }
        # Always keep the latest state so training can be resumed.
        torch.save(checkpoint, os.path.join(save_path, 'current_ckpoint.pt'))

        # Epoch 1 is the best so far by definition. It is handled like any
        # other improvement so that best_epoch and the best checkpoint are
        # also recorded when no later epoch beats it.
        if epoch == 1 or mae < best_mae:
            best_mae = mae
            best_epoch = epoch
            torch.save(checkpoint, os.path.join(save_path, 'best_modell.pt'))
            print('best epoch:{}'.format(epoch))
        logging.info('#TEST#:Epoch:{} MAE:{} bestEpoch:{} bestMAE:{}'.format(epoch,mae,best_epoch,best_mae))

# validation test function for the model with side supervision

def test(test_loader,model,epoch,save_path):
    """Evaluate on the validation set, log the MAE and write checkpoints.

    Args:
        test_loader: object exposing ``size`` and ``load_data()``; each call
            returns ``(image, gt, depth, name, img_for_post)``.
        model: network called as ``model(image, depth)``; the first element of
            its return value is taken as the prediction.
        epoch: current epoch number (TensorBoard step, stored in checkpoints).
        save_path: directory that receives ``current_ckpoint.pt`` and
            ``best_modell.pt``.

    Reads the module-level ``optimizer`` and ``writer`` and updates the
    module-level ``best_mae`` / ``best_epoch``.
    """
    global best_mae,best_epoch
    model.eval()
    with torch.no_grad():
        mae_sum=0
        for _ in range(test_loader.size):
            image, gt, depth, name, img_for_post = test_loader.load_data()
            gt = np.asarray(gt, np.float32)
            gt /= (gt.max() + 1e-8)
            image = image.cuda()
            depth = depth.cuda()

            # Only the first output is evaluated; the side outputs are ignored.
            res = model(image, depth)[0]
            # Resize the prediction to the ground-truth size before comparing.
            m = nn.Upsample(size=gt.shape, mode='bilinear', align_corners=False)
            res = m(res)
            res = res.sigmoid().data.cpu().numpy().squeeze()
            # Min-max normalise; the epsilon guards a flat prediction.
            res = (res - res.min()) / (res.max() - res.min() + 1e-8)

            mae_sum+=np.sum(np.abs(res-gt))*1.0/(gt.shape[0]*gt.shape[1])
        mae=mae_sum/test_loader.size
        writer.add_scalar('MAE', torch.tensor(mae), global_step=epoch)
        # Printed before the update below, so "best" refers to earlier epochs.
        print('Epoch: {} MAE: {} ####  bestMAE: {} bestEpoch: {}'.format(epoch,mae,best_mae,best_epoch))

        checkpoint = {
            'epoch': epoch,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'loss': mae,
        }
        # Always keep the latest state so training can be resumed.
        torch.save(checkpoint, os.path.join(save_path, 'current_ckpoint.pt'))

        # Epoch 1 is the best so far by definition. It is handled like any
        # other improvement so that best_epoch and the best checkpoint are
        # also recorded when no later epoch beats it.
        if epoch == 1 or mae < best_mae:
            best_mae = mae
            best_epoch = epoch
            torch.save(checkpoint, os.path.join(save_path, 'best_modell.pt'))
            print('best epoch:{}'.format(epoch))
        logging.info('#TEST#:Epoch:{} MAE:{} bestEpoch:{} bestMAE:{}'.format(epoch,mae,best_epoch,best_mae))

print("Start train...")
model.cuda()

# `epoch` holds the configured number of epochs. A separate loop variable
# keeps that value intact (reusing the name would overwrite it, so a second
# run would train fewer epochs), and the stop value is epoch + 1 because
# range() excludes its upper bound.
total_epochs = epoch
for cur_epoch in range(1, total_epochs + 1):
    # Update the learning rate for this epoch and record it.
    cur_lr = adjust_lr(optimizer, lr, cur_epoch, decay_rate, decay_epoch)
    writer.add_scalar('learning_rate', cur_lr, global_step=cur_epoch)

    train(train_loader, model, optimizer, cur_epoch, save_path_cpoints)

    test(test_loader, model, cur_epoch, save_path_cpoints)

# Flush any buffered TensorBoard events to disk.
writer.close()