1. LLM Inference Fundamentals
Inference is where LLMs deliver value: generating text, answering questions, and powering applications. However, inference presents unique challenges compared to training. While training optimizes for throughput over hours or days, inference must deliver responses within milliseconds to seconds, often under strict cost constraints.
1.1 The Inference Pipeline
LLM inference consists of two distinct phases with different computational characteristics:
Two-Phase Inference
Prefill Phase (Prompt Processing):
- Processes all input tokens in parallel
- Compute-bound: benefits from GPU parallelism
- Generates KV cache for subsequent generation
- Time proportional to prompt length
Decode Phase (Token Generation):
- Generates tokens one at a time (autoregressive)
- Memory-bound: limited by KV cache access
- Each step requires reading entire KV cache
- Time proportional to output length
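The compute-bound/memory-bound split can be checked with a back-of-envelope arithmetic-intensity calculation. This is a sketch: the 2-FLOPs-per-parameter-per-token rule of thumb and the H100 figures (~990 TFLOPS FP16, ~3.35 TB/s HBM bandwidth) are assumptions, and KV cache traffic is ignored.

```python
def arithmetic_intensity(params_b: float, tokens_in_flight: int,
                         bytes_per_param: int = 2) -> float:
    """FLOPs per byte of weight traffic for one forward pass."""
    flops = 2 * params_b * 1e9 * tokens_in_flight   # ~2 FLOPs per param per token
    bytes_moved = params_b * 1e9 * bytes_per_param  # weights read once per pass
    return flops / bytes_moved

# Ridge point: FLOPs/byte at which compute and bandwidth balance (assumed H100 specs)
RIDGE = 990e12 / 3.35e12  # ~295 FLOPs/byte

prefill = arithmetic_intensity(70, tokens_in_flight=2048)  # whole prompt at once
decode = arithmetic_intensity(70, tokens_in_flight=1)      # one token at a time
print(f"prefill: {prefill:.0f} FLOPs/byte (>> ridge {RIDGE:.0f}) -> compute-bound")
print(f"decode:  {decode:.0f} FLOPs/byte (<< ridge) -> memory-bound")
```

With weights in FP16, intensity reduces to exactly the number of tokens in flight, which is why batching helps decode so much: more tokens amortize the same weight reads.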
1.2 Memory Requirements
Understanding memory consumption is crucial for inference optimization. Total inference memory splits into three components: model weights, the KV cache, and activations:
def calculate_inference_memory(
    params_billions: float,
    context_length: int,
    batch_size: int,
    num_layers: int,
    hidden_dim: int,
    num_q_heads: int,
    num_kv_heads: int,
    precision: str = "fp16"
) -> dict:
    """
    Calculate approximate memory requirements for LLM inference.
    Args:
        params_billions: Model size in billions of parameters
        context_length: Maximum sequence length
        batch_size: Number of concurrent requests
        num_layers: Number of transformer layers
        hidden_dim: Model hidden dimension
        num_q_heads: Number of query attention heads
        num_kv_heads: Number of KV attention heads (< num_q_heads with GQA)
        precision: Weight precision ("fp32", "fp16", "bf16", "fp8", "int8", "int4")
    Returns:
        Memory breakdown in GB
    """
    bytes_per_param = {
        "fp32": 4, "fp16": 2, "bf16": 2,
        "int8": 1, "int4": 0.5, "fp8": 1
    }
    # Model weights: params * bytes per param
    weight_gb = params_billions * bytes_per_param[precision]
    # KV cache: 2 (K and V) * layers * batch * seq * kv_heads * head_dim.
    # head_dim is set by the query heads; GQA reduces only the KV head count.
    head_dim = hidden_dim // num_q_heads
    kv_cache_bytes = (
        2 * num_layers * batch_size * context_length *
        num_kv_heads * head_dim * 2  # KV cache kept in fp16
    )
    kv_cache_gb = kv_cache_bytes / 1e9
    # Activation memory (rough approximation of transient buffers)
    activation_gb = batch_size * context_length * hidden_dim * 4 / 1e9
    return {
        "weights_gb": round(weight_gb, 2),
        "kv_cache_gb": round(kv_cache_gb, 2),
        "activations_gb": round(activation_gb, 2),
        "total_gb": round(weight_gb + kv_cache_gb + activation_gb, 2)
    }
# Example: Llama-3-70B with 128K context
memory = calculate_inference_memory(
    params_billions=70,
    context_length=128000,
    batch_size=1,
    num_layers=80,
    hidden_dim=8192,
    num_q_heads=64,
    num_kv_heads=8,  # GQA
    precision="int4"
)
print(f"Memory breakdown: {memory}")
# weights_gb: 35.0, kv_cache_gb: 41.94, activations_gb: 4.19, total_gb: 81.14
1.3 Latency Components
| Component | Description | Optimization Target |
|---|---|---|
| Time To First Token (TTFT) | Latency until first token appears | Prefill optimization, prompt caching |
| Time Per Output Token (TPOT) | Latency per generated token | Decode optimization, batching |
| End-to-End Latency | TTFT + (tokens × TPOT) | Speculative decoding |
| Throughput | Tokens/second across all requests | Continuous batching |
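The identities in the table compose directly; a minimal sketch with illustrative (not measured) numbers:

```python
def end_to_end_latency_ms(ttft_ms: float, tpot_ms: float, output_tokens: int) -> float:
    """End-to-end latency = TTFT + output_tokens * TPOT."""
    return ttft_ms + output_tokens * tpot_ms

def decode_throughput_tps(tpot_ms: float, batch_size: int) -> float:
    """Aggregate decode throughput across a batch, in tokens/second."""
    return batch_size * 1000.0 / tpot_ms

# Illustrative workload: 200 ms TTFT, 20 ms/token, 256 output tokens
latency = end_to_end_latency_ms(200, 20, 256)   # 5320.0 ms
throughput = decode_throughput_tps(20, 64)      # 3200.0 tok/s across 64 requests
```

Note the tension the table hints at: larger batches raise aggregate throughput but typically also raise per-request TPOT, so the two metrics must be traded off per workload.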
2. Model Quantization
Quantization reduces model precision from 16-bit or 32-bit floating point to lower bit-widths, dramatically reducing memory footprint and accelerating inference. Modern quantization techniques achieve near-lossless quality while cutting memory by 4-8x.
2.1 Quantization Fundamentals
import torch
import torch.nn.functional as F
from typing import Tuple
class Quantizer:
"""Basic quantization utilities for understanding the fundamentals."""
@staticmethod
def absmax_quantize(
tensor: torch.Tensor,
bits: int = 8
) -> Tuple[torch.Tensor, float]:
"""
Absmax (symmetric) quantization.
Maps values to [-2^(bits-1), 2^(bits-1)-1]
"""
        qmax = 2 ** (bits - 1) - 1
        scale = tensor.abs().max().item() / qmax
        quantized = torch.round(tensor / scale).clamp(-qmax, qmax).to(torch.int8)
        return quantized, scale
@staticmethod
def absmax_dequantize(
quantized: torch.Tensor,
scale: float
) -> torch.Tensor:
"""Dequantize absmax-quantized tensor."""
return quantized.float() * scale
@staticmethod
def zeropoint_quantize(
tensor: torch.Tensor,
bits: int = 8
) -> Tuple[torch.Tensor, float, int]:
"""
Zero-point (asymmetric) quantization.
Better for tensors with non-symmetric distributions.
"""
        qmin, qmax = 0, 2 ** bits - 1
        min_val, max_val = tensor.min().item(), tensor.max().item()
        scale = (max_val - min_val) / (qmax - qmin)
        zero_point = int(round(-min_val / scale))
        quantized = torch.round(tensor / scale + zero_point)
        quantized = torch.clamp(quantized, qmin, qmax).to(torch.uint8)
        return quantized, scale, zero_point
@staticmethod
def block_quantize(
tensor: torch.Tensor,
block_size: int = 128,
bits: int = 4
) -> Tuple[torch.Tensor, torch.Tensor]:
"""
Block-wise quantization (used in GPTQ, AWQ).
Each block has its own scale factor.
"""
# Reshape into blocks
original_shape = tensor.shape
tensor_flat = tensor.view(-1)
# Pad to multiple of block_size
pad_len = (block_size - len(tensor_flat) % block_size) % block_size
if pad_len:
tensor_flat = F.pad(tensor_flat, (0, pad_len))
blocks = tensor_flat.view(-1, block_size)
# Quantize each block independently
qmax = 2 ** (bits - 1) - 1
scales = blocks.abs().max(dim=1, keepdim=True).values / qmax
scales = scales.clamp(min=1e-8)
quantized = torch.round(blocks / scales).to(torch.int8)
return quantized, scales.squeeze()
# Demonstration
tensor = torch.randn(1024, 1024)
# Compare quantization methods
int8_quant, int8_scale = Quantizer.absmax_quantize(tensor, bits=8)
int4_quant, int4_scales = Quantizer.block_quantize(tensor, bits=4)
print(f"Original size: {tensor.numel() * 4 / 1024:.1f} KB")
print(f"INT8 size: {int8_quant.numel() * 1 / 1024:.1f} KB (4x reduction)")
print(f"INT4 size: {int4_quant.numel() * 0.5 / 1024:.1f} KB (8x reduction)")
2.2 Advanced Quantization Methods
GPTQ (Post-Training Quantization)
GPTQ uses second-order information (Hessian) to minimize quantization error. It processes weights column by column, adjusting remaining weights to compensate for quantization errors.
# Using AutoGPTQ for quantization
from transformers import AutoModelForCausalLM, AutoTokenizer
from auto_gptq import AutoGPTQForCausalLM, BaseQuantizeConfig
def quantize_with_gptq(
model_id: str,
output_dir: str,
bits: int = 4,
group_size: int = 128,
calibration_samples: int = 128
):
"""
Quantize a model using GPTQ.
Args:
model_id: HuggingFace model ID
output_dir: Where to save quantized model
bits: Target bit width (4 or 8)
group_size: Block size for quantization
calibration_samples: Number of samples for calibration
"""
# Load tokenizer and model
tokenizer = AutoTokenizer.from_pretrained(model_id)
# Configure quantization
quantize_config = BaseQuantizeConfig(
bits=bits,
group_size=group_size,
desc_act=True, # Activation-order quantization
sym=False, # Asymmetric quantization
)
# Load model for quantization
model = AutoGPTQForCausalLM.from_pretrained(
model_id,
quantize_config=quantize_config,
torch_dtype=torch.float16,
device_map="auto"
)
# Prepare calibration data
calibration_data = prepare_calibration_data(
tokenizer,
num_samples=calibration_samples
)
# Quantize
model.quantize(calibration_data)
# Save
model.save_quantized(output_dir)
tokenizer.save_pretrained(output_dir)
return model
def prepare_calibration_data(tokenizer, num_samples=128):
"""Prepare calibration data from C4 dataset."""
from datasets import load_dataset
    dataset = load_dataset("allenai/c4", "en", split="train", streaming=True)
samples = []
for i, example in enumerate(dataset):
if i >= num_samples:
break
tokenized = tokenizer(
example["text"],
truncation=True,
max_length=2048,
return_tensors="pt"
)
samples.append(tokenized.input_ids)
return samples
AWQ (Activation-aware Weight Quantization)
AWQ identifies salient weight channels based on activation magnitudes. It preserves these important channels at higher precision, achieving better quality than uniform quantization.
# Using AWQ for quantization
from awq import AutoAWQForCausalLM
from transformers import AutoTokenizer
def quantize_with_awq(
model_id: str,
output_dir: str,
bits: int = 4,
group_size: int = 128,
zero_point: bool = True
):
"""
Quantize using AWQ (Activation-aware Weight Quantization).
AWQ identifies salient weights based on activation patterns
and preserves them at higher precision.
"""
# Load model
model = AutoAWQForCausalLM.from_pretrained(
model_id,
device_map="auto",
safetensors=True
)
tokenizer = AutoTokenizer.from_pretrained(model_id)
# Quantization config
quant_config = {
"zero_point": zero_point,
"q_group_size": group_size,
"w_bit": bits,
"version": "GEMM" # or "GEMV" for batch size 1
}
# Quantize with calibration
model.quantize(
tokenizer,
quant_config=quant_config,
calib_data="pileval", # or custom dataset
calib_size=512
)
# Save quantized model
model.save_quantized(output_dir)
tokenizer.save_pretrained(output_dir)
print(f"Model quantized to {bits}-bit and saved to {output_dir}")
return model
# Load and use quantized model
def load_awq_model(model_path: str):
"""Load AWQ quantized model for inference."""
model = AutoAWQForCausalLM.from_quantized(
model_path,
fuse_layers=True, # Fuse QKV and MLP for speed
device_map="auto"
)
tokenizer = AutoTokenizer.from_pretrained(model_path)
return model, tokenizer
FP8 Quantization (2025-2026 Standard)
FP8 has emerged as the preferred format for inference on modern GPUs (H100, H200, B200). It maintains dynamic range through floating-point representation while achieving INT8-level efficiency.
import torch
from typing import Literal
def fp8_quantize(
tensor: torch.Tensor,
format: Literal["e4m3", "e5m2"] = "e4m3"
) -> torch.Tensor:
"""
Convert tensor to FP8 format.
E4M3: 4-bit exponent, 3-bit mantissa
- Range: ±448, precision: ~1/16
- Better for weights and activations
E5M2: 5-bit exponent, 2-bit mantissa
- Range: ±57344, precision: ~1/4
- Better for gradients (wider range)
"""
if format == "e4m3":
return tensor.to(torch.float8_e4m3fn)
else:
return tensor.to(torch.float8_e5m2)
# FP8 inference with scaling
class FP8Linear(torch.nn.Module):
"""FP8 linear layer with per-tensor scaling."""
def __init__(self, in_features: int, out_features: int):
super().__init__()
self.in_features = in_features
self.out_features = out_features
# FP8 weights with scale
self.register_buffer(
"weight",
torch.zeros(out_features, in_features, dtype=torch.float8_e4m3fn)
)
self.register_buffer("weight_scale", torch.ones(1))
self.register_buffer("input_scale", torch.ones(1))
def forward(self, x: torch.Tensor) -> torch.Tensor:
# Scale input to FP8 range
x_scaled = x / self.input_scale
x_fp8 = x_scaled.to(torch.float8_e4m3fn)
# FP8 matmul (hardware accelerated on H100+)
output = torch._scaled_mm(
x_fp8,
self.weight.t(),
scale_a=self.input_scale,
scale_b=self.weight_scale,
out_dtype=torch.float16
)
return output
    @classmethod
    def from_float(cls, linear: torch.nn.Linear) -> "FP8Linear":
        """Convert an FP16/32 linear layer to FP8."""
        fp8_linear = cls(linear.in_features, linear.out_features)
        # Per-tensor weight scale; 448 is the E4M3 max representable value
        weight_scale = linear.weight.abs().max() / 448
        fp8_linear.weight = (linear.weight / weight_scale).to(torch.float8_e4m3fn)
        fp8_linear.weight_scale = weight_scale.reshape(1)
        # input_scale is left at 1.0 here; in practice it is calibrated
        # from activation statistics
        return fp8_linear
2.3 Quantization Comparison
Quantization Methods Comparison (Llama-3-70B)
| Method | Bits | Memory (GB) | Perplexity | Speed vs FP16 |
|---|---|---|---|---|
| FP16 (Baseline) | 16 | 140 | 5.42 | 1.0x |
| FP8 | 8 | 70 | 5.44 | 1.8x |
| INT8 (Absmax) | 8 | 70 | 5.51 | 1.6x |
| GPTQ | 4 | 35 | 5.58 | 2.2x |
| AWQ | 4 | 35 | 5.52 | 2.4x |
| GGUF Q4_K_M | 4.5* | ~40 | 5.55 | 2.1x |
*Mixed precision with important layers at higher precision
3. High-Performance Inference Engines
Modern inference engines optimize the entire serving stack: batching, memory management, kernel optimization, and distributed execution. The two leading solutions in 2025-2026 are vLLM and TensorRT-LLM.
3.1 vLLM: PagedAttention
vLLM introduced PagedAttention, which manages KV cache like virtual memory pages. This eliminates memory fragmentation and enables efficient continuous batching, achieving 2-24x throughput improvements.
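The paging idea can be sketched in a few lines. This is a toy model, not vLLM's actual allocator; all names here are illustrative.

```python
class BlockTable:
    """Toy block table: maps a sequence's logical token positions to
    fixed-size physical KV cache blocks, so memory is allocated on demand
    instead of reserving a contiguous max-length region per sequence."""

    def __init__(self, block_size: int = 16, num_physical_blocks: int = 1024):
        self.block_size = block_size
        self.free_blocks = list(range(num_physical_blocks))
        self.table = []  # logical block index -> physical block id

    def slot_for(self, pos: int):
        """Return (physical_block, offset) for token position `pos`,
        allocating a new block lazily when a block boundary is crossed."""
        if pos // self.block_size == len(self.table):
            self.table.append(self.free_blocks.pop(0))
        block = self.table[pos // self.block_size]
        return block, pos % self.block_size
```

Because blocks are freed back to the pool when a request completes, fragmentation is bounded to less than one block per sequence, and identical prefixes can share physical blocks, which is what makes prefix caching cheap.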
from vllm import LLM, SamplingParams
from typing import List, Dict
def setup_vllm_server(
model: str,
tensor_parallel_size: int = 1,
gpu_memory_utilization: float = 0.9,
max_model_len: int = 32768,
quantization: str = None
) -> LLM:
"""
Configure vLLM for high-throughput inference.
Args:
model: Model name or path
tensor_parallel_size: Number of GPUs for tensor parallelism
gpu_memory_utilization: Fraction of GPU memory to use
max_model_len: Maximum sequence length
quantization: "awq", "gptq", "fp8", or None
"""
llm = LLM(
model=model,
tensor_parallel_size=tensor_parallel_size,
gpu_memory_utilization=gpu_memory_utilization,
max_model_len=max_model_len,
quantization=quantization,
# PagedAttention settings
block_size=16, # KV cache block size
swap_space=4, # CPU swap space in GB
# Optimization flags
enforce_eager=False, # Use CUDA graphs
enable_prefix_caching=True, # Cache common prefixes
)
return llm
def batch_inference(
llm: LLM,
prompts: List[str],
max_tokens: int = 512,
temperature: float = 0.7,
top_p: float = 0.95
) -> List[str]:
"""
Run batch inference with vLLM.
vLLM automatically handles:
- Continuous batching (dynamic batch composition)
- PagedAttention (efficient KV cache management)
- Prefix caching (reuse common prefixes)
"""
sampling_params = SamplingParams(
max_tokens=max_tokens,
temperature=temperature,
top_p=top_p,
# Advanced options
presence_penalty=0.0,
frequency_penalty=0.0,
repetition_penalty=1.0,
        stop=["</s>", "[/INST]"],
)
outputs = llm.generate(prompts, sampling_params)
return [output.outputs[0].text for output in outputs]
# Example usage
llm = setup_vllm_server(
model="meta-llama/Llama-3.1-70B-Instruct",
tensor_parallel_size=4, # 4x A100
quantization="awq"
)
prompts = [
"Explain quantum computing in simple terms.",
"Write a Python function to merge two sorted lists.",
"What are the key differences between TCP and UDP?"
]
results = batch_inference(llm, prompts)
for prompt, result in zip(prompts, results):
print(f"Q: {prompt[:50]}...")
print(f"A: {result[:200]}...")
print("-" * 50)
vLLM OpenAI-Compatible Server
# Start vLLM server with OpenAI-compatible API
python -m vllm.entrypoints.openai.api_server \
--model meta-llama/Llama-3.1-70B-Instruct \
--tensor-parallel-size 4 \
--quantization awq \
--max-model-len 32768 \
--enable-prefix-caching \
--port 8000
# Use with OpenAI client
from openai import OpenAI
client = OpenAI(base_url="http://localhost:8000/v1", api_key="dummy")
response = client.chat.completions.create(
model="meta-llama/Llama-3.1-70B-Instruct",
messages=[{"role": "user", "content": "Hello!"}],
max_tokens=100
)
3.2 TensorRT-LLM: Maximum Performance
TensorRT-LLM is NVIDIA's inference library, providing the highest performance on NVIDIA GPUs through aggressive kernel fusion, in-flight batching, and custom CUDA kernels.
# TensorRT-LLM model building and serving
import tensorrt_llm
from tensorrt_llm import BuildConfig, Mapping
from tensorrt_llm.models import LLaMAForCausalLM
from tensorrt_llm.quantization import QuantMode
def build_trtllm_engine(
model_dir: str,
output_dir: str,
tp_size: int = 1,
pp_size: int = 1,
max_batch_size: int = 64,
max_input_len: int = 4096,
max_output_len: int = 2048,
quantization: str = "fp8"
):
"""
Build TensorRT-LLM engine for maximum inference performance.
Args:
model_dir: Path to HuggingFace model
output_dir: Where to save TRT engine
tp_size: Tensor parallelism degree
pp_size: Pipeline parallelism degree
max_batch_size: Maximum concurrent requests
max_input_len: Maximum input sequence length
max_output_len: Maximum output sequence length
quantization: "fp8", "int8_sq", "int4_awq", etc.
"""
# Configure quantization
quant_mode = QuantMode(0)
if quantization == "fp8":
quant_mode = QuantMode.use_fp8_qdq()
elif quantization == "int8_sq":
quant_mode = QuantMode.use_smooth_quant()
elif quantization == "int4_awq":
quant_mode = QuantMode.use_weight_only() | QuantMode.use_int4_weights()
# Build configuration
build_config = BuildConfig(
max_batch_size=max_batch_size,
max_input_len=max_input_len,
max_output_len=max_output_len,
max_beam_width=1,
# Optimization settings
strongly_typed=True,
builder_opt_level=5, # Maximum optimization
# Memory optimization
use_paged_context_fmha=True,
tokens_per_block=128,
)
# Parallelism mapping
mapping = Mapping(
world_size=tp_size * pp_size,
tp_size=tp_size,
pp_size=pp_size
)
# Build engine
model = LLaMAForCausalLM.from_hugging_face(
model_dir,
mapping=mapping,
quant_mode=quant_mode
)
engine = tensorrt_llm.build(model, build_config)
engine.save(output_dir)
print(f"Engine saved to {output_dir}")
return engine
# Running inference with TensorRT-LLM
from tensorrt_llm.runtime import ModelRunner
def trtllm_inference(
engine_dir: str,
prompts: list,
max_output_len: int = 512
):
"""Run inference with TensorRT-LLM engine."""
runner = ModelRunner.from_dir(
engine_dir,
rank=0, # For multi-GPU, set appropriate rank
)
outputs = runner.generate(
prompts,
max_new_tokens=max_output_len,
end_id=128001, # Llama-3 EOS
pad_id=128001,
temperature=0.7,
top_p=0.95,
streaming=False
)
return outputs
3.3 Inference Engine Comparison
| Feature | vLLM | TensorRT-LLM | llama.cpp |
|---|---|---|---|
| Best For | General serving | Maximum throughput | Local/Edge deployment |
| Hardware | NVIDIA, AMD, Intel | NVIDIA only | CPU, GPU, Apple Silicon |
| Setup Complexity | Low | High | Very Low |
| Quantization | AWQ, GPTQ, FP8 | FP8, INT8, INT4 | GGUF (many formats) |
| Continuous Batching | Yes | Yes (In-flight) | Limited |
| Throughput (relative) | 1.0x | 1.2-1.5x | 0.3-0.5x |
4. Long Context Handling
Context windows have expanded dramatically from 4K tokens in 2023 to 1M+ tokens in 2026. Managing these long contexts efficiently is a key challenge for modern inference systems.
4.1 Context Length Evolution
| Model | Context Length | Release |
|---|---|---|
| GPT-4 (Original) | 8K / 32K | 2023-03 |
| Claude 2 | 100K | 2023-07 |
| GPT-4 Turbo | 128K | 2023-11 |
| Claude 3 | 200K | 2024-03 |
| Gemini 1.5 Pro | 1M → 2M | 2024-02 |
| Llama 4 Scout | 10M | 2025-04 |
| Gemini 3.0 | 10M+ | 2026-01 |
4.2 KV Cache Optimization Techniques
import torch
from typing import Optional, Tuple
class SlidingWindowAttention:
"""
Sliding window attention for efficient long context.
Each token only attends to the last `window_size` tokens.
"""
def __init__(self, window_size: int = 4096):
self.window_size = window_size
def create_mask(
self,
seq_len: int,
device: torch.device
) -> torch.Tensor:
"""Create sliding window causal mask."""
# Start with causal mask
mask = torch.triu(
torch.ones(seq_len, seq_len, device=device),
diagonal=1
).bool()
# Add window constraint
window_mask = torch.triu(
torch.ones(seq_len, seq_len, device=device),
diagonal=-self.window_size
).bool()
mask = mask | ~window_mask
return mask
class StreamingLLM:
"""
StreamingLLM: Efficient infinite context through attention sinks.
Key insight: First few tokens ("attention sinks") capture global
information. Keep these plus a sliding window for infinite streaming.
Reference: https://arxiv.org/abs/2309.17453
"""
def __init__(
self,
num_sink_tokens: int = 4,
window_size: int = 4096
):
self.num_sink_tokens = num_sink_tokens
self.window_size = window_size
def evict_kv_cache(
self,
k_cache: torch.Tensor,
v_cache: torch.Tensor,
current_len: int
) -> Tuple[torch.Tensor, torch.Tensor]:
"""
Evict middle tokens while keeping sinks and recent window.
Cache layout: [sink_tokens | ... evicted ... | window_tokens]
"""
max_len = self.num_sink_tokens + self.window_size
if current_len <= max_len:
return k_cache, v_cache
# Keep sink tokens and recent window
sink_k = k_cache[:, :, :self.num_sink_tokens, :]
sink_v = v_cache[:, :, :self.num_sink_tokens, :]
window_k = k_cache[:, :, -self.window_size:, :]
window_v = v_cache[:, :, -self.window_size:, :]
new_k = torch.cat([sink_k, window_k], dim=2)
new_v = torch.cat([sink_v, window_v], dim=2)
return new_k, new_v
class MultiQueryAttention:
    """
    Multi-Query Attention (MQA) and Grouped-Query Attention (GQA)
    reduce KV cache size by sharing KV heads across query heads.
    MQA: 1 KV head shared by all Q heads (cache shrinks by num_q_heads, e.g. 64x)
    GQA: each group of Q heads shares one KV head (typical 4-8x reduction)
    """
@staticmethod
def calculate_kv_cache_size(
batch_size: int,
seq_len: int,
num_layers: int,
hidden_dim: int,
num_q_heads: int,
num_kv_heads: int,
dtype_bytes: int = 2 # FP16
) -> dict:
"""Compare KV cache sizes for different attention variants."""
head_dim = hidden_dim // num_q_heads
# MHA: num_kv_heads == num_q_heads
mha_size = 2 * batch_size * seq_len * num_layers * num_q_heads * head_dim * dtype_bytes
# GQA: num_kv_heads < num_q_heads
gqa_size = 2 * batch_size * seq_len * num_layers * num_kv_heads * head_dim * dtype_bytes
# MQA: num_kv_heads == 1
mqa_size = 2 * batch_size * seq_len * num_layers * 1 * head_dim * dtype_bytes
return {
"MHA": f"{mha_size / 1e9:.2f} GB",
"GQA": f"{gqa_size / 1e9:.2f} GB",
"MQA": f"{mqa_size / 1e9:.2f} GB",
"GQA_reduction": f"{mha_size / gqa_size:.1f}x",
"MQA_reduction": f"{mha_size / mqa_size:.1f}x"
}
# Example: Llama-3-70B KV cache comparison
kv_comparison = MultiQueryAttention.calculate_kv_cache_size(
batch_size=1,
seq_len=128000,
num_layers=80,
hidden_dim=8192,
num_q_heads=64,
num_kv_heads=8 # GQA with 8 KV heads
)
print("KV Cache Comparison:")
for k, v in kv_comparison.items():
print(f" {k}: {v}")
4.3 Prefix Caching and Prompt Caching
Prompt Caching Benefits
Prompt caching stores the KV cache for frequently used prefixes (system prompts, few-shot examples). Benefits include:
- Latency: Skip prefill for cached prompts (up to 85% TTFT reduction)
- Cost: Major providers offer 50-90% discount for cached tokens
- Throughput: More GPU memory available for active requests
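The cost arithmetic is straightforward; a quick sketch using the discount range quoted above (the function and its defaults are illustrative, not any provider's actual pricing model):

```python
def cached_input_cost(prompt_tokens: int, cached_tokens: int,
                      price_per_token: float, cache_discount: float = 0.9) -> float:
    """Effective input cost when `cached_tokens` of the prompt hit the cache.
    cache_discount=0.9 means cached tokens cost 10% of the normal price."""
    uncached = prompt_tokens - cached_tokens
    return (uncached + cached_tokens * (1 - cache_discount)) * price_per_token

# Illustrative: 10K-token system prompt cached, 500 fresh tokens,
# at an assumed $3 per million input tokens
cost = cached_input_cost(10_500, 10_000, 3e-6)   # $0.0045 vs $0.0315 uncached
```

The savings grow with the ratio of shared prefix to fresh tokens, which is why caching pays off most for long system prompts and few-shot preambles reused across many requests.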
# Using prompt caching with Anthropic API
import anthropic
import hashlib
def cached_completion(
client: anthropic.Anthropic,
system_prompt: str,
user_message: str,
model: str = "claude-sonnet-4-20250514"
):
"""
Use prompt caching for repeated system prompts.
The system prompt KV cache is stored for 5 minutes,
providing 90% cost reduction on cached tokens.
"""
response = client.messages.create(
model=model,
max_tokens=1024,
system=[
{
"type": "text",
"text": system_prompt,
"cache_control": {"type": "ephemeral"} # Enable caching
}
],
messages=[
{"role": "user", "content": user_message}
]
)
# Check cache usage
usage = response.usage
print(f"Input tokens: {usage.input_tokens}")
print(f"Cache read tokens: {getattr(usage, 'cache_read_input_tokens', 0)}")
print(f"Cache creation tokens: {getattr(usage, 'cache_creation_input_tokens', 0)}")
return response.content[0].text
# vLLM prefix caching
from vllm import LLM, SamplingParams
def setup_prefix_caching():
"""Configure vLLM with automatic prefix caching."""
    llm = LLM(
        model="meta-llama/Llama-3.1-70B-Instruct",
        enable_prefix_caching=True,  # Enable automatic prefix caching
        max_num_seqs=256,  # Maximum concurrent sequences per scheduler step
    )
return llm
# Example: Chat with shared system prompt
SYSTEM_PROMPT = """You are a helpful AI assistant specializing in
software development. You provide clear, concise, and accurate
technical guidance...""" # Long system prompt
def chat_with_caching(llm: LLM, messages: list):
"""Multiple requests share the cached system prompt."""
results = []
for msg in messages:
# System prompt is automatically cached after first request
        full_prompt = f"{SYSTEM_PROMPT}\n\nUser: {msg}\n\nAssistant:"
output = llm.generate(
[full_prompt],
SamplingParams(max_tokens=512, temperature=0.7)
)
results.append(output[0].outputs[0].text)
return results
5. Throughput Optimization
5.1 Continuous Batching
Traditional static batching waits for all requests in a batch to complete before accepting new ones. Continuous (or in-flight) batching dynamically adds and removes requests as they arrive and complete, maximizing GPU utilization.
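A seeded simulation makes the waste concrete. The model is a deliberate simplification: in a static batch, every slot is held until the longest request finishes, so slot-time utilization is mean length over max length.

```python
import random

random.seed(0)
# Output lengths for one static batch of 64 requests (illustrative distribution)
lengths = [random.randint(32, 512) for _ in range(64)]

# Static batching: slots idle until the LONGEST request in the batch completes
static_util = sum(lengths) / (len(lengths) * max(lengths))

# Continuous batching refills a slot the moment its request completes,
# keeping utilization near 100% whenever the waiting queue is non-empty
print(f"static batching slot utilization: {static_util:.0%}")
```

The spread of output lengths is exactly what makes static batching inefficient: the wider the length distribution, the more slot-time is wasted waiting on stragglers.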
import asyncio
from dataclasses import dataclass
from typing import List, Optional
import time
@dataclass
class InferenceRequest:
"""Represents a single inference request."""
request_id: str
prompt: str
max_tokens: int
arrival_time: float
generated_tokens: int = 0
is_complete: bool = False
class ContinuousBatcher:
"""
Simplified continuous batching scheduler.
Key concepts:
- Requests join/leave batch dynamically
- Each iteration processes one token per request
- Completed requests free slots for waiting requests
"""
def __init__(
self,
max_batch_size: int = 64,
max_waiting_requests: int = 256
):
self.max_batch_size = max_batch_size
self.max_waiting_requests = max_waiting_requests
self.active_batch: List[InferenceRequest] = []
self.waiting_queue: List[InferenceRequest] = []
self.completed: List[InferenceRequest] = []
def add_request(self, request: InferenceRequest) -> bool:
"""Add request to batch or waiting queue."""
if len(self.active_batch) < self.max_batch_size:
self.active_batch.append(request)
return True
elif len(self.waiting_queue) < self.max_waiting_requests:
self.waiting_queue.append(request)
return True
return False
def iteration_step(self):
"""
Process one iteration (one token per request).
This represents a single forward pass of the model.
"""
# Generate one token for each active request
for request in self.active_batch:
request.generated_tokens += 1
if request.generated_tokens >= request.max_tokens:
request.is_complete = True
# Remove completed requests
completed = [r for r in self.active_batch if r.is_complete]
self.completed.extend(completed)
self.active_batch = [r for r in self.active_batch if not r.is_complete]
# Fill slots from waiting queue
slots_available = self.max_batch_size - len(self.active_batch)
for _ in range(min(slots_available, len(self.waiting_queue))):
self.active_batch.append(self.waiting_queue.pop(0))
return len(completed)
def get_metrics(self) -> dict:
"""Get current scheduler metrics."""
return {
"active_requests": len(self.active_batch),
"waiting_requests": len(self.waiting_queue),
"completed_requests": len(self.completed),
"batch_utilization": len(self.active_batch) / self.max_batch_size
}
5.2 Speculative Decoding
Speculative decoding uses a small "draft" model to propose multiple tokens, which the large "target" model verifies in parallel. This can achieve 2-3x speedup for autoregressive generation.
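The expected speedup can be estimated from the draft acceptance rate. The expected-tokens formula follows the standard speculative sampling analysis (Leviathan et al., 2023); the cost model below is a simplification, and the draft cost ratio is an assumed parameter.

```python
def expected_speedup(alpha: float, k: int, draft_cost_ratio: float = 0.1) -> float:
    """Estimated speculative decoding speedup.
    alpha: probability the target model accepts each draft token
    k: number of speculative tokens per step
    draft_cost_ratio: cost of one draft pass relative to one target pass
    """
    # Expected tokens produced per target verification pass: sum of alpha^i
    expected_tokens = (1 - alpha ** (k + 1)) / (1 - alpha)
    # Cost per step: one target pass plus k draft passes
    step_cost = 1 + k * draft_cost_ratio
    return expected_tokens / step_cost

# e.g. 80% acceptance, 5 draft tokens, draft model at 10% of target cost
print(f"{expected_speedup(0.8, 5):.2f}x")  # ~2.46x
```

This is why draft model choice matters: acceptance rate and draft cost pull in opposite directions, and a draft that is too small tanks alpha faster than it saves compute.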
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
class SpeculativeDecoder:
"""
Speculative decoding for faster generation.
Algorithm:
1. Draft model generates K candidate tokens
2. Target model verifies all K tokens in one pass
3. Accept matching tokens, reject and regenerate from first mismatch
"""
def __init__(
self,
target_model_name: str,
draft_model_name: str,
num_speculative_tokens: int = 5
):
self.num_speculative_tokens = num_speculative_tokens
# Load models
self.target_model = AutoModelForCausalLM.from_pretrained(
target_model_name,
torch_dtype=torch.float16,
device_map="auto"
)
self.draft_model = AutoModelForCausalLM.from_pretrained(
draft_model_name,
torch_dtype=torch.float16,
device_map="auto"
)
self.tokenizer = AutoTokenizer.from_pretrained(target_model_name)
@torch.no_grad()
def generate(
self,
prompt: str,
max_tokens: int = 100,
temperature: float = 1.0
) -> str:
"""Generate text using speculative decoding."""
input_ids = self.tokenizer.encode(prompt, return_tensors="pt")
input_ids = input_ids.to(self.target_model.device)
generated_tokens = []
accepted_tokens = 0
total_draft_tokens = 0
while len(generated_tokens) < max_tokens:
# Step 1: Draft model generates K tokens
draft_tokens = self._draft_generate(
input_ids,
self.num_speculative_tokens,
temperature
)
total_draft_tokens += len(draft_tokens)
# Step 2: Target model verifies all tokens in parallel
candidate_ids = torch.cat([input_ids, draft_tokens.unsqueeze(0)], dim=1)
target_logits = self.target_model(candidate_ids).logits
# Step 3: Accept/reject tokens
num_accepted = self._verify_tokens(
draft_tokens,
target_logits[:, input_ids.shape[1]-1:-1, :],
temperature
)
accepted_tokens += num_accepted
# Add accepted tokens
generated_tokens.extend(draft_tokens[:num_accepted].tolist())
            input_ids = torch.cat([
                input_ids,
                draft_tokens[:num_accepted].unsqueeze(0)
            ], dim=1)
            # If not all accepted, sample the correction token from the
            # target distribution at the first rejected position
            if num_accepted < len(draft_tokens):
                target_token = self._sample_token(
                    target_logits[:, input_ids.shape[1] - 1, :],
                    temperature
                )
                generated_tokens.append(target_token.item())
                input_ids = torch.cat([
                    input_ids,
                    target_token.unsqueeze(0)
                ], dim=1)
# Check for EOS
if self.tokenizer.eos_token_id in generated_tokens[-self.num_speculative_tokens:]:
break
acceptance_rate = accepted_tokens / total_draft_tokens if total_draft_tokens > 0 else 0
print(f"Acceptance rate: {acceptance_rate:.2%}")
return self.tokenizer.decode(generated_tokens, skip_special_tokens=True)
def _draft_generate(
self,
input_ids: torch.Tensor,
num_tokens: int,
temperature: float
) -> torch.Tensor:
"""Generate tokens with draft model."""
        draft_tokens = []
        current_ids = input_ids
        for _ in range(num_tokens):
            logits = self.draft_model(current_ids).logits[:, -1, :]
            token = self._sample_token(logits, temperature)  # shape (1,)
            draft_tokens.append(token)
            current_ids = torch.cat([current_ids, token.unsqueeze(0)], dim=1)
        return torch.cat(draft_tokens)  # shape (num_tokens,)
def _sample_token(
self,
logits: torch.Tensor,
temperature: float
) -> torch.Tensor:
"""Sample a token from logits."""
if temperature == 0:
return logits.argmax(dim=-1)
probs = torch.softmax(logits / temperature, dim=-1)
return torch.multinomial(probs, num_samples=1).squeeze(-1)
def _verify_tokens(
self,
draft_tokens: torch.Tensor,
target_logits: torch.Tensor,
temperature: float
) -> int:
"""Verify draft tokens against target distribution."""
for i, (draft_token, logits) in enumerate(zip(draft_tokens, target_logits[0])):
target_probs = torch.softmax(logits / max(temperature, 1e-8), dim=-1)
draft_prob = target_probs[draft_token]
            # Exact speculative sampling accepts with probability
            # min(1, p_target / p_draft); simplified here to a fixed threshold
            if draft_prob < 0.1:  # Rejection threshold
return i
return len(draft_tokens)
5.3 Production Deployment Checklist
Inference Optimization Checklist
Memory Optimization:
- Use quantization (AWQ/GPTQ for 4-bit, FP8 for minimal quality loss)
- Enable GQA/MQA for KV cache reduction
- Set appropriate max_model_len for your use case
- Configure swap space for memory overflow
Throughput Optimization:
- Enable continuous batching (vLLM or TensorRT-LLM)
- Use tensor parallelism for large models
- Enable CUDA graphs for reduced kernel launch overhead
- Consider speculative decoding for latency-sensitive applications
Latency Optimization:
- Enable prefix caching for repeated prompts
- Use streaming for real-time applications
- Consider smaller models with speculative decoding
- Profile TTFT vs TPOT for your workload
Cost Optimization:
- Right-size instances (don't over-provision GPU memory)
- Use spot instances for batch workloads
- Implement request queuing for load smoothing
- Monitor and optimize batch sizes
Summary
Chapter 4 Key Takeaways
- Inference Phases: Prefill (compute-bound) and Decode (memory-bound) require different optimizations
- Quantization: AWQ and GPTQ for 4-bit, FP8 for near-lossless 8-bit; 2-8x memory reduction
- Inference Engines: vLLM for general serving, TensorRT-LLM for maximum NVIDIA performance
- Long Context: Sliding window, StreamingLLM, and GQA enable efficient 1M+ token handling
- Throughput: Continuous batching and speculative decoding maximize GPU utilization