import os
import codecs
import random
from pathlib import Path

import numpy as np
import torch

from utils import torch_utils, helper
from model.trainer import CDSRTrainer
from utils.loader import *
from utils.MoCo_utils import compute_features
from utils.cluster import run_kmeans
from utils.collator import CLDataCollator
from model.item_generator import Generator
from config.config import get_args
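
# Entry point for cross-domain sequential recommendation (CDSR): seed all RNGs,
# validate the run configuration, optionally load pretrained item generators for
# data augmentation / group contrastive learning, build the dataloaders, and then
# either evaluate a saved model or train one from scratch.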
def main(args):
    def seed_everything(seed=1111):
        # fix every source of randomness for reproducibility
        random.seed(seed)
        torch.manual_seed(seed)
        torch.cuda.manual_seed_all(seed)
        np.random.seed(seed)
        os.environ['PYTHONHASHSEED'] = str(seed)
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False
        print("Seed set done! seed={}".format(seed))
    if args.cpu:
        args.cuda = False
    elif args.cuda:
        torch.cuda.manual_seed(args.seed)
    # parse the comma-separated cluster counts, e.g. "256,512" -> [256, 512]
    args.num_cluster = [int(n) for n in args.num_cluster.split(',')]
    # flatten the parsed args into a plain option dict
    opt = vars(args)
print("My seed:", opt["seed"])
seed_everything(opt["seed"])
    model_id = opt["id"]
    folder = opt['save_dir'] + '/' + str(opt['data_dir']) + '/' + str(model_id)
    Path(folder).mkdir(parents=True, exist_ok=True)
    model_save_dir = folder + '/' + str(opt['seed'])
    opt['model_save_dir'] = model_save_dir
    helper.ensure_dir(model_save_dir, verbose=True)
# save config
helper.save_config(opt, model_save_dir + '/config.json', verbose=True)
file_logger = helper.FileLogger(model_save_dir + '/' + opt['log'],
header="# test_MRR\ttest_NDCG_10\ttest_HR_10")
    # print the run configuration
    helper.print_config(opt)
print("Loading data from {} with batch size {}...".format(opt['data_dir'], opt['batch_size']))
if opt["training_mode"] not in ["finetune","joint_learn","evaluation"]:
raise ValueError("training mode must be finetune, joint_learn or evaluation")
if opt["training_mode"] in ["joint_learn"] and opt["ssl"] not in ["GMiT","group_CL","both"]:
raise ValueError("SSL must be GMiT, group_CL or both")
    # read the per-domain item counts from the head of the training file
    def read_item(fname):
        # the first two lines of the file hold the source- and target-domain item counts
        with codecs.open(fname, "r", encoding="utf-8") as fr:
            item_num = [int(d.strip()) for d in fr.readlines()[:2]]
        return item_num
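    # Expected head of train.txt (values are illustrative, not from any real dataset):
    #   3389   <- number of source-domain items
    #   6929   <- number of target-domain items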
    filename = opt["data_dir"]
    opt["source_item_num"], opt["target_item_num"] = read_item(f"./fairness_dataset/{opt['dataset']}/" + filename + "/train.txt")
    opt['itemnum'] = opt["source_item_num"] + opt["target_item_num"] + 1  # +1 presumably reserves a padding index
    if opt['data_augmentation'] is not None and opt['data_augmentation'] not in ["item_augmentation", "user_generation"]:
        raise ValueError("data_augmentation must be item_augmentation or user_generation")
    # load the pretrained item generators (source 'X', target 'Y', and 'mixed'-domain)
    if opt['data_augmentation'] == "item_augmentation" or opt['ssl'] in ["group_CL", "both"]:
        def load_generator(domain):
            generator = Generator(opt, type=domain)
            checkpoint = torch.load(f"./generator_model/{opt['data_dir']}/{domain}/{str(opt['load_pretrain_epoch'])}/model.pt")
            generator.load_state_dict(checkpoint['model'])
            return generator
        source_generator = load_generator('X')
        target_generator = load_generator('Y')
        mixed_generator = load_generator('mixed')
        print("\033[01;32m Generator loaded! \033[0m")
    # decide whether a contrastive-learning collator is needed for group CL
    if opt['ssl'] in ["group_CL", "both"] and opt["substitute_mode"] in ["DGIR", "AGIR", "random"]:
        # GAW and Hybrid substitute modes should warm up for a few epochs first
        collator = CLDataCollator(opt, eval=-1, mixed_generator=mixed_generator)
    else:
        collator = None
    # build dataloaders (evaluation flag: -1 = train, 2 = valid, 1 = test)
    if opt['training_mode'] != "evaluation":
        if opt['data_augmentation'] == "item_augmentation":
            train_batch = CustomDataLoader(opt['data_dir'], opt['batch_size'], opt, evaluation=-1, collate_fn=collator, generator=[source_generator, target_generator, mixed_generator])
        else:
            train_batch = CustomDataLoader(opt['data_dir'], opt['batch_size'], opt, evaluation=-1, collate_fn=collator)
    valid_batch = CustomDataLoader(opt['data_dir'], opt["batch_size"], opt, evaluation=2, collate_fn=None)
    test_batch = CustomDataLoader(opt['data_dir'], opt["batch_size"], opt, evaluation=1, collate_fn=None)
    print("Data loading done!")
    # build the trainer that wraps the CDSR model
    trainer = CDSRTrainer(opt)
    if opt['training_mode'] == "evaluation":
        if opt['evaluation_model'] is None:
            raise ValueError("evaluation model is not specified!")
        if opt["main_task"] == "X":
            evaluation_path = "models/" + opt['data_dir'] + f"/{opt['evaluation_model']}/{str(opt['seed'])}/X_model.pt"
        elif opt["main_task"] == "Y":
            evaluation_path = "models/" + opt['data_dir'] + f"/{opt['evaluation_model']}/{str(opt['seed'])}/Y_model.pt"
        else:
            raise ValueError("main_task must be X or Y")
        print("evaluation_path:", evaluation_path)
        if os.path.exists(evaluation_path):
            print("\033[01;32m Loading evaluation model from {}... \033[0m\n".format(evaluation_path))
            trainer.load(evaluation_path)
            print("\033[01;32m Loading evaluation model done! \033[0m\n")
        else:
            raise ValueError("evaluation model does not exist!")
        print("\033[01;34m Start evaluation... \033[0m\n")
        best_Y_test, best_Y_test_male, best_Y_test_female = trainer.evaluate(test_batch, file_logger)
        return best_Y_test, best_Y_test_male, best_Y_test_female
    else:
        print("\033[01;32m Model training from scratch... \033[0m\n")
        if opt['training_mode'] == "joint_learn":
            print("\033[01;34m Start joint learning... \033[0m\n")
        if opt['data_augmentation'] == "item_augmentation" or opt['ssl'] in ["group_CL", "both"]:
            # hand the pretrained generators to the trainer for augmentation / group CL
            trainer.generator = [source_generator, target_generator, mixed_generator]
        best_test, best_test_male, best_test_female = trainer.train(opt['num_epoch'], train_batch, valid_batch, test_batch, file_logger)
        return best_test, best_test_male, best_test_female
if __name__ == '__main__':
args = get_args()
main(args)
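
# Example invocation (a sketch: flag names mirror the option keys read above and are
# assumptions about config.get_args; dataset/path values are hypothetical):
#   python main.py --dataset amazon --data_dir Food-Kitchen --training_mode joint_learn \
#       --ssl group_CL --num_cluster 256,512 --seed 1111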