diff --git a/python/python/lance_context/api.py b/python/python/lance_context/api.py index 7f0e1c0..81fb205 100644 --- a/python/python/lance_context/api.py +++ b/python/python/lance_context/api.py @@ -174,6 +174,21 @@ def _normalize_retrieve_hit(raw: dict[str, Any]) -> dict[str, Any]: return result +def _normalize_state_metadata( + value: Mapping[str, Any] | None, +) -> dict[str, Any] | None: + if value is None: + return None + if not isinstance(value, Mapping): + raise TypeError("state_metadata must be a mapping") + return { + "step": value.get("step"), + "active_plan_id": value.get("active_plan_id"), + "tokens_used": value.get("tokens_used"), + "custom": value.get("custom"), + } + + _AWS_KWARG_MAP: dict[str, str] = { "aws_access_key_id": "aws_access_key_id", "aws_secret_access_key": "aws_secret_access_key", @@ -424,6 +439,7 @@ def add( bot_id: str | None = None, session_id: str | None = None, external_id: str | None = None, + state_metadata: Mapping[str, Any] | None = None, metadata: dict[str, Any] | None = None, relationships: list[dict[str, Any]] | None = None, expires_at: datetime | str | None = None, @@ -450,6 +466,7 @@ def add( bot_id, session_id, external_id, + _normalize_state_metadata(state_metadata), _json_dumps(metadata, "metadata"), _coerce_timestamp(expires_at, field_name="expires_at"), retention_policy, @@ -576,9 +593,9 @@ def add_many(self, records: Iterable[Mapping[str, Any]]) -> None: Each record accepts the same fields as :meth:`add`: ``role``, ``content``, optional ``content_type``/``data_type``, ``embedding``, - ``bot_id``, ``session_id``, ``external_id``, ``metadata``, - ``relationships``, and lifecycle fields such as ``expires_at`` and - ``lifecycle_status``. + ``bot_id``, ``session_id``, ``external_id``, ``state_metadata``, + ``metadata``, ``relationships``, and lifecycle fields such as + ``expires_at`` and ``lifecycle_status``. """ normalized: list[dict[str, Any]] = [] for index, record in enumerate(records): @@ -608,6 +625,9 @@ def add_many(self, records: Iterable[Mapping[str, Any]]) -> None: "bot_id": record.get("bot_id"), "session_id": record.get("session_id"), "external_id": record.get("external_id"), + "state_metadata": _normalize_state_metadata( + record.get("state_metadata") + ), "metadata_json": _json_dumps(record.get("metadata"), "metadata"), "relationships_json": _json_dumps( record.get("relationships"), "relationships" @@ -915,6 +935,7 @@ async def add( bot_id: str | None = None, session_id: str | None = None, external_id: str | None = None, + state_metadata: Mapping[str, Any] | None = None, metadata: dict[str, Any] | None = None, relationships: list[dict[str, Any]] | None = None, expires_at: datetime | str | None = None, @@ -937,6 +958,7 @@ async def add( bot_id=bot_id, session_id=session_id, external_id=external_id, + state_metadata=state_metadata, metadata=metadata, relationships=relationships, expires_at=expires_at, diff --git a/python/src/lib.rs b/python/src/lib.rs index d0b4c26..a58e474 100644 --- a/python/src/lib.rs +++ b/python/src/lib.rs @@ -15,7 +15,8 @@ use lance_context_core::serde::CONTENT_TYPE_TEXT; use lance_context_core::{ CompactionConfig, CompactionMetrics, CompactionStats, Context as RustContext, ContextRecord, ContextStore, ContextStoreOptions, DistanceMetric, IdIndexType, LifecycleQueryOptions, - RecordFilters, RecordPatch, Relationship, RetrieveResult, SearchResult, LIFECYCLE_ACTIVE, + RecordFilters, RecordPatch, Relationship, RetrieveResult, SearchResult, StateMetadata, + LIFECYCLE_ACTIVE, }; const DEFAULT_BINARY_CONTENT_TYPE: &str = "application/octet-stream"; @@ -35,6 +36,7 @@ struct RecordInput { bot_id: Option, session_id: Option, external_id: Option, + state_metadata: Option, metadata_json: Option, relationships: Vec, lifecycle: LifecycleFields, @@ -239,7 +241,7 @@ impl Context { } #[allow(clippy::too_many_arguments)] - #[pyo3(signature = (role, content, data_type = None, embedding = None, bot_id = None, session_id = None, external_id = None, metadata_json = None, expires_at = None, retention_policy = None, lifecycle_status = None, retired_at = None, retired_reason = None, supersedes_id = None, superseded_by_id = None, relationships_json = None))] + #[pyo3(signature = (role, content, data_type = None, embedding = None, bot_id = None, session_id = None, external_id = None, state_metadata = None, metadata_json = None, expires_at = None, retention_policy = None, lifecycle_status = None, retired_at = None, retired_reason = None, supersedes_id = None, superseded_by_id = None, relationships_json = None))] fn add( &mut self, py: Python<'_>, @@ -250,6 +252,7 @@ impl Context { bot_id: Option, session_id: Option, external_id: Option, + state_metadata: Option<&Bound<'_, PyDict>>, metadata_json: Option, expires_at: Option, retention_policy: Option, @@ -278,6 +281,7 @@ impl Context { bot_id, session_id, external_id, + state_metadata: state_metadata_from_dict(state_metadata)?, metadata_json, relationships: relationships_from_json(relationships_json)?, lifecycle, @@ -348,6 +352,7 @@ impl Context { bot_id, session_id, external_id, + state_metadata: None, metadata_json, relationships: relationships_from_json(relationships_json)?, lifecycle, @@ -742,6 +747,17 @@ impl Context { let session_id = optional_item(dict, "session_id")?.map(|value| value.extract::()); let external_id = optional_item(dict, "external_id")?.map(|value| value.extract::()); + let state_metadata = match optional_item(dict, "state_metadata")? { + Some(value) => { + let metadata = value.downcast::().map_err(|_| { + PyTypeError::new_err(format!( + "records[{index}].state_metadata must be a dict" + )) + })?; + state_metadata_from_dict(Some(metadata))? + } + None => None, + }; let metadata_json = optional_item(dict, "metadata_json")?.map(|value| value.extract::()); let relationships_json = @@ -778,6 +794,7 @@ impl Context { bot_id: bot_id.transpose()?, session_id: session_id.transpose()?, external_id: external_id.transpose()?, + state_metadata, metadata_json: metadata_json.transpose()?, relationships: relationships_from_json(relationships_json.transpose()?)?, lifecycle, @@ -800,6 +817,7 @@ impl Context { bot_id, session_id, external_id, + state_metadata, metadata_json, relationships, lifecycle, @@ -839,7 +857,7 @@ impl Context { session_id, created_at: Utc::now(), role: role.clone(), - state_metadata: None, + state_metadata, metadata, relationships, expires_at: lifecycle.expires_at, @@ -877,6 +895,27 @@ fn optional_item<'py>(dict: &Bound<'py, PyDict>, key: &str) -> PyResult>) -> PyResult> { + let Some(dict) = dict else { + return Ok(None); + }; + + Ok(Some(StateMetadata { + step: optional_item(dict, "step")? + .map(|value| value.extract::()) + .transpose()?, + active_plan_id: optional_item(dict, "active_plan_id")? + .map(|value| value.extract::()) + .transpose()?, + tokens_used: optional_item(dict, "tokens_used")? + .map(|value| value.extract::()) + .transpose()?, + custom: optional_item(dict, "custom")? + .map(|value| value.extract::()) + .transpose()?, + })) +} + fn relationships_patch_from_json(value: Option) -> PyResult>> { value .map(|value| relationships_from_json(Some(value))) diff --git a/python/tests/test_search.py b/python/tests/test_search.py index 9ce7c90..b54026d 100644 --- a/python/tests/test_search.py +++ b/python/tests/test_search.py @@ -38,6 +38,7 @@ def __init__(self) -> None: self.update_calls: list[dict[str, Any]] = [] self.lifecycle_add_calls: list[dict[str, Any]] = [] self.relationship_add_calls: list[str | None] = [] + self.state_metadata_add_calls: list[dict[str, Any] | None] = [] self.add_calls: list[ tuple[ str, @@ -61,6 +62,7 @@ def add( bot_id: str | None, session_id: str | None, external_id: str | None, + state_metadata: dict[str, Any] | None, metadata_json: str | None, expires_at: str | None = None, retention_policy: str | None = None, @@ -95,6 +97,7 @@ def add( } ) self.relationship_add_calls.append(relationships_json) + self.state_metadata_add_calls.append(state_metadata) def upsert( self, @@ -1011,6 +1014,32 @@ def test_context_add_forwards_relationships(): ] +def test_context_add_forwards_state_metadata(): + ctx = Context.__new__(Context) + dummy = DummyInner() + ctx._inner = dummy # type: ignore[attr-defined] + + ctx.add( + "assistant", + "step complete", + state_metadata={ + "step": 3, + "active_plan_id": "plan-1", + "tokens_used": 128, + "custom": "retrieval", + }, + ) + + assert dummy.state_metadata_add_calls == [ + { + "step": 3, + "active_plan_id": "plan-1", + "tokens_used": 128, + "custom": "retrieval", + } + ] + + def test_context_add_rejects_non_json_metadata(): ctx = Context.__new__(Context) dummy = DummyInner() @@ -1218,6 +1247,7 @@ def test_context_add_many_normalizes_records(): "bot_id": None, "session_id": None, "external_id": None, + "state_metadata": None, "metadata_json": None, "relationships_json": None, "expires_at": None, @@ -1236,6 +1266,7 @@ def test_context_add_many_normalizes_records(): "bot_id": "bot", "session_id": "sess", "external_id": "doc-1#chunk-2", + "state_metadata": None, "metadata_json": None, "relationships_json": None, "expires_at": None, diff --git a/python/tests/test_state_metadata.py b/python/tests/test_state_metadata.py new file mode 100644 index 0000000..6766c15 --- /dev/null +++ b/python/tests/test_state_metadata.py @@ -0,0 +1,59 @@ +"""Tests for writing structured state metadata from Python.""" + +from __future__ import annotations + +from typing import TYPE_CHECKING + +if TYPE_CHECKING: + from pathlib import Path + +from lance_context.api import Context + + +def test_add_roundtrips_state_metadata(tmp_path: Path) -> None: + uri = str(tmp_path / "context.lance") + ctx = Context.create(uri) + + ctx.add( + "assistant", + "plan step complete", + state_metadata={ + "step": 3, + "active_plan_id": "plan-1", + "tokens_used": 128, + "custom": "retrieval", + }, + ) + + records = ctx.list() + assert records[0]["state_metadata"] == { + "step": 3, + "active_plan_id": "plan-1", + "tokens_used": 128, + "custom": "retrieval", + } + + +def test_add_many_roundtrips_partial_state_metadata(tmp_path: Path) -> None: + uri = str(tmp_path / "context.lance") + ctx = Context.create(uri) + + ctx.add_many( + [ + { + "role": "user", + "content": "first", + "state_metadata": {"step": 1, "tokens_used": 10}, + }, + {"role": "assistant", "content": "second"}, + ] + ) + + records = ctx.list() + assert records[0]["state_metadata"] == { + "step": 1, + "active_plan_id": None, + "tokens_used": 10, + "custom": None, + } + assert records[1]["state_metadata"] is None