import torch as th
import numpy as np
import dgl


def _drop_edges(graph, edge_mask_rate):
    """Return a copy of ``graph`` with a random subset of its edges removed.

    Every node of ``graph`` is kept; each edge survives according to
    ``mask_edge``. The kept endpoints are also returned so callers can
    reuse them without querying the new graph.

    Args:
        graph: input DGL graph.
        edge_mask_rate: probability of dropping each edge.

    Returns:
        ``(new_graph, kept_src, kept_dst)``.
    """
    n_node = graph.number_of_nodes()
    kept = mask_edge(graph, edge_mask_rate)
    src, dst = graph.edges()
    nsrc, ndst = src[kept], dst[kept]
    # ``num_nodes`` keeps nodes whose incident edges were all dropped.
    # Building from the endpoint tensors makes the new graph live on the
    # same device as the input (the old empty-graph + add_edges route was
    # CPU-only).
    ng = dgl.graph((nsrc, ndst), num_nodes=n_node)
    return ng, nsrc, ndst


def random_aug(graph, x, feat_drop_rate, edge_mask_rate):
    """Produce one randomly augmented view of ``(graph, x)``.

    Args:
        graph: input DGL graph.
        x: node feature tensor; columns (dim 1) are dropped as a whole.
        feat_drop_rate: probability of zeroing each feature column.
        edge_mask_rate: probability of dropping each edge.

    Returns:
        ``(new_graph, new_features)``.
    """
    # Edge mask is drawn before the feature mask (RNG order preserved).
    ng, _, _ = _drop_edges(graph, edge_mask_rate)
    feat = drop_feature(x, feat_drop_rate)
    return ng, feat


def random_aug_white(graph, x, feat_drop_rate, edge_mask_rate):
    """Like ``random_aug`` but also return the kept edge list.

    Returns:
        ``(new_graph, new_features, edge_list)``, where ``edge_list`` is a
        NumPy array of shape ``(2, num_kept_edges)``. Row 0 holds the
        source ids and row 1 the destination ids.
    """
    ng, nsrc, ndst = _drop_edges(graph, edge_mask_rate)
    feat = drop_feature(x, feat_drop_rate)
    # th.stack + .cpu() replaces np.concatenate on raw tensors, which
    # fails for tensors that are not on the CPU.
    aug_list = th.stack((nsrc, ndst)).cpu().numpy()
    return ng, feat, aug_list


def drop_feature(x, drop_prob):
    """Zero out whole feature columns of ``x`` at random.

    Each column (index along dim 1) is zeroed independently with
    probability ``drop_prob``. The input is not modified; a clone is
    returned.
    """
    drop_mask = (
        th.empty((x.size(1),), dtype=th.float32, device=x.device).uniform_(0, 1)
        < drop_prob
    )
    x = x.clone()
    x[:, drop_mask] = 0
    return x


def mask_edge(graph, mask_prob):
    """Sample the ids of the edges to KEEP.

    Each edge is kept independently with probability ``1 - mask_prob``.

    Returns:
        1-D tensor of kept edge ids on ``graph.device``; it may be empty.
    """
    n_edges = graph.number_of_edges()
    keep_prob = th.full((n_edges,), 1.0 - mask_prob, device=graph.device)
    keep = th.bernoulli(keep_prob)
    return keep.nonzero().squeeze(1)