from typing import Optional

import torch
from torch_geometric.data import Data
from torch_geometric.data import Dataset
class BipartiteDataRL(Data):
    """Bipartite constraint/variable graph with a candidate set attached.

    Every constructor argument is optional so the class can be created
    without arguments; the PyTorch Geometric mini-batching documentation
    recommends this for custom ``Data`` subclasses that define their own
    ``__init__``.

    Args:
        x_c: Features of the constraint nodes (one row per constraint).
        x_v: Features of the variable nodes (one row per variable).
        edge_index: Edge index; ``__inc__`` shifts row 0 by the number of
            constraint nodes and row 1 by the number of variable nodes.
        edge_features: Features attached to the edges.
        candidates: Candidate entries, shifted by the number of variable
            nodes when graphs are batched (presumably indices into
            ``x_v`` -- confirm against the caller).
        nb_candidates: Number of entries in ``candidates``.
    """

    def __init__(
        self,
        x_c: Optional[torch.Tensor] = None,  # constraint nodes
        x_v: Optional[torch.Tensor] = None,  # variable nodes
        edge_index: Optional[torch.Tensor] = None,  # edge index
        edge_features: Optional[torch.Tensor] = None,  # edge features
        candidates=None,
        nb_candidates: Optional[int] = None,
    ):
        super().__init__()
        self.x_c = x_c
        self.x_v = x_v
        self.edge_index = edge_index
        self.edge_features = edge_features
        self.candidates = candidates
        self.nb_candidates = nb_candidates

    def __inc__(self, key, value, *args, **kwargs):
        """Return the offset added to attribute ``key`` when batching graphs.

        ``edge_index`` gets a per-row offset (constraint count for row 0,
        variable count for row 1) and ``candidates`` is offset by the
        variable count. Every other key uses the parent implementation.
        """
        if key == 'edge_index':
            # Bipartite graph: the two rows index two different node sets,
            # so each row needs its own increment.
            return torch.tensor([[self.x_c.size(0)], [self.x_v.size(0)]])
        if key == "candidates":
            # Candidates move together with the variable nodes.
            return self.x_v.size(0)
        return super().__inc__(key, value, *args, **kwargs)
class BipartiteDatasetRL(Dataset):
    """Index-aligned view over stored transitions.

    Each of the six constructor arguments is an indexable collection; entry
    ``i`` of every collection belongs to the same transition. ``get`` turns
    the stored observation into a ``BipartiteDataRL`` graph and returns it
    together with the remaining per-transition values.
    """

    def __init__(self, state, action, reward, action_set, done, old_log_probs):
        super().__init__()
        self.state, self.action, self.reward = state, action, reward
        self.action_set, self.done = action_set, done
        self.old_log_probs = old_log_probs

    def len(self):
        """Return the number of stored transitions (length of ``state``)."""
        return len(self.state)

    def get(self, index):
        """Build the sample stored at position ``index``.

        Returns:
            A tuple ``(graph, action, reward, done, old_log_probs)`` where
            ``graph`` is a ``BipartiteDataRL`` built from the observation
            and the action set stored at ``index``.
        """
        # Look every buffer up once, in a fixed order.
        buffers = (
            self.state,
            self.action,
            self.reward,
            self.action_set,
            self.done,
            self.old_log_probs,
        )
        observation, chosen, reward, cand_set, finished, log_probs = (
            buffer[index] for buffer in buffers
        )

        # The observation is stored as (constraints, edge index, edge
        # features, variables); the graph constructor takes them by name.
        cons_feats, edge_idx, edge_feats, var_feats = observation
        bipartite = BipartiteDataRL(
            x_c=cons_feats,
            x_v=var_feats,
            edge_index=edge_idx,
            edge_features=edge_feats,
            candidates=cand_set,
            nb_candidates=len(cand_set),
        )

        # PyTorch Geometric needs the total node count stated explicitly
        # for indexing purposes: constraints plus variables.
        bipartite.num_nodes = cons_feats.shape[0] + var_feats.shape[0]

        return bipartite, chosen, reward, finished, log_probs