diff --git a/python/packages/core/AGENTS.md b/python/packages/core/AGENTS.md index a8907b7ec0..ade7d8a2c3 100644 --- a/python/packages/core/AGENTS.md +++ b/python/packages/core/AGENTS.md @@ -76,7 +76,7 @@ agent_framework/ - **`SkillScriptRunner`** - Protocol for file-based script execution. Any callable matching `(skill, script, args) -> Any` satisfies it. Code-defined scripts do not use a runner. - **`SkillScriptArgumentParser`** - Public type alias for an optional callable `(raw args: dict | list[str] | str | None) -> dict | None` that converts the raw `args` value before an `InlineSkillScript` runs (applied before the inline list-args guard). It is an opt-in customization hook (port of .NET PR #6498) that lets callers support backends sending tool-call arguments in a non-conforming shape (e.g. vLLM JSON strings). The output is constrained to a `dict` (named keyword arguments) or `None`, because inline scripts bind arguments by keyword name. Supply it via the `argument_parser=` constructor arg on `InlineSkillScript`, `InlineSkill` (default for scripts added via `@skill.script`), or `ClassSkill` (default for scripts discovered via `@ClassSkill.script`). When `None` (the default), the raw value is used unchanged. File-based scripts are unaffected (their runner owns arg handling). - **`SkillsProvider`** - Context provider (extends `ContextProvider`) that discovers file-based skills from `SKILL.md` files and/or accepts code-defined `Skill` instances. Follows progressive disclosure: advertise → load → read resources / run scripts. All three tools it exposes (`load_skill`, `read_skill_resource`, `run_skill_script`) are registered with `approval_mode="always_require"`, so every skill operation needs approval. To run unattended, pass one of the static auto-approval rules to `ToolApprovalMiddleware` (via `auto_approval_rules`): `SkillsProvider.read_only_tools_auto_approval_rule` approves only the read-only tools (`load_skill`, `read_skill_resource`) while still prompting for `run_skill_script`, and `SkillsProvider.all_tools_auto_approval_rule` approves every skill tool including script execution. Both rules reject any call carrying a `server_label` so they stay scoped to this provider's local tools and never auto-approve a same-named hosted tool. The tool names are also exposed as class constants (`LOAD_SKILL_TOOL_NAME`, `READ_SKILL_RESOURCE_TOOL_NAME`, `RUN_SKILL_SCRIPT_TOOL_NAME`). -- **`SkillsSource` decorators** - Skill sources are composable: `SkillsSource` is the abstract base, with concrete sources (`InMemorySkillsSource`, `FileSkillsSource`, `MCPSkillsSource`) and decorators that wrap an inner source — `AggregatingSkillsSource` (concatenate several sources), `FilteringSkillsSource` (predicate filter), `DeduplicatingSkillsSource` (first-wins by name), and `CachingSkillsSource` (cache the inner source's skills list). `DelegatingSkillsSource` is the abstract base for decorators. **Caching lives in the source pipeline, not the provider**: `SkillsProvider` wraps its resolved source in a `CachingSkillsSource` by default (so expensive filesystem/network discovery runs once and is reused), and rebuilds instructions/tools from the cached skills each run. Pass `disable_caching=True` to `SkillsProvider(...)` / `SkillsProvider.from_paths(...)` to skip the wrapping and re-query the source on every run. `CachingSkillsSource` shares a single in-flight fetch across concurrent callers and resets its cache on failure so the next call retries. +- **`SkillsSource` decorators** - Skill sources are composable: `SkillsSource` is the abstract base, with concrete sources (`InMemorySkillsSource`, `FileSkillsSource`, `MCPSkillsSource`) and decorators that wrap an inner source — `AggregatingSkillsSource` (concatenate several sources), `FilteringSkillsSource` (predicate filter), `DeduplicatingSkillsSource` (first-wins by name), and `CachingSkillsSource` (cache the inner source's skills list). `DelegatingSkillsSource` is the abstract base for decorators. **`get_skills` takes a `SkillsSourceContext`**: every source/decorator implements `async def get_skills(self, context: SkillsSourceContext) -> list[Skill]` and forwards `context` to inner sources. `SkillsSourceContext` (frozen, experimental) carries the invoking `agent` (`SupportsAgentRun`) and optional `session` (`AgentSession | None`); `SkillsProvider` builds it from `before_run`'s `agent`/`session` and passes it into the pipeline. `FilteringSkillsSource`'s predicate is context-aware: `Callable[[Skill, SkillsSourceContext], bool]` (port of .NET #6797). **Caching lives in the source pipeline, not the provider**: `SkillsProvider` wraps its resolved source in a `CachingSkillsSource` by default (so expensive filesystem/network discovery runs once and is reused), and rebuilds instructions/tools from the cached skills each run. Pass `disable_caching=True` to `SkillsProvider(...)` / `SkillsProvider.from_paths(...)` to skip the wrapping and re-query the source on every run. `CachingSkillsSource` shares a single in-flight fetch across concurrent callers (per cache key) and resets its cache on failure so the next call retries. By default all callers share one cache bucket; pass `cache_isolation_key_selector=Callable[[SkillsSourceContext], str | None]` to cache separately per key (e.g. per agent name) for context-aware inner sources — the key should be low-cardinality and stable, and returning `None` (or leaving the selector `None`) uses the shared bucket. ### Model Context Protocol (`_mcp.py`) diff --git a/python/packages/core/agent_framework/__init__.py b/python/packages/core/agent_framework/__init__.py index 7f58dc2394..34b317934d 100644 --- a/python/packages/core/agent_framework/__init__.py +++ b/python/packages/core/agent_framework/__init__.py @@ -204,6 +204,7 @@ SkillScriptRunner, SkillsProvider, SkillsSource, + SkillsSourceContext, ) from ._telemetry import ( AGENT_FRAMEWORK_USER_AGENT, @@ -520,6 +521,7 @@ "SkillScriptRunner", "SkillsProvider", "SkillsSource", + "SkillsSourceContext", "SlidingWindowStrategy", "StepWrapper", "SubWorkflowRequestMessage", diff --git a/python/packages/core/agent_framework/_skills.py b/python/packages/core/agent_framework/_skills.py index f33dfb2aaf..6b345f1960 100644 --- a/python/packages/core/agent_framework/_skills.py +++ b/python/packages/core/agent_framework/_skills.py @@ -53,6 +53,7 @@ import re from abc import ABC, abstractmethod from collections.abc import Callable, Sequence +from dataclasses import dataclass from html import escape as xml_escape from pathlib import Path, PurePosixPath from typing import TYPE_CHECKING, Any, ClassVar, Final, Protocol, TypeAlias, TypeVar, cast, runtime_checkable @@ -2216,7 +2217,9 @@ def _create_instructions( resource_instructions=resource_instructions or "", ) - async def _create_context(self) -> tuple[Sequence[Skill], str | None, list[FunctionTool]]: + async def _create_context( + self, source_context: SkillsSourceContext + ) -> tuple[Sequence[Skill], str | None, list[FunctionTool]]: """Build skills, instructions, and tools from the source. Queries the source for skills and constructs the instruction prompt @@ -2225,10 +2228,14 @@ async def _create_context(self) -> tuple[Sequence[Skill], str | None, list[Funct rebuilds instructions and tools from the (possibly cached) skills on every call. + Args: + source_context: Contextual information about the agent and session + requesting skills, forwarded to the source pipeline. + Returns: A tuple of ``(skills, instructions, tools)``. """ - skills = await self._source.get_skills() + skills = await self._source.get_skills(source_context) if not skills: return skills, None, [] @@ -2270,7 +2277,8 @@ async def before_run( context: Session context to extend with instructions and tools. state: Mutable per-run state dictionary (unused by this provider). """ - skills, instructions, tools = await self._create_context() + source_context = SkillsSourceContext(agent=agent, session=session) + skills, instructions, tools = await self._create_context(source_context) if not skills: return @@ -2557,6 +2565,28 @@ def _create_script_element(script: SkillScript) -> str: # region Skill Sources +@experimental(feature_id=ExperimentalFeature.SKILLS) +@dataclass(frozen=True) +class SkillsSourceContext: + """Contextual information passed to a :class:`SkillsSource` when retrieving skills. + + Exposes the invoking *agent* and, when available, the current *session* so + that skill sources and decorators can make context-aware decisions such as + per-agent filtering (see :class:`FilteringSkillsSource`) or per-key cache + isolation (see :class:`CachingSkillsSource`). + + The context is constructed by :class:`SkillsProvider` from the invoking + agent run and flows through every source and decorator in the pipeline. + + Attributes: + agent: The agent requesting skills. + session: The session associated with the agent invocation, if any. + """ + + agent: SupportsAgentRun + session: AgentSession | None = None + + @experimental(feature_id=ExperimentalFeature.SKILLS) class SkillsSource(ABC): """Abstract base class for skill sources. @@ -2570,9 +2600,13 @@ class SkillsSource(ABC): """ @abstractmethod - async def get_skills(self) -> list[Skill]: + async def get_skills(self, context: SkillsSourceContext) -> list[Skill]: """Discover and return all skills from this source. + Args: + context: Contextual information about the agent and session + requesting skills. + Returns: A list of :class:`Skill` instances discovered by this source. """ @@ -2601,7 +2635,9 @@ class FileSkillsSource(SkillsSource): .. code-block:: python source = FileSkillsSource(skill_paths="./skills") - skills = await source.get_skills() + # `context` is normally supplied by SkillsProvider at runtime. + context = SkillsSourceContext(agent=agent) + skills = await source.get_skills(context) With a script runner and filter predicates: @@ -2675,13 +2711,17 @@ def __init__( self._script_filter = script_filter self._resource_filter = resource_filter - async def get_skills(self) -> list[Skill]: + async def get_skills(self, context: SkillsSourceContext) -> list[Skill]: """Discover and return all file-based skills from configured paths. Scans directories for ``SKILL.md`` files, parses their frontmatter, discovers resource and script files, and returns populated :class:`Skill` instances. + Args: + context: Contextual information about the agent and session + requesting skills. Unused by this source. + Returns: A list of discovered file-based skills. """ @@ -3366,7 +3406,9 @@ class InMemorySkillsSource(SkillsSource): instructions="Instructions here...", ) source = InMemorySkillsSource([skill]) - skills = await source.get_skills() + # `context` is normally supplied by SkillsProvider at runtime. + context = SkillsSourceContext(agent=agent) + skills = await source.get_skills(context) """ def __init__(self, skills: Sequence[Skill]) -> None: @@ -3377,9 +3419,13 @@ def __init__(self, skills: Sequence[Skill]) -> None: """ self._skills = list(skills) - async def get_skills(self) -> list[Skill]: + async def get_skills(self, context: SkillsSourceContext) -> list[Skill]: """Return the stored skills. + Args: + context: Contextual information about the agent and session + requesting skills. Unused by this source. + Returns: A list of :class:`Skill` instances. """ @@ -3410,15 +3456,19 @@ def inner_source(self) -> SkillsSource: """The wrapped inner skill source.""" return self._inner_source - async def get_skills(self) -> list[Skill]: + async def get_skills(self, context: SkillsSourceContext) -> list[Skill]: """Delegate to the inner source. Subclasses should override this to intercept the results. + Args: + context: Contextual information about the agent and session + requesting skills, forwarded to the inner source. + Returns: Skills from the inner source. """ - return await self._inner_source.get_skills() + return await self._inner_source.get_skills(context) class DeduplicatingSkillsSource(DelegatingSkillsSource): @@ -3435,7 +3485,9 @@ class DeduplicatingSkillsSource(DelegatingSkillsSource): .. code-block:: python deduped = DeduplicatingSkillsSource(inner_source) - skills = await deduped.get_skills() + # `context` is normally supplied by SkillsProvider at runtime. + context = SkillsSourceContext(agent=agent) + skills = await deduped.get_skills(context) """ def __init__(self, inner_source: SkillsSource) -> None: @@ -3446,13 +3498,17 @@ def __init__(self, inner_source: SkillsSource) -> None: """ super().__init__(inner_source) - async def get_skills(self) -> list[Skill]: + async def get_skills(self, context: SkillsSourceContext) -> list[Skill]: """Return deduplicated skills (first-one-wins by name). + Args: + context: Contextual information about the agent and session + requesting skills, forwarded to the inner source. + Returns: A list of :class:`Skill` instances with duplicate names removed. """ - skills = await self._inner_source.get_skills() + skills = await self._inner_source.get_skills(context) seen: dict[str, Skill] = {} result: list[Skill] = [] @@ -3475,42 +3531,51 @@ class FilteringSkillsSource(DelegatingSkillsSource): """Decorator that filters skills from an inner source by predicate. Only skills for which *predicate* returns ``True`` are included in the - result. The predicate receives each :class:`Skill` and should return - a boolean. + result. The predicate receives each :class:`Skill` together with the + :class:`SkillsSourceContext` for the current invocation, enabling + context-aware filtering (for example, per-agent skill selection). Examples: .. code-block:: python filtered = FilteringSkillsSource( inner_source=my_source, - predicate=lambda s: s.frontmatter.name != "internal", + predicate=lambda skill, context: skill.frontmatter.name != "internal", ) - skills = await filtered.get_skills() + # `source_context` is normally supplied by SkillsProvider at runtime. + source_context = SkillsSourceContext(agent=agent) + skills = await filtered.get_skills(source_context) """ def __init__( self, inner_source: SkillsSource, - predicate: Callable[[Skill], bool], + predicate: Callable[[Skill, SkillsSourceContext], bool], ) -> None: """Initialize a FilteringSkillsSource. Args: inner_source: The source to filter. - predicate: A callable that receives a :class:`Skill` and returns - ``True`` to keep it or ``False`` to exclude it. + predicate: A callable that receives a :class:`Skill` and the + :class:`SkillsSourceContext` and returns ``True`` to keep the + skill or ``False`` to exclude it. """ super().__init__(inner_source) self._predicate = predicate - async def get_skills(self) -> list[Skill]: + async def get_skills(self, context: SkillsSourceContext) -> list[Skill]: """Return only skills that match the predicate. + Args: + context: Contextual information about the agent and session + requesting skills, forwarded to the inner source and passed to + the predicate. + Returns: A filtered list of :class:`Skill` instances. """ - skills = await self._inner_source.get_skills() - return [s for s in skills if self._predicate(s)] + skills = await self._inner_source.get_skills(context) + return [s for s in skills if self._predicate(s, context)] @experimental(feature_id=ExperimentalFeature.SKILLS) @@ -3528,49 +3593,114 @@ class CachingSkillsSource(DelegatingSkillsSource): are typically static discovery metadata, so querying once and reusing the result is a pure performance win. - Concurrency: concurrent callers share a single in-flight fetch, so the - inner source is queried at most once even under concurrent access. If the - fetch fails (or is cancelled), the cache is left empty so the next call - retries. + Concurrency: concurrent callers for the same cache key share a single + in-flight fetch, so the inner source is queried at most once per key even + under concurrent access. If the fetch fails (or is cancelled), the cache + is left empty so the next call retries. + + Cache isolation: by default all callers share a single cache bucket. Pass + a *cache_isolation_key_selector* to derive a cache key from the + :class:`SkillsSourceContext`, so context-aware inner sources (which may + return different skills per agent) are cached separately per key. The key + should be low-cardinality and stable (for example, an agent name or tenant + id); high-cardinality keys such as per-session ids can cause the cache to + grow without bound. Returning ``None`` from the selector (or leaving it + ``None``) uses the shared bucket. Examples: .. code-block:: python cached = CachingSkillsSource(expensive_source) - skills = await cached.get_skills() # queries the inner source - skills = await cached.get_skills() # returns the cached list + # `context` is normally supplied by SkillsProvider at runtime. + context = SkillsSourceContext(agent=agent) + skills = await cached.get_skills(context) # queries the inner source + skills = await cached.get_skills(context) # returns the cached list + + Isolating the cache per agent: + + .. code-block:: python + + cached = CachingSkillsSource( + expensive_source, + cache_isolation_key_selector=lambda context: context.agent.name, + ) """ - def __init__(self, inner_source: SkillsSource) -> None: + _SHARED_CACHE_KEY: Final[str] = "__caching_skills_source_shared__" + + def __init__( + self, + inner_source: SkillsSource, + *, + cache_isolation_key_selector: Callable[[SkillsSourceContext], str | None] | None = None, + ) -> None: """Initialize a CachingSkillsSource. Args: inner_source: The source whose results will be cached. + + Keyword Args: + cache_isolation_key_selector: Optional callable that derives a cache + key from the :class:`SkillsSourceContext`. When ``None`` (the + default), or when the callable returns ``None``, skills are + stored in a single shared bucket. Otherwise skills are cached + under the returned key. Keys should be low-cardinality and + stable to keep the cache bounded. """ super().__init__(inner_source) - self._lock = asyncio.Lock() - self._cached_skills: list[Skill] | None = None + self._cache_isolation_key_selector = cache_isolation_key_selector + self._locks_guard = asyncio.Lock() + self._locks: dict[str, asyncio.Lock] = {} + self._cached_skills: dict[str, list[Skill]] = {} + + def _resolve_cache_key(self, context: SkillsSourceContext) -> str: + """Resolve the cache bucket key for *context*.""" + if self._cache_isolation_key_selector is not None: + selected = self._cache_isolation_key_selector(context) + if selected is not None: + return selected + return self._SHARED_CACHE_KEY + + async def _get_lock(self, key: str) -> asyncio.Lock: + """Return the per-key lock, creating it under a guard if needed.""" + async with self._locks_guard: + lock = self._locks.get(key) + if lock is None: + lock = asyncio.Lock() + self._locks[key] = lock + return lock + + async def get_skills(self, context: SkillsSourceContext) -> list[Skill]: + """Return the inner source's skills, caching them per key on first call. - async def get_skills(self) -> list[Skill]: - """Return the inner source's skills, caching them on first call. + Args: + context: Contextual information about the agent and session + requesting skills. Used to resolve the cache key (via + *cache_isolation_key_selector*) and forwarded to the inner + source. Returns: - The cached list of :class:`Skill` instances. On the first call - the inner source is queried; subsequent calls return the cached - list. If the first query fails, the cache is not populated and - the next call retries. + The cached list of :class:`Skill` instances for the resolved cache + key. On the first call for a key the inner source is queried; + subsequent calls return the cached list. If the first query fails, + the cache is not populated and the next call retries. """ - if self._cached_skills is not None: - return self._cached_skills + key = self._resolve_cache_key(context) + + cached = self._cached_skills.get(key) + if cached is not None: + return cached - async with self._lock: + lock = await self._get_lock(key) + async with lock: # Another coroutine may have populated the cache while we awaited # the lock; re-check before querying the inner source. - if self._cached_skills is not None: - return self._cached_skills + cached = self._cached_skills.get(key) + if cached is not None: + return cached - skills = await self._inner_source.get_skills() - self._cached_skills = skills + skills = await self._inner_source.get_skills(context) + self._cached_skills[key] = skills return skills @@ -3580,10 +3710,10 @@ class AggregatingSkillsSource(SkillsSource): def __init__(self, sources: Sequence[SkillsSource]) -> None: self._sources = list(sources) - async def get_skills(self) -> list[Skill]: + async def get_skills(self, context: SkillsSourceContext) -> list[Skill]: result: list[Skill] = [] for source in self._sources: - skills = await source.get_skills() + skills = await source.get_skills(context) result.extend(skills) return result @@ -3912,7 +4042,9 @@ class MCPSkillsSource(SkillsSource): from mcp.client.session import ClientSession source = MCPSkillsSource(client=session) - skills = await source.get_skills() + # `context` is normally supplied by SkillsProvider at runtime. + context = SkillsSourceContext(agent=agent) + skills = await source.get_skills(context) """ _INDEX_URI: Final[str] = "skill://index.json" @@ -3927,12 +4059,16 @@ def __init__(self, client: ClientSession) -> None: """ self._client = client - async def get_skills(self) -> list[Skill]: + async def get_skills(self, context: SkillsSourceContext) -> list[Skill]: """Discover and return skills from the MCP server. Reads ``skill://index.json``, parses it, and creates an :class:`MCPSkill` for each valid ``skill-md`` entry. + Args: + context: Contextual information about the agent and session + requesting skills. Unused by this source. + Returns: A list of discovered :class:`MCPSkill` instances. """ diff --git a/python/packages/core/tests/core/test_mcp_skills.py b/python/packages/core/tests/core/test_mcp_skills.py index 74993997d0..0d0925ece9 100644 --- a/python/packages/core/tests/core/test_mcp_skills.py +++ b/python/packages/core/tests/core/test_mcp_skills.py @@ -18,13 +18,19 @@ ) from pydantic import AnyUrl -from agent_framework import MCPSkill, MCPSkillResource, MCPSkillsSource +from agent_framework import MCPSkill, MCPSkillResource, MCPSkillsSource, SkillsSourceContext from agent_framework._skills import _parse_mcp_skill_index +from .conftest import MockAgent + # --------------------------------------------------------------------------- # Fixtures & helpers # --------------------------------------------------------------------------- + +# Shared context for exercising skill sources where the agent/session are irrelevant. +_SOURCE_CTX = SkillsSourceContext(agent=MockAgent()) # type: ignore[abstract] # pyrefly: ignore[bad-instantiation] + SAMPLE_SKILL_MD = """\ --- name: unit-converter @@ -342,7 +348,7 @@ async def test_index_based_discovery_returns_skill(self) -> None: "skill://unit-converter/SKILL.md": _make_text_result(SAMPLE_SKILL_MD), }) source = MCPSkillsSource(client=client) - skills = await source.get_skills() + skills = await source.get_skills(_SOURCE_CTX) assert len(skills) == 1 assert skills[0].frontmatter.name == "unit-converter" @@ -356,7 +362,7 @@ async def test_index_based_discovery_returns_skill(self) -> None: async def test_no_index_returns_empty(self) -> None: client = _make_client() # No resources at all source = MCPSkillsSource(client=client) - skills = await source.get_skills() + skills = await source.get_skills(_SOURCE_CTX) assert skills == [] @pytest.mark.asyncio @@ -365,7 +371,7 @@ async def test_does_not_read_skill_md_during_discovery(self) -> None: # Discovery should succeed because it only reads the index. client = _make_client(**{"skill://index.json": _make_text_result(SAMPLE_SKILL_INDEX, uri="skill://index.json")}) source = MCPSkillsSource(client=client) - skills = await source.get_skills() + skills = await source.get_skills(_SOURCE_CTX) assert len(skills) == 1 assert skills[0].frontmatter.name == "unit-converter" @@ -385,7 +391,7 @@ async def test_invalid_name_is_skipped(self) -> None: }) client = _make_client(**{"skill://index.json": _make_text_result(index_json, uri="skill://index.json")}) source = MCPSkillsSource(client=client) - skills = await source.get_skills() + skills = await source.get_skills(_SOURCE_CTX) assert skills == [] @pytest.mark.asyncio @@ -402,7 +408,7 @@ async def test_missing_required_fields_is_skipped(self) -> None: }) client = _make_client(**{"skill://index.json": _make_text_result(index_json, uri="skill://index.json")}) source = MCPSkillsSource(client=client) - skills = await source.get_skills() + skills = await source.get_skills(_SOURCE_CTX) assert skills == [] @pytest.mark.asyncio @@ -420,7 +426,7 @@ async def test_unsupported_type_is_skipped(self) -> None: }) client = _make_client(**{"skill://index.json": _make_text_result(index_json, uri="skill://index.json")}) source = MCPSkillsSource(client=client) - skills = await source.get_skills() + skills = await source.get_skills(_SOURCE_CTX) assert skills == [] @pytest.mark.asyncio @@ -437,21 +443,21 @@ async def test_template_type_is_skipped(self) -> None: }) client = _make_client(**{"skill://index.json": _make_text_result(index_json, uri="skill://index.json")}) source = MCPSkillsSource(client=client) - skills = await source.get_skills() + skills = await source.get_skills(_SOURCE_CTX) assert skills == [] @pytest.mark.asyncio async def test_empty_index_returns_empty(self) -> None: client = _make_client(**{"skill://index.json": _make_text_result('{"skills": []}', uri="skill://index.json")}) source = MCPSkillsSource(client=client) - skills = await source.get_skills() + skills = await source.get_skills(_SOURCE_CTX) assert skills == [] @pytest.mark.asyncio async def test_malformed_index_json_returns_empty(self) -> None: client = _make_client(**{"skill://index.json": _make_text_result("not valid json", uri="skill://index.json")}) source = MCPSkillsSource(client=client) - skills = await source.get_skills() + skills = await source.get_skills(_SOURCE_CTX) assert skills == [] @pytest.mark.asyncio @@ -462,7 +468,7 @@ async def test_sibling_text_resource(self) -> None: "skill://unit-converter/references/checklist.md": _make_text_result("- check thing 1\n- check thing 2"), }) source = MCPSkillsSource(client=client) - skill = (await source.get_skills())[0] + skill = (await source.get_skills(_SOURCE_CTX))[0] resource = await skill.get_resource("references/checklist.md") assert resource is not None content = await resource.read() @@ -477,7 +483,7 @@ async def test_sibling_binary_resource(self) -> None: "skill://unit-converter/assets/icon.bin": _make_blob_result(data), }) source = MCPSkillsSource(client=client) - skill = (await source.get_skills())[0] + skill = (await source.get_skills(_SOURCE_CTX))[0] resource = await skill.get_resource("assets/icon.bin") assert resource is not None content = await resource.read() @@ -504,7 +510,7 @@ async def test_index_method_not_found_returns_empty(self) -> None: client = AsyncMock() client.read_resource = AsyncMock(side_effect=McpError(error=ErrorData(code=-32601, message="Method not found"))) source = MCPSkillsSource(client=client) - skills = await source.get_skills() + skills = await source.get_skills(_SOURCE_CTX) assert skills == [] @pytest.mark.asyncio @@ -515,7 +521,7 @@ async def test_index_resource_not_found_returns_empty(self) -> None: side_effect=McpError(error=ErrorData(code=-32002, message="Resource not found")) ) source = MCPSkillsSource(client=client) - skills = await source.get_skills() + skills = await source.get_skills(_SOURCE_CTX) assert skills == [] @pytest.mark.asyncio @@ -525,7 +531,7 @@ async def test_index_invalid_params_propagates(self) -> None: client.read_resource = AsyncMock(side_effect=McpError(error=ErrorData(code=-32602, message="Invalid params"))) source = MCPSkillsSource(client=client) with pytest.raises(McpError): - await source.get_skills() + await source.get_skills(_SOURCE_CTX) @pytest.mark.asyncio async def test_index_internal_error_propagates(self) -> None: @@ -534,7 +540,7 @@ async def test_index_internal_error_propagates(self) -> None: client.read_resource = AsyncMock(side_effect=McpError(error=ErrorData(code=-32603, message="Internal error"))) source = MCPSkillsSource(client=client) with pytest.raises(McpError): - await source.get_skills() + await source.get_skills(_SOURCE_CTX) @pytest.mark.asyncio async def test_index_connection_closed_propagates(self) -> None: @@ -545,7 +551,7 @@ async def test_index_connection_closed_propagates(self) -> None: ) source = MCPSkillsSource(client=client) with pytest.raises(McpError): - await source.get_skills() + await source.get_skills(_SOURCE_CTX) @pytest.mark.asyncio async def test_index_generic_error_code_propagates(self) -> None: @@ -554,7 +560,7 @@ async def test_index_generic_error_code_propagates(self) -> None: client.read_resource = AsyncMock(side_effect=McpError(error=ErrorData(code=0, message="Some handler error"))) source = MCPSkillsSource(client=client) with pytest.raises(McpError): - await source.get_skills() + await source.get_skills(_SOURCE_CTX) @pytest.mark.asyncio async def test_index_non_mcp_error_propagates(self) -> None: @@ -563,7 +569,7 @@ async def test_index_non_mcp_error_propagates(self) -> None: client.read_resource = AsyncMock(side_effect=ConnectionError("connection lost")) source = MCPSkillsSource(client=client) with pytest.raises(ConnectionError): - await source.get_skills() + await source.get_skills(_SOURCE_CTX) @pytest.mark.asyncio async def test_get_resource_internal_error_propagates(self) -> None: @@ -634,4 +640,4 @@ async def test_index_timeout_error_propagates(self) -> None: client.read_resource = AsyncMock(side_effect=TimeoutError("read timed out")) source = MCPSkillsSource(client=client) with pytest.raises(TimeoutError): - await source.get_skills() + await source.get_skills(_SOURCE_CTX) diff --git a/python/packages/core/tests/core/test_skills.py b/python/packages/core/tests/core/test_skills.py index 3a170d2a34..2f9a45ebbb 100644 --- a/python/packages/core/tests/core/test_skills.py +++ b/python/packages/core/tests/core/test_skills.py @@ -33,6 +33,7 @@ SkillScriptRunner, SkillsProvider, SkillsSource, + SkillsSourceContext, ) from agent_framework._skills import ( DEFAULT_RESOURCE_EXTENSIONS, @@ -45,12 +46,34 @@ _FileSkillResource, ) +from .conftest import MockAgent, MockAgentSession + pytestmark = pytest.mark.filterwarnings(r"ignore:\[SKILLS\].*:FutureWarning") # Cross-platform absolute path prefix for tests _ABS = "C:\\skills" if os.name == "nt" else "/skills" +class _NamedMockAgent(MockAgent): + """A :class:`MockAgent` with a configurable name for context-aware skill tests.""" + + def __init__(self, name: str = "test-agent") -> None: + self._name = name + + @property + def name(self) -> str | None: # type: ignore[override] # pyrefly: ignore[bad-override] + return self._name + + +def _make_source_context(agent_name: str = "test-agent") -> SkillsSourceContext: + """Build a :class:`SkillsSourceContext` for exercising skill sources in tests.""" + return SkillsSourceContext(agent=_NamedMockAgent(agent_name)) # type: ignore[abstract] # pyrefly: ignore[bad-instantiation] + + +# Shared context for the common case where the agent/session are irrelevant. +_SOURCE_CTX = _make_source_context() + + async def _noop_script_runner(skill: Any, script: Any, args: Any = None) -> None: """No-op script runner for tests that need a SkillScriptRunner.""" return @@ -63,7 +86,7 @@ def __init__(self, skills: Sequence[Skill]) -> None: self._skills = list(skills) self.call_count = 0 - async def get_skills(self) -> list[Skill]: + async def get_skills(self, context: SkillsSourceContext) -> list[Skill]: self.call_count += 1 return list(self._skills) @@ -76,7 +99,7 @@ async def _init_provider(provider: SkillsProvider) -> SkillsProvider: skills list itself is cached by the source pipeline (see ``CachingSkillsSource``); this helper just captures one built context. """ - provider._test_context = await provider._create_context() # type: ignore[attr-defined] # pyright: ignore[reportPrivateUsage, reportAttributeAccessIssue] # ty: ignore[unresolved-attribute] + provider._test_context = await provider._create_context(_SOURCE_CTX) # type: ignore[attr-defined] # pyright: ignore[reportPrivateUsage, reportAttributeAccessIssue] # ty: ignore[unresolved-attribute] return provider @@ -187,7 +210,7 @@ async def _discover_file_skills_for_test( kwargs["script_runner"] = script_runner source = FileSkillsSource(skill_paths, **kwargs) - skills = await source.get_skills() + skills = await source.get_skills(_SOURCE_CTX) result: dict[str, FileSkill] = {} for s in skills: assert isinstance(s, FileSkill), f"Expected FileSkill, got {type(s).__name__}" @@ -1722,14 +1745,14 @@ async def test_search_depth_controls_resource_discovery(self, tmp_path: Path) -> # depth=1: only root source1 = FileSkillsSource(str(tmp_path), search_depth=1) - skills1 = await source1.get_skills() + skills1 = await source1.get_skills(_SOURCE_CTX) names1 = [r.name for r in skills1[0]._resources] # type: ignore[attr-defined] # ty: ignore[unresolved-attribute] assert "root.md" in names1 assert "a/level1.md" not in names1 # depth=2 (default): root + one level source2 = FileSkillsSource(str(tmp_path)) - skills2 = await source2.get_skills() + skills2 = await source2.get_skills(_SOURCE_CTX) names2 = [r.name for r in skills2[0]._resources] # type: ignore[attr-defined] # ty: ignore[unresolved-attribute] assert "root.md" in names2 assert "a/level1.md" in names2 @@ -1737,7 +1760,7 @@ async def test_search_depth_controls_resource_discovery(self, tmp_path: Path) -> # depth=3: finds all source3 = FileSkillsSource(str(tmp_path), search_depth=3) - skills3 = await source3.get_skills() + skills3 = await source3.get_skills(_SOURCE_CTX) names3 = [r.name for r in skills3[0]._resources] # type: ignore[attr-defined] # ty: ignore[unresolved-attribute] assert "a/b/level2.md" in names3 @@ -1757,7 +1780,7 @@ async def test_resource_filter_excludes_files(self, tmp_path: Path) -> None: str(tmp_path), resource_filter=lambda name, path: "secret" not in path, ) - skills = await source.get_skills() + skills = await source.get_skills(_SOURCE_CTX) resource_names = [r.name for r in skills[0]._resources] # type: ignore[attr-defined] # ty: ignore[unresolved-attribute] assert "references/keep.md" in resource_names assert "references/secret.md" not in resource_names @@ -1777,7 +1800,7 @@ async def test_script_filter_excludes_files(self, tmp_path: Path) -> None: str(tmp_path), script_filter=lambda name, path: not path.startswith("test_"), ) - skills = await source.get_skills() + skills = await source.get_skills(_SOURCE_CTX) script_names = [s.name for s in skills[0]._scripts] # type: ignore[attr-defined] # ty: ignore[unresolved-attribute] assert "run.py" in script_names assert "test_run.py" not in script_names @@ -1841,7 +1864,7 @@ async def test_nested_skill_directory_absorbed_into_parent(self, tmp_path: Path) (child_dir / "child-script.py").write_text("print('child')", encoding="utf-8") source = FileSkillsSource(str(tmp_path), search_depth=3) - skills = await source.get_skills() + skills = await source.get_skills(_SOURCE_CTX) skills_dict = {s.frontmatter.name: s for s in skills} # Only the parent skill is discovered; the nested SKILL.md is not its own skill. @@ -2635,7 +2658,7 @@ async def test_file_skill_fields_populated_from_discovery(self, tmp_path: Path) encoding="utf-8", ) source = FileSkillsSource(str(tmp_path)) - skills = await source.get_skills() + skills = await source.get_skills(_SOURCE_CTX) assert len(skills) == 1 skill = skills[0] assert isinstance(skill, FileSkill) @@ -4267,7 +4290,7 @@ async def test_provider_loads_class_skill_content(self) -> None: async def test_in_memory_source_with_class_skill(self) -> None: skill = _MinimalClassSkill() source = InMemorySkillsSource([skill]) - skills = await source.get_skills() + skills = await source.get_skills(_SOURCE_CTX) assert len(skills) == 1 assert skills[0].frontmatter.name == "minimal-skill" @@ -5189,7 +5212,7 @@ async def test_file_skill_takes_precedence_over_code_skill(self, tmp_path: Path) InMemorySkillsSource([code_skill]), ]) ) - result = await source.get_skills() + result = await source.get_skills(_SOURCE_CTX) skills_by_name = {s.frontmatter.name: s for s in result} assert "my-skill" in skills_by_name assert skills_by_name["my-skill"].path is not None # type: ignore[attr-defined] # ty: ignore[unresolved-attribute] # file-based skill has path set @@ -5213,7 +5236,7 @@ async def test_file_skills_source_discovers_skills(self, tmp_path: Path) -> None ) source = FileSkillsSource(str(tmp_path)) - skills = await source.get_skills() + skills = await source.get_skills(_SOURCE_CTX) assert len(skills) == 1 assert skills[0].frontmatter.name == "my-skill" assert skills[0].path is not None # type: ignore[attr-defined] # ty: ignore[unresolved-attribute] @@ -5232,7 +5255,7 @@ async def test_file_skills_source_with_extensions(self, tmp_path: Path) -> None: # Only allow .json resources source = FileSkillsSource(str(tmp_path), resource_extensions=(".json",)) - skills = await source.get_skills() + skills = await source.get_skills(_SOURCE_CTX) assert len(skills) == 1 resource_names = [r.name for r in skills[0]._resources] # type: ignore[attr-defined] # ty: ignore[unresolved-attribute] assert "references/data.json" in resource_names @@ -5246,7 +5269,7 @@ async def test_in_memory_skills_source_returns_all_skills(self) -> None: s2 = InlineSkill(frontmatter=SkillFrontmatter(name="skill-b", description="B"), instructions="body") source = InMemorySkillsSource([s1, s2]) - skills = await source.get_skills() + skills = await source.get_skills(_SOURCE_CTX) assert len(skills) == 2 assert skills[0].frontmatter.name == "skill-a" assert skills[1].frontmatter.name == "skill-b" @@ -5262,7 +5285,7 @@ async def test_aggregating_source_combines_sources(self) -> None: InMemorySkillsSource([s1]), InMemorySkillsSource([s2]), ]) - skills = await source.get_skills() + skills = await source.get_skills(_SOURCE_CTX) names = [s.frontmatter.name for s in skills] assert names == ["skill-a", "skill-b"] @@ -5275,9 +5298,9 @@ async def test_filtering_source_filters_by_predicate(self) -> None: source = FilteringSkillsSource( InMemorySkillsSource([s1, s2]), - predicate=lambda s: s.frontmatter.name.startswith("keep"), + predicate=lambda s, _ctx: s.frontmatter.name.startswith("keep"), ) - skills = await source.get_skills() + skills = await source.get_skills(_SOURCE_CTX) assert len(skills) == 1 assert skills[0].frontmatter.name == "keep-me" @@ -5290,7 +5313,7 @@ async def test_deduplicating_source_removes_duplicates(self) -> None: s3 = InlineSkill(frontmatter=SkillFrontmatter(name="other", description="other"), instructions="body3") source = DeduplicatingSkillsSource(InMemorySkillsSource([s1, s2, s3])) - skills = await source.get_skills() + skills = await source.get_skills(_SOURCE_CTX) assert len(skills) == 2 names = {s.frontmatter.name for s in skills} assert names == {"my-skill", "other"} @@ -5310,7 +5333,7 @@ class PassthroughSource(DelegatingSkillsSource): source = PassthroughSource(inner) assert source.inner_source is inner - skills = await source.get_skills() + skills = await source.get_skills(_SOURCE_CTX) assert len(skills) == 1 assert skills[0].frontmatter.name == "test-skill" @@ -5320,8 +5343,8 @@ async def test_caching_source_caches_inner_results(self) -> None: inner = _CountingSkillsSource([skill]) cached = CachingSkillsSource(inner) - first = await cached.get_skills() - second = await cached.get_skills() + first = await cached.get_skills(_SOURCE_CTX) + second = await cached.get_skills(_SOURCE_CTX) assert inner.call_count == 1 assert first is second @@ -5344,7 +5367,7 @@ class FlakySource(SkillsSource): def __init__(self) -> None: self.call_count = 0 - async def get_skills(self) -> list[Skill]: + async def get_skills(self, context: SkillsSourceContext) -> list[Skill]: self.call_count += 1 if self.call_count == 1: raise RuntimeError("transient failure") @@ -5359,9 +5382,9 @@ async def get_skills(self) -> list[Skill]: cached = CachingSkillsSource(inner) with pytest.raises(RuntimeError, match="transient failure"): - await cached.get_skills() + await cached.get_skills(_SOURCE_CTX) - skills = await cached.get_skills() + skills = await cached.get_skills(_SOURCE_CTX) assert inner.call_count == 2 assert [s.frontmatter.name for s in skills] == ["test-skill"] @@ -5376,7 +5399,7 @@ class SlowSource(SkillsSource): def __init__(self) -> None: self.call_count = 0 - async def get_skills(self) -> list[Skill]: + async def get_skills(self, context: SkillsSourceContext) -> list[Skill]: self.call_count += 1 started.set() await release.wait() @@ -5390,9 +5413,9 @@ async def get_skills(self) -> list[Skill]: inner = SlowSource() cached = CachingSkillsSource(inner) - first = asyncio.ensure_future(cached.get_skills()) + first = asyncio.ensure_future(cached.get_skills(_SOURCE_CTX)) await started.wait() - second = asyncio.ensure_future(cached.get_skills()) + second = asyncio.ensure_future(cached.get_skills(_SOURCE_CTX)) release.set() results = await asyncio.gather(first, second) @@ -5459,10 +5482,10 @@ async def test_composed_source_pipeline(self, tmp_path: Path) -> None: InMemorySkillsSource([code_skill, internal]), ]) ), - predicate=lambda s: s.frontmatter.name != "internal", + predicate=lambda s, _ctx: s.frontmatter.name != "internal", ) - skills = await source.get_skills() + skills = await source.get_skills(_SOURCE_CTX) names = {s.frontmatter.name for s in skills} assert names == {"file-skill", "code-skill"} assert "internal" not in names @@ -5473,6 +5496,108 @@ async def test_composed_source_pipeline(self, tmp_path: Path) -> None: # --------------------------------------------------------------------------- +class TestSkillsSourceContext: + """Tests for SkillsSourceContext propagation and context-aware sources.""" + + async def test_context_exposes_agent_and_session(self) -> None: + """SkillsSourceContext carries the agent and optional session.""" + agent = _NamedMockAgent() # type: ignore[abstract] # pyrefly: ignore[bad-instantiation] + ctx = SkillsSourceContext(agent=agent) + assert ctx.agent is agent + assert ctx.session is None + + session = MockAgentSession() + ctx_with_session = SkillsSourceContext(agent=agent, session=session) + assert ctx_with_session.session is session + + async def test_context_flows_through_decorator_pipeline(self) -> None: + """The context passed to get_skills reaches the innermost source.""" + received: list[SkillsSourceContext] = [] + + class _RecordingSource(SkillsSource): + async def get_skills(self, context: SkillsSourceContext) -> list[Skill]: + received.append(context) + return [ + InlineSkill( + frontmatter=SkillFrontmatter(name="skill-a", description="A"), + instructions="body", + ) + ] + + source = DeduplicatingSkillsSource(CachingSkillsSource(_RecordingSource())) + ctx = _make_source_context("agent-x") + + skills = await source.get_skills(ctx) + assert [s.frontmatter.name for s in skills] == ["skill-a"] + assert received == [ctx] + assert received[0].agent.name == "agent-x" + + async def test_filtering_predicate_receives_context(self) -> None: + """FilteringSkillsSource passes the context to the predicate.""" + from agent_framework import FilteringSkillsSource + + seen: list[SkillsSourceContext] = [] + + def _predicate(skill: Skill, context: SkillsSourceContext) -> bool: + seen.append(context) + # Keep only skills whose name matches the invoking agent's name. + return skill.frontmatter.name == context.agent.name + + s1 = InlineSkill(frontmatter=SkillFrontmatter(name="agent-x", description="A"), instructions="body") + s2 = InlineSkill(frontmatter=SkillFrontmatter(name="agent-y", description="B"), instructions="body") + + source = FilteringSkillsSource(InMemorySkillsSource([s1, s2]), predicate=_predicate) + ctx = _make_source_context("agent-x") + + skills = await source.get_skills(ctx) + assert [s.frontmatter.name for s in skills] == ["agent-x"] + assert all(c is ctx for c in seen) + + async def test_caching_shared_bucket_by_default(self) -> None: + """Without an isolation key selector, all contexts share one cache entry.""" + inner = _CountingSkillsSource([ + InlineSkill(frontmatter=SkillFrontmatter(name="skill-a", description="A"), instructions="body") + ]) + cached = CachingSkillsSource(inner) + + first = await cached.get_skills(_make_source_context("agent-x")) + second = await cached.get_skills(_make_source_context("agent-y")) + + assert inner.call_count == 1 + assert first is second + + async def test_caching_isolation_key_separates_buckets(self) -> None: + """An isolation key selector caches skills separately per key.""" + inner = _CountingSkillsSource([ + InlineSkill(frontmatter=SkillFrontmatter(name="skill-a", description="A"), instructions="body") + ]) + cached = CachingSkillsSource( + inner, + cache_isolation_key_selector=lambda context: context.agent.name, + ) + + first_x = await cached.get_skills(_make_source_context("agent-x")) + first_y = await cached.get_skills(_make_source_context("agent-y")) + second_x = await cached.get_skills(_make_source_context("agent-x")) + + # One fetch per distinct key; repeated keys are served from cache. + assert inner.call_count == 2 + assert first_x is second_x + assert first_x is not first_y + + async def test_caching_isolation_key_none_uses_shared_bucket(self) -> None: + """A selector returning None falls back to the shared cache bucket.""" + inner = _CountingSkillsSource([ + InlineSkill(frontmatter=SkillFrontmatter(name="skill-a", description="A"), instructions="body") + ]) + cached = CachingSkillsSource(inner, cache_isolation_key_selector=lambda context: None) + + await cached.get_skills(_make_source_context("agent-x")) + await cached.get_skills(_make_source_context("agent-y")) + + assert inner.call_count == 1 + + class TestSourceComposition: """Tests for composing sources directly instead of using a builder.""" @@ -5523,7 +5648,7 @@ async def test_filtering_source_excludes_skills(self) -> None: source = DeduplicatingSkillsSource( FilteringSkillsSource( InMemorySkillsSource([s1, s2]), - predicate=lambda s: s.frontmatter.name.startswith("keep"), + predicate=lambda s, _ctx: s.frontmatter.name.startswith("keep"), ) ) provider = SkillsProvider(source) @@ -5725,8 +5850,8 @@ async def test_default_caching_queries_source_once(self) -> None: inner = _CountingSkillsSource([skill]) provider = SkillsProvider(inner) - await provider._create_context() # pyright: ignore[reportPrivateUsage] - await provider._create_context() # pyright: ignore[reportPrivateUsage] + await provider._create_context(_SOURCE_CTX) # pyright: ignore[reportPrivateUsage] + await provider._create_context(_SOURCE_CTX) # pyright: ignore[reportPrivateUsage] assert inner.call_count == 1 async def test_disable_caching_does_not_wrap_source(self) -> None: @@ -5743,8 +5868,8 @@ async def test_disable_caching_rebuilds_on_every_call(self) -> None: inner = _CountingSkillsSource([skill]) provider = SkillsProvider(inner, disable_caching=True) - await provider._create_context() # pyright: ignore[reportPrivateUsage] - await provider._create_context() # pyright: ignore[reportPrivateUsage] + await provider._create_context(_SOURCE_CTX) # pyright: ignore[reportPrivateUsage] + await provider._create_context(_SOURCE_CTX) # pyright: ignore[reportPrivateUsage] assert inner.call_count == 2 async def test_disable_caching_via_constructor(self) -> None: diff --git a/python/packages/foundry_hosting/agent_framework_foundry_hosting/_toolbox.py b/python/packages/foundry_hosting/agent_framework_foundry_hosting/_toolbox.py index d5b11957a7..28b8759341 100644 --- a/python/packages/foundry_hosting/agent_framework_foundry_hosting/_toolbox.py +++ b/python/packages/foundry_hosting/agent_framework_foundry_hosting/_toolbox.py @@ -8,7 +8,13 @@ from urllib.parse import urlsplit import httpx -from agent_framework import MCPSkillsSource, MCPStreamableHTTPTool, SkillsProvider, SkillsSource +from agent_framework import ( + MCPSkillsSource, + MCPStreamableHTTPTool, + SkillsProvider, + SkillsSource, + SkillsSourceContext, +) from azure.ai.agentserver.core import get_request_context if TYPE_CHECKING: @@ -246,7 +252,7 @@ class _FoundryToolboxSkillsSource(SkillsSource): def __init__(self, toolbox: FoundryToolbox) -> None: self._toolbox = toolbox - async def get_skills(self) -> list[Skill]: + async def get_skills(self, context: SkillsSourceContext) -> list[Skill]: session = self._toolbox.session if session is None: raise RuntimeError( @@ -254,4 +260,4 @@ async def get_skills(self) -> list[Skill]: "Pass the toolbox to the agent (tools=...) or enter it as an async " "context manager before the agent runs." ) - return await MCPSkillsSource(client=session).get_skills() + return await MCPSkillsSource(client=session).get_skills(context) diff --git a/python/packages/foundry_hosting/tests/test_toolbox.py b/python/packages/foundry_hosting/tests/test_toolbox.py index 02f48d793f..c9a1391ad3 100644 --- a/python/packages/foundry_hosting/tests/test_toolbox.py +++ b/python/packages/foundry_hosting/tests/test_toolbox.py @@ -5,11 +5,12 @@ from __future__ import annotations from datetime import datetime, timezone +from typing import cast from unittest.mock import AsyncMock import httpx import pytest -from agent_framework import SkillsProvider +from agent_framework import SkillsProvider, SkillsSourceContext, SupportsAgentRun from azure.ai.agentserver.core import ( FoundryAgentRequestContext, reset_request_context, @@ -25,6 +26,17 @@ ) +class _StubAgent: + """Minimal stand-in for a ``SupportsAgentRun`` used to build a source context.""" + + name = "test-agent" + + +def _source_context() -> SkillsSourceContext: + """Build a :class:`SkillsSourceContext` for exercising skill sources in tests.""" + return SkillsSourceContext(agent=cast(SupportsAgentRun, _StubAgent())) + + class _FakeAccessToken: def __init__(self, token: str) -> None: self.token = token @@ -162,7 +174,7 @@ async def test_skills_source_requires_connection() -> None: assert toolbox.session is None source = _FoundryToolboxSkillsSource(toolbox) with pytest.raises(RuntimeError, match="not connected"): - await source.get_skills() + await source.get_skills(_source_context()) async def test_skills_source_uses_connected_session(monkeypatch: pytest.MonkeyPatch) -> None: @@ -179,12 +191,12 @@ class _StubSkillsSource: def __init__(self, *, client: object) -> None: captured["client"] = client - async def get_skills(self) -> list[str]: + async def get_skills(self, context: SkillsSourceContext) -> list[str]: return ["skill-a"] monkeypatch.setattr("agent_framework_foundry_hosting._toolbox.MCPSkillsSource", _StubSkillsSource) - result = await _FoundryToolboxSkillsSource(toolbox).get_skills() + result = await _FoundryToolboxSkillsSource(toolbox).get_skills(_source_context()) assert result == ["skill-a"] assert captured["client"] is sentinel_session diff --git a/python/samples/02-agents/skills/skill_filtering/skill_filtering.py b/python/samples/02-agents/skills/skill_filtering/skill_filtering.py index 35fcd58495..f25a0fa6b3 100644 --- a/python/samples/02-agents/skills/skill_filtering/skill_filtering.py +++ b/python/samples/02-agents/skills/skill_filtering/skill_filtering.py @@ -76,7 +76,7 @@ async def main() -> None: FilteringSkillsSource( FileSkillsSource(str(skills_dir), script_runner=subprocess_script_runner), # Only keep the volume-converter skill - predicate=lambda s: s.frontmatter.name != "length-converter", + predicate=lambda skill, context: skill.frontmatter.name != "length-converter", ) )