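# Source: KGTOSA / GNN-Methods / NodeClassifcation / IBS / run_ogbn_ppr.py
# Sacred/seml experiment script: loads a graph, trains a GNN with PPR-based
# data loaders and reports run statistics (see `run` below).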
import logging
import resource
import time
import traceback

import os.path as osp
import numpy as np
import seml
import torch
from sacred import Experiment

from dataloaders.get_loaders_ppr import get_loaders
from data.data_preparation import check_consistence, load_data,load_data_hetero, GraphPreprocess,GraphPreprocess_hetero,GraphPreprocess_homo_hetero,load_data_hetero_homo
from models.get_model import get_model
from train.trainer_ppr import Trainer
from resource import *
# Sacred experiment object for this script. The config, post-run hook and main
# function below are all registered on `ex` through decorators, and seml's
# logger is attached to it right after creation.
ex = Experiment()
seml.setup_logger(ex)

@ex.post_run_hook
def collect_stats(_run):
    """Hand the finished Sacred run to seml so it can collect experiment stats.

    Sacred supplies the current run through the special ``_run`` parameter
    name, so the parameter must keep that exact name.
    """
    finished_run = _run
    seml.collect_exp_stats(finished_run)


@ex.config
def config():
    # Sacred config scope (seml's standard template).
    # NOTE(review): Sacred records the local variables of this function as
    # config entries, so adding/renaming locals here changes the config keys.
    # The None defaults are presumably overridden by seml at run time (Sacred
    # config-scope semantics); otherwise the observer branch below never fires.
    overwrite = None
    db_collection = None
    print("db_collection=",db_collection)
    # Attach a MongoDB observer only when a target collection was supplied.
    if db_collection is not None:
        ex.observers.append(seml.create_mongodb_observer(db_collection, overwrite=overwrite))


def _mean_or_zero(values):
    """Return the arithmetic mean of ``values``, or ``0.`` when it is empty."""
    return sum(values) / len(values) if values else 0.


def _build_results(database, gpu_memory):
    """Assemble the result dict returned to Sacred from the trainer's records.

    Args:
        database: the trainer's record dict. ``database['training_curves']`` is
            iterated as a sequence of dicts whose ``'per_train_time'`` and
            ``'per_self_val_time'`` entries are concatenated as lists. Every
            other entry is reported raw (``<key>_record``) and as a
            ``(mean, std)`` pair (``<key>_stats``, ``(0., 0.)`` when empty).
        gpu_memory: peak CUDA memory figure, reported unchanged.

    Returns:
        dict with per-epoch runtimes, memory figures, the raw curves and the
        per-key records/stats.
    """
    runtime_train_lst = []
    runtime_self_val_lst = []
    for curves in database['training_curves']:
        runtime_train_lst += curves['per_train_time']
        runtime_self_val_lst += curves['per_self_val_time']

    results = {
        # Guarded means: a run that recorded no epochs reports 0 instead of
        # failing with ZeroDivisionError.
        'runtime_train_perEpoch': _mean_or_zero(runtime_train_lst),
        'runtime_selfval_perEpoch': _mean_or_zero(runtime_self_val_lst),
        'gpu_memory': gpu_memory,
        # NOTE: ru_maxrss is in KiB on Linux (bytes on macOS); the factor 1024
        # converts to bytes assuming Linux.
        'max_memory': 1024 * resource.getrusage(resource.RUSAGE_SELF).ru_maxrss,
        'curves': database['training_curves'],
    }

    for key, item in database.items():
        if key == 'training_curves':
            continue
        results[f'{key}_record'] = item
        values = np.array(item)
        results[f'{key}_stats'] = (values.mean(), values.std()) if len(values) else (0., 0.)
    return results


# NOTE: the helpers above must be defined before `@ex.automain`, which starts
# the experiment as soon as `run` is decorated when this file is executed.
@ex.automain
def run(dataset_name,
        mode,
        batch_size,
        micro_batch,
        batch_order,
        inference,
        LBMB_val,
        small_trainingset,

        ppr_params,
        batch_params=None,
        graphmodel=None,
        hidden_channels=256,
        reg=0.,
        num_layers=3,
        heads=None,
        epoch_min=30,  # earlier runs used 300
        epoch_max=30,  # earlier runs used 800
        patience=100,
        lr=1e-3,
        seed=None, is_hetero=0):
    """Load a graph, train a GNN with PPR-based loaders and return run stats.

    Sacred entry point: every parameter is filled from the experiment config.

    Args:
        dataset_name: dataset identifier passed to the data loaders.
        mode, batch_order: batching configuration; validated together by
            ``check_consistence`` and forwarded to the loaders/trainer.
        batch_size, micro_batch: forwarded to ``Trainer`` (``batch_size`` also
            to ``get_loaders``).
        inference: when truthy, the trained model is also evaluated with
            ``trainer.inference`` on the validation and test loaders.
        LBMB_val, ppr_params: forwarded to ``get_loaders``.
        small_trainingset: forwarded to the graph loader.
        batch_params: must provide ``batch_params['num_batches'][0]``, which is
            passed to ``Trainer`` (the ``None`` default cannot be used as is).
        graphmodel: model name for ``get_model``; also joined into the run
            comment, so it must be a string.
        hidden_channels, num_layers, heads: model hyper-parameters.
        reg, lr: regularisation and learning rate for ``trainer.train``.
        epoch_min, epoch_max, patience: training schedule for ``Trainer``.
        seed: appended to the run stamp (Sacred supplies it).
        is_hetero: 0 loads a homogeneous graph with ``load_data``; any other
            value uses ``load_data_hetero_homo``, which additionally yields
            the heterogeneous graph and index mappings.

    Returns:
        dict of runtimes, memory usage, training curves and per-metric
        records/stats built from ``trainer.database``.
    """
    try:
        check_consistence(mode, batch_order)
        logging.info(f'dataset: {dataset_name}, graphmodel: {graphmodel}, mode: {mode}')

        # Outputs only produced by the heterogeneous loader. They are bound
        # up front so the homogeneous path can pass them on as None.
        # (Previously this path crashed with NameError at get_loaders.)
        hetero_graph = key2int = None
        local2global = subject_node_idx = None
        if is_hetero == 0:
            graph, (train_indices, val_indices, test_indices) = load_data(
                dataset_name,
                small_trainingset,
                GraphPreprocess(True, True))
        else:
            (graph, (train_indices, val_indices, test_indices), hetero_graph,
             key2int, _split_idx, local2global, subject_node_idx) = load_data_hetero_homo(
                dataset_name,
                small_trainingset,
                GraphPreprocess_homo_hetero(True, True))

        logging.info("Graph loaded!\n")
        print(resource.getrusage(resource.RUSAGE_SELF))

        # NOTE: training is pinned to the CPU; use
        # `'cuda' if torch.cuda.is_available() else 'cpu'` to enable a GPU.
        device = 'cpu'
        trainer = Trainer(mode,
                          batch_params['num_batches'][0],
                          micro_batch=micro_batch,
                          batch_size=batch_size,
                          epoch_max=epoch_max,
                          epoch_min=epoch_min,
                          patience=patience)

        comment = '_'.join([dataset_name, graphmodel, mode])

        # NOTE(review): is_hetero=0 is passed unconditionally; the loaders
        # always receive the homogeneous `graph`. local2global/subject_node_idx
        # are None in homogeneous mode — presumably ignored there, confirm.
        (train_loader,
         self_val_loader,
         self_test_loader) = get_loaders(
            graph,
            (train_indices, val_indices, test_indices),
            batch_size,
            mode,
            batch_order,
            ppr_params,
            inference,
            LBMB_val, is_hetero=0, local2global=local2global, subject_node_idx=subject_node_idx)

        # Run id: current timestamp with the decimal point removed, plus seed.
        stamp = str(time.time()).replace('.', '') + str(seed)

        logging.info(f'model info: {comment}/model_{stamp}.pt')
        model = get_model(graphmodel,
                          graph.num_node_features,
                          # presumably the class count (assumes labels 0..C-1)
                          graph.y.max().item() + 1,
                          hidden_channels,
                          num_layers,
                          heads,
                          device, hetero_graph, key2int)

        # Random Xavier-initialised 128-dim features for every node, stored
        # under key 3. NOTE(review): key 3 is presumably the node-type id the
        # trainer expects — confirm against Trainer.
        feat = torch.Tensor(graph.num_nodes, 128)
        torch.nn.init.xavier_uniform_(feat)
        x_feat = {3: feat}
        print("x_feat=", x_feat)

        model = trainer.train(train_loader,
                              self_val_loader,
                              model=model,
                              lr=lr,
                              reg=reg,
                              comment=comment,
                              run_no=stamp, org_graph=graph, x_feat=x_feat)

        # Only query CUDA statistics when CUDA exists; this may raise on
        # CPU-only torch builds.
        gpu_memory = torch.cuda.max_memory_allocated() if torch.cuda.is_available() else 0
        if inference:
            # Evaluates the model returned by trainer.train (not reloaded from disk).
            model.eval()
            trainer.inference(self_val_loader,
                              self_test_loader,
                              model, org_graph=graph, x_feat=x_feat)

        results = _build_results(trainer.database, gpu_memory)
        print(resource.getrusage(resource.RUSAGE_SELF))
        return results
    except Exception:
        # Keep the full traceback visible, then re-raise instead of exit() so
        # the failure propagates to Sacred/seml rather than being swallowed.
        traceback.print_exc()
        raise