"""PPO training of a GRU policy/value network that generates RF pulses.

The network emits one (magnitude, phase) pair per step; a full episode of
``args.el`` steps is interpolated and mirrored into a pulse, scored by the gym
environment ``args.env``, and the terminal reward is used for a clipped
surrogate (PPO) policy loss plus a squared-error value loss.
"""
from os import path
import sys
sys.path.append(path.join(path.dirname(__file__), '..'))  # must precede project-local imports
import os
import numpy as np
import tensorflow as tf
import time
import argparse
from ast import literal_eval as make_tuple
import random
import gym
import torch
import torch.nn as nn
import envs
import matplotlib.pyplot as plt
from utils.logger import Logger
from utils.summary import EvaluationMetrics
from scipy.io import loadmat

# %% arguments
parser = argparse.ArgumentParser()
parser.add_argument('--gpu', type=int, default=0, help='activated GPU number')
parser.add_argument('--lr', type=float, default=1e-5, help='learning rate')
parser.add_argument('--el', type=int, default=32, help='episode length')
parser.add_argument('--amp', type=float, default=1e-3, help='amplitude scaling')  # 1e-3
parser.add_argument('--ph', type=float, default=1e+1, help='phase scaling')  # 1e+1
parser.add_argument('--hss', type=int, default=256, help='length of hidden state in GRU')
parser.add_argument('--batch', type=int, default=256, help='batch size (# episodes)')
parser.add_argument('--mb', type=int, default=2048, help='mini batch size (for 1 epoch)')
parser.add_argument('--v_hs', type=str, default='(256,128,64,32)', help='network structure')
parser.add_argument('--gamma', type=float, default=1.0, help='discount factor')
parser.add_argument('--lmbda', type=float, default=0.95, help='lambda for GAE')
parser.add_argument('--eps', type=float, default=0.1, help='args.eps for PPO')
parser.add_argument('--epochs', type=int, default=4, help='number of epochs for gradient-descent')
parser.add_argument('--max', type=int, default=300, help='maximum number of iterations')
parser.add_argument('--kl', type=float, default=0.01, help='target KL value for early stopping')
parser.add_argument('--du', type=float, default=2.56, help='duration of pulse in ms')
parser.add_argument('--w_v', type=float, default=1.0, help='value loss weight')
parser.add_argument('--amp_std', type=float, default=0.03, help='fixed amplitude standard deviation')
parser.add_argument('--ph_std', type=float, default=0.05, help='fixed phase standard deviation')
parser.add_argument('--seed', type=int, default=1003, help='random seed')
parser.add_argument('--grad', type=float, default=1000.0, help='l2-norm for gradient clipping')
parser.add_argument('--save', type=int, default=300, help='save period in iterations')
parser.add_argument("--tag", type=str, default='ppo_rnn_exc_21')
parser.add_argument("--log_level", type=int, default=10)
parser.add_argument("--debug", "-d", action="store_true")
parser.add_argument("--quiet", "-q", action="store_true")
parser.add_argument("--sampling_rate", type=int, default=256)
parser.add_argument("--log_step", type=int, default=10)
parser.add_argument("--env", type=str, default="Exc-v11")
args = parser.parse_args()

# %% preparation step
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
os.environ["CUDA_VISIBLE_DEVICES"] = str(args.gpu)

# constants
output_size = 2  # (magnitude, phase)
EPS = 1e-8

# parsing arguments
ts = float(args.du) * 1e-3 / (float(args.sampling_rate))  # duration (ms -> s) per sample
max_rad = 2 * np.pi * 42.577 * 1e+6 * 0.2 * 1e-4 * ts

# random seed
tf.reset_default_graph()
tf.random.set_random_seed(args.seed)
np.random.seed(args.seed)
random.seed(args.seed)

# load reference pulse
device = 'cuda' if torch.cuda.is_available() else 'cpu'
# NOTE: this path is resolved against the current working directory.
preset = loadmat('../data/conv_rf/SLR_exc.mat')
ref_pulse = torch.unsqueeze(torch.from_numpy(np.array(preset['result'], dtype=np.float32)), dim=0).to(device)


# %% function definitions
def mlp(x, hidden_sizes=(32,), activation=tf.tanh, output_activation=None):
    """Stack dense layers on top of ``x``.

    Every entry of ``hidden_sizes`` except the last becomes a dense layer with
    ``activation``; the last entry becomes the output layer and uses
    ``output_activation`` instead.

    Returns the output tensor of the final dense layer.
    """
    for h in hidden_sizes[:-1]:
        x = tf.layers.dense(x, units=h, activation=activation)
    return tf.layers.dense(x, units=hidden_sizes[-1], activation=output_activation)


# %% neural network definition
rnn_in = tf.placeholder(tf.float32, [None, None, output_size])
state_in = tf.placeholder(tf.float32, [None, args.hss])  # hidden state in/out size
adv = tf.placeholder(tf.float32, [None, ])  # advantages
prob = tf.placeholder(tf.float32, [None, ])  # history of probabilities
ret_in = tf.placeholder(tf.float32, [None, ])  # returns for training

cell = tf.contrib.rnn.GRUCell(args.hss, reuse=tf.AUTO_REUSE)
state_out, _ = tf.nn.dynamic_rnn(cell, rnn_in, initial_state=state_in, dtype=tf.float32, time_major=True)
# Time-major outputs are flattened row-wise, so every per-sample tensor below
# (mean, p_pi, val_out) is ordered step 0 for all episodes, then step 1, ...
state_out = tf.reshape(state_out, [-1, args.hss])
out = mlp(state_out, hidden_sizes=make_tuple(str(args.v_hs)), activation=tf.nn.relu, output_activation=tf.nn.relu)
amp_policy_out = tf.exp(tf.layers.dense(out, units=1))  # exp keeps the amplitude mean positive
ph_policy_out = tf.layers.dense(out, units=1)
mean = tf.concat([amp_policy_out, ph_policy_out], 1)
val_out = tf.layers.dense(out, units=1)

# Despite the name, this is a fixed (non-trainable) standard deviation, not its log.
log_std = tf.multiply([max_rad * (1 / args.amp), 2 * np.pi * (1 / args.ph)], [args.amp_std, args.ph_std])
std = log_std
r = tf.random_normal(tf.shape(mean))
pi = mean + tf.multiply(r, std)
pi_clip = tf.clip_by_value(pi, [EPS, -np.inf], [max_rad * (1 / args.amp), np.inf])

# Gaussian density of the (clipped) sample under N(mean, std), multiplied over
# the two action components.
# NOTE(review): in the update step p_pi is evaluated on a freshly sampled
# pi_clip rather than on the actions stored during the rollout; wherever the
# sample is not clipped, (pi_clip - mean) == r * std and carries no gradient
# w.r.t. mean. Confirm this is intended.
middle1 = tf.exp(tf.multiply(tf.square(pi_clip - mean),
                             tf.divide(-tf.ones(tf.shape(std)), tf.scalar_mul(2.0, tf.square(std)))))
p_pi_a = tf.reduce_prod(
    tf.multiply(middle1, tf.scalar_mul(1 / np.sqrt(np.pi * 2.0), tf.divide(tf.ones(tf.shape(std)), std))),
    axis=1)
p_pi = p_pi_a + EPS  # EPS keeps the log below finite

ratio = tf.exp(tf.log(p_pi) - tf.log(prob))  # ratio of pi_theta and pi_theta_old
approx_kl = -tf.reduce_mean(tf.log(p_pi) - tf.log(prob))
surr1 = tf.multiply(ratio, adv)
surr2 = tf.multiply(tf.clip_by_value(ratio, 1 - args.eps, 1 + args.eps), adv)
policy_loss = -tf.reduce_mean(tf.minimum(surr1, surr2))
value_loss = tf.reduce_mean(tf.square(tf.squeeze(val_out) - ret_in))
loss = policy_loss + args.w_v * value_loss

# for gradient clipping
params = tf.trainable_variables()
trainer = tf.train.AdamOptimizer(args.lr)
grads_and_var = trainer.compute_gradients(loss, params)
grads, var = zip(*grads_and_var)
grads, grad_norm = tf.clip_by_global_norm(grads, args.grad)
grads_and_var = list(zip(grads, var))
train_opt = trainer.apply_gradients(grads_and_var)

# saver
saver = tf.train.Saver(max_to_keep=100000)

# %% start training
start_t = time.time()
best_SAR = 1e+10
best_RF = 0
best_ind = 0

logger = Logger('AINV', args)
env = gym.make(args.env)

# Create summary statistics
info = EvaluationMetrics([
    'Rew/Mean',
    'SAR/Mean',
    'Mz1/Mean',
    'Mz2/Mean',
    'Rew/Best',
    'SAR/Best',
    'Mz1/Best',
    'Mz2/Best',
    'Time',
])

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    # Row 0 is an uninitialised placeholder; it is dropped (rf_list[1:]) on save.
    rf_list = np.empty((1, args.el, 2))
    rew_list = np.array([])

    for it in range(args.max):
        it_start_t = time.time()
        prob_list = np.array([])
        rnn_in_ = np.ones((1, args.batch, output_size))
        state_in_ = np.zeros((args.batch, args.hss))

        # ---- rollout: feed each sampled action back in as the next input ----
        for ep in range(args.el):
            if ep == 0:
                input_list = np.copy(rnn_in_)
            else:
                input_list = np.vstack((input_list, rnn_in_))
            pi__, state_out_, p_pi__, value_out_ = sess.run(
                [pi_clip, state_out, p_pi, val_out],
                feed_dict={rnn_in: rnn_in_, state_in: state_in_})
            if ep == 0:
                val_list = value_out_
            else:
                val_list = np.hstack((val_list, value_out_))
            rnn_in_ = np.expand_dims(pi__, axis=0)  # output to input
            state_in_ = state_out_  # hidden state n-1 to n
            prob_list = np.append(prob_list, p_pi__)  # save action probability

        # Append the final action: input_list now holds args.el + 1 steps, so
        # [:-1] are the network inputs and [-args.el:] are the sampled actions.
        input_list = np.vstack((input_list, rnn_in_))

        # ---- convert actions to normalised (magnitude, phase) per episode ----
        input_list_tmp = np.swapaxes(np.copy(input_list[-args.el:, :, :]), 0, 1)
        input_list_tmp[:, :, 0] = input_list_tmp[:, :, 0] * args.amp
        input_list_tmp[:, :, 1] = input_list_tmp[:, :, 1] * args.ph
        input_list_tmp[:, :, 0] = 2 * (input_list_tmp[:, :, 0] / max_rad) - 1  # -1 ~ 1
        # input_list_tmp[:, :, 1] = (input_list_tmp[:, :, 1] % (2 * np.pi) - np.pi) / np.pi  # -1 ~ 1
        # Phase is expressed relative to each episode's first step, in units of pi.
        input_list_tmp[:, :, 1] = (input_list_tmp[:, :, 1] - input_list_tmp[:, 0, 1, np.newaxis]) / np.pi
        rf_list = np.vstack((rf_list, input_list_tmp))

        with torch.no_grad():  # for inference only
            m = nn.Parameter(torch.from_numpy(input_list_tmp[:, :, 0]).to(device))
            p = nn.Parameter(torch.from_numpy(input_list_tmp[:, :, 1]).to(device))
            # Interpolate to half the sampling rate, then mirror each channel
            # so the full pulse is symmetric about its centre.
            b1 = nn.functional.interpolate(
                torch.stack([m, p], dim=1), size=int(args.sampling_rate/2.0), mode='linear', align_corners=True
            )
            b1 = torch.stack([torch.cat((b1[:, 0, :], torch.fliplr(b1[:, 0, :])), dim=1),
                              torch.cat((b1[:, 1, :], torch.fliplr(b1[:, 1, :])), dim=1)], dim=1)
            b1 = torch.cat((ref_pulse, b1), dim=0)  # reference pulse becomes row 0

            # Simulation
            t = 0
            done = False
            total_rewards = 0.0
            while not done:
                Mt, rews, done = env.step(b1[..., t])
                t += 1
                total_rewards += rews
            env.reset()

        rew = total_rewards[1:, ...].detach().cpu().numpy()  # drop row 0 (reference pulse)

        # SAR
        amp = ((b1[:, 0] + 1.0) * env.max_amp * 1e+4 / 2).pow(2).sum(-1)  # (G^2)
        sar = amp * env.du / len(env) * 1e6  # (mG^2)*sec

        # Magnetization
        Mt1 = Mt[:, :200, 0, :]  # passband
        Mt2 = Mt[:, 200:, 0, :]  # stopband
        ripple1 = torch.max(torch.abs(1. + Mt1[..., 2]), dim=1)[0]
        ripple2 = torch.max(torch.abs(1. - Mt2[..., 2]), dim=1)[0]
        mz = Mt[:, :, 0, 2].detach().cpu().numpy()

        rew_list = np.append(rew_list, rew)

        # Update statistics
        # NOTE(review): idx is computed on `rew`, which excludes the reference
        # pulse, but is then used to index sar / ripple1 / ripple2 / mz /
        # magnitude / phase, which appear to still contain the reference pulse
        # at row 0 -- possible off-by-one; confirm against the environment.
        idx = np.argmax(rew)
        info.update('Rew/Mean', np.mean(rew))
        info.update('SAR/Mean', sar.mean().item())
        info.update('Mz1/Mean', ripple1.mean().item())
        info.update('Mz2/Mean', ripple2.mean().item())
        info.update('Rew/Best', rew[idx])
        info.update('SAR/Best', sar[idx].item())
        info.update('Mz1/Best', ripple1[idx].item())
        info.update('Mz2/Best', ripple2[idx].item())
        info.update('Time', time.time() - start_t)

        # Log summary statistics
        if (it + 1) % args.log_step == 0:
            # Inversion profile
            profile = plt.figure(1)
            plt.plot(np.concatenate((env.df[200:200 + 800], env.df[:200], env.df[200 + 800:200 + 1600])),
                     np.concatenate((mz[idx, 200:200 + 800], mz[idx, :200], mz[idx, 200 + 800:200 + 1600])))
            # logger.image_summary(profile, it + 1, 'profile')

            # RF pulse magnitude
            t = np.linspace(0, env.du / len(env), len(env))
            magnitude = b1[:, 0].detach().cpu().numpy()
            fig_m = plt.figure(2)
            plt.plot(t, magnitude[idx])
            plt.ylim(-1, 1)
            # logger.image_summary(fig_m, it + 1, 'magnitude')

            # RF pulse phase
            phase = b1[:, 1].detach().cpu().numpy()
            fig_p = plt.figure(3)
            plt.plot(t, phase[idx])
            plt.ylim(-1, 1)
            # logger.image_summary(fig_p, it + 1, 'phase')

            # array_dict is only built on save iterations, so the .mat file is
            # written inside the same branch to avoid using an undefined or
            # stale dictionary on other log steps.
            if (it + 1) % args.save == 0:
                array_dict = {'magnitude': magnitude, 'phase': phase, 'sar': sar,
                              'rf_list': rf_list[1:, ...],
                              'mz1': ripple1, 'mz2': ripple2, 'rew': rew, 'rew_list': rew_list}
                logger.savemat('pulse' + str(it + 1), array_dict)

            logger.scalar_summary(info.val, it + 1)
            info.reset()

        # ---- GAE ----
        # val_list's size: (args.batch, args.el)
        # Only the terminal step carries a reward; earlier targets bootstrap
        # from the next step's value estimate.
        target = np.roll(val_list, -1, axis=1) * args.gamma
        target[:, -1] = rew
        delta = target - val_list
        adv_list = []
        advs = 0
        for t in range(args.el, 0, -1):
            advs = args.gamma * args.lmbda * advs + delta[:, t-1]
            adv_list.append(advs)
        adv_list.reverse()
        adv_list = np.array(adv_list)  # (args.el, args.batch)
        # EPS guards against a zero standard deviation (all advantages equal).
        normalize_rewards = (adv_list - np.mean(adv_list)) / (np.std(adv_list) + EPS)

        # Not used by the update below (rnn_in is fed input_list[:-1] directly).
        input_list_resize = np.expand_dims(
            np.reshape(input_list[:-1, :, :], (args.batch * args.el, output_size)), axis=0)
        # adv_list is (el, batch); a C-order flatten yields the time-major
        # layout that p_pi, prob_list and ret use. The previous order='F'
        # flatten produced a batch-major layout and paired each advantage
        # with the wrong sample.
        norm_reward_resize = normalize_rewards.flatten()
        # (batch, el) flattened column-wise is likewise time-major.
        ret = np.repeat(rew[:, np.newaxis], args.el, axis=1).flatten(order='F')

        # ---- PPO update ----
        for ee in range(args.epochs):
            _, kl, p_pi_, log_std_, policy_loss_, val_loss_, grad_norm_ = sess.run(
                [train_opt, approx_kl, p_pi, log_std, policy_loss, value_loss, grad_norm],
                feed_dict={rnn_in: input_list[:-1, :, :],
                           state_in: np.zeros((args.batch, args.hss)),
                           adv: norm_reward_resize,
                           prob: prob_list,
                           ret_in: ret})
            # early stopping
            if kl > args.kl:
                break