Project: Production-Ready RAG System
Build a complete, production-ready RAG system with the features real-world deployment demands: monitoring, caching, error handling, optimization, and scalability.
Production Requirements: A production RAG system needs more than basic retrieval and generation - it has to stay fast, observable, and resilient under real traffic, which is what the caching, monitoring, retry, and scaling pieces below provide.
System Architecture
┌─────────────────────────────────────────────┐
│                 API Gateway                 │
│            (Rate Limiting, Auth)            │
└─────────────────────────────────────────────┘
                       │
                       ▼
┌─────────────────────────────────────────────┐
│              RAG Orchestrator               │
│         (Query Processing, Caching)         │
└─────────────────────────────────────────────┘
       │                │                │
       ▼                ▼                ▼
┌──────────────┐ ┌──────────────┐ ┌──────────────┐
│  Retriever   │ │   Reranker   │ │  Generator   │
│   (Hybrid)   │ │   (Cohere)   │ │   (GPT-4)    │
└──────────────┘ └──────────────┘ └──────────────┘
       │
       ▼
┌─────────────────────────────────────────────┐
│     Vector Database (Pinecone/Weaviate)     │
└─────────────────────────────────────────────┘
                       │
                       ▼
┌─────────────────────────────────────────────┐
│ Monitoring & Logging (Prometheus, Grafana)  │
└─────────────────────────────────────────────┘
Step 1: Configuration Management
# Pydantic v1 style imports; with Pydantic v2, BaseSettings lives in the pydantic-settings package
from pydantic import BaseSettings, validator
from typing import List, Optional
class RAGConfig(BaseSettings):
"""Production RAG configuration using Pydantic."""
# API Keys
openai_api_key: str
cohere_api_key: Optional[str] = None
pinecone_api_key: Optional[str] = None
# Model Configuration
embedding_model: str = "text-embedding-ada-002"
chat_model: str = "gpt-4"
rerank_model: str = "rerank-english-v2.0"
# Retrieval Configuration
initial_k: int = 50
final_k: int = 5
hybrid_alpha: float = 0.5
# Cache Configuration
enable_cache: bool = True
cache_ttl: int = 3600 # 1 hour
redis_url: Optional[str] = None
# Performance
max_tokens: int = 2000
timeout_seconds: int = 30
max_retries: int = 3
# Monitoring
enable_monitoring: bool = True
log_level: str = "INFO"
metrics_port: int = 8000
# Rate Limiting
requests_per_minute: int = 60
burst_size: int = 10
class Config:
env_file = ".env"
env_file_encoding = "utf-8"
@validator("log_level")
def validate_log_level(cls, v):
valid_levels = ["DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"]
if v.upper() not in valid_levels:
raise ValueError(f"Invalid log level. Must be one of {valid_levels}")
return v.upper()
# Load configuration
config = RAGConfig()
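The settings above are read from environment variables or a .env file (field names map case-insensitively to variable names). A minimal .env sketch - all values below are placeholders, not real keys:

# .env (example values only)
OPENAI_API_KEY=sk-...
COHERE_API_KEY=...
PINECONE_API_KEY=...
REDIS_URL=redis://localhost:6379
CHAT_MODEL=gpt-4
LOG_LEVEL=INFO
REQUESTS_PER_MINUTE=60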
Step 2: Caching Layer
import hashlib
import json
import redis
from typing import Any, List, Optional
from datetime import timedelta
class RAGCache:
"""
Caching layer for RAG system.
Caches retrieval results and generated responses.
"""
def __init__(
self,
redis_url: Optional[str] = None,
ttl: int = 3600,
enabled: bool = True
):
self.enabled = enabled
self.ttl = ttl
if enabled and redis_url:
self.redis_client = redis.from_url(redis_url)
else:
self.redis_client = None
def _generate_key(self, prefix: str, data: Any) -> str:
"""Generate cache key from data."""
data_str = json.dumps(data, sort_keys=True)
hash_val = hashlib.md5(data_str.encode()).hexdigest()
return f"{prefix}:{hash_val}"
def get_retrieval(self, query: str) -> Optional[List]:
"""Get cached retrieval results."""
if not self.enabled or not self.redis_client:
return None
key = self._generate_key("retrieval", {"query": query})
try:
cached = self.redis_client.get(key)
if cached:
return json.loads(cached)
except Exception as e:
print(f"Cache get error: {e}")
return None
def set_retrieval(self, query: str, results: List):
"""Cache retrieval results."""
if not self.enabled or not self.redis_client:
return
key = self._generate_key("retrieval", {"query": query})
try:
self.redis_client.setex(
key,
timedelta(seconds=self.ttl),
json.dumps(results)
)
except Exception as e:
print(f"Cache set error: {e}")
def get_response(self, query: str, context: str) -> Optional[str]:
"""Get cached response."""
if not self.enabled or not self.redis_client:
return None
key = self._generate_key("response", {
"query": query,
"context": context
})
try:
cached = self.redis_client.get(key)
if cached:
return cached.decode('utf-8')
except Exception as e:
print(f"Cache get error: {e}")
return None
def set_response(self, query: str, context: str, response: str):
"""Cache response."""
if not self.enabled or not self.redis_client:
return
key = self._generate_key("response", {
"query": query,
"context": context
})
try:
self.redis_client.setex(
key,
timedelta(seconds=self.ttl),
response
)
except Exception as e:
print(f"Cache set error: {e}")
def clear(self):
"""Clear all cache."""
if self.redis_client:
try:
self.redis_client.flushdb()
except Exception as e:
print(f"Cache clear error: {e}")
Step 3: Monitoring & Metrics
from prometheus_client import Counter, Histogram, Gauge, start_http_server
import time
from functools import wraps
import logging
# Define metrics
query_counter = Counter(
'rag_queries_total',
'Total number of queries',
['status']
)
query_duration = Histogram(
'rag_query_duration_seconds',
'Query processing duration',
['stage']
)
retrieval_count = Histogram(
'rag_retrieval_documents',
'Number of documents retrieved',
['method']
)
cache_hits = Counter(
'rag_cache_hits_total',
'Cache hit count',
['cache_type']
)
cache_misses = Counter(
'rag_cache_misses_total',
'Cache miss count',
['cache_type']
)
active_requests = Gauge(
'rag_active_requests',
'Number of active requests'
)
error_counter = Counter(
'rag_errors_total',
'Total errors',
['error_type']
)
class RAGMonitor:
"""Monitoring and metrics for RAG system."""
def __init__(self, config: RAGConfig):
self.config = config
self.logger = self._setup_logging()
if config.enable_monitoring:
start_http_server(config.metrics_port)
self.logger.info(f"Metrics server started on port {config.metrics_port}")
def _setup_logging(self) -> logging.Logger:
"""Setup logging configuration."""
logging.basicConfig(
level=self.config.log_level,
format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
)
return logging.getLogger('RAG')
@staticmethod
def track_query(func):
"""Decorator to track query metrics."""
@wraps(func)
def wrapper(*args, **kwargs):
active_requests.inc()
start_time = time.time()
try:
result = func(*args, **kwargs)
query_counter.labels(status='success').inc()
return result
except Exception as e:
query_counter.labels(status='error').inc()
error_counter.labels(error_type=type(e).__name__).inc()
raise
finally:
duration = time.time() - start_time
query_duration.labels(stage='total').observe(duration)
active_requests.dec()
return wrapper
@staticmethod
def track_stage(stage_name: str):
"""Decorator to track individual stage duration."""
def decorator(func):
@wraps(func)
def wrapper(*args, **kwargs):
start_time = time.time()
try:
return func(*args, **kwargs)
finally:
duration = time.time() - start_time
query_duration.labels(stage=stage_name).observe(duration)
return wrapper
return decorator
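With monitoring enabled, start_http_server exposes all of these metrics in Prometheus' plain-text format on config.metrics_port. A few illustrative lines of what a scrape might return (the values are made up):

rag_queries_total{status="success"} 42.0
rag_cache_hits_total{cache_type="retrieval"} 17.0
rag_query_duration_seconds_count{stage="total"} 42.0
rag_query_duration_seconds_sum{stage="total"} 63.5
rag_active_requests 2.0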
Step 4: Production Retriever
from typing import List, Dict, Any
import openai  # this code targets the pre-1.0 openai SDK (openai.Embedding / openai.ChatCompletion)
import cohere
from tenacity import retry, stop_after_attempt, wait_exponential
class ProductionRetriever:
"""
Production-ready retriever with hybrid search and re-ranking.
"""
def __init__(
self,
config: RAGConfig,
cache: RAGCache,
monitor: RAGMonitor
):
self.config = config
self.cache = cache
self.monitor = monitor
# Initialize clients
openai.api_key = config.openai_api_key
if config.cohere_api_key:
self.cohere_client = cohere.Client(config.cohere_api_key)
else:
self.cohere_client = None
# Initialize vector store (example with Pinecone)
self.vector_store = self._init_vector_store()
def _init_vector_store(self):
"""Initialize vector database."""
# Example with Pinecone (pinecone-client < 3.0 interface)
try:
import pinecone
pinecone.init(
api_key=self.config.pinecone_api_key,
environment="us-west1-gcp"
)
index_name = "rag-index"
if index_name not in pinecone.list_indexes():
pinecone.create_index(
index_name,
dimension=1536, # OpenAI embedding dimension
metric="cosine"
)
return pinecone.Index(index_name)
except Exception as e:
self.monitor.logger.error(f"Vector store init error: {e}")
return None
@RAGMonitor.track_stage("retrieval")
@retry(
stop=stop_after_attempt(3),
wait=wait_exponential(multiplier=1, min=2, max=10)
)
def retrieve(
self,
query: str,
use_cache: bool = True
) -> List[Dict[str, Any]]:
"""
Retrieve relevant documents.
Args:
query: Search query
use_cache: Whether to use cache
Returns:
Retrieved documents
"""
# Check cache
if use_cache:
cached = self.cache.get_retrieval(query)
if cached:
cache_hits.labels(cache_type='retrieval').inc()
self.monitor.logger.info("Cache hit for retrieval")
return cached
cache_misses.labels(cache_type='retrieval').inc()
# Generate query embedding
response = openai.Embedding.create(
model=self.config.embedding_model,
input=query
)
query_embedding = response['data'][0]['embedding']
# Vector search
results = self.vector_store.query(
vector=query_embedding,
top_k=self.config.initial_k,
include_metadata=True
)
# Format results
documents = []
for match in results['matches']:
documents.append({
'id': match['id'],
'content': match['metadata'].get('text', ''),
'score': float(match['score']),
'metadata': match['metadata']
})
retrieval_count.labels(method='vector').observe(len(documents))
# Re-rank if available
if self.cohere_client and len(documents) > self.config.final_k:
documents = self._rerank(query, documents)
# Cache results
if use_cache:
self.cache.set_retrieval(query, documents[:self.config.final_k])
return documents[:self.config.final_k]
@RAGMonitor.track_stage("reranking")
def _rerank(
self,
query: str,
documents: List[Dict[str, Any]]
) -> List[Dict[str, Any]]:
"""Re-rank documents using Cohere."""
try:
texts = [doc['content'] for doc in documents]
rerank_response = self.cohere_client.rerank(
model=self.config.rerank_model,
query=query,
documents=texts,
top_n=self.config.final_k
)
# Reorder documents based on rerank results
reranked = []
for result in rerank_response.results:
doc = documents[result.index].copy()
doc['rerank_score'] = result.relevance_score
reranked.append(doc)
retrieval_count.labels(method='reranked').observe(len(reranked))
return reranked
except Exception as e:
self.monitor.logger.error(f"Reranking error: {e}")
error_counter.labels(error_type='rerank_error').inc()
return documents # Fallback to original ranking
Step 5: Production Generator
class ProductionGenerator:
"""
Production-ready response generator.
"""
def __init__(
self,
config: RAGConfig,
cache: RAGCache,
monitor: RAGMonitor
):
self.config = config
self.cache = cache
self.monitor = monitor
openai.api_key = config.openai_api_key
@RAGMonitor.track_stage("generation")
@retry(
stop=stop_after_attempt(3),
wait=wait_exponential(multiplier=1, min=2, max=10)
)
def generate(
self,
query: str,
context_docs: List[Dict[str, Any]],
use_cache: bool = True
) -> Dict[str, Any]:
"""
Generate response from context.
Args:
query: User query
context_docs: Retrieved context documents
use_cache: Whether to use cache
Returns:
Generated response with metadata
"""
# Build context
context = self._build_context(context_docs)
# Check cache
if use_cache:
cached = self.cache.get_response(query, context)
if cached:
cache_hits.labels(cache_type='response').inc()
return {
'answer': cached,
'cached': True,
'sources': context_docs
}
cache_misses.labels(cache_type='response').inc()
# Generate response
prompt = self._build_prompt(query, context)
try:
response = openai.ChatCompletion.create(
model=self.config.chat_model,
messages=[
{
"role": "system",
"content": "You are a helpful assistant that answers questions based on provided context. Be concise and accurate."
},
{"role": "user", "content": prompt}
],
temperature=0.3,
max_tokens=self.config.max_tokens,
request_timeout=self.config.timeout_seconds  # the pre-1.0 openai SDK expects request_timeout
)
answer = response.choices[0].message.content
# Cache response
if use_cache:
self.cache.set_response(query, context, answer)
return {
'answer': answer,
'cached': False,
'sources': context_docs,
'model': self.config.chat_model,
'tokens_used': response['usage']['total_tokens']
}
except Exception as e:
self.monitor.logger.error(f"Generation error: {e}")
error_counter.labels(error_type='generation_error').inc()
raise
def _build_context(self, docs: List[Dict[str, Any]]) -> str:
"""Build context from documents."""
context_parts = []
for i, doc in enumerate(docs, 1):
source_info = ""
if 'metadata' in doc and 'source' in doc['metadata']:
source_info = f" (Source: {doc['metadata']['source']})"
context_parts.append(
f"[{i}]{source_info}\n{doc['content']}"
)
return "\n\n".join(context_parts)
def _build_prompt(self, query: str, context: str) -> str:
"""Build generation prompt."""
return f"""Answer the following question based on the provided context.
If the answer cannot be found in the context, say so.
Context:
{context}
Question: {query}
Answer:"""
Step 6: Main RAG Orchestrator
from dataclasses import dataclass
from typing import Optional
import time
@dataclass
class RAGResponse:
"""Structured RAG response."""
answer: str
sources: List[Dict[str, Any]]
query: str
cached: bool
processing_time: float
tokens_used: Optional[int] = None
metadata: Optional[Dict] = None
class ProductionRAG:
"""
Complete production RAG system.
"""
def __init__(self, config: Optional[RAGConfig] = None):
self.config = config or RAGConfig()
# Initialize components
self.cache = RAGCache(
redis_url=self.config.redis_url,
ttl=self.config.cache_ttl,
enabled=self.config.enable_cache
)
self.monitor = RAGMonitor(self.config)
self.retriever = ProductionRetriever(
self.config,
self.cache,
self.monitor
)
self.generator = ProductionGenerator(
self.config,
self.cache,
self.monitor
)
self.monitor.logger.info("Production RAG system initialized")
@RAGMonitor.track_query
def query(
self,
question: str,
use_cache: bool = True,
metadata: Optional[Dict] = None
) -> RAGResponse:
"""
Process a query through the RAG pipeline.
Args:
question: User question
use_cache: Whether to use caching
metadata: Optional metadata to include
Returns:
RAG response
"""
start_time = time.time()
self.monitor.logger.info(f"Processing query: {question[:50]}...")
try:
# Step 1: Retrieve relevant documents
self.monitor.logger.debug("Starting retrieval...")
context_docs = self.retriever.retrieve(
question,
use_cache=use_cache
)
self.monitor.logger.debug(
f"Retrieved {len(context_docs)} documents"
)
# Step 2: Generate response
self.monitor.logger.debug("Starting generation...")
generation_result = self.generator.generate(
question,
context_docs,
use_cache=use_cache
)
# Build response
processing_time = time.time() - start_time
response = RAGResponse(
answer=generation_result['answer'],
sources=generation_result['sources'],
query=question,
cached=generation_result['cached'],
processing_time=processing_time,
tokens_used=generation_result.get('tokens_used'),
metadata=metadata
)
self.monitor.logger.info(
f"Query completed in {processing_time:.2f}s "
f"(cached: {response.cached})"
)
return response
except Exception as e:
self.monitor.logger.error(f"Query failed: {e}")
raise
def health_check(self) -> Dict[str, Any]:
"""Health check endpoint."""
return {
'status': 'healthy',
'cache_enabled': self.cache.enabled,
'monitoring_enabled': self.config.enable_monitoring,
'vector_store': 'connected' if self.retriever.vector_store else 'disconnected'
}
def get_metrics(self) -> Dict[str, Any]:
    """Get a small metrics snapshot (debugging helper)."""
    # Reads prometheus_client internals; in production, scrape the HTTP
    # endpoint started by start_http_server() instead.
    return {
        'active_requests': active_requests._value.get(),
        'total_queries': sum(
            child._value.get()
            for child in query_counter._metrics.values()
        ),
        # ... other metrics
    }
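Before putting an API in front of it, the orchestrator can be exercised directly. A short sketch, assuming API keys are configured in the environment and documents have already been indexed into the vector store:

rag = ProductionRAG()

result = rag.query("What is retrieval-augmented generation?")
print(result.answer)
print(f"{len(result.sources)} sources, "
      f"{result.processing_time:.2f}s, cached={result.cached}")

print(rag.health_check())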
Step 7: API Layer (FastAPI)
from fastapi import FastAPI, HTTPException, Request
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel
from typing import Dict, List, Optional
from slowapi import Limiter, _rate_limit_exceeded_handler
from slowapi.util import get_remote_address
from slowapi.errors import RateLimitExceeded
# Request/Response models
class QueryRequest(BaseModel):
question: str
use_cache: bool = True
metadata: Optional[Dict] = None
class QueryResponse(BaseModel):
answer: str
sources: List[Dict]
processing_time: float
cached: bool
tokens_used: Optional[int]
# Initialize FastAPI
app = FastAPI(
title="Production RAG API",
description="Production-ready RAG system",
version="1.0.0"
)
# CORS
app.add_middleware(
CORSMiddleware,
allow_origins=["*"],
allow_credentials=True,
allow_methods=["*"],
allow_headers=["*"],
)
# Rate limiting
limiter = Limiter(key_func=get_remote_address)
app.state.limiter = limiter
app.add_exception_handler(RateLimitExceeded, _rate_limit_exceeded_handler)
# Initialize RAG system
rag_system = ProductionRAG()
@app.post("/query", response_model=QueryResponse)
@limiter.limit("60/minute")
async def query_endpoint(request: Request, query: QueryRequest):
    """
    Query the RAG system. The raw `request` argument is required so slowapi's rate limiter can hook in.
    """
    try:
        response = rag_system.query(
            question=query.question,
            use_cache=query.use_cache,
            metadata=query.metadata
        )
return QueryResponse(
answer=response.answer,
sources=response.sources,
processing_time=response.processing_time,
cached=response.cached,
tokens_used=response.tokens_used
)
except Exception as e:
raise HTTPException(status_code=500, detail=str(e))
@app.get("/health")
async def health():
"""Health check endpoint."""
return rag_system.health_check()
@app.get("/metrics")
async def metrics():
"""Get system metrics."""
return rag_system.get_metrics()
if __name__ == "__main__":
import uvicorn
uvicorn.run(
app,
host="0.0.0.0",
port=8080,
log_level="info"
)
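Once the server is running (for example with python api.py), the endpoint can be called over HTTP. A minimal client sketch using the requests library, assuming the API is reachable on localhost:8080:

import requests

resp = requests.post(
    "http://localhost:8080/query",
    json={"question": "What is RAG?", "use_cache": True},
    timeout=60,
)
resp.raise_for_status()
data = resp.json()
print(data["answer"])
print(f"cached={data['cached']}, took {data['processing_time']:.2f}s")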
Step 8: Document Indexing Pipeline
class DocumentIndexer:
"""
Index documents into the vector store.
"""
def __init__(self, retriever: ProductionRetriever):
self.retriever = retriever
def index_documents(
self,
documents: List[Dict[str, Any]],
batch_size: int = 100
):
"""
Index documents in batches.
Args:
documents: List of {id, text, metadata} dicts
batch_size: Batch size for indexing
"""
total = len(documents)
print(f"📚 Indexing {total} documents...")
for i in range(0, total, batch_size):
batch = documents[i:i + batch_size]
# Generate embeddings
texts = [doc['text'] for doc in batch]
embeddings = self._generate_embeddings(texts)
# Prepare for upsert
vectors = []
for j, doc in enumerate(batch):
vectors.append({
'id': doc['id'],
'values': embeddings[j],
'metadata': {
'text': doc['text'],
**doc.get('metadata', {})
}
})
# Upsert to vector store
self.retriever.vector_store.upsert(vectors=vectors)
print(f" Indexed {i + len(batch)}/{total}")
print("✅ Indexing complete")
def _generate_embeddings(self, texts: List[str]) -> List[List[float]]:
"""Generate embeddings for texts."""
response = openai.Embedding.create(
model=self.retriever.config.embedding_model,
input=texts
)
return [item['embedding'] for item in response['data']]
# Usage
indexer = DocumentIndexer(rag_system.retriever)
documents = [
{
'id': 'doc1',
'text': 'Document content here...',
'metadata': {'source': 'manual', 'date': '2024-01-01'}
},
# ... more documents
]
indexer.index_documents(documents)
Step 9: Deployment (Docker)
# Dockerfile
FROM python:3.10-slim
WORKDIR /app
# Install dependencies
COPY requirements.txt .
RUN pip install --no-cache-dir -r requirements.txt
# Copy application
COPY . .
# Expose ports
EXPOSE 8080 8000
# Run application
CMD ["python", "api.py"]
# docker-compose.yml
version: '3.8'

services:
  rag-api:
    build: .
    ports:
      - "8080:8080"
      - "8000:8000"
    environment:
      - OPENAI_API_KEY=${OPENAI_API_KEY}
      - COHERE_API_KEY=${COHERE_API_KEY}
      - PINECONE_API_KEY=${PINECONE_API_KEY}
      - REDIS_URL=redis://redis:6379
    depends_on:
      - redis

  redis:
    image: redis:7-alpine
    ports:
      - "6379:6379"

  prometheus:
    image: prom/prometheus
    ports:
      - "9090:9090"
    volumes:
      - ./prometheus.yml:/etc/prometheus/prometheus.yml

  grafana:
    image: grafana/grafana
    ports:
      - "3000:3000"
    environment:
      - GF_SECURITY_ADMIN_PASSWORD=admin
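The compose file mounts a prometheus.yml that is not shown above. A minimal scrape config, assuming Prometheus should scrape the metrics server the RAG API starts on port 8000, could be:

# prometheus.yml (minimal example)
global:
  scrape_interval: 15s

scrape_configs:
  - job_name: rag-api
    static_configs:
      - targets: ["rag-api:8000"]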
Production Ready: This system includes caching, monitoring, error handling, and rate limiting, and because state lives in Redis and the vector database rather than in the API process, the service can be scaled horizontally behind a load balancer.
Step 10: Testing
import pytest
from unittest.mock import Mock, patch
class TestProductionRAG:
"""Test suite for production RAG."""
@pytest.fixture
def rag_system(self):
"""Create RAG system for testing."""
config = RAGConfig(
enable_cache=False,
enable_monitoring=False
)
return ProductionRAG(config)
def test_query_success(self, rag_system):
"""Test successful query."""
response = rag_system.query("What is RAG?")
assert response.answer is not None
assert len(response.sources) > 0
assert response.processing_time > 0
def test_cache_hit(self):
    """Test cache functionality (assumes a local Redis at redis://localhost:6379)."""
    # The cache client is created at init time, so toggling config afterwards
    # has no effect - build a system with caching enabled from the start.
    config = RAGConfig(
        enable_cache=True,
        enable_monitoring=False,
        redis_url="redis://localhost:6379"
    )
    cached_system = ProductionRAG(config)
    # First query populates the cache
    response1 = cached_system.query("Test query")
    assert not response1.cached
    # Second identical query should be served from the cache
    response2 = cached_system.query("Test query")
    assert response2.cached
@patch('openai.Embedding.create')
def test_retrieval_error_handling(self, mock_embed, rag_system):
"""Test error handling in retrieval."""
mock_embed.side_effect = Exception("API Error")
with pytest.raises(Exception):
rag_system.query("Test query")
# Run tests
# pytest test_rag.py -v
Monitoring is Critical: Always monitor your production RAG system for latency, error rates, cache hit rates, and token usage to identify issues early.
Key Takeaways
- Configuration management - use environment variables and validation
- Caching - dramatically improves performance and reduces costs
- Monitoring - track metrics to understand system behavior
- Error handling - retry logic and graceful degradation
- Scalability - horizontal scaling with load balancing
- Testing - comprehensive tests for reliability
- Deployment - containerization for easy deployment
Quiz
Test your understanding of production RAG systems: