Skip to content

Commit 9da87f1

Browse files
committed
{"schema":"cmsg/1","type":"feat","scope":"eval","summary":"expected_keys + consolidation harness","intent":"support consolidation/reflection evaluation by stable keys","impact":"elf-eval dataset supports expected_keys; new e2e-consolidation-harness; docs+memo","breaking":false,"risk":"low","refs":[]}
1 parent b701893 commit 9da87f1

7 files changed

Lines changed: 978 additions & 176 deletions

File tree

Makefile.toml

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -203,6 +203,7 @@ args = [
203203
# | ------------------------------ | --------- | --- |
204204
# | e2e | composite | |
205205
# | e2e-context-misranking-harness | command | |
206+
# | e2e-consolidation-harness | command | |
206207

207208
[tasks.e2e]
208209
workspace = false
@@ -217,6 +218,13 @@ args = [
217218
"scripts/context-misranking-harness.sh",
218219
]
219220

221+
[tasks.e2e-consolidation-harness]
222+
workspace = false
223+
command = "bash"
224+
args = [
225+
"scripts/consolidation-harness.sh",
226+
]
227+
220228

221229
# Meta
222230
# | task | type | cwd |

apps/elf-eval/src/lib.rs

Lines changed: 191 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,8 @@ use uuid::Uuid;
1616

1717
use elf_config::Config;
1818
use elf_service::{
19-
ElfService, RankingRequestOverride, SearchIndexResponse, SearchRequest, search::TraceReplayItem,
19+
ElfService, RankingRequestOverride, SearchIndexItem, SearchIndexResponse, SearchRequest,
20+
search::TraceReplayItem,
2021
};
2122
use elf_storage::{db::Db, qdrant::QdrantStore};
2223

@@ -71,7 +72,10 @@ struct EvalQuery {
7172
read_profile: Option<String>,
7273
top_k: Option<u32>,
7374
candidate_k: Option<u32>,
75+
#[serde(default)]
7476
expected_note_ids: Vec<Uuid>,
77+
#[serde(default)]
78+
expected_keys: Vec<String>,
7579
ranking: Option<RankingRequestOverride>,
7680
}
7781

@@ -106,6 +110,7 @@ struct EvalSummary {
106110
mean_ndcg: f64,
107111
latency_ms_p50: f64,
108112
latency_ms_p95: f64,
113+
avg_retrieved_summary_chars: f64,
109114
#[serde(skip_serializing_if = "Option::is_none")]
110115
stability: Option<StabilitySummary>,
111116
}
@@ -133,11 +138,23 @@ struct QueryReport {
133138
ndcg: f64,
134139
latency_ms: f64,
135140
expected_note_ids: Vec<Uuid>,
141+
expected_keys: Vec<String>,
142+
expected_kind: ExpectedKind,
136143
retrieved_note_ids: Vec<Uuid>,
144+
#[serde(skip_serializing_if = "Vec::is_empty")]
145+
retrieved_keys: Vec<Option<String>>,
146+
retrieved_summary_chars: usize,
137147
#[serde(skip_serializing_if = "Option::is_none")]
138148
stability: Option<QueryStability>,
139149
}
140150

151+
/// Identifies which expectation list a query is scored against.
///
/// Serialized in `snake_case` (see the serde attribute below).
#[derive(Debug, Serialize, Clone, Copy, PartialEq, Eq)]
152+
#[serde(rename_all = "snake_case")]
153+
enum ExpectedKind {
154+
/// The query is scored against its `expected_note_ids`.
NoteId,
155+
/// The query is scored against its `expected_keys`.
Key,
156+
}
157+
141158
#[derive(Debug, Serialize, Clone, Copy)]
142159
struct QueryStability {
143160
runs_per_query: u32,
@@ -172,6 +189,7 @@ struct EvalSummaryDelta {
172189
mean_ndcg: f64,
173190
latency_ms_p50: f64,
174191
latency_ms_p95: f64,
192+
avg_retrieved_summary_chars: f64,
175193
#[serde(skip_serializing_if = "Option::is_none")]
176194
stability: Option<StabilitySummaryDelta>,
177195
}
@@ -357,6 +375,8 @@ struct MergedQuery {
357375
id: String,
358376
query: String,
359377
expected_note_ids: Vec<Uuid>,
378+
expected_keys: Vec<String>,
379+
expected_kind: ExpectedKind,
360380
request: SearchRequest,
361381
}
362382

@@ -511,6 +531,7 @@ fn diff_summary(a: &EvalSummary, b: &EvalSummary) -> EvalSummaryDelta {
511531
mean_ndcg: b.mean_ndcg - a.mean_ndcg,
512532
latency_ms_p50: b.latency_ms_p50 - a.latency_ms_p50,
513533
latency_ms_p95: b.latency_ms_p95 - a.latency_ms_p95,
534+
avg_retrieved_summary_chars: b.avg_retrieved_summary_chars - a.avg_retrieved_summary_chars,
514535
stability: match (&a.stability, &b.stability) {
515536
(Some(sa), Some(sb)) => Some(StabilitySummaryDelta {
516537
avg_positional_churn_at_k: sb.avg_positional_churn_at_k
@@ -612,12 +633,8 @@ fn merge_query(
612633
cfg: &Config,
613634
index: usize,
614635
) -> Result<MergedQuery> {
615-
if query.expected_note_ids.is_empty() {
616-
return Err(eyre::eyre!(
617-
"Query at index {index} must include at least one expected_note_id."
618-
));
619-
}
620-
636+
let expected_kind =
637+
resolve_expected_mode(index, &query.expected_note_ids, &query.expected_keys)?;
621638
let tenant_id = query
622639
.tenant_id
623640
.clone()
@@ -652,6 +669,8 @@ fn merge_query(
652669
id,
653670
query: query.query.clone(),
654671
expected_note_ids: query.expected_note_ids.clone(),
672+
expected_keys: query.expected_keys.clone(),
673+
expected_kind,
655674
request: SearchRequest {
656675
tenant_id,
657676
project_id,
@@ -669,16 +688,29 @@ fn merge_query(
669688
})
670689
}
671690

672-
fn unique_ids<I>(iter: I) -> Vec<Uuid>
673-
where
674-
I: Iterator<Item = Uuid>,
675-
{
691+
fn resolve_expected_mode(index: usize, note_ids: &[Uuid], keys: &[String]) -> Result<ExpectedKind> {
692+
let has_note_ids = !note_ids.is_empty();
693+
let has_keys = !keys.is_empty();
694+
695+
match (has_note_ids, has_keys) {
696+
(true, false) => Ok(ExpectedKind::NoteId),
697+
(false, true) => Ok(ExpectedKind::Key),
698+
(true, true) => Err(eyre::eyre!(
699+
"Query at index {index} must define exactly one expectation mode: expected_note_ids or expected_keys."
700+
)),
701+
(false, false) => Err(eyre::eyre!(
702+
"Query at index {index} must include at least one expected_note_ids or expected_keys."
703+
)),
704+
}
705+
}
706+
707+
fn unique_items(items: &[SearchIndexItem]) -> Vec<SearchIndexItem> {
676708
let mut seen = HashSet::new();
677709
let mut out = Vec::new();
678710

679-
for id in iter {
680-
if seen.insert(id) {
681-
out.push(id);
711+
for item in items {
712+
if seen.insert(item.note_id) {
713+
out.push(item.clone());
682714
}
683715
}
684716

@@ -730,12 +762,87 @@ fn compute_metrics(retrieved: &[Uuid], expected: &HashSet<Uuid>) -> Metrics {
730762
Metrics { recall_at_k, precision_at_k, rr, ndcg, relevant_count }
731763
}
732764

765+
fn compute_metrics_for_keys(retrieved: &[Option<String>], expected: &HashSet<String>) -> Metrics {
766+
let expected_count = expected.len();
767+
let mut matched: HashSet<String> = HashSet::new();
768+
let mut relevant_count = 0_usize;
769+
let mut dcg = 0.0_f64;
770+
let mut rr = 0.0_f64;
771+
let mut first_hit: Option<usize> = None;
772+
773+
for (idx, maybe_key) in retrieved.iter().enumerate() {
774+
let Some(key) = maybe_key else {
775+
continue;
776+
};
777+
778+
if expected.contains(key) && !matched.contains(key) {
779+
matched.insert(key.clone());
780+
781+
relevant_count += 1;
782+
783+
let rank = idx + 1;
784+
let denom = (rank as f64 + 1.0).log2();
785+
786+
dcg += 1.0 / denom;
787+
788+
if first_hit.is_none() {
789+
first_hit = Some(rank);
790+
}
791+
}
792+
}
793+
794+
if let Some(rank) = first_hit {
795+
rr = 1.0 / rank as f64;
796+
}
797+
798+
let ideal_hits = expected_count.min(retrieved.len());
799+
let mut idcg = 0.0_f64;
800+
801+
for idx in 0..ideal_hits {
802+
let rank = idx + 1;
803+
let denom = (rank as f64 + 1.0).log2();
804+
805+
idcg += 1.0 / denom;
806+
}
807+
808+
let ndcg = if idcg > 0.0 { dcg / idcg } else { 0.0 };
809+
let precision_at_k =
810+
if retrieved.is_empty() { 0.0 } else { relevant_count as f64 / retrieved.len() as f64 };
811+
let recall_at_k =
812+
if expected_count == 0 { 0.0 } else { relevant_count as f64 / expected_count as f64 };
813+
814+
Metrics { recall_at_k, precision_at_k, rr, ndcg, relevant_count }
815+
}
816+
817+
fn compute_metrics_for_query(
818+
merged: &MergedQuery,
819+
retrieved_note_ids: &[Uuid],
820+
retrieved_keys: &[Option<String>],
821+
) -> (Metrics, usize) {
822+
match merged.expected_kind {
823+
ExpectedKind::NoteId => {
824+
let expected: HashSet<Uuid> = merged.expected_note_ids.iter().copied().collect();
825+
let expected_count = expected.len();
826+
827+
(compute_metrics(retrieved_note_ids, &expected), expected_count)
828+
},
829+
ExpectedKind::Key => {
830+
let expected: HashSet<String> = merged.expected_keys.iter().cloned().collect();
831+
let expected_count = expected.len();
832+
833+
(compute_metrics_for_keys(retrieved_keys, &expected), expected_count)
834+
},
835+
}
836+
}
837+
733838
fn summarize(reports: &[QueryReport], latencies_ms: &[f64]) -> EvalSummary {
734839
let count = reports.len().max(1) as f64;
735840
let avg_recall_at_k = reports.iter().map(|r| r.recall_at_k).sum::<f64>() / count;
736841
let avg_precision_at_k = reports.iter().map(|r| r.precision_at_k).sum::<f64>() / count;
737842
let mean_rr = reports.iter().map(|r| r.rr).sum::<f64>() / count;
738843
let mean_ndcg = reports.iter().map(|r| r.ndcg).sum::<f64>() / count;
844+
let avg_retrieved_summary_chars =
845+
reports.iter().map(|r| r.retrieved_summary_chars as f64).sum::<f64>() / count;
739846
let mut sorted = latencies_ms.to_vec();
740847

741848
sorted.sort_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
@@ -750,6 +857,7 @@ fn summarize(reports: &[QueryReport], latencies_ms: &[f64]) -> EvalSummary {
750857
mean_ndcg,
751858
latency_ms_p50: p50,
752859
latency_ms_p95: p95,
860+
avg_retrieved_summary_chars,
753861
stability: None,
754862
}
755863
}
@@ -1140,11 +1248,16 @@ async fn eval_config(
11401248

11411249
for (index, query) in dataset.queries.iter().enumerate() {
11421250
let merged = merge_query(&defaults, query, args, &service.cfg, index)?;
1143-
let expected: HashSet<Uuid> = merged.expected_note_ids.iter().copied().collect();
11441251
let (first, latency_ms, stability, trace_ids) =
1145-
run_query_n_times(&service, merged.request, runs_per_query).await?;
1146-
let retrieved = unique_ids(first.items.iter().map(|item| item.note_id));
1147-
let metrics = compute_metrics(&retrieved, &expected);
1252+
run_query_n_times(&service, merged.request.clone(), runs_per_query).await?;
1253+
let retrieved = unique_items(&first.items);
1254+
let retrieved_note_ids: Vec<Uuid> = retrieved.iter().map(|item| item.note_id).collect();
1255+
let retrieved_keys: Vec<Option<String>> =
1256+
retrieved.iter().map(|item| item.key.clone()).collect();
1257+
let retrieved_summary_chars =
1258+
retrieved.iter().map(|item| item.summary.len()).sum::<usize>();
1259+
let (metrics, expected_count) =
1260+
compute_metrics_for_query(&merged, &retrieved_note_ids, &retrieved_keys);
11481261

11491262
if let Some(s) = stability {
11501263
stability_positional.push(s.positional_churn_at_k);
@@ -1156,16 +1269,24 @@ async fn eval_config(
11561269
query: merged.query,
11571270
trace_id: first.trace_id,
11581271
trace_ids: (trace_ids.len() > 1).then_some(trace_ids),
1159-
expected_count: expected.len(),
1160-
retrieved_count: retrieved.len(),
1272+
expected_count,
1273+
retrieved_count: retrieved_note_ids.len(),
11611274
relevant_count: metrics.relevant_count,
11621275
recall_at_k: metrics.recall_at_k,
11631276
precision_at_k: metrics.precision_at_k,
11641277
rr: metrics.rr,
11651278
ndcg: metrics.ndcg,
11661279
latency_ms,
11671280
expected_note_ids: merged.expected_note_ids,
1168-
retrieved_note_ids: retrieved,
1281+
expected_keys: merged.expected_keys,
1282+
expected_kind: merged.expected_kind,
1283+
retrieved_note_ids,
1284+
retrieved_keys: if merged.expected_kind == ExpectedKind::Key {
1285+
retrieved_keys
1286+
} else {
1287+
Vec::new()
1288+
},
1289+
retrieved_summary_chars,
11691290
stability,
11701291
});
11711292
latencies_ms.push(latency_ms);
@@ -1217,7 +1338,7 @@ async fn run_query_n_times(
12171338
let k = request.top_k.unwrap_or(1).max(1) as usize;
12181339
let runs = runs_per_query.max(1);
12191340
let mut first_response: Option<SearchIndexResponse> = None;
1220-
let mut first_retrieved: Vec<Uuid> = Vec::new();
1341+
let mut first_retrieved_ids: Vec<Uuid> = Vec::new();
12211342
let mut trace_ids: Vec<Uuid> = Vec::with_capacity(runs as usize);
12221343
let mut latency_total_ms = 0.0_f64;
12231344
let mut positional_churn_sum = 0.0_f64;
@@ -1233,17 +1354,18 @@ async fn run_query_n_times(
12331354

12341355
trace_ids.push(response.trace_id);
12351356

1236-
let retrieved = unique_ids(response.items.iter().map(|item| item.note_id));
1357+
let retrieved = unique_items(&response.items);
1358+
let retrieved_ids = retrieved.iter().map(|item| item.note_id).collect::<Vec<_>>();
12371359

12381360
if run_idx == 0 {
1239-
first_retrieved = retrieved;
1361+
first_retrieved_ids = retrieved_ids;
12401362
first_response = Some(response);
12411363

12421364
continue;
12431365
}
12441366

12451367
let (positional_churn_at_k, set_churn_at_k) =
1246-
churn_against_baseline_at_k(&first_retrieved, &retrieved, k);
1368+
churn_against_baseline_at_k(&first_retrieved_ids, &retrieved_ids, k);
12471369

12481370
positional_churn_sum += positional_churn_at_k;
12491371
set_churn_sum += set_churn_at_k;
@@ -1271,7 +1393,50 @@ async fn run_query_n_times(
12711393

12721394
#[cfg(test)]
12731395
mod tests {
1274-
use crate::{OffsetDateTime, Uuid, retrieval_top_rank_retention};
1396+
use std::collections::HashSet;
1397+
1398+
use crate::{
1399+
ExpectedKind, OffsetDateTime, Uuid, compute_metrics_for_keys, resolve_expected_mode,
1400+
retrieval_top_rank_retention,
1401+
};
1402+
1403+
#[test]
fn resolve_expected_mode_requires_exactly_one_definition() {
    let ids = [Uuid::new_v4()];
    let keys = [String::from("key-1")];

    // A single populated list selects the matching mode.
    let only_ids = resolve_expected_mode(0, &ids, &[]).unwrap();
    let only_keys = resolve_expected_mode(0, &[], &keys).unwrap();

    assert_eq!(only_ids, ExpectedKind::NoteId);
    assert_eq!(only_keys, ExpectedKind::Key);

    // Zero or two populated lists are both rejected.
    assert!(
        resolve_expected_mode(0, &[], &[]).is_err(),
        "Expected missing expectations to be rejected"
    );
    assert!(
        resolve_expected_mode(0, &ids, &keys).is_err(),
        "Expected both expectation fields to be rejected"
    );
}
1418+
1419+
#[test]
fn compute_metrics_for_keys_counts_first_hit_per_unique_key_and_ignores_missing_keys() {
    let wanted = HashSet::from(["alpha", "beta", "gamma"].map(String::from));
    let ranked: Vec<Option<String>> =
        [None, Some("alpha"), Some("alpha"), Some("gamma"), Some("missing")]
            .into_iter()
            .map(|slot| slot.map(String::from))
            .collect();

    let metrics = compute_metrics_for_keys(&ranked, &wanted);

    // Hits land at ranks 2 ("alpha") and 4 ("gamma"); the ideal ordering fills ranks 1..=3.
    let gain = |rank: f64| 1.0 / (rank + 1.0).log2();
    let expected_dcg = gain(2.0) + gain(4.0);
    let expected_idcg = gain(1.0) + gain(2.0) + gain(3.0);
    let close = |actual: f64, wanted: f64| (actual - wanted).abs() < 1e-12;

    assert_eq!(metrics.relevant_count, 2);
    assert!(close(metrics.precision_at_k, 2.0 / 5.0));
    assert!(close(metrics.recall_at_k, 2.0 / 3.0));
    assert!(close(metrics.rr, 1.0 / 2.0));
    assert!(close(metrics.ndcg, expected_dcg / expected_idcg));
}
12751440

12761441
#[test]
12771442
fn retrieval_top_rank_retention_counts_unique_notes_and_retained_notes() {

0 commit comments

Comments
 (0)