# MIA-GCL/MVGRL/GCL2/augmentors/edge_attr_masking.py
from GCL2.augmentors.augmentor import Graph, Augmentor
from GCL2.augmentors.functional import drop_feature


class EdgeAttrMasking(Augmentor):
    """Augmentor that applies ``drop_feature`` to a graph's edge weights.

    Node features (``x``) and connectivity (``edge_index``) are passed
    through as-is; only the edge weights, when present, are masked.
    """

    def __init__(self, pf: float):
        """Store the masking probability.

        Args:
            pf: value forwarded to ``drop_feature`` as its second argument
                (presumably a drop probability in [0, 1] — confirm against
                ``GCL2.augmentors.functional.drop_feature``).
        """
        super().__init__()
        self.pf = pf

    def augment(self, g: Graph) -> Graph:
        """Build a new ``Graph`` whose edge weights have been masked.

        Args:
            g: input graph; unpacked with ``g.unfold()`` into
                ``(x, edge_index, edge_weights)``.

        Returns:
            A newly constructed ``Graph`` carrying the original ``x`` and
            ``edge_index``. Its ``edge_weights`` is the output of
            ``drop_feature``, or ``None`` when ``g`` had no edge weights.
        """
        feats, index, weights = g.unfold()
        # Graphs without edge weights have nothing to mask; keep None.
        # NOTE(review): drop_feature is named for node-feature matrices —
        # confirm it accepts the edge-weight tensor's shape (e.g. 1-D weights).
        masked = None if weights is None else drop_feature(weights, self.pf)
        return Graph(x=feats, edge_index=index, edge_weights=masked)