DQN¶
Deep Q Network (DQN) builds on Fitted Q-Iteration (FQI) and uses several tricks to stabilize learning with neural networks: a replay buffer, a target network, and gradient clipping.
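To make the role of the target network concrete, here is a minimal NumPy sketch of the one-step regression target that DQN-style methods typically fit. It is not taken from lightning_baselines3: the function name `td_targets` and its arguments are hypothetical, and the Q values are made-up numbers standing in for the output of a target network.

```python
import numpy as np


def td_targets(rewards, dones, next_q_target, gamma=0.99):
    """Hypothetical helper: r + gamma * max_a' Q_target(s', a'), without bootstrapping at episode ends."""
    rewards = np.asarray(rewards, dtype=np.float64)
    dones = np.asarray(dones, dtype=np.float64)
    next_q_target = np.asarray(next_q_target, dtype=np.float64)
    # Greedy value of the next state according to the (frozen) target network
    max_next_q = next_q_target.max(axis=1)
    # Terminal transitions (done == 1) keep only the immediate reward
    return rewards + gamma * (1.0 - dones) * max_next_q


# Two made-up transitions: the second one ends its episode
rewards = [1.0, 1.0]
dones = [0.0, 1.0]
next_q = [[0.5, 2.0], [3.0, 1.0]]
targets = td_targets(rewards, dones, next_q, gamma=0.9)
print(targets)
```

Because the targets come from a separate network that changes only occasionally, the regression problem stays fixed between target updates, which is the stabilizing effect described above.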
Notes¶
Original paper: https://arxiv.org/abs/1312.5602
Further reference: https://www.nature.com/articles/nature14236
Note
This implementation provides only vanilla Deep Q-Learning and has no extensions such as Double-DQN, Dueling-DQN, or Prioritized Experience Replay. It also does not include any specific policy, so you can implement whichever one you like. The example below uses epsilon-greedy with a linearly decaying epsilon.
Can I use?¶
Recurrent policies: ❌
Multi processing: ❌
Gym spaces:
| Space | Action | Observation |
|---|---|---|
| Discrete | ✔ | ✔ |
| Box | ❌ | ✔ |
| MultiDiscrete | ❌ | ✔ |
| MultiBinary | ❌ | ✔ |
Example¶
import copy

import numpy as np
import torch
from torch import nn
import pytorch_lightning as pl

from lightning_baselines3.off_policy_models import DQN


class Model(DQN):
    def __init__(self, **kwargs):
        # **kwargs will pass our arguments on to DQN
        super(Model, self).__init__(**kwargs)

        # Q network: maps an observation to one Q value per discrete action
        self.qnet = nn.Sequential(
            nn.Linear(self.observation_space.shape[0], 64),
            nn.Tanh(),
            nn.Linear(64, 64),
            nn.Tanh(),
            nn.Linear(64, self.action_space.n))
        # The target network starts as an exact copy of the Q network
        self.qnet_target = copy.deepcopy(self.qnet)

        # Epsilon-greedy schedule: decay linearly from eps_init to eps_final
        # over the first eps_decay environment steps
        self.eps = 1.0
        self.eps_init = 1.0
        self.eps_decay = 5000
        self.eps_final = 0.05

        self.save_hyperparameters()

    # This is for running the model, returns the Q values given our observation
    def forward(self, x):
        return self.qnet(x)

    # This is for running the target Q network
    def forward_target(self, x):
        return self.qnet_target(x)

    # This is for updating the target Q network
    def update_target(self):
        self.qnet_target.load_state_dict(self.qnet.state_dict())

    # Use the environment step callback to linearly decay our epsilon
    # per environment step for epsilon greedy
    def on_step(self):
        k = max(self.eps_decay - self.num_timesteps, 0) / self.eps_decay
        self.eps = self.eps_final + k * (self.eps_init - self.eps_final)

    # This is for inference and evaluation of our model, returns the action
    def predict(self, x, deterministic=True):
        out = self.qnet(x)
        if deterministic:
            # Greedy action: index of the largest Q value
            out = torch.max(out, dim=1)[1]
        else:
            # With probability eps pick a uniformly random action,
            # otherwise pick the greedy one
            explore = (torch.rand_like(out[:, 0]) < self.eps).float()
            random_action = torch.rand_like(out).max(dim=1)[1]
            greedy_action = out.max(dim=1)[1]
            out = explore * random_action + (1 - explore) * greedy_action
        return out.long().cpu().numpy()

    def configure_optimizers(self):
        optimizer = torch.optim.Adam(self.parameters(), lr=3e-4)
        return optimizer


if __name__ == '__main__':
    model = Model(env='CartPole-v1', eval_env='CartPole-v1')

    # gradient_clip_val enables the gradient clipping used to stabilize training
    trainer = pl.Trainer(max_epochs=20, gradient_clip_val=0.5)
    trainer.fit(model)

    rewards, lengths = model.evaluate(num_eval_episodes=10, render=True)
    print(np.mean(rewards), np.mean(lengths))
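The linear decay computed in `on_step` can be checked in isolation. The sketch below copies that formula into a standalone function; the name `linear_epsilon` is hypothetical and not part of lightning_baselines3, and the default values are the ones used in the example above.

```python
def linear_epsilon(num_timesteps, eps_init=1.0, eps_final=0.05, eps_decay=5000):
    """Hypothetical helper mirroring the on_step formula from the example."""
    # k goes from 1 at step 0 down to 0 at step eps_decay, then stays at 0
    k = max(eps_decay - num_timesteps, 0) / eps_decay
    return eps_final + k * (eps_init - eps_final)


# Epsilon at the start, halfway through, at the end of the decay, and beyond it
for t in (0, 2500, 5000, 10000):
    print(t, linear_epsilon(t))
```

With these defaults epsilon starts near 1.0, is roughly 0.525 halfway through the decay, and stays at 0.05 for every step from 5000 onwards.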
Parameters¶
- class lightning_baselines3.off_policy_models.dqn.DQN(env, eval_env, batch_size=256, buffer_length=1000000, warmup_length=100, train_freq=4, episodes_per_rollout=-1, num_rollouts=1000, gradient_steps=1, target_update_interval=10000, num_eval_episodes=10, gamma=0.99, verbose=0, seed=None)
Deep Q-Network (DQN)
Papers: https://arxiv.org/abs/1312.5602, https://www.nature.com/articles/nature14236
Default hyperparameters are taken from the Nature paper, except for the optimizer and learning rate, which are taken from the Stable Baselines defaults.
- Parameters
  - env (Union[Env, VecEnv, str]) – The environment to learn from (if registered in Gym, can be str. Can be None for loading trained models)
  - eval_env (Union[Env, VecEnv, str]) – The environment to evaluate on, must not be vectorised/parallelised (if registered in Gym, can be str. Can be None for loading trained models)
  - batch_size (int) – Minibatch size for each gradient update
  - buffer_length (int) – Length of the replay buffer
  - warmup_length (int) – How many steps of the model to collect transitions for before learning starts
  - train_freq (int) – Update the model every train_freq steps. Set to -1 to disable.
  - episodes_per_rollout (int) – Update the model every episodes_per_rollout episodes. Note that this cannot be used at the same time as train_freq. Set to -1 to disable.
  - num_rollouts (int) – Number of rollouts to do per PyTorch Lightning epoch. This does not affect the training dynamics, only how often the model is evaluated, since evaluation happens at the end of each Lightning epoch.
  - gradient_steps (int) – How many gradient steps to do after each rollout
  - target_update_interval (int) – How many environment steps to wait between updates of the target Q network
  - num_eval_episodes (int) – The number of episodes to evaluate for at the end of a PyTorch Lightning epoch
  - gamma (float) – The discount factor
  - verbose (int) – The verbosity level: 0 none, 1 training information, 2 debug (default: 0)
  - seed (Optional[int]) – Seed for the pseudo-random generators
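As a rough illustration of what `buffer_length` and `batch_size` control, the following is a minimal replay buffer sketch in plain Python. It is an assumption-laden stand-in, not the buffer used inside lightning_baselines3: the `ReplayBuffer` class, its methods, and the transition tuple layout are all hypothetical.

```python
import random
from collections import deque


class ReplayBuffer:
    """Hypothetical fixed-size buffer; the oldest transitions are dropped first."""

    def __init__(self, buffer_length):
        # deque with maxlen discards the oldest item once the buffer is full
        self.storage = deque(maxlen=buffer_length)

    def add(self, transition):
        self.storage.append(transition)

    def sample(self, batch_size):
        # Uniform sampling without replacement from the stored transitions
        return random.sample(list(self.storage), batch_size)

    def __len__(self):
        return len(self.storage)


buffer = ReplayBuffer(buffer_length=5)
for step in range(8):
    # Made-up transition: (observation, action, reward, next_observation, done)
    buffer.add((step, 0, 1.0, step + 1, False))

batch = buffer.sample(batch_size=3)
print(len(buffer), len(batch))
```

After eight insertions into a buffer of length five, only the five most recent transitions remain, and each minibatch is drawn uniformly from those.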
- forward(x)
  Runs the Q network. Override this function with your own.
  - Parameters
    - x (Tensor) – The input observations
  - Return type
    Tensor
  - Returns
    The output Q values of the Q network
- forward_target(x)
  Runs the target Q network. Override this function with your own.
  - Parameters
    - x (Tensor) – The input observations
  - Return type
    Tensor
  - Returns
    The output Q values of the target Q network