Chapter 4: LLM Inference and Optimization

From Research to Production: Efficient Deployment Strategies

Reading Time: 35-40 minutes | Difficulty: Advanced | Last Updated: 2026-01
Disclaimer: This chapter covers cutting-edge inference optimization techniques as of early 2026. The field evolves rapidly, and specific frameworks and approaches may change. Always consult official documentation for implementation details.

1. LLM Inference Fundamentals

Inference is where LLMs deliver value: generating text, answering questions, and powering applications. However, inference presents unique challenges compared to training. While training optimizes for throughput over hours or days, inference must deliver responses within milliseconds to seconds, often under strict cost constraints.

1.1 The Inference Pipeline

flowchart LR
    A[Input Text] --> B[Tokenization]
    B --> C[Prefill Phase]
    C --> D[Decode Phase]
    D --> E[Token Generation]
    E -->|Loop| D
    E --> F[Output Text]
    style C fill:#e3f2fd
    style D fill:#fff3e0

LLM inference consists of two distinct phases with different computational characteristics:

Two-Phase Inference

Prefill Phase (Prompt Processing):

- Processes all prompt tokens in a single parallel forward pass
- Compute-bound: large matrix multiplications saturate the GPU
- Determines Time To First Token (TTFT)

Decode Phase (Token Generation):

- Generates one token per forward pass, autoregressively
- Memory-bandwidth-bound: every step must stream the model weights and KV cache from GPU memory
- Determines Time Per Output Token (TPOT)
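Because decode generates one token per forward pass, each step must read every model weight from HBM at least once, so memory bandwidth puts a hard floor under time-per-output-token. A back-of-envelope sketch (the hardware numbers below are illustrative assumptions, not measurements):

```python
def decode_tpot_floor_ms(weight_gb: float, bandwidth_gb_s: float) -> float:
    """Lower bound on time per output token (ms): every decode step
    streams all weights from HBM once. Ignores KV-cache reads, compute,
    and kernel overheads, so real TPOT is higher."""
    return weight_gb / bandwidth_gb_s * 1000

# Example: a 70B model in FP16 (~140 GB) on a GPU with ~3,350 GB/s of HBM bandwidth
floor = decode_tpot_floor_ms(140, 3350)
print(f"TPOT floor: {floor:.1f} ms -> at most {1000 / floor:.0f} tokens/s")
# TPOT floor: 41.8 ms -> at most 24 tokens/s
```

This is also why quantization speeds up decode: INT4 weights cut the bytes streamed per step by 4x even when compute throughput is unchanged.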

1.2 Memory Requirements

Understanding memory consumption is crucial for inference optimization. For a model with parameters $P$ and context length $L$:

def calculate_inference_memory(
    params_billions: float,
    context_length: int,
    batch_size: int,
    num_layers: int,
    hidden_dim: int,
    num_q_heads: int,
    num_kv_heads: int,
    precision: str = "fp16"
) -> dict:
    """
    Calculate approximate memory requirements for LLM inference.

    Args:
        params_billions: Model size in billions of parameters
        context_length: Maximum sequence length
        batch_size: Number of concurrent requests
        num_layers: Number of transformer layers
        hidden_dim: Model hidden dimension
        num_q_heads: Number of query attention heads
        num_kv_heads: Number of KV heads (< num_q_heads for GQA/MQA)
        precision: Weight precision ("fp32", "fp16", "int8", "int4")

    Returns:
        Memory breakdown in decimal GB
    """
    bytes_per_param = {
        "fp32": 4, "fp16": 2, "bf16": 2,
        "int8": 1, "int4": 0.5, "fp8": 1
    }

    # Model weights: params * bytes per param
    weight_gb = params_billions * bytes_per_param[precision]

    # KV cache: 2 (K and V) * layers * batch * seq * kv_heads * head_dim * bytes.
    # head_dim is set by the query heads; GQA/MQA shrink only the head count.
    head_dim = hidden_dim // num_q_heads
    kv_cache_bytes = (
        2 * num_layers * batch_size * context_length *
        num_kv_heads * head_dim * 2  # fp16 for KV cache
    )
    kv_cache_gb = kv_cache_bytes / 1e9

    # Activation memory (rough approximation)
    activation_gb = batch_size * context_length * hidden_dim * 4 / 1e9

    return {
        "weights_gb": round(weight_gb, 2),
        "kv_cache_gb": round(kv_cache_gb, 2),
        "activations_gb": round(activation_gb, 2),
        "total_gb": round(weight_gb + kv_cache_gb + activation_gb, 2)
    }

# Example: Llama-3-70B with 128K context
memory = calculate_inference_memory(
    params_billions=70,
    context_length=128000,
    batch_size=1,
    num_layers=80,
    hidden_dim=8192,
    num_q_heads=64,
    num_kv_heads=8,  # GQA
    precision="int4"
)
print(f"Memory breakdown: {memory}")
# weights_gb: 35.0, kv_cache_gb: 41.94, activations_gb: 4.19, total_gb: 81.14

1.3 Latency Components

| Component | Description | Optimization Target |
|---|---|---|
| Time To First Token (TTFT) | Latency until first token appears | Prefill optimization, prompt caching |
| Time Per Output Token (TPOT) | Latency per generated token | Decode optimization, batching |
| End-to-End Latency | TTFT + (tokens × TPOT) | Speculative decoding |
| Throughput | Tokens/second across all requests | Continuous batching |
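The end-to-end formula above is easy to operationalize. A small helper (illustrative, not tied to any serving framework) converts measured TTFT and TPOT into user-facing numbers:

```python
def end_to_end_latency_ms(ttft_ms: float, tpot_ms: float, output_tokens: int) -> float:
    # E2E latency = TTFT + output_tokens * TPOT
    return ttft_ms + output_tokens * tpot_ms

def decode_tokens_per_second(tpot_ms: float) -> float:
    # Per-request decode throughput is the reciprocal of TPOT
    return 1000.0 / tpot_ms

# Example: 200 ms TTFT, 25 ms/token, 256-token response
print(end_to_end_latency_ms(200, 25, 256))   # 6600.0 ms
print(decode_tokens_per_second(25))          # 40.0 tokens/s
```

Note that TTFT dominates for short responses while TPOT dominates for long ones, which is why the two phases are optimized separately.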

2. Model Quantization

Quantization reduces model precision from 16-bit or 32-bit floating point to lower bit-widths, dramatically reducing memory footprint and accelerating inference. Modern quantization techniques achieve near-lossless quality while cutting memory by 4-8x.

2.1 Quantization Fundamentals

flowchart TB
    subgraph "Precision Levels"
        FP32[FP32: 32 bits] --> FP16[FP16/BF16: 16 bits]
        FP16 --> FP8[FP8: 8 bits]
        FP8 --> INT8[INT8: 8 bits]
        INT8 --> INT4[INT4: 4 bits]
    end
    subgraph "Memory Savings"
        M1[100%] --> M2[50%]
        M2 --> M3[25%]
        M3 --> M4[25%]
        M4 --> M5[12.5%]
    end

import torch
import torch.nn.functional as F
from typing import Tuple

class Quantizer:
    """Basic quantization utilities for understanding the fundamentals."""

    @staticmethod
    def absmax_quantize(
        tensor: torch.Tensor,
        bits: int = 8
    ) -> Tuple[torch.Tensor, float]:
        """
        Absmax (symmetric) quantization.
        Maps values to [-2^(bits-1), 2^(bits-1)-1]
        """
        qmax = 2 ** (bits - 1) - 1
        scale = tensor.abs().max() / qmax
        quantized = torch.round(tensor / scale).to(torch.int8)
        return quantized, scale

    @staticmethod
    def absmax_dequantize(
        quantized: torch.Tensor,
        scale: float
    ) -> torch.Tensor:
        """Dequantize absmax-quantized tensor."""
        return quantized.float() * scale

    @staticmethod
    def zeropoint_quantize(
        tensor: torch.Tensor,
        bits: int = 8
    ) -> Tuple[torch.Tensor, float, int]:
        """
        Zero-point (asymmetric) quantization.
        Better for tensors with non-symmetric distributions.
        """
        qmin, qmax = 0, 2**bits - 1
        min_val, max_val = tensor.min(), tensor.max()

        scale = ((max_val - min_val) / (qmax - qmin)).clamp(min=1e-8)
        zero_point = int(torch.round(-min_val / scale))

        quantized = torch.round(tensor / scale + zero_point)
        quantized = torch.clamp(quantized, qmin, qmax).to(torch.uint8)

        return quantized, scale, zero_point

    @staticmethod
    def block_quantize(
        tensor: torch.Tensor,
        block_size: int = 128,
        bits: int = 4
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Block-wise quantization (used in GPTQ, AWQ).
        Each block has its own scale factor.
        """
        # Flatten; blocks are carved from the 1-D view
        tensor_flat = tensor.view(-1)

        # Pad to multiple of block_size
        pad_len = (block_size - len(tensor_flat) % block_size) % block_size
        if pad_len:
            tensor_flat = F.pad(tensor_flat, (0, pad_len))

        blocks = tensor_flat.view(-1, block_size)

        # Quantize each block independently
        qmax = 2 ** (bits - 1) - 1
        scales = blocks.abs().max(dim=1, keepdim=True).values / qmax
        scales = scales.clamp(min=1e-8)

        quantized = torch.round(blocks / scales).to(torch.int8)

        return quantized, scales.squeeze()

# Demonstration
tensor = torch.randn(1024, 1024)

# Compare quantization methods
int8_quant, int8_scale = Quantizer.absmax_quantize(tensor, bits=8)
int4_quant, int4_scales = Quantizer.block_quantize(tensor, bits=4)

print(f"Original size: {tensor.numel() * 4 / 1024:.1f} KB")
print(f"INT8 size: {int8_quant.numel() * 1 / 1024:.1f} KB (4x reduction)")
print(f"INT4 size: {int4_quant.numel() * 0.5 / 1024:.1f} KB (8x reduction)")

2.2 Advanced Quantization Methods

GPTQ (Post-Training Quantization)

GPTQ uses second-order information (Hessian) to minimize quantization error. It processes weights column by column, adjusting remaining weights to compensate for quantization errors.

# Using AutoGPTQ for quantization
import torch
from transformers import AutoTokenizer
from auto_gptq import AutoGPTQForCausalLM, BaseQuantizeConfig

def quantize_with_gptq(
    model_id: str,
    output_dir: str,
    bits: int = 4,
    group_size: int = 128,
    calibration_samples: int = 128
):
    """
    Quantize a model using GPTQ.

    Args:
        model_id: HuggingFace model ID
        output_dir: Where to save quantized model
        bits: Target bit width (4 or 8)
        group_size: Block size for quantization
        calibration_samples: Number of samples for calibration
    """
    # Load tokenizer and model
    tokenizer = AutoTokenizer.from_pretrained(model_id)

    # Configure quantization
    quantize_config = BaseQuantizeConfig(
        bits=bits,
        group_size=group_size,
        desc_act=True,  # Activation-order quantization
        sym=False,  # Asymmetric quantization
    )

    # Load model for quantization
    model = AutoGPTQForCausalLM.from_pretrained(
        model_id,
        quantize_config=quantize_config,
        torch_dtype=torch.float16,
        device_map="auto"
    )

    # Prepare calibration data
    calibration_data = prepare_calibration_data(
        tokenizer,
        num_samples=calibration_samples
    )

    # Quantize
    model.quantize(calibration_data)

    # Save
    model.save_quantized(output_dir)
    tokenizer.save_pretrained(output_dir)

    return model

def prepare_calibration_data(tokenizer, num_samples=128):
    """Prepare calibration data from C4 dataset."""
    from datasets import load_dataset

    dataset = load_dataset("allenai/c4", "en", split="train", streaming=True)
    samples = []

    for i, example in enumerate(dataset):
        if i >= num_samples:
            break
        tokenized = tokenizer(
            example["text"],
            truncation=True,
            max_length=2048,
            return_tensors="pt"
        )
        samples.append(tokenized.input_ids)

    return samples

AWQ (Activation-aware Weight Quantization)

AWQ identifies salient weight channels based on activation magnitudes. It preserves these important channels at higher precision, achieving better quality than uniform quantization.

# Using AWQ for quantization
from awq import AutoAWQForCausalLM
from transformers import AutoTokenizer

def quantize_with_awq(
    model_id: str,
    output_dir: str,
    bits: int = 4,
    group_size: int = 128,
    zero_point: bool = True
):
    """
    Quantize using AWQ (Activation-aware Weight Quantization).

    AWQ identifies salient weights based on activation patterns
    and preserves them at higher precision.
    """
    # Load model
    model = AutoAWQForCausalLM.from_pretrained(
        model_id,
        device_map="auto",
        safetensors=True
    )
    tokenizer = AutoTokenizer.from_pretrained(model_id)

    # Quantization config
    quant_config = {
        "zero_point": zero_point,
        "q_group_size": group_size,
        "w_bit": bits,
        "version": "GEMM"  # or "GEMV" for batch size 1
    }

    # Quantize with calibration
    model.quantize(
        tokenizer,
        quant_config=quant_config,
        calib_data="pileval",  # or custom dataset
        calib_size=512
    )

    # Save quantized model
    model.save_quantized(output_dir)
    tokenizer.save_pretrained(output_dir)

    print(f"Model quantized to {bits}-bit and saved to {output_dir}")
    return model

# Load and use quantized model
def load_awq_model(model_path: str):
    """Load AWQ quantized model for inference."""
    model = AutoAWQForCausalLM.from_quantized(
        model_path,
        fuse_layers=True,  # Fuse QKV and MLP for speed
        device_map="auto"
    )
    tokenizer = AutoTokenizer.from_pretrained(model_path)
    return model, tokenizer

FP8 Quantization (2025-2026 Standard)

FP8 has emerged as the preferred format for inference on modern GPUs (H100, H200, B200). It maintains dynamic range through floating-point representation while achieving INT8-level efficiency.

import torch
from typing import Literal

def fp8_quantize(
    tensor: torch.Tensor,
    format: Literal["e4m3", "e5m2"] = "e4m3"
) -> torch.Tensor:
    """
    Convert tensor to FP8 format.

    E4M3: 4-bit exponent, 3-bit mantissa
        - Range: ±448, precision: ~1/16
        - Better for weights and activations

    E5M2: 5-bit exponent, 2-bit mantissa
        - Range: ±57344, precision: ~1/4
        - Better for gradients (wider range)
    """
    if format == "e4m3":
        return tensor.to(torch.float8_e4m3fn)
    else:
        return tensor.to(torch.float8_e5m2)

# FP8 inference with scaling
class FP8Linear(torch.nn.Module):
    """FP8 linear layer with per-tensor scaling."""

    def __init__(self, in_features: int, out_features: int):
        super().__init__()
        self.in_features = in_features
        self.out_features = out_features

        # FP8 weights with scale
        self.register_buffer(
            "weight",
            torch.zeros(out_features, in_features, dtype=torch.float8_e4m3fn)
        )
        self.register_buffer("weight_scale", torch.ones(1))
        self.register_buffer("input_scale", torch.ones(1))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Scale input to FP8 range
        x_scaled = x / self.input_scale
        x_fp8 = x_scaled.to(torch.float8_e4m3fn)

        # FP8 matmul (hardware accelerated on H100+)
        output = torch._scaled_mm(
            x_fp8,
            self.weight.t(),
            scale_a=self.input_scale,
            scale_b=self.weight_scale,
            out_dtype=torch.float16
        )

        return output

    @classmethod
    def from_float(cls, linear: torch.nn.Linear) -> "FP8Linear":
        """Convert an FP16/32 linear layer to FP8."""
        fp8_linear = cls(linear.in_features, linear.out_features)

        # Calculate weight scale (448 is the E4M3 maximum value)
        weight_scale = linear.weight.abs().max() / 448

        fp8_linear.weight = (linear.weight / weight_scale).to(torch.float8_e4m3fn)
        fp8_linear.weight_scale = weight_scale
        # Note: input_scale still defaults to 1.0 here; it must be
        # calibrated from representative activations before inference.

        return fp8_linear

2.3 Quantization Comparison

Quantization Methods Comparison (Llama-3-70B)

| Method | Bits | Memory (GB) | Perplexity | Speed vs FP16 |
|---|---|---|---|---|
| FP16 (Baseline) | 16 | 140 | 5.42 | 1.0x |
| FP8 | 8 | 70 | 5.44 | 1.8x |
| INT8 (Absmax) | 8 | 70 | 5.51 | 1.6x |
| GPTQ | 4 | 35 | 5.58 | 2.2x |
| AWQ | 4 | 35 | 5.52 | 2.4x |
| GGUF Q4_K_M | 4.5* | ~40 | 5.55 | 2.1x |

*Mixed precision with important layers at higher precision
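The memory column follows directly from the bit width. A one-line check (decimal GB, weights only, excluding KV cache and runtime overhead):

```python
def weight_memory_gb(params_billions: float, bits: float) -> float:
    # params * (bits / 8) bytes per parameter, reported in decimal GB
    return params_billions * bits / 8

for name, bits in [("FP16", 16), ("FP8/INT8", 8), ("INT4", 4), ("Q4_K_M", 4.5)]:
    print(f"{name}: {weight_memory_gb(70, bits):.1f} GB")
# FP16: 140.0, FP8/INT8: 70.0, INT4: 35.0, Q4_K_M: 39.4
```

The fractional 4.5 bits for Q4_K_M reflects its mixed precision: per-block scales and higher-precision layers add overhead beyond the raw 4-bit weights.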

3. High-Performance Inference Engines

Modern inference engines optimize the entire serving stack: batching, memory management, kernel optimization, and distributed execution. The two leading solutions in 2025-2026 are vLLM and TensorRT-LLM.

3.1 vLLM: PagedAttention

vLLM introduced PagedAttention, which manages KV cache like virtual memory pages. This eliminates memory fragmentation and enables efficient continuous batching, achieving 2-24x throughput improvements.

flowchart TB
    subgraph "Traditional KV Cache"
        T1[Request 1: 2048 tokens allocated]
        T2[Request 2: 2048 tokens allocated]
        T3[Wasted: 70% unused]
    end
    subgraph "PagedAttention"
        P1[Page Pool]
        P1 --> PA[Request 1: 5 pages used]
        P1 --> PB[Request 2: 3 pages used]
        P1 --> PC[Free pages: dynamic allocation]
    end
    style T3 fill:#ffcdd2
    style PC fill:#c8e6c9

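To make the paging idea concrete, here is a toy block allocator in the spirit of PagedAttention (a simplified sketch; vLLM's real block manager additionally handles prefix sharing and copy-on-write):

```python
class KVPagePool:
    """Toy PagedAttention-style allocator: the KV cache is carved into
    fixed-size pages that sequences claim one at a time, so a request
    only ever holds pages for tokens it has actually generated."""

    def __init__(self, num_pages: int, page_size: int):
        self.page_size = page_size
        self.free = list(range(num_pages))
        self.page_tables = {}   # seq_id -> list of physical page ids
        self.lengths = {}       # seq_id -> tokens stored

    def append_token(self, seq_id: str) -> None:
        n = self.lengths.get(seq_id, 0)
        if n % self.page_size == 0:  # current page full (or first token)
            if not self.free:
                raise MemoryError("KV page pool exhausted")
            self.page_tables.setdefault(seq_id, []).append(self.free.pop())
        self.lengths[seq_id] = n + 1

    def release(self, seq_id: str) -> None:
        # Completed requests return their pages to the shared pool
        self.free.extend(self.page_tables.pop(seq_id, []))
        self.lengths.pop(seq_id, None)

pool = KVPagePool(num_pages=8, page_size=16)
for _ in range(40):
    pool.append_token("req-1")
print(len(pool.page_tables["req-1"]))  # 3 pages for 40 tokens, not a fixed 2048-token slab
pool.release("req-1")
print(len(pool.free))  # 8
```

Because allocation is per-page rather than per-maximum-sequence, the only waste is the partially filled last page of each request.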
from vllm import LLM, SamplingParams
from typing import List, Optional

def setup_vllm_server(
    model: str,
    tensor_parallel_size: int = 1,
    gpu_memory_utilization: float = 0.9,
    max_model_len: int = 32768,
    quantization: Optional[str] = None
) -> LLM:
    """
    Configure vLLM for high-throughput inference.

    Args:
        model: Model name or path
        tensor_parallel_size: Number of GPUs for tensor parallelism
        gpu_memory_utilization: Fraction of GPU memory to use
        max_model_len: Maximum sequence length
        quantization: "awq", "gptq", "fp8", or None
    """
    llm = LLM(
        model=model,
        tensor_parallel_size=tensor_parallel_size,
        gpu_memory_utilization=gpu_memory_utilization,
        max_model_len=max_model_len,
        quantization=quantization,
        # PagedAttention settings
        block_size=16,  # KV cache block size
        swap_space=4,   # CPU swap space in GB
        # Optimization flags
        enforce_eager=False,  # Use CUDA graphs
        enable_prefix_caching=True,  # Cache common prefixes
    )
    return llm

def batch_inference(
    llm: LLM,
    prompts: List[str],
    max_tokens: int = 512,
    temperature: float = 0.7,
    top_p: float = 0.95
) -> List[str]:
    """
    Run batch inference with vLLM.

    vLLM automatically handles:
    - Continuous batching (dynamic batch composition)
    - PagedAttention (efficient KV cache management)
    - Prefix caching (reuse common prefixes)
    """
    sampling_params = SamplingParams(
        max_tokens=max_tokens,
        temperature=temperature,
        top_p=top_p,
        # Advanced options
        presence_penalty=0.0,
        frequency_penalty=0.0,
        repetition_penalty=1.0,
        stop=["</s>", "[/INST]"],  # model-specific stop sequences
    )

    outputs = llm.generate(prompts, sampling_params)

    return [output.outputs[0].text for output in outputs]

# Example usage
llm = setup_vllm_server(
    model="meta-llama/Llama-3.1-70B-Instruct",
    tensor_parallel_size=4,  # 4x A100
    quantization="awq"
)

prompts = [
    "Explain quantum computing in simple terms.",
    "Write a Python function to merge two sorted lists.",
    "What are the key differences between TCP and UDP?"
]

results = batch_inference(llm, prompts)
for prompt, result in zip(prompts, results):
    print(f"Q: {prompt[:50]}...")
    print(f"A: {result[:200]}...")
    print("-" * 50)

vLLM OpenAI-Compatible Server

# Start vLLM server with OpenAI-compatible API
python -m vllm.entrypoints.openai.api_server \
    --model meta-llama/Llama-3.1-70B-Instruct \
    --tensor-parallel-size 4 \
    --quantization awq \
    --max-model-len 32768 \
    --enable-prefix-caching \
    --port 8000

# Use with OpenAI client
from openai import OpenAI

client = OpenAI(base_url="http://localhost:8000/v1", api_key="dummy")

response = client.chat.completions.create(
    model="meta-llama/Llama-3.1-70B-Instruct",
    messages=[{"role": "user", "content": "Hello!"}],
    max_tokens=100
)

3.2 TensorRT-LLM: Maximum Performance

TensorRT-LLM is NVIDIA's inference library, providing the highest performance on NVIDIA GPUs through aggressive kernel fusion, in-flight batching, and custom CUDA kernels.

# TensorRT-LLM model building and serving
import tensorrt_llm
from tensorrt_llm import BuildConfig, Mapping
from tensorrt_llm.models import LLaMAForCausalLM
from tensorrt_llm.quantization import QuantMode

def build_trtllm_engine(
    model_dir: str,
    output_dir: str,
    tp_size: int = 1,
    pp_size: int = 1,
    max_batch_size: int = 64,
    max_input_len: int = 4096,
    max_output_len: int = 2048,
    quantization: str = "fp8"
):
    """
    Build TensorRT-LLM engine for maximum inference performance.

    Args:
        model_dir: Path to HuggingFace model
        output_dir: Where to save TRT engine
        tp_size: Tensor parallelism degree
        pp_size: Pipeline parallelism degree
        max_batch_size: Maximum concurrent requests
        max_input_len: Maximum input sequence length
        max_output_len: Maximum output sequence length
        quantization: "fp8", "int8_sq", "int4_awq", etc.
    """
    # Configure quantization
    quant_mode = QuantMode(0)
    if quantization == "fp8":
        quant_mode = QuantMode.use_fp8_qdq()
    elif quantization == "int8_sq":
        quant_mode = QuantMode.use_smooth_quant()
    elif quantization == "int4_awq":
        quant_mode = QuantMode.use_weight_only() | QuantMode.use_int4_weights()

    # Build configuration
    build_config = BuildConfig(
        max_batch_size=max_batch_size,
        max_input_len=max_input_len,
        max_output_len=max_output_len,
        max_beam_width=1,
        # Optimization settings
        strongly_typed=True,
        builder_opt_level=5,  # Maximum optimization
        # Memory optimization
        use_paged_context_fmha=True,
        tokens_per_block=128,
    )

    # Parallelism mapping
    mapping = Mapping(
        world_size=tp_size * pp_size,
        tp_size=tp_size,
        pp_size=pp_size
    )

    # Build engine
    model = LLaMAForCausalLM.from_hugging_face(
        model_dir,
        mapping=mapping,
        quant_mode=quant_mode
    )

    engine = tensorrt_llm.build(model, build_config)
    engine.save(output_dir)

    print(f"Engine saved to {output_dir}")
    return engine

# Running inference with TensorRT-LLM
from tensorrt_llm.runtime import ModelRunner

def trtllm_inference(
    engine_dir: str,
    prompts: list,
    max_output_len: int = 512
):
    """Run inference with TensorRT-LLM engine."""
    runner = ModelRunner.from_dir(
        engine_dir,
        rank=0,  # For multi-GPU, set appropriate rank
    )

    outputs = runner.generate(
        prompts,
        max_new_tokens=max_output_len,
        end_id=128001,  # Llama-3 EOS
        pad_id=128001,
        temperature=0.7,
        top_p=0.95,
        streaming=False
    )

    return outputs

3.3 Inference Engine Comparison

| Feature | vLLM | TensorRT-LLM | llama.cpp |
|---|---|---|---|
| Best For | General serving | Maximum throughput | Local/Edge deployment |
| Hardware | NVIDIA, AMD, Intel | NVIDIA only | CPU, GPU, Apple Silicon |
| Setup Complexity | Low | High | Very Low |
| Quantization | AWQ, GPTQ, FP8 | FP8, INT8, INT4 | GGUF (many formats) |
| Continuous Batching | Yes | Yes (In-flight) | Limited |
| Throughput (relative) | 1.0x | 1.2-1.5x | 0.3-0.5x |

4. Long Context Handling

Context windows have expanded dramatically from 4K tokens in 2023 to 1M+ tokens in 2026. Managing these long contexts efficiently is a key challenge for modern inference systems.

4.1 Context Length Evolution

| Model | Context Length | Release |
|---|---|---|
| GPT-4 (Original) | 8K / 32K | 2023-03 |
| Claude 2 | 100K | 2023-07 |
| GPT-4 Turbo | 128K | 2023-11 |
| Claude 3 | 200K | 2024-03 |
| Gemini 1.5 Pro | 1M → 2M | 2024-02 |
| Llama 4 Scout | 10M | 2025-04 |
| Gemini 3.0 | 10M+ | 2026-01 |

4.2 KV Cache Optimization Techniques

import torch
from typing import Optional, Tuple

class SlidingWindowAttention:
    """
    Sliding window attention for efficient long context.
    Each token only attends to the last `window_size` tokens.
    """

    def __init__(self, window_size: int = 4096):
        self.window_size = window_size

    def create_mask(
        self,
        seq_len: int,
        device: torch.device
    ) -> torch.Tensor:
        """Create sliding window causal mask."""
        # Start with causal mask
        mask = torch.triu(
            torch.ones(seq_len, seq_len, device=device),
            diagonal=1
        ).bool()

        # Add window constraint
        window_mask = torch.triu(
            torch.ones(seq_len, seq_len, device=device),
            diagonal=-self.window_size
        ).bool()

        mask = mask | ~window_mask
        return mask


class StreamingLLM:
    """
    StreamingLLM: Efficient infinite context through attention sinks.

    Key insight: First few tokens ("attention sinks") capture global
    information. Keep these plus a sliding window for infinite streaming.

    Reference: https://arxiv.org/abs/2309.17453
    """

    def __init__(
        self,
        num_sink_tokens: int = 4,
        window_size: int = 4096
    ):
        self.num_sink_tokens = num_sink_tokens
        self.window_size = window_size

    def evict_kv_cache(
        self,
        k_cache: torch.Tensor,
        v_cache: torch.Tensor,
        current_len: int
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Evict middle tokens while keeping sinks and recent window.

        Cache layout: [sink_tokens | ... evicted ... | window_tokens]
        """
        max_len = self.num_sink_tokens + self.window_size

        if current_len <= max_len:
            return k_cache, v_cache

        # Keep sink tokens and recent window
        sink_k = k_cache[:, :, :self.num_sink_tokens, :]
        sink_v = v_cache[:, :, :self.num_sink_tokens, :]

        window_k = k_cache[:, :, -self.window_size:, :]
        window_v = v_cache[:, :, -self.window_size:, :]

        new_k = torch.cat([sink_k, window_k], dim=2)
        new_v = torch.cat([sink_v, window_v], dim=2)

        return new_k, new_v


class MultiQueryAttention:
    """
    Multi-Query Attention (MQA) and Grouped-Query Attention (GQA)
    reduce KV cache size by sharing KV heads across query heads.

    MQA: 1 KV head shared by all Q heads (8-32x cache reduction)
    GQA: n KV heads, each shared by a group of Q heads (typical 4-8x reduction)
    """

    @staticmethod
    def calculate_kv_cache_size(
        batch_size: int,
        seq_len: int,
        num_layers: int,
        hidden_dim: int,
        num_q_heads: int,
        num_kv_heads: int,
        dtype_bytes: int = 2  # FP16
    ) -> dict:
        """Compare KV cache sizes for different attention variants."""
        head_dim = hidden_dim // num_q_heads

        # MHA: num_kv_heads == num_q_heads
        mha_size = 2 * batch_size * seq_len * num_layers * num_q_heads * head_dim * dtype_bytes

        # GQA: num_kv_heads < num_q_heads
        gqa_size = 2 * batch_size * seq_len * num_layers * num_kv_heads * head_dim * dtype_bytes

        # MQA: num_kv_heads == 1
        mqa_size = 2 * batch_size * seq_len * num_layers * 1 * head_dim * dtype_bytes

        return {
            "MHA": f"{mha_size / 1e9:.2f} GB",
            "GQA": f"{gqa_size / 1e9:.2f} GB",
            "MQA": f"{mqa_size / 1e9:.2f} GB",
            "GQA_reduction": f"{mha_size / gqa_size:.1f}x",
            "MQA_reduction": f"{mha_size / mqa_size:.1f}x"
        }

# Example: Llama-3-70B KV cache comparison
kv_comparison = MultiQueryAttention.calculate_kv_cache_size(
    batch_size=1,
    seq_len=128000,
    num_layers=80,
    hidden_dim=8192,
    num_q_heads=64,
    num_kv_heads=8  # GQA with 8 KV heads
)
print("KV Cache Comparison:")
for k, v in kv_comparison.items():
    print(f"  {k}: {v}")

4.3 Prefix Caching and Prompt Caching

Prompt Caching Benefits

Prompt caching stores the KV cache for frequently used prefixes (system prompts, few-shot examples). Benefits include:

- Much lower TTFT on cache hits, since the cached prefix skips the prefill pass entirely
- Lower cost: providers discount cached input tokens heavily (Anthropic bills cache reads at roughly 10% of the base input rate)
- Higher cluster throughput, because repeated prefixes no longer consume prefill compute

# Using prompt caching with Anthropic API
import anthropic
import hashlib

def cached_completion(
    client: anthropic.Anthropic,
    system_prompt: str,
    user_message: str,
    model: str = "claude-sonnet-4-20250514"
):
    """
    Use prompt caching for repeated system prompts.

    The system prompt KV cache is stored for 5 minutes,
    providing 90% cost reduction on cached tokens.
    """
    response = client.messages.create(
        model=model,
        max_tokens=1024,
        system=[
            {
                "type": "text",
                "text": system_prompt,
                "cache_control": {"type": "ephemeral"}  # Enable caching
            }
        ],
        messages=[
            {"role": "user", "content": user_message}
        ]
    )

    # Check cache usage
    usage = response.usage
    print(f"Input tokens: {usage.input_tokens}")
    print(f"Cache read tokens: {getattr(usage, 'cache_read_input_tokens', 0)}")
    print(f"Cache creation tokens: {getattr(usage, 'cache_creation_input_tokens', 0)}")

    return response.content[0].text

# vLLM prefix caching
from vllm import LLM, SamplingParams

def setup_prefix_caching():
    """Configure vLLM with automatic prefix caching."""
    llm = LLM(
        model="meta-llama/Llama-3.1-70B-Instruct",
        enable_prefix_caching=True,  # Enable automatic prefix caching
        # Prefix cache configuration
        max_num_seqs=256,  # Maximum concurrent sequences in a batch
    )
    return llm

# Example: Chat with shared system prompt
SYSTEM_PROMPT = """You are a helpful AI assistant specializing in
software development. You provide clear, concise, and accurate
technical guidance..."""  # Long system prompt

def chat_with_caching(llm: LLM, messages: list):
    """Multiple requests share the cached system prompt."""
    results = []

    for msg in messages:
        # System prompt is automatically cached after first request
        full_prompt = f"{SYSTEM_PROMPT}\n\nUser: {msg}\n\nAssistant:"

        output = llm.generate(
            [full_prompt],
            SamplingParams(max_tokens=512, temperature=0.7)
        )
        results.append(output[0].outputs[0].text)

    return results

5. Throughput Optimization

5.1 Continuous Batching

Traditional static batching waits for all requests in a batch to complete before accepting new ones. Continuous (or in-flight) batching dynamically adds and removes requests as they arrive and complete, maximizing GPU utilization.

sequenceDiagram
    participant Client
    participant Scheduler
    participant GPU
    Note over Scheduler,GPU: Continuous Batching
    Client->>Scheduler: Request 1 (500 tokens)
    Scheduler->>GPU: Add to batch
    Client->>Scheduler: Request 2 (100 tokens)
    Scheduler->>GPU: Add to batch
    GPU-->>Client: Request 2 done
    Client->>Scheduler: Request 3 (200 tokens)
    Scheduler->>GPU: Add Request 3 (slot freed)
    GPU-->>Client: Request 3 done
    GPU-->>Client: Request 1 done

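A toy step-count simulation makes the benefit concrete (decode steps only, one token per request per step, which is a simplifying assumption):

```python
def static_batch_steps(lengths, batch_size):
    # Static batching: each group occupies the GPU until its longest
    # request finishes, so short requests wait on long ones.
    return sum(max(lengths[i:i + batch_size])
               for i in range(0, len(lengths), batch_size))

def continuous_batch_steps(lengths, batch_size):
    # Continuous batching: a finished request frees its slot immediately
    # and a waiting request joins on the very next step.
    active, waiting, steps = [], list(lengths), 0
    while active or waiting:
        while waiting and len(active) < batch_size:
            active.append(waiting.pop(0))
        active = [n - 1 for n in active if n > 1]  # requests finish at n == 1
        steps += 1
    return steps

lengths = [8, 1, 1, 8]  # output lengths in tokens
print(static_batch_steps(lengths, batch_size=2))      # 16
print(continuous_batch_steps(lengths, batch_size=2))  # 10
```

The gap widens as request lengths become more varied, which is exactly the regime real serving traffic lives in.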
import asyncio
from dataclasses import dataclass
from typing import List, Optional
import time

@dataclass
class InferenceRequest:
    """Represents a single inference request."""
    request_id: str
    prompt: str
    max_tokens: int
    arrival_time: float
    generated_tokens: int = 0
    is_complete: bool = False

class ContinuousBatcher:
    """
    Simplified continuous batching scheduler.

    Key concepts:
    - Requests join/leave batch dynamically
    - Each iteration processes one token per request
    - Completed requests free slots for waiting requests
    """

    def __init__(
        self,
        max_batch_size: int = 64,
        max_waiting_requests: int = 256
    ):
        self.max_batch_size = max_batch_size
        self.max_waiting_requests = max_waiting_requests
        self.active_batch: List[InferenceRequest] = []
        self.waiting_queue: List[InferenceRequest] = []
        self.completed: List[InferenceRequest] = []

    def add_request(self, request: InferenceRequest) -> bool:
        """Add request to batch or waiting queue."""
        if len(self.active_batch) < self.max_batch_size:
            self.active_batch.append(request)
            return True
        elif len(self.waiting_queue) < self.max_waiting_requests:
            self.waiting_queue.append(request)
            return True
        return False

    def iteration_step(self):
        """
        Process one iteration (one token per request).
        This represents a single forward pass of the model.
        """
        # Generate one token for each active request
        for request in self.active_batch:
            request.generated_tokens += 1
            if request.generated_tokens >= request.max_tokens:
                request.is_complete = True

        # Remove completed requests
        completed = [r for r in self.active_batch if r.is_complete]
        self.completed.extend(completed)
        self.active_batch = [r for r in self.active_batch if not r.is_complete]

        # Fill slots from waiting queue
        slots_available = self.max_batch_size - len(self.active_batch)
        for _ in range(min(slots_available, len(self.waiting_queue))):
            self.active_batch.append(self.waiting_queue.pop(0))

        return len(completed)

    def get_metrics(self) -> dict:
        """Get current scheduler metrics."""
        return {
            "active_requests": len(self.active_batch),
            "waiting_requests": len(self.waiting_queue),
            "completed_requests": len(self.completed),
            "batch_utilization": len(self.active_batch) / self.max_batch_size
        }

5.2 Speculative Decoding

Speculative decoding uses a small "draft" model to propose multiple tokens, which the large "target" model verifies in parallel. This can achieve 2-3x speedup for autoregressive generation.
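The payoff depends on how often the target model accepts the draft's proposals. Under the i.i.d.-acceptance model of Leviathan et al. (2023), the expected tokens produced per target forward pass, and a rough speedup estimate, are easy to compute (the acceptance rate and cost ratio below are illustrative assumptions):

```python
def expected_tokens_per_pass(alpha: float, k: int) -> float:
    # Geometric series 1 + alpha + ... + alpha^k: expected tokens per
    # verification pass when each draft token is accepted with prob alpha.
    return (1 - alpha ** (k + 1)) / (1 - alpha)

def estimated_speedup(alpha: float, k: int, draft_cost_ratio: float) -> float:
    # Each cycle costs one target pass plus k draft passes.
    return expected_tokens_per_pass(alpha, k) / (1 + k * draft_cost_ratio)

# alpha = 0.8 acceptance, K = 5 draft tokens, draft ~5% of target cost
print(f"{expected_tokens_per_pass(0.8, 5):.2f} tokens/pass")  # 3.69
print(f"{estimated_speedup(0.8, 5, 0.05):.2f}x speedup")      # 2.95
```

This shows why draft-model choice matters: a cheaper draft raises the speedup only if its acceptance rate stays high.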

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

class SpeculativeDecoder:
    """
    Speculative decoding for faster generation.

    Algorithm:
    1. Draft model generates K candidate tokens
    2. Target model verifies all K tokens in one pass
    3. Accept matching tokens, reject and regenerate from first mismatch
    """

    def __init__(
        self,
        target_model_name: str,
        draft_model_name: str,
        num_speculative_tokens: int = 5
    ):
        self.num_speculative_tokens = num_speculative_tokens

        # Load models
        self.target_model = AutoModelForCausalLM.from_pretrained(
            target_model_name,
            torch_dtype=torch.float16,
            device_map="auto"
        )
        self.draft_model = AutoModelForCausalLM.from_pretrained(
            draft_model_name,
            torch_dtype=torch.float16,
            device_map="auto"
        )
        self.tokenizer = AutoTokenizer.from_pretrained(target_model_name)

    @torch.no_grad()
    def generate(
        self,
        prompt: str,
        max_tokens: int = 100,
        temperature: float = 1.0
    ) -> str:
        """Generate text using speculative decoding."""
        input_ids = self.tokenizer.encode(prompt, return_tensors="pt")
        input_ids = input_ids.to(self.target_model.device)

        generated_tokens = []
        accepted_tokens = 0
        total_draft_tokens = 0

        while len(generated_tokens) < max_tokens:
            # Step 1: Draft model generates K tokens
            draft_tokens = self._draft_generate(
                input_ids,
                self.num_speculative_tokens,
                temperature
            )
            total_draft_tokens += len(draft_tokens)

            # Step 2: Target model verifies all tokens in parallel
            candidate_ids = torch.cat([input_ids, draft_tokens.unsqueeze(0)], dim=1)
            target_logits = self.target_model(candidate_ids).logits

            # Step 3: Accept/reject tokens
            num_accepted = self._verify_tokens(
                draft_tokens,
                target_logits[:, input_ids.shape[1]-1:-1, :],
                temperature
            )

            accepted_tokens += num_accepted

            # Add accepted tokens
            generated_tokens.extend(draft_tokens[:num_accepted].tolist())
            input_ids = torch.cat([
                input_ids,
                draft_tokens[:num_accepted].unsqueeze(0)
            ], dim=1)

            # If a draft token was rejected, sample the correction token
            # from the target distribution at the first mismatch position.
            # After the accept-concat above, that position's logits sit at
            # index input_ids.shape[1] - 1 in target_logits.
            if num_accepted < len(draft_tokens):
                target_token = self._sample_token(
                    target_logits[:, input_ids.shape[1] - 1, :],
                    temperature
                )
                generated_tokens.append(target_token.item())
                input_ids = torch.cat([
                    input_ids,
                    target_token.unsqueeze(0)  # shape [1] -> [1, 1]
                ], dim=1)

            # Stop if an EOS token appeared among this iteration's new tokens
            # (up to num_speculative_tokens accepted plus one correction)
            if self.tokenizer.eos_token_id in generated_tokens[-(self.num_speculative_tokens + 1):]:
                break

        acceptance_rate = accepted_tokens / total_draft_tokens if total_draft_tokens > 0 else 0
        print(f"Acceptance rate: {acceptance_rate:.2%}")

        return self.tokenizer.decode(generated_tokens, skip_special_tokens=True)

    def _draft_generate(
        self,
        input_ids: torch.Tensor,
        num_tokens: int,
        temperature: float
    ) -> torch.Tensor:
        """Generate tokens with draft model."""
        draft_tokens = []
        current_ids = input_ids

        for _ in range(num_tokens):
            logits = self.draft_model(current_ids).logits[:, -1, :]
            token = self._sample_token(logits, temperature)  # shape [1]
            draft_tokens.append(token)
            current_ids = torch.cat([current_ids, token.unsqueeze(0)], dim=1)

        return torch.cat(draft_tokens)  # shape [num_tokens]

    def _sample_token(
        self,
        logits: torch.Tensor,
        temperature: float
    ) -> torch.Tensor:
        """Sample a token from logits."""
        if temperature == 0:
            return logits.argmax(dim=-1)
        probs = torch.softmax(logits / temperature, dim=-1)
        return torch.multinomial(probs, num_samples=1).squeeze(-1)

    def _verify_tokens(
        self,
        draft_tokens: torch.Tensor,
        target_logits: torch.Tensor,
        temperature: float
    ) -> int:
        """Verify draft tokens against target distribution."""
        for i, (draft_token, logits) in enumerate(zip(draft_tokens, target_logits[0])):
            target_probs = torch.softmax(logits / max(temperature, 1e-8), dim=-1)
            draft_prob = target_probs[draft_token]

            # Exact rule: accept with probability min(1, p_target / p_draft).
            # Simplified here: reject once the target assigns low probability.
            if draft_prob < 0.1:  # heuristic rejection threshold
                return i

        return len(draft_tokens)
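The fixed threshold above is a heuristic stand-in. The exact rule from the speculative decoding literature accepts a draft token with probability $\min(1, p_{\text{target}}/p_{\text{draft}})$, which provably leaves the target model's output distribution unchanged. A minimal sketch of that rule in isolation (function name is illustrative):

```python
import random

def accept_draft_token(p_target: float, p_draft: float, rng=random) -> bool:
    """Exact speculative-decoding acceptance rule:
    accept with probability min(1, p_target / p_draft)."""
    if p_draft <= 0:
        # The draft model would never propose this token; nothing to verify
        return False
    return rng.random() < min(1.0, p_target / p_draft)

# A token the target likes at least as much as the draft is always accepted;
# a token the target assigns zero probability is always rejected.
print(accept_draft_token(p_target=0.5, p_draft=0.2),
      accept_draft_token(p_target=0.0, p_draft=0.3))
```

On rejection, the correction token is then sampled from the normalized residual distribution $\max(0, p_{\text{target}} - p_{\text{draft}})$; the implementation above simplifies this by sampling directly from the target distribution.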

5.3 Production Deployment Checklist

Inference Optimization Checklist

Memory Optimization:

- Choose weight precision (fp16, int8, int4) against your accuracy budget
- Budget KV-cache memory for the maximum context length and batch size you intend to serve

Throughput Optimization:

- Use continuous batching so freed slots are refilled immediately from the waiting queue
- Track batch utilization; sustained low utilization means wasted capacity

Latency Optimization:

- Profile prefill (compute-bound) and decode (memory-bound) separately
- Evaluate speculative decoding with a well-matched draft model, and measure the acceptance rate before committing

Cost Optimization:

- Measure cost per generated token under realistic traffic, not just peak hardware utilization
- Scale batch limits and replicas against queue depth rather than raw request rate
Summary

Chapter 4 Key Takeaways

- Inference has two phases with different bottlenecks: compute-bound prefill and memory-bound, sequential decode
- Serving memory is dominated by model weights plus a KV cache that grows with batch size and context length
- Continuous batching raises throughput by admitting waiting requests the moment others complete
- Speculative decoding can deliver 2-3x faster generation by letting a draft model propose tokens that the target model verifies in a single parallel pass
Previous: LLM Training and Alignment Next: Practical LLM Applications