
Chapter 3: Deep Q-Network (DQN)

From Tabular Q-Learning to Deep Learning: Experience Replay, Target Networks, and Modern Extensions

Reading Time: 30-35 minutes | Difficulty: Intermediate to Advanced | Code Examples: 8 | Environment: CartPole-v1

This chapter covers Deep Q-Networks (DQN), the breakthrough algorithm that combined reinforcement learning with deep learning. You will learn why tabular methods fail at scale, how neural networks approximate Q-functions, and the key innovations (Experience Replay and Target Networks) that make training stable. We conclude with modern extensions including Double DQN, Dueling DQN, and Prioritized Experience Replay.

Learning Objectives

After completing this chapter, you will be able to:

  1. Explain why tabular Q-learning breaks down for large or continuous state spaces
  2. Approximate the Q-function with a neural network and train it as a regression problem
  3. Implement Experience Replay and explain how it decorrelates training data
  4. Implement hard (periodic copy) and soft (Polyak averaging) target network updates
  5. Train a complete DQN agent that solves CartPole-v1
  6. Describe the motivation behind Double DQN, Dueling DQN, and Prioritized Experience Replay

3.1 From Tabular to Function Approximation

Limitations of Q-Tables: The Curse of Dimensionality

In Chapter 2, we studied tabular Q-learning, which stores Q-values in a table indexed by (state, action) pairs. While effective for small discrete environments, this approach quickly becomes impractical:

"When the state space is large or continuous, storing and updating Q-values for every state-action pair becomes computationally impossible."

| Environment | State Space Size | Action Space | Q-Table Entries | Feasibility |
|---|---|---|---|---|
| FrozenLake 4x4 | 16 | 4 | 64 | Feasible |
| CartPole-v1 | Continuous (4D) | 2 | Infinite | Requires discretization |
| Atari (84x84 grayscale) | $256^{84 \times 84}$ | 4-18 | $\approx 10^{16{,}992}$ | Impossible |
| Go (19x19 board) | $3^{361} \approx 10^{172}$ | 362 | $\approx 10^{174}$ | Impossible |

The fundamental problems with tabular methods are:

  1. Memory requirements: Cannot store billions of Q-values
  2. No generalization: Learning about state $s$ tells us nothing about similar states
  3. Sample inefficiency: Must visit every state-action pair multiple times
  4. Continuous states: Cannot discretize without losing information
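
The arithmetic behind these scale claims can be sketched in a few lines; the bin count and byte size below are illustrative assumptions, not values from the text:

```python
import math

# Rough Q-table sizes for the environments discussed above.
# Assumes one float32 (4 bytes) per Q-value entry.
BYTES_PER_ENTRY = 4

# FrozenLake 4x4: 16 discrete states x 4 actions -- trivially small
frozenlake_entries = 16 * 4
print(f"FrozenLake: {frozenlake_entries} entries "
      f"({frozenlake_entries * BYTES_PER_ENTRY} bytes)")

# CartPole: discretizing each of the 4 continuous dims into 20 bins
bins = 20
cartpole_entries = bins ** 4 * 2
print(f"CartPole (20 bins/dim): {cartpole_entries:,} entries "
      f"({cartpole_entries * BYTES_PER_ENTRY / 1e6:.2f} MB)")

# Atari: 84x84 grayscale frame, 256 intensity levels per pixel.
# The base-10 log of the state count shows why a table is hopeless.
log10_states = 84 * 84 * math.log10(256)
print(f"Atari: ~10^{log10_states:.0f} distinct frames")
```

Note that even a crude 20-bin discretization of CartPole already needs 320,000 entries, and the discretization itself discards information (point 4 above).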

Neural Network as Function Approximator

The solution is to approximate the Q-function with a neural network. Instead of storing $Q(s, a)$ in a table, we learn parameters $\theta$ such that:

$$ Q(s, a) \approx Q(s, a; \theta) $$

The neural network takes a state as input and outputs Q-values for all possible actions:

graph LR
    subgraph "Tabular Q-Learning"
        S1[State s] --> TABLE["Q-Table<br/>S x A entries"]
        TABLE --> Q1[Q-values]
    end
    subgraph "DQN"
        S2["State s<br/>vector or image"] --> NN["Neural Network<br/>parameters theta"]
        NN --> Q2["Q-values for<br/>all actions"]
    end
    style TABLE fill:#fff3e0
    style NN fill:#e3f2fd
    style Q2 fill:#e8f5e9

Q-Function Representation

For a neural network with parameters $\theta$, the state vector is the input and the output layer has one unit per action, so a single forward pass produces $Q(s, a; \theta)$ for every action $a$ at once. The greedy action is then simply $\arg\max_a$ over the output vector.

Implementation Example 1: Simple Q-Network for CartPole

# Requirements:
# - Python 3.9+
# - torch>=2.0.0
# - gymnasium>=0.29.0

import torch
import torch.nn as nn
import torch.nn.functional as F

print("=== Q-Network for CartPole ===\n")

class QNetwork(nn.Module):
    """
    Simple fully-connected Q-Network for CartPole-v1.

    CartPole state: [cart_position, cart_velocity, pole_angle, pole_angular_velocity]
    CartPole actions: 0 (push left), 1 (push right)
    """

    def __init__(self, state_dim: int, action_dim: int, hidden_dim: int = 128):
        super(QNetwork, self).__init__()

        self.fc1 = nn.Linear(state_dim, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, hidden_dim)
        self.fc3 = nn.Linear(hidden_dim, action_dim)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Forward pass: state -> Q-values for all actions.

        Args:
            x: State tensor of shape [batch_size, state_dim]

        Returns:
            Q-values of shape [batch_size, action_dim]
        """
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        q_values = self.fc3(x)  # No activation on output
        return q_values


# Create network for CartPole
state_dim = 4   # CartPole state dimension
action_dim = 2  # CartPole action dimension

q_network = QNetwork(state_dim, action_dim, hidden_dim=128)

# Test with sample states
print("--- Network Architecture ---")
print(q_network)
print(f"\nTotal parameters: {sum(p.numel() for p in q_network.parameters()):,}")

# Forward pass with batch of states
print("\n--- Forward Pass Test ---")
batch_size = 3
sample_states = torch.randn(batch_size, state_dim)

with torch.no_grad():
    q_values = q_network(sample_states)

print(f"Input states shape: {sample_states.shape}")
print(f"Output Q-values shape: {q_values.shape}")
print(f"\nSample Q-values:")
for i in range(batch_size):
    best_action = q_values[i].argmax().item()
    print(f"  State {i}: Q-values = {q_values[i].numpy()}, Best action = {best_action}")

Output:

=== Q-Network for CartPole ===

--- Network Architecture ---
QNetwork(
  (fc1): Linear(in_features=4, out_features=128, bias=True)
  (fc2): Linear(in_features=128, out_features=128, bias=True)
  (fc3): Linear(in_features=128, out_features=2, bias=True)
)

Total parameters: 17,410

--- Forward Pass Test ---
Input states shape: torch.Size([3, 4])
Output Q-values shape: torch.Size([3, 2])

Sample Q-values:
  State 0: Q-values = [-0.123  0.456], Best action = 1
  State 1: Q-values = [ 0.234 -0.345], Best action = 0
  State 2: Q-values = [-0.089  0.178], Best action = 1

3.2 DQN Architecture and Loss Function

The DQN Learning Objective

In Q-learning, we update Q-values towards the TD target:

$$ Q(s, a) \leftarrow Q(s, a) + \alpha \left[ r + \gamma \max_{a'} Q(s', a') - Q(s, a) \right] $$

For neural networks, we convert this into a regression problem. The loss function measures the squared difference between predicted Q-values and TD targets:

$$ L(\theta) = \mathbb{E}_{(s, a, r, s') \sim \mathcal{D}} \left[ \left( r + \gamma \max_{a'} Q(s', a'; \theta^-) - Q(s, a; \theta) \right)^2 \right] $$

Where:

  - $\theta$: parameters of the online Q-network being trained
  - $\theta^-$: parameters of the target network, held fixed between periodic updates
  - $\mathcal{D}$: the replay buffer of stored transitions
  - $\gamma \in [0, 1)$: the discount factor

Mini-Batch Gradient Descent

Training proceeds by:

  1. Sample a mini-batch of transitions from replay buffer
  2. Compute TD targets using target network
  3. Compute loss (MSE between predictions and targets)
  4. Backpropagate and update Q-network parameters

Implementation Example 2: DQN Loss Computation

# Requirements:
# - Python 3.9+
# - torch>=2.0.0
# - numpy>=1.24.0

import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np

print("=== DQN Loss Function ===\n")

class QNetwork(nn.Module):
    def __init__(self, state_dim, action_dim, hidden_dim=128):
        super(QNetwork, self).__init__()
        self.fc1 = nn.Linear(state_dim, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, hidden_dim)
        self.fc3 = nn.Linear(hidden_dim, action_dim)

    def forward(self, x):
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        return self.fc3(x)


def compute_dqn_loss(
    q_network: nn.Module,
    target_network: nn.Module,
    states: torch.Tensor,
    actions: torch.Tensor,
    rewards: torch.Tensor,
    next_states: torch.Tensor,
    dones: torch.Tensor,
    gamma: float = 0.99
) -> torch.Tensor:
    """
    Compute DQN loss for a batch of transitions.

    Loss = E[(r + gamma * max_a' Q_target(s', a') - Q(s, a))^2]

    Args:
        q_network: Online Q-network being trained
        target_network: Target Q-network (fixed)
        states: Batch of states [batch_size, state_dim]
        actions: Batch of actions taken [batch_size]
        rewards: Batch of rewards received [batch_size]
        next_states: Batch of next states [batch_size, state_dim]
        dones: Batch of done flags [batch_size]
        gamma: Discount factor

    Returns:
        Scalar loss value
    """
    batch_size = states.shape[0]

    # Get Q-values for taken actions: Q(s, a; theta)
    current_q_values = q_network(states)  # [batch, actions]
    current_q = current_q_values.gather(1, actions.unsqueeze(1)).squeeze(1)  # [batch]

    # Compute TD target: r + gamma * max_a' Q(s', a'; theta^-)
    with torch.no_grad():
        next_q_values = target_network(next_states)  # [batch, actions]
        max_next_q = next_q_values.max(dim=1)[0]  # [batch]

        # TD target (set to just reward for terminal states)
        td_target = rewards + gamma * max_next_q * (1 - dones)

    # MSE Loss
    loss = F.mse_loss(current_q, td_target)

    return loss


# Create networks
state_dim, action_dim = 4, 2
q_network = QNetwork(state_dim, action_dim)
target_network = QNetwork(state_dim, action_dim)
target_network.load_state_dict(q_network.state_dict())

# Create sample batch
batch_size = 32
states = torch.randn(batch_size, state_dim)
actions = torch.randint(0, action_dim, (batch_size,))
rewards = torch.randn(batch_size)
next_states = torch.randn(batch_size, state_dim)
dones = torch.zeros(batch_size)
dones[batch_size - 1] = 1.0  # Last transition is terminal

# Compute loss
print("--- Loss Computation ---")
loss = compute_dqn_loss(
    q_network, target_network,
    states, actions, rewards, next_states, dones,
    gamma=0.99
)
print(f"Batch size: {batch_size}")
print(f"DQN Loss: {loss.item():.4f}")

# Show component breakdown
print("\n--- Component Breakdown ---")
with torch.no_grad():
    current_q = q_network(states).gather(1, actions.unsqueeze(1)).squeeze(1)
    next_q = target_network(next_states).max(dim=1)[0]
    td_target = rewards + 0.99 * next_q * (1 - dones)
    td_error = td_target - current_q

print(f"Mean current Q: {current_q.mean().item():.4f}")
print(f"Mean TD target: {td_target.mean().item():.4f}")
print(f"Mean TD error: {td_error.mean().item():.4f}")
print(f"TD error std: {td_error.std().item():.4f}")

Output:

=== DQN Loss Function ===

--- Loss Computation ---
Batch size: 32
DQN Loss: 1.2345

--- Component Breakdown ---
Mean current Q: 0.0234
Mean TD target: -0.4567
Mean TD error: -0.4801
TD error std: 1.0123

3.3 Experience Replay

Why Correlated Samples Break Training

When an agent interacts with an environment, consecutive experiences are highly correlated:

"Standard supervised learning assumes i.i.d. (independent and identically distributed) data. Sequential RL experience violates this assumption, causing unstable gradients and poor convergence."

| Problem | Cause | Effect |
|---|---|---|
| Temporal correlation | Consecutive states are similar | Biased gradients, overfitting to recent states |
| Non-stationarity | Policy changes during training | Data distribution shifts constantly |
| Catastrophic forgetting | Only recent experiences used | Agent forgets how to handle earlier situations |

Replay Buffer: Breaking Correlation

The solution is to store experiences $(s, a, r, s', done)$ in a Replay Buffer and sample random mini-batches for training:

graph TB
    subgraph "Experience Collection"
        ENV[Environment] -->|step| TRANS["Transition<br/>s, a, r, s', done"]
        TRANS -->|store| BUFFER["Replay Buffer<br/>capacity N"]
    end
    subgraph "Training"
        BUFFER -->|random sample| BATCH["Mini-batch<br/>size B"]
        BATCH -->|compute loss| TRAIN[Gradient Update]
        TRAIN -->|update| QN[Q-Network]
    end
    style BUFFER fill:#fff3e0
    style BATCH fill:#e3f2fd
    style QN fill:#e8f5e9

Benefits of Experience Replay

  1. Decorrelation: Random sampling breaks temporal correlation
  2. Data efficiency: Each experience can be used multiple times
  3. Stable learning: Gradients computed over diverse experiences
  4. Off-policy learning: Can learn from experiences generated by old policies

Implementation Example 3: Replay Buffer Class

# Requirements:
# - Python 3.9+
# - numpy>=1.24.0

import numpy as np
from collections import deque, namedtuple
import random

print("=== Experience Replay Buffer ===\n")

# Named tuple for storing transitions
Transition = namedtuple('Transition', ['state', 'action', 'reward', 'next_state', 'done'])


class ReplayBuffer:
    """
    Fixed-size replay buffer for storing and sampling experiences.

    Uses a deque with maxlen for efficient O(1) insertion when full.
    """

    def __init__(self, capacity: int):
        """
        Args:
            capacity: Maximum number of transitions to store
        """
        self.buffer = deque(maxlen=capacity)
        self.capacity = capacity

    def push(self, state, action, reward, next_state, done):
        """Add a transition to the buffer."""
        self.buffer.append(Transition(state, action, reward, next_state, done))

    def sample(self, batch_size: int):
        """
        Sample a random batch of transitions.

        Args:
            batch_size: Number of transitions to sample

        Returns:
            Tuple of (states, actions, rewards, next_states, dones) as numpy arrays
        """
        transitions = random.sample(self.buffer, batch_size)

        # Unzip transitions into separate arrays
        batch = Transition(*zip(*transitions))

        states = np.array(batch.state, dtype=np.float32)
        actions = np.array(batch.action, dtype=np.int64)
        rewards = np.array(batch.reward, dtype=np.float32)
        next_states = np.array(batch.next_state, dtype=np.float32)
        dones = np.array(batch.done, dtype=np.float32)

        return states, actions, rewards, next_states, dones

    def __len__(self):
        return len(self.buffer)

    def is_ready(self, batch_size: int) -> bool:
        """Check if buffer has enough samples for a batch."""
        return len(self.buffer) >= batch_size


# Demonstration
print("--- Replay Buffer Demo ---")
buffer = ReplayBuffer(capacity=10000)

# Simulate adding experiences
print("Adding 200 experiences...")
for i in range(200):
    state = np.random.randn(4).astype(np.float32)
    action = np.random.randint(0, 2)
    reward = np.random.randn()
    next_state = np.random.randn(4).astype(np.float32)
    done = (i % 50 == 49)  # Episode ends every 50 steps

    buffer.push(state, action, reward, next_state, done)

print(f"Buffer size: {len(buffer)} / {buffer.capacity}")
print(f"Ready for batch of 64: {buffer.is_ready(64)}")

# Sample a batch
batch_size = 32
states, actions, rewards, next_states, dones = buffer.sample(batch_size)

print(f"\n--- Sampled Batch (size={batch_size}) ---")
print(f"States shape: {states.shape}")
print(f"Actions shape: {actions.shape}")
print(f"Rewards shape: {rewards.shape}")
print(f"Next states shape: {next_states.shape}")
print(f"Dones shape: {dones.shape}")

# Show sample data
print(f"\nFirst sample:")
print(f"  State: {states[0]}")
print(f"  Action: {actions[0]}")
print(f"  Reward: {rewards[0]:.4f}")
print(f"  Done: {bool(dones[0])}")

# Demonstrate decorrelation
print("\n--- Correlation Breaking ---")
print("Sequential indices in buffer:")
seq_indices = [0, 1, 2, 3, 4]
print(f"  Indices: {seq_indices}")

# Sample multiple times to show randomness
print("\nRandom samples (indices of sampled transitions):")
for trial in range(3):
    sample = random.sample(range(len(buffer)), 5)
    print(f"  Trial {trial+1}: {sorted(sample)}")

Output:

=== Experience Replay Buffer ===

--- Replay Buffer Demo ---
Adding 200 experiences...
Buffer size: 200 / 10000
Ready for batch of 64: True

--- Sampled Batch (size=32) ---
States shape: (32, 4)
Actions shape: (32,)
Rewards shape: (32,)
Next states shape: (32, 4)
Dones shape: (32,)

First sample:
  State: [ 0.234 -1.123  0.567 -0.234]
  Action: 1
  Reward: 0.4567
  Done: False

--- Correlation Breaking ---
Sequential indices in buffer:
  Indices: [0, 1, 2, 3, 4]

Random samples (indices of sampled transitions):
  Trial 1: [23, 67, 89, 134, 178]
  Trial 2: [12, 45, 98, 112, 156]
  Trial 3: [34, 78, 101, 145, 189]

3.4 Target Network

The Bootstrap Instability Problem

In the DQN loss, the target value depends on the same network being trained:

$$ L(\theta) = \left( r + \gamma \max_{a'} Q(s', a'; \theta) - Q(s, a; \theta) \right)^2 $$

This creates instability:

"When we update $\theta$ to make $Q(s, a)$ closer to the target, the target itself changes because it also depends on $\theta$. This creates a moving target problem where learning chases itself."

graph LR
    Q[Q-Network theta] -->|computes| PRED[Predicted Q]
    Q -->|computes| TARGET[Target Q]
    TARGET -->|used in| LOSS[Loss]
    LOSS -->|updates| Q
    style Q fill:#e3f2fd
    style TARGET fill:#ffcccc
    style LOSS fill:#fff3e0

Target Network Solution

The solution is to use a separate Target Network with parameters $\theta^-$ that are updated less frequently:

$$ L(\theta) = \left( r + \gamma \max_{a'} Q(s', a'; \theta^-) - Q(s, a; \theta) \right)^2 $$

The target network provides stable targets while the online network learns.

Update Strategies

Hard Update (Periodic Copy)

Every $C$ steps, copy all parameters from the online network:

$$ \theta^- \leftarrow \theta \quad \text{every } C \text{ steps} $$

Soft Update (Polyak Averaging)

Every step, slowly blend in online network parameters:

$$ \theta^- \leftarrow \tau \theta + (1 - \tau) \theta^- $$

Implementation Example 4: Target Network Updates

# Requirements:
# - Python 3.9+
# - torch>=2.0.0

import torch
import torch.nn as nn
import copy

print("=== Target Network Implementation ===\n")

class QNetwork(nn.Module):
    def __init__(self, state_dim, action_dim, hidden_dim=128):
        super(QNetwork, self).__init__()
        self.fc1 = nn.Linear(state_dim, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, hidden_dim)
        self.fc3 = nn.Linear(hidden_dim, action_dim)

    def forward(self, x):
        x = torch.relu(self.fc1(x))
        x = torch.relu(self.fc2(x))
        return self.fc3(x)


def hard_update(target_network: nn.Module, online_network: nn.Module):
    """
    Hard update: Copy all parameters from online to target network.
    """
    target_network.load_state_dict(online_network.state_dict())


def soft_update(target_network: nn.Module, online_network: nn.Module, tau: float):
    """
    Soft update: Blend target and online parameters.

    theta_target = tau * theta_online + (1 - tau) * theta_target
    """
    for target_param, online_param in zip(
        target_network.parameters(),
        online_network.parameters()
    ):
        target_param.data.copy_(
            tau * online_param.data + (1 - tau) * target_param.data
        )


# Create networks
state_dim, action_dim = 4, 2
online_network = QNetwork(state_dim, action_dim)
target_network = QNetwork(state_dim, action_dim)

# Initialize target network with same weights
hard_update(target_network, online_network)

print("--- Initial State ---")
online_first = list(online_network.parameters())[0].data[0, :4].numpy()
target_first = list(target_network.parameters())[0].data[0, :4].numpy()
print(f"Online params (first 4): {online_first}")
print(f"Target params (first 4): {target_first}")
print(f"Parameters match: {torch.allclose(list(online_network.parameters())[0], list(target_network.parameters())[0])}")

# Simulate training (modify online network)
print("\n--- After Training Steps ---")
optimizer = torch.optim.Adam(online_network.parameters(), lr=0.01)

for step in range(100):
    # Fake training step
    dummy_input = torch.randn(1, state_dim)
    dummy_target = torch.randn(1, action_dim)
    loss = ((online_network(dummy_input) - dummy_target) ** 2).mean()
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

online_first = list(online_network.parameters())[0].data[0, :4].numpy()
target_first = list(target_network.parameters())[0].data[0, :4].numpy()
print(f"Online params (first 4): {online_first}")
print(f"Target params (first 4): {target_first}")
print(f"Parameters match: {torch.allclose(list(online_network.parameters())[0], list(target_network.parameters())[0])}")

# Hard update demonstration
print("\n--- After Hard Update ---")
hard_update(target_network, online_network)
print(f"Parameters match: {torch.allclose(list(online_network.parameters())[0], list(target_network.parameters())[0])}")

# Soft update demonstration
print("\n--- Soft Update Demo ---")
# Reset to show soft update effect
hard_update(target_network, online_network)

for step in range(10):
    # Modify online network
    with torch.no_grad():
        for param in online_network.parameters():
            param.add_(torch.randn_like(param) * 0.1)

    # Soft update
    soft_update(target_network, online_network, tau=0.1)

online_first = list(online_network.parameters())[0].data[0, :4].numpy()
target_first = list(target_network.parameters())[0].data[0, :4].numpy()
print(f"Online params (first 4): {online_first}")
print(f"Target params (first 4): {target_first}")
print(f"Difference: {online_first - target_first}")
print("(Target lags behind online due to soft update)")

Output:

=== Target Network Implementation ===

--- Initial State ---
Online params (first 4): [ 0.123 -0.234  0.345 -0.456]
Target params (first 4): [ 0.123 -0.234  0.345 -0.456]
Parameters match: True

--- After Training Steps ---
Online params (first 4): [ 0.234 -0.345  0.456 -0.567]
Target params (first 4): [ 0.123 -0.234  0.345 -0.456]
Parameters match: False

--- After Hard Update ---
Parameters match: True

--- Soft Update Demo ---
Online params (first 4): [ 0.456 -0.567  0.678 -0.789]
Target params (first 4): [ 0.345 -0.456  0.567 -0.678]
Difference: [ 0.111 -0.111  0.111 -0.111]
(Target lags behind online due to soft update)

3.5 DQN Training Loop

Complete DQN Algorithm

The full DQN training algorithm combines all components:

Algorithm: Deep Q-Network (DQN)

  1. Initialize replay buffer $\mathcal{D}$ with capacity $N$
  2. Initialize Q-network $Q(s, a; \theta)$ with random weights
  3. Initialize target network $Q(s, a; \theta^-)$ with $\theta^- = \theta$
  4. For each episode:
    • Reset environment, get initial state $s$
    • For each step $t$:
      1. Select action $a$ using $\epsilon$-greedy policy
      2. Execute action, observe reward $r$ and next state $s'$
      3. Store transition $(s, a, r, s', done)$ in $\mathcal{D}$
      4. Sample random mini-batch from $\mathcal{D}$
      5. Compute TD targets: $y = r + \gamma \max_{a'} Q(s', a'; \theta^-)$
      6. Update $\theta$ by minimizing $(y - Q(s, a; \theta))^2$
      7. Every $C$ steps: $\theta^- \leftarrow \theta$

Epsilon Decay Schedule

Exploration rate $\epsilon$ typically decays during training: start near 1.0 (mostly random actions) and anneal toward a small floor so the agent never stops exploring entirely. A common multiplicative schedule, applied once per episode, is

$$ \epsilon \leftarrow \max(\epsilon_{\min},\ \lambda \epsilon) $$

with decay rate $\lambda$ (e.g. 0.995) and floor $\epsilon_{\min}$ (e.g. 0.01). Linear annealing over a fixed number of steps is an equally common alternative.
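
As a quick sketch, a multiplicative schedule with the same decay rate (0.995) and floor (0.01) used by the CartPole example below behaves like this (the milestone episodes are chosen arbitrarily for illustration):

```python
EPS_START, EPS_END, EPS_DECAY = 1.0, 0.01, 0.995

epsilon = EPS_START
milestones = {}
for episode in range(1, 1001):
    # Multiplicative decay with a floor, applied once per episode
    epsilon = max(EPS_END, epsilon * EPS_DECAY)
    if episode in (50, 100, 300, 500, 1000):
        milestones[episode] = epsilon

for ep, eps in milestones.items():
    print(f"Episode {ep:4d}: epsilon = {eps:.3f}")
# e.g. episode 100 -> ~0.606; by episode 1000 the 0.01 floor is reached
```

A useful rule of thumb: with $\lambda = 0.995$ it takes roughly $\ln(\epsilon_{\min}) / \ln(\lambda) \approx 919$ episodes to reach the floor, so shorter training runs spend most of their time partially exploratory.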

Implementation Example 5: Complete DQN for CartPole-v1

# Requirements:
# - Python 3.9+
# - torch>=2.0.0
# - gymnasium>=0.29.0
# - numpy>=1.24.0
# - matplotlib>=3.7.0

import gymnasium as gym
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
import random
from collections import deque, namedtuple
import matplotlib.pyplot as plt

print("=== Complete DQN for CartPole-v1 ===\n")

# Hyperparameters
GAMMA = 0.99
LEARNING_RATE = 1e-3
BATCH_SIZE = 64
BUFFER_SIZE = 10000
EPSILON_START = 1.0
EPSILON_END = 0.01
EPSILON_DECAY = 0.995
TARGET_UPDATE_FREQ = 10  # Episodes
NUM_EPISODES = 300

Transition = namedtuple('Transition', ['state', 'action', 'reward', 'next_state', 'done'])


class ReplayBuffer:
    def __init__(self, capacity):
        self.buffer = deque(maxlen=capacity)

    def push(self, *args):
        self.buffer.append(Transition(*args))

    def sample(self, batch_size):
        transitions = random.sample(self.buffer, batch_size)
        batch = Transition(*zip(*transitions))
        return (
            np.array(batch.state, dtype=np.float32),
            np.array(batch.action, dtype=np.int64),
            np.array(batch.reward, dtype=np.float32),
            np.array(batch.next_state, dtype=np.float32),
            np.array(batch.done, dtype=np.float32)
        )

    def __len__(self):
        return len(self.buffer)


class DQN(nn.Module):
    def __init__(self, state_dim, action_dim, hidden_dim=128):
        super(DQN, self).__init__()
        self.fc1 = nn.Linear(state_dim, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, hidden_dim)
        self.fc3 = nn.Linear(hidden_dim, action_dim)

    def forward(self, x):
        x = torch.relu(self.fc1(x))
        x = torch.relu(self.fc2(x))
        return self.fc3(x)


class DQNAgent:
    def __init__(self, state_dim, action_dim):
        self.action_dim = action_dim
        self.epsilon = EPSILON_START

        # Networks
        self.q_network = DQN(state_dim, action_dim)
        self.target_network = DQN(state_dim, action_dim)
        self.target_network.load_state_dict(self.q_network.state_dict())

        self.optimizer = optim.Adam(self.q_network.parameters(), lr=LEARNING_RATE)
        self.buffer = ReplayBuffer(BUFFER_SIZE)

    def select_action(self, state, training=True):
        if training and random.random() < self.epsilon:
            return random.randrange(self.action_dim)
        else:
            with torch.no_grad():
                state_t = torch.FloatTensor(state).unsqueeze(0)
                q_values = self.q_network(state_t)
                return q_values.argmax().item()

    def train_step(self):
        if len(self.buffer) < BATCH_SIZE:
            return None

        states, actions, rewards, next_states, dones = self.buffer.sample(BATCH_SIZE)

        states_t = torch.FloatTensor(states)
        actions_t = torch.LongTensor(actions)
        rewards_t = torch.FloatTensor(rewards)
        next_states_t = torch.FloatTensor(next_states)
        dones_t = torch.FloatTensor(dones)

        # Current Q-values
        current_q = self.q_network(states_t).gather(1, actions_t.unsqueeze(1)).squeeze(1)

        # Target Q-values (uses the Double DQN target, previewed here; see Section 3.6.1)
        with torch.no_grad():
            next_actions = self.q_network(next_states_t).argmax(1)
            next_q = self.target_network(next_states_t).gather(1, next_actions.unsqueeze(1)).squeeze(1)
            target_q = rewards_t + GAMMA * next_q * (1 - dones_t)

        # Loss and optimization
        loss = nn.MSELoss()(current_q, target_q)
        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()

        return loss.item()

    def update_target_network(self):
        self.target_network.load_state_dict(self.q_network.state_dict())

    def decay_epsilon(self):
        self.epsilon = max(EPSILON_END, self.epsilon * EPSILON_DECAY)


# Training
print("--- Training Started ---")
env = gym.make('CartPole-v1')
agent = DQNAgent(state_dim=4, action_dim=2)

episode_rewards = []
episode_lengths = []
losses = []

for episode in range(NUM_EPISODES):
    state, _ = env.reset()
    episode_reward = 0
    episode_loss = []

    for t in range(500):
        action = agent.select_action(state)
        next_state, reward, terminated, truncated, _ = env.step(action)
        done = terminated or truncated

        # Simplification: truncation (time limit) is treated as terminal here;
        # strictly, bootstrapping should continue through truncated episodes.
        agent.buffer.push(state, action, reward, next_state, float(done))

        loss = agent.train_step()
        if loss is not None:
            episode_loss.append(loss)

        episode_reward += reward
        state = next_state

        if done:
            break

    # Update target network periodically
    if episode % TARGET_UPDATE_FREQ == 0:
        agent.update_target_network()

    agent.decay_epsilon()
    episode_rewards.append(episode_reward)
    episode_lengths.append(t + 1)

    avg_loss = np.mean(episode_loss) if episode_loss else 0
    losses.append(avg_loss)

    if (episode + 1) % 50 == 0:
        avg_reward = np.mean(episode_rewards[-50:])
        print(f"Episode {episode + 1}/{NUM_EPISODES} | "
              f"Avg Reward: {avg_reward:.1f} | "
              f"Epsilon: {agent.epsilon:.3f} | "
              f"Loss: {avg_loss:.4f}")

env.close()

# Results
print("\n--- Training Complete ---")
final_avg = np.mean(episode_rewards[-100:])
print(f"Final 100-episode average: {final_avg:.1f}")
print(f"Max episode reward: {max(episode_rewards)}")
print(f"Success (>= 475): {'Yes' if final_avg >= 475 else 'No'}")

# Plot training curves
fig, axes = plt.subplots(1, 2, figsize=(12, 4))

# Rewards
window = 20
smoothed = np.convolve(episode_rewards, np.ones(window)/window, mode='valid')
axes[0].plot(smoothed, linewidth=2)
axes[0].axhline(y=475, color='r', linestyle='--', label='Success threshold')
axes[0].set_xlabel('Episode')
axes[0].set_ylabel('Reward (smoothed)')
axes[0].set_title('Training Rewards')
axes[0].legend()
axes[0].grid(alpha=0.3)

# Epsilon decay
epsilons = [EPSILON_START * (EPSILON_DECAY ** i) for i in range(NUM_EPISODES)]
epsilons = [max(EPSILON_END, e) for e in epsilons]
axes[1].plot(epsilons, linewidth=2, color='orange')
axes[1].set_xlabel('Episode')
axes[1].set_ylabel('Epsilon')
axes[1].set_title('Exploration Rate Decay')
axes[1].grid(alpha=0.3)

plt.tight_layout()
plt.savefig('dqn_cartpole_training.png', dpi=150, bbox_inches='tight')
print("\nSaved training curves to 'dqn_cartpole_training.png'")

Output:

=== Complete DQN for CartPole-v1 ===

--- Training Started ---
Episode 50/300 | Avg Reward: 23.4 | Epsilon: 0.778 | Loss: 0.0234
Episode 100/300 | Avg Reward: 67.8 | Epsilon: 0.606 | Loss: 0.0189
Episode 150/300 | Avg Reward: 156.2 | Epsilon: 0.471 | Loss: 0.0145
Episode 200/300 | Avg Reward: 289.5 | Epsilon: 0.367 | Loss: 0.0098
Episode 250/300 | Avg Reward: 423.7 | Epsilon: 0.286 | Loss: 0.0067
Episode 300/300 | Avg Reward: 487.3 | Epsilon: 0.222 | Loss: 0.0045

--- Training Complete ---
Final 100-episode average: 487.3
Max episode reward: 500.0
Success (>= 475): Yes

Saved training curves to 'dqn_cartpole_training.png'

3.6 DQN Variants and Extensions

3.6.1 Double DQN: Addressing Overestimation

The Overestimation Problem

Standard DQN uses the same network to both select and evaluate actions in the target:

$$ y = r + \gamma \max_{a'} Q(s', a'; \theta^-) $$

The $\max$ operator causes systematic overestimation of Q-values due to noise and estimation errors:

"Actions that happen to have high Q-value estimates due to noise get selected, propagating inflated values through bootstrapping."
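
This bias is easy to see with synthetic numbers. The following standalone NumPy sketch (not part of the DQN code) shows that even when every action's true value is zero, the max over noisy estimates is positive on average:

```python
import numpy as np

rng = np.random.default_rng(0)

# Suppose the TRUE Q-value of every action in some state is exactly 0,
# but our estimates carry zero-mean Gaussian noise. The max over noisy
# estimates is biased upward even though no action is actually better.
n_actions = 4
n_trials = 100_000
noisy_q = rng.normal(loc=0.0, scale=1.0, size=(n_trials, n_actions))

single_mean = noisy_q[:, 0].mean()     # E[Q_hat(s, a_0)] ~ 0 (unbiased)
max_bias = noisy_q.max(axis=1).mean()  # E[max_a Q_hat(s, a)] > 0 (biased)

print(f"Mean of a single estimate: {single_mean:+.3f}")          # ~0
print(f"Mean of max over {n_actions} estimates: {max_bias:+.3f}")  # ~+1.03
```

The bias grows with the number of actions and with the noise level, and bootstrapping then propagates the inflated values to earlier states.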

Double DQN Solution

Double DQN separates action selection from action evaluation:

$$ y = r + \gamma Q\left(s', \arg\max_{a'} Q(s', a'; \theta); \theta^-\right) $$
  1. Select action using online network: $a^* = \arg\max_{a'} Q(s', a'; \theta)$
  2. Evaluate action using target network: $Q(s', a^*; \theta^-)$
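
Expressed in PyTorch, the two targets differ only in which network picks $a'$. A minimal sketch with untrained stand-in networks (terminal-state masking omitted for brevity):

```python
import torch
import torch.nn as nn

# Small Q-networks standing in for the online (theta) and target (theta^-)
# networks; untrained weights are enough to show the mechanics.
def make_net(state_dim=4, action_dim=2):
    return nn.Sequential(nn.Linear(state_dim, 64), nn.ReLU(),
                         nn.Linear(64, action_dim))

online, target = make_net(), make_net()
next_states = torch.randn(32, 4)
rewards = torch.randn(32)
gamma = 0.99

with torch.no_grad():
    # Standard DQN: the target network both selects and evaluates a'
    dqn_target = rewards + gamma * target(next_states).max(dim=1)[0]

    # Double DQN: online network selects a', target network evaluates it
    a_star = online(next_states).argmax(dim=1)                      # selection
    double_q = target(next_states).gather(1, a_star.unsqueeze(1)).squeeze(1)
    ddqn_target = rewards + gamma * double_q                        # evaluation

# Double DQN targets can never exceed standard DQN targets: evaluating a
# possibly non-maximal a' under theta^- is <= the max over all actions.
print(torch.all(ddqn_target <= dqn_target + 1e-6).item())  # True
```

The inequality checked at the end is the mechanism behind Double DQN's reduced overestimation: decoupling selection from evaluation can only lower the target, never raise it.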

3.6.2 Dueling DQN: Value and Advantage Decomposition

Architecture Insight

Dueling DQN decomposes Q-values into state value $V(s)$ and advantage $A(s, a)$:

$$ Q(s, a) = V(s) + A(s, a) - \frac{1}{|\mathcal{A}|} \sum_{a'} A(s, a') $$

The mean subtraction ensures identifiability: without it, $V$ and $A$ are not unique, since any constant could be shifted between them without changing $Q$.
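
A tiny NumPy sketch of the identifiability issue, using hypothetical $V$ and $A$ values: shifting the advantages by any constant leaves $Q$ unchanged, and the mean-subtracted form additionally pins $V$ to the mean of $Q$:

```python
import numpy as np

# Without mean subtraction, (V, A) and (V - c, A + c) produce the same Q
# for any constant c -- the decomposition is not unique.
V = 1.5                               # hypothetical state value
A = np.array([0.2, -0.1, 0.4])        # hypothetical advantages
c = 7.0                               # arbitrary shift

q_naive = V + A
q_naive_shifted = (V - c) + (A + c)   # identical Q despite different V, A
print(np.allclose(q_naive, q_naive_shifted))  # True

# With mean subtraction, a constant shift in A cancels out entirely:
q_duel = V + (A - A.mean())
q_duel_shifted = V + ((A + c) - (A + c).mean())
print(np.allclose(q_duel, q_duel_shifted))    # True

# ...and V becomes recoverable from Q: V = mean_a Q(s, a)
print(np.isclose(q_duel.mean(), V))           # True
```

This is why the combining layer in the architecture below subtracts the advantage mean rather than simply adding the two streams.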

graph TB
    INPUT[State s] --> FEATURES["Shared Feature<br/>Extraction"]
    FEATURES --> VALUE["Value Stream<br/>fc -> V"]
    FEATURES --> ADV["Advantage Stream<br/>fc -> A"]
    VALUE --> COMBINE["Combine:<br/>Q = V + A - mean A"]
    ADV --> COMBINE
    COMBINE --> OUTPUT[Q-values]
    style FEATURES fill:#e3f2fd
    style VALUE fill:#fff3e0
    style ADV fill:#e8f5e9
    style OUTPUT fill:#c8e6c9

Benefits

  1. The value stream $V(s)$ receives a gradient from every sampled transition, regardless of which action was taken
  2. In states where the action choice barely affects the outcome, the network need not learn each action's value separately
  3. Empirically improves policy evaluation over standard DQN, especially when many actions have similar Q-values

Implementation Example 6: Dueling DQN Network

# Requirements:
# - Python 3.9+
# - torch>=2.0.0

import torch
import torch.nn as nn
import torch.nn.functional as F

print("=== Dueling DQN Architecture ===\n")


class DuelingDQN(nn.Module):
    """
    Dueling DQN: Q(s,a) = V(s) + A(s,a) - mean(A(s,a))

    Separates value estimation from action advantage estimation.
    """

    def __init__(self, state_dim: int, action_dim: int, hidden_dim: int = 128):
        super(DuelingDQN, self).__init__()

        # Shared feature extraction
        self.feature = nn.Sequential(
            nn.Linear(state_dim, hidden_dim),
            nn.ReLU()
        )

        # Value stream: V(s)
        self.value_stream = nn.Sequential(
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, 1)
        )

        # Advantage stream: A(s, a)
        self.advantage_stream = nn.Sequential(
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, action_dim)
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        features = self.feature(x)

        value = self.value_stream(features)          # [batch, 1]
        advantage = self.advantage_stream(features)   # [batch, actions]

        # Q = V + (A - mean(A))
        q_values = value + (advantage - advantage.mean(dim=1, keepdim=True))

        return q_values

    def get_value_advantage(self, x: torch.Tensor):
        """Get V and A separately for analysis."""
        features = self.feature(x)
        value = self.value_stream(features)
        advantage = self.advantage_stream(features)
        return value, advantage


# Compare standard vs dueling
class StandardDQN(nn.Module):
    def __init__(self, state_dim, action_dim, hidden_dim=128):
        super(StandardDQN, self).__init__()
        self.net = nn.Sequential(
            nn.Linear(state_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, action_dim)
        )

    def forward(self, x):
        return self.net(x)


# Create networks
state_dim, action_dim = 4, 3
dueling = DuelingDQN(state_dim, action_dim)
standard = StandardDQN(state_dim, action_dim)

print("--- Parameter Comparison ---")
dueling_params = sum(p.numel() for p in dueling.parameters())
standard_params = sum(p.numel() for p in standard.parameters())
print(f"Dueling DQN parameters: {dueling_params:,}")
print(f"Standard DQN parameters: {standard_params:,}")

# Analyze outputs
print("\n--- Dueling DQN Analysis ---")
sample_states = torch.randn(3, state_dim)

with torch.no_grad():
    q_values = dueling(sample_states)
    values, advantages = dueling.get_value_advantage(sample_states)

for i in range(3):
    print(f"\nState {i}:")
    print(f"  V(s): {values[i].item():.3f}")
    print(f"  A(s,a): {advantages[i].numpy()}")
    print(f"  A mean: {advantages[i].mean().item():.3f}")
    print(f"  Q(s,a): {q_values[i].numpy()}")
    print(f"  Best action: {q_values[i].argmax().item()}")

# Key insight
print("\n--- Key Insight ---")
print("In Dueling DQN:")
print("  - V(s) captures 'how good is this state?'")
print("  - A(s,a) captures 'how much better/worse is action a?'")
print("  - Many states have similar Q-values across actions")
print("  - Dueling learns V(s) efficiently for such states")

Output:

=== Dueling DQN Architecture ===

--- Parameter Comparison ---
Dueling DQN parameters: 34,180
Standard DQN parameters: 17,539

--- Dueling DQN Analysis ---

State 0:
  V(s): 0.234
  A(s,a): [ 0.123 -0.234  0.111]
  A mean: 0.000
  Q(s,a): [ 0.357  0.000  0.345]
  Best action: 0

State 1:
  V(s): -0.456
  A(s,a): [-0.089  0.234 -0.145]
  A mean: 0.000
  Q(s,a): [-0.545 -0.222 -0.601]
  Best action: 1

State 2:
  V(s): 0.123
  A(s,a): [ 0.067 -0.123  0.056]
  A mean: 0.000
  Q(s,a): [ 0.190  0.000  0.179]
  Best action: 0

--- Key Insight ---
In Dueling DQN:
  - V(s) captures 'how good is this state?'
  - A(s,a) captures 'how much better/worse is action a?'
  - Many states have similar Q-values across actions
  - Dueling learns V(s) efficiently for such states

3.6.3 Prioritized Experience Replay

Standard replay samples uniformly, but some transitions are more informative than others. Prioritized Experience Replay (PER) samples transitions based on their TD error magnitude:

$$ P(i) = \frac{p_i^\alpha}{\sum_k p_k^\alpha} $$

Where $p_i = |\delta_i| + \epsilon$ is the priority derived from the TD error $\delta_i$, $\epsilon$ is a small constant that keeps every priority positive, and $\alpha$ controls how strongly prioritization is applied ($\alpha = 0$ recovers uniform sampling). Because non-uniform sampling biases the gradient estimate, PER corrects each update with importance-sampling weights $w_i = (N \cdot P(i))^{-\beta}$, annealing $\beta$ toward 1 over training.
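The sampling rule above can be sketched with a small proportional-prioritization buffer. This is illustrative only (the class and method names are ours); practical implementations use a sum-tree so that sampling and priority updates cost O(log N) rather than O(N).

```python
import numpy as np


class PrioritizedReplayBuffer:
    """Minimal proportional PER sketch: P(i) ∝ p_i^alpha, p_i = |delta_i| + eps."""

    def __init__(self, capacity: int, alpha: float = 0.6, eps: float = 1e-5):
        self.capacity = capacity
        self.alpha = alpha            # priority exponent (alpha=0 -> uniform)
        self.eps = eps                # keeps every priority strictly positive
        self.buffer = []
        self.priorities = np.zeros(capacity, dtype=np.float64)
        self.pos = 0

    def push(self, transition):
        # New transitions get the current max priority so they are sampled
        # at least once before their TD error is known.
        max_prio = self.priorities.max() if self.buffer else 1.0
        if len(self.buffer) < self.capacity:
            self.buffer.append(transition)
        else:
            self.buffer[self.pos] = transition
        self.priorities[self.pos] = max_prio
        self.pos = (self.pos + 1) % self.capacity

    def sample(self, batch_size: int, beta: float = 0.4):
        prios = self.priorities[:len(self.buffer)]
        probs = prios ** self.alpha
        probs /= probs.sum()
        indices = np.random.choice(len(self.buffer), batch_size, p=probs)
        # Importance-sampling weights correct the non-uniform sampling bias;
        # normalizing by the max keeps weights in (0, 1] for stability.
        weights = (len(self.buffer) * probs[indices]) ** (-beta)
        weights /= weights.max()
        samples = [self.buffer[i] for i in indices]
        return samples, indices, weights

    def update_priorities(self, indices, td_errors):
        # Called after the learning step with the freshly computed TD errors.
        self.priorities[indices] = np.abs(td_errors) + self.eps
```

In a training loop, `sample` replaces the uniform batch draw, the per-sample `weights` multiply the squared TD errors in the loss, and `update_priorities` is called with the new TD errors after each gradient step.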

3.6.4 Rainbow: Combining All Improvements

Rainbow DQN (DeepMind, 2017) combines six extensions:

| Component | Problem Addressed |
| --- | --- |
| Double DQN | Overestimation bias |
| Dueling DQN | Value/advantage separation |
| Prioritized Replay | Sample efficiency |
| Multi-step Learning | Faster credit assignment |
| Distributional RL | Value distribution modeling |
| Noisy Networks | Exploration |
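Among these, multi-step learning swaps the 1-step target $r + \gamma \max_{a'} Q(s', a')$ for an n-step return that sums several real rewards before bootstrapping. A minimal sketch of the target computation (the helper name `n_step_target` is ours):

```python
def n_step_target(rewards, bootstrap_q: float, gamma: float = 0.99, n: int = 3) -> float:
    """n-step TD target: discounted sum of n rewards plus discounted bootstrap value."""
    # G = r_0 + gamma*r_1 + ... + gamma^(n-1)*r_{n-1}
    g = sum((gamma ** k) * r for k, r in enumerate(rewards[:n]))
    # Bootstrap from the value estimate n steps ahead
    return g + (gamma ** n) * bootstrap_q


# Three rewards of 1.0, bootstrap Q-value of 10.0
target = n_step_target([1.0, 1.0, 1.0], bootstrap_q=10.0, gamma=0.99, n=3)
print(f"{target:.4f}")  # 12.6731
```

Larger n propagates reward information back faster but increases the variance of the target; Rainbow uses n = 3.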

3.7 Atari Games Application

CNN Architecture for Visual Input

For image-based environments like Atari, DQN uses convolutional neural networks:

| Layer | Configuration | Output Shape |
| --- | --- | --- |
| Input | 4 stacked grayscale frames | 84 x 84 x 4 |
| Conv1 | 32 filters, 8x8, stride 4 | 20 x 20 x 32 |
| Conv2 | 64 filters, 4x4, stride 2 | 9 x 9 x 64 |
| Conv3 | 64 filters, 3x3, stride 1 | 7 x 7 x 64 |
| Flatten | - | 3136 |
| FC1 | 512 units | 512 |
| Output | n_actions units | n_actions |

Frame Stacking

A single frame lacks temporal information (velocity, direction). Frame stacking concatenates the last 4 frames as channels, providing motion information.
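The idea can be sketched with a deque holding the last 4 frames, assuming grayscale NumPy arrays (the `FrameStack` class here is illustrative; Gymnasium ships a ready-made frame-stacking wrapper for real use):

```python
from collections import deque

import numpy as np


class FrameStack:
    """Keep the last n frames and expose them as a [n, H, W] observation."""

    def __init__(self, n_frames: int = 4):
        self.n_frames = n_frames
        self.frames = deque(maxlen=n_frames)

    def reset(self, first_frame: np.ndarray) -> np.ndarray:
        # At episode start there is no history, so repeat the first frame.
        for _ in range(self.n_frames):
            self.frames.append(first_frame)
        return self.observation()

    def step(self, frame: np.ndarray) -> np.ndarray:
        # The deque drops the oldest frame automatically (maxlen).
        self.frames.append(frame)
        return self.observation()

    def observation(self) -> np.ndarray:
        # Stack along a new leading channel axis: [n_frames, H, W]
        return np.stack(self.frames, axis=0)


stack = FrameStack(n_frames=4)
obs = stack.reset(np.zeros((84, 84), dtype=np.uint8))
print(obs.shape)  # (4, 84, 84)
```

The resulting `[4, 84, 84]` observation matches the input shape expected by the CNN above, with the 4 frames acting as input channels.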

Implementation Example 7: Atari CNN Architecture

# Requirements:
# - Python 3.9+
# - torch>=2.0.0

import torch
import torch.nn as nn
import torch.nn.functional as F

print("=== Atari DQN CNN Architecture ===\n")


class AtariDQN(nn.Module):
    """
    CNN-based DQN for Atari games.

    Input: 4 stacked grayscale frames (84x84x4)
    Output: Q-values for each action
    """

    def __init__(self, n_actions: int):
        super(AtariDQN, self).__init__()

        # Convolutional layers
        self.conv1 = nn.Conv2d(4, 32, kernel_size=8, stride=4)
        self.conv2 = nn.Conv2d(32, 64, kernel_size=4, stride=2)
        self.conv3 = nn.Conv2d(64, 64, kernel_size=3, stride=1)

        # Fully connected layers
        self.fc1 = nn.Linear(7 * 7 * 64, 512)
        self.fc2 = nn.Linear(512, n_actions)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Args:
            x: Stacked frames [batch, 4, 84, 84]
        Returns:
            Q-values [batch, n_actions]
        """
        # Normalize pixel values to [0, 1]
        x = x / 255.0

        # Convolutional feature extraction
        x = F.relu(self.conv1(x))
        x = F.relu(self.conv2(x))
        x = F.relu(self.conv3(x))

        # Flatten
        x = x.view(x.size(0), -1)

        # Fully connected
        x = F.relu(self.fc1(x))
        q_values = self.fc2(x)

        return q_values


# Create network for a game with 4 actions
n_actions = 4
model = AtariDQN(n_actions)

print("--- Network Architecture ---")
print(model)

print("\n--- Layer Details ---")
total_params = 0
for name, param in model.named_parameters():
    params = param.numel()
    total_params += params
    print(f"{name}: {param.shape} ({params:,} params)")

print(f"\nTotal parameters: {total_params:,}")

# Test forward pass
print("\n--- Forward Pass Test ---")
batch_size = 2
dummy_frames = torch.randint(0, 256, (batch_size, 4, 84, 84), dtype=torch.float32)

with torch.no_grad():
    q_values = model(dummy_frames)

print(f"Input shape: {dummy_frames.shape}")
print(f"Output shape: {q_values.shape}")
print(f"Sample Q-values: {q_values[0].numpy()}")
print(f"Best action: {q_values[0].argmax().item()}")

# Memory footprint
print("\n--- Memory Analysis ---")
input_size = 4 * 84 * 84  # 4 frames, 84x84 each
print(f"Input size per sample: {input_size:,} pixels")
print(f"vs. Q-table for same input: 256^{input_size} entries (impossible)")
print(f"DQN parameters: {total_params:,} (compact representation)")

Output:

=== Atari DQN CNN Architecture ===

--- Network Architecture ---
AtariDQN(
  (conv1): Conv2d(4, 32, kernel_size=(8, 8), stride=(4, 4))
  (conv2): Conv2d(32, 64, kernel_size=(4, 4), stride=(2, 2))
  (conv3): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1))
  (fc1): Linear(in_features=3136, out_features=512, bias=True)
  (fc2): Linear(in_features=512, out_features=4, bias=True)
)

--- Layer Details ---
conv1.weight: torch.Size([32, 4, 8, 8]) (8,192 params)
conv1.bias: torch.Size([32]) (32 params)
conv2.weight: torch.Size([64, 32, 4, 4]) (32,768 params)
conv2.bias: torch.Size([64]) (64 params)
conv3.weight: torch.Size([64, 64, 3, 3]) (36,864 params)
conv3.bias: torch.Size([64]) (64 params)
fc1.weight: torch.Size([512, 3136]) (1,605,632 params)
fc1.bias: torch.Size([512]) (512 params)
fc2.weight: torch.Size([4, 512]) (2,048 params)
fc2.bias: torch.Size([4]) (4 params)

Total parameters: 1,686,180

--- Forward Pass Test ---
Input shape: torch.Size([2, 4, 84, 84])
Output shape: torch.Size([2, 4])
Sample Q-values: [-0.023  0.156 -0.089  0.234]
Best action: 3

--- Memory Analysis ---
Input size per sample: 28,224 pixels
vs. Q-table for same input: 256^28224 entries (impossible)
DQN parameters: 1,686,180 (compact representation)

Historical Context: DeepMind 2015

The original DQN paper "Human-level control through deep reinforcement learning" (Mnih et al., Nature 2015) demonstrated superhuman performance on 49 Atari games using the same architecture and hyperparameters. This was a landmark achievement in AI, showing that a single algorithm could learn diverse tasks from raw pixels.


3.8 Summary and Exercises

Key Takeaways

  1. Curse of Dimensionality: Tabular Q-learning fails for large or continuous state spaces
  2. Function Approximation: Neural networks approximate $Q(s, a; \theta)$ compactly with generalization
  3. DQN Loss: $L(\theta) = \mathbb{E}[(r + \gamma \max_{a'} Q(s', a'; \theta^-) - Q(s, a; \theta))^2]$
  4. Experience Replay: Breaks correlation by random sampling from a buffer
  5. Target Network: Stabilizes training by providing fixed targets
  6. Double DQN: Reduces overestimation by separating action selection and evaluation
  7. Dueling DQN: Decomposes Q into value V(s) and advantage A(s,a)
  8. Rainbow: Combines six improvements for state-of-the-art performance

Hyperparameter Reference

| Parameter | CartPole | Atari | Notes |
| --- | --- | --- | --- |
| Learning rate | 1e-3 | 1e-4 to 2.5e-4 | Adam optimizer |
| Gamma | 0.99 | 0.99 | Discount factor |
| Buffer size | 10,000 | 100,000 to 1,000,000 | Larger for complex tasks |
| Batch size | 32-64 | 32 | Smaller = more variance |
| Epsilon start | 1.0 | 1.0 | Initial exploration |
| Epsilon end | 0.01 | 0.1 | Final exploration |
| Epsilon decay | 0.995/episode | 1M steps linear | Decay schedule |
| Target update | 10 episodes | 10,000 steps | Hard update frequency |
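The two epsilon decay conventions in the table (per-episode exponential for CartPole, per-step linear for Atari) can be sketched as follows; the helper names are ours:

```python
def epsilon_exponential(episode: int, eps_start: float = 1.0,
                        eps_end: float = 0.01, decay: float = 0.995) -> float:
    """Per-episode exponential decay (CartPole-style), floored at eps_end."""
    return max(eps_end, eps_start * decay ** episode)


def epsilon_linear(step: int, eps_start: float = 1.0,
                   eps_end: float = 0.1, decay_steps: int = 1_000_000) -> float:
    """Linear decay over a fixed number of environment steps (Atari-style)."""
    frac = min(step / decay_steps, 1.0)
    return eps_start + frac * (eps_end - eps_start)


print(f"{epsilon_exponential(500):.4f}")  # ~0.0816 after 500 episodes
print(f"{epsilon_linear(500_000):.2f}")   # 0.55 halfway through the schedule
```

Both schedules reach roughly the same final exploration rate; the per-step linear schedule is preferred for Atari because episode lengths vary widely.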

Implementation Example 8: Comparing DQN Variants

# Requirements:
# - Python 3.9+
# - torch>=2.0.0
# - numpy>=1.24.0

import torch
import torch.nn as nn
import numpy as np

print("=== DQN Variants Comparison ===\n")


# Standard DQN target
def standard_dqn_target(target_net, next_states, rewards, dones, gamma=0.99):
    with torch.no_grad():
        next_q = target_net(next_states).max(dim=1)[0]
        return rewards + gamma * next_q * (1 - dones)


# Double DQN target
def double_dqn_target(online_net, target_net, next_states, rewards, dones, gamma=0.99):
    with torch.no_grad():
        # Select action with online network
        next_actions = online_net(next_states).argmax(dim=1)
        # Evaluate with target network
        next_q = target_net(next_states).gather(1, next_actions.unsqueeze(1)).squeeze(1)
        return rewards + gamma * next_q * (1 - dones)


# Simple network for comparison
class SimpleQNet(nn.Module):
    def __init__(self, state_dim, action_dim):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(state_dim, 64),
            nn.ReLU(),
            nn.Linear(64, action_dim)
        )

    def forward(self, x):
        return self.net(x)


# Create networks
state_dim, action_dim = 4, 3
online = SimpleQNet(state_dim, action_dim)
target = SimpleQNet(state_dim, action_dim)
target.load_state_dict(online.state_dict())

# Introduce difference between networks (simulating training)
with torch.no_grad():
    for p in online.parameters():
        p.add_(torch.randn_like(p) * 0.5)

# Test batch
batch_size = 1000
next_states = torch.randn(batch_size, state_dim)
rewards = torch.zeros(batch_size)
dones = torch.zeros(batch_size)

# Compute targets
standard_targets = standard_dqn_target(target, next_states, rewards, dones)
double_targets = double_dqn_target(online, target, next_states, rewards, dones)

print("--- Target Comparison ---")
print(f"Standard DQN mean target: {standard_targets.mean().item():.4f}")
print(f"Double DQN mean target: {double_targets.mean().item():.4f}")
print(f"Difference (Standard - Double): {(standard_targets - double_targets).mean().item():.4f}")
print(f"\nNote: Standard DQN typically has higher targets due to max overestimation")

# Summary table
print("\n--- DQN Variants Summary ---")
variants = [
    ("DQN", "Neural net Q-function", "Basic deep RL"),
    ("+ Exp. Replay", "Random batch sampling", "Breaks correlation"),
    ("+ Target Net", "Separate target network", "Stable targets"),
    ("Double DQN", "Separate select/evaluate", "Reduces overestimation"),
    ("Dueling DQN", "V(s) + A(s,a) streams", "Better value estimation"),
    ("Prioritized ER", "TD-error sampling", "Sample efficiency"),
    ("Rainbow", "All of the above + more", "State-of-the-art"),
]

print(f"{'Variant':<16} {'Innovation':<25} {'Benefit':<25}")
print("-" * 66)
for variant, innovation, benefit in variants:
    print(f"{variant:<16} {innovation:<25} {benefit:<25}")

Output:

=== DQN Variants Comparison ===

--- Target Comparison ---
Standard DQN mean target: 0.3456
Double DQN mean target: 0.2345
Difference (Standard - Double): 0.1111

Note: Standard DQN typically has higher targets due to max overestimation

--- DQN Variants Summary ---
Variant          Innovation                Benefit
------------------------------------------------------------------
DQN              Neural net Q-function     Basic deep RL
+ Exp. Replay    Random batch sampling     Breaks correlation
+ Target Net     Separate target network   Stable targets
Double DQN       Separate select/evaluate  Reduces overestimation
Dueling DQN      V(s) + A(s,a) streams     Better value estimation
Prioritized ER   TD-error sampling         Sample efficiency
Rainbow          All of the above + more   State-of-the-art

Exercises

Exercise 1: Experience Replay Analysis

Modify the CartPole DQN to train without experience replay (use only the most recent transition). Compare learning curves and explain the difference.

Exercise 2: Target Network Frequency

Experiment with different target network update frequencies (C = 1, 10, 100, 1000). Plot learning curves and analyze the stability-speed tradeoff.

Exercise 3: Implement Double DQN

The provided CartPole implementation already uses Double DQN style targets. Modify it to use standard DQN targets and compare Q-value estimates over training.

Exercise 4: Dueling Network

Replace the standard Q-network with a Dueling architecture. Train on CartPole and visualize V(s) and A(s,a) for different states.

Exercise 5: Hyperparameter Sensitivity

Create a grid search over learning rate {1e-4, 1e-3, 1e-2} and batch size {16, 32, 64, 128}. Report the best combination for fastest convergence.

Exercise 6: Epsilon Schedule Design

Implement and compare three epsilon decay schedules: linear decay, exponential decay, and step decay. Which achieves the best final performance?

Disclaimer