This chapter covers deploying RAG systems to production: structuring the system as services, optimizing performance with caching and batching, monitoring and evaluating quality, and scaling the system securely.
1. System Architecture
1.1 Microservice Design
Production RAG systems are typically decomposed into multiple services, each of which can be scaled and deployed independently.
Main components of a RAG system:
- Ingestion Service: Document intake and preprocessing
- Embedding Service: Vectorization processing
- Search Service: Vector DB search and reranking
- Generation Service: LLM invocation and answer generation
- API Gateway: Request management and authentication
Implementation Example 1: FastAPI-based RAG System
from fastapi import FastAPI, HTTPException, Depends
from pydantic import BaseModel
from typing import List, Optional
import asyncio
from functools import lru_cache
app = FastAPI(title="RAG API", version="1.0.0")
# Request/Response models
class QueryRequest(BaseModel):
query: str
top_k: int = 5
filters: Optional[dict] = None
class SearchResult(BaseModel):
content: str
score: float
metadata: dict
class QueryResponse(BaseModel):
answer: str
sources: List[SearchResult]
processing_time: float
# Dependency injection
@lru_cache()
def get_rag_service():
"""Get RAG service singleton"""
from services.rag_service import RAGService
return RAGService()
@app.post("/query", response_model=QueryResponse)
async def query_endpoint(
request: QueryRequest,
rag_service = Depends(get_rag_service)
):
"""RAG query endpoint"""
try:
import time
start = time.time()
        # Run retrieval and answer generation via the RAG service
answer, sources = await rag_service.query(
query=request.query,
top_k=request.top_k,
filters=request.filters
)
processing_time = time.time() - start
return QueryResponse(
answer=answer,
sources=sources,
processing_time=processing_time
)
except Exception as e:
raise HTTPException(status_code=500, detail=str(e))
class IndexRequest(BaseModel):
documents: List[dict]
collection_name: str
@app.post("/index")
async def index_endpoint(
request: IndexRequest,
rag_service = Depends(get_rag_service)
):
"""Document indexing"""
try:
result = await rag_service.index_documents(
documents=request.documents,
collection_name=request.collection_name
)
return {"status": "success", "indexed": result}
except Exception as e:
raise HTTPException(status_code=500, detail=str(e))
@app.get("/health")
async def health_check():
"""Health check"""
return {"status": "healthy"}
# RAG service implementation (services/rag_service.py)
class RAGService:
"""RAG business logic"""
def __init__(self):
self.vectorstore = self._init_vectorstore()
self.llm = self._init_llm()
self.embeddings = self._init_embeddings()
def _init_vectorstore(self):
# Initialize vector store
pass
def _init_llm(self):
# Initialize LLM
pass
def _init_embeddings(self):
# Initialize embedding model
pass
async def query(self, query: str, top_k: int = 5, filters: dict = None):
"""Asynchronous query processing"""
# Search
search_results = await self._search(query, top_k, filters)
# Generate
answer = await self._generate(query, search_results)
return answer, search_results
async def _search(self, query: str, top_k: int, filters: dict):
"""Asynchronous search"""
# Implementation
pass
async def _generate(self, query: str, context: list):
"""Asynchronous generation"""
# Implementation
pass
async def index_documents(self, documents: list, collection_name: str):
"""Asynchronous indexing"""
# Implementation
pass
# Execution
if __name__ == "__main__":
import uvicorn
uvicorn.run(app, host="0.0.0.0", port=8000)
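The FastAPI application above keeps retrieval and generation behind a single RAGService for brevity. In the microservice layout described in 1.1, those steps would usually live in separate deployments that the gateway calls over the network. The following is a minimal sketch of that composition, assuming an httpx client and hypothetical internal URLs and endpoint paths (/search, /generate) for the Search and Generation services:
import httpx
# Hypothetical internal service addresses (for example, Kubernetes service DNS names)
SEARCH_SERVICE_URL = "http://search-service:8001"
GENERATION_SERVICE_URL = "http://generation-service:8002"
async def query_via_services(query: str, top_k: int = 5) -> dict:
    """Call the Search service, then pass its results to the Generation service."""
    async with httpx.AsyncClient(timeout=30.0) as client:
        # 1) Retrieve candidate chunks from the Search service
        search_resp = await client.post(
            f"{SEARCH_SERVICE_URL}/search",
            json={"query": query, "top_k": top_k}
        )
        search_resp.raise_for_status()
        sources = search_resp.json()["results"]
        # 2) Ask the Generation service to answer using those chunks as context
        gen_resp = await client.post(
            f"{GENERATION_SERVICE_URL}/generate",
            json={"query": query, "context": sources}
        )
        gen_resp.raise_for_status()
        return {"answer": gen_resp.json()["answer"], "sources": sources}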
2. Performance Optimization
2.1 Caching Strategy
Cache results for frequent queries to reduce response time and API cost.
Implementation Example 2: Multi-level Caching
import redis
from collections import OrderedDict
import hashlib
import json
import time
class MultiLevelCache:
"""Multi-level caching system"""
    def __init__(self, redis_url: str = "redis://localhost:6379"):
        # L1: in-process memory cache (LRU via OrderedDict)
        self.memory_cache_size = 100
        self.l1_cache = OrderedDict()
        # L2: Redis
        self.redis_client = redis.from_url(redis_url)
        # Cache statistics
        self.stats = {
            'hits': 0,
            'misses': 0,
            'l1_hits': 0,
            'l2_hits': 0
        }
def _generate_key(self, query: str, params: dict) -> str:
"""Generate cache key"""
cache_input = f"{query}:{json.dumps(params, sort_keys=True)}"
return hashlib.md5(cache_input.encode()).hexdigest()
    def _l1_get(self, key: str):
        """L1 cache get (in-process memory)"""
        if key in self.l1_cache:
            # Mark as most recently used
            self.l1_cache.move_to_end(key)
            return self.l1_cache[key]
        return None
def get(self, query: str, params: dict):
"""Cache get (L1 → L2)"""
key = self._generate_key(query, params)
# L1 check
        result = self._l1_get(key)
        if result is not None:
            self.stats['hits'] += 1
            self.stats['l1_hits'] += 1
            return result
# L2 check (Redis)
try:
cached = self.redis_client.get(key)
if cached:
result = json.loads(cached)
# Promote to L1
self._l1_set(key, result)
self.stats['hits'] += 1
self.stats['l2_hits'] += 1
return result
except Exception as e:
print(f"Redis error: {e}")
self.stats['misses'] += 1
return None
    def _l1_set(self, key: str, value):
        """L1 cache set (evict the least recently used entry when full)"""
        self.l1_cache[key] = value
        self.l1_cache.move_to_end(key)
        if len(self.l1_cache) > self.memory_cache_size:
            self.l1_cache.popitem(last=False)
def set(self, query: str, params: dict, value, ttl: int = 3600):
"""Cache set (L1 & L2)"""
key = self._generate_key(query, params)
# L1 set
self._l1_set(key, value)
# L2 set (Redis)
try:
self.redis_client.setex(
key,
ttl,
json.dumps(value)
)
except Exception as e:
print(f"Redis set error: {e}")
def invalidate(self, pattern: str = "*"):
"""Cache invalidation"""
# L1 clear
        self.l1_cache.clear()
# L2 clear (pattern matching)
try:
keys = self.redis_client.keys(pattern)
if keys:
self.redis_client.delete(*keys)
except Exception as e:
print(f"Redis invalidate error: {e}")
def get_stats(self):
"""Cache statistics"""
total = self.stats['hits'] + self.stats['misses']
hit_rate = self.stats['hits'] / total if total > 0 else 0
return {
**self.stats,
'hit_rate': hit_rate,
'total_requests': total
}
# Integration with RAG system
class CachedRAGService:
"""RAG service with caching"""
def __init__(self):
self.cache = MultiLevelCache()
self.rag_service = RAGService()
async def query(self, query: str, top_k: int = 5):
"""Cache-aware query"""
params = {'top_k': top_k}
# Check cache
cached_result = self.cache.get(query, params)
if cached_result:
print("Cache hit!")
return cached_result
# Cache miss: execute
print("Cache miss, executing query...")
result = await self.rag_service.query(query, top_k)
# Save to cache (1 hour TTL)
self.cache.set(query, params, result, ttl=3600)
return result
def get_cache_stats(self):
"""Get cache statistics"""
return self.cache.get_stats()
# Usage example
cached_rag = CachedRAGService()
# First query (cache miss). Note: the awaits below must run inside an async function.
result1 = await cached_rag.query("What is machine learning")
# Same query (cache hit)
result2 = await cached_rag.query("What is machine learning")
# Display statistics
stats = cached_rag.get_cache_stats()
print(f"Hit rate: {stats['hit_rate']:.2%}")
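Cached answers can become stale when the underlying index changes. A simple policy, sketched below under the assumption that a full cache flush is acceptable, is to invalidate the cache whenever new documents are indexed; a per-collection key prefix would allow finer-grained invalidation.
async def index_and_invalidate(cached_rag: CachedRAGService, documents: list, collection_name: str):
    """Index new documents, then drop cached answers so stale responses are not served."""
    result = await cached_rag.rag_service.index_documents(documents, collection_name)
    # Clears the L1 memory cache and all matching Redis keys ("*" = everything)
    cached_rag.cache.invalidate("*")
    return result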
2.2 Batch Processing Optimization
Implement batch processing to index large numbers of documents efficiently.
Implementation Example 3: Batch Indexing
import asyncio
from typing import List
from concurrent.futures import ThreadPoolExecutor
import numpy as np
class BatchIndexer:
"""Batch indexing system"""
def __init__(self, embeddings, vectorstore, batch_size=50, max_workers=4):
self.embeddings = embeddings
self.vectorstore = vectorstore
self.batch_size = batch_size
self.max_workers = max_workers
def create_batches(self, documents: List, batch_size: int):
"""Split documents into batches"""
for i in range(0, len(documents), batch_size):
yield documents[i:i + batch_size]
async def process_batch_async(self, batch: List):
"""Asynchronous batch processing"""
        # Generate embeddings for the batch in a worker thread.
        # Note: if the vector store embeds documents itself inside add_documents,
        # this explicit embedding step is redundant and can be dropped.
        texts = [doc.page_content for doc in batch]
        embeddings = await asyncio.to_thread(
            self.embeddings.embed_documents,
            texts
        )
# Add to vector store
await asyncio.to_thread(
self.vectorstore.add_documents,
batch
)
return len(batch)
async def index_documents_parallel(self, documents: List):
"""Parallel batch indexing"""
batches = list(self.create_batches(documents, self.batch_size))
print(f"Processing started: {len(documents)} documents, {len(batches)} batches")
        # Run the batches concurrently, limiting concurrency with a semaphore
        semaphore = asyncio.Semaphore(self.max_workers)
        async def limited(coro):
            async with semaphore:
                return await coro
        results = await asyncio.gather(
            *[limited(self.process_batch_async(batch)) for batch in batches]
        )
total_indexed = sum(results)
print(f"Indexing completed: {total_indexed} documents")
return total_indexed
def index_with_progress(self, documents: List):
"""Indexing with progress bar"""
from tqdm import tqdm
batches = list(self.create_batches(documents, self.batch_size))
with ThreadPoolExecutor(max_workers=self.max_workers) as executor:
futures = []
for batch in batches:
future = executor.submit(self._process_batch_sync, batch)
futures.append(future)
# Progress display
with tqdm(total=len(documents), desc="Indexing") as pbar:
for future in futures:
count = future.result()
pbar.update(count)
print("Indexing completed")
def _process_batch_sync(self, batch: List):
"""Synchronous batch processing"""
texts = [doc.page_content for doc in batch]
self.embeddings.embed_documents(texts)
self.vectorstore.add_documents(batch)
return len(batch)
# Usage example
batch_indexer = BatchIndexer(
embeddings=embeddings,
vectorstore=vectorstore,
batch_size=50,
max_workers=4
)
# Large number of documents
large_documents = [...] # 10000 documents
# Asynchronous parallel processing (must be awaited inside an async function)
await batch_indexer.index_documents_parallel(large_documents)
# Processing with progress bar
batch_indexer.index_with_progress(large_documents)
3. Monitoring and Evaluation
3.1 Metrics Design
Define metrics to measure RAG system quality and performance.
Key Metrics:
- Latency: Search time, generation time, total processing time
- Accuracy: Search precision, answer correctness
- Relevance: Relevance between search results and query
- Cost: API call count, token usage
Implementation Example 4: Metrics Collection System
from prometheus_client import Counter, Histogram, Gauge
import time
from functools import wraps
# Prometheus metrics definition
query_counter = Counter(
'rag_queries_total',
'Total number of RAG queries',
['status']
)
query_latency = Histogram(
'rag_query_duration_seconds',
'RAG query duration',
['component']
)
search_results_count = Gauge(
'rag_search_results',
'Number of search results returned'
)
llm_tokens = Counter(
'rag_llm_tokens_total',
'Total LLM tokens used',
['type'] # prompt or completion
)
class RAGMetrics:
"""RAG metrics collection"""
def __init__(self):
self.metrics_data = []
def track_query(self, func):
"""Query processing metrics"""
@wraps(func)
async def wrapper(*args, **kwargs):
start = time.time()
try:
result = await func(*args, **kwargs)
query_counter.labels(status='success').inc()
# Record latency
duration = time.time() - start
query_latency.labels(component='total').observe(duration)
# Record result count
if hasattr(result, 'sources'):
search_results_count.set(len(result.sources))
return result
except Exception as e:
query_counter.labels(status='error').inc()
raise
return wrapper
def track_search(self, func):
"""Search metrics"""
@wraps(func)
async def wrapper(*args, **kwargs):
start = time.time()
result = await func(*args, **kwargs)
duration = time.time() - start
query_latency.labels(component='search').observe(duration)
return result
return wrapper
def track_generation(self, func):
"""Generation metrics"""
@wraps(func)
async def wrapper(*args, **kwargs):
start = time.time()
result = await func(*args, **kwargs)
duration = time.time() - start
query_latency.labels(component='generation').observe(duration)
# Record token count
if hasattr(result, 'usage'):
llm_tokens.labels(type='prompt').inc(result.usage.prompt_tokens)
llm_tokens.labels(type='completion').inc(result.usage.completion_tokens)
return result
return wrapper
class RAGEvaluator:
"""RAG quality evaluation"""
def __init__(self, llm):
self.llm = llm
def evaluate_retrieval(self, query: str, retrieved_docs: list, relevant_docs: list):
"""Search accuracy evaluation"""
# Precision@K
retrieved_ids = {doc.metadata.get('id') for doc in retrieved_docs}
relevant_ids = {doc.metadata.get('id') for doc in relevant_docs}
hits = retrieved_ids.intersection(relevant_ids)
precision = len(hits) / len(retrieved_ids) if retrieved_ids else 0
recall = len(hits) / len(relevant_ids) if relevant_ids else 0
f1 = 2 * (precision * recall) / (precision + recall) if (precision + recall) > 0 else 0
return {
'precision': precision,
'recall': recall,
'f1_score': f1
}
async def evaluate_answer_quality(self, query: str, answer: str, ground_truth: str):
"""Answer quality evaluation (LLM-based)"""
prompt = f"""Please evaluate the following question, answer, and ground truth.
Rate with a score from 1-5 (5 being the best) and provide reasoning.
Question: {query}
Answer: {answer}
Ground Truth: {ground_truth}
Evaluation (JSON format):
{{
"accuracy_score": <1-5>,
"relevance_score": <1-5>,
"completeness_score": <1-5>,
"reasoning": ""
}}
"""
response = await self.llm(prompt)
# JSON parsing
import json
evaluation = json.loads(response.content)
return evaluation
    def calculate_mrr(self, queries: list, results: list):
        """MRR: mean of 1/rank of the first relevant document per query
        (e.g., first relevant hits at ranks 1, 3, 2 give MRR = (1 + 1/3 + 1/2) / 3 ≈ 0.61)."""
reciprocal_ranks = []
for query_results in results:
for rank, doc in enumerate(query_results, 1):
if doc.metadata.get('is_relevant'):
reciprocal_ranks.append(1 / rank)
break
else:
reciprocal_ranks.append(0)
mrr = sum(reciprocal_ranks) / len(reciprocal_ranks) if reciprocal_ranks else 0
return mrr
# Usage example
metrics = RAGMetrics()
evaluator = RAGEvaluator(llm)
# Apply metrics decorators
class MonitoredRAGService:
def __init__(self):
self.metrics = RAGMetrics()
@metrics.track_query
async def query(self, query: str):
# Query processing
search_results = await self.search(query)
answer = await self.generate(query, search_results)
return answer
@metrics.track_search
async def search(self, query: str):
# Search processing
pass
@metrics.track_generation
async def generate(self, query: str, context: list):
# Generation processing
pass
# Execute evaluation
evaluation = await evaluator.evaluate_answer_quality(
query="What is machine learning",
answer="Generated answer",
ground_truth="Ground truth answer"
)
print(f"Accuracy score: {evaluation['accuracy_score']}/5")
3.2 A/B Testing Implementation
Implement A/B testing to compare the effectiveness of different RAG configurations.
Implementation Example 5: A/B Testing Framework
import hashlib
import time
from typing import Dict, Any
from dataclasses import dataclass
from collections import defaultdict
import numpy as np
@dataclass
class ExperimentVariant:
"""Experiment variant"""
name: str
config: Dict[str, Any]
traffic_ratio: float # 0.0-1.0
class ABTestManager:
"""A/B test management"""
def __init__(self):
self.experiments = {}
self.results = defaultdict(lambda: defaultdict(list))
def create_experiment(self, experiment_name: str, variants: list):
"""Create experiment"""
# Check traffic ratio sum
total_ratio = sum(v.traffic_ratio for v in variants)
if abs(total_ratio - 1.0) > 0.001:
raise ValueError("Traffic ratio sum must be 1.0")
self.experiments[experiment_name] = variants
print(f"Experiment created: {experiment_name}")
def assign_variant(self, experiment_name: str, user_id: str):
"""Assign variant to user"""
if experiment_name not in self.experiments:
raise ValueError(f"Experiment does not exist: {experiment_name}")
        # Deterministic assignment (same user always gets the same variant).
        # Use a stable hash (hashlib) so assignments survive process restarts,
        # unlike the built-in hash(), which is salted per process.
        digest = hashlib.md5(f"{experiment_name}:{user_id}".encode()).hexdigest()
        hash_val = int(digest, 16) % 1000 / 1000
cumulative_ratio = 0
for variant in self.experiments[experiment_name]:
cumulative_ratio += variant.traffic_ratio
if hash_val < cumulative_ratio:
return variant
# Fallback
return self.experiments[experiment_name][0]
def record_result(self, experiment_name: str, variant_name: str,
metric_name: str, value: float):
"""Record result"""
self.results[experiment_name][variant_name].append({
'metric': metric_name,
'value': value
})
def analyze_results(self, experiment_name: str, metric_name: str):
"""Analyze results"""
if experiment_name not in self.results:
return None
analysis = {}
for variant_name, results in self.results[experiment_name].items():
metric_values = [
r['value'] for r in results
if r['metric'] == metric_name
]
if metric_values:
analysis[variant_name] = {
'mean': np.mean(metric_values),
'std': np.std(metric_values),
'count': len(metric_values),
'min': min(metric_values),
'max': max(metric_values)
}
return analysis
# Usage in RAG system
class ABTestedRAGService:
"""A/B tested RAG service"""
def __init__(self):
self.ab_test = ABTestManager()
self._setup_experiments()
def _setup_experiments(self):
"""Setup experiments"""
# Chunking strategy test
chunking_variants = [
ExperimentVariant(
name="fixed_500",
config={"chunk_size": 500, "overlap": 50},
traffic_ratio=0.5
),
ExperimentVariant(
name="fixed_1000",
config={"chunk_size": 1000, "overlap": 100},
traffic_ratio=0.5
)
]
self.ab_test.create_experiment("chunking_strategy", chunking_variants)
# Reranking test
rerank_variants = [
ExperimentVariant(
name="no_rerank",
config={"use_reranking": False},
traffic_ratio=0.33
),
ExperimentVariant(
name="cross_encoder",
config={"use_reranking": True, "method": "cross_encoder"},
traffic_ratio=0.33
),
ExperimentVariant(
name="mmr",
config={"use_reranking": True, "method": "mmr"},
traffic_ratio=0.34
)
]
self.ab_test.create_experiment("reranking_method", rerank_variants)
async def query(self, query: str, user_id: str):
"""Query with variant applied"""
# Assign variants
chunking_variant = self.ab_test.assign_variant("chunking_strategy", user_id)
rerank_variant = self.ab_test.assign_variant("reranking_method", user_id)
print(f"User {user_id}:")
print(f" Chunking: {chunking_variant.name}")
print(f" Reranking: {rerank_variant.name}")
# Apply configuration
start = time.time()
# Execute query (based on configuration)
result = await self._execute_query(
query,
chunking_variant.config,
rerank_variant.config
)
latency = time.time() - start
# Record results
self.ab_test.record_result(
"chunking_strategy",
chunking_variant.name,
"latency",
latency
)
self.ab_test.record_result(
"reranking_method",
rerank_variant.name,
"latency",
latency
)
return result
async def _execute_query(self, query: str, chunking_config: dict,
rerank_config: dict):
"""Execute query with configuration"""
# Implementation
pass
def get_experiment_results(self):
"""Get experiment results"""
chunking_results = self.ab_test.analyze_results(
"chunking_strategy", "latency"
)
rerank_results = self.ab_test.analyze_results(
"reranking_method", "latency"
)
return {
'chunking': chunking_results,
'reranking': rerank_results
}
# Usage example
ab_rag = ABTestedRAGService()
# Execute test queries
for user_id in range(100):
await ab_rag.query("Machine learning evaluation metrics", f"user_{user_id}")
# Analyze results
results = ab_rag.get_experiment_results()
print("\nChunking strategy comparison:")
for variant, stats in results['chunking'].items():
print(f"{variant}: average {stats['mean']:.3f}s (n={stats['count']})")
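Comparing means alone does not show whether an observed difference is real or just noise. The sketch below adds a significance check with Welch's t-test on the raw recorded latencies; it assumes SciPy is available, which is not otherwise used in this chapter.
from scipy import stats
def compare_variants(ab_test: ABTestManager, experiment: str, metric: str,
                     variant_a: str, variant_b: str, alpha: float = 0.05):
    """Welch's t-test between two variants on a single metric."""
    def values(variant):
        return [r['value'] for r in ab_test.results[experiment][variant]
                if r['metric'] == metric]
    a, b = values(variant_a), values(variant_b)
    t_stat, p_value = stats.ttest_ind(a, b, equal_var=False)  # Welch's t-test
    return {
        'p_value': float(p_value),
        'significant': p_value < alpha,
        'mean_diff': float(np.mean(a) - np.mean(b))
    }
# Example: did chunk size change latency in a statistically meaningful way?
comparison = compare_variants(ab_rag.ab_test, "chunking_strategy", "latency",
                              "fixed_500", "fixed_1000")
print(f"p={comparison['p_value']:.3f}, significant={comparison['significant']}")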
4. Scalability and Security
4.1 Distributed Processing and Load Balancing
Build a distributed architecture to handle large-scale traffic.
Implementation Example 6: Celery-based Asynchronous Processing
from celery import Celery
import os
# Celery configuration
celery_app = Celery(
'rag_tasks',
broker=os.getenv('REDIS_URL', 'redis://localhost:6379/0'),
backend=os.getenv('REDIS_URL', 'redis://localhost:6379/0')
)
celery_app.conf.update(
task_serializer='json',
accept_content=['json'],
result_serializer='json',
timezone='Asia/Tokyo',
enable_utc=True,
task_routes={
'rag_tasks.index_document': {'queue': 'indexing'},
'rag_tasks.generate_embedding': {'queue': 'embedding'},
'rag_tasks.query_rag': {'queue': 'query'}
}
)
# Task definitions
@celery_app.task(name='rag_tasks.index_document')
def index_document_task(document_data: dict):
"""Document indexing task"""
from services.indexer import DocumentIndexer
indexer = DocumentIndexer()
result = indexer.index(document_data)
return {
'status': 'completed',
'document_id': document_data.get('id'),
'indexed_chunks': result['chunks']
}
@celery_app.task(name='rag_tasks.generate_embedding')
def generate_embedding_task(text: str):
"""Embedding generation task"""
from services.embeddings import EmbeddingService
embedding_service = EmbeddingService()
embedding = embedding_service.generate(text)
return embedding.tolist()
@celery_app.task(name='rag_tasks.query_rag', bind=True)
def query_rag_task(self, query: str, user_id: str):
"""RAG query task (with retry)"""
try:
from services.rag_service import RAGService
rag = RAGService()
result = rag.query_sync(query)
return {
'status': 'success',
'answer': result['answer'],
'sources': result['sources']
}
except Exception as e:
# Retry (up to 3 times, exponential backoff)
raise self.retry(exc=e, countdown=2 ** self.request.retries, max_retries=3)
# FastAPI integration
@app.post("/query_async")
async def query_async(request: QueryRequest):
    """Asynchronous query: hand the work off to a Celery worker"""
# Execute Celery task
task = query_rag_task.delay(request.query, "user_123")
return {
'task_id': task.id,
'status': 'processing'
}
@app.get("/task_status/{task_id}")
async def get_task_status(task_id: str):
"""Check task status"""
task = celery_app.AsyncResult(task_id)
if task.ready():
return {
'status': 'completed',
'result': task.result
}
else:
return {
'status': 'processing'
}
# Batch indexing
@app.post("/batch_index")
async def batch_index(documents: List[dict]):
"""Batch indexing"""
# Execute parallel tasks
tasks = [
index_document_task.delay(doc)
for doc in documents
]
return {
'status': 'processing',
'task_count': len(tasks),
'task_ids': [task.id for task in tasks]
}
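If the caller needs a single handle for the whole batch rather than a list of individual task IDs, Celery's group primitive can submit the batch as one unit. A minimal sketch (the endpoint path is illustrative):
from celery import group
@app.post("/batch_index_group")
async def batch_index_group(documents: List[dict]):
    """Submit the whole batch as a single Celery group"""
    # .s(...) builds a task signature; group(...) runs them in parallel on the workers
    job = group(index_document_task.s(doc) for doc in documents).apply_async()
    return {
        'group_id': job.id,
        'task_ids': [child.id for child in job.results]
    }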
# Celery worker startup command
# celery -A tasks.celery_app worker --loglevel=info --queues=indexing,embedding,query
Production Operation Best Practices:
- Monitoring: Visualization with Prometheus + Grafana
- Log Management: Structured logs (JSON) and Elasticsearch integration
- Security: API key authentication, rate limiting, input validation (see the sketch after this list)
- CI/CD: Automated testing, staged deployment
- Disaster Recovery: Backups, replication
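To make the security bullet above concrete, here is a minimal sketch of API-key authentication and a simple in-memory rate limiter implemented as FastAPI dependencies. The header name, key store, and limits are illustrative assumptions; a production deployment would load keys from a secret manager and keep rate-limit counters in Redis so they are shared across replicas.
import time
from fastapi import Depends, HTTPException, Security
from fastapi.security import APIKeyHeader
api_key_header = APIKeyHeader(name="X-API-Key", auto_error=False)
VALID_API_KEYS = {"example-key-1"}  # assumption: loaded from a secret store in practice
RATE_LIMIT = 60  # requests per minute per API key (illustrative)
_request_log = {}  # api_key -> list of recent request timestamps
def verify_api_key(api_key: str = Security(api_key_header)) -> str:
    """Reject requests without a known API key"""
    if api_key not in VALID_API_KEYS:
        raise HTTPException(status_code=401, detail="Invalid or missing API key")
    return api_key
def rate_limit(api_key: str = Depends(verify_api_key)) -> str:
    """Very simple sliding-window limiter (per process only)"""
    now = time.time()
    window = [t for t in _request_log.get(api_key, []) if now - t < 60]
    if len(window) >= RATE_LIMIT:
        raise HTTPException(status_code=429, detail="Rate limit exceeded")
    window.append(now)
    _request_log[api_key] = window
    return api_key
# Protected variant of the query endpoint from Implementation Example 1
@app.post("/secure_query", response_model=QueryResponse)
async def secure_query(
    request: QueryRequest,
    api_key: str = Depends(rate_limit),
    rag_service = Depends(get_rag_service)
):
    start = time.time()
    answer, sources = await rag_service.query(request.query, request.top_k, request.filters)
    return QueryResponse(answer=answer, sources=sources, processing_time=time.time() - start)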
Summary
- Separate each component with microservice architecture
- Optimize performance with multi-level caching and batch processing
- Continuous improvement through comprehensive metrics collection and A/B testing
- Handle large-scale traffic with Celery-based asynchronous processing
- Production operation design considering security and scalability
Disclaimer
- This content is provided solely for educational, research, and informational purposes and does not constitute professional advice (legal, accounting, technical warranty, etc.).
- This content and accompanying code examples are provided "AS IS" without any warranty, express or implied, including but not limited to merchantability, fitness for a particular purpose, non-infringement, accuracy, completeness, operation, or safety.
- The author and Tohoku University assume no responsibility for the content, availability, or safety of external links, third-party data, tools, libraries, etc.
- To the maximum extent permitted by applicable law, the author and Tohoku University shall not be liable for any direct, indirect, incidental, special, consequential, or punitive damages arising from the use, execution, or interpretation of this content.
- The content may be changed, updated, or discontinued without notice.
- The copyright and license of this content are subject to the stated conditions (e.g., CC BY 4.0). Such licenses typically include no-warranty clauses.