# EECE571F-project / learn2branch_gasse / replay.py
# Copyright (c) 2023 Qualcomm Technologies, Inc.
# All Rights Reserved.

# Source: https://github.com/Qualcomm-AI-research/neural-simulated-annealing/blob/main/neuralsa/training/replay.py

import random
from collections import deque, namedtuple
from typing import List

# One stored step of experience. Fields, in positional order:
#   state, action, reward, action_set, done
# NOTE(review): field meanings follow their names (e.g. action_set presumably
# holds the candidate actions available at `state`) — confirm with producers.
Transition = namedtuple("Transition", "state action reward action_set done")


class Replay:
    """
    Bounded buffer of ``Transition`` records observed during training.

    Storage is a ``deque`` whose ``maxlen`` is ``capacity``: once the buffer
    is full, each new ``push`` discards the oldest stored transition (a
    notice is printed to stdout when that is about to happen).
    """

    def __init__(self, capacity: int) -> None:
        """
        Create an empty buffer.

        Args:
            capacity: maximum number of transitions kept. Must be an ``int``;
                ``deque`` raises ``TypeError`` for a float ``maxlen`` and
                ``ValueError`` for a negative one.
        """
        self.capacity = capacity
        self.memory = deque([], maxlen=capacity)

    def push(self, *args, **kwargs) -> None:
        """
        Append one transition built as ``Transition(*args, **kwargs)``.

        Fields may be given positionally in ``Transition`` field order or by
        keyword (e.g. ``push(state=s, action=a, reward=r, action_set=A,
        done=d)``). If the buffer is already at capacity, the bounded deque
        drops its oldest entry to make room.
        """
        if len(self.memory) == self.capacity:
            print("Replay buffer is full. Oldest entry will be overwritten.")
        self.memory.append(Transition(*args, **kwargs))

    def pop(self) -> Transition:
        """
        Remove and return the most recently pushed transition.

        Raises:
            IndexError: if the buffer is empty (propagated from ``deque.pop``).
        """
        return self.memory.pop()

    def sample(self, batch_size: int) -> List[Transition]:
        """
        Return up to ``batch_size`` distinct stored transitions, chosen at random.

        Sampling is without replacement and does not modify the buffer. If
        fewer than ``batch_size`` transitions are stored, all of them are
        returned in random order.

        Raises:
            ValueError: if ``batch_size`` is negative (from ``random.sample``).
        """
        # Clamp so a large batch request returns the whole buffer rather
        # than raising from random.sample.
        return random.sample(self.memory, min(batch_size, len(self.memory)))

    def clean(self) -> None:
        """Discard every stored transition, keeping the configured capacity."""
        # A fresh deque (rather than .clear()) picks up the current
        # self.capacity value.
        self.memory = deque([], maxlen=self.capacity)

    def __len__(self) -> int:
        """Return the number of transitions currently stored."""
        return len(self.memory)