-
Notifications
You must be signed in to change notification settings - Fork 2
refactor: move UBI hybrid LLM-fill callback to the service layer #571
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change | ||||||||||||||||||||||||||||||||||||||
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
|
|
@@ -14,6 +14,8 @@ | |||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||
| from __future__ import annotations | ||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||
| import asyncio | ||||||||||||||||||||||||||||||||||||||||
| import json | ||||||||||||||||||||||||||||||||||||||||
| from collections.abc import Sequence | ||||||||||||||||||||||||||||||||||||||||
| from typing import Any | ||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||
|
|
@@ -27,6 +29,7 @@ | |||||||||||||||||||||||||||||||||||||||
| from backend.app.adapters.protocol import NativeQuery, QueryTemplate | ||||||||||||||||||||||||||||||||||||||||
| from backend.app.db import repo | ||||||||||||||||||||||||||||||||||||||||
| from backend.app.domain.study.template_defaults import compute_default_params | ||||||||||||||||||||||||||||||||||||||||
| from backend.app.domain.ubi import LlmRateCallback | ||||||||||||||||||||||||||||||||||||||||
| from backend.app.llm.budget_gate import BudgetExceededError, peek_daily_total, safe_record_cost | ||||||||||||||||||||||||||||||||||||||||
| from backend.app.llm.cost_model import UnknownModelPricingError, estimated_max_call_cost | ||||||||||||||||||||||||||||||||||||||||
| from backend.app.llm.openai_judge import rate_query_batch | ||||||||||||||||||||||||||||||||||||||||
|
|
@@ -67,10 +70,8 @@ def _build_doc_inputs(hits: Sequence[Any]) -> list[dict[str, str]]: | |||||||||||||||||||||||||||||||||||||||
| # dump anyway now that this helper is service-layer: a future caller | ||||||||||||||||||||||||||||||||||||||||
| # could pass a hit whose source holds a non-serializable value, and | ||||||||||||||||||||||||||||||||||||||||
| # a TypeError here would abort the whole judgment run. | ||||||||||||||||||||||||||||||||||||||||
| import json as _json | ||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||
| try: | ||||||||||||||||||||||||||||||||||||||||
| body_raw = _json.dumps(source, ensure_ascii=False) | ||||||||||||||||||||||||||||||||||||||||
| body_raw = json.dumps(source, ensure_ascii=False) | ||||||||||||||||||||||||||||||||||||||||
| except TypeError: | ||||||||||||||||||||||||||||||||||||||||
| body_raw = str(source) | ||||||||||||||||||||||||||||||||||||||||
| body = str(body_raw) | ||||||||||||||||||||||||||||||||||||||||
|
|
@@ -360,3 +361,153 @@ async def process_judgment_query( | |||||||||||||||||||||||||||||||||||||||
| duration_ms=result.duration_ms, | ||||||||||||||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||||||||||||||
| return True | ||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||
| def make_hybrid_llm_rate_callback( | ||||||||||||||||||||||||||||||||||||||||
| *, | ||||||||||||||||||||||||||||||||||||||||
| openai_client: AsyncOpenAI, | ||||||||||||||||||||||||||||||||||||||||
| model: str, | ||||||||||||||||||||||||||||||||||||||||
| rubric: str, | ||||||||||||||||||||||||||||||||||||||||
| bundle_system: str, | ||||||||||||||||||||||||||||||||||||||||
| target: str, | ||||||||||||||||||||||||||||||||||||||||
| adapter: Any, | ||||||||||||||||||||||||||||||||||||||||
| redis: Redis, | ||||||||||||||||||||||||||||||||||||||||
| budget_usd: float, | ||||||||||||||||||||||||||||||||||||||||
| llm_fill_calls_counter: list[int], | ||||||||||||||||||||||||||||||||||||||||
| ) -> LlmRateCallback: | ||||||||||||||||||||||||||||||||||||||||
| """Construct the :class:`LlmRateCallback` bound for UBI hybrid mode. | ||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||
| The callback receives ``[(query_id, doc_id, query_text), ...]`` for | ||||||||||||||||||||||||||||||||||||||||
| pairs the inner converter couldn't rate (``impression_count < | ||||||||||||||||||||||||||||||||||||||||
| llm_fill_threshold``). It groups by ``query_text``, fetches doc | ||||||||||||||||||||||||||||||||||||||||
| bodies via :func:`adapter.get_document`, calls | ||||||||||||||||||||||||||||||||||||||||
| :func:`rate_query_batch`, records cost through the daily budget | ||||||||||||||||||||||||||||||||||||||||
| gate, and returns ``{(query_id, doc_id): rating}``. | ||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||
| Lives in the service layer (not the worker) alongside the rest of the | ||||||||||||||||||||||||||||||||||||||||
| judgment composition; the UBI worker constructs the OpenAI client and | ||||||||||||||||||||||||||||||||||||||||
| passes it in. ``llm_fill_calls_counter`` is a mutable list (length-1 int | ||||||||||||||||||||||||||||||||||||||||
| counter) so the worker can read the per-list LLM-call total for the | ||||||||||||||||||||||||||||||||||||||||
| calibration JSONB without threading another mutable through. | ||||||||||||||||||||||||||||||||||||||||
| """ | ||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||
| async def _callback( | ||||||||||||||||||||||||||||||||||||||||
| pairs: list[tuple[str, str, str]], | ||||||||||||||||||||||||||||||||||||||||
| ) -> dict[tuple[str, str], int]: | ||||||||||||||||||||||||||||||||||||||||
| out: dict[tuple[str, str], int] = {} | ||||||||||||||||||||||||||||||||||||||||
| if not pairs: | ||||||||||||||||||||||||||||||||||||||||
| return out | ||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||
| # Group by query_text → run one rate_query_batch per query. | ||||||||||||||||||||||||||||||||||||||||
| by_query: dict[str, list[tuple[str, str]]] = {} | ||||||||||||||||||||||||||||||||||||||||
| query_id_for_text: dict[str, str] = {} | ||||||||||||||||||||||||||||||||||||||||
| for qid, did, qtext in pairs: | ||||||||||||||||||||||||||||||||||||||||
| by_query.setdefault(qtext, []).append((qid, did)) | ||||||||||||||||||||||||||||||||||||||||
| query_id_for_text.setdefault(qtext, qid) | ||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||
| for query_text, qd_pairs in by_query.items(): | ||||||||||||||||||||||||||||||||||||||||
| qid = query_id_for_text[query_text] | ||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||
| # Pre-call budget peek (spec FR-2 + FR-5). | ||||||||||||||||||||||||||||||||||||||||
| if budget_usd > 0: | ||||||||||||||||||||||||||||||||||||||||
| current = await peek_daily_total(redis) | ||||||||||||||||||||||||||||||||||||||||
| est_max = estimated_max_call_cost(model) | ||||||||||||||||||||||||||||||||||||||||
| if current + est_max > budget_usd: | ||||||||||||||||||||||||||||||||||||||||
| raise BudgetExceededError( | ||||||||||||||||||||||||||||||||||||||||
| f"current ${current:.4f} + estimated ${est_max:.4f} " | ||||||||||||||||||||||||||||||||||||||||
| f"> budget ${budget_usd:.4f}" | ||||||||||||||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||
| # Fetch doc bodies for this query's pairs. Map each prompt ordinal | ||||||||||||||||||||||||||||||||||||||||
| # back to the FULL (query_id, doc_id) tuple — not just doc_id. | ||||||||||||||||||||||||||||||||||||||||
| # Multiple distinct internal query_ids can share the same | ||||||||||||||||||||||||||||||||||||||||
| # query_text (duplicate rows in the operator's query set), and | ||||||||||||||||||||||||||||||||||||||||
| # they're grouped together here; mapping prompt_id → doc_id alone | ||||||||||||||||||||||||||||||||||||||||
| # would attribute every rating to the single representative `qid` | ||||||||||||||||||||||||||||||||||||||||
| # and silently drop ratings for the others (Gemini PR #317 | ||||||||||||||||||||||||||||||||||||||||
| # findings #2 + #3). | ||||||||||||||||||||||||||||||||||||||||
| # The per-pair doc fetches are independent + idempotent, so issue | ||||||||||||||||||||||||||||||||||||||||
| # them concurrently rather than serially round-tripping the engine. | ||||||||||||||||||||||||||||||||||||||||
| docs = await asyncio.gather( | ||||||||||||||||||||||||||||||||||||||||
| *(adapter.get_document(target, doc_id) for _, doc_id in qd_pairs) | ||||||||||||||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||||||||||||||
| doc_inputs: list[dict[str, str]] = [] | ||||||||||||||||||||||||||||||||||||||||
| prompt_id_to_real: dict[str, tuple[str, str]] = {} | ||||||||||||||||||||||||||||||||||||||||
| for i, ((pair_query_id, doc_id), doc) in enumerate(zip(qd_pairs, docs, strict=True)): | ||||||||||||||||||||||||||||||||||||||||
| source = getattr(doc, "source", None) if doc is not None else None | ||||||||||||||||||||||||||||||||||||||||
| body_raw: str | ||||||||||||||||||||||||||||||||||||||||
| if isinstance(source, dict) and source.get("body"): | ||||||||||||||||||||||||||||||||||||||||
| body_raw = str(source["body"]) | ||||||||||||||||||||||||||||||||||||||||
| elif source is not None: | ||||||||||||||||||||||||||||||||||||||||
| body_raw = json.dumps(source, ensure_ascii=False) | ||||||||||||||||||||||||||||||||||||||||
| else: | ||||||||||||||||||||||||||||||||||||||||
| body_raw = "" | ||||||||||||||||||||||||||||||||||||||||
| body = body_raw[:_DOC_BODY_CHAR_LIMIT] | ||||||||||||||||||||||||||||||||||||||||
| prompt_id = f"item-{i}" | ||||||||||||||||||||||||||||||||||||||||
| doc_inputs.append({"doc_id": prompt_id, "body": body}) | ||||||||||||||||||||||||||||||||||||||||
| prompt_id_to_real[prompt_id] = (pair_query_id, doc_id) | ||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||
| expected = set(prompt_id_to_real.keys()) | ||||||||||||||||||||||||||||||||||||||||
| user_prompt = render_user_prompt( | ||||||||||||||||||||||||||||||||||||||||
| rubric_text=rubric, | ||||||||||||||||||||||||||||||||||||||||
| query_text=query_text, | ||||||||||||||||||||||||||||||||||||||||
| docs=doc_inputs, | ||||||||||||||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||||||||||||||
| system_prompt = f"{bundle_system}\n\n<rubric>\n{rubric}\n</rubric>" | ||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||
| try: | ||||||||||||||||||||||||||||||||||||||||
| result = await rate_query_batch( | ||||||||||||||||||||||||||||||||||||||||
| client=openai_client, | ||||||||||||||||||||||||||||||||||||||||
| model=model, | ||||||||||||||||||||||||||||||||||||||||
| system_prompt=system_prompt, | ||||||||||||||||||||||||||||||||||||||||
| user_prompt=user_prompt, | ||||||||||||||||||||||||||||||||||||||||
| expected_doc_ids=expected, | ||||||||||||||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||||||||||||||
| except ( | ||||||||||||||||||||||||||||||||||||||||
| openai.AuthenticationError, | ||||||||||||||||||||||||||||||||||||||||
| openai.PermissionDeniedError, | ||||||||||||||||||||||||||||||||||||||||
| openai.BadRequestError, | ||||||||||||||||||||||||||||||||||||||||
| openai.NotFoundError, | ||||||||||||||||||||||||||||||||||||||||
| UnknownModelPricingError, | ||||||||||||||||||||||||||||||||||||||||
| ): | ||||||||||||||||||||||||||||||||||||||||
| # Persistent provider misconfig OR unknown model pricing — | ||||||||||||||||||||||||||||||||||||||||
| # propagate so the worker fails the whole list (via | ||||||||||||||||||||||||||||||||||||||||
| # fail_on_budget_or_pricing_error) rather than silently | ||||||||||||||||||||||||||||||||||||||||
| # skipping the query. With the budget disabled (budget_usd | ||||||||||||||||||||||||||||||||||||||||
| # <= 0) the pre-call estimated_max_call_cost guard above is | ||||||||||||||||||||||||||||||||||||||||
| # skipped, so this is the path that surfaces an unknown-pricing | ||||||||||||||||||||||||||||||||||||||||
| # model from rate_query_batch's cost computation. | ||||||||||||||||||||||||||||||||||||||||
| raise | ||||||||||||||||||||||||||||||||||||||||
|
Comment on lines
+466
to
+480
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The
Suggested change
|
||||||||||||||||||||||||||||||||||||||||
| except Exception as exc: | ||||||||||||||||||||||||||||||||||||||||
| # Per-query operational failure — skip this query's pairs | ||||||||||||||||||||||||||||||||||||||||
| # (matches the LLM worker's isolation pattern). | ||||||||||||||||||||||||||||||||||||||||
| logger.warning( | ||||||||||||||||||||||||||||||||||||||||
| "ubi worker: hybrid LLM call failed for query, skipping", | ||||||||||||||||||||||||||||||||||||||||
| event_type="ubi_hybrid_llm_failed", | ||||||||||||||||||||||||||||||||||||||||
| query_id=qid, | ||||||||||||||||||||||||||||||||||||||||
| error_type=type(exc).__name__, | ||||||||||||||||||||||||||||||||||||||||
| error=str(exc), | ||||||||||||||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||||||||||||||
| continue | ||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||
| await safe_record_cost( | ||||||||||||||||||||||||||||||||||||||||
| redis, | ||||||||||||||||||||||||||||||||||||||||
| result.cost_usd, | ||||||||||||||||||||||||||||||||||||||||
| logger=logger, | ||||||||||||||||||||||||||||||||||||||||
| log_message="ubi worker: record_cost failed (budget telemetry only)", | ||||||||||||||||||||||||||||||||||||||||
| event_type="ubi_record_cost_failed", | ||||||||||||||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||
| llm_fill_calls_counter[0] += 1 | ||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||
| # Map prompt-only ordinals back to the real (query_id, doc_id) so | ||||||||||||||||||||||||||||||||||||||||
| # ratings stay attributed to the requesting query. | ||||||||||||||||||||||||||||||||||||||||
| for r in result.ratings: | ||||||||||||||||||||||||||||||||||||||||
| real_pair = prompt_id_to_real.get(r.doc_id) | ||||||||||||||||||||||||||||||||||||||||
| if real_pair is None: | ||||||||||||||||||||||||||||||||||||||||
| continue # hallucinated id; skip | ||||||||||||||||||||||||||||||||||||||||
| out[real_pair] = r.rating | ||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||
| return out | ||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||
| return _callback | ||||||||||||||||||||||||||||||||||||||||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Fetching document bodies sequentially in a loop using
await adapter.get_document(...)introduces a significant performance bottleneck due to sequential network/IO round-trips. Since these document fetches are independent, they can be executed concurrently usingasyncio.gatherto greatly reduce latency.