# KGTOSA / GNN-Methods / NodeClassifcation / IBS / dataloaders / utils.py
from typing import List, Union, Tuple, Optional

import numba
import numpy as np
import torch
from scipy.sparse import csr_matrix, coo_matrix
from torch_geometric.utils import coalesce
from torch_sparse import SparseTensor
from tqdm import tqdm

def topk_ppr_matrix(edge_index: torch.Tensor,
                    num_nodes: int,
                    alpha: float,
                    eps: float,
                    output_node_indices: Union[np.ndarray, torch.LongTensor],
                    topk: int,
                    normalization='row') -> Tuple[csr_matrix, List[np.ndarray]]:
    """Create a sparse matrix where each node has up to the topk PPR neighbors and their weights."""
    if isinstance(output_node_indices, torch.Tensor):
        output_node_indices = output_node_indices.numpy()

    edge_index = coalesce(edge_index, num_nodes=num_nodes)
    edge_index_np = edge_index.cpu().numpy()

    # Build CSR-style indptr / out-degree arrays from the coalesced (sorted)
    # source nodes. Note: this assumes every node in [0, num_nodes) appears as
    # a source at least once; isolated nodes would misalign indptr against the
    # node ids.
    _, indptr, out_degree = np.unique(edge_index_np[0],
                                      return_index=True,
                                      return_counts=True)
    indptr = np.append(indptr, len(edge_index_np[0]))

    neighbors, weights = calc_ppr_topk_parallel(indptr, edge_index_np[1], out_degree,
                                                alpha, eps, output_node_indices)

    # One row of PPR scores per output node, against all num_nodes columns.
    ppr_matrix = construct_sparse(neighbors, weights, (len(output_node_indices), num_nodes))
    ppr_matrix = ppr_matrix.tocsr()

    # Keep only the topk highest-weight neighbors per node, and make sure each
    # output node is contained in its own neighborhood.
    neighbors = sparsify(neighbors, weights, topk)
    neighbors = [np.union1d(nei, pr) for nei, pr in zip(neighbors, output_node_indices)]

    if normalization == 'sym':
        # Assume undirected (symmetric) adjacency matrix:
        # rescale the row-normalized PPR scores as D^{1/2} PPR D^{-1/2}.
        deg_sqrt = np.sqrt(np.maximum(out_degree, 1e-12))
        deg_inv_sqrt = 1. / deg_sqrt

        row, col = ppr_matrix.nonzero()
        ppr_matrix.data = deg_sqrt[output_node_indices[row]] * ppr_matrix.data * deg_inv_sqrt[col]
    elif normalization == 'col':
        # Assume undirected (symmetric) adjacency matrix:
        # rescale as D PPR D^{-1} to obtain column normalization.
        deg_inv = 1. / np.maximum(out_degree, 1e-12)

        row, col = ppr_matrix.nonzero()
        ppr_matrix.data = out_degree[output_node_indices[row]] * ppr_matrix.data * deg_inv[col]
    elif normalization == 'row':
        pass
    else:
        raise ValueError(f"Unknown PPR normalization: {normalization}")

    return ppr_matrix, neighbors
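

# A hedged usage sketch for topk_ppr_matrix (not part of the original module).
# The triangle graph, alpha/eps values, and topk below are illustrative
# assumptions; only the call signature comes from the function above.
def _example_topk_ppr():
    # Undirected triangle 0-1-2, given as a directed edge list in both directions.
    edge_index = torch.tensor([[0, 1, 1, 2, 2, 0],
                               [1, 0, 2, 1, 0, 2]])
    targets = np.array([0, 2])
    ppr, neighbors = topk_ppr_matrix(edge_index, num_nodes=3,
                                     alpha=0.25, eps=1e-4,
                                     output_node_indices=targets,
                                     topk=2, normalization='row')
    print(ppr.toarray())  # one row of PPR scores per target node
    print(neighbors)      # per-target neighborhoods, each containing the target itself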


def construct_sparse(neighbors, weights, shape):
    """Build a COO matrix from per-row ragged neighbor/weight lists."""
    i = np.repeat(np.arange(len(neighbors)), np.fromiter(map(len, neighbors), dtype=int))
    j = np.concatenate(neighbors)
    return coo_matrix((np.concatenate(weights), (i, j)), shape)
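

# A small illustrative check (assumed data, not from the source): the helper
# maps ragged per-row neighbor/weight lists onto a single COO matrix.
def _example_construct_sparse():
    neighbors = [np.array([0, 2]), np.array([1])]  # row 0 -> {0, 2}, row 1 -> {1}
    weights = [np.array([0.7, 0.3]), np.array([1.0])]
    mat = construct_sparse(neighbors, weights, (2, 3))
    # mat.toarray() == [[0.7, 0.0, 0.3],
    #                   [0.0, 1.0, 0.0]]
    return mat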


# Adapted from:
# https://pytorch-geometric.readthedocs.io/en/latest/_modules/torch_geometric/transforms/gdc.html#GDC.diffusion_matrix_approx
# Note: the parallel numba decorator is left disabled so tqdm can report
# progress; outside of njit, numba.prange simply falls back to range.
# @numba.njit(cache=True, parallel=True)
def calc_ppr_topk_parallel(indptr, indices, deg, alpha, epsilon, nodes) \
        -> Tuple[List[np.ndarray], List[np.ndarray]]:
    """Run _calc_ppr_node for every node in `nodes` and collect the results."""
    js = [np.zeros(0, dtype=np.int64)] * len(nodes)
    vals = [np.zeros(0, dtype=np.float32)] * len(nodes)
    for i in tqdm(numba.prange(len(nodes))):
        j, val = _calc_ppr_node(nodes[i], indptr, indices, deg, alpha, epsilon)
        js[i] = np.array(j)
        vals[i] = np.array(val)
    return js, vals


@numba.njit(cache=True, locals={'_val': numba.float32, 'res': numba.float32, 'res_vnode': numba.float32})
def _calc_ppr_node(inode, indptr, indices, deg, alpha, epsilon):
    """Push-based approximate personalized PageRank for a single source node
    (Andersen et al., "Local Graph Partitioning using PageRank Vectors")."""
    alpha_eps = alpha * epsilon
    f32_0 = numba.float32(0)
    p = {inode: f32_0}  # accumulated PPR mass per node
    r = {}              # residual mass still to be pushed
    r[inode] = alpha
    q = [inode]         # frontier of nodes with residual above the threshold

    while len(q) > 0:
        unode = q.pop()

        # Move the residual of unode into its PPR estimate...
        res = r[unode] if unode in r else f32_0
        if unode in p:
            p[unode] += res
        else:
            p[unode] = res
        r[unode] = f32_0

        # ...then push a (1 - alpha) share of it to every out-neighbor.
        for vnode in indices[indptr[unode]:indptr[unode + 1]]:
            _val = (1 - alpha) * res / deg[unode]
            if vnode in r:
                r[vnode] += _val
            else:
                r[vnode] = _val

            # Re-queue vnode once its residual crosses the push threshold.
            res_vnode = r[vnode] if vnode in r else f32_0
            if res_vnode >= alpha_eps * deg[vnode]:
                if vnode not in q:
                    q.append(vnode)

    return list(p.keys()), list(p.values())
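

# A hedged verification sketch for _calc_ppr_node (toy graph and parameters are
# assumptions): the push loop above approximates the random-walk personalized
# PageRank pi = alpha * (I - (1 - alpha) * W^T)^{-1} e_s with W = D^{-1} A,
# so it can be compared against a dense linear solve on a small graph.
def _check_ppr_node(alpha=0.25, epsilon=1e-6):
    # Undirected triangle in CSR form: each node neighbors the other two.
    indptr = np.array([0, 2, 4, 6], dtype=np.int64)
    indices = np.array([1, 2, 0, 2, 0, 1], dtype=np.int64)
    deg = np.array([2, 2, 2], dtype=np.int64)

    nodes, vals = _calc_ppr_node(0, indptr, indices, deg, alpha, epsilon)
    approx = np.zeros(3)
    approx[np.asarray(nodes)] = vals

    W = np.array([[0., .5, .5], [.5, 0., .5], [.5, .5, 0.]])  # row-stochastic walk
    e0 = np.array([1., 0., 0.])
    exact = alpha * np.linalg.solve(np.eye(3) - (1 - alpha) * W.T, e0)
    print(np.abs(approx - exact).max())  # should be on the order of epsilon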


def sparsify(neighbors: List[np.ndarray], weights: List[np.ndarray], topk: int):
    """Keep only the topk highest-weight neighbors of each node."""
    new_neighbors = []
    for n, w in zip(neighbors, weights):
        idx_topk = np.argsort(w)[-topk:]
        new_neighbors.append(n[idx_topk])

    return new_neighbors
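

# Illustrative sketch (assumed data): sparsify keeps the topk largest-weight
# entries of each row; the order inside the result follows the argsort.
def _example_sparsify():
    neighbors = [np.array([5, 8, 9])]
    weights = [np.array([0.1, 0.6, 0.3])]
    top2 = sparsify(neighbors, weights, topk=2)
    # argsort gives [0, 2, 1], so the last two positions select nodes 9 and 8.
    assert np.array_equal(top2[0], np.array([9, 8]))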


def get_partitions(edge_index: Union[torch.LongTensor, SparseTensor],
                   num_partitions: int,
                   indices: torch.LongTensor,
                   num_nodes: int,
                   output_weight: Optional[float] = None) -> List[torch.LongTensor]:
    """
    Graph partitioning using METIS.
    If output_weight is given, assign weights on output nodes.

    :param edge_index:
    :param num_partitions:
    :param indices:
    :param num_nodes:
    :param output_weight:
    :return:
    """

    assert isinstance(edge_index, (torch.LongTensor, SparseTensor)), f'Unsupported edge_index type {type(edge_index)}'
    if isinstance(edge_index, torch.LongTensor):
        edge_index = SparseTensor.from_edge_index(edge_index, sparse_sizes=(num_nodes, num_nodes))

    if output_weight is not None and output_weight != 1:
        node_weight = torch.ones(num_nodes)
        node_weight[indices] = output_weight
    else:
        node_weight = None

    _, partptr, perm = edge_index.partition(num_parts=num_partitions,
                                            recursive=False,
                                            weighted=False,
                                            node_weight=node_weight)

    partitions = []
    for i in range(len(partptr) - 1):
        partitions.append(perm[partptr[i]: partptr[i + 1]])

    return partitions
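

# A hedged usage sketch for get_partitions (illustrative graph and weights;
# requires a torch-sparse build with METIS support):
def _example_get_partitions():
    # Undirected 4-cycle 0-1-2-3-0 as a directed edge list in both directions.
    edge_index = torch.tensor([[0, 1, 1, 2, 2, 3, 3, 0],
                               [1, 0, 2, 1, 3, 2, 0, 3]])
    train_idx = torch.tensor([0, 2])
    parts = get_partitions(edge_index, num_partitions=2, indices=train_idx,
                           num_nodes=4, output_weight=2.)
    # parts is a list of two LongTensors whose node ids together cover 0..3.
    return parts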