from typing import Optional

import numpy as np
import torch
from torch_geometric.data import Data
from torch_geometric.data import Dataset


def _bipartite_increment(data, key):
    """Return the batching offset for bipartite-specific keys, else ``None``.

    Shared by ``BipartiteData`` and ``BipartiteDataRL`` so both classes
    apply exactly the same index-shifting rules when graphs are batched.
    """
    if key == "edge_index":
        # Row 0 of edge_index addresses constraint nodes and row 1 addresses
        # variable nodes, so each row is shifted by its own node count.
        return torch.tensor([[data.x_c.size(0)], [data.x_v.size(0)]])
    if key == "candidates":
        # Candidates are variable indices: shift by the number of variables.
        return data.x_v.size(0)
    return None


class BipartiteData(Data):
    """Bipartite constraint/variable graph with expert branching labels.

    All arguments default to ``None`` so the class can also be constructed
    without arguments; PyG utilities that rebuild single examples from a
    batch instantiate the data class that way. Existing positional and
    keyword call sites are unaffected.

    Args:
        x_c: Constraint node features.
        x_v: Variable node features.
        edge_index: Edge index (constraint index row, variable index row).
        edge_features: Edge features.
        candidates: Set of candidate variables to branch on.
        candidates_scores: Branch scores for each candidate.
        candidate_choice: Candidate picked by the expert.
        nb_candidates: Number of candidates.
    """

    def __init__(
        self,
        x_c: Optional[torch.Tensor] = None,
        x_v: Optional[torch.Tensor] = None,
        edge_index: Optional[torch.Tensor] = None,
        edge_features: Optional[torch.Tensor] = None,
        candidates: Optional[torch.Tensor] = None,
        candidates_scores: Optional[torch.Tensor] = None,
        candidate_choice: Optional[torch.Tensor] = None,
        nb_candidates: Optional[int] = None,
    ):
        super().__init__()
        self.x_c = x_c
        self.x_v = x_v
        self.edge_index = edge_index
        self.edge_features = edge_features
        self.candidates = candidates
        self.candidates_scores = candidates_scores
        self.candidate_choice = candidate_choice
        self.nb_candidates = nb_candidates

    def __inc__(self, key, value, *args, **kwargs):
        """Return how much ``key`` must be incremented when batching."""
        increment = _bipartite_increment(self, key)
        if increment is None:
            return super().__inc__(key, value, *args, **kwargs)
        return increment


class BipartiteDataset(Dataset):
    """Dataset turning stored expert samples into ``BipartiteData`` graphs.

    Args:
        samples: List of ``(observation, action, action_set, scores)``
            tuples, as unpacked in ``get``.
    """

    def __init__(self, samples: list):
        super().__init__()
        self.samples = samples

    def len(self):
        """Return the number of stored samples."""
        return len(self.samples)

    def get(self, index):
        """Build the ``BipartiteData`` graph for the sample at ``index``.

        Raises:
            IndexError: If the expert action is not part of the action set.
        """
        observation, action, action_set, scores = self.samples[index]
        # NOTE(review): the observation is presumably an Ecole NodeBipartite
        # observation (row_features / edge_features / variable_features) -
        # confirm against the code that produces the samples.

        constraint_features = observation.row_features
        variable_features = observation.variable_features
        # The indices end up in int64 LongTensors, so convert straight to
        # int64 instead of narrowing to int32 first.
        edge_indices = observation.edge_features.indices.astype(np.int64)
        edge_features = np.expand_dims(observation.edge_features.values, axis=-1)

        candidates = np.array(action_set, dtype=np.int64)
        candidate_scores = np.array([scores[j] for j in candidates])

        # Position of the expert's action inside the candidate list.
        matches = np.where(candidates == action)[0]
        if matches.size == 0:
            raise IndexError(
                f"expert action {action!r} is not in the action set "
                f"of sample {index}"
            )
        candidate_choice = matches[0]

        graph = BipartiteData(
            torch.FloatTensor(constraint_features),
            torch.FloatTensor(variable_features),
            torch.LongTensor(edge_indices),
            torch.FloatTensor(edge_features),
            torch.LongTensor(candidates),
            torch.FloatTensor(candidate_scores),
            torch.LongTensor([candidate_choice]),
            len(candidates),
        )

        # We must tell pytorch geometric how many nodes there are, for
        # indexing purposes.
        graph.num_nodes = constraint_features.shape[0] + variable_features.shape[0]

        return graph


class BipartiteDataRL(Data):
    """Bipartite constraint/variable graph used for RL transitions.

    All arguments default to ``None`` so the class can also be constructed
    without arguments; PyG utilities that rebuild single examples from a
    batch instantiate the data class that way. Existing positional and
    keyword call sites are unaffected.

    Args:
        x_c: Constraint node features.
        x_v: Variable node features.
        edge_index: Edge index (constraint index row, variable index row).
        edge_features: Edge features.
        candidates: Candidate set stored on the graph.
        nb_candidates: Number of candidates.
    """

    def __init__(
        self,
        x_c: Optional[torch.Tensor] = None,
        x_v: Optional[torch.Tensor] = None,
        edge_index: Optional[torch.Tensor] = None,
        edge_features: Optional[torch.Tensor] = None,
        candidates=None,
        nb_candidates=None,
    ):
        super().__init__()
        self.x_c = x_c
        self.x_v = x_v
        self.edge_index = edge_index
        self.edge_features = edge_features
        self.candidates = candidates
        self.nb_candidates = nb_candidates

    def __inc__(self, key, value, *args, **kwargs):
        """Return how much ``key`` must be incremented when batching."""
        increment = _bipartite_increment(self, key)
        if increment is None:
            return super().__inc__(key, value, *args, **kwargs)
        return increment


class BipartiteDatasetRL(Dataset):
    """Dataset over parallel collections of RL transition components.

    Args:
        state: Indexable collection; each item unpacks into
            ``(constraint_features, edge_indices, edge_features,
            variable_features)``.
        action: Indexable collection of actions.
        reward: Indexable collection of rewards.
        action_set: Indexable collection of action sets.
        done: Indexable collection of done flags.

    The dataset length is taken from ``state``; the other collections are
    assumed to be index-aligned with it (TODO confirm with the caller).
    """

    def __init__(self, state, action, reward, action_set, done):
        super().__init__()
        self.state = state
        self.action = action
        self.reward = reward
        self.action_set = action_set
        self.done = done

    def len(self):
        """Return the number of stored states."""
        return len(self.state)

    def get(self, index):
        """Return ``(graph, action, reward, done)`` for transition ``index``."""
        (
            constraint_features,
            edge_indices,
            edge_features,
            variable_features,
        ) = self.state[index]
        action_set = self.action_set[index]

        graph = BipartiteDataRL(
            constraint_features,
            variable_features,
            edge_indices,
            edge_features,
            action_set,
            len(action_set),
        )

        # We must tell pytorch geometric how many nodes there are, for
        # indexing purposes.
        graph.num_nodes = constraint_features.shape[0] + variable_features.shape[0]

        return graph, self.action[index], self.reward[index], self.done[index]