
Model Serving with vLLM and TGI

Deploy LLMs at scale using vLLM and Text Generation Inference with advanced batching and caching

30 min read · serving · vLLM · TGI · deployment

Learn to deploy LLMs at scale using state-of-the-art serving frameworks that maximize throughput and minimize latency.

What You'll Learn: vLLM and TGI are specialized inference servers that dramatically improve LLM serving performance through techniques like continuous batching, PagedAttention, and optimized CUDA kernels.

Introduction to Model Serving

Serving Challenges

A naive server processes requests strictly one after another: while it waits for each request to finish generating, most of the GPU's parallel capacity goes unused, so throughput collapses as traffic grows. The sketch below makes the problem concrete:

python
import time
import asyncio
from typing import List, Dict
from dataclasses import dataclass
from datetime import datetime

@dataclass
class InferenceRequest:
    id: str
    prompt: str
    max_tokens: int
    timestamp: datetime

class NaiveInferenceServer:
    """Demonstrates why naive serving is inefficient"""

    def __init__(self, model, tokenizer):
        self.model = model
        self.tokenizer = tokenizer
        self.request_queue = []

    def process_request(self, request: InferenceRequest):
        """Process single request (blocking)"""
        start = time.time()

        inputs = self.tokenizer(request.prompt, return_tensors="pt").to(self.model.device)

        outputs = self.model.generate(
            **inputs,
            max_new_tokens=request.max_tokens,
            do_sample=False
        )

        text = self.tokenizer.decode(outputs[0], skip_special_tokens=True)
        latency = time.time() - start

        return {
            "request_id": request.id,
            "output": text,
            "latency": latency
        }

    def benchmark_naive_serving(self, num_requests: int = 10):
        """Benchmark naive sequential processing"""

        requests = [
            InferenceRequest(
                id=f"req_{i}",
                prompt=f"Explain topic {i} in detail:",
                max_tokens=100,
                timestamp=datetime.now()
            )
            for i in range(num_requests)
        ]

        start = time.time()
        results = []

        for req in requests:
            result = self.process_request(req)
            results.append(result)

        total_time = time.time() - start

        # Calculate metrics
        avg_latency = sum(r["latency"] for r in results) / len(results)
        throughput = num_requests / total_time

        return {
            "total_time": total_time,
            "avg_latency": avg_latency,
            "throughput": throughput,
            "requests_per_second": throughput
        }

# The problem: Sequential processing = low throughput
# GPU sits idle while waiting for each request to complete
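
For comparison with the optimized servers below, this is roughly how the naive benchmark would be run (a minimal sketch assuming a small Hugging Face model such as gpt2; any causal LM works):

python
# Hypothetical usage of the naive server above
from transformers import AutoModelForCausalLM, AutoTokenizer

model = AutoModelForCausalLM.from_pretrained("gpt2", device_map="auto")
tokenizer = AutoTokenizer.from_pretrained("gpt2")

naive_server = NaiveInferenceServer(model, tokenizer)
metrics = naive_server.benchmark_naive_serving(num_requests=10)
print(f"Throughput: {metrics['throughput']:.2f} req/s "
      f"(avg latency {metrics['avg_latency']:.2f}s)")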

vLLM: High-Throughput Inference

vLLM uses PagedAttention to manage KV cache efficiently, enabling continuous batching and dramatically higher throughput than naive implementations.

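To build intuition for why paging the KV cache helps, here is a toy sketch of block-based allocation (illustrative only, not vLLM's actual code): each sequence maps token positions to small fixed-size physical blocks that are claimed lazily and returned the moment the request finishes, instead of reserving one large contiguous buffer per request.

python
# Toy illustration of a paged KV cache (not vLLM's real implementation)
BLOCK_SIZE = 16  # tokens per physical block

class PagedKVCache:
    def __init__(self, num_blocks: int):
        self.free_blocks = list(range(num_blocks))   # pool of physical blocks
        self.block_tables = {}                       # seq_id -> [physical block ids]

    def append_token(self, seq_id: str, position: int):
        """Claim a new block only when a sequence crosses a block boundary."""
        table = self.block_tables.setdefault(seq_id, [])
        if position // BLOCK_SIZE >= len(table):
            if not self.free_blocks:
                raise MemoryError("KV cache exhausted; request must wait")
            table.append(self.free_blocks.pop())

    def release(self, seq_id: str):
        """Return a finished sequence's blocks to the pool immediately."""
        self.free_blocks.extend(self.block_tables.pop(seq_id, []))

cache = PagedKVCache(num_blocks=8)
for pos in range(40):               # a 40-token sequence needs ceil(40 / 16) = 3 blocks
    cache.append_token("req_0", pos)
print(cache.block_tables["req_0"])
cache.release("req_0")              # freed blocks are instantly reusable by other requests

Because blocks are small and reclaimed as soon as a request completes, the scheduler can keep admitting new sequences without reserving worst-case memory up front, which is what makes continuous batching practical.
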
vLLM Setup and Basic Usage

python
from typing import Dict, List

from vllm import LLM, SamplingParams

class VLLMServer:
    def __init__(
        self,
        model_name: str,
        tensor_parallel_size: int = 1,
        gpu_memory_utilization: float = 0.9,
        max_model_len: int = 2048
    ):
        """
        Initialize vLLM server

        Args:
            model_name: HuggingFace model name or path
            tensor_parallel_size: Number of GPUs for tensor parallelism
            gpu_memory_utilization: GPU memory usage (0.0-1.0)
            max_model_len: Maximum sequence length
        """

        self.llm = LLM(
            model=model_name,
            tensor_parallel_size=tensor_parallel_size,
            gpu_memory_utilization=gpu_memory_utilization,
            max_model_len=max_model_len,
            trust_remote_code=True
        )

        print(f"vLLM server initialized with {model_name}")

    def generate(
        self,
        prompts: List[str],
        temperature: float = 0.7,
        top_p: float = 0.9,
        max_tokens: int = 100,
        n: int = 1
    ) -> List[Dict]:
        """
        Generate completions for multiple prompts

        Args:
            prompts: List of input prompts
            temperature: Sampling temperature
            top_p: Nucleus sampling parameter
            max_tokens: Maximum tokens to generate
            n: Number of completions per prompt
        """

        sampling_params = SamplingParams(
            temperature=temperature,
            top_p=top_p,
            max_tokens=max_tokens,
            n=n
        )

        # Generate (automatically batched)
        outputs = self.llm.generate(prompts, sampling_params)

        results = []
        for output in outputs:
            results.append({
                "prompt": output.prompt,
                "outputs": [o.text for o in output.outputs],
                "tokens_generated": sum(len(o.token_ids) for o in output.outputs)
            })

        return results

    def benchmark_throughput(self, num_requests: int = 100):
        """Benchmark vLLM throughput"""
        import time

        # Generate test prompts
        prompts = [
            f"Write a detailed explanation about topic {i}:"
            for i in range(num_requests)
        ]

        start = time.time()
        results = self.generate(prompts, max_tokens=100)
        elapsed = time.time() - start

        total_tokens = sum(r["tokens_generated"] for r in results)

        metrics = {
            "total_time": elapsed,
            "num_requests": num_requests,
            "throughput_requests_per_sec": num_requests / elapsed,
            "throughput_tokens_per_sec": total_tokens / elapsed,
            "avg_latency": elapsed / num_requests
        }

        return metrics

# Example usage
vllm_server = VLLMServer(
    model_name="meta-llama/Llama-2-7b-hf",
    tensor_parallel_size=1,
    gpu_memory_utilization=0.9
)

# Generate for multiple prompts (automatically batched)
prompts = [
    "Explain quantum computing:",
    "What is machine learning?",
    "Describe neural networks:"
]

results = vllm_server.generate(prompts, temperature=0.7, max_tokens=50)

for result in results:
    print(f"Prompt: {result['prompt']}")
    print(f"Output: {result['outputs'][0]}\n")

# Benchmark
metrics = vllm_server.benchmark_throughput(num_requests=50)
print(f"\nvLLM Benchmark:")
print(f"Throughput: {metrics['throughput_requests_per_sec']:.2f} req/s")
print(f"Token throughput: {metrics['throughput_tokens_per_sec']:.2f} tokens/s")

Async vLLM Server

python
import time

from fastapi import FastAPI, HTTPException
from fastapi.responses import StreamingResponse
from pydantic import BaseModel
from typing import List
import uvicorn

from vllm import SamplingParams
from vllm.engine.arg_utils import AsyncEngineArgs
from vllm.engine.async_llm_engine import AsyncLLMEngine

class GenerationRequest(BaseModel):
    prompt: str
    max_tokens: int = 100
    temperature: float = 0.7
    top_p: float = 0.9
    n: int = 1
    stream: bool = False

class GenerationResponse(BaseModel):
    text: str
    tokens_generated: int
    finish_reason: str

class AsyncVLLMServer:
    def __init__(self, model_name: str):
        # Initialize async engine
        engine_args = AsyncEngineArgs(
            model=model_name,
            tensor_parallel_size=1,
            gpu_memory_utilization=0.9
        )

        self.engine = AsyncLLMEngine.from_engine_args(engine_args)

    async def generate(
        self,
        prompt: str,
        sampling_params: SamplingParams
    ):
        """Async generation (returns only the completed outputs)"""

        request_id = f"req_{time.time()}"

        # The engine yields a partial RequestOutput after every decoding step,
        # so keep only the final one instead of accumulating duplicates
        final_output = None
        async for request_output in self.engine.generate(
            prompt,
            sampling_params,
            request_id
        ):
            final_output = request_output

        return [
            {
                "text": output.text,
                "tokens_generated": len(output.token_ids),
                "finish_reason": output.finish_reason
            }
            for output in final_output.outputs
        ]

    async def stream_generate(
        self,
        prompt: str,
        sampling_params: SamplingParams
    ):
        """Stream generation token by token"""

        request_id = f"req_{time.time()}"

        # output.text is cumulative, so yield only the newly generated suffix
        num_sent = 0
        async for request_output in self.engine.generate(
            prompt,
            sampling_params,
            request_id
        ):
            text = request_output.outputs[0].text
            yield text[num_sent:]
            num_sent = len(text)

# FastAPI application
app = FastAPI(title="vLLM Inference Server")

# Global server instance
vllm_async_server = None

@app.on_event("startup")
async def startup():
    global vllm_async_server
    vllm_async_server = AsyncVLLMServer(model_name="meta-llama/Llama-2-7b-hf")

@app.post("/v1/completions", response_model=List[GenerationResponse])
async def completions(request: GenerationRequest):
    """Generate completions"""

    if not vllm_async_server:
        raise HTTPException(status_code=503, detail="Server not initialized")

    sampling_params = SamplingParams(
        temperature=request.temperature,
        top_p=request.top_p,
        max_tokens=request.max_tokens,
        n=request.n
    )

    try:
        results = await vllm_async_server.generate(
            request.prompt,
            sampling_params
        )
        return results
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))

@app.post("/v1/completions/stream")
async def completions_stream(request: GenerationRequest):
    """Stream completions"""

    if not vllm_async_server:
        raise HTTPException(status_code=503, detail="Server not initialized")

    sampling_params = SamplingParams(
        temperature=request.temperature,
        top_p=request.top_p,
        max_tokens=request.max_tokens
    )

    async def generate():
        async for text in vllm_async_server.stream_generate(
            request.prompt,
            sampling_params
        ):
            yield f"data: {text}\n\n"

    return StreamingResponse(generate(), media_type="text/event-stream")

# Run server: uvicorn script:app --host 0.0.0.0 --port 8000
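
Once the server is running, it can be exercised with a small client like the following (hypothetical sketch; adjust host and port to your deployment):

python
# Hypothetical client for the FastAPI/vLLM server above (assumes port 8000)
import requests

BASE_URL = "http://localhost:8000"

# Non-streaming completion
resp = requests.post(
    f"{BASE_URL}/v1/completions",
    json={"prompt": "Explain quantum computing:", "max_tokens": 64, "temperature": 0.7},
)
print(resp.json())

# Streaming completion (server-sent events, one "data: ..." chunk per update)
with requests.post(
    f"{BASE_URL}/v1/completions/stream",
    json={"prompt": "Write a haiku about GPUs:", "max_tokens": 32},
    stream=True,
) as stream:
    for line in stream.iter_lines():
        if line:
            print(line.decode("utf-8").removeprefix("data: "), end="", flush=True)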

Text Generation Inference (TGI)

TGI is Hugging Face's production-ready inference server, with built-in support for quantization, sharded multi-GPU inference, token streaming, and Prometheus monitoring.

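Once a TGI container is running (the next section shows one way to launch it), the HTTP API itself is small; a raw request looks roughly like this (assuming the server listens on port 8080):

python
# Minimal raw request against a running TGI server (hypothetical local endpoint)
import requests

response = requests.post(
    "http://localhost:8080/generate",
    json={
        "inputs": "Explain the theory of relativity:",
        "parameters": {"max_new_tokens": 100, "temperature": 0.7},
    },
)
print(response.json()["generated_text"])
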
TGI Deployment

python
import os
import subprocess
import time
import requests
from typing import Dict, Optional

class TGIServer:
    def __init__(
        self,
        model_id: str,
        port: int = 8080,
        num_shard: int = 1,
        quantize: Optional[str] = None
    ):
        """
        Deploy model with TGI

        Args:
            model_id: HuggingFace model ID
            port: Server port
            num_shard: Number of shards (GPUs)
            quantize: Quantization method (bitsandbytes, gptq, awq)
        """

        self.model_id = model_id
        self.port = port
        self.base_url = f"http://localhost:{port}"
        self.container_id = None

        # Build docker command ("$HOME" is not expanded when subprocess gets a
        # list, so resolve the Hugging Face cache path in Python)
        hf_cache = os.path.expanduser("~/.cache/huggingface")
        docker_cmd = [
            "docker", "run", "-d",
            "--gpus", "all",
            "--shm-size", "1g",
            "-p", f"{port}:80",
            "-v", f"{hf_cache}:/data",
            "-e", f"MODEL_ID={model_id}",
            "-e", f"NUM_SHARD={num_shard}",
            "-e", "MAX_TOTAL_TOKENS=2048",
            "-e", "MAX_INPUT_LENGTH=1024",
        ]

        if quantize:
            docker_cmd.extend(["-e", f"QUANTIZE={quantize}"])

        docker_cmd.append("ghcr.io/huggingface/text-generation-inference:latest")

        # Start container
        print(f"Starting TGI server for {model_id}...")
        result = subprocess.run(docker_cmd, capture_output=True, text=True)

        if result.returncode == 0:
            self.container_id = result.stdout.strip()
            print(f"TGI server started: {self.container_id[:12]}")
            self.wait_for_ready()
        else:
            raise RuntimeError(f"Failed to start TGI: {result.stderr}")

    def wait_for_ready(self, timeout: int = 300):
        """Wait for server to be ready"""
        start = time.time()
        while time.time() - start < timeout:
            try:
                response = requests.get(f"{self.base_url}/health")
                if response.status_code == 200:
                    print("TGI server ready!")
                    return
            except requests.exceptions.ConnectionError:
                pass

            time.sleep(5)

        raise TimeoutError("TGI server failed to start")

    def generate(
        self,
        prompt: str,
        max_new_tokens: int = 100,
        temperature: float = 0.7,
        top_p: float = 0.9,
        do_sample: bool = True
    ) -> Dict:
        """Generate completion"""

        payload = {
            "inputs": prompt,
            "parameters": {
                "max_new_tokens": max_new_tokens,
                "temperature": temperature,
                "top_p": top_p,
                "do_sample": do_sample
            }
        }

        response = requests.post(
            f"{self.base_url}/generate",
            json=payload
        )

        if response.status_code == 200:
            return response.json()
        else:
            raise RuntimeError(f"Generation failed: {response.text}")

    def stream_generate(self, prompt: str, **kwargs):
        """Stream generation"""

        payload = {
            "inputs": prompt,
            "parameters": kwargs
        }

        response = requests.post(
            f"{self.base_url}/generate_stream",
            json=payload,
            stream=True
        )

        for line in response.iter_lines():
            if line:
                yield line.decode('utf-8')

    def get_info(self) -> Dict:
        """Get model info"""
        response = requests.get(f"{self.base_url}/info")
        return response.json()

    def get_metrics(self) -> str:
        """Get Prometheus metrics"""
        response = requests.get(f"{self.base_url}/metrics")
        return response.text

    def stop(self):
        """Stop TGI server"""
        if self.container_id:
            subprocess.run(["docker", "stop", self.container_id])
            subprocess.run(["docker", "rm", self.container_id])
            print("TGI server stopped")

# Example usage
tgi_server = TGIServer(
    model_id="meta-llama/Llama-2-7b-hf",
    port=8080,
    num_shard=1,
    quantize="bitsandbytes"  # Use 8-bit quantization
)

# Generate
result = tgi_server.generate(
    "Explain the theory of relativity:",
    max_new_tokens=100
)
print(result["generated_text"])

# Stream
print("\nStreaming generation:")
for chunk in tgi_server.stream_generate(
    "Write a story about AI:",
    max_new_tokens=100
):
    print(chunk, end="", flush=True)

# Get info
info = tgi_server.get_info()
print(f"\nModel: {info['model_id']}")
print(f"Device: {info['device_type']}")

# Stop server
tgi_server.stop()

TGI with Python Client

python
from huggingface_hub import InferenceClient
from typing import Dict, Iterator, List, Optional

class TGIClient:
    def __init__(self, endpoint: str, token: Optional[str] = None):
        """
        Initialize TGI client

        Args:
            endpoint: TGI server endpoint (local or HF Inference Endpoints)
            token: HuggingFace token for private models
        """
        self.client = InferenceClient(model=endpoint, token=token)

    def generate(
        self,
        prompt: str,
        max_new_tokens: int = 100,
        **kwargs
    ) -> str:
        """Generate completion"""

        response = self.client.text_generation(
            prompt,
            max_new_tokens=max_new_tokens,
            **kwargs
        )

        return response

    def stream_generate(
        self,
        prompt: str,
        max_new_tokens: int = 100,
        **kwargs
    ) -> Iterator[str]:
        """Stream generation"""

        for token in self.client.text_generation(
            prompt,
            max_new_tokens=max_new_tokens,
            stream=True,
            **kwargs
        ):
            yield token

    def batch_generate(
        self,
        prompts: List[str],
        max_new_tokens: int = 100,
        **kwargs
    ) -> List[str]:
        """Batch generation"""

        results = []
        for prompt in prompts:
            result = self.generate(prompt, max_new_tokens, **kwargs)
            results.append(result)

        return results

    def chat(
        self,
        messages: List[Dict[str, str]],
        max_new_tokens: int = 100,
        **kwargs
    ) -> str:
        """Chat completion (for chat-tuned models)"""

        # Convert messages to prompt format
        prompt = self._format_chat_prompt(messages)

        return self.generate(prompt, max_new_tokens, **kwargs)

    def _format_chat_prompt(self, messages: List[Dict[str, str]]) -> str:
        """Format messages for chat models"""

        # Example format (adjust for your model)
        prompt = ""
        for msg in messages:
            role = msg["role"]
            content = msg["content"]

            if role == "system":
                prompt += f"System: {content}\n"
            elif role == "user":
                prompt += f"User: {content}\n"
            elif role == "assistant":
                prompt += f"Assistant: {content}\n"

        prompt += "Assistant:"
        return prompt

# Example usage
client = TGIClient(endpoint="http://localhost:8080")

# Generate
response = client.generate("Explain machine learning:")
print(response)

# Stream
for token in client.stream_generate("Write a poem:"):
    print(token, end="", flush=True)

# Batch
prompts = [
    "What is Python?",
    "What is JavaScript?",
    "What is Rust?"
]
results = client.batch_generate(prompts)

# Chat
messages = [
    {"role": "system", "content": "You are a helpful assistant."},
    {"role": "user", "content": "What is AI?"}
]
response = client.chat(messages)
print(response)

Advanced Batching Strategies

Continuous Batching: Unlike static batching, continuous batching lets new requests join a running batch as soon as earlier requests finish, so the GPU never waits on the slowest sequence in a group. vLLM and TGI implement this at the iteration (per-token) level; the engine below approximates the idea with a simpler dynamic, request-level batcher that runs on plain Transformers.

python
import asyncio
import time
import torch
from typing import Dict, List

class ContinuousBatchingEngine:
    def __init__(
        self,
        model,
        tokenizer,
        max_batch_size: int = 8,
        max_wait_time: float = 0.1
    ):
        """
        Simplified dynamic-batching engine (approximates continuous batching)

        Args:
            model: LLM model
            tokenizer: Tokenizer
            max_batch_size: Maximum batch size
            max_wait_time: Maximum time to wait for batch to fill
        """
        self.model = model
        self.tokenizer = tokenizer
        self.max_batch_size = max_batch_size
        self.max_wait_time = max_wait_time

        self.request_queue = asyncio.Queue()
        self.running = False

    async def add_request(
        self,
        prompt: str,
        max_tokens: int = 100
    ) -> str:
        """Add request to queue"""

        future = asyncio.Future()

        await self.request_queue.put({
            "prompt": prompt,
            "max_tokens": max_tokens,
            "future": future
        })

        # Wait for result
        result = await future
        return result

    async def process_batch(self, requests: List[Dict]):
        """Process a batch of requests"""

        prompts = [req["prompt"] for req in requests]

        # Tokenize batch
        inputs = self.tokenizer(
            prompts,
            return_tensors="pt",
            padding=True,
            truncation=True
        ).to(self.model.device)

        # Generate
        max_tokens = max(req["max_tokens"] for req in requests)

        with torch.no_grad():
            outputs = self.model.generate(
                **inputs,
                max_new_tokens=max_tokens,
                pad_token_id=self.tokenizer.eos_token_id
            )

        # Decode and return results
        for i, req in enumerate(requests):
            text = self.tokenizer.decode(outputs[i], skip_special_tokens=True)
            req["future"].set_result(text)

    async def run(self):
        """Main processing loop"""

        self.running = True

        while self.running:
            batch = []
            batch_start = time.time()

            # Collect requests for batch
            while len(batch) < self.max_batch_size:
                timeout = self.max_wait_time - (time.time() - batch_start)

                if timeout <= 0:
                    break

                try:
                    request = await asyncio.wait_for(
                        self.request_queue.get(),
                        timeout=timeout
                    )
                    batch.append(request)
                except asyncio.TimeoutError:
                    break

            # Process batch if not empty
            if batch:
                await self.process_batch(batch)

            # Small delay to prevent busy waiting
            await asyncio.sleep(0.001)

    async def start(self):
        """Start the engine"""
        self.task = asyncio.create_task(self.run())

    async def stop(self):
        """Stop the engine"""
        self.running = False
        await self.task

# Example usage
async def test_continuous_batching():
    from transformers import AutoModelForCausalLM, AutoTokenizer

    model = AutoModelForCausalLM.from_pretrained(
        "gpt2",
        device_map="auto"
    )
    tokenizer = AutoTokenizer.from_pretrained("gpt2")
    tokenizer.pad_token = tokenizer.eos_token

    # Create engine
    engine = ContinuousBatchingEngine(
        model=model,
        tokenizer=tokenizer,
        max_batch_size=4,
        max_wait_time=0.1
    )

    # Start engine
    await engine.start()

    # Send requests concurrently
    async def send_request(i):
        start = time.time()
        result = await engine.add_request(
            prompt=f"Topic {i}:",
            max_tokens=50
        )
        latency = time.time() - start
        return result, latency

    # Create 10 concurrent requests
    tasks = [send_request(i) for i in range(10)]
    results = await asyncio.gather(*tasks)

    # Calculate metrics
    latencies = [r[1] for r in results]
    print(f"\nAverage latency: {sum(latencies)/len(latencies):.3f}s")
    print(f"Total time: {max(latencies):.3f}s")
    print(f"Throughput: {len(results)/max(latencies):.2f} req/s")

    # Stop engine
    await engine.stop()

# Run test
# asyncio.run(test_continuous_batching())

Production Deployment

python
import logging

from fastapi import FastAPI
from fastapi.responses import Response
from prometheus_client import CONTENT_TYPE_LATEST, Counter, Histogram, generate_latest

# Metrics
request_counter = Counter('inference_requests_total', 'Total inference requests')
request_duration = Histogram('inference_duration_seconds', 'Inference duration')
token_counter = Counter('tokens_generated_total', 'Total tokens generated')

class ProductionInferenceServer:
    def __init__(self, model_name: str):
        self.vllm = VLLMServer(model_name)
        self.logger = self._setup_logging()

    def _setup_logging(self):
        """Setup logging"""
        logging.basicConfig(
            level=logging.INFO,
            format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
        )
        return logging.getLogger(__name__)

    async def generate_with_metrics(
        self,
        prompt: str,
        max_tokens: int = 100,
        **kwargs
    ):
        """Generate with metrics and logging"""

        request_counter.inc()
        self.logger.info(f"Request received: {prompt[:50]}...")

        with request_duration.time():
            results = self.vllm.generate(
                [prompt],
                max_tokens=max_tokens,
                **kwargs
            )

        tokens_generated = results[0]["tokens_generated"]
        token_counter.inc(tokens_generated)

        self.logger.info(f"Generated {tokens_generated} tokens")

        return results[0]

# FastAPI app
app = FastAPI(title="Production Inference Server")

@app.get("/metrics")
async def metrics():
    """Prometheus metrics endpoint"""
    return generate_latest()

@app.post("/v1/completions")
async def completions(request: GenerationRequest):
    """Generate completion with monitoring"""
    server = ProductionInferenceServer("meta-llama/Llama-2-7b-hf")

    result = await server.generate_with_metrics(
        request.prompt,
        request.max_tokens,
        temperature=request.temperature,
        top_p=request.top_p
    )

    return result
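
To sanity-check the deployment end to end, a small load-test script can fire a few concurrent requests and then read back the Prometheus counters (a hedged sketch; assumes the app above is served on port 8000 via uvicorn):

python
# Hypothetical smoke test for the production server above
import asyncio
import httpx

BASE_URL = "http://localhost:8000"

async def one_request(client: httpx.AsyncClient, i: int) -> float:
    payload = {"prompt": f"Summarize topic {i}:", "max_tokens": 64}
    resp = await client.post(f"{BASE_URL}/v1/completions", json=payload, timeout=120)
    resp.raise_for_status()
    return resp.elapsed.total_seconds()

async def main(concurrency: int = 8):
    async with httpx.AsyncClient() as client:
        latencies = await asyncio.gather(
            *[one_request(client, i) for i in range(concurrency)]
        )
        print(f"avg latency: {sum(latencies) / len(latencies):.2f}s")

        # Scrape the Prometheus metrics exposed by the app
        metrics = await client.get(f"{BASE_URL}/metrics")
        for line in metrics.text.splitlines():
            if line.startswith(("inference_requests_total", "tokens_generated_total")):
                print(line)

# asyncio.run(main())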


Summary

In this lesson, you learned:

  • Model serving challenges: Why naive approaches are inefficient
  • vLLM: High-throughput inference with PagedAttention and continuous batching
  • TGI: Production-ready serving with built-in optimization and monitoring
  • Batching strategies: Continuous batching for maximum GPU utilization
  • Production deployment: Monitoring, metrics, and scaling strategies

Effective model serving is crucial for production LLM applications, directly impacting cost, latency, and user experience.