# Copyright (c) 2023 Qualcomm Technologies, Inc.
# All Rights Reserved.
# Source: https://github.com/Qualcomm-AI-research/neural-simulated-annealing/blob/main/neuralsa/training/replay.py
import random
from collections import deque, namedtuple
from typing import List

Transition = namedtuple("Transition", ("state", "action", "reward", "action_set", "done"))


class Replay:
"""
Stores the transitions observed during training.
"""

    def __init__(self, capacity: int) -> None:
        self.capacity = capacity
        self.memory = deque([], maxlen=capacity)

    def push(self, *args) -> None:
        """Save a transition; once full, the deque evicts its oldest entry."""
        if len(self.memory) == self.capacity:
            print("Replay buffer is full. Oldest entry will be overwritten.")
        self.memory.append(Transition(*args))

    def pop(self) -> Transition:
        """Remove and return the most recently added transition."""
        return self.memory.pop()

    def sample(self, batch_size: int) -> List[Transition]:
        """Draw up to batch_size transitions uniformly at random, without replacement."""
        if batch_size > len(self.memory):
            return random.sample(self.memory, len(self.memory))
        return random.sample(self.memory, batch_size)

    def clean(self) -> None:
        """Empty the buffer by re-creating the underlying deque."""
        self.memory = deque([], maxlen=self.capacity)

def __len__(self) -> int:
return len(self.memory)
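

# A minimal usage sketch, not part of the original module: it illustrates the
# push/sample cycle and the deque-based eviction of the oldest transition once
# capacity is reached. The field values below (strings, floats, None) are
# placeholders; in training they would be environment states, actions, rewards,
# action sets, and done flags.
if __name__ == "__main__":
    buffer = Replay(capacity=4)
    for step in range(6):
        # Arguments follow the Transition field order: state, action, reward, action_set, done.
        buffer.push(f"s{step}", f"a{step}", float(step), None, step == 5)

    print(len(buffer))  # 4: the two oldest transitions were evicted.

    # sample() returns a list of Transition namedtuples; a common pattern is to
    # transpose them into a single Transition whose fields are batched tuples.
    batch = buffer.sample(batch_size=2)
    batched = Transition(*zip(*batch))
    print(batched.state, batched.reward)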