A2C
A synchronous, deterministic variant of Asynchronous Advantage Actor Critic (A3C). It uses multiple workers to avoid the use of a replay buffer.
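A minimal sketch of what "synchronous" and "no replay buffer" mean in practice, using a toy environment. `ToyEnv`, `collect_rollout` and `greedy_policy` are hypothetical names for illustration, not part of lightning_baselines3. Every environment advances exactly one step before any environment advances again. Each batch covers `buffer_length` steps from every environment and is used for the update, then discarded rather than stored for later reuse.

```python
import random
from typing import Callable, List, Tuple


class ToyEnv:
    """Hypothetical stand-in for a Gym env: observations are random floats,
    reward is 1.0 for action 1, and every episode lasts exactly 3 steps."""

    def __init__(self, seed: int):
        self.rng = random.Random(seed)
        self.t = 0

    def reset(self) -> float:
        self.t = 0
        return self.rng.random()

    def step(self, action: int) -> Tuple[float, float, bool]:
        self.t += 1
        reward = 1.0 if action == 1 else 0.0
        return self.rng.random(), reward, self.t >= 3


def collect_rollout(envs: List[ToyEnv], obs: List[float], buffer_length: int,
                    policy: Callable[[float], int]):
    """Advance all envs together, one step each, buffer_length times."""
    batch = []  # batch[t][i] = (obs, action, reward, done) for env i at step t
    for _ in range(buffer_length):
        step_data = []
        for i, env in enumerate(envs):
            action = policy(obs[i])
            next_obs, reward, done = env.step(action)
            step_data.append((obs[i], action, reward, done))
            # Reset finished episodes immediately, as vectorised envs usually do
            obs[i] = env.reset() if done else next_obs
        batch.append(step_data)
    return batch, obs


def greedy_policy(o: float) -> int:
    return 1 if o > 0.5 else 0


envs = [ToyEnv(seed=i) for i in range(4)]
obs = [env.reset() for env in envs]
for _ in range(3):
    batch, obs = collect_rollout(envs, obs, buffer_length=5, policy=greedy_policy)
    # A real learner would compute the loss on `batch` and take a gradient
    # step here; the batch is then dropped instead of going into a buffer.
```

Because workers never run ahead of each other, the order of transitions in each batch is fixed, which is the sense in which this variant is deterministic compared with the asynchronous A3C workers.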
Warning
If you find training unstable or want to match the performance of the stable-baselines (SB2) A2C, consider using the RMSpropTFLike optimizer from lightning_baselines3.common.sb2_compat.rmsprop_tf_like.
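A minimal sketch of why the optimizer choice matters, in plain Python on a single scalar parameter. It assumes that `RMSpropTFLike` mirrors the Stable Baselines3 optimizer of the same name, which, as we understand it, differs from `torch.optim.RMSprop` in two ways: `eps` is added inside the square root, and the running average of squared gradients starts at 1 instead of 0. The two helper functions are hypothetical and ignore momentum and centering.

```python
import math


def rmsprop_step_torch_style(g, v, lr=0.01, alpha=0.99, eps=1e-5):
    """One torch.optim.RMSprop-style update (no momentum, not centered)."""
    v = alpha * v + (1 - alpha) * g * g
    return lr * g / (math.sqrt(v) + eps), v  # eps outside the sqrt


def rmsprop_step_tf_style(g, v, lr=0.01, alpha=0.99, eps=1e-5):
    """One TF1-style RMSprop update (no momentum, not centered)."""
    v = alpha * v + (1 - alpha) * g * g
    return lr * g / math.sqrt(v + eps), v  # eps inside the sqrt


g = 1.0
torch_step, _ = rmsprop_step_torch_style(g, v=0.0)  # square average starts at 0
tf_step, _ = rmsprop_step_tf_style(g, v=1.0)        # square average starts at 1
print(f"first step: torch-style {torch_step:.5f}, TF-style {tf_step:.5f}")
```

With the square average starting at 0, the first PyTorch-style step is roughly ten times larger than the TF-style one, which is one plausible source of the early-training instability the warning mentions.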
Notes
Original paper: https://arxiv.org/abs/1602.01783
OpenAI blog post: https://openai.com/blog/baselines-acktr-a2c/
Can I use?
Recurrent policies: ❌
Multi processing: ✔️
Gym spaces:
| Space | Action | Observation |
|---|---|---|
| Discrete | ✔️ | ✔️ |
| Box | ✔️ | ✔️ |
| MultiDiscrete | ✔️ | ✔️ |
| MultiBinary | ✔️ | ✔️ |
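The example on this page builds its actor head from `self.action_space.n`, which only exists for `Discrete` spaces; for the other spaces in the table the actor output has to be read differently. A minimal NumPy sketch of a greedy (deterministic) action for each space, assuming the actor returns probabilities for `Discrete`, a Gaussian mean for `Box`, concatenated per-dimension scores for `MultiDiscrete` (block sizes given by `nvec`) and independent Bernoulli probabilities for `MultiBinary`. `deterministic_action` is a hypothetical helper, not part of lightning_baselines3.

```python
import numpy as np


def deterministic_action(space_type, actor_out, nvec=None, low=None, high=None):
    """Hypothetical helper: greedy action from one observation's actor output.

    space_type: "Discrete", "Box", "MultiDiscrete" or "MultiBinary".
    actor_out: 1-D NumPy array produced by the actor.
    """
    if space_type == "Discrete":
        # Probabilities over n actions: pick the most likely index
        return int(np.argmax(actor_out))
    if space_type == "Box":
        # Gaussian mean, clipped into the action bounds
        return np.clip(actor_out, low, high)
    if space_type == "MultiDiscrete":
        # Split the output into one block per sub-action, sizes given by nvec
        blocks = np.split(actor_out, np.cumsum(nvec)[:-1])
        return np.array([int(np.argmax(b)) for b in blocks])
    if space_type == "MultiBinary":
        # Independent Bernoulli probabilities: threshold at 0.5
        return (actor_out > 0.5).astype(np.int8)
    raise ValueError(f"unsupported space: {space_type}")


print(deterministic_action("MultiDiscrete",
                           np.array([0.2, 0.8, 0.5, 0.1, 0.4]), nvec=[2, 3]))
```

The `MultiDiscrete` call treats the first two scores as sub-action 0 and the last three as sub-action 1, so it picks index 1 from the first block and index 0 from the second.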
Example
Train an A2C agent on CartPole-v1 using 4 environments.
```python
import gym
import torch
from torch import distributions
from torch import nn
import pytorch_lightning as pl
from lightning_baselines3.common.vec_env import make_vec_env, SubprocVecEnv
from lightning_baselines3.on_policy_models import A2C


class Model(A2C):
    def __init__(self, **kwargs):
        # **kwargs passes our arguments (env, eval_env, ...) on to A2C
        super().__init__(**kwargs)

        # Policy network: outputs a probability for each discrete action
        self.actor = nn.Sequential(
            nn.Linear(self.observation_space.shape[0], 64),
            nn.Tanh(),
            nn.Linear(64, 64),
            nn.Tanh(),
            nn.Linear(64, self.action_space.n),
            nn.Softmax(dim=-1))

        # Value network: outputs a single state-value estimate
        self.critic = nn.Sequential(
            nn.Linear(self.observation_space.shape[0], 64),
            nn.Tanh(),
            nn.Linear(64, 64),
            nn.Tanh(),
            nn.Linear(64, 1))

        self.save_hyperparameters()

    # Used for training: returns the action distribution and the state value
    def forward(self, x):
        out = self.actor(x)
        dist = distributions.Categorical(probs=out)
        return dist, self.critic(x).flatten()

    # Used for inference and evaluation: returns the chosen action(s)
    def predict(self, x, deterministic=True):
        out = self.actor(x)
        if deterministic:
            # Greedy action: index of the highest probability
            out = torch.argmax(out, dim=-1)
        else:
            out = distributions.Categorical(probs=out).sample()
        return out.cpu().numpy()

    def configure_optimizers(self):
        optimizer = torch.optim.Adam(self.parameters(), lr=3e-4)
        return optimizer


if __name__ == '__main__':
    # 4 copies of CartPole-v1, each running in its own subprocess
    env = make_vec_env('CartPole-v1', n_envs=4, vec_env_cls=SubprocVecEnv)
    eval_env = gym.make('CartPole-v1')
    model = Model(env=env, eval_env=eval_env)

    trainer = pl.Trainer(max_epochs=20, gradient_clip_val=0.5)
    trainer.fit(model)

    model.evaluate(num_eval_episodes=10, render=True)
```
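The `forward` method in the example returns the action distribution and the state value. A minimal NumPy sketch of how such outputs might be combined into one scalar A2C loss, assuming a Stable Baselines3-style formulation (this page credits Stable Baselines3 as a code source) weighted by the `value_coef` and `entropy_coef` parameters documented in the Parameters section. `a2c_loss` is a hypothetical helper; in a PyTorch version, the advantages would be computed without gradients so the policy term does not update the critic.

```python
import numpy as np


def a2c_loss(log_probs, values, returns, advantages, entropies,
             value_coef=0.5, entropy_coef=0.0):
    """Combine policy, value and entropy terms into one scalar loss."""
    # Policy gradient term: raise log-prob of actions with positive advantage
    policy_loss = -np.mean(advantages * log_probs)
    # Critic regression towards the (bootstrapped) returns
    value_loss = np.mean((returns - values) ** 2)
    # Negative entropy, so a positive entropy_coef rewards exploration
    entropy_loss = -np.mean(entropies)
    return policy_loss + value_coef * value_loss + entropy_coef * entropy_loss


log_probs = np.array([-1.0, -2.0])  # e.g. dist.log_prob(actions)
values = np.array([0.0, 0.5])       # critic outputs
returns = np.array([1.0, 1.0])      # bootstrapped returns
entropies = np.array([0.5, 0.5])    # e.g. dist.entropy()
loss = a2c_loss(log_probs, values, returns, returns - values, entropies)
print(loss)
```

With the default `entropy_coef=0.0`, the entropy term contributes nothing, so the loss is just the policy term plus half the mean squared value error.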
Parameters
- class lightning_baselines3.on_policy_models.a2c.A2C(env, eval_env, buffer_length=5, num_rollouts=100, batch_size=128, epochs_per_rollout=1, num_eval_episodes=10, gamma=0.99, gae_lambda=1.0, value_coef=0.5, entropy_coef=0.0, use_sde=False, sde_sample_freq=-1, verbose=0, seed=None)
Advantage Actor Critic (A2C)
Paper: https://arxiv.org/abs/1602.01783
Code: This implementation borrows code from https://github.com/ikostrikov/pytorch-a2c-ppo-acktr-gail and Stable Baselines3 (https://github.com/DLR-RM/stable-baselines3)
Introduction to A2C: https://hackernoon.com/intuitive-rl-intro-to-advantage-actor-critic-a2c-4ff545978752
- Parameters
  - env (Union[Env, VecEnv, str]) – The environment to learn from (if registered in Gym, can be str)
  - eval_env (Union[Env, VecEnv, str]) – The environment to evaluate on; must not be vectorised/parallelised (if registered in Gym, can be str; can be None when loading trained models)
  - buffer_length (int) – Length of the buffer, i.e. the number of steps to run in each environment per update
  - num_rollouts (int) – Number of rollouts to do per PyTorch Lightning epoch. This does not affect the training dynamics, only how often the model is evaluated, since evaluation happens at the end of each Lightning epoch
  - batch_size (int) – Minibatch size for each gradient update
  - epochs_per_rollout (int) – Number of epochs to optimise the loss for on each rollout
  - num_eval_episodes (int) – Number of episodes to evaluate for at the end of each PyTorch Lightning epoch
  - gamma (float) – Discount factor
  - gae_lambda (float) – Factor for the bias vs variance trade-off of the Generalized Advantage Estimator. Equivalent to the classic advantage when set to 1
  - value_coef (float) – Value function coefficient for the loss calculation
  - entropy_coef (float) – Entropy coefficient for the loss calculation
  - use_sde (bool) – Whether to use generalized State Dependent Exploration (gSDE) instead of action noise exploration
  - sde_sample_freq (int) – Sample a new noise matrix every n steps when using gSDE. Default: -1 (only sample at the beginning of the rollout)
  - verbose (int) – Verbosity level: 0 none, 1 training information, 2 debug
  - seed (Optional[int]) – Seed for the pseudo-random generators
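The `gamma` and `gae_lambda` parameters can be illustrated with a minimal NumPy sketch of Generalized Advantage Estimation over one environment's buffer of `buffer_length` steps. `compute_gae` is a hypothetical helper, not the library's implementation; it assumes `dones[t]` is 1.0 when the episode ended at step `t` and that `last_value` is the critic's estimate for the state following the buffer.

```python
import numpy as np


def compute_gae(rewards, values, dones, last_value, gamma=0.99, gae_lambda=1.0):
    """Return (advantages, returns) for arrays of shape (T,)."""
    T = len(rewards)
    advantages = np.zeros(T)
    gae = 0.0
    for t in reversed(range(T)):
        next_value = last_value if t == T - 1 else values[t + 1]
        not_done = 1.0 - dones[t]
        # One-step TD error; do not bootstrap past the end of an episode
        delta = rewards[t] + gamma * next_value * not_done - values[t]
        # Exponentially weighted sum of future TD errors
        gae = delta + gamma * gae_lambda * not_done * gae
        advantages[t] = gae
    return advantages, advantages + values


rewards = np.array([1.0, 1.0, 1.0])
values = np.array([0.5, 0.2, 0.1])
adv, ret = compute_gae(rewards, values, np.zeros(3), last_value=2.0,
                       gamma=0.9, gae_lambda=1.0)
print(adv, ret)
```

With `gae_lambda=1.0`, each advantage equals the discounted return bootstrapped from `last_value` minus the value estimate (for step 0: 1 + 0.9 + 0.81 + 0.729 × 2 − 0.5 = 3.668), which is the "classic advantage" mentioned above. With `gae_lambda=0.0`, the advantages reduce to the one-step TD errors, trading variance for bias.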