diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index e486ab1..74696bf 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -2,7 +2,7 @@ name: CI on: push: - branches: [main] + branches: [main, dev] pull_request: branches: [main] @@ -13,7 +13,7 @@ jobs: - uses: actions/checkout@v4 - uses: astral-sh/setup-uv@v3 - run: uv sync --dev - - run: uv run ruff check src/metorial/ examples/ tests/ + - run: uv run ruff check src/metorial/ tests/ format: runs-on: ubuntu-latest @@ -21,7 +21,7 @@ jobs: - uses: actions/checkout@v4 - uses: astral-sh/setup-uv@v3 - run: uv sync --dev - - run: uv run cblack --check src/metorial/ examples/ --exclude='src/metorial/_generated' + - run: uv run ruff format --check src/metorial/ type-check: runs-on: ubuntu-latest diff --git a/.gitignore b/.gitignore index 8566c20..841d546 100644 --- a/.gitignore +++ b/.gitignore @@ -185,9 +185,9 @@ cython_debug/ .abstra/ # Visual Studio Code -# Visual Studio Code specific template is maintained in a separate VisualStudioCode.gitignore +# Visual Studio Code specific template is maintained in a separate VisualStudioCode.gitignore # that can be found at https://github.com/github/gitignore/blob/main/Global/VisualStudioCode.gitignore -# and can be added to the global gitignore or merged into this file. However, if you prefer, +# and can be added to the global gitignore or merged into this file. However, if you prefer, # you could uncomment the following to ignore the entire vscode folder # .vscode/ @@ -211,5 +211,4 @@ __marimo__/ # Project-specific playground/ -tests/ -pytest.ini \ No newline at end of file +pytest.ini diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 27cf479..54046e6 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -5,15 +5,7 @@ repos: - id: ruff args: [--fix, --exit-non-zero-on-fix] exclude: ^src/metorial/_generated/ - - - repo: local - hooks: - - id: cblack - name: cblack - entry: cblack - language: system - types: [python] - args: [--check] + - id: ruff-format exclude: ^src/metorial/_generated/ - repo: https://github.com/pre-commit/pre-commit-hooks diff --git a/pyproject.toml b/pyproject.toml index 30797a8..e05a5bb 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -44,7 +44,6 @@ dev = [ "pytest>=7.0.0", "pytest-asyncio>=0.20.0", "ruff>=0.4.0", - "cblack>=22.6.0", "mypy>=1.0.0", "types-requests>=2.32.0", "pre-commit>=3.5.0", @@ -64,27 +63,7 @@ packages = ["src/metorial"] [tool.hatch.build.targets.sdist] include = ["/src", "/README.md", "/LICENSE"] -# cblack configuration (Black with 2-space indentation) -[tool.cblack] -line-length = 88 -target-version = ['py310'] -include = '\.pyi?$' -extend-exclude = ''' -/( - # directories - \.eggs - | \.git - | \.hg - | \.mypy_cache - | \.tox - | \.venv - | build - | dist - | src/metorial/_generated -)/ -''' - -# MyPy configuration (WorkOS pattern - strict mode) +# MyPy configuration [tool.mypy] python_version = "3.10" strict = true diff --git a/src/metorial/_base.py b/src/metorial/_base.py index 083ce21..72a302f 100644 --- a/src/metorial/_base.py +++ b/src/metorial/_base.py @@ -337,12 +337,10 @@ def mcp(self) -> dict[str, Any]: } @overload - def create_mcp_session(self, init: MetorialMcpSessionInit) -> MetorialSession: - ... + def create_mcp_session(self, init: MetorialMcpSessionInit) -> MetorialSession: ... @overload - def create_mcp_session(self, init: dict[str, Any]) -> MetorialSession: - ... + def create_mcp_session(self, init: dict[str, Any]) -> MetorialSession: ... def create_mcp_session( self, init: MetorialMcpSessionInit | dict[str, Any] diff --git a/src/metorial/_client_core.py b/src/metorial/_client_core.py index 84502b6..c329d6c 100644 --- a/src/metorial/_client_core.py +++ b/src/metorial/_client_core.py @@ -50,7 +50,7 @@ def _normalize_server_deployments( raise ValueError(f"Invalid deployment object format: {deployment}") else: raise ValueError( - f"Invalid deployment type: {type(deployment)} " "- must be string or dict" + f"Invalid deployment type: {type(deployment)} - must be string or dict" ) return normalized diff --git a/src/metorial/_protocols.py b/src/metorial/_protocols.py index 5025b9a..3e801e9 100644 --- a/src/metorial/_protocols.py +++ b/src/metorial/_protocols.py @@ -13,15 +13,12 @@ class ToolLike(Protocol): """Protocol for tool-like objects that have name, description, and parameters.""" @property - def name(self) -> str: - ... + def name(self) -> str: ... @property - def description(self) -> str | None: - ... + def description(self) -> str | None: ... - def get_parameters_as(self, format: str) -> dict[str, Any]: - ... + def get_parameters_as(self, format: str) -> dict[str, Any]: ... @runtime_checkable diff --git a/src/metorial/_sdk_builder.py b/src/metorial/_sdk_builder.py index 45ee93c..a5613d8 100644 --- a/src/metorial/_sdk_builder.py +++ b/src/metorial/_sdk_builder.py @@ -44,7 +44,7 @@ def build( raise ValueError("api_host must be set") def builder( - get_endpoints: Callable[[MetorialEndpointManager], dict[str, Any]] + get_endpoints: Callable[[MetorialEndpointManager], dict[str, Any]], ) -> Callable[[dict[str, Any]], dict[str, Any]]: def sdk(config: dict[str, Any]) -> dict[str, Any]: full_config = get_config(config) diff --git a/src/metorial/adapters/openai_compatible.py b/src/metorial/adapters/openai_compatible.py index 4cbcd74..08b6153 100644 --- a/src/metorial/adapters/openai_compatible.py +++ b/src/metorial/adapters/openai_compatible.py @@ -20,12 +20,12 @@ call_openai_compatible_tools, ) except ImportError: - call_openai_compatible_tools: Callable[ - [Any, list[Any]], list[dict[str, Any]] - ] | None = None - build_openai_compatible_tools: Callable[ - [Any, bool], list[dict[str, Any]] - ] | None = None + call_openai_compatible_tools: ( + Callable[[Any, list[Any]], list[dict[str, Any]]] | None + ) = None + build_openai_compatible_tools: ( + Callable[[Any, bool], list[dict[str, Any]]] | None + ) = None class OpenAICompatibleAdapter(ProviderAdapter): diff --git a/src/metorial/integrations/pydantic_ai.py b/src/metorial/integrations/pydantic_ai.py index 65186c5..bc39735 100644 --- a/src/metorial/integrations/pydantic_ai.py +++ b/src/metorial/integrations/pydantic_ai.py @@ -158,7 +158,8 @@ async def tool_fn(**kwargs: Any) -> str: # Set annotations for PydanticAI to discover parameters tool_fn.__annotations__ = { - k: v[0] for k, v in fields.items() # Get the type from (type, Field) tuple + k: v[0] + for k, v in fields.items() # Get the type from (type, Field) tuple } tool_fn.__annotations__["return"] = str diff --git a/src/metorial/mcp/mcp_session.py b/src/metorial/mcp/mcp_session.py index fa813b5..4c8d373 100644 --- a/src/metorial/mcp/mcp_session.py +++ b/src/metorial/mcp/mcp_session.py @@ -80,8 +80,7 @@ class _ServersAPI(Protocol): """Protocol for servers API - only requires capabilities sub-API.""" @property - def capabilities(self) -> _CapabilitiesAPI: - ... + def capabilities(self) -> _CapabilitiesAPI: ... class MetorialCoreSDK(Protocol): @@ -93,16 +92,13 @@ class MetorialCoreSDK(Protocol): """ @property - def _config(self) -> _SDKConfig: - ... + def _config(self) -> _SDKConfig: ... @property - def sessions(self) -> _SessionsAPI | None: - ... + def sessions(self) -> _SessionsAPI | None: ... @property - def servers(self) -> _ServersAPI | None: - ... + def servers(self) -> _ServersAPI | None: ... class _SessionResponse(TypedDict): diff --git a/tests/conftest.py b/tests/conftest.py new file mode 100644 index 0000000..4cf8cc3 --- /dev/null +++ b/tests/conftest.py @@ -0,0 +1,131 @@ +""" +Shared test fixtures and configuration for metorial tests. + +Includes sync/async client parametrization pattern. +""" + +from __future__ import annotations + +from typing import TYPE_CHECKING, Literal +from unittest.mock import AsyncMock, MagicMock + +import pytest + +if TYPE_CHECKING: + from metorial import Metorial, MetorialSync + + +# -------------------------------------------------------------------------- +# Pytest configuration +# -------------------------------------------------------------------------- + + +def pytest_configure(config: pytest.Config) -> None: + """Register custom markers.""" + config.addinivalue_line( + "markers", + "sync_and_async: mark test to run with both sync and async clients", + ) + + +# -------------------------------------------------------------------------- +# Sync/Async client parametrization +# -------------------------------------------------------------------------- + + +@pytest.fixture(params=["sync", "async"]) +def client_type(request: pytest.FixtureRequest) -> Literal["sync", "async"]: + """Parametrize tests to run with both sync and async clients.""" + return request.param + + +@pytest.fixture +def metorial_client( + client_type: Literal["sync", "async"], + mock_metorial_config: dict[str, str], +) -> Metorial | MetorialSync: + """Create a Metorial client based on client_type fixture. + + This allows tests marked with @pytest.mark.sync_and_async to run + automatically with both sync and async clients. + """ + if client_type == "sync": + from metorial import MetorialSync + + return MetorialSync(api_key=mock_metorial_config["apiKey"]) + else: + from metorial import Metorial + + return Metorial(api_key=mock_metorial_config["apiKey"]) + + +@pytest.fixture +def async_metorial_client(mock_metorial_config: dict[str, str]) -> Metorial: + """Create an async-only Metorial client for async-specific tests.""" + from metorial import Metorial + + return Metorial(api_key=mock_metorial_config["apiKey"]) + + +@pytest.fixture +def sync_metorial_client(mock_metorial_config: dict[str, str]) -> MetorialSync: + """Create a sync-only Metorial client for sync-specific tests.""" + from metorial import MetorialSync + + return MetorialSync(api_key=mock_metorial_config["apiKey"]) + + +# -------------------------------------------------------------------------- +# Mock fixtures +# -------------------------------------------------------------------------- + + +@pytest.fixture +def mock_tool_manager() -> MagicMock: + """Mock tool manager for testing.""" + manager = MagicMock() + manager.get_tools.return_value = [] + manager.call_tool = AsyncMock(return_value={"content": "result"}) + manager.get_tool.return_value = None + return manager + + +@pytest.fixture +def mock_metorial_config() -> dict[str, str]: + """Mock configuration for testing.""" + return { + "apiKey": "test-api-key", + "apiHost": "https://api.metorial.com", + "mcpHost": "https://mcp.metorial.com", + } + + +@pytest.fixture +def mock_mcp_tool() -> MagicMock: + """Mock MCP tool object.""" + tool = MagicMock() + tool.name = "test_tool" + tool.description = "A test tool" + tool.parameters = {"type": "object", "properties": {"param1": {"type": "string"}}} + return tool + + +@pytest.fixture +def mock_mcp_session() -> MagicMock: + """Mock MCP session for testing.""" + session = MagicMock() + session.get_tool_manager = AsyncMock() + session.close = AsyncMock() + return session + + +@pytest.fixture +def mock_http_response() -> MagicMock: + """Mock HTTP response for testing RawResponse.""" + response = MagicMock() + response.status_code = 200 + response.headers = { + "X-Request-ID": "req-test-123", + "Content-Type": "application/json", + } + return response diff --git a/tests/test_client_fixtures.py b/tests/test_client_fixtures.py new file mode 100644 index 0000000..ff2a494 --- /dev/null +++ b/tests/test_client_fixtures.py @@ -0,0 +1,262 @@ +""" +Tests for client fixture functionality and sync/async parametrization. +""" + +from __future__ import annotations + +from typing import TYPE_CHECKING + +import pytest + +if TYPE_CHECKING: + from metorial import Metorial, MetorialSync + + +# ============================================================================= +# Client Type Parametrization Tests +# ============================================================================= + + +class TestClientTypeFixture: + """Tests for client_type parametrized fixture.""" + + def test_client_type_is_sync_or_async(self, client_type: str) -> None: + """client_type should be either 'sync' or 'async'.""" + assert client_type in ("sync", "async") + + +# ============================================================================= +# Metorial Client Fixture Tests +# ============================================================================= + + +class TestMetorialClientFixture: + """Tests for metorial_client parametrized fixture.""" + + def test_creates_sync_client(self, mock_metorial_config: dict[str, str]) -> None: + """Should create sync client when client_type is 'sync'.""" + from metorial import MetorialSync + + client = MetorialSync(api_key=mock_metorial_config["apiKey"]) + assert isinstance(client, MetorialSync) + + def test_creates_async_client(self, mock_metorial_config: dict[str, str]) -> None: + """Should create async client when client_type is 'async'.""" + from metorial import Metorial + + client = Metorial(api_key=mock_metorial_config["apiKey"]) + assert isinstance(client, Metorial) + + def test_client_has_api_key(self, metorial_client: Metorial | MetorialSync) -> None: + """Client should have API key configured.""" + assert metorial_client._config_data["apiKey"] == "test-api-key" + + +# ============================================================================= +# Async Metorial Client Fixture Tests +# ============================================================================= + + +class TestAsyncMetorialClientFixture: + """Tests for async_metorial_client fixture.""" + + def test_creates_async_client(self, async_metorial_client: Metorial) -> None: + """Should create an async Metorial client.""" + from metorial import Metorial + + assert isinstance(async_metorial_client, Metorial) + + def test_async_client_has_config(self, async_metorial_client: Metorial) -> None: + """Async client should have proper configuration.""" + assert async_metorial_client._config_data["apiKey"] == "test-api-key" + assert "apiHost" in async_metorial_client._config_data + + def test_async_client_has_provider_session( + self, async_metorial_client: Metorial + ) -> None: + """Async client should have provider_session method.""" + assert hasattr(async_metorial_client, "provider_session") + assert callable(async_metorial_client.provider_session) + + +# ============================================================================= +# Sync Metorial Client Fixture Tests +# ============================================================================= + + +class TestSyncMetorialClientFixture: + """Tests for sync_metorial_client fixture.""" + + def test_creates_sync_client(self, sync_metorial_client: MetorialSync) -> None: + """Should create a sync MetorialSync client.""" + from metorial import MetorialSync + + assert isinstance(sync_metorial_client, MetorialSync) + + def test_sync_client_has_config(self, sync_metorial_client: MetorialSync) -> None: + """Sync client should have proper configuration.""" + assert sync_metorial_client._config_data["apiKey"] == "test-api-key" + assert "apiHost" in sync_metorial_client._config_data + + def test_sync_client_has_session_method( + self, sync_metorial_client: MetorialSync + ) -> None: + """Sync client should have session method.""" + assert hasattr(sync_metorial_client, "session") + assert callable(sync_metorial_client.session) + + +# ============================================================================= +# Mock Configuration Fixture Tests +# ============================================================================= + + +class TestMockMetorialConfigFixture: + """Tests for mock_metorial_config fixture.""" + + def test_has_required_keys(self, mock_metorial_config: dict[str, str]) -> None: + """Config should have all required keys.""" + assert "apiKey" in mock_metorial_config + assert "apiHost" in mock_metorial_config + assert "mcpHost" in mock_metorial_config + + def test_has_valid_values(self, mock_metorial_config: dict[str, str]) -> None: + """Config should have valid values.""" + assert mock_metorial_config["apiKey"] == "test-api-key" + assert mock_metorial_config["apiHost"].startswith("https://") + assert mock_metorial_config["mcpHost"].startswith("https://") + + +# ============================================================================= +# Mock Tool Manager Fixture Tests +# ============================================================================= + + +class TestMockToolManagerFixture: + """Tests for mock_tool_manager fixture.""" + + def test_has_get_tools_method(self, mock_tool_manager) -> None: + """Should have get_tools method.""" + assert hasattr(mock_tool_manager, "get_tools") + assert mock_tool_manager.get_tools() == [] + + def test_has_call_tool_method(self, mock_tool_manager) -> None: + """Should have call_tool method.""" + assert hasattr(mock_tool_manager, "call_tool") + + def test_has_get_tool_method(self, mock_tool_manager) -> None: + """Should have get_tool method.""" + assert hasattr(mock_tool_manager, "get_tool") + assert mock_tool_manager.get_tool() is None + + +# ============================================================================= +# Mock MCP Tool Fixture Tests +# ============================================================================= + + +class TestMockMcpToolFixture: + """Tests for mock_mcp_tool fixture.""" + + def test_has_name(self, mock_mcp_tool) -> None: + """Should have name attribute.""" + assert mock_mcp_tool.name == "test_tool" + + def test_has_description(self, mock_mcp_tool) -> None: + """Should have description attribute.""" + assert mock_mcp_tool.description == "A test tool" + + def test_has_parameters(self, mock_mcp_tool) -> None: + """Should have parameters attribute.""" + assert "type" in mock_mcp_tool.parameters + assert "properties" in mock_mcp_tool.parameters + + +# ============================================================================= +# Mock MCP Session Fixture Tests +# ============================================================================= + + +class TestMockMcpSessionFixture: + """Tests for mock_mcp_session fixture.""" + + def test_has_get_tool_manager_method(self, mock_mcp_session) -> None: + """Should have get_tool_manager method.""" + assert hasattr(mock_mcp_session, "get_tool_manager") + + def test_has_close_method(self, mock_mcp_session) -> None: + """Should have close method.""" + assert hasattr(mock_mcp_session, "close") + + +# ============================================================================= +# Mock HTTP Response Fixture Tests +# ============================================================================= + + +class TestMockHttpResponseFixture: + """Tests for mock_http_response fixture.""" + + def test_has_status_code(self, mock_http_response) -> None: + """Should have status_code attribute.""" + assert mock_http_response.status_code == 200 + + def test_has_headers(self, mock_http_response) -> None: + """Should have headers dict.""" + assert "X-Request-ID" in mock_http_response.headers + assert "Content-Type" in mock_http_response.headers + + def test_has_request_id_header(self, mock_http_response) -> None: + """Should have X-Request-ID header.""" + assert mock_http_response.headers["X-Request-ID"] == "req-test-123" + + +# ============================================================================= +# Client Context Manager Tests +# ============================================================================= + + +class TestClientContextManagers: + """Tests for client context manager behavior.""" + + def test_sync_client_context_manager( + self, sync_metorial_client: MetorialSync + ) -> None: + """Sync client should support context manager.""" + with sync_metorial_client as client: + assert client is sync_metorial_client + + @pytest.mark.asyncio + async def test_async_client_context_manager( + self, async_metorial_client: Metorial + ) -> None: + """Async client should support async context manager.""" + async with async_metorial_client as client: + assert client is async_metorial_client + + +# ============================================================================= +# Client Configuration Tests +# ============================================================================= + + +class TestClientConfiguration: + """Tests for client configuration handling.""" + + def test_client_default_timeout( + self, metorial_client: Metorial | MetorialSync + ) -> None: + """Client should have default timeout configured.""" + assert metorial_client._config_data["timeout"] == 30.0 + + def test_client_default_max_retries( + self, metorial_client: Metorial | MetorialSync + ) -> None: + """Client should have default max retries configured.""" + assert metorial_client._config_data["maxRetries"] == 3 + + def test_client_has_http_client( + self, metorial_client: Metorial | MetorialSync + ) -> None: + """Client should have HTTP client initialized.""" + assert hasattr(metorial_client, "_http_client") diff --git a/tests/test_config.py b/tests/test_config.py new file mode 100644 index 0000000..fb1a2f6 --- /dev/null +++ b/tests/test_config.py @@ -0,0 +1,107 @@ +""" +Tests for configuration handling. +""" + +import pytest + +from metorial._base import MetorialBase + + +class TestConfigFromEnv: + """Tests for configuration from environment""" + + def test_load_config_from_env(self): + """Configuration should load API key from parameter.""" + base = MetorialBase(api_key="test-key") + + assert base._config_data["apiKey"] == "test-key" + assert base._config_data["apiHost"] == "https://api.metorial.com" + assert base._config_data["mcpHost"] == "https://mcp.metorial.com" + + def test_config_with_dict(self): + """Configuration should accept dict format.""" + config = { + "apiKey": "dict-api-key", + "apiHost": "https://custom-api.example.com", + "mcpHost": "https://custom-mcp.example.com", + } + base = MetorialBase(api_key=config) + + assert base._config_data["apiKey"] == "dict-api-key" + assert base._config_data["apiHost"] == "https://custom-api.example.com" + assert base._config_data["mcpHost"] == "https://custom-mcp.example.com" + + +class TestConfigValidation: + """Tests for configuration validation""" + + def test_validate_config_missing_key(self): + """Missing API key should raise ValueError.""" + with pytest.raises(ValueError, match="api_key is required"): + MetorialBase(api_key=None) + + def test_validate_config_empty_key(self): + """Empty API key should raise ValueError.""" + with pytest.raises(ValueError, match="api_key is required"): + MetorialBase(api_key="") + + def test_config_with_updates(self): + """Additional kwargs should be stored in config.""" + base = MetorialBase(api_key="test-key", custom_param="custom_value") + + assert base._config_data["custom_param"] == "custom_value" + + +class TestConfigHostDerivation: + """Tests for automatic host derivation""" + + def test_derive_mcp_from_api_host(self): + """MCP host should be derived from custom API host.""" + base = MetorialBase(api_key="test-key", api_host="https://api.custom.example.com") + + assert base._config_data["mcpHost"] == "https://mcp.custom.example.com" + + def test_derive_api_from_mcp_host(self): + """API host should be derived from custom MCP host.""" + base = MetorialBase(api_key="test-key", mcp_host="https://mcp.custom.example.com") + + assert base._config_data["apiHost"] == "https://api.custom.example.com" + + def test_explicit_hosts_not_overwritten(self): + """Explicitly provided hosts should not be overwritten.""" + base = MetorialBase( + api_key="test-key", + api_host="https://api.explicit.com", + mcp_host="https://mcp.explicit.com", + ) + + assert base._config_data["apiHost"] == "https://api.explicit.com" + assert base._config_data["mcpHost"] == "https://mcp.explicit.com" + + +class TestConfigTimeout: + """Tests for timeout configuration""" + + def test_default_timeout(self): + """Default timeout should be 30 seconds.""" + base = MetorialBase(api_key="test-key") + + assert base._config_data["timeout"] == 30.0 + + def test_custom_timeout(self): + """Custom timeout should be respected.""" + base = MetorialBase(api_key="test-key", timeout=60.0) + + assert base._config_data["timeout"] == 60.0 + + def test_max_retries_default(self): + """Default max retries should be 3.""" + base = MetorialBase(api_key="test-key") + + assert base._config_data["maxRetries"] == 3 + + def test_custom_max_retries(self): + """Custom max retries should be respected.""" + base = MetorialBase(api_key="test-key", max_retries=5) + + assert base._config_data["maxRetries"] == 5 diff --git a/tests/test_endpoint_manager.py b/tests/test_endpoint_manager.py new file mode 100644 index 0000000..7c83eca --- /dev/null +++ b/tests/test_endpoint_manager.py @@ -0,0 +1,604 @@ +""" +Tests for endpoint manager HTTP client functionality. +""" + +from __future__ import annotations + +from unittest.mock import MagicMock, patch + +import pytest + +from metorial._endpoint.endpoint_manager import MetorialEndpointManager +from metorial._endpoint.request import MetorialRequest +from metorial.exceptions import ( + AuthenticationError, + BadRequestError, + InternalServerError, + MetorialSDKError, + NotFoundError, + RateLimitError, +) + +# ============================================================================= +# Fixtures +# ============================================================================= + + +@pytest.fixture +def endpoint_manager() -> MetorialEndpointManager: + """Create an endpoint manager for testing.""" + config = {"apiKey": "test-key"} + return MetorialEndpointManager( + config=config, + api_host="https://api.metorial.com", + get_headers=lambda c: {"Authorization": f"Bearer {c['apiKey']}"}, + enable_debug_logging=False, + ) + + +@pytest.fixture +def mock_response() -> MagicMock: + """Create a mock HTTP response.""" + response = MagicMock() + response.status_code = 200 + response.ok = True + response.text = '{"data": "test"}' + response.headers = {"Content-Type": "application/json"} + response.json.return_value = {"data": "test"} + return response + + +# ============================================================================= +# Request ID Capture Tests +# ============================================================================= + + +class TestRequestIdCapture: + """Tests for X-Request-ID header capture in errors.""" + + @patch("metorial._endpoint.endpoint_manager.requests.request") + def test_captures_request_id_on_400( + self, mock_request: MagicMock, endpoint_manager: MetorialEndpointManager + ) -> None: + """Request ID should be captured on 400 errors.""" + mock_request.return_value = MagicMock( + status_code=400, + ok=False, + text='{"message": "Bad request"}', + headers={"X-Request-ID": "req-bad-400", "Content-Type": "application/json"}, + json=MagicMock(return_value={"message": "Bad request"}), + reason="Bad Request", + ) + + request = MetorialRequest(path="/test") + with pytest.raises(BadRequestError) as exc_info: + endpoint_manager._request("GET", request) + + assert exc_info.value.request_id == "req-bad-400" + + @patch("metorial._endpoint.endpoint_manager.requests.request") + def test_captures_request_id_on_401( + self, mock_request: MagicMock, endpoint_manager: MetorialEndpointManager + ) -> None: + """Request ID should be captured on 401 errors.""" + mock_request.return_value = MagicMock( + status_code=401, + ok=False, + text='{"message": "Unauthorized"}', + headers={"X-Request-ID": "req-auth-401"}, + json=MagicMock(return_value={"message": "Unauthorized"}), + reason="Unauthorized", + ) + + request = MetorialRequest(path="/test") + with pytest.raises(AuthenticationError) as exc_info: + endpoint_manager._request("GET", request) + + assert exc_info.value.request_id == "req-auth-401" + + @patch("metorial._endpoint.endpoint_manager.requests.request") + def test_captures_request_id_on_404( + self, mock_request: MagicMock, endpoint_manager: MetorialEndpointManager + ) -> None: + """Request ID should be captured on 404 errors.""" + mock_request.return_value = MagicMock( + status_code=404, + ok=False, + text='{"message": "Not found"}', + headers={"X-Request-ID": "req-notfound-404"}, + json=MagicMock(return_value={"message": "Not found"}), + reason="Not Found", + ) + + request = MetorialRequest(path="/test") + with pytest.raises(NotFoundError) as exc_info: + endpoint_manager._request("GET", request) + + assert exc_info.value.request_id == "req-notfound-404" + + @patch("metorial._endpoint.endpoint_manager.requests.request") + def test_captures_request_id_on_429( + self, mock_request: MagicMock, endpoint_manager: MetorialEndpointManager + ) -> None: + """Request ID should be captured on 429 errors (after retries exhausted).""" + mock_request.return_value = MagicMock( + status_code=429, + ok=False, + text='{"message": "Rate limited"}', + headers={"X-Request-ID": "req-rate-429"}, + json=MagicMock(return_value={"message": "Rate limited"}), + reason="Too Many Requests", + ) + + request = MetorialRequest(path="/test") + # After 3 retries, should raise RateLimitError + with pytest.raises(RateLimitError) as exc_info: + endpoint_manager._request("GET", request, try_count=3) + + assert exc_info.value.request_id == "req-rate-429" + + @patch("metorial._endpoint.endpoint_manager.requests.request") + def test_captures_request_id_on_500( + self, mock_request: MagicMock, endpoint_manager: MetorialEndpointManager + ) -> None: + """Request ID should be captured on 500 errors.""" + mock_request.return_value = MagicMock( + status_code=500, + ok=False, + text='{"message": "Internal server error"}', + headers={"X-Request-ID": "req-server-500"}, + json=MagicMock(return_value={"message": "Internal server error"}), + reason="Internal Server Error", + ) + + request = MetorialRequest(path="/test") + with pytest.raises(InternalServerError) as exc_info: + endpoint_manager._request("GET", request) + + assert exc_info.value.request_id == "req-server-500" + + @patch("metorial._endpoint.endpoint_manager.requests.request") + def test_handles_missing_request_id( + self, mock_request: MagicMock, endpoint_manager: MetorialEndpointManager + ) -> None: + """Should handle responses without X-Request-ID header.""" + mock_request.return_value = MagicMock( + status_code=400, + ok=False, + text='{"message": "Bad request"}', + headers={}, # No X-Request-ID + json=MagicMock(return_value={"message": "Bad request"}), + reason="Bad Request", + ) + + request = MetorialRequest(path="/test") + with pytest.raises(BadRequestError) as exc_info: + endpoint_manager._request("GET", request) + + assert exc_info.value.request_id is None + + +# ============================================================================= +# Error Message Extraction Tests +# ============================================================================= + + +class TestErrorMessageExtraction: + """Tests for error message extraction from responses.""" + + @patch("metorial._endpoint.endpoint_manager.requests.request") + def test_extracts_message_from_message_field( + self, mock_request: MagicMock, endpoint_manager: MetorialEndpointManager + ) -> None: + """Should extract message from 'message' field.""" + mock_request.return_value = MagicMock( + status_code=400, + ok=False, + text='{"message": "Custom error message"}', + headers={"X-Request-ID": "req-123"}, + json=MagicMock(return_value={"message": "Custom error message"}), + reason="Bad Request", + ) + + request = MetorialRequest(path="/test") + with pytest.raises(BadRequestError) as exc_info: + endpoint_manager._request("GET", request) + + assert exc_info.value.message == "Custom error message" + + @patch("metorial._endpoint.endpoint_manager.requests.request") + def test_extracts_message_from_error_field( + self, mock_request: MagicMock, endpoint_manager: MetorialEndpointManager + ) -> None: + """Should extract message from 'error' field when 'message' is missing.""" + mock_request.return_value = MagicMock( + status_code=400, + ok=False, + text='{"error": "Error description"}', + headers={"X-Request-ID": "req-123"}, + json=MagicMock(return_value={"error": "Error description"}), + reason="Bad Request", + ) + + request = MetorialRequest(path="/test") + with pytest.raises(BadRequestError) as exc_info: + endpoint_manager._request("GET", request) + + assert exc_info.value.message == "Error description" + + @patch("metorial._endpoint.endpoint_manager.requests.request") + def test_falls_back_to_reason_phrase( + self, mock_request: MagicMock, endpoint_manager: MetorialEndpointManager + ) -> None: + """Should fall back to HTTP reason phrase when no message in body.""" + mock_request.return_value = MagicMock( + status_code=400, + ok=False, + text='{"code": "ERROR_CODE"}', # No message or error field + headers={"X-Request-ID": "req-123"}, + json=MagicMock(return_value={"code": "ERROR_CODE"}), + reason="Bad Request", + ) + + request = MetorialRequest(path="/test") + with pytest.raises(BadRequestError) as exc_info: + endpoint_manager._request("GET", request) + + assert exc_info.value.message == "Bad Request" + + @patch("metorial._endpoint.endpoint_manager.requests.request") + def test_handles_string_response_body( + self, mock_request: MagicMock, endpoint_manager: MetorialEndpointManager + ) -> None: + """Should handle string response body.""" + mock_request.return_value = MagicMock( + status_code=500, + ok=False, + text="Plain text error", + headers={"X-Request-ID": "req-123"}, + json=MagicMock(side_effect=ValueError("Not JSON")), + reason="Internal Server Error", + ) + + request = MetorialRequest(path="/test") + # Should raise MetorialSDKError due to malformed response + with pytest.raises(MetorialSDKError): + endpoint_manager._request("GET", request) + + +# ============================================================================= +# Response Body Capture Tests +# ============================================================================= + + +class TestResponseBodyCapture: + """Tests for response body capture in errors.""" + + @patch("metorial._endpoint.endpoint_manager.requests.request") + def test_captures_dict_body( + self, mock_request: MagicMock, endpoint_manager: MetorialEndpointManager + ) -> None: + """Should capture dict response body.""" + body = {"message": "Error", "details": {"field": "invalid"}} + mock_request.return_value = MagicMock( + status_code=422, + ok=False, + text='{"message": "Error", "details": {"field": "invalid"}}', + headers={"X-Request-ID": "req-123"}, + json=MagicMock(return_value=body), + reason="Unprocessable Entity", + ) + + request = MetorialRequest(path="/test") + with pytest.raises(Exception) as exc_info: + endpoint_manager._request("GET", request) + + assert exc_info.value.body == body + + @patch("metorial._endpoint.endpoint_manager.requests.request") + def test_captures_validation_errors( + self, mock_request: MagicMock, endpoint_manager: MetorialEndpointManager + ) -> None: + """Should capture validation error details.""" + body = { + "message": "Validation failed", + "errors": [ + {"field": "email", "message": "Invalid format"}, + {"field": "password", "message": "Too short"}, + ], + } + mock_request.return_value = MagicMock( + status_code=422, + ok=False, + text="...", + headers={"X-Request-ID": "req-validation"}, + json=MagicMock(return_value=body), + reason="Unprocessable Entity", + ) + + request = MetorialRequest(path="/test") + with pytest.raises(Exception) as exc_info: + endpoint_manager._request("GET", request) + + assert exc_info.value.body["errors"][0]["field"] == "email" + + +# ============================================================================= +# Successful Response Tests +# ============================================================================= + + +class TestSuccessfulResponses: + """Tests for successful response handling.""" + + @patch("metorial._endpoint.endpoint_manager.requests.request") + def test_returns_json_for_200( + self, mock_request: MagicMock, endpoint_manager: MetorialEndpointManager + ) -> None: + """Should return parsed JSON for 200 response.""" + mock_request.return_value = MagicMock( + status_code=200, + ok=True, + text='{"data": "test"}', + headers={}, + json=MagicMock(return_value={"data": "test"}), + ) + + request = MetorialRequest(path="/test") + result = endpoint_manager._request("GET", request) + + assert result == {"data": "test"} + + @patch("metorial._endpoint.endpoint_manager.requests.request") + def test_returns_empty_dict_for_204( + self, mock_request: MagicMock, endpoint_manager: MetorialEndpointManager + ) -> None: + """Should return empty dict for 204 No Content.""" + mock_request.return_value = MagicMock( + status_code=204, + ok=True, + text="", + headers={}, + ) + + request = MetorialRequest(path="/test") + result = endpoint_manager._request("DELETE", request) + + assert result == {} + + +# ============================================================================= +# Network Error Tests +# ============================================================================= + + +class TestNetworkErrors: + """Tests for network error handling.""" + + @patch("metorial._endpoint.endpoint_manager.requests.request") + def test_raises_sdk_error_on_connection_error( + self, mock_request: MagicMock, endpoint_manager: MetorialEndpointManager + ) -> None: + """Should raise MetorialSDKError on connection error.""" + mock_request.side_effect = ConnectionError("Connection refused") + + request = MetorialRequest(path="/test") + with pytest.raises(MetorialSDKError) as exc_info: + endpoint_manager._request("GET", request) + + assert "Unable to connect" in str(exc_info.value) + + @patch("metorial._endpoint.endpoint_manager.requests.request") + def test_raises_sdk_error_on_timeout( + self, mock_request: MagicMock, endpoint_manager: MetorialEndpointManager + ) -> None: + """Should raise MetorialSDKError on timeout.""" + import requests + + mock_request.side_effect = requests.Timeout("Request timed out") + + request = MetorialRequest(path="/test") + with pytest.raises(MetorialSDKError) as exc_info: + endpoint_manager._request("GET", request) + + assert "Unable to connect" in str(exc_info.value) + + +# ============================================================================= +# Status Code Mapping Tests +# ============================================================================= + + +class TestStatusCodeMapping: + """Tests for HTTP status code to exception mapping.""" + + @pytest.mark.parametrize( + "status,expected_exception", + [ + (400, BadRequestError), + (401, AuthenticationError), + (404, NotFoundError), + (429, RateLimitError), + (500, InternalServerError), + (502, InternalServerError), + (503, InternalServerError), + ], + ) + @patch("metorial._endpoint.endpoint_manager.requests.request") + def test_maps_status_to_correct_exception( + self, + mock_request: MagicMock, + endpoint_manager: MetorialEndpointManager, + status: int, + expected_exception: type, + ) -> None: + """Should map HTTP status codes to correct exception types.""" + mock_request.return_value = MagicMock( + status_code=status, + ok=False, + text='{"message": "Error"}', + headers={"X-Request-ID": "req-test"}, + json=MagicMock(return_value={"message": "Error"}), + reason="Error", + ) + + request = MetorialRequest(path="/test") + + # Skip retry for 429 by setting try_count to 3 + try_count = 3 if status == 429 else 0 + + with pytest.raises(expected_exception): + endpoint_manager._request("GET", request, try_count=try_count) + + +# ============================================================================= +# Debug Logging Tests +# ============================================================================= + + +class TestDebugLogging: + """Tests for debug logging functionality.""" + + @patch("metorial._endpoint.endpoint_manager.requests.request") + @patch("metorial._endpoint.endpoint_manager.logger") + def test_logs_request_when_debug_enabled( + self, + mock_logger: MagicMock, + mock_request: MagicMock, + ) -> None: + """Should log request when debug logging is enabled.""" + config = {"apiKey": "test-key"} + manager = MetorialEndpointManager( + config=config, + api_host="https://api.metorial.com", + get_headers=lambda c: {}, + enable_debug_logging=True, + ) + + mock_request.return_value = MagicMock( + status_code=200, + ok=True, + text='{"data": "test"}', + headers={}, + json=MagicMock(return_value={"data": "test"}), + ) + + request = MetorialRequest(path="/test") + manager._request("GET", request) + + mock_logger.debug.assert_called() + + @patch("metorial._endpoint.endpoint_manager.requests.request") + @patch("metorial._endpoint.endpoint_manager.logger") + def test_logs_error_with_request_id_when_debug_enabled( + self, + mock_logger: MagicMock, + mock_request: MagicMock, + ) -> None: + """Should log error with request ID when debug logging is enabled.""" + config = {"apiKey": "test-key"} + manager = MetorialEndpointManager( + config=config, + api_host="https://api.metorial.com", + get_headers=lambda c: {}, + enable_debug_logging=True, + ) + + mock_request.return_value = MagicMock( + status_code=400, + ok=False, + text='{"message": "Error"}', + headers={"X-Request-ID": "req-debug-test"}, + json=MagicMock(return_value={"message": "Error"}), + reason="Bad Request", + ) + + request = MetorialRequest(path="/test") + with pytest.raises(BadRequestError): + manager._request("GET", request) + + # Verify error was logged with request_id + mock_logger.error.assert_called() + call_args = str(mock_logger.error.call_args) + assert "req-debug-test" in call_args + + +# ============================================================================= +# URL Construction Tests +# ============================================================================= + + +class TestUrlConstruction: + """Tests for URL construction.""" + + @patch("metorial._endpoint.endpoint_manager.requests.request") + def test_constructs_url_from_string_path( + self, mock_request: MagicMock, endpoint_manager: MetorialEndpointManager + ) -> None: + """Should construct URL from string path.""" + mock_request.return_value = MagicMock( + status_code=200, + ok=True, + text="{}", + headers={}, + json=MagicMock(return_value={}), + ) + + request = MetorialRequest(path="users/123") + endpoint_manager._request("GET", request) + + # requests.request is called with positional args: method, url + call_args = mock_request.call_args + called_url = ( + call_args.args[1] if len(call_args.args) > 1 else call_args.kwargs.get("url") + ) + assert "users/123" in called_url + + @patch("metorial._endpoint.endpoint_manager.requests.request") + def test_constructs_url_from_list_path( + self, mock_request: MagicMock, endpoint_manager: MetorialEndpointManager + ) -> None: + """Should construct URL from list path.""" + mock_request.return_value = MagicMock( + status_code=200, + ok=True, + text="{}", + headers={}, + json=MagicMock(return_value={}), + ) + + request = MetorialRequest(path=["users", "123", "sessions"]) + endpoint_manager._request("GET", request) + + # requests.request is called with positional args: method, url + call_args = mock_request.call_args + called_url = ( + call_args.args[1] if len(call_args.args) > 1 else call_args.kwargs.get("url") + ) + assert "users/123/sessions" in called_url + + @patch("metorial._endpoint.endpoint_manager.requests.request") + def test_uses_custom_host( + self, mock_request: MagicMock, endpoint_manager: MetorialEndpointManager + ) -> None: + """Should use custom host when specified.""" + mock_request.return_value = MagicMock( + status_code=200, + ok=True, + text="{}", + headers={}, + json=MagicMock(return_value={}), + ) + + request = MetorialRequest( + path="test", + host="https://custom.metorial.com", + ) + endpoint_manager._request("GET", request) + + # requests.request is called with positional args: method, url + call_args = mock_request.call_args + called_url = ( + call_args.args[1] if len(call_args.args) > 1 else call_args.kwargs.get("url") + ) + assert "custom.metorial.com" in called_url diff --git a/tests/test_exceptions.py b/tests/test_exceptions.py new file mode 100644 index 0000000..4fc46b1 --- /dev/null +++ b/tests/test_exceptions.py @@ -0,0 +1,553 @@ +""" +Tests for exception classes and make_status_error factory. +""" + +import pytest + +from metorial.exceptions import ( + AuthenticationError, + BadRequestError, + ConflictError, + InternalServerError, + MetorialAPIError, + MetorialConfigError, + MetorialConnectionError, + MetorialDuplicateToolError, + MetorialError, + MetorialSDKError, + MetorialSessionError, + MetorialTimeoutError, + MetorialToolError, + NotFoundError, + PermissionDeniedError, + RateLimitError, + UnprocessableEntityError, + is_metorial_sdk_error, + make_status_error, +) + +# ============================================================================= +# make_status_error Factory Tests +# ============================================================================= + + +class TestMakeStatusError: + """Tests for make_status_error factory function.""" + + def test_returns_bad_request_error_for_400(self) -> None: + err = make_status_error(400, "Bad request") + assert isinstance(err, BadRequestError) + assert err.status_code == 400 + assert err.message == "Bad request" + + def test_returns_authentication_error_for_401(self) -> None: + err = make_status_error(401, "Invalid API key") + assert isinstance(err, AuthenticationError) + assert err.status_code == 401 + + def test_returns_permission_denied_error_for_403(self) -> None: + err = make_status_error(403, "Forbidden") + assert isinstance(err, PermissionDeniedError) + assert err.status_code == 403 + + def test_returns_not_found_error_for_404(self) -> None: + err = make_status_error(404, "Not found") + assert isinstance(err, NotFoundError) + assert err.status_code == 404 + + def test_returns_conflict_error_for_409(self) -> None: + err = make_status_error(409, "Conflict") + assert isinstance(err, ConflictError) + assert err.status_code == 409 + + def test_returns_unprocessable_entity_error_for_422(self) -> None: + err = make_status_error(422, "Validation failed") + assert isinstance(err, UnprocessableEntityError) + assert err.status_code == 422 + + def test_returns_rate_limit_error_for_429(self) -> None: + err = make_status_error(429, "Rate limited", request_id="req-123") + assert isinstance(err, RateLimitError) + assert err.status_code == 429 + assert err.request_id == "req-123" + + def test_returns_internal_server_error_for_500(self) -> None: + err = make_status_error(500, "Internal server error") + assert isinstance(err, InternalServerError) + assert err.status_code == 500 + + def test_returns_internal_server_error_for_502(self) -> None: + err = make_status_error(502, "Bad gateway") + assert isinstance(err, InternalServerError) + assert err.status_code == 502 + + def test_returns_internal_server_error_for_503(self) -> None: + err = make_status_error(503, "Service unavailable") + assert isinstance(err, InternalServerError) + assert err.status_code == 503 + + def test_returns_internal_server_error_for_504(self) -> None: + err = make_status_error(504, "Gateway timeout") + assert isinstance(err, InternalServerError) + assert err.status_code == 504 + + def test_returns_generic_api_error_for_unknown_status(self) -> None: + err = make_status_error(418, "I'm a teapot") + assert isinstance(err, MetorialAPIError) + assert not isinstance(err, BadRequestError) + assert err.status_code == 418 + + def test_returns_generic_api_error_for_uncommon_4xx(self) -> None: + for status in [402, 405, 406, 407, 408, 410, 411, 412, 413, 414, 415]: + err = make_status_error(status, f"Error {status}") + assert isinstance(err, MetorialAPIError) + assert err.status_code == status + + def test_includes_request_id(self) -> None: + err = make_status_error(404, "Not found", request_id="req-abc-123") + assert err.request_id == "req-abc-123" + assert "request_id=req-abc-123" in str(err) + + def test_includes_body_dict(self) -> None: + body = {"error": "details", "code": "RESOURCE_NOT_FOUND"} + err = make_status_error(404, "Not found", body=body) + assert err.body == body + + def test_includes_body_string(self) -> None: + body = "Raw error text" + err = make_status_error(500, "Server error", body=body) + assert err.body == body + + def test_all_parameters_together(self) -> None: + body = {"detail": "User not found"} + err = make_status_error( + status=404, + message="User not found", + request_id="req-xyz-789", + body=body, + ) + assert err.status_code == 404 + assert err.message == "User not found" + assert err.request_id == "req-xyz-789" + assert err.body == body + + +# ============================================================================= +# Exception String Representation Tests +# ============================================================================= + + +class TestExceptionStringRepresentation: + """Tests for exception __str__ method.""" + + def test_str_includes_message_and_status(self) -> None: + err = make_status_error(404, "Resource not found") + s = str(err) + assert "Resource not found" in s + assert "status=404" in s + + def test_str_includes_request_id_when_present(self) -> None: + err = make_status_error(500, "Server error", request_id="req-xyz") + s = str(err) + assert "request_id=req-xyz" in s + + def test_str_without_request_id(self) -> None: + err = make_status_error(400, "Bad request") + s = str(err) + assert "request_id" not in s + + def test_str_with_empty_request_id(self) -> None: + err = make_status_error(400, "Bad request", request_id="") + s = str(err) + # Empty string is falsy, so should not appear + assert "request_id=" not in s + + def test_str_format_consistency(self) -> None: + """Verify string format is consistent across exception types.""" + error_classes = [ + (400, BadRequestError), + (401, AuthenticationError), + (403, PermissionDeniedError), + (404, NotFoundError), + (429, RateLimitError), + (500, InternalServerError), + ] + for status, _cls in error_classes: + err = make_status_error(status, "Test message", request_id="req-test") + s = str(err) + assert "Test message" in s + assert f"status={status}" in s + assert "request_id=req-test" in s + + +# ============================================================================= +# Exception Inheritance Tests +# ============================================================================= + + +class TestExceptionInheritance: + """Tests for exception class hierarchy.""" + + def test_all_status_errors_inherit_from_metorial_api_error(self) -> None: + error_classes = [ + BadRequestError, + AuthenticationError, + PermissionDeniedError, + NotFoundError, + ConflictError, + UnprocessableEntityError, + RateLimitError, + InternalServerError, + ] + for cls in error_classes: + err = cls("test") + assert isinstance(err, MetorialAPIError) + assert isinstance(err, MetorialSDKError) + assert isinstance(err, MetorialError) + assert isinstance(err, Exception) + + def test_can_catch_specific_error(self) -> None: + err = make_status_error(429, "Rate limited") + with pytest.raises(RateLimitError): + raise err + + def test_can_catch_generic_api_error(self) -> None: + err = make_status_error(429, "Rate limited") + with pytest.raises(MetorialAPIError): + raise err + + def test_can_catch_sdk_error(self) -> None: + err = make_status_error(500, "Server error") + with pytest.raises(MetorialSDKError): + raise err + + def test_can_catch_base_metorial_error(self) -> None: + err = make_status_error(404, "Not found") + with pytest.raises(MetorialError): + raise err + + +# ============================================================================= +# Individual Exception Class Tests +# ============================================================================= + + +class TestBadRequestError: + """Tests for BadRequestError (400).""" + + def test_creation(self) -> None: + err = BadRequestError("Invalid input") + assert err.message == "Invalid input" + assert isinstance(err, MetorialAPIError) + + def test_with_body(self) -> None: + err = BadRequestError( + "Validation failed", + status_code=400, + body={"errors": [{"field": "email", "message": "Invalid format"}]}, + ) + assert err.body["errors"][0]["field"] == "email" + + +class TestAuthenticationError: + """Tests for AuthenticationError (401).""" + + def test_creation(self) -> None: + err = AuthenticationError("Invalid API key") + assert err.message == "Invalid API key" + + def test_with_request_id(self) -> None: + err = AuthenticationError( + "Token expired", + status_code=401, + request_id="req-auth-fail", + ) + assert err.request_id == "req-auth-fail" + + +class TestRateLimitError: + """Tests for RateLimitError (429).""" + + def test_creation(self) -> None: + err = RateLimitError("Rate limit exceeded") + assert err.message == "Rate limit exceeded" + + def test_with_retry_info_in_body(self) -> None: + err = RateLimitError( + "Too many requests", + status_code=429, + body={"retry_after": 60, "limit": 100, "remaining": 0}, + ) + assert err.body["retry_after"] == 60 + assert err.body["limit"] == 100 + + +class TestInternalServerError: + """Tests for InternalServerError (5xx).""" + + def test_creation(self) -> None: + err = InternalServerError("Database connection failed") + assert err.message == "Database connection failed" + + def test_with_trace_id(self) -> None: + err = InternalServerError( + "Unexpected error", + status_code=500, + request_id="trace-abc-123", + body={"trace_id": "trace-abc-123"}, + ) + assert err.request_id == "trace-abc-123" + + +# ============================================================================= +# Domain-Specific Exception Tests +# ============================================================================= + + +class TestMetorialToolError: + """Tests for MetorialToolError.""" + + def test_creation(self) -> None: + err = MetorialToolError("Tool execution failed", tool_name="my_tool") + assert err.message == "Tool execution failed" + assert err.tool_name == "my_tool" + + def test_with_args(self) -> None: + err = MetorialToolError( + "Invalid arguments", + tool_name="search", + tool_args={"query": "test"}, + ) + assert err.tool_args == {"query": "test"} + + def test_str_includes_tool_name(self) -> None: + err = MetorialToolError("Failed", tool_name="test_tool") + assert "test_tool" in str(err) + + +class TestMetorialTimeoutError: + """Tests for MetorialTimeoutError.""" + + def test_creation(self) -> None: + err = MetorialTimeoutError("Request timed out", timeout_duration=30.0) + assert err.timeout_duration == 30.0 + + def test_with_operation(self) -> None: + err = MetorialTimeoutError( + "Timeout", + timeout_duration=10.0, + operation="tool_execution", + ) + assert err.operation == "tool_execution" + + def test_str_includes_timeout_info(self) -> None: + err = MetorialTimeoutError( + "Timed out", + timeout_duration=5.0, + operation="api_call", + ) + s = str(err) + assert "5.0" in s + assert "api_call" in s + + +class TestMetorialSessionError: + """Tests for MetorialSessionError.""" + + def test_creation(self) -> None: + err = MetorialSessionError("Session closed", session_id="sess-123") + assert err.session_id == "sess-123" + + def test_with_deployment_id(self) -> None: + err = MetorialSessionError( + "Connection failed", + deployment_id="deploy-abc", + ) + assert err.deployment_id == "deploy-abc" + + +class TestMetorialConfigError: + """Tests for MetorialConfigError.""" + + def test_creation(self) -> None: + err = MetorialConfigError("Invalid config", config_key="api_key") + assert err.config_key == "api_key" + + def test_with_value(self) -> None: + err = MetorialConfigError( + "Invalid timeout", + config_key="timeout", + config_value=-1, + ) + assert err.config_value == -1 + + +class TestMetorialConnectionError: + """Tests for MetorialConnectionError.""" + + def test_creation(self) -> None: + err = MetorialConnectionError("Connection refused", host="api.metorial.com") + assert err.host == "api.metorial.com" + + def test_with_retry_count(self) -> None: + err = MetorialConnectionError( + "Failed after retries", + host="api.example.com", + retry_count=3, + ) + assert err.retry_count == 3 + + +class TestMetorialDuplicateToolError: + """Tests for MetorialDuplicateToolError.""" + + def test_creation(self) -> None: + err = MetorialDuplicateToolError( + "Duplicate tool name", + tool_name="search", + ) + assert err.tool_name == "search" + + +# ============================================================================= +# Utility Function Tests +# ============================================================================= + + +class TestIsMetorialSdkError: + """Tests for is_metorial_sdk_error utility function.""" + + def test_returns_true_for_sdk_error(self) -> None: + err = MetorialSDKError({"message": "test", "status": 500, "code": "error"}) + assert is_metorial_sdk_error(err) is True + + def test_returns_true_for_api_error(self) -> None: + err = make_status_error(404, "Not found") + assert is_metorial_sdk_error(err) is True + + def test_returns_false_for_base_error(self) -> None: + err = MetorialError("test") + assert is_metorial_sdk_error(err) is False + + def test_returns_false_for_standard_exception(self) -> None: + err = ValueError("test") + assert is_metorial_sdk_error(err) is False + + def test_returns_false_for_none_attribute(self) -> None: + class FakeError(Exception): + pass + + err = FakeError("test") + assert is_metorial_sdk_error(err) is False + + +# ============================================================================= +# MetorialError.is_metorial_error Tests +# ============================================================================= + + +class TestIsMetorialError: + """Tests for MetorialError.is_metorial_error static method.""" + + def test_returns_true_for_metorial_error(self) -> None: + err = MetorialError("test") + assert MetorialError.is_metorial_error(err) is True + + def test_returns_true_for_api_error(self) -> None: + err = make_status_error(400, "Bad request") + assert MetorialError.is_metorial_error(err) is True + + def test_returns_true_for_tool_error(self) -> None: + err = MetorialToolError("Failed", tool_name="test") + assert MetorialError.is_metorial_error(err) is True + + def test_returns_false_for_standard_exception(self) -> None: + err = ValueError("test") + assert MetorialError.is_metorial_error(err) is False + + +# ============================================================================= +# Exception Backwards Compatibility Tests +# ============================================================================= + + +class TestBackwardsCompatibility: + """Tests for backwards compatibility with existing code.""" + + def test_metorial_api_error_status_code_attribute(self) -> None: + """Ensure status_code attribute is available for existing code.""" + err = MetorialAPIError("Test", status_code=404) + assert err.status_code == 404 + + def test_metorial_api_error_response_data_attribute(self) -> None: + """Ensure response_data attribute is available for existing code.""" + err = MetorialAPIError("Test", response_data={"key": "value"}) + assert err.response_data == {"key": "value"} + + def test_metorial_api_error_status_attribute(self) -> None: + """Ensure status attribute is available from parent class.""" + err = MetorialAPIError("Test", status_code=500) + assert err.status == 500 + + def test_metorial_sdk_error_data_attribute(self) -> None: + """Ensure data attribute is available on SDK errors.""" + err = MetorialSDKError({"message": "test", "status": 400, "code": "error"}) + assert err.data["status"] == 400 + + +# ============================================================================= +# Edge Cases and Error Scenarios +# ============================================================================= + + +class TestEdgeCases: + """Tests for edge cases and unusual scenarios.""" + + def test_empty_message(self) -> None: + err = make_status_error(400, "") + assert err.message == "" + + def test_very_long_message(self) -> None: + long_msg = "x" * 10000 + err = make_status_error(500, long_msg) + assert err.message == long_msg + + def test_unicode_message(self) -> None: + err = make_status_error(400, "Error: 你好世界 🌍") + assert "你好世界" in err.message + assert "🌍" in err.message + + def test_none_body(self) -> None: + err = make_status_error(404, "Not found", body=None) + assert err.body is None + + def test_empty_dict_body(self) -> None: + err = make_status_error(400, "Bad request", body={}) + assert err.body == {} + + def test_nested_body(self) -> None: + body = { + "errors": [ + {"field": "name", "errors": ["required", "too_short"]}, + {"field": "email", "errors": ["invalid_format"]}, + ], + "meta": {"request_id": "123"}, + } + err = make_status_error(422, "Validation failed", body=body) + assert err.body["errors"][0]["field"] == "name" + assert err.body["meta"]["request_id"] == "123" + + def test_special_characters_in_request_id(self) -> None: + req_id = "req_abc-123.xyz/456" + err = make_status_error(500, "Error", request_id=req_id) + assert err.request_id == req_id + + def test_exception_can_be_pickled(self) -> None: + """Ensure exceptions can be pickled for multiprocessing.""" + import pickle + + err = make_status_error(404, "Not found", request_id="req-123") + pickled = pickle.dumps(err) + restored = pickle.loads(pickled) + assert restored.message == "Not found" + assert restored.status_code == 404 diff --git a/tests/test_raw_response.py b/tests/test_raw_response.py new file mode 100644 index 0000000..eb2d4ab --- /dev/null +++ b/tests/test_raw_response.py @@ -0,0 +1,458 @@ +""" +Tests for RawResponse wrapper class. +""" + +from __future__ import annotations + +from typing import Any +from unittest.mock import MagicMock + +import pytest + +from metorial._raw_response import RawResponse, RawResponseWrapper + +# ============================================================================= +# RawResponse Basic Tests +# ============================================================================= + + +class TestRawResponseBasics: + """Basic tests for RawResponse class.""" + + def test_parse_returns_parsed_data(self, mock_http_response: MagicMock) -> None: + parsed_data = {"id": "123", "name": "test"} + raw = RawResponse(mock_http_response, parsed_data) + assert raw.parse() == parsed_data + + def test_status_code(self, mock_http_response: MagicMock) -> None: + raw = RawResponse(mock_http_response, {}) + assert raw.status_code == 200 + + def test_request_id(self, mock_http_response: MagicMock) -> None: + raw = RawResponse(mock_http_response, {}) + assert raw.request_id == "req-test-123" + + def test_request_id_none_when_missing(self) -> None: + response = MagicMock() + response.status_code = 200 + response.headers = {"Content-Type": "application/json"} + raw = RawResponse(response, {}) + assert raw.request_id is None + + def test_headers(self, mock_http_response: MagicMock) -> None: + raw = RawResponse(mock_http_response, {}) + headers = raw.headers + assert headers["X-Request-ID"] == "req-test-123" + assert headers["Content-Type"] == "application/json" + + def test_content_type(self, mock_http_response: MagicMock) -> None: + raw = RawResponse(mock_http_response, {}) + assert raw.content_type == "application/json" + + def test_content_type_none_when_missing(self) -> None: + response = MagicMock() + response.status_code = 200 + response.headers = {} + raw = RawResponse(response, {}) + assert raw.content_type is None + + +# ============================================================================= +# RawResponse is_success Tests +# ============================================================================= + + +class TestRawResponseIsSuccess: + """Tests for is_success property.""" + + @pytest.mark.parametrize("status", [200, 201, 202, 203, 204, 205, 206]) + def test_is_success_true_for_2xx(self, status: int) -> None: + response = MagicMock() + response.status_code = status + response.headers = {} + raw = RawResponse(response, {}) + assert raw.is_success is True + + @pytest.mark.parametrize("status", [100, 101, 102]) + def test_is_success_false_for_1xx(self, status: int) -> None: + response = MagicMock() + response.status_code = status + response.headers = {} + raw = RawResponse(response, {}) + assert raw.is_success is False + + @pytest.mark.parametrize("status", [300, 301, 302, 303, 304, 307, 308]) + def test_is_success_false_for_3xx(self, status: int) -> None: + response = MagicMock() + response.status_code = status + response.headers = {} + raw = RawResponse(response, {}) + assert raw.is_success is False + + @pytest.mark.parametrize("status", [400, 401, 403, 404, 422, 429]) + def test_is_success_false_for_4xx(self, status: int) -> None: + response = MagicMock() + response.status_code = status + response.headers = {} + raw = RawResponse(response, {}) + assert raw.is_success is False + + @pytest.mark.parametrize("status", [500, 501, 502, 503, 504]) + def test_is_success_false_for_5xx(self, status: int) -> None: + response = MagicMock() + response.status_code = status + response.headers = {} + raw = RawResponse(response, {}) + assert raw.is_success is False + + +# ============================================================================= +# RawResponse __repr__ Tests +# ============================================================================= + + +class TestRawResponseRepr: + """Tests for __repr__ method.""" + + def test_repr(self, mock_http_response: MagicMock) -> None: + raw = RawResponse(mock_http_response, {}) + repr_str = repr(raw) + assert "RawResponse" in repr_str + assert "status_code=200" in repr_str + assert "request_id='req-test-123'" in repr_str + assert "content_type='application/json'" in repr_str + + def test_repr_with_none_values(self) -> None: + response = MagicMock() + response.status_code = 204 + response.headers = {} + raw = RawResponse(response, None) + repr_str = repr(raw) + assert "status_code=204" in repr_str + assert "request_id=None" in repr_str + assert "content_type=None" in repr_str + + +# ============================================================================= +# RawResponse Generic Type Tests +# ============================================================================= + + +class TestRawResponseGenericType: + """Tests for generic type preservation.""" + + def test_dict_type_preserved(self, mock_http_response: MagicMock) -> None: + data: dict[str, Any] = {"items": [1, 2, 3], "count": 3} + raw: RawResponse[dict[str, Any]] = RawResponse(mock_http_response, data) + result = raw.parse() + assert result["items"] == [1, 2, 3] + assert result["count"] == 3 + + def test_list_type_preserved(self, mock_http_response: MagicMock) -> None: + data = [{"id": 1}, {"id": 2}, {"id": 3}] + raw: RawResponse[list[dict[str, int]]] = RawResponse(mock_http_response, data) + result = raw.parse() + assert len(result) == 3 + assert result[0]["id"] == 1 + + def test_string_type_preserved(self, mock_http_response: MagicMock) -> None: + data = "plain text response" + raw: RawResponse[str] = RawResponse(mock_http_response, data) + result = raw.parse() + assert result == "plain text response" + + def test_none_type_preserved(self, mock_http_response: MagicMock) -> None: + raw: RawResponse[None] = RawResponse(mock_http_response, None) + result = raw.parse() + assert result is None + + def test_custom_object_type_preserved(self, mock_http_response: MagicMock) -> None: + class User: + def __init__(self, id: int, name: str): + self.id = id + self.name = name + + user = User(id=123, name="Test") + raw: RawResponse[User] = RawResponse(mock_http_response, user) + result = raw.parse() + assert result.id == 123 + assert result.name == "Test" + + +# ============================================================================= +# RawResponse with Different Response Types Tests +# ============================================================================= + + +class TestRawResponseWithDifferentResponseTypes: + """Tests for RawResponse with different HTTP response implementations.""" + + def test_with_dict_like_headers(self) -> None: + """Test with a response that has dict-like headers.""" + response = MagicMock() + response.status_code = 200 + response.headers = { + "X-Request-ID": "req-dict-123", + "Content-Type": "text/plain", + "X-Custom-Header": "custom-value", + } + raw = RawResponse(response, "data") + assert raw.request_id == "req-dict-123" + assert raw.content_type == "text/plain" + assert raw.headers["X-Custom-Header"] == "custom-value" + + def test_with_case_sensitive_headers(self) -> None: + """Test header access is case-sensitive by default.""" + response = MagicMock() + response.status_code = 200 + response.headers = { + "X-Request-ID": "req-123", + "x-request-id": "req-456", # Different case + } + raw = RawResponse(response, {}) + # Should get the exact case match + assert raw.request_id == "req-123" + + +# ============================================================================= +# RawResponseWrapper Tests +# ============================================================================= + + +class TestRawResponseWrapper: + """Tests for RawResponseWrapper class.""" + + def test_to_raw_creates_raw_response(self, mock_http_response: MagicMock) -> None: + parsed = {"key": "value"} + wrapper = RawResponseWrapper(parsed, mock_http_response) + raw = wrapper.to_raw() + assert isinstance(raw, RawResponse) + assert raw.parse() == parsed + assert raw.status_code == 200 + + def test_wrapper_preserves_parsed_data(self, mock_http_response: MagicMock) -> None: + parsed = [1, 2, 3] + wrapper = RawResponseWrapper(parsed, mock_http_response) + raw = wrapper.to_raw() + assert raw.parse() == [1, 2, 3] + + def test_wrapper_preserves_response(self, mock_http_response: MagicMock) -> None: + parsed = "test" + wrapper = RawResponseWrapper(parsed, mock_http_response) + raw = wrapper.to_raw() + assert raw.request_id == "req-test-123" + + +# ============================================================================= +# RawResponse Edge Cases Tests +# ============================================================================= + + +class TestRawResponseEdgeCases: + """Tests for edge cases and unusual scenarios.""" + + def test_empty_dict_data(self, mock_http_response: MagicMock) -> None: + raw = RawResponse(mock_http_response, {}) + assert raw.parse() == {} + + def test_empty_list_data(self, mock_http_response: MagicMock) -> None: + raw = RawResponse(mock_http_response, []) + assert raw.parse() == [] + + def test_empty_string_data(self, mock_http_response: MagicMock) -> None: + raw = RawResponse(mock_http_response, "") + assert raw.parse() == "" + + def test_zero_value_data(self, mock_http_response: MagicMock) -> None: + raw = RawResponse(mock_http_response, 0) + assert raw.parse() == 0 + + def test_false_value_data(self, mock_http_response: MagicMock) -> None: + raw = RawResponse(mock_http_response, False) + assert raw.parse() is False + + def test_deeply_nested_data(self, mock_http_response: MagicMock) -> None: + data = {"level1": {"level2": {"level3": {"level4": [1, 2, {"level5": "deep"}]}}}} + raw = RawResponse(mock_http_response, data) + result = raw.parse() + assert result["level1"]["level2"]["level3"]["level4"][2]["level5"] == "deep" + + def test_large_data(self, mock_http_response: MagicMock) -> None: + data = {"items": list(range(10000))} + raw = RawResponse(mock_http_response, data) + result = raw.parse() + assert len(result["items"]) == 10000 + + def test_unicode_data(self, mock_http_response: MagicMock) -> None: + data = {"message": "Hello 世界 🌍 مرحبا"} + raw = RawResponse(mock_http_response, data) + result = raw.parse() + assert "世界" in result["message"] + assert "🌍" in result["message"] + + def test_binary_like_data(self, mock_http_response: MagicMock) -> None: + data = b"binary content" + raw = RawResponse(mock_http_response, data) + assert raw.parse() == b"binary content" + + +# ============================================================================= +# RawResponse Headers Tests +# ============================================================================= + + +class TestRawResponseHeaders: + """Tests for headers access.""" + + def test_headers_returns_dict(self, mock_http_response: MagicMock) -> None: + raw = RawResponse(mock_http_response, {}) + headers = raw.headers + assert isinstance(headers, dict) + + def test_headers_is_copy(self, mock_http_response: MagicMock) -> None: + """Modifying returned headers should not affect the response.""" + raw = RawResponse(mock_http_response, {}) + headers = raw.headers + headers["New-Header"] = "new-value" + # Original response headers should not be modified + assert "New-Header" not in mock_http_response.headers + + def test_multiple_headers_access(self, mock_http_response: MagicMock) -> None: + """Multiple calls to headers should return consistent data.""" + raw = RawResponse(mock_http_response, {}) + headers1 = raw.headers + headers2 = raw.headers + assert headers1 == headers2 + + def test_special_header_values(self) -> None: + """Test headers with special characters in values.""" + response = MagicMock() + response.status_code = 200 + response.headers = { + "Content-Type": "application/json; charset=utf-8", + "Cache-Control": "no-cache, no-store, must-revalidate", + "Link": '; rel="next"', + } + raw = RawResponse(response, {}) + assert raw.content_type == "application/json; charset=utf-8" + assert raw.headers["Cache-Control"] == "no-cache, no-store, must-revalidate" + + +# ============================================================================= +# RawResponse Immutability Tests +# ============================================================================= + + +class TestRawResponseImmutability: + """Tests to ensure RawResponse behaves correctly with data.""" + + def test_parse_returns_same_object(self, mock_http_response: MagicMock) -> None: + """Multiple calls to parse() should return the same object.""" + data = {"key": "value"} + raw = RawResponse(mock_http_response, data) + result1 = raw.parse() + result2 = raw.parse() + assert result1 is result2 + + def test_modifying_parsed_data_affects_future_calls( + self, mock_http_response: MagicMock + ) -> None: + """Modifying parsed data affects future parse() calls (same reference).""" + data: dict[str, Any] = {"key": "value"} + raw = RawResponse(mock_http_response, data) + result = raw.parse() + result["new_key"] = "new_value" + # Should see the modification + assert raw.parse()["new_key"] == "new_value" + + +# ============================================================================= +# RawResponse with Real-world-like Scenarios +# ============================================================================= + + +class TestRawResponseRealWorldScenarios: + """Tests simulating real-world usage scenarios.""" + + def test_api_list_response(self) -> None: + """Test handling of paginated list response.""" + response = MagicMock() + response.status_code = 200 + response.headers = { + "X-Request-ID": "req-list-123", + "Content-Type": "application/json", + "X-Total-Count": "100", + "X-Page": "1", + "X-Per-Page": "10", + } + data = { + "items": [{"id": i} for i in range(10)], + "total": 100, + "page": 1, + "per_page": 10, + } + raw = RawResponse(response, data) + + assert raw.is_success + assert raw.request_id == "req-list-123" + assert len(raw.parse()["items"]) == 10 + assert raw.headers["X-Total-Count"] == "100" + + def test_api_create_response(self) -> None: + """Test handling of resource creation response.""" + response = MagicMock() + response.status_code = 201 + response.headers = { + "X-Request-ID": "req-create-456", + "Content-Type": "application/json", + "Location": "https://api.example.com/resources/new-123", + } + data = { + "id": "new-123", + "name": "New Resource", + "created_at": "2024-01-15T10:30:00Z", + } + raw = RawResponse(response, data) + + assert raw.is_success + assert raw.status_code == 201 + assert raw.parse()["id"] == "new-123" + assert raw.headers["Location"] == "https://api.example.com/resources/new-123" + + def test_api_delete_response(self) -> None: + """Test handling of deletion response (204 No Content).""" + response = MagicMock() + response.status_code = 204 + response.headers = { + "X-Request-ID": "req-delete-789", + } + raw = RawResponse(response, None) + + assert raw.is_success + assert raw.status_code == 204 + assert raw.parse() is None + + def test_error_response_parsing(self) -> None: + """Test handling of error response for debugging.""" + response = MagicMock() + response.status_code = 422 + response.headers = { + "X-Request-ID": "req-error-abc", + "Content-Type": "application/json", + } + data = { + "error": "validation_error", + "message": "Validation failed", + "details": [ + {"field": "email", "error": "Invalid email format"}, + {"field": "age", "error": "Must be positive"}, + ], + } + raw = RawResponse(response, data) + + assert not raw.is_success + assert raw.status_code == 422 + assert raw.request_id == "req-error-abc" + errors = raw.parse() + assert errors["error"] == "validation_error" + assert len(errors["details"]) == 2 diff --git a/tests/test_safe_cleanup.py b/tests/test_safe_cleanup.py new file mode 100644 index 0000000..c7dee76 --- /dev/null +++ b/tests/test_safe_cleanup.py @@ -0,0 +1,231 @@ +""" +Tests for safe AsyncIO cleanup functionality. +Fast, deterministic tests to ensure proper resource cleanup without warnings. +""" + +import asyncio +import logging +import warnings + +import pytest + +from metorial._safe_cleanup import ( + attach_noise_filters, + drain_pending_tasks, + install_warning_filters, + quiet_asyncio_shutdown, +) + + +class TestWarningFilters: + """Test warning filter installation""" + + def test_install_warning_filters(self): + """Test that warning filters are installed correctly""" + with warnings.catch_warnings(record=True) as caught_warnings: + warnings.simplefilter("always") # Capture all warnings + + install_warning_filters() + + # Try to trigger the specific warnings that should be filtered + # These would normally generate RuntimeWarnings + warnings.warn("generator didn't stop after athrow", RuntimeWarning, stacklevel=2) + warnings.warn("Attempted to exit cancel scope", RuntimeWarning, stacklevel=2) + warnings.warn( + "an error occurred during closing of asynchronous generator", + RuntimeWarning, + stacklevel=2, + ) + + # Check that SSE-related warnings were filtered + sse_warnings = [ + w + for w in caught_warnings + if any( + phrase in str(w.message) + for phrase in [ + "generator didn't stop", + "cancel scope", + "closing of asynchronous generator", + ] + ) + ] + + assert len(sse_warnings) == 0, f"SSE warnings should be filtered: {sse_warnings}" + + +class TestQuietAsyncioShutdown: + """Test the scoped exception handler context manager""" + + @pytest.mark.asyncio + async def test_quiet_shutdown_context(self): + """Test that quiet_asyncio_shutdown provides scoped suppression""" + + # Track handler changes + loop = asyncio.get_running_loop() + original_handler = loop.get_exception_handler() + + handler_during_context = None + + with quiet_asyncio_shutdown(): + handler_during_context = loop.get_exception_handler() + + handler_after_context = loop.get_exception_handler() + + # Handler should change during context + assert handler_during_context != original_handler + + # Handler should be restored after context + assert handler_after_context == original_handler + + @pytest.mark.asyncio + async def test_suppresses_known_noise(self): + """Test that known SSE cleanup noise is suppressed""" + + exception_caught = False + + def test_handler(loop, context): + nonlocal exception_caught + exception_caught = True + + loop = asyncio.get_running_loop() + loop.set_exception_handler(test_handler) + + try: + with quiet_asyncio_shutdown(): + # Simulate the types of exceptions that should be suppressed + context = { + "exception": RuntimeError("generator didn't stop after athrow"), + "message": "Test SSE cleanup error", + } + loop.call_exception_handler(context) + + # Exception should have been suppressed + assert not exception_caught, "SSE cleanup exception should be suppressed" + + finally: + loop.set_exception_handler(None) + + @pytest.mark.asyncio + async def test_preserves_real_exceptions(self): + """Test that real exceptions still surface properly""" + + caught_contexts = [] + + def test_handler(loop, context): + caught_contexts.append(context) + + loop = asyncio.get_running_loop() + original_handler = loop.get_exception_handler() + loop.set_exception_handler(test_handler) + + try: + with quiet_asyncio_shutdown(): + # Simulate a real exception that should NOT be suppressed + context = { + "exception": ValueError("Real application error"), + "message": "Important user error", + } + loop.call_exception_handler(context) + + # Real exception should still be processed (by default handler within our handler) + # The context should have been passed to our test handler + assert len(caught_contexts) > 0, "Real exceptions should still be handled" + assert caught_contexts[0]["message"] == "Important user error" + + finally: + loop.set_exception_handler(original_handler) + + +class TestDrainPendingTasks: + """Test the task draining utility""" + + @pytest.mark.asyncio + async def test_drain_empty_tasks(self): + """Test draining when no tasks are pending""" + # Should complete immediately without error + await drain_pending_tasks(timeout=0.1) + + @pytest.mark.asyncio + async def test_drain_completed_tasks(self): + """Test draining when tasks complete normally""" + + async def quick_task(): + await asyncio.sleep(0.01) + return "completed" + + # Start task + task = asyncio.create_task(quick_task()) + + # Wait for it to complete + await task + + # Draining should work even with completed tasks + await drain_pending_tasks(timeout=0.1) + + @pytest.mark.asyncio + async def test_drain_cancels_hanging_tasks(self): + """Test that hanging tasks are cancelled on timeout""" + + async def hanging_task(): + try: + await asyncio.sleep(10) # Long-running task + return "should_not_reach" + except asyncio.CancelledError: + return "cancelled" + + # Start hanging task + task = asyncio.create_task(hanging_task()) + + # Drain with short timeout should cancel the task + await drain_pending_tasks(timeout=0.1) + + # Task should be cancelled + assert task.cancelled() or task.done() + + +@pytest.mark.asyncio +async def test_no_shutdown_warnings(recwarn, caplog): + """Integration test: ensure no shutdown warnings are captured""" + + caplog.set_level(logging.DEBUG) + + # Install filters + install_warning_filters() + attach_noise_filters() + + # Simulate a complete shutdown sequence with SSE cleanup + with quiet_asyncio_shutdown(): + try: + # Simulate some async work + await asyncio.sleep(0.01) + finally: + await drain_pending_tasks(timeout=0.1) + + # Check that no SSE-related warnings were captured + sse_warnings = [ + w + for w in recwarn + if any( + phrase in str(w.message) + for phrase in [ + "generator didn't stop", + "cancel scope", + "closing of asynchronous generator", + ] + ) + ] + + assert len(sse_warnings) == 0, f"Should have no SSE warnings: {sse_warnings}" + + # Check that no SSE-related log messages were captured + sse_logs = [ + r + for r in caplog.records + if any( + phrase in r.message + for phrase in ["closing of asynchronous generator", "sse_client", "aconnect_sse"] + ) + ] + + assert len(sse_logs) == 0, f"Should have no SSE log noise: {sse_logs}" diff --git a/tests/test_tool_adapters.py b/tests/test_tool_adapters.py new file mode 100644 index 0000000..1d279bd --- /dev/null +++ b/tests/test_tool_adapters.py @@ -0,0 +1,176 @@ +""" +Tests for tool adapter functionality. +""" + +from unittest.mock import MagicMock + +from metorial._tool_adapters import ( + ToolFormatAdapter, + ToolSanitizer, +) + + +class TestSanitizeFunctionName: + """Tests for ToolFormatAdapter.sanitize_function_name""" + + def test_sanitize_function_name_valid(self): + """Valid names should pass through unchanged.""" + assert ToolFormatAdapter.sanitize_function_name("my_tool") == "my_tool" + assert ToolFormatAdapter.sanitize_function_name("myTool") == "myTool" + assert ToolFormatAdapter.sanitize_function_name("tool123") == "tool123" + assert ToolFormatAdapter.sanitize_function_name("my-tool") == "my_tool" + + def test_sanitize_function_name_special_chars(self): + """Special characters should be replaced or removed.""" + assert ToolFormatAdapter.sanitize_function_name("tool.name") == "tool_name" + assert ToolFormatAdapter.sanitize_function_name("tool&name") == "tool_and_name" + assert ToolFormatAdapter.sanitize_function_name("tool+name") == "tool_plus_name" + assert ToolFormatAdapter.sanitize_function_name("tool#name") == "tool_hash_name" + assert ToolFormatAdapter.sanitize_function_name("tool@name") == "tool_at_name" + + def test_sanitize_function_name_spaces(self): + """Spaces should become underscores.""" + assert ToolFormatAdapter.sanitize_function_name("my tool") == "my_tool" + assert ToolFormatAdapter.sanitize_function_name("my tool") == "my_tool" + assert ToolFormatAdapter.sanitize_function_name(" my tool ") == "my_tool" + + def test_sanitize_function_name_empty(self): + """Empty names should return default.""" + assert ToolFormatAdapter.sanitize_function_name("") == "unknown_tool" + + def test_sanitize_function_name_numeric_prefix(self): + """Names starting with numbers should get prefix.""" + assert ToolFormatAdapter.sanitize_function_name("123tool") == "tool_123tool" + + +class TestOpenAIFunctionPattern: + """Tests for the OpenAI function name pattern regex.""" + + def test_openai_function_pattern_valid(self): + """Valid function names should match the pattern.""" + pattern = ToolFormatAdapter.OPENAI_FUNCTION_PATTERN + assert pattern.match("my_tool") + assert pattern.match("myTool") + assert pattern.match("tool123") + assert pattern.match("my-tool") + assert pattern.match("TOOL") + assert pattern.match("a") + + def test_openai_function_pattern_invalid(self): + """Invalid function names should not match the pattern.""" + pattern = ToolFormatAdapter.OPENAI_FUNCTION_PATTERN + assert not pattern.match("my tool") # space + assert not pattern.match("tool.name") # dot + assert not pattern.match("tool@name") # at symbol + assert not pattern.match("tool!name") # exclamation + assert not pattern.match("") # empty + + +class TestToOpenAIFormat: + """Tests for ToolFormatAdapter.to_openai_format""" + + def test_to_openai_format_valid_tool(self, mock_mcp_tool): + """Valid tools should convert correctly.""" + result = ToolFormatAdapter.to_openai_format(mock_mcp_tool) + + assert result is not None + assert result["type"] == "function" + assert result["function"]["name"] == "test_tool" + assert result["function"]["description"] == "A test tool" + assert "properties" in result["function"]["parameters"] + + def test_to_openai_format_sanitizes_name(self): + """Tool names should be sanitized.""" + tool = MagicMock() + tool.name = "my tool name" + tool.description = "Test" + tool.parameters = {} + + result = ToolFormatAdapter.to_openai_format(tool) + + assert result is not None + assert result["function"]["name"] == "my_tool_name" + + def test_to_openai_format_missing_name(self): + """Tools without names should return None.""" + tool = MagicMock() + tool.name = None + tool.description = "Test" + tool.parameters = {} + + result = ToolFormatAdapter.to_openai_format(tool) + + assert result is None + + +class TestToolValidation: + """Tests for ToolFormatAdapter.validate_tool""" + + def test_validate_tool_valid(self, mock_mcp_tool): + """Valid tools should pass validation.""" + result = ToolFormatAdapter.validate_tool(mock_mcp_tool) + + assert result.is_valid + assert len(result.errors) == 0 + assert result.sanitized_name == "test_tool" + + def test_validate_tool_missing_name(self): + """Tools without names should fail validation.""" + tool = MagicMock() + tool.name = None + tool.description = "Test" + tool.parameters = {} + + result = ToolFormatAdapter.validate_tool(tool) + + assert not result.is_valid + assert any("name" in err.lower() for err in result.errors) + + def test_validate_tool_name_warning(self): + """Tools with sanitized names should have warnings.""" + tool = MagicMock() + tool.name = "my tool" + tool.description = "Test" + tool.parameters = {} + + result = ToolFormatAdapter.validate_tool(tool) + + assert result.is_valid + assert result.sanitized_name == "my_tool" + assert any("sanitized" in warn.lower() for warn in result.warnings) + + +class TestToolSanitizer: + """Tests for ToolSanitizer class""" + + def test_sanitize_tools_filters_invalid(self): + """Invalid tools should be filtered out.""" + valid_tool = MagicMock() + valid_tool.name = "valid_tool" + valid_tool.description = "Valid" + valid_tool.parameters = {} + + invalid_tool = MagicMock() + invalid_tool.name = None + invalid_tool.description = "Invalid" + invalid_tool.parameters = {} + + result = ToolSanitizer.sanitize_tools( + [valid_tool, invalid_tool], log_warnings=False + ) + + assert len(result) == 1 + assert result[0]["function"]["name"] == "valid_tool" + + def test_get_tool_statistics(self, mock_mcp_tool): + """Tool statistics should be calculated correctly.""" + invalid_tool = MagicMock() + invalid_tool.name = None + invalid_tool.description = "Invalid" + invalid_tool.parameters = {} + + stats = ToolSanitizer.get_tool_statistics([mock_mcp_tool, invalid_tool]) + + assert stats["total_tools"] == 2 + assert stats["valid_tools"] == 1 + assert stats["invalid_tools"] == 1 diff --git a/tests/test_tool_manager.py b/tests/test_tool_manager.py new file mode 100644 index 0000000..c8af900 --- /dev/null +++ b/tests/test_tool_manager.py @@ -0,0 +1,182 @@ +""" +Tests for ToolManager functionality. +""" + +from unittest.mock import AsyncMock + +import pytest + +from metorial._tool_manager import ToolManager + + +class TestToolManagerGetTools: + """Tests for ToolManager.get_tools""" + + def test_get_tools_returns_list(self, mock_tool_manager): + """get_tools should return a list.""" + mock_tool_manager.get_tools.return_value = [] + manager = ToolManager(mock_tool_manager) + + result = manager.get_tools() + + assert isinstance(result, list) + mock_tool_manager.get_tools.assert_called_once() + + def test_get_tools_delegates_to_mcp_manager(self, mock_tool_manager, mock_mcp_tool): + """get_tools should delegate to MCP manager.""" + mock_tool_manager.get_tools.return_value = [mock_mcp_tool] + manager = ToolManager(mock_tool_manager) + + result = manager.get_tools() + + assert len(result) == 1 + assert result[0].name == "test_tool" + + +class TestToolManagerGetToolsForOpenAI: + """Tests for ToolManager.get_tools_for_openai""" + + def test_get_tools_for_openai_returns_list(self, mock_tool_manager, mock_mcp_tool): + """get_tools_for_openai should return OpenAI-formatted tools.""" + mock_tool_manager.get_tools.return_value = [mock_mcp_tool] + manager = ToolManager(mock_tool_manager) + + result = manager.get_tools_for_openai() + + assert isinstance(result, list) + assert len(result) == 1 + assert result[0]["type"] == "function" + assert result[0]["function"]["name"] == "test_tool" + + def test_get_tools_for_openai_caching(self, mock_tool_manager, mock_mcp_tool): + """get_tools_for_openai should cache results.""" + mock_tool_manager.get_tools.return_value = [mock_mcp_tool] + manager = ToolManager(mock_tool_manager) + + # First call + result1 = manager.get_tools_for_openai() + # Second call should use cache + result2 = manager.get_tools_for_openai() + + assert result1 == result2 + # Should only call get_tools once due to caching + assert mock_tool_manager.get_tools.call_count == 1 + + def test_get_tools_for_openai_force_refresh(self, mock_tool_manager, mock_mcp_tool): + """force_refresh should bypass cache.""" + mock_tool_manager.get_tools.return_value = [mock_mcp_tool] + manager = ToolManager(mock_tool_manager) + + # First call + manager.get_tools_for_openai() + # Force refresh + manager.get_tools_for_openai(force_refresh=True) + + # Should call get_tools twice + assert mock_tool_manager.get_tools.call_count == 2 + + +class TestToolManagerCacheInvalidation: + """Tests for cache invalidation""" + + def test_refresh_cache(self, mock_tool_manager, mock_mcp_tool): + """refresh_cache should clear the cache.""" + mock_tool_manager.get_tools.return_value = [mock_mcp_tool] + manager = ToolManager(mock_tool_manager) + + # Populate cache + manager.get_tools_for_openai() + assert mock_tool_manager.get_tools.call_count == 1 + + # Refresh cache + manager.refresh_cache() + + # Next call should fetch again + manager.get_tools_for_openai() + assert mock_tool_manager.get_tools.call_count == 2 + + def test_get_cache_info(self, mock_tool_manager, mock_mcp_tool): + """get_cache_info should return cache state.""" + mock_tool_manager.get_tools.return_value = [mock_mcp_tool] + manager = ToolManager(mock_tool_manager) + + # Before caching + info = manager.get_cache_info() + assert info["cached"] is False + + # After caching + manager.get_tools_for_openai() + info = manager.get_cache_info() + assert info["cached"] is True + assert info["cache_age_seconds"] is not None + + +class TestToolManagerExecuteTool: + """Tests for ToolManager.execute_tool""" + + @pytest.mark.asyncio + async def test_execute_tool_success(self, mock_tool_manager): + """execute_tool should execute tool and return result.""" + mock_tool_manager.call_tool = AsyncMock(return_value={"content": "test result"}) + manager = ToolManager(mock_tool_manager) + + result = await manager.execute_tool("test_tool", {"param1": "value1"}) + + assert result["content"] == "test result" + mock_tool_manager.call_tool.assert_called_once_with( + "test_tool", {"param1": "value1"} + ) + + @pytest.mark.asyncio + async def test_execute_tool_json_arguments(self, mock_tool_manager): + """execute_tool should parse JSON string arguments.""" + mock_tool_manager.call_tool = AsyncMock(return_value={"content": "test result"}) + manager = ToolManager(mock_tool_manager) + + result = await manager.execute_tool("test_tool", '{"param1": "value1"}') + + assert result["content"] == "test result" + mock_tool_manager.call_tool.assert_called_once_with( + "test_tool", {"param1": "value1"} + ) + + @pytest.mark.asyncio + async def test_execute_tool_invalid_json(self, mock_tool_manager): + """execute_tool should raise ValueError for invalid JSON.""" + manager = ToolManager(mock_tool_manager) + + with pytest.raises(ValueError, match="Invalid JSON"): + await manager.execute_tool("test_tool", "not valid json") + + @pytest.mark.asyncio + async def test_execute_tool_not_found(self, mock_tool_manager, mock_mcp_tool): + """execute_tool should raise ValueError when tool not found.""" + mock_tool_manager.call_tool = AsyncMock(side_effect=Exception("Tool not found")) + mock_tool_manager.get_tools.return_value = [mock_mcp_tool] + manager = ToolManager(mock_tool_manager) + + with pytest.raises(ValueError, match="not found"): + await manager.execute_tool("nonexistent_tool", {}) + + +class TestToolManagerGetTool: + """Tests for ToolManager.get_tool""" + + def test_get_tool_delegates(self, mock_tool_manager, mock_mcp_tool): + """get_tool should delegate to MCP manager.""" + mock_tool_manager.get_tool.return_value = mock_mcp_tool + manager = ToolManager(mock_tool_manager) + + result = manager.get_tool("test_tool") + + assert result == mock_mcp_tool + mock_tool_manager.get_tool.assert_called_once_with("test_tool") + + def test_get_tool_not_found(self, mock_tool_manager): + """get_tool should return None for unknown tools.""" + mock_tool_manager.get_tool.return_value = None + manager = ToolManager(mock_tool_manager) + + result = manager.get_tool("unknown_tool") + + assert result is None