# KGTOSA / GNN-Methods / LinkPrediction / Morse / main.py
import argparse
from utils import init_dir, set_seed, get_num_rel
from meta_trainer import MetaTrainer
from post_trainer import PostTrainer
import os
from subgraph import gen_subgraph_datasets
from pre_process import data2pkl,data2pkl_Trans_to_Ind
from resource import *
import datetime
from utils import Log
def build_arg_parser():
    """Build the command-line parser for this training script.

    Every numeric option declares an explicit ``type`` so that a value given
    on the command line has the same Python type as the option's default
    (argparse only applies ``type`` to strings, and untyped options yield
    ``str`` when overridden).

    Returns:
        argparse.ArgumentParser: the configured parser.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--data_name', default='Yago10')
    parser.add_argument('--name', default='Yago3-10_transe', type=str)
    parser.add_argument('--step', default='meta_train', type=str,
                        choices=['meta_train', 'fine_tune'])
    parser.add_argument('--metatrain_state',
                        default='./state/fb237_v1_transe/fb237_v1_transe.best',
                        type=str)

    parser.add_argument('--state_dir', '-state_dir', default='./state', type=str)
    parser.add_argument('--log_dir', '-log_dir', default='./log', type=str)
    parser.add_argument('--tb_log_dir', '-tb_log_dir', default='./tb_log', type=str)
    # Directory handed to the pre-processing step as ``datapath``; the default
    # is the location that used to be hard-coded at the call site.
    parser.add_argument('--raw_data_dir',
                        default='/shared_mnt/github_repos/RGCN_LP/data', type=str)

    # params for subgraph
    parser.add_argument('--num_train_subgraph', default=10000, type=int)
    parser.add_argument('--num_valid_subgraph', default=200, type=int)
    parser.add_argument('--num_sample_for_estimate_size', default=50, type=int)
    parser.add_argument('--rw_0', default=10, type=int)
    parser.add_argument('--rw_1', default=10, type=int)
    parser.add_argument('--rw_2', default=5, type=int)
    parser.add_argument('--num_sample_cand', default=5, type=int)

    # params for meta-train
    parser.add_argument('--metatrain_num_neg', default=32, type=int)
    parser.add_argument('--metatrain_num_epoch', default=10, type=int)
    parser.add_argument('--metatrain_bs', default=64, type=int)
    parser.add_argument('--metatrain_lr', default=0.01, type=float)
    parser.add_argument('--metatrain_check_per_step', default=50, type=int)
    parser.add_argument('--indtest_eval_bs', default=512, type=int)

    # params for fine-tune
    parser.add_argument('--posttrain_num_neg', default=64, type=int)
    parser.add_argument('--posttrain_bs', default=512, type=int)
    # Must be float: with type=int any fractional value passed on the command
    # line (e.g. "0.0005") is rejected by argparse.
    parser.add_argument('--posttrain_lr', default=0.001, type=float)
    parser.add_argument('--posttrain_num_epoch', default=10, type=int)
    parser.add_argument('--posttrain_check_per_epoch', default=1, type=int)

    # params for R-GCN
    parser.add_argument('--num_layers', default=3, type=int)
    parser.add_argument('--num_bases', default=4, type=int)
    parser.add_argument('--emb_dim', default=32, type=int)

    # params for KGE
    parser.add_argument('--kge', default='TransE', type=str,
                        choices=['TransE', 'DistMult', 'ComplEx', 'RotatE'])
    parser.add_argument('--gamma', default=10, type=float)
    parser.add_argument('--adv_temp', default=1, type=float)

    parser.add_argument('--gpu', default='cpu', type=str)
    parser.add_argument('--seed', default=1234, type=int)
    return parser


def resolve_embedding_dims(kge, emb_dim):
    """Return ``(ent_dim, rel_dim)`` for the selected KGE model.

    Args:
        kge: model name, one of the ``--kge`` choices.
        emb_dim: base embedding dimension (``--emb_dim``).

    Returns:
        tuple: entity dimension is doubled for 'ComplEx' and 'RotatE';
        relation dimension is doubled for 'ComplEx' only; otherwise both
        equal ``emb_dim``.
    """
    ent_dim = emb_dim * 2 if kge in ('ComplEx', 'RotatE') else emb_dim
    rel_dim = emb_dim * 2 if kge == 'ComplEx' else emb_dim
    return ent_dim, rel_dim


if __name__ == '__main__':
    args = build_arg_parser().parse_args()
    logger = Log(args.log_dir, args.name).get_logger()
    init_dir(args)

    args.ent_dim, args.rel_dim = resolve_embedding_dims(args.kge, args.emb_dim)

    # Relation name forwarded to the pre-processing step.
    Target_rel = 'isConnectedTo'
    # Run the selected step once per dataset variant; the variant label is
    # part of both the pickle path and the subgraph-db path.
    for BGP in ['BSQ', 'PQ', 'BPQ', 'FG']:
        start_t = datetime.datetime.now()
        sample_start_t = datetime.datetime.now()

        # specify the paths for original data and subgraph db
        args.data_path = 'data/' + args.data_name + '_' + BGP + '.pkl'
        args.db_path = 'data/' + args.data_name + '_' + BGP + '_subgraph'

        # load original data and make index (only when the pickle is missing)
        if not os.path.exists(args.data_path):
            data2pkl_Trans_to_Ind(args.data_name, BGP, Target_rel,
                                  logger=logger, datapath=args.raw_data_dir)
            sampling_sec = (datetime.datetime.now() - sample_start_t).total_seconds()
            logger.info("Sampling Time Sec=" + str(sampling_sec))

        # build the subgraph db only when it does not exist yet
        if not os.path.exists(args.db_path):
            gen_subgraph_datasets(args)

        args.num_rel = get_num_rel(args)
        set_seed(args.seed)

        if args.step == 'meta_train':
            meta_trainer = MetaTrainer(args, logger)
            logger.info("BGP=" + str(BGP))
            # train() receives the wall-clock time at which training starts
            meta_trainer.train(datetime.datetime.now())
            total_sec = (datetime.datetime.now() - start_t).total_seconds()
            logger.info("Total Time Sec=" + str(total_sec))
        elif args.step == 'fine_tune':
            post_trainer = PostTrainer(args)
            post_trainer.train()