import gymnasium as gym
from tqdm import tqdm
import torch
import torch.nn as nn
import torch.distributions as D
import matplotlib.pyplot as plt


class Actor(nn.Module):
    """Policy network: maps a state to a categorical distribution over actions."""

    def __init__(self, *args, **kwargs) -> None:
        super(Actor, self).__init__()
        input_dim = kwargs.get("input_dim", 4)
        hidden_dim = kwargs.get("hidden_dim", 32)
        action_dim = kwargs.get("action_dim", 2)
        self.fc = nn.Sequential(
            nn.Linear(input_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, action_dim),
        )

    def forward(self, x):
        logits = self.fc(x)
        return D.Categorical(logits=logits)


class Critic(nn.Module):
    """Value network: maps a state to a scalar baseline estimate V(s)."""

    def __init__(self, *args, **kwargs) -> None:
        super(Critic, self).__init__()
        input_dim = kwargs.get("input_dim", 4)
        hidden_dim = kwargs.get("hidden_dim", 32)
        self.fc = nn.Sequential(
            nn.Linear(input_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, 1),
        )

    def forward(self, x):
        return self.fc(x)


def REINFORCE(actor_optimizer, value_fnc, value_optimizer, rewards, log_probs, states, gamma):
    # Compute the discounted return G_t for every timestep, working backwards
    # through the episode.
    discounted_rewards = []
    G = 0.0  # return
    for r in reversed(rewards):
        G = r + gamma * G
        discounted_rewards.insert(0, G)
    discounted_rewards = torch.tensor(discounted_rewards, dtype=torch.float32).unsqueeze(-1)
    # Optional variance reduction: normalize the returns.
    # discounted_rewards = (discounted_rewards - discounted_rewards.mean()) / (discounted_rewards.std() + 1e-8)

    # Baseline values from the critic; detach so the policy loss does not
    # backpropagate into the critic.
    states_tensor = torch.stack(states)
    baseline_values = value_fnc(states_tensor)
    advantages = discounted_rewards - baseline_values.detach()

    # Policy-gradient loss: -sum_t A_t * log pi(a_t | s_t)
    loss = []
    for A_t, log_prob_t in zip(advantages, log_probs):
        loss.append(-log_prob_t * A_t)
    loss = torch.stack(loss).sum()

    actor_optimizer.zero_grad()
    loss.backward()
    actor_optimizer.step()

    # Fit the baseline to the observed returns with an MSE regression loss.
    value_loss = torch.nn.functional.mse_loss(baseline_values, discounted_rewards)
    value_optimizer.zero_grad()
    value_loss.backward()
    value_optimizer.step()

    return loss.item()


def main():
    env = gym.make('CartPole-v1', render_mode="rgb_array")
    n_episodes = 1000
    gamma = 0.99

    args = {"input_dim": 4, "hidden_dim": 32, "action_dim": 2}
    agent = Actor(**args)
    value_fnc = Critic(**args)
    # Note: the original passed weight_decay=0.9 to the actor's Adam optimizer,
    # which is far too aggressive an L2 penalty and prevents learning; it is removed here.
    actor_optimizer = torch.optim.Adam(agent.parameters(), lr=1e-3)
    value_optimizer = torch.optim.Adam(value_fnc.parameters(), lr=1e-3)

    fig, axis = plt.subplots()
    loss_agg = []
    ret_agg = []

    for episode in tqdm(range(n_episodes)):
        axis.clear()
        episode_over = False

        # Per-episode buffers for log probabilities, rewards, and states.
        batch_logp = []
        ep_reward = []
        batch_states = []

        observation, info = env.reset()
        while not episode_over:
            state = torch.tensor(observation, dtype=torch.float32)
            batch_states.append(state)

            action_dist = agent(state)
            action = action_dist.sample()  # sample an action from the categorical policy
            batch_logp.append(action_dist.log_prob(action))

            observation, reward, terminated, truncated, info = env.step(action.item())
            episode_over = terminated or truncated
            ep_reward.append(reward)

            if episode_over:
                loss = REINFORCE(actor_optimizer, value_fnc, value_optimizer,
                                 ep_reward, batch_logp, batch_states, gamma)
                # loss_agg.append(loss)

        # Track and plot the undiscounted episode return.
        ret_agg.append(sum(ep_reward))
        axis.plot(ret_agg)
        plt.pause(0.001)

    env.close()


if __name__ == "__main__":
    main()