Base RL Class

Common interface for all the RL algorithms

Abstract base classes for RL algorithms.

class lightning_baselines3.common.base_model.BaseModel(env, eval_env, num_eval_episodes=10, verbose=0, support_multi_env=False, seed=None, use_sde=False)

The base class for RL algorithms

Parameters
  • env (Union[Env, VecEnv, str]) – The environment to learn from. If registered in Gym, can be given as a str. Can be None when loading trained models

  • eval_env (Union[Env, VecEnv, str]) – The environment to evaluate on. Must not be vectorised/parallelised. If registered in Gym, can be given as a str. Can be None when loading trained models

  • num_eval_episodes (int) – The number of episodes to evaluate for at the end of a PyTorch Lightning epoch

  • verbose (int) – The verbosity level: 0 none, 1 training information, 2 debug

  • support_multi_env (bool) – Whether the algorithm supports training with multiple environments in parallel

  • seed (Optional[int]) – Seed for the pseudo random generators

  • use_sde (bool) – Whether to use generalized State Dependent Exploration (gSDE)

evaluate(num_eval_episodes, deterministic=True, render=False, record=False, record_fn=None)

Evaluate the model with eval_env

Parameters
  • num_eval_episodes (int) – Number of episodes to evaluate for

  • deterministic (bool) – Whether to evaluate deterministically

  • render (bool) – Whether to render while evaluating

  • record (bool) – Whether to record while evaluating

  • record_fn (Optional[str]) – File to record the environment to when record is True

Return type

Tuple[List[float], List[int]]

Returns

A list of total episode rewards and a list of episode lengths
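The evaluation loop below is a minimal sketch and not the library's implementation. It assumes the classic Gym step API that returns (obs, reward, done, info). ToyEnv and the predict callable are hypothetical stand-ins for eval_env and the model's predict method.

```python
from typing import Callable, List, Tuple


class ToyEnv:
    """Hypothetical stand-in for a Gym env: every episode lasts `horizon` steps."""

    def __init__(self, horizon: int = 5):
        self.horizon = horizon
        self.t = 0

    def reset(self):
        self.t = 0
        return 0.0  # dummy observation

    def step(self, action):
        self.t += 1
        reward = 1.0 if action == 1 else 0.0
        done = self.t >= self.horizon
        return float(self.t), reward, done, {}


def evaluate(env, predict: Callable, num_eval_episodes: int,
             deterministic: bool = True) -> Tuple[List[float], List[int]]:
    """Run `num_eval_episodes` episodes; return (total rewards, episode lengths)."""
    episode_rewards, episode_lengths = [], []
    for _ in range(num_eval_episodes):
        obs = env.reset()
        done, total_reward, length = False, 0.0, 0
        while not done:
            action = predict(obs, deterministic=deterministic)
            obs, reward, done, _ = env.step(action)
            total_reward += reward
            length += 1
        episode_rewards.append(total_reward)
        episode_lengths.append(length)
    return episode_rewards, episode_lengths
```

For example, `evaluate(ToyEnv(5), lambda obs, deterministic: 1, 3)` returns `([5.0, 5.0, 5.0], [5, 5, 5])`.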

predict(obs, deterministic=False)

Override this function with the predict function of your own model

Parameters
  • obs (Union[Tuple, Dict[str, Any], ndarray, int]) – The input observations

  • deterministic (bool) – Whether to predict deterministically

Return type

ndarray

Returns

The chosen actions

reset()

Reset the environment

Return type

None

sample_action(obs, deterministic=False)

Samples an action from the environment or from our model

Parameters
  • obs (ndarray) – The input observation

  • deterministic (bool) – Whether we are sampling deterministically.

Return type

Tuple[ndarray, ndarray]

Returns

The action to step with, and the action to store in our buffer

save_hyperparameters(frame=None, exclude=['env', 'eval_env'])

Utility function to save the hyperparameters of the model. This function behaves identically to LightningModule.save_hyperparameters, but by default it excludes the Gym environments. See https://pytorch-lightning.readthedocs.io/en/latest/hyperparameters.html#lightningmodule-hyperparameters for more details.
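The standard inspect module can illustrate how the excluded arguments are skipped. This is a hypothetical sketch, not Lightning's implementation, and collect_init_args and MyModel are made-up names. The helper reads the arguments of the calling __init__ from its frame and drops the names listed in exclude.

```python
import inspect
from types import FrameType
from typing import Any, Dict, Sequence


def collect_init_args(frame: FrameType,
                      exclude: Sequence[str] = ("env", "eval_env")) -> Dict[str, Any]:
    """Return the arguments of the function running in `frame`, minus `exclude`."""
    arg_info = inspect.getargvalues(frame)
    return {
        name: arg_info.locals[name]
        for name in arg_info.args
        if name != "self" and name not in exclude
    }


class MyModel:
    def __init__(self, env, eval_env, gamma=0.99, seed=None):
        # Capture this __init__'s arguments, skipping the Gym environments.
        self.hparams = collect_init_args(inspect.currentframe())
```

In this sketch, `MyModel(env="CartPole-v1", eval_env="CartPole-v1", gamma=0.9).hparams` is `{"gamma": 0.9, "seed": None}`. The environments are left out because they are large, unpicklable objects rather than configuration values.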

set_random_seed(seed)

Set the seed of the pseudo-random generators (Python, NumPy, PyTorch, Gym)

Parameters

seed (int) – The random seed to set

Return type

None
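Below is a minimal sketch of this kind of seeding, under the assumption that only the global generators need to be seeded. Seeding the Gym environment itself is omitted because the call differs between Gym versions: older releases used env.seed(seed) and newer ones use env.reset(seed=seed).

```python
import random

import numpy as np


def set_random_seed(seed: int) -> None:
    """Seed Python's and NumPy's global generators, and PyTorch's if installed."""
    random.seed(seed)
    np.random.seed(seed)
    try:
        import torch

        torch.manual_seed(seed)
    except ImportError:
        pass  # PyTorch is not available in this environment
```

Calling `set_random_seed` twice with the same value makes the following draws from `random` and `np.random` repeat exactly.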

training_epoch_end(outputs)

Run the evaluation function at the end of the training epoch. Override this if you also wish to do other things at the end of a training epoch.

Return type

None

Base Off-Policy Class

The base RL class for Off-Policy algorithms (ex: SAC/TD3)

class lightning_baselines3.off_policy_models.off_policy_model.OffPolicyModel(env, eval_env, batch_size=256, buffer_length=1000000, warmup_length=100, train_freq=-1, episodes_per_rollout=-1, num_rollouts=1, gradient_steps=1, num_eval_episodes=10, gamma=0.99, squashed_actions=False, use_sde=False, sde_sample_freq=-1, use_sde_at_warmup=False, verbose=0, seed=None)

The base for Off-Policy algorithms (ex: SAC/TD3)

Parameters
  • env (Union[Env, VecEnv, str]) – The environment to learn from. If registered in Gym, can be given as a str. Can be None when loading trained models

  • eval_env (Union[Env, VecEnv, str]) – The environment to evaluate on. Must not be vectorised/parallelised. If registered in Gym, can be given as a str. Can be None when loading trained models

  • batch_size (int) – Minibatch size for each gradient update

  • buffer_length (int) – Length (capacity) of the replay buffer

  • warmup_length (int) – Number of environment steps to collect transitions for before learning starts

  • train_freq (int) – Update the model every train_freq steps. Set to -1 to disable.

  • episodes_per_rollout (int) – Update the model every episodes_per_rollout episodes. Note that this cannot be used at the same time as train_freq. Set to -1 to disable.

  • num_rollouts (int) – Number of rollouts to do per PyTorch Lightning epoch. This does not affect the training dynamics, only how often the model is evaluated, since evaluation happens at the end of each Lightning epoch

  • gradient_steps (int) – How many gradient steps to do after each rollout

  • num_eval_episodes (int) – The number of episodes to evaluate for at the end of a PyTorch Lightning epoch

  • gamma (float) – The discount factor

  • squashed_actions (bool) – Whether the actions are squashed between [-1, 1] and need to be unsquashed

  • use_sde (bool) – Whether to use generalized State Dependent Exploration (gSDE)

  • sde_sample_freq (int) – Sample a new noise matrix every n steps when using gSDE. Default: -1 (only sample at the beginning of the rollout)

  • use_sde_at_warmup (bool) – Whether to use gSDE instead of uniform sampling during the warm up phase (before learning starts)

  • verbose (int) – The verbosity level: 0 none, 1 training information, 2 debug (default: 0)

  • seed (Optional[int]) – Seed for the pseudo random generators
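The sketch below shows one way train_freq and episodes_per_rollout could decide when a rollout ends. The helper names are hypothetical. The rule that exactly one of the two must be enabled is an assumption drawn from the parameter descriptions above: they cannot both be set, and if both were disabled a rollout would never end.

```python
def validate_rollout_args(train_freq: int, episodes_per_rollout: int) -> None:
    """Assumed rule: exactly one of train_freq / episodes_per_rollout is enabled (> 0)."""
    if (train_freq > 0) == (episodes_per_rollout > 0):
        raise ValueError(
            "Set exactly one of train_freq or episodes_per_rollout; use -1 to disable the other."
        )


def rollout_finished(steps: int, episodes: int,
                     train_freq: int, episodes_per_rollout: int) -> bool:
    """Return True once the current rollout has collected enough experience."""
    if train_freq > 0:
        return steps >= train_freq  # step-based rollouts
    return episodes >= episodes_per_rollout  # episode-based rollouts
```

After each finished rollout, gradient_steps updates would follow. A Lightning epoch would consist of num_rollouts such rollout/update cycles before evaluation.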

collect_rollouts()

Collect rollouts and put them into the ReplayBuffer
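A replay buffer with a fixed buffer_length is commonly implemented as a ring buffer that overwrites the oldest transitions once it is full. The class below is a hypothetical, minimal sketch of that idea with uniform sampling. It is not the ReplayBuffer class the library uses.

```python
import random
from typing import Any, List, Tuple


class ReplayBuffer:
    """Fixed-size FIFO buffer: once full, new transitions overwrite the oldest."""

    def __init__(self, buffer_length: int):
        self.buffer_length = buffer_length
        self.storage: List[Tuple[Any, ...]] = []
        self.pos = 0  # index the next transition will be written to

    def add(self, obs, action, reward, next_obs, done) -> None:
        transition = (obs, action, reward, next_obs, done)
        if len(self.storage) < self.buffer_length:
            self.storage.append(transition)
        else:
            self.storage[self.pos] = transition  # overwrite the oldest entry
        self.pos = (self.pos + 1) % self.buffer_length

    def sample(self, batch_size: int) -> List[Tuple[Any, ...]]:
        """Uniformly sample `batch_size` transitions, with replacement."""
        return [random.choice(self.storage) for _ in range(batch_size)]

    def __len__(self) -> int:
        return len(self.storage)
```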

on_step()

Simple callback for each step we take in the environment

reset()

Reset the environment and set the num_timesteps to 0

sample_action(obs, deterministic=False)

Samples an action from the environment or from our model

Parameters
  • obs (ndarray) – The input observation

  • deterministic (bool) – Whether we are sampling deterministically.

Return type

Tuple[ndarray, ndarray]

Returns

The action to step with, and the action to store in our buffer
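During the warm-up phase, off-policy algorithms typically sample actions uniformly from the action space instead of querying the model. The sketch below shows that decision for a Box action space. Its signature is hypothetical, and the policy callable stands in for the model. It returns a single raw action, which would still need to go through scale_actions to produce the (environment action, buffer action) pair described above.

```python
import numpy as np


def sample_action(obs, num_timesteps, warmup_length, low, high, policy, rng,
                  use_sde_at_warmup=False, deterministic=False):
    """Uniform random action during warm-up, otherwise the policy's action."""
    if num_timesteps < warmup_length and not use_sde_at_warmup:
        # Warm-up: explore uniformly over the Box bounds.
        return rng.uniform(low, high)
    # After warm-up (or with use_sde_at_warmup): ask the model; with gSDE the
    # policy itself is assumed to apply the exploration noise.
    return policy(obs, deterministic)
```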

scale_actions(actions, squashed=False)

Scale the actions appropriately for spaces.Box action spaces, based on whether they are squashed between [-1, 1]

Parameters
  • actions (ndarray) – The input actions

  • squashed (bool) – Whether the actions are squashed between [-1, 1]

Return type

Tuple[ndarray, ndarray]

Returns

The action to step the environment with and the action to buffer with
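The sketch below shows one plausible convention for a Box action space with bounds low and high. Squashed actions in [-1, 1] are stored in the buffer unchanged and rescaled linearly to [low, high] for the environment. Unsquashed actions are clipped to the bounds. This is an assumption about the behaviour, not a copy of the library code.

```python
from typing import Tuple

import numpy as np


def scale_actions(actions: np.ndarray, low: np.ndarray, high: np.ndarray,
                  squashed: bool = False) -> Tuple[np.ndarray, np.ndarray]:
    """Return (action for env.step, action for the replay buffer) for a Box space."""
    if squashed:
        # Policy output already lies in [-1, 1]: keep it for the buffer and
        # map it linearly onto [low, high] for the environment.
        buffer_actions = actions
        env_actions = low + 0.5 * (actions + 1.0) * (high - low)
    else:
        # Unsquashed output: clip into the valid range before stepping.
        env_actions = np.clip(actions, low, high)
        buffer_actions = env_actions
    return env_actions, buffer_actions
```

Storing the squashed value keeps the buffer in the same [-1, 1] range the policy outputs, regardless of the environment's bounds.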

train_dataloader()

Create the dataloader for our OffPolicyModel

training_epoch_end(outputs)

Run the evaluation function at the end of the training epoch. Override this if you also wish to do other things at the end of a training epoch.

Return type

None

Base On-Policy Class

The base RL class for On-Policy algorithms (ex: A2C/PPO)

class lightning_baselines3.on_policy_models.on_policy_model.OnPolicyModel(env, eval_env, buffer_length, num_rollouts, batch_size, epochs_per_rollout, num_eval_episodes=10, gamma=0.99, gae_lambda=0.95, use_sde=False, sde_sample_freq=-1, verbose=0, seed=None)

The base for On-Policy algorithms (ex: A2C/PPO).

Parameters
  • env (Union[Env, VecEnv, str]) – The environment to learn from. If registered in Gym, can be given as a str

  • eval_env (Union[Env, VecEnv, str]) – The environment to evaluate on. Must not be vectorised/parallelised. If registered in Gym, can be given as a str. Can be None when loading trained models

  • buffer_length (int) – Length of the buffer and the number of steps to run for each environment per update

  • num_rollouts (int) – Number of rollouts to do per PyTorch Lightning epoch. This does not affect the training dynamics, only how often the model is evaluated, since evaluation happens at the end of each Lightning epoch

  • batch_size (int) – Minibatch size for each gradient update

  • epochs_per_rollout (int) – Number of epochs to optimise the loss for on each rollout

  • num_eval_episodes (int) – The number of episodes to evaluate for at the end of a PyTorch Lightning epoch

  • gamma (float) – Discount factor

  • gae_lambda (float) – Factor for the trade-off of bias vs variance in the Generalized Advantage Estimator. Equivalent to the classic advantage when set to 1.

  • use_sde (bool) – Whether to use generalized State Dependent Exploration (gSDE) instead of action noise exploration

  • sde_sample_freq (int) – Sample a new noise matrix every n steps when using gSDE. Default: -1 (only sample at the beginning of the rollout)

  • verbose (int) – The verbosity level: 0 none, 1 training information, 2 debug

  • seed (Optional[int]) – Seed for the pseudo random generators
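A standard Generalized Advantage Estimation recursion makes the roles of gamma and gae_lambda concrete. Here d_t is 1 if the episode ended at step t:
delta_t = r_t + gamma * V(s_{t+1}) * (1 - d_t) - V(s_t),  A_t = delta_t + gamma * lambda * (1 - d_t) * A_{t+1}.
The function below is a minimal single-environment sketch with a hypothetical signature. With gae_lambda=1 the advantages reduce to the discounted return minus the value estimate, which is the classic advantage mentioned above.

```python
import numpy as np


def compute_gae(rewards, values, dones, last_value, gamma=0.99, gae_lambda=0.95):
    """GAE over one rollout of a single environment.

    dones[t] is 1.0 if the episode ended at step t, so we do not bootstrap past it.
    Returns (advantages, returns) with returns = advantages + values.
    """
    n = len(rewards)
    advantages = np.zeros(n)
    last_gae = 0.0
    for t in reversed(range(n)):
        next_value = last_value if t == n - 1 else values[t + 1]
        next_non_terminal = 1.0 - dones[t]
        # One-step TD error
        delta = rewards[t] + gamma * next_value * next_non_terminal - values[t]
        # Exponentially weighted sum of future TD errors, cut at episode ends
        last_gae = delta + gamma * gae_lambda * next_non_terminal * last_gae
        advantages[t] = last_gae
    return advantages, advantages + np.asarray(values)
```

With gae_lambda=0 the advantage is just the one-step TD error, which has lower variance but more bias. Values in between trade one off against the other.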

collect_rollouts()

Collect rollouts and put them into the RolloutBuffer

Return type

RolloutBufferSamples

forward(obs)

Override this function with the forward function of your model

Parameters

obs (Union[Tuple, Dict[str, Any], ndarray, int]) – The input observations

Return type

Tuple[Distribution, Tensor]

Returns

The distribution over actions and the value estimate for the given observations

train_dataloader()

Create the dataloader for our OnPolicyModel
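Conceptually, the on-policy dataloader goes over the freshly collected rollout epochs_per_rollout times, in shuffled minibatches of batch_size. The generator below is a hypothetical sketch of that index scheduling only. Here num_samples stands for the number of stored steps. With several parallel environments this would presumably be buffer_length times the number of environments, which is an assumption.

```python
from typing import Iterator

import numpy as np


def minibatch_indices(num_samples: int, batch_size: int, epochs_per_rollout: int,
                      rng: np.random.Generator) -> Iterator[np.ndarray]:
    """Yield shuffled minibatch index arrays over one rollout buffer.

    Every stored step is visited once per epoch; the last minibatch of an
    epoch may be smaller than batch_size.
    """
    for _ in range(epochs_per_rollout):
        perm = rng.permutation(num_samples)
        for start in range(0, num_samples, batch_size):
            yield perm[start:start + batch_size]
```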