# KGTOSA / GNN-Methods / NodeClassifcation / IBS / infer.py
import logging
import os
import resource
import traceback

import numpy as np
import seml
import torch
from sacred import Experiment

from dataloaders.get_loaders import get_loaders
from data.data_preparation import load_data, GraphPreprocess
from models.get_model import get_model
from train.trainer import Trainer

# Module-level Sacred experiment; the decorated hook, config scope and entry
# point in this module all attach to it, so it must be created first.
ex = Experiment()
# Hand the experiment to seml so it can set up its logger.
seml.setup_logger(ex)


@ex.post_run_hook
def collect_stats(_run):
    """Sacred post-run hook that forwards the finished run to seml.

    ``_run`` is the Sacred run object, passed straight to
    ``seml.collect_exp_stats``.
    """
    seml.collect_exp_stats(_run)


@ex.config
def config():
    # Sacred config scope: every local variable assigned here becomes an
    # experiment config entry, so add only comments (no helper variables) here.
    overwrite = None
    db_collection = None
    # When a MongoDB collection name is supplied (presumably by seml when it
    # launches the run), attach an observer so the run is recorded there.
    if db_collection is not None:
        ex.observers.append(seml.create_mongodb_observer(db_collection, overwrite=overwrite))


@ex.automain
def run(dataset_name,
        mode,
        batch_size,
        full_graph_chunks,
        model_dir,

        ppr_params=None,
        batch_params=None,
        n_sampling_params=None,
        rw_sampling_params=None,
        ladies_params=None,
        shadow_ppr_params=None,
        rand_ppr_params=None,

        graphmodel='gcn',
        hidden_channels=256,
        num_layers=3,
        heads=None, ):
    """Evaluate every ``*.pt`` checkpoint in ``model_dir`` on the val/test splits.

    All arguments are injected by Sacred from the experiment config.

    Args:
        dataset_name: Dataset identifier forwarded to ``load_data``.
        mode: Forwarded to both ``Trainer`` and ``get_loaders``.
        batch_size: Batch size forwarded to ``get_loaders``.
        full_graph_chunks: Forwarded to ``Trainer``.
        model_dir: Directory scanned (non-recursively) for ``.pt`` files, each
            loaded as a state dict into the model built below.
        ppr_params, batch_params, n_sampling_params, rw_sampling_params,
        ladies_params, shadow_ppr_params, rand_ppr_params: Sampler option
            objects forwarded unchanged to ``get_loaders``.
        graphmodel, hidden_channels, num_layers, heads: Architecture options
            forwarded to ``get_model``; they must match the saved checkpoints,
            otherwise ``load_state_dict`` raises.

    Returns:
        dict with ``gpu_memory`` (``torch.cuda.max_memory_allocated()``),
        ``max_memory`` (``ru_maxrss * 1024``; assumes Linux, where ``ru_maxrss``
        is reported in KiB), and, for every ``trainer.database`` entry except
        ``training_curves``, ``<key>_record`` (the raw values) and
        ``<key>_stats`` (``(mean, std)``, or ``(0., 0.)`` when empty).

    Raises:
        Any exception raised while loading data/checkpoints or running
        inference is logged and re-raised, so the run is reported as failed
        instead of exiting with status 0.
    """
    try:
        logging.info(f'dataset: {dataset_name}, graphmodel: {graphmodel}, mode: {mode}')

        device = 'cuda' if torch.cuda.is_available() else 'cpu'
        graph, (train_indices, val_indices, test_indices) = load_data(dataset_name, 1,
                                                                      GraphPreprocess(True, True))
        logging.info("Graph loaded!\n")

        # NOTE(review): Trainer receives batch_size=1 while the loaders use the
        # configured batch_size — confirm this is intended.
        trainer = Trainer(mode, full_graph_chunks, batch_size=1, )

        # The first returned loader (training) is not needed for inference.
        (_,
         self_val_loader,
         ppr_val_loader,
         batch_val_loader,
         self_test_loader,
         ppr_test_loader,
         batch_test_loader) = get_loaders(
            graph,
            (train_indices, val_indices, test_indices),
            batch_size,
            mode,
            'rand',
            ppr_params,
            batch_params,
            rw_sampling_params,
            shadow_ppr_params,
            rand_ppr_params,
            ladies_params,
            n_sampling_params,
            inference=True,
            ibmb_val=False)

        model = get_model(graphmodel,
                          graph.num_node_features,
                          graph.y.max().item() + 1,
                          hidden_channels,
                          num_layers,
                          heads,
                          device)

        # os.listdir order is arbitrary; sort so checkpoints are evaluated, and
        # their results recorded, in a reproducible order.
        checkpoints = sorted(name for name in os.listdir(model_dir) if name.endswith('.pt'))
        if not checkpoints:
            logging.warning(f'No .pt checkpoints found in {model_dir}')

        for _file in checkpoints:
            model_path = os.path.join(model_dir, _file)
            # map_location lets checkpoints saved on GPU load on a CPU-only host.
            # torch.load unpickles the file: only point model_dir at trusted files.
            model.load_state_dict(torch.load(model_path, map_location=device))
            model.eval()

            trainer.inference(self_val_loader,
                              ppr_val_loader,
                              batch_val_loader,
                              self_test_loader,
                              ppr_test_loader,
                              batch_test_loader,
                              model, )

            trainer.full_graph_inference(model, graph, val_indices, test_indices)

        results = {
            'gpu_memory': torch.cuda.max_memory_allocated(),
            'max_memory': 1024 * resource.getrusage(resource.RUSAGE_SELF).ru_maxrss
        }

        for key, item in trainer.database.items():
            if key == 'training_curves':
                continue
            results[f'{key}_record'] = item
            values = np.array(item)
            results[f'{key}_stats'] = (values.mean(), values.std(),) if len(values) else (0., 0.,)

        return results
    except Exception:
        # Previously a bare except + exit() swallowed the error and exited with
        # status 0; log it and re-raise so Sacred/seml record the failure.
        logging.exception('Inference run failed')
        raise