import random
from collections import deque, namedtuple
from typing import List, Optional

import numpy as np
import torch
import torch_geometric
from torch_geometric.data import Data
from torch_geometric.data import Dataset

# One step of experience as stored by ``ReplayBuffer`` (fields in push order).
Transition = namedtuple("Transition", ("state", "log_prob", "reward", "done"))


class ReplayBuffer:
    """
    Stores the transitions observed during training.

    Parameters
    ----------
    capacity : int or None
        Maximum number of transitions kept. Once the buffer is full, pushing a
        new transition discards the oldest one. ``None`` means unbounded.
        Float values (e.g. ``1e5``) are converted with ``int()``, which
        truncates any fractional part.
    """

    def __init__(self, capacity: Optional[float]) -> None:
        # ``deque`` only accepts an int (or None) for ``maxlen``; a float such
        # as ``1e5`` would otherwise raise TypeError.
        self.capacity = None if capacity is None else int(capacity)
        self.memory = deque([], maxlen=self.capacity)

    def push(self, *args) -> None:
        """Append a transition built from ``(state, log_prob, reward, done)``."""
        self.memory.append(Transition(*args))

    def pop(self) -> Transition:
        """Remove and return the most recently pushed transition.

        Raises
        ------
        IndexError
            If the buffer is empty.
        """
        return self.memory.pop()

    def sample(self, batch_size: int) -> List[Transition]:
        """Return up to ``batch_size`` distinct transitions in random order.

        If fewer than ``batch_size`` transitions are stored, all of them are
        returned (shuffled).
        """
        return random.sample(self.memory, min(batch_size, len(self.memory)))

    def clean(self) -> None:
        """Discard every stored transition, keeping the same capacity."""
        # Rebind rather than ``clear()`` so any outside reference to the old
        # deque is left untouched (matches the original behavior).
        self.memory = deque([], maxlen=self.capacity)

    def __len__(self) -> int:
        return len(self.memory)


class BipartiteNodeData(Data):
    """
    Data class modelling a single bipartite graph between constraint nodes
    and variable nodes.

    All parameters default to ``None`` so that PyG can build empty instances
    of this class (the pattern PyG documents for custom ``Data`` subclasses).

    Parameters
    ----------
    constraint_features : torch.float32
        Features of the constraint nodes.
    edge_indices : torch.int64
        Stored as ``edge_index``. Row 0 indexes constraint nodes and row 1
        indexes variable nodes (see ``__inc__``).
    edge_features : torch.float32
        Stored as ``edge_attr``.
    variable_features : torch.float32
        Features of the variable nodes.
    candidates : torch.int64
        Indices into the variable nodes. ``nb_candidates`` is set to their
        count, or ``None`` when ``candidates`` is ``None``.
    candidate_choice : torch.int64
        Stored as ``candidate_choices``.
    """

    def __init__(
        self,
        constraint_features: Optional[torch.Tensor] = None,
        edge_indices: Optional[torch.Tensor] = None,
        edge_features: Optional[torch.Tensor] = None,
        variable_features: Optional[torch.Tensor] = None,
        candidates: Optional[torch.Tensor] = None,
        candidate_choice: Optional[torch.Tensor] = None,
    ) -> None:
        super().__init__()
        self.constraint_features = constraint_features
        self.edge_index = edge_indices
        self.edge_attr = edge_features
        self.variable_features = variable_features
        self.candidates = candidates
        self.nb_candidates = len(candidates) if candidates is not None else None
        self.candidate_choices = candidate_choice

    def __inc__(self, key, value, *args, **kwargs):
        """Return the offset added to ``key`` when graphs are batched together."""
        if key == 'edge_index':
            # The two rows index different node sets, so each row is shifted
            # by the size of its own set (constraints, then variables).
            return torch.tensor(
                [[self.constraint_features.size(0)], [self.variable_features.size(0)]]
            )
        if key == 'candidates':
            # Candidates are variable-node indices.
            return self.variable_features.size(0)
        # Forward the extra arguments PyG passes instead of dropping them.
        return super().__inc__(key, value, *args, **kwargs)


class GraphDataset(Dataset):
    """
    Dataset wrapping pre-collected rollout data.

    Item ``i`` is the tuple ``(state_graph[i], log_probs[i], reward[i],
    done[i])``.

    Parameters
    ----------
    state_graph : sequence
        Observed states; presumably ``BipartiteNodeData`` instances (confirm
        with callers).
    log_probs : sequence
        Log-probabilities associated with each state.
    reward : sequence
        Rewards associated with each state.
    done : sequence
        Episode-termination flags. Its length defines the dataset length.
        NOTE(review): the other sequences are assumed to be index-aligned
        with ``done``; this is not checked.
    """

    def __init__(self, state_graph, log_probs, reward, done):
        super().__init__()
        self.state_graph = state_graph
        self.log_probs = log_probs
        self.reward = reward
        self.done = done

    def len(self):
        """Return the number of samples in the dataset."""
        return len(self.done)

    def get(self, index):
        """
        Return the sample stored at position ``index``.

        Parameters
        ----------
        index : int
            Position of the sample in the stored sequences.

        Returns
        -------
        tuple
            ``(observation, log_prob, reward, done)`` taken from the four
            sequences passed to the constructor.
        """
        return (
            self.state_graph[index],
            self.log_probs[index],
            self.reward[index],
            self.done[index],
        )