from utils import get_g_bidir
from datasets import TrainSubgraphDataset, ValidSubgraphDataset
from torch.utils.data import DataLoader
import torch
from torch import optim
from trainer import Trainer
import dgl
from collections import defaultdict as ddict
from resource import getrusage, RUSAGE_SELF
import datetime


class MetaTrainer(Trainer):
    """Meta-trains entity initializer, R-GCN, and KGE model over sampled support/query subgraphs."""

    def __init__(self, args, logger=None):
        super(MetaTrainer, self).__init__(args, logger)

        # dataloaders over sampled training / validation subgraphs
        train_subgraph_dataset = TrainSubgraphDataset(args)
        valid_subgraph_dataset = ValidSubgraphDataset(args)
        self.train_subgraph_dataloader = DataLoader(train_subgraph_dataset,
                                                    batch_size=args.metatrain_bs,
                                                    shuffle=True,
                                                    collate_fn=TrainSubgraphDataset.collate_fn)
        self.valid_subgraph_dataloader = DataLoader(valid_subgraph_dataset,
                                                    batch_size=args.metatrain_bs,
                                                    shuffle=False,
                                                    collate_fn=ValidSubgraphDataset.collate_fn)

        # optimizer over all meta-learned modules
        self.optimizer = optim.Adam(list(self.ent_init.parameters())
                                    + list(self.rgcn.parameters())
                                    + list(self.kge_model.parameters()),
                                    lr=self.args.metatrain_lr)

    def load_pretrain(self):
        """Load pretrained states for the entity initializer, R-GCN, and KGE model."""
        state = torch.load(self.args.pretrain_state, map_location=self.args.gpu)
        self.ent_init.load_state_dict(state['ent_init'])
        self.rgcn.load_state_dict(state['rgcn'])
        self.kge_model.load_state_dict(state['kge_model'])

    def train(self, start_time=None):
        best_step = 0
        best_eval_rst = {'mrr': 0, 'hits@1': 0, 'hits@5': 0, 'hits@10': 0}
        bad_count = 0
        self.logger.info('start meta-training')

        for e in range(self.args.metatrain_num_epoch):
            step = 0
            self.logger.info(getrusage(RUSAGE_SELF))
            # self.logger.info('number of batches={:}'.format(len(self.train_subgraph_dataloader)))
            for batch in self.train_subgraph_dataloader:
                batch_loss = 0

                # encode all support graphs in one batched DGL graph, then split per subgraph
                batch_sup_g = dgl.batch([get_g_bidir(d[0], self.args) for d in batch]).to(self.args.gpu)
                self.get_ent_emb(batch_sup_g)
                sup_g_list = dgl.unbatch(batch_sup_g)

                for batch_i, data in enumerate(batch):
                    que_tri, que_neg_tail_ent, que_neg_head_ent = [d.to(self.args.gpu) for d in data[1:]]
                    ent_emb = sup_g_list[batch_i].ndata['h']

                    # kge loss on query triples against sampled negatives
                    loss = self.get_loss(que_tri, que_neg_tail_ent, que_neg_head_ent, ent_emb)
                    batch_loss += loss

                batch_loss /= len(batch)
                self.optimizer.zero_grad()
                batch_loss.backward()
                self.optimizer.step()

                step += 1
                self.logger.info('epoch: {} | step: {} | loss: {:.4f}'.format(e, step, batch_loss.item()))
                self.write_training_loss(batch_loss.item(), step)

                # periodic validation with early-stopping bookkeeping
                if step % self.args.metatrain_check_per_step == 0:
                    eval_res = self.evaluate_valid_subgraphs()
                    self.write_evaluation_result(eval_res, step)
                    self.logger.info(getrusage(RUSAGE_SELF))
                    self.logger.info('Train Time Sec=' + str((datetime.datetime.now() - start_time).total_seconds()))

                    if eval_res['mrr'] > best_eval_rst['mrr']:
                        best_eval_rst = eval_res
                        best_step = step
                        self.logger.info('best model | mrr {:.4f}'.format(best_eval_rst['mrr']))
                        self.save_checkpoint(step)
                        bad_count = 0
                    else:
                        bad_count += 1
                        self.logger.info('best model is at step {0}, mrr {1:.4f}, bad count {2}'.format(
                            best_step, best_eval_rst['mrr'], bad_count))

        self.logger.info('finish meta-training')
        self.logger.info('save best model')
        self.save_model(best_step)

        self.logger.info('best validation | epoch: {:}, mrr: {:.4f}, hits@1: {:.4f}, hits@5: {:.4f}, hits@10: {:.4f}'.format(
            e, best_eval_rst['mrr'], best_eval_rst['hits@1'],
            best_eval_rst['hits@5'], best_eval_rst['hits@10']))

        self.before_test_load()
        self.evaluate_indtest_test_triples(num_cand=50)

    def evaluate_valid_subgraphs(self):
        """Average link-prediction metrics over all validation subgraphs."""
        all_results = ddict(int)
        for batch in self.valid_subgraph_dataloader:
            batch_sup_g = dgl.batch([get_g_bidir(d[0], self.args) for d in batch]).to(self.args.gpu)
            self.get_ent_emb(batch_sup_g)
            sup_g_list = dgl.unbatch(batch_sup_g)

            for batch_i, data in enumerate(batch):
                que_dataloader = data[1]
                ent_emb = sup_g_list[batch_i].ndata['h']

                results = self.evaluate(ent_emb, que_dataloader)
                for k, v in results.items():
                    all_results[k] += v

        for k, v in all_results.items():
            all_results[k] = v / self.args.num_valid_subgraph

        self.logger.info('valid on valid subgraphs')
        self.logger.info('mrr: {:.4f}, hits@1: {:.4f}, hits@5: {:.4f}, hits@10: {:.4f}'.format(
            all_results['mrr'], all_results['hits@1'], all_results['hits@5'], all_results['hits@10']))

        return all_results
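

# --- Hypothetical usage sketch (not part of the original training pipeline) ---
# A minimal example of how MetaTrainer might be driven from a script, assuming
# the parent Trainer and the subgraph datasets accept an argparse namespace with
# (at least) the options referenced in this file. The real configuration very
# likely defines additional arguments; all default values and the checkpoint
# path below are illustrative placeholders, not values from the repository.
if __name__ == '__main__':
    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument('--metatrain_bs', type=int, default=64)
    parser.add_argument('--metatrain_lr', type=float, default=0.01)
    parser.add_argument('--metatrain_num_epoch', type=int, default=10)
    parser.add_argument('--metatrain_check_per_step', type=int, default=100)
    parser.add_argument('--num_valid_subgraph', type=int, default=200)
    parser.add_argument('--pretrain_state', type=str, default='state/pretrain.best')
    parser.add_argument('--gpu', type=str, default='cuda:0')
    args = parser.parse_args()

    meta_trainer = MetaTrainer(args)
    meta_trainer.load_pretrain()
    meta_trainer.train(start_time=datetime.datetime.now())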