# KGTOSA / GNN-Methods / NodeClassifcation / IBS / dataloaders / get_loaders.py
# File: get_loaders.py
import os
import numpy as np
import pandas as pd
from math import ceil
from typing import Dict, Tuple, Union, Optional

from torch import LongTensor

from torch_geometric.data import Data

from dataloaders.GraphSAINTRWSampler import SaintRWTrainSampler, SaintRWValSampler
from dataloaders.IBMBBatchLoader import IBMBBatchLoader,IBMBBatchLoader_hetero
from dataloaders.IBMBNodeLoader import IBMBNodeLoader,IBMBNodeLoader_hetero

# Union of the loader/sampler classes imported by this module; used as the
# element type in the return annotation of get_loaders.  The *_hetero variants
# are listed because get_loaders instantiates them when is_hetero != 0.
Loader = Union[
    SaintRWTrainSampler,
    SaintRWValSampler,
    IBMBBatchLoader,
    IBMBBatchLoader_hetero,
    IBMBNodeLoader,
    IBMBNodeLoader_hetero,
]
# Tag passed to the loader constructors in get_loaders as their edge-index
# type argument.
EDGE_INDEX_TYPE = 'adj'


def num_out_nodes_per_batch_normalization(num_out_nodes: int,
                                          num_out_per_batch: int):
    """Even out the per-batch output-node count.

    First work out how many batches are needed when each batch may hold at
    most ``num_out_per_batch`` output nodes, then spread ``num_out_nodes``
    over that many batches and round up.  The batch count is unchanged but
    the batches end up (almost) equally sized.

    Args:
        num_out_nodes: total number of output nodes to distribute.
        num_out_per_batch: requested upper bound of output nodes per batch.

    Returns:
        The balanced number of output nodes per batch.

    Raises:
        ZeroDivisionError: if either argument is zero.
    """
    batch_count = ceil(num_out_nodes / num_out_per_batch)
    balanced_size = ceil(num_out_nodes / batch_count)
    return balanced_size


def _build_node_loader(loader_cls,
                       graph,
                       batch_order,
                       output_indices,
                       num_output_nodes_per_batch,
                       ppr_params,
                       batch_size):
    """Instantiate a node-wise IBMB loader class with the arguments shared by all call sites."""
    return loader_cls(graph,
                      batch_order,
                      output_indices,
                      EDGE_INDEX_TYPE,
                      ppr_params['neighbor_topk'],
                      num_output_nodes_per_batch=num_output_nodes_per_batch,
                      num_auxiliary_nodes_per_batch=None,
                      alpha=ppr_params['alpha'],
                      eps=ppr_params['eps'],
                      batch_size=batch_size,
                      shuffle=False)  # must be false, instead we define our own order!


def _build_batch_loader(loader_cls,
                        graph,
                        batch_order,
                        output_indices,
                        num_batches,
                        part_topk,
                        batch_params,
                        batch_size):
    """Instantiate a batch-wise IBMB loader class with the arguments shared by all call sites."""
    return loader_cls(graph,
                      batch_order,
                      num_batches,
                      output_indices,
                      EDGE_INDEX_TYPE,
                      part_topk,
                      alpha=batch_params['alpha'],
                      batch_size=batch_size,
                      shuffle=False)


def get_loaders(graph: Data,
                splits: Tuple[LongTensor, LongTensor, LongTensor],
                batch_size: int,
                mode: str,
                batch_order: str,
                ppr_params: Optional[Dict],
                batch_params: Optional[Dict],
                rw_sampling_params: Optional[Dict],
                shadow_ppr_params: Optional[Dict],
                rand_ppr_params: Optional[Dict],
                ladies_params: Optional[Dict],
                n_sampling_params: Optional[Dict],
                inference: bool = True,
                ibmb_val: bool = True,
                is_hetero=0,
                local2global=None,
                subject_node_idx=None) -> Tuple[
    Optional[Loader],
    Optional[Loader],
    Optional[Loader],
    Optional[Loader],
    Optional[Loader],
    Optional[Loader],
    Optional[Loader]
]:
    """Build the train / validation / test loaders for the requested batching mode.

    Args:
        graph: graph object handed to every loader constructor.
        splits: ``(train_indices, val_indices, test_indices)``.
        batch_size: forwarded as ``batch_size`` to the IBMB, ClusterGCN and
            randfix loaders.
        mode: one of ``'ppr'``, ``'part'``, ``'rw_sampling'``,
            ``'clustergcn'``, ``'ppr_shadow'``, ``'rand'``, ``'randfix'``,
            ``'ladies'``, ``'n_sampling'``.
        batch_order: forwarded to the IBMB and randfix loaders.
        ppr_params: needs ``'neighbor_topk'``, ``'primes_per_batch'``,
            ``'alpha'``, ``'eps'``.  Read in modes ``'ppr'`` and ``'randfix'``
            and for the extra node-wise evaluation loaders.
        batch_params: needs ``'num_batches'`` (indexed 0/1/2 for
            train/val/test), ``'part_topk'`` (indexed 0/1) and, for the IBMB
            batch loaders, ``'alpha'``.
        rw_sampling_params: needs ``'batch_size'`` (indexed 0/1),
            ``'walk_length'``, ``'num_steps'``, ``'sample_coverage'``.
        shadow_ppr_params: needs ``'neighbor_topk'``, ``'alpha'``, ``'eps'``,
            ``'primes_per_batch'``.
        rand_ppr_params: needs ``'neighbor_topk'``, ``'alpha'``, ``'eps'``
            and, in mode ``'rand'``, ``'primes_per_batch'``.
        ladies_params: needs ``'sample_size'`` and ``'num_batches'`` (both
            indexed 0/1/2) and ``'num_layers'``.
        n_sampling_params: needs ``'n_nodes'`` and ``'num_batches'`` (indexed
            0/1/2).
        inference: when false, no test loader is created.
        ibmb_val: when true, additionally create node-wise IBMB evaluation
            loaders (unless ``mode == 'ppr'`` or ``ppr_params`` is None) and
            batch-wise IBMB evaluation loaders (unless ``mode == 'part'`` or
            ``batch_params`` is None).
        is_hetero: ``0`` selects the homogeneous loader classes; any other
            value selects the ``*_hetero`` classes in mode ``'ppr'`` and for
            the extra batch-wise evaluation loaders.
        local2global: optional mapping; in homogeneous ``'ppr'`` mode the
            split indices are translated through
            ``local2global[subject_node_idx][indices]`` when it is not None.
        subject_node_idx: key into ``local2global``.

    Returns:
        ``(train_loader, self_val_loader, ppr_val_loader, batch_val_loader,
        self_test_loader, ppr_test_loader, batch_test_loader)``; entries that
        were not built are ``None``.

    Raises:
        NotImplementedError: if ``mode`` is not one of the supported values.
    """
    train_indices, val_indices, test_indices = splits

    train_loader = None
    self_val_loader = None
    ppr_val_loader = None
    batch_val_loader = None
    self_test_loader = None
    ppr_test_loader = None
    batch_test_loader = None

    homogeneous = is_hetero == 0

    if mode == 'ppr':
        node_loader_cls = IBMBNodeLoader if homogeneous else IBMBNodeLoader_hetero

        def _output_ids(indices):
            # Only the homogeneous path translates split indices; batch sizes
            # below are still derived from the untranslated split lengths.
            if homogeneous and local2global is not None:
                return local2global[subject_node_idx][indices]
            return indices

        train_loader = _build_node_loader(
            node_loader_cls, graph, batch_order, _output_ids(train_indices),
            num_out_nodes_per_batch_normalization(
                len(train_indices), ppr_params['primes_per_batch']),
            ppr_params, batch_size)
        self_val_loader = _build_node_loader(
            node_loader_cls, graph, batch_order, _output_ids(val_indices),
            num_out_nodes_per_batch_normalization(
                len(val_indices), ppr_params['primes_per_batch'] * 2),
            ppr_params, batch_size)
        if inference:
            # Homogeneous graphs put every test node into a single batch,
            # heterogeneous graphs use twice the training batch size.
            if homogeneous:
                test_primes = len(test_indices)
            else:
                test_primes = ppr_params['primes_per_batch'] * 2
            self_test_loader = _build_node_loader(
                node_loader_cls, graph, batch_order, _output_ids(test_indices),
                num_out_nodes_per_batch_normalization(len(test_indices), test_primes),
                ppr_params, batch_size)

    elif mode == 'part':
        # NOTE(review): this mode always uses the homogeneous IBMBBatchLoader,
        # regardless of is_hetero -- confirm that is intended.
        train_loader = _build_batch_loader(
            IBMBBatchLoader, graph, batch_order, train_indices,
            batch_params['num_batches'][0], batch_params['part_topk'][0],
            batch_params, batch_size)
        self_val_loader = _build_batch_loader(
            IBMBBatchLoader, graph, batch_order, val_indices,
            batch_params['num_batches'][1], batch_params['part_topk'][1],
            batch_params, batch_size)
        if inference:
            # The test loader reuses the validation part_topk entry (index 1).
            self_test_loader = _build_batch_loader(
                IBMBBatchLoader, graph, batch_order, test_indices,
                batch_params['num_batches'][2], batch_params['part_topk'][1],
                batch_params, batch_size)

    elif mode == 'rw_sampling':
        dir_name = './saint_cache'
        # exist_ok avoids the check-then-create race of isdir() + mkdir().
        os.makedirs(dir_name, exist_ok=True)
        train_loader = SaintRWTrainSampler(graph,
                                           train_indices,
                                           EDGE_INDEX_TYPE,
                                           graph.num_nodes,
                                           rw_sampling_params['batch_size'][0],
                                           rw_sampling_params['walk_length'],
                                           rw_sampling_params['num_steps'],
                                           rw_sampling_params['sample_coverage'],
                                           save_dir=dir_name,
                                           shuffle=True)
        self_val_loader = SaintRWValSampler(graph,
                                            val_indices,
                                            EDGE_INDEX_TYPE,
                                            graph.num_nodes,
                                            rw_sampling_params['walk_length'],
                                            rw_sampling_params['sample_coverage'],
                                            save_dir=dir_name,
                                            batch_size=rw_sampling_params['batch_size'][1],
                                            shuffle=True)
        if inference:
            self_test_loader = SaintRWValSampler(graph,
                                                 test_indices,
                                                 EDGE_INDEX_TYPE,
                                                 graph.num_nodes,
                                                 rw_sampling_params['walk_length'],
                                                 rw_sampling_params['sample_coverage'],
                                                 save_dir=dir_name,
                                                 batch_size=rw_sampling_params['batch_size'][1],
                                                 shuffle=True)

    elif mode == 'clustergcn':
        # NOTE(review): ClusterGCNLoader is not among the imports visible at
        # the top of this file; unless it is provided elsewhere this branch
        # raises NameError -- TODO confirm / add the import.
        train_loader = ClusterGCNLoader(graph,
                                        batch_params['num_batches'][0],
                                        train_indices,
                                        EDGE_INDEX_TYPE,
                                        batch_size=batch_size,
                                        shuffle=True)
        self_val_loader = ClusterGCNLoader(graph,
                                           batch_params['num_batches'][1],
                                           val_indices,
                                           EDGE_INDEX_TYPE,
                                           batch_size=batch_size,
                                           shuffle=True)
        if inference:
            self_test_loader = ClusterGCNLoader(graph,
                                                batch_params['num_batches'][2],
                                                test_indices,
                                                EDGE_INDEX_TYPE,
                                                batch_size=batch_size,
                                                shuffle=True)

    elif mode == 'ppr_shadow':
        # NOTE(review): ShaDowLoader is not among the visible imports -- TODO
        # confirm it is available before using this mode.
        train_loader = ShaDowLoader(graph,
                                    train_indices,
                                    EDGE_INDEX_TYPE,
                                    shadow_ppr_params['neighbor_topk'],
                                    shadow_ppr_params['alpha'],
                                    shadow_ppr_params['eps'],
                                    batch_size=num_out_nodes_per_batch_normalization(
                                        len(train_indices), shadow_ppr_params['primes_per_batch']),
                                    shuffle=True)
        self_val_loader = ShaDowLoader(graph,
                                       val_indices,
                                       EDGE_INDEX_TYPE,
                                       shadow_ppr_params['neighbor_topk'],
                                       shadow_ppr_params['alpha'],
                                       shadow_ppr_params['eps'],
                                       batch_size=num_out_nodes_per_batch_normalization(
                                           len(val_indices), shadow_ppr_params['primes_per_batch'] * 2),
                                       shuffle=True)
        if inference:
            self_test_loader = ShaDowLoader(graph,
                                            test_indices,
                                            EDGE_INDEX_TYPE,
                                            shadow_ppr_params['neighbor_topk'],
                                            shadow_ppr_params['alpha'],
                                            shadow_ppr_params['eps'],
                                            batch_size=num_out_nodes_per_batch_normalization(
                                                len(test_indices), shadow_ppr_params['primes_per_batch'] * 2),
                                            shuffle=True)

    elif mode == 'rand':
        # NOTE(review): IBMBRandLoader is not among the visible imports -- TODO
        # confirm it is available before using this mode.
        train_loader = IBMBRandLoader(graph,
                                      train_indices,
                                      EDGE_INDEX_TYPE,
                                      rand_ppr_params['neighbor_topk'],
                                      rand_ppr_params['alpha'],
                                      rand_ppr_params['eps'],
                                      batch_size=num_out_nodes_per_batch_normalization(
                                          len(train_indices), rand_ppr_params['primes_per_batch']),
                                      shuffle=True)
        self_val_loader = IBMBRandLoader(graph,
                                         val_indices,
                                         EDGE_INDEX_TYPE,
                                         rand_ppr_params['neighbor_topk'],
                                         rand_ppr_params['alpha'],
                                         rand_ppr_params['eps'],
                                         batch_size=num_out_nodes_per_batch_normalization(
                                             len(val_indices), rand_ppr_params['primes_per_batch'] * 2),
                                         shuffle=True)
        if inference:
            self_test_loader = IBMBRandLoader(graph,
                                              test_indices,
                                              EDGE_INDEX_TYPE,
                                              rand_ppr_params['neighbor_topk'],
                                              rand_ppr_params['alpha'],
                                              rand_ppr_params['eps'],
                                              batch_size=num_out_nodes_per_batch_normalization(
                                                  len(test_indices), rand_ppr_params['primes_per_batch'] * 2),
                                              shuffle=True)

    elif mode == 'randfix':
        # NOTE(review): IBMBRandfixLoader is not among the visible imports.
        # Also, the batch size is derived from ppr_params while every other
        # setting comes from rand_ppr_params -- confirm both are intended.
        train_loader = IBMBRandfixLoader(graph,
                                         batch_order,
                                         train_indices,
                                         EDGE_INDEX_TYPE,
                                         num_out_nodes_per_batch_normalization(
                                             len(train_indices), ppr_params['primes_per_batch']),
                                         rand_ppr_params['neighbor_topk'],
                                         rand_ppr_params['alpha'],
                                         rand_ppr_params['eps'],
                                         batch_size=batch_size,
                                         shuffle=False)
        self_val_loader = IBMBRandfixLoader(graph,
                                            batch_order,
                                            val_indices,
                                            EDGE_INDEX_TYPE,
                                            num_out_nodes_per_batch_normalization(
                                                len(val_indices), ppr_params['primes_per_batch'] * 2),
                                            rand_ppr_params['neighbor_topk'],
                                            rand_ppr_params['alpha'],
                                            rand_ppr_params['eps'],
                                            batch_size=batch_size,
                                            shuffle=False)
        if inference:
            self_test_loader = IBMBRandfixLoader(graph,
                                                 batch_order,
                                                 test_indices,
                                                 EDGE_INDEX_TYPE,
                                                 num_out_nodes_per_batch_normalization(
                                                     len(test_indices), ppr_params['primes_per_batch'] * 2),
                                                 rand_ppr_params['neighbor_topk'],
                                                 rand_ppr_params['alpha'],
                                                 rand_ppr_params['eps'],
                                                 batch_size=batch_size,
                                                 shuffle=False)

    elif mode == 'ladies':
        # NOTE(review): LADIESSampler is not among the visible imports -- TODO
        # confirm it is available before using this mode.
        train_loader = LADIESSampler(graph,
                                     train_indices,
                                     EDGE_INDEX_TYPE,
                                     [ladies_params['sample_size'][0]] * ladies_params['num_layers'],
                                     batch_size=ceil(len(train_indices) / ladies_params['num_batches'][0]),
                                     shuffle=True)
        self_val_loader = LADIESSampler(graph,
                                        val_indices,
                                        EDGE_INDEX_TYPE,
                                        [ladies_params['sample_size'][1]] * ladies_params['num_layers'],
                                        batch_size=ceil(len(val_indices) / ladies_params['num_batches'][1]),
                                        shuffle=True)
        if inference:
            self_test_loader = LADIESSampler(graph,
                                             test_indices,
                                             EDGE_INDEX_TYPE,
                                             [ladies_params['sample_size'][2]] * ladies_params['num_layers'],
                                             batch_size=ceil(len(test_indices) / ladies_params['num_batches'][2]),
                                             shuffle=True)

    elif mode == 'n_sampling':
        # NOTE(review): NeighborSamplingLoader is not among the visible
        # imports -- TODO confirm it is available before using this mode.
        train_loader = NeighborSamplingLoader(graph,
                                              sizes=n_sampling_params['n_nodes'],
                                              node_idx=train_indices,
                                              batch_size=ceil(
                                                  len(train_indices) / n_sampling_params['num_batches'][0]),
                                              shuffle=True)
        self_val_loader = NeighborSamplingLoader(graph,
                                                 sizes=n_sampling_params['n_nodes'],
                                                 node_idx=val_indices,
                                                 batch_size=ceil(
                                                     len(val_indices) / n_sampling_params['num_batches'][1]),
                                                 shuffle=True)
        if inference:
            self_test_loader = NeighborSamplingLoader(graph,
                                                      sizes=n_sampling_params['n_nodes'],
                                                      node_idx=test_indices,
                                                      batch_size=ceil(
                                                          len(test_indices) / n_sampling_params['num_batches'][2]),
                                                      shuffle=True)
    else:
        raise NotImplementedError

    if ibmb_val:
        # Extra IBMB evaluation loaders, built on top of the mode-specific
        # ones.  NOTE(review): these use the untranslated split indices and the
        # node-wise variant is always the homogeneous class -- confirm.
        if mode != 'ppr' and ppr_params is not None:
            ppr_val_loader = _build_node_loader(
                IBMBNodeLoader, graph, batch_order, val_indices,
                num_out_nodes_per_batch_normalization(
                    len(val_indices), ppr_params['primes_per_batch'] * 2),
                ppr_params, batch_size)
            if inference:
                ppr_test_loader = _build_node_loader(
                    IBMBNodeLoader, graph, batch_order, test_indices,
                    num_out_nodes_per_batch_normalization(
                        len(test_indices), ppr_params['primes_per_batch'] * 2),
                    ppr_params, batch_size)
        if mode != 'part' and batch_params is not None:
            batch_loader_cls = IBMBBatchLoader if homogeneous else IBMBBatchLoader_hetero
            batch_val_loader = _build_batch_loader(
                batch_loader_cls, graph, batch_order, val_indices,
                batch_params['num_batches'][1], batch_params['part_topk'][1],
                batch_params, batch_size)
            if inference:
                batch_test_loader = _build_batch_loader(
                    batch_loader_cls, graph, batch_order, test_indices,
                    batch_params['num_batches'][2], batch_params['part_topk'][1],
                    batch_params, batch_size)

    return (train_loader,
            self_val_loader,
            ppr_val_loader,
            batch_val_loader,
            self_test_loader,
            ppr_test_loader,
            batch_test_loader)