Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,8 @@ dependencies = [
"python-dotenv>=1.2.2",
# Used for token estimation before LLM calls (LCORE-1569 / conversation compaction)
"tiktoken>=0.8.0",
# Pydantic AI framework, required by the pydantic_ai_lightspeed Llama Stack provider
"pydantic-ai>=1.99.0"
]


Expand Down
1 change: 1 addition & 0 deletions src/pydantic_ai_lightspeed/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
"""Pydantic AI integrations/extensions for Lightspeed Core Stack."""
5 changes: 5 additions & 0 deletions src/pydantic_ai_lightspeed/llamastack/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
"""Pydantic AI provider for Llama Stack."""

# Re-export the provider so callers can import it directly from this package
# instead of reaching into the private ``_provider`` module.
from ._provider import LlamaStackProvider

__all__ = ["LlamaStackProvider"]
115 changes: 115 additions & 0 deletions src/pydantic_ai_lightspeed/llamastack/_provider.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,115 @@
"""Llama Stack provider implementation for Pydantic AI."""

from __future__ import annotations as _annotations

import os
from typing import TYPE_CHECKING

import httpx
from openai import AsyncOpenAI
from pydantic_ai import ModelProfile
from pydantic_ai.models import create_async_http_client
from pydantic_ai.profiles.openai import openai_model_profile
from pydantic_ai.providers import Provider

from ._transport import LlamaStackLibraryTransport

if TYPE_CHECKING:
from llama_stack.core.library_client import AsyncLlamaStackAsLibraryClient

# Fallback endpoint for server mode, used when neither the ``base_url`` argument
# nor the ``LLAMA_STACK_BASE_URL`` environment variable supplies a value.
DEFAULT_BASE_URL = "http://localhost:8321/v1"


class LlamaStackProvider(Provider[AsyncOpenAI]):
    """Provider for Llama Stack — connects to a Llama Stack server's OpenAI-compatible API.

    Supports two modes:

    1. **Server mode** — connect to a running Llama Stack server via HTTP
    2. **Library mode** — run Llama Stack in-process via ``AsyncLlamaStackAsLibraryClient``
    """

    @property
    def name(self) -> str:
        """The provider name."""
        return "llama-stack"

    @property
    def base_url(self) -> str:
        """The base URL for the provider API."""
        return str(self._client.base_url)

    @property
    def client(self) -> AsyncOpenAI:
        """The OpenAI-compatible client for the provider."""
        return self._client

    @staticmethod
    def model_profile(model_name: str) -> ModelProfile | None:
        """Return the model profile for the named model, if available."""
        return openai_model_profile(model_name)

    def __init__(
        self,
        *,
        base_url: str | None = None,
        api_key: str | None = None,
        library_client: AsyncLlamaStackAsLibraryClient | None = None,
        http_client: httpx.AsyncClient | None = None,
    ) -> None:
        """Create a new Llama Stack provider.

        Args:
            base_url: The base URL for the Llama Stack server (OpenAI-compatible endpoint).
                Defaults to ``LLAMA_STACK_BASE_URL`` env var, then ``http://localhost:8321/v1``.
                Must be ``None`` when ``library_client`` is provided.
            api_key: The API key for authentication. Defaults to ``LLAMA_STACK_API_KEY`` env
                var, then ``'not-needed'`` since local Llama Stack servers typically don't
                require one. Must be ``None`` when ``library_client`` is provided.
            library_client: An initialized ``AsyncLlamaStackAsLibraryClient`` for library mode.
                When provided, requests are dispatched in-process (no server needed).
                Mutually exclusive with ``base_url``, ``api_key``, and ``http_client``.
            http_client: An existing ``httpx.AsyncClient`` to use for making HTTP requests.
                Must be ``None`` when ``library_client`` is provided.

        Raises:
            ValueError: If ``library_client`` is combined with ``base_url``,
                ``api_key``, or ``http_client``.
        """
        # Always define the attribute so it can be inspected in either mode;
        # it stays ``None`` in server mode.
        self._library_client = library_client

        if library_client is not None:
            # Explicit raises instead of ``assert``: assertions are stripped when
            # Python runs with -O, which would let conflicting arguments through.
            if base_url is not None:
                raise ValueError("Cannot provide both `library_client` and `base_url`")
            if api_key is not None:
                raise ValueError("Cannot provide both `library_client` and `api_key`")
            if http_client is not None:
                raise ValueError(
                    "Cannot provide both `library_client` and `http_client`"
                )

            # The custom transport hands every request to the library client, so
            # the host name below only serves as a well-formed placeholder URL.
            transport = LlamaStackLibraryTransport(library_client)
            lib_http_client = httpx.AsyncClient(
                transport=transport, base_url="http://llama-stack-library"
            )
            self._client = AsyncOpenAI(
                http_client=lib_http_client,
                base_url="http://llama-stack-library/v1",
                api_key="not-needed",
            )
            return

        # Server mode: explicit argument, then environment variable, then default.
        base_url = (
            base_url or os.environ.get("LLAMA_STACK_BASE_URL") or DEFAULT_BASE_URL
        )
        api_key = api_key or os.environ.get("LLAMA_STACK_API_KEY") or "not-needed"

        if http_client is None:
            http_client = create_async_http_client()
        self._client = AsyncOpenAI(
            base_url=base_url, api_key=api_key, http_client=http_client
        )

    def __repr__(self) -> str:
        """Return a string representation of the provider."""
        return f"LlamaStackProvider(name={self.name!r}, base_url={self.base_url!r})"

    def _set_http_client(self, http_client: httpx.AsyncClient) -> None:
        """Replace the HTTP client used by the underlying ``AsyncOpenAI`` client."""
        self._client._client = http_client  # pyright: ignore[reportPrivateUsage] # pylint: disable=protected-access
157 changes: 157 additions & 0 deletions src/pydantic_ai_lightspeed/llamastack/_transport.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,157 @@
"""httpx transport that routes OpenAI-compatible requests through a Llama Stack library client."""

from __future__ import annotations as _annotations

import json
from collections.abc import AsyncGenerator, AsyncIterator
from typing import Any

import httpx
from llama_stack.core.library_client import (
AsyncLlamaStackAsLibraryClient,
convert_pydantic_to_json_value,
)
from llama_stack.core.request_headers import (
PROVIDER_DATA_VAR,
request_provider_data_context,
)
from llama_stack.core.server.routes import find_matching_route
from llama_stack.core.utils.context import preserve_contexts_async_generator


class _AsyncByteStream(httpx.AsyncByteStream):
    """Wraps an async byte generator as an httpx AsyncByteStream."""

    def __init__(self, gen: AsyncGenerator[bytes, None]) -> None:
        """Store the generator whose chunks form the response body.

        Args:
            gen: Async generator yielding the body as ``bytes`` chunks.
        """
        self._gen = gen

    async def __aiter__(self) -> AsyncIterator[bytes]:
        """Yield each chunk produced by the wrapped generator, in order."""
        async for chunk in self._gen:
            yield chunk

    async def aclose(self) -> None:
        """Close the wrapped generator when httpx closes the response.

        Without this, a response abandoned before the stream is exhausted
        would leave the generator suspended, deferring its cleanup to
        garbage collection. Closing an already-exhausted generator is a no-op.
        """
        await self._gen.aclose()


class LlamaStackLibraryTransport(httpx.AsyncBaseTransport):
    """Custom httpx transport that dispatches requests through a Llama Stack library client.

    Instead of making real HTTP calls, this transport routes requests directly
    to the Llama Stack's in-process route handlers via the library client's
    route matching and body conversion logic.
    """

    def __init__(self, client: AsyncLlamaStackAsLibraryClient) -> None:
        """Initialize the transport with a Llama Stack library client.

        Args:
            client: An initialized ``AsyncLlamaStackAsLibraryClient`` whose route
                handlers will receive dispatched requests.
        """
        self._client = client

    async def handle_async_request(self, request: httpx.Request) -> httpx.Response:
        """Dispatch an httpx request to the in-process Llama Stack route handlers.

        Args:
            request: The outgoing httpx request to route.

        Returns:
            An httpx response built from the matched route handler result.

        Raises:
            RuntimeError: If the library client has not been initialized.
        """
        if self._client.route_impls is None:
            raise RuntimeError(
                "Llama Stack library client not initialized. Call initialize() first."
            )

        method = request.method
        path = request.url.raw_path.decode("utf-8")

        # An empty request body is treated as an empty JSON object.
        body = json.loads(request.content) if request.content else {}

        # Raw header pairs arrive as bytes; normalise both names and values to str.
        headers: dict[str, str] = {
            k.decode("utf-8") if isinstance(k, bytes) else k: (
                v.decode("utf-8") if isinstance(v, bytes) else v
            )
            for k, v in request.headers.raw
        }

        # Fall back to the client's provider data only when the request did not
        # supply its own provider-data header.
        if self._client.provider_data:
            keys = ["X-LlamaStack-Provider-Data", "x-llamastack-provider-data"]
            if all(key not in headers for key in keys):
                headers["X-LlamaStack-Provider-Data"] = json.dumps(
                    self._client.provider_data
                )

        with request_provider_data_context(headers):
            is_stream = body.get("stream", False)

            if is_stream:
                return await self._handle_streaming(request, method, path, body)
            return await self._handle_non_streaming(request, method, path, body)

    def _resolve_handler(
        self,
        method: str,
        path: str,
        body: dict[str, Any],
    ) -> tuple[Any, dict[str, Any]]:
        """Find the route handler for a request and build its call arguments.

        Args:
            method: The HTTP method of the request.
            path: The request path used for route matching.
            body: The decoded request body; path parameters are merged into it.

        Returns:
            A ``(handler, kwargs)`` pair: the matched route function and the
            converted body to pass to it as keyword arguments.

        Raises:
            RuntimeError: If the library client has not been initialized.
        """
        route_impls = self._client.route_impls
        # Explicit check rather than ``assert`` so the guard survives ``python -O``.
        if route_impls is None:
            raise RuntimeError(
                "Llama Stack library client not initialized. Call initialize() first."
            )

        handler, path_params, _, _ = find_matching_route(method, path, route_impls)
        body |= path_params
        kwargs = self._client._convert_body(  # pylint: disable=protected-access
            handler, body
        )
        return handler, kwargs

    async def _handle_non_streaming(
        self,
        request: httpx.Request,
        method: str,
        path: str,
        body: dict[str, Any],
    ) -> httpx.Response:
        """Call the matched handler once and return its result as a JSON response.

        Args:
            request: The originating httpx request, attached to the response.
            method: The HTTP method of the request.
            path: The request path used for route matching.
            body: The decoded request body.

        Returns:
            A 200 response carrying the JSON-serialised result, or an empty 204
            response when a ``DELETE`` handler returns ``None``.

        Raises:
            RuntimeError: If the library client has not been initialized.
        """
        handler, kwargs = self._resolve_handler(method, path, body)
        result = await handler(**kwargs)

        if method.upper() == "DELETE" and result is None:
            status_code = httpx.codes.NO_CONTENT
            json_content = ""
        else:
            status_code = httpx.codes.OK
            json_content = json.dumps(convert_pydantic_to_json_value(result))

        return httpx.Response(
            status_code=status_code,
            content=json_content.encode("utf-8"),
            headers={"Content-Type": "application/json"},
            request=request,
        )

    async def _handle_streaming(
        self,
        request: httpx.Request,
        method: str,
        path: str,
        body: dict[str, Any],
    ) -> httpx.Response:
        """Call the matched handler and return its chunks as a server-sent event stream.

        Args:
            request: The originating httpx request, attached to the response.
            method: The HTTP method of the request.
            path: The request path used for route matching.
            body: The decoded request body.

        Returns:
            A 200 ``text/event-stream`` response whose body yields one
            ``data: <json>`` event per chunk produced by the handler.

        Raises:
            RuntimeError: If the library client has not been initialized.
        """
        handler, kwargs = self._resolve_handler(method, path, body)
        result = await handler(**kwargs)

        async def gen() -> AsyncGenerator[bytes, None]:
            async for chunk in result:
                data = json.dumps(convert_pydantic_to_json_value(chunk))
                yield f"data: {data}\n\n".encode("utf-8")

        # The body is consumed after this method returns; wrap the generator so
        # the provider-data context variable is carried along while it runs.
        wrapped_gen = preserve_contexts_async_generator(gen(), [PROVIDER_DATA_VAR])

        return httpx.Response(
            status_code=httpx.codes.OK,
            stream=_AsyncByteStream(wrapped_gen),
            headers={"Content-Type": "text/event-stream"},
            request=request,
        )
Loading
Loading