diff --git a/examples/mock_backend/main.py b/examples/mock_backend/main.py
new file mode 100644
index 000000000..da5978ce4
--- /dev/null
+++ b/examples/mock_backend/main.py
@@ -0,0 +1,122 @@
+#!/usr/bin/env python3
+"""
+Mock vLLM backend used for local Swagger UI and routing smoke tests.
+
+Features:
+  - Implements /v1/chat/completions, /v1/completions, /v1/embeddings
+  - Lightweight, deterministic-style responses
+  - Adjustable port via --port (default 8000)
+
+This is NOT a production server; it's only for development / CI smoke tests.
+"""
+
+# Copyright 2024-2025 The vLLM Production Stack Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import annotations
+
+import argparse
+import time
+import uuid
+
+import uvicorn
+from fastapi import FastAPI, Request
+from fastapi.responses import JSONResponse
+
+app = FastAPI(title="Mock vLLM Backend", version="1.0.0")
+
+
+@app.post("/v1/chat/completions")
+async def mock_chat_completions(
+    request: Request,
+):  # pragma: no cover - exercised in e2e
+    body = await request.json()
+    response = {
+        "id": f"chatcmpl-{uuid.uuid4().hex[:10]}",
+        "object": "chat.completion",
+        "created": int(time.time()),
+        "model": body.get("model", "mock-model"),
+        "choices": [
+            {
+                "index": 0,
+                "message": {
+                    "role": "assistant",
+                    "content": "Hello! This is a mock response from the Swagger UI integration test.",
+                },
+                "finish_reason": "stop",
+            }
+        ],
+        "usage": {"prompt_tokens": 10, "completion_tokens": 15, "total_tokens": 25},
+    }
+    return JSONResponse(content=response)
+
+
+@app.post("/v1/completions")
+async def mock_completions(request: Request):  # pragma: no cover
+    body = await request.json()
+    response = {
+        "id": f"cmpl-{uuid.uuid4().hex[:10]}",
+        "object": "text_completion",
+        "created": int(time.time()),
+        "model": body.get("model", "mock-model"),
+        "choices": [
+            {
+                "text": " This is a mock completion response.",
+                "index": 0,
+                "finish_reason": "stop",
+            }
+        ],
+        "usage": {"prompt_tokens": 5, "completion_tokens": 8, "total_tokens": 13},
+    }
+    return JSONResponse(content=response)
+
+
+@app.post("/v1/embeddings")
+async def mock_embeddings(request: Request):  # pragma: no cover
+    body = await request.json()
+    response = {
+        "object": "list",
+        "data": [
+            {
+                "object": "embedding",
+                "embedding": [0.1, 0.2, 0.3, 0.4, 0.5] * 100,  # 500-dim
+                "index": 0,
+            }
+        ],
+        "model": body.get("model", "mock-embedding-model"),
+        "usage": {"prompt_tokens": 8, "total_tokens": 8},
+    }
+    return JSONResponse(content=response)
+
+
+@app.get("/health")
+async def health():  # pragma: no cover
+    return {"status": "healthy"}
+
+
+def parse_args():  # pragma: no cover
+    p = argparse.ArgumentParser()
+    p.add_argument("--port", type=int, default=8000)
+    p.add_argument("--host", type=str, default="0.0.0.0")
+    return p.parse_args()
+
+
+def main():  # pragma: no cover
+    args = parse_args()
+    print(f"🚀 Starting Mock vLLM Backend on http://{args.host}:{args.port}")
+    uvicorn.run(app, host=args.host, port=args.port)
+
+
+if __name__ == "__main__":  # pragma: no cover
+    main()
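For a quick manual check, the mock backend can be exercised directly; a minimal sketch, assuming it is running on the default host/port from `parse_args()`:

```python
# Manual sanity check against the mock backend (assumes the default
# http://localhost:8000; request/response shapes match the handlers above).
import requests

resp = requests.post(
    "http://localhost:8000/v1/chat/completions",
    json={"model": "mock-model", "messages": [{"role": "user", "content": "hi"}]},
    timeout=5,
)
resp.raise_for_status()
print(resp.json()["choices"][0]["message"]["content"])
```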
diff --git a/pyproject.toml b/pyproject.toml
index dfc923313..749aa0188 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -60,5 +60,6 @@ lint = [
 ]
 test = [
     "pytest>=8.3.4",
-    "pytest-asyncio>=0.25.3"
+    "pytest-asyncio>=0.25.3",
+    "httpx==0.28.1"
 ]
diff --git a/requirements-test.txt b/requirements-test.txt
index f180cb7c9..b59cac88d 100644
--- a/requirements-test.txt
+++ b/requirements-test.txt
@@ -1,4 +1,5 @@
 faiss-cpu>=1.7.4
+httpx==0.28.1
 huggingface-hub==0.33.0
 pytest
 pytest-asyncio
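The httpx pin is needed because `fastapi.testclient.TestClient` is built on httpx in current Starlette releases, so the new TestClient-based tests cannot even be constructed without it. A tiny illustration:

```python
# TestClient (re-exported by FastAPI from Starlette) is httpx-based; without
# httpx installed it fails at import/construction time.
from fastapi import FastAPI
from fastapi.testclient import TestClient

app = FastAPI()


@app.get("/ping")
def ping():
    return {"pong": True}


client = TestClient(app)  # requires httpx under the hood
assert client.get("/ping").json() == {"pong": True}
```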
+""" + +from __future__ import annotations + +import json +import time +from dataclasses import dataclass, field +from typing import Dict, List + +import requests + + +@dataclass +class TestResult: + name: str + success: bool + detail: str = "" + extra: Dict = field(default_factory=dict) + + +class SwaggerUITester: + def __init__(self, base_url: str): + self.base_url = base_url.rstrip("/") + self.session = requests.Session() + self.results: List[TestResult] = [] + + def record(self, name: str, success: bool, detail: str = "", **extra): + self.results.append(TestResult(name, success, detail, extra)) + + def _url(self, path: str) -> str: + return f"{self.base_url}{path}" + + def test_docs(self): + try: + r = self.session.get(self._url("/docs"), timeout=5) + self.record("/docs", r.status_code == 200, f"status={r.status_code}") + except Exception as e: # pragma: no cover + self.record("/docs", False, str(e)) + + def test_openapi(self): + try: + r = self.session.get(self._url("/openapi.json"), timeout=5) + if r.status_code != 200: + self.record("openapi", False, f"status={r.status_code}") + return + schema = r.json() + paths = schema.get("paths", {}) + expected = ["/v1/chat/completions", "/v1/completions", "/v1/embeddings"] + missing = [p for p in expected if p not in paths] + self.record( + "openapi", not missing, "ok" if not missing else f"missing={missing}" + ) + except Exception as e: # pragma: no cover + self.record("openapi", False, str(e)) + + def test_core_endpoints(self): + # chat valid + chat = { + "model": "gpt-3.5-turbo", + "messages": [{"role": "user", "content": "hi"}], + "max_tokens": 4, + } + r = self.session.post(self._url("/v1/chat/completions"), json=chat, timeout=8) + self.record("chat_valid", r.status_code == 200, f"status={r.status_code}") + # chat 422 + r2 = self.session.post( + self._url("/v1/chat/completions"), json={"messages": []}, timeout=5 + ) + self.record("chat_422", r2.status_code == 422, f"status={r2.status_code}") + # completions + comp = {"model": "gpt-3.5-turbo", "prompt": "Hello", "max_tokens": 5} + r3 = self.session.post(self._url("/v1/completions"), json=comp, timeout=8) + self.record( + "completions_valid", r3.status_code == 200, f"status={r3.status_code}" + ) + # embeddings + emb = {"model": "text-embedding-ada-002", "input": "hello"} + r4 = self.session.post(self._url("/v1/embeddings"), json=emb, timeout=8) + self.record( + "embeddings_valid", r4.status_code == 200, f"status={r4.status_code}" + ) + + def run(self): + start = time.time() + self.test_docs() + self.test_openapi() + self.test_core_endpoints() + elapsed = time.time() - start + passed = sum(1 for r in self.results if r.success) + return passed == len(self.results), elapsed, self.results + + +def run_smoke(base_url: str) -> bool: + tester = SwaggerUITester(base_url) + ok, elapsed, results = tester.run() + print( + f"Swagger smoke: {passed_count(results)}/{len(results)} passed in {elapsed:.2f}s" + ) + for r in results: + icon = "✅" if r.success else "❌" + print(f" {icon} {r.name} - {r.detail}") + return ok + + +def passed_count(results: List[TestResult]) -> int: + return sum(1 for r in results if r.success) diff --git a/scripts/swagger_smoke.py b/scripts/swagger_smoke.py new file mode 100644 index 000000000..ec8c689a5 --- /dev/null +++ b/scripts/swagger_smoke.py @@ -0,0 +1,30 @@ +#!/usr/bin/env python3 +"""CLI Swagger smoke test script. + +Uses the same tester logic as the e2e pytest but runnable standalone. +Exit code 0 = all pass; non-zero otherwise. 
+""" + +# Copyright 2024-2025 The vLLM Production Stack Authors. +# Licensed under the Apache License, Version 2.0. + +from __future__ import annotations + +import os +import sys + +from _swagger_smoke_core import run_smoke + + +def main(): # pragma: no cover + base = ( + sys.argv[1] + if len(sys.argv) > 1 + else os.getenv("SWAGGER_BASE_URL", "http://localhost:8080") + ) + ok = run_smoke(base) + sys.exit(0 if ok else 1) + + +if __name__ == "__main__": # pragma: no cover + main() diff --git a/src/tests/test_swagger_integration.py b/src/tests/test_swagger_integration.py new file mode 100644 index 000000000..4e753390f --- /dev/null +++ b/src/tests/test_swagger_integration.py @@ -0,0 +1,228 @@ +#!/usr/bin/env python3 +"""Unit tests for Swagger UI integration: request validation & OpenAPI generation.""" + +import json +import os +import sys + +import pytest +from fastapi import FastAPI +from fastapi.testclient import TestClient + +sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..")) + + +class MockServiceDiscovery: + def get_endpoint_info(self): + return [] + + def get_health(self): + return True + + +class MockEngineStatsScraper: + def get_health(self): + return True + + +class MockDynamicConfigWatcher: + def get_current_config(self): + class MockConfig: + def to_json_str(self): + return '{"mock": true}' + + return MockConfig() + + +sys.modules["vllm_router.service_discovery"] = type( + "MockModule", (), {"get_service_discovery": lambda: MockServiceDiscovery()} +)() +sys.modules["vllm_router.stats.engine_stats"] = type( + "MockModule", (), {"get_engine_stats_scraper": lambda: MockEngineStatsScraper()} +)() +sys.modules["vllm_router.dynamic_config"] = type( + "MockModule", (), {"get_dynamic_config_watcher": lambda: MockDynamicConfigWatcher()} +)() +sys.modules["vllm_router.version"] = type("MockModule", (), {"__version__": "1.0.0"})() + + +class MockRequestModule: + @staticmethod + async def route_general_request( + request, endpoint, background_tasks, request_body=None + ): + if request_body: + data = json.loads(request_body) + else: + data = await request.json() + return { + "mock_response": True, + "endpoint": endpoint, + "model": data.get("model"), + "request_type": "pydantic" if request_body else "raw", + "data": data, + } + + @staticmethod + def route_sleep_wakeup_request(r, e, b): # pragma: no cover + return {"sleep": True} + + +sys.modules["vllm_router.services.request_service.request"] = MockRequestModule() + +from vllm_router.routers.main_router import main_router # noqa: E402 + +app = FastAPI() +app.include_router(main_router) +client = TestClient(app) + + +class TestSwaggerIntegration: + def test_chat_completions_pydantic_model(self): + resp = client.post( + "/v1/chat/completions", + json={ + "model": "gpt-3.5-turbo", + "messages": [{"role": "user", "content": "Hello"}], + "max_tokens": 100, + "temperature": 0.7, + }, + ) + assert resp.status_code == 200, resp.text + data = resp.json() + assert data["mock_response"] + assert data["endpoint"] == "/v1/chat/completions" + assert data["model"] == "gpt-3.5-turbo" + assert data["request_type"] == "pydantic" + + def test_completions_pydantic_model(self): + resp = client.post( + "/v1/completions", + json={"model": "gpt-3.5-turbo", "prompt": "Hello world", "max_tokens": 50}, + ) + assert resp.status_code == 200, resp.text + data = resp.json() + assert data["mock_response"] + assert data["endpoint"] == "/v1/completions" + assert data["model"] == "gpt-3.5-turbo" + assert data["request_type"] == "pydantic" + + def 
diff --git a/src/tests/test_swagger_integration.py b/src/tests/test_swagger_integration.py
new file mode 100644
index 000000000..4e753390f
--- /dev/null
+++ b/src/tests/test_swagger_integration.py
@@ -0,0 +1,228 @@
+#!/usr/bin/env python3
+"""Unit tests for Swagger UI integration: request validation & OpenAPI generation."""
+
+import json
+import os
+import sys
+
+import pytest
+from fastapi import FastAPI
+from fastapi.testclient import TestClient
+
+sys.path.insert(0, os.path.join(os.path.dirname(__file__), ".."))
+
+
+class MockServiceDiscovery:
+    def get_endpoint_info(self):
+        return []
+
+    def get_health(self):
+        return True
+
+
+class MockEngineStatsScraper:
+    def get_health(self):
+        return True
+
+
+class MockDynamicConfigWatcher:
+    def get_current_config(self):
+        class MockConfig:
+            def to_json_str(self):
+                return '{"mock": true}'
+
+        return MockConfig()
+
+
+# staticmethod(...) keeps these zero-arg factories unbound when accessed on the
+# module stand-in instances (a bare lambda would be bound and receive `self`).
+sys.modules["vllm_router.service_discovery"] = type(
+    "MockModule", (), {"get_service_discovery": staticmethod(MockServiceDiscovery)}
+)()
+sys.modules["vllm_router.stats.engine_stats"] = type(
+    "MockModule", (), {"get_engine_stats_scraper": staticmethod(MockEngineStatsScraper)}
+)()
+sys.modules["vllm_router.dynamic_config"] = type(
+    "MockModule", (), {"get_dynamic_config_watcher": staticmethod(MockDynamicConfigWatcher)}
+)()
+sys.modules["vllm_router.version"] = type("MockModule", (), {"__version__": "1.0.0"})()
+
+
+class MockRequestModule:
+    @staticmethod
+    async def route_general_request(
+        request, endpoint, background_tasks, request_body=None
+    ):
+        if request_body:
+            data = json.loads(request_body)
+        else:
+            data = await request.json()
+        return {
+            "mock_response": True,
+            "endpoint": endpoint,
+            "model": data.get("model"),
+            "request_type": "pydantic" if request_body else "raw",
+            "data": data,
+        }
+
+    @staticmethod
+    def route_sleep_wakeup_request(r, e, b):  # pragma: no cover
+        return {"sleep": True}
+
+
+sys.modules["vllm_router.services.request_service.request"] = MockRequestModule()
+
+from vllm_router.routers.main_router import main_router  # noqa: E402
+
+app = FastAPI()
+app.include_router(main_router)
+client = TestClient(app)
+
+
+class TestSwaggerIntegration:
+    def test_chat_completions_pydantic_model(self):
+        resp = client.post(
+            "/v1/chat/completions",
+            json={
+                "model": "gpt-3.5-turbo",
+                "messages": [{"role": "user", "content": "Hello"}],
+                "max_tokens": 100,
+                "temperature": 0.7,
+            },
+        )
+        assert resp.status_code == 200, resp.text
+        data = resp.json()
+        assert data["mock_response"]
+        assert data["endpoint"] == "/v1/chat/completions"
+        assert data["model"] == "gpt-3.5-turbo"
+        assert data["request_type"] == "pydantic"
+
+    def test_completions_pydantic_model(self):
+        resp = client.post(
+            "/v1/completions",
+            json={"model": "gpt-3.5-turbo", "prompt": "Hello world", "max_tokens": 50},
+        )
+        assert resp.status_code == 200, resp.text
+        data = resp.json()
+        assert data["mock_response"]
+        assert data["endpoint"] == "/v1/completions"
+        assert data["model"] == "gpt-3.5-turbo"
+        assert data["request_type"] == "pydantic"
+
+    def test_embeddings_pydantic_model(self):
+        resp = client.post(
+            "/v1/embeddings",
+            json={"model": "text-embedding-ada-002", "input": "Hello world"},
+        )
+        assert resp.status_code == 200, resp.text
+        data = resp.json()
+        assert data["mock_response"]
+        assert data["endpoint"] == "/v1/embeddings"
+        assert data["model"] == "text-embedding-ada-002"
+        assert data["request_type"] == "pydantic"
+
+    def test_extra_fields_handling(self):
+        resp = client.post(
+            "/v1/chat/completions",
+            json={
+                "model": "gpt-3.5-turbo",
+                "messages": [{"role": "user", "content": "Hello"}],
+                "max_tokens": 100,
+                "unknown_field": "ignored",
+            },
+        )
+        assert resp.status_code == 200, resp.text
+        data = resp.json()
+        assert data["mock_response"]
+        assert data["request_type"] == "pydantic"
+
+
+class TestSemanticCacheCompatibility:
+    def test_semantic_cache_uses_raw_request(self):
+        import vllm_router.routers.main_router as router_module
+
+        received_request_type = None
+
+        async def mock_check_semantic_cache(request):
+            nonlocal received_request_type
+            received_request_type = type(request).__name__
+            return None
+
+        if hasattr(router_module, "check_semantic_cache"):
+            original_check = router_module.check_semantic_cache
+            original_flag = getattr(router_module, "semantic_cache_available", False)
+            router_module.check_semantic_cache = mock_check_semantic_cache
+            router_module.semantic_cache_available = True
+            try:
+                resp = client.post(
+                    "/v1/chat/completions",
+                    json={
+                        "model": "gpt-3.5-turbo",
+                        "messages": [{"role": "user", "content": "Cache test"}],
+                    },
+                )
+                assert resp.status_code == 200
+                assert received_request_type == "Request"
+            finally:
+                router_module.check_semantic_cache = original_check
+                router_module.semantic_cache_available = original_flag
+        else:
+            resp = client.post(
+                "/v1/chat/completions",
+                json={
+                    "model": "gpt-3.5-turbo",
+                    "messages": [{"role": "user", "content": "Hello"}],
+                },
+            )
+            assert resp.status_code == 200
+
+    def test_semantic_cache_with_pydantic_request_body(self):
+        resp = client.post(
+            "/v1/chat/completions",
+            json={
+                "model": "gpt-3.5-turbo",
+                "messages": [{"role": "user", "content": "Test message"}],
+                "cache_similarity_threshold": 0.9,
+            },
+        )
+        assert resp.status_code == 200
+        data = resp.json()
+        assert data["mock_response"]
+        assert data["request_type"] == "pydantic"
+        assert "cache_similarity_threshold" in data["data"]
+        assert data["data"]["cache_similarity_threshold"] == 0.9
+
+
+class TestBackwardCompatibility:
+    def test_validation_errors(self):
+        resp = client.post(
+            "/v1/chat/completions",
+            json={"messages": [{"role": "user", "content": "Hello"}]},
+        )
+        assert resp.status_code == 422
+        error_data = resp.json()
+        assert "detail" in error_data
+        assert any("model" in str(err).lower() for err in error_data["detail"])
+
+    def test_invalid_json(self):
+        resp = client.post(
+            "/v1/chat/completions",
+            data="invalid json",
+            headers={"Content-Type": "application/json"},
+        )
+        assert resp.status_code == 422
+
+    def test_openapi_schema_generation(self):
+        resp = client.get("/openapi.json")
+        assert resp.status_code == 200
+        schema = resp.json()
+        assert "paths" in schema
+        paths = schema["paths"]
+        for p in ["/v1/chat/completions", "/v1/completions", "/v1/embeddings"]:
+            assert p in paths
+        chat_post = paths["/v1/chat/completions"]["post"]
+        assert "requestBody" in chat_post
+        rb = chat_post["requestBody"]["content"]["application/json"]
+        assert "$ref" in rb["schema"]
+
+
+if __name__ == "__main__":  # pragma: no cover
+    pytest.main([__file__, "-v"])
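The key trick in this test file is planting stand-in objects in `sys.modules` before `main_router` is imported, so its `from ... import ...` statements resolve against the stubs instead of the real dependencies. A condensed, self-contained illustration of the mechanism (names here are made up for the example):

```python
# An object planted under a module name in sys.modules is what a later
# `from x import y` resolves against; no real module file is needed.
import sys
import types

stub = types.ModuleType("fake_dependency")
stub.answer = lambda: 42
sys.modules["fake_dependency"] = stub

from fake_dependency import answer  # noqa: E402  (resolved against the stub)

assert answer() == 42
```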
diff --git a/src/vllm_router/protocols.py b/src/vllm_router/protocols.py
index 449ed5214..d3526b7f8 100644
--- a/src/vllm_router/protocols.py
+++ b/src/vllm_router/protocols.py
@@ -1,5 +1,5 @@
 import time
-from typing import List, Optional
+from typing import List, Optional, Union
 
 from pydantic import BaseModel, ConfigDict, Field, model_validator
 
@@ -54,3 +54,73 @@ class ModelCard(OpenAIBaseModel):
 class ModelList(OpenAIBaseModel):
     object: str = "list"
     data: List[ModelCard] = Field(default_factory=list)
+
+
+# ===== Core Request Models =====
+# Based on vLLM official protocol.py definitions
+
+
+class ChatCompletionRequest(OpenAIBaseModel):
+    """ChatCompletion API request model based on the OpenAI specification"""
+
+    # Core required fields
+    messages: List[dict]  # Simplified message type to avoid complex nested definitions
+    model: str  # Required field according to the OpenAI API spec
+
+    # Core sampling parameters
+    max_tokens: Optional[int] = None
+    max_completion_tokens: Optional[int] = None
+    temperature: Optional[float] = None
+    top_p: Optional[float] = None
+    frequency_penalty: Optional[float] = 0.0
+    presence_penalty: Optional[float] = 0.0
+
+    # Core control parameters
+    stream: Optional[bool] = False
+    stop: Optional[Union[str, List[str]]] = None
+    n: Optional[int] = 1
+
+    # Other common parameters
+    seed: Optional[int] = None
+    user: Optional[str] = None
+
+
+class CompletionRequest(OpenAIBaseModel):
+    """Completion API request model based on the OpenAI specification"""
+
+    # Core required fields
+    prompt: Optional[Union[str, List[str], List[int], List[List[int]]]] = None
+    model: str  # Required field according to the OpenAI API spec
+
+    # Core sampling parameters
+    max_tokens: Optional[int] = 16
+    temperature: Optional[float] = None
+    top_p: Optional[float] = None
+    frequency_penalty: Optional[float] = 0.0
+    presence_penalty: Optional[float] = 0.0
+
+    # Core control parameters
+    stream: Optional[bool] = False
+    stop: Optional[Union[str, List[str]]] = None
+    n: int = 1
+    echo: Optional[bool] = False
+
+    # Other common parameters
+    seed: Optional[int] = None
+    user: Optional[str] = None
+    best_of: Optional[int] = None
+    logprobs: Optional[int] = None
+    suffix: Optional[str] = None
+
+
+class EmbeddingRequest(OpenAIBaseModel):
+    """Embedding API request model based on the OpenAI specification"""
+
+    # Core required fields
+    input: Union[str, List[str], List[int], List[List[int]]]
+    model: str  # Required field according to the OpenAI API spec
+
+    # Core control parameters
+    encoding_format: Optional[str] = "float"
+    dimensions: Optional[int] = None
+    user: Optional[str] = None
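With these models in place, validation happens before any routing code runs; a minimal sketch of how the router consumes them (the extra-fields behavior assumes `OpenAIBaseModel` is configured to allow unknown keys, which the integration tests above rely on):

```python
# Sketch of the validate-then-serialize flow the router uses (assumes
# OpenAIBaseModel permits extra fields so unknown keys survive round-tripping).
from vllm_router.protocols import ChatCompletionRequest

req = ChatCompletionRequest(
    model="gpt-3.5-turbo",
    messages=[{"role": "user", "content": "hi"}],
    max_tokens=8,
)
request_body = req.model_dump_json().encode("utf-8")  # bytes forwarded downstream

# Omitting `model` would raise pydantic.ValidationError,
# which FastAPI surfaces to clients as a 422 response.
```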
diff --git a/src/vllm_router/routers/main_router.py b/src/vllm_router/routers/main_router.py
index 5d77124dd..a052dffac 100644
--- a/src/vllm_router/routers/main_router.py
+++ b/src/vllm_router/routers/main_router.py
@@ -22,16 +22,31 @@
 from vllm_router.dynamic_config import get_dynamic_config_watcher
 from vllm_router.log import init_logger
-from vllm_router.protocols import ModelCard, ModelList
+from vllm_router.protocols import (
+    ChatCompletionRequest,
+    CompletionRequest,
+    EmbeddingRequest,
+    ModelCard,
+    ModelList,
+)
 from vllm_router.service_discovery import get_service_discovery
 from vllm_router.services.request_service.request import (
     route_general_request,
-    route_general_transcriptions,
     route_sleep_wakeup_request,
 )
 from vllm_router.stats.engine_stats import get_engine_stats_scraper
 from vllm_router.version import __version__
 
+try:
+    # Optional transcriptions support (absent in older request_service builds)
+    from vllm_router.services.request_service.request import (
+        route_general_transcriptions,
+    )
+
+    _route_general_transcriptions = True
+except ImportError:
+    _route_general_transcriptions = False
+
 try:
     # Semantic cache integration
     from vllm_router.experimental.semantic_cache_integration import (
@@ -48,30 +63,52 @@
 
 @main_router.post("/v1/chat/completions")
-async def route_chat_completion(request: Request, background_tasks: BackgroundTasks):
+async def route_chat_completion(
+    request: ChatCompletionRequest,
+    raw_request: Request,
+    background_tasks: BackgroundTasks,
+):
     if semantic_cache_available:
         # Check if the request can be served from the semantic cache
         logger.debug("Received chat completion request, checking semantic cache")
-        cache_response = await check_semantic_cache(request=request)
+        cache_response = await check_semantic_cache(request=raw_request)
 
         if cache_response:
             logger.info("Serving response from semantic cache")
             return cache_response
 
     logger.debug("No cache hit, forwarding request to backend")
+
+    # Convert Pydantic model to JSON bytes for existing service
+    request_body = request.model_dump_json().encode("utf-8")
+
     return await route_general_request(
-        request, "/v1/chat/completions", background_tasks
+        raw_request, "/v1/chat/completions", background_tasks, request_body
     )
 
 
 @main_router.post("/v1/completions")
-async def route_completion(request: Request, background_tasks: BackgroundTasks):
-    return await route_general_request(request, "/v1/completions", background_tasks)
+async def route_completion(
+    request: CompletionRequest, raw_request: Request, background_tasks: BackgroundTasks
+):
+    # Convert Pydantic model to JSON bytes for existing service
+    request_body = request.model_dump_json().encode("utf-8")
+
+    return await route_general_request(
+        raw_request, "/v1/completions", background_tasks, request_body
+    )
 
 
 @main_router.post("/v1/embeddings")
-async def route_embeddings(request: Request, background_tasks: BackgroundTasks):
-    return await route_general_request(request, "/v1/embeddings", background_tasks)
+async def route_embeddings(
+    request: EmbeddingRequest, raw_request: Request, background_tasks: BackgroundTasks
+):
+    # Convert Pydantic model to JSON bytes for existing service
+    request_body = request.model_dump_json().encode("utf-8")
+
+    return await route_general_request(
+        raw_request, "/v1/embeddings", background_tasks, request_body
+    )
 
 
 @main_router.post("/tokenize")
@@ -236,11 +273,13 @@
 async def health() -> Response:
     return JSONResponse(content={"status": "healthy"}, status_code=200)
 
-@main_router.post("/v1/audio/transcriptions")
-async def route_v1_audio_transcriptions(
-    request: Request, background_tasks: BackgroundTasks
-):
-    """Handles audio transcription requests."""
-    return await route_general_transcriptions(
-        request, "/v1/audio/transcriptions", background_tasks
-    )
+if _route_general_transcriptions:
+
+    @main_router.post("/v1/audio/transcriptions")
+    async def route_v1_audio_transcriptions(
+        request: Request, background_tasks: BackgroundTasks
+    ):
+        """Handles audio transcription requests."""
+        return await route_general_transcriptions(
+            request, "/v1/audio/transcriptions", background_tasks
+        )
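The new handler signatures rely on standard FastAPI behavior: a Pydantic-typed parameter is populated (and validated) from the JSON body, while a Starlette `Request` parameter is injected alongside it. A standalone illustration of the pattern, with hypothetical names:

```python
# FastAPI fills the Pydantic model from the JSON body and still injects the
# raw Request, so handlers get validation plus access to headers/client info.
from fastapi import FastAPI, Request
from pydantic import BaseModel

app = FastAPI()


class EchoRequest(BaseModel):
    model: str
    prompt: str = ""


@app.post("/echo")
async def echo(request: EchoRequest, raw_request: Request):
    # `request` is validated and typed; `raw_request` keeps headers etc.
    return {
        "body": request.model_dump(),
        "request_id": raw_request.headers.get("X-Request-Id"),
    }
```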
diff --git a/src/vllm_router/services/request_service/request.py b/src/vllm_router/services/request_service/request.py
index 0c5005715..f231b255e 100644
--- a/src/vllm_router/services/request_service/request.py
+++ b/src/vllm_router/services/request_service/request.py
@@ -138,7 +138,10 @@ async def process_request(
 
 
 async def route_general_request(
-    request: Request, endpoint: str, background_tasks: BackgroundTasks
+    request: Request,
+    endpoint: str,
+    background_tasks: BackgroundTasks,
+    request_body: Optional[bytes] = None,
 ):
     """
     Route the incoming request to the backend server and stream the response back to the client.
@@ -163,8 +166,14 @@
     in_router_time = time.time()
     # Same as vllm, Get request_id from X-Request-Id header if available
     request_id = request.headers.get("X-Request-Id") or str(uuid.uuid4())
-    request_body = await request.body()
+
+    # Use pre-provided request_body if available, otherwise read from request
+    if request_body is None:
+        request_body = await request.body()
+
     request_json = json.loads(request_body)
+    # Determine if request expects streaming (OpenAI style)
+    is_streaming = bool(request_json.get("stream", False))
 
     if request.query_params:
         request_endpoint = request.query_params.get("id")
@@ -195,6 +204,11 @@
         )
         logger.info(f"Request for model {requested_model} was rewritten")
         request_body = rewritten_body
+        # IMPORTANT: after rewriting, update Content-Length so backend reads full JSON
+        try:
+            update_content_length(request, request_body)
+        except Exception as e:
+            logger.warning(f"Failed to update Content-Length after rewrite: {e}")
         # Update request_json if the body was rewritten
         try:
             request_json = json.loads(request_body)
@@ -294,11 +308,13 @@
 
     headers, status = await anext(stream_generator)
     headers_dict = {key: value for key, value in headers.items()}
     headers_dict["X-Request-Id"] = request_id
+    # Choose appropriate media type. If client didn't request streaming, return JSON.
+    media_type = "text/event-stream" if is_streaming else "application/json"
     return StreamingResponse(
         stream_generator,
         status_code=status,
         headers=headers_dict,
-        media_type="text/event-stream",
+        media_type=media_type,
     )
 
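The media-type selection mirrors OpenAI semantics: only requests with `stream: true` should be answered as server-sent events, while everything else gets plain JSON even though the body is still relayed chunk by chunk. A minimal sketch of the decision, lifted out as a pure function:

```python
# Sketch of the media-type decision above (stand-alone, for illustration only).
def pick_media_type(request_json: dict) -> str:
    is_streaming = bool(request_json.get("stream", False))
    return "text/event-stream" if is_streaming else "application/json"


assert pick_media_type({"stream": True}) == "text/event-stream"
assert pick_media_type({"stream": False}) == "application/json"
assert pick_media_type({}) == "application/json"
```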