SAC¶
Soft Actor Critic (SAC): Off-Policy Maximum Entropy Deep Reinforcement Learning with a Stochastic Actor.
SAC is the successor of Soft Q-Learning (SQL) and incorporates the double Q-learning trick from TD3. A key feature of SAC, and a major difference from common RL algorithms, is that it is trained to maximize a trade-off between expected return and entropy, a measure of randomness in the policy.
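As a rough illustration of how the double Q-learning trick and the entropy term combine, the sketch below computes an entropy-regularized TD target for the critics. The function name `soft_td_target` and its default values for `gamma` and `alpha` are assumptions made for this example; it is not code taken from the library.

```python
import torch


def soft_td_target(reward, done, next_q1, next_q2, next_log_prob,
                   gamma=0.99, alpha=0.2):
    """Entropy-regularized TD target (hypothetical helper, for illustration).

    next_q1 / next_q2 are the two target-critic estimates for the next state,
    next_log_prob is log pi(a'|s') of the action sampled from the policy.
    """
    # Double Q-learning trick: use the smaller of the two estimates
    min_next_q = torch.min(next_q1, next_q2)
    # Entropy bonus: subtracting alpha * log_prob rewards random policies
    soft_value = min_next_q - alpha * next_log_prob
    # No bootstrapping past the end of an episode
    return reward + gamma * (1.0 - done) * soft_value


target = soft_td_target(
    reward=torch.tensor([1.0]),
    done=torch.tensor([0.0]),
    next_q1=torch.tensor([10.0]),
    next_q2=torch.tensor([8.0]),
    next_log_prob=torch.tensor([-1.0]))
```

With these inputs the smaller critic estimate (8.0) is used, and the negative log-probability adds a bonus of 0.2 before discounting.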
Notes¶
Original paper: https://arxiv.org/abs/1801.01290
OpenAI Spinning Up guide for SAC: https://spinningup.openai.com/en/latest/algorithms/sac.html
Original Implementation: https://github.com/haarnoja/sac
Blog post on using SAC with real robots: https://bair.berkeley.edu/blog/2018/12/14/sac/
Note
In our implementation, we use an entropy coefficient (as in OpenAI Spinning Up or Facebook Horizon), which is equivalent to the inverse of the reward scale in the original SAC paper. The main reason is that it avoids overly large errors when updating the Q functions.
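The equivalence between the two parameterizations can be checked with a small arithmetic sketch. The function names and the numbers below are made up for illustration; the point is only that the two per-step objectives differ by a constant factor, so they rank policies identically while the entropy-coefficient form keeps values on the scale of the raw reward.

```python
import math


def objective_with_reward_scale(reward, entropy, reward_scale):
    # Form used in the original paper: scale the reward, unit entropy weight
    return reward_scale * reward + entropy


def objective_with_entropy_coef(reward, entropy, entropy_coef):
    # Form used here: raw reward, weighted entropy
    return reward + entropy_coef * entropy


reward, entropy, reward_scale = 2.0, 1.5, 5.0
scaled = objective_with_reward_scale(reward, entropy, reward_scale)
coef = objective_with_entropy_coef(reward, entropy, 1.0 / reward_scale)

# The two objectives are proportional, with factor reward_scale
proportional = math.isclose(scaled, reward_scale * coef)
```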
Note
The example model for SAC differs a bit from the others: it uses ReLU instead of tanh activations, to match the original paper.
Can I use?¶
Recurrent policies: ❌
Multi processing: ❌
Gym spaces:
| Space | Action | Observation |
|---|---|---|
| Discrete | ❌ | ✔️ |
| Box | ✔️ | ✔️ |
| MultiDiscrete | ❌ | ✔️ |
| MultiBinary | ❌ | ✔️ |
Example¶
import copy

import numpy as np
import torch
from torch import nn
import pytorch_lightning as pl

from lightning_baselines3.off_policy_models import SAC
from lightning_baselines3.common.distributions import SquashedMultivariateNormal
from lightning_baselines3.common.utils import polyak_update


class Model(SAC):
    def __init__(self, *args, **kwargs):
        super(Model, self).__init__(*args, **kwargs, squashed_actions=True)

        # Note: The output layer of the actor must be Tanh activated
        self.actor = nn.Sequential(
            nn.Linear(self.observation_space.shape[0], 256),
            nn.ReLU(),
            nn.Linear(256, 256),
            nn.ReLU(),
            nn.Linear(256, self.action_space.shape[0] * 2))

        in_dim = self.observation_space.shape[0] + self.action_space.shape[0]
        self.critic1 = nn.Sequential(
            nn.Linear(in_dim, 256),
            nn.ReLU(),
            nn.Linear(256, 256),
            nn.ReLU(),
            nn.Linear(256, 1))
        self.critic2 = nn.Sequential(
            nn.Linear(in_dim, 256),
            nn.ReLU(),
            nn.Linear(256, 256),
            nn.ReLU(),
            nn.Linear(256, 1))

        self.critic_target1 = copy.deepcopy(self.critic1)
        self.critic_target2 = copy.deepcopy(self.critic2)

        self.save_hyperparameters()

    def forward_actor(self, x):
        out = list(torch.chunk(self.actor(x), 2, dim=1))
        out[1] = torch.diag_embed(
            torch.exp(torch.clamp(out[1], -5, 5)))
        dist = SquashedMultivariateNormal(
            loc=torch.tanh(out[0]), scale_tril=out[1])
        return dist

    def forward_critics(self, obs, action):
        out = [
            self.critic1(torch.cat([obs, action], dim=1)),
            self.critic2(torch.cat([obs, action], dim=1))]
        return out

    def forward_critic_targets(self, obs, action):
        out = [
            self.critic_target1(torch.cat([obs, action], dim=1)),
            self.critic_target2(torch.cat([obs, action], dim=1))]
        return out

    def update_targets(self):
        polyak_update(
            self.critic1.parameters(),
            self.critic_target1.parameters(),
            tau=0.005)
        polyak_update(
            self.critic2.parameters(),
            self.critic_target2.parameters(),
            tau=0.005)

    def predict(self, x, deterministic=True):
        out = self.actor(x)
        if deterministic:
            out = torch.chunk(out, 2, dim=1)[0]
        else:
            out = list(torch.chunk(out, 2, dim=1))
            out[1] = torch.diag_embed(
                torch.exp(torch.clamp(out[1], -5, 5)))
            out = SquashedMultivariateNormal(
                loc=torch.tanh(out[0]), scale_tril=out[1]).sample()
        return out.cpu().numpy()

    def configure_optimizers(self):
        opt_actor = torch.optim.Adam(self.actor.parameters(), lr=3e-4)
        opt_critic = torch.optim.Adam(
            list(self.critic1.parameters()) + list(self.critic2.parameters()),
            lr=3e-4)
        return opt_critic, opt_actor


if __name__ == '__main__':
    model = Model(
        env='LunarLanderContinuous-v2',
        eval_env='LunarLanderContinuous-v2',
        warmup_length=1000)

    trainer = pl.Trainer(max_epochs=20, gradient_clip_val=0.5)
    trainer.fit(model)

    model.evaluate(num_eval_episodes=10, render=True)
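The `update_targets` method in the example relies on Polyak (soft) averaging to move the target critics slowly toward the online critics. The sketch below shows what such an update typically does, written with plain PyTorch. The helper `soft_update` is a hypothetical stand-in for illustration, and the library's own `polyak_update` may differ in its details.

```python
import copy

import torch
from torch import nn


def soft_update(source, target, tau):
    """Hypothetical helper: target <- (1 - tau) * target + tau * source."""
    with torch.no_grad():
        for p, p_targ in zip(source.parameters(), target.parameters()):
            p_targ.mul_(1.0 - tau).add_(p, alpha=tau)


critic = nn.Linear(4, 1)
critic_target = copy.deepcopy(critic)

# Pretend a gradient step moved every online weight by +1
with torch.no_grad():
    critic.weight.add_(1.0)

# The target closes only 0.5% of the gap per update
soft_update(critic, critic_target, tau=0.005)
gap = critic.weight - critic_target.weight
```

A small `tau` such as 0.005 keeps the bootstrapped targets nearly stationary between updates, which tends to stabilize critic training.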
Parameters¶
- class lightning_baselines3.off_policy_models.sac.SAC(env, eval_env, batch_size=256, buffer_length=1000000, warmup_length=100, train_freq=1, episodes_per_rollout=-1, num_rollouts=1000, gradient_steps=1, target_update_interval=1, num_eval_episodes=10, gamma=0.99, entropy_coef='auto', target_entropy='auto', use_sde=False, sde_sample_freq=-1, use_sde_at_warmup=False, squashed_actions=True, verbose=0, seed=None)¶
Soft Actor-Critic (SAC): Off-Policy Maximum Entropy Deep Reinforcement Learning with a Stochastic Actor. This implementation borrows code from the original implementation (https://github.com/haarnoja/sac), from OpenAI Spinning Up (https://github.com/openai/spinningup), from the softlearning repo (https://github.com/rail-berkeley/softlearning/) and from Stable Baselines (https://github.com/hill-a/stable-baselines).
Paper: https://arxiv.org/abs/1801.01290
Introduction to SAC: https://spinningup.openai.com/en/latest/algorithms/sac.html
Note: we use double q target and not value target as discussed in https://github.com/hill-a/stable-baselines/issues/270
- Parameters
  - env (Union[Env, VecEnv, str]) – The environment to learn from (if registered in Gym, can be str. Can be None for loading trained models)
  - eval_env (Union[Env, VecEnv, str]) – The environment to evaluate on, must not be vectorised/parallelised (if registered in Gym, can be str. Can be None for loading trained models)
  - batch_size (int) – Minibatch size for each gradient update
  - buffer_length (int) – Length of the replay buffer
  - warmup_length (int) – How many steps of the model to collect transitions for before learning starts
  - train_freq (int) – Update the model every train_freq steps. Set to -1 to disable.
  - episodes_per_rollout (int) – Update the model every episodes_per_rollout episodes. Note that this cannot be used at the same time as train_freq. Set to -1 to disable.
  - num_rollouts (int) – Number of rollouts to do per PyTorch Lightning epoch. This does not affect the training dynamics, only how often the model is evaluated, since evaluation happens at the end of each Lightning epoch
  - gradient_steps (int) – How many gradient steps to do after each rollout
  - target_update_interval (int) – How many environment steps to wait between updating the target Q network
  - num_eval_episodes (int) – The number of episodes to evaluate for at the end of a PyTorch Lightning epoch
  - gamma (float) – The discount factor
  - entropy_coef (Union[str, float]) – Entropy regularization coefficient (equivalent to the inverse of the reward scale in the original SAC paper). Controls the exploration/exploitation trade-off. Set it to 'auto' to learn it automatically (and 'auto_0.1' to use 0.1 as the initial value)
  - target_entropy (Union[str, float]) – Target entropy when learning entropy_coef (entropy_coef = 'auto')
  - use_sde (bool) – Whether to use generalized State Dependent Exploration (gSDE) instead of action noise exploration (default: False)
  - sde_sample_freq (int) – Sample a new noise matrix every n steps when using gSDE. Default: -1 (only sample at the beginning of the rollout)
  - use_sde_at_warmup (bool) – Whether to use gSDE instead of uniform sampling during the warm-up phase (before learning starts)
  - squashed_actions (bool) – Whether the actions are squashed between [-1, 1] and need to be unsquashed
  - verbose (int) – The verbosity level: 0 none, 1 training information, 2 debug (default: 0)
  - seed (Optional[int]) – Seed for the pseudo random generators
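To make the `entropy_coef='auto'` and `target_entropy` settings more concrete, here is a minimal sketch of how an entropy coefficient is commonly learned in SAC implementations. Everything in it is an assumption made for illustration: the names `log_alpha` and `alpha_step`, the learning rate, and the heuristic of setting the target entropy to minus the action dimension (as done in Stable Baselines). The library's internals may differ.

```python
import torch

action_dim = 2
# Assumed heuristic for target_entropy='auto': -dim(action space)
target_entropy = -float(action_dim)

# Optimize log(alpha) so that alpha stays positive; 0.0 means alpha starts at 1.0
# (an initial value of 0.1, as in 'auto_0.1', would correspond to log(0.1))
log_alpha = torch.zeros(1, requires_grad=True)
alpha_optimizer = torch.optim.Adam([log_alpha], lr=3e-4)


def alpha_step(log_prob):
    """One gradient step on the entropy coefficient (hypothetical helper)."""
    # log_prob: log pi(a|s) for a batch of actions sampled from the policy
    loss = -(log_alpha * (log_prob + target_entropy).detach()).mean()
    alpha_optimizer.zero_grad()
    loss.backward()
    alpha_optimizer.step()
    return log_alpha.exp().item()


alpha_before = log_alpha.exp().item()
# High log-probabilities mean low entropy, so alpha should be pushed up
alpha_after = alpha_step(torch.full((8,), 5.0))
```

When the policy's entropy falls below the target, the coefficient grows and the entropy bonus gains weight; when the entropy is above the target, the coefficient shrinks.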
- forward_actor(obs)¶
Runs the actor network. Override this function with your own.
- Parameters
  - obs (Tensor) – The input observations
- Return type
  Tensor
- Returns
  The deterministic action of the actor
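When overriding `forward_actor` for squashed actions, it helps to see what tanh squashing does to samples and their log-probabilities. The following is a sketch in plain PyTorch under stated assumptions: `sample_squashed_gaussian` is a hypothetical helper, it uses a diagonal Gaussian, and it is not the library's `SquashedMultivariateNormal`. The log-standard-deviation clamp to [-5, 5] mirrors the example model above.

```python
import math

import torch
import torch.nn.functional as F
from torch.distributions import Normal


def sample_squashed_gaussian(mean, log_std):
    """Sample tanh-squashed actions and their log-probabilities (illustrative)."""
    std = torch.clamp(log_std, -5, 5).exp()
    base = Normal(mean, std)
    u = base.rsample()            # reparameterized pre-squash sample
    action = torch.tanh(u)        # squashed into [-1, 1]
    # Change of variables: log pi(a) = log N(u) - sum(log(1 - tanh(u)^2)),
    # using the numerically stable identity
    # log(1 - tanh(u)^2) = 2 * (log 2 - u - softplus(-2u))
    log_prob = base.log_prob(u).sum(dim=-1)
    log_prob = log_prob - (
        2.0 * (math.log(2.0) - u - F.softplus(-2.0 * u))).sum(dim=-1)
    return action, log_prob


torch.manual_seed(0)
mean = torch.zeros(4, 2)      # batch of 4 observations, 2 action dimensions
log_std = torch.zeros(4, 2)
action, log_prob = sample_squashed_gaussian(mean, log_std)
```

The correction term matters because the entropy bonus is computed from the log-probability of the squashed action, not of the underlying Gaussian sample.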
- forward_critic_targets(obs, action)¶
Runs all the target critic networks. Override this function with your own.
- Parameters
  - obs (Tensor) – The input observations
  - action (Tensor) – The input actions
- Return type
  Tuple[Tensor, …]
- Returns
  The output Q values of the target critic networks in the form of a list
- forward_critics(obs, action)¶
Runs all the critic networks. Override this function with your own.
- Parameters
  - obs (Tensor) – The input observations
  - action (Tensor) – The input actions
- Return type
  Tuple[Tensor, …]
- Returns
  The output Q values of the critic networks in the form of a list