import torch
from abc import ABC, abstractmethod
from torch_scatter import scatter


class Sampler(ABC):
    """Base class for samplers that pair anchors with samples via masks.

    Subclasses implement :meth:`sample`, which returns a tuple
    ``(anchor, sample, pos_mask, neg_mask)``. Row ``i`` of each mask refers to
    ``anchor[i]`` and column ``j`` refers to ``sample[j]``.
    """

    def __init__(self, intraview_negs=False):
        # When True, ``__call__`` appends the anchors themselves to ``sample``
        # as additional negatives (see ``add_intraview_negs``).
        self.intraview_negs = intraview_negs

    def __call__(self, anchor, sample, *args, **kwargs):
        """Run :meth:`sample`, optionally extending it with intraview negatives."""
        ret = self.sample(anchor, sample, *args, **kwargs)
        if self.intraview_negs:
            ret = self.add_intraview_negs(*ret)
        return ret

    @abstractmethod
    def sample(self, anchor, sample, *args, **kwargs):
        """Return ``(anchor, sample, pos_mask, neg_mask)`` for the given inputs."""
        raise NotImplementedError

    @staticmethod
    def add_intraview_negs(anchor, sample, pos_mask, neg_mask):
        """Append ``anchor`` to ``sample`` so other anchors act as negatives.

        Args:
            anchor: Tensor with ``M`` rows.
            sample: Tensor with ``N`` rows (same trailing dims as ``anchor``).
            pos_mask: ``M x N`` positive mask.
            neg_mask: ``M x N`` negative mask.

        Returns:
            ``(anchor, new_sample, new_pos_mask, new_neg_mask)`` where
            ``new_sample`` is ``sample`` followed by ``anchor`` (``N + M`` rows)
            and both masks are ``M x (N + M)``. In the appended ``M x M`` block
            no entry is positive, and every off-diagonal entry is negative
            (an anchor is never its own negative).
        """
        num_anchors = anchor.size(0)
        device = anchor.device
        # The appended block compares anchors with anchors, so it is M x M.
        # It must not be shaped like pos_mask: that is M x N and only square
        # for same-scale sampling.
        block_shape = (num_anchors, num_anchors)
        intraview_pos_mask = torch.zeros(block_shape, dtype=pos_mask.dtype, device=device)
        intraview_neg_mask = (
            torch.ones(block_shape, dtype=pos_mask.dtype, device=device)
            - torch.eye(num_anchors, dtype=pos_mask.dtype, device=device)
        )
        new_sample = torch.cat([sample, anchor], dim=0)  # (N+M) * K
        new_pos_mask = torch.cat([pos_mask, intraview_pos_mask], dim=1)  # M * (N+M)
        new_neg_mask = torch.cat([neg_mask, intraview_neg_mask], dim=1)  # M * (N+M)
        return anchor, new_sample, new_pos_mask, new_neg_mask


class SameScaleSampler(Sampler):
    """Sampler where ``anchor[i]`` and ``sample[i]`` form the only positive pair.

    Used for the ``'L2L'`` and ``'G2G'`` modes of :func:`get_sampler`.
    """

    def sample(self, anchor, sample, *args, **kwargs):
        """Build identity positive masks between equally sized anchor/sample sets.

        Returns:
            ``(anchor, sample, pos_mask, neg_mask)`` with ``pos_mask`` the
            ``N x N`` float32 identity and ``neg_mask = 1 - pos_mask``.

        Raises:
            AssertionError: if ``anchor`` and ``sample`` differ in row count.
        """
        assert anchor.size(0) == sample.size(0), 'anchor and sample must have the same number of rows'
        num_nodes = anchor.size(0)
        device = anchor.device
        pos_mask = torch.eye(num_nodes, dtype=torch.float32, device=device)
        neg_mask = 1. - pos_mask
        return anchor, sample, pos_mask, neg_mask


class CrossScaleSampler(Sampler):
    """Sampler that contrasts ``M`` anchors against ``N`` samples by membership.

    Used for the ``'G2L'`` mode of :func:`get_sampler`. Sample ``j`` is a
    positive for anchor ``batch[j]``; every other pair is a negative.
    """

    def sample(self, anchor, sample, batch=None, neg_sample=None, use_gpu=True, *args, **kwargs):
        """Build cross-scale positive/negative masks.

        Args:
            anchor: Tensor with ``M`` rows.
            sample: Tensor with ``N`` rows.
            batch: Length-``N`` index mapping each sample row to an anchor row.
                Required when ``neg_sample`` is ``None``.
            neg_sample: Optional explicit negatives, same shape as ``sample``.
                Only valid when ``M == 1``. If given, all of ``sample`` is
                positive and all of ``neg_sample`` is negative.
            use_gpu: If True, build the mask with ``torch_scatter.scatter``
                over an ``N x N`` identity. Otherwise use direct index
                assignment. Both produce an ``M x N`` mask.

        Returns:
            ``(anchor, sample, pos_mask, neg_mask)``. If ``neg_sample`` is
            given, ``sample`` is ``cat([sample, neg_sample])`` and the masks
            are ``M x 2N``; otherwise the masks are ``M x N``.

        Raises:
            AssertionError: on the precondition violations described above.
        """
        num_graphs = anchor.shape[0]  # M
        num_nodes = sample.shape[0]  # N
        device = sample.device

        if neg_sample is not None:
            # With a single anchor every sample row is positive, so negatives
            # have to be supplied explicitly.
            assert num_graphs == 1, 'explicit neg_sample requires exactly one anchor'
            assert sample.shape == neg_sample.shape, 'neg_sample must have the same shape as sample'
            pos_mask1 = torch.ones((num_graphs, num_nodes), dtype=torch.float32, device=device)
            pos_mask0 = torch.zeros((num_graphs, num_nodes), dtype=torch.float32, device=device)
            pos_mask = torch.cat([pos_mask1, pos_mask0], dim=1)  # M * 2N
            sample = torch.cat([sample, neg_sample], dim=0)  # 2N * K
        else:
            assert batch is not None, 'batch is required when neg_sample is not given'
            if use_gpu:
                ones = torch.eye(num_nodes, dtype=torch.float32, device=device)  # N * N
                # dim_size pins the row count to M. Without it the output has
                # batch.max() + 1 rows, which is fewer than M when trailing
                # anchors have no samples.
                pos_mask = scatter(ones, batch, dim=0, dim_size=num_graphs, reduce='sum')  # M * N
            else:
                pos_mask = torch.zeros((num_graphs, num_nodes), dtype=torch.float32, device=device)
                graph_idx = torch.as_tensor(batch, dtype=torch.long, device=device)
                node_idx = torch.arange(graph_idx.numel(), device=device)
                # Vectorized form of: pos_mask[batch[j], j] = 1 for every j.
                pos_mask[graph_idx, node_idx] = 1.  # M * N
        neg_mask = 1. - pos_mask
        return anchor, sample, pos_mask, neg_mask


def get_sampler(mode: str, intraview_negs: bool) -> Sampler:
    """Return the sampler for a contrastive ``mode``.

    Args:
        mode: ``'L2L'`` or ``'G2G'`` selects :class:`SameScaleSampler`;
            ``'G2L'`` selects :class:`CrossScaleSampler`.
        intraview_negs: Forwarded to the sampler's constructor.

    Raises:
        RuntimeError: if ``mode`` is not one of the supported values.
    """
    if mode in {'L2L', 'G2G'}:
        return SameScaleSampler(intraview_negs=intraview_negs)
    elif mode == 'G2L':
        return CrossScaleSampler(intraview_negs=intraview_negs)
    else:
        raise RuntimeError(f'unsupported mode: {mode}')