import logging
import queue as Q
from heapq import heappush, heappop, heapify
from math import ceil
from typing import Optional, List, Tuple

import numba
import numpy as np
import pandas as pd
import torch
from torch.utils.data import Sampler
from scipy.sparse import csr_matrix
from torch_geometric.data import Data
from torch_geometric.utils import is_undirected
# from torch_geometric.transforms.gdc import get_calc_ppr
from torch_sparse import SparseTensor
from tqdm import tqdm

from dataloaders.utils import topk_ppr_matrix
from .BaseLoader import BaseLoader


def get_pairs(ppr_mat: csr_matrix) -> np.ndarray:
    """
    Get symmetric ppr pairs, ordered by descending score.

    The matrix is symmetrised as ``ppr_mat + ppr_mat.T`` (so it must be
    square) and only entries with ``row > col`` are kept, i.e. the strictly
    lower triangle, which yields each unordered pair once.

    :param ppr_mat: square sparse PPR matrix
    :return: array of shape ``(num_pairs, 2)`` holding ``[row, col]`` per pair
    """
    ppr_mat = ppr_mat + ppr_mat.transpose()

    ppr_mat = ppr_mat.tocoo()  # find issue: https://github.com/scipy/scipy/blob/v1.7.1/scipy/sparse/extract.py#L12-L40
    row, col, data = ppr_mat.row, ppr_mat.col, ppr_mat.data
    mask = (row > col)  # strictly lower triangle

    row, col, data = row[mask], col[mask], data[mask]
    sort_arg = np.argsort(data)[::-1]
    # sort_arg = parallel_sort.parallel_argsort(data)[::-1]

    # map prime_nodes to arange
    ppr_pairs = np.vstack((row[sort_arg], col[sort_arg])).T
    return ppr_pairs


@numba.njit(cache=True)
def prime_orient_merge(ppr_pairs: np.ndarray, primes_per_batch: int, num_nodes: int):
    """
    Greedily merge output ("prime") nodes into batches, following pair order.

    Every node starts in its own batch. For each pair ``(i, j)``, taken in the
    order given, the batches currently holding ``i`` and ``j`` are merged when
    they differ and their combined size does not exceed ``primes_per_batch``.

    :param ppr_pairs: array of ``[i, j]`` index pairs (positions in ``0..num_nodes-1``)
    :param primes_per_batch: maximum number of prime nodes per batch
    :param num_nodes: number of prime nodes
    :return: list of batches, each a list of node positions
    """
    # cannot use list for id_primes_list, updating node_id_list[id_primes_list[id2]] require id_primes_list to be array
    id_primes_list = list(np.arange(num_nodes, dtype=np.int32).reshape(-1, 1))
    node_id_list = np.arange(num_nodes, dtype=np.int32)
    placeholder = np.zeros(0, dtype=np.int32)
    # size_flag = [{a} for a in np.arange(num_nodes, dtype=np.int32)]

    for i, j in ppr_pairs:
        id1, id2 = node_id_list[i], node_id_list[j]
        if id1 > id2:
            id1, id2 = id2, id1

        # if not (id1 in size_flag[id2] or id2 in size_flag[id1])
        if id1 != id2 and len(id_primes_list[id1]) + len(id_primes_list[id2]) <= primes_per_batch:
            id_primes_list[id1] = np.concatenate((id_primes_list[id1], id_primes_list[id2]))
            # re-label every member of the absorbed batch
            node_id_list[id_primes_list[id2]] = id1
            # node_id_list[j] = id1
            id_primes_list[id2] = placeholder

    prime_lst = list()
    ids = np.unique(node_id_list)

    for _id in ids:
        prime_lst.append(list(id_primes_list[_id]))

    return list(prime_lst)


def prime_post_process(loader, merge_max_size):
    """
    Merge the smallest batches together while they still fit.

    The two smallest batches are repeatedly popped from a heap and
    concatenated; the loop stops as soon as the two smallest no longer fit
    into ``merge_max_size``.

    :param loader: iterable of batches (lists of node positions)
    :param merge_max_size: maximum allowed batch length
    :return: list of batches in heap order (ascending length)
    """
    h = [(len(p), p,) for p in loader]
    heapify(h)

    while len(h) > 1:
        len1, p1 = heappop(h)
        len2, p2 = heappop(h)
        if len1 + len2 <= merge_max_size:
            heappush(h, (len1 + len2, p1 + p2))
        else:
            # the two smallest do not fit, so no other pair can either
            heappush(h, (len1, p1,))
            heappush(h, (len2, p2,))
            break

    new_batch = []

    while len(h):
        _, p = heappop(h)
        new_batch.append(p)

    return new_batch


@numba.njit(cache=True, locals={'p1': numba.int64,
                                'p2': numba.int64,
                                'p3': numba.int64,
                                'new_list': numba.int64[::1]})
def merge_lists(lst1, lst2):
    """
    Two-pointer merge of two integer arrays.

    A value found at the current position of both inputs is written once.
    When each input is sorted ascending and free of duplicates the result is
    their sorted union; duplicates inside a single input are not removed.

    :param lst1: first integer array
    :param lst2: second integer array
    :return: int64 array with the merged values
    """
    p1, p2, p3 = numba.int64(0), numba.int64(0), numba.int64(0)
    new_list = np.zeros(len(lst1) + len(lst2), dtype=np.int64)

    while p2 < len(lst2) and p1 < len(lst1):
        if lst2[p2] <= lst1[p1]:
            new_list[p3] = lst2[p2]
            p2 += 1

            # same value in both inputs: consume it from lst1 as well
            if lst2[p2 - 1] == lst1[p1]:
                p1 += 1

        elif lst2[p2] > lst1[p1]:
            new_list[p3] = lst1[p1]
            p1 += 1
        p3 += 1

    if p2 == len(lst2) and p1 == len(lst1):
        return new_list[:p3]
    elif p1 == len(lst1):
        rest = lst2[p2:]
    else:
        rest = lst1[p1:]

    p3_ = p3 + len(rest)
    new_list[p3: p3_] = rest

    return new_list[:p3_]


@numba.njit(cache=True, locals={'node_id_list': numba.int64[::1],
                                'placeholder': numba.int64[::1],
                                'id1': numba.int64,
                                'id2': numba.int64})
def aux_orient_merge(ppr_pairs, prime_indices, id_second_list, merge_max_size):
    """
    Greedily merge batches under a cap on the number of auxiliary nodes.

    Every prime node starts in its own batch with the auxiliary node array
    found at the same position of ``id_second_list``. For each pair, taken in
    the order given, the two batches are merged when the merged auxiliary
    array (via ``merge_lists``) has at most ``merge_max_size`` entries.
    Batch pairs that failed once are remembered and not tried again.

    Note: ``id_second_list`` is modified in place.

    :param ppr_pairs: array of ``[n1, n2]`` pairs of prime node positions
    :param prime_indices: array of prime node ids, indexed by position
    :param id_second_list: list of auxiliary node arrays, one per prime node
    :param merge_max_size: maximum number of auxiliary nodes per batch
    :return: list of ``(prime node ids, auxiliary node ids)`` tuples
    """
    # cheap pre-filter on the summed lengths before the actual merge is computed
    thresh = numba.int64(merge_max_size * 1.0005)
    num_nodes = len(prime_indices)
    node_id_list = np.arange(num_nodes, dtype=np.int64)

    id_prime_list = list(np.arange(num_nodes, dtype=np.int64).reshape(-1, 1))
    # size_flag[b] holds the batch ids that batch b must not be merged with again
    size_flag = [{a} for a in np.arange(num_nodes, dtype=np.int64)]
    placeholder = np.zeros(0, dtype=np.int64)

    for (n1, n2) in ppr_pairs:
        id1, id2 = node_id_list[n1], node_id_list[n2]
        id1, id2 = (id1, id2) if id1 < id2 else (id2, id1)

        if id1 != id2 and not (id2 in size_flag[id1]) and not (id1 in size_flag[id2]):

            batch_second1 = id_second_list[id1]
            batch_second2 = id_second_list[id2]

            if len(batch_second1) + len(batch_second2) <= thresh:
                new_batch_second = merge_lists(batch_second1, batch_second2)
                if len(new_batch_second) <= merge_max_size:
                    batch_prime1 = id_prime_list[id1]
                    batch_prime2 = id_prime_list[id2]

                    new_batch_prime = np.concatenate((batch_prime1, batch_prime2))

                    id_prime_list[id1] = new_batch_prime
                    id_second_list[id1] = new_batch_second
                    id_second_list[id2] = placeholder

                    id_prime_list[id2] = placeholder
                    node_id_list[batch_prime2] = id1
                    size_flag[id1].update(size_flag[id2])
                    size_flag[id2].clear()
                else:
                    size_flag[id1].add(id2)
                    size_flag[id2].add(id1)
            else:
                size_flag[id1].add(id2)
                size_flag[id2].add(id1)

    prime_second_lst = list()
    ids = np.unique(node_id_list)
    for _id in ids:
        prime_second_lst.append((prime_indices[id_prime_list[_id]], id_second_list[_id]))

    return list(prime_second_lst)


def aux_post_process(loader, merge_max_size):
    """
    Merge the batches with the fewest auxiliary nodes while they still fit.

    :param loader: iterable of ``(prime node ids, auxiliary node ids)`` pairs
    :param merge_max_size: maximum number of auxiliary nodes per batch
    :return: list of ``(prime ids, auxiliary ids)`` NumPy array pairs,
        ordered by ascending number of auxiliary nodes
    """
    # merge the smallest clusters first
    que = Q.PriorityQueue()
    for p, n in loader:
        que.put((len(n), (list(p), list(n))))

    while que.qsize() > 1:
        len1, (p1, n1) = que.get()
        len2, (p2, n2) = que.get()
        n = merge_lists(np.array(n1), np.array(n2))

        if len(n) > merge_max_size:
            # the two smallest do not fit together: put them back and stop
            que.put((len1, (p1, n1)))
            que.put((len2, (p2, n2)))
            break
        else:
            que.put((len(n), (p1 + p2, list(n))))

    new_batch = []

    while not que.empty():
        _, (p, n) = que.get()
        new_batch.append((np.array(p), np.array(n)))

    return new_batch


def flush_PPR_Scores(ppr_matrix, path, graph):
    """
    Write every stored PPR entry to a CSV file.

    :param ppr_matrix: sparse PPR matrix (anything providing ``tocoo()``)
    :param path: destination CSV path
    :param graph: unused; kept so the signature matches ``flush_PPR_Scores_per_type``
    """
    ppr_coo = ppr_matrix.tocoo()
    df = pd.DataFrame.from_dict({'src': ppr_coo.row,
                                 'dest': ppr_coo.col,
                                 'ppr_score': [f"{score:.5f}" for score in ppr_coo.data]})
    df.to_csv(path, index=False)


def flush_PPR_Scores_per_type(ppr_matrix, path, graph):
    """
    Write min/max/mean/std of the PPR scores per node type to a CSV file.

    A stored entry ``(row, col)`` is attributed to the type of its column
    node, ``graph.node_type[col]``. The previous implementation indexed the
    array of stored values directly with node ids, which mixes up positions
    in the value array with node ids.
    NOTE(review): columns are treated as node ids because the caller indexes
    the columns of this matrix with output node ids - confirm that the column
    (rather than the row) node is the one the statistics should be grouped by.

    :param ppr_matrix: sparse PPR matrix whose columns are node ids
    :param path: destination CSV path
    :param graph: graph object exposing a per-node ``node_type`` array/tensor
    """
    ppr_coo = ppr_matrix.tocoo()
    node_types = np.asarray(graph.node_type)
    entry_types = node_types[ppr_coo.col]

    columns = ['node_type', 'min_ppr', 'max_ppr', 'avg_ppr', 'std_ppr']
    res = []
    for node_type in np.unique(node_types):
        scores = ppr_coo.data[entry_types == node_type]
        if scores.size:
            stats = (scores.min(), scores.max(), scores.mean(), scores.std())
        else:
            # no PPR entry points at this type; keep the row but mark it empty
            stats = (np.nan,) * 4
        res.append(dict(zip(columns, (node_type.item(),) + stats)))

    df = pd.DataFrame(res, columns=columns)
    df = df.sort_values(by=["avg_ppr"])
    df.to_csv(path, index=False)


def _build_out_aux_pairs(ppr_pairs, neighbors, output_indices,
                         num_output_nodes_per_batch, num_auxiliary_nodes_per_batch):
    """
    Turn PPR pairs into ``(output node ids, auxiliary node ids)`` batches.

    Exactly one of the two size limits must be given; it selects output-node
    oriented or auxiliary-node oriented merging.

    :param ppr_pairs: pairs as returned by ``get_pairs``
    :param neighbors: per-output-node auxiliary node arrays
    :param output_indices: array of output node ids
    :param num_output_nodes_per_batch: cap on output nodes per batch, or None
    :param num_auxiliary_nodes_per_batch: cap on auxiliary nodes per batch, or None
    :return: list of ``(output ids, auxiliary ids)`` array pairs
    """
    assert (num_output_nodes_per_batch is not None) ^ (num_auxiliary_nodes_per_batch is not None)

    if num_output_nodes_per_batch is not None:
        logging.info("Output node oriented merging")
        output_list = prime_orient_merge(ppr_pairs, num_output_nodes_per_batch, len(output_indices))
        output_list = prime_post_process(output_list, num_output_nodes_per_batch)

        if isinstance(neighbors, list):
            neighbors = np.array(neighbors, dtype=object)

        # auxiliary nodes of a batch = union of the neighbours of its output nodes
        return [(output_indices[p], np.unique(np.concatenate(neighbors[p])).astype(np.int64))
                for p in output_list]

    logging.info("Auxiliary node oriented merging")
    prime_second_lst = aux_orient_merge(ppr_pairs,
                                        output_indices,
                                        list(neighbors),
                                        merge_max_size=num_auxiliary_nodes_per_batch)
    return aux_post_process(prime_second_lst, num_auxiliary_nodes_per_batch)


class IBMBNodeLoader(BaseLoader):
    """
    Batch-wise IBMB dataloader from paper Influence-Based Mini-Batching for Graph Neural Networks
    """

    def __init__(self, graph: Data,
                 batch_order: str,
                 output_indices: torch.LongTensor,
                 return_edge_index_type: str,
                 num_auxiliary_node_per_output: int,
                 num_output_nodes_per_batch: Optional[int] = None,
                 num_auxiliary_nodes_per_batch: Optional[int] = None,
                 alpha: float = 0.2,
                 eps: float = 1.e-4,
                 sampler: Sampler = None, FG_adj_df=None,
                 **kwargs):
        """
        Build the batches and initialise the underlying loader.

        :param graph: graph to load from
        :param batch_order: forwarded to ``define_sampler``
        :param output_indices: ids of the output nodes (tensor or NumPy array)
        :param return_edge_index_type: ``'adj'`` or ``'edge_index'``
        :param num_auxiliary_node_per_output: forwarded to ``topk_ppr_matrix``
        :param num_output_nodes_per_batch: cap on output nodes per batch
        :param num_auxiliary_nodes_per_batch: cap on auxiliary nodes per batch
            (exactly one of the two caps must be given)
        :param alpha: forwarded to ``topk_ppr_matrix``
        :param eps: forwarded to ``topk_ppr_matrix``
        :param sampler: sampler used unless more than two batches exist, in
            which case one is created with ``define_sampler``
        :param FG_adj_df: stored on the instance; replaced by a freshly built
            frame when ``return_edge_index_type == 'adj'``
        :param kwargs: forwarded to the base loader; must contain ``batch_size``
        """
        self.subgraphs = []
        self.node_wise_out_aux_pairs = []

        self.original_graph = None
        self.adj = None

        # assert is_undirected(graph.edge_index, num_nodes=graph.num_nodes), "Assume the graph to be undirected"
        # subgraphs are only pre-computed when each loader batch is one IBMB batch
        self.cache_data = kwargs['batch_size'] == 1
        self._batchsize = kwargs['batch_size']
        try:
            self.output_indices = output_indices.numpy()
        except AttributeError:
            # no .numpy(), e.g. already a NumPy array: use it as given
            self.output_indices = output_indices
        assert return_edge_index_type in ['adj', 'edge_index']
        self.return_edge_index_type = return_edge_index_type
        self.num_auxiliary_node_per_output = num_auxiliary_node_per_output
        self.num_output_nodes_per_batch = num_output_nodes_per_batch
        self.num_auxiliary_nodes_per_batch = num_auxiliary_nodes_per_batch
        self.alpha = alpha
        self.eps = eps
        self.FG_adj_df = FG_adj_df
        self.create_node_wise_loader(graph)

        if len(self.node_wise_out_aux_pairs) > 2:  # <= 2 order makes no sense
            ys = [graph.y[out].numpy() for out, _ in self.node_wise_out_aux_pairs]
            sampler = self.define_sampler(batch_order,
                                          ys,
                                          graph.y.max().item() + 1)

        if not self.cache_data:
            self.original_graph = graph  # need to cache the original graph

        super().__init__(self.subgraphs if self.cache_data else self.node_wise_out_aux_pairs, sampler=sampler, **kwargs)

    def create_node_wise_loader(self, graph: Data):
        """
        Compute PPR scores, merge output nodes into batches and, depending on
        the configuration, build the adjacency matrix and the subgraph cache.

        Side effect: writes per-type PPR statistics to
        ``ogbn_ibmb_ppr_scores_per_type.csv`` in the working directory, which
        requires ``graph.node_type``.

        :param graph: graph to build the batches from
        """
        logging.info("Start PPR calculation")
        ppr_matrix, neighbors = topk_ppr_matrix(graph.edge_index,
                                                graph.num_nodes,
                                                self.alpha,
                                                self.eps,
                                                self.output_indices,
                                                self.num_auxiliary_node_per_output)
        flush_PPR_Scores_per_type(ppr_matrix, "ogbn_ibmb_ppr_scores_per_type.csv", graph)
        # flush_PPR_Scores(ppr_matrix, "ogbn_ibmb_ppr_scores.csv",graph)
        ppr_matrix = ppr_matrix[:, self.output_indices]

        logging.info("Getting PPR pairs")
        ppr_pairs = get_pairs(ppr_matrix)

        node_wise_out_aux_pairs = _build_out_aux_pairs(ppr_pairs,
                                                       neighbors,
                                                       self.output_indices,
                                                       self.num_output_nodes_per_batch,
                                                       self.num_auxiliary_nodes_per_batch)

        self.indices_complete_check(node_wise_out_aux_pairs, self.output_indices)
        logging.info("Number of subgraphs: %d", len(node_wise_out_aux_pairs))
        self.node_wise_out_aux_pairs = node_wise_out_aux_pairs

        if self.return_edge_index_type == 'adj':
            adj = SparseTensor.from_edge_index(graph.edge_index, sparse_sizes=(graph.num_nodes, graph.num_nodes))
            # remember the attribute of every (src, dst) edge; for duplicated
            # edges the last occurrence wins
            src_np = graph.edge_index[0].numpy()
            dst_np = graph.edge_index[1].numpy()
            edge_attr_np = graph.edge_attr.numpy()
            edge_attr_by_edge = dict(zip(zip(src_np.tolist(), dst_np.tolist()), edge_attr_np))

            adj = self.normalize_adjmat(adj, normalization='rw')
            row, col, val = adj.coo()
            row_np = row.numpy()
            col_np = col.numpy()
            # edge attributes re-ordered to follow the entries of the normalised adjacency
            edge_att_adj = [edge_attr_by_edge[edge] for edge in zip(row_np.tolist(), col_np.tolist())]
            self.FG_adj_df = pd.DataFrame(np.vstack((row_np, col_np, np.array(edge_att_adj))).T)
        else:
            adj = None

        if self.cache_data:
            self.prepare_cache(graph, node_wise_out_aux_pairs, adj)
        else:
            if self.return_edge_index_type == 'adj':
                self.adj = adj

    def prepare_cache(self, graph: Data,
                      batch_wise_out_aux_pairs: List[Tuple[np.ndarray, np.ndarray]],
                      adj: Optional[SparseTensor]):
        """
        Pre-compute one subgraph per batch and store them in ``self.subgraphs``.

        :param graph: graph the subgraphs are taken from
        :param batch_wise_out_aux_pairs: ``(output ids, auxiliary ids)`` per batch
        :param adj: adjacency matrix, required when the return type is ``'adj'``
        """
        pbar = tqdm(batch_wise_out_aux_pairs)
        pbar.set_description(f"Caching data with type {self.return_edge_index_type}")

        if self.return_edge_index_type == 'adj':
            assert adj is not None, "Trying to cache adjacency matrix, got None type."

        for out, aux in pbar:
            # marks which auxiliary nodes are output nodes of this batch
            mask = torch.from_numpy(np.isin(aux, out))
            if isinstance(aux, np.ndarray):
                aux = torch.from_numpy(aux)
            subg = self.get_subgraph(aux, graph, self.return_edge_index_type, adj,
                                     output_node_mask=mask, FG_adj_df=self.FG_adj_df)
            self.subgraphs.append(subg)

    def __getitem__(self, idx):
        """Return the cached subgraph, or the raw ``(out, aux)`` pair when not caching."""
        return self.subgraphs[idx] if self.cache_data else self.node_wise_out_aux_pairs[idx]

    def __len__(self):
        """Return the number of IBMB batches."""
        return len(self.node_wise_out_aux_pairs)

    @property
    def loader_len(self):
        """Number of loader iterations per epoch for the configured batch size."""
        return ceil(len(self.node_wise_out_aux_pairs) / self._batchsize)

    def __collate__(self, data_list):
        """
        Combine the sampled items into one subgraph.

        A single cached ``Data`` object is returned unchanged; otherwise the
        ``(out, aux)`` pairs are merged and the subgraph is built on the fly.

        :param data_list: items produced by ``__getitem__``
        :return: subgraph for the combined batch
        """
        if len(data_list) == 1 and isinstance(data_list[0], Data):
            return data_list[0]

        out, aux = zip(*data_list)
        out = np.concatenate(out)
        aux = np.unique(np.concatenate(aux))  # still need it to be overlapping
        mask = torch.from_numpy(np.isin(aux, out))
        aux = torch.from_numpy(aux)

        # NOTE(review): prepare_cache passes FG_adj_df=self.FG_adj_df to
        # get_subgraph, this path does not - confirm whether that is intended.
        subg = self.get_subgraph(aux, self.original_graph, self.return_edge_index_type, self.adj,
                                 output_node_mask=mask)
        return subg


class IBMBNodeLoader_hetero(BaseLoader):
    """
    Batch-wise IBMB dataloader from paper Influence-Based Mini-Batching for Graph Neural Networks
    """

    # subgraphs with at least this many auxiliary nodes are not cached in
    # the homogeneous branch of prepare_cache
    MAX_CACHED_SUBGRAPH_NODES = 100000

    def __init__(self, graph: Data,
                 batch_order: str,
                 output_indices: torch.LongTensor,
                 return_edge_index_type: str,
                 num_auxiliary_node_per_output: int,
                 num_output_nodes_per_batch: Optional[int] = None,
                 num_auxiliary_nodes_per_batch: Optional[int] = None,
                 alpha: float = 0.2,
                 eps: float = 1.e-4,
                 sampler: Sampler = None,
                 **kwargs):
        """
        Build the batches for a heterogeneous graph and initialise the loader.

        The output node type is the first key of ``graph.y_dict``.

        :param graph: heterogeneous graph exposing ``y_dict``,
            ``edge_index_dict`` and ``num_nodes_dict``
        :param batch_order: forwarded to ``define_sampler``
        :param output_indices: ids of the output nodes
        :param return_edge_index_type: ``'adj'`` or ``'edge_index'``
        :param num_auxiliary_node_per_output: forwarded to ``topk_ppr_matrix``
        :param num_output_nodes_per_batch: cap on output nodes per batch
        :param num_auxiliary_nodes_per_batch: cap on auxiliary nodes per batch
            (exactly one of the two caps must be given)
        :param alpha: forwarded to ``topk_ppr_matrix``
        :param eps: forwarded to ``topk_ppr_matrix``
        :param sampler: sampler used unless more than two batches exist
        :param kwargs: forwarded to the base loader; must contain ``batch_size``
        """
        output_node = list(graph.y_dict.keys())[0]
        self.subgraphs = []
        self.node_wise_out_aux_pairs = []

        self.original_graph = None
        self.adj = None
        # for key in graph.edge_index_dict.keys():
        #     n_nodes=graph.num_nodes_dict[key[0]]+graph.num_nodes_dict[key[2]]
        #     assert is_undirected(graph.edge_index_dict[key], num_nodes=n_nodes), "Assume the graph to be undirected"
        self.cache_data = kwargs['batch_size'] == 1
        self._batchsize = kwargs['batch_size']
        self.output_indices = output_indices.numpy()
        assert return_edge_index_type in ['adj', 'edge_index']
        self.return_edge_index_type = return_edge_index_type
        self.num_auxiliary_node_per_output = num_auxiliary_node_per_output
        self.num_output_nodes_per_batch = num_output_nodes_per_batch
        self.num_auxiliary_nodes_per_batch = num_auxiliary_nodes_per_batch
        self.alpha = alpha
        self.eps = eps
        self.create_node_wise_loader_hetero(graph, ishetero=1)

        if len(self.node_wise_out_aux_pairs) > 2:  # <= 2 order makes no sense
            ys = [graph.y_dict[output_node][out].numpy() for out, _ in self.node_wise_out_aux_pairs]
            sampler = self.define_sampler(batch_order,
                                          ys,
                                          graph.y_dict[output_node].max().item() + 1)

        if not self.cache_data:
            self.original_graph = graph  # need to cache the original graph

        super().__init__(self.subgraphs if self.cache_data else self.node_wise_out_aux_pairs, sampler=sampler, **kwargs)

    def create_node_wise_loader(self, graph: Data):
        """
        Homogeneous variant of the batch construction (``graph.edge_index``).

        Not called by ``__init__`` of this class, which uses
        ``create_node_wise_loader_hetero`` instead.

        :param graph: homogeneous graph to build the batches from
        """
        logging.info("Start PPR calculation")
        ppr_matrix, neighbors = topk_ppr_matrix(graph.edge_index,
                                                graph.num_nodes,
                                                self.alpha,
                                                self.eps,
                                                self.output_indices,
                                                self.num_auxiliary_node_per_output)

        ppr_matrix = ppr_matrix[:, self.output_indices]

        logging.info("Getting PPR pairs")
        ppr_pairs = get_pairs(ppr_matrix)

        node_wise_out_aux_pairs = _build_out_aux_pairs(ppr_pairs,
                                                       neighbors,
                                                       self.output_indices,
                                                       self.num_output_nodes_per_batch,
                                                       self.num_auxiliary_nodes_per_batch)

        self.indices_complete_check(node_wise_out_aux_pairs, self.output_indices)
        self.node_wise_out_aux_pairs = node_wise_out_aux_pairs

        if self.return_edge_index_type == 'adj':
            adj = SparseTensor.from_edge_index(graph.edge_index, sparse_sizes=(graph.num_nodes, graph.num_nodes))
            adj = self.normalize_adjmat(adj, normalization='rw')
        else:
            adj = None

        if self.cache_data:
            self.prepare_cache(graph, node_wise_out_aux_pairs, adj)
        else:
            if self.return_edge_index_type == 'adj':
                self.adj = adj

    def create_node_wise_loader_hetero(self, graph: Data, ishetero=0):
        """
        Build the batches for a heterogeneous graph.

        PPR is computed on the first edge type whose source node type equals
        the output node type (first key of ``graph.y_dict``).

        :param graph: heterogeneous graph
        :param ishetero: forwarded to ``prepare_cache``
        :raises ValueError: if no edge type starts at the output node type
        """
        logging.info("Start PPR calculation")
        output_node = list(graph.y_dict.keys())[0]
        for (s, p, o) in graph.edge_index_dict.keys():
            if s == output_node:
                ppr_matrix, neighbors = topk_ppr_matrix(graph.edge_index_dict[(s, p, o)],
                                                        graph.num_nodes_dict[s],
                                                        self.alpha,
                                                        self.eps,
                                                        self.output_indices,
                                                        self.num_auxiliary_node_per_output)
                ppr_matrix = ppr_matrix[:, self.output_indices]
                break  ## ToDos include all edges
        else:
            # previously surfaced as an UnboundLocalError on ppr_matrix below
            raise ValueError(f"No edge type has source node type {output_node!r}")

        logging.info("Getting PPR pairs")
        ppr_pairs = get_pairs(ppr_matrix)

        node_wise_out_aux_pairs = _build_out_aux_pairs(ppr_pairs,
                                                       neighbors,
                                                       self.output_indices,
                                                       self.num_output_nodes_per_batch,
                                                       self.num_auxiliary_nodes_per_batch)

        self.indices_complete_check(node_wise_out_aux_pairs, self.output_indices)
        self.node_wise_out_aux_pairs = node_wise_out_aux_pairs

        if self.return_edge_index_type == 'adj':
            # NOTE(review): PPR above uses the FIRST matching edge type, this
            # loop keeps the adjacency of the LAST matching one - confirm
            # which relation is intended.
            for (s, p, o) in graph.edge_index_dict.keys():
                if s == output_node:
                    edge_index = graph.edge_index_dict[(s, p, o)]
                    size = int(edge_index.max()) + 1  # sparse_sizes expects plain ints
                    adj = SparseTensor.from_edge_index(edge_index, sparse_sizes=(size, size))
                    adj = self.normalize_adjmat(adj, normalization='rw')
        else:
            adj = None

        if self.cache_data:
            self.prepare_cache(graph, node_wise_out_aux_pairs, adj, ishetero)
        else:
            if self.return_edge_index_type == 'adj':
                self.adj = adj

    def prepare_cache(self, graph: Data,
                      batch_wise_out_aux_pairs: List[Tuple[np.ndarray, np.ndarray]],
                      adj: Optional[SparseTensor], ishetero=0):
        """
        Pre-compute one subgraph per batch and store them in ``self.subgraphs``.

        :param graph: graph the subgraphs are taken from
        :param batch_wise_out_aux_pairs: ``(output ids, auxiliary ids)`` per batch
        :param adj: adjacency matrix, required when the return type is ``'adj'``
        :param ishetero: ``0`` uses ``get_subgraph``; any other value uses
            ``get_subgraph_hetero``
        """
        pbar = tqdm(batch_wise_out_aux_pairs)
        pbar.set_description(f"Caching data with type {self.return_edge_index_type}")

        if self.return_edge_index_type == 'adj':
            assert adj is not None, "Trying to cache adjacency matrix, got None type."

        for out, aux in pbar:
            mask = torch.from_numpy(np.isin(aux, out))
            if isinstance(aux, np.ndarray):
                aux = torch.from_numpy(aux)

            if ishetero == 0:
                # The original test was ``aux.dense_dim[0] < 100000``, which
                # subscripts a method and raises TypeError; the number of
                # auxiliary nodes is what the size limit is compared against.
                # NOTE(review): skipped batches make self.subgraphs shorter
                # than __len__ reports - confirm this is acceptable.
                if len(aux) < self.MAX_CACHED_SUBGRAPH_NODES:  ### ignore large subgraphs
                    subg = self.get_subgraph(aux, graph, self.return_edge_index_type, adj, output_node_mask=mask)
                    self.subgraphs.append(subg)
            else:
                subg = self.get_subgraph_hetero(aux, graph, self.return_edge_index_type, adj, output_node_mask=mask)
                self.subgraphs.append(subg)

    def __getitem__(self, idx):
        """Return the cached subgraph, or the raw ``(out, aux)`` pair when not caching."""
        return self.subgraphs[idx] if self.cache_data else self.node_wise_out_aux_pairs[idx]

    def __len__(self):
        """Return the number of IBMB batches."""
        return len(self.node_wise_out_aux_pairs)

    @property
    def loader_len(self):
        """Number of loader iterations per epoch for the configured batch size."""
        return ceil(len(self.node_wise_out_aux_pairs) / self._batchsize)

    def __collate__(self, data_list):
        """
        Combine the sampled items into one subgraph.

        A single cached ``Data`` object is returned unchanged; otherwise the
        ``(out, aux)`` pairs are merged and the subgraph is built on the fly.

        :param data_list: items produced by ``__getitem__``
        :return: subgraph for the combined batch
        """
        if len(data_list) == 1 and isinstance(data_list[0], Data):
            return data_list[0]

        out, aux = zip(*data_list)
        out = np.concatenate(out)
        aux = np.unique(np.concatenate(aux))  # still need it to be overlapping
        mask = torch.from_numpy(np.isin(aux, out))
        aux = torch.from_numpy(aux)

        # NOTE(review): the cache path of this class uses get_subgraph_hetero
        # when ishetero != 0, while this path always calls get_subgraph -
        # confirm whether un-cached heterogeneous batches are supported.
        subg = self.get_subgraph(aux, self.original_graph, self.return_edge_index_type, self.adj,
                                 output_node_mask=mask)
        return subg