# MIA-GCL / MVGRL / GCL2 / augmentors / feature_masking_mia.py
from typing import Any, Tuple

from GCL2.augmentors.augmentor import Graph, Augmentor
from GCL2.augmentors.functional import drop_feature,drop_feature_mia


class FeatureMasking(Augmentor):
    """Augmentor that runs a graph's ``x`` component through ``drop_feature``.

    Args:
        pf: Value forwarded as the second argument to ``drop_feature``
            (presumably the feature-masking probability; confirm against
            ``GCL2.augmentors.functional``).
    """

    def __init__(self, pf: float):
        super().__init__()
        self.pf = pf

    def augment(self, g: Graph) -> Graph:
        """Build a new ``Graph`` from ``g`` with ``drop_feature`` applied to ``x``.

        The edge index and edge weights obtained from ``g.unfold()`` are
        passed to the new ``Graph`` unchanged.
        """
        features, edges, weights = g.unfold()
        masked_features = drop_feature(features, self.pf)
        return Graph(
            x=masked_features,
            edge_index=edges,
            edge_weights=weights,
        )


class FeatureMaskingMia(Augmentor):
    """Variant of feature masking that also hands back the mask it used.

    Args:
        pf: Value forwarded as the second argument to ``drop_feature_mia``
            (presumably the feature-masking probability; confirm against
            ``GCL2.augmentors.functional``).
    """

    def __init__(self, pf: float):
        super().__init__()
        self.pf = pf

    def augment(self, g: Graph) -> Tuple[Graph, Any]:
        """Apply ``drop_feature_mia`` to ``g`` and return the result with its mask.

        Returns:
            A 2-tuple ``(graph, drop_mask)``: ``graph`` is a new ``Graph``
            built from the ``x`` returned by ``drop_feature_mia`` plus the
            original edge index and edge weights; ``drop_mask`` is the second
            value returned by ``drop_feature_mia``, passed through as-is.

        Note:
            This returns a tuple rather than a bare ``Graph``, so callers must
            unpack two values.
        """
        x, edge_index, edge_weights = g.unfold()
        x, drop_mask = drop_feature_mia(x, self.pf)
        augmented = Graph(x=x, edge_index=edge_index, edge_weights=edge_weights)
        return augmented, drop_mask