
Project: Production-Ready RAG System

Build a complete production-ready RAG system with monitoring, caching, optimization, and deployment strategies

60 min read · Project · RAG · Production · Monitoring


Build a complete, production-ready RAG system with all the features needed for real-world deployment: monitoring, caching, error handling, optimization, and scalability.

Production Requirements: A production RAG system needs more than basic retrieval and generation - it requires monitoring, caching, error handling, optimization, and the ability to scale.

System Architecture

┌─────────────────────────────────────────────────────────────┐
│                      API Gateway                             │
│                   (Rate Limiting, Auth)                      │
└─────────────────────────────────────────────────────────────┘
                            │
                            ▼
┌─────────────────────────────────────────────────────────────┐
│                    RAG Orchestrator                          │
│              (Query Processing, Caching)                     │
└─────────────────────────────────────────────────────────────┘
         │                  │                  │
         ▼                  ▼                  ▼
┌──────────────┐  ┌──────────────┐  ┌──────────────┐
│   Retriever  │  │   Reranker   │  │  Generator   │
│  (Hybrid)    │  │  (Cohere)    │  │   (GPT-4)    │
└──────────────┘  └──────────────┘  └──────────────┘
         │
         ▼
┌──────────────────────────────────────────────────────────┐
│              Vector Database (Pinecone/Weaviate)          │
└──────────────────────────────────────────────────────────┘
         │
         ▼
┌──────────────────────────────────────────────────────────┐
│        Monitoring & Logging (Prometheus, Grafana)         │
└──────────────────────────────────────────────────────────┘
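
The components above also imply a dependency set. A possible `requirements.txt` is sketched below; the packages and version constraints are assumptions matching the SDK generations this walkthrough uses (pre-1.0 `openai`, Pydantic v1, the v2 `pinecone-client`), not an official manifest:

text
# requirements.txt (assumed; pinned to the SDK styles used in this project)
openai<1.0          # pre-1.0 client (openai.Embedding / openai.ChatCompletion)
cohere
pinecone-client<3   # v2 API (pinecone.init / pinecone.Index)
pydantic<2          # v1-style BaseSettings and @validator
redis
tenacity
prometheus-client
fastapi
uvicorn
slowapi
pytest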

Step 1: Configuration Management

python
# Pydantic v1 style; with Pydantic v2, BaseSettings lives in the separate
# `pydantic-settings` package and @validator becomes @field_validator
from pydantic import BaseSettings, validator
from typing import Optional, List
import os


class RAGConfig(BaseSettings):
    """Production RAG configuration using Pydantic."""

    # API Keys
    openai_api_key: str
    cohere_api_key: Optional[str] = None
    pinecone_api_key: Optional[str] = None

    # Model Configuration
    embedding_model: str = "text-embedding-ada-002"
    chat_model: str = "gpt-4"
    rerank_model: str = "rerank-english-v2.0"

    # Retrieval Configuration
    initial_k: int = 50
    final_k: int = 5
    hybrid_alpha: float = 0.5

    # Cache Configuration
    enable_cache: bool = True
    cache_ttl: int = 3600  # 1 hour
    redis_url: Optional[str] = None

    # Performance
    max_tokens: int = 2000
    timeout_seconds: int = 30
    max_retries: int = 3

    # Monitoring
    enable_monitoring: bool = True
    log_level: str = "INFO"
    metrics_port: int = 8000

    # Rate Limiting
    requests_per_minute: int = 60
    burst_size: int = 10

    class Config:
        env_file = ".env"
        env_file_encoding = "utf-8"

    @validator("log_level")
    def validate_log_level(cls, v):
        valid_levels = ["DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"]
        if v.upper() not in valid_levels:
            raise ValueError(f"Invalid log level. Must be one of {valid_levels}")
        return v.upper()


# Load configuration
config = RAGConfig()
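
Because `RAGConfig` reads from `.env`, every field can also be supplied as an environment variable (Pydantic matches names case-insensitively). A minimal example with placeholder values:

text
# .env (placeholder values - do not commit real keys)
OPENAI_API_KEY=sk-...
COHERE_API_KEY=...
PINECONE_API_KEY=...
REDIS_URL=redis://localhost:6379
CHAT_MODEL=gpt-4
ENABLE_CACHE=true
LOG_LEVEL=INFO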

Step 2: Caching Layer

python
import hashlib
import json
import redis
from typing import Optional, Any, List
from datetime import timedelta


class RAGCache:
    """
    Caching layer for RAG system.

    Caches retrieval results and generated responses.
    """

    def __init__(
        self,
        redis_url: Optional[str] = None,
        ttl: int = 3600,
        enabled: bool = True
    ):
        self.enabled = enabled
        self.ttl = ttl

        if enabled and redis_url:
            self.redis_client = redis.from_url(redis_url)
        else:
            self.redis_client = None

    def _generate_key(self, prefix: str, data: Any) -> str:
        """Generate cache key from data."""
        data_str = json.dumps(data, sort_keys=True)
        hash_val = hashlib.md5(data_str.encode()).hexdigest()
        return f"{prefix}:{hash_val}"

    def get_retrieval(self, query: str) -> Optional[List]:
        """Get cached retrieval results."""
        if not self.enabled or not self.redis_client:
            return None

        key = self._generate_key("retrieval", {"query": query})

        try:
            cached = self.redis_client.get(key)
            if cached:
                return json.loads(cached)
        except Exception as e:
            print(f"Cache get error: {e}")

        return None

    def set_retrieval(self, query: str, results: List):
        """Cache retrieval results."""
        if not self.enabled or not self.redis_client:
            return

        key = self._generate_key("retrieval", {"query": query})

        try:
            self.redis_client.setex(
                key,
                timedelta(seconds=self.ttl),
                json.dumps(results)
            )
        except Exception as e:
            print(f"Cache set error: {e}")

    def get_response(self, query: str, context: str) -> Optional[str]:
        """Get cached response."""
        if not self.enabled or not self.redis_client:
            return None

        key = self._generate_key("response", {
            "query": query,
            "context": context
        })

        try:
            cached = self.redis_client.get(key)
            if cached:
                return cached.decode('utf-8')
        except Exception as e:
            print(f"Cache get error: {e}")

        return None

    def set_response(self, query: str, context: str, response: str):
        """Cache response."""
        if not self.enabled or not self.redis_client:
            return

        key = self._generate_key("response", {
            "query": query,
            "context": context
        })

        try:
            self.redis_client.setex(
                key,
                timedelta(seconds=self.ttl),
                response
            )
        except Exception as e:
            print(f"Cache set error: {e}")

    def clear(self):
        """Clear all cache."""
        if self.redis_client:
            try:
                self.redis_client.flushdb()
            except Exception as e:
                print(f"Cache clear error: {e}")

Step 3: Monitoring & Metrics

python
from prometheus_client import Counter, Histogram, Gauge, start_http_server
import time
from functools import wraps
import logging


# Define metrics
query_counter = Counter(
    'rag_queries_total',
    'Total number of queries',
    ['status']
)

query_duration = Histogram(
    'rag_query_duration_seconds',
    'Query processing duration',
    ['stage']
)

retrieval_count = Histogram(
    'rag_retrieval_documents',
    'Number of documents retrieved',
    ['method']
)

cache_hits = Counter(
    'rag_cache_hits_total',
    'Cache hit count',
    ['cache_type']
)

cache_misses = Counter(
    'rag_cache_misses_total',
    'Cache miss count',
    ['cache_type']
)

active_requests = Gauge(
    'rag_active_requests',
    'Number of active requests'
)

error_counter = Counter(
    'rag_errors_total',
    'Total errors',
    ['error_type']
)


class RAGMonitor:
    """Monitoring and metrics for RAG system."""

    def __init__(self, config: RAGConfig):
        self.config = config
        self.logger = self._setup_logging()

        if config.enable_monitoring:
            start_http_server(config.metrics_port)
            self.logger.info(f"Metrics server started on port {config.metrics_port}")

    def _setup_logging(self) -> logging.Logger:
        """Setup logging configuration."""
        logging.basicConfig(
            level=self.config.log_level,
            format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
        )
        return logging.getLogger('RAG')

    @staticmethod
    def track_query(func):
        """Decorator to track query metrics."""
        @wraps(func)
        def wrapper(*args, **kwargs):
            active_requests.inc()
            start_time = time.time()

            try:
                result = func(*args, **kwargs)
                query_counter.labels(status='success').inc()
                return result

            except Exception as e:
                query_counter.labels(status='error').inc()
                error_counter.labels(error_type=type(e).__name__).inc()
                raise

            finally:
                duration = time.time() - start_time
                query_duration.labels(stage='total').observe(duration)
                active_requests.dec()

        return wrapper

    @staticmethod
    def track_stage(stage_name: str):
        """Decorator to track individual stage duration."""
        def decorator(func):
            @wraps(func)
            def wrapper(*args, **kwargs):
                start_time = time.time()
                try:
                    return func(*args, **kwargs)
                finally:
                    duration = time.time() - start_time
                    query_duration.labels(stage=stage_name).observe(duration)
            return wrapper
        return decorator
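
To sanity-check the metrics outside the full pipeline, the decorators can be applied to any function and the collected samples read back with `prometheus_client.generate_latest()`. This is just a sketch, not part of the system itself:

python
import time
from prometheus_client import generate_latest


@RAGMonitor.track_stage("demo")
def slow_step():
    time.sleep(0.1)


slow_step()

# Dump current samples in Prometheus text format; expect entries such as
# rag_query_duration_seconds_count{stage="demo"} 1.0
print(generate_latest().decode("utf-8"))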

Step 4: Production Retriever

python
from typing import List, Dict, Any
import openai  # pre-1.0 OpenAI SDK (openai.Embedding / openai.ChatCompletion)
import cohere
from tenacity import retry, stop_after_attempt, wait_exponential


class ProductionRetriever:
    """
    Production-ready retriever with hybrid search and re-ranking.
    """

    def __init__(
        self,
        config: RAGConfig,
        cache: RAGCache,
        monitor: RAGMonitor
    ):
        self.config = config
        self.cache = cache
        self.monitor = monitor

        # Initialize clients
        openai.api_key = config.openai_api_key

        if config.cohere_api_key:
            self.cohere_client = cohere.Client(config.cohere_api_key)
        else:
            self.cohere_client = None

        # Initialize vector store (example with Pinecone)
        self.vector_store = self._init_vector_store()

    def _init_vector_store(self):
        """Initialize vector database."""
        # Example with Pinecone (v2 `pinecone-client` API; newer clients use Pinecone(...) instead of init)
        try:
            import pinecone

            pinecone.init(
                api_key=self.config.pinecone_api_key,
                environment="us-west1-gcp"  # example environment; make this configurable in practice
            )

            index_name = "rag-index"
            if index_name not in pinecone.list_indexes():
                pinecone.create_index(
                    index_name,
                    dimension=1536,  # OpenAI embedding dimension
                    metric="cosine"
                )

            return pinecone.Index(index_name)

        except Exception as e:
            self.monitor.logger.error(f"Vector store init error: {e}")
            return None

    @RAGMonitor.track_stage("retrieval")
    @retry(
        stop=stop_after_attempt(3),
        wait=wait_exponential(multiplier=1, min=2, max=10)
    )
    def retrieve(
        self,
        query: str,
        use_cache: bool = True
    ) -> List[Dict[str, Any]]:
        """
        Retrieve relevant documents.

        Args:
            query: Search query
            use_cache: Whether to use cache

        Returns:
            Retrieved documents
        """
        # Check cache
        if use_cache:
            cached = self.cache.get_retrieval(query)
            if cached:
                cache_hits.labels(cache_type='retrieval').inc()
                self.monitor.logger.info("Cache hit for retrieval")
                return cached

            cache_misses.labels(cache_type='retrieval').inc()

        # Generate query embedding
        response = openai.Embedding.create(
            model=self.config.embedding_model,
            input=query
        )
        query_embedding = response['data'][0]['embedding']

        # Vector search
        results = self.vector_store.query(
            vector=query_embedding,
            top_k=self.config.initial_k,
            include_metadata=True
        )

        # Format results
        documents = []
        for match in results['matches']:
            documents.append({
                'id': match['id'],
                'content': match['metadata'].get('text', ''),
                'score': float(match['score']),
                'metadata': match['metadata']
            })

        retrieval_count.labels(method='vector').observe(len(documents))

        # Re-rank if available
        if self.cohere_client and len(documents) > self.config.final_k:
            documents = self._rerank(query, documents)

        # Cache results
        if use_cache:
            self.cache.set_retrieval(query, documents[:self.config.final_k])

        return documents[:self.config.final_k]

    @RAGMonitor.track_stage("reranking")
    def _rerank(
        self,
        query: str,
        documents: List[Dict[str, Any]]
    ) -> List[Dict[str, Any]]:
        """Re-rank documents using Cohere."""
        try:
            texts = [doc['content'] for doc in documents]

            rerank_response = self.cohere_client.rerank(
                model=self.config.rerank_model,
                query=query,
                documents=texts,
                top_n=self.config.final_k
            )

            # Reorder documents based on rerank results
            reranked = []
            for result in rerank_response.results:
                doc = documents[result.index].copy()
                doc['rerank_score'] = result.relevance_score
                reranked.append(doc)

            retrieval_count.labels(method='reranked').observe(len(reranked))

            return reranked

        except Exception as e:
            self.monitor.logger.error(f"Reranking error: {e}")
            error_counter.labels(error_type='rerank_error').inc()
            return documents  # Fallback to original ranking
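
Note that the config defines `hybrid_alpha`, but `retrieve` above only performs dense vector search. One way to make the retriever genuinely hybrid is to blend a lexical score with the vector score before re-ranking. A minimal sketch using the `rank_bm25` package (an assumption, not part of the code above):

python
from typing import Any, Dict, List

from rank_bm25 import BM25Okapi


def hybrid_scores(query: str, documents: List[Dict[str, Any]], alpha: float) -> List[Dict[str, Any]]:
    """Blend dense and lexical relevance: alpha * vector_score + (1 - alpha) * bm25_score."""
    corpus = [doc["content"].split() for doc in documents]
    bm25 = BM25Okapi(corpus)
    lexical = bm25.get_scores(query.split())

    # Normalise BM25 scores to [0, 1] so they are comparable to cosine similarity
    max_lex = max(lexical) or 1.0
    for doc, lex in zip(documents, lexical):
        doc["hybrid_score"] = alpha * doc["score"] + (1 - alpha) * (lex / max_lex)

    return sorted(documents, key=lambda d: d["hybrid_score"], reverse=True)

Cohere re-ranking can then run on the blended list exactly as before.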

Step 5: Production Generator

python
class ProductionGenerator:
    """
    Production-ready response generator.
    """

    def __init__(
        self,
        config: RAGConfig,
        cache: RAGCache,
        monitor: RAGMonitor
    ):
        self.config = config
        self.cache = cache
        self.monitor = monitor
        openai.api_key = config.openai_api_key

    @RAGMonitor.track_stage("generation")
    @retry(
        stop=stop_after_attempt(3),
        wait=wait_exponential(multiplier=1, min=2, max=10)
    )
    def generate(
        self,
        query: str,
        context_docs: List[Dict[str, Any]],
        use_cache: bool = True
    ) -> Dict[str, Any]:
        """
        Generate response from context.

        Args:
            query: User query
            context_docs: Retrieved context documents
            use_cache: Whether to use cache

        Returns:
            Generated response with metadata
        """
        # Build context
        context = self._build_context(context_docs)

        # Check cache
        if use_cache:
            cached = self.cache.get_response(query, context)
            if cached:
                cache_hits.labels(cache_type='response').inc()
                return {
                    'answer': cached,
                    'cached': True,
                    'sources': context_docs
                }

            cache_misses.labels(cache_type='response').inc()

        # Generate response
        prompt = self._build_prompt(query, context)

        try:
            response = openai.ChatCompletion.create(
                model=self.config.chat_model,
                messages=[
                    {
                        "role": "system",
                        "content": "You are a helpful assistant that answers questions based on provided context. Be concise and accurate."
                    },
                    {"role": "user", "content": prompt}
                ],
                temperature=0.3,
                max_tokens=self.config.max_tokens,
                request_timeout=self.config.timeout_seconds  # pre-1.0 SDK uses request_timeout
            )

            answer = response.choices[0].message.content

            # Cache response
            if use_cache:
                self.cache.set_response(query, context, answer)

            return {
                'answer': answer,
                'cached': False,
                'sources': context_docs,
                'model': self.config.chat_model,
                'tokens_used': response['usage']['total_tokens']
            }

        except Exception as e:
            self.monitor.logger.error(f"Generation error: {e}")
            error_counter.labels(error_type='generation_error').inc()
            raise

    def _build_context(self, docs: List[Dict[str, Any]]) -> str:
        """Build context from documents."""
        context_parts = []

        for i, doc in enumerate(docs, 1):
            source_info = ""
            if 'metadata' in doc and 'source' in doc['metadata']:
                source_info = f" (Source: {doc['metadata']['source']})"

            context_parts.append(
                f"[{i}]{source_info}\n{doc['content']}"
            )

        return "\n\n".join(context_parts)

    def _build_prompt(self, query: str, context: str) -> str:
        """Build generation prompt."""
        return f"""Answer the following question based on the provided context.
If the answer cannot be found in the context, say so.

Context:
{context}

Question: {query}

Answer:"""

Step 6: Main RAG Orchestrator

python
from dataclasses import dataclass
from typing import Optional
import time


@dataclass
class RAGResponse:
    """Structured RAG response."""
    answer: str
    sources: List[Dict[str, Any]]
    query: str
    cached: bool
    processing_time: float
    tokens_used: Optional[int] = None
    metadata: Optional[Dict] = None


class ProductionRAG:
    """
    Complete production RAG system.
    """

    def __init__(self, config: Optional[RAGConfig] = None):
        self.config = config or RAGConfig()

        # Initialize components
        self.cache = RAGCache(
            redis_url=self.config.redis_url,
            ttl=self.config.cache_ttl,
            enabled=self.config.enable_cache
        )

        self.monitor = RAGMonitor(self.config)

        self.retriever = ProductionRetriever(
            self.config,
            self.cache,
            self.monitor
        )

        self.generator = ProductionGenerator(
            self.config,
            self.cache,
            self.monitor
        )

        self.monitor.logger.info("Production RAG system initialized")

    @RAGMonitor.track_query
    def query(
        self,
        question: str,
        use_cache: bool = True,
        metadata: Optional[Dict] = None
    ) -> RAGResponse:
        """
        Process a query through the RAG pipeline.

        Args:
            question: User question
            use_cache: Whether to use caching
            metadata: Optional metadata to include

        Returns:
            RAG response
        """
        start_time = time.time()

        self.monitor.logger.info(f"Processing query: {question[:50]}...")

        try:
            # Step 1: Retrieve relevant documents
            self.monitor.logger.debug("Starting retrieval...")
            context_docs = self.retriever.retrieve(
                question,
                use_cache=use_cache
            )

            self.monitor.logger.debug(
                f"Retrieved {len(context_docs)} documents"
            )

            # Step 2: Generate response
            self.monitor.logger.debug("Starting generation...")
            generation_result = self.generator.generate(
                question,
                context_docs,
                use_cache=use_cache
            )

            # Build response
            processing_time = time.time() - start_time

            response = RAGResponse(
                answer=generation_result['answer'],
                sources=generation_result['sources'],
                query=question,
                cached=generation_result['cached'],
                processing_time=processing_time,
                tokens_used=generation_result.get('tokens_used'),
                metadata=metadata
            )

            self.monitor.logger.info(
                f"Query completed in {processing_time:.2f}s "
                f"(cached: {response.cached})"
            )

            return response

        except Exception as e:
            self.monitor.logger.error(f"Query failed: {e}")
            raise

    def health_check(self) -> Dict[str, Any]:
        """Health check endpoint."""
        return {
            'status': 'healthy',
            'cache_enabled': self.cache.enabled,
            'monitoring_enabled': self.config.enable_monitoring,
            'vector_store': 'connected' if self.retriever.vector_store else 'disconnected'
        }

    def get_metrics(self) -> str:
        """Return current metrics in Prometheus text exposition format."""
        from prometheus_client import generate_latest
        return generate_latest().decode("utf-8")
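
With the components assembled, using the system is a few lines (assuming the API keys are available via the `.env` file from Step 1):

python
# Minimal usage sketch
rag = ProductionRAG()

result = rag.query("What are the benefits of hybrid retrieval?")
print(result.answer)
print(f"sources: {len(result.sources)} | "
      f"time: {result.processing_time:.2f}s | cached: {result.cached}")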

Step 7: API Layer (FastAPI)

python
from fastapi import FastAPI, HTTPException, Request
from fastapi.responses import PlainTextResponse
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel
from slowapi import Limiter, _rate_limit_exceeded_handler
from slowapi.util import get_remote_address
from slowapi.errors import RateLimitExceeded


# Request/Response models
class QueryRequest(BaseModel):
    question: str
    use_cache: bool = True
    metadata: Optional[Dict] = None


class QueryResponse(BaseModel):
    answer: str
    sources: List[Dict]
    processing_time: float
    cached: bool
    tokens_used: Optional[int]


# Initialize FastAPI
app = FastAPI(
    title="Production RAG API",
    description="Production-ready RAG system",
    version="1.0.0"
)

# CORS
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# Rate limiting
limiter = Limiter(key_func=get_remote_address)
app.state.limiter = limiter
app.add_exception_handler(RateLimitExceeded, _rate_limit_exceeded_handler)

# Initialize RAG system
rag_system = ProductionRAG()


@app.post("/query", response_model=QueryResponse)
@limiter.limit("60/minute")
async def query_endpoint(request: QueryRequest):
    """
    Query the RAG system.
    """
    try:
        response = rag_system.query(
            question=request.question,
            use_cache=request.use_cache,
            metadata=request.metadata
        )

        return QueryResponse(
            answer=response.answer,
            sources=response.sources,
            processing_time=response.processing_time,
            cached=response.cached,
            tokens_used=response.tokens_used
        )

    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))


@app.get("/health")
async def health():
    """Health check endpoint."""
    return rag_system.health_check()


@app.get("/metrics")
async def metrics():
    """Get system metrics."""
    return rag_system.get_metrics()


if __name__ == "__main__":
    import uvicorn

    uvicorn.run(
        app,
        host="0.0.0.0",
        port=8080,
        log_level="info"
    )
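
Once the server is running, the endpoint can be exercised from any HTTP client. A short example with `requests` (the URL assumes the local port configured above):

python
import requests

resp = requests.post(
    "http://localhost:8080/query",
    json={"question": "What is RAG?", "use_cache": True},
    timeout=60,
)
resp.raise_for_status()

data = resp.json()
print(data["answer"])
print("cached:", data["cached"], "| tokens:", data["tokens_used"])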

Step 8: Document Indexing Pipeline

python
class DocumentIndexer:
    """
    Index documents into the vector store.
    """

    def __init__(self, retriever: ProductionRetriever):
        self.retriever = retriever

    def index_documents(
        self,
        documents: List[Dict[str, Any]],
        batch_size: int = 100
    ):
        """
        Index documents in batches.

        Args:
            documents: List of {id, text, metadata} dicts
            batch_size: Batch size for indexing
        """
        total = len(documents)
        print(f"📚 Indexing {total} documents...")

        for i in range(0, total, batch_size):
            batch = documents[i:i + batch_size]

            # Generate embeddings
            texts = [doc['text'] for doc in batch]
            embeddings = self._generate_embeddings(texts)

            # Prepare for upsert
            vectors = []
            for j, doc in enumerate(batch):
                vectors.append({
                    'id': doc['id'],
                    'values': embeddings[j],
                    'metadata': {
                        'text': doc['text'],
                        **doc.get('metadata', {})
                    }
                })

            # Upsert to vector store
            self.retriever.vector_store.upsert(vectors=vectors)

            print(f"   Indexed {i + len(batch)}/{total}")

        print("✅ Indexing complete")

    def _generate_embeddings(self, texts: List[str]) -> List[List[float]]:
        """Generate embeddings for texts."""
        response = openai.Embedding.create(
            model=self.retriever.config.embedding_model,
            input=texts
        )

        return [item['embedding'] for item in response['data']]


# Usage
indexer = DocumentIndexer(rag_system.retriever)

documents = [
    {
        'id': 'doc1',
        'text': 'Document content here...',
        'metadata': {'source': 'manual', 'date': '2024-01-01'}
    },
    # ... more documents
]

indexer.index_documents(documents)
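
The indexer assumes each document is small enough to embed as a single unit. In practice, long documents should be chunked first; a simple fixed-size chunker with overlap is sketched below (chunk sizes are assumptions to tune for your corpus):

python
from typing import Any, Dict, List


def chunk_document(doc_id: str, text: str, metadata: Dict[str, Any],
                   chunk_size: int = 800, overlap: int = 100) -> List[Dict[str, Any]]:
    """Split a long document into overlapping character chunks ready for index_documents()."""
    chunks = []
    start = 0
    while start < len(text):
        chunks.append({
            'id': f"{doc_id}-{len(chunks)}",
            'text': text[start:start + chunk_size],
            'metadata': {**metadata, 'parent_id': doc_id, 'chunk': len(chunks)}
        })
        start += chunk_size - overlap
    return chunks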

Step 9: Deployment (Docker)

dockerfile
# Dockerfile
FROM python:3.10-slim

WORKDIR /app

# Install dependencies
COPY requirements.txt .
RUN pip install --no-cache-dir -r requirements.txt

# Copy application
COPY . .

# Expose ports
EXPOSE 8080 8000

# Run application
CMD ["python", "api.py"]
yaml
# docker-compose.yml
version: '3.8'

services:
  rag-api:
    build: .
    ports:
      - "8080:8080"
      - "8000:8000"
    environment:
      - OPENAI_API_KEY=${OPENAI_API_KEY}
      - COHERE_API_KEY=${COHERE_API_KEY}
      - PINECONE_API_KEY=${PINECONE_API_KEY}
      - REDIS_URL=redis://redis:6379
    depends_on:
      - redis

  redis:
    image: redis:7-alpine
    ports:
      - "6379:6379"

  prometheus:
    image: prom/prometheus
    ports:
      - "9090:9090"
    volumes:
      - ./prometheus.yml:/etc/prometheus/prometheus.yml

  grafana:
    image: grafana/grafana
    ports:
      - "3000:3000"
    environment:
      - GF_SECURITY_ADMIN_PASSWORD=admin
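
The compose file mounts a `prometheus.yml` that is not shown above. A minimal scrape config pointing at the metrics port exposed by the RAG container might look like this (job name and interval are assumptions):

yaml
# prometheus.yml (minimal example)
global:
  scrape_interval: 15s

scrape_configs:
  - job_name: "rag-api"
    static_configs:
      - targets: ["rag-api:8000"]   # metrics_port from RAGConfig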

Production Ready: This system includes caching, monitoring, error handling, rate limiting, and horizontal scalability - all essential for production deployment.

Step 10: Testing

python
import pytest
from unittest.mock import Mock, patch


class TestProductionRAG:
    """Test suite for production RAG."""

    @pytest.fixture
    def rag_system(self):
        """Create RAG system for testing (assumes API keys are supplied via env/.env)."""
        config = RAGConfig(
            enable_cache=False,
            enable_monitoring=False
        )
        return ProductionRAG(config)

    def test_query_success(self, rag_system):
        """Test successful query."""
        response = rag_system.query("What is RAG?")

        assert response.answer is not None
        assert len(response.sources) > 0
        assert response.processing_time > 0

    def test_cache_hit(self):
        """Test cache functionality (requires a local Redis instance)."""
        # Toggling config.enable_cache after init has no effect, because the
        # RAGCache instance is created at construction time; build a cached system instead.
        config = RAGConfig(
            enable_cache=True,
            redis_url="redis://localhost:6379",
            enable_monitoring=False
        )
        cached_system = ProductionRAG(config)

        # First query populates the cache
        response1 = cached_system.query("Test query")
        assert not response1.cached

        # Second identical query should be served from the cache
        response2 = cached_system.query("Test query")
        assert response2.cached

    @patch('openai.Embedding.create')
    def test_retrieval_error_handling(self, mock_embed, rag_system):
        """Test error handling in retrieval."""
        mock_embed.side_effect = Exception("API Error")

        with pytest.raises(Exception):
            rag_system.query("Test query")


# Run tests
# pytest test_rag.py -v

Monitoring is Critical: Always monitor your production RAG system for latency, error rates, cache hit rates, and token usage to identify issues early.

Key Takeaways

  1. Configuration management - use environment variables and validation
  2. Caching - dramatically improves performance and reduces costs
  3. Monitoring - track metrics to understand system behavior
  4. Error handling - retry logic and graceful degradation
  5. Scalability - horizontal scaling with load balancing
  6. Testing - comprehensive tests for reliability
  7. Deployment - containerization for easy deployment

Quiz

Test your understanding of production RAG systems: