SAC

Soft Actor-Critic (SAC): Off-Policy Maximum Entropy Deep Reinforcement Learning with a Stochastic Actor.

SAC is the successor of Soft Q-Learning (SQL) and incorporates the clipped double Q-learning trick from TD3. A key feature of SAC, and a major difference from common RL algorithms, is that it is trained to maximize a trade-off between expected return and entropy, a measure of randomness in the policy.
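For reference, the maximum entropy objective of the SAC paper (https://arxiv.org/abs/1801.01290) adds the entropy of the policy at each visited state to the reward, weighted by a temperature α (the entropy coefficient discussed below). As α → 0 it reduces to the usual expected return:

```latex
J(\pi) = \sum_{t=0}^{T} \mathbb{E}_{(s_t, a_t) \sim \rho_\pi}\left[ r(s_t, a_t) + \alpha \, \mathcal{H}\big(\pi(\cdot \mid s_t)\big) \right],
\qquad
\mathcal{H}\big(\pi(\cdot \mid s)\big) = -\,\mathbb{E}_{a \sim \pi(\cdot \mid s)}\left[ \log \pi(a \mid s) \right]
```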

Notes

Note

In our implementation, we use an entropy coefficient (as in OpenAI Spinning Up or Facebook Horizon), which is equivalent to the inverse of the reward scale in the original SAC paper. The main reason is that it avoids excessively large errors when updating the Q-functions.
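To see why an entropy coefficient α plays the role of an inverse reward scale, note that r + α·H = α·(r/α + H): dividing the whole objective by α > 0 does not change which policy maximises it. The sketch below checks this numerically on a one-step problem with three discrete actions. The helpers `entropy` and `soft_objective` are hypothetical, written only for illustration; the last candidate is the Boltzmann policy softmax(r/α), the known maximiser of this one-step objective.

```python
import numpy as np


def entropy(p):
    """Shannon entropy (in nats) of a discrete distribution p."""
    p = np.asarray(p, dtype=float)
    nz = p[p > 0]  # 0 * log(0) is taken as 0
    return float(-np.sum(nz * np.log(nz)))


def soft_objective(p, rewards, alpha=1.0, reward_scale=1.0):
    """Expected (scaled) reward plus alpha-weighted entropy for a one-step problem."""
    return float(np.dot(p, reward_scale * rewards)) + alpha * entropy(p)


rewards = np.array([1.0, 0.5, -0.2])
alpha = 0.2

# Boltzmann policy softmax(r / alpha): the exact maximiser of p.r + alpha * H(p)
boltzmann = np.exp(rewards / alpha)
boltzmann /= boltzmann.sum()

candidates = [np.array([1.0, 0.0, 0.0]), np.array([0.6, 0.3, 0.1]),
              np.full(3, 1.0 / 3.0), np.array([0.8, 0.15, 0.05]), boltzmann]

# Entropy-coefficient formulation: unscaled reward, entropy weighted by alpha
with_coef = [soft_objective(p, rewards, alpha=alpha) for p in candidates]
# Original-paper formulation: entropy weight 1, reward scaled by 1 / alpha
with_scale = [soft_objective(p, rewards, reward_scale=1.0 / alpha) for p in candidates]

# The second objective is the first divided by alpha, so both rank policies identically
print(int(np.argmax(with_coef)), int(np.argmax(with_scale)))
```

The reward-scale view makes the trade-off explicit: a small α acts like a large reward scale, which is exactly when Q-values and their errors become large.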

Note

The example model for SAC differs a bit from the others: it uses ReLU instead of tanh activations, to match the original paper.

Can I use?

  • Recurrent policies: ❌

  • Multi processing: ❌

  • Gym spaces:

Space          Action   Observation
Discrete                ✔️
Box            ✔️        ✔️
MultiDiscrete           ✔️
MultiBinary             ✔️

Example

import copy

import numpy as np

import torch
from torch import nn

import pytorch_lightning as pl

from lightning_baselines3.off_policy_models import SAC
from lightning_baselines3.common.distributions import SquashedMultivariateNormal
from lightning_baselines3.common.utils import polyak_update


class Model(SAC):
    def __init__(self, *args, **kwargs):
        super(Model, self).__init__(*args, **kwargs, squashed_actions=True)

        # The actor outputs a mean and a log standard deviation per action dimension;
        # forward_actor passes the mean through tanh so that it lies in [-1, 1]
        self.actor = nn.Sequential(
            nn.Linear(self.observation_space.shape[0], 256),
            nn.ReLU(),
            nn.Linear(256, 256),
            nn.ReLU(),
            nn.Linear(256, self.action_space.shape[0] * 2))

        in_dim = self.observation_space.shape[0] + self.action_space.shape[0]
        self.critic1 = nn.Sequential(
            nn.Linear(in_dim, 256),
            nn.ReLU(),
            nn.Linear(256, 256),
            nn.ReLU(),
            nn.Linear(256, 1))

        self.critic2 = nn.Sequential(
            nn.Linear(in_dim, 256),
            nn.ReLU(),
            nn.Linear(256, 256),
            nn.ReLU(),
            nn.Linear(256, 1))

        self.critic_target1 = copy.deepcopy(self.critic1)
        self.critic_target2 = copy.deepcopy(self.critic2)

        self.save_hyperparameters()

    def forward_actor(self, x):
        out = list(torch.chunk(self.actor(x), 2, dim=1))
        out[1] = torch.diag_embed(
            torch.exp(torch.clamp(out[1], -5, 5)))
        dist = SquashedMultivariateNormal(
            loc=torch.tanh(out[0]), scale_tril=out[1])
        return dist

    def forward_critics(self, obs, action):
        out = [
            self.critic1(torch.cat([obs, action], dim=1)),
            self.critic2(torch.cat([obs, action], dim=1))]
        return out

    def forward_critic_targets(self, obs, action):
        out = [
            self.critic_target1(torch.cat([obs, action], dim=1)),
            self.critic_target2(torch.cat([obs, action], dim=1))]
        return out

    def update_targets(self):
        polyak_update(
            self.critic1.parameters(),
            self.critic_target1.parameters(),
            tau=0.005)
        polyak_update(
            self.critic2.parameters(),
            self.critic_target2.parameters(),
            tau=0.005)

    def predict(self, x, deterministic=True):
        # Inference only: skip autograd (calling .numpy() on a tensor that requires grad fails)
        with torch.no_grad():
            if deterministic:
                # Use the tanh-squashed mean, i.e. the loc of the distribution built in forward_actor
                out = torch.tanh(torch.chunk(self.actor(x), 2, dim=1)[0])
            else:
                # Sample from the same distribution that forward_actor returns during training
                out = self.forward_actor(x).sample()
        return out.cpu().numpy()

    def configure_optimizers(self):
        opt_actor = torch.optim.Adam(self.actor.parameters(), lr=3e-4)
        opt_critic = torch.optim.Adam(
            list(self.critic1.parameters()) + list(self.critic2.parameters()),
            lr=3e-4)
        return opt_critic, opt_actor


if __name__ == '__main__':
    model = Model(
        env='LunarLanderContinuous-v2',
        eval_env='LunarLanderContinuous-v2',
        warmup_length=1000)

    trainer = pl.Trainer(max_epochs=20, gradient_clip_val=0.5)
    trainer.fit(model)

    model.evaluate(num_eval_episodes=10, render=True)

Results

Atari Games

Coming soon

How to replicate the results?

Coming soon

Parameters

class lightning_baselines3.off_policy_models.sac.SAC(env, eval_env, batch_size=256, buffer_length=1000000, warmup_length=100, train_freq=1, episodes_per_rollout=-1, num_rollouts=1000, gradient_steps=1, target_update_interval=1, num_eval_episodes=10, gamma=0.99, entropy_coef='auto', target_entropy='auto', use_sde=False, sde_sample_freq=-1, use_sde_at_warmup=False, squashed_actions=True, verbose=0, seed=None)

Soft Actor-Critic (SAC): Off-Policy Maximum Entropy Deep Reinforcement Learning with a Stochastic Actor. This implementation borrows code from the original implementation (https://github.com/haarnoja/sac), from OpenAI Spinning Up (https://github.com/openai/spinningup), from the softlearning repo (https://github.com/rail-berkeley/softlearning/) and from Stable Baselines (https://github.com/hill-a/stable-baselines).

Paper: https://arxiv.org/abs/1801.01290

Introduction to SAC: https://spinningup.openai.com/en/latest/algorithms/sac.html

Note: we use a double Q target rather than a value target, as discussed in https://github.com/hill-a/stable-baselines/issues/270
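A minimal NumPy sketch of a clipped double Q soft target, for illustration. Instead of a separate value network V(s'), the target takes the smaller of the two target critics at a freshly sampled next action and adds the entropy bonus. It follows the formulation used by Stable Baselines3 and Spinning Up; that this implementation computes exactly the same quantity is an assumption, and `soft_double_q_target` is a hypothetical helper, not part of lightning_baselines3.

```python
import numpy as np


def soft_double_q_target(rewards, dones, next_q1, next_q2, next_log_prob,
                         gamma=0.99, alpha=0.2):
    """Soft TD target built from two target critics (clipped double Q).

    next_q1 / next_q2: target critic values Q'_i(s', a') for actions a' sampled
    from the current policy at the next observations s'.
    next_log_prob: log pi(a' | s') for those same actions.
    """
    # Pessimistic estimate of the two critics, plus the entropy bonus -alpha * log pi
    soft_next_q = np.minimum(next_q1, next_q2) - alpha * next_log_prob
    # Terminal transitions (done = 1) do not bootstrap from the next state
    return rewards + gamma * (1.0 - dones) * soft_next_q


target = soft_double_q_target(
    rewards=np.array([1.0, 0.0]),
    dones=np.array([0.0, 1.0]),
    next_q1=np.array([10.0, 5.0]),
    next_q2=np.array([8.0, 6.0]),
    next_log_prob=np.array([-1.0, -2.0]))
# First entry: 1 + 0.99 * (min(10, 8) + 0.2) = 9.118; the second transition is terminal, so 0.0
print(target)
```

Taking the minimum counteracts the overestimation bias of a single learned Q-function, which is the trick SAC borrows from TD3.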

Parameters
  • env (Union[Env, VecEnv, str]) – The environment to learn from (can be a str if registered in Gym; can be None when loading trained models)

  • eval_env (Union[Env, VecEnv, str]) – The environment to evaluate on; must not be vectorised/parallelised (can be a str if registered in Gym; can be None when loading trained models)

  • batch_size (int) – Minibatch size for each gradient update

  • buffer_length (int) – Length (capacity) of the replay buffer

  • warmup_length (int) – Number of environment steps for which transitions are collected before learning starts

  • train_freq (int) – Update the model every train_freq steps. Set to -1 to disable.

  • episodes_per_rollout (int) – Update the model every episodes_per_rollout episodes. Note that this cannot be used at the same time as train_freq. Set to -1 to disable.

  • num_rollouts (int) – Number of rollouts to do per PyTorch Lightning epoch. This does not affect the training dynamics; it only controls how often the model is evaluated, since evaluation happens at the end of each Lightning epoch

  • gradient_steps (int) – How many gradient steps to do after each rollout

  • target_update_interval (int) – How many environment steps to wait between updating the target Q network

  • num_eval_episodes (int) – The number of episodes to evaluate for at the end of a PyTorch Lightning epoch

  • gamma (float) – The discount factor

  • entropy_coef (Union[str, float]) – Entropy regularization coefficient (equivalent to the inverse of the reward scale in the original SAC paper). Controls the exploration/exploitation trade-off. Set it to ‘auto’ to learn it automatically (and ‘auto_0.1’ to use 0.1 as the initial value)

  • target_entropy (Union[str, float]) – Target entropy when learning the entropy coefficient (entropy_coef = 'auto')

  • use_sde (bool) – Whether to use generalized State Dependent Exploration (gSDE) instead of action noise exploration (default: False)

  • sde_sample_freq (int) – Sample a new noise matrix every n steps when using gSDE. Default: -1 (only sample at the beginning of the rollout)

  • use_sde_at_warmup (bool) – Whether to use gSDE instead of uniform sampling during the warm up phase (before learning starts)

  • squashed_actions (bool) – Whether the actions are squashed between [-1, 1] and need to be unsquashed

  • verbose (int) – The verbosity level: 0 none, 1 training information, 2 debug (default: 0)

  • seed (Optional[int]) – Seed for the pseudo-random generators
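The entropy_coef and target_entropy options can be interpreted as in the sketch below. It mirrors the Stable Baselines3 conventions ('auto' starts from 1.0, 'auto_<x>' starts from x, and an 'auto' target entropy is -dim(action space)); whether lightning_baselines3 uses exactly these defaults is an assumption. The function names are hypothetical, and the coefficient update is written as a plain gradient step on log α rather than an optimizer step.

```python
import numpy as np


def parse_entropy_coef(entropy_coef):
    """Return (initial_alpha, learn_alpha) for a float, 'auto' or 'auto_<init>'."""
    if isinstance(entropy_coef, str) and entropy_coef.startswith("auto"):
        # 'auto' starts from 1.0; 'auto_0.1' starts from 0.1 (assumed SB3-style default)
        init = float(entropy_coef.split("_")[1]) if "_" in entropy_coef else 1.0
        return init, True
    # A fixed coefficient is used as-is and never updated
    return float(entropy_coef), False


def resolve_target_entropy(target_entropy, action_shape):
    """'auto' maps to -dim(action space), a common heuristic; floats pass through."""
    if target_entropy == "auto":
        return -float(np.prod(action_shape))
    return float(target_entropy)


def alpha_step(log_alpha, log_probs, target_entropy, lr=3e-4):
    """One gradient-descent step on L = -log_alpha * mean(log_pi + target_entropy).

    If the policy entropy (-mean(log_pi)) is below the target, the gradient is
    negative and log_alpha grows, which strengthens the entropy bonus.
    """
    grad = -np.mean(log_probs + target_entropy)
    return log_alpha - lr * grad


init_alpha, learn_alpha = parse_entropy_coef("auto_0.1")
target = resolve_target_entropy("auto", action_shape=(2,))
log_alpha = np.log(init_alpha)
# log pi values averaging 2.5 mean an entropy estimate of -2.5, below the target of -2.0
new_log_alpha = alpha_step(log_alpha, np.array([3.0, 2.0]), target)
print(init_alpha, learn_alpha, target, np.exp(new_log_alpha) > init_alpha)
```

Learning log α instead of α keeps the coefficient positive without any explicit constraint.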

forward_actor(obs)

Runs the actor network. Override this function with your own.

Parameters

obs (Tensor) – The input observations

Return type

Tensor

Returns

The deterministic action of the actor
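The example model's forward_actor splits the network output into a mean and a log standard deviation, clamps the latter to [-5, 5] and exponentiates it. The NumPy sketch below shows that step together with the tanh squashing and log-probability correction described in the SAC paper's section on enforcing action bounds. It is the textbook parameterisation, not a copy of SquashedMultivariateNormal, whose internals (for instance where tanh is applied) may differ; `gaussian_head` and `sample_squashed` are hypothetical names.

```python
import numpy as np


def gaussian_head(actor_out, log_std_min=-5.0, log_std_max=5.0):
    """Split the raw actor output [mean | log_std] and map log_std to a positive std."""
    mean, log_std = np.split(actor_out, 2, axis=-1)
    # Clamping keeps exp() away from 0 and infinity, like torch.clamp(..., -5, 5) in the example
    std = np.exp(np.clip(log_std, log_std_min, log_std_max))
    return mean, std


def sample_squashed(mean, std, rng, eps_stability=1e-6):
    """Reparameterised sample u = mean + std * noise, squashed to (-1, 1) by tanh.

    Returns the action and its log-probability, including the change-of-variables
    correction -sum log(1 - tanh(u)^2) for the tanh squashing.
    """
    noise = rng.standard_normal(mean.shape)
    u = mean + std * noise
    action = np.tanh(u)
    # Log-density of u under the diagonal Gaussian N(mean, diag(std^2))
    log_prob_u = -0.5 * (noise ** 2 + 2.0 * np.log(std) + np.log(2.0 * np.pi))
    log_prob = np.sum(log_prob_u - np.log(1.0 - action ** 2 + eps_stability), axis=-1)
    return action, log_prob


rng = np.random.default_rng(0)
mean, std = gaussian_head(np.array([[0.0, 0.5, -10.0, 0.0]]))
action, log_prob = sample_squashed(mean, std, rng)
deterministic_action = np.tanh(mean)  # common choice of deterministic action at evaluation time
print(action, log_prob, deterministic_action)
```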

forward_critic_targets(obs, action)

Runs all the target critic networks. Override this function with your own.

Parameters
  • obs (Tensor) – The input observations

  • action (Tensor) – The input actions

Return type

Tuple[Tensor, …]

Returns

The output Q values of the critic networks in the form of a list

forward_critics(obs, action)

Runs all the critic networks. Override this function with your own.

Parameters
  • obs (Tensor) – The input observations

  • action (Tensor) – The input actions

Return type

Tuple[Tensor, …]

Returns

The output Q values of the critic networks in the form of a list

reset()

Resets the environment and the automatic entropy tuning.

training_step(batch, batch_idx, optimizer_idx)

Specifies the update step for SAC. Override this if you wish to modify the SAC algorithm.
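In Stable Baselines3-style SAC, an update step minimises two losses: the critics regress onto the soft target, and the actor trades off the clipped Q value against its log-probability. The NumPy sketch below only evaluates these losses (no gradients) to make the formulas concrete; that this training_step uses exactly these expressions is an assumption, and the function names are hypothetical. Since Lightning passes optimizer_idx in the order returned by configure_optimizers, in the example model index 0 is the critic optimizer and index 1 the actor optimizer.

```python
import numpy as np


def critic_loss(q1, q2, target):
    """Half the sum of both critics' mean squared errors against the same soft target."""
    return 0.5 * (np.mean((q1 - target) ** 2) + np.mean((q2 - target) ** 2))


def actor_loss(log_prob, q1_pi, q2_pi, alpha):
    """alpha * log pi(a|s) - min_i Q_i(s, a), averaged over the batch.

    a is a fresh reparameterised sample from the current policy at s, so in an
    autodiff framework the gradient flows through both log_prob and the Q values.
    """
    return np.mean(alpha * log_prob - np.minimum(q1_pi, q2_pi))


# Critic MSEs are 0.5 and 1.0, so the critic loss is 0.5 * 1.5 = 0.75
c_loss = critic_loss(np.array([1.0, 2.0]), np.array([0.0, 2.0]), np.array([1.0, 1.0]))
# Per-sample terms are 0.5 * -1 - 2 = -2.5 and 0.5 * -2 - 1 = -2.0, averaging -2.25
a_loss = actor_loss(np.array([-1.0, -2.0]), np.array([3.0, 1.0]), np.array([2.0, 4.0]), alpha=0.5)
print(c_loss, a_loss)
```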

update_targets()

Function to update the target critic networks periodically. Override this function with your own.

Return type

None
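In the example model, update_targets calls polyak_update(..., tau=0.005), so each call moves every target parameter a fraction tau of the way toward its online counterpart; how often it runs is presumably governed by target_update_interval. A NumPy sketch of that soft update follows; the name `soft_update` is hypothetical, chosen to avoid confusion with the library's polyak_update.

```python
import numpy as np


def soft_update(params, target_params, tau=0.005):
    """In-place Polyak averaging: target <- (1 - tau) * target + tau * param."""
    for param, target in zip(params, target_params):
        target *= 1.0 - tau
        target += tau * param


online = [np.ones(3), np.full((2, 2), 4.0)]
targets = [np.zeros(3), np.zeros((2, 2))]
for _ in range(1000):
    soft_update(online, targets, tau=0.005)
# Toward a fixed source, the remaining gap shrinks by a factor (1 - tau) per update
print(targets[0], 1.0 - 0.995 ** 1000)
```

A small tau makes the target critics lag behind the online critics, which stabilises the bootstrapped target at the cost of slower propagation of new value estimates.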