Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 25 additions & 3 deletions python/python/lance_context/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -174,6 +174,21 @@ def _normalize_retrieve_hit(raw: dict[str, Any]) -> dict[str, Any]:
return result


def _normalize_state_metadata(
value: Mapping[str, Any] | None,
) -> dict[str, Any] | None:
if value is None:
return None
if not isinstance(value, Mapping):
raise TypeError("state_metadata must be a mapping")
return {
"step": value.get("step"),
"active_plan_id": value.get("active_plan_id"),
"tokens_used": value.get("tokens_used"),
"custom": value.get("custom"),
}


_AWS_KWARG_MAP: dict[str, str] = {
"aws_access_key_id": "aws_access_key_id",
"aws_secret_access_key": "aws_secret_access_key",
Expand Down Expand Up @@ -424,6 +439,7 @@ def add(
bot_id: str | None = None,
session_id: str | None = None,
external_id: str | None = None,
state_metadata: Mapping[str, Any] | None = None,
metadata: dict[str, Any] | None = None,
relationships: list[dict[str, Any]] | None = None,
expires_at: datetime | str | None = None,
Expand All @@ -450,6 +466,7 @@ def add(
bot_id,
session_id,
external_id,
_normalize_state_metadata(state_metadata),
_json_dumps(metadata, "metadata"),
_coerce_timestamp(expires_at, field_name="expires_at"),
retention_policy,
Expand Down Expand Up @@ -576,9 +593,9 @@ def add_many(self, records: Iterable[Mapping[str, Any]]) -> None:

Each record accepts the same fields as :meth:`add`: ``role``,
``content``, optional ``content_type``/``data_type``, ``embedding``,
``bot_id``, ``session_id``, ``external_id``, ``metadata``,
``relationships``, and lifecycle fields such as ``expires_at`` and
``lifecycle_status``.
``bot_id``, ``session_id``, ``external_id``, ``state_metadata``,
``metadata``, ``relationships``, and lifecycle fields such as
``expires_at`` and ``lifecycle_status``.
"""
normalized: list[dict[str, Any]] = []
for index, record in enumerate(records):
Expand Down Expand Up @@ -608,6 +625,9 @@ def add_many(self, records: Iterable[Mapping[str, Any]]) -> None:
"bot_id": record.get("bot_id"),
"session_id": record.get("session_id"),
"external_id": record.get("external_id"),
"state_metadata": _normalize_state_metadata(
record.get("state_metadata")
),
"metadata_json": _json_dumps(record.get("metadata"), "metadata"),
"relationships_json": _json_dumps(
record.get("relationships"), "relationships"
Expand Down Expand Up @@ -915,6 +935,7 @@ async def add(
bot_id: str | None = None,
session_id: str | None = None,
external_id: str | None = None,
state_metadata: Mapping[str, Any] | None = None,
metadata: dict[str, Any] | None = None,
relationships: list[dict[str, Any]] | None = None,
expires_at: datetime | str | None = None,
Expand All @@ -937,6 +958,7 @@ async def add(
bot_id=bot_id,
session_id=session_id,
external_id=external_id,
state_metadata=state_metadata,
metadata=metadata,
relationships=relationships,
expires_at=expires_at,
Expand Down
45 changes: 42 additions & 3 deletions python/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,8 @@ use lance_context_core::serde::CONTENT_TYPE_TEXT;
use lance_context_core::{
CompactionConfig, CompactionMetrics, CompactionStats, Context as RustContext, ContextRecord,
ContextStore, ContextStoreOptions, DistanceMetric, IdIndexType, LifecycleQueryOptions,
RecordFilters, RecordPatch, Relationship, RetrieveResult, SearchResult, LIFECYCLE_ACTIVE,
RecordFilters, RecordPatch, Relationship, RetrieveResult, SearchResult, StateMetadata,
LIFECYCLE_ACTIVE,
};

const DEFAULT_BINARY_CONTENT_TYPE: &str = "application/octet-stream";
Expand All @@ -35,6 +36,7 @@ struct RecordInput {
bot_id: Option<String>,
session_id: Option<String>,
external_id: Option<String>,
state_metadata: Option<StateMetadata>,
metadata_json: Option<String>,
relationships: Vec<Relationship>,
lifecycle: LifecycleFields,
Expand Down Expand Up @@ -239,7 +241,7 @@ impl Context {
}

#[allow(clippy::too_many_arguments)]
#[pyo3(signature = (role, content, data_type = None, embedding = None, bot_id = None, session_id = None, external_id = None, metadata_json = None, expires_at = None, retention_policy = None, lifecycle_status = None, retired_at = None, retired_reason = None, supersedes_id = None, superseded_by_id = None, relationships_json = None))]
#[pyo3(signature = (role, content, data_type = None, embedding = None, bot_id = None, session_id = None, external_id = None, state_metadata = None, metadata_json = None, expires_at = None, retention_policy = None, lifecycle_status = None, retired_at = None, retired_reason = None, supersedes_id = None, superseded_by_id = None, relationships_json = None))]
fn add(
&mut self,
py: Python<'_>,
Expand All @@ -250,6 +252,7 @@ impl Context {
bot_id: Option<String>,
session_id: Option<String>,
external_id: Option<String>,
state_metadata: Option<&Bound<'_, PyDict>>,
metadata_json: Option<String>,
expires_at: Option<String>,
retention_policy: Option<String>,
Expand Down Expand Up @@ -278,6 +281,7 @@ impl Context {
bot_id,
session_id,
external_id,
state_metadata: state_metadata_from_dict(state_metadata)?,
metadata_json,
relationships: relationships_from_json(relationships_json)?,
lifecycle,
Expand Down Expand Up @@ -348,6 +352,7 @@ impl Context {
bot_id,
session_id,
external_id,
state_metadata: None,
metadata_json,
relationships: relationships_from_json(relationships_json)?,
lifecycle,
Expand Down Expand Up @@ -742,6 +747,17 @@ impl Context {
let session_id = optional_item(dict, "session_id")?.map(|value| value.extract::<String>());
let external_id =
optional_item(dict, "external_id")?.map(|value| value.extract::<String>());
let state_metadata = match optional_item(dict, "state_metadata")? {
Some(value) => {
let metadata = value.downcast::<PyDict>().map_err(|_| {
PyTypeError::new_err(format!(
"records[{index}].state_metadata must be a dict"
))
})?;
state_metadata_from_dict(Some(metadata))?
}
None => None,
};
let metadata_json =
optional_item(dict, "metadata_json")?.map(|value| value.extract::<String>());
let relationships_json =
Expand Down Expand Up @@ -778,6 +794,7 @@ impl Context {
bot_id: bot_id.transpose()?,
session_id: session_id.transpose()?,
external_id: external_id.transpose()?,
state_metadata,
metadata_json: metadata_json.transpose()?,
relationships: relationships_from_json(relationships_json.transpose()?)?,
lifecycle,
Expand All @@ -800,6 +817,7 @@ impl Context {
bot_id,
session_id,
external_id,
state_metadata,
metadata_json,
relationships,
lifecycle,
Expand Down Expand Up @@ -839,7 +857,7 @@ impl Context {
session_id,
created_at: Utc::now(),
role: role.clone(),
state_metadata: None,
state_metadata,
metadata,
relationships,
expires_at: lifecycle.expires_at,
Expand Down Expand Up @@ -877,6 +895,27 @@ fn optional_item<'py>(dict: &Bound<'py, PyDict>, key: &str) -> PyResult<Option<B
Ok(dict.get_item(key)?.filter(|value| !value.is_none()))
}

/// Builds a `StateMetadata` from an optional Python dict.
///
/// Returns `Ok(None)` when no dict is supplied. Otherwise each of `step`,
/// `active_plan_id`, `tokens_used` and `custom` is looked up through
/// `optional_item` and extracted (`i32` for `step`/`tokens_used`, `String`
/// for `active_plan_id`/`custom`); a key for which `optional_item` yields
/// nothing leaves that field as `None`. Lookup and extraction errors are
/// propagated to the caller.
fn state_metadata_from_dict(dict: Option<&Bound<'_, PyDict>>) -> PyResult<Option<StateMetadata>> {
    let dict = match dict {
        Some(dict) => dict,
        None => return Ok(None),
    };

    // Fields are read in declaration order so the first failing field is the
    // one whose error surfaces.
    let step = match optional_item(dict, "step")? {
        Some(value) => Some(value.extract::<i32>()?),
        None => None,
    };
    let active_plan_id = match optional_item(dict, "active_plan_id")? {
        Some(value) => Some(value.extract::<String>()?),
        None => None,
    };
    let tokens_used = match optional_item(dict, "tokens_used")? {
        Some(value) => Some(value.extract::<i32>()?),
        None => None,
    };
    let custom = match optional_item(dict, "custom")? {
        Some(value) => Some(value.extract::<String>()?),
        None => None,
    };

    Ok(Some(StateMetadata {
        step,
        active_plan_id,
        tokens_used,
        custom,
    }))
}

fn relationships_patch_from_json(value: Option<String>) -> PyResult<Option<Vec<Relationship>>> {
value
.map(|value| relationships_from_json(Some(value)))
Expand Down
31 changes: 31 additions & 0 deletions python/tests/test_search.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@ def __init__(self) -> None:
self.update_calls: list[dict[str, Any]] = []
self.lifecycle_add_calls: list[dict[str, Any]] = []
self.relationship_add_calls: list[str | None] = []
self.state_metadata_add_calls: list[dict[str, Any] | None] = []
self.add_calls: list[
tuple[
str,
Expand All @@ -61,6 +62,7 @@ def add(
bot_id: str | None,
session_id: str | None,
external_id: str | None,
state_metadata: dict[str, Any] | None,
metadata_json: str | None,
expires_at: str | None = None,
retention_policy: str | None = None,
Expand Down Expand Up @@ -95,6 +97,7 @@ def add(
}
)
self.relationship_add_calls.append(relationships_json)
self.state_metadata_add_calls.append(state_metadata)

def upsert(
self,
Expand Down Expand Up @@ -1011,6 +1014,32 @@ def test_context_add_forwards_relationships():
]


def test_context_add_forwards_state_metadata():
    """``Context.add`` passes ``state_metadata`` through to the inner ``add``."""
    payload = {
        "step": 3,
        "active_plan_id": "plan-1",
        "tokens_used": 128,
        "custom": "retrieval",
    }
    # Snapshot the expectation before the call so it is independent of the
    # object handed to ``add``.
    expected_calls = [dict(payload)]

    recorder = DummyInner()
    ctx = Context.__new__(Context)  # skip __init__; only ``_inner`` is needed
    ctx._inner = recorder  # type: ignore[attr-defined]

    ctx.add("assistant", "step complete", state_metadata=payload)

    assert recorder.state_metadata_add_calls == expected_calls


def test_context_add_rejects_non_json_metadata():
ctx = Context.__new__(Context)
dummy = DummyInner()
Expand Down Expand Up @@ -1218,6 +1247,7 @@ def test_context_add_many_normalizes_records():
"bot_id": None,
"session_id": None,
"external_id": None,
"state_metadata": None,
"metadata_json": None,
"relationships_json": None,
"expires_at": None,
Expand All @@ -1236,6 +1266,7 @@ def test_context_add_many_normalizes_records():
"bot_id": "bot",
"session_id": "sess",
"external_id": "doc-1#chunk-2",
"state_metadata": None,
"metadata_json": None,
"relationships_json": None,
"expires_at": None,
Expand Down
59 changes: 59 additions & 0 deletions python/tests/test_state_metadata.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
"""Tests for writing structured state metadata from Python."""

from __future__ import annotations

from typing import TYPE_CHECKING

if TYPE_CHECKING:
from pathlib import Path

from lance_context.api import Context


def test_add_roundtrips_state_metadata(tmp_path: Path) -> None:
    """A fully populated ``state_metadata`` written by ``add`` is read back by ``list``."""
    expected = {
        "step": 3,
        "active_plan_id": "plan-1",
        "tokens_used": 128,
        "custom": "retrieval",
    }
    ctx = Context.create(str(tmp_path / "context.lance"))

    # Pass a copy so the expectation stays independent of the argument.
    ctx.add("assistant", "plan step complete", state_metadata=dict(expected))

    stored = ctx.list()[0]["state_metadata"]
    assert stored == expected


def test_add_many_roundtrips_partial_state_metadata(tmp_path: Path) -> None:
    """``add_many`` stores partial ``state_metadata`` and leaves it unset when omitted."""
    with_state = {
        "role": "user",
        "content": "first",
        "state_metadata": {"step": 1, "tokens_used": 10},
    }
    without_state = {"role": "assistant", "content": "second"}

    ctx = Context.create(str(tmp_path / "context.lance"))
    ctx.add_many([with_state, without_state])

    first, second = ctx.list()[0], ctx.list()[1]

    # Fields that were not supplied come back as ``None``.
    assert first["state_metadata"] == {
        "step": 1,
        "active_plan_id": None,
        "tokens_used": 10,
        "custom": None,
    }
    # A record written without state metadata reads back with none at all.
    assert second["state_metadata"] is None
Loading