
Chapter 5: Advanced RL and Modern Applications

From SAC and RLHF to Model-Based RL and Real-World Deployment

Reading Time: 35-40 minutes | Difficulty: Advanced | Code Examples: 7 | Exercises: 5

This chapter covers the most advanced and impactful reinforcement learning techniques as of 2025. You will learn about SAC for continuous control, RLHF that powers modern language models like ChatGPT and Claude, model-based approaches like DreamerV3 and MuZero, offline RL with Decision Transformer, and practical deployment strategies.

Learning Objectives

By reading this chapter, you will be able to:

- Explain the maximum entropy objective and implement SAC for continuous control
- Describe the three-stage RLHF pipeline (SFT, reward modeling, PPO) and how DPO simplifies it
- Compare model-based methods such as DreamerV3 and MuZero with model-free approaches
- Apply offline RL with Decision Transformer to fixed datasets
- Recognize safety constraints and practical considerations for real-world deployment

5.1 SAC (Soft Actor-Critic)

5.1.1 Maximum Entropy Reinforcement Learning

Soft Actor-Critic (SAC) is a state-of-the-art off-policy algorithm based on the maximum entropy reinforcement learning framework. It optimizes policies to maximize both expected returns and policy entropy, leading to more robust and exploratory behavior.

The Maximum Entropy Objective

Traditional RL maximizes expected cumulative reward:

$$ J(\pi) = \mathbb{E}_{\tau \sim \pi} \left[ \sum_{t=0}^{\infty} \gamma^t r(s_t, a_t) \right] $$

SAC adds an entropy bonus to encourage exploration:

$$ J(\pi) = \mathbb{E}_{\tau \sim \pi} \left[ \sum_{t=0}^{\infty} \gamma^t \left( r(s_t, a_t) + \alpha \mathcal{H}(\pi(\cdot | s_t)) \right) \right] $$

where $\mathcal{H}(\pi) = -\sum_a \pi(a|s) \log \pi(a|s)$ is the policy entropy and $\alpha$ is the temperature parameter.
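To make the entropy term concrete, here is a small standalone illustration (not part of the chapter's SAC code) comparing the entropy of a uniform and a near-deterministic discrete policy:

```python
import numpy as np

def policy_entropy(probs):
    """Shannon entropy H(pi) = -sum_a pi(a|s) log pi(a|s)."""
    probs = np.asarray(probs, dtype=float)
    return float(-np.sum(probs * np.log(probs + 1e-12)))

uniform = policy_entropy([0.25, 0.25, 0.25, 0.25])  # maximal: log(4) ~ 1.386
peaked = policy_entropy([0.97, 0.01, 0.01, 0.01])   # close to zero
print(uniform, peaked)
```

The entropy bonus $\alpha \mathcal{H}$ therefore rewards the agent for keeping its action distribution closer to uniform, which is exactly what discourages premature convergence to a single action.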

Why Maximum Entropy Works

| Benefit | Description |
|---|---|
| Exploration | High entropy encourages trying diverse actions, preventing premature convergence |
| Robustness | Stochastic policies are more robust to perturbations and model errors |
| Multi-modal solutions | Can capture multiple good strategies rather than collapsing to one |
| Transfer | Entropy-regularized policies transfer better to new tasks |

5.1.2 SAC Architecture

graph TB
    subgraph "SAC Architecture"
        S[State s] --> Actor[Actor Network<br/>Gaussian Policy]
        S --> Q1[Q-Network 1]
        S --> Q2[Q-Network 2]
        Actor --> |Sample action| A[Action a]
        Actor --> |Log prob| LP[log pi]
        Q1 --> MIN[min Q1, Q2]
        Q2 --> MIN
        MIN --> |Soft Q-value| SQ[Q - alpha * log pi]
        LP --> SQ
        SQ --> |Policy update| Actor
        TQ1[Target Q1] --> |Soft update| Q1
        TQ2[Target Q2] --> |Soft update| Q2
    end
    style Actor fill:#27ae60,color:#fff
    style Q1 fill:#e74c3c,color:#fff
    style Q2 fill:#e74c3c,color:#fff
    style MIN fill:#f39c12,color:#fff

Key Components

- Gaussian policy actor: outputs a mean and log-std, with tanh squashing to bounded actions
- Twin Q-networks: two critics whose minimum is used to reduce overestimation bias
- Target Q-networks: slowly updated copies that stabilize the Bellman targets
- Temperature $\alpha$: weights the entropy bonus, optionally tuned automatically

5.1.3 SAC Objective Functions

Soft Q-Function Update (Bellman backup):

$$ Q(s_t, a_t) \leftarrow r_t + \gamma \mathbb{E}_{s_{t+1}} \left[ V(s_{t+1}) \right] $$

where the soft value function is:

$$ V(s) = \mathbb{E}_{a \sim \pi} \left[ Q(s, a) - \alpha \log \pi(a|s) \right] $$

Policy Update (maximize entropy-augmented Q):

$$ \mathcal{L}_\pi(\theta) = \mathbb{E}_{s \sim \mathcal{D}} \left[ \mathbb{E}_{a \sim \pi_\theta} \left[ \alpha \log \pi_\theta(a|s) - Q(s, a) \right] \right] $$

Automatic Temperature Tuning:

$$ \mathcal{L}_\alpha = \mathbb{E}_{a \sim \pi} \left[ -\alpha \left( \log \pi(a|s) + \bar{\mathcal{H}} \right) \right] $$

where $\bar{\mathcal{H}}$ is the target entropy (typically $-\dim(\mathcal{A})$ for continuous actions).

5.1.4 SAC Implementation

# Requirements:
# - Python 3.9+
# - numpy>=1.24.0
# - torch>=2.0.0
# - gymnasium>=0.29.0

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.distributions import Normal
import numpy as np
from collections import deque
import random

class GaussianPolicy(nn.Module):
    """
    Gaussian policy for SAC with reparameterization trick.
    Outputs mean and log_std, samples action via tanh squashing.
    """

    def __init__(self, state_dim, action_dim, hidden_dim=256,
                 log_std_min=-20, log_std_max=2):
        super().__init__()
        self.log_std_min = log_std_min
        self.log_std_max = log_std_max

        self.fc1 = nn.Linear(state_dim, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, hidden_dim)
        self.mean = nn.Linear(hidden_dim, action_dim)
        self.log_std = nn.Linear(hidden_dim, action_dim)

    def forward(self, state):
        x = F.relu(self.fc1(state))
        x = F.relu(self.fc2(x))
        mean = self.mean(x)
        log_std = torch.clamp(self.log_std(x), self.log_std_min, self.log_std_max)
        return mean, log_std

    def sample(self, state):
        """Sample action with reparameterization trick + tanh squashing."""
        mean, log_std = self.forward(state)
        std = log_std.exp()

        # Reparameterization: a = mu + sigma * epsilon
        normal = Normal(mean, std)
        x_t = normal.rsample()  # Differentiable sampling

        # Tanh squashing to [-1, 1]
        action = torch.tanh(x_t)

        # Log probability with tanh correction
        log_prob = normal.log_prob(x_t)
        log_prob -= torch.log(1 - action.pow(2) + 1e-6)
        log_prob = log_prob.sum(dim=-1, keepdim=True)

        return action, log_prob


class SoftQNetwork(nn.Module):
    """Q-Network for SAC: Q(s, a) -> scalar."""

    def __init__(self, state_dim, action_dim, hidden_dim=256):
        super().__init__()
        self.fc1 = nn.Linear(state_dim + action_dim, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, hidden_dim)
        self.fc3 = nn.Linear(hidden_dim, 1)

    def forward(self, state, action):
        x = torch.cat([state, action], dim=-1)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        return self.fc3(x)


class SAC:
    """
    Soft Actor-Critic with automatic temperature tuning.

    Key features:
    - Maximum entropy RL for robust exploration
    - Double Q-learning to prevent overestimation
    - Off-policy with experience replay
    - Automatic alpha (temperature) tuning
    """

    def __init__(self, state_dim, action_dim, lr=3e-4, gamma=0.99,
                 tau=0.005, alpha=0.2, auto_alpha=True):
        self.gamma = gamma
        self.tau = tau
        self.auto_alpha = auto_alpha

        # Networks
        self.actor = GaussianPolicy(state_dim, action_dim)
        self.q1 = SoftQNetwork(state_dim, action_dim)
        self.q2 = SoftQNetwork(state_dim, action_dim)
        self.q1_target = SoftQNetwork(state_dim, action_dim)
        self.q2_target = SoftQNetwork(state_dim, action_dim)

        # Copy weights to targets
        self.q1_target.load_state_dict(self.q1.state_dict())
        self.q2_target.load_state_dict(self.q2.state_dict())

        # Optimizers
        self.actor_optim = torch.optim.Adam(self.actor.parameters(), lr=lr)
        self.q1_optim = torch.optim.Adam(self.q1.parameters(), lr=lr)
        self.q2_optim = torch.optim.Adam(self.q2.parameters(), lr=lr)

        # Automatic temperature tuning
        if auto_alpha:
            self.target_entropy = -action_dim  # Heuristic: -dim(A)
            self.log_alpha = torch.zeros(1, requires_grad=True)
            self.alpha_optim = torch.optim.Adam([self.log_alpha], lr=lr)
            self.alpha = self.log_alpha.exp().item()
        else:
            self.alpha = alpha

        # Replay buffer
        self.buffer = deque(maxlen=100000)

    def select_action(self, state, evaluate=False):
        """Select action: stochastic during training, deterministic during eval."""
        state = torch.FloatTensor(state).unsqueeze(0)

        if evaluate:
            with torch.no_grad():
                mean, _ = self.actor(state)
                return torch.tanh(mean).cpu().numpy()[0]
        else:
            with torch.no_grad():
                action, _ = self.actor.sample(state)
                return action.cpu().numpy()[0]

    def update(self, batch_size=256):
        """Perform one SAC update step."""
        if len(self.buffer) < batch_size:
            return

        # Sample batch
        batch = random.sample(self.buffer, batch_size)
        state, action, reward, next_state, done = zip(*batch)

        state = torch.FloatTensor(np.array(state))
        action = torch.FloatTensor(np.array(action))
        reward = torch.FloatTensor(reward).unsqueeze(1)
        next_state = torch.FloatTensor(np.array(next_state))
        done = torch.FloatTensor(done).unsqueeze(1)

        # --- Q-function update ---
        with torch.no_grad():
            next_action, next_log_prob = self.actor.sample(next_state)
            target_q1 = self.q1_target(next_state, next_action)
            target_q2 = self.q2_target(next_state, next_action)
            target_q = torch.min(target_q1, target_q2) - self.alpha * next_log_prob
            target_value = reward + (1 - done) * self.gamma * target_q

        q1_loss = F.mse_loss(self.q1(state, action), target_value)
        q2_loss = F.mse_loss(self.q2(state, action), target_value)

        self.q1_optim.zero_grad()
        q1_loss.backward()
        self.q1_optim.step()

        self.q2_optim.zero_grad()
        q2_loss.backward()
        self.q2_optim.step()

        # --- Policy update ---
        new_action, log_prob = self.actor.sample(state)
        q1_new = self.q1(state, new_action)
        q2_new = self.q2(state, new_action)
        q_new = torch.min(q1_new, q2_new)

        actor_loss = (self.alpha * log_prob - q_new).mean()

        self.actor_optim.zero_grad()
        actor_loss.backward()
        self.actor_optim.step()

        # --- Alpha (temperature) update ---
        if self.auto_alpha:
            alpha_loss = -(self.log_alpha * (log_prob + self.target_entropy).detach()).mean()

            self.alpha_optim.zero_grad()
            alpha_loss.backward()
            self.alpha_optim.step()

            self.alpha = self.log_alpha.exp().item()

        # --- Soft update target networks ---
        for param, target_param in zip(self.q1.parameters(), self.q1_target.parameters()):
            target_param.data.copy_(self.tau * param.data + (1 - self.tau) * target_param.data)
        for param, target_param in zip(self.q2.parameters(), self.q2_target.parameters()):
            target_param.data.copy_(self.tau * param.data + (1 - self.tau) * target_param.data)


# --- Training Example ---
if __name__ == "__main__":
    import gymnasium as gym

    print("SAC Training on Pendulum-v1")
    print("=" * 50)

    env = gym.make('Pendulum-v1')
    state_dim = env.observation_space.shape[0]
    action_dim = env.action_space.shape[0]

    agent = SAC(state_dim, action_dim, auto_alpha=True)

    for episode in range(100):
        state, _ = env.reset()
        episode_reward = 0

        for step in range(200):
            action = agent.select_action(state)
            action_scaled = action * 2.0  # Pendulum action range [-2, 2]

            next_state, reward, terminated, truncated, _ = env.step(action_scaled)
            done = terminated or truncated

            agent.buffer.append((state, action, reward, next_state, float(done)))
            agent.update()

            episode_reward += reward
            state = next_state

            if done:
                break

        if (episode + 1) % 10 == 0:
            print(f"Episode {episode+1}, Reward: {episode_reward:.1f}, Alpha: {agent.alpha:.3f}")

    print("\nTraining complete!")

SAC Key Points: The reparameterization trick enables end-to-end gradient flow through stochastic sampling. Automatic temperature tuning adjusts exploration dynamically: alpha stays high early for exploration and decreases later for exploitation. SAC is considered one of the most sample-efficient model-free algorithms for continuous control.


5.2 RLHF: Reinforcement Learning from Human Feedback

RLHF is the key technology that transformed language models into helpful AI assistants like ChatGPT and Claude. This section explains how RLHF works and why it is crucial for AI alignment.

5.2.1 Why RLHF Matters

Pre-trained language models (like GPT) are trained to predict the next token. However, this objective does not directly optimize for:

- Helpfulness: following instructions and answering the user's actual question
- Harmlessness: refusing unsafe or deceptive requests
- Honesty: acknowledging uncertainty instead of confabulating

RLHF bridges this gap by using human preferences to define what "good" outputs look like, then using RL to optimize the model toward those preferences.

5.2.2 The RLHF Pipeline

flowchart TB
    subgraph "Step 1: Supervised Fine-Tuning SFT"
        PT[Pre-trained LLM] --> |Fine-tune on demonstrations| SFT[SFT Model]
        HD[Human Demonstrations<br/>High-quality responses] --> SFT
    end
    subgraph "Step 2: Reward Model Training"
        SFT --> |Generate responses| RESP[Response Pairs<br/>y1, y2 for prompt x]
        RESP --> |Human ranks| PREF[Preference Data<br/>y1 better than y2]
        PREF --> |Train| RM[Reward Model<br/>r = RM x,y]
    end
    subgraph "Step 3: RL Fine-Tuning with PPO"
        SFT --> |Initialize| POLICY[Policy Model]
        RM --> |Provides reward| PPO[PPO Training]
        POLICY --> |Generate| GEN[Generated Response]
        GEN --> RM
        PPO --> |Update| POLICY
        KL[KL Penalty<br/>Stay close to SFT] --> PPO
    end
    POLICY --> FINAL[RLHF-Aligned Model<br/>ChatGPT, Claude]
    style PT fill:#e3f2fd
    style SFT fill:#fff3e0
    style RM fill:#f3e5f5
    style POLICY fill:#e8f5e9
    style FINAL fill:#ffcdd2

5.2.3 Step 1: Supervised Fine-Tuning (SFT)

Start with a pre-trained LLM and fine-tune it on high-quality demonstrations:

$$ \mathcal{L}_{\text{SFT}} = -\mathbb{E}_{(x, y) \sim \mathcal{D}_{\text{demo}}} \left[ \log \pi_\theta(y | x) \right] $$

where $\mathcal{D}_{\text{demo}}$ contains human-written responses to prompts.
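As a sketch (with hypothetical tensor shapes, not tied to any particular model API), the SFT objective is simply token-level cross-entropy over the demonstration responses:

```python
import torch
import torch.nn.functional as F

# Hypothetical shapes: batch of 2 sequences, length 5, vocabulary of 100 tokens.
vocab_size = 100
logits = torch.randn(2, 5, vocab_size)          # model outputs for pi_theta(. | x, y_<t)
targets = torch.randint(0, vocab_size, (2, 5))  # demonstration tokens y

# L_SFT = -E[log pi_theta(y | x)], averaged over all tokens
sft_loss = F.cross_entropy(logits.reshape(-1, vocab_size), targets.reshape(-1))
print(sft_loss.item())
```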

Why SFT First?

- It gives PPO a reasonable starting policy, so RL explores near sensible outputs
- It keeps the responses that reward-model annotators rank on-distribution
- It serves as the reference model for the KL penalty in Step 3

5.2.4 Step 2: Reward Model Training

Train a reward model to predict human preferences from comparison data.

Data Collection

  1. Sample prompts from a diverse dataset
  2. Generate multiple responses using the SFT model
  3. Human annotators rank responses (e.g., A > B)

Bradley-Terry Model

Model the probability of preferring response $y_1$ over $y_2$:

$$ P(y_1 \succ y_2 | x) = \sigma(r(x, y_1) - r(x, y_2)) $$

where $\sigma$ is the sigmoid function and $r(x, y)$ is the reward model.
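For intuition, a reward gap of 1.0 between two responses corresponds to roughly a 73% preference probability under this model. A small numeric check (independent of any actual reward model):

```python
import math

def preference_prob(r_1, r_2):
    """Bradley-Terry: P(y1 preferred over y2) = sigmoid(r(x, y1) - r(x, y2))."""
    return 1.0 / (1.0 + math.exp(-(r_1 - r_2)))

print(preference_prob(2.0, 1.0))  # ~0.731: y1 moderately preferred
print(preference_prob(1.0, 1.0))  # 0.5: no preference
```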

Reward Model Loss

$$ \mathcal{L}_{\text{RM}} = -\mathbb{E}_{(x, y_w, y_l) \sim \mathcal{D}} \left[ \log \sigma(r(x, y_w) - r(x, y_l)) \right] $$

where $y_w$ is the preferred (winning) response and $y_l$ is the rejected (losing) response.

# Reward Model Training (Conceptual)
import torch
import torch.nn as nn

class RewardModel(nn.Module):
    """
    Reward model for RLHF.
    Takes (prompt, response) and outputs a scalar reward.
    """

    def __init__(self, base_model):
        super().__init__()
        self.backbone = base_model  # Pre-trained LLM
        self.reward_head = nn.Linear(base_model.hidden_size, 1)

    def forward(self, input_ids, attention_mask):
        # Get last hidden state
        outputs = self.backbone(input_ids, attention_mask=attention_mask)
        last_hidden = outputs.last_hidden_state[:, -1, :]  # [CLS] or last token
        reward = self.reward_head(last_hidden)
        return reward.squeeze(-1)


def reward_model_loss(reward_model, chosen_ids, rejected_ids, attention_mask_c, attention_mask_r):
    """
    Bradley-Terry loss for preference learning.
    """
    r_chosen = reward_model(chosen_ids, attention_mask_c)
    r_rejected = reward_model(rejected_ids, attention_mask_r)

    # Loss: -log(sigmoid(r_chosen - r_rejected))
    loss = -torch.log(torch.sigmoid(r_chosen - r_rejected)).mean()
    return loss

5.2.5 Step 3: PPO Fine-Tuning

Use PPO to optimize the policy (LLM) to maximize the reward model's score while staying close to the SFT model.

RLHF Objective

$$ \mathcal{L}_{\text{RLHF}} = \mathbb{E}_{x \sim \mathcal{D}, y \sim \pi_\theta} \left[ r(x, y) - \beta \cdot \text{KL}(\pi_\theta(y|x) \| \pi_{\text{SFT}}(y|x)) \right] $$

The KL penalty prevents the policy from:

- Drifting so far from the SFT model that outputs become incoherent
- Exploiting ("hacking") flaws in the reward model with degenerate text
- Collapsing onto a narrow set of high-reward responses (mode collapse)

PPO for Language Models

The PPO clipped objective from Chapter 4 is applied to token-level actions:

$$ \mathcal{L}^{\text{CLIP}}(\theta) = \mathbb{E}_t \left[ \min(r_t(\theta) \hat{A}_t, \text{clip}(r_t(\theta), 1-\epsilon, 1+\epsilon) \hat{A}_t) \right] $$

where $r_t(\theta) = \frac{\pi_\theta(a_t|s_t)}{\pi_{\text{old}}(a_t|s_t)}$ is the probability ratio for each token.

# RLHF PPO Training (Simplified Conceptual Code)
def compute_rlhf_reward(response_ids, prompt_ids, reward_model, ref_model, policy_model, beta=0.1):
    """
    Compute RLHF reward with KL penalty.

    reward = r(x, y) - beta * KL(policy || reference)
    """
    # Get reward from reward model
    reward_score = reward_model(prompt_ids, response_ids)

    # Compute KL divergence per token
    with torch.no_grad():
        ref_logprobs = ref_model.get_log_probs(prompt_ids, response_ids)
    policy_logprobs = policy_model.get_log_probs(prompt_ids, response_ids)

    kl_div = (policy_logprobs - ref_logprobs).sum(dim=-1)

    # Total reward
    total_reward = reward_score - beta * kl_div
    return total_reward


def ppo_step(policy_model, old_logprobs, states, actions, advantages, epsilon=0.2):
    """
    PPO update for language model policy.
    """
    new_logprobs = policy_model.get_log_probs(states, actions)

    # Probability ratio
    ratio = torch.exp(new_logprobs - old_logprobs)

    # Clipped objective
    surr1 = ratio * advantages
    surr2 = torch.clamp(ratio, 1 - epsilon, 1 + epsilon) * advantages

    loss = -torch.min(surr1, surr2).mean()
    return loss

5.2.6 DPO: Direct Preference Optimization

DPO (Direct Preference Optimization) is a simpler alternative to RLHF that skips the reward model and directly optimizes preferences.

Key Insight

The optimal policy under the RLHF objective can be expressed in closed form:

$$ \pi^*(y|x) = \frac{1}{Z(x)} \pi_{\text{ref}}(y|x) \exp\left(\frac{1}{\beta} r(x, y)\right) $$

This means we can rearrange to express the reward in terms of policies:

$$ r(x, y) = \beta \log \frac{\pi^*(y|x)}{\pi_{\text{ref}}(y|x)} + \beta \log Z(x) $$

DPO Loss

Substituting into the Bradley-Terry preference model, the intractable $\log Z(x)$ terms cancel in the reward difference, giving the DPO objective:

$$ \mathcal{L}_{\text{DPO}}(\theta) = -\mathbb{E}_{(x, y_w, y_l)} \left[ \log \sigma \left( \beta \log \frac{\pi_\theta(y_w|x)}{\pi_{\text{ref}}(y_w|x)} - \beta \log \frac{\pi_\theta(y_l|x)}{\pi_{\text{ref}}(y_l|x)} \right) \right] $$

DPO vs RLHF

| Aspect | RLHF | DPO |
|---|---|---|
| Reward model | Required (separate training) | Not needed (implicit) |
| RL training | PPO with online sampling | Supervised learning only |
| Complexity | Higher (3 stages) | Lower (1 stage after SFT) |
| Stability | Can be unstable (RL issues) | More stable (supervised) |
| Flexibility | Can iterate with new prompts | Fixed to preference dataset |

# DPO Implementation (Simplified)
def dpo_loss(policy_model, ref_model, chosen_ids, rejected_ids, beta=0.1):
    """
    Direct Preference Optimization loss.

    No reward model needed - directly optimize preferences.
    """
    # Get log probabilities from policy and reference
    policy_chosen_logps = policy_model.get_log_probs(chosen_ids)
    policy_rejected_logps = policy_model.get_log_probs(rejected_ids)

    with torch.no_grad():
        ref_chosen_logps = ref_model.get_log_probs(chosen_ids)
        ref_rejected_logps = ref_model.get_log_probs(rejected_ids)

    # Compute log ratios
    chosen_ratio = beta * (policy_chosen_logps - ref_chosen_logps)
    rejected_ratio = beta * (policy_rejected_logps - ref_rejected_logps)

    # DPO loss: -log sigmoid(chosen_ratio - rejected_ratio)
    loss = -torch.log(torch.sigmoid(chosen_ratio - rejected_ratio)).mean()
    return loss

RLHF Importance for AI Alignment: RLHF is not just about making models more helpful; it is a critical technique for AI safety. By learning from human feedback, models can be steered away from harmful behaviors that emerge from pure next-token prediction. The "H" in RLHF represents the human values we want to instill in AI systems.


5.3 Model-Based Reinforcement Learning

Model-based RL learns a model of the environment dynamics and uses it for planning, achieving much higher sample efficiency than model-free methods.

5.3.1 World Models Concept

A world model predicts how the environment will respond to actions:

$$ \hat{s}_{t+1}, \hat{r}_t = f_\theta(s_t, a_t) $$

This allows the agent to "imagine" outcomes without actually taking actions in the real environment.
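A minimal world model can be sketched as a network mapping $(s_t, a_t)$ to a predicted next state and reward. This is a toy deterministic illustration, not DreamerV3's actual latent architecture:

```python
import torch
import torch.nn as nn

class SimpleWorldModel(nn.Module):
    """Deterministic one-step dynamics model: (s, a) -> (s_hat_next, r_hat)."""

    def __init__(self, state_dim, action_dim, hidden_dim=64):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(state_dim + action_dim, hidden_dim), nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim), nn.ReLU(),
        )
        self.next_state_head = nn.Linear(hidden_dim, state_dim)
        self.reward_head = nn.Linear(hidden_dim, 1)

    def forward(self, state, action):
        h = self.net(torch.cat([state, action], dim=-1))
        return self.next_state_head(h), self.reward_head(h)

model = SimpleWorldModel(state_dim=4, action_dim=2)
s = torch.randn(8, 4)
a = torch.randn(8, 2)
s_next, r = model(s, a)
print(s_next.shape, r.shape)
```

Such a model would be trained by regression against observed transitions; the payoff is that, once trained, it can generate transitions without touching the environment.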

graph LR
    subgraph "Model-Free RL"
        E1[Real Environment] --> |Experience| P1[Policy]
        P1 --> |Actions| E1
    end
    subgraph "Model-Based RL"
        E2[Real Environment] --> |Experience| WM[World Model]
        WM --> |Imagined Experience| P2[Policy]
        P2 --> |Actions| E2
        WM --> |Planning| PLAN[MCTS / MPC]
        PLAN --> P2
    end
    style WM fill:#f39c12,color:#fff
    style PLAN fill:#9b59b6,color:#fff

Advantages of Model-Based RL

| Advantage | Description |
|---|---|
| Sample efficiency | Learn from 10-100x fewer environment interactions |
| Imagination | Generate unlimited synthetic experience for training |
| Planning | Look ahead to evaluate action sequences |
| Transfer | World model can transfer to new tasks in the same environment |

5.3.2 DreamerV3 (2023-2025)

DreamerV3 is a state-of-the-art world model algorithm that has achieved remarkable results across 150+ diverse tasks, including collecting diamonds in Minecraft from scratch.

Key Innovations

- Symlog predictions: compress widely varying reward and value scales into a stable range
- Discrete latent representations: categorical latents that work robustly across domains
- Fixed hyperparameters: a single configuration works across 150+ tasks without per-task tuning

DreamerV3 Architecture

graph TB
    subgraph "DreamerV3 World Model"
        OBS[Observation x_t] --> ENC[Encoder CNN/MLP]
        ENC --> POST[Posterior z_t<br/>Discrete Latent]
        H[Recurrent State h_t] --> PRIOR[Prior z_t]
        A[Action a_t-1] --> PRIOR
        POST --> DEC[Decoder]
        H --> DEC
        DEC --> PRED[Predicted x_t, r_t]
        POST --> |Update| H
        H --> |Next step| H
    end
    subgraph "Imagination Training"
        H --> |Unroll| IMAG[Imagined Trajectories]
        IMAG --> ACTOR[Actor Loss]
        IMAG --> CRITIC[Critic Loss]
    end
    style POST fill:#27ae60,color:#fff
    style IMAG fill:#e74c3c,color:#fff

Learning in Imagination

DreamerV3 trains the policy entirely in imagination:

  1. Collect real experience: Interact with environment, store in buffer
  2. Train world model: Learn to predict observations and rewards
  3. Imagine trajectories: Unroll world model with current policy
  4. Train actor-critic: Use imagined returns to update policy
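Step 3 above can be sketched as a rollout loop that never touches the real environment. This toy code assumes a world model and policy with the simple callable interfaces shown; real imagination rollouts operate on latent states:

```python
def imagine_trajectory(world_model, policy, s0, horizon):
    """Unroll the learned model with the current policy ('imagination')."""
    states, actions, rewards = [s0], [], []
    s = s0
    for _ in range(horizon):
        a = policy(s)
        s, r = world_model(s, a)  # predicted transition, not a real one
        actions.append(a)
        states.append(s)
        rewards.append(r)
    return states, actions, rewards

# Toy deterministic stand-ins for illustration:
toy_model = lambda s, a: (s + a, -abs(s))  # (next state, reward)
toy_policy = lambda s: 0.5
states, actions, rewards = imagine_trajectory(toy_model, toy_policy, s0=0.0, horizon=3)
print(states)   # [0.0, 0.5, 1.0, 1.5]
print(rewards)  # [0.0, -0.5, -1.0]
```

The imagined rewards and states then feed the actor and critic losses in step 4.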

Minecraft Diamond Achievement

DreamerV3 was the first algorithm to collect a diamond in Minecraft without human demonstrations or curriculum learning - a task requiring:

- A long chain of subgoals: gather wood, craft tools, mine iron, smelt, then dig for diamonds
- Extremely sparse rewards over horizons of tens of thousands of steps
- Generalization across procedurally generated worlds

5.3.3 MuZero Principles

MuZero (DeepMind, 2020) combines learned dynamics models with Monte Carlo Tree Search (MCTS), achieving superhuman performance on Atari, Chess, Shogi, and Go without knowing the game rules.

Key Insight

MuZero learns three components:

- A representation function $h$: maps observations to a latent state, $s_0 = h(o_t)$
- A dynamics function $g$: predicts the next latent state and reward, $(s_{k+1}, r_k) = g(s_k, a_k)$
- A prediction function $f$: outputs policy and value from a latent state, $(p_k, v_k) = f(s_k)$

MuZero Architecture

graph LR
    O[Observation o_t] --> |h| S0[Hidden State s_0]
    S0 --> |f| PV0[Policy p_0<br/>Value v_0]
    S0 --> |g with a_0| S1[State s_1<br/>Reward r_0]
    S1 --> |f| PV1[p_1, v_1]
    S1 --> |g with a_1| S2[State s_2<br/>Reward r_1]
    S2 --> |f| PV2[p_2, v_2]
    subgraph "MCTS Planning"
        PV0 --> MCTS[Tree Search]
        PV1 --> MCTS
        PV2 --> MCTS
        MCTS --> ACT[Action Selection]
    end
    style MCTS fill:#9b59b6,color:#fff

Why MuZero Works Without Rules

The dynamics model is never asked to reproduce the true game rules or reconstruct observations; it only has to predict the quantities that matter for planning: reward, value, and policy. MCTS then searches entirely in this learned latent space, so no rules-based simulator is needed at planning time.

Model-Based vs Model-Free Comparison

| Aspect | Model-Free (DQN, PPO) | Model-Based (DreamerV3, MuZero) |
|---|---|---|
| Sample efficiency | Low (millions of steps) | High (10-100x fewer) |
| Computation | Lower per step | Higher (planning overhead) |
| Model errors | No model to be wrong | Errors can compound |
| Generalization | Task-specific | Model can transfer |

5.4 Offline RL and Decision Transformer

Offline RL (also called Batch RL) learns policies from fixed datasets without environment interaction - crucial when real-world interaction is expensive or dangerous.

5.4.1 The Offline RL Challenge

Standard RL algorithms fail on offline data due to distribution shift:

- Q-learning evaluates actions the behavior policy never took, and their values are systematically overestimated
- The learned policy then exploits exactly these erroneous value estimates
- With no environment interaction, there is no corrective feedback to fix the errors

When to Use Offline RL

- Interaction is dangerous or expensive (healthcare, autonomous driving)
- Large logged datasets already exist (recommender systems, robot operation logs)
- Online exploration is prohibited by regulation or safety policy

5.4.2 Decision Transformer

Decision Transformer (2021) reformulates RL as a sequence modeling problem, using Transformers to generate actions conditioned on desired returns.

Key Insight

Instead of predicting values and optimizing returns, condition on the desired return and predict actions that achieve it.

Sequence Format

Decision Transformer models trajectories as sequences:

$$ \tau = (\hat{R}_1, s_1, a_1, \hat{R}_2, s_2, a_2, \ldots, \hat{R}_T, s_T, a_T) $$

where $\hat{R}_t = \sum_{t'=t}^{T} r_{t'}$ is the return-to-go (remaining return from time $t$).
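The return-to-go sequence can be computed in a single backward pass over the rewards (a small helper for illustration, not from the original paper's code):

```python
def returns_to_go(rewards):
    """R_t = sum of rewards from step t to the end of the trajectory."""
    rtg = []
    running = 0.0
    for r in reversed(rewards):
        running += r
        rtg.append(running)
    return rtg[::-1]

print(returns_to_go([1.0, 2.0, 3.0]))  # [6.0, 5.0, 3.0]
```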

graph LR
    subgraph "Decision Transformer Input"
        R1[R-to-go 1] --> EMB[Token Embedding]
        S1[State 1] --> EMB
        A1[Action 1] --> EMB
        R2[R-to-go 2] --> EMB
        S2[State 2] --> EMB
    end
    EMB --> TF[Transformer<br/>Causal Attention]
    TF --> PRED[Predicted Action a_2]
    style TF fill:#9b59b6,color:#fff
    style PRED fill:#27ae60,color:#fff

Training Objective

Simple supervised learning on offline trajectories:

$$ \mathcal{L} = \mathbb{E}_{\tau \sim \mathcal{D}} \left[ \sum_t \| a_t - \hat{a}_t \|^2 \right] $$

where $\hat{a}_t$ is the predicted action given context $(\hat{R}_1, s_1, a_1, \ldots, \hat{R}_t, s_t)$.

Inference: Return Conditioning

At test time, specify the desired return and the model generates actions to achieve it:

  1. Set $\hat{R}_1$ to desired total return (e.g., expert-level performance)
  2. Observe $s_1$, predict $a_1$
  3. Execute $a_1$, observe $r_1, s_2$
  4. Update $\hat{R}_2 = \hat{R}_1 - r_1$
  5. Repeat
# Decision Transformer (Simplified Implementation)
import torch
import torch.nn as nn

class DecisionTransformer(nn.Module):
    """
    Decision Transformer for offline RL.
    Conditions on return-to-go to generate actions.
    """

    def __init__(self, state_dim, action_dim, hidden_dim=128,
                 n_layers=3, n_heads=1, max_length=20):
        super().__init__()

        self.hidden_dim = hidden_dim
        self.max_length = max_length

        # Embeddings for each modality
        self.state_embed = nn.Linear(state_dim, hidden_dim)
        self.action_embed = nn.Linear(action_dim, hidden_dim)
        self.return_embed = nn.Linear(1, hidden_dim)

        # Positional embedding
        self.pos_embed = nn.Embedding(max_length * 3, hidden_dim)

        # Transformer
        encoder_layer = nn.TransformerEncoderLayer(
            d_model=hidden_dim, nhead=n_heads,
            dim_feedforward=hidden_dim * 4, batch_first=True
        )
        self.transformer = nn.TransformerEncoder(encoder_layer, num_layers=n_layers)

        # Action prediction head
        self.action_head = nn.Linear(hidden_dim, action_dim)

    def forward(self, returns_to_go, states, actions, timesteps):
        """
        Forward pass: predict actions given context.

        Args:
            returns_to_go: (batch, seq_len, 1)
            states: (batch, seq_len, state_dim)
            actions: (batch, seq_len, action_dim)
            timesteps: (batch, seq_len)
        """
        batch_size, seq_len = states.shape[0], states.shape[1]

        # Embed each modality
        state_embeddings = self.state_embed(states)
        action_embeddings = self.action_embed(actions)
        return_embeddings = self.return_embed(returns_to_go)

        # Interleave: [R_1, s_1, a_1, R_2, s_2, a_2, ...]
        # Shape: (batch, seq_len * 3, hidden_dim)
        stacked = torch.stack([return_embeddings, state_embeddings, action_embeddings], dim=2)
        stacked = stacked.reshape(batch_size, seq_len * 3, self.hidden_dim)

        # Add positional embeddings
        positions = torch.arange(seq_len * 3, device=states.device).unsqueeze(0)
        stacked = stacked + self.pos_embed(positions)

        # Causal mask (can only attend to past)
        mask = nn.Transformer.generate_square_subsequent_mask(seq_len * 3).to(states.device)

        # Transformer forward
        output = self.transformer(stacked, mask=mask)

        # Predict actions (from state positions: indices 1, 4, 7, ...)
        state_positions = torch.arange(1, seq_len * 3, 3)
        action_preds = self.action_head(output[:, state_positions, :])

        return action_preds

    def get_action(self, returns_to_go, states, actions, timesteps):
        """Get action for current state (inference mode)."""
        action_preds = self.forward(returns_to_go, states, actions, timesteps)
        return action_preds[:, -1, :]  # Return last predicted action


# Example usage
if __name__ == "__main__":
    print("Decision Transformer Example")
    print("=" * 50)

    state_dim, action_dim = 4, 2
    model = DecisionTransformer(state_dim, action_dim)

    # Simulate offline data
    batch_size, seq_len = 8, 10
    returns = torch.randn(batch_size, seq_len, 1)
    states = torch.randn(batch_size, seq_len, state_dim)
    actions = torch.randn(batch_size, seq_len, action_dim)
    timesteps = torch.arange(seq_len).unsqueeze(0).expand(batch_size, -1)

    # Forward pass
    pred_actions = model(returns, states, actions, timesteps)
    print(f"Input shape: states {states.shape}")
    print(f"Output shape: predicted actions {pred_actions.shape}")

    # Inference: get action for new state
    action = model.get_action(returns, states, actions, timesteps)
    print(f"Inference action: {action[0].detach().numpy()}")
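The return-conditioning inference procedure (steps 1-5 above) can be sketched as a loop. `env` and `predict_action` here are hypothetical stand-ins for a real environment and the Decision Transformer's action prediction:

```python
def run_episode(env, predict_action, target_return, max_steps=100):
    """Roll out a return-conditioned policy, decrementing return-to-go each step."""
    state = env.reset()
    rtg = target_return
    total = 0.0
    for _ in range(max_steps):
        action = predict_action(rtg, state)    # model conditioned on desired return
        state, reward, done = env.step(action)
        rtg -= reward                          # R_{t+1} = R_t - r_t
        total += reward
        if done:
            break
    return total

# Toy environment: reward 1.0 per step, episode ends after 5 steps.
class ToyEnv:
    def reset(self):
        self.t = 0
        return 0.0
    def step(self, action):
        self.t += 1
        return float(self.t), 1.0, self.t >= 5

total = run_episode(ToyEnv(), lambda rtg, s: 0.0, target_return=5.0)
print(total)  # 5.0
```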

Decision Transformer Results

In the original paper, Decision Transformer matched or exceeded strong value-based offline RL baselines such as CQL on the D4RL benchmarks, despite training with a purely supervised objective. Performance also tracks the conditioning signal: requesting higher returns-to-go (up to the best returns present in the dataset) yields correspondingly better rollouts.

5.5 Multi-Agent Reinforcement Learning

Multi-Agent RL (MARL) involves multiple agents learning and interacting in a shared environment.

5.5.1 CTDE: Centralized Training, Decentralized Execution

The dominant paradigm in MARL:

graph TB
    subgraph "Centralized Training"
        G[Global State] --> CRITIC[Centralized Critic]
        O1[Obs Agent 1] --> CRITIC
        O2[Obs Agent 2] --> CRITIC
        CRITIC --> |Value estimates| TRAIN[Training Signal]
    end
    subgraph "Decentralized Execution"
        O1_E[Local Obs 1] --> A1[Agent 1 Policy]
        O2_E[Local Obs 2] --> A2[Agent 2 Policy]
        A1 --> ACT1[Action 1]
        A2 --> ACT2[Action 2]
    end
    style CRITIC fill:#e74c3c,color:#fff
    style A1 fill:#27ae60,color:#fff
    style A2 fill:#27ae60,color:#fff
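The paradigm can be sketched in a few lines: a critic that sees every agent's observation and action during training, and per-agent actors that see only their own observation at execution time. This is a minimal illustration of the CTDE structure, not any specific published algorithm:

```python
import torch
import torch.nn as nn

obs_dim, action_dim, n_agents = 4, 2, 3

# Decentralized actors: one policy per agent, local observation only.
actors = [nn.Linear(obs_dim, action_dim) for _ in range(n_agents)]

# Centralized critic: conditions on all observations and actions during training.
critic = nn.Linear(n_agents * (obs_dim + action_dim), 1)

obs = [torch.randn(obs_dim) for _ in range(n_agents)]
actions = [torch.tanh(actor(o)) for actor, o in zip(actors, obs)]  # execution: local info only

joint = torch.cat(obs + actions)  # training: global information
value = critic(joint)
print(value.shape)
```

At deployment only the actors are kept, so each agent acts from local observations alone.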

5.5.2 Cooperative vs Competitive Settings

| Setting | Description | Examples |
|---|---|---|
| Cooperative | Shared reward, team objective | Robot swarms, StarCraft micromanagement |
| Competitive | Zero-sum, adversarial | Chess, Go, poker |
| Mixed | Both cooperation and competition | Traffic, negotiations, social dilemmas |

5.5.3 Key Algorithms

- MADDPG: multi-agent DDPG with one centralized critic per agent
- QMIX: cooperative value decomposition with a monotonic mixing network
- MAPPO: PPO with a centralized value function, a strong baseline on cooperative benchmarks


5.6 Safe Reinforcement Learning

Real-world RL applications require safety constraints during both training and deployment.

5.6.1 Constrained MDPs

Extend standard MDPs with cost constraints:

$$ \max_\pi J(\pi) = \mathbb{E}\left[\sum_t \gamma^t r_t\right] \quad \text{s.t.} \quad C_i(\pi) = \mathbb{E}\left[\sum_t \gamma^t c_i(s_t, a_t)\right] \leq d_i $$

where $c_i$ are cost functions (e.g., unsafe actions) and $d_i$ are thresholds.
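A common way to solve this constrained problem is a primal-dual (Lagrangian) scheme: the policy maximizes reward minus $\lambda$ times cost, while $\lambda$ is adjusted by gradient ascent on the constraint violation. A schematic dual-variable update (simplified; not a full constrained-RL implementation):

```python
def lagrangian_dual_update(lmbda, avg_cost, threshold, lr=0.01):
    """Increase lambda when the constraint is violated, decrease otherwise.
    Lambda is clipped at zero to stay dual-feasible."""
    return max(0.0, lmbda + lr * (avg_cost - threshold))

lmbda = 1.0
# Constraint violated (cost 0.8 > threshold 0.5): lambda grows.
lmbda = lagrangian_dual_update(lmbda, avg_cost=0.8, threshold=0.5)
print(lmbda)  # ~1.003
# Constraint satisfied (cost 0.2 < 0.5): lambda shrinks back toward zero.
lmbda = lagrangian_dual_update(lmbda, avg_cost=0.2, threshold=0.5)
print(lmbda)  # ~1.0
```

As $\lambda$ grows, the cost term dominates the policy objective, pushing the agent back inside the constraint set.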

5.6.2 Solution Methods

- Lagrangian relaxation: convert the constrained problem into an unconstrained one with learned multipliers
- Constrained Policy Optimization (CPO): trust-region updates that enforce the constraints at each step
- Safety layers / shielding: project or filter proposed actions so that executed actions satisfy the constraints

5.6.3 Real-World Safety Requirements

| Domain | Safety Constraints |
|---|---|
| Robotics | Joint limits, collision avoidance, force limits |
| Autonomous driving | Lane boundaries, minimum distance, speed limits |
| Healthcare | Drug dosage limits, treatment safety thresholds |
| Finance | Position limits, drawdown constraints, risk budgets |

5.7 Real-World Applications

5.7.1 Robotics: Sim-to-Real Transfer

Training robots in simulation then transferring to the real world.

Key Techniques

- Domain randomization: vary physics, textures, and lighting in simulation so the real world appears as just another variation
- System identification: calibrate simulator parameters to match the physical robot
- Domain adaptation: fine-tune perception or policy on a small amount of real-world data

SLIM System (2024)

Stanford's Sim-to-real Learning for Manipulation:

5.7.2 Autonomous Driving

RL is used for:

- Behavior planning: lane changes, merging, and negotiation at intersections
- Trajectory optimization under comfort and safety constraints
- Training rare-event handling (cut-ins, emergency braking) in simulation

5.7.3 Game AI

| System | Game | Achievement |
|---|---|---|
| AlphaGo | Go | Beat world champion Lee Sedol |
| AlphaStar | StarCraft II | Grandmaster level, beat top pros |
| OpenAI Five | Dota 2 | Beat world champions OG |
| MuZero | Atari, Chess, Shogi, Go | Superhuman without knowing the rules |

5.7.4 Resource Optimization

- Data-center cooling: DeepMind's RL controller reduced Google's cooling energy use by roughly 40%
- Chip design: RL-based floorplanning (AlphaChip) has been used in recent TPU layouts
- Traffic signal control and smart-grid load balancing


5.8 Practical Training with Stable-Baselines3

5.8.1 SAC on Continuous Control

# Requirements:
# - Python 3.9+
# - stable-baselines3>=2.1.0
# - gymnasium>=0.29.0

"""
Stable-Baselines3 SAC Example
Professional-grade RL training with minimal code
"""

import gymnasium as gym
from stable_baselines3 import SAC
from stable_baselines3.common.evaluation import evaluate_policy
from stable_baselines3.common.callbacks import EvalCallback, CheckpointCallback
from stable_baselines3.common.monitor import Monitor
import numpy as np

def train_sac_pendulum():
    """Train SAC on Pendulum-v1 with monitoring and evaluation."""

    print("SAC Training with Stable-Baselines3")
    print("=" * 50)

    # Create environment with monitoring
    env = gym.make("Pendulum-v1")
    env = Monitor(env)

    # Evaluation environment
    eval_env = gym.make("Pendulum-v1")
    eval_env = Monitor(eval_env)

    # SAC with optimized hyperparameters
    model = SAC(
        "MlpPolicy",
        env,
        learning_rate=3e-4,
        buffer_size=100000,
        learning_starts=1000,
        batch_size=256,
        tau=0.005,              # Soft update coefficient
        gamma=0.99,
        train_freq=1,
        gradient_steps=1,
        ent_coef="auto",        # Automatic entropy tuning
        verbose=1,
        tensorboard_log="./sac_pendulum_tensorboard/"
    )

    # Callbacks for monitoring
    eval_callback = EvalCallback(
        eval_env,
        best_model_save_path="./logs/best_model/",
        log_path="./logs/",
        eval_freq=5000,
        deterministic=True,
        render=False
    )

    checkpoint_callback = CheckpointCallback(
        save_freq=10000,
        save_path="./logs/checkpoints/",
        name_prefix="sac_pendulum"
    )

    # Train
    print("\nStarting training...")
    model.learn(
        total_timesteps=50000,
        callback=[eval_callback, checkpoint_callback],
        progress_bar=True
    )

    # Evaluate final performance
    mean_reward, std_reward = evaluate_policy(model, eval_env, n_eval_episodes=10)
    print(f"\nFinal evaluation: {mean_reward:.2f} +/- {std_reward:.2f}")

    # Save final model
    model.save("sac_pendulum_final")
    print("Model saved to sac_pendulum_final.zip")

    return model


def train_ppo_lunarlander():
    """Train PPO on LunarLander with vectorized environments."""

    from stable_baselines3 import PPO
    from stable_baselines3.common.env_util import make_vec_env

    print("\nPPO Training with Vectorized Environments")
    print("=" * 50)

    # Vectorized environment (4 parallel environments)
    env = make_vec_env("LunarLander-v2", n_envs=4)

    model = PPO(
        "MlpPolicy",
        env,
        learning_rate=3e-4,
        n_steps=2048,
        batch_size=64,
        n_epochs=10,
        gamma=0.99,
        gae_lambda=0.95,
        clip_range=0.2,
        verbose=1
    )

    print("Training for 100k timesteps...")
    model.learn(total_timesteps=100000, progress_bar=True)

    # Evaluate
    eval_env = gym.make("LunarLander-v2")
    mean_reward, std_reward = evaluate_policy(model, eval_env, n_eval_episodes=10)
    print(f"Evaluation: {mean_reward:.2f} +/- {std_reward:.2f}")

    return model


def hyperparameter_tuning_example():
    """Example of hyperparameter tuning with Optuna (optional)."""

    try:
        import optuna
        from stable_baselines3 import PPO

        print("\nHyperparameter Tuning with Optuna")
        print("=" * 50)

        def objective(trial):
            """Optuna objective function."""
            lr = trial.suggest_float("lr", 1e-5, 1e-3, log=True)
            n_steps = trial.suggest_categorical("n_steps", [256, 512, 1024, 2048])
            gamma = trial.suggest_float("gamma", 0.9, 0.9999)

            env = gym.make("CartPole-v1")
            model = PPO("MlpPolicy", env, learning_rate=lr, n_steps=n_steps, gamma=gamma, verbose=0)
            model.learn(total_timesteps=10000)

            mean_reward, _ = evaluate_policy(model, env, n_eval_episodes=5)
            return mean_reward

        study = optuna.create_study(direction="maximize")
        study.optimize(objective, n_trials=10, show_progress_bar=True)

        print(f"\nBest hyperparameters: {study.best_params}")
        print(f"Best reward: {study.best_value:.2f}")

    except ImportError:
        print("Optuna not installed. Run: pip install optuna")


if __name__ == "__main__":
    # Example 1: SAC on Pendulum
    sac_model = train_sac_pendulum()

    # Example 2: PPO on LunarLander
    ppo_model = train_ppo_lunarlander()

    # Example 3: Hyperparameter tuning (optional)
    # hyperparameter_tuning_example()

    print("\n" + "=" * 50)
    print("All training examples completed!")
    print("View TensorBoard logs: tensorboard --logdir ./sac_pendulum_tensorboard/")

5.9 Summary and Future Directions

What We Learned

Future Directions (2025+)

| Direction | Description |
| --- | --- |
| Foundation Models for RL | Pre-trained models that transfer across tasks and environments |
| RLHF Improvements | Constitutional AI, debate, scalable oversight |
| Sim-to-Real at Scale | Better transfer from simulation to diverse real robots |
| Hierarchical RL | Learning reusable skills and temporal abstractions |
| Multi-Modal RL | Vision-language-action models for robotics |

Resources


Exercises

Exercise 5.1: SAC Temperature Analysis

Problem: Compare SAC with fixed alpha values (0.05, 0.1, 0.2, 0.5) vs automatic tuning on Pendulum-v1.

Tasks:

Expected Insight: Automatic tuning should achieve competitive performance without a manual search over alpha.
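Before running the SAC experiments, the effect of alpha can be seen analytically: for a single state, the maximum-entropy optimal policy is a softmax over Q-values with temperature alpha, pi(a) proportional to exp(Q(a)/alpha), so larger alpha yields a more uniform, higher-entropy policy. A quick numpy check (the Q-values below are made up for illustration):

```python
import numpy as np

def soft_policy(q_values, alpha):
    """Maximum-entropy optimal policy for a single state: softmax(Q / alpha)."""
    logits = np.asarray(q_values, dtype=float) / alpha
    logits -= logits.max()              # subtract max for numerical stability
    p = np.exp(logits)
    return p / p.sum()

def entropy(p):
    """Shannon entropy of a discrete distribution."""
    return -np.sum(p * np.log(p + 1e-12))

q = [1.0, 0.5, 0.0]                     # illustrative action values
for alpha in [0.05, 0.1, 0.2, 0.5]:     # the fixed temperatures from the exercise
    p = soft_policy(q, alpha)
    print(f"alpha={alpha}: policy={np.round(p, 3)}, entropy={entropy(p):.3f}")
```

At alpha=0.05 the policy is nearly deterministic on the best action; at alpha=0.5 it spreads probability across all three, which is exactly the exploration/exploitation trade-off the exercise asks you to measure.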

Exercise 5.2: RLHF Reward Model

Problem: Implement a simple reward model for text classification preferences.

Tasks:

Hint: Use a pre-trained sentence encoder (e.g., sentence-transformers) as backbone.
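A standard starting point is the Bradley-Terry objective used in RLHF reward modeling: given a preferred (chosen) and a dispreferred (rejected) response, minimize -log sigmoid(r_chosen - r_rejected). The sketch below trains a linear reward head on synthetic embeddings with plain gradient descent; in the exercise, the embeddings would come from the sentence encoder mentioned in the hint, and the synthetic "hidden direction" is a placeholder for real preference structure:

```python
import numpy as np

def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))

def bt_loss_and_grad(w, emb_chosen, emb_rejected):
    """Bradley-Terry preference loss for a linear reward model r(x) = w . x.

    Loss: -log sigmoid(r_chosen - r_rejected); returns loss and grad w.r.t. w.
    """
    margin = emb_chosen @ w - emb_rejected @ w
    loss = -np.log(sigmoid(margin) + 1e-12)
    grad = -(1.0 - sigmoid(margin)) * (emb_chosen - emb_rejected)
    return loss, grad

rng = np.random.default_rng(0)
dim = 8
w = np.zeros(dim)
# Synthetic data: chosen embeddings shifted along a hidden preference direction
hidden = rng.normal(size=dim)
pairs = [(rng.normal(size=dim) + 0.5 * hidden, rng.normal(size=dim))
         for _ in range(200)]

for _ in range(100):                     # plain full-batch-free SGD
    for c, r in pairs:
        _, grad = bt_loss_and_grad(w, c, r)
        w -= 0.01 * grad

accuracy = np.mean([c @ w > r @ w for c, r in pairs])
print(f"training pair accuracy: {accuracy:.2f}")
```

With zero weights the margin is zero and the loss is -log(0.5), a useful sanity check before training.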

Exercise 5.3: Decision Transformer Implementation

Problem: Train Decision Transformer on CartPole offline data.

Tasks:

Hint: Start with short context lengths (5-10 steps) for faster iteration.
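Decision Transformer conditions each timestep on the return-to-go, the sum of rewards from that step to the end of the episode. Computing it is a single reverse pass over the reward sequence; a numpy sketch:

```python
import numpy as np

def returns_to_go(rewards, gamma=1.0):
    """Return-to-go at each step: R_t = sum over t' >= t of gamma^(t'-t) * r_t'.

    Decision Transformer typically uses gamma = 1 (undiscounted returns).
    """
    rtg = np.zeros(len(rewards))
    running = 0.0
    for t in reversed(range(len(rewards))):
        running = rewards[t] + gamma * running
        rtg[t] = running
    return rtg

# CartPole gives +1 per step, so a 5-step episode has RTG 5, 4, 3, 2, 1
print(returns_to_go([1, 1, 1, 1, 1]))   # [5. 4. 3. 2. 1.]
```

These per-step targets, together with states and actions, form the (return-to-go, state, action) token triples the model is trained on.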

Exercise 5.4: Multi-Agent Cooperative Task

Problem: Implement a simple multi-agent environment where agents must cooperate.

Tasks:

Hint: Use PettingZoo for multi-agent environments.
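Before reaching for PettingZoo, the core coordination problem can be reproduced in a two-agent matrix game: both agents receive +1 only when they choose the same action. A minimal independent Q-learning sketch in numpy (the payoff, learning rate, and exploration schedule are all illustrative):

```python
import numpy as np

rng = np.random.default_rng(0)
n_actions = 2
# Each agent keeps its own Q-table over its two actions (no shared state)
q = [np.zeros(n_actions), np.zeros(n_actions)]
eps, lr = 0.1, 0.1

for step in range(2000):
    acts = []
    for i in range(2):
        if rng.random() < eps:                       # epsilon-greedy exploration
            acts.append(int(rng.integers(n_actions)))
        else:
            acts.append(int(np.argmax(q[i])))
    reward = 1.0 if acts[0] == acts[1] else 0.0      # cooperative payoff
    for i in range(2):                               # fully independent updates
        q[i][acts[i]] += lr * (reward - q[i][acts[i]])

greedy = [int(np.argmax(qi)) for qi in q]
print("greedy joint action:", greedy, "coordinated:", greedy[0] == greedy[1])
```

Each agent treats the other as part of a non-stationary environment; scaling this idea to sequential environments with observations is exactly what the PettingZoo-based exercise asks for.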

Exercise 5.5: Safe RL with Constraints

Problem: Add safety constraints to a standard RL task.

Tasks:

Hint: Start with a simple reward penalty approach before full CPO.
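The reward-penalty approach from the hint replaces the constrained objective (maximize expected return subject to expected cost at most d) with the unconstrained reward r - lambda * c, where the Lagrange multiplier lambda is adapted by dual gradient ascent: lambda <- max(0, lambda + eta * (avg_cost - d)). A numpy sketch of just the multiplier update (the cost curve is synthetic, standing in for a policy that responds to the penalty):

```python
import numpy as np

def update_multiplier(lam, avg_cost, budget, lr=0.05):
    """Dual ascent step on the Lagrange multiplier for constraint E[C] <= budget."""
    return max(0.0, lam + lr * (avg_cost - budget))

budget = 1.0
lam = 0.0
# Synthetic training curve: episode cost shrinks as lambda grows, mimicking
# a policy that reduces cost under the penalty r - lam * c
avg_cost = 3.0
for epoch in range(200):
    lam = update_multiplier(lam, avg_cost, budget)
    avg_cost = max(budget * 0.9, avg_cost - 0.05 * lam)

print(f"final lambda={lam:.2f}, final avg cost={avg_cost:.2f}")
```

Note the two behaviors that make this work: lambda keeps growing while the constraint is violated (cost above budget), and shrinks back toward zero once the policy is safely inside the budget, so the penalty is only as strong as it needs to be.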


Disclaimer