Caching Strategies for LLMs
Learn advanced caching techniques to dramatically reduce inference costs and latency for LLM applications.
What You'll Learn: Caching is one of the most effective ways to optimize LLM applications. We'll explore response caching, prompt/prefix caching, KV cache management, semantic caching, and distributed caching with Redis.
Understanding LLM Caching
Cache Types
from typing import Any, Dict, List, Optional, Tuple
import hashlib
import time
from dataclasses import dataclass
from datetime import datetime, timedelta
@dataclass
class CacheEntry:
key: str
value: str
created_at: datetime
hit_count: int = 0
last_accessed: Optional[datetime] = None
def __post_init__(self):
if self.last_accessed is None:
self.last_accessed = self.created_at
class LLMCachingLayers:
"""Demonstrates different caching layers"""
def __init__(self):
# Layer 1: Response cache (full prompt + response)
self.response_cache: Dict[str, CacheEntry] = {}
# Layer 2: KV cache (prefix cache for transformers)
self.kv_cache: Dict[str, Any] = {}
# Layer 3: Semantic cache (similar prompts)
self.semantic_cache: Dict[str, List[Tuple[str, str]]] = {}
def compute_cache_key(self, prompt: str, model_params: Dict) -> str:
"""Generate cache key from prompt and parameters"""
# Include model parameters in key
key_string = f"{prompt}|{sorted(model_params.items())}"
return hashlib.sha256(key_string.encode()).hexdigest()
def get_response_cache(
self,
prompt: str,
model_params: Dict
) -> Optional[str]:
"""Get cached response"""
key = self.compute_cache_key(prompt, model_params)
if key in self.response_cache:
entry = self.response_cache[key]
entry.hit_count += 1
entry.last_accessed = datetime.now()
print(f"Cache HIT! (hits: {entry.hit_count})")
return entry.value
print("Cache MISS")
return None
def set_response_cache(
self,
prompt: str,
response: str,
model_params: Dict
):
"""Store response in cache"""
key = self.compute_cache_key(prompt, model_params)
self.response_cache[key] = CacheEntry(
key=key,
value=response,
created_at=datetime.now()
)
def demonstrate_caching_impact(self):
"""Demonstrate cache impact on performance"""
model_params = {"temperature": 0.7, "max_tokens": 100}
# First request (cache miss)
start = time.time()
result1 = self.get_response_cache("What is AI?", model_params)
if result1 is None:
# Simulate LLM inference (expensive)
time.sleep(2)
result1 = "AI is artificial intelligence..."
self.set_response_cache("What is AI?", result1, model_params)
time1 = time.time() - start
# Second request (cache hit)
start = time.time()
result2 = self.get_response_cache("What is AI?", model_params)
time2 = time.time() - start
print(f"\nFirst request: {time1:.3f}s")
print(f"Second request: {time2:.3f}s")
print(f"Speedup: {time1/time2:.1f}x")
# Example usage
cache = LLMCachingLayers()
cache.demonstrate_caching_impact()
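The created_at and last_accessed fields on CacheEntry exist so entries can be expired and evicted rather than kept forever. A minimal sketch of how a TTL and a size limit could be layered onto the response cache (the ttl_seconds and max_entries parameters are illustrative, not part of the class above):

def evict_expired_and_lru(
    cache: Dict[str, CacheEntry],
    ttl_seconds: int = 3600,
    max_entries: int = 10_000
) -> int:
    """Drop entries older than ttl_seconds, then trim to max_entries by LRU."""
    now = datetime.now()
    removed = 0
    # 1. TTL expiry based on created_at
    expired = [k for k, e in cache.items()
               if (now - e.created_at).total_seconds() > ttl_seconds]
    for key in expired:
        del cache[key]
        removed += 1
    # 2. LRU trim based on last_accessed
    while len(cache) > max_entries:
        lru_key = min(cache, key=lambda k: cache[k].last_accessed)
        del cache[lru_key]
        removed += 1
    return removed

# Usage sketch: evict_expired_and_lru(cache.response_cache, ttl_seconds=600)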
Prompt Caching
Prompt Caching: Cache common prompt prefixes to avoid recomputing attention for repeated content. This is especially effective for system prompts and few-shot examples.
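To see why this pays off, it helps to count how much of each request is actually a shared prefix. A quick sketch (the model name, prompts, and few-shot example below are purely illustrative):

from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")  # any tokenizer works here

shared_prefix = (
    "You are a helpful assistant.\n"
    "Example: Q: What is 2 + 2? A: 4\n"
)
requests = ["What is the capital of France?", "Summarize the paragraph above."]

prefix_len = len(tokenizer(shared_prefix)["input_ids"])
for req in requests:
    total_len = len(tokenizer(shared_prefix + req)["input_ids"])
    print(f"{prefix_len}/{total_len} prompt tokens are reusable prefix")

When the shared prefix dominates the prompt, caching its attention states removes most of the prefill cost for every request after the first.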
Prefix Caching Implementation
import hashlib
import time
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from typing import Dict, Optional, Tuple
class PrefixCache:
def __init__(self, max_size: int = 100):
"""
Prefix cache for transformer KV states
Args:
max_size: Maximum number of cached prefixes
"""
self.cache: Dict[str, Tuple[Tuple[torch.Tensor, torch.Tensor], ...]] = {}
self.max_size = max_size
self.access_times: Dict[str, float] = {}
def get_cache_key(self, prefix: str) -> str:
"""Generate cache key for prefix"""
return hashlib.sha256(prefix.encode()).hexdigest()
def get(self, prefix: str) -> Optional[Tuple[Tuple[torch.Tensor, torch.Tensor], ...]]:
"""Get cached KV states for prefix"""
key = self.get_cache_key(prefix)
if key in self.cache:
self.access_times[key] = time.time()
return self.cache[key]
return None
def set(
self,
prefix: str,
past_key_values: Tuple[Tuple[torch.Tensor, torch.Tensor], ...]
):
"""Cache KV states for prefix"""
# Evict oldest if cache is full
if len(self.cache) >= self.max_size:
oldest_key = min(self.access_times, key=self.access_times.get)
del self.cache[oldest_key]
del self.access_times[oldest_key]
key = self.get_cache_key(prefix)
self.cache[key] = past_key_values
self.access_times[key] = time.time()
def clear(self):
"""Clear cache"""
self.cache.clear()
self.access_times.clear()
class CachedLLM:
def __init__(self, model_name: str):
"""Initialize model with prefix caching"""
self.model = AutoModelForCausalLM.from_pretrained(
model_name,
device_map="auto",
torch_dtype=torch.float16
)
self.tokenizer = AutoTokenizer.from_pretrained(model_name)
self.prefix_cache = PrefixCache(max_size=50)
def generate_with_cache(
self,
prompt: str,
system_prompt: str = "",
max_new_tokens: int = 100
) -> str:
"""
Generate with prefix caching
Args:
prompt: User prompt
system_prompt: System prompt (will be cached)
max_new_tokens: Maximum tokens to generate
"""
# Try to get cached KV states for system prompt
past_key_values = None
prefix_tokens = None
if system_prompt:
past_key_values = self.prefix_cache.get(system_prompt)
if past_key_values:
print("Using cached system prompt!")
# Tokenize only the user prompt
prefix_tokens = self.tokenizer(
system_prompt,
return_tensors="pt"
)["input_ids"]
user_tokens = self.tokenizer(
prompt,
return_tensors="pt"
)["input_ids"]
input_ids = torch.cat([prefix_tokens, user_tokens], dim=1)
else:
print("System prompt not cached, computing...")
# Tokenize full prompt
full_prompt = system_prompt + prompt
input_ids = self.tokenizer(
full_prompt,
return_tensors="pt"
)["input_ids"]
# Compute KV states for system prompt
prefix_tokens = self.tokenizer(
system_prompt,
return_tensors="pt"
)["input_ids"]
with torch.no_grad():
outputs = self.model(
prefix_tokens.to(self.model.device),
use_cache=True
)
past_key_values = outputs.past_key_values
# Newer transformers versions return a mutable Cache object here; converting
# to the legacy tuple format (when available) keeps the cached prefix from
# being modified in place by later generate() calls.
if hasattr(past_key_values, "to_legacy_cache"):
past_key_values = past_key_values.to_legacy_cache()
# Cache the KV states
self.prefix_cache.set(system_prompt, past_key_values)
else:
input_ids = self.tokenizer(prompt, return_tensors="pt")["input_ids"]
# Generate, reusing cached KV states when available
input_ids = input_ids.to(self.model.device)
# The attention mask must cover the cached prefix plus the new tokens,
# so build it from the full sequence before slicing off the prefix
attention_mask = torch.ones_like(input_ids)
# If we have cached KV states, feed only the new (user) tokens
if past_key_values and prefix_tokens is not None:
# Note: this slice assumes tokenization splits cleanly at the prefix boundary
input_ids = input_ids[:, prefix_tokens.shape[1]:]
with torch.no_grad():
outputs = self.model.generate(
input_ids,
attention_mask=attention_mask,
past_key_values=past_key_values,
max_new_tokens=max_new_tokens,
do_sample=True,
temperature=0.7,
pad_token_id=self.tokenizer.eos_token_id
)
return self.tokenizer.decode(outputs[0], skip_special_tokens=True)
def benchmark_caching(self):
"""Benchmark prefix caching benefits"""
system_prompt = """You are a helpful AI assistant. You are an expert in
mathematics, science, and programming. You provide clear, accurate,
and detailed explanations. """
prompts = [
"What is calculus?",
"Explain quantum physics.",
"How does a computer work?",
"What is machine learning?"
]
# First run: cache cleared, so the first prompt pays the full prefix cost
self.prefix_cache.clear()
start = time.time()
for prompt in prompts:
_ = self.generate_with_cache(prompt, system_prompt, max_new_tokens=50)
time_no_cache = time.time() - start
# Second run: the system-prompt KV states are already cached
start = time.time()
for prompt in prompts:
_ = self.generate_with_cache(prompt, system_prompt, max_new_tokens=50)
time_with_cache = time.time() - start
print(f"\nCold run: {time_no_cache:.2f}s")
print(f"Warm run: {time_with_cache:.2f}s")
print(f"Speedup: {time_no_cache/time_with_cache:.2f}x")
# Example usage
cached_llm = CachedLLM("gpt2")
system_prompt = "You are a helpful assistant. "
# First call - cache miss
response1 = cached_llm.generate_with_cache(
"What is Python?",
system_prompt=system_prompt
)
# Second call - cache hit
response2 = cached_llm.generate_with_cache(
"What is JavaScript?",
system_prompt=system_prompt
)
# Benchmark
cached_llm.benchmark_caching()
KV Cache Management
KV Cache: Transformers cache key and value tensors during generation to avoid recomputing attention for previous tokens. Managing this cache is crucial for memory efficiency.
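Before managing the cache, it helps to see it in action. A minimal sketch of step-by-step decoding with past_key_values (GPT-2 is assumed purely for illustration): each forward pass feeds only one new token and reuses the keys and values from all previous steps.

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

tok = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2")
model.eval()

input_ids = tok("The KV cache stores", return_tensors="pt")["input_ids"]
past = None
with torch.no_grad():
    for _ in range(5):
        out = model(input_ids, past_key_values=past, use_cache=True)
        next_token = out.logits[:, -1, :].argmax(dim=-1, keepdim=True)
        past = out.past_key_values  # keys/values for everything seen so far
        input_ids = next_token      # only the new token is fed next time
        print(tok.decode(next_token[0]), end="")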
Advanced KV Cache Strategies
class KVCacheManager:
def __init__(self, model, tokenizer):
self.model = model
self.tokenizer = tokenizer
def estimate_kv_cache_size(
self,
batch_size: int,
sequence_length: int,
num_layers: int = None,
hidden_size: int = None,
num_heads: int = None
) -> Dict[str, float]:
"""
Estimate KV cache memory usage
Formula: 2 * batch_size * num_layers * num_heads * seq_len * (hidden_size / num_heads) * bytes_per_element
"""
if num_layers is None:
num_layers = self.model.config.num_hidden_layers
if hidden_size is None:
hidden_size = self.model.config.hidden_size
if num_heads is None:
num_heads = self.model.config.num_attention_heads
head_dim = hidden_size // num_heads
bytes_per_element = 2 # FP16
# KV cache size (keys + values)
cache_size_bytes = (
2 * batch_size * num_layers * num_heads *
sequence_length * head_dim * bytes_per_element
)
return {
"total_bytes": cache_size_bytes,
"total_mb": cache_size_bytes / (1024 ** 2),
"total_gb": cache_size_bytes / (1024 ** 3),
"per_token_mb": cache_size_bytes / sequence_length / (1024 ** 2)
}
def sliding_window_generation(
self,
prompt: str,
max_new_tokens: int = 100,
window_size: int = 512
) -> str:
"""
Generate with sliding window KV cache
Keeps only the last `window_size` tokens in KV cache
to limit memory usage for long generations.
"""
input_ids = self.tokenizer(prompt, return_tensors="pt")["input_ids"]
input_ids = input_ids.to(self.model.device)
generated_tokens = []
past_key_values = None
for _ in range(max_new_tokens):
# If the cached sequence is too long, trim it to the window.
# This assumes the legacy tuple KV format (one (key, value) pair per layer);
# newer transformers Cache objects need different handling.
if past_key_values is not None:
# Keys/values have shape (batch, num_heads, seq_len, head_dim)
cached_len = past_key_values[0][0].shape[2]
if cached_len > window_size:
# Keep only the last window_size positions in every layer.
# Note: position ids are not adjusted, so models with absolute position
# embeddings (e.g. GPT-2) still hit their maximum position limit.
trimmed_past = []
for layer_past in past_key_values:
trimmed_layer = tuple(
p[:, :, -window_size:, :] for p in layer_past
)
trimmed_past.append(trimmed_layer)
past_key_values = tuple(trimmed_past)
# Adjust input_ids (use only last token)
input_ids = input_ids[:, -1:]
# Forward pass
with torch.no_grad():
outputs = self.model(
input_ids,
past_key_values=past_key_values,
use_cache=True
)
# Get next token
next_token = outputs.logits[:, -1, :].argmax(dim=-1, keepdim=True)
generated_tokens.append(next_token.item())
# Update for next iteration
input_ids = next_token
past_key_values = outputs.past_key_values
# Stop if EOS
if next_token.item() == self.tokenizer.eos_token_id:
break
# Decode
full_output = self.tokenizer.decode(generated_tokens, skip_special_tokens=True)
return full_output
def demonstrate_memory_savings(self):
"""Demonstrate memory savings from KV cache management"""
# Long sequence scenario
batch_size = 1
seq_length = 2048
# Without windowing
full_cache = self.estimate_kv_cache_size(batch_size, seq_length)
# With windowing (512 tokens)
windowed_cache = self.estimate_kv_cache_size(batch_size, 512)
print("KV Cache Memory Usage:")
print(f"Full cache ({seq_length} tokens): {full_cache['total_mb']:.2f} MB")
print(f"Windowed cache (512 tokens): {windowed_cache['total_mb']:.2f} MB")
print(f"Memory saved: {(1 - windowed_cache['total_mb']/full_cache['total_mb'])*100:.1f}%")
# Example usage
kv_manager = KVCacheManager(cached_llm.model, cached_llm.tokenizer)
# Estimate cache size
cache_info = kv_manager.estimate_kv_cache_size(
batch_size=4,
sequence_length=1024
)
print(f"Estimated KV cache: {cache_info['total_mb']:.2f} MB")
# Sliding window generation
response = kv_manager.sliding_window_generation(
"Write a long story about space exploration:",
max_new_tokens=200,
window_size=512
)
# Demonstrate savings
kv_manager.demonstrate_memory_savings()
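To put the estimate into perspective, here is a back-of-envelope calculation for an assumed Llama-2-7B-style configuration (32 layers, 32 heads, hidden size 4096, FP16). The numbers follow directly from the formula in estimate_kv_cache_size:

# 2 (keys + values) * batch * layers * heads * seq_len * head_dim * 2 bytes (FP16)
batch, layers, heads, hidden, seq_len = 1, 32, 32, 4096, 2048
head_dim = hidden // heads  # 128
kv_bytes = 2 * batch * layers * heads * seq_len * head_dim * 2
print(f"{kv_bytes / 1024**3:.2f} GB")  # ~1.0 GB for a single 2048-token sequence

At roughly 0.5 MB per token, the KV cache quickly becomes the dominant memory cost for long contexts and large batch sizes, which is why windowing and eviction matter.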
Redis Caching for Distributed Systems
Redis Caching: For distributed applications, use Redis to share cached responses across multiple instances and reduce overall inference costs.
Redis Cache Implementation
import hashlib
import json
import pickle
import time
from typing import Any, Dict, Optional
import redis
class RedisLLMCache:
def __init__(
self,
host: str = "localhost",
port: int = 6379,
db: int = 0,
ttl: int = 3600 # 1 hour default TTL
):
"""
Initialize Redis cache for LLM responses
Args:
host: Redis host
port: Redis port
db: Redis database number
ttl: Time to live in seconds
"""
self.redis_client = redis.Redis(
host=host,
port=port,
db=db,
decode_responses=False # We'll use pickle for complex objects
)
self.ttl = ttl
def generate_cache_key(
self,
prompt: str,
model_name: str,
params: Dict[str, Any]
) -> str:
"""Generate unique cache key"""
# Sort params for consistency
param_str = json.dumps(params, sort_keys=True)
key_content = f"{model_name}:{prompt}:{param_str}"
return hashlib.sha256(key_content.encode()).hexdigest()
def get(
self,
prompt: str,
model_name: str,
params: Dict[str, Any]
) -> Optional[str]:
"""Get cached response"""
key = self.generate_cache_key(prompt, model_name, params)
cached_value = self.redis_client.get(key)
if cached_value:
# Increment hit counter
self.redis_client.hincrby("llm_cache_stats", "hits", 1)
# Deserialize
return pickle.loads(cached_value)
# Increment miss counter
self.redis_client.hincrby("llm_cache_stats", "misses", 1)
return None
def set(
self,
prompt: str,
model_name: str,
params: Dict[str, Any],
response: str
):
"""Cache response"""
key = self.generate_cache_key(prompt, model_name, params)
# Serialize and store with TTL
self.redis_client.setex(
key,
self.ttl,
pickle.dumps(response)
)
def get_stats(self) -> Dict[str, Any]:
"""Get cache statistics"""
stats = self.redis_client.hgetall("llm_cache_stats")
return {
"hits": int(stats.get(b"hits", 0)),
"misses": int(stats.get(b"misses", 0)),
"hit_rate": self._calculate_hit_rate(stats)
}
def _calculate_hit_rate(self, stats: Dict) -> float:
"""Calculate cache hit rate"""
hits = int(stats.get(b"hits", 0))
misses = int(stats.get(b"misses", 0))
total = hits + misses
if total == 0:
return 0.0
return (hits / total) * 100
def clear_cache(self):
"""Clear all cached entries"""
self.redis_client.flushdb()
def clear_old_entries(self, max_age_seconds: int):
"""Clear entries older than max_age_seconds"""
# Redis TTL handles this automatically, but we can manually scan
cursor = 0
deleted = 0
while True:
cursor, keys = self.redis_client.scan(
cursor,
count=100
)
for key in keys:
if key != b"llm_cache_stats":
ttl = self.redis_client.ttl(key)
if ttl > 0 and (self.ttl - ttl) > max_age_seconds:
self.redis_client.delete(key)
deleted += 1
if cursor == 0:
break
return deleted
class CachedLLMWithRedis:
def __init__(
self,
model_name: str,
redis_host: str = "localhost",
redis_port: int = 6379
):
"""Initialize LLM with Redis caching"""
self.model_name = model_name
self.cache = RedisLLMCache(host=redis_host, port=redis_port)
# Initialize model
from transformers import pipeline
self.generator = pipeline(
"text-generation",
model=model_name,
device_map="auto"
)
def generate(
self,
prompt: str,
max_length: int = 100,
temperature: float = 0.7,
top_p: float = 0.9
) -> str:
"""Generate with Redis caching"""
params = {
"max_length": max_length,
"temperature": temperature,
"top_p": top_p
}
# Check cache
cached_response = self.cache.get(prompt, self.model_name, params)
if cached_response:
print("Cache HIT (Redis)")
return cached_response
# Generate
print("Cache MISS - Generating...")
start = time.time()
response = self.generator(
prompt,
max_length=max_length,
do_sample=True,
temperature=temperature,
top_p=top_p,
num_return_sequences=1
)[0]["generated_text"]
generation_time = time.time() - start
# Cache response
self.cache.set(prompt, self.model_name, params, response)
print(f"Generated in {generation_time:.2f}s")
return response
def get_cache_stats(self) -> Dict:
"""Get cache statistics"""
return self.cache.get_stats()
# Example usage
redis_llm = CachedLLMWithRedis(
model_name="gpt2",
redis_host="localhost"
)
# First request - cache miss
response1 = redis_llm.generate("What is machine learning?")
# Second request - cache hit
response2 = redis_llm.generate("What is machine learning?")
# Different params - cache miss
response3 = redis_llm.generate(
"What is machine learning?",
temperature=0.5
)
# Get stats
stats = redis_llm.get_cache_stats()
print(f"\nCache Statistics:")
print(f"Hits: {stats['hits']}")
print(f"Misses: {stats['misses']}")
print(f"Hit Rate: {stats['hit_rate']:.2f}%")
Semantic Caching
from datetime import datetime
from typing import Any, Dict, List, Optional, Tuple
import numpy as np
from sentence_transformers import SentenceTransformer
class SemanticCache:
def __init__(
self,
similarity_threshold: float = 0.95,
model_name: str = "all-MiniLM-L6-v2"
):
"""
Semantic cache that matches similar (not just identical) prompts
Args:
similarity_threshold: Minimum cosine similarity for cache hit
model_name: Sentence embedding model
"""
self.similarity_threshold = similarity_threshold
self.encoder = SentenceTransformer(model_name)
self.cache: List[Dict[str, Any]] = []
def compute_embedding(self, text: str) -> np.ndarray:
"""Compute text embedding"""
return self.encoder.encode(text, convert_to_numpy=True)
def cosine_similarity(
self,
embedding1: np.ndarray,
embedding2: np.ndarray
) -> float:
"""Calculate cosine similarity"""
return np.dot(embedding1, embedding2) / (
np.linalg.norm(embedding1) * np.linalg.norm(embedding2)
)
def get(self, prompt: str) -> Optional[Tuple[str, float]]:
"""Get cached response for similar prompt"""
if not self.cache:
return None
# Compute prompt embedding
prompt_embedding = self.compute_embedding(prompt)
# Find most similar cached prompt
best_match = None
best_similarity = 0.0
for entry in self.cache:
similarity = self.cosine_similarity(
prompt_embedding,
entry["embedding"]
)
if similarity > best_similarity:
best_similarity = similarity
best_match = entry
# Return if above threshold
if best_similarity >= self.similarity_threshold:
print(f"Semantic cache HIT (similarity: {best_similarity:.3f})")
return best_match["response"], best_similarity
print("Semantic cache MISS")
return None
def set(self, prompt: str, response: str):
"""Cache prompt-response pair"""
embedding = self.compute_embedding(prompt)
self.cache.append({
"prompt": prompt,
"response": response,
"embedding": embedding,
"timestamp": datetime.now()
})
# Example usage
semantic_cache = SemanticCache(similarity_threshold=0.90)
# Cache original
semantic_cache.set(
"What is artificial intelligence?",
"AI is the simulation of human intelligence by machines..."
)
# Try similar prompt
result = semantic_cache.get("Can you explain what AI is?")
if result:
response, similarity = result
print(f"Found similar prompt: {response}")
# Try different prompt
result = semantic_cache.get("What is quantum computing?")
# This should miss
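The linear scan in SemanticCache.get computes one similarity per cached entry, which is fine for a few hundred entries but slow beyond that. A sketch of a vectorized lookup that stacks normalized embeddings into a matrix so one product yields every similarity at once (this subclass and its internals are illustrative, not part of the class above):

class VectorizedSemanticCache(SemanticCache):
    def get(self, prompt: str) -> Optional[Tuple[str, float]]:
        if not self.cache:
            return None
        query = self.compute_embedding(prompt)
        query = query / np.linalg.norm(query)
        # Stack and normalize all cached embeddings; a single matrix-vector
        # product then gives the cosine similarity to every entry.
        matrix = np.stack([entry["embedding"] for entry in self.cache])
        matrix = matrix / np.linalg.norm(matrix, axis=1, keepdims=True)
        sims = matrix @ query
        best = int(np.argmax(sims))
        if sims[best] >= self.similarity_threshold:
            return self.cache[best]["response"], float(sims[best])
        return None

For caches with many thousands of entries, an approximate-nearest-neighbour index (for example FAISS) is the usual next step.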
Production Caching Strategy
class ProductionCacheStrategy:
def __init__(self):
self.redis_cache = RedisLLMCache()
self.semantic_cache = SemanticCache()
self.kv_cache_manager = None # Set when model is loaded
def multi_level_lookup(
self,
prompt: str,
model_name: str,
params: Dict
) -> Optional[str]:
"""
Multi-level cache lookup:
1. Exact match (Redis)
2. Semantic match (embeddings)
3. KV cache (for prefixes)
"""
# Level 1: Exact match in Redis
exact_match = self.redis_cache.get(prompt, model_name, params)
if exact_match:
print("L1 Cache HIT (exact)")
return exact_match
# Level 2: Semantic match
semantic_match = self.semantic_cache.get(prompt)
if semantic_match:
response, similarity = semantic_match
print(f"L2 Cache HIT (semantic, similarity: {similarity:.3f})")
# Also cache in Redis for faster future lookups
self.redis_cache.set(prompt, model_name, params, response)
return response
# Level 3: KV cache will be used during generation
print("Cache MISS - Will use KV cache during generation")
return None
def cache_response(
self,
prompt: str,
response: str,
model_name: str,
params: Dict
):
"""Cache at all levels"""
# Cache in Redis
self.redis_cache.set(prompt, model_name, params, response)
# Cache in semantic cache
self.semantic_cache.set(prompt, response)
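A sketch of how the strategy could be wired into request handling; generate_fn and the model name "my-model" are placeholders for whatever inference call and identifier your application uses, and a running Redis instance is assumed:

# Example usage (generate_fn is a placeholder for your inference call)
strategy = ProductionCacheStrategy()
params = {"temperature": 0.7, "max_tokens": 200}

def answer(prompt: str, generate_fn) -> str:
    cached = strategy.multi_level_lookup(prompt, "my-model", params)
    if cached is not None:
        return cached
    response = generate_fn(prompt)  # cache miss: run inference
    strategy.cache_response(prompt, response, "my-model", params)
    return response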
Summary
In this lesson, you learned:
- Cache layers: Response cache, KV cache, and semantic cache
- Prefix caching: Optimize repeated system prompts and few-shot examples
- KV cache management: Sliding windows and memory optimization
- Redis caching: Distributed caching for production systems
- Semantic caching: Match similar prompts using embeddings
Effective caching can substantially reduce inference costs and response times in production LLM applications; for workloads with heavy prompt repetition, savings in the 50-90% range are commonly reported.