Back
advanced
Optimization & Deployment

Caching Strategies for LLMs

Implement prompt caching, KV cache management, and Redis integration for efficient LLM inference

20 min read · caching · optimization · redis · kv-cache

Caching Strategies for LLMs

Learn advanced caching techniques to dramatically reduce inference costs and latency for LLM applications.

What You'll Learn: Caching is one of the most effective ways to optimize LLM applications. We'll explore prompt caching, KV cache management, and distributed caching with Redis.

Understanding LLM Caching

Cache Types

python
from typing import Dict, Optional, List, Tuple
import hashlib
import time
from dataclasses import dataclass
from datetime import datetime, timedelta

@dataclass
class CacheEntry:
    """A single cached LLM response plus simple access bookkeeping."""

    key: str  # cache key this entry is stored under
    value: str  # the cached response text
    created_at: datetime  # when the entry was first stored
    hit_count: int = 0  # number of lookups served from this entry
    # Time of the most recent lookup; defaults to ``created_at`` (see __post_init__).
    last_accessed: Optional[datetime] = None

    def __post_init__(self):
        # A freshly created entry counts as "accessed" at creation time.
        if self.last_accessed is None:
            self.last_accessed = self.created_at

class LLMCachingLayers:
    """Demonstrates different caching layers.

    Only the exact-match response cache (layer 1) is implemented; ``kv_cache``
    and ``semantic_cache`` are empty placeholders illustrating the other layers.
    """

    def __init__(self):
        # Layer 1: Response cache (full prompt + response), keyed by compute_cache_key()
        self.response_cache: Dict[str, CacheEntry] = {}

        # Layer 2: KV cache (prefix cache for transformers).
        # Annotated with ``object``: the builtin ``any`` is a function, not a type.
        self.kv_cache: Dict[str, object] = {}

        # Layer 3: Semantic cache (similar prompts)
        self.semantic_cache: Dict[str, List[Tuple[str, str]]] = {}

    def compute_cache_key(self, prompt: str, model_params: Dict) -> str:
        """Return a SHA-256 hex key for ``prompt`` plus its model parameters.

        Parameters are sorted by name so that dict ordering does not change the key.
        """

        # Include model parameters in key: same prompt + different params = different entry
        key_string = f"{prompt}|{sorted(model_params.items())}"
        return hashlib.sha256(key_string.encode()).hexdigest()

    def get_response_cache(
        self,
        prompt: str,
        model_params: Dict
    ) -> Optional[str]:
        """Return the cached response for (prompt, params), or None on a miss.

        On a hit, the entry's ``hit_count`` and ``last_accessed`` are updated.
        """

        key = self.compute_cache_key(prompt, model_params)
        entry = self.response_cache.get(key)

        if entry is None:
            print("Cache MISS")
            return None

        entry.hit_count += 1
        entry.last_accessed = datetime.now()

        print(f"Cache HIT! (hits: {entry.hit_count})")
        return entry.value

    def set_response_cache(
        self,
        prompt: str,
        response: str,
        model_params: Dict
    ):
        """Store ``response`` for (prompt, params), replacing any existing entry."""

        key = self.compute_cache_key(prompt, model_params)

        self.response_cache[key] = CacheEntry(
            key=key,
            value=response,
            created_at=datetime.now()
        )

    def demonstrate_caching_impact(self, simulated_latency: float = 2.0):
        """Demonstrate cache impact on performance.

        Issues the same request twice. The first request is a miss that simulates
        an LLM call by sleeping, then fills the cache. The second request is served
        from the cache.

        Args:
            simulated_latency: Seconds to sleep to simulate inference on a miss.
        """

        model_params = {"temperature": 0.7, "max_tokens": 100}

        # First request (cache miss). perf_counter is monotonic and high-resolution,
        # unlike time.time().
        start = time.perf_counter()
        result1 = self.get_response_cache("What is AI?", model_params)
        if result1 is None:
            # Simulate LLM inference (expensive)
            time.sleep(simulated_latency)
            result1 = "AI is artificial intelligence..."
            self.set_response_cache("What is AI?", result1, model_params)
        time1 = time.perf_counter() - start

        # Second request (cache hit)
        start = time.perf_counter()
        self.get_response_cache("What is AI?", model_params)
        time2 = time.perf_counter() - start

        print(f"\nFirst request: {time1:.3f}s")
        print(f"Second request: {time2:.3f}s")
        # A cache hit can complete within the clock's resolution; don't divide by zero.
        if time2 > 0:
            print(f"Speedup: {time1/time2:.1f}x")
        else:
            print("Speedup: too fast to measure")

# Example usage
# Runs the miss-then-hit demo once: the first lookup misses and sleeps (~2s) to
# simulate inference, the second is served from the in-memory response cache.
cache = LLMCachingLayers()
cache.demonstrate_caching_impact()

Prompt Caching

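Prompt Caching: Cache the key/value (KV) states computed for common prompt prefixes so that repeated content does not have to be re-encoded on every request. This is especially effective for long system prompts and few-shot examples that are shared across many requests.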

Prefix Caching Implementation

python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from typing import Dict, Tuple, Optional

class PrefixCache:
    """LRU cache mapping prompt-prefix text to the model's cached KV states.

    Entries are keyed by the SHA-256 of the prefix. Once ``max_size`` entries
    are held, inserting a new prefix evicts the least recently used one. Both
    get() and set() count as a use.
    """

    def __init__(self, max_size: int = 100):
        """
        Prefix cache for transformer KV states

        Args:
            max_size: Maximum number of cached prefixes (<= 0 disables caching)
        """
        # Values are whatever the model returned as ``past_key_values``; the
        # (key, value) tensor-pair annotation is illustrative.
        self.cache: Dict[str, Tuple[torch.Tensor, torch.Tensor]] = {}
        self.max_size = max_size
        # Last get/set time per key; used to choose the LRU eviction victim.
        self.access_times: Dict[str, float] = {}

    def get_cache_key(self, prefix: str) -> str:
        """Return the SHA-256 hex digest of ``prefix`` used as its cache key."""
        return hashlib.sha256(prefix.encode()).hexdigest()

    def get(self, prefix: str) -> Optional[Tuple[torch.Tensor, torch.Tensor]]:
        """Return cached KV states for ``prefix`` (refreshing its LRU time), or None."""

        key = self.get_cache_key(prefix)

        if key in self.cache:
            self.access_times[key] = time.time()
            return self.cache[key]

        return None

    def set(
        self,
        prefix: str,
        past_key_values: Tuple[torch.Tensor, torch.Tensor]
    ):
        """Cache KV states for ``prefix``.

        Re-setting a prefix that is already cached replaces its value without
        evicting anything. Inserting a new prefix into a full cache first evicts
        the least recently used entry.
        """

        if self.max_size <= 0:
            # Nothing can be held; previously this crashed on min() of an empty dict.
            return

        key = self.get_cache_key(prefix)

        # Evict only when adding a *new* key; overwriting an existing key
        # doesn't grow the cache and must not drop an unrelated prefix.
        if key not in self.cache:
            while self.cache and len(self.cache) >= self.max_size:
                oldest_key = min(self.access_times, key=self.access_times.get)
                del self.cache[oldest_key]
                del self.access_times[oldest_key]

        self.cache[key] = past_key_values
        self.access_times[key] = time.time()

    def clear(self):
        """Remove all cached prefixes and their access times."""
        self.cache.clear()
        self.access_times.clear()

class CachedLLM:
    """Causal LM wrapper that reuses cached KV states for a shared system prompt."""

    def __init__(self, model_name: str):
        """Initialize model with prefix caching

        Args:
            model_name: Model identifier passed to ``from_pretrained``.
        """

        self.model = AutoModelForCausalLM.from_pretrained(
            model_name,
            device_map="auto",
            torch_dtype=torch.float16
        )
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        self.prefix_cache = PrefixCache(max_size=50)

    def _generate(self, input_ids, past_key_values, max_new_tokens):
        """Run sampled generation on ``input_ids``, optionally resuming from a KV cache."""
        with torch.no_grad():
            return self.model.generate(
                input_ids,
                attention_mask=torch.ones_like(input_ids),
                past_key_values=past_key_values,
                max_new_tokens=max_new_tokens,
                do_sample=True,
                temperature=0.7
            )

    def generate_with_cache(
        self,
        prompt: str,
        system_prompt: str = "",
        max_new_tokens: int = 100
    ) -> str:
        """
        Generate with prefix caching

        The KV states of ``system_prompt`` are computed once and stored in
        ``self.prefix_cache``. Later calls with the same system prompt reuse them.
        The model is always given the full token sequence (system + user) together
        with the cached states, so ``generate`` only has to process the tokens the
        cache does not cover.

        Args:
            prompt: User prompt
            system_prompt: System prompt (will be cached)
            max_new_tokens: Maximum tokens to generate

        Returns:
            The decoded user prompt followed by the generated continuation
            (system-prompt tokens are not included).
        """
        import copy  # deep-copy cached KV states so generate() can't alter the stored copy

        device = self.model.device

        if not system_prompt:
            input_ids = self.tokenizer(prompt, return_tensors="pt")["input_ids"].to(device)
            outputs = self._generate(input_ids, None, max_new_tokens)
            return self.tokenizer.decode(outputs[0], skip_special_tokens=True)

        prefix_ids = self.tokenizer(system_prompt, return_tensors="pt")["input_ids"].to(device)

        # Tokenize the user prompt on its own, with no special tokens (avoids a BOS
        # in mid-sequence), and append its ids to the prefix ids. Tokenizing
        # ``system_prompt + prompt`` as one string can merge tokens across the
        # boundary, and then the cached prefix would no longer line up with input_ids.
        user_token_ids = self.tokenizer(prompt, add_special_tokens=False)["input_ids"]
        user_ids = torch.tensor([user_token_ids], dtype=prefix_ids.dtype, device=device)
        input_ids = torch.cat([prefix_ids, user_ids], dim=1)

        past_key_values = self.prefix_cache.get(system_prompt)
        if past_key_values is not None:
            print("Using cached system prompt!")
        else:
            print("System prompt not cached, computing...")
            with torch.no_grad():
                past_key_values = self.model(prefix_ids, use_cache=True).past_key_values
            self.prefix_cache.set(system_prompt, past_key_values)

        # generate() needs at least one uncached input token. With an empty user
        # prompt there is nothing to resume from, so run without the cache.
        # Otherwise pass a copy: generation may extend the cache in place, and the
        # stored prefix states must stay intact for the next request.
        cache_for_call = copy.deepcopy(past_key_values) if user_token_ids else None

        outputs = self._generate(input_ids, cache_for_call, max_new_tokens)

        # Drop the system-prompt tokens: result is user prompt + continuation.
        return self.tokenizer.decode(
            outputs[0, prefix_ids.shape[1]:],
            skip_special_tokens=True
        )

    def benchmark_caching(self):
        """Benchmark prefix caching benefits.

        Runs the same prompts twice with a shared system prompt. In the first
        pass the prefix cache is cleared before every call; in the second pass
        the cache is left warm.
        """

        system_prompt = """You are a helpful AI assistant. You are expert in
        mathematics, science, and programming. You provide clear, accurate,
        and detailed explanations. """

        prompts = [
            "What is calculus?",
            "Explain quantum physics.",
            "How does a computer work?",
            "What is machine learning?"
        ]

        # Without cache: clear before *every* call so each one recomputes the
        # system-prompt KV states. Clearing only once made calls 2..n cache hits.
        start = time.perf_counter()
        for prompt in prompts:
            self.prefix_cache.clear()
            _ = self.generate_with_cache(prompt, system_prompt, max_new_tokens=50)
        time_no_cache = time.perf_counter() - start

        # With cache: the last call above left the system prompt cached.
        start = time.perf_counter()
        for prompt in prompts:
            _ = self.generate_with_cache(prompt, system_prompt, max_new_tokens=50)
        time_with_cache = time.perf_counter() - start

        print(f"\nWithout cache: {time_no_cache:.2f}s")
        print(f"With cache: {time_with_cache:.2f}s")
        print(f"Speedup: {time_no_cache/time_with_cache:.2f}x")

# Example usage
# Loads GPT-2 through CachedLLM (fp16, device_map="auto") with a 50-entry prefix cache.
cached_llm = CachedLLM("gpt2")

# Shared system prompt whose KV states are cached between calls.
system_prompt = "You are a helpful assistant. "

# First call - cache miss
# (the system prompt's KV states are computed and stored in the prefix cache)
response1 = cached_llm.generate_with_cache(
    "What is Python?",
    system_prompt=system_prompt
)

# Second call - cache hit
# (same system prompt, so its stored KV states are found in the prefix cache)
response2 = cached_llm.generate_with_cache(
    "What is JavaScript?",
    system_prompt=system_prompt
)

# Benchmark
# Times two passes over four prompts using its own, longer system prompt.
cached_llm.benchmark_caching()

KV Cache Management

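KV Cache: During generation, transformers cache the key and value tensors of all previous tokens so they do not have to be recomputed at every decoding step. Because this cache grows linearly with sequence length and batch size, managing it is crucial for memory efficiency.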

Advanced KV Cache Strategies

python
class KVCacheManager:
    """Helpers for estimating and bounding a causal LM's KV-cache memory."""

    def __init__(self, model, tokenizer):
        self.model = model
        self.tokenizer = tokenizer

    def estimate_kv_cache_size(
        self,
        batch_size: int,
        sequence_length: int,
        num_layers: Optional[int] = None,
        hidden_size: Optional[int] = None,
        num_heads: Optional[int] = None,
        num_kv_heads: Optional[int] = None,
        bytes_per_element: int = 2
    ) -> Dict[str, float]:
        """
        Estimate KV cache memory usage

        Formula: 2 * batch_size * num_layers * num_kv_heads * seq_len * head_dim * bytes_per_element
        with head_dim = hidden_size // num_heads. The factor 2 is keys + values.

        Args:
            batch_size: Number of sequences held in the cache.
            sequence_length: Cached tokens per sequence.
            num_layers / hidden_size / num_heads: Read from ``self.model.config``
                when omitted.
            num_kv_heads: Key/value heads per layer. Grouped-/multi-query
                attention models store fewer than ``num_heads``. When omitted
                and ``num_heads`` also came from the config, the config's
                ``num_key_value_heads`` is used if present; otherwise it
                defaults to ``num_heads``.
            bytes_per_element: 2 for FP16/BF16, 4 for FP32.

        Returns:
            Total size in bytes / MB / GB, and the size per cached token in MB.
        """

        if num_layers is None:
            num_layers = self.model.config.num_hidden_layers
        if hidden_size is None:
            hidden_size = self.model.config.hidden_size
        heads_from_config = num_heads is None
        if heads_from_config:
            num_heads = self.model.config.num_attention_heads
        if num_kv_heads is None:
            if heads_from_config:
                num_kv_heads = (
                    getattr(self.model.config, "num_key_value_heads", None) or num_heads
                )
            else:
                num_kv_heads = num_heads

        head_dim = hidden_size // num_heads

        # Keys + values for one token across all layers and the whole batch.
        # Computed directly so that per_token_mb is defined even for sequence_length == 0.
        per_token_bytes = (
            2 * batch_size * num_layers * num_kv_heads * head_dim * bytes_per_element
        )
        cache_size_bytes = per_token_bytes * sequence_length

        return {
            "total_bytes": cache_size_bytes,
            "total_mb": cache_size_bytes / (1024 ** 2),
            "total_gb": cache_size_bytes / (1024 ** 3),
            "per_token_mb": per_token_bytes / (1024 ** 2)
        }

    def sliding_window_generation(
        self,
        prompt: str,
        max_new_tokens: int = 100,
        window_size: int = 512
    ) -> str:
        """
        Generate with sliding window KV cache

        Greedy (argmax) decoding that keeps only the last `window_size` tokens in
        the KV cache, to limit memory usage for long generations. Returns only the
        generated text (the prompt is not included).

        Position ids are tracked explicitly. After trimming, the cache length no
        longer equals the number of tokens seen, so positions inferred from it would
        freeze at `window_size`. For models with learned absolute position
        embeddings (e.g. GPT-2), prompt length + max_new_tokens must therefore stay
        within the model's maximum positions.

        NOTE(review): trimming assumes legacy tuple-of-(key, value) per-layer
        ``past_key_values`` indexed as [batch, heads, seq, head_dim]. Confirm this
        against the installed transformers version.
        """

        input_ids = self.tokenizer(prompt, return_tensors="pt")["input_ids"]
        input_ids = input_ids.to(self.model.device)

        generated_tokens = []
        past_key_values = None
        # Absolute position of the first token in the next forward pass.
        next_position = 0

        for _ in range(max_new_tokens):
            # If the cached sequence grew past the window, keep only its tail.
            if past_key_values is not None and past_key_values[0][0].shape[2] > window_size:
                past_key_values = tuple(
                    tuple(p[:, :, -window_size:, :] for p in layer_past)
                    for layer_past in past_key_values
                )

            step_len = input_ids.shape[1]
            position_ids = torch.arange(
                next_position, next_position + step_len, device=input_ids.device
            ).unsqueeze(0)

            # Forward pass
            with torch.no_grad():
                outputs = self.model(
                    input_ids,
                    past_key_values=past_key_values,
                    position_ids=position_ids,
                    use_cache=True
                )
            next_position += step_len

            # Get next token (greedy)
            next_token = outputs.logits[:, -1, :].argmax(dim=-1, keepdim=True)
            token_id = next_token.item()
            generated_tokens.append(token_id)

            # Only the new token is fed next; earlier context lives in the cache.
            input_ids = next_token
            past_key_values = outputs.past_key_values

            # Stop if EOS
            if token_id == self.tokenizer.eos_token_id:
                break

        return self.tokenizer.decode(generated_tokens, skip_special_tokens=True)

    def demonstrate_memory_savings(self, seq_length: int = 2048, window_size: int = 512):
        """Print estimated KV cache memory for a full sequence vs. a sliding window.

        Args:
            seq_length: Full (un-windowed) sequence length.
            window_size: Number of tokens a sliding window would keep.
        """

        batch_size = 1

        # Without windowing
        full_cache = self.estimate_kv_cache_size(batch_size, seq_length)

        # With windowing
        windowed_cache = self.estimate_kv_cache_size(batch_size, window_size)

        print("KV Cache Memory Usage:")
        print(f"Full cache ({seq_length} tokens): {full_cache['total_mb']:.2f} MB")
        print(f"Windowed cache ({window_size} tokens): {windowed_cache['total_mb']:.2f} MB")
        print(f"Memory saved: {(1 - windowed_cache['total_mb']/full_cache['total_mb'])*100:.1f}%")

# Example usage
# Reuses the model and tokenizer already loaded by `cached_llm`.
kv_manager = KVCacheManager(cached_llm.model, cached_llm.tokenizer)

# Estimate cache size
# (layer count, hidden size and head count are read from the model config)
cache_info = kv_manager.estimate_kv_cache_size(
    batch_size=4,
    sequence_length=1024
)
print(f"Estimated KV cache: {cache_info['total_mb']:.2f} MB")

# Sliding window generation
# Greedy (argmax) decoding of up to 200 tokens, trimming the KV cache to its
# last 512 positions whenever it grows past that.
response = kv_manager.sliding_window_generation(
    "Write a long story about space exploration:",
    max_new_tokens=200,
    window_size=512
)

# Demonstrate savings
# Prints the estimated full (2048-token) vs. windowed (512-token) cache sizes.
kv_manager.demonstrate_memory_savings()

Redis Caching for Distributed Systems

Redis Caching: For distributed applications, use Redis to share cached responses across multiple instances and reduce overall inference costs.

Redis Cache Implementation

python
import redis
import json
from typing import Optional, Dict, Any
import pickle

class RedisLLMCache:
    """Exact-match LLM response cache stored in Redis.

    Entries are keyed by the SHA-256 of (model name, prompt, params), pickled,
    and written with a TTL. Hit/miss counters live in the ``llm_cache_stats``
    hash in the same database.
    """

    def __init__(
        self,
        host: str = "localhost",
        port: int = 6379,
        db: int = 0,
        ttl: int = 3600  # 1 hour default TTL
    ):
        """
        Initialize Redis cache for LLM responses

        Args:
            host: Redis host
            port: Redis port
            db: Redis database number
            ttl: Time to live in seconds
        """
        self.redis_client = redis.Redis(
            host=host,
            port=port,
            db=db,
            decode_responses=False  # We'll use pickle for complex objects
        )
        self.ttl = ttl

    def generate_cache_key(
        self,
        prompt: str,
        model_name: str,
        params: Dict[str, Any]
    ) -> str:
        """Return a 64-char SHA-256 hex key for (model_name, prompt, params)."""

        # Sort params for consistency: dict order must not change the key
        param_str = json.dumps(params, sort_keys=True)
        key_content = f"{model_name}:{prompt}:{param_str}"

        return hashlib.sha256(key_content.encode()).hexdigest()

    @staticmethod
    def _is_cache_key(key) -> bool:
        """Return True if ``key`` has the shape produced by generate_cache_key().

        Keys arrive as bytes when ``decode_responses=False``; str is accepted too.
        """
        if isinstance(key, str):
            key = key.encode()
        # 64 characters, all lowercase hex (what hexdigest() emits).
        return len(key) == 64 and not key.strip(b"0123456789abcdef")

    def get(
        self,
        prompt: str,
        model_name: str,
        params: Dict[str, Any]
    ) -> Optional[str]:
        """Return the cached response, or None on a miss; updates hit/miss counters."""

        key = self.generate_cache_key(prompt, model_name, params)

        cached_value = self.redis_client.get(key)

        if cached_value:
            # Increment hit counter
            self.redis_client.hincrby("llm_cache_stats", "hits", 1)

            # SECURITY: pickle.loads runs arbitrary code embedded in the payload.
            # This is only safe if nothing untrusted can write to this Redis
            # database. For plain-string responses, storing UTF-8/JSON would avoid
            # the risk.
            return pickle.loads(cached_value)

        # Increment miss counter
        self.redis_client.hincrby("llm_cache_stats", "misses", 1)
        return None

    def set(
        self,
        prompt: str,
        model_name: str,
        params: Dict[str, Any],
        response: str
    ):
        """Cache ``response`` under the request's key, expiring after ``self.ttl`` seconds."""

        key = self.generate_cache_key(prompt, model_name, params)

        # Serialize and store with TTL (see the pickle security note in get()).
        self.redis_client.setex(
            key,
            self.ttl,
            pickle.dumps(response)
        )

    def get_stats(self) -> Dict[str, int]:
        """Return hit and miss counts and the hit rate as a percentage (0-100)."""

        stats = self.redis_client.hgetall("llm_cache_stats")

        return {
            "hits": int(stats.get(b"hits", 0)),
            "misses": int(stats.get(b"misses", 0)),
            "hit_rate": self._calculate_hit_rate(stats)
        }

    def _calculate_hit_rate(self, stats: Dict) -> float:
        """Return hits / (hits + misses) as a percentage; 0.0 when there were no lookups."""

        hits = int(stats.get(b"hits", 0))
        misses = int(stats.get(b"misses", 0))
        total = hits + misses

        if total == 0:
            return 0.0

        return (hits / total) * 100

    def clear_cache(self):
        """Clear all cached entries.

        Uses FLUSHDB, which deletes *every* key in the selected Redis database,
        including the stats hash and any keys not written by this cache.
        """
        self.redis_client.flushdb()

    def clear_old_entries(self, max_age_seconds: int) -> int:
        """Delete cache entries older than ``max_age_seconds`` and return how many.

        Redis already expires entries after ``self.ttl``; this applies a shorter
        cut-off. An entry's age is inferred as ``self.ttl - remaining TTL``, which
        is only meaningful for keys written by set() with this instance's ttl.
        Only keys shaped like generate_cache_key() output are examined, so the
        stats hash and unrelated keys sharing the database are never deleted.
        """

        deleted = 0

        for key in self.redis_client.scan_iter(count=100):
            if not self._is_cache_key(key):
                continue

            remaining = self.redis_client.ttl(key)
            # TTL returns -1 (no expiry) or -2 (key gone); those can't be aged.
            if remaining > 0 and (self.ttl - remaining) > max_age_seconds:
                self.redis_client.delete(key)
                deleted += 1

        return deleted

class CachedLLMWithRedis:
    """Hugging Face text-generation pipeline fronted by a RedisLLMCache."""

    def __init__(
        self,
        model_name: str,
        redis_host: str = "localhost",
        redis_port: int = 6379
    ):
        """Initialize LLM with Redis caching

        Args:
            model_name: Model passed to the pipeline; also part of every cache key.
            redis_host: Redis host for the response cache.
            redis_port: Redis port for the response cache.
        """

        self.model_name = model_name
        self.cache = RedisLLMCache(host=redis_host, port=redis_port)

        # Initialize model
        from transformers import pipeline
        self.generator = pipeline(
            "text-generation",
            model=model_name,
            device_map="auto"
        )

    def generate(
        self,
        prompt: str,
        max_length: int = 100,
        temperature: float = 0.7,
        top_p: float = 0.9
    ) -> str:
        """Generate with Redis caching

        Returns the cached text for an identical (model, prompt, params) request
        when present. Otherwise runs the pipeline and caches its ``generated_text``.
        """

        params = {
            "max_length": max_length,
            "temperature": temperature,
            "top_p": top_p
        }

        # Check cache. Compare against None so that a cached empty string still
        # counts as a hit (the cache has already recorded it as one).
        cached_response = self.cache.get(prompt, self.model_name, params)

        if cached_response is not None:
            print("Cache HIT (Redis)")
            return cached_response

        # Generate
        print("Cache MISS - Generating...")
        start = time.perf_counter()

        response = self.generator(
            prompt,
            max_length=max_length,
            temperature=temperature,
            top_p=top_p,
            num_return_sequences=1
        )[0]["generated_text"]

        generation_time = time.perf_counter() - start

        # Cache response
        self.cache.set(prompt, self.model_name, params, response)

        print(f"Generated in {generation_time:.2f}s")
        return response

    def get_cache_stats(self) -> Dict:
        """Return the underlying cache's hit/miss statistics."""
        return self.cache.get_stats()

# Example usage
# Needs a Redis server reachable at localhost:6379 (the default port).
redis_llm = CachedLLMWithRedis(
    model_name="gpt2",
    redis_host="localhost"
)

# First request - cache miss
# (a hit instead if an earlier run cached the same request within the TTL)
response1 = redis_llm.generate("What is machine learning?")

# Second request - cache hit
response2 = redis_llm.generate("What is machine learning?")

# Different params - cache miss
# (temperature is part of the cache key, so this is a separate entry)
response3 = redis_llm.generate(
    "What is machine learning?",
    temperature=0.5
)

# Get stats
# Counters are stored in Redis, so they accumulate across runs sharing the DB.
stats = redis_llm.get_cache_stats()
print(f"\nCache Statistics:")
print(f"Hits: {stats['hits']}")
print(f"Misses: {stats['misses']}")
print(f"Hit Rate: {stats['hit_rate']:.2f}%")

Semantic Caching

python
from sentence_transformers import SentenceTransformer
import numpy as np

class SemanticCache:
    """In-memory cache that returns a stored response for a *similar* prompt.

    Prompts are embedded with a SentenceTransformer model. A lookup returns the
    response whose stored prompt has the highest cosine similarity to the query,
    provided that similarity reaches ``similarity_threshold``. Lookups scan every
    entry (linear in cache size), and the cache is unbounded.
    """

    def __init__(
        self,
        similarity_threshold: float = 0.95,
        model_name: str = "all-MiniLM-L6-v2"
    ):
        """
        Semantic cache that matches similar (not just identical) prompts

        Args:
            similarity_threshold: Minimum cosine similarity for cache hit
            model_name: Sentence embedding model
        """
        self.similarity_threshold = similarity_threshold
        self.encoder = SentenceTransformer(model_name)

        # Each entry: {"prompt", "response", "embedding", "timestamp"}
        self.cache: List[Dict[str, Any]] = []

    def compute_embedding(self, text: str) -> np.ndarray:
        """Return the encoder's embedding of ``text`` as a numpy array."""
        return self.encoder.encode(text, convert_to_numpy=True)

    def cosine_similarity(
        self,
        embedding1: np.ndarray,
        embedding2: np.ndarray
    ) -> float:
        """Return the cosine similarity of two vectors (0.0 if either is all zeros)."""
        denominator = np.linalg.norm(embedding1) * np.linalg.norm(embedding2)
        if denominator == 0:
            # Undefined for a zero vector: treat it as "not similar" rather
            # than returning NaN with a RuntimeWarning.
            return 0.0
        return float(np.dot(embedding1, embedding2) / denominator)

    def get(self, prompt: str) -> Optional[Tuple[str, float]]:
        """Return ``(response, similarity)`` for the most similar cached prompt, or None.

        None is returned when the cache is empty or when the best similarity is
        below ``similarity_threshold``.
        """

        if not self.cache:
            return None

        # Compute prompt embedding
        prompt_embedding = self.compute_embedding(prompt)

        # Score every entry and keep the best; max() keeps the first entry on ties.
        # best_match is always a real entry, even when all similarities are <= 0.
        best_similarity, best_match = max(
            (
                (self.cosine_similarity(prompt_embedding, entry["embedding"]), entry)
                for entry in self.cache
            ),
            key=lambda scored: scored[0]
        )

        # Return if above threshold
        if best_similarity >= self.similarity_threshold:
            print(f"Semantic cache HIT (similarity: {best_similarity:.3f})")
            return best_match["response"], best_similarity

        print("Semantic cache MISS")
        return None

    def set(self, prompt: str, response: str):
        """Append a prompt/response pair, with its embedding and insertion time."""

        embedding = self.compute_embedding(prompt)

        self.cache.append({
            "prompt": prompt,
            "response": response,
            "embedding": embedding,
            "timestamp": datetime.now()
        })

# Example usage
# A lookup hits when its cosine similarity to a stored prompt is >= 0.90.
semantic_cache = SemanticCache(similarity_threshold=0.90)

# Cache original
semantic_cache.set(
    "What is artificial intelligence?",
    "AI is the simulation of human intelligence by machines..."
)

# Try similar prompt
# NOTE(review): whether this paraphrase clears 0.90 depends on the embedding
# model. If it doesn't, get() prints a MISS, returns None, and the branch below is skipped.
result = semantic_cache.get("Can you explain what AI is?")
if result:
    response, similarity = result
    print(f"Found similar prompt: {response}")

# Try different prompt
result = semantic_cache.get("What is quantum computing?")
# This should miss (expected to score below the 0.90 threshold; not guaranteed)

Production Caching Strategy

python
class ProductionCacheStrategy:
    """Layered response lookup: exact match in Redis, then semantic match, then miss."""

    def __init__(self, redis_cache=None, semantic_cache=None):
        """
        Args:
            redis_cache: Exact-match cache to use; a new ``RedisLLMCache()`` if omitted.
            semantic_cache: Similarity cache to use; a new ``SemanticCache()`` if omitted.
        """
        self.redis_cache = redis_cache if redis_cache is not None else RedisLLMCache()
        self.semantic_cache = (
            semantic_cache if semantic_cache is not None else SemanticCache()
        )
        self.kv_cache_manager = None  # Set when model is loaded

    def multi_level_lookup(
        self,
        prompt: str,
        model_name: str,
        params: Dict
    ) -> Optional[str]:
        """
        Multi-level cache lookup:
        1. Exact match (Redis)
        2. Semantic match (embeddings)
        3. KV cache (for prefixes): not consulted here; a miss falls through to generation

        Returns the cached response, or None on a miss.

        NOTE(review): the semantic cache is keyed only by prompt text, so a
        level-2 hit may return a response produced by a different model or with
        different params. Keep one SemanticCache per (model_name, params) if that
        matters.
        """

        # Level 1: Exact match in Redis (compare to None: "" is a valid cached response)
        exact_match = self.redis_cache.get(prompt, model_name, params)
        if exact_match is not None:
            print("L1 Cache HIT (exact)")
            return exact_match

        # Level 2: Semantic match
        semantic_match = self.semantic_cache.get(prompt)
        if semantic_match is not None:
            response, similarity = semantic_match
            print(f"L2 Cache HIT (semantic, similarity: {similarity:.3f})")

            # Promote to Redis so the next identical request is an L1 hit
            self.redis_cache.set(prompt, model_name, params, response)
            return response

        # Level 3: KV cache will be used during generation
        print("Cache MISS - Will use KV cache during generation")
        return None

    def cache_response(
        self,
        prompt: str,
        response: str,
        model_name: str,
        params: Dict
    ):
        """Store a freshly generated response in both the exact and semantic caches."""

        # Cache in Redis
        self.redis_cache.set(prompt, model_name, params, response)

        # Cache in semantic cache
        self.semantic_cache.set(prompt, response)

Quiz

Test your understanding of caching strategies:

Summary

In this lesson, you learned:

  • Cache layers: Response cache, KV cache, and semantic cache
  • Prefix caching: Optimize repeated system prompts and few-shot examples
  • KV cache management: Sliding windows and memory optimization
  • Redis caching: Distributed caching for production systems
  • Semantic caching: Match similar prompts using embeddings

Depending on how repetitive your traffic is (that is, your cache hit rate), effective caching can reduce inference costs by 50-90% in production LLM applications while also improving response times.