# FairCDSR/train_rec.py
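"""Training entry point for FairCDSR.

Seeds all RNGs for reproducibility, loads the cross-domain interaction data,
optionally loads pretrained item generators (for item_augmentation and the
group_CL / GMiT self-supervised objectives), builds the train/valid/test
dataloaders, and then either trains a CDSRTrainer or evaluates a saved model,
reporting overall, male, and female test metrics.
"""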
import codecs  # needed by read_item below
import os
import numpy as np
import random
from pathlib import Path
import torch
from utils import torch_utils, helper
from model.trainer import CDSRTrainer
from utils.loader import *
from utils.MoCo_utils import compute_features
from utils.cluster import run_kmeans
from utils.collator import CLDataCollator
from model.item_generator import Generator
from config.config import get_args

def main(args):
    def seed_everything(seed=1111):
        random.seed(seed)
        torch.manual_seed(seed)
        torch.cuda.manual_seed_all(seed)
        np.random.seed(seed)
        os.environ['PYTHONHASHSEED'] = str(seed)
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False
        print("seed set done! seed{}".format(seed))
    if args.cpu:
        args.cuda = False
    elif args.cuda:
        torch.cuda.manual_seed(args.seed)
    # parse the comma-separated cluster counts (e.g. "256,512") into a list of ints
    args.num_cluster = [int(n) for n in args.num_cluster.split(',')]
    
    # expose the parsed args as a plain dict for the rest of the pipeline
    opt = vars(args)
    print("Seed:", opt["seed"])
    seed_everything(opt["seed"])
    model_id = opt["id"]
    folder = opt['save_dir'] + '/' + str(opt['data_dir']) + '/' + str(model_id)
    Path(folder).mkdir(parents=True, exist_ok=True)
    model_save_dir = folder + '/' + str(opt['seed'])
    opt['model_save_dir'] = model_save_dir
    helper.ensure_dir(model_save_dir, verbose=True)
    # save config
    helper.save_config(opt, model_save_dir + '/config.json', verbose=True)
    file_logger = helper.FileLogger(model_save_dir + '/' + opt['log'],
                                    header="# test_MRR\ttest_NDCG_10\ttest_HR_10")

    # print model info
    helper.print_config(opt)

    print("Loading data from {} with batch size {}...".format(opt['data_dir'], opt['batch_size']))
    
    if opt["training_mode"] not in ["finetune","joint_learn","evaluation"]:
        raise ValueError("training mode must be finetune, joint_learn or evaluation")
    if opt["training_mode"] in ["joint_learn"] and opt["ssl"] not in ["GMiT","group_CL","both"]:
        raise ValueError("SSL must be GMiT, group_CL or both")
    
    # read number of items
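    # train.txt is expected to start with two integer header lines, e.g.
    #   3389    <- source-domain item count
    #   2435    <- target-domain item count
    # (illustrative values; the remaining lines presumably hold the sequences)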
    def read_item(fname):
        with codecs.open(fname, "r", encoding="utf-8") as fr:
            item_num = [int(d.strip()) for d in fr.readlines()[:2]]
        return item_num
    filename = opt["data_dir"]

    opt["source_item_num"], opt["target_item_num"] = read_item(f"./fairness_dataset/{opt['dataset']}/" + filename + "/train.txt")
    opt['itemnum'] = opt["source_item_num"] + opt["target_item_num"] +1
    if opt['data_augmentation'] not in ["item_augmentation","user_generation"] and opt['data_augmentation'] is not None:
        raise ValueError("data augmentation must be item_augmentation or user_generation")
    # load the pretrained item generators for the source (X), target (Y) and
    # mixed domains; each checkpoint stores its weights under the 'model' key
    if opt['data_augmentation'] == "item_augmentation" or opt['ssl'] in ["group_CL", "both"]:
        generators = {}
        for domain in ('X', 'Y', 'mixed'):
            generators[domain] = Generator(opt, type=domain)
            checkpoint = torch.load(f"./generator_model/{opt['data_dir']}/{domain}/{opt['load_pretrain_epoch']}/model.pt")
            generators[domain].load_state_dict(checkpoint['model'])
        source_generator, target_generator, mixed_generator = (generators[d] for d in ('X', 'Y', 'mixed'))
        print("\033[01;32m Generator loaded! \033[0m")
    # use a collator for group CL only when substitute_mode allows immediate
    # item substitution; GAW and Hybrid modes need a few warm-up epochs first
    if opt['ssl'] in ["group_CL", "both"] and opt["substitute_mode"] in ["DGIR", "AGIR", "random"]:
        collator = CLDataCollator(opt, eval=-1, mixed_generator=mixed_generator)
    else:
        collator = None
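    # (CLDataCollator presumably builds the contrastive views batch-side using
    # the mixed-domain generator; inferred from its constructor, not documented)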
    # build dataloader
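    # CustomDataLoader evaluation codes used below: -1 = train, 2 = valid, 1 = test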
    if opt['training_mode'] != "evaluation":
        if opt['data_augmentation'] == "item_augmentation":
            train_batch = CustomDataLoader(opt['data_dir'], opt['batch_size'], opt, evaluation=-1, collate_fn=collator, generator=[source_generator, target_generator, mixed_generator])
        else:
            train_batch = CustomDataLoader(opt['data_dir'], opt['batch_size'], opt, evaluation=-1, collate_fn=collator)
        valid_batch = CustomDataLoader(opt['data_dir'], opt["batch_size"], opt, evaluation=2, collate_fn=None)
    test_batch = CustomDataLoader(opt['data_dir'], opt["batch_size"], opt, evaluation=1, collate_fn=None)
    print("Data loading done!")    
    # model
    trainer = CDSRTrainer(opt)
    if opt['training_mode'] == "evaluation":
        if opt['evaluation_model'] is None:
            raise ValueError("evaluation model is not specified!")
        if opt["main_task"] == "X":
            evaluation_path = "models/" + opt['data_dir'] + f"/{opt['evaluation_model']}/{opt['seed']}/X_model.pt"
        elif opt["main_task"] == "Y":
            evaluation_path = "models/" + opt['data_dir'] + f"/{opt['evaluation_model']}/{opt['seed']}/Y_model.pt"
        else:
            raise ValueError("main task must be X or Y")  # otherwise evaluation_path would be undefined below
        print("evaluation_path:", evaluation_path)
        if os.path.exists(evaluation_path):
            print("\033[01;32m Loading evaluation model from {}... \033[0m\n".format(evaluation_path))
            trainer.load(evaluation_path)
            print("\033[01;32m Loading evaluation model done! \033[0m\n")
        else:
            raise ValueError("evaluation model does not exist!")
        print("\033[01;34m Start evaluation... \033[0m\n")
        best_Y_test, best_Y_test_male, best_Y_test_female = trainer.evaluate(test_batch, file_logger)
        return best_Y_test, best_Y_test_male, best_Y_test_female
    else:
        print("\033[01;32m Model training from scratch... \033[0m\n")
    if opt['training_mode']=="joint_learn":
        print("\033[01;34m Start joint learning... \033[0m\n")
    if opt['data_augmentation'] == "item_augmentation" or opt['ssl'] in ["group_CL", "both"]:
        trainer.generator = [source_generator, target_generator, mixed_generator]
    best_test, best_test_male, best_test_female = trainer.train(opt['num_epoch'], train_batch, valid_batch, test_batch, file_logger)
    return best_test, best_test_male, best_test_female
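
# Example invocation (a sketch: flag names assume config/config.py's get_args
# exposes the opt keys above as same-named CLI arguments; values illustrative):
#   python train_rec.py --data_dir <domain_pair> --dataset <dataset> \
#       --training_mode joint_learn --ssl group_CL --substitute_mode DGIR \
#       --num_epoch 100 --batch_size 256 --seed 1111 --id run0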
if __name__ == '__main__':
    args = get_args()
    main(args)