from typing import Optional

import numpy as np
import torch
from torch_geometric.data import Data
from torch_geometric.data import Dataset
class BipartiteData(Data):
    """Bipartite constraint/variable graph for one supervised branching sample.

    The graph has two node sets, stored separately as ``x_c`` (constraint
    nodes) and ``x_v`` (variable nodes). ``edge_index`` connects them: its
    first row refers to rows of ``x_c`` and its second row to rows of ``x_v``
    (this is what the per-row offsets returned by ``__inc__`` encode).

    Every constructor argument defaults to ``None`` so that an empty instance
    can be created with ``BipartiteData()``. PyG utilities that rebuild
    individual graphs (for example ``Batch.to_data_list``) instantiate the
    data class without arguments, which fails with a ``TypeError`` when the
    arguments are required. Existing positional and keyword calls behave
    exactly as before.

    Args:
        x_c: Constraint node features.
        x_v: Variable node features.
        edge_index: Constraint-to-variable edge indices.
        edge_features: Edge features.
        candidates: Indices (into ``x_v``) of the candidate variables to
            branch on.
        candidates_scores: Branching score of each candidate.
        candidate_choice: Candidate picked by the expert.
        nb_candidates: Number of candidates.
    """

    def __init__(
        self,
        x_c: Optional[torch.Tensor] = None,  # constraint nodes
        x_v: Optional[torch.Tensor] = None,  # variable nodes
        edge_index: Optional[torch.Tensor] = None,  # edge index
        edge_features: Optional[torch.Tensor] = None,  # edge features
        candidates: Optional[torch.Tensor] = None,  # candidate variables to branch on
        candidates_scores: Optional[torch.Tensor] = None,  # branch score per candidate
        candidate_choice: Optional[torch.Tensor] = None,  # candidate picked by expert
        nb_candidates: Optional[int] = None,
    ):
        super().__init__()
        self.x_c = x_c
        self.x_v = x_v
        self.edge_index = edge_index
        self.edge_features = edge_features
        self.candidates = candidates
        self.candidates_scores = candidates_scores
        self.candidate_choice = candidate_choice
        self.nb_candidates = nb_candidates

    def __inc__(self, key, value, *args, **kwargs):
        """Return the offset PyG adds to attribute ``key`` when batching graphs.

        Args:
            key: Name of the attribute being collated.
            value: Value of that attribute on this graph.
            *args: Forwarded unchanged to ``Data.__inc__``.
            **kwargs: Forwarded unchanged to ``Data.__inc__``.

        Returns:
            For ``"edge_index"``, a ``(2, 1)`` tensor holding the number of
            constraint nodes (offset for the first row) and the number of
            variable nodes (offset for the second row). For ``"candidates"``,
            the number of variable nodes. Otherwise whatever the parent class
            returns.
        """
        if key == "edge_index":
            # The two rows index different node sets, so each row needs its
            # own offset rather than a single shared ``num_nodes`` increment.
            return torch.tensor([[self.x_c.size(0)], [self.x_v.size(0)]])
        if key == "candidates":
            # Candidates are variable-node indices: shift by the variable count.
            return self.x_v.size(0)
        return super().__inc__(key, value, *args, **kwargs)
class BipartiteDataset(Dataset):
    """Dataset that converts recorded branching samples into ``BipartiteData``.

    Each entry of ``samples`` must unpack into four items:
    ``(observation, action, action_set, scores)``. The observation is read
    through its ``row_features``, ``variable_features`` and
    ``edge_features`` (``.indices`` / ``.values``) attributes.

    Args:
        samples: Sequence of samples as described above.
    """

    def __init__(self, samples: list):
        super().__init__()
        self.samples = samples

    def len(self):
        """Return the number of stored samples."""
        return len(self.samples)

    def get(self, index: int) -> "BipartiteData":
        """Build the bipartite graph for the sample at ``index``.

        Args:
            index: Position of the sample in ``self.samples``.

        Returns:
            A ``BipartiteData`` graph whose ``num_nodes`` is the number of
            constraint rows plus the number of variable rows.

        Raises:
            IndexError: If the expert action is not part of the sample's
                action set (same exception type as before, but with a
                message that names the offending sample).
        """
        sample_observation, sample_action, sample_action_set, sample_scores = self.samples[index]

        constraint_features = sample_observation.row_features
        variable_features = sample_observation.variable_features
        # Cast straight to int64, the dtype of ``torch.LongTensor``. Going
        # through int32 first would narrow the indices only to widen them
        # again during tensor construction.
        edge_indices = sample_observation.edge_features.indices.astype(np.int64)
        # Give each edge a trailing feature axis of length 1.
        edge_features = np.expand_dims(sample_observation.edge_features.values, axis=-1)

        candidates = np.array(sample_action_set, dtype=np.int64)
        candidate_scores = np.array([sample_scores[j] for j in candidates])

        # Position of the expert's action inside the candidate list.
        matches = np.where(candidates == sample_action)[0]
        if len(matches) == 0:
            raise IndexError(
                f"expert action {sample_action!r} is not in the action set "
                f"of sample {index}"
            )
        candidate_choice = matches[0]

        graph = BipartiteData(
            torch.FloatTensor(constraint_features),
            torch.FloatTensor(variable_features),
            torch.LongTensor(edge_indices),
            torch.FloatTensor(edge_features),
            torch.LongTensor(candidates),
            torch.FloatTensor(candidate_scores),
            torch.LongTensor([candidate_choice]),
            len(candidates),
        )

        # We must tell pytorch geometric how many nodes there are, for indexing purposes
        graph.num_nodes = constraint_features.shape[0] + variable_features.shape[0]
        return graph
class BipartiteDataRL(Data):
    """Bipartite constraint/variable graph for one reinforcement-learning state.

    Same layout as a supervised branching graph but without expert scores or
    an expert choice: ``x_c`` holds constraint nodes, ``x_v`` holds variable
    nodes, and ``edge_index`` connects them (first row refers to ``x_c``,
    second row to ``x_v``, as encoded by ``__inc__``).

    Every constructor argument defaults to ``None`` so that an empty instance
    can be created with ``BipartiteDataRL()``. PyG utilities that rebuild
    individual graphs (for example ``Batch.to_data_list``) instantiate the
    data class without arguments, which fails with a ``TypeError`` when the
    arguments are required. Existing positional and keyword calls behave
    exactly as before.

    Args:
        x_c: Constraint node features.
        x_v: Variable node features.
        edge_index: Constraint-to-variable edge indices.
        edge_features: Edge features.
        candidates: Candidate variables available for branching; presumably
            a tensor of indices into ``x_v`` (TODO confirm against callers).
        nb_candidates: Number of candidates.
    """

    def __init__(
        self,
        x_c: Optional[torch.Tensor] = None,  # constraint nodes
        x_v: Optional[torch.Tensor] = None,  # variable nodes
        edge_index: Optional[torch.Tensor] = None,  # edge index
        edge_features: Optional[torch.Tensor] = None,  # edge features
        candidates=None,
        nb_candidates=None,
    ):
        super().__init__()
        self.x_c = x_c
        self.x_v = x_v
        self.edge_index = edge_index
        self.edge_features = edge_features
        self.candidates = candidates
        self.nb_candidates = nb_candidates

    def __inc__(self, key, value, *args, **kwargs):
        """Return the offset PyG adds to attribute ``key`` when batching graphs.

        Args:
            key: Name of the attribute being collated.
            value: Value of that attribute on this graph.
            *args: Forwarded unchanged to ``Data.__inc__``.
            **kwargs: Forwarded unchanged to ``Data.__inc__``.

        Returns:
            For ``"edge_index"``, a ``(2, 1)`` tensor holding the number of
            constraint nodes (offset for the first row) and the number of
            variable nodes (offset for the second row). For ``"candidates"``,
            the number of variable nodes. Otherwise whatever the parent class
            returns.
        """
        if key == "edge_index":
            # The two rows index different node sets, so each row needs its
            # own offset rather than a single shared ``num_nodes`` increment.
            return torch.tensor([[self.x_c.size(0)], [self.x_v.size(0)]])
        if key == "candidates":
            # Candidates are variable-node indices: shift by the variable count.
            return self.x_v.size(0)
        return super().__inc__(key, value, *args, **kwargs)
class BipartiteDatasetRL(Dataset):
    """Dataset of reinforcement-learning transitions backed by parallel sequences.

    The five constructor arguments are index-aligned: position ``i`` of each
    one describes the same transition. Each ``state`` entry must unpack into
    ``(constraint_features, edge_indices, edge_features, variable_features)``,
    and both feature objects must expose ``.shape``.

    NOTE: successor-state data (``next_state`` / ``next_action_set``) is not
    stored by this dataset.

    Args:
        state: Per-transition observations, as described above.
        action: Per-transition action.
        reward: Per-transition reward.
        action_set: Per-transition collection of candidate actions.
        done: Per-transition termination flag.
    """

    def __init__(self, state, action, reward, action_set, done):
        super().__init__()
        self.state = state
        self.action = action
        self.reward = reward
        self.action_set = action_set
        self.done = done

    def len(self):
        """Return the number of transitions, taken from the length of ``state``."""
        return len(self.state)

    def get(self, index):
        """Return ``(graph, action, reward, done)`` for the transition at ``index``.

        ``graph`` is a ``BipartiteDataRL`` built from the stored observation
        and action set, with ``num_nodes`` set to the number of constraint
        rows plus the number of variable rows.
        """
        # Pull the index-th entry out of every parallel sequence, in order.
        observation, chosen_action, reward, action_set, done = (
            column[index]
            for column in (self.state, self.action, self.reward, self.action_set, self.done)
        )

        # current state
        row_feats, coo_indices, coo_values, col_feats = observation

        graph = BipartiteDataRL(
            x_c=row_feats,
            x_v=col_feats,
            edge_index=coo_indices,
            edge_features=coo_values,
            candidates=action_set,
            nb_candidates=len(action_set),
        )

        # We must tell pytorch geometric how many nodes there are, for indexing purposes
        total_nodes = row_feats.shape[0] + col_feats.shape[0]
        graph.num_nodes = total_nodes

        return graph, chosen_action, reward, done