import gymnasium as gym
from tqdm import tqdm
import torch
import torch.nn as nn
import torch.distributions as D
import matplotlib.pyplot as plt
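# REINFORCE with a learned state-value baseline on Gymnasium's CartPole-v1.
# The Actor parameterizes the policy pi(a|s); the Critic estimates V(s) and is used
# only as a baseline to reduce the variance of the policy-gradient estimate.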
class Actor(nn.Module):
def __init__(self, *args, **kwargs) -> None:
        super().__init__()
input_dim = kwargs.get("input_dim", 4)
hidden_dim = kwargs.get("hidden_dim", 32)
action_dim = kwargs.get("action_dim", 2)
self.fc = nn.Sequential(
nn.Linear(input_dim, hidden_dim),
nn.ReLU(),
nn.Linear(hidden_dim, action_dim),
)
def forward(self, x):
logits = self.fc(x)
return D.Categorical(logits=logits)
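# Actor.forward returns a torch.distributions.Categorical; the training loop draws actions
# with .sample() and records .log_prob(action) terms for the policy-gradient loss.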
class Critic(nn.Module):
def __init__(self, *args, **kwargs) -> None:
        super().__init__()
input_dim = kwargs.get("input_dim", 4)
hidden_dim = kwargs.get("hidden_dim", 32)
self.fc = nn.Sequential(
nn.Linear(input_dim, hidden_dim),
nn.ReLU(),
nn.Linear(hidden_dim, 1),
)
    def forward(self, x):
        value = self.fc(x)  # scalar state-value estimate V(s)
        return value
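# Illustrative sketch (not executed) of how the two networks are queried for one observation,
# using the names from main() below:
#   obs_t = torch.tensor(observation, dtype=torch.float32)  # CartPole state, shape (4,)
#   dist = agent(obs_t)                # Categorical over the 2 discrete actions
#   action = dist.sample()
#   logp = dist.log_prob(action)       # stored for the REINFORCE update
#   value = value_fnc(obs_t)           # baseline estimate V(s), shape (1,)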
def REINFORCE(actor_optimizer, value_fnc, value_optimizer, rewards, log_probs, states, gamma):
    # Discounted returns, computed backwards through the episode: G_t = r_t + gamma * G_{t+1}
    discounted_rewards = []
    G = 0.0  # running return
    for r in reversed(rewards):
        G = r + gamma * G
        discounted_rewards.insert(0, G)
    discounted_rewards = torch.tensor(discounted_rewards, dtype=torch.float32).unsqueeze(-1)
    # Optional return normalization (the small epsilon guards against a zero std):
    # discounted_rewards = (discounted_rewards - discounted_rewards.mean()) / (discounted_rewards.std() + 1e-8)
    states_tensor = torch.stack(states)
    baseline_values = value_fnc(states_tensor)
    # Advantage A_t = G_t - V(s_t); detach so the policy loss does not backpropagate into the critic
    advantages = discounted_rewards - baseline_values.detach()
    # Policy-gradient loss: -sum_t log pi(a_t | s_t) * A_t
    loss = []
    for A_t, log_prob_t in zip(advantages, log_probs):
        loss.append(-log_prob_t * A_t)
    loss = torch.stack(loss).sum()
    actor_optimizer.zero_grad()
    loss.backward()
    actor_optimizer.step()
    # Fit the baseline to the observed returns with an MSE regression loss
    value_loss = torch.nn.functional.mse_loss(baseline_values, discounted_rewards)
    value_optimizer.zero_grad()
    value_loss.backward()
    value_optimizer.step()
    return loss.item()
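# Training driver: run full episodes, storing per-step states, log-probabilities and rewards,
# then apply one REINFORCE update at the end of each episode and live-plot episode returns.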
def main():
env = gym.make('CartPole-v1', render_mode="rgb_array")
    n_episodes = 1000
    gamma = 0.99
    args = {"input_dim": 4, "hidden_dim": 32, "action_dim": 2}
    agent = Actor(**args)
    value_fnc = Critic(**args)  # Critic only uses input_dim/hidden_dim; action_dim is ignored
    actor_optimizer = torch.optim.Adam(agent.parameters(), lr=1e-3)
    value_optimizer = torch.optim.Adam(value_fnc.parameters(), lr=1e-3)
fig, axis = plt.subplots()
loss_agg = []
ret_agg = []
for episode in range(n_episodes):
axis.clear()
episode_over = False
# make some empty lists for logging.
batch_logp = [] # for log probs
ep_reward = [] # list for rewards accrued throughout ep
batch_states = []
observation, info = env.reset()
while not episode_over:
            obs_t = torch.tensor(observation, dtype=torch.float32)
            batch_states.append(obs_t)
            action_dist = agent(obs_t)  # forward pass returns a Categorical over actions
action = action_dist.sample() # sample action from categorical
# save log probs
batch_logp.append(action_dist.log_prob(action))
observation, reward, terminated, truncated, info = env.step(action.item()) # take a step
episode_over = terminated or truncated
# save reward
ep_reward.append(reward)
if episode_over:
                loss = REINFORCE(actor_optimizer, value_fnc, value_optimizer, ep_reward, batch_logp, batch_states, gamma)
                loss_agg.append(loss)
ret_agg.append(sum(ep_reward))
axis.plot(ret_agg)
plt.pause(0.001)
env.close()
if __name__ == "__main__":
main()