Lightning Baselines3 in 3 Steps¶
Step 1: Choose an algorithm¶
We will use A2C in this example.
# Minimal Example for the CartPole-v1 environment with PPO
import gym
import torch
from torch import distributions
from torch import nn
import pytorch_lightning as pl
from lightning_baselines3.on_policy_models import A2C
Step 2: Define Your Model¶
class Model(A2C):
def __init__(self, **kwargs):
# **kwargs will pass our arguments on to PPO
super(Model, self).__init__(**kwargs)
self.actor = nn.Sequential(
nn.Linear(self.observation_space.shape[0], 64),
nn.Tanh(),
nn.Linear(64, 64),
nn.Tanh(),
nn.Linear(64, self.action_space.n),
nn.Softmax(dim=1))
self.critic = nn.Sequential(
nn.Linear(self.observation_space.shape[0], 64),
nn.Tanh(),
nn.Linear(64, 64),
nn.Tanh(),
nn.Linear(64, 1))
self.save_hyperparameters()
# This is for training the model, output the distribution and the corresponding value function estimate
def forward(self, x):
out = self.actor(x)
dist = distributions.Categorical(probs=out)
return dist, self.critic(x).flatten()
# We need this for inference and evaluation of our model
def predict(self, x, deterministic=True):
out = self.actor(x)
if deterministic:
out = torch.max(out, dim=1)[1]
else:
out = distributions.Categorical(probs=out).sample()
return out.cpu().numpy()
def configure_optimizers(self):
optimizer = torch.optim.Adam(self.parameters(), lr=3e-4)
return optimizer
Step 2: Fit with Lightning Trainer¶
if __name__ == '__main__':
env = gym.make('CartPole-v1') # Make the environment
model = Model(env=env, eval_env=env) # Use that environment for training and evaluation
# Add some gradient clipping for good measure
trainer = pl.Trainer(max_epochs=20, gradient_clip_val=0.5)
trainer.fit(model)
# Evaluate
model.evaluate(num_eval_episodes=10, render=True)