diff --git a/AGENTS.md b/AGENTS.md
index cb07d5b31..65ecf105f 100644
--- a/AGENTS.md
+++ b/AGENTS.md
@@ -180,6 +180,8 @@ Intrinsics are specialized LoRA adapters that add task-specific capabilities (RA
 | `rag` | `find_citations(response, documents, context, backend)` | Document sentences supporting the response |
 | `rag` | `check_context_relevance(question, document, context, backend)` | Whether a document is relevant (0–1); only supported for granite-4.0, not granite-4.1 |
 | `rag` | `flag_hallucinated_content(response, documents, context, backend)` | Flag potentially hallucinated sentences |
+| `guardian` | `factuality_detection(context, backend, *, documents=None, model_options=None)` | Determine if the last response is factually incorrect ("yes"/"no") |
+| `guardian` | `factuality_correction(context, backend, *, documents=None, model_options=None)` | Correct the last response to be factually accurate |
 
 ```python
 from mellea.backends.huggingface import LocalHFBackend
diff --git a/docs/examples/intrinsics/factuality_correction.py b/docs/examples/intrinsics/factuality_correction.py
index 80310a009..547e7a64c 100644
--- a/docs/examples/intrinsics/factuality_correction.py
+++ b/docs/examples/intrinsics/factuality_correction.py
@@ -86,11 +86,8 @@
 )
 
 # NOTE: This example can also be run with the OpenAIBackend using a GraniteSwitch model. See docs/examples/granite-switch/.
-ctx = (
-    ctx.add(document)
-    .add(Message("user", user_text))
-    .add(Message("assistant", response_text))
-)
+ctx = ctx.add(Message("user", user_text))
+ctx = ctx.add(Message("assistant", response_text))
 
-result = guardian.factuality_correction(ctx, backend)
+result = guardian.factuality_correction(ctx, backend, documents=[document])
 print(f"Result of factuality correction: {result}")  # corrected response string
diff --git a/docs/examples/intrinsics/factuality_detection.py b/docs/examples/intrinsics/factuality_detection.py
index a292eea94..71cbef9f5 100644
--- a/docs/examples/intrinsics/factuality_detection.py
+++ b/docs/examples/intrinsics/factuality_detection.py
@@ -29,11 +29,8 @@
 )
 
 # NOTE: This example can also be run with the OpenAIBackend using a GraniteSwitch model. See docs/examples/granite-switch/.
-ctx = (
-    ctx.add(document)
-    .add(Message("user", user_text))
-    .add(Message("assistant", response_text))
-)
+ctx = ctx.add(Message("user", user_text))
+ctx = ctx.add(Message("assistant", response_text))
 
-result = guardian.factuality_detection(ctx, backend)
+result = guardian.factuality_detection(ctx, backend, documents=[document])
 print(f"Result of factuality detection: {result}")  # string "yes" or "no"
diff --git a/mellea/stdlib/components/intrinsic/_util.py b/mellea/stdlib/components/intrinsic/_util.py
index 94349612d..6905adf66 100644
--- a/mellea/stdlib/components/intrinsic/_util.py
+++ b/mellea/stdlib/components/intrinsic/_util.py
@@ -13,23 +13,25 @@
 from ....stdlib import functional as mfuncs
 from ...components import Document
 from ...context import ChatContext
+from ..chat import Message
 from .intrinsic import Intrinsic
 
 
 def _resolve_question(
     question: str | None, context: ChatContext, backend: Backend | None = None
-) -> tuple[str, ChatContext]:
-    """Return ``(question_text, context_to_use)``.
+) -> tuple[str, ChatContext, list[Document] | None]:
+    """Return ``(question_text, context_to_use, documents)``.
 
-    When *question* is not ``None``, returns it with *context* unchanged.
+    When *question* is not ``None``, returns it with *context* unchanged and no documents.
     When ``None``, extracts the text from the last turn's ``model_input``
-    and rewinds *context* to before that element.
+    and rewinds *context* to before that element. Also extracts documents
+    when the last input is a ``Message``.
 
     Supports ``Message`` (via ``.content``), ``CBlock`` (via ``.value``), and
     generic ``Component`` types (via ``TemplateFormatter.print()``).
     """
     if question is not None:
-        return question, context
+        return question, context, None
 
     from ....core import CBlock, Component
     from ..chat import Message
@@ -40,8 +42,10 @@ def _resolve_question(
         )
 
     model_input = turn.model_input
+    documents: list[Document] | None = None
     if isinstance(model_input, Message):
         text = model_input.content
+        documents = model_input._docs
    elif isinstance(model_input, CBlock):
         if model_input.value is None:
             raise ValueError(
@@ -65,29 +69,48 @@ def _resolve_question(
     rewound = context.previous_node
     if rewound is None:
         raise ValueError("Cannot rewind context past the root node")
-    return text, rewound  # type: ignore[return-value]
+    return text, rewound, documents  # type: ignore[return-value]
 
 
 def _resolve_response(
     response: str | None, context: ChatContext
-) -> tuple[str, ChatContext]:
-    """Return ``(response_text, context_to_use)``.
+) -> tuple[str, ChatContext, list[Document] | None]:
+    """Return ``(response_text, context_to_use, documents)``.
 
-    When *response* is not ``None``, returns it with *context* unchanged.
-    When ``None``, extracts from the last turn's ``output.value`` and rewinds
-    *context* to before that output.
+    When *response* is not ``None``, returns it with *context* unchanged and no documents.
+    When ``None``, extracts from the last turn's ``output.value`` (generated) or
+    ``model_input.content`` (manually-added Message), then rewinds *context*
+    to before that turn. Documents are extracted when the response comes from
+    a manually-added assistant ``Message``.
""" if response is not None: - return response, context + return response, context, None turn = context.last_turn() - if turn is None or turn.output is None: + if turn is None: raise ValueError("response is None and context has no last turn with output") - if turn.output.value is None: - raise ValueError("response is None and last turn output has no value") + + documents: list[Document] | None = None + # Try generated output first + if turn.output is not None: + if turn.output.value is None: + raise ValueError("response is None and last turn output has no value") + response_text = turn.output.value + # Fall back to manually-added assistant Message + elif ( + turn.model_input is not None + and isinstance(turn.model_input, Message) + and turn.model_input.role == "assistant" + ): + response_text = turn.model_input.content + documents = turn.model_input._docs + else: + raise ValueError( + "response is None and context has no last turn with output or assistant message" + ) + rewound = context.previous_node if rewound is None: raise ValueError("Cannot rewind context past the root node") - return turn.output.value, rewound # type: ignore[return-value] + return response_text, rewound, documents # type: ignore[return-value] def call_intrinsic( diff --git a/mellea/stdlib/components/intrinsic/core.py b/mellea/stdlib/components/intrinsic/core.py index 8ac248a55..470a82444 100644 --- a/mellea/stdlib/components/intrinsic/core.py +++ b/mellea/stdlib/components/intrinsic/core.py @@ -9,7 +9,9 @@ from ._util import _resolve_response, call_intrinsic -def check_certainty(context: ChatContext, backend: AdapterMixin) -> float: +def check_certainty( + context: ChatContext, backend: AdapterMixin, model_options: dict | None = None +) -> float: """Estimate the model's certainty about its last response. Intrinsic function that evaluates how certain the model is about the @@ -19,11 +21,15 @@ def check_certainty(context: ChatContext, backend: AdapterMixin) -> float: Args: context: Chat context containing user question and assistant answer. backend: Backend instance that supports LoRA/aLoRA adapters. + model_options: Optional model options to pass to the backend (e.g., + temperature, max_tokens). Defaults to ``{ModelOption.TEMPERATURE: 0.0}``. Returns: Certainty score as a float (higher = more certain). """ - result_json = call_intrinsic("uncertainty", context, backend) + result_json = call_intrinsic( + "uncertainty", context, backend, model_options=model_options + ) return result_json["certainty"] @@ -37,7 +43,10 @@ def check_certainty(context: ChatContext, backend: AdapterMixin) -> float: def requirement_check( - context: ChatContext, backend: AdapterMixin, requirement: str + context: ChatContext, + backend: AdapterMixin, + requirement: str, + model_options: dict | None = None, ) -> float: """Detect if text adheres to provided requirements. @@ -49,13 +58,17 @@ def requirement_check( context: Chat context containing user question and assistant answer. backend: Backend instance that supports LoRA/aLoRA adapters. requirement: Set of requirements to satisfy. + model_options: Optional model options to pass to the backend (e.g., + temperature, max_tokens). Defaults to ``{ModelOption.TEMPERATURE: 0.0}``. Returns: Score as a float between 0.0 and 1.0 (higher = more likely satisfied). 
""" eval_message = f": {requirement}\n{_EVALUATION_PROMPT}" context = context.add(Message("user", eval_message)) - result_json = call_intrinsic("requirement-check", context, backend) + result_json = call_intrinsic( + "requirement-check", context, backend, model_options=model_options + ) return result_json["requirement_check"]["score"] @@ -64,6 +77,7 @@ def find_context_attributions( documents: collections.abc.Iterable[str | Document], context: ChatContext, backend: AdapterMixin, + model_options: dict | None = None, ) -> list[dict]: """Find sentences in conversation history and documents that most influence an LLM's response. @@ -71,39 +85,35 @@ def find_context_attributions( documents that were most important to the LLM in generating each sentence in the assistant response. - :param response: Assistant response. When ``None``, the response is extracted - from the last assistant output in ``context``. - :param documents: Documents that were used to generate ``response``. Each element - may be a ``Document`` or a plain string. Strings are wrapped in ``Document`` - with an auto-generated ``doc_id`` (``"0"``, ``"1"``, ...); for explicit - control, pass ``Document`` objects with ``doc_id`` set. ``Document`` objects - without ``doc_id`` trigger a warning because the intrinsic uses ``doc_id`` to - identify attribution sources. - :param context: Context of the dialog between user and assistant, ending with a - user query - :param backend: Backend that supports intrinsic adapters - - :return: List of records with the following fields: - * ``response_begin`` - * ``response_end`` - * ``response_text`` - * ``attribution_doc_id`` - * ``attribution_msg_index`` - * ``attribution_begin`` - * ``attribution_end`` - * ``attribution_text`` - Begin and end offsets are character offsets into their respective UTF-8 strings. + Args: + response: Assistant response. When ``None``, the response is extracted + from the last assistant output in ``context``. + documents: Documents that were used to generate ``response``. Each element + may be a ``Document`` or a plain string. Strings are wrapped in ``Document`` + with an auto-generated ``doc_id`` (``"0"``, ``"1"``, ...); for explicit + control, pass ``Document`` objects with ``doc_id`` set. ``Document`` objects + without ``doc_id`` trigger a warning because the intrinsic uses ``doc_id`` to + identify attribution sources. + context: Context of the dialog between user and assistant, ending with a + user query. + backend: Backend that supports intrinsic adapters. + model_options: Optional model options to pass to the backend (e.g., + temperature, max_tokens). Defaults to ``{ModelOption.TEMPERATURE: 0.0}``. + + Returns: + List of records with the following fields: ``response_begin``, + ``response_end``, ``response_text``, ``attribution_doc_id``, + ``attribution_msg_index``, ``attribution_begin``, ``attribution_end``, + ``attribution_text``. Begin and end offsets are character offsets into + their respective UTF-8 strings. 
""" - response, context = _resolve_response(response, context) + response, context, resolved_docs = _resolve_response(response, context) + explicit_docs = _coerce_to_documents(documents, auto_doc_id=False) + docs_to_use = [*(explicit_docs or []), *(resolved_docs or [])] or None result_json = call_intrinsic( "context-attribution", - context.add( - Message( - "assistant", - response, - documents=_coerce_to_documents(documents, auto_doc_id=False), - ) - ), + context.add(Message("assistant", response, documents=docs_to_use)), backend, + model_options=model_options, ) return result_json diff --git a/mellea/stdlib/components/intrinsic/guardian.py b/mellea/stdlib/components/intrinsic/guardian.py index 3dcc843a9..2ca0cc0cb 100644 --- a/mellea/stdlib/components/intrinsic/guardian.py +++ b/mellea/stdlib/components/intrinsic/guardian.py @@ -6,25 +6,37 @@ ``io.yaml``. """ +import collections.abc + from ....backends import model_ids from ....backends.adapters import AdapterMixin +from ...components import Document from ...context import ChatContext from ..chat import Message -from ._util import call_intrinsic +from ..docs.document import _coerce_to_documents +from ._util import _resolve_response, call_intrinsic def policy_guardrails( - context: ChatContext, backend: AdapterMixin, policy_text: str + context: ChatContext, + backend: AdapterMixin, + policy_text: str, + model_options: dict | None = None, ) -> str: """Checks whether text complied with specified policy. Uses the policy_guardrails LoRA adapter to judge whether the scenario described in the last message in ``context`` is compliant with the given ``policy_text``. - :param context: Chat context containing the conversation to evaluate. - :param backend: Backend instance that supports LoRA adapters. - :param policy_text: Policy against with compliance is to be checked - :return: Compliance as a "Yes/No/Ambiguous" label (Yes = compliant). + Args: + context: Chat context containing the conversation to evaluate. + backend: Backend instance that supports LoRA adapters. + policy_text: Policy against which compliance is to be checked. + model_options: Optional model options to pass to the backend (e.g., + temperature, max_tokens). Defaults to ``{ModelOption.TEMPERATURE: 0.0}``. + + Returns: + Compliance as a "Yes/No/Ambiguous" label (Yes = compliant). """ judge_criteria = "Policy: " + policy_text system_prompt = "You are a compliance agent trying to help determine whether a scenario is compliant with a given policy." @@ -34,7 +46,9 @@ def policy_guardrails( judge_protocol = f" {system_prompt}\n\n### Criteria: {judge_criteria}\n\n### Scoring Schema: {scoring_schema}" context = context.add(Message("user", judge_protocol)) - result_json = call_intrinsic("policy-guardrails", context, backend) + result_json = call_intrinsic( + "policy-guardrails", context, backend, model_options=model_options + ) if "label" not in result_json.keys() and "score" not in result_json.keys(): raise Exception( @@ -138,6 +152,9 @@ def guardian_check( backend: AdapterMixin, criteria: str, target_role: str = "assistant", + *, + documents: collections.abc.Iterable[str | Document] | None = None, + model_options: dict | None = None, ) -> float: """Check whether text meets specified safety/quality criteria. @@ -152,10 +169,36 @@ def guardian_check( criteria string. target_role: Role whose last message is being evaluated (``"user"`` or ``"assistant"``). + documents: Optional document snippets to attach to the target message. 
+            Primarily used for the ``"groundedness"`` criterion, to provide
+            reference context for grounding checks. Each element may be a
+            ``Document`` or a plain string (automatically wrapped in ``Document``).
+            Keyword-only.
+        model_options: Optional model options to pass to the backend (e.g.,
+            temperature, max_tokens). Defaults to ``{ModelOption.TEMPERATURE: 0.0}``.
 
     Returns:
         Risk score as a float between 0.0 (no risk) and 1.0 (risk detected).
+
+    Raises:
+        ValueError: If documents are provided but target_role is not "assistant".
     """
+    if documents is not None and target_role != "assistant":
+        raise ValueError(
+            "documents parameter is only supported when target_role='assistant'"
+        )
+
+    if documents is not None:
+        response, context, resolved_docs = _resolve_response(None, context)
+        explicit_docs = _coerce_to_documents(documents)
+        docs_to_use = explicit_docs
+        if resolved_docs is not None:
+            if docs_to_use is not None:
+                docs_to_use.extend(resolved_docs)
+            else:
+                docs_to_use = resolved_docs
+        context = context.add(Message("assistant", response, documents=docs_to_use))
+
     criteria_text = CRITERIA_BANK.get(criteria, criteria)
 
     scoring = (
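> Review note: the explicit-plus-resolved document merge that `guardian_check` performs above reappears nearly verbatim in `check_answerability`, `clarify_query`, `find_citations`, and `flag_hallucinated_content` below. A follow-up could deduplicate it with a small shared helper; a sketch only (the name `_merge_documents` is hypothetical and not part of this diff), assuming `_coerce_to_documents` yields `list[Document] | None`:

```python
from mellea.stdlib.components import Document  # assumed absolute form of `from ...components import Document`


def _merge_documents(
    explicit: list[Document] | None, resolved: list[Document] | None
) -> list[Document] | None:
    """Combine explicitly passed documents with ones recovered from context.

    Returns None when both inputs are None, so the result can be handed
    straight to Message(..., documents=...).
    """
    if explicit is None:
        return resolved
    if resolved is None:
        return explicit
    return [*explicit, *resolved]
```

Each call site would then collapse to `docs_to_use = _merge_documents(_coerce_to_documents(documents), resolved_docs)`.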
@@ -168,21 +211,37 @@ def guardian_check(
         f"### Scoring Schema: {scoring}"
     )
     context = context.add(Message("user", judge_protocol))
-    result_json = call_intrinsic("guardian-core", context, backend)
+    result_json = call_intrinsic(
+        "guardian-core", context, backend, model_options=model_options
+    )
     return result_json["guardian"]["score"]
 
 
-def factuality_detection(context: ChatContext, backend: AdapterMixin) -> float:
-    """Determine is the last response is factually incorrect.
+def factuality_detection(
+    context: ChatContext,
+    backend: AdapterMixin,
+    *,
+    documents: collections.abc.Iterable[str | Document] | None = None,
+    model_options: dict | None = None,
+) -> str:
+    """Determine if a response is factually incorrect.
 
-    Intrinsic function that evaluates the factuality of the
-    assistant's response to a user's question. The context should end with
-    a user question followed by an assistant answer.
+    Intrinsic function that evaluates the factuality of an assistant's response
+    to a user's question. The context should typically end with a user question
+    followed by an assistant answer.
 
-    :param context: Chat context containing user question and assistant answer.
-    :param backend: Backend instance that supports LoRA/aLoRA adapters.
+    Args:
+        context: Chat context containing user question and assistant answer.
+        backend: Backend instance that supports LoRA/aLoRA adapters.
+        documents: Document snippets that provide factual context for evaluation.
+            Each element may be a ``Document`` or a plain string (automatically
+            wrapped in ``Document``). When provided, replaces any documents
+            on the last assistant message. Keyword-only.
+        model_options: Optional model options to pass to the backend (e.g.,
+            temperature, max_tokens). Defaults to ``{ModelOption.TEMPERATURE: 0.0}``.
 
-    :return: Factuality score as a "yes/no" label (yes = factually incorrect).
+    Returns:
+        Factuality score as a "yes/no" label (yes = factually incorrect).
     """
     detector_message = """
 As a judge agent, your role is to help assess whether the provided text meets the given judging criteria, utilizing all available information, including conversations, documents, and tools.
@@ -192,21 +251,60 @@ def factuality_detection(context: ChatContext, backend: AdapterMixin) -> float:
 
     ### Scoring Schema: If the last assistant's text meets the criteria, return 'yes'; otherwise, return 'no'.
     """
+    if documents is not None:
+        response_text, context, _ = _resolve_response(None, context)
+        context = context.add(
+            Message(
+                "assistant",
+                response_text,
+                documents=_coerce_to_documents(documents),
+            )
+        )
+
     context = context.add(Message("user", detector_message))
-    result_json = call_intrinsic("factuality-detection", context, backend)
+    result_json = call_intrinsic(
+        "factuality-detection", context, backend, model_options=model_options
+    )
     return result_json["score"]
 
 
-def factuality_correction(context: ChatContext, backend: AdapterMixin) -> float:
-    """Corrects the last response so that it is factually correct.
+def factuality_correction(
+    context: ChatContext,
+    backend: AdapterMixin,
+    *,
+    documents: collections.abc.Iterable[str | Document] | None = None,
+    model_options: dict | None = None,
+) -> str:
+    """Correct a response to be factually accurate.
 
     Intrinsic function that corrects the assistant's response to a user's
     question relative to the given contextual information.
 
-    :param context: Chat context containing user question and assistant answer.
-    :param backend: Backend instance that supports LoRA/aLoRA adapters.
+    Args:
+        context: Chat context containing user question and assistant answer.
+        backend: Backend instance that supports LoRA/aLoRA adapters.
+        documents: Document snippets that provide factual context for correction.
+            Each element may be a ``Document`` or a plain string (automatically
+            wrapped in ``Document``). When provided, replaces any documents
+            on the last assistant message. Keyword-only.
+        model_options: Optional model options to pass to the backend (e.g.,
+            temperature, max_tokens). Defaults to ``{ModelOption.TEMPERATURE: 0.0}``.
 
-    :return: Correct assistant response.
+    Returns:
+        Corrected assistant response.
     """
     corrector_message = """
 As a judge agent, your role is to help assess whether the provided text meets the given judging criteria, utilizing all available information, including conversations, documents, and tools.
@@ -216,6 +314,31 @@ def factuality_correction(context: ChatContext, backend: AdapterMixin) -> float:
 
     ### Scoring Schema: If the last assistant's text meets the criteria, return a corrected version of the assistant's message based on the given context; otherwise, return 'none'.
""" + if documents is not None: + turn = context.last_turn() + if turn is not None: + if turn.output is not None and turn.output.value is not None: + response_text = turn.output.value + elif ( + turn.model_input is not None + and isinstance(turn.model_input, Message) + and turn.model_input.role == "assistant" + ): + response_text = turn.model_input.content + else: + response_text = None + + if response_text is not None: + context = context.add( + Message( + "assistant", + response_text, + documents=_coerce_to_documents(documents), + ) + ) + context = context.add(Message("user", corrector_message)) - result_json = call_intrinsic("factuality-correction", context, backend) + result_json = call_intrinsic( + "factuality-correction", context, backend, model_options=model_options + ) return result_json["correction"] diff --git a/mellea/stdlib/components/intrinsic/rag.py b/mellea/stdlib/components/intrinsic/rag.py index 4dd20a1c8..bfa4c42e9 100644 --- a/mellea/stdlib/components/intrinsic/rag.py +++ b/mellea/stdlib/components/intrinsic/rag.py @@ -15,6 +15,7 @@ def check_answerability( documents: collections.abc.Iterable[str | Document], context: ChatContext, backend: AdapterMixin, + model_options: dict | None = None, ) -> str: """Test a user's question for answerability. @@ -31,23 +32,34 @@ def check_answerability( context: Chat context containing the conversation thus far. backend: Backend instance that supports adding the LoRA or aLoRA adapters for answerability checks. + model_options: Optional model options to pass to the backend (e.g., + temperature, max_tokens). Defaults to ``{ModelOption.TEMPERATURE: 0.0}``. Returns: A string value of either "answerable" or "unanswerable" """ - question, context = _resolve_question(question, context, backend) + question, context, resolved_docs = _resolve_question(question, context, backend) + explicit_docs = _coerce_to_documents(documents) + docs_to_use = explicit_docs + if resolved_docs is not None: + if docs_to_use is not None: + docs_to_use.extend(resolved_docs) + else: + docs_to_use = resolved_docs result_json = call_intrinsic( "answerability", - context.add( - Message("user", question, documents=_coerce_to_documents(documents)) - ), + context.add(Message("user", question, documents=docs_to_use)), backend, + model_options=model_options, ) return result_json["answerability"] def rewrite_question( - question: str | None, context: ChatContext, backend: AdapterMixin + question: str | None, + context: ChatContext, + backend: AdapterMixin, + model_options: dict | None = None, ) -> str: """Rewrite a user's question for retrieval. @@ -60,13 +72,18 @@ def rewrite_question( user message in ``context``. context: Chat context containing the conversation thus far. backend: Backend instance that supports adding the LoRA or aLoRA adapters. + model_options: Optional model options to pass to the backend (e.g., + temperature, max_tokens). Defaults to ``{ModelOption.TEMPERATURE: 0.0}``. Returns: Rewritten version of ``question``. 
""" - question, context = _resolve_question(question, context, backend) + question, context, resolved_docs = _resolve_question(question, context, backend) result_json = call_intrinsic( - "query_rewrite", context.add(Message("user", question)), backend + "query_rewrite", + context.add(Message("user", question, documents=resolved_docs)), + backend, + model_options=model_options, ) return result_json["rewritten_question"] @@ -76,6 +93,7 @@ def clarify_query( documents: collections.abc.Iterable[str | Document], context: ChatContext, backend: AdapterMixin, + model_options: dict | None = None, ) -> str: """Generate clarification for an ambiguous query. @@ -92,18 +110,26 @@ def clarify_query( context: Chat context containing the conversation thus far. backend: Backend instance that supports the adapters that implement this intrinsic. + model_options: Optional model options to pass to the backend (e.g., + temperature, max_tokens). Defaults to ``{ModelOption.TEMPERATURE: 0.0}``. Returns: Clarification question string (e.g., "Do you mean A or B?"), or the string "CLEAR" if no clarification is needed. """ - question, context = _resolve_question(question, context, backend) + question, context, resolved_docs = _resolve_question(question, context, backend) + explicit_docs = _coerce_to_documents(documents) + docs_to_use = explicit_docs + if resolved_docs is not None: + if docs_to_use is not None: + docs_to_use.extend(resolved_docs) + else: + docs_to_use = resolved_docs result_json = call_intrinsic( "query_clarification", - context.add( - Message("user", question, documents=_coerce_to_documents(documents)) - ), + context.add(Message("user", question, documents=docs_to_use)), backend, + model_options=model_options, ) return result_json["clarification"] @@ -113,6 +139,7 @@ def find_citations( documents: collections.abc.Iterable[str | Document], context: ChatContext, backend: AdapterMixin, + model_options: dict | None = None, ) -> list[dict]: """Find information in documents that supports an assistant response. @@ -132,6 +159,8 @@ def find_citations( the user has just asked a question that will be answered with RAG documents. backend: Backend that supports one of the adapters that implements this intrinsic. + model_options: Optional model options to pass to the backend (e.g., + temperature, max_tokens). Defaults to ``{ModelOption.TEMPERATURE: 0.0}``. Returns: List of records with the following fields: ``response_begin``, @@ -139,17 +168,19 @@ def find_citations( ``citation_end``, ``citation_text``. Begin and end offsets are character offsets into their respective UTF-8 strings. """ - response, context = _resolve_response(response, context) + response, context, resolved_docs = _resolve_response(response, context) + explicit_docs = _coerce_to_documents(documents, auto_doc_id=False) + docs_to_use = explicit_docs + if resolved_docs is not None: + if docs_to_use is not None: + docs_to_use.extend(resolved_docs) + else: + docs_to_use = resolved_docs result_json = call_intrinsic( "citations", - context.add( - Message( - "assistant", - response, - documents=_coerce_to_documents(documents, auto_doc_id=False), - ) - ), + context.add(Message("assistant", response, documents=docs_to_use)), backend, + model_options=model_options, ) return result_json @@ -159,6 +190,7 @@ def check_context_relevance( document: str | Document, context: ChatContext, backend: AdapterMixin, + model_options: dict | None = None, ) -> str: """Test whether a document is relevant to a user's question. 
@@ -159,6 +190,7 @@ def check_context_relevance(
     document: str | Document,
     context: ChatContext,
     backend: AdapterMixin,
+    model_options: dict | None = None,
 ) -> str:
     """Test whether a document is relevant to a user's question.
 
@@ -174,6 +206,8 @@
         context: The chat up to the point where the user asked a question.
         backend: Backend instance that supports the adapters that implement this
             intrinsic.
+        model_options: Optional model options to pass to the backend (e.g.,
+            temperature, max_tokens). Defaults to ``{ModelOption.TEMPERATURE: 0.0}``.
 
     Returns:
         Context relevance judgement as one of the following strings:
         - "relevant"
         - "irrelevant"
         - "partially relevant"
     """
-    question, context = _resolve_question(question, context, backend)
+    question, context, resolved_docs = _resolve_question(question, context, backend)
     document = _coerce_to_document(document)
     result_json = call_intrinsic(
         "context_relevance",
-        context.add(Message("user", question)),
+        context.add(Message("user", question, documents=resolved_docs)),
         backend,
         # Target document is passed as an argument
         kwargs={"document_content": document.text},
+        model_options=model_options,
     )
     return result_json["context_relevance"]
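> Review note: usage of the per-document relevance check is unchanged apart from the optional `model_options`; a sketch (`backend` is a placeholder for any configured `AdapterMixin` backend, import paths assumed as above):

```python
from mellea.stdlib.components import Document  # assumed import paths, as above
from mellea.stdlib.components.chat import Message
from mellea.stdlib.components.intrinsic.rag import check_context_relevance
from mellea.stdlib.context import ChatContext

doc = Document(text="Paris has been the capital of France since 987.", doc_id="1")
ctx = ChatContext().add(Message("user", "What is the capital of France?"))

# backend = ...  # placeholder: any AdapterMixin backend

# question=None: the question (and any documents attached to it) is resolved
# from the last user message; the target document travels via kwargs.
label = check_context_relevance(None, doc, ctx, backend)
assert label in ("relevant", "irrelevant", "partially relevant")
```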
""" - response, context = _resolve_response(response, context) + response, context, resolved_docs = _resolve_response(response, context) + explicit_docs = _coerce_to_documents(documents) + docs_to_use = explicit_docs + if resolved_docs is not None: + if docs_to_use is not None: + docs_to_use.extend(resolved_docs) + else: + docs_to_use = resolved_docs result_json = call_intrinsic( "hallucination_detection", - context.add( - Message("assistant", response, documents=_coerce_to_documents(documents)) - ), + context.add(Message("assistant", response, documents=docs_to_use)), backend, + model_options=model_options, ) return result_json diff --git a/test/backends/test_openai_intrinsics.py b/test/backends/test_openai_intrinsics.py index 6181ab03f..269e361d7 100644 --- a/test/backends/test_openai_intrinsics.py +++ b/test/backends/test_openai_intrinsics.py @@ -556,14 +556,12 @@ def test_call_intrinsic_factuality_detection(call_intrinsic_backend): for d in data.get("extra_body", {}).get("documents", []) ] messages = data["messages"] - for i, m in enumerate(messages): - is_last = i == len(messages) - 1 - if is_last and docs: - context = context.add(Message(m["role"], m["content"], documents=docs)) - else: - context = context.add(Message(m["role"], m["content"])) + for m in messages: + context = context.add(Message(m["role"], m["content"])) - result = guardian.factuality_detection(context, call_intrinsic_backend) + result = guardian.factuality_detection( + context, call_intrinsic_backend, documents=docs + ) assert result in ("yes", "no") @@ -581,12 +579,10 @@ def test_call_intrinsic_factuality_correction(call_intrinsic_backend): for d in data.get("extra_body", {}).get("documents", []) ] messages = data["messages"] - for i, m in enumerate(messages): - is_last = i == len(messages) - 1 - if is_last and docs: - context = context.add(Message(m["role"], m["content"], documents=docs)) - else: - context = context.add(Message(m["role"], m["content"])) + for m in messages: + context = context.add(Message(m["role"], m["content"])) - result = guardian.factuality_correction(context, call_intrinsic_backend) + result = guardian.factuality_correction( + context, call_intrinsic_backend, documents=docs + ) assert isinstance(result, str) diff --git a/test/stdlib/components/intrinsic/test_guardian.py b/test/stdlib/components/intrinsic/test_guardian.py index 0fd16f55d..cb9eeb7e1 100644 --- a/test/stdlib/components/intrinsic/test_guardian.py +++ b/test/stdlib/components/intrinsic/test_guardian.py @@ -43,21 +43,30 @@ def _backend(): cleanup_gpu_backend(backend_, "test_guardian") -def _read_guardian_input(file_name: str) -> ChatContext: - """Read test input and convert to a ChatContext.""" +def _read_guardian_input(file_name: str) -> tuple[ChatContext, list[Document]]: + """Read test input and convert to a ChatContext and documents list.""" with open(DATA_ROOT / "input_json" / file_name, encoding="utf-8") as f: json_data = json.load(f) context = ChatContext() for m in json_data["messages"]: context = context.add(Message(m["role"], m["content"])) - return context + + # Extract documents if present + documents = [] + if "extra_body" in json_data and "documents" in json_data["extra_body"]: + for doc_data in json_data["extra_body"]["documents"]: + documents.append( + Document(text=doc_data["text"], doc_id=doc_data.get("doc_id")) + ) + + return context, documents @pytest.mark.qualitative def test_policy_guardrails(backend): """Verify that policy_guardrails checks scenaio compliance with policy.""" - context = 
+    context, _ = _read_guardian_input("policy_guardrails.json")
 
     policy_text = 'hiring managers should steer away from any questions that directly seek information about protected classes—such as "how old are you," "where are you from," "what year did you graduate" or "what are your plans for having kids."'
 
@@ -70,7 +79,7 @@
 @pytest.mark.qualitative
 def test_guardian_check_harm(backend):
     """Verify that guardian_check detects harmful prompts."""
-    context = _read_guardian_input("guardian_core.json")
+    context, _ = _read_guardian_input("guardian_core.json")
 
     # First call triggers adapter loading
     result = guardian.guardian_check(
@@ -164,8 +173,32 @@ def test_guardian_check_function_call(backend):
 @pytest.mark.qualitative
 def test_factuality_detection(backend):
     """Verify that the factuality detection intrinsic functions properly."""
-    context = _read_guardian_input("factuality_detection.json")
+    context, documents = _read_guardian_input("factuality_detection.json")
+
+    # Test with documents passed as argument
+    result = guardian.factuality_detection(context, backend, documents=documents)
+    assert result == "yes" or result == "no"
+
+
+@pytest.mark.qualitative
+def test_factuality_detection_from_context(backend):
+    """Verify factuality detection works when documents are already in the last message."""
+    context, documents = _read_guardian_input("factuality_detection.json")
+
+    # Rebuild context with documents attached to the last assistant message
+    last_turn = context.last_turn()
+    if last_turn and last_turn.model_input:
+        context = ChatContext().add(Message("user", "What is the question?"))
+        if isinstance(last_turn.model_input, Message):
+            context = context.add(
+                Message(
+                    last_turn.model_input.role,
+                    last_turn.model_input.content,
+                    documents=documents,
+                )
+            )
+
+    # Call without documents= argument; documents should be picked up from context
     result = guardian.factuality_detection(context, backend)
     assert result == "yes" or result == "no"
 
@@ -173,8 +206,32 @@
 @pytest.mark.qualitative
 def test_factuality_correction(backend):
     """Verify that the factuality correction intrinsic functions properly."""
-    context = _read_guardian_input("factuality_correction.json")
+    context, documents = _read_guardian_input("factuality_correction.json")
+
+    # Test with documents passed as argument
+    result = guardian.factuality_correction(context, backend, documents=documents)
+    assert isinstance(result, str)
+
+
+@pytest.mark.qualitative
+def test_factuality_correction_from_context(backend):
+    """Verify factuality correction works when documents are already in the last message."""
+    context, documents = _read_guardian_input("factuality_correction.json")
+
+    # Rebuild context with documents attached to the last assistant message
+    last_turn = context.last_turn()
+    if last_turn and last_turn.model_input:
+        context = ChatContext().add(Message("user", "What is the question?"))
+        if isinstance(last_turn.model_input, Message):
+            context = context.add(
+                Message(
+                    last_turn.model_input.role,
+                    last_turn.model_input.content,
+                    documents=documents,
+                )
+            )
+
+    # Call without documents= argument; documents should be picked up from context
     result = guardian.factuality_correction(context, backend)
     assert isinstance(result, str)
diff --git a/test/stdlib/components/intrinsic/test_resolve_util.py b/test/stdlib/components/intrinsic/test_resolve_util.py
index 90c1822ba..d3ca9b756 100644
--- a/test/stdlib/components/intrinsic/test_resolve_util.py
+++ b/test/stdlib/components/intrinsic/test_resolve_util.py
@@ -14,15 +14,17 @@ class TestResolveQuestion:
     def test_explicit_string(self):
         ctx = ChatContext()
-        text, returned_ctx = _resolve_question("hello", ctx)
+        text, returned_ctx, docs = _resolve_question("hello", ctx)
         assert text == "hello"
         assert returned_ctx is ctx
+        assert docs is None
 
     def test_from_context(self):
         ctx = ChatContext().add(Message("user", "What is 2+2?"))
-        text, rewound = _resolve_question(None, ctx)
+        text, rewound, docs = _resolve_question(None, ctx)
         assert text == "What is 2+2?"
         assert rewound.is_root_node  # type: ignore[union-attr]
+        assert docs is None
 
     def test_context_with_prior_messages(self):
         ctx = (
@@ -31,8 +33,9 @@ def test_context_with_prior_messages(self):
             .add(Message("assistant", "reply"))
             .add(Message("user", "second"))
         )
-        text, rewound = _resolve_question(None, ctx)
+        text, rewound, docs = _resolve_question(None, ctx)
         assert text == "second"
+        assert docs is None
         # Rewound context should end with the assistant reply
         last = rewound.last_turn()  # type: ignore[union-attr]
         assert last is not None
@@ -46,9 +49,10 @@ def test_empty_context_raises(self):
     def test_from_cblock(self):
         ctx = ChatContext().add(CBlock("raw question"))
-        text, rewound = _resolve_question(None, ctx)
+        text, rewound, docs = _resolve_question(None, ctx)
         assert text == "raw question"
         assert rewound.is_root_node  # type: ignore[union-attr]
+        assert docs is None
 
     def test_cblock_none_value_raises(self):
         ctx = ChatContext().add(CBlock(None))
@@ -57,15 +61,17 @@ def test_cblock_none_value_raises(self):
     def test_from_component(self):
         ctx = ChatContext().add(Document("some document text"))
-        text, rewound = _resolve_question(None, ctx)
+        text, rewound, docs = _resolve_question(None, ctx)
         assert "some document text" in text
         assert rewound.is_root_node  # type: ignore[union-attr]
+        assert docs is None
 
     def test_from_instruction_component(self):
         ctx = ChatContext().add(Instruction("Summarize the article"))
-        text, rewound = _resolve_question(None, ctx)
+        text, rewound, docs = _resolve_question(None, ctx)
         assert "Summarize the article" in text
         assert rewound.is_root_node  # type: ignore[union-attr]
+        assert docs is None
 
     def test_from_component_uses_backend_formatter(self):
         from unittest.mock import MagicMock
@@ -74,18 +80,30 @@ def test_from_component_uses_backend_formatter(self):
         mock_backend = MagicMock()
         mock_backend.formatter.print.return_value = "custom formatted"
 
-        text, rewound = _resolve_question(None, ctx, backend=mock_backend)
+        text, rewound, docs = _resolve_question(None, ctx, backend=mock_backend)
         assert text == "custom formatted"
         mock_backend.formatter.print.assert_called_once()
         assert rewound.is_root_node  # type: ignore[union-attr]
+        assert docs is None
+
+    def test_from_message_with_documents(self):
+        doc = Document(text="Supporting text", doc_id="1")
+        ctx = ChatContext().add(Message("user", "What is 2+2?", documents=[doc]))
+        text, rewound, docs = _resolve_question(None, ctx)
+        assert text == "What is 2+2?"
+        assert docs is not None
+        assert len(docs) == 1
+        assert docs[0].text == "Supporting text"
+        assert rewound.is_root_node  # type: ignore[union-attr]
 
 
 class TestResolveResponse:
     def test_explicit_string(self):
         ctx = ChatContext()
-        text, returned_ctx = _resolve_response("answer", ctx)
+        text, returned_ctx, docs = _resolve_response("answer", ctx)
         assert text == "answer"
         assert returned_ctx is ctx
+        assert docs is None
 
     def test_from_context(self):
         ctx = (
@@ -93,8 +111,9 @@ def test_from_context(self):
             .add(Message("user", "question"))
             .add(ModelOutputThunk(value="The answer is 4."))
         )
-        text, rewound = _resolve_response(None, ctx)
+        text, rewound, docs = _resolve_response(None, ctx)
         assert text == "The answer is 4."
+        assert docs is None
         # Rewound context should still have the user question
         last = rewound.last_turn()  # type: ignore[union-attr]
         assert last is not None
@@ -110,3 +129,21 @@ def test_none_value_raises(self):
         ctx = ChatContext().add(ModelOutputThunk(value=None))
         with pytest.raises(ValueError, match="no value"):
             _resolve_response(None, ctx)
+
+    def test_from_message_with_documents(self):
+        doc = Document(text="Supporting text", doc_id="1")
+        ctx = (
+            ChatContext()
+            .add(Message("user", "question"))
+            .add(Message("assistant", "The answer is 4.", documents=[doc]))
+        )
+        text, rewound, docs = _resolve_response(None, ctx)
+        assert text == "The answer is 4."
+        assert docs is not None
+        assert len(docs) == 1
+        assert docs[0].text == "Supporting text"
+        # Rewound context should still have the user question
+        last = rewound.last_turn()  # type: ignore[union-attr]
+        assert last is not None
+        assert isinstance(last.model_input, Message)
+        assert last.model_input.content == "question"
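> Review note: end to end, the new keyword-only surface reads as follows. A sketch assembled from the docs examples above, not a prescribed setup; the backend line is a placeholder, import paths are assumed as elsewhere in these notes, and only backends with the guardian adapters support these calls:

```python
from mellea.stdlib.components import Document  # assumed import paths, as above
from mellea.stdlib.components.chat import Message
from mellea.stdlib.components.intrinsic import guardian
from mellea.stdlib.context import ChatContext

document = Document(text="The first Moon landing took place in 1969.", doc_id="1")

ctx = ChatContext()
ctx = ctx.add(Message("user", "When was the first Moon landing?"))
ctx = ctx.add(Message("assistant", "The first Moon landing took place in 1972."))

# backend = LocalHFBackend(...)  # placeholder: any AdapterMixin backend

verdict = guardian.factuality_detection(ctx, backend, documents=[document])
if verdict == "yes":  # "yes" means the last response is factually incorrect
    fixed = guardian.factuality_correction(ctx, backend, documents=[document])
    print(fixed)
```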