Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
105 changes: 104 additions & 1 deletion python/packages/core/agent_framework/_tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -647,6 +647,9 @@ def parse_result(result: Any) -> str:
def to_json_schema_spec(self) -> dict[str, Any]:
"""Convert a FunctionTool to the JSON Schema function specification format.

The parameter schema is sanitized to remove JSON Schema features
(e.g. ``$ref``, ``$defs``, ``$schema``) that LLM APIs may not accept.

Returns:
A dictionary containing the function specification in JSON Schema format.
"""
Expand All @@ -655,7 +658,7 @@ def to_json_schema_spec(self) -> dict[str, Any]:
"function": {
"name": self.name,
"description": self.description,
"parameters": self.parameters(),
"parameters": sanitize_schema_for_api(self.parameters()),
},
}

Expand All @@ -668,6 +671,106 @@ def to_dict(self, *, exclude: set[str] | None = None, exclude_none: bool = True)
return as_dict


# Keys that are valid JSON Schema metadata but not accepted by most LLM APIs
# when used in function tool parameter schemas.
_UNSUPPORTED_SCHEMA_ROOT_KEYS: Final[frozenset[str]] = frozenset({
"$schema",
"$id",
"title",
})


def _resolve_refs(schema: dict[str, Any], defs: dict[str, Any]) -> dict[str, Any]:
"""Recursively resolve ``$ref`` references by inlining definitions.

Args:
schema: A JSON Schema node (possibly containing ``$ref``).
defs: The top-level ``$defs`` / ``definitions`` mapping to resolve against.

Returns:
A new dict with ``$ref`` pointers replaced by their resolved definitions.
"""
if "$ref" in schema:
ref_path: str = schema["$ref"]
# Only handle local fragment references: #/$defs/Name or #/definitions/Name
for prefix in ("#/$defs/", "#/definitions/"):
Comment on lines +688 to +696
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Circular $ref bug: if defs['A'] itself contains a $ref back to 'A' (directly or transitively), this _resolve_refs(resolved, defs) call recurses infinitely, raising RecursionError. Add a _seen: set[str] | None = None parameter (treated as an empty set when None, to avoid a mutable default argument) and skip resolution when def_name in _seen.

Suggested change
defs: The top-level ``$defs`` / ``definitions`` mapping to resolve against.
Returns:
A new dict with ``$ref`` pointers replaced by their resolved definitions.
"""
if "$ref" in schema:
ref_path: str = schema["$ref"]
# Only handle local fragment references: #/$defs/Name or #/definitions/Name
for prefix in ("#/$defs/", "#/definitions/"):
if def_name in defs:
if def_name in _seen:
# Circular reference — return without the $ref to break the cycle
return {k: v for k, v in schema.items() if k != "$ref"}
resolved = dict(defs[def_name])
# Merge any sibling keys (e.g. description) from the referring node
for k, v in schema.items():
if k != "$ref" and k not in resolved:
resolved[k] = v
return _resolve_refs(resolved, defs, _seen | {def_name})

if ref_path.startswith(prefix):
def_name = ref_path[len(prefix) :]
if def_name in defs:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This fallback drops $ref but does not recurse into the remaining sibling values. If any sibling key contains a nested dict with a resolvable $ref, it will not be resolved.

Suggested change
if def_name in defs:
return _resolve_refs({k: v for k, v in schema.items() if k != "$ref"}, defs)

resolved = dict(defs[def_name])
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Circular $ref chains (A → B → A) will cause infinite recursion here. Consider adding a seen: set[str] parameter to detect cycles and break out gracefully. There is no test covering this edge case.

Suggested change
resolved = dict(defs[def_name])
def _resolve_refs(schema: dict[str, Any], defs: dict[str, Any], _seen: set[str] | None = None) -> dict[str, Any]:
if _seen is None:
_seen = set()
if "$ref" in schema:
ref_path: str = schema["$ref"]
if ref_path in _seen:
return {k: v for k, v in schema.items() if k != "$ref"}
# Only handle local fragment references: #/$defs/Name or #/definitions/Name
for prefix in ("#/$defs/", "#/definitions/"):
if ref_path.startswith(prefix):
def_name = ref_path[len(prefix) :]
if def_name in defs:
resolved = dict(defs[def_name])
# Merge any sibling keys (e.g. description) from the referring node
for k, v in schema.items():
if k != "$ref" and k not in resolved:
resolved[k] = v
return _resolve_refs(resolved, defs, _seen | {ref_path})
# Unresolvable $ref — drop it and keep sibling keys as a best-effort fallback
return {k: v for k, v in schema.items() if k != "$ref"}

Comment on lines +690 to +700
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This recursive call will loop forever on circular $ref (e.g., a tree node whose children items reference the node definition itself). Pass a seen: set[str] through the recursion and short-circuit when a definition name is revisited — for example, emit {} or {"type": "object"} as a fallback for the cyclic reference.

Suggested change
Returns:
A new dict with ``$ref`` pointers replaced by their resolved definitions.
"""
if "$ref" in schema:
ref_path: str = schema["$ref"]
# Only handle local fragment references: #/$defs/Name or #/definitions/Name
for prefix in ("#/$defs/", "#/definitions/"):
if ref_path.startswith(prefix):
def_name = ref_path[len(prefix) :]
if def_name in defs:
resolved = dict(defs[def_name])
if "$ref" in schema:
ref_path: str = schema["$ref"]
# Only handle local fragment references: #/$defs/Name or #/definitions/Name
for prefix in ("#/$defs/", "#/definitions/"):
if ref_path.startswith(prefix):
def_name = ref_path[len(prefix) :]
if def_name in defs:
if def_name in _seen:
# Circular reference — emit best-effort fallback
return {k: v for k, v in schema.items() if k != "$ref"}
resolved = dict(defs[def_name])
# Merge any sibling keys (e.g. description) from the referring node
for k, v in schema.items():
if k != "$ref" and k not in resolved:
resolved[k] = v
return _resolve_refs(resolved, defs, _seen | {def_name})
# Unresolvable $ref — drop it and keep sibling keys as a best-effort fallback
return {k: v for k, v in schema.items() if k != "$ref"}

# Merge any sibling keys (e.g. description) from the referring node
for k, v in schema.items():
if k != "$ref" and k not in resolved:
resolved[k] = v
return _resolve_refs(resolved, defs)
Comment on lines +693 to +705
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Circular $ref chains (common in JSON Schema for recursive types like trees) will recurse until RecursionError. Track visited ref names and bail out when a cycle is detected.

Suggested change
if "$ref" in schema:
ref_path: str = schema["$ref"]
# Only handle local fragment references: #/$defs/Name or #/definitions/Name
for prefix in ("#/$defs/", "#/definitions/"):
if ref_path.startswith(prefix):
def_name = ref_path[len(prefix) :]
if def_name in defs:
resolved = dict(defs[def_name])
# Merge any sibling keys (e.g. description) from the referring node
for k, v in schema.items():
if k != "$ref" and k not in resolved:
resolved[k] = v
return _resolve_refs(resolved, defs)
if "$ref" in schema:
ref_path: str = schema["$ref"]
# Only handle local fragment references: #/$defs/Name or #/definitions/Name
for prefix in ("#/$defs/", "#/definitions/"):
if ref_path.startswith(prefix):
def_name = ref_path[len(prefix) :]
if def_name in defs and def_name not in _seen:
resolved = dict(defs[def_name])
# Merge any sibling keys (e.g. description) from the referring node
for k, v in schema.items():
if k != "$ref" and k not in resolved:
resolved[k] = v
return _resolve_refs(resolved, defs, _seen | {def_name})

# Unresolvable $ref — drop it and keep sibling keys as a best-effort fallback
return {k: v for k, v in schema.items() if k != "$ref"}

result: dict[str, Any] = {}
for key, value in schema.items():
if isinstance(value, dict):
result[key] = _resolve_refs(cast(dict[str, Any], value), defs)
elif isinstance(value, list):
result[key] = [
_resolve_refs(cast(dict[str, Any], item), defs) if isinstance(item, dict) else item
for item in value # type: ignore[union-attr]
]
else:
result[key] = value
return result


def sanitize_schema_for_api(schema: dict[str, Any]) -> dict[str, Any]:
    """Sanitize a JSON Schema for use as LLM function-tool parameters.

    MCP servers may return ``inputSchema`` dicts that contain standard JSON
    Schema features (``$schema``, ``$defs``, ``$ref``, ``title``, etc.) which
    many LLM API backends do not accept. This function produces a clean copy
    suitable for the ``parameters`` field of a function-tool definition.

    The root *schema* dict passed in is never modified: every removal or
    insertion happens on the new dict returned by ``_resolve_refs``.

    Args:
        schema: The raw JSON Schema dict (e.g. from ``tool.inputSchema``).

    Returns:
        A new dict with the root-level keys in
        ``_UNSUPPORTED_SCHEMA_ROOT_KEYS`` removed, ``$defs`` / ``definitions``
        removed, ``$ref`` pointers resolved inline, and ``type`` defaulting to
        ``"object"`` when ``properties`` is present. An empty or falsy
        *schema* yields ``{"type": "object", "properties": {}}``.
    """
    if not schema:
        return {"type": "object", "properties": {}}

    definition_keys = ("$defs", "definitions")

    # Collect $defs / definitions before traversing the tree.
    # Both tables are merged so refs using either prefix can be resolved;
    # on a name clash the ``$defs`` entry wins.
    defs: dict[str, Any] = {}
    raw_defs = schema.get("$defs")
    if isinstance(raw_defs, Mapping):
        defs.update(raw_defs)  # type: ignore[reportUnknownArgumentType]
    raw_definitions = schema.get("definitions")
    if isinstance(raw_definitions, Mapping):
        for def_name, def_value in raw_definitions.items():  # type: ignore[reportUnknownVariableType]
            defs.setdefault(def_name, def_value)

    # Leave the definition tables out of the tree that gets traversed: they are
    # discarded below anyway, so resolving refs inside them is wasted work, and
    # it would recurse into definitions that are never referenced.
    root: dict[str, Any] = {key: value for key, value in schema.items() if key not in definition_keys}

    # Resolve $ref pointers inline (builds new dicts while traversing)
    sanitized = _resolve_refs(root, defs)

    # Strip unsupported root-level keys
    for key in _UNSUPPORTED_SCHEMA_ROOT_KEYS:
        sanitized.pop(key, None)

    # A root-level $ref may have inlined a definition that carries its own
    # $defs / definitions, so drop those keys from the result as well.
    for key in definition_keys:
        sanitized.pop(key, None)

    # Ensure top-level type is "object" when properties are present
    if "properties" in sanitized and "type" not in sanitized:
        sanitized["type"] = "object"

    return sanitized


# Union alias over FunctionTool, MCPTool, string-keyed mappings, and ``object``.
# NOTE(review): because ``object`` is a member, type checkers will accept any
# value for this alias — confirm that this breadth is intentional.
ToolTypes: TypeAlias = FunctionTool | MCPTool | Mapping[str, Any] | object


Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@
FunctionTool,
ToolTypes,
normalize_tools,
sanitize_schema_for_api,
tool,
)
from .._types import (
Expand Down Expand Up @@ -468,7 +469,7 @@ def _prepare_tools_for_openai(
)
continue
if isinstance(tool_item, FunctionTool):
params = tool_item.parameters()
params = sanitize_schema_for_api(tool_item.parameters())
params["additionalProperties"] = False
response_tools.append(
FunctionToolParam(
Expand Down
192 changes: 192 additions & 0 deletions python/packages/core/tests/core/test_tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,8 @@
from agent_framework._tools import (
_parse_annotation,
_parse_inputs,
_resolve_refs,
sanitize_schema_for_api,
)
from agent_framework.observability import OtelAttr

Expand Down Expand Up @@ -1001,3 +1003,193 @@ def test_parse_annotation_with_annotated_and_literal():


# endregion

# region sanitize_schema_for_api tests


def test_sanitize_schema_empty_returns_default() -> None:
    """Sanitizing an empty dict yields the minimal valid object schema."""
    minimal = {"type": "object", "properties": {}}
    outcome = sanitize_schema_for_api({})
    assert outcome == minimal


def test_sanitize_schema_simple_unchanged() -> None:
    """A schema built only from supported keywords comes back equal to the input."""
    plain: dict[str, Any] = {
        "type": "object",
        "properties": {"name": {"type": "string"}},
        "required": ["name"],
    }
    cleaned = sanitize_schema_for_api(plain)
    assert cleaned == plain


def test_sanitize_schema_does_not_mutate_original() -> None:
    """The original schema dict must never be modified, at any nesting depth."""
    import copy

    schema: dict[str, Any] = {
        "type": "object",
        "$schema": "https://json-schema.org/draft/2020-12/schema",
        "properties": {"x": {"type": "integer"}},
    }
    # Snapshot the full structure: comparing only the top-level key set would
    # miss changed values and mutations of nested dicts.
    snapshot = copy.deepcopy(schema)
    sanitize_schema_for_api(schema)
    assert schema == snapshot


def test_sanitize_schema_strips_unsupported_root_keys() -> None:
    """Root-level $schema, $id and title are removed; type and properties survive."""
    annotated: dict[str, Any] = {
        "$schema": "https://json-schema.org/draft/2020-12/schema",
        "$id": "urn:example",
        "title": "Args",
        "type": "object",
        "properties": {"x": {"type": "number"}},
    }
    cleaned = sanitize_schema_for_api(annotated)
    for dropped in ("$schema", "$id", "title"):
        assert dropped not in cleaned
    assert cleaned["type"] == "object"
    assert cleaned["properties"] == {"x": {"type": "number"}}


def test_sanitize_schema_adds_type_object_when_missing() -> None:
    """A schema with properties but no type gets type 'object' filled in."""
    untyped = {"properties": {"name": {"type": "string"}}}
    cleaned = sanitize_schema_for_api(untyped)
    assert cleaned["type"] == "object"


def test_sanitize_schema_no_type_added_without_properties() -> None:
    """Without a properties key, no type is injected into the result."""
    bare = {"description": "A schema without properties"}
    cleaned = sanitize_schema_for_api(bare)
    assert "type" not in cleaned


def test_sanitize_schema_resolves_simple_ref() -> None:
    """A property that is a bare $ref into $defs is replaced by the definition."""
    customer_param: dict[str, Any] = {
        "type": "object",
        "properties": {"customer_id": {"type": "integer"}},
        "required": ["customer_id"],
    }
    referencing: dict[str, Any] = {
        "type": "object",
        "properties": {"params": {"$ref": "#/$defs/CustomerIdParam"}},
        "$defs": {"CustomerIdParam": customer_param},
    }
    cleaned = sanitize_schema_for_api(referencing)
    assert "$defs" not in cleaned
    assert cleaned["properties"]["params"] == {
        "type": "object",
        "properties": {"customer_id": {"type": "integer"}},
        "required": ["customer_id"],
    }


def test_sanitize_schema_resolves_nested_refs() -> None:
    """A definition that itself contains a $ref is resolved all the way down."""
    definitions: dict[str, Any] = {
        "Order": {
            "type": "object",
            "properties": {"customer": {"$ref": "#/$defs/Customer"}},
        },
        "Customer": {
            "type": "object",
            "properties": {"name": {"type": "string"}},
        },
    }
    chained: dict[str, Any] = {
        "type": "object",
        "properties": {"order": {"$ref": "#/$defs/Order"}},
        "$defs": definitions,
    }
    cleaned = sanitize_schema_for_api(chained)
    assert "$defs" not in cleaned
    inlined_customer = cleaned["properties"]["order"]["properties"]["customer"]
    assert inlined_customer == {
        "type": "object",
        "properties": {"name": {"type": "string"}},
    }


def test_sanitize_schema_resolves_ref_in_array_items() -> None:
    """A $ref used as an array's items schema is inlined."""
    with_array: dict[str, Any] = {
        "type": "object",
        "properties": {"items": {"type": "array", "items": {"$ref": "#/$defs/Item"}}},
        "$defs": {"Item": {"type": "object", "properties": {"sku": {"type": "string"}}}},
    }
    cleaned = sanitize_schema_for_api(with_array)
    element_schema = cleaned["properties"]["items"]["items"]
    assert element_schema == {
        "type": "object",
        "properties": {"sku": {"type": "string"}},
    }


def test_sanitize_schema_unresolvable_ref_dropped() -> None:
    """A $ref whose target is missing from $defs is removed rather than kept."""
    dangling: dict[str, Any] = {
        "type": "object",
        "properties": {"data": {"$ref": "#/$defs/NonExistent"}},
        "$defs": {},
    }
    cleaned = sanitize_schema_for_api(dangling)
    data_schema = cleaned["properties"]["data"]
    assert "$ref" not in data_schema


def test_sanitize_schema_go_jsonschema_output() -> None:
    """Schema generated by google/jsonschema-go (as used by matlab-mcp-core-server)."""
    code_property = {"type": "string", "description": "The MATLAB code to evaluate."}
    generated: dict[str, Any] = {
        "$schema": "https://json-schema.org/draft/2020-12/schema",
        "$id": "urn:matlab:evaluate_matlab_code",
        "title": "Args",
        "type": "object",
        "properties": {"code": dict(code_property)},
        "required": ["code"],
        "additionalProperties": False,
    }
    cleaned = sanitize_schema_for_api(generated)
    for dropped in ("$schema", "$id", "title"):
        assert dropped not in cleaned
    assert cleaned == {
        "type": "object",
        "properties": {"code": code_property},
        "required": ["code"],
        "additionalProperties": False,
    }


def test_resolve_refs_deep_copies() -> None:
    """Editing a nested dict in the _resolve_refs result leaves the input untouched."""
    source: dict[str, Any] = {
        "type": "object",
        "properties": {"nested": {"type": "object", "properties": {"deep": {"type": "boolean"}}}},
    }
    copied = _resolve_refs(source, {})
    # Mutate the copy two levels down; the input must keep its own value.
    copied["properties"]["nested"]["type"] = "array"
    assert source["properties"]["nested"]["type"] == "object"


def test_sanitize_schema_both_defs_and_definitions() -> None:
    """Refs into $defs and into definitions are both resolved in the same schema."""
    dual: dict[str, Any] = {
        "type": "object",
        "properties": {
            "a": {"$ref": "#/$defs/TypeA"},
            "b": {"$ref": "#/definitions/TypeB"},
        },
        "$defs": {"TypeA": {"type": "string"}},
        "definitions": {"TypeB": {"type": "integer"}},
    }
    cleaned = sanitize_schema_for_api(dual)
    assert "$defs" not in cleaned
    assert "definitions" not in cleaned
    resolved_props = cleaned["properties"]
    assert resolved_props["a"] == {"type": "string"}
    assert resolved_props["b"] == {"type": "integer"}


# endregion
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

There is no test that FunctionTool.to_json_schema_spec() produces a sanitized schema. This is a separate integration point from _prepare_tools_for_openai and should have its own test to prevent regressions.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Missing test for _resolve_refs sibling-key merging: when a $ref node has additional keys (e.g. description), those should appear in the resolved output. Currently no test exercises lines 697-700 of _tools.py.

Loading