Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 6 additions & 4 deletions python/packages/ag-ui/agent_framework_ag_ui/_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
import logging
import sys
import uuid
from binascii import Error as BinasciiError
from collections.abc import AsyncIterable, Awaitable, Mapping, MutableSequence, Sequence
from functools import wraps
from typing import TYPE_CHECKING, Any, Generic, TypedDict, cast
Expand Down Expand Up @@ -294,16 +295,17 @@ def _extract_state_from_messages(self, messages: Sequence[Message]) -> tuple[lis
if isinstance(content, Content) and content.type == "data" and content.media_type == "application/json":
try:
uri = content.uri
if uri.startswith("data:application/json;base64,"): # type: ignore[union-attr]
prefix, _, encoded_data = uri.partition(",") # type: ignore[union-attr]
media_type, *parameters = prefix[5:].split(";")
if prefix.startswith("data:") and media_type == "application/json" and "base64" in parameters:
import base64

encoded_data = uri.split(",", 1)[1] # type: ignore[union-attr]
decoded_bytes = base64.b64decode(encoded_data)
decoded_bytes = base64.b64decode(encoded_data, validate=True)
state = json.loads(decoded_bytes.decode("utf-8"))
Comment on lines +300 to 304

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Addressed in 829a628 by validating base64 with base64.b64decode(..., validate=True) and catching binascii.Error in the existing warning/fallthrough path. I also added a malformed-base64 regression test that preserves the original message list and returns state is None.

Validation run locally:

  • uv run pytest packages/ag-ui/tests/ag_ui/test_ag_ui_client.py -q
  • uv run ruff check packages/ag-ui/agent_framework_ag_ui/_client.py packages/ag-ui/tests/ag_ui/test_ag_ui_client.py
  • git diff --check -- python/packages/ag-ui/agent_framework_ag_ui/_client.py python/packages/ag-ui/tests/ag_ui/test_ag_ui_client.py


messages_without_state = list(messages[:-1]) if len(messages) > 1 else []
return messages_without_state, state
except (json.JSONDecodeError, ValueError, KeyError) as e:
except (BinasciiError, json.JSONDecodeError, ValueError, KeyError) as e:
logger.warning(f"Failed to extract state from message: {e}")

return list(messages), None
Expand Down
40 changes: 40 additions & 0 deletions python/packages/ag-ui/tests/ag_ui/test_ag_ui_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,6 +104,30 @@ async def test_extract_state_from_messages_with_state(self) -> None:
assert result_messages[0].text == "Hello"
assert state == state_data

async def test_extract_state_from_messages_with_parameterized_data_uri(self) -> None:
"""Test state extraction from JSON data URIs with media type parameters."""
import base64

client = StubAGUIChatClient(endpoint="http://localhost:8888/")

state_data = {"key": "value", "count": 42}
state_json = json.dumps(state_data)
state_b64 = base64.b64encode(state_json.encode("utf-8")).decode("utf-8")

messages = [
Message(role="user", contents=["Hello"]),
Message(
role="user",
contents=[Content.from_uri(uri=f"data:application/json;charset=utf-8;base64,{state_b64}")],
),
]

result_messages, state = client.extract_state_from_messages(messages)

assert len(result_messages) == 1
assert result_messages[0].text == "Hello"
assert state == state_data

async def test_extract_state_invalid_json(self) -> None:
"""Test state extraction with invalid JSON."""
import base64
Expand All @@ -125,6 +149,22 @@ async def test_extract_state_invalid_json(self) -> None:
assert result_messages == messages
assert state is None

async def test_extract_state_invalid_base64(self) -> None:
"""Test state extraction with invalid base64."""
client = StubAGUIChatClient(endpoint="http://localhost:8888/")

messages = [
Message(
role="user",
contents=[Content.from_uri(uri="data:application/json;base64,not-valid-base64!")],
),
]

result_messages, state = client.extract_state_from_messages(messages)

assert result_messages == messages
assert state is None

async def test_convert_messages_to_agui_format(self) -> None:
"""Test message conversion to AG-UI format."""
client = StubAGUIChatClient(endpoint="http://localhost:8888/")
Expand Down