nmi-val / deepq / playground / utils / misc.py
misc.py
Raw
import matplotlib; matplotlib.use('Agg')
import matplotlib.pyplot as plt
import numpy as np
import os
import pandas as pd

from gym.wrappers.monitor import load_results
from copy import deepcopy

REPO_ROOT = os.path.abspath(os.path.join(os.path.dirname(__file__), '..', '..'))


class Config:
    def __init__(self, **kwargs):
        # read parameters from parents, and children can override the values.
        parents = []
        queue = [self.__class__]
        while queue:
            parent = queue.pop()
            if issubclass(parent, Config) and parent is not Config:
                parents.append(parent)
                for p in reversed(parent.__bases__):
                    queue.append(p)

        params = {}
        for cfg in reversed(parents):
            params.update(cfg.__dict__)

        # Set all instance variable based on kwargs and default class variables
        for key, value in params.items():
            if key.startswith('__'):
                continue

            if key in kwargs:
                # override default with provided parameter
                value = kwargs[key]
            else:
                # Need to make copies of class variables so that they aren't changed by instances
                value = deepcopy(value)

            self.__dict__[key] = value

    def __setattr__(self, name, value):
        if name not in self.__dict__:
            raise AttributeError(f"{self.__class__.__name__} does not have attribute {name}")
        self.__dict__[name] = value

    def __getattr__(self, name):
        # Raise error on assignment of missing variable
        if name not in self.__dict__:
            raise AttributeError(f"{self.__class__.__name__} does not have attribute {name}")
        return self.__dict__[name]

    def as_dict(self):
        return deepcopy(self.__dict__)

    def copy(self):
        return self.__class__(**self.as_dict())

    def get(self, name, default):
        return self.as_dict().get(name, default)

    def __repr__(self):
        return super().__repr__() + "\n" + self.dumps()


def plot_learning_curve(filename, value_dict, xlabel='step'):
    # Plot step vs the mean(last 50 episodes' rewards)
    fig = plt.figure(figsize=(12, 4 * len(value_dict)))

    for i, (key, values) in enumerate(value_dict.items()):
        ax = fig.add_subplot(len(value_dict), 1, i + 1)
        ax.plot(range(len(values)), values)
        ax.set_xlabel(xlabel)
        ax.set_ylabel(key)
        ax.grid('k--', alpha=0.6)

    plt.tight_layout()
    os.makedirs(os.path.join(REPO_ROOT, 'figs'), exist_ok=True)
    plt.savefig(os.path.join(REPO_ROOT, 'figs', filename))


def plot_from_monitor_results(monitor_dir, window=10):
    assert os.path.exists(monitor_dir)
    if monitor_dir.endswith('/'):
        monitor_dir = monitor_dir[:-1]

    data = load_results(monitor_dir)
    n_episodes = len(data['episode_lengths'])
    assert n_episodes > 0

    fig, (ax1, ax2) = plt.subplots(2, 1, figsize=(12, 8), tight_layout=True, sharex=True)

    ax1.plot(range(n_episodes), pd.rolling_mean(np.array(data['episode_lengths']), window))
    ax1.set_xlabel('episode')
    ax1.set_ylabel('episode length')
    ax1.grid('k--', alpha=0.6)

    ax2.plot(range(n_episodes), pd.rolling_mean(np.array(data['episode_rewards']), window))
    ax2.set_xlabel('episode')
    ax2.set_ylabel('episode reward')
    ax2.grid('k--', alpha=0.6)

    os.makedirs(os.path.join(REPO_ROOT, 'figs'), exist_ok=True)
    plt.savefig(os.path.join(REPO_ROOT, 'figs', os.path.basename(monitor_dir) + '-monitor'))