Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions ori_sdk/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
)
from ori_sdk.gateway import (
GATEWAY_HEALTH_TOPIC,
GATEWAY_REASONING_REQUEST_TOPIC_FILTER,
GatewayRetryPolicy,
build_gateway_reasoning_request,
gateway_request_topic,
Expand Down Expand Up @@ -51,6 +52,7 @@
"SkillMetadataValidationError",
# gateway
"GATEWAY_HEALTH_TOPIC",
"GATEWAY_REASONING_REQUEST_TOPIC_FILTER",
"GatewayRetryPolicy",
"build_gateway_reasoning_request",
"gateway_request_topic",
Expand Down
21 changes: 15 additions & 6 deletions ori_sdk/gateway.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,19 +12,28 @@
from ori_sdk.models import GatewayReasoningRequest, GatewayReasoningResponse

GATEWAY_HEALTH_TOPIC = "ori/gateway/health"
GATEWAY_REASONING_REQUEST_TOPIC_FILTER = "ori/+/reasoning/request"


def gateway_request_topic(device_id: str) -> str:
device = device_id.strip()
if not device:
def _validate_mqtt_device_id(device_id: str) -> str:
if not device_id:
raise ValueError("device_id must not be empty")
if device_id.strip() != device_id:
raise ValueError("device_id must not contain leading or trailing whitespace")
if any(char in device_id for char in "/+#"):
raise ValueError(
"device_id must not contain MQTT topic separators or wildcards"
)
return device_id


def gateway_request_topic(device_id: str) -> str:
device = _validate_mqtt_device_id(device_id)
return f"ori/{device}/reasoning/request"


def gateway_response_topic(device_id: str) -> str:
device = device_id.strip()
if not device:
raise ValueError("device_id must not be empty")
device = _validate_mqtt_device_id(device_id)
return f"ori/{device}/reasoning/response"


Expand Down
9 changes: 9 additions & 0 deletions ori_sdk/health.py
Original file line number Diff line number Diff line change
Expand Up @@ -165,3 +165,12 @@ async def _arequest(self, message: str) -> bytes:
"empty health response payload", code=ORI_SDK_EMPTY_RESPONSE
)
return bytes(buf)


__all__ = [
"DEFAULT_HEALTH_SOCKET_PATH",
"DEFAULT_TIMEOUT_S",
"MAX_RESPONSE_BYTES",
"HealthClientError",
"RuntimeHealthClient",
]
8 changes: 8 additions & 0 deletions ori_sdk/validation.py
Original file line number Diff line number Diff line change
Expand Up @@ -187,3 +187,11 @@ def validate_skill_metadata(skill: object) -> Mapping[str, object]:
)

return root


__all__ = [
"MAX_HISTORY_PLACEHOLDERS",
"SkillMetadataValidationError",
"validate_skill_metadata",
"validate_skill_metadata_file",
]
27 changes: 27 additions & 0 deletions tests/test_gateway.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
from ori_sdk.errors import GatewayContractError
from ori_sdk.gateway import (
GATEWAY_HEALTH_TOPIC,
GATEWAY_REASONING_REQUEST_TOPIC_FILTER,
GatewayRetryPolicy,
build_gateway_reasoning_request,
gateway_request_topic,
Expand All @@ -28,6 +29,32 @@ def test_gateway_topics_include_device_id() -> None:
assert gateway_request_topic("site-a") == "ori/site-a/reasoning/request"
assert gateway_response_topic("site-a") == "ori/site-a/reasoning/response"
assert GATEWAY_HEALTH_TOPIC == "ori/gateway/health"
assert GATEWAY_REASONING_REQUEST_TOPIC_FILTER == "ori/+/reasoning/request"


@pytest.mark.parametrize(
"device_id",
["", " site-a", "site-a ", "site/a", "site+a", "site#a"],
)
def test_gateway_topics_reject_invalid_device_ids(device_id: str) -> None:
with pytest.raises(ValueError):
gateway_request_topic(device_id)
with pytest.raises(ValueError):
gateway_response_topic(device_id)


def test_gateway_topics_do_not_use_legacy_gateway_namespace() -> None:
request_topic = gateway_request_topic("site-a")
response_topic = gateway_response_topic("site-a")
legacy_fragments = [
"ori/gateway/site-a/reason/request",
"ori/gateway/site-a/reason/response",
"/reason/request",
"/reason/response",
]
for fragment in legacy_fragments:
assert fragment not in request_topic
assert fragment not in response_topic


def test_gateway_request_builder_preserves_request_id() -> None:
Expand Down
22 changes: 16 additions & 6 deletions tests/test_health.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
import json
import socket
from pathlib import Path
from types import TracebackType

import pytest

Expand Down Expand Up @@ -39,7 +40,12 @@ def recv(self, _n: int) -> bytes:
def __enter__(self) -> _FakeSocket:
return self

def __exit__(self, exc_type, exc, tb) -> None:
def __exit__(
self,
exc_type: type[BaseException] | None,
exc: BaseException | None,
tb: TracebackType | None,
) -> None:
return None


Expand All @@ -49,7 +55,7 @@ def test_get_health_parses_success_payload(monkeypatch: pytest.MonkeyPatch) -> N
monkeypatch.setattr(
socket,
"socket",
lambda *args, **kwargs: _FakeSocket([raw]), # type: ignore[return-value]
lambda *args, **kwargs: _FakeSocket([raw]),
)

response = RuntimeHealthClient("/tmp/ori-health.sock").get_health()
Expand All @@ -66,7 +72,7 @@ def test_get_health_parses_error_payload(monkeypatch: pytest.MonkeyPatch) -> Non
monkeypatch.setattr(
socket,
"socket",
lambda *args, **kwargs: _FakeSocket([raw]), # type: ignore[return-value]
lambda *args, **kwargs: _FakeSocket([raw]),
)

response = RuntimeHealthClient("/tmp/ori-health.sock").get_health()
Expand All @@ -89,7 +95,7 @@ def connect(self, _path: str) -> None:
monkeypatch.setattr(
socket,
"socket",
lambda *args, **kwargs: _RefusingSocket(), # type: ignore[return-value]
lambda *args, **kwargs: _RefusingSocket(),
)

client = RuntimeHealthClient("/tmp/ori-health.sock")
Expand Down Expand Up @@ -131,7 +137,9 @@ def test_aget_health_parses_success_payload(monkeypatch: pytest.MonkeyPatch) ->
payload = json.loads((FIXTURES / "runtime_health_success.json").read_text())
raw = json.dumps(payload).encode("utf-8")

async def _open_unix_connection(_path: str):
async def _open_unix_connection(
_path: str,
) -> tuple[_FakeAsyncReader, _FakeAsyncWriter]:
return _FakeAsyncReader([raw]), _FakeAsyncWriter()

monkeypatch.setattr(asyncio, "open_unix_connection", _open_unix_connection)
Expand All @@ -145,7 +153,9 @@ async def _open_unix_connection(_path: str):
def test_aget_health_handles_connection_refused(
monkeypatch: pytest.MonkeyPatch,
) -> None:
async def _open_unix_connection(_path: str):
async def _open_unix_connection(
_path: str,
) -> tuple[_FakeAsyncReader, _FakeAsyncWriter]:
raise ConnectionRefusedError("refused")

monkeypatch.setattr(asyncio, "open_unix_connection", _open_unix_connection)
Expand Down
4 changes: 2 additions & 2 deletions tests/test_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,12 +12,12 @@
posture_interpretation,
staleness_summary,
)
from ori_sdk.models import HealthResponse
from ori_sdk.models import HealthResponse, HealthStatus

FIXTURES = Path(__file__).parent / "fixtures"


def _load_health_status(): # type: ignore[return]
def _load_health_status() -> HealthStatus:
payload = json.loads((FIXTURES / "runtime_health_success.json").read_text())
resp = HealthResponse.from_dict(payload)
assert resp.health is not None
Expand Down
Loading