# EECE571F-project / policy_graident / REINFORCE.py
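"""REINFORCE with a learned state-value baseline on CartPole-v1.

Each episode is rolled out to completion, the Monte Carlo returns
G_t = r_t + gamma * G_{t+1} are computed, and the policy is updated with the
gradient estimator

    grad_theta J ~ sum_t grad_theta log pi_theta(a_t | s_t) * (G_t - V_w(s_t)),

while the baseline V_w(s) is regressed onto the returns with an MSE loss.
"""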
import gymnasium as gym
from tqdm import tqdm
import torch
import torch.nn as nn
import torch.distributions as D
import matplotlib.pyplot as plt


class Actor(nn.Module):
    def __init__(self, *args, **kwargs) -> None:
        super(Actor, self).__init__()
        input_dim = kwargs.get("input_dim", 4)
        hidden_dim = kwargs.get("hidden_dim", 32)
        action_dim = kwargs.get("action_dim", 2)

        self.fc = nn.Sequential(
            nn.Linear(input_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, action_dim),
        )

    def forward(self, x):
        logits = self.fc(x)
        return D.Categorical(logits=logits)
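
# Illustrative usage (shapes assume CartPole-v1, as in main() below):
#   policy = Actor(input_dim=4, hidden_dim=32, action_dim=2)
#   dist = policy(torch.zeros(4))            # Categorical over the 2 actions
#   action = dist.sample(); logp = dist.log_prob(action)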
    

class Critic(nn.Module):
    def __init__(self, *args, **kwargs) -> None:
        super(Critic, self).__init__()
        input_dim = kwargs.get("input_dim", 4)
        hidden_dim = kwargs.get("hidden_dim", 32)

        self.fc = nn.Sequential(
            nn.Linear(input_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, 1),
        )

    def forward(self, x):
        value = self.fc(x)  # scalar state-value estimate V(s)
        return value
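
# Illustrative usage: Critic(input_dim=4)(torch.zeros(4)) returns a single
# state-value estimate with shape (1,); it serves as the baseline in
# REINFORCE() below.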
    
def REINFORCE(actor_optimizer, value_fnc, value_optimizer, rewards, log_probs, states, gamma):
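    """Perform one REINFORCE-with-baseline update from a single finished episode.

    rewards   -- per-step rewards r_0, ..., r_{T-1}
    log_probs -- log pi(a_t | s_t) tensors, still attached to the actor's graph
    states    -- visited states s_0, ..., s_{T-1} as float tensors
    Returns the scalar policy loss for logging.
    """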
    # Compute discounted returns G_t = r_t + gamma * G_{t+1}, back to front.
    discounted_rewards = []
    G = 0.0
    for r in reversed(rewards):
        G = r + gamma * G
        discounted_rewards.insert(0, G)
    
    discounted_rewards = torch.tensor(discounted_rewards).unsqueeze(-1)
    # discounted_rewards = (discounted_rewards - discounted_rewards.mean()) / discounted_rewards.std()

    states_tensor = torch.stack(states)
    baseline_values = value_fnc(states_tensor)
    # Advantage A_t = G_t - V(s_t); detach the baseline so the policy loss
    # does not backpropagate into the critic.
    advantages = discounted_rewards - baseline_values.detach()

    # Policy loss: -sum_t log pi(a_t | s_t) * A_t
    loss = []
    for A_t, log_prob_t in zip(advantages, log_probs):
        loss.append(-log_prob_t * A_t)
    loss = torch.stack(loss).sum()
   
    actor_optimizer.zero_grad()
    loss.backward()
    actor_optimizer.step()

    # Fit the baseline to the Monte Carlo returns.
    value_loss = torch.nn.functional.mse_loss(baseline_values, discounted_rewards)

    value_optimizer.zero_grad()
    value_loss.backward()
    value_optimizer.step()

    return loss.item()



def main():
    env = gym.make('CartPole-v1', render_mode="rgb_array")
    n_episodes = 1000
    gamma = 0.99

    args = {"input_dim": 4,'hidden_dim': 32,'action_dim':2}
    agent = Actor(**args)
    value_fnc = Critic(**args)

    actor_optimizer = torch.optim.Adam(agent.parameters(), lr=1e-3)
    value_optimizer = torch.optim.Adam(value_fnc.parameters(), lr=1e-3)


    fig, axis = plt.subplots()
    loss_agg = []
    ret_agg = []
    for episode in range(n_episodes):
        axis.clear()

        episode_over = False 

        # Per-episode buffers for the update.
        batch_logp = []     # log pi(a_t | s_t) for each step
        ep_reward = []      # rewards accrued throughout the episode
        batch_states = []   # visited states s_t

        observation, info = env.reset()
        while not episode_over:
            state = torch.tensor(observation, dtype=torch.float32)
            batch_states.append(state)

            action_dist = agent(state)  # categorical policy pi(. | s_t)
            action = action_dist.sample() # sample action from categorical
            
            # save log probs
            batch_logp.append(action_dist.log_prob(action))

            observation, reward, terminated, truncated, info = env.step(action.item()) # take a step
            episode_over = terminated or truncated
            
            # save reward
            ep_reward.append(reward)

            if episode_over:
                loss = REINFORCE(actor_optimizer, value_fnc, value_optimizer, ep_reward, batch_logp, batch_states, gamma)
                # loss_agg.append(loss)
                ret_agg.append(sum(ep_reward))
                axis.plot(ret_agg)
                plt.pause(0.001)


    env.close()
    plt.show()  # keep the final return curve visible after training ends

if __name__ == "__main__":
    main()
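
# To run (assumed setup, not pinned by this repo):
#   pip install torch "gymnasium[classic-control]" matplotlib
#   python REINFORCE.py
# Episode returns are plotted live and should trend toward CartPole-v1's
# 500-step cap as training progresses.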