# Copyright (c) 2023 Qualcomm Technologies, Inc.
# All Rights Reserved.
import random
from collections import deque, namedtuple
from typing import List
# Immutable record for one observed transition; fields can be read by name
# or by position, in the order listed in the field string below.
Transition = namedtuple(
    "Transition",
    "state action next_state reward old_log_probs gamma",
)
class Replay:
    """
    Stores the transitions observed during training.

    The transitions live in a bounded ``deque``: once ``capacity`` entries
    are stored, appending a new one silently discards the oldest.
    """

    def __init__(self, capacity: float) -> None:
        """
        Create an empty replay memory.

        Args:
            capacity: Maximum number of stored transitions. Integral floats
                (e.g. ``1e4``) are accepted and converted to ``int``;
                ``None`` yields an unbounded memory.

        Raises:
            TypeError: If ``capacity`` is a non-integral float or another
                value ``deque`` cannot interpret as an integer.
            ValueError: If ``capacity`` is negative.
        """
        # deque(maxlen=...) only accepts int or None, so an integral float
        # such as 1e4 has to be converted first. Anything else is passed
        # through unchanged so deque raises its usual errors.
        if isinstance(capacity, float) and capacity.is_integer():
            capacity = int(capacity)
        self.capacity = capacity
        self.memory = deque([], maxlen=capacity)

    def push(self, *args):
        """Build a ``Transition`` from ``args`` and append it to the memory."""
        self.memory.append(Transition(*args))

    def pop(self) -> "Transition":
        """
        Remove and return the most recently stored transition.

        Raises:
            IndexError: If the memory is empty.
        """
        return self.memory.pop()

    def sample(self, batch_size: int) -> List["Transition"]:
        """
        Return transitions drawn at random without replacement.

        If ``batch_size`` exceeds the number of stored transitions, every
        stored transition is returned (in random order) instead.
        """
        sample_size = min(batch_size, len(self.memory))
        return random.sample(self.memory, sample_size)

    def clean(self) -> None:
        """Discard all stored transitions, keeping the same capacity."""
        # Rebind to a fresh deque rather than clearing in place, so any
        # outside reference to the old deque keeps its contents.
        self.memory = deque([], maxlen=self.capacity)

    def __len__(self) -> int:
        """Return the number of transitions currently stored."""
        return len(self.memory)