import numpy as np
import random
from collections import deque, namedtuple
from typing import List
import torch
import torch_geometric
from torch_geometric.data import Data
from torch_geometric.data import Dataset
# One recorded environment step: the observed state, the log-probability of
# the action taken, the reward received and the episode-termination flag.
Transition = namedtuple("Transition", ("state", "log_prob", "reward", "done"))


class ReplayBuffer:
    """
    Bounded store for the transitions observed during training.

    Transitions are kept in a ``collections.deque``; once ``capacity`` entries
    are stored, pushing a new transition silently discards the oldest one.

    Parameters
    ----------
    capacity : int or float
        Maximum number of transitions kept. Whole-number floats such as
        ``1e5`` are accepted and converted to ``int`` (``deque`` itself rejects
        floats). ``None`` yields an unbounded buffer.

    Raises
    ------
    ValueError
        If ``capacity`` is not a whole number, or is negative.
    """

    def __init__(self, capacity: float) -> None:
        if capacity is not None:
            # deque(maxlen=...) only accepts integers, so normalise here
            # instead of failing with an opaque TypeError on e.g. 1e5.
            whole = int(capacity)
            if whole != capacity:
                raise ValueError(
                    f"capacity must be a whole number, got {capacity!r}"
                )
            capacity = whole
        self.capacity = capacity
        self.memory = deque([], maxlen=capacity)

    def push(self, *args):
        """Append a new ``Transition`` built from ``(state, log_prob, reward, done)``."""
        self.memory.append(Transition(*args))

    def pop(self) -> Transition:
        """
        Remove and return the most recently pushed transition.

        Raises
        ------
        IndexError
            If the buffer is empty.
        """
        return self.memory.pop()

    def sample(self, batch_size: int) -> List[Transition]:
        """
        Draw transitions uniformly at random, without replacement.

        If ``batch_size`` exceeds the number of stored transitions, every
        stored transition is returned (in random order) instead of raising.
        """
        count = min(batch_size, len(self.memory))
        return random.sample(self.memory, count)

    def clean(self) -> None:
        """Drop all stored transitions, keeping the configured capacity."""
        # Rebind to a fresh deque (rather than clearing in place) so that any
        # outside reference to the previous container is left untouched.
        self.memory = deque([], maxlen=self.capacity)

    def __len__(self) -> int:
        """Return the number of transitions currently stored."""
        return len(self.memory)
class BipartiteNodeData(Data):
    """
    Data class modelling a single bipartite graph.

    Parameters
    ----------
    constraint_features : torch.Tensor
        Constraint-node features; its first dimension is used as the number
        of constraint nodes when graphs are batched. (float32 according to
        the original notes -- not enforced here.)
    edge_indices : torch.Tensor
        Edge index tensor, stored as ``edge_index``. (int64 according to the
        original notes -- not enforced here.)
    edge_features : torch.Tensor
        Edge features, stored as ``edge_attr``.
    variable_features : torch.Tensor
        Variable-node features; its first dimension is used as the number of
        variable nodes when graphs are batched.
    candidates : torch.Tensor
        Candidate indices; ``nb_candidates`` is set to ``len(candidates)``, or
        to ``None`` when ``candidates`` is ``None``.
    candidate_choice : torch.Tensor
        Stored as ``candidate_choices``.

    Notes
    -----
    The former ``candidate_scores`` argument has been removed.
    """

    def __init__(self, constraint_features: torch.Tensor, edge_indices: torch.Tensor, edge_features,
                 variable_features, candidates, candidate_choice):
        super().__init__()
        self.constraint_features = constraint_features
        self.edge_index = edge_indices
        self.edge_attr = edge_features
        self.variable_features = variable_features
        self.candidates = candidates
        # candidates may be None (e.g. when the object is built without real
        # data), in which case there is no count to record.
        self.nb_candidates = len(candidates) if candidates is not None else None
        self.candidate_choices = candidate_choice

    def __inc__(self, key, value, *args, **kwargs):
        """
        Return the offset added to attribute ``key`` when graphs are batched.

        ``edge_index`` gets a per-row offset: the first row is shifted by the
        number of constraint rows and the second by the number of variable
        rows. ``candidates`` is shifted by the number of variable rows. Every
        other key falls back to the parent implementation.
        """
        if key == 'edge_index':
            return torch.tensor([[self.constraint_features.size(0)], [self.variable_features.size(0)]])
        if key == 'candidates':
            return self.variable_features.size(0)
        # Forward any extra arguments supplied by the framework instead of
        # silently dropping them.
        return super().__inc__(key, value, *args, **kwargs)
class GraphDataset(Dataset):
    """
    Dataset serving recorded transitions from in-memory, index-aligned
    sequences.

    Parameters
    ----------
    state_graph : sequence
        Observed states, one entry per sample.
    log_probs : sequence
        Log-probabilities, one entry per sample.
    reward : sequence
        Rewards, one entry per sample.
    done : sequence
        Termination flags, one entry per sample; its length defines the size
        of the dataset.

    Notes
    -----
    ``action`` and ``next_state_graph`` are not part of the dataset at
    present.
    """

    def __init__(self, state_graph, log_probs, reward, done):
        super().__init__()
        self.state_graph, self.log_probs = state_graph, log_probs
        self.reward, self.done = reward, done

    def len(self):
        """
        Returns the number of samples in the dataset (length of ``done``).
        """
        return len(self.done)

    def get(self, index):
        """
        Returns the sample stored at position <index>.

        Parameters
        ----------
        index : int
            Position of the sample in each of the stored sequences.

        Returns
        -------
        tuple
            ``(state_graph[index], log_probs[index], reward[index], done[index])``.
        """
        columns = (self.state_graph, self.log_probs, self.reward, self.done)
        return tuple(column[index] for column in columns)