"""Replay a trained policy from a saved checkpoint and roll it out.

The script reads ``params.json`` from the directory that contains the
checkpoint, restores the policy weights stored in ``checkpoint.pkl``, and runs
a number of rollouts in the evaluation environment.
"""

import argparse
from distutils.util import strtobool
import json
import os
import pickle

import tensorflow as tf

from softlearning.environments.utils import get_environment_from_params
from softlearning.policies.utils import get_policy_from_variant
from softlearning.samplers import rollouts

import pandas as pd


def parse_args():
    """Parse the command-line options of the script.

    Returns:
        argparse.Namespace with the attributes ``checkpoint_path``,
        ``max_path_length``, ``num_rollouts``, ``render_mode`` and
        ``deterministic``.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument(
        '--checkpoint_path',
        type=str,
        default="/home/kai/pc_scp/pc1/ray_mbpo/Valkyrie/defaults/seed:4650_2020-10-28_17-26-31tm52_dwv/checkpoint_1000",
        help='Path to the checkpoint.')
    parser.add_argument('--max-path-length', '-l', type=int, default=1000)
    parser.add_argument('--num-rollouts', '-n', type=int, default=10)
    parser.add_argument(
        '--render-mode', '-r',
        type=str,
        default='human',
        choices=('human', 'rgb_array', None),
        help="Mode to render the rollouts in.")
    # `nargs='?'` with `const=True` lets a bare `--deterministic` mean True,
    # while `--deterministic false` is converted through strtobool.
    parser.add_argument(
        '--deterministic', '-d',
        type=lambda x: bool(strtobool(x)),
        nargs='?',
        const=True,
        default=True,
        help="Evaluate policy deterministically.")

    args = parser.parse_args()

    return args


def simulate_policy(args):
    """Restore a policy from ``args.checkpoint_path`` and run rollouts.

    Args:
        args: Namespace as produced by ``parse_args``. Reads
            ``checkpoint_path``, ``deterministic``, ``num_rollouts``,
            ``max_path_length`` and ``render_mode``.

    Returns:
        The value returned by ``rollouts`` for the requested rollouts.
    """
    session = tf.keras.backend.get_session()

    # The experiment variant lives one level above the checkpoint directory.
    checkpoint_path = args.checkpoint_path.rstrip('/')
    experiment_path = os.path.dirname(checkpoint_path)

    variant_path = os.path.join(experiment_path, 'params.json')
    with open(variant_path, 'r') as f:
        variant = json.load(f)

    # NOTE(review): pickle.load executes arbitrary code from the file; only
    # point this script at checkpoints you trust.
    with session.as_default():
        pickle_path = os.path.join(checkpoint_path, 'checkpoint.pkl')
        with open(pickle_path, 'rb') as f:
            picklable = pickle.load(f)

    # Prefer the dedicated evaluation parameters; fall back to training ones.
    all_environment_params = variant['environment_params']
    if 'evaluation' in all_environment_params:
        environment_params = all_environment_params['evaluation']
    else:
        environment_params = all_environment_params['training']

    # Force rendering on, but merge into the kwargs recorded in params.json
    # instead of replacing them: overwriting the whole dict would silently
    # drop every other environment kwarg stored in the variant.
    environment_kwargs = dict(environment_params.get('kwargs') or {})
    environment_kwargs['renders'] = True
    environment_params['kwargs'] = environment_kwargs

    evaluation_environment = get_environment_from_params(environment_params)

    # Disabled experiment kept for reference: scale every body mass to 10%.
    # mb = evaluation_environment._env.env.model.body_mass
    # for idx, item in enumerate(mb):
    #     evaluation_environment._env.env.sim.model.body_mass[idx] = item*0.1

    # Also set the render flag directly on the wrapped environment.
    evaluation_environment._env.env._renders = True

    policy = get_policy_from_variant(
        variant, evaluation_environment, Qs=[None])
    policy.set_weights(picklable['policy_weights'])

    with policy.set_deterministic(args.deterministic):
        paths = rollouts(args.num_rollouts,
                         evaluation_environment,
                         policy,
                         path_length=args.max_path_length,
                         render_mode=args.render_mode)

    if args.render_mode != 'human':
        # Nothing is shown on screen in this mode, so stop in the debugger to
        # let the user inspect `paths`; pprint is imported for that session.
        from pprint import pprint
        import pdb
        pdb.set_trace()

    return paths


if __name__ == '__main__':
    args = parse_args()
    simulate_policy(args)