def _hybrid_doc_body(doc: Any) -> str:
    """Return the truncated prompt body for one fetched document.

    Prefers the ``body`` field of ``doc.source`` when the source is a dict
    with a truthy ``body``; otherwise falls back to a JSON dump of the whole
    source. A missing document (``None``) or a missing ``source`` attribute
    yields the empty string. The result is cut to ``_DOC_BODY_CHAR_LIMIT``
    characters.
    """
    source = getattr(doc, "source", None) if doc is not None else None
    if source is None:
        return ""
    if isinstance(source, dict) and source.get("body"):
        body_raw = str(source["body"])
    else:
        # Same guard as ``_build_doc_inputs``: a source holding a value that
        # json cannot encode (TypeError) or a circular structure (ValueError)
        # must not abort the whole hybrid fill — this code runs outside the
        # per-query try/except below, so an escape here would kill every
        # remaining query in the batch.
        try:
            body_raw = json.dumps(source, ensure_ascii=False)
        except (TypeError, ValueError):
            body_raw = str(source)
    return body_raw[:_DOC_BODY_CHAR_LIMIT]


def make_hybrid_llm_rate_callback(
    *,
    openai_client: AsyncOpenAI,
    model: str,
    rubric: str,
    bundle_system: str,
    target: str,
    adapter: Any,
    redis: Redis,
    budget_usd: float,
    llm_fill_calls_counter: list[int],
) -> LlmRateCallback:
    """Construct the :class:`LlmRateCallback` bound for UBI hybrid mode.

    The callback receives ``[(query_id, doc_id, query_text), ...]`` for
    pairs the inner converter couldn't rate (``impression_count <
    llm_fill_threshold``). It groups by ``query_text``, fetches doc
    bodies via :func:`adapter.get_document`, calls
    :func:`rate_query_batch`, records cost through the daily budget
    gate, and returns ``{(query_id, doc_id): rating}``.

    Lives in the service layer (not the worker) alongside the rest of the
    judgment composition; the UBI worker constructs the OpenAI client and
    passes it in. ``llm_fill_calls_counter`` is a mutable list (length-1 int
    counter) so the worker can read the per-list LLM-call total for the
    calibration JSONB without threading another mutable through; it is
    incremented once per successful ``rate_query_batch`` call.

    Failure semantics of the returned callback:

    * ``BudgetExceededError`` is raised before a query's LLM call when
      ``budget_usd > 0`` and the current daily total plus the estimated
      maximum call cost would exceed it.
    * ``openai.AuthenticationError``, ``openai.PermissionDeniedError``,
      ``openai.BadRequestError``, ``openai.NotFoundError`` and
      ``UnknownModelPricingError`` from ``rate_query_batch`` propagate.
    * Any other exception from ``rate_query_batch`` is logged and that
      query's pairs are skipped (absent from the returned mapping).
    * Exceptions from ``adapter.get_document`` are not caught here and
      propagate to the caller.
    """

    async def _callback(
        pairs: list[tuple[str, str, str]],
    ) -> dict[tuple[str, str], int]:
        out: dict[tuple[str, str], int] = {}
        if not pairs:
            return out

        # Group by query_text → run one rate_query_batch per query.
        by_query: dict[str, list[tuple[str, str]]] = {}
        query_id_for_text: dict[str, str] = {}
        for qid, did, qtext in pairs:
            by_query.setdefault(qtext, []).append((qid, did))
            query_id_for_text.setdefault(qtext, qid)

        for query_text, qd_pairs in by_query.items():
            # Representative id, used only for the failure log below.
            qid = query_id_for_text[query_text]

            # Pre-call budget peek (spec FR-2 + FR-5).
            if budget_usd > 0:
                current = await peek_daily_total(redis)
                est_max = estimated_max_call_cost(model)
                if current + est_max > budget_usd:
                    raise BudgetExceededError(
                        f"current ${current:.4f} + estimated ${est_max:.4f} "
                        f"> budget ${budget_usd:.4f}"
                    )

            # Fetch doc bodies for this query's pairs. Map each prompt ordinal
            # back to the FULL (query_id, doc_id) tuple — not just doc_id.
            # Multiple distinct internal query_ids can share the same
            # query_text (duplicate rows in the operator's query set), and
            # they're grouped together here; mapping prompt_id → doc_id alone
            # would attribute every rating to the single representative `qid`
            # and silently drop ratings for the others (Gemini PR #317
            # findings #2 + #3).
            # The per-pair doc fetches are independent + idempotent, so issue
            # them concurrently rather than serially round-tripping the engine.
            docs = await asyncio.gather(
                *(adapter.get_document(target, doc_id) for _, doc_id in qd_pairs)
            )
            doc_inputs: list[dict[str, str]] = []
            prompt_id_to_real: dict[str, tuple[str, str]] = {}
            for i, (pair, doc) in enumerate(zip(qd_pairs, docs, strict=True)):
                prompt_id = f"item-{i}"
                doc_inputs.append({"doc_id": prompt_id, "body": _hybrid_doc_body(doc)})
                prompt_id_to_real[prompt_id] = pair

            expected = set(prompt_id_to_real)
            user_prompt = render_user_prompt(
                rubric_text=rubric,
                query_text=query_text,
                docs=doc_inputs,
            )
            system_prompt = f"{bundle_system}\n\n\n{rubric}\n"

            try:
                result = await rate_query_batch(
                    client=openai_client,
                    model=model,
                    system_prompt=system_prompt,
                    user_prompt=user_prompt,
                    expected_doc_ids=expected,
                )
            except (
                openai.AuthenticationError,
                openai.PermissionDeniedError,
                openai.BadRequestError,
                openai.NotFoundError,
                UnknownModelPricingError,
            ):
                # Persistent provider misconfig OR unknown model pricing —
                # propagate so the worker fails the whole list (via
                # fail_on_budget_or_pricing_error) rather than silently
                # skipping the query. With the budget disabled (budget_usd
                # <= 0) the pre-call estimated_max_call_cost guard above is
                # skipped, so this is the path that surfaces an unknown-pricing
                # model from rate_query_batch's cost computation.
                raise
            except Exception as exc:
                # Per-query operational failure — skip this query's pairs
                # (matches the LLM worker's isolation pattern).
                logger.warning(
                    "ubi worker: hybrid LLM call failed for query, skipping",
                    event_type="ubi_hybrid_llm_failed",
                    query_id=qid,
                    error_type=type(exc).__name__,
                    error=str(exc),
                )
                continue

            await safe_record_cost(
                redis,
                result.cost_usd,
                logger=logger,
                log_message="ubi worker: record_cost failed (budget telemetry only)",
                event_type="ubi_record_cost_failed",
            )

            llm_fill_calls_counter[0] += 1

            # Map prompt-only ordinals back to the real (query_id, doc_id) so
            # ratings stay attributed to the requesting query.
            for r in result.ratings:
                real_pair = prompt_id_to_real.get(r.doc_id)
                if real_pair is None:
                    continue  # hallucinated id; skip
                out[real_pair] = r.rating

        return out

    return _callback
@@ -182,8 +188,9 @@ def test_worker_no_direct_openai_construction_outside_callback_factory() -> None source = Path("backend/workers/judgments_ubi.py").read_text() tree = ast.parse(source) # Walk top-level definitions; the only AsyncOpenAI(...) call we permit is - # inside _build_converter (or _make_llm_rate_callback as a nested - # construction). Other code paths must NOT instantiate the client. + # inside _build_converter (which constructs the client and passes it to + # the service-layer hybrid callback factory). Other code paths must NOT + # instantiate the client. forbidden_paths: list[str] = [] class _Visitor(ast.NodeVisitor): @@ -205,7 +212,6 @@ def visit_Call(self, node: ast.Call) -> None: if isinstance(func, ast.Name) and func.id == "AsyncOpenAI": if not self._current_fn or self._current_fn[-1] not in { "_build_converter", - "_make_llm_rate_callback", }: where = ".".join(self._current_fn) or "" forbidden_paths.append(where) diff --git a/backend/workers/judgments_ubi.py b/backend/workers/judgments_ubi.py index fa366d4d..25b10047 100644 --- a/backend/workers/judgments_ubi.py +++ b/backend/workers/judgments_ubi.py @@ -52,8 +52,9 @@ llm_fill_threshold``; below the threshold, defers the pair to LLM-fill via an injected callback." The callback signature is ``Callable[[list[tuple[str, str, str]]], Awaitable[dict[tuple[str, str], int]]]`` -taking ``(query_id, doc_id, query_text)`` tuples. The worker-local -:func:`_make_llm_rate_callback` constructs the callback bound to: +taking ``(query_id, doc_id, query_text)`` tuples. The service-layer +:func:`backend.app.services.judgment_generation.make_hybrid_llm_rate_callback` +constructs the callback bound to: * :func:`adapter.get_document` for a doc-body fetch BY ID. 
The FR-2 callback contract is explicitly per-``(query_id, doc_id)`` pair (it @@ -75,11 +76,10 @@ from __future__ import annotations import time -from collections.abc import Awaitable, Callable, Sequence +from collections.abc import Callable, Sequence from datetime import datetime from typing import Any -import openai import structlog import uuid_utils from openai import AsyncOpenAI @@ -97,18 +97,14 @@ SignalsConverter, ) from backend.app.domain.ubi.converter import ConverterConfig -from backend.app.llm.budget_gate import ( - BudgetExceededError, - peek_daily_total, - safe_record_cost, -) -from backend.app.llm.cost_model import UnknownModelPricingError, estimated_max_call_cost -from backend.app.llm.openai_judge import rate_query_batch -from backend.app.llm.prompt_loader import load_judgment_prompts, render_user_prompt +from backend.app.llm.budget_gate import BudgetExceededError +from backend.app.llm.cost_model import UnknownModelPricingError +from backend.app.llm.prompt_loader import load_judgment_prompts from backend.app.services.cluster import build_adapter from backend.app.services.judgment_generation import ( fail_judgment_list, fail_on_budget_or_pricing_error, + make_hybrid_llm_rate_callback, ) from backend.app.services.ubi_errors import UbiNotEnabledError from backend.app.services.ubi_reader import UbiReader @@ -116,9 +112,6 @@ logger = structlog.get_logger(__name__) -_DOC_BODY_CHAR_LIMIT = 500 -"""Per-doc body truncation length (mirrors generate_judgments_llm).""" - # ---------------------------------------------------------------------------- # mapping_strategy join helper @@ -178,151 +171,6 @@ def _apply_mapping_strategy( return mapping, ambiguous_count -# ---------------------------------------------------------------------------- -# Hybrid LLM-fill callback factory -# ---------------------------------------------------------------------------- - - -def _make_llm_rate_callback( - *, - openai_client: AsyncOpenAI, - model: str, - rubric: str, - 
bundle_system: str, - target: str, - adapter: Any, - redis: Redis, - budget_usd: float, - llm_fill_calls_counter: list[int], -) -> LlmRateCallback: - """Construct the :class:`LlmRateCallback` bound for hybrid mode. - - The callback receives ``[(query_id, doc_id, query_text), ...]`` for - pairs the inner converter couldn't rate (``impression_count < - llm_fill_threshold``). It groups by ``query_text``, fetches doc - bodies via :func:`adapter.get_document`, calls - :func:`rate_query_batch`, records cost through the daily budget - gate, and returns ``{(query_id, doc_id): rating}``. - - ``llm_fill_calls_counter`` is a mutable list (length-1 int counter) - so the worker can read the per-list LLM-call total for the - calibration JSONB without threading another mutable through. - """ - - async def _callback( - pairs: list[tuple[str, str, str]], - ) -> dict[tuple[str, str], int]: - out: dict[tuple[str, str], int] = {} - if not pairs: - return out - - # Group by query_text → run one rate_query_batch per query. - by_query: dict[str, list[tuple[str, str]]] = {} - query_id_for_text: dict[str, str] = {} - for qid, did, qtext in pairs: - by_query.setdefault(qtext, []).append((qid, did)) - query_id_for_text.setdefault(qtext, qid) - - for query_text, qd_pairs in by_query.items(): - qid = query_id_for_text[query_text] - - # Pre-call budget peek (spec FR-2 + FR-5). - if budget_usd > 0: - current = await peek_daily_total(redis) - est_max = estimated_max_call_cost(model) - if current + est_max > budget_usd: - raise BudgetExceededError( - f"current ${current:.4f} + estimated ${est_max:.4f} " - f"> budget ${budget_usd:.4f}" - ) - - # Fetch doc bodies for this query's pairs. Map each prompt ordinal - # back to the FULL (query_id, doc_id) tuple — not just doc_id. 
- # Multiple distinct internal query_ids can share the same - # query_text (duplicate rows in the operator's query set), and - # they're grouped together here; mapping prompt_id → doc_id alone - # would attribute every rating to the single representative `qid` - # and silently drop ratings for the others (Gemini PR #317 - # findings #2 + #3). - doc_inputs: list[dict[str, str]] = [] - prompt_id_to_real: dict[str, tuple[str, str]] = {} - for i, (pair_query_id, doc_id) in enumerate(qd_pairs): - doc = await adapter.get_document(target, doc_id) - source = getattr(doc, "source", None) if doc is not None else None - body_raw: str - if isinstance(source, dict) and source.get("body"): - body_raw = str(source["body"]) - elif source is not None: - import json as _json - - body_raw = _json.dumps(source, ensure_ascii=False) - else: - body_raw = "" - body = body_raw[:_DOC_BODY_CHAR_LIMIT] - prompt_id = f"item-{i}" - doc_inputs.append({"doc_id": prompt_id, "body": body}) - prompt_id_to_real[prompt_id] = (pair_query_id, doc_id) - - expected = set(prompt_id_to_real.keys()) - user_prompt = render_user_prompt( - rubric_text=rubric, - query_text=query_text, - docs=doc_inputs, - ) - system_prompt = f"{bundle_system}\n\n\n{rubric}\n" - - try: - result = await rate_query_batch( - client=openai_client, - model=model, - system_prompt=system_prompt, - user_prompt=user_prompt, - expected_doc_ids=expected, - ) - except ( - openai.AuthenticationError, - openai.PermissionDeniedError, - openai.BadRequestError, - openai.NotFoundError, - ): - # Persistent provider misconfig — propagate so the worker - # marks the list failed. - raise - except Exception as exc: - # Per-query operational failure — skip this query's pairs - # (matches the LLM worker's isolation pattern). 
- logger.warning( - "ubi worker: hybrid LLM call failed for query, skipping", - event_type="ubi_hybrid_llm_failed", - query_id=qid, - error_type=type(exc).__name__, - error=str(exc), - ) - continue - - await safe_record_cost( - redis, - result.cost_usd, - logger=logger, - log_message="ubi worker: record_cost failed (budget telemetry only)", - event_type="ubi_record_cost_failed", - ) - - llm_fill_calls_counter[0] += 1 - - # Map prompt-only ordinals back to the real (query_id, doc_id) so - # ratings stay attributed to the requesting query. - for r in result.ratings: - real_pair = prompt_id_to_real.get(r.doc_id) - if real_pair is None: - continue # hallucinated id; skip - out[real_pair] = r.rating - - return out - - return _callback - - # ---------------------------------------------------------------------------- # Main entry point # ---------------------------------------------------------------------------- @@ -664,7 +512,7 @@ def _build_converter( api_key=settings.openai_api_key, base_url=settings.openai_base_url ) bundle = load_judgment_prompts() - callback: LlmRateCallback = _make_llm_rate_callback( + callback: LlmRateCallback = make_hybrid_llm_rate_callback( openai_client=openai_client, model=settings.openai_model, rubric=params.get("rubric") or "Rate the document's relevance to the query (0-3).", @@ -718,11 +566,5 @@ async def _write_calibration_and_complete( __all__ = [ "generate_judgments_from_ubi", - "_make_llm_rate_callback", "_apply_mapping_strategy", ] - - -# Keep mypy happy about the Awaitable / Callable imports being used. -_LlmRateCallbackT: type = type(Callable[[list[tuple[str, str, str]]], Awaitable[Any]]) -del _LlmRateCallbackT