# GraphSAINTRWSampler.py — GraphSAINT random-walk train/val subgraph samplers.
# (KGTOSA / GNN-Methods / NodeClassification / IBS / dataloaders)
import os.path as osp
from math import ceil
from typing import Optional

import numpy as np
import torch
from torch_geometric.data import Data
from torch_sparse import SparseTensor
from tqdm import tqdm

from dataloaders.BaseLoader import BaseLoader


class SaintRWTrainSampler(BaseLoader):
    """GraphSAINT random-walk subgraph sampler for training.

    The object doubles as its own dataset: ``__getitem__`` draws a fresh
    random-walk-induced subgraph, and ``super().__init__(self, ...)`` hands
    the object to the (DataLoader-like) base class as the dataset to iterate.
    ``batch_size=1`` is passed to the base class because each item already
    is a whole subgraph.

    When ``sample_coverage > 0``, GraphSAINT node/edge normalization
    coefficients are estimated (or loaded from ``save_dir``) so that
    minibatch losses are unbiased estimates of the full-graph loss.
    """

    def __init__(self,
                 graph: Data,
                 output_indices: torch.LongTensor,
                 return_edge_index_type: str,
                 num_nodes: int,
                 batch_size: int,
                 walk_length: int,
                 num_steps: int = 1,
                 sample_coverage: int = 0,
                 save_dir: Optional[str] = None,
                 log: bool = True,
                 **kwargs):
        """
        Args:
            graph: full graph; ``x``, ``y``, ``edge_index`` and ``num_edges``
                are read.
            output_indices: node ids of the split whose loss is computed
                (e.g. the training nodes).
            return_edge_index_type: ``'adj'`` to return a weighted
                SparseTensor as ``edge_index``, ``'edge_index'`` to return a
                dense ``2 x E`` index tensor.
            num_nodes: number of nodes random-walk starts are drawn from.
            walk_length: length of each random walk.
            batch_size: number of walk start nodes per sampled subgraph.
            num_steps: number of subgraphs per epoch (``len(self)``).
            sample_coverage: if > 0, estimate normalization statistics from
                approximately ``num_nodes * sample_coverage`` sampled nodes.
            save_dir: optional directory used to cache the statistics.
            log: show a progress bar while computing the statistics.
            **kwargs: forwarded to the BaseLoader constructor.
        """
        self.walk_length = walk_length
        self.num_steps = num_steps
        self.__batch_size__ = batch_size
        self.sample_coverage = sample_coverage
        self.log = log

        self.N = num_nodes
        self.E = graph.num_edges

        # Row-normalize ("rw") the adjacency, stash the normalized weights,
        # then overwrite the sparse values with edge ids so that
        # ``saint_subgraph`` can report which original edges were sampled.
        self.adj = SparseTensor.from_edge_index(graph.edge_index, sparse_sizes=(graph.num_nodes, graph.num_nodes))
        self.adj = self.normalize_adjmat(self.adj, normalization='rw')
        self.edge_weight = self.adj.storage.value()
        self.adj.set_value_(torch.arange(self.E))

        self.original_graph = graph
        self.output_indices = output_indices

        assert return_edge_index_type in ['adj', 'edge_index']
        self.return_edge_index_type = return_edge_index_type

        # NOTE(review): ``self`` is deliberately passed as the first
        # positional argument — BaseLoader appears to wrap
        # torch.utils.data.DataLoader, taking the dataset first. Confirm
        # against BaseLoader's signature.
        super().__init__(self, batch_size=1, **kwargs)

        if self.sample_coverage > 0:
            path = osp.join(save_dir or '', self.__filename__)
            if save_dir is not None and osp.exists(path):  # pragma: no cover
                self.node_norm, self.edge_norm = torch.load(path)
            else:
                self.node_norm, self.edge_norm = self.__compute_norm__()
                if save_dir is not None:  # pragma: no cover
                    torch.save((self.node_norm, self.edge_norm), path)

    @property
    def __filename__(self):
        """Cache-file name for the normalization statistics."""
        return (f'{self.N}_{self.walk_length}_'
                f'{self.sample_coverage}.pt')

    def __len__(self):
        # One "epoch" consists of num_steps sampled subgraphs.
        return self.num_steps

    def loader_len(self):
        # NOTE(review): plain method here, but a @property on
        # SaintRWValSampler — confirm which form callers expect.
        return self.num_steps

    def __sample_nodes__(self, batch_size):
        """Seed ``batch_size`` uniform-random start nodes and return every
        node visited by length-``walk_length`` random walks (with repeats)."""
        start = torch.randint(0, self.N, (batch_size,), dtype=torch.long)
        node_idx = self.adj.random_walk(start.flatten(), self.walk_length)
        return node_idx.view(-1)

    def __getitem__(self, idx):
        # ``idx`` is intentionally ignored: every item is an independent
        # random subgraph.
        node_idx = self.__sample_nodes__(self.__batch_size__).unique()
        adj, _ = self.adj.saint_subgraph(node_idx)
        return node_idx, adj

    def __collate__(self, data_list):
        """Convert the single sampled ``(node_idx, adj)`` pair into a Data
        object carrying features, labels, edges and loss-normalization info."""
        assert len(data_list) == 1
        node_idx, adj = data_list[0]

        # Mark which sampled nodes belong to the output split.
        # np.isin replaces the deprecated np.in1d (same result for 1-D input).
        mask = torch.from_numpy(np.isin(node_idx.cpu().numpy(), self.output_indices))

        if self.sample_coverage > 0:
            # GraphSAINT node weights, restricted to output nodes and
            # renormalized to sum to 1 within this batch.
            node_norm = self.node_norm[node_idx]
            node_norm = node_norm[mask]
            node_norm /= node_norm.sum()
        else:
            node_norm = None

        # ``adj`` currently stores original edge ids as values; map them
        # back to the row-normalized weights (optionally edge_norm-scaled).
        _, _, edge_idx = adj.coo()
        if self.return_edge_index_type == 'adj':
            if self.sample_coverage > 0:
                adj.set_value_(self.edge_norm[edge_idx] * self.edge_weight[edge_idx], layout='csr')
            else:
                adj.set_value_(self.edge_weight[edge_idx], layout='csr')
        else:
            adj = torch.vstack([adj.storage.row(), adj.storage.col()])

        graph = Data(x=self.original_graph.x[node_idx],
                     y=self.original_graph.y[node_idx],
                     edge_index=adj,
                     output_node_mask=mask,
                     node_norm=node_norm)

        return graph

    def __compute_norm__(self):
        """Estimate GraphSAINT normalization coefficients by Monte Carlo.

        Subgraphs are sampled (through this object's own ``__getitem__``)
        until roughly ``N * sample_coverage`` node visits are accumulated,
        then:

            node_norm[v]     = num_samples / count(v) / N
            edge_norm[(u,v)] = count(u) / count((u, v))

        with never-sampled nodes/edges falling back to 0.1.

        Returns:
            Tuple ``(node_norm, edge_norm)`` of float tensors of sizes
            ``(N,)`` and ``(E,)``.
        """
        node_count = torch.zeros(self.N, dtype=torch.float)
        edge_count = torch.zeros(self.E, dtype=torch.float)

        # Identity collate: each batch is a list of (node_idx, adj) pairs.
        loader = torch.utils.data.DataLoader(self, batch_size=200,
                                             collate_fn=lambda x: x,
                                             num_workers=self.num_workers)

        if self.log:  # pragma: no cover
            pbar = tqdm(total=self.N * self.sample_coverage)
            pbar.set_description('Compute GraphSAINT normalization')

        num_samples = 0
        total_sampled_nodes = 0
        while total_sampled_nodes < self.N * self.sample_coverage:
            for data in loader:
                for node_idx, adj in data:
                    # Values of the subgraph adj are original edge ids.
                    edge_idx = adj.storage.value()
                    node_count[node_idx] += 1
                    edge_count[edge_idx] += 1
                    total_sampled_nodes += node_idx.size(0)

                    if self.log:  # pragma: no cover
                        pbar.update(node_idx.size(0))
            # One full pass over the loader yields len(self) == num_steps
            # sampled subgraphs.
            num_samples += self.num_steps

        if self.log:  # pragma: no cover
            pbar.close()

        # Scatter count(source-node) onto each edge id, then divide by the
        # edge's own count; inf (never-sampled edge) and nan map to 0.1.
        row, _, edge_idx = self.adj.coo()
        t = torch.empty_like(edge_count).scatter_(0, edge_idx, node_count[row])
        edge_norm = (t / edge_count)
        edge_norm[edge_norm == float('inf')] = 0.1
        edge_norm[torch.isnan(edge_norm)] = 0.1

        node_count[node_count == 0] = 0.1
        node_norm = num_samples / node_count / self.N

        return node_norm, edge_norm


class SaintRWValSampler(BaseLoader):
    """GraphSAINT-style random-walk sampler for validation/testing.

    Unlike the training sampler, the dataset items are the evaluation nodes
    themselves: each batch seeds random walks at a batch of output
    (val/test) nodes, builds the induced subgraph, and marks exactly those
    seed nodes for evaluation.
    """

    def __init__(self,
                 graph: Data,
                 output_indices: torch.LongTensor,
                 return_edge_index_type: str,
                 num_nodes: int,
                 walk_length: int,
                 sample_coverage: int = 0,
                 save_dir: Optional[str] = None,
                 **kwargs):
        """
        Args:
            graph: full graph; ``x``, ``y``, ``edge_index`` and ``num_edges``
                are read.
            output_indices: node ids of the evaluation split; these are both
                the dataset items and the random-walk seeds.
            return_edge_index_type: ``'adj'`` for a weighted SparseTensor,
                ``'edge_index'`` for a dense ``2 x E`` index tensor.
            num_nodes: total number of nodes (sizes the node-norm fallback).
            walk_length: length of each random walk.
            sample_coverage: if > 0, cached normalization statistics
                (written by the training sampler) MUST exist in ``save_dir``.
            save_dir: directory holding the cached statistics.
            **kwargs: forwarded to the BaseLoader constructor.
        """
        self.walk_length = walk_length
        self.sample_coverage = sample_coverage

        self.N = num_nodes
        self.E = graph.num_edges

        # Row-normalize ("rw") the adjacency, stash the normalized weights,
        # then store edge ids as values so saint_subgraph reports which
        # original edges were sampled.
        self.adj = SparseTensor.from_edge_index(graph.edge_index, sparse_sizes=(graph.num_nodes, graph.num_nodes))
        self.adj = self.normalize_adjmat(self.adj, normalization='rw')
        self.edge_weight = self.adj.storage.value()
        self.adj.set_value_(torch.arange(self.E))

        self.original_graph = graph
        self.output_indices = output_indices

        assert return_edge_index_type in ['adj', 'edge_index']
        self.return_edge_index_type = return_edge_index_type

        # NOTE(review): ``self`` passed as the dataset argument — BaseLoader
        # appears to wrap torch.utils.data.DataLoader; confirm.
        super().__init__(self, **kwargs)

        if self.sample_coverage > 0:
            # The cache file name matches the training sampler's, so the
            # statistics it computed are reused here.
            path = osp.join(save_dir or '', self.__filename__)
            assert save_dir is not None and osp.exists(path)
            self.node_norm, self.edge_norm = torch.load(path)
        else:
            # Uniform weights when no coverage-based statistics are used.
            self.node_norm = torch.ones(self.N)
            self.edge_norm = torch.ones(self.E)

    @property
    def __filename__(self):
        """Cache-file name for the normalization statistics."""
        return (f'{self.N}_{self.walk_length}_'
                f'{self.sample_coverage}.pt')

    def __len__(self):
        return len(self.output_indices)

    @property
    def loader_len(self):
        # Number of batches needed to cover all evaluation nodes.
        # NOTE(review): @property here but a plain method on
        # SaintRWTrainSampler — confirm which form callers expect.
        return ceil(len(self.output_indices) / self.batch_size)

    def __getitem__(self, idx):
        # Items are the evaluation node ids themselves (0-dim tensors).
        return self.output_indices[idx]

    def __collate__(self, data_list):
        """Build the Data object for a batch of evaluation seed nodes."""
        prime_nodes = torch.tensor(data_list, dtype=torch.long)
        # Union of all nodes touched by walks started at the seed nodes.
        node_idx = self.adj.random_walk(prime_nodes.flatten(), self.walk_length).view(-1).unique()
        adj, _ = self.adj.saint_subgraph(node_idx)

        # ``adj`` values are original edge ids; map back to (optionally
        # edge_norm-scaled) row-normalized weights.
        _, _, edge_idx = adj.coo()
        if self.return_edge_index_type == 'adj':
            if self.sample_coverage > 0:
                adj.set_value_(self.edge_norm[edge_idx] * self.edge_weight[edge_idx], layout='csr')
            else:
                adj.set_value_(self.edge_weight[edge_idx], layout='csr')
        else:
            adj = torch.vstack([adj.storage.row(), adj.storage.col()])

        # Per-seed loss weights, renormalized within the batch. Fancy
        # indexing copies, so the stored self.node_norm is not mutated.
        node_norm = self.node_norm[prime_nodes]
        node_norm /= node_norm.sum()
        # np.isin replaces the deprecated np.in1d (same result, 1-D input).
        graph = Data(x=self.original_graph.x[node_idx],
                     y=self.original_graph.y[node_idx],
                     edge_index=adj,
                     output_node_mask=torch.from_numpy(np.isin(node_idx.numpy(), prime_nodes.numpy())),
                     node_norm=node_norm)

        return graph