# EECE571F-project / ml4co-competition / submissions / dual / training / utils.py
import numpy as np

import random
from collections import deque, namedtuple
from typing import List
import torch
import torch_geometric
from torch_geometric.data import Data
from torch_geometric.data import Dataset

Transition = namedtuple("Transition", ("state", "log_prob", "reward", "done"))


class ReplayBuffer:
    """
    Fixed-capacity FIFO store of the transitions observed during training.

    Once ``capacity`` transitions are held, pushing a new one silently
    discards the oldest (``collections.deque`` ``maxlen`` semantics).

    Parameters
    ----------
    capacity : int
        Maximum number of transitions kept.
    """

    def __init__(self, capacity: int) -> None:
        # deque's maxlen must be an int; passing a float raises TypeError.
        self.capacity = capacity
        self.memory = deque([], maxlen=capacity)

    def push(self, *args) -> None:
        """Append a ``Transition(state, log_prob, reward, done)`` built from ``args``."""
        self.memory.append(Transition(*args))

    def pop(self) -> Transition:
        """
        Remove and return the most recently pushed transition.

        Raises
        ------
        IndexError
            If the buffer is empty.
        """
        return self.memory.pop()

    def sample(self, batch_size: int) -> List[Transition]:
        """
        Return up to ``batch_size`` distinct transitions drawn uniformly at random.

        If fewer than ``batch_size`` transitions are stored, all of them are
        returned in random order.
        """
        k = min(batch_size, len(self.memory))
        # Snapshot into a list: deque indexing is O(n) away from the ends and
        # random.sample indexes once per drawn element. Picks are identical.
        return random.sample(list(self.memory), k)

    def clean(self) -> None:
        """Discard every stored transition, keeping the configured capacity."""
        self.memory = deque([], maxlen=self.capacity)

    def __len__(self) -> int:
        """Return the number of transitions currently stored."""
        return len(self.memory)
    
class BipartiteNodeData(Data):
    """
    Data class modelling a single bipartite constraint/variable graph.

    Parameters
    ----------
    constraint_features : torch.float32
        Stored as ``self.constraint_features``.
    edge_indices : torch.int64
        Stored as ``self.edge_index``. Row 0 indexes constraints and row 1
        indexes variables (see the offsets applied in ``__inc__``).
    edge_features : torch.float32
        Stored as ``self.edge_attr``.
    variable_features : torch.float32
        Stored as ``self.variable_features``.
    candidates : torch.int64
        Stored as ``self.candidates`` (variable indices); its length is kept
        in ``self.nb_candidates`` (``None`` when no candidates are given).
    candidate_choice : torch.int64
        Stored as ``self.candidate_choices``.

    Every parameter defaults to ``None`` because PyTorch Geometric
    instantiates data classes without arguments in some code paths, e.g.
    when splitting a ``Batch`` back into individual examples.
    """
    def __init__(self, constraint_features: torch.Tensor = None, edge_indices: torch.Tensor = None,
                 edge_features=None, variable_features=None, candidates=None, candidate_choice=None):
        super().__init__()
        self.constraint_features = constraint_features
        self.edge_index = edge_indices
        self.edge_attr = edge_features
        self.variable_features = variable_features
        self.candidates = candidates
        self.nb_candidates = len(candidates) if candidates is not None else None
        self.candidate_choices = candidate_choice

    def __inc__(self, key, value, *args, **kwargs):
        """Return the offset added to attribute ``key`` when graphs are batched."""
        if key == 'edge_index':
            # Two node sets: shift constraint indices (row 0) by the number of
            # constraints and variable indices (row 1) by the number of variables.
            return torch.tensor([[self.constraint_features.size(0)], [self.variable_features.size(0)]])
        if key == 'candidates':
            # Candidates index into the variable node set.
            return self.variable_features.size(0)
        # Forward any extra arguments PyG passes so the default logic sees them.
        return super().__inc__(key, value, *args, **kwargs)


class GraphDataset(Dataset):
    """
    In-memory dataset of rollout samples; no files are read.

    Item ``i`` is the tuple
    ``(state_graph[i], log_probs[i], reward[i], done[i])``. The four sequences
    are indexed at the same position, so they are expected to be aligned and
    at least as long as ``done`` (whose length defines the dataset length).

    Parameters
    ----------
    state_graph : sequence
        Observed states, presumably ``BipartiteNodeData`` graphs (TODO: confirm
        against the caller).
    log_probs : sequence
        Per-sample values returned as the second tuple element, presumably the
        policy log-probabilities of the chosen actions.
    reward : sequence
        Per-sample rewards.
    done : sequence
        Per-sample episode-termination flags.
    """
    def __init__(
            self,
            state_graph,
            log_probs,
            reward,
            done,
        ):
        super().__init__()
        self.state_graph = state_graph
        self.log_probs = log_probs
        self.reward = reward
        self.done = done

    def len(self):
        """
        Returns the number of samples in the dataset (the length of ``done``).
        """
        return len(self.done)

    def get(self, index):
        """
        Return the sample stored at position ``index``.

        Parameters
        ----------
        index : int
            Position of the sample in the dataset.

        Returns
        -------
        tuple
            ``(state_graph[index], log_probs[index], reward[index], done[index])``.
        """
        return (
            self.state_graph[index],
            self.log_probs[index],
            self.reward[index],
            self.done[index],
        )