Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
157 changes: 154 additions & 3 deletions backend/app/services/judgment_generation.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@

from __future__ import annotations

import asyncio
import json
from collections.abc import Sequence
from typing import Any

Expand All @@ -27,6 +29,7 @@
from backend.app.adapters.protocol import NativeQuery, QueryTemplate
from backend.app.db import repo
from backend.app.domain.study.template_defaults import compute_default_params
from backend.app.domain.ubi import LlmRateCallback
from backend.app.llm.budget_gate import BudgetExceededError, peek_daily_total, safe_record_cost
from backend.app.llm.cost_model import UnknownModelPricingError, estimated_max_call_cost
from backend.app.llm.openai_judge import rate_query_batch
Expand Down Expand Up @@ -67,10 +70,8 @@ def _build_doc_inputs(hits: Sequence[Any]) -> list[dict[str, str]]:
# dump anyway now that this helper is service-layer: a future caller
# could pass a hit whose source holds a non-serializable value, and
# a TypeError here would abort the whole judgment run.
import json as _json

try:
body_raw = _json.dumps(source, ensure_ascii=False)
body_raw = json.dumps(source, ensure_ascii=False)
except TypeError:
body_raw = str(source)
body = str(body_raw)
Expand Down Expand Up @@ -360,3 +361,153 @@ async def process_judgment_query(
duration_ms=result.duration_ms,
)
return True


def make_hybrid_llm_rate_callback(
    *,
    openai_client: AsyncOpenAI,
    model: str,
    rubric: str,
    bundle_system: str,
    target: str,
    adapter: Any,
    redis: Redis,
    budget_usd: float,
    llm_fill_calls_counter: list[int],
) -> LlmRateCallback:
    """Construct the :class:`LlmRateCallback` bound for UBI hybrid mode.

    The callback receives ``[(query_id, doc_id, query_text), ...]`` for
    pairs the inner converter couldn't rate (``impression_count <
    llm_fill_threshold``). It groups by ``query_text``, fetches doc
    bodies via :func:`adapter.get_document`, calls
    :func:`rate_query_batch`, records cost through the daily budget
    gate, and returns ``{(query_id, doc_id): rating}``.

    Lives in the service layer (not the worker) alongside the rest of the
    judgment composition; the UBI worker constructs the OpenAI client and
    passes it in.

    Args:
        openai_client: Client handed to :func:`rate_query_batch`.
        model: Model name used for both the pre-call cost estimate and the
            rating call.
        rubric: Rubric text; rendered into the user prompt and embedded in
            the system prompt inside ``<rubric>`` tags.
        bundle_system: Leading part of the system prompt.
        target: First argument forwarded to ``adapter.get_document``.
        adapter: Object exposing an awaitable
            ``get_document(target, doc_id)``.
        redis: Connection passed to :func:`peek_daily_total` and
            :func:`safe_record_cost`.
        budget_usd: Daily budget. The pre-call budget peek only runs when
            this is ``> 0``.
        llm_fill_calls_counter: Mutable length-1 int counter; element ``0``
            is incremented once per successful :func:`rate_query_batch`
            call so the worker can read the per-list LLM-call total for the
            calibration JSONB without threading another mutable through.

    Returns:
        An async callable mapping the unrated pairs to
        ``{(query_id, doc_id): rating}``. Pairs whose query failed with a
        per-query operational error, or whose rating the model did not
        return, are simply absent from the result.

    Raises:
        BudgetExceededError: (from the callback) when the current daily
            total plus the estimated maximum call cost exceeds
            ``budget_usd``.
        openai.AuthenticationError, openai.PermissionDeniedError,
        openai.BadRequestError, openai.NotFoundError,
        UnknownModelPricingError: (from the callback) re-raised from the
            rating call so the worker can fail the whole list.
        Exception: anything raised by ``adapter.get_document`` or by the
            pre-call budget helpers propagates unchanged; those calls are
            not wrapped.
    """

    async def _callback(
        pairs: list[tuple[str, str, str]],
    ) -> dict[tuple[str, str], int]:
        out: dict[tuple[str, str], int] = {}
        if not pairs:
            return out

        # Group by query_text → run one rate_query_batch per query.
        by_query: dict[str, list[tuple[str, str]]] = {}
        query_id_for_text: dict[str, str] = {}
        for qid, did, qtext in pairs:
            by_query.setdefault(qtext, []).append((qid, did))
            # First query_id seen for a text is the representative one;
            # it is only used for log attribution below.
            query_id_for_text.setdefault(qtext, qid)

        # Identical for every query in this call, so build it once.
        system_prompt = f"{bundle_system}\n\n<rubric>\n{rubric}\n</rubric>"

        for query_text, qd_pairs in by_query.items():
            qid = query_id_for_text[query_text]

            # Pre-call budget peek (spec FR-2 + FR-5).
            if budget_usd > 0:
                current = await peek_daily_total(redis)
                est_max = estimated_max_call_cost(model)
                if current + est_max > budget_usd:
                    raise BudgetExceededError(
                        f"current ${current:.4f} + estimated ${est_max:.4f} "
                        f"> budget ${budget_usd:.4f}"
                    )

            # Fetch doc bodies for this query's pairs. Map each prompt ordinal
            # back to the FULL (query_id, doc_id) tuple — not just doc_id.
            # Multiple distinct internal query_ids can share the same
            # query_text (duplicate rows in the operator's query set), and
            # they're grouped together here; mapping prompt_id → doc_id alone
            # would attribute every rating to the single representative `qid`
            # and silently drop ratings for the others (Gemini PR #317
            # findings #2 + #3).
            # The per-pair doc fetches are independent + idempotent, so issue
            # them concurrently rather than serially round-tripping the engine.
            docs = await asyncio.gather(
                *(adapter.get_document(target, doc_id) for _, doc_id in qd_pairs)
            )
            doc_inputs: list[dict[str, str]] = []
            prompt_id_to_real: dict[str, tuple[str, str]] = {}
            for i, ((pair_query_id, doc_id), doc) in enumerate(
                zip(qd_pairs, docs, strict=True)
            ):
                source = getattr(doc, "source", None) if doc is not None else None
                body_raw: str
                if isinstance(source, dict) and source.get("body"):
                    body_raw = str(source["body"])
                elif source is not None:
                    # Mirror _build_doc_inputs: a source holding a value the
                    # JSON encoder rejects must degrade to str() rather than
                    # raise out of the callback and abort the whole run.
                    try:
                        body_raw = json.dumps(source, ensure_ascii=False)
                    except (TypeError, ValueError):
                        body_raw = str(source)
                else:
                    # Missing document or missing source → empty body.
                    body_raw = ""
                body = body_raw[:_DOC_BODY_CHAR_LIMIT]
                # The prompt only ever sees ordinal ids; real ids stay local.
                prompt_id = f"item-{i}"
                doc_inputs.append({"doc_id": prompt_id, "body": body})
                prompt_id_to_real[prompt_id] = (pair_query_id, doc_id)

            expected = set(prompt_id_to_real.keys())
            user_prompt = render_user_prompt(
                rubric_text=rubric,
                query_text=query_text,
                docs=doc_inputs,
            )

            try:
                result = await rate_query_batch(
                    client=openai_client,
                    model=model,
                    system_prompt=system_prompt,
                    user_prompt=user_prompt,
                    expected_doc_ids=expected,
                )
            except (
                openai.AuthenticationError,
                openai.PermissionDeniedError,
                openai.BadRequestError,
                openai.NotFoundError,
                UnknownModelPricingError,
            ):
                # Persistent provider misconfig OR unknown model pricing —
                # propagate so the worker fails the whole list (via
                # fail_on_budget_or_pricing_error) rather than silently
                # skipping the query. With the budget disabled (budget_usd
                # <= 0) the pre-call estimated_max_call_cost guard above is
                # skipped, so this is the path that surfaces an unknown-pricing
                # model from rate_query_batch's cost computation.
                raise
            except Exception as exc:
                # Per-query operational failure — skip this query's pairs
                # (matches the LLM worker's isolation pattern).
                logger.warning(
                    "ubi worker: hybrid LLM call failed for query, skipping",
                    event_type="ubi_hybrid_llm_failed",
                    query_id=qid,
                    error_type=type(exc).__name__,
                    error=str(exc),
                )
                continue

            await safe_record_cost(
                redis,
                result.cost_usd,
                logger=logger,
                log_message="ubi worker: record_cost failed (budget telemetry only)",
                event_type="ubi_record_cost_failed",
            )

            llm_fill_calls_counter[0] += 1

            # Map prompt-only ordinals back to the real (query_id, doc_id) so
            # ratings stay attributed to the requesting query.
            for r in result.ratings:
                real_pair = prompt_id_to_real.get(r.doc_id)
                if real_pair is None:
                    continue  # hallucinated id; skip
                out[real_pair] = r.rating

        return out

    return _callback
14 changes: 10 additions & 4 deletions backend/tests/unit/workers/test_judgments_ubi_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -151,10 +151,16 @@ def test_worker_exports() -> None:
from backend.workers import judgments_ubi

assert callable(judgments_ubi.generate_judgments_from_ubi)
assert callable(judgments_ubi._make_llm_rate_callback)
assert callable(judgments_ubi._apply_mapping_strategy)


def test_hybrid_rate_callback_factory_lives_in_service() -> None:
    """The service module must export the hybrid LLM-fill callback factory."""
    # Imported inside the test so a missing export fails this test only,
    # not collection of the whole module.
    from backend.app.services.judgment_generation import (
        make_hybrid_llm_rate_callback as factory,
    )

    is_callable = callable(factory)
    assert is_callable


def test_worker_registered_in_worker_settings() -> None:
"""The boot-time WorkerSettings.functions list MUST include the UBI job
so Arq dispatches it to the worker process (FR-5 step 3).
Expand Down Expand Up @@ -182,8 +188,9 @@ def test_worker_no_direct_openai_construction_outside_callback_factory() -> None
source = Path("backend/workers/judgments_ubi.py").read_text()
tree = ast.parse(source)
# Walk top-level definitions; the only AsyncOpenAI(...) call we permit is
# inside _build_converter (or _make_llm_rate_callback as a nested
# construction). Other code paths must NOT instantiate the client.
# inside _build_converter (which constructs the client and passes it to
# the service-layer hybrid callback factory). Other code paths must NOT
# instantiate the client.
forbidden_paths: list[str] = []

class _Visitor(ast.NodeVisitor):
Expand All @@ -205,7 +212,6 @@ def visit_Call(self, node: ast.Call) -> None:
if isinstance(func, ast.Name) and func.id == "AsyncOpenAI":
if not self._current_fn or self._current_fn[-1] not in {
"_build_converter",
"_make_llm_rate_callback",
}:
where = ".".join(self._current_fn) or "<module>"
forbidden_paths.append(where)
Expand Down
Loading
Loading