# Copyright (c) 2023 Qualcomm Technologies, Inc.
# All Rights Reserved.
# Source: https://github.com/Qualcomm-AI-research/neural-simulated-annealing/blob/main/neuralsa/training/replay.py
import random
from collections import deque, namedtuple
from typing import List

# A single environment step: the state before acting, the action taken, the
# reward received, the set of available actions, and the termination flag.
Transition = namedtuple("Transition", ("state", "action", "reward", "action_set", "done"))


class Replay:
    """Fixed-capacity replay buffer for transitions observed during training.

    Backed by a ``deque`` with ``maxlen=capacity``, so once the buffer is
    full every new ``push`` drops the oldest stored transition.
    """

    def __init__(self, capacity: int) -> None:
        """Create an empty buffer holding at most ``capacity`` transitions.

        Args:
            capacity: maximum number of transitions to retain.
                (Annotated ``float`` upstream, but ``deque(maxlen=...)``
                requires an integer, so ``int`` is the actual contract.)
        """
        self.capacity = capacity
        self.memory = deque([], maxlen=capacity)

    def push(self, *args) -> None:
        """Store ``Transition(*args)``, warning when the buffer is full."""
        if len(self.memory) == self.capacity:
            # deque(maxlen=...) evicts the oldest entry automatically; this
            # message only makes the eviction visible to the user.
            print("Replay buffer is full. Oldest entry will be overwritten.")
        self.memory.append(Transition(*args))

    def pop(self) -> Transition:
        """Remove and return the most recently pushed transition."""
        return self.memory.pop()

    def sample(self, batch_size: int) -> List[Transition]:
        """Return ``batch_size`` transitions sampled without replacement.

        If fewer than ``batch_size`` transitions are stored, the entire
        buffer is returned in random order instead of raising ``ValueError``.
        """
        return random.sample(self.memory, min(batch_size, len(self.memory)))

    def clean(self) -> None:
        """Discard all stored transitions, keeping the same capacity."""
        # clear() preserves maxlen, so no need to rebuild the deque.
        self.memory.clear()

    def __len__(self) -> int:
        """Number of transitions currently stored."""
        return len(self.memory)