feat(consolidation): parallelize llm_batch processing within a single op #1604
connorblack wants to merge 4 commits into
Replaces the serial `for llm_batch in llm_batches` loop in `run_consolidation_job` with bounded `asyncio.gather` using a `Semaphore` at `config.consolidation_llm_max_concurrent`. Each batch executes in its own `ConsolidationPerfLog` instance so timings/llm_calls/prompt_chars are not raced across concurrent gather participants; the per-task perf is merged into the parent's shared perf after gather completes.

Why this matters:
- `consolidation_llm_max_concurrent` (default 8) was unused for single-bank workloads: the previous serial loop bottlenecked on single-LLM-call latency. With gather + Semaphore we saturate up to N parallel LLM calls within one op for an N× speedup.
- The tag-group security boundary (memories with different tags never share an LLM call) is preserved unchanged.
- The adaptive split-on-failure protocol remains serial within each batch (correctness requirement of the split protocol).
- The cross-batch DB commit is now a single round-trip per fetch wave instead of one commit per batch (functionally equivalent, since IDs are disjoint by tag-grouping).

Cancellation granularity changes from per-batch to per-gather-wave; acceptable since waves are bounded by `consolidation_batch_size` memories (default 50) and complete in seconds.

New module surface:
- `ConsolidationPerfLog.merge()`: aggregate per-task perf into shared
- `_BatchExecutionResult` dataclass: self-contained per-batch result
- `_execute_one_llm_batch()`: extracted per-batch work, used as gather task
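The pattern described in this commit (a semaphore-bounded gather where each task owns its perf log and the parent merges afterwards) can be sketched roughly as follows. This is a minimal illustration, not the PR's code: `PerfLog`, `run_wave`, and `run_one` are hypothetical names, the field set is assumed, and `asyncio.sleep` stands in for the LLM call.

```python
import asyncio
from dataclasses import dataclass


@dataclass
class PerfLog:
    """Hypothetical stand-in for ConsolidationPerfLog (fields are assumptions)."""

    llm_calls: int = 0
    prompt_chars: int = 0

    def merge(self, other: "PerfLog") -> None:
        self.llm_calls += other.llm_calls
        self.prompt_chars += other.prompt_chars


async def run_wave(batches, max_concurrent):
    sem = asyncio.Semaphore(max(1, max_concurrent))
    in_flight = 0
    peak = 0

    async def run_one(batch):
        nonlocal in_flight, peak
        perf = PerfLog()  # private to this task, so nothing is shared mid-flight
        async with sem:  # at most max_concurrent tasks pass this point at once
            in_flight += 1
            peak = max(peak, in_flight)
            await asyncio.sleep(0.01)  # placeholder for the LLM call
            in_flight -= 1
        perf.llm_calls += 1
        perf.prompt_chars += sum(len(memory) for memory in batch)
        return perf

    per_task = await asyncio.gather(*(run_one(batch) for batch in batches))

    # Merge only after gather returns, so the parent log is written by one task.
    parent = PerfLog()
    for perf in per_task:
        parent.merge(perf)
    return parent, peak


wave = [[f"memory-{i}"] for i in range(8)]
parent_perf, peak_in_flight = asyncio.run(run_wave(wave, max_concurrent=3))
print(parent_perf, peak_in_flight)
```

Tracking `peak` here mirrors the slot-counter idea used to check that the concurrency cap is honored; the real job would derive the cap from configuration rather than a literal.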
Address reviewer feedback on the prior commit before opening upstream PR:
- `ConsolidationPerfLog.merge()` -> `__iadd__`: matches the codebase's `TokenUsage.__add__` accumulator idiom (used heavily in fact_extraction). Callers now write `perf += batch_result.perf`.
- Extract `_resolve_obs_tags_list()` so observation_scopes parsing happens once per llm_batch instead of once per sub_batch (all sub_batches share the same parent batch's tags by tag-grouping invariant).
- Extract `_apply_action_to_stats()` so the action-vocabulary mapping has one definition; per-batch counters and aggregate stats now come from a single pass over `batch_result.results` instead of two.
- Plumb `operation_id` through `_execute_one_llm_batch` with a per-sub-batch cancellation check via `_check_op_alive`. This restores the per-batch cancellation granularity that the prior commit traded for per-wave only.
- Tighten three docstrings (`merge`, `_BatchExecutionResult`, `_execute_one_llm_batch`) to contracts; drop refactor-narrating paragraphs.
- Inline `_resolve_obs_tags_list` and the shared-tags assignment at the top of `_execute_one_llm_batch`, replacing the per-sub-batch redundant computation and the for-memory tag-tracking loop (all memories in a batch share tags).
- Comment in `run_consolidation_job` notes the new semaphore stacks on top of the global LLM semaphore in llm_wrapper.py: effective concurrency is min(this, the global cap).
- Rename loop variable `b` -> `batch` and `br` -> `batch_result` for readability.
- Drop dead `denom = max(1, br.memories_count)` guard (memories_count is always >= 1 for non-empty batches by construction); use `or 1` inline.

No behavior changes intended. Smoke-tested end-to-end on a deployed image.
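The `__iadd__` accumulator idiom mentioned above might look like the sketch below. `PerfLogSketch` and its fields are assumptions made for illustration; the actual `ConsolidationPerfLog` and `TokenUsage` definitions are not shown in this conversation.

```python
from dataclasses import dataclass, field


@dataclass
class PerfLogSketch:
    """Hypothetical perf accumulator; the field names are assumptions."""

    llm_calls: int = 0
    prompt_chars: int = 0
    timings: dict[str, float] = field(default_factory=dict)

    def __iadd__(self, other: "PerfLogSketch") -> "PerfLogSketch":
        self.llm_calls += other.llm_calls
        self.prompt_chars += other.prompt_chars
        for key, seconds in other.timings.items():
            self.timings[key] = self.timings.get(key, 0.0) + seconds
        # `a += b` rebinds `a` to whatever __iadd__ returns, so return self.
        return self


perf = PerfLogSketch()
perf += PerfLogSketch(llm_calls=1, prompt_chars=120, timings={"llm": 1.5})
perf += PerfLogSketch(llm_calls=2, prompt_chars=80, timings={"llm": 2.0, "db": 0.25})
print(perf)
```

Mutating in place and returning `self` keeps the parent log's identity stable, which matters if other code holds a reference to it.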
Address remaining findings before opening upstream PR:
- Introduce `_ResultDelta` dataclass with `__iadd__`; `_record_result` returns `_ResultDelta` instead of an opaque dict. Caller now writes `batch_counters += _record_result(stats, result)` (matches the codebase's `TokenUsage.__add__` accumulator idiom).
- Track 'merged' in per-batch counters and per-batch log line. Pre-patch silently dropped merged actions from the per-batch log; the new helper is the right place to fix this.
- Extract `_merge_pass_result(existing, new) -> dict` from the dense inline block in `_execute_one_llm_batch`'s multi-pass loop. Centralizes the skipped-is-weak / non-skipped-combine-into-multiple action vocabulary that previously duplicated knowledge across two sites.
- Extract `_is_op_cancelled(memory_engine, operation_id) -> bool` predicate. Used at both the per-sub-batch check inside `_execute_one_llm_batch` and the end-of-wave checkpoint in `run_consolidation_job`; the implicit `operation_id is not None` short-circuit is now in one place.
- Replace `start_num = llm_batch_num` capture pattern with `enumerate(llm_batches, start=llm_batch_num + 1)` to match the codebase's idiom (memory_engine.py:3234, search/tracer.py:325, etc.).
- Drop dead `(memories_count or 1)` defensive guard. memories_count is always >= 1 by tag-group construction (range slice over non-empty group).
- Rename `_apply_action_to_stats` -> `_record_result`. Old name suggested one-way write; new name covers the mutate-and-return shape clearly.
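One possible reading of the "skipped-is-weak / non-skipped-combine-into-multiple" rule is sketched below. The result shape (a dict with an `action` key and an optional `parts` list) and the handling of a missing `existing` value are assumptions for illustration; the real `_merge_pass_result` may differ.

```python
from typing import Any, Optional

Result = dict[str, Any]


def merge_pass_result(existing: Optional[Result], new: Result) -> Result:
    """Combine the results of two passes over the same batch (hypothetical shape)."""
    # "skipped" is weak: any other outcome replaces it.
    if existing is None or existing["action"] == "skipped":
        return new
    if new["action"] == "skipped":
        return existing
    # Two non-skipped outcomes are combined under a "multiple" action.
    parts = existing.get("parts", [existing]) + new.get("parts", [new])
    return {"action": "multiple", "parts": parts}


created = {"action": "created", "id": "obs-1"}
updated = {"action": "updated", "id": "obs-2"}
skipped = {"action": "skipped"}

combined = merge_pass_result(merge_pass_result(skipped, created), updated)
print(combined["action"], len(combined["parts"]))
```

Keeping the rule in one pure function is what makes it testable in isolation, without mocking the LLM or the database.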
Pull request overview
This PR parallelizes consolidation LLM-batch execution within each DB fetch “wave” by replacing a serial per-batch loop with bounded asyncio.gather() concurrency, while refactoring the batch execution into an async helper and making per-batch performance/log aggregation safe under parallelism.
Changes:
- Execute `llm_batches` concurrently with a per-wave `asyncio.Semaphore` bounded by `config.consolidation_llm_max_concurrent`.
- Extract per-batch execution into `_execute_one_llm_batch()` and return a `_BatchExecutionResult` for parent aggregation.
- Centralize stats/log accounting via `_record_result()` / `_ResultDelta` and merge per-task perf into the parent via `ConsolidationPerfLog.__iadd__`.
Address Copilot review on vectorize-io#1604:

1. Coalesce None on consolidation_llm_max_concurrent
   The field is Optional[int] in HindsightConfig and only set when HINDSIGHT_API_CONSOLIDATION_LLM_MAX_CONCURRENT is in env. Without the coalesce, max(1, None) raises TypeError on first wave. Fall back to the global llm_max_concurrent cap (which always has a default).

2. Use return_exceptions=True on the gather wave
   Without it, gather propagates the first exception to the caller immediately and the results of the other batches are never collected, so the post-wave UPDATE would be skipped, leaving observation rows already inserted by successful batches without their consolidated_at marker. Those memories would be re-consolidated on the next run, producing duplicate observations. Now: partition gather results into successes vs the first exception, apply DB markers for successful batches, then re-raise so the worker poller's standard exception handling kicks in. Additional exceptions in the same wave are logged via exc_info.

3. Revert obs_tags caching in _execute_one_llm_batch
   Round-2 hoisted observation_scopes parsing out of the while-pending loop on the assumption tag_groups upstream guaranteed uniform scopes per batch. tag_groups actually only keys on tags, so adaptive split sub_batches could legitimately have different observation_scopes than the parent llm_batch. Parse per sub_batch to preserve scope semantics.

4. Add tests/test_consolidation_parallelism.py
   - test_consolidation_honors_max_concurrent: pin max=4, ingest 8 memories with distinct tag groups, mock _process_memory_batch with a slot counter; assert 1 < max_in_flight <= 4.
   - test_partial_failure_preserves_succeeded_markers: ingest 6 memories, mock raises on the 2nd call; assert 5 succeeded markers present, 1 absent, exception re-raised.
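Items 1 and 2 can be illustrated with a small sketch, assuming a fake batch runner in place of the real LLM and database work. All names here (`resolve_max_concurrent`, `consolidate_wave`, the `marked` list standing in for the `consolidated_at` UPDATE) are hypothetical.

```python
import asyncio
import logging
from typing import Optional

logger = logging.getLogger("consolidation_sketch")


def resolve_max_concurrent(specific: Optional[int], global_cap: int) -> int:
    # Coalesce None before max(): max(1, None) would raise TypeError.
    return max(1, specific if specific is not None else global_cap)


async def consolidate_wave(batch_ids, fail_on, marked):
    async def run_one(batch_id):
        await asyncio.sleep(0)  # placeholder for the per-batch LLM work
        if batch_id in fail_on:
            raise RuntimeError(f"batch {batch_id} failed")
        return batch_id

    results = await asyncio.gather(
        *(run_one(batch_id) for batch_id in batch_ids),
        return_exceptions=True,  # exceptions come back as values, in input order
    )
    succeeded = [r for r in results if not isinstance(r, BaseException)]
    errors = [r for r in results if isinstance(r, BaseException)]

    marked.extend(succeeded)  # stand-in for the post-wave marker UPDATE
    for extra in errors[1:]:
        logger.error("additional batch failure in wave", exc_info=extra)
    if errors:
        raise errors[0]  # re-raise only after successful batches are marked


marked: list[int] = []
raised: Optional[BaseException] = None
try:
    asyncio.run(consolidate_wave(range(6), fail_on={1}, marked=marked))
except RuntimeError as exc:
    raised = exc
print(marked, raised)
```

The ordering is the point of the design: markers for successful batches are applied first, and the failure is surfaced afterwards, so a partial failure does not cause already-consolidated memories to be processed again.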
nicoloboschi left a comment
Must fix: same-tag batches must NOT run in parallel
The gather parallelizes all llm_batches, but multiple batches can come from the same tag group (when a group has more memories than llm_batch_size):
    for group in tag_groups.values():
        for i in range(0, len(group), llm_batch_size):
            llm_batches.append(group[i : i + llm_batch_size])  # multiple batches, same tags

When two same-tag batches run concurrently, both read the observation network at the same time, so batch 2 never sees observations created/updated by batch 1. The LLM decides "create new observation" in both batches for what should have been a single consolidated observation. Memories that should be consolidated together won't be, because each batch operates on a stale snapshot of the observation graph.
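The stale-snapshot problem can be reproduced in miniature. In this sketch an in-memory list stands in for the observation network and `asyncio.sleep` for the LLM call; `consolidate_batch` and `run` are hypothetical names, and the create/update decision is reduced to a membership check.

```python
import asyncio


async def consolidate_batch(batch, observations):
    snapshot = set(observations)  # read the observation network up front
    await asyncio.sleep(0.01)  # placeholder for the LLM call
    for topic in batch:
        if topic not in snapshot:
            observations.append(topic)  # "create new observation"
        # otherwise: "update existing observation", nothing new is appended


async def run(parallel):
    observations = []
    batch_1, batch_2 = ["alice-job"], ["alice-job"]  # same tag group, same topic
    if parallel:
        await asyncio.gather(
            consolidate_batch(batch_1, observations),
            consolidate_batch(batch_2, observations),
        )
    else:
        await consolidate_batch(batch_1, observations)
        await consolidate_batch(batch_2, observations)
    return observations


parallel_result = asyncio.run(run(parallel=True))
serial_result = asyncio.run(run(parallel=False))
print(parallel_result)  # both batches saw an empty snapshot and created
print(serial_result)  # batch 2 saw batch 1's write
```

In the parallel run both batches take their snapshot before either writes, which is exactly the duplicate-observation outcome described above.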
Suggested fix
Parallelize across tag groups (disjoint observation scopes → safe), but keep batches within a tag group serial (shared observation scope → must see prior writes):
    async def _process_tag_group(group_batches, sem, ...):
        results = []
        async with sem:  # bound concurrency across all tag groups
            for batch in group_batches:
                results.append(await _execute_one_llm_batch(batch, ...))
        return results

    # Build per-tag-group batch lists instead of flattening
    per_tag_group: list[list[...]] = []
    for group in tag_groups.values():
        per_tag_group.append([group[i : i + llm_batch_size] for i in range(0, len(group), llm_batch_size)])

    gather_results = await asyncio.gather(
        *(_process_tag_group(batches, sem, ...) for batches in per_tag_group),
        return_exceptions=True,
    )

Test gap
test_consolidation_honors_max_concurrent uses distinct tags per memory (person0, person1, ...) so it can't catch this. A regression test should ingest N memories with the same tag, set llm_batch_size small enough to split them, and assert that batch 2's observation recall sees batch 1's writes (i.e., observations are consolidated together, not duplicated).
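A regression check along those lines might be shaped like the sketch below, which runs tag groups in parallel but keeps batches within a group serial, then records how many observations each batch could see when it started. The in-memory `store`, the `log`, and all function names are assumptions; a real test would go through the consolidation job and the database instead.

```python
import asyncio
from collections import defaultdict


async def execute_batch(tag, batch, store, log):
    visible = len(store[tag])  # observations visible when this batch starts
    await asyncio.sleep(0.01)  # placeholder for the LLM call
    store[tag].extend(batch)  # write this batch's observations
    log.append((tag, visible))


async def process_tag_group(tag, batches, sem, store, log):
    async with sem:  # bounds how many tag groups run at once
        for batch in batches:  # serial within a tag group
            await execute_batch(tag, batch, store, log)


async def run(tag_groups, llm_batch_size, max_concurrent):
    sem = asyncio.Semaphore(max_concurrent)
    store = defaultdict(list)
    log = []
    per_group = {
        tag: [group[i : i + llm_batch_size] for i in range(0, len(group), llm_batch_size)]
        for tag, group in tag_groups.items()
    }
    await asyncio.gather(
        *(process_tag_group(tag, batches, sem, store, log) for tag, batches in per_group.items())
    )
    return log


groups = {"person0": ["m1", "m2", "m3", "m4"], "person1": ["m5", "m6"]}
log = asyncio.run(run(groups, llm_batch_size=2, max_concurrent=4))
seen_by_person0 = [visible for tag, visible in log if tag == "person0"]
print(seen_by_person0)
```

The second `person0` batch starting with two visible observations is the property the regression test needs to pin: it fails if same-tag batches are ever flattened back into one parallel wave.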
Summary
Replace the serial `for llm_batch in llm_batches` loop in `run_consolidation_job` with bounded `asyncio.gather` using a `Semaphore` at `config.consolidation_llm_max_concurrent`. The env knob has existed for a while (default 8) but was unused for single-bank workloads: the previous serial loop bottlenecked on single-LLM-call latency regardless of available capacity.

Why this matters
In production we observed `consolidation_llm_max_concurrent=8` configured but the consolidator only ever made one in-flight LLM call per op. With Qwen3.6-35B-A3B-FP8 on a single GB10 GPU, latency per consolidation call is 5-10s once `chat_template_kwargs.enable_thinking=false` is set; with parallelism unlocked, the same wave processes 8 memories in roughly the same wall time as 1.

Verified on a 14k-memory bank: same-millisecond timestamp on the per-batch log lines for all batches in a wave (proves `gather` returned them as a wave) with no double-create or skipped failures across hundreds of batches.

Correctness invariants preserved
- The adaptive-split `pending` queue (halve sub-batch on LLM failure, mark failed only at size=1) remains serial within each batch; that ordering is required by the split protocol.
- The `LIMIT $2` query produces a fixed set per fetch round; tag-group slicing produces non-overlapping subsets. No two batches in a gather wave can race on the same `memory_id`.
- The DB commit is a single `executemany` per fetch round (functionally equivalent to the previous per-batch commits, since IDs are disjoint).

Module surface added
- `ConsolidationPerfLog.__iadd__`: accumulator (matches `TokenUsage.__add__` codebase idiom).
- `_BatchExecutionResult` dataclass: self-contained per-batch outputs for parent aggregation.
- `_record_result(stats, result) -> _ResultDelta`: centralizes the action-vocabulary mapping (created/updated/merged/multiple/skipped/failed). Side benefit: pre-patch silently dropped `merged` actions from per-batch log lines; the new helper tracks them.
- `_merge_pass_result(existing, new) -> dict`: extracted from the dense inline block in the multi-pass loop; testable in isolation.
- `_resolve_obs_tags_list(memory_tags, scope_spec)`: translates `observation_scopes` spec into concrete tag-set passes; called once per llm_batch instead of once per sub_batch.
- `_is_op_cancelled(memory_engine, operation_id) -> bool`: predicate used at both the per-sub-batch check inside `_execute_one_llm_batch` and the end-of-wave checkpoint in `run_consolidation_job`.
- `_execute_one_llm_batch(...)`: async helper extracted from the for-loop body. Accepts `operation_id` so per-sub-batch cancellation polling restores per-batch cancellation granularity (the gather wave checkpoint alone would have lengthened cancellation latency from ~5s to ~45s).

Per-batch logging
Each batch executes against its own `ConsolidationPerfLog` instance so timings/llm_calls/prompt_chars are not raced across concurrent gather participants. Per-task perf is merged into the parent via `+=` after gather completes. The pre-patch snap-delta-style logging (capture-before, log-after) was incorrect under parallel execution; the new shape carries per-batch perf in the result dataclass and emits log lines deterministically ordered by `batch_num`.

Throughput notes
Empirical: ~16× speedup of consolidation throughput when combined with thinking-off (~5s/call vs ~80s/call). At a fixed concurrency of 8, the 8 in-flight LLM calls per wave saturate the GPU; cranking to 16 increases per-call latency due to KV-cache contention without proportional throughput gain (this is GPU-bound, not Python-bound). Conservative production setting: leave `consolidation_llm_max_concurrent` at the default of 8.

Test plan
- `hindsight-api-slim/tests/test_consolidation*.py` requires `HINDSIGHT_API_LLM_API_KEY` for testcontainers setup; was not exercised against this fork. CI's `core` lane (paths-filter: `hindsight-api-slim/**`) will run the suite end-to-end.
- No `consolidation_failed_at` regressions across 1000+ batches.
- `pre-commit` hooks (`ruff check --fix`, `ruff format`, `ty check`) all green per `scripts/hooks/lint.sh`.

Notes for reviewers
Three commits on the branch — visible iteration via two rounds of internal review (reuse / quality / efficiency). Squash on merge is fine if that's the project preference; the final state is what matters.
Cancellation latency under parallelism: per-sub-batch via `operation_id` (~one DB roundtrip per sub_batch) plus end-of-wave checkpoint. With wave size ≤ `consolidation_batch_size` (default 50) and waves completing in seconds with thinking off, worst-case cancellation latency is bounded by the slowest LLM call in the active wave.

Related follow-ups (not in this PR)