from GCL2.augmentors.augmentor import Graph, Augmentor
from GCL2.augmentors.functional import dropout_adj


def _remove_edges(g: Graph, pe: float) -> Graph:
    """Build a new ``Graph`` from *g* with its edges passed through ``dropout_adj``.

    Node features ``x`` are forwarded unchanged. ``edge_index`` and
    ``edge_weights`` (given to ``dropout_adj`` as ``edge_attr``) are replaced
    by the two values ``dropout_adj`` returns.

    Args:
        g: Input graph; unpacked via ``g.unfold()``.
        pe: Forwarded to ``dropout_adj`` as ``p``.

    Returns:
        A new ``Graph`` built from ``x`` and the filtered edges.
    """
    x, edge_index, edge_weights = g.unfold()
    edge_index, edge_weights = dropout_adj(edge_index, edge_attr=edge_weights, p=pe)
    return Graph(x=x, edge_index=edge_index, edge_weights=edge_weights)


class EdgeRemoving(Augmentor):
    """Augmentor that removes edges from a graph using ``dropout_adj``."""

    def __init__(self, pe: float):
        """
        Args:
            pe: Value passed as ``p`` to ``dropout_adj``; presumably the
                per-edge drop probability -- confirm against its definition.
        """
        super().__init__()
        self.pe = pe

    def augment(self, g: Graph) -> Graph:
        """Return a copy of *g* with edges removed via ``dropout_adj(p=self.pe)``."""
        return _remove_edges(g, self.pe)


class EdgeRemoving_degree(Augmentor):
    """Edge-removal augmentor exposed under a degree-specific name.

    NOTE(review): despite the name, node degree is never consulted -- this
    class behaves exactly like ``EdgeRemoving`` (plain ``dropout_adj``).
    Confirm whether a degree-weighted drop scheme was intended; the logic is
    shared with ``EdgeRemoving`` via ``_remove_edges`` so the two cannot
    silently drift apart.
    """

    def __init__(self, pe: float):
        """
        Args:
            pe: Value passed as ``p`` to ``dropout_adj``; presumably the
                per-edge drop probability -- confirm against its definition.
        """
        super().__init__()
        self.pe = pe

    def augment(self, g: Graph) -> Graph:
        """Return a copy of *g* with edges removed via ``dropout_adj(p=self.pe)``."""
        return _remove_edges(g, self.pe)