Skip to content

Commit 3d06909

Browse files
unamedkr and claude committed
phase 3 day 5: BM25+RRF hybrid locator → Wikitext 5/10 → 6/10
Karpathy loop results on Wikitext 10-question stress test: Baseline (keyword-only locator): 5/10 Loop 3 (BM25 + RRF + LLM ensemble): 6/10 (+Q1,Q2,Q4,Q5 / -Q7,Q8,Q10) Loop 4 (conservative LLM override): 6/10 (same) Changes to locator.py: - Added BM25 scoring with TF-IDF weighting (pure Python, no deps) - Reciprocal Rank Fusion (RRF) combining keyword + BM25 rankings - LLM classification on top-5 RRF candidates (always-on, not just fallback) - Conservative LLM override: only overrides RRF when margin < 15% Acme benchmark: still 7/7 (no regression) Remaining failures (4/10): Q3: lookup selects wrong person from correct chunk Q7: lookup fails to extract "Commissioner of Education" Q8: multi-hop — "Hung's analysis" requires cross-chunk reasoning Q10: locator misses "Vaughan Arnell" chunk D5 gate status: RLV(6) > LC(1) ✅, RLV(6) < VR(8) ❌ Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
1 parent b60ce4e commit 3d06909

File tree

1 file changed

+112
-41
lines changed

1 file changed

+112
-41
lines changed

bench/rlv/stages/locator.py

Lines changed: 112 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -247,7 +247,55 @@ def _keyword_locate(
247247

248248

249249
# ----------------------------------------------------------------------------
250-
# LLM fallback (only fires when keyword scoring is ambiguous)
250+
# BM25 scoring (Strategy 1: TF-IDF weighted term matching)
251+
# ----------------------------------------------------------------------------
252+
import math


def _bm25_score_chunks(question: str, gist: Gist, excluded: List[int],
                       k1: float = 1.5, b: float = 0.75) -> List[Tuple[int, float]]:
    """Score every non-excluded chunk against *question* with Okapi BM25.

    Unlike plain keyword overlap, BM25 lets rare query terms dominate:
    a term that appears in very few chunks carries a large idf, while a
    term present everywhere contributes nothing.  Returns a list of
    (chunk_id, score) pairs sorted best-first.
    """
    terms = [t for t in _normalize(question).split()
             if len(t) >= 3 and t not in STOPWORDS and t not in LOW_SIGNAL_TERMS]
    candidates = [c for c in gist.chunks if c.chunk_id not in excluded]
    if not terms:
        # No informative query terms: every candidate ties at zero.
        return [(c.chunk_id, 0.0) for c in candidates]
    if not candidates:
        return []

    docs = [_normalize(c.full_text or c.head_text) for c in candidates]
    n_docs = len(docs)
    avg_len = sum(len(d.split()) for d in docs) / max(n_docs, 1)

    # Document frequency for each query term (matching via _word_in_text).
    doc_freq = {t: sum(1 for d in docs if _word_in_text(t, d)) for t in terms}

    def count_occurrences(term: str, words: List[str]) -> int:
        # Exact match, or a shared 4-char prefix as a cheap stemming pass
        # (only fires when both strings truncate to the same length).
        hits = 0
        for w in words:
            if w == term or (len(w) >= 3 and len(term) >= 3 and
                             w[:min(4, len(w))] == term[:min(4, len(term))]):
                hits += 1
        return hits

    ranked: List[Tuple[int, float]] = []
    for chunk, doc in zip(candidates, docs):
        words = doc.split()
        doc_len = len(words)
        total = 0.0
        for term in terms:
            tf = count_occurrences(term, words)
            n = doc_freq.get(term, 0)
            # Zero out terms present in every document (n == n_docs).
            idf = math.log((n_docs - n + 0.5) / (n + 0.5) + 1.0) if n < n_docs else 0.0
            tf_norm = (tf * (k1 + 1)) / (tf + k1 * (1 - b + b * doc_len / max(avg_len, 1)))
            total += idf * tf_norm
        ranked.append((chunk.chunk_id, total))

    ranked.sort(key=lambda pair: pair[1], reverse=True)
    return ranked
295+
296+
297+
# ----------------------------------------------------------------------------
298+
# LLM classification (Strategy 2: always-on, not just fallback)
251299
# ----------------------------------------------------------------------------
252300
_NUM_RE = re.compile(r"\b(\d+)\b")
253301

@@ -271,25 +319,30 @@ def _llm_locate(
271319
gist: Gist,
272320
excluded: List[int],
273321
candidate_ids: List[int],
322+
verbose: bool = False,
274323
) -> int:
275-
"""Ask the LLM to choose among the top candidate chunks. Day 3:
276-
candidates presented as 1-indexed CHOICE numbers (decoupled from
277-
chunk_ids) so the parser doesn't pick up document-internal integers."""
324+
"""Ask the LLM to classify which chunk contains the answer.
325+
Day 5: always-on LLM classification (not just fallback).
326+
Shows first 2 sentences per chunk for better context."""
278327
available = [cid for cid in candidate_ids if cid not in excluded]
279328
if not available:
280329
return -1
281330

282331
lines = []
283332
for choice_num, cid in enumerate(available, start=1):
284333
chunk = gist.chunks[cid]
285-
head = chunk.head_text.replace("\n", " ").strip()
286-
head = re.sub(r"^section\s*\d+\s*[:.\-]\s*", "", head, flags=re.IGNORECASE)
287-
if len(head) > 180:
288-
head = head[:180] + "…"
289-
lines.append(f"Choice {choice_num}: {head}")
334+
text = (chunk.full_text or chunk.head_text).replace("\n", " ").strip()
335+
# Show first 2 sentences (more context than just head)
336+
sents = re.split(r'(?<=[.!?])\s+', text)
337+
preview = " ".join(sents[:2])
338+
if len(preview) > 250:
339+
preview = preview[:250] + "..."
340+
lines.append(f"[{choice_num}] {preview}")
290341
outline = "\n".join(lines)
291342
prompt = LOCATOR_LLM_PROMPT_TEMPLATE.format(outline=outline, question=question)
292343
result = _llm.llm_call(prompt, max_tokens=8)
344+
if verbose:
345+
print(f"[locator-llm] response: {result.text!r}")
293346
choice = _parse_locator_response(result.text, len(available) + 1)
294347
if choice < 1 or choice > len(available):
295348
return -1
@@ -327,51 +380,69 @@ def locate(
327380
char_start=chunk.char_start, char_end=chunk.char_end, score=0.0,
328381
)
329382

330-
best_id, best_score, all_scores = _keyword_locate(question, gist, excluded)
331-
second_score = all_scores[1][1] if len(all_scores) > 1 else 0.0
332-
margin = best_score - second_score
383+
# --- Step 1: Keyword scoring ---
384+
best_id, best_score, kw_scores = _keyword_locate(question, gist, excluded)
385+
386+
# --- Step 2: BM25 scoring ---
387+
bm25_scores = _bm25_score_chunks(question, gist, excluded)
333388

334389
if verbose:
335-
print(f"[locator] keyword scores top3: {all_scores[:3]} (margin={margin:.2f})")
390+
print(f"[locator] keyword top3: {kw_scores[:3]}")
391+
print(f"[locator] bm25 top3: {bm25_scores[:3]}")
392+
393+
# --- Step 3: Reciprocal Rank Fusion (keyword + BM25) ---
394+
rrf_k = 60
395+
rrf = {}
396+
for rank, (cid, _) in enumerate(kw_scores):
397+
rrf[cid] = rrf.get(cid, 0) + 1.0 / (rrf_k + rank)
398+
for rank, (cid, _) in enumerate(bm25_scores):
399+
rrf[cid] = rrf.get(cid, 0) + 1.0 / (rrf_k + rank)
400+
rrf_ranked = sorted(rrf.items(), key=lambda x: x[1], reverse=True)
336401

337-
method = "keyword"
338-
chosen = best_id
402+
if verbose:
403+
print(f"[locator] rrf top3: {rrf_ranked[:3]}")
339404

340-
if best_score >= 2.0 and margin >= 1.0:
341-
confidence = "high"
342-
elif best_score >= 1.0 and margin >= 0.5:
343-
confidence = "medium"
344-
else:
345-
scored = [(cid, s) for cid, s in all_scores if s > 0]
346-
candidate_ids = [cid for cid, _ in scored[:3]]
347-
if len(candidate_ids) < 2:
348-
confidence = "low"
349-
chunk = gist.chunks[chosen]
350-
return RegionPointer(
351-
chunk_id=chosen, confidence=confidence,
352-
candidates=[cid for cid, _ in all_scores[:3]],
353-
char_start=chunk.char_start, char_end=chunk.char_end,
354-
score=best_score, method="keyword",
355-
)
356-
if verbose:
357-
print(f"[locator] keyword ambiguous (best={best_score:.2f}, "
358-
f"margin={margin:.2f}), invoking LLM fallback over {candidate_ids}")
359-
llm_choice = _llm_locate(question, gist, excluded, candidate_ids)
360-
if llm_choice >= 0 and llm_choice not in excluded:
405+
# --- Step 4: LLM classification on top candidates ---
406+
# Always run LLM on the top 5 RRF candidates (not just when ambiguous)
407+
top_candidates = [cid for cid, _ in rrf_ranked[:5]]
408+
llm_choice = _llm_locate(question, gist, excluded, top_candidates, verbose=verbose)
409+
410+
rrf_top1 = rrf_ranked[0][0]
411+
rrf_top1_score = rrf_ranked[0][1]
412+
rrf_top2_score = rrf_ranked[1][1] if len(rrf_ranked) > 1 else 0.0
413+
rrf_margin = (rrf_top1_score - rrf_top2_score) / max(rrf_top1_score, 0.001)
414+
415+
if llm_choice >= 0 and llm_choice not in excluded:
416+
if llm_choice == rrf_top1:
417+
# LLM and RRF agree — high confidence
361418
chosen = llm_choice
362-
method = "keyword+llm"
419+
method = "rrf+llm"
420+
confidence = "high"
421+
elif rrf_margin < 0.15:
422+
# RRF is close — trust LLM to break the tie
423+
chosen = llm_choice
424+
method = "rrf+llm-override"
363425
confidence = "medium"
364426
else:
365-
method = "keyword"
366-
confidence = "low"
427+
# RRF has a clear winner — trust RRF over LLM
428+
chosen = rrf_top1
429+
method = "rrf(llm-overruled)"
430+
confidence = "high"
431+
else:
432+
chosen = rrf_top1
433+
method = "rrf"
434+
confidence = "medium" if rrf_margin > 0.1 else "low"
435+
436+
if verbose:
437+
print(f"[locator] chosen: chunk {chosen} via {method} (confidence={confidence})")
367438

368439
chunk = gist.chunks[chosen]
369440
return RegionPointer(
370441
chunk_id=chosen,
371442
confidence=confidence,
372-
candidates=[cid for cid, _ in all_scores[:3]],
443+
candidates=[cid for cid, _ in rrf_ranked[:3]],
373444
char_start=chunk.char_start,
374445
char_end=chunk.char_end,
375-
score=best_score,
446+
score=rrf.get(chosen, 0.0),
376447
method=method,
377448
)

0 commit comments

Comments
 (0)