# Source: EECE571F-project / RL_ecole / replay.py
# Copyright (c) 2023 Qualcomm Technologies, Inc.
# All Rights Reserved.

import random
from collections import deque, namedtuple
from typing import List

# Record type stored in the replay buffer; values are supplied positionally
# in exactly this field order (see ``Replay.push``).
Transition = namedtuple(
    "Transition", ("state", "action", "next_state", "reward", "old_log_probs", "gamma")
)


class Replay:
    """
    Stores the transitions observed during training.

    The buffer is a ``collections.deque`` created with ``maxlen=capacity``:
    once it holds ``capacity`` items, every further ``push`` silently drops
    the oldest stored transition.
    """

    def __init__(self, capacity: float) -> None:
        """
        Create an empty buffer.

        Args:
            capacity: Maximum number of transitions kept. An ``int`` is used
                as is; a whole-number ``float`` such as ``1e5`` is converted
                to ``int``; ``None`` gives an unbounded buffer.

        Raises:
            ValueError: If ``capacity`` is a float that is not a whole number
                (this includes ``inf`` and ``nan``), or is negative.
        """
        # deque(maxlen=...) rejects floats outright, so a call like
        # Replay(1e5) would fail with TypeError; normalise whole floats here.
        if isinstance(capacity, float):
            if not capacity.is_integer():
                raise ValueError(
                    f"capacity must be a whole number, got {capacity!r}"
                )
            capacity = int(capacity)
        self.capacity = capacity
        self.memory = deque(maxlen=capacity)

    def push(self, *args):
        """Append ``Transition(*args)`` as the newest entry of the buffer."""
        self.memory.append(Transition(*args))

    def pop(self) -> Transition:
        """
        Remove and return the most recently pushed transition.

        Raises:
            IndexError: If the buffer is empty.
        """
        return self.memory.pop()

    def sample(self, batch_size: int) -> List[Transition]:
        """
        Return transitions drawn at random without replacement.

        At most ``batch_size`` transitions are returned; when fewer are
        stored, all of them are returned in random order. The buffer itself
        is left unchanged.

        Raises:
            ValueError: If ``batch_size`` is negative.
        """
        # Clamp the request so random.sample never sees k > population size.
        count = min(batch_size, len(self.memory))
        return random.sample(self.memory, count)

    def clean(self) -> None:
        """Discard all stored transitions, keeping the same capacity."""
        # Rebind to a fresh deque instead of clearing in place, so references
        # to the previous deque held elsewhere keep their contents.
        self.memory = deque(maxlen=self.capacity)

    def __len__(self) -> int:
        """Return the number of transitions currently stored."""
        return len(self.memory)