From 3b71fee3399b41706ba3746eeba338a0cc43dc1a Mon Sep 17 00:00:00 2001 From: Wasiu Bakare Date: Fri, 15 May 2026 18:57:55 +0100 Subject: [PATCH] fix(gateway): harden topic helpers --- ori_sdk/__init__.py | 2 ++ ori_sdk/gateway.py | 21 +++++++++++++++------ ori_sdk/health.py | 9 +++++++++ ori_sdk/validation.py | 8 ++++++++ tests/test_gateway.py | 27 +++++++++++++++++++++++++++ tests/test_health.py | 22 ++++++++++++++++------ tests/test_helpers.py | 4 ++-- 7 files changed, 79 insertions(+), 14 deletions(-) diff --git a/ori_sdk/__init__.py b/ori_sdk/__init__.py index 4336157..d33e574 100644 --- a/ori_sdk/__init__.py +++ b/ori_sdk/__init__.py @@ -11,6 +11,7 @@ ) from ori_sdk.gateway import ( GATEWAY_HEALTH_TOPIC, + GATEWAY_REASONING_REQUEST_TOPIC_FILTER, GatewayRetryPolicy, build_gateway_reasoning_request, gateway_request_topic, @@ -51,6 +52,7 @@ "SkillMetadataValidationError", # gateway "GATEWAY_HEALTH_TOPIC", + "GATEWAY_REASONING_REQUEST_TOPIC_FILTER", "GatewayRetryPolicy", "build_gateway_reasoning_request", "gateway_request_topic", diff --git a/ori_sdk/gateway.py b/ori_sdk/gateway.py index d4c1b3a..88c1830 100644 --- a/ori_sdk/gateway.py +++ b/ori_sdk/gateway.py @@ -12,19 +12,28 @@ from ori_sdk.models import GatewayReasoningRequest, GatewayReasoningResponse GATEWAY_HEALTH_TOPIC = "ori/gateway/health" +GATEWAY_REASONING_REQUEST_TOPIC_FILTER = "ori/+/reasoning/request" -def gateway_request_topic(device_id: str) -> str: - device = device_id.strip() - if not device: +def _validate_mqtt_device_id(device_id: str) -> str: + if not device_id: raise ValueError("device_id must not be empty") + if device_id.strip() != device_id: + raise ValueError("device_id must not contain leading or trailing whitespace") + if any(char in device_id for char in "/+#"): + raise ValueError( + "device_id must not contain MQTT topic separators or wildcards" + ) + return device_id + + +def gateway_request_topic(device_id: str) -> str: + device = _validate_mqtt_device_id(device_id) return f"ori/{device}/reasoning/request" def gateway_response_topic(device_id: str) -> str: - device = device_id.strip() - if not device: - raise ValueError("device_id must not be empty") + device = _validate_mqtt_device_id(device_id) return f"ori/{device}/reasoning/response" diff --git a/ori_sdk/health.py b/ori_sdk/health.py index 3e7cdfe..1afd832 100644 --- a/ori_sdk/health.py +++ b/ori_sdk/health.py @@ -165,3 +165,12 @@ async def _arequest(self, message: str) -> bytes: "empty health response payload", code=ORI_SDK_EMPTY_RESPONSE ) return bytes(buf) + + +__all__ = [ + "DEFAULT_HEALTH_SOCKET_PATH", + "DEFAULT_TIMEOUT_S", + "MAX_RESPONSE_BYTES", + "HealthClientError", + "RuntimeHealthClient", +] diff --git a/ori_sdk/validation.py b/ori_sdk/validation.py index 0a31278..55b7245 100644 --- a/ori_sdk/validation.py +++ b/ori_sdk/validation.py @@ -187,3 +187,11 @@ def validate_skill_metadata(skill: object) -> Mapping[str, object]: ) return root + + +__all__ = [ + "MAX_HISTORY_PLACEHOLDERS", + "SkillMetadataValidationError", + "validate_skill_metadata", + "validate_skill_metadata_file", +] diff --git a/tests/test_gateway.py b/tests/test_gateway.py index 56c5b48..61a0db6 100644 --- a/tests/test_gateway.py +++ b/tests/test_gateway.py @@ -11,6 +11,7 @@ from ori_sdk.errors import GatewayContractError from ori_sdk.gateway import ( GATEWAY_HEALTH_TOPIC, + GATEWAY_REASONING_REQUEST_TOPIC_FILTER, GatewayRetryPolicy, build_gateway_reasoning_request, gateway_request_topic, @@ -28,6 +29,32 @@ def test_gateway_topics_include_device_id() -> None: assert gateway_request_topic("site-a") == "ori/site-a/reasoning/request" assert gateway_response_topic("site-a") == "ori/site-a/reasoning/response" assert GATEWAY_HEALTH_TOPIC == "ori/gateway/health" + assert GATEWAY_REASONING_REQUEST_TOPIC_FILTER == "ori/+/reasoning/request" + + +@pytest.mark.parametrize( + "device_id", + ["", " site-a", "site-a ", "site/a", "site+a", "site#a"], +) +def test_gateway_topics_reject_invalid_device_ids(device_id: str) -> None: + with pytest.raises(ValueError): + gateway_request_topic(device_id) + with pytest.raises(ValueError): + gateway_response_topic(device_id) + + +def test_gateway_topics_do_not_use_legacy_gateway_namespace() -> None: + request_topic = gateway_request_topic("site-a") + response_topic = gateway_response_topic("site-a") + legacy_fragments = [ + "ori/gateway/site-a/reason/request", + "ori/gateway/site-a/reason/response", + "/reason/request", + "/reason/response", + ] + for fragment in legacy_fragments: + assert fragment not in request_topic + assert fragment not in response_topic def test_gateway_request_builder_preserves_request_id() -> None: diff --git a/tests/test_health.py b/tests/test_health.py index 87229cd..3cb56a8 100644 --- a/tests/test_health.py +++ b/tests/test_health.py @@ -7,6 +7,7 @@ import json import socket from pathlib import Path +from types import TracebackType import pytest @@ -39,7 +40,12 @@ def recv(self, _n: int) -> bytes: def __enter__(self) -> _FakeSocket: return self - def __exit__(self, exc_type, exc, tb) -> None: + def __exit__( + self, + exc_type: type[BaseException] | None, + exc: BaseException | None, + tb: TracebackType | None, + ) -> None: return None @@ -49,7 +55,7 @@ def test_get_health_parses_success_payload(monkeypatch: pytest.MonkeyPatch) -> N monkeypatch.setattr( socket, "socket", - lambda *args, **kwargs: _FakeSocket([raw]), # type: ignore[return-value] + lambda *args, **kwargs: _FakeSocket([raw]), ) response = RuntimeHealthClient("/tmp/ori-health.sock").get_health() @@ -66,7 +72,7 @@ def test_get_health_parses_error_payload(monkeypatch: pytest.MonkeyPatch) -> Non monkeypatch.setattr( socket, "socket", - lambda *args, **kwargs: _FakeSocket([raw]), # type: ignore[return-value] + lambda *args, **kwargs: _FakeSocket([raw]), ) response = RuntimeHealthClient("/tmp/ori-health.sock").get_health() @@ -89,7 +95,7 @@ def connect(self, _path: str) -> None: monkeypatch.setattr( socket, "socket", - lambda *args, **kwargs: _RefusingSocket(), # type: ignore[return-value] + lambda *args, **kwargs: _RefusingSocket(), ) client = RuntimeHealthClient("/tmp/ori-health.sock") @@ -131,7 +137,9 @@ def test_aget_health_parses_success_payload(monkeypatch: pytest.MonkeyPatch) -> payload = json.loads((FIXTURES / "runtime_health_success.json").read_text()) raw = json.dumps(payload).encode("utf-8") - async def _open_unix_connection(_path: str): + async def _open_unix_connection( + _path: str, + ) -> tuple[_FakeAsyncReader, _FakeAsyncWriter]: return _FakeAsyncReader([raw]), _FakeAsyncWriter() monkeypatch.setattr(asyncio, "open_unix_connection", _open_unix_connection) @@ -145,7 +153,9 @@ async def _open_unix_connection(_path: str): def test_aget_health_handles_connection_refused( monkeypatch: pytest.MonkeyPatch, ) -> None: - async def _open_unix_connection(_path: str): + async def _open_unix_connection( + _path: str, + ) -> tuple[_FakeAsyncReader, _FakeAsyncWriter]: raise ConnectionRefusedError("refused") monkeypatch.setattr(asyncio, "open_unix_connection", _open_unix_connection) diff --git a/tests/test_helpers.py b/tests/test_helpers.py index 9a500d9..8979fde 100644 --- a/tests/test_helpers.py +++ b/tests/test_helpers.py @@ -12,12 +12,12 @@ posture_interpretation, staleness_summary, ) -from ori_sdk.models import HealthResponse +from ori_sdk.models import HealthResponse, HealthStatus FIXTURES = Path(__file__).parent / "fixtures" -def _load_health_status(): # type: ignore[return] +def _load_health_status() -> HealthStatus: payload = json.loads((FIXTURES / "runtime_health_success.json").read_text()) resp = HealthResponse.from_dict(payload) assert resp.health is not None