from torch_geometric.data import Data
from torch_geometric.data import Dataset
import torch


class BipartiteDataRL(Data):
    """Graph sample with separate constraint-node and variable-node features.

    Attributes set by the constructor:
        x_c: constraint-node tensor.
        x_v: variable-node tensor.
        edge_index: edge index tensor.
        edge_features: edge feature tensor.
        candidates: the action set passed in by the caller.
        nb_candidates: the number of entries in ``candidates`` as supplied
            by the caller (not recomputed here).
    """

    def __init__(
        self,
        x_c: torch.Tensor,
        x_v: torch.Tensor,
        edge_index: torch.Tensor,
        edge_features: torch.Tensor,
        candidates,
        nb_candidates
    ):
        super().__init__()

        # Node features for the two sides of the graph.
        self.x_c = x_c
        self.x_v = x_v

        # Connectivity and per-edge attributes.
        self.edge_index = edge_index
        self.edge_features = edge_features

        # Candidate set and its size.
        self.candidates = candidates
        self.nb_candidates = nb_candidates

    def __inc__(self, key, value, *args, **kwargs):
        """Return the increment used for attribute ``key``.

        * ``'edge_index'``: a 2x1 tensor holding the size of dimension 0 of
          ``x_c`` (first row) and of ``x_v`` (second row), so each row of the
          edge index gets its own increment.
        * ``'candidates'``: the size of dimension 0 of ``x_v``.
        * any other key: whatever the parent class returns.
        """
        if key == 'edge_index':
            constraint_count = self.x_c.size(0)
            variable_count = self.x_v.size(0)
            increment = torch.tensor([[constraint_count], [variable_count]])
        elif key == "candidates":
            # Candidates are incremented by the number of variables.
            increment = self.x_v.size(0)
        else:
            increment = super().__inc__(key, value, *args, **kwargs)
        return increment


class BipartiteDatasetRL(Dataset):
    """Dataset over parallel, index-aligned sequences of transition fields.

    Each constructor argument is stored as-is and indexed with the same
    position in ``get``; the number of samples is ``len(state)``.
    """

    def __init__(self, state, action, reward, action_set, done, old_log_probs):
        super().__init__()
        self.state = state
        self.action = action
        self.reward = reward
        self.action_set = action_set
        self.done = done
        self.old_log_probs = old_log_probs

    def len(self):
        """Return the number of stored states."""
        return len(self.state)

    def get(self, index):
        """Build the sample stored at position ``index``.

        Returns:
            A tuple ``(graph, action, reward, done, old_log_probs)`` where
            ``graph`` is a ``BipartiteDataRL`` built from the stored state
            and action set, and the remaining entries are the raw values
            stored at ``index``.
        """
        observation = self.state[index]
        action = self.action[index]
        reward = self.reward[index]
        action_set = self.action_set[index]
        done = self.done[index]
        old_log_probs = self.old_log_probs[index]

        # A stored state unpacks into exactly these four parts, in this order.
        (
            constraint_features,
            edge_indices,
            edge_features,
            variable_features,
        ) = observation

        graph = BipartiteDataRL(
            x_c=constraint_features,
            x_v=variable_features,
            edge_index=edge_indices,
            edge_features=edge_features,
            candidates=action_set,
            nb_candidates=len(action_set),
        )

        # PyTorch Geometric must be told the total node count explicitly,
        # for indexing purposes: constraint nodes plus variable nodes.
        total_nodes = constraint_features.shape[0] + variable_features.shape[0]
        graph.num_nodes = total_nodes

        return graph, action, reward, done, old_log_probs