diff --git a/python/PACKAGE_STATUS.md b/python/PACKAGE_STATUS.md
index ec996de657..b2cd91e35c 100644
--- a/python/PACKAGE_STATUS.md
+++ b/python/PACKAGE_STATUS.md
@@ -30,6 +30,7 @@ Status is grouped into these buckets:
 | `agent-framework-declarative` | `python/packages/declarative` | `beta` |
 | `agent-framework-devui` | `python/packages/devui` | `beta` |
 | `agent-framework-durabletask` | `python/packages/durabletask` | `beta` |
+| `agent-framework-fake` | `python/packages/core` | `alpha` |
 | `agent-framework-foundry` | `python/packages/foundry` | `released` |
 | `agent-framework-foundry-local` | `python/packages/foundry_local` | `beta` |
 | `agent-framework-gemini` | `python/packages/gemini` | `alpha` |
diff --git a/python/packages/core/agent_framework/fake/__init__.py b/python/packages/core/agent_framework/fake/__init__.py
new file mode 100644
index 0000000000..bd9ae6d944
--- /dev/null
+++ b/python/packages/core/agent_framework/fake/__init__.py
@@ -0,0 +1,8 @@
+# Copyright (c) Microsoft. All rights reserved.
+
+from ._chat_client import FakeChatClient, FakeChatOptions
+
+__all__ = [
+    "FakeChatClient",
+    "FakeChatOptions",
+]
diff --git a/python/packages/core/agent_framework/fake/_chat_client.py b/python/packages/core/agent_framework/fake/_chat_client.py
new file mode 100644
index 0000000000..711c5d3029
--- /dev/null
+++ b/python/packages/core/agent_framework/fake/_chat_client.py
@@ -0,0 +1,167 @@
+# Copyright (c) Microsoft. All rights reserved.
+
+from __future__ import annotations
+
+import copy
+import sys
+from collections.abc import AsyncIterable, Awaitable, Mapping, Sequence
+from typing import Any, ClassVar, Generic
+
+from pydantic import BaseModel
+
+from .._clients import BaseChatClient
+from .._middleware import ChatAndFunctionMiddlewareTypes, ChatMiddlewareLayer
+from .._tools import FunctionInvocationConfiguration, FunctionInvocationLayer
+from .._types import (
+    ChatOptions,
+    ChatResponse,
+    ChatResponseUpdate,
+    Message,
+    ResponseStream,
+)
+from ..exceptions import ChatClientInvalidRequestException
+from ..observability import ChatTelemetryLayer
+
+if sys.version_info >= (3, 13):
+    from typing import TypeVar  # type: ignore # pragma: no cover
+else:
+    from typing_extensions import TypeVar  # type: ignore # pragma: no cover
+
+if sys.version_info >= (3, 12):
+    from typing import override  # type: ignore # pragma: no cover
+else:
+    from typing_extensions import override  # type: ignore # pragma: no cover
+
+
+__all__ = ["FakeChatClient", "FakeChatOptions"]
+
+ResponseModelT = TypeVar("ResponseModelT", bound=BaseModel | None, default=None)
+
+FakeResponseItem = str | Message | ChatResponse
+
+
+class FakeChatOptions(ChatOptions[ResponseModelT], Generic[ResponseModelT], total=False):
+    """Fake-model options used by FakeChatClient.
+
+    Keys:
+        model: Optional model name override for this request.
+        response: Optional one-off response that overrides queued responses.
+        cycle: Optional per-request override for cycling behavior.
+ """ + + response: FakeResponseItem + cycle: bool + + +class FakeChatClient( + FunctionInvocationLayer[FakeChatOptions], + ChatMiddlewareLayer[FakeChatOptions], + ChatTelemetryLayer[FakeChatOptions], + BaseChatClient[FakeChatOptions], +): + """Deterministic fake chat client useful for tests and local demos.""" + + OTEL_PROVIDER_NAME: ClassVar[str] = "fake" + + def __init__( + self, + *, + responses: Sequence[FakeResponseItem], + model: str = "fake-model", + cycle: bool = False, + additional_properties: dict[str, Any] | None = None, + middleware: Sequence[ChatAndFunctionMiddlewareTypes] | None = None, + function_invocation_configuration: FunctionInvocationConfiguration | None = None, + ) -> None: + """Initialize a fake chat client. + + Keyword Args: + responses: Ordered fake responses returned on successive calls. + model: Default model name used in generated responses. + cycle: Whether responses should wrap to the beginning when exhausted. + When False, an error is always raised once the list is exhausted. + additional_properties: Additional properties stored on the client instance. + middleware: Optional middleware to apply to the client. + function_invocation_configuration: Optional function invocation configuration override. + """ + super().__init__( + additional_properties=additional_properties, + middleware=middleware, + function_invocation_configuration=function_invocation_configuration, + ) + self.model = model + self._responses = list(responses) + self._response_index = 0 + self._cycle = cycle + + @override + def _inner_get_response( + self, + *, + messages: Sequence[Message], + options: Mapping[str, Any], + stream: bool = False, + **kwargs: Any, + ) -> Awaitable[ChatResponse] | ResponseStream[ChatResponseUpdate, ChatResponse]: + del kwargs + + response = self._select_response(messages=messages, options=options) + if stream: + return self._to_stream(response) + + async def _get_response() -> ChatResponse: + return response + + return _get_response() + + def _to_stream(self, response: ChatResponse) -> ResponseStream[ChatResponseUpdate, ChatResponse]: + async def _stream() -> AsyncIterable[ChatResponseUpdate]: + assistant_messages = [message for message in response.messages if message.role == "assistant"] + if not assistant_messages: + return + + for index, message in enumerate(assistant_messages): + yield ChatResponseUpdate( + contents=message.contents, + role="assistant", + model=response.model, + created_at=response.created_at, + finish_reason=response.finish_reason if index == len(assistant_messages) - 1 else None, + ) + + def _finalize(_updates: Sequence[ChatResponseUpdate]) -> ChatResponse: + return response + + return ResponseStream(_stream(), finalizer=_finalize) + + def _select_response(self, *, messages: Sequence[Message], options: Mapping[str, Any]) -> ChatResponse: + + if not messages: + raise ChatClientInvalidRequestException("Messages are required for chat completions") + + if (single_response := options.get("response")) is not None: + return self._materialize_response(single_response, options) + + if self._response_index >= len(self._responses): + should_cycle = bool(options.get("cycle", self._cycle)) + if should_cycle: + self._response_index = 0 + else: + raise ChatClientInvalidRequestException( + "FakeChatClient response list is exhausted. Provide more responses or enable cycle=True." 
+                )
+
+        item = self._responses[self._response_index]
+        self._response_index += 1
+        return self._materialize_response(item, options)
+
+    def _materialize_response(self, value: FakeResponseItem, options: Mapping[str, Any]) -> ChatResponse:
+        model = str(options.get("model") or self.model)
+
+        if isinstance(value, ChatResponse):
+            # Shallow-copy to avoid mutating the original queued item (e.g. under cycle=True).
+            cloned = copy.copy(value)
+            cloned.model = model
+            return cloned
+        if isinstance(value, Message):
+            return ChatResponse(messages=[value], model=model)
+        return ChatResponse(messages=[Message(role="assistant", contents=[value])], model=model)
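A minimal usage sketch of the client added above, for reviewers skimming the diff. Only the `main` wrapper and the `asyncio.run` call are illustrative; everything else is the public API from this change:

    import asyncio

    from agent_framework import Message
    from agent_framework.fake import FakeChatClient


    async def main() -> None:
        # Queue two deterministic replies; cycle=True wraps to the start when exhausted.
        client = FakeChatClient(responses=["first reply", "second reply"], cycle=True)

        # Non-streaming: each call consumes the next queued response.
        response = await client.get_response(messages=[Message(role="user", contents=["hello"])])
        print(response.text)  # first reply

        # Streaming: the next queued response is replayed as updates plus a final response.
        stream = client.get_response(messages=[Message(role="user", contents=["again"])], stream=True)
        async for update in stream:
            print(update.text)  # second reply
        final = await stream.get_final_response()
        print(final.model)  # fake-model


    asyncio.run(main())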
diff --git a/python/packages/core/tests/core/test_fake_chat_client.py b/python/packages/core/tests/core/test_fake_chat_client.py
new file mode 100644
index 0000000000..96e94ba254
--- /dev/null
+++ b/python/packages/core/tests/core/test_fake_chat_client.py
@@ -0,0 +1,137 @@
+# Copyright (c) Microsoft. All rights reserved.
+
+import pytest
+from agent_framework import (
+    ChatResponse,
+    Message,
+    chat_middleware,
+)
+from agent_framework.exceptions import ChatClientInvalidRequestException
+from agent_framework.fake import FakeChatClient
+
+
+def test_init() -> None:
+    fake_chat_client = FakeChatClient(model="fake-model", responses=["Hello!", "This framework is amazing!"])
+    assert fake_chat_client.model == "fake-model"
+    assert fake_chat_client._responses == ["Hello!", "This framework is amazing!"]
+    assert not fake_chat_client._cycle
+
+
+def test_serialize() -> None:
+    settings = {
+        "responses": ["Hello!", "This framework is amazing!"],
+        "model": "fake-model-serialize",
+        "cycle": False,
+    }
+
+    fake_chat_client = FakeChatClient.from_dict(settings)
+    serialized = fake_chat_client.to_dict()
+
+    assert isinstance(serialized, dict)
+    # Only public attributes are serialized.
+    assert serialized["model"] == "fake-model-serialize"
+
+
+def test_chat_middleware() -> None:
+    @chat_middleware
+    async def sample_middleware(context, call_next):
+        await call_next()
+
+    fake_chat_client = FakeChatClient(responses=["Hello!"], middleware=[sample_middleware])
+    assert len(fake_chat_client.chat_middleware) == 1
+    assert fake_chat_client.chat_middleware[0] == sample_middleware
+
+
+async def test_empty_messages() -> None:
+    fake_chat_client = FakeChatClient(responses=["Test :)"])
+    with pytest.raises(ChatClientInvalidRequestException):
+        await fake_chat_client.get_response(messages=[])
+
+
+async def test_get_response() -> None:
+    fake_chat_client = FakeChatClient(
+        responses=[
+            "the most beautiful number is 1729",
+            "It is the smallest number that can be written as the "
+            "sum of cubes in two different ways: 1729 = 1**3 + 12**3 = 9**3 + 10**3",
+        ],
+        cycle=False,
+    )
+
+    result_first = await fake_chat_client.get_response(
+        messages=[Message(contents=["what is the most beautiful number?"], role="user")]
+    )
+    assert result_first.text == "the most beautiful number is 1729"
+    result_second = await fake_chat_client.get_response(messages=[Message(contents=["and why is it?"], role="user")])
+    assert (
+        result_second.text == "It is the smallest number that can be written as the "
+        "sum of cubes in two different ways: 1729 = 1**3 + 12**3 = 9**3 + 10**3"
+    )
+    with pytest.raises(
+        ChatClientInvalidRequestException,
+        match="FakeChatClient response list is exhausted. Provide more responses or enable cycle=True.",
+    ):
+        await fake_chat_client.get_response(messages=[Message(contents=["Do you have more?"], role="user")])
+
+
+async def test_get_response_cycle() -> None:
+    client = FakeChatClient(responses=["a", "b"], cycle=True)
+    messages = [Message(role="user", contents=["hi"])]
+
+    r1 = await client.get_response(messages=messages)
+    r2 = await client.get_response(messages=messages)
+    r3 = await client.get_response(messages=messages)
+    r4 = await client.get_response(messages=messages)
+
+    assert r1.text == "a"
+    assert r2.text == "b"
+    assert r3.text == "a"
+    assert r4.text == "b"
+
+
+async def test_get_response_stream() -> None:
+    client = FakeChatClient(responses=["streaming response"])
+    messages = [Message(role="user", contents=["hi"])]
+
+    stream = client.get_response(messages=messages, stream=True)
+    updates = [update async for update in stream]
+    final = await stream.get_final_response()
+
+    assert len(updates) == 1
+    assert updates[0].text == "streaming response"
+    assert final.text == "streaming response"
+
+
+async def test_chat_response_model_override_from_queue() -> None:
+    queued = ChatResponse(messages=[Message(role="assistant", contents=["hi"])], model="original-model")
+    client = FakeChatClient(responses=[queued], model="default-model")
+    messages = [Message(role="user", contents=["hello"])]
+
+    result = await client.get_response(messages=messages, options={"model": "override-model"})
+
+    assert result.model == "override-model"
+
+
+async def test_chat_response_model_override_from_options_response() -> None:
+    one_off = ChatResponse(messages=[Message(role="assistant", contents=["hi"])], model="original-model")
+    client = FakeChatClient(responses=[], model="default-model")
+    messages = [Message(role="user", contents=["hello"])]
+
+    result = await client.get_response(messages=messages, options={"response": one_off, "model": "override-model"})
+
+    assert result.model == "override-model"
+
+
+async def test_middleware_wraps_response() -> None:
+    @chat_middleware
+    async def wrapping_middleware(context, call_next):
+        await call_next()
+        context.result = ChatResponse(
+            messages=[Message(role="assistant", contents=[f"[wrapped] {context.result.text}"])],
+            model=context.result.model,
+        )
+
+    client = FakeChatClient(responses=["hello"], middleware=[wrapping_middleware])
+    result = await client.get_response(messages=[Message(role="user", contents=["hello"])])
+
+    assert result.text == "[wrapped] hello"
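The payoff shows up in downstream suites: code written against the chat-client interface can be exercised with canned replies and no network. A sketch, assuming a hypothetical `summarize` helper as the code under test (only `FakeChatClient`, `Message`, and `ChatResponse` come from this change):

    from agent_framework import ChatResponse, Message
    from agent_framework.fake import FakeChatClient


    async def summarize(client, text: str) -> str:
        # Hypothetical code under test: it works with any chat client, fake or real.
        response = await client.get_response(messages=[Message(role="user", contents=[text])])
        return response.text


    async def test_summarize() -> None:
        # Queue the exact reply the fake should hand back to the code under test.
        canned = ChatResponse(messages=[Message(role="assistant", contents=["a short summary"])], model="fake-model")
        client = FakeChatClient(responses=[canned])

        assert await summarize(client, "a very long article") == "a short summary"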