Skip to content

Commit a8f14de

Browse files
committed
feat: use judge_passed for all calcs
1 parent 0ed243e commit a8f14de

4 files changed

Lines changed: 24 additions & 24 deletions

File tree

packages/optimization/src/ldai_optimizer/client.py

Lines changed: 3 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -60,6 +60,7 @@
6060
extract_json_from_response,
6161
generate_slug,
6262
interpolate_variables,
63+
judge_passed,
6364
restore_variable_placeholders,
6465
validate_variation_response,
6566
)
@@ -142,16 +143,6 @@ def _compute_validation_count(pool_size: int) -> int:
142143
}
143144

144145

145-
def _judge_passed(score: float, threshold: float, is_inverted: bool) -> bool:
146-
"""Return True when a judge score meets its threshold.
147-
148-
For standard judges (higher is better) the score must reach the threshold
149-
from below: ``score >= threshold``. For inverted judges (lower is better,
150-
e.g. toxicity) the score must stay at or below the threshold:
151-
``score <= threshold``.
152-
"""
153-
return score <= threshold if is_inverted else score >= threshold
154-
155146

156147
class OptimizationClient:
157148
_options: OptimizationOptions
@@ -481,7 +472,7 @@ async def _call_judges(
481472
if optimization_judge.threshold is not None
482473
else 1.0
483474
)
484-
passed = _judge_passed(result.score, threshold, optimization_judge.is_inverted)
475+
passed = judge_passed(result.score, threshold, optimization_judge.is_inverted)
485476
logger.debug(
486477
"[Iteration %d] -> Judge '%s' scored %.3f (threshold=%.3f, inverted=%s) -> %s%s",
487478
iteration,
@@ -1868,7 +1859,7 @@ def _evaluate_response(self, optimize_context: OptimizationContext) -> bool:
18681859
if optimization_judge.threshold is not None
18691860
else 1.0
18701861
)
1871-
if not _judge_passed(result.score, threshold, optimization_judge.is_inverted):
1862+
if not judge_passed(result.score, threshold, optimization_judge.is_inverted):
18721863
return False
18731864

18741865
return True

packages/optimization/src/ldai_optimizer/prompts.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
OptimizationContext,
88
OptimizationJudge,
99
)
10+
from ldai_optimizer.util import judge_passed
1011

1112
_DURATION_KEYWORDS = re.compile(
1213
r"\b(fast|faster|quickly|quick|latency|low-latency|duration|response\s+time|"
@@ -285,10 +286,7 @@ def variation_prompt_feedback(
285286
if optimization_judge:
286287
score = result.score
287288
if optimization_judge.threshold is not None:
288-
if optimization_judge.is_inverted:
289-
passed = score <= optimization_judge.threshold
290-
else:
291-
passed = score >= optimization_judge.threshold
289+
passed = judge_passed(score, optimization_judge.threshold, optimization_judge.is_inverted)
292290
status = "PASSED" if passed else "FAILED"
293291
feedback_line = (
294292
f"- {judge_key}: Score {score:.3f}"

packages/optimization/src/ldai_optimizer/util.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -303,3 +303,13 @@ def extract_json_from_response(response_str: str) -> Dict[str, Any]:
303303
)
304304

305305
return response_data
306+
307+
308+
def judge_passed(score: float, threshold: float, is_inverted: bool) -> bool:
309+
"""Return True when a judge score meets its threshold.
310+
311+
For standard judges (higher is better) the score must reach the threshold:
312+
``score >= threshold``. For inverted judges (lower is better, e.g. toxicity)
313+
the score must stay at or below the threshold: ``score <= threshold``.
314+
"""
315+
return score <= threshold if is_inverted else score >= threshold

packages/optimization/tests/test_client.py

Lines changed: 9 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,8 @@
1010
from ldai.tracker import TokenUsage
1111
from ldclient import Context
1212

13-
from ldai_optimizer.client import OptimizationClient, _compute_validation_count, _find_model_config, _judge_passed
13+
from ldai_optimizer.client import OptimizationClient, _compute_validation_count, _find_model_config
14+
from ldai_optimizer.util import judge_passed
1415
from ldai_optimizer.dataclasses import (
1516
AIJudgeCallConfig,
1617
GroundTruthOptimizationOptions,
@@ -4410,24 +4411,24 @@ async def test_optimization_key_in_post_url_uses_string_key_not_uuid(self):
44104411

44114412

44124413
# ---------------------------------------------------------------------------
4413-
# _judge_passed helper
4414+
# judge_passed helper
44144415
# ---------------------------------------------------------------------------
44154416

44164417

44174418
class TestJudgePassed:
44184419
def test_standard_judge_passes_at_or_above_threshold(self):
4419-
assert _judge_passed(0.8, 0.8, is_inverted=False) is True
4420-
assert _judge_passed(1.0, 0.8, is_inverted=False) is True
4420+
assert judge_passed(0.8, 0.8, is_inverted=False) is True
4421+
assert judge_passed(1.0, 0.8, is_inverted=False) is True
44214422

44224423
def test_standard_judge_fails_below_threshold(self):
4423-
assert _judge_passed(0.5, 0.8, is_inverted=False) is False
4424+
assert judge_passed(0.5, 0.8, is_inverted=False) is False
44244425

44254426
def test_inverted_judge_passes_at_or_below_threshold(self):
4426-
assert _judge_passed(0.1, 0.3, is_inverted=True) is True
4427-
assert _judge_passed(0.3, 0.3, is_inverted=True) is True
4427+
assert judge_passed(0.1, 0.3, is_inverted=True) is True
4428+
assert judge_passed(0.3, 0.3, is_inverted=True) is True
44284429

44294430
def test_inverted_judge_fails_above_threshold(self):
4430-
assert _judge_passed(0.8, 0.3, is_inverted=True) is False
4431+
assert judge_passed(0.8, 0.3, is_inverted=True) is False
44314432

44324433

44334434
# ---------------------------------------------------------------------------

0 commit comments

Comments
 (0)