diff --git a/backend/app/services/judgment_generation.py b/backend/app/services/judgment_generation.py
index 121f1f8d..353828b3 100644
--- a/backend/app/services/judgment_generation.py
+++ b/backend/app/services/judgment_generation.py
@@ -14,6 +14,8 @@
from __future__ import annotations
+import asyncio
+import json
from collections.abc import Sequence
from typing import Any
@@ -27,6 +29,7 @@
from backend.app.adapters.protocol import NativeQuery, QueryTemplate
from backend.app.db import repo
from backend.app.domain.study.template_defaults import compute_default_params
+from backend.app.domain.ubi import LlmRateCallback
from backend.app.llm.budget_gate import BudgetExceededError, peek_daily_total, safe_record_cost
from backend.app.llm.cost_model import UnknownModelPricingError, estimated_max_call_cost
from backend.app.llm.openai_judge import rate_query_batch
@@ -67,10 +70,8 @@ def _build_doc_inputs(hits: Sequence[Any]) -> list[dict[str, str]]:
# dump anyway now that this helper is service-layer: a future caller
# could pass a hit whose source holds a non-serializable value, and
# a TypeError here would abort the whole judgment run.
- import json as _json
-
try:
- body_raw = _json.dumps(source, ensure_ascii=False)
+ body_raw = json.dumps(source, ensure_ascii=False)
except TypeError:
body_raw = str(source)
body = str(body_raw)
@@ -360,3 +361,153 @@ async def process_judgment_query(
duration_ms=result.duration_ms,
)
return True
+
+
+def make_hybrid_llm_rate_callback(
+    *,
+    openai_client: AsyncOpenAI,
+    model: str,
+    rubric: str,
+    bundle_system: str,
+    target: str,
+    adapter: Any,
+    redis: Redis,
+    budget_usd: float,
+    llm_fill_calls_counter: list[int],
+) -> LlmRateCallback:
+    """Construct the :class:`LlmRateCallback` bound for UBI hybrid mode.
+
+    The callback receives ``[(query_id, doc_id, query_text), ...]`` for
+    pairs the inner converter couldn't rate (``impression_count <
+    llm_fill_threshold``). It groups by ``query_text``, fetches doc
+    bodies via :func:`adapter.get_document`, calls
+    :func:`rate_query_batch`, records cost through the daily budget
+    gate, and returns ``{(query_id, doc_id): rating}``.
+
+    Lives in the service layer (not the worker) alongside the rest of the
+    judgment composition; the UBI worker constructs the OpenAI client and
+    passes it in. ``llm_fill_calls_counter`` is a mutable list (length-1 int
+    counter) so the worker can read the per-list LLM-call total for the
+    calibration JSONB without threading another mutable through.
+    """
+
+    async def _callback(
+        pairs: list[tuple[str, str, str]],
+    ) -> dict[tuple[str, str], int]:
+        out: dict[tuple[str, str], int] = {}
+        if not pairs:
+            return out
+
+        # Group by query_text → run one rate_query_batch per query.
+        by_query: dict[str, list[tuple[str, str]]] = {}
+        query_id_for_text: dict[str, str] = {}
+        for qid, did, qtext in pairs:
+            by_query.setdefault(qtext, []).append((qid, did))
+            query_id_for_text.setdefault(qtext, qid)
+
+        for query_text, qd_pairs in by_query.items():
+            qid = query_id_for_text[query_text]
+            # Pre-call budget peek (spec FR-2 + FR-5).
+            if budget_usd > 0:
+                current = await peek_daily_total(redis)
+                est_max = estimated_max_call_cost(model)
+                if current + est_max > budget_usd:
+                    raise BudgetExceededError(
+                        f"current ${current:.4f} + estimated ${est_max:.4f} "
+                        f"> budget ${budget_usd:.4f}"
+                    )
+
+            # Fetch doc bodies for this query's pairs. Map each prompt ordinal
+            # back to the FULL (query_id, doc_id) tuple — not just doc_id.
+            # Multiple distinct internal query_ids can share the same
+            # query_text (duplicate rows in the operator's query set), and
+            # they're grouped together here; mapping prompt_id → doc_id alone
+            # would attribute every rating to the single representative `qid`
+            # and silently drop ratings for the others (Gemini PR #317
+            # findings #2 + #3).
+            # The per-pair doc fetches are independent + idempotent, so issue
+            # them concurrently rather than serially round-tripping the engine.
+            docs = await asyncio.gather(
+                *(adapter.get_document(target, doc_id) for _, doc_id in qd_pairs)
+            )
+            doc_inputs: list[dict[str, str]] = []
+            prompt_id_to_real: dict[str, tuple[str, str]] = {}
+            for i, ((pair_query_id, doc_id), doc) in enumerate(zip(qd_pairs, docs, strict=True)):
+                source = getattr(doc, "source", None) if doc is not None else None
+                if isinstance(source, dict) and source.get("body"):
+                    body_raw = str(source["body"])
+                elif source is not None:
+                    # Mirror _build_doc_inputs: an unserializable source must not abort the run.
+                    try:
+                        body_raw = json.dumps(source, ensure_ascii=False)
+                    except (TypeError, ValueError):
+                        body_raw = str(source)
+                else:
+                    body_raw = ""
+                body = body_raw[:_DOC_BODY_CHAR_LIMIT]
+                prompt_id = f"item-{i}"
+                doc_inputs.append({"doc_id": prompt_id, "body": body})
+                prompt_id_to_real[prompt_id] = (pair_query_id, doc_id)
+
+            user_prompt = render_user_prompt(
+                rubric_text=rubric,
+                query_text=query_text,
+                docs=doc_inputs,
+            )
+            system_prompt = f"{bundle_system}\n\n\n{rubric}\n"
+
+            try:
+                result = await rate_query_batch(
+                    client=openai_client,
+                    model=model,
+                    system_prompt=system_prompt,
+                    user_prompt=user_prompt,
+                    expected_doc_ids=set(prompt_id_to_real),
+                )
+            except (
+                openai.AuthenticationError,
+                openai.PermissionDeniedError,
+                openai.BadRequestError,
+                openai.NotFoundError,
+                UnknownModelPricingError,
+            ):
+                # Persistent provider misconfig OR unknown model pricing —
+                # propagate so the worker fails the whole list (via
+                # fail_on_budget_or_pricing_error) rather than silently
+                # skipping the query. With the budget disabled (budget_usd
+                # <= 0) the pre-call estimated_max_call_cost guard above is
+                # skipped, so this is the path that surfaces an unknown-pricing
+                # model from rate_query_batch's cost computation.
+                raise
+            except Exception as exc:
+                # Per-query operational failure — skip this query's pairs
+                # (matches the LLM worker's isolation pattern).
+                logger.warning(
+                    "ubi worker: hybrid LLM call failed for query, skipping",
+                    event_type="ubi_hybrid_llm_failed",
+                    query_id=qid,
+                    error_type=type(exc).__name__,
+                    error=str(exc),
+                )
+                continue
+
+            await safe_record_cost(
+                redis,
+                result.cost_usd,
+                logger=logger,
+                log_message="ubi worker: record_cost failed (budget telemetry only)",
+                event_type="ubi_record_cost_failed",
+            )
+            llm_fill_calls_counter[0] += 1
+
+            # Map prompt-only ordinals back to the real (query_id, doc_id) so
+            # ratings stay attributed to the requesting query.
+            for r in result.ratings:
+                real_pair = prompt_id_to_real.get(r.doc_id)
+                if real_pair is None:
+                    continue  # hallucinated id; skip
+                out[real_pair] = r.rating
+
+        return out
+
+    return _callback
diff --git a/backend/tests/unit/workers/test_judgments_ubi_helpers.py b/backend/tests/unit/workers/test_judgments_ubi_helpers.py
index 2820c6ad..3be97f1b 100644
--- a/backend/tests/unit/workers/test_judgments_ubi_helpers.py
+++ b/backend/tests/unit/workers/test_judgments_ubi_helpers.py
@@ -151,10 +151,16 @@ def test_worker_exports() -> None:
from backend.workers import judgments_ubi
assert callable(judgments_ubi.generate_judgments_from_ubi)
- assert callable(judgments_ubi._make_llm_rate_callback)
assert callable(judgments_ubi._apply_mapping_strategy)
+def test_hybrid_rate_callback_factory_lives_in_service() -> None:
+    """The service layer now exposes the hybrid LLM-fill callback factory."""
+    from backend.app.services import judgment_generation
+
+    assert callable(judgment_generation.make_hybrid_llm_rate_callback)
+
+
def test_worker_registered_in_worker_settings() -> None:
"""The boot-time WorkerSettings.functions list MUST include the UBI job
so Arq dispatches it to the worker process (FR-5 step 3).
@@ -182,8 +188,9 @@ def test_worker_no_direct_openai_construction_outside_callback_factory() -> None
source = Path("backend/workers/judgments_ubi.py").read_text()
tree = ast.parse(source)
# Walk top-level definitions; the only AsyncOpenAI(...) call we permit is
- # inside _build_converter (or _make_llm_rate_callback as a nested
- # construction). Other code paths must NOT instantiate the client.
+ # inside _build_converter (which constructs the client and passes it to
+ # the service-layer hybrid callback factory). Other code paths must NOT
+ # instantiate the client.
forbidden_paths: list[str] = []
class _Visitor(ast.NodeVisitor):
@@ -205,7 +212,6 @@ def visit_Call(self, node: ast.Call) -> None:
if isinstance(func, ast.Name) and func.id == "AsyncOpenAI":
if not self._current_fn or self._current_fn[-1] not in {
"_build_converter",
- "_make_llm_rate_callback",
}:
where = ".".join(self._current_fn) or ""
forbidden_paths.append(where)
diff --git a/backend/workers/judgments_ubi.py b/backend/workers/judgments_ubi.py
index fa366d4d..25b10047 100644
--- a/backend/workers/judgments_ubi.py
+++ b/backend/workers/judgments_ubi.py
@@ -52,8 +52,9 @@
llm_fill_threshold``; below the threshold, defers the pair to LLM-fill
via an injected callback." The callback signature is
``Callable[[list[tuple[str, str, str]]], Awaitable[dict[tuple[str, str], int]]]``
-taking ``(query_id, doc_id, query_text)`` tuples. The worker-local
-:func:`_make_llm_rate_callback` constructs the callback bound to:
+taking ``(query_id, doc_id, query_text)`` tuples. The service-layer
+:func:`backend.app.services.judgment_generation.make_hybrid_llm_rate_callback`
+constructs the callback bound to:
* :func:`adapter.get_document` for a doc-body fetch BY ID. The FR-2
callback contract is explicitly per-``(query_id, doc_id)`` pair (it
@@ -75,11 +76,10 @@
from __future__ import annotations
import time
-from collections.abc import Awaitable, Callable, Sequence
+from collections.abc import Callable, Sequence
from datetime import datetime
from typing import Any
-import openai
import structlog
import uuid_utils
from openai import AsyncOpenAI
@@ -97,18 +97,14 @@
SignalsConverter,
)
from backend.app.domain.ubi.converter import ConverterConfig
-from backend.app.llm.budget_gate import (
- BudgetExceededError,
- peek_daily_total,
- safe_record_cost,
-)
-from backend.app.llm.cost_model import UnknownModelPricingError, estimated_max_call_cost
-from backend.app.llm.openai_judge import rate_query_batch
-from backend.app.llm.prompt_loader import load_judgment_prompts, render_user_prompt
+from backend.app.llm.budget_gate import BudgetExceededError
+from backend.app.llm.cost_model import UnknownModelPricingError
+from backend.app.llm.prompt_loader import load_judgment_prompts
from backend.app.services.cluster import build_adapter
from backend.app.services.judgment_generation import (
fail_judgment_list,
fail_on_budget_or_pricing_error,
+ make_hybrid_llm_rate_callback,
)
from backend.app.services.ubi_errors import UbiNotEnabledError
from backend.app.services.ubi_reader import UbiReader
@@ -116,9 +112,6 @@
logger = structlog.get_logger(__name__)
-_DOC_BODY_CHAR_LIMIT = 500
-"""Per-doc body truncation length (mirrors generate_judgments_llm)."""
-
# ----------------------------------------------------------------------------
# mapping_strategy join helper
@@ -178,151 +171,6 @@ def _apply_mapping_strategy(
return mapping, ambiguous_count
-# ----------------------------------------------------------------------------
-# Hybrid LLM-fill callback factory
-# ----------------------------------------------------------------------------
-
-
-def _make_llm_rate_callback(
- *,
- openai_client: AsyncOpenAI,
- model: str,
- rubric: str,
- bundle_system: str,
- target: str,
- adapter: Any,
- redis: Redis,
- budget_usd: float,
- llm_fill_calls_counter: list[int],
-) -> LlmRateCallback:
- """Construct the :class:`LlmRateCallback` bound for hybrid mode.
-
- The callback receives ``[(query_id, doc_id, query_text), ...]`` for
- pairs the inner converter couldn't rate (``impression_count <
- llm_fill_threshold``). It groups by ``query_text``, fetches doc
- bodies via :func:`adapter.get_document`, calls
- :func:`rate_query_batch`, records cost through the daily budget
- gate, and returns ``{(query_id, doc_id): rating}``.
-
- ``llm_fill_calls_counter`` is a mutable list (length-1 int counter)
- so the worker can read the per-list LLM-call total for the
- calibration JSONB without threading another mutable through.
- """
-
- async def _callback(
- pairs: list[tuple[str, str, str]],
- ) -> dict[tuple[str, str], int]:
- out: dict[tuple[str, str], int] = {}
- if not pairs:
- return out
-
- # Group by query_text → run one rate_query_batch per query.
- by_query: dict[str, list[tuple[str, str]]] = {}
- query_id_for_text: dict[str, str] = {}
- for qid, did, qtext in pairs:
- by_query.setdefault(qtext, []).append((qid, did))
- query_id_for_text.setdefault(qtext, qid)
-
- for query_text, qd_pairs in by_query.items():
- qid = query_id_for_text[query_text]
-
- # Pre-call budget peek (spec FR-2 + FR-5).
- if budget_usd > 0:
- current = await peek_daily_total(redis)
- est_max = estimated_max_call_cost(model)
- if current + est_max > budget_usd:
- raise BudgetExceededError(
- f"current ${current:.4f} + estimated ${est_max:.4f} "
- f"> budget ${budget_usd:.4f}"
- )
-
- # Fetch doc bodies for this query's pairs. Map each prompt ordinal
- # back to the FULL (query_id, doc_id) tuple — not just doc_id.
- # Multiple distinct internal query_ids can share the same
- # query_text (duplicate rows in the operator's query set), and
- # they're grouped together here; mapping prompt_id → doc_id alone
- # would attribute every rating to the single representative `qid`
- # and silently drop ratings for the others (Gemini PR #317
- # findings #2 + #3).
- doc_inputs: list[dict[str, str]] = []
- prompt_id_to_real: dict[str, tuple[str, str]] = {}
- for i, (pair_query_id, doc_id) in enumerate(qd_pairs):
- doc = await adapter.get_document(target, doc_id)
- source = getattr(doc, "source", None) if doc is not None else None
- body_raw: str
- if isinstance(source, dict) and source.get("body"):
- body_raw = str(source["body"])
- elif source is not None:
- import json as _json
-
- body_raw = _json.dumps(source, ensure_ascii=False)
- else:
- body_raw = ""
- body = body_raw[:_DOC_BODY_CHAR_LIMIT]
- prompt_id = f"item-{i}"
- doc_inputs.append({"doc_id": prompt_id, "body": body})
- prompt_id_to_real[prompt_id] = (pair_query_id, doc_id)
-
- expected = set(prompt_id_to_real.keys())
- user_prompt = render_user_prompt(
- rubric_text=rubric,
- query_text=query_text,
- docs=doc_inputs,
- )
- system_prompt = f"{bundle_system}\n\n\n{rubric}\n"
-
- try:
- result = await rate_query_batch(
- client=openai_client,
- model=model,
- system_prompt=system_prompt,
- user_prompt=user_prompt,
- expected_doc_ids=expected,
- )
- except (
- openai.AuthenticationError,
- openai.PermissionDeniedError,
- openai.BadRequestError,
- openai.NotFoundError,
- ):
- # Persistent provider misconfig — propagate so the worker
- # marks the list failed.
- raise
- except Exception as exc:
- # Per-query operational failure — skip this query's pairs
- # (matches the LLM worker's isolation pattern).
- logger.warning(
- "ubi worker: hybrid LLM call failed for query, skipping",
- event_type="ubi_hybrid_llm_failed",
- query_id=qid,
- error_type=type(exc).__name__,
- error=str(exc),
- )
- continue
-
- await safe_record_cost(
- redis,
- result.cost_usd,
- logger=logger,
- log_message="ubi worker: record_cost failed (budget telemetry only)",
- event_type="ubi_record_cost_failed",
- )
-
- llm_fill_calls_counter[0] += 1
-
- # Map prompt-only ordinals back to the real (query_id, doc_id) so
- # ratings stay attributed to the requesting query.
- for r in result.ratings:
- real_pair = prompt_id_to_real.get(r.doc_id)
- if real_pair is None:
- continue # hallucinated id; skip
- out[real_pair] = r.rating
-
- return out
-
- return _callback
-
-
# ----------------------------------------------------------------------------
# Main entry point
# ----------------------------------------------------------------------------
@@ -664,7 +512,7 @@ def _build_converter(
api_key=settings.openai_api_key, base_url=settings.openai_base_url
)
bundle = load_judgment_prompts()
- callback: LlmRateCallback = _make_llm_rate_callback(
+ callback: LlmRateCallback = make_hybrid_llm_rate_callback(
openai_client=openai_client,
model=settings.openai_model,
rubric=params.get("rubric") or "Rate the document's relevance to the query (0-3).",
@@ -718,11 +566,5 @@ async def _write_calibration_and_complete(
__all__ = [
"generate_judgments_from_ubi",
- "_make_llm_rate_callback",
"_apply_mapping_strategy",
]
-
-
-# Keep mypy happy about the Awaitable / Callable imports being used.
-_LlmRateCallbackT: type = type(Callable[[list[tuple[str, str, str]]], Awaitable[Any]])
-del _LlmRateCallbackT