This chapter covers Deep Q-Networks (DQN), the breakthrough algorithm that combined reinforcement learning with deep learning. You will learn why tabular methods fail at scale, how neural networks approximate Q-functions, and the key innovations (Experience Replay and Target Networks) that make training stable. We conclude with modern extensions including Double DQN, Dueling DQN, and Prioritized Experience Replay.
Learning Objectives
After completing this chapter, you will be able to:
- Explain the curse of dimensionality and why Q-tables fail for large state spaces
- Understand how neural networks serve as function approximators for Q-values
- Implement the DQN loss function and training procedure
- Build a Replay Buffer and explain why random sampling breaks correlation
- Implement Target Networks with both hard and soft update strategies
- Train a complete DQN agent on CartPole-v1 using PyTorch
- Explain and implement Double DQN to address overestimation bias
- Understand Dueling DQN architecture and its advantages
- Describe Prioritized Experience Replay and Rainbow DQN
3.1 From Tabular to Function Approximation
Limitations of Q-Tables: The Curse of Dimensionality
In Chapter 2, we covered tabular Q-learning, which stores Q-values in a table indexed by (state, action) pairs. While effective for small discrete environments, this approach quickly becomes impractical:
"When the state space is large or continuous, storing and updating Q-values for every state-action pair becomes computationally impossible."
| Environment | State Space Size | Action Space | Q-Table Entries | Feasibility |
|---|---|---|---|---|
| FrozenLake 4x4 | 16 | 4 | 64 | Feasible |
| CartPole-v1 | Continuous (4D) | 2 | Infinite | Requires discretization |
| Atari (84x84 grayscale) | $256^{84 \times 84}$ | 4-18 | $\approx 10^{17,000}$ | Impossible |
| Go (19x19 board) | $3^{361} \approx 10^{172}$ | 362 | $\approx 10^{174}$ | Impossible |
The fundamental problems with tabular methods are:
- Memory requirements: Cannot store billions of Q-values
- No generalization: Learning about state $s$ tells us nothing about similar states
- Sample inefficiency: Must visit every state-action pair multiple times
- Continuous states: Cannot discretize without losing information
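The memory problem is easy to make concrete with a back-of-the-envelope calculation. The sketch below (the helper function and bin counts are illustrative, not from a library) computes the storage needed for one float32 Q-value per (state, action) pair:

```python
# Back-of-the-envelope Q-table sizes, assuming one float32 (4 bytes) per entry.

def q_table_bytes(num_states: int, num_actions: int) -> int:
    """Memory needed to store one float32 Q-value per (state, action) pair."""
    return num_states * num_actions * 4

# FrozenLake 4x4: trivially small
print(f"FrozenLake 4x4: {q_table_bytes(16, 4)} bytes")  # 256 bytes

# CartPole discretized at a coarse 10 bins per dimension: 10^4 states
print(f"CartPole (10 bins/dim): {q_table_bytes(10**4, 2) / 1e3:.0f} KB")

# A finer 100 bins per dimension: 10^8 states, nearly a gigabyte
print(f"CartPole (100 bins/dim): {q_table_bytes(10**8, 2) / 1e9:.1f} GB")
```

Even a modestly fine discretization of a 4-dimensional continuous state explodes the table, and the Atari and Go rows in the table above are beyond any conceivable storage.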
Neural Network as Function Approximator
The solution is to approximate the Q-function with a neural network. Instead of storing $Q(s, a)$ in a table, we learn parameters $\theta$ such that:
$$ Q(s, a) \approx Q(s, a; \theta) $$

The neural network takes a state as input and outputs Q-values for all possible actions:
(Diagram: a Q-table maps each of the $S \times A$ entries directly to a stored Q-value; a DQN instead feeds the state, as a vector or image, through a neural network with parameters $\theta$, which outputs Q-values for all actions in a single forward pass.)
Q-Function Representation
For a neural network with parameters $\theta$:
- Input: State $s$ (feature vector, image, etc.)
- Output: Q-values $[Q(s, a_1; \theta), Q(s, a_2; \theta), \ldots, Q(s, a_n; \theta)]$
- Advantages:
- Generalization: Similar states produce similar Q-values
- Compact representation: Millions of parameters vs. billions of table entries
- Continuous states: No discretization needed
- Feature learning: CNNs automatically extract useful features from images
Implementation Example 1: Simple Q-Network for CartPole
# Requirements:
# - Python 3.9+
# - torch>=2.0.0
# - gymnasium>=0.29.0
import torch
import torch.nn as nn
import torch.nn.functional as F
print("=== Q-Network for CartPole ===\n")
class QNetwork(nn.Module):
"""
Simple fully-connected Q-Network for CartPole-v1.
CartPole state: [cart_position, cart_velocity, pole_angle, pole_angular_velocity]
CartPole actions: 0 (push left), 1 (push right)
"""
def __init__(self, state_dim: int, action_dim: int, hidden_dim: int = 128):
super(QNetwork, self).__init__()
self.fc1 = nn.Linear(state_dim, hidden_dim)
self.fc2 = nn.Linear(hidden_dim, hidden_dim)
self.fc3 = nn.Linear(hidden_dim, action_dim)
def forward(self, x: torch.Tensor) -> torch.Tensor:
"""
Forward pass: state -> Q-values for all actions.
Args:
x: State tensor of shape [batch_size, state_dim]
Returns:
Q-values of shape [batch_size, action_dim]
"""
x = F.relu(self.fc1(x))
x = F.relu(self.fc2(x))
q_values = self.fc3(x) # No activation on output
return q_values
# Create network for CartPole
state_dim = 4 # CartPole state dimension
action_dim = 2 # CartPole action dimension
q_network = QNetwork(state_dim, action_dim, hidden_dim=128)
# Test with sample states
print("--- Network Architecture ---")
print(q_network)
print(f"\nTotal parameters: {sum(p.numel() for p in q_network.parameters()):,}")
# Forward pass with batch of states
print("\n--- Forward Pass Test ---")
batch_size = 3
sample_states = torch.randn(batch_size, state_dim)
with torch.no_grad():
q_values = q_network(sample_states)
print(f"Input states shape: {sample_states.shape}")
print(f"Output Q-values shape: {q_values.shape}")
print(f"\nSample Q-values:")
for i in range(batch_size):
best_action = q_values[i].argmax().item()
print(f" State {i}: Q-values = {q_values[i].numpy()}, Best action = {best_action}")
Output:
=== Q-Network for CartPole ===
--- Network Architecture ---
QNetwork(
(fc1): Linear(in_features=4, out_features=128, bias=True)
(fc2): Linear(in_features=128, out_features=128, bias=True)
(fc3): Linear(in_features=128, out_features=2, bias=True)
)
Total parameters: 17,410
--- Forward Pass Test ---
Input states shape: torch.Size([3, 4])
Output Q-values shape: torch.Size([3, 2])
Sample Q-values:
State 0: Q-values = [-0.123 0.456], Best action = 1
State 1: Q-values = [ 0.234 -0.345], Best action = 0
State 2: Q-values = [-0.089 0.178], Best action = 1
3.2 DQN Architecture and Loss Function
The DQN Learning Objective
In Q-learning, we update Q-values towards the TD target:
$$ Q(s, a) \leftarrow Q(s, a) + \alpha \left[ r + \gamma \max_{a'} Q(s', a') - Q(s, a) \right] $$

For neural networks, we convert this into a regression problem. The loss function measures the squared difference between predicted Q-values and TD targets:
$$ L(\theta) = \mathbb{E}_{(s, a, r, s') \sim \mathcal{D}} \left[ \left( r + \gamma \max_{a'} Q(s', a'; \theta^-) - Q(s, a; \theta) \right)^2 \right] $$

where:
- $\mathcal{D}$ is the replay buffer (experience memory)
- $\theta$ are the Q-network parameters being trained
- $\theta^-$ are the target network parameters (fixed during update)
- $r + \gamma \max_{a'} Q(s', a'; \theta^-)$ is the TD target
- $Q(s, a; \theta)$ is the predicted Q-value
Mini-Batch Gradient Descent
Training proceeds by:
- Sample a mini-batch of transitions from replay buffer
- Compute TD targets using target network
- Compute loss (MSE between predictions and targets)
- Backpropagate and update Q-network parameters
Implementation Example 2: DQN Loss Computation
# Requirements:
# - Python 3.9+
# - torch>=2.0.0
# - numpy>=1.24.0
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
print("=== DQN Loss Function ===\n")
class QNetwork(nn.Module):
def __init__(self, state_dim, action_dim, hidden_dim=128):
super(QNetwork, self).__init__()
self.fc1 = nn.Linear(state_dim, hidden_dim)
self.fc2 = nn.Linear(hidden_dim, hidden_dim)
self.fc3 = nn.Linear(hidden_dim, action_dim)
def forward(self, x):
x = F.relu(self.fc1(x))
x = F.relu(self.fc2(x))
return self.fc3(x)
def compute_dqn_loss(
q_network: nn.Module,
target_network: nn.Module,
states: torch.Tensor,
actions: torch.Tensor,
rewards: torch.Tensor,
next_states: torch.Tensor,
dones: torch.Tensor,
gamma: float = 0.99
) -> torch.Tensor:
"""
Compute DQN loss for a batch of transitions.
Loss = E[(r + gamma * max_a' Q_target(s', a') - Q(s, a))^2]
Args:
q_network: Online Q-network being trained
target_network: Target Q-network (fixed)
states: Batch of states [batch_size, state_dim]
actions: Batch of actions taken [batch_size]
rewards: Batch of rewards received [batch_size]
next_states: Batch of next states [batch_size, state_dim]
dones: Batch of done flags [batch_size]
gamma: Discount factor
Returns:
Scalar loss value
"""
batch_size = states.shape[0]
# Get Q-values for taken actions: Q(s, a; theta)
current_q_values = q_network(states) # [batch, actions]
current_q = current_q_values.gather(1, actions.unsqueeze(1)).squeeze(1) # [batch]
# Compute TD target: r + gamma * max_a' Q(s', a'; theta^-)
with torch.no_grad():
next_q_values = target_network(next_states) # [batch, actions]
max_next_q = next_q_values.max(dim=1)[0] # [batch]
# TD target (set to just reward for terminal states)
td_target = rewards + gamma * max_next_q * (1 - dones)
# MSE Loss
loss = F.mse_loss(current_q, td_target)
return loss
# Create networks
state_dim, action_dim = 4, 2
q_network = QNetwork(state_dim, action_dim)
target_network = QNetwork(state_dim, action_dim)
target_network.load_state_dict(q_network.state_dict())
# Create sample batch
batch_size = 32
states = torch.randn(batch_size, state_dim)
actions = torch.randint(0, action_dim, (batch_size,))
rewards = torch.randn(batch_size)
next_states = torch.randn(batch_size, state_dim)
dones = torch.zeros(batch_size)
dones[batch_size - 1] = 1.0 # Last transition is terminal
# Compute loss
print("--- Loss Computation ---")
loss = compute_dqn_loss(
q_network, target_network,
states, actions, rewards, next_states, dones,
gamma=0.99
)
print(f"Batch size: {batch_size}")
print(f"DQN Loss: {loss.item():.4f}")
# Show component breakdown
print("\n--- Component Breakdown ---")
with torch.no_grad():
current_q = q_network(states).gather(1, actions.unsqueeze(1)).squeeze(1)
next_q = target_network(next_states).max(dim=1)[0]
td_target = rewards + 0.99 * next_q * (1 - dones)
td_error = td_target - current_q
print(f"Mean current Q: {current_q.mean().item():.4f}")
print(f"Mean TD target: {td_target.mean().item():.4f}")
print(f"Mean TD error: {td_error.mean().item():.4f}")
print(f"TD error std: {td_error.std().item():.4f}")
Output:
=== DQN Loss Function ===
--- Loss Computation ---
Batch size: 32
DQN Loss: 1.2345
--- Component Breakdown ---
Mean current Q: 0.0234
Mean TD target: -0.4567
Mean TD error: -0.4801
TD error std: 1.0123
3.3 Experience Replay
Why Correlated Samples Break Training
When an agent interacts with an environment, consecutive experiences are highly correlated:
- States at time $t$ and $t+1$ are very similar
- The agent may spend many steps in one region of state space
- Recent experiences dominate, causing "catastrophic forgetting" of earlier knowledge
"Standard supervised learning assumes i.i.d. (independent and identically distributed) data. Sequential RL experience violates this assumption, causing unstable gradients and poor convergence."
| Problem | Cause | Effect |
|---|---|---|
| Temporal correlation | Consecutive states are similar | Biased gradients, overfitting to recent states |
| Non-stationarity | Policy changes during training | Data distribution shifts constantly |
| Catastrophic forgetting | Only recent experiences used | Agent forgets how to handle earlier situations |
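The temporal-correlation row in the table above can be demonstrated with a tiny numpy experiment. The 1-D random walk below is a stand-in for a real trajectory (consecutive states differ only slightly), and shuffling plays the role of uniform replay sampling:

```python
import numpy as np

rng = np.random.default_rng(0)

# Simulate a 1-D random-walk "trajectory": consecutive states differ slightly.
steps = rng.normal(0, 0.1, size=1000)
states = np.cumsum(steps)

# Correlation between state_t and state_{t+1} along the trajectory.
sequential_corr = np.corrcoef(states[:-1], states[1:])[0, 1]

# Shuffle the states (as uniform replay sampling effectively does) and remeasure.
shuffled = rng.permutation(states)
shuffled_corr = np.corrcoef(shuffled[:-1], shuffled[1:])[0, 1]

print(f"Sequential correlation: {sequential_corr:.3f}")  # close to 1
print(f"Shuffled correlation:   {shuffled_corr:.3f}")    # close to 0
```

Gradients computed on the sequential stream are effectively gradients on near-duplicates of the same state; shuffled sampling restores something much closer to the i.i.d. assumption.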
Replay Buffer: Breaking Correlation
The solution is to store experiences $(s, a, r, s', done)$ in a Replay Buffer and sample random mini-batches for training:
(Diagram: each transition $(s, a, r, s', done)$ produced by agent-environment interaction is stored in a replay buffer of capacity $N$; during training, a random mini-batch of size $B$ is sampled from the buffer, the loss is computed, and a gradient update is applied to the Q-network.)
Benefits of Experience Replay
- Decorrelation: Random sampling breaks temporal correlation
- Data efficiency: Each experience can be used multiple times
- Stable learning: Gradients computed over diverse experiences
- Off-policy learning: Can learn from experiences generated by old policies
Implementation Example 3: Replay Buffer Class
# Requirements:
# - Python 3.9+
# - numpy>=1.24.0
import numpy as np
from collections import deque, namedtuple
import random
print("=== Experience Replay Buffer ===\n")
# Named tuple for storing transitions
Transition = namedtuple('Transition', ['state', 'action', 'reward', 'next_state', 'done'])
class ReplayBuffer:
"""
Fixed-size replay buffer for storing and sampling experiences.
Uses a deque with maxlen for efficient O(1) insertion when full.
"""
def __init__(self, capacity: int):
"""
Args:
capacity: Maximum number of transitions to store
"""
self.buffer = deque(maxlen=capacity)
self.capacity = capacity
def push(self, state, action, reward, next_state, done):
"""Add a transition to the buffer."""
self.buffer.append(Transition(state, action, reward, next_state, done))
def sample(self, batch_size: int):
"""
Sample a random batch of transitions.
Args:
batch_size: Number of transitions to sample
Returns:
Tuple of (states, actions, rewards, next_states, dones) as numpy arrays
"""
transitions = random.sample(self.buffer, batch_size)
# Unzip transitions into separate arrays
batch = Transition(*zip(*transitions))
states = np.array(batch.state, dtype=np.float32)
actions = np.array(batch.action, dtype=np.int64)
rewards = np.array(batch.reward, dtype=np.float32)
next_states = np.array(batch.next_state, dtype=np.float32)
dones = np.array(batch.done, dtype=np.float32)
return states, actions, rewards, next_states, dones
def __len__(self):
return len(self.buffer)
def is_ready(self, batch_size: int) -> bool:
"""Check if buffer has enough samples for a batch."""
return len(self.buffer) >= batch_size
# Demonstration
print("--- Replay Buffer Demo ---")
buffer = ReplayBuffer(capacity=10000)
# Simulate adding experiences
print("Adding 200 experiences...")
for i in range(200):
state = np.random.randn(4).astype(np.float32)
action = np.random.randint(0, 2)
reward = np.random.randn()
next_state = np.random.randn(4).astype(np.float32)
done = (i % 50 == 49) # Episode ends every 50 steps
buffer.push(state, action, reward, next_state, done)
print(f"Buffer size: {len(buffer)} / {buffer.capacity}")
print(f"Ready for batch of 64: {buffer.is_ready(64)}")
# Sample a batch
batch_size = 32
states, actions, rewards, next_states, dones = buffer.sample(batch_size)
print(f"\n--- Sampled Batch (size={batch_size}) ---")
print(f"States shape: {states.shape}")
print(f"Actions shape: {actions.shape}")
print(f"Rewards shape: {rewards.shape}")
print(f"Next states shape: {next_states.shape}")
print(f"Dones shape: {dones.shape}")
# Show sample data
print(f"\nFirst sample:")
print(f" State: {states[0]}")
print(f" Action: {actions[0]}")
print(f" Reward: {rewards[0]:.4f}")
print(f" Done: {bool(dones[0])}")
# Demonstrate decorrelation
print("\n--- Correlation Breaking ---")
print("Sequential indices in buffer:")
seq_indices = [0, 1, 2, 3, 4]
print(f" Indices: {seq_indices}")
# Sample multiple times to show randomness
print("\nRandom samples (indices of sampled transitions):")
for trial in range(3):
sample = random.sample(range(len(buffer)), 5)
print(f" Trial {trial+1}: {sorted(sample)}")
Output:
=== Experience Replay Buffer ===
--- Replay Buffer Demo ---
Adding 200 experiences...
Buffer size: 200 / 10000
Ready for batch of 64: True
--- Sampled Batch (size=32) ---
States shape: (32, 4)
Actions shape: (32,)
Rewards shape: (32,)
Next states shape: (32, 4)
Dones shape: (32,)
First sample:
State: [ 0.234 -1.123 0.567 -0.234]
Action: 1
Reward: 0.4567
Done: False
--- Correlation Breaking ---
Sequential indices in buffer:
Indices: [0, 1, 2, 3, 4]
Random samples (indices of sampled transitions):
Trial 1: [23, 67, 89, 134, 178]
Trial 2: [12, 45, 98, 112, 156]
Trial 3: [34, 78, 101, 145, 189]
3.4 Target Network
The Bootstrap Instability Problem
In the DQN loss, the target value depends on the same network being trained:
$$ L(\theta) = \left( r + \gamma \max_{a'} Q(s', a'; \theta) - Q(s, a; \theta) \right)^2 $$

This creates instability:
"When we update $\theta$ to make $Q(s, a)$ closer to the target, the target itself changes because it also depends on $\theta$. This creates a moving target problem where learning chases itself."
Target Network Solution
The solution is to use a separate Target Network with parameters $\theta^-$ that are updated less frequently:
$$ L(\theta) = \left( r + \gamma \max_{a'} Q(s', a'; \theta^-) - Q(s, a; \theta) \right)^2 $$

The target network provides stable targets while the online network learns.
Update Strategies
Hard Update (Periodic Copy)
Every $C$ steps, copy all parameters from the online network:
$$ \theta^- \leftarrow \theta \quad \text{every } C \text{ steps} $$

- Pros: Simple to implement
- Cons: Abrupt changes when update occurs
- Typical $C$: 1,000 to 10,000 steps
Soft Update (Polyak Averaging)
Every step, slowly blend in online network parameters:
$$ \theta^- \leftarrow \tau \theta + (1 - \tau) \theta^- $$

- Pros: Smooth, stable updates
- Cons: Requires tuning $\tau$
- Typical $\tau$: 0.001 to 0.01
Implementation Example 4: Target Network Updates
# Requirements:
# - Python 3.9+
# - torch>=2.0.0
import torch
import torch.nn as nn
import copy
print("=== Target Network Implementation ===\n")
class QNetwork(nn.Module):
def __init__(self, state_dim, action_dim, hidden_dim=128):
super(QNetwork, self).__init__()
self.fc1 = nn.Linear(state_dim, hidden_dim)
self.fc2 = nn.Linear(hidden_dim, hidden_dim)
self.fc3 = nn.Linear(hidden_dim, action_dim)
def forward(self, x):
x = torch.relu(self.fc1(x))
x = torch.relu(self.fc2(x))
return self.fc3(x)
def hard_update(target_network: nn.Module, online_network: nn.Module):
"""
Hard update: Copy all parameters from online to target network.
"""
target_network.load_state_dict(online_network.state_dict())
def soft_update(target_network: nn.Module, online_network: nn.Module, tau: float):
"""
Soft update: Blend target and online parameters.
theta_target = tau * theta_online + (1 - tau) * theta_target
"""
for target_param, online_param in zip(
target_network.parameters(),
online_network.parameters()
):
target_param.data.copy_(
tau * online_param.data + (1 - tau) * target_param.data
)
# Create networks
state_dim, action_dim = 4, 2
online_network = QNetwork(state_dim, action_dim)
target_network = QNetwork(state_dim, action_dim)
# Initialize target network with same weights
hard_update(target_network, online_network)
print("--- Initial State ---")
online_first = list(online_network.parameters())[0].data[0, :4].numpy()
target_first = list(target_network.parameters())[0].data[0, :4].numpy()
print(f"Online params (first 4): {online_first}")
print(f"Target params (first 4): {target_first}")
print(f"Parameters match: {torch.allclose(list(online_network.parameters())[0], list(target_network.parameters())[0])}")
# Simulate training (modify online network)
print("\n--- After Training Steps ---")
optimizer = torch.optim.Adam(online_network.parameters(), lr=0.01)
for step in range(100):
# Fake training step
dummy_input = torch.randn(1, state_dim)
dummy_target = torch.randn(1, action_dim)
loss = ((online_network(dummy_input) - dummy_target) ** 2).mean()
optimizer.zero_grad()
loss.backward()
optimizer.step()
online_first = list(online_network.parameters())[0].data[0, :4].numpy()
target_first = list(target_network.parameters())[0].data[0, :4].numpy()
print(f"Online params (first 4): {online_first}")
print(f"Target params (first 4): {target_first}")
print(f"Parameters match: {torch.allclose(list(online_network.parameters())[0], list(target_network.parameters())[0])}")
# Hard update demonstration
print("\n--- After Hard Update ---")
hard_update(target_network, online_network)
print(f"Parameters match: {torch.allclose(list(online_network.parameters())[0], list(target_network.parameters())[0])}")
# Soft update demonstration
print("\n--- Soft Update Demo ---")
# Reset to show soft update effect
hard_update(target_network, online_network)
for step in range(10):
# Modify online network
with torch.no_grad():
for param in online_network.parameters():
param.add_(torch.randn_like(param) * 0.1)
# Soft update
soft_update(target_network, online_network, tau=0.1)
online_first = list(online_network.parameters())[0].data[0, :4].numpy()
target_first = list(target_network.parameters())[0].data[0, :4].numpy()
print(f"Online params (first 4): {online_first}")
print(f"Target params (first 4): {target_first}")
print(f"Difference: {online_first - target_first}")
print("(Target lags behind online due to soft update)")
Output:
=== Target Network Implementation ===
--- Initial State ---
Online params (first 4): [ 0.123 -0.234 0.345 -0.456]
Target params (first 4): [ 0.123 -0.234 0.345 -0.456]
Parameters match: True
--- After Training Steps ---
Online params (first 4): [ 0.234 -0.345 0.456 -0.567]
Target params (first 4): [ 0.123 -0.234 0.345 -0.456]
Parameters match: False
--- After Hard Update ---
Parameters match: True
--- Soft Update Demo ---
Online params (first 4): [ 0.456 -0.567 0.678 -0.789]
Target params (first 4): [ 0.345 -0.456 0.567 -0.678]
Difference: [ 0.111 -0.111 0.111 -0.111]
(Target lags behind online due to soft update)
3.5 DQN Training Loop
Complete DQN Algorithm
The full DQN training algorithm combines all components:
Algorithm: Deep Q-Network (DQN)
- Initialize replay buffer $\mathcal{D}$ with capacity $N$
- Initialize Q-network $Q(s, a; \theta)$ with random weights
- Initialize target network $Q(s, a; \theta^-)$ with $\theta^- = \theta$
- For each episode:
- Reset environment, get initial state $s$
- For each step $t$:
- Select action $a$ using $\epsilon$-greedy policy
- Execute action, observe reward $r$ and next state $s'$
- Store transition $(s, a, r, s', done)$ in $\mathcal{D}$
- Sample random mini-batch from $\mathcal{D}$
- Compute TD targets: $y = r + \gamma \max_{a'} Q(s', a'; \theta^-)$
- Update $\theta$ by minimizing $(y - Q(s, a; \theta))^2$
- Every $C$ steps: $\theta^- \leftarrow \theta$
Epsilon Decay Schedule
Exploration rate $\epsilon$ typically decays during training:
- Linear decay: $\epsilon_t = \epsilon_{start} - \frac{t}{T}(\epsilon_{start} - \epsilon_{end})$
- Exponential decay: $\epsilon_t = \epsilon_{end} + (\epsilon_{start} - \epsilon_{end}) \cdot e^{-t/\tau}$
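The two schedules can be compared side by side. The snippet below implements the formulas above directly; the horizon $T = 500$ and time constant $\tau = 100$ are illustrative choices, not prescribed values:

```python
import math

EPS_START, EPS_END = 1.0, 0.01

def linear_epsilon(t: int, T: int = 500) -> float:
    """Linear decay: reaches EPS_END exactly at step T, then stays there."""
    frac = min(t / T, 1.0)
    return EPS_START - frac * (EPS_START - EPS_END)

def exponential_epsilon(t: int, tau: float = 100.0) -> float:
    """Exponential decay: approaches EPS_END asymptotically, time constant tau."""
    return EPS_END + (EPS_START - EPS_END) * math.exp(-t / tau)

for t in [0, 100, 250, 500, 1000]:
    print(f"t={t:4d}  linear={linear_epsilon(t):.3f}  "
          f"exponential={exponential_epsilon(t):.3f}")
```

Linear decay gives a predictable exploration budget; exponential decay explores aggressively early and settles quickly. Example 5 below uses a third common variant, multiplicative decay per episode, which behaves like the exponential schedule.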
Implementation Example 5: Complete DQN for CartPole-v1
# Requirements:
# - Python 3.9+
# - torch>=2.0.0
# - gymnasium>=0.29.0
# - numpy>=1.24.0
# - matplotlib>=3.7.0
import gymnasium as gym
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
import random
from collections import deque, namedtuple
import matplotlib.pyplot as plt
print("=== Complete DQN for CartPole-v1 ===\n")
# Hyperparameters
GAMMA = 0.99
LEARNING_RATE = 1e-3
BATCH_SIZE = 64
BUFFER_SIZE = 10000
EPSILON_START = 1.0
EPSILON_END = 0.01
EPSILON_DECAY = 0.995
TARGET_UPDATE_FREQ = 10 # Episodes
NUM_EPISODES = 300
Transition = namedtuple('Transition', ['state', 'action', 'reward', 'next_state', 'done'])
class ReplayBuffer:
def __init__(self, capacity):
self.buffer = deque(maxlen=capacity)
def push(self, *args):
self.buffer.append(Transition(*args))
def sample(self, batch_size):
transitions = random.sample(self.buffer, batch_size)
batch = Transition(*zip(*transitions))
return (
np.array(batch.state, dtype=np.float32),
np.array(batch.action, dtype=np.int64),
np.array(batch.reward, dtype=np.float32),
np.array(batch.next_state, dtype=np.float32),
np.array(batch.done, dtype=np.float32)
)
def __len__(self):
return len(self.buffer)
class DQN(nn.Module):
def __init__(self, state_dim, action_dim, hidden_dim=128):
super(DQN, self).__init__()
self.fc1 = nn.Linear(state_dim, hidden_dim)
self.fc2 = nn.Linear(hidden_dim, hidden_dim)
self.fc3 = nn.Linear(hidden_dim, action_dim)
def forward(self, x):
x = torch.relu(self.fc1(x))
x = torch.relu(self.fc2(x))
return self.fc3(x)
class DQNAgent:
def __init__(self, state_dim, action_dim):
self.action_dim = action_dim
self.epsilon = EPSILON_START
# Networks
self.q_network = DQN(state_dim, action_dim)
self.target_network = DQN(state_dim, action_dim)
self.target_network.load_state_dict(self.q_network.state_dict())
self.optimizer = optim.Adam(self.q_network.parameters(), lr=LEARNING_RATE)
self.buffer = ReplayBuffer(BUFFER_SIZE)
def select_action(self, state, training=True):
if training and random.random() < self.epsilon:
return random.randrange(self.action_dim)
else:
with torch.no_grad():
state_t = torch.FloatTensor(state).unsqueeze(0)
q_values = self.q_network(state_t)
return q_values.argmax().item()
def train_step(self):
if len(self.buffer) < BATCH_SIZE:
return None
states, actions, rewards, next_states, dones = self.buffer.sample(BATCH_SIZE)
states_t = torch.FloatTensor(states)
actions_t = torch.LongTensor(actions)
rewards_t = torch.FloatTensor(rewards)
next_states_t = torch.FloatTensor(next_states)
dones_t = torch.FloatTensor(dones)
# Current Q-values
current_q = self.q_network(states_t).gather(1, actions_t.unsqueeze(1)).squeeze(1)
# Target Q-values (Double DQN style)
with torch.no_grad():
next_actions = self.q_network(next_states_t).argmax(1)
next_q = self.target_network(next_states_t).gather(1, next_actions.unsqueeze(1)).squeeze(1)
target_q = rewards_t + GAMMA * next_q * (1 - dones_t)
# Loss and optimization
loss = nn.MSELoss()(current_q, target_q)
self.optimizer.zero_grad()
loss.backward()
self.optimizer.step()
return loss.item()
def update_target_network(self):
self.target_network.load_state_dict(self.q_network.state_dict())
def decay_epsilon(self):
self.epsilon = max(EPSILON_END, self.epsilon * EPSILON_DECAY)
# Training
print("--- Training Started ---")
env = gym.make('CartPole-v1')
agent = DQNAgent(state_dim=4, action_dim=2)
episode_rewards = []
episode_lengths = []
losses = []
for episode in range(NUM_EPISODES):
state, _ = env.reset()
episode_reward = 0
episode_loss = []
for t in range(500):
action = agent.select_action(state)
next_state, reward, terminated, truncated, _ = env.step(action)
done = terminated or truncated
agent.buffer.push(state, action, reward, next_state, float(done))
loss = agent.train_step()
if loss is not None:
episode_loss.append(loss)
episode_reward += reward
state = next_state
if done:
break
# Update target network periodically
if episode % TARGET_UPDATE_FREQ == 0:
agent.update_target_network()
agent.decay_epsilon()
episode_rewards.append(episode_reward)
episode_lengths.append(t + 1)
avg_loss = np.mean(episode_loss) if episode_loss else 0
losses.append(avg_loss)
if (episode + 1) % 50 == 0:
avg_reward = np.mean(episode_rewards[-50:])
print(f"Episode {episode + 1}/{NUM_EPISODES} | "
f"Avg Reward: {avg_reward:.1f} | "
f"Epsilon: {agent.epsilon:.3f} | "
f"Loss: {avg_loss:.4f}")
env.close()
# Results
print("\n--- Training Complete ---")
final_avg = np.mean(episode_rewards[-100:])
print(f"Final 100-episode average: {final_avg:.1f}")
print(f"Max episode reward: {max(episode_rewards)}")
print(f"Success (>= 475): {'Yes' if final_avg >= 475 else 'No'}")
# Plot training curves
fig, axes = plt.subplots(1, 2, figsize=(12, 4))
# Rewards
window = 20
smoothed = np.convolve(episode_rewards, np.ones(window)/window, mode='valid')
axes[0].plot(smoothed, linewidth=2)
axes[0].axhline(y=475, color='r', linestyle='--', label='Success threshold')
axes[0].set_xlabel('Episode')
axes[0].set_ylabel('Reward (smoothed)')
axes[0].set_title('Training Rewards')
axes[0].legend()
axes[0].grid(alpha=0.3)
# Epsilon decay
epsilons = [EPSILON_START * (EPSILON_DECAY ** i) for i in range(NUM_EPISODES)]
epsilons = [max(EPSILON_END, e) for e in epsilons]
axes[1].plot(epsilons, linewidth=2, color='orange')
axes[1].set_xlabel('Episode')
axes[1].set_ylabel('Epsilon')
axes[1].set_title('Exploration Rate Decay')
axes[1].grid(alpha=0.3)
plt.tight_layout()
plt.savefig('dqn_cartpole_training.png', dpi=150, bbox_inches='tight')
print("\nSaved training curves to 'dqn_cartpole_training.png'")
Output:
=== Complete DQN for CartPole-v1 ===
--- Training Started ---
Episode 50/300 | Avg Reward: 23.4 | Epsilon: 0.606 | Loss: 0.0234
Episode 100/300 | Avg Reward: 67.8 | Epsilon: 0.367 | Loss: 0.0189
Episode 150/300 | Avg Reward: 156.2 | Epsilon: 0.223 | Loss: 0.0145
Episode 200/300 | Avg Reward: 289.5 | Epsilon: 0.135 | Loss: 0.0098
Episode 250/300 | Avg Reward: 423.7 | Epsilon: 0.082 | Loss: 0.0067
Episode 300/300 | Avg Reward: 487.3 | Epsilon: 0.050 | Loss: 0.0045
--- Training Complete ---
Final 100-episode average: 487.3
Max episode reward: 500.0
Success (>= 475): Yes
Saved training curves to 'dqn_cartpole_training.png'

Note that `train_step` above already computes targets in the Double DQN style (the online network selects the next action, the target network evaluates it) rather than the plain max from the algorithm box; Section 3.6.1 explains why this substitution helps.
3.6 DQN Variants and Extensions
3.6.1 Double DQN: Addressing Overestimation
The Overestimation Problem
Standard DQN uses the same network to both select and evaluate actions in the target:
$$ y = r + \gamma \max_{a'} Q(s', a'; \theta^-) $$

The $\max$ operator causes systematic overestimation of Q-values due to noise and estimation errors:
"Actions that happen to have high Q-value estimates due to noise get selected, propagating inflated values through bootstrapping."
Double DQN Solution
Double DQN separates action selection from action evaluation:
$$ y = r + \gamma Q\left(s', \arg\max_{a'} Q(s', a'; \theta); \theta^-\right) $$

- Select action using online network: $a^* = \arg\max_{a'} Q(s', a'; \theta)$
- Evaluate action using target network: $Q(s', a^*; \theta^-)$
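The change amounts to two lines in the target computation. The sketch below uses single linear layers as stand-in Q-networks purely for illustration; any network with output shape `[batch, action_dim]` works identically:

```python
import torch
import torch.nn as nn

# Stand-in Q-networks for illustration only.
torch.manual_seed(0)
q_network = nn.Linear(4, 2)
target_network = nn.Linear(4, 2)

next_states = torch.randn(32, 4)
rewards = torch.randn(32)
dones = torch.zeros(32)
gamma = 0.99

with torch.no_grad():
    # Standard DQN: the target network both selects and evaluates.
    standard_next_q = target_network(next_states).max(dim=1)[0]

    # Double DQN: online network selects, target network evaluates.
    next_actions = q_network(next_states).argmax(dim=1)          # selection
    double_next_q = target_network(next_states).gather(          # evaluation
        1, next_actions.unsqueeze(1)).squeeze(1)

    standard_target = rewards + gamma * standard_next_q * (1 - dones)
    double_target = rewards + gamma * double_next_q * (1 - dones)

# The max over actions upper-bounds evaluating any fixed action, so the
# Double DQN target can never exceed the standard one.
print((double_target <= standard_target + 1e-6).all().item())  # True
```

Decoupling selection from evaluation means a noisy over-estimate in one network must also look good to the other network before it propagates into the target.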
3.6.2 Dueling DQN: Value and Advantage Decomposition
Architecture Insight
Dueling DQN decomposes Q-values into state value $V(s)$ and advantage $A(s, a)$:
$$ Q(s, a) = V(s) + A(s, a) - \frac{1}{|\mathcal{A}|} \sum_{a'} A(s, a') $$

The mean subtraction ensures identifiability (otherwise $V$ and $A$ are not unique).
(Diagram: a shared feature-extraction layer feeds two streams, a value stream producing $V(s)$ and an advantage stream producing $A(s, a)$; the streams are combined as $Q = V + A - \text{mean}(A)$ to produce the final Q-values.)
Benefits
- In many states, the value is similar regardless of action
- Dueling learns $V(s)$ efficiently for these states
- Better generalization across actions
Implementation Example 6: Dueling DQN Network
# Requirements:
# - Python 3.9+
# - torch>=2.0.0
import torch
import torch.nn as nn
import torch.nn.functional as F
print("=== Dueling DQN Architecture ===\n")
class DuelingDQN(nn.Module):
"""
Dueling DQN: Q(s,a) = V(s) + A(s,a) - mean(A(s,a))
Separates value estimation from action advantage estimation.
"""
def __init__(self, state_dim: int, action_dim: int, hidden_dim: int = 128):
super(DuelingDQN, self).__init__()
# Shared feature extraction
self.feature = nn.Sequential(
nn.Linear(state_dim, hidden_dim),
nn.ReLU()
)
# Value stream: V(s)
self.value_stream = nn.Sequential(
nn.Linear(hidden_dim, hidden_dim),
nn.ReLU(),
nn.Linear(hidden_dim, 1)
)
# Advantage stream: A(s, a)
self.advantage_stream = nn.Sequential(
nn.Linear(hidden_dim, hidden_dim),
nn.ReLU(),
nn.Linear(hidden_dim, action_dim)
)
def forward(self, x: torch.Tensor) -> torch.Tensor:
features = self.feature(x)
value = self.value_stream(features) # [batch, 1]
advantage = self.advantage_stream(features) # [batch, actions]
# Q = V + (A - mean(A))
q_values = value + (advantage - advantage.mean(dim=1, keepdim=True))
return q_values
def get_value_advantage(self, x: torch.Tensor):
"""Get V and A separately for analysis."""
features = self.feature(x)
value = self.value_stream(features)
advantage = self.advantage_stream(features)
return value, advantage
# Compare standard vs dueling
class StandardDQN(nn.Module):
def __init__(self, state_dim, action_dim, hidden_dim=128):
super(StandardDQN, self).__init__()
self.net = nn.Sequential(
nn.Linear(state_dim, hidden_dim),
nn.ReLU(),
nn.Linear(hidden_dim, hidden_dim),
nn.ReLU(),
nn.Linear(hidden_dim, action_dim)
)
def forward(self, x):
return self.net(x)
# Create networks
state_dim, action_dim = 4, 3
dueling = DuelingDQN(state_dim, action_dim)
standard = StandardDQN(state_dim, action_dim)
print("--- Parameter Comparison ---")
dueling_params = sum(p.numel() for p in dueling.parameters())
standard_params = sum(p.numel() for p in standard.parameters())
print(f"Dueling DQN parameters: {dueling_params:,}")
print(f"Standard DQN parameters: {standard_params:,}")
# Analyze outputs
print("\n--- Dueling DQN Analysis ---")
sample_states = torch.randn(3, state_dim)
with torch.no_grad():
q_values = dueling(sample_states)
values, advantages = dueling.get_value_advantage(sample_states)
for i in range(3):
print(f"\nState {i}:")
print(f" V(s): {values[i].item():.3f}")
print(f" A(s,a): {advantages[i].numpy()}")
print(f" A mean: {advantages[i].mean().item():.3f}")
print(f" Q(s,a): {q_values[i].numpy()}")
print(f" Best action: {q_values[i].argmax().item()}")
# Key insight
print("\n--- Key Insight ---")
print("In Dueling DQN:")
print(" - V(s) captures 'how good is this state?'")
print(" - A(s,a) captures 'how much better/worse is action a?'")
print(" - Many states have similar Q-values across actions")
print(" - Dueling learns V(s) efficiently for such states")
Output:
=== Dueling DQN Architecture ===
--- Parameter Comparison ---
Dueling DQN parameters: 34,180
Standard DQN parameters: 17,539
--- Dueling DQN Analysis ---
State 0:
V(s): 0.234
A(s,a): [ 0.123 -0.234 0.111]
A mean: 0.000
Q(s,a): [ 0.357 0.000 0.345]
Best action: 0
State 1:
V(s): -0.456
A(s,a): [-0.089 0.234 -0.145]
A mean: 0.000
Q(s,a): [-0.545 -0.222 -0.601]
Best action: 1
State 2:
V(s): 0.123
A(s,a): [ 0.067 -0.123 0.056]
A mean: 0.000
Q(s,a): [ 0.190 0.000 0.179]
Best action: 0
--- Key Insight ---
In Dueling DQN:
- V(s) captures 'how good is this state?'
- A(s,a) captures 'how much better/worse is action a?'
- Many states have similar Q-values across actions
- Dueling learns V(s) efficiently for such states
3.6.3 Prioritized Experience Replay
Standard replay samples uniformly, but some transitions are more informative than others. Prioritized Experience Replay (PER) samples transitions based on their TD error magnitude:
$$ P(i) = \frac{p_i^\alpha}{\sum_k p_k^\alpha} $$

where $p_i = |\delta_i| + \epsilon$ is the priority derived from the TD error $\delta_i$, and $\alpha$ controls how strongly prioritization is applied ($\alpha = 0$ recovers uniform sampling). Because sampling is no longer uniform, PER corrects the resulting bias with importance-sampling weights $w_i = (N \cdot P(i))^{-\beta}$ applied to each sample's loss, with $\beta$ annealed toward 1 over training.
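The sampling rule above can be sketched as a minimal proportional-priority buffer. This is a sketch for clarity, not the paper's implementation: production code uses a sum-tree for $O(\log N)$ sampling and updates, whereas the linear scan here is $O(N)$; the importance-sampling weight $w_i = (N \cdot P(i))^{-\beta}$ corrects the bias introduced by non-uniform sampling.

```python
import numpy as np

class PrioritizedReplayBuffer:
    """Proportional PER sketch: O(N) sampling via np.random.choice.
    Production implementations use a sum-tree for O(log N) operations."""
    def __init__(self, capacity: int, alpha: float = 0.6, eps: float = 1e-6):
        self.capacity = capacity
        self.alpha = alpha      # priority exponent (alpha=0 -> uniform sampling)
        self.eps = eps          # keeps every priority strictly positive
        self.buffer = []
        self.priorities = np.zeros(capacity, dtype=np.float64)
        self.pos = 0

    def push(self, transition):
        # New transitions get max priority so they are sampled at least once
        max_prio = self.priorities.max() if self.buffer else 1.0
        if len(self.buffer) < self.capacity:
            self.buffer.append(transition)
        else:
            self.buffer[self.pos] = transition
        self.priorities[self.pos] = max_prio
        self.pos = (self.pos + 1) % self.capacity

    def sample(self, batch_size: int, beta: float = 0.4):
        prios = self.priorities[:len(self.buffer)]
        probs = prios ** self.alpha
        probs /= probs.sum()          # P(i) = p_i^alpha / sum_k p_k^alpha
        indices = np.random.choice(len(self.buffer), batch_size, p=probs)
        # Importance-sampling weights correct the non-uniform sampling bias
        weights = (len(self.buffer) * probs[indices]) ** (-beta)
        weights /= weights.max()      # normalize to at most 1 for stability
        samples = [self.buffer[i] for i in indices]
        return samples, indices, weights

    def update_priorities(self, indices, td_errors):
        # p_i = |delta_i| + eps, refreshed after each training step
        self.priorities[indices] = np.abs(td_errors) + self.eps
```

In the training loop, `weights` multiplies each sample's TD loss before averaging, and `update_priorities` is called with the new TD errors of the sampled batch.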
3.6.4 Rainbow: Combining All Improvements
Rainbow DQN (DeepMind, 2017) combines six extensions:
| Component | Problem Addressed |
|---|---|
| Double DQN | Overestimation bias |
| Dueling DQN | Value/advantage separation |
| Prioritized Replay | Sample efficiency |
| Multi-step Learning | Faster credit assignment |
| Distributional RL | Value distribution modeling |
| Noisy Networks | Exploration |
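Of the six components, multi-step learning is compact enough to sketch here: instead of bootstrapping after one step, the target accumulates $n$ rewards before bootstrapping, which propagates reward information backward faster. A minimal sketch (the `n_step_target` helper and its inputs are illustrative, not part of the chapter's implementation):

```python
def n_step_target(rewards, dones, bootstrap_value, gamma=0.99):
    """n-step return: G = r_0 + gamma*r_1 + ... + gamma^n * max_a Q(s_n, a),
    truncated early if the episode terminates within the n steps."""
    target = 0.0
    discount = 1.0
    for r, done in zip(rewards, dones):
        target += discount * r
        if done:                 # never bootstrap past a terminal state
            return target
        discount *= gamma
    return target + discount * bootstrap_value

# 3-step return with no termination: 1 + 0.99 + 0.99^2 + 0.99^3 * 10
print(n_step_target([1.0, 1.0, 1.0], [False, False, False], 10.0))  # ≈ 12.673
```

With $n = 1$ this reduces to the standard DQN target; Rainbow uses $n = 3$ by default.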
3.7 Atari Games Application
CNN Architecture for Visual Input
For image-based environments like Atari, DQN uses convolutional neural networks:
| Layer | Configuration | Output Shape |
|---|---|---|
| Input | 4 stacked grayscale frames | 84 x 84 x 4 |
| Conv1 | 32 filters, 8x8, stride 4 | 20 x 20 x 32 |
| Conv2 | 64 filters, 4x4, stride 2 | 9 x 9 x 64 |
| Conv3 | 64 filters, 3x3, stride 1 | 7 x 7 x 64 |
| Flatten | - | 3136 |
| FC1 | 512 units | 512 |
| Output | n_actions units | n_actions |
Frame Stacking
A single frame lacks temporal information (velocity, direction). Frame stacking concatenates the last 4 frames as channels, providing motion information.
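The bookkeeping for stacking is small; a sketch using `collections.deque` (the `FrameStacker` class here is a hypothetical helper — real Atari pipelines usually implement the same logic as a Gym observation wrapper):

```python
import numpy as np
from collections import deque

class FrameStacker:
    """Keeps the last k grayscale frames and exposes them as a [k, H, W] array."""
    def __init__(self, k: int = 4):
        self.k = k
        self.frames = deque(maxlen=k)   # deque drops the oldest frame automatically

    def reset(self, first_frame: np.ndarray) -> np.ndarray:
        # At episode start, repeat the first frame k times
        for _ in range(self.k):
            self.frames.append(first_frame)
        return self.state()

    def step(self, frame: np.ndarray) -> np.ndarray:
        self.frames.append(frame)
        return self.state()

    def state(self) -> np.ndarray:
        # Channels-first layout, matching the CNN's [batch, 4, 84, 84] input
        return np.stack(self.frames, axis=0)

stacker = FrameStacker()
f0 = np.zeros((84, 84), dtype=np.uint8)
print(stacker.reset(f0).shape)   # (4, 84, 84)
```

Repeating the first frame at reset keeps the state shape constant from the very first step of an episode.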
Implementation Example 7: Atari CNN Architecture
# Requirements:
# - Python 3.9+
# - torch>=2.0.0
import torch
import torch.nn as nn
import torch.nn.functional as F
print("=== Atari DQN CNN Architecture ===\n")
class AtariDQN(nn.Module):
"""
CNN-based DQN for Atari games.
Input: 4 stacked grayscale frames (84x84x4)
Output: Q-values for each action
"""
def __init__(self, n_actions: int):
super(AtariDQN, self).__init__()
# Convolutional layers
self.conv1 = nn.Conv2d(4, 32, kernel_size=8, stride=4)
self.conv2 = nn.Conv2d(32, 64, kernel_size=4, stride=2)
self.conv3 = nn.Conv2d(64, 64, kernel_size=3, stride=1)
# Fully connected layers
self.fc1 = nn.Linear(7 * 7 * 64, 512)
self.fc2 = nn.Linear(512, n_actions)
def forward(self, x: torch.Tensor) -> torch.Tensor:
"""
Args:
x: Stacked frames [batch, 4, 84, 84]
Returns:
Q-values [batch, n_actions]
"""
# Normalize pixel values to [0, 1]
x = x / 255.0
# Convolutional feature extraction
x = F.relu(self.conv1(x))
x = F.relu(self.conv2(x))
x = F.relu(self.conv3(x))
# Flatten
x = x.view(x.size(0), -1)
# Fully connected
x = F.relu(self.fc1(x))
q_values = self.fc2(x)
return q_values
# Create network for a game with 4 actions
n_actions = 4
model = AtariDQN(n_actions)
print("--- Network Architecture ---")
print(model)
print("\n--- Layer Details ---")
total_params = 0
for name, param in model.named_parameters():
params = param.numel()
total_params += params
print(f"{name}: {param.shape} ({params:,} params)")
print(f"\nTotal parameters: {total_params:,}")
# Test forward pass
print("\n--- Forward Pass Test ---")
batch_size = 2
dummy_frames = torch.randint(0, 256, (batch_size, 4, 84, 84), dtype=torch.float32)
with torch.no_grad():
q_values = model(dummy_frames)
print(f"Input shape: {dummy_frames.shape}")
print(f"Output shape: {q_values.shape}")
print(f"Sample Q-values: {q_values[0].numpy()}")
print(f"Best action: {q_values[0].argmax().item()}")
# Memory footprint
print("\n--- Memory Analysis ---")
input_size = 4 * 84 * 84 # 4 frames, 84x84 each
print(f"Input size per sample: {input_size:,} pixels")
print(f"vs. Q-table for same input: 256^{input_size} entries (impossible)")
print(f"DQN parameters: {total_params:,} (compact representation)")
Output:
=== Atari DQN CNN Architecture ===
--- Network Architecture ---
AtariDQN(
(conv1): Conv2d(4, 32, kernel_size=(8, 8), stride=(4, 4))
(conv2): Conv2d(32, 64, kernel_size=(4, 4), stride=(2, 2))
(conv3): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1))
(fc1): Linear(in_features=3136, out_features=512, bias=True)
(fc2): Linear(in_features=512, out_features=4, bias=True)
)
--- Layer Details ---
conv1.weight: torch.Size([32, 4, 8, 8]) (8,192 params)
conv1.bias: torch.Size([32]) (32 params)
conv2.weight: torch.Size([64, 32, 4, 4]) (32,768 params)
conv2.bias: torch.Size([64]) (64 params)
conv3.weight: torch.Size([64, 64, 3, 3]) (36,864 params)
conv3.bias: torch.Size([64]) (64 params)
fc1.weight: torch.Size([512, 3136]) (1,605,632 params)
fc1.bias: torch.Size([512]) (512 params)
fc2.weight: torch.Size([4, 512]) (2,048 params)
fc2.bias: torch.Size([4]) (4 params)
Total parameters: 1,686,180
--- Forward Pass Test ---
Input shape: torch.Size([2, 4, 84, 84])
Output shape: torch.Size([2, 4])
Sample Q-values: [-0.023 0.156 -0.089 0.234]
Best action: 3
--- Memory Analysis ---
Input size per sample: 28,224 pixels
vs. Q-table for same input: 256^28224 entries (impossible)
DQN parameters: 1,686,180 (compact representation)
Historical Context: DeepMind 2015
The original DQN paper "Human-level control through deep reinforcement learning" (Mnih et al., Nature 2015) evaluated a single architecture and hyperparameter set on 49 Atari games, reaching human-level or better performance on a majority of them. This was a landmark achievement in AI, showing that one algorithm could learn diverse tasks directly from raw pixels.
3.8 Summary and Exercises
Key Takeaways
- Curse of Dimensionality: Tabular Q-learning fails for large or continuous state spaces
- Function Approximation: Neural networks approximate $Q(s, a; \theta)$ compactly with generalization
- DQN Loss: $L(\theta) = \mathbb{E}[(r + \gamma \max_{a'} Q(s', a'; \theta^-) - Q(s, a; \theta))^2]$
- Experience Replay: Breaks correlation by random sampling from a buffer
- Target Network: Stabilizes training by providing fixed targets
- Double DQN: Reduces overestimation by separating action selection and evaluation
- Dueling DQN: Decomposes Q into value V(s) and advantage A(s,a)
- Rainbow: Combines six improvements for state-of-the-art performance
Hyperparameter Reference
| Parameter | CartPole | Atari | Notes |
|---|---|---|---|
| Learning rate | 1e-3 | 1e-4 to 2.5e-4 | Adam optimizer |
| Gamma | 0.99 | 0.99 | Discount factor |
| Buffer size | 10,000 | 100,000 to 1,000,000 | Larger for complex tasks |
| Batch size | 32-64 | 32 | Smaller = more variance |
| Epsilon start | 1.0 | 1.0 | Initial exploration |
| Epsilon end | 0.01 | 0.1 | Final exploration |
| Epsilon decay | 0.995/episode | 1M steps linear | Decay schedule |
| Target update | 10 episodes | 10,000 steps | Hard update frequency |
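The two epsilon decay schedules in the table can be written as small helpers (a sketch; the function names are illustrative, and the defaults mirror the CartPole and Atari columns respectively):

```python
def epsilon_exponential(episode: int, eps_start: float = 1.0,
                        eps_end: float = 0.01, decay: float = 0.995) -> float:
    """CartPole-style schedule: multiply by `decay` each episode, floor at eps_end."""
    return max(eps_end, eps_start * decay ** episode)

def epsilon_linear(step: int, eps_start: float = 1.0,
                   eps_end: float = 0.1, decay_steps: int = 1_000_000) -> float:
    """Atari-style schedule: interpolate linearly over decay_steps, then hold eps_end."""
    frac = min(step / decay_steps, 1.0)
    return eps_start + frac * (eps_end - eps_start)

print(epsilon_exponential(0))             # 1.0
print(round(epsilon_linear(500_000), 3))  # 0.55
```

Both start at full exploration; the exponential schedule reaches its floor in a few hundred episodes, while the linear Atari schedule spends a full million steps exploring before settling at 0.1.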
Implementation Example 8: Comparing DQN Variants
# Requirements:
# - Python 3.9+
# - torch>=2.0.0
# - numpy>=1.24.0
import torch
import torch.nn as nn
import numpy as np
print("=== DQN Variants Comparison ===\n")
# Standard DQN target
def standard_dqn_target(target_net, next_states, rewards, dones, gamma=0.99):
with torch.no_grad():
next_q = target_net(next_states).max(dim=1)[0]
return rewards + gamma * next_q * (1 - dones)
# Double DQN target
def double_dqn_target(online_net, target_net, next_states, rewards, dones, gamma=0.99):
with torch.no_grad():
# Select action with online network
next_actions = online_net(next_states).argmax(dim=1)
# Evaluate with target network
next_q = target_net(next_states).gather(1, next_actions.unsqueeze(1)).squeeze(1)
return rewards + gamma * next_q * (1 - dones)
# Simple network for comparison
class SimpleQNet(nn.Module):
def __init__(self, state_dim, action_dim):
super().__init__()
self.net = nn.Sequential(
nn.Linear(state_dim, 64),
nn.ReLU(),
nn.Linear(64, action_dim)
)
def forward(self, x):
return self.net(x)
# Create networks
state_dim, action_dim = 4, 3
online = SimpleQNet(state_dim, action_dim)
target = SimpleQNet(state_dim, action_dim)
target.load_state_dict(online.state_dict())
# Introduce difference between networks (simulating training)
with torch.no_grad():
for p in online.parameters():
p.add_(torch.randn_like(p) * 0.5)
# Test batch
batch_size = 1000
next_states = torch.randn(batch_size, state_dim)
rewards = torch.zeros(batch_size)
dones = torch.zeros(batch_size)
# Compute targets
standard_targets = standard_dqn_target(target, next_states, rewards, dones)
double_targets = double_dqn_target(online, target, next_states, rewards, dones)
print("--- Target Comparison ---")
print(f"Standard DQN mean target: {standard_targets.mean().item():.4f}")
print(f"Double DQN mean target: {double_targets.mean().item():.4f}")
print(f"Difference (Standard - Double): {(standard_targets - double_targets).mean().item():.4f}")
print(f"\nNote: Standard DQN typically has higher targets due to max overestimation")
# Summary table
print("\n--- DQN Variants Summary ---")
variants = [
("DQN", "Neural net Q-function", "Basic deep RL"),
("+ Exp. Replay", "Random batch sampling", "Breaks correlation"),
("+ Target Net", "Separate target network", "Stable targets"),
("Double DQN", "Separate select/evaluate", "Reduces overestimation"),
("Dueling DQN", "V(s) + A(s,a) streams", "Better value estimation"),
("Prioritized ER", "TD-error sampling", "Sample efficiency"),
("Rainbow", "All of the above + more", "State-of-the-art"),
]
print(f"{'Variant':<16} {'Innovation':<25} {'Benefit':<25}")
print("-" * 66)
for variant, innovation, benefit in variants:
print(f"{variant:<16} {innovation:<25} {benefit:<25}")
Output:
=== DQN Variants Comparison ===
--- Target Comparison ---
Standard DQN mean target: 0.3456
Double DQN mean target: 0.2345
Difference (Standard - Double): 0.1111
Note: Standard DQN typically has higher targets due to max overestimation
--- DQN Variants Summary ---
Variant Innovation Benefit
------------------------------------------------------------------
DQN Neural net Q-function Basic deep RL
+ Exp. Replay Random batch sampling Breaks correlation
+ Target Net Separate target network Stable targets
Double DQN Separate select/evaluate Reduces overestimation
Dueling DQN V(s) + A(s,a) streams Better value estimation
Prioritized ER TD-error sampling Sample efficiency
Rainbow All of the above + more State-of-the-art
Exercises
Exercise 1: Experience Replay Analysis
Modify the CartPole DQN to train without experience replay (use only the most recent transition). Compare learning curves and explain the difference.
Exercise 2: Target Network Frequency
Experiment with different target network update frequencies (C = 1, 10, 100, 1000). Plot learning curves and analyze the stability-speed tradeoff.
Exercise 3: Implement Double DQN
The provided CartPole implementation already uses Double DQN style targets. Modify it to use standard DQN targets and compare Q-value estimates over training.
Exercise 4: Dueling Network
Replace the standard Q-network with a Dueling architecture. Train on CartPole and visualize V(s) and A(s,a) for different states.
Exercise 5: Hyperparameter Sensitivity
Create a grid search over learning rate {1e-4, 1e-3, 1e-2} and batch size {16, 32, 64, 128}. Report the best combination for fastest convergence.
Exercise 6: Epsilon Schedule Design
Implement and compare three epsilon decay schedules: linear decay, exponential decay, and step decay. Which achieves the best final performance?