# MIA-GCL / MVGRL / GCL2 / augmentors / node_shuffling.py
from GCL2.augmentors.augmentor import Graph, Augmentor
from GCL2.augmentors.functional import permute


class NodeShuffling(Augmentor):
    """Augmentor that runs the graph's node features through ``permute``.

    Only ``x`` is transformed; ``edge_index`` and ``edge_weights`` are
    passed through to the resulting graph unchanged.
    """

    def __init__(self):
        super().__init__()

    def augment(self, g: Graph) -> Graph:
        """Build a new ``Graph`` with permuted features and the original edges.

        Args:
            g: The graph to augment; it is unpacked via ``g.unfold()``.

        Returns:
            A ``Graph`` whose ``x`` is ``permute(x)`` and whose edge index
            and edge weights are those of ``g``.
        """
        features, connectivity, weights = g.unfold()
        # NOTE(review): ``permute`` is defined in GCL2.augmentors.functional;
        # presumably it shuffles the node (row) order of x -- confirm there.
        return Graph(
            x=permute(features),
            edge_index=connectivity,
            edge_weights=weights,
        )