# RoboFireFuseNet-private — train.py
# Training entry point: train/validation loops, checkpointing, early stopping.
import os

import numpy as np
import torch
from torch.utils.data import DataLoader
from tqdm import tqdm

from utils.logger import Logger
from utils.tools import *
from utils.training_tools import *


def train(model, train_data, val_data, args, logger):
    """Run the full training loop with periodic validation, checkpointing and early stopping.

    Args:
        model: network to train (handed to Trainer, which owns the optimization step).
        train_data: training dataset (also receives per-class IoU feedback, see below).
        val_data: validation dataset.
        args: configuration dict; keys used here: BATCHSIZE, NUM_WORKERS, EPOCHS,
            NUM_CLASSES, CLS_NAMES, CLASS_WEIGHTS, VALID_FREQ, PROJECTNAME, SESSIONAME.
        logger: Logger instance used to record and print metrics.

    Returns:
        The best validation mean-IoU observed during training (``logger.best_miou``).
    """
    train_loader = DataLoader(
        train_data, batch_size=args['BATCHSIZE'], shuffle=True,
        num_workers=args['NUM_WORKERS'], prefetch_factor=1,
        drop_last=True, pin_memory=True)
    trainer = Trainer(args, model, len(train_data))
    # Trainer may resume from a checkpoint, hence trainer.start_epoch.
    for epoch in range(trainer.start_epoch, args['EPOCHS']):
        running_loss = 0.0
        conf_mat = np.zeros((args['NUM_CLASSES'], args['NUM_CLASSES']))
        # --- training loop ---
        for batch in tqdm(train_loader, desc=f'Epoch {epoch}'):
            loss, curr_conf_mat = trainer.training_step(batch)
            conf_mat += curr_conf_mat
            # Dividing each term by len(train_loader) yields the epoch-mean loss.
            running_loss += loss.item() / len(train_loader)
        metrics, _ = calculate_metrics(conf_mat, running_loss,
                                       cls_names=args['CLS_NAMES'],
                                       cls_weights=args['CLASS_WEIGHTS'])
        # --- validation, checkpointing and stop criterion ---
        if epoch % args['VALID_FREQ'] == 0:
            val_metrics, iou = valid(trainer, val_data, epoch, args)
            # BUGFIX: cls_weights previously received args['CLS_NAMES'] (the list
            # of class names) instead of the actual class weights; every other
            # call site in this file passes args['CLASS_WEIGHTS'].
            logger.print_metrics(val_metrics, args['CLS_NAMES'], args['NUM_CLASSES'],
                                 cls_weights=args['CLASS_WEIGHTS'], val=True)
            metrics = {**metrics, **val_metrics}
            trainer.save_checkpoint(
                os.path.join('weights', args['PROJECTNAME'], args['SESSIONAME']), epoch)
            if trainer.stop_sign(metrics):
                break
            # Feed per-class IoU back to the dataset — presumably used for
            # difficulty-aware sampling; TODO confirm against the dataset class.
            train_data.mious = iou
        logger.log(metrics, epoch, trainer.scheduler)
        # BUGFIX: same cls_weights mix-up as above (was args['CLS_NAMES']).
        logger.print_metrics(metrics, cls_names=args['CLS_NAMES'],
                             num_classes=args['NUM_CLASSES'],
                             cls_weights=args['CLASS_WEIGHTS'], val=False)
    print(f"Training Finished! Best Epoch {logger.best_epoch}: MIOU {logger.best_miou}")
    return logger.best_miou

def valid(trainer, val_data, cur_epoch, args):
    """Evaluate the model on the validation set and dump qualitative examples.

    Args:
        trainer: Trainer instance providing ``valid_step`` and ``inference``.
        val_data: validation dataset.
        cur_epoch: current epoch number (used to name the saved visualization).
        args: configuration dict; keys used here: NUM_WORKERS, NUM_CLASSES,
            CLS_NAMES, CLASS_WEIGHTS, PROJECTNAME, SESSIONAME.

    Returns:
        Tuple of (metrics dict, per-class IoU) from ``calculate_metrics``.
    """
    # FIX: shuffle=False — validation metrics are order-independent (the
    # confusion matrix is summed and the loss averaged), so shuffling only
    # made evaluation order nondeterministic for no benefit.
    valid_loader = DataLoader(val_data, batch_size=1, shuffle=False,
                              num_workers=args['NUM_WORKERS'])
    running_loss = 0.0
    conf_mat = np.zeros((args['NUM_CLASSES'], args['NUM_CLASSES']))
    for batch in tqdm(valid_loader, desc='Validation'):
        loss, curr_conf_mat = trainer.valid_step(batch)
        conf_mat += curr_conf_mat
        # Accumulate the mean validation loss over all batches.
        running_loss += loss.item() / len(valid_loader)
    metrics, iou = calculate_metrics(conf_mat, running_loss,
                                     cls_names=args['CLS_NAMES'],
                                     cls_weights=args['CLASS_WEIGHTS'], val=True)
    # Save a qualitative visualization of predictions for this epoch.
    qualitive_eval(lambda data: trainer.inference(data), val_data,
                   ex_path=f'./outputs/{args["PROJECTNAME"]}/{args["SESSIONAME"]}/visualizations',
                   name=f'Epoch_{cur_epoch}.png')
    return metrics, iou

def main(args):
    """Set up reproducibility, data, model and logging, then run training."""
    set_reproducibility(args['SEED'])
    # Third split (test set) is unused here.
    tr_set, va_set, _ = get_dataset(args)
    net = get_model(args)
    run_logger = Logger(args)
    train(net, tr_set, va_set, args, run_logger)
    run_logger.close()


if __name__ == '__main__':
    # Parse CLI/config arguments and launch training.
    main(parse_args())