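"""Replay memory buffers: a flat transition buffer (ReplayMemory) and a
trajectory buffer that samples fixed-length windows (ReplayTrajMemory)."""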
from collections import deque, namedtuple
import numpy as np
import itertools

# This is the default buffer record namedtuple type.
Transition = namedtuple('Transition', ['s', 'a', 'r', 's_next', 'done'])


class ReplayMemory:
    def __init__(self, capacity=100000, replace=False, tuple_class=Transition):
        self.buffer = []
        self.capacity = capacity
        self.replace = replace  # whether sample() draws with replacement
        self.tuple_class = tuple_class
        self.fields = tuple_class._fields

    def add(self, record):
        """Add a single record (an instance of `tuple_class`) or a list of
        records to the buffer.
        """
        if isinstance(record, self.tuple_class):
            self.buffer.append(record)
        elif isinstance(record, list):
            self.buffer += record

        # Evict the oldest records once the buffer grows past capacity.
        while self.capacity and self.size > self.capacity:
            self.buffer.pop(0)

    def _reformat(self, indices):
        # Stack the records at `indices` (list<int>) into a dict of numpy
        # arrays, one entry per field.
        return {
            field_name: np.array([getattr(self.buffer[i], field_name) for i in indices])
            for field_name in self.fields
        }

    def sample(self, batch_size):
        # Sample a batch of records uniformly at random.
        assert len(self.buffer) >= batch_size
        idxs = np.random.choice(len(self.buffer), size=batch_size, replace=self.replace)
        return self._reformat(idxs)

    def pop(self, batch_size):
        # Pop the first `batch_size` Transition items out.
        i = min(self.size, batch_size)
        batch = self._reformat(range(i))
        self.buffer = self.buffer[i:]
        return batch

    def loop(self, batch_size, epoch=None):
        # Iterate over the buffer in order, cycling back to the start and
        # yielding full batches only (a trailing partial batch is dropped).
        # If `epoch` is given, stop after that many complete passes.
        indices = []
        ep = -1
        for i in itertools.cycle(range(len(self.buffer))):
            indices.append(i)
            if i == 0:
                ep += 1
            if epoch is not None and ep == epoch:
                break

            if len(indices) == batch_size:
                yield self._reformat(indices)
                indices = []

    @property
    def size(self):
        return len(self.buffer)
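

# Illustrative usage sketch (not part of the original module): populate a
# ReplayMemory with made-up toy transitions and draw one training batch.
# The 2-d states and scalar actions/rewards below are assumptions for the demo.
def _replay_memory_demo():
    memory = ReplayMemory(capacity=100)
    for t in range(10):
        memory.add(Transition(
            s=np.array([t, t + 1], dtype=np.float32),
            a=t % 3,
            r=float(t),
            s_next=np.array([t + 1, t + 2], dtype=np.float32),
            done=(t == 9),
        ))
    batch = memory.sample(4)
    # Each value is a numpy array batched on the first axis, e.g.
    # batch['s'].shape == (4, 2) and batch['a'].shape == (4,).
    print({k: v.shape for k, v in batch.items()})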


class ReplayTrajMemory:
    def __init__(self, capacity=100000, step_size=16):
        self.buffer = deque(maxlen=capacity)
        self.step_size = step_size  # length of the window sampled from each trajectory

    def add(self, traj):
        # traj (list<Transition>): keep only trajectories long enough to
        # provide at least one window of `step_size` consecutive steps.
        if len(traj) >= self.step_size:
            self.buffer.append(traj)

    def sample(self, batch_size):
        # Sample `batch_size` trajectories (with replacement) and slice a
        # random window of `step_size` consecutive transitions from each.
        traj_idxs = np.random.choice(len(self.buffer), size=batch_size, replace=True)
        batch_data = {field_name: [] for field_name in Transition._fields}

        for traj_idx in traj_idxs:
            i = np.random.randint(0, len(self.buffer[traj_idx]) + 1 - self.step_size)
            transitions = self.buffer[traj_idx][i: i + self.step_size]

            for field_name in Transition._fields:
                batch_data[field_name] += [getattr(t, field_name) for t in transitions]

        assert all(len(v) == batch_size * self.step_size for v in batch_data.values())
        return {k: np.array(v) for k, v in batch_data.items()}

    @property
    def size(self):
        return len(self.buffer)

    @property
    def transition_size(self):
        # Total number of transitions stored across all trajectories.
        return sum(map(len, self.buffer))
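

# Illustrative usage sketch (not part of the original module): store one
# made-up toy trajectory in a ReplayTrajMemory and sample fixed-length
# windows from it. The trajectory length (6) and step_size (4) are assumptions.
def _replay_traj_memory_demo():
    memory = ReplayTrajMemory(capacity=100, step_size=4)
    traj = [
        Transition(
            s=np.array([t, t + 1], dtype=np.float32),
            a=t % 2,
            r=float(t),
            s_next=np.array([t + 1, t + 2], dtype=np.float32),
            done=(t == 5),
        )
        for t in range(6)
    ]
    memory.add(traj)  # trajectories shorter than step_size would be dropped
    batch = memory.sample(batch_size=3)
    # Each value has batch_size * step_size rows, e.g. batch['s'].shape == (12, 2).
    print({k: v.shape for k, v in batch.items()})


if __name__ == '__main__':
    _replay_memory_demo()
    _replay_traj_memory_demo()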