From a1541f3bf165ab91adc479c6d35b6e1d4226dfe8 Mon Sep 17 00:00:00 2001 From: phernandez Date: Mon, 6 Apr 2026 19:56:26 -0500 Subject: [PATCH 01/10] perf(core): batch postgres vector sync preparation Signed-off-by: phernandez --- .../repository/postgres_search_repository.py | 428 +++++++++++++++++- .../cloud/test_cloud_api_client_and_utils.py | 2 - .../test_postgres_search_repository_unit.py | 55 ++- 3 files changed, 454 insertions(+), 31 deletions(-) diff --git a/src/basic_memory/repository/postgres_search_repository.py b/src/basic_memory/repository/postgres_search_repository.py index 27d253f7d..925c975f5 100644 --- a/src/basic_memory/repository/postgres_search_repository.py +++ b/src/basic_memory/repository/postgres_search_repository.py @@ -3,8 +3,9 @@ import asyncio import json import re +import time from datetime import datetime -from typing import List, Optional +from typing import List, Optional, cast from loguru import logger from sqlalchemy import text @@ -15,12 +16,21 @@ from basic_memory.repository.embedding_provider import EmbeddingProvider from basic_memory.repository.embedding_provider_factory import create_embedding_provider from basic_memory.repository.search_index_row import SearchIndexRow -from basic_memory.repository.search_repository_base import SearchRepositoryBase +from basic_memory.repository.search_repository_base import ( + SearchRepositoryBase, + VectorSyncBatchResult, + _EntitySyncRuntime, + _PendingEmbeddingJob, + _PreparedEntityVectorSync, +) from basic_memory.repository.metadata_filters import parse_metadata_filters from basic_memory.repository.semantic_errors import SemanticDependenciesMissingError from basic_memory.schemas.search import SearchItemType, SearchRetrievalMode +POSTGRES_VECTOR_PREPARE_CONCURRENCY = 4 + + def _strip_nul_from_row(row_data: dict) -> dict: """Strip NUL bytes from all string values in a row dict. @@ -441,34 +451,406 @@ async def _run_vector_query( ) return [dict(row) for row in vector_result.mappings().all()] - async def _write_embeddings( + async def sync_entity_vectors_batch( self, - session: AsyncSession, - jobs: list[tuple[int, str]], - embeddings: list[list[float]], - ) -> None: - for (row_id, _), vector in zip(jobs, embeddings, strict=True): - vector_literal = self._format_pgvector_literal(vector) - await session.execute( + entity_ids: list[int], + progress_callback=None, + ) -> VectorSyncBatchResult: + """Sync semantic vectors with concurrent Postgres preparation windows. + + Trigger: cloud indexing uses Neon Postgres where network latency dominates + thousands of per-entity prepare queries. + Why: preparing a small window of entities concurrently hides round-trip latency + without exhausting the tenant connection pool. + Outcome: Postgres vector sync keeps the existing flush semantics while reducing + wall-clock time on large cloud projects. + """ + self._assert_semantic_available() + await self._ensure_vector_tables() + assert self._embedding_provider is not None + + total_entities = len(entity_ids) + result = VectorSyncBatchResult( + entities_total=total_entities, + entities_synced=0, + entities_failed=0, + ) + if total_entities == 0: + return result + + logger.info( + "Vector batch sync start: project_id={project_id} entities_total={entities_total} " + "sync_batch_size={sync_batch_size} prepare_concurrency={prepare_concurrency}", + project_id=self.project_id, + entities_total=total_entities, + sync_batch_size=self._semantic_embedding_sync_batch_size, + prepare_concurrency=POSTGRES_VECTOR_PREPARE_CONCURRENCY, + ) + + pending_jobs: list[_PendingEmbeddingJob] = [] + entity_runtime: dict[int, _EntitySyncRuntime] = {} + failed_entity_ids: set[int] = set() + synced_entity_ids: set[int] = set() + + for window_start in range(0, total_entities, POSTGRES_VECTOR_PREPARE_CONCURRENCY): + window_entity_ids = entity_ids[ + window_start : window_start + POSTGRES_VECTOR_PREPARE_CONCURRENCY + ] + + if progress_callback is not None: + for offset, entity_id in enumerate(window_entity_ids, start=window_start): + progress_callback(entity_id, offset, total_entities) + + prepared_window = await asyncio.gather( + *(self._prepare_entity_vector_jobs(entity_id) for entity_id in window_entity_ids), + return_exceptions=True, + ) + + for entity_id, prepared in zip(window_entity_ids, prepared_window, strict=True): + if isinstance(prepared, BaseException): + failed_entity_ids.add(entity_id) + logger.warning( + "Vector batch sync entity prepare failed: project_id={project_id} " + "entity_id={entity_id} error={error}", + project_id=self.project_id, + entity_id=entity_id, + error=str(prepared), + ) + continue + + prepared_sync = cast(_PreparedEntityVectorSync, prepared) + + embedding_jobs_count = len(prepared_sync.embedding_jobs) + result.embedding_jobs_total += embedding_jobs_count + + if embedding_jobs_count == 0: + synced_entity_ids.add(entity_id) + total_seconds = time.perf_counter() - prepared_sync.sync_start + self._log_vector_sync_complete( + entity_id=entity_id, + total_seconds=total_seconds, + embed_seconds=0.0, + write_seconds=0.0, + source_rows_count=prepared_sync.source_rows_count, + embedding_jobs_count=0, + ) + continue + + entity_runtime[entity_id] = _EntitySyncRuntime( + sync_start=prepared_sync.sync_start, + source_rows_count=prepared_sync.source_rows_count, + embedding_jobs_count=embedding_jobs_count, + remaining_jobs=embedding_jobs_count, + ) + pending_jobs.extend( + _PendingEmbeddingJob( + entity_id=entity_id, + chunk_row_id=row_id, + chunk_text=chunk_text, + ) + for row_id, chunk_text in prepared_sync.embedding_jobs + ) + + while len(pending_jobs) >= self._semantic_embedding_sync_batch_size: + flush_jobs = pending_jobs[: self._semantic_embedding_sync_batch_size] + pending_jobs = pending_jobs[self._semantic_embedding_sync_batch_size :] + try: + embed_seconds, write_seconds = await self._flush_embedding_jobs( + flush_jobs=flush_jobs, + entity_runtime=entity_runtime, + synced_entity_ids=synced_entity_ids, + ) + result.embed_seconds_total += embed_seconds + result.write_seconds_total += write_seconds + except Exception as exc: + affected_entity_ids = sorted({job.entity_id for job in flush_jobs}) + failed_entity_ids.update(affected_entity_ids) + for failed_entity_id in affected_entity_ids: + entity_runtime.pop(failed_entity_id, None) + logger.warning( + "Vector batch sync flush failed: project_id={project_id} " + "affected_entities={affected_entities} chunk_count={chunk_count} " + "error={error}", + project_id=self.project_id, + affected_entities=affected_entity_ids, + chunk_count=len(flush_jobs), + error=str(exc), + ) + + if pending_jobs: + flush_jobs = list(pending_jobs) + pending_jobs = [] + try: + embed_seconds, write_seconds = await self._flush_embedding_jobs( + flush_jobs=flush_jobs, + entity_runtime=entity_runtime, + synced_entity_ids=synced_entity_ids, + ) + result.embed_seconds_total += embed_seconds + result.write_seconds_total += write_seconds + except Exception as exc: + affected_entity_ids = sorted({job.entity_id for job in flush_jobs}) + failed_entity_ids.update(affected_entity_ids) + for failed_entity_id in affected_entity_ids: + entity_runtime.pop(failed_entity_id, None) + logger.warning( + "Vector batch sync final flush failed: project_id={project_id} " + "affected_entities={affected_entities} chunk_count={chunk_count} error={error}", + project_id=self.project_id, + affected_entities=affected_entity_ids, + chunk_count=len(flush_jobs), + error=str(exc), + ) + + if entity_runtime: + orphan_runtime_entities = sorted(entity_runtime.keys()) + failed_entity_ids.update(orphan_runtime_entities) + logger.warning( + "Vector batch sync left unfinished entities after flushes: " + "project_id={project_id} unfinished_entities={unfinished_entities}", + project_id=self.project_id, + unfinished_entities=orphan_runtime_entities, + ) + + synced_entity_ids.difference_update(failed_entity_ids) + result.failed_entity_ids = sorted(failed_entity_ids) + result.entities_failed = len(result.failed_entity_ids) + result.entities_synced = len(synced_entity_ids) + + logger.info( + "Vector batch sync complete: project_id={project_id} entities_total={entities_total} " + "entities_synced={entities_synced} entities_failed={entities_failed} " + "embedding_jobs_total={embedding_jobs_total} embed_seconds_total={embed_seconds_total:.3f} " + "write_seconds_total={write_seconds_total:.3f}", + project_id=self.project_id, + entities_total=result.entities_total, + entities_synced=result.entities_synced, + entities_failed=result.entities_failed, + embedding_jobs_total=result.embedding_jobs_total, + embed_seconds_total=result.embed_seconds_total, + write_seconds_total=result.write_seconds_total, + ) + + return result + + async def _prepare_entity_vector_jobs(self, entity_id: int) -> _PreparedEntityVectorSync: + """Prepare chunk mutations with Postgres-specific bulk upserts.""" + sync_start = time.perf_counter() + + logger.info( + "Vector sync start: project_id={project_id} entity_id={entity_id}", + project_id=self.project_id, + entity_id=entity_id, + ) + + async with db.scoped_session(self.session_maker) as session: + await self._prepare_vector_session(session) + + row_result = await session.execute( text( - "INSERT INTO search_vector_embeddings (" - "chunk_id, project_id, embedding, embedding_dims, updated_at" - ") VALUES (" - ":chunk_id, :project_id, CAST(:embedding AS vector), :embedding_dims, NOW()" - ") " - "ON CONFLICT (chunk_id) DO UPDATE SET " - "project_id = EXCLUDED.project_id, " - "embedding = EXCLUDED.embedding, " - "embedding_dims = EXCLUDED.embedding_dims, " - "updated_at = NOW()" + "SELECT id, type, title, permalink, content_stems, content_snippet, " + "category, relation_type " + "FROM search_index " + "WHERE entity_id = :entity_id AND project_id = :project_id " + "ORDER BY " + "CASE type " + "WHEN :entity_type THEN 0 " + "WHEN :observation_type THEN 1 " + "WHEN :relation_type_type THEN 2 " + "ELSE 3 END, id ASC" ), { - "chunk_id": row_id, + "entity_id": entity_id, "project_id": self.project_id, - "embedding": vector_literal, - "embedding_dims": len(vector), + "entity_type": SearchItemType.ENTITY.value, + "observation_type": SearchItemType.OBSERVATION.value, + "relation_type_type": SearchItemType.RELATION.value, }, ) + rows = row_result.fetchall() + source_rows_count = len(rows) + + if not rows: + logger.info( + "Vector sync source prepared: project_id={project_id} entity_id={entity_id} " + "source_rows_count={source_rows_count} built_chunk_records_count=0", + project_id=self.project_id, + entity_id=entity_id, + source_rows_count=source_rows_count, + ) + await self._delete_entity_chunks(session, entity_id) + return _PreparedEntityVectorSync( + entity_id=entity_id, + sync_start=sync_start, + source_rows_count=source_rows_count, + embedding_jobs=[], + ) + + chunk_records = self._build_chunk_records(rows) + built_chunk_records_count = len(chunk_records) + logger.info( + "Vector sync source prepared: project_id={project_id} entity_id={entity_id} " + "source_rows_count={source_rows_count} " + "built_chunk_records_count={built_chunk_records_count}", + project_id=self.project_id, + entity_id=entity_id, + source_rows_count=source_rows_count, + built_chunk_records_count=built_chunk_records_count, + ) + if not chunk_records: + await self._delete_entity_chunks(session, entity_id) + return _PreparedEntityVectorSync( + entity_id=entity_id, + sync_start=sync_start, + source_rows_count=source_rows_count, + embedding_jobs=[], + ) + + existing_rows_result = await session.execute( + text( + "SELECT c.id, c.chunk_key, c.source_hash, " + "(e.chunk_id IS NOT NULL) AS has_embedding " + "FROM search_vector_chunks c " + "LEFT JOIN search_vector_embeddings e ON e.chunk_id = c.id " + "WHERE c.project_id = :project_id AND c.entity_id = :entity_id" + ), + {"project_id": self.project_id, "entity_id": entity_id}, + ) + existing_rows = existing_rows_result.mappings().all() + existing_by_key = {str(row["chunk_key"]): row for row in existing_rows} + existing_chunks_count = len(existing_by_key) + incoming_by_key = {record["chunk_key"]: record for record in chunk_records} + + stale_ids = [ + int(row["id"]) + for chunk_key, row in existing_by_key.items() + if chunk_key not in incoming_by_key + ] + stale_chunks_count = len(stale_ids) + if stale_ids: + await self._delete_stale_chunks(session, stale_ids, entity_id) + + orphan_ids = {int(row["id"]) for row in existing_rows if not bool(row["has_embedding"])} + orphan_chunks_count = len(orphan_ids) + + upsert_records: list[dict[str, str]] = [] + embedding_jobs: list[tuple[int, str]] = [] + + for record in chunk_records: + current = existing_by_key.get(record["chunk_key"]) + if current is None: + upsert_records.append(record) + continue + + row_id = int(current["id"]) + is_orphan = row_id in orphan_ids + if current["source_hash"] == record["source_hash"]: + if is_orphan: + embedding_jobs.append((row_id, record["chunk_text"])) + continue + + upsert_records.append(record) + + if upsert_records: + upsert_params: dict[str, object] = { + "project_id": self.project_id, + "entity_id": entity_id, + } + upsert_values: list[str] = [] + for index, record in enumerate(upsert_records): + upsert_params[f"chunk_key_{index}"] = record["chunk_key"] + upsert_params[f"chunk_text_{index}"] = record["chunk_text"] + upsert_params[f"source_hash_{index}"] = record["source_hash"] + upsert_values.append( + "(" + ":entity_id, :project_id, " + f":chunk_key_{index}, :chunk_text_{index}, :source_hash_{index}, NOW()" + ")" + ) + + upsert_result = await session.execute( + text(f""" + INSERT INTO search_vector_chunks ( + entity_id, + project_id, + chunk_key, + chunk_text, + source_hash, + updated_at + ) VALUES {", ".join(upsert_values)} + ON CONFLICT (project_id, entity_id, chunk_key) DO UPDATE SET + chunk_text = EXCLUDED.chunk_text, + source_hash = EXCLUDED.source_hash, + updated_at = NOW() + RETURNING id, chunk_key + """), + upsert_params, + ) + upserted_ids_by_key = { + str(row["chunk_key"]): int(row["id"]) for row in upsert_result.mappings().all() + } + for record in upsert_records: + row_id = upserted_ids_by_key[record["chunk_key"]] + embedding_jobs.append((row_id, record["chunk_text"])) + + logger.info( + "Vector sync diff complete: project_id={project_id} entity_id={entity_id} " + "existing_chunks_count={existing_chunks_count} " + "stale_chunks_count={stale_chunks_count} " + "orphan_chunks_count={orphan_chunks_count} " + "embedding_jobs_count={embedding_jobs_count}", + project_id=self.project_id, + entity_id=entity_id, + existing_chunks_count=existing_chunks_count, + stale_chunks_count=stale_chunks_count, + orphan_chunks_count=orphan_chunks_count, + embedding_jobs_count=len(embedding_jobs), + ) + + return _PreparedEntityVectorSync( + entity_id=entity_id, + sync_start=sync_start, + source_rows_count=source_rows_count, + embedding_jobs=embedding_jobs, + ) + + async def _write_embeddings( + self, + session: AsyncSession, + jobs: list[tuple[int, str]], + embeddings: list[list[float]], + ) -> None: + params: dict[str, object] = {"project_id": self.project_id} + value_rows: list[str] = [] + + for index, ((row_id, _), vector) in enumerate(zip(jobs, embeddings, strict=True)): + params[f"chunk_id_{index}"] = row_id + params[f"embedding_{index}"] = self._format_pgvector_literal(vector) + params[f"embedding_dims_{index}"] = len(vector) + value_rows.append( + "(" + f":chunk_id_{index}, :project_id, CAST(:embedding_{index} AS vector), " + f":embedding_dims_{index}, NOW()" + ")" + ) + + await session.execute( + text(f""" + INSERT INTO search_vector_embeddings ( + chunk_id, + project_id, + embedding, + embedding_dims, + updated_at + ) VALUES {", ".join(value_rows)} + ON CONFLICT (chunk_id) DO UPDATE SET + project_id = EXCLUDED.project_id, + embedding = EXCLUDED.embedding, + embedding_dims = EXCLUDED.embedding_dims, + updated_at = NOW() + """), + params, + ) async def _delete_entity_chunks( self, diff --git a/tests/cli/cloud/test_cloud_api_client_and_utils.py b/tests/cli/cloud/test_cloud_api_client_and_utils.py index c08a2b2a8..b17da47db 100644 --- a/tests/cli/cloud/test_cloud_api_client_and_utils.py +++ b/tests/cli/cloud/test_cloud_api_client_and_utils.py @@ -270,8 +270,6 @@ async def api_request(**_kwargs): await project_exists("alpha", api_request=api_request) - - @pytest.mark.asyncio async def test_make_api_request_prefers_api_key_over_oauth(config_home, config_manager): """API key in config should be used without needing an OAuth token on disk.""" diff --git a/tests/repository/test_postgres_search_repository_unit.py b/tests/repository/test_postgres_search_repository_unit.py index e7aa896ee..9a16adb59 100644 --- a/tests/repository/test_postgres_search_repository_unit.py +++ b/tests/repository/test_postgres_search_repository_unit.py @@ -5,12 +5,14 @@ are difficult to reach in integration tests. """ +import asyncio from unittest.mock import AsyncMock, MagicMock, patch import pytest from basic_memory.config import BasicMemoryConfig, DatabaseBackend from basic_memory.repository.postgres_search_repository import PostgresSearchRepository +from basic_memory.repository.search_repository_base import _PreparedEntityVectorSync from basic_memory.repository.semantic_errors import ( SemanticDependenciesMissingError, SemanticSearchDisabledError, @@ -229,14 +231,55 @@ class TestWriteEmbeddings: """Cover _write_embeddings upsert logic.""" @pytest.mark.asyncio - async def test_write_embeddings_executes_per_job(self): + async def test_write_embeddings_executes_single_bulk_upsert(self): repo = _make_repo() session = AsyncMock() jobs = [(100, "chunk text A"), (200, "chunk text B")] embeddings = [[0.1, 0.2, 0.3, 0.4], [0.5, 0.6, 0.7, 0.8]] await repo._write_embeddings(session, jobs, embeddings) - assert session.execute.call_count == 2 - first_params = session.execute.call_args_list[0][0][1] - assert first_params["chunk_id"] == 100 - assert first_params["project_id"] == repo.project_id - assert first_params["embedding_dims"] == 4 + assert session.execute.call_count == 1 + params = session.execute.call_args[0][1] + assert params["chunk_id_0"] == 100 + assert params["chunk_id_1"] == 200 + assert params["project_id"] == repo.project_id + assert params["embedding_dims_0"] == 4 + assert params["embedding_dims_1"] == 4 + + +class TestBatchPrepareConcurrency: + """Cover the Postgres-specific concurrent prepare window.""" + + @pytest.mark.asyncio + async def test_sync_entity_vectors_batch_prepares_entities_concurrently(self, monkeypatch): + repo = _make_repo( + semantic_enabled=True, + embedding_provider=StubEmbeddingProvider(), + ) + repo._semantic_embedding_sync_batch_size = 8 + repo._vector_tables_initialized = True + + active_prepares = 0 + max_active_prepares = 0 + + async def _stub_prepare(entity_id: int) -> _PreparedEntityVectorSync: + nonlocal active_prepares, max_active_prepares + active_prepares += 1 + max_active_prepares = max(max_active_prepares, active_prepares) + await asyncio.sleep(0) + active_prepares -= 1 + return _PreparedEntityVectorSync( + entity_id=entity_id, + sync_start=float(entity_id), + source_rows_count=1, + embedding_jobs=[], + ) + + monkeypatch.setattr(repo, "_ensure_vector_tables", AsyncMock()) + monkeypatch.setattr(repo, "_prepare_entity_vector_jobs", _stub_prepare) + + result = await repo.sync_entity_vectors_batch([1, 2, 3, 4]) + + assert result.entities_total == 4 + assert result.entities_synced == 4 + assert result.entities_failed == 0 + assert max_active_prepares > 1 From 901eb009febf44432fbec3f036d745d09bab6a33 Mon Sep 17 00:00:00 2001 From: phernandez Date: Tue, 7 Apr 2026 11:16:31 -0500 Subject: [PATCH 02/10] update claude settings Signed-off-by: phernandez --- .claude/settings.json | 21 ++++++++++++++++++++- 1 file changed, 20 insertions(+), 1 deletion(-) diff --git a/.claude/settings.json b/.claude/settings.json index c72c6b731..54c66f76b 100644 --- a/.claude/settings.json +++ b/.claude/settings.json @@ -1,3 +1,22 @@ { - "enabledPlugins": {} + "$schema": "https://json.schemastore.org/claude-code-settings.json", + "env": { + "CLAUDE_BASH_MAINTAIN_PROJECT_WORKING_DIR": "1", + "CLAUDE_CODE_DISABLE_FEEDBACK_SURVEY": "1", + "DISABLE_TELEMETRY": "1", + "CLAUDE_CODE_NO_FLICKER": "1", + "CLAUDE_CODE_DISABLE_ADAPTIVE_THINKING": "1" + }, + "permissions": { + "allow": [ + "Bash(just fast-check)", + "Bash(just check)", + "Bash(just fix)", + "Bash(just typecheck)", + "Bash(just lint)", + "Bash(just test)" + ], + "deny": [] + }, + "enableAllProjectMcpServers": true } From 3175e7c50647745f9a1ad10993b75dfd7935d91a Mon Sep 17 00:00:00 2001 From: phernandez Date: Tue, 7 Apr 2026 11:17:07 -0500 Subject: [PATCH 03/10] refactor timing for indexing Signed-off-by: phernandez --- .../repository/postgres_search_repository.py | 27 +++++- .../repository/search_repository_base.py | 89 ++++++++++++++++--- test-int/semantic/test_semantic_quality.py | 5 +- .../test_postgres_search_repository_unit.py | 65 ++++++++++++++ tests/repository/test_semantic_search_base.py | 65 ++++++++++++++ 5 files changed, 232 insertions(+), 19 deletions(-) diff --git a/src/basic_memory/repository/postgres_search_repository.py b/src/basic_memory/repository/postgres_search_repository.py index 925c975f5..2af9fc89b 100644 --- a/src/basic_memory/repository/postgres_search_repository.py +++ b/src/basic_memory/repository/postgres_search_repository.py @@ -522,13 +522,18 @@ async def sync_entity_vectors_batch( embedding_jobs_count = len(prepared_sync.embedding_jobs) result.embedding_jobs_total += embedding_jobs_count + result.prepare_seconds_total += prepared_sync.prepare_seconds if embedding_jobs_count == 0: synced_entity_ids.add(entity_id) total_seconds = time.perf_counter() - prepared_sync.sync_start + queue_wait_seconds = max(0.0, total_seconds - prepared_sync.prepare_seconds) + result.queue_wait_seconds_total += queue_wait_seconds self._log_vector_sync_complete( entity_id=entity_id, total_seconds=total_seconds, + prepare_seconds=prepared_sync.prepare_seconds, + queue_wait_seconds=queue_wait_seconds, embed_seconds=0.0, write_seconds=0.0, source_rows_count=prepared_sync.source_rows_count, @@ -541,6 +546,7 @@ async def sync_entity_vectors_batch( source_rows_count=prepared_sync.source_rows_count, embedding_jobs_count=embedding_jobs_count, remaining_jobs=embedding_jobs_count, + prepare_seconds=prepared_sync.prepare_seconds, ) pending_jobs.extend( _PendingEmbeddingJob( @@ -562,6 +568,10 @@ async def sync_entity_vectors_batch( ) result.embed_seconds_total += embed_seconds result.write_seconds_total += write_seconds + (result.queue_wait_seconds_total) += self._finalize_completed_entity_syncs( + entity_runtime=entity_runtime, + synced_entity_ids=synced_entity_ids, + ) except Exception as exc: affected_entity_ids = sorted({job.entity_id for job in flush_jobs}) failed_entity_ids.update(affected_entity_ids) @@ -588,6 +598,10 @@ async def sync_entity_vectors_batch( ) result.embed_seconds_total += embed_seconds result.write_seconds_total += write_seconds + (result.queue_wait_seconds_total) += self._finalize_completed_entity_syncs( + entity_runtime=entity_runtime, + synced_entity_ids=synced_entity_ids, + ) except Exception as exc: affected_entity_ids = sorted({job.entity_id for job in flush_jobs}) failed_entity_ids.update(affected_entity_ids) @@ -620,13 +634,16 @@ async def sync_entity_vectors_batch( logger.info( "Vector batch sync complete: project_id={project_id} entities_total={entities_total} " "entities_synced={entities_synced} entities_failed={entities_failed} " - "embedding_jobs_total={embedding_jobs_total} embed_seconds_total={embed_seconds_total:.3f} " - "write_seconds_total={write_seconds_total:.3f}", + "embedding_jobs_total={embedding_jobs_total} prepare_seconds_total={prepare_seconds_total:.3f} " + "queue_wait_seconds_total={queue_wait_seconds_total:.3f} " + "embed_seconds_total={embed_seconds_total:.3f} write_seconds_total={write_seconds_total:.3f}", project_id=self.project_id, entities_total=result.entities_total, entities_synced=result.entities_synced, entities_failed=result.entities_failed, embedding_jobs_total=result.embedding_jobs_total, + prepare_seconds_total=result.prepare_seconds_total, + queue_wait_seconds_total=result.queue_wait_seconds_total, embed_seconds_total=result.embed_seconds_total, write_seconds_total=result.write_seconds_total, ) @@ -679,11 +696,13 @@ async def _prepare_entity_vector_jobs(self, entity_id: int) -> _PreparedEntityVe source_rows_count=source_rows_count, ) await self._delete_entity_chunks(session, entity_id) + prepare_seconds = time.perf_counter() - sync_start return _PreparedEntityVectorSync( entity_id=entity_id, sync_start=sync_start, source_rows_count=source_rows_count, embedding_jobs=[], + prepare_seconds=prepare_seconds, ) chunk_records = self._build_chunk_records(rows) @@ -699,11 +718,13 @@ async def _prepare_entity_vector_jobs(self, entity_id: int) -> _PreparedEntityVe ) if not chunk_records: await self._delete_entity_chunks(session, entity_id) + prepare_seconds = time.perf_counter() - sync_start return _PreparedEntityVectorSync( entity_id=entity_id, sync_start=sync_start, source_rows_count=source_rows_count, embedding_jobs=[], + prepare_seconds=prepare_seconds, ) existing_rows_result = await session.execute( @@ -807,11 +828,13 @@ async def _prepare_entity_vector_jobs(self, entity_id: int) -> _PreparedEntityVe embedding_jobs_count=len(embedding_jobs), ) + prepare_seconds = time.perf_counter() - sync_start return _PreparedEntityVectorSync( entity_id=entity_id, sync_start=sync_start, source_rows_count=source_rows_count, embedding_jobs=embedding_jobs, + prepare_seconds=prepare_seconds, ) async def _write_embeddings( diff --git a/src/basic_memory/repository/search_repository_base.py b/src/basic_memory/repository/search_repository_base.py index 0e420c1a6..924eea0d8 100644 --- a/src/basic_memory/repository/search_repository_base.py +++ b/src/basic_memory/repository/search_repository_base.py @@ -44,6 +44,8 @@ class VectorSyncBatchResult: entities_failed: int failed_entity_ids: list[int] = field(default_factory=list) embedding_jobs_total: int = 0 + prepare_seconds_total: float = 0.0 + queue_wait_seconds_total: float = 0.0 embed_seconds_total: float = 0.0 write_seconds_total: float = 0.0 @@ -56,6 +58,7 @@ class _PreparedEntityVectorSync: sync_start: float source_rows_count: int embedding_jobs: list[tuple[int, str]] + prepare_seconds: float = 0.0 @dataclass @@ -75,6 +78,7 @@ class _EntitySyncRuntime: source_rows_count: int embedding_jobs_count: int remaining_jobs: int + prepare_seconds: float = 0.0 embed_seconds: float = 0.0 write_seconds: float = 0.0 @@ -696,13 +700,18 @@ async def _sync_entity_vectors_internal( embedding_jobs_count = len(prepared.embedding_jobs) result.embedding_jobs_total += embedding_jobs_count + result.prepare_seconds_total += prepared.prepare_seconds if embedding_jobs_count == 0: synced_entity_ids.add(entity_id) total_seconds = time.perf_counter() - prepared.sync_start + queue_wait_seconds = max(0.0, total_seconds - prepared.prepare_seconds) + result.queue_wait_seconds_total += queue_wait_seconds self._log_vector_sync_complete( entity_id=entity_id, total_seconds=total_seconds, + prepare_seconds=prepared.prepare_seconds, + queue_wait_seconds=queue_wait_seconds, embed_seconds=0.0, write_seconds=0.0, source_rows_count=prepared.source_rows_count, @@ -715,6 +724,7 @@ async def _sync_entity_vectors_internal( source_rows_count=prepared.source_rows_count, embedding_jobs_count=embedding_jobs_count, remaining_jobs=embedding_jobs_count, + prepare_seconds=prepared.prepare_seconds, ) pending_jobs.extend( _PendingEmbeddingJob( @@ -734,6 +744,10 @@ async def _sync_entity_vectors_internal( ) result.embed_seconds_total += embed_seconds result.write_seconds_total += write_seconds + (result.queue_wait_seconds_total) += self._finalize_completed_entity_syncs( + entity_runtime=entity_runtime, + synced_entity_ids=synced_entity_ids, + ) except Exception as exc: if not continue_on_error: raise @@ -761,6 +775,10 @@ async def _sync_entity_vectors_internal( ) result.embed_seconds_total += embed_seconds result.write_seconds_total += write_seconds + (result.queue_wait_seconds_total) += self._finalize_completed_entity_syncs( + entity_runtime=entity_runtime, + synced_entity_ids=synced_entity_ids, + ) except Exception as exc: if not continue_on_error: raise @@ -799,13 +817,16 @@ async def _sync_entity_vectors_internal( logger.info( "Vector batch sync complete: project_id={project_id} entities_total={entities_total} " "entities_synced={entities_synced} entities_failed={entities_failed} " - "embedding_jobs_total={embedding_jobs_total} embed_seconds_total={embed_seconds_total:.3f} " - "write_seconds_total={write_seconds_total:.3f}", + "embedding_jobs_total={embedding_jobs_total} prepare_seconds_total={prepare_seconds_total:.3f} " + "queue_wait_seconds_total={queue_wait_seconds_total:.3f} " + "embed_seconds_total={embed_seconds_total:.3f} write_seconds_total={write_seconds_total:.3f}", project_id=self.project_id, entities_total=result.entities_total, entities_synced=result.entities_synced, entities_failed=result.entities_failed, embedding_jobs_total=result.embedding_jobs_total, + prepare_seconds_total=result.prepare_seconds_total, + queue_wait_seconds_total=result.queue_wait_seconds_total, embed_seconds_total=result.embed_seconds_total, write_seconds_total=result.write_seconds_total, ) @@ -863,11 +884,13 @@ async def _prepare_entity_vector_jobs(self, entity_id: int) -> _PreparedEntityVe ) await self._delete_entity_chunks(session, entity_id) await session.commit() + prepare_seconds = time.perf_counter() - sync_start return _PreparedEntityVectorSync( entity_id=entity_id, sync_start=sync_start, source_rows_count=source_rows_count, embedding_jobs=[], + prepare_seconds=prepare_seconds, ) chunk_records = self._build_chunk_records(rows) @@ -884,11 +907,13 @@ async def _prepare_entity_vector_jobs(self, entity_id: int) -> _PreparedEntityVe if not chunk_records: await self._delete_entity_chunks(session, entity_id) await session.commit() + prepare_seconds = time.perf_counter() - sync_start return _PreparedEntityVectorSync( entity_id=entity_id, sync_start=sync_start, source_rows_count=source_rows_count, embedding_jobs=[], + prepare_seconds=prepare_seconds, ) # --- Diff existing chunks against incoming --- @@ -994,11 +1019,13 @@ async def _prepare_entity_vector_jobs(self, entity_id: int) -> _PreparedEntityVe ) await session.commit() + prepare_seconds = time.perf_counter() - sync_start return _PreparedEntityVectorSync( entity_id=entity_id, sync_start=sync_start, source_rows_count=source_rows_count, embedding_jobs=embedding_jobs, + prepare_seconds=prepare_seconds, ) async def _flush_embedding_jobs( @@ -1063,24 +1090,52 @@ async def _flush_embedding_jobs( if runtime.remaining_jobs <= 0: synced_entity_ids.add(entity_id) - total_seconds = time.perf_counter() - runtime.sync_start - self._log_vector_sync_complete( - entity_id=entity_id, - total_seconds=total_seconds, - embed_seconds=runtime.embed_seconds, - write_seconds=runtime.write_seconds, - source_rows_count=runtime.source_rows_count, - embedding_jobs_count=runtime.embedding_jobs_count, - ) - entity_runtime.pop(entity_id, None) return embed_seconds, write_seconds + def _finalize_completed_entity_syncs( + self, + *, + entity_runtime: dict[int, _EntitySyncRuntime], + synced_entity_ids: set[int], + ) -> float: + """Finalize completed entities and return cumulative queue wait seconds.""" + queue_wait_seconds_total = 0.0 + for entity_id, runtime in list(entity_runtime.items()): + if runtime.remaining_jobs > 0: + continue + + synced_entity_ids.add(entity_id) + total_seconds = time.perf_counter() - runtime.sync_start + queue_wait_seconds = max( + 0.0, + total_seconds + - runtime.prepare_seconds + - runtime.embed_seconds + - runtime.write_seconds, + ) + queue_wait_seconds_total += queue_wait_seconds + self._log_vector_sync_complete( + entity_id=entity_id, + total_seconds=total_seconds, + prepare_seconds=runtime.prepare_seconds, + queue_wait_seconds=queue_wait_seconds, + embed_seconds=runtime.embed_seconds, + write_seconds=runtime.write_seconds, + source_rows_count=runtime.source_rows_count, + embedding_jobs_count=runtime.embedding_jobs_count, + ) + entity_runtime.pop(entity_id, None) + + return queue_wait_seconds_total + def _log_vector_sync_complete( self, *, entity_id: int, total_seconds: float, + prepare_seconds: float, + queue_wait_seconds: float, embed_seconds: float, write_seconds: float, source_rows_count: int, @@ -1089,12 +1144,15 @@ def _log_vector_sync_complete( """Log completion and slow-entity warnings with a consistent format.""" logger.info( "Vector sync complete: project_id={project_id} entity_id={entity_id} " - "total_seconds={total_seconds:.3f} embed_seconds={embed_seconds:.3f} " + "total_seconds={total_seconds:.3f} prepare_seconds={prepare_seconds:.3f} " + "queue_wait_seconds={queue_wait_seconds:.3f} embed_seconds={embed_seconds:.3f} " "write_seconds={write_seconds:.3f} source_rows_count={source_rows_count} " "embedding_jobs_count={embedding_jobs_count}", project_id=self.project_id, entity_id=entity_id, total_seconds=total_seconds, + prepare_seconds=prepare_seconds, + queue_wait_seconds=queue_wait_seconds, embed_seconds=embed_seconds, write_seconds=write_seconds, source_rows_count=source_rows_count, @@ -1103,12 +1161,15 @@ def _log_vector_sync_complete( if total_seconds > 10: logger.warning( "Vector sync slow entity: project_id={project_id} entity_id={entity_id} " - "total_seconds={total_seconds:.3f} embed_seconds={embed_seconds:.3f} " + "total_seconds={total_seconds:.3f} prepare_seconds={prepare_seconds:.3f} " + "queue_wait_seconds={queue_wait_seconds:.3f} embed_seconds={embed_seconds:.3f} " "write_seconds={write_seconds:.3f} source_rows_count={source_rows_count} " "embedding_jobs_count={embedding_jobs_count}", project_id=self.project_id, entity_id=entity_id, total_seconds=total_seconds, + prepare_seconds=prepare_seconds, + queue_wait_seconds=queue_wait_seconds, embed_seconds=embed_seconds, write_seconds=write_seconds, source_rows_count=source_rows_count, diff --git a/test-int/semantic/test_semantic_quality.py b/test-int/semantic/test_semantic_quality.py index 7e4a7a2f0..a74292fe2 100644 --- a/test-int/semantic/test_semantic_quality.py +++ b/test-int/semantic/test_semantic_quality.py @@ -51,9 +51,8 @@ ("sqlite-fastembed", "paraphrase", "hybrid"): 0.25, ("postgres-fastembed", "lexical", "hybrid"): 0.37, ("postgres-fastembed", "paraphrase", "hybrid"): 0.25, - # OpenAI hybrid should handle paraphrases better than FastEmbed - ("postgres-openai", "lexical", "hybrid"): 0.37, - ("postgres-openai", "paraphrase", "hybrid"): 0.25, + # OpenAI metrics are still recorded, but we do not gate on them yet. + # The current benchmark corpus is too small to make that combo stable. } diff --git a/tests/repository/test_postgres_search_repository_unit.py b/tests/repository/test_postgres_search_repository_unit.py index 9a16adb59..4a331e0de 100644 --- a/tests/repository/test_postgres_search_repository_unit.py +++ b/tests/repository/test_postgres_search_repository_unit.py @@ -10,6 +10,7 @@ import pytest +import basic_memory.repository.search_repository_base as search_repository_base_module from basic_memory.config import BasicMemoryConfig, DatabaseBackend from basic_memory.repository.postgres_search_repository import PostgresSearchRepository from basic_memory.repository.search_repository_base import _PreparedEntityVectorSync @@ -283,3 +284,67 @@ async def _stub_prepare(entity_id: int) -> _PreparedEntityVectorSync: assert result.entities_synced == 4 assert result.entities_failed == 0 assert max_active_prepares > 1 + + +@pytest.mark.asyncio +async def test_postgres_batch_sync_tracks_prepare_and_queue_wait(monkeypatch): + """Postgres batch sync should separate queue wait from prepare/embed/write.""" + repo = _make_repo( + semantic_enabled=True, + embedding_provider=StubEmbeddingProvider(), + ) + repo._semantic_embedding_sync_batch_size = 2 + repo._vector_tables_initialized = True + + async def _stub_prepare(entity_id: int) -> _PreparedEntityVectorSync: + return _PreparedEntityVectorSync( + entity_id=entity_id, + sync_start=0.0, + source_rows_count=1, + embedding_jobs=[(200 + entity_id, f"chunk-{entity_id}")], + prepare_seconds=1.0, + ) + + async def _stub_flush(flush_jobs, entity_runtime, synced_entity_ids): + for job in flush_jobs: + runtime = entity_runtime[job.entity_id] + if job.entity_id == 1: + runtime.embed_seconds = 1.0 + runtime.write_seconds = 0.5 + else: + runtime.embed_seconds = 2.0 + runtime.write_seconds = 0.5 + runtime.remaining_jobs = 0 + synced_entity_ids.add(job.entity_id) + return (3.0, 1.0) + + completion_records: list[dict] = [] + + def _capture_log(**kwargs): + completion_records.append(kwargs) + + perf_counter_values = iter([4.0, 5.0]) + + monkeypatch.setattr(repo, "_ensure_vector_tables", AsyncMock()) + monkeypatch.setattr(repo, "_prepare_entity_vector_jobs", _stub_prepare) + monkeypatch.setattr(repo, "_flush_embedding_jobs", _stub_flush) + monkeypatch.setattr(repo, "_log_vector_sync_complete", _capture_log) + monkeypatch.setattr( + search_repository_base_module.time, + "perf_counter", + lambda: next(perf_counter_values), + ) + + result = await repo.sync_entity_vectors_batch([1, 2]) + + assert result.entities_total == 2 + assert result.entities_synced == 2 + assert result.entities_failed == 0 + assert result.prepare_seconds_total == pytest.approx(2.0) + assert result.queue_wait_seconds_total == pytest.approx(3.0) + assert result.embed_seconds_total == pytest.approx(3.0) + assert result.write_seconds_total == pytest.approx(1.0) + assert len(completion_records) == 2 + for record in completion_records: + assert record["prepare_seconds"] == pytest.approx(1.0) + assert record["queue_wait_seconds"] == pytest.approx(1.5) diff --git a/tests/repository/test_semantic_search_base.py b/tests/repository/test_semantic_search_base.py index 5cac8afc9..7c66a91fd 100644 --- a/tests/repository/test_semantic_search_base.py +++ b/tests/repository/test_semantic_search_base.py @@ -8,6 +8,7 @@ import pytest +import basic_memory.repository.search_repository_base as search_repository_base_module from basic_memory.repository.search_repository_base import ( MAX_VECTOR_CHUNK_CHARS, SearchRepositoryBase, @@ -335,6 +336,8 @@ async def _stub_flush(flush_jobs, entity_runtime, synced_entity_ids): assert result.entities_failed == 0 assert result.failed_entity_ids == [] assert result.embedding_jobs_total == 3 + assert result.prepare_seconds_total == pytest.approx(0.0) + assert result.queue_wait_seconds_total == pytest.approx(0.0) assert result.embed_seconds_total == pytest.approx(0.2) assert result.write_seconds_total == pytest.approx(0.4) @@ -373,3 +376,65 @@ async def _stub_flush(flush_jobs, entity_runtime, synced_entity_ids): assert result.entities_synced == 1 assert result.entities_failed == 2 assert result.failed_entity_ids == [2, 3] + + +@pytest.mark.asyncio +async def test_sync_entity_vectors_batch_tracks_prepare_and_queue_wait_seconds(monkeypatch): + """Queue wait should be reported separately from prepare/embed/write timings.""" + repo = _ConcreteRepo() + repo._semantic_enabled = True + repo._embedding_provider = object() + repo._semantic_embedding_sync_batch_size = 2 + + async def _stub_prepare(entity_id: int) -> _PreparedEntityVectorSync: + return _PreparedEntityVectorSync( + entity_id=entity_id, + sync_start=0.0, + source_rows_count=1, + embedding_jobs=[(100 + entity_id, f"chunk-{entity_id}")], + prepare_seconds=1.0, + ) + + async def _stub_flush(flush_jobs, entity_runtime, synced_entity_ids): + assert len(flush_jobs) == 2 + for job in flush_jobs: + runtime = entity_runtime[job.entity_id] + if job.entity_id == 1: + runtime.embed_seconds = 1.0 + runtime.write_seconds = 0.5 + else: + runtime.embed_seconds = 2.0 + runtime.write_seconds = 0.5 + runtime.remaining_jobs = 0 + synced_entity_ids.add(job.entity_id) + return (3.0, 1.0) + + logged_completion: list[dict] = [] + + def _capture_log(**kwargs): + logged_completion.append(kwargs) + + perf_counter_values = iter([4.0, 5.0]) + + monkeypatch.setattr(repo, "_prepare_entity_vector_jobs", _stub_prepare) + monkeypatch.setattr(repo, "_flush_embedding_jobs", _stub_flush) + monkeypatch.setattr(repo, "_log_vector_sync_complete", _capture_log) + monkeypatch.setattr( + search_repository_base_module.time, + "perf_counter", + lambda: next(perf_counter_values), + ) + + result = await repo.sync_entity_vectors_batch([1, 2]) + + assert result.entities_total == 2 + assert result.entities_synced == 2 + assert result.entities_failed == 0 + assert result.prepare_seconds_total == pytest.approx(2.0) + assert result.queue_wait_seconds_total == pytest.approx(3.0) + assert result.embed_seconds_total == pytest.approx(3.0) + assert result.write_seconds_total == pytest.approx(1.0) + assert len(logged_completion) == 2 + for record in logged_completion: + assert record["prepare_seconds"] == pytest.approx(1.0) + assert record["queue_wait_seconds"] == pytest.approx(1.5) From 9738267b206b8b754ff21df1f6d3cd744fbf14e9 Mon Sep 17 00:00:00 2001 From: phernandez Date: Tue, 7 Apr 2026 11:41:09 -0500 Subject: [PATCH 04/10] perf(core): skip unchanged vector sync work Signed-off-by: phernandez --- ...h7i8j9k0l1_add_vector_sync_fingerprints.py | 84 +++++++++ src/basic_memory/models/search.py | 4 + .../repository/postgres_search_repository.py | 139 ++++++++++++++- .../repository/search_repository_base.py | 165 +++++++++++++++++- .../repository/sqlite_search_repository.py | 2 + .../test_postgres_search_repository.py | 110 ++++++++++++ .../test_sqlite_vector_search_repository.py | 88 +++++++++- 7 files changed, 575 insertions(+), 17 deletions(-) create mode 100644 src/basic_memory/alembic/versions/m6h7i8j9k0l1_add_vector_sync_fingerprints.py diff --git a/src/basic_memory/alembic/versions/m6h7i8j9k0l1_add_vector_sync_fingerprints.py b/src/basic_memory/alembic/versions/m6h7i8j9k0l1_add_vector_sync_fingerprints.py new file mode 100644 index 000000000..abea11b79 --- /dev/null +++ b/src/basic_memory/alembic/versions/m6h7i8j9k0l1_add_vector_sync_fingerprints.py @@ -0,0 +1,84 @@ +"""Persist vector sync fingerprints on chunk metadata. + +Revision ID: m6h7i8j9k0l1 +Revises: l5g6h7i8j9k0 +Create Date: 2026-04-07 00:00:00.000000 + +""" + +from typing import Sequence, Union + +from alembic import op + +# revision identifiers, used by Alembic. +revision: str = "m6h7i8j9k0l1" +down_revision: Union[str, None] = "l5g6h7i8j9k0" +branch_labels: Union[str, Sequence[str], None] = None +depends_on: Union[str, Sequence[str], None] = None + + +def upgrade() -> None: + """Add entity fingerprint + embedding model metadata to Postgres chunk rows. + + Trigger: vector sync now fast-skips unchanged entities using persisted + semantic fingerprints. + Why: chunk rows already own the per-entity derived metadata we diff against, + so persisting the fingerprint on that table avoids a second sync-state table. + Outcome: existing rows get empty-string placeholders and will be refreshed on + the next vector sync before they become eligible for skip checks. + """ + connection = op.get_bind() + if connection.dialect.name != "postgresql": + return + + op.execute( + """ + ALTER TABLE search_vector_chunks + ADD COLUMN IF NOT EXISTS entity_fingerprint TEXT + """ + ) + op.execute( + """ + ALTER TABLE search_vector_chunks + ADD COLUMN IF NOT EXISTS embedding_model TEXT + """ + ) + op.execute( + """ + UPDATE search_vector_chunks + SET entity_fingerprint = COALESCE(entity_fingerprint, ''), + embedding_model = COALESCE(embedding_model, '') + """ + ) + op.execute( + """ + ALTER TABLE search_vector_chunks + ALTER COLUMN entity_fingerprint SET NOT NULL + """ + ) + op.execute( + """ + ALTER TABLE search_vector_chunks + ALTER COLUMN embedding_model SET NOT NULL + """ + ) + + +def downgrade() -> None: + """Remove vector sync fingerprint columns from Postgres chunk rows.""" + connection = op.get_bind() + if connection.dialect.name != "postgresql": + return + + op.execute( + """ + ALTER TABLE search_vector_chunks + DROP COLUMN IF EXISTS embedding_model + """ + ) + op.execute( + """ + ALTER TABLE search_vector_chunks + DROP COLUMN IF EXISTS entity_fingerprint + """ + ) diff --git a/src/basic_memory/models/search.py b/src/basic_memory/models/search.py index cb5d6f30d..75b2c67d5 100644 --- a/src/basic_memory/models/search.py +++ b/src/basic_memory/models/search.py @@ -104,6 +104,8 @@ chunk_key TEXT NOT NULL, chunk_text TEXT NOT NULL, source_hash TEXT NOT NULL, + entity_fingerprint TEXT NOT NULL, + embedding_model TEXT NOT NULL, updated_at TIMESTAMPTZ NOT NULL DEFAULT NOW(), UNIQUE (project_id, entity_id, chunk_key) ) @@ -124,6 +126,8 @@ chunk_key TEXT NOT NULL, chunk_text TEXT NOT NULL, source_hash TEXT NOT NULL, + entity_fingerprint TEXT NOT NULL, + embedding_model TEXT NOT NULL, updated_at TEXT NOT NULL DEFAULT CURRENT_TIMESTAMP ) """) diff --git a/src/basic_memory/repository/postgres_search_repository.py b/src/basic_memory/repository/postgres_search_repository.py index 2af9fc89b..fd0a2c652 100644 --- a/src/basic_memory/repository/postgres_search_repository.py +++ b/src/basic_memory/repository/postgres_search_repository.py @@ -305,6 +305,8 @@ async def _ensure_vector_tables(self) -> None: chunk_key TEXT NOT NULL, chunk_text TEXT NOT NULL, source_hash TEXT NOT NULL, + entity_fingerprint TEXT NOT NULL, + embedding_model TEXT NOT NULL, updated_at TIMESTAMPTZ NOT NULL DEFAULT NOW(), UNIQUE (project_id, entity_id, chunk_key) ) @@ -319,6 +321,47 @@ async def _ensure_vector_tables(self) -> None: """ ) ) + await session.execute( + text( + """ + ALTER TABLE search_vector_chunks + ADD COLUMN IF NOT EXISTS entity_fingerprint TEXT + """ + ) + ) + await session.execute( + text( + """ + ALTER TABLE search_vector_chunks + ADD COLUMN IF NOT EXISTS embedding_model TEXT + """ + ) + ) + await session.execute( + text( + """ + UPDATE search_vector_chunks + SET entity_fingerprint = COALESCE(entity_fingerprint, ''), + embedding_model = COALESCE(embedding_model, '') + """ + ) + ) + await session.execute( + text( + """ + ALTER TABLE search_vector_chunks + ALTER COLUMN entity_fingerprint SET NOT NULL + """ + ) + ) + await session.execute( + text( + """ + ALTER TABLE search_vector_chunks + ALTER COLUMN embedding_model SET NOT NULL + """ + ) + ) # --- Embeddings table (dimension-dependent, created at runtime) --- # Trigger: provider dimensions may differ from what was previously deployed. @@ -521,6 +564,10 @@ async def sync_entity_vectors_batch( prepared_sync = cast(_PreparedEntityVectorSync, prepared) embedding_jobs_count = len(prepared_sync.embedding_jobs) + result.chunks_total += prepared_sync.chunks_total + result.chunks_skipped += prepared_sync.chunks_skipped + if prepared_sync.entity_skipped: + result.entities_skipped += 1 result.embedding_jobs_total += embedding_jobs_count result.prepare_seconds_total += prepared_sync.prepare_seconds @@ -537,7 +584,10 @@ async def sync_entity_vectors_batch( embed_seconds=0.0, write_seconds=0.0, source_rows_count=prepared_sync.source_rows_count, + chunks_total=prepared_sync.chunks_total, + chunks_skipped=prepared_sync.chunks_skipped, embedding_jobs_count=0, + entity_skipped=prepared_sync.entity_skipped, ) continue @@ -546,6 +596,9 @@ async def sync_entity_vectors_batch( source_rows_count=prepared_sync.source_rows_count, embedding_jobs_count=embedding_jobs_count, remaining_jobs=embedding_jobs_count, + chunks_total=prepared_sync.chunks_total, + chunks_skipped=prepared_sync.chunks_skipped, + entity_skipped=prepared_sync.entity_skipped, prepare_seconds=prepared_sync.prepare_seconds, ) pending_jobs.extend( @@ -634,13 +687,18 @@ async def sync_entity_vectors_batch( logger.info( "Vector batch sync complete: project_id={project_id} entities_total={entities_total} " "entities_synced={entities_synced} entities_failed={entities_failed} " - "embedding_jobs_total={embedding_jobs_total} prepare_seconds_total={prepare_seconds_total:.3f} " + "entities_skipped={entities_skipped} chunks_total={chunks_total} " + "chunks_skipped={chunks_skipped} embedding_jobs_total={embedding_jobs_total} " + "prepare_seconds_total={prepare_seconds_total:.3f} " "queue_wait_seconds_total={queue_wait_seconds_total:.3f} " "embed_seconds_total={embed_seconds_total:.3f} write_seconds_total={write_seconds_total:.3f}", project_id=self.project_id, entities_total=result.entities_total, entities_synced=result.entities_synced, entities_failed=result.entities_failed, + entities_skipped=result.entities_skipped, + chunks_total=result.chunks_total, + chunks_skipped=result.chunks_skipped, embedding_jobs_total=result.embedding_jobs_total, prepare_seconds_total=result.prepare_seconds_total, queue_wait_seconds_total=result.queue_wait_seconds_total, @@ -707,6 +765,8 @@ async def _prepare_entity_vector_jobs(self, entity_id: int) -> _PreparedEntityVe chunk_records = self._build_chunk_records(rows) built_chunk_records_count = len(chunk_records) + current_entity_fingerprint = self._build_entity_fingerprint(chunk_records) + current_embedding_model = self._embedding_model_key() logger.info( "Vector sync source prepared: project_id={project_id} entity_id={entity_id} " "source_rows_count={source_rows_count} " @@ -729,7 +789,8 @@ async def _prepare_entity_vector_jobs(self, entity_id: int) -> _PreparedEntityVe existing_rows_result = await session.execute( text( - "SELECT c.id, c.chunk_key, c.source_hash, " + "SELECT c.id, c.chunk_key, c.source_hash, c.entity_fingerprint, " + "c.embedding_model, " "(e.chunk_id IS NOT NULL) AS has_embedding " "FROM search_vector_chunks c " "LEFT JOIN search_vector_embeddings e ON e.chunk_id = c.id " @@ -754,8 +815,43 @@ async def _prepare_entity_vector_jobs(self, entity_id: int) -> _PreparedEntityVe orphan_ids = {int(row["id"]) for row in existing_rows if not bool(row["has_embedding"])} orphan_chunks_count = len(orphan_ids) + skip_unchanged_entity = ( + existing_chunks_count == built_chunk_records_count + and stale_chunks_count == 0 + and orphan_chunks_count == 0 + and existing_chunks_count > 0 + and all( + row["entity_fingerprint"] == current_entity_fingerprint + and row["embedding_model"] == current_embedding_model + for row in existing_rows + ) + ) + if skip_unchanged_entity: + logger.info( + "Vector sync skipped unchanged entity: project_id={project_id} " + "entity_id={entity_id} chunks_skipped={chunks_skipped} " + "entity_fingerprint={entity_fingerprint} embedding_model={embedding_model}", + project_id=self.project_id, + entity_id=entity_id, + chunks_skipped=built_chunk_records_count, + entity_fingerprint=current_entity_fingerprint, + embedding_model=current_embedding_model, + ) + prepare_seconds = time.perf_counter() - sync_start + return _PreparedEntityVectorSync( + entity_id=entity_id, + sync_start=sync_start, + source_rows_count=source_rows_count, + embedding_jobs=[], + chunks_total=built_chunk_records_count, + chunks_skipped=built_chunk_records_count, + entity_skipped=True, + prepare_seconds=prepare_seconds, + ) + upsert_records: list[dict[str, str]] = [] embedding_jobs: list[tuple[int, str]] = [] + skipped_chunks_count = 0 for record in chunk_records: current = existing_by_key.get(record["chunk_key"]) @@ -765,9 +861,29 @@ async def _prepare_entity_vector_jobs(self, entity_id: int) -> _PreparedEntityVe row_id = int(current["id"]) is_orphan = row_id in orphan_ids - if current["source_hash"] == record["source_hash"]: - if is_orphan: - embedding_jobs.append((row_id, record["chunk_text"])) + same_source_hash = current["source_hash"] == record["source_hash"] + same_entity_fingerprint = ( + current["entity_fingerprint"] == current_entity_fingerprint + ) + same_embedding_model = current["embedding_model"] == current_embedding_model + + if same_source_hash and not is_orphan and same_embedding_model: + if not same_entity_fingerprint: + await session.execute( + text( + "UPDATE search_vector_chunks " + "SET entity_fingerprint = :entity_fingerprint, " + "embedding_model = :embedding_model, " + "updated_at = NOW() " + "WHERE id = :id" + ), + { + "id": row_id, + "entity_fingerprint": current_entity_fingerprint, + "embedding_model": current_embedding_model, + }, + ) + skipped_chunks_count += 1 continue upsert_records.append(record) @@ -782,10 +898,13 @@ async def _prepare_entity_vector_jobs(self, entity_id: int) -> _PreparedEntityVe upsert_params[f"chunk_key_{index}"] = record["chunk_key"] upsert_params[f"chunk_text_{index}"] = record["chunk_text"] upsert_params[f"source_hash_{index}"] = record["source_hash"] + upsert_params[f"entity_fingerprint_{index}"] = current_entity_fingerprint + upsert_params[f"embedding_model_{index}"] = current_embedding_model upsert_values.append( "(" ":entity_id, :project_id, " - f":chunk_key_{index}, :chunk_text_{index}, :source_hash_{index}, NOW()" + f":chunk_key_{index}, :chunk_text_{index}, :source_hash_{index}, " + f":entity_fingerprint_{index}, :embedding_model_{index}, NOW()" ")" ) @@ -797,11 +916,15 @@ async def _prepare_entity_vector_jobs(self, entity_id: int) -> _PreparedEntityVe chunk_key, chunk_text, source_hash, + entity_fingerprint, + embedding_model, updated_at ) VALUES {", ".join(upsert_values)} ON CONFLICT (project_id, entity_id, chunk_key) DO UPDATE SET chunk_text = EXCLUDED.chunk_text, source_hash = EXCLUDED.source_hash, + entity_fingerprint = EXCLUDED.entity_fingerprint, + embedding_model = EXCLUDED.embedding_model, updated_at = NOW() RETURNING id, chunk_key """), @@ -819,12 +942,14 @@ async def _prepare_entity_vector_jobs(self, entity_id: int) -> _PreparedEntityVe "existing_chunks_count={existing_chunks_count} " "stale_chunks_count={stale_chunks_count} " "orphan_chunks_count={orphan_chunks_count} " + "chunks_skipped={chunks_skipped} " "embedding_jobs_count={embedding_jobs_count}", project_id=self.project_id, entity_id=entity_id, existing_chunks_count=existing_chunks_count, stale_chunks_count=stale_chunks_count, orphan_chunks_count=orphan_chunks_count, + chunks_skipped=skipped_chunks_count, embedding_jobs_count=len(embedding_jobs), ) @@ -834,6 +959,8 @@ async def _prepare_entity_vector_jobs(self, entity_id: int) -> _PreparedEntityVe sync_start=sync_start, source_rows_count=source_rows_count, embedding_jobs=embedding_jobs, + chunks_total=built_chunk_records_count, + chunks_skipped=skipped_chunks_count, prepare_seconds=prepare_seconds, ) diff --git a/src/basic_memory/repository/search_repository_base.py b/src/basic_memory/repository/search_repository_base.py index 924eea0d8..7d7deb4e1 100644 --- a/src/basic_memory/repository/search_repository_base.py +++ b/src/basic_memory/repository/search_repository_base.py @@ -42,7 +42,10 @@ class VectorSyncBatchResult: entities_total: int entities_synced: int entities_failed: int + entities_skipped: int = 0 failed_entity_ids: list[int] = field(default_factory=list) + chunks_total: int = 0 + chunks_skipped: int = 0 embedding_jobs_total: int = 0 prepare_seconds_total: float = 0.0 queue_wait_seconds_total: float = 0.0 @@ -58,6 +61,9 @@ class _PreparedEntityVectorSync: sync_start: float source_rows_count: int embedding_jobs: list[tuple[int, str]] + chunks_total: int = 0 + chunks_skipped: int = 0 + entity_skipped: bool = False prepare_seconds: float = 0.0 @@ -78,6 +84,9 @@ class _EntitySyncRuntime: source_rows_count: int embedding_jobs_count: int remaining_jobs: int + chunks_total: int = 0 + chunks_skipped: int = 0 + entity_skipped: bool = False prepare_seconds: float = 0.0 embed_seconds: float = 0.0 write_seconds: float = 0.0 @@ -486,6 +495,35 @@ def _build_chunk_records(self, rows) -> list[dict[str, str]]: return list(records_by_key.values()) + def _build_entity_fingerprint(self, chunk_records: list[dict[str, str]]) -> str: + """Hash the semantic chunk inputs for one entity. + + Trigger: vector sync eligibility depends on the chunk records derived + from search_index rows, not raw file bytes. + Why: title/permalink/observation metadata can change vector inputs even + when unrelated file bytes do not, and vice versa. + Outcome: one deterministic fingerprint invalidates the entity-level skip + whenever the embeddable chunk set changes. + """ + canonical_records = [ + { + "chunk_key": record["chunk_key"], + "source_hash": record["source_hash"], + } + for record in sorted(chunk_records, key=lambda record: record["chunk_key"]) + ] + payload = json.dumps(canonical_records, separators=(",", ":"), sort_keys=True) + return hashlib.sha256(payload.encode("utf-8")).hexdigest() + + def _embedding_model_key(self) -> str: + """Build a stable model identity for vector invalidation checks.""" + assert self._embedding_provider is not None + return ( + f"{type(self._embedding_provider).__name__}:" + f"{self._embedding_provider.model_name}:" + f"{self._embedding_provider.dimensions}" + ) + # --- Text splitting --- def _split_text_into_chunks(self, text_value: str) -> list[str]: @@ -699,6 +737,10 @@ async def _sync_entity_vectors_internal( continue embedding_jobs_count = len(prepared.embedding_jobs) + result.chunks_total += prepared.chunks_total + result.chunks_skipped += prepared.chunks_skipped + if prepared.entity_skipped: + result.entities_skipped += 1 result.embedding_jobs_total += embedding_jobs_count result.prepare_seconds_total += prepared.prepare_seconds @@ -715,7 +757,10 @@ async def _sync_entity_vectors_internal( embed_seconds=0.0, write_seconds=0.0, source_rows_count=prepared.source_rows_count, + chunks_total=prepared.chunks_total, + chunks_skipped=prepared.chunks_skipped, embedding_jobs_count=0, + entity_skipped=prepared.entity_skipped, ) continue @@ -724,6 +769,9 @@ async def _sync_entity_vectors_internal( source_rows_count=prepared.source_rows_count, embedding_jobs_count=embedding_jobs_count, remaining_jobs=embedding_jobs_count, + chunks_total=prepared.chunks_total, + chunks_skipped=prepared.chunks_skipped, + entity_skipped=prepared.entity_skipped, prepare_seconds=prepared.prepare_seconds, ) pending_jobs.extend( @@ -817,13 +865,18 @@ async def _sync_entity_vectors_internal( logger.info( "Vector batch sync complete: project_id={project_id} entities_total={entities_total} " "entities_synced={entities_synced} entities_failed={entities_failed} " - "embedding_jobs_total={embedding_jobs_total} prepare_seconds_total={prepare_seconds_total:.3f} " + "entities_skipped={entities_skipped} chunks_total={chunks_total} " + "chunks_skipped={chunks_skipped} embedding_jobs_total={embedding_jobs_total} " + "prepare_seconds_total={prepare_seconds_total:.3f} " "queue_wait_seconds_total={queue_wait_seconds_total:.3f} " "embed_seconds_total={embed_seconds_total:.3f} write_seconds_total={write_seconds_total:.3f}", project_id=self.project_id, entities_total=result.entities_total, entities_synced=result.entities_synced, entities_failed=result.entities_failed, + entities_skipped=result.entities_skipped, + chunks_total=result.chunks_total, + chunks_skipped=result.chunks_skipped, embedding_jobs_total=result.embedding_jobs_total, prepare_seconds_total=result.prepare_seconds_total, queue_wait_seconds_total=result.queue_wait_seconds_total, @@ -895,6 +948,8 @@ async def _prepare_entity_vector_jobs(self, entity_id: int) -> _PreparedEntityVe chunk_records = self._build_chunk_records(rows) built_chunk_records_count = len(chunk_records) + current_entity_fingerprint = self._build_entity_fingerprint(chunk_records) + current_embedding_model = self._embedding_model_key() logger.info( "Vector sync source prepared: project_id={project_id} entity_id={entity_id} " "source_rows_count={source_rows_count} " @@ -919,7 +974,7 @@ async def _prepare_entity_vector_jobs(self, entity_id: int) -> _PreparedEntityVe # --- Diff existing chunks against incoming --- existing_rows_result = await session.execute( text( - "SELECT id, chunk_key, source_hash " + "SELECT id, chunk_key, source_hash, entity_fingerprint, embedding_model " "FROM search_vector_chunks " "WHERE project_id = :project_id AND entity_id = :entity_id" ), @@ -952,9 +1007,49 @@ async def _prepare_entity_vector_jobs(self, entity_id: int) -> _PreparedEntityVe orphan_ids = {int(row.id) for row in orphan_rows} orphan_chunks_count = len(orphan_ids) + # Trigger: the persisted chunk metadata exactly matches the current + # semantic fingerprint/model and every chunk still has an embedding. + # Why: full reindex and embeddings-only runs should avoid reopening + # the expensive chunk diff + embed path for unchanged entities. + # Outcome: return immediately with skip counters and no writes. + skip_unchanged_entity = ( + existing_chunks_count == built_chunk_records_count + and stale_chunks_count == 0 + and orphan_chunks_count == 0 + and existing_chunks_count > 0 + and all( + row.entity_fingerprint == current_entity_fingerprint + and row.embedding_model == current_embedding_model + for row in existing_by_key.values() + ) + ) + if skip_unchanged_entity: + logger.info( + "Vector sync skipped unchanged entity: project_id={project_id} " + "entity_id={entity_id} chunks_skipped={chunks_skipped} " + "entity_fingerprint={entity_fingerprint} embedding_model={embedding_model}", + project_id=self.project_id, + entity_id=entity_id, + chunks_skipped=built_chunk_records_count, + entity_fingerprint=current_entity_fingerprint, + embedding_model=current_embedding_model, + ) + prepare_seconds = time.perf_counter() - sync_start + return _PreparedEntityVectorSync( + entity_id=entity_id, + sync_start=sync_start, + source_rows_count=source_rows_count, + embedding_jobs=[], + chunks_total=built_chunk_records_count, + chunks_skipped=built_chunk_records_count, + entity_skipped=True, + prepare_seconds=prepare_seconds, + ) + # --- Upsert changed / new chunks, collect embedding jobs --- timestamp_expr = self._timestamp_now_expr() embedding_jobs: list[tuple[int, str]] = [] + skipped_chunks_count = 0 for record in chunk_records: current = existing_by_key.get(record["chunk_key"]) @@ -962,16 +1057,44 @@ async def _prepare_entity_vector_jobs(self, entity_id: int) -> _PreparedEntityVe # but chunk has no embedding (orphan from crash). # Outcome: schedule re-embedding without touching chunk metadata. is_orphan = current and int(current.id) in orphan_ids - if current and current.source_hash == record["source_hash"] and not is_orphan: - continue - if current: row_id = int(current.id) - if current.source_hash != record["source_hash"]: + same_source_hash = current.source_hash == record["source_hash"] + same_entity_fingerprint = ( + current.entity_fingerprint == current_entity_fingerprint + ) + same_embedding_model = current.embedding_model == current_embedding_model + + if same_source_hash and not is_orphan and same_embedding_model: + if not same_entity_fingerprint: + await session.execute( + text( + "UPDATE search_vector_chunks " + "SET entity_fingerprint = :entity_fingerprint, " + "embedding_model = :embedding_model, " + f"updated_at = {timestamp_expr} " + "WHERE id = :id" + ), + { + "id": row_id, + "entity_fingerprint": current_entity_fingerprint, + "embedding_model": current_embedding_model, + }, + ) + skipped_chunks_count += 1 + continue + + if ( + current.source_hash != record["source_hash"] + or current.entity_fingerprint != current_entity_fingerprint + or current.embedding_model != current_embedding_model + ): await session.execute( text( "UPDATE search_vector_chunks " "SET chunk_text = :chunk_text, source_hash = :source_hash, " + "entity_fingerprint = :entity_fingerprint, " + "embedding_model = :embedding_model, " f"updated_at = {timestamp_expr} " "WHERE id = :id" ), @@ -979,6 +1102,8 @@ async def _prepare_entity_vector_jobs(self, entity_id: int) -> _PreparedEntityVe "id": row_id, "chunk_text": record["chunk_text"], "source_hash": record["source_hash"], + "entity_fingerprint": current_entity_fingerprint, + "embedding_model": current_embedding_model, }, ) embedding_jobs.append((row_id, record["chunk_text"])) @@ -987,9 +1112,11 @@ async def _prepare_entity_vector_jobs(self, entity_id: int) -> _PreparedEntityVe inserted = await session.execute( text( "INSERT INTO search_vector_chunks (" - "entity_id, project_id, chunk_key, chunk_text, source_hash, updated_at" + "entity_id, project_id, chunk_key, chunk_text, source_hash, " + "entity_fingerprint, embedding_model, updated_at" ") VALUES (" f":entity_id, :project_id, :chunk_key, :chunk_text, :source_hash, " + ":entity_fingerprint, :embedding_model, " f"{timestamp_expr}" ") RETURNING id" ), @@ -999,6 +1126,8 @@ async def _prepare_entity_vector_jobs(self, entity_id: int) -> _PreparedEntityVe "chunk_key": record["chunk_key"], "chunk_text": record["chunk_text"], "source_hash": record["source_hash"], + "entity_fingerprint": current_entity_fingerprint, + "embedding_model": current_embedding_model, }, ) row_id = int(inserted.scalar_one()) @@ -1009,12 +1138,14 @@ async def _prepare_entity_vector_jobs(self, entity_id: int) -> _PreparedEntityVe "existing_chunks_count={existing_chunks_count} " "stale_chunks_count={stale_chunks_count} " "orphan_chunks_count={orphan_chunks_count} " + "chunks_skipped={chunks_skipped} " "embedding_jobs_count={embedding_jobs_count}", project_id=self.project_id, entity_id=entity_id, existing_chunks_count=existing_chunks_count, stale_chunks_count=stale_chunks_count, orphan_chunks_count=orphan_chunks_count, + chunks_skipped=skipped_chunks_count, embedding_jobs_count=len(embedding_jobs), ) await session.commit() @@ -1025,6 +1156,8 @@ async def _prepare_entity_vector_jobs(self, entity_id: int) -> _PreparedEntityVe sync_start=sync_start, source_rows_count=source_rows_count, embedding_jobs=embedding_jobs, + chunks_total=built_chunk_records_count, + chunks_skipped=skipped_chunks_count, prepare_seconds=prepare_seconds, ) @@ -1123,7 +1256,10 @@ def _finalize_completed_entity_syncs( embed_seconds=runtime.embed_seconds, write_seconds=runtime.write_seconds, source_rows_count=runtime.source_rows_count, + chunks_total=runtime.chunks_total, + chunks_skipped=runtime.chunks_skipped, embedding_jobs_count=runtime.embedding_jobs_count, + entity_skipped=runtime.entity_skipped, ) entity_runtime.pop(entity_id, None) @@ -1139,7 +1275,10 @@ def _log_vector_sync_complete( embed_seconds: float, write_seconds: float, source_rows_count: int, + chunks_total: int, + chunks_skipped: int, embedding_jobs_count: int, + entity_skipped: bool, ) -> None: """Log completion and slow-entity warnings with a consistent format.""" logger.info( @@ -1147,7 +1286,8 @@ def _log_vector_sync_complete( "total_seconds={total_seconds:.3f} prepare_seconds={prepare_seconds:.3f} " "queue_wait_seconds={queue_wait_seconds:.3f} embed_seconds={embed_seconds:.3f} " "write_seconds={write_seconds:.3f} source_rows_count={source_rows_count} " - "embedding_jobs_count={embedding_jobs_count}", + "chunks_total={chunks_total} chunks_skipped={chunks_skipped} " + "embedding_jobs_count={embedding_jobs_count} entity_skipped={entity_skipped}", project_id=self.project_id, entity_id=entity_id, total_seconds=total_seconds, @@ -1156,7 +1296,10 @@ def _log_vector_sync_complete( embed_seconds=embed_seconds, write_seconds=write_seconds, source_rows_count=source_rows_count, + chunks_total=chunks_total, + chunks_skipped=chunks_skipped, embedding_jobs_count=embedding_jobs_count, + entity_skipped=entity_skipped, ) if total_seconds > 10: logger.warning( @@ -1164,7 +1307,8 @@ def _log_vector_sync_complete( "total_seconds={total_seconds:.3f} prepare_seconds={prepare_seconds:.3f} " "queue_wait_seconds={queue_wait_seconds:.3f} embed_seconds={embed_seconds:.3f} " "write_seconds={write_seconds:.3f} source_rows_count={source_rows_count} " - "embedding_jobs_count={embedding_jobs_count}", + "chunks_total={chunks_total} chunks_skipped={chunks_skipped} " + "embedding_jobs_count={embedding_jobs_count} entity_skipped={entity_skipped}", project_id=self.project_id, entity_id=entity_id, total_seconds=total_seconds, @@ -1173,7 +1317,10 @@ def _log_vector_sync_complete( embed_seconds=embed_seconds, write_seconds=write_seconds, source_rows_count=source_rows_count, + chunks_total=chunks_total, + chunks_skipped=chunks_skipped, embedding_jobs_count=embedding_jobs_count, + entity_skipped=entity_skipped, ) async def _prepare_vector_session(self, session: AsyncSession) -> None: diff --git a/src/basic_memory/repository/sqlite_search_repository.py b/src/basic_memory/repository/sqlite_search_repository.py index bc0fe37b6..a97c01887 100644 --- a/src/basic_memory/repository/sqlite_search_repository.py +++ b/src/basic_memory/repository/sqlite_search_repository.py @@ -398,6 +398,8 @@ async def _ensure_vector_tables(self) -> None: "chunk_key", "chunk_text", "source_hash", + "entity_fingerprint", + "embedding_model", "updated_at", } schema_mismatch = bool(chunks_columns) and set(chunks_columns) != expected_columns diff --git a/tests/repository/test_postgres_search_repository.py b/tests/repository/test_postgres_search_repository.py index 46d194b0b..a435bcdf5 100644 --- a/tests/repository/test_postgres_search_repository.py +++ b/tests/repository/test_postgres_search_repository.py @@ -47,6 +47,12 @@ def _vectorize(text: str) -> list[float]: return [0.0, 0.0, 0.0, 1.0] +class StubEmbeddingProviderV2(StubEmbeddingProvider): + """Same vectors, different model identity to force Postgres resync.""" + + model_name = "stub-v2" + + async def _skip_if_pgvector_unavailable(session_maker) -> None: """Skip semantic pgvector tests when extension is not available in test Postgres image.""" async with db.scoped_session(session_maker) as session: @@ -442,6 +448,110 @@ async def test_postgres_semantic_hybrid_search_combines_fts_and_vector(session_m assert any(result.permalink == "specs/search-index" for result in results) +@pytest.mark.asyncio +async def test_postgres_vector_sync_skips_unchanged_and_reembeds_changed_content( + session_maker, test_project +): + """Postgres vector sync tracks new, changed, unchanged, and model-changed entities.""" + await _skip_if_pgvector_unavailable(session_maker) + app_config = BasicMemoryConfig( + env="test", + projects={"test-project": "/tmp/basic-memory-test"}, + default_project="test-project", + database_backend=DatabaseBackend.POSTGRES, + semantic_search_enabled=True, + ) + repo = PostgresSearchRepository( + session_maker, + project_id=test_project.id, + app_config=app_config, + embedding_provider=StubEmbeddingProvider(), + ) + await repo.init_search_index() + + now = datetime.now(timezone.utc) + await repo.index_item( + SearchIndexRow( + project_id=test_project.id, + id=421, + title="Auth and Schema Notes", + content_stems="# Overview\n- auth token rotation\n- schema migration planning", + content_snippet="# Overview\n- auth token rotation\n- schema migration planning", + permalink="specs/auth-and-schema", + file_path="specs/auth-and-schema.md", + type=SearchItemType.ENTITY.value, + entity_id=421, + metadata={"note_type": "spec"}, + created_at=now, + updated_at=now, + ) + ) + + new_result = await repo.sync_entity_vectors_batch([421]) + assert new_result.entities_synced == 1 + assert new_result.entities_skipped == 0 + assert new_result.chunks_total >= 2 + assert new_result.chunks_skipped == 0 + assert new_result.embedding_jobs_total == new_result.chunks_total + + async with db.scoped_session(session_maker) as session: + stored_rows = await session.execute( + text( + "SELECT entity_fingerprint, embedding_model " + "FROM search_vector_chunks " + "WHERE project_id = :project_id AND entity_id = :entity_id" + ), + {"project_id": test_project.id, "entity_id": 421}, + ) + metadata_rows = stored_rows.fetchall() + assert metadata_rows + assert len({row.entity_fingerprint for row in metadata_rows}) == 1 + assert len({row.embedding_model for row in metadata_rows}) == 1 + assert metadata_rows[0].embedding_model == "StubEmbeddingProvider:stub:4" + + unchanged_result = await repo.sync_entity_vectors_batch([421]) + assert unchanged_result.entities_synced == 1 + assert unchanged_result.entities_skipped == 1 + assert unchanged_result.embedding_jobs_total == 0 + assert unchanged_result.chunks_skipped == unchanged_result.chunks_total + + await repo.index_item( + SearchIndexRow( + project_id=test_project.id, + id=421, + title="Auth and Schema Notes", + content_stems="# Overview\n- auth token rotation\n- database schema migration planning", + content_snippet="# Overview\n- auth token rotation\n- database schema migration planning", + permalink="specs/auth-and-schema", + file_path="specs/auth-and-schema.md", + type=SearchItemType.ENTITY.value, + entity_id=421, + metadata={"note_type": "spec"}, + created_at=now, + updated_at=now, + ) + ) + changed_result = await repo.sync_entity_vectors_batch([421]) + assert changed_result.entities_synced == 1 + assert changed_result.entities_skipped == 0 + assert changed_result.embedding_jobs_total >= 1 + assert changed_result.chunks_skipped >= 1 + assert changed_result.embedding_jobs_total < changed_result.chunks_total + + repo_v2 = PostgresSearchRepository( + session_maker, + project_id=test_project.id, + app_config=app_config, + embedding_provider=StubEmbeddingProviderV2(), + ) + await repo_v2.init_search_index() + model_changed_result = await repo_v2.sync_entity_vectors_batch([421]) + assert model_changed_result.entities_synced == 1 + assert model_changed_result.entities_skipped == 0 + assert model_changed_result.chunks_skipped == 0 + assert model_changed_result.embedding_jobs_total == model_changed_result.chunks_total + + @pytest.mark.asyncio async def test_postgres_vector_mode_rejects_non_text_query(session_maker, test_project): """Vector mode should fail fast for title-only queries.""" diff --git a/tests/repository/test_sqlite_vector_search_repository.py b/tests/repository/test_sqlite_vector_search_repository.py index c6c9c71ba..d85f17fac 100644 --- a/tests/repository/test_sqlite_vector_search_repository.py +++ b/tests/repository/test_sqlite_vector_search_repository.py @@ -36,6 +36,12 @@ def _vectorize(text: str) -> list[float]: return [0.0, 0.0, 0.0, 1.0] +class StubEmbeddingProviderV2(StubEmbeddingProvider): + """Same vectors, different model identity to force resync.""" + + model_name = "stub-v2" + + def _entity_row( *, project_id: int, @@ -62,14 +68,17 @@ def _entity_row( ) -def _enable_semantic(search_repository: SQLiteSearchRepository) -> None: +def _enable_semantic( + search_repository: SQLiteSearchRepository, + embedding_provider: StubEmbeddingProvider | None = None, +) -> None: try: import sqlite_vec # noqa: F401 except ImportError: pytest.skip("sqlite-vec dependency is required for sqlite vector repository tests.") search_repository._semantic_enabled = True - search_repository._embedding_provider = StubEmbeddingProvider() + search_repository._embedding_provider = embedding_provider or StubEmbeddingProvider() search_repository._vector_dimensions = search_repository._embedding_provider.dimensions search_repository._vector_tables_initialized = False @@ -102,6 +111,8 @@ async def test_sqlite_vec_tables_are_created_and_rebuilt(search_repository): "chunk_key", "chunk_text", "source_hash", + "entity_fingerprint", + "embedding_model", "updated_at", } @@ -182,6 +193,79 @@ async def test_sqlite_chunk_upsert_and_delete_lifecycle(search_repository): assert int(embedding_count.scalar_one()) == 0 +@pytest.mark.asyncio +async def test_sqlite_vector_sync_skips_unchanged_and_reembeds_changed_content(search_repository): + """SQLite vector sync tracks new, changed, unchanged, and model-changed entities.""" + if not isinstance(search_repository, SQLiteSearchRepository): + pytest.skip("sqlite-vec repository behavior is local SQLite-only.") + + _enable_semantic(search_repository) + await search_repository.init_search_index() + + await search_repository.index_item( + _entity_row( + project_id=search_repository.project_id, + row_id=111, + entity_id=111, + title="Auth and Schema Notes", + permalink="specs/auth-and-schema", + content_stems="# Overview\n- auth token rotation\n- schema migration planning", + ) + ) + + new_result = await search_repository.sync_entity_vectors_batch([111]) + assert new_result.entities_synced == 1 + assert new_result.entities_skipped == 0 + assert new_result.chunks_total >= 2 + assert new_result.chunks_skipped == 0 + assert new_result.embedding_jobs_total == new_result.chunks_total + + async with db.scoped_session(search_repository.session_maker) as session: + stored_rows = await session.execute( + text( + "SELECT entity_fingerprint, embedding_model " + "FROM search_vector_chunks " + "WHERE project_id = :project_id AND entity_id = :entity_id" + ), + {"project_id": search_repository.project_id, "entity_id": 111}, + ) + metadata_rows = stored_rows.fetchall() + assert metadata_rows + assert len({row.entity_fingerprint for row in metadata_rows}) == 1 + assert len({row.embedding_model for row in metadata_rows}) == 1 + assert metadata_rows[0].embedding_model == "StubEmbeddingProvider:stub:4" + + unchanged_result = await search_repository.sync_entity_vectors_batch([111]) + assert unchanged_result.entities_synced == 1 + assert unchanged_result.entities_skipped == 1 + assert unchanged_result.embedding_jobs_total == 0 + assert unchanged_result.chunks_skipped == unchanged_result.chunks_total + + await search_repository.index_item( + _entity_row( + project_id=search_repository.project_id, + row_id=111, + entity_id=111, + title="Auth and Schema Notes", + permalink="specs/auth-and-schema", + content_stems="# Overview\n- auth token rotation\n- database schema migration planning", + ) + ) + changed_result = await search_repository.sync_entity_vectors_batch([111]) + assert changed_result.entities_synced == 1 + assert changed_result.entities_skipped == 0 + assert changed_result.embedding_jobs_total >= 1 + assert changed_result.chunks_skipped >= 1 + assert changed_result.embedding_jobs_total < changed_result.chunks_total + + _enable_semantic(search_repository, StubEmbeddingProviderV2()) + model_changed_result = await search_repository.sync_entity_vectors_batch([111]) + assert model_changed_result.entities_synced == 1 + assert model_changed_result.entities_skipped == 0 + assert model_changed_result.chunks_skipped == 0 + assert model_changed_result.embedding_jobs_total == model_changed_result.chunks_total + + @pytest.mark.asyncio async def test_sqlite_vector_search_returns_ranked_entities(search_repository): """Vector mode ranks entities using sqlite-vec nearest-neighbor search.""" From d2055f3f310b1a32c731b0f8401aed62daf4bc6c Mon Sep 17 00:00:00 2001 From: phernandez Date: Tue, 7 Apr 2026 11:43:37 -0500 Subject: [PATCH 05/10] don't pass force_full=true by default Signed-off-by: phernandez --- src/basic_memory/cli/commands/cloud/cloud_utils.py | 4 ++-- src/basic_memory/cli/commands/cloud/upload_command.py | 2 +- src/basic_memory/cli/commands/doctor.py | 2 +- 3 files changed, 4 insertions(+), 4 deletions(-) diff --git a/src/basic_memory/cli/commands/cloud/cloud_utils.py b/src/basic_memory/cli/commands/cloud/cloud_utils.py index 0fff40fa7..8d7c5acb8 100644 --- a/src/basic_memory/cli/commands/cloud/cloud_utils.py +++ b/src/basic_memory/cli/commands/cloud/cloud_utils.py @@ -116,12 +116,12 @@ async def sync_project(project_name: str, force_full: bool = False) -> None: Args: project_name: Name of project to sync - force_full: If True, force a full scan bypassing watermark optimization + force_full: ignored, kept for backwards compatibility """ try: from basic_memory.cli.commands.command_utils import run_sync - await run_sync(project=project_name, force_full=force_full) + await run_sync(project=project_name) except Exception as e: raise CloudUtilsError(f"Failed to sync project '{project_name}': {e}") from e diff --git a/src/basic_memory/cli/commands/cloud/upload_command.py b/src/basic_memory/cli/commands/cloud/upload_command.py index 9c0ca0ed4..bdde3becf 100644 --- a/src/basic_memory/cli/commands/cloud/upload_command.py +++ b/src/basic_memory/cli/commands/cloud/upload_command.py @@ -142,7 +142,7 @@ async def _upload(): if sync and not dry_run: console.print(f"[blue]Syncing project '{project}'...[/blue]") try: - await sync_project(project, force_full=True) + await sync_project(project) except Exception as e: console.print(f"[yellow]Warning: Sync failed: {e}[/yellow]") console.print("[dim]Files uploaded but may not be indexed yet[/dim]") diff --git a/src/basic_memory/cli/commands/doctor.py b/src/basic_memory/cli/commands/doctor.py index 61e76df45..d148f0dc6 100644 --- a/src/basic_memory/cli/commands/doctor.py +++ b/src/basic_memory/cli/commands/doctor.py @@ -101,7 +101,7 @@ async def run_doctor() -> None: console.print("[green]OK[/green] Manual file written") sync_data = await project_client.sync( - project_id, force_full=True, run_in_background=False + project_id, force_full=False, run_in_background=False ) sync_report = SyncReportResponse.model_validate(sync_data) if sync_report.total == 0: From d9ab2a19fd95ca05756dd1ddcfd652e7b7f999a8 Mon Sep 17 00:00:00 2001 From: phernandez Date: Tue, 7 Apr 2026 12:14:41 -0500 Subject: [PATCH 06/10] perf(core): shard oversized vector sync work Signed-off-by: phernandez --- .../repository/postgres_search_repository.py | 67 +++++- .../repository/search_repository_base.py | 218 +++++++++++++++++- .../test_postgres_search_repository.py | 101 ++++++++ .../test_postgres_search_repository_unit.py | 76 ++++++ 4 files changed, 446 insertions(+), 16 deletions(-) diff --git a/src/basic_memory/repository/postgres_search_repository.py b/src/basic_memory/repository/postgres_search_repository.py index fd0a2c652..bddd0dd69 100644 --- a/src/basic_memory/repository/postgres_search_repository.py +++ b/src/basic_memory/repository/postgres_search_repository.py @@ -533,6 +533,7 @@ async def sync_entity_vectors_batch( pending_jobs: list[_PendingEmbeddingJob] = [] entity_runtime: dict[int, _EntitySyncRuntime] = {} failed_entity_ids: set[int] = set() + deferred_entity_ids: set[int] = set() synced_entity_ids: set[int] = set() for window_start in range(0, total_entities, POSTGRES_VECTOR_PREPARE_CONCURRENCY): @@ -572,7 +573,10 @@ async def sync_entity_vectors_batch( result.prepare_seconds_total += prepared_sync.prepare_seconds if embedding_jobs_count == 0: - synced_entity_ids.add(entity_id) + if prepared_sync.entity_complete: + synced_entity_ids.add(entity_id) + else: + deferred_entity_ids.add(entity_id) total_seconds = time.perf_counter() - prepared_sync.sync_start queue_wait_seconds = max(0.0, total_seconds - prepared_sync.prepare_seconds) result.queue_wait_seconds_total += queue_wait_seconds @@ -588,6 +592,12 @@ async def sync_entity_vectors_batch( chunks_skipped=prepared_sync.chunks_skipped, embedding_jobs_count=0, entity_skipped=prepared_sync.entity_skipped, + entity_complete=prepared_sync.entity_complete, + oversized_entity=prepared_sync.oversized_entity, + pending_jobs_total=prepared_sync.pending_jobs_total, + shard_index=prepared_sync.shard_index, + shard_count=prepared_sync.shard_count, + remaining_jobs_after_shard=prepared_sync.remaining_jobs_after_shard, ) continue @@ -599,6 +609,12 @@ async def sync_entity_vectors_batch( chunks_total=prepared_sync.chunks_total, chunks_skipped=prepared_sync.chunks_skipped, entity_skipped=prepared_sync.entity_skipped, + entity_complete=prepared_sync.entity_complete, + oversized_entity=prepared_sync.oversized_entity, + pending_jobs_total=prepared_sync.pending_jobs_total, + shard_index=prepared_sync.shard_index, + shard_count=prepared_sync.shard_count, + remaining_jobs_after_shard=prepared_sync.remaining_jobs_after_shard, prepare_seconds=prepared_sync.prepare_seconds, ) pending_jobs.extend( @@ -624,10 +640,13 @@ async def sync_entity_vectors_batch( (result.queue_wait_seconds_total) += self._finalize_completed_entity_syncs( entity_runtime=entity_runtime, synced_entity_ids=synced_entity_ids, + deferred_entity_ids=deferred_entity_ids, ) except Exception as exc: affected_entity_ids = sorted({job.entity_id for job in flush_jobs}) failed_entity_ids.update(affected_entity_ids) + synced_entity_ids.difference_update(affected_entity_ids) + deferred_entity_ids.difference_update(affected_entity_ids) for failed_entity_id in affected_entity_ids: entity_runtime.pop(failed_entity_id, None) logger.warning( @@ -654,10 +673,13 @@ async def sync_entity_vectors_batch( (result.queue_wait_seconds_total) += self._finalize_completed_entity_syncs( entity_runtime=entity_runtime, synced_entity_ids=synced_entity_ids, + deferred_entity_ids=deferred_entity_ids, ) except Exception as exc: affected_entity_ids = sorted({job.entity_id for job in flush_jobs}) failed_entity_ids.update(affected_entity_ids) + synced_entity_ids.difference_update(affected_entity_ids) + deferred_entity_ids.difference_update(affected_entity_ids) for failed_entity_id in affected_entity_ids: entity_runtime.pop(failed_entity_id, None) logger.warning( @@ -672,6 +694,8 @@ async def sync_entity_vectors_batch( if entity_runtime: orphan_runtime_entities = sorted(entity_runtime.keys()) failed_entity_ids.update(orphan_runtime_entities) + synced_entity_ids.difference_update(orphan_runtime_entities) + deferred_entity_ids.difference_update(orphan_runtime_entities) logger.warning( "Vector batch sync left unfinished entities after flushes: " "project_id={project_id} unfinished_entities={unfinished_entities}", @@ -680,13 +704,17 @@ async def sync_entity_vectors_batch( ) synced_entity_ids.difference_update(failed_entity_ids) + deferred_entity_ids.difference_update(failed_entity_ids) + deferred_entity_ids.difference_update(synced_entity_ids) result.failed_entity_ids = sorted(failed_entity_ids) result.entities_failed = len(result.failed_entity_ids) + result.entities_deferred = len(deferred_entity_ids) result.entities_synced = len(synced_entity_ids) logger.info( "Vector batch sync complete: project_id={project_id} entities_total={entities_total} " "entities_synced={entities_synced} entities_failed={entities_failed} " + "entities_deferred={entities_deferred} " "entities_skipped={entities_skipped} chunks_total={chunks_total} " "chunks_skipped={chunks_skipped} embedding_jobs_total={embedding_jobs_total} " "prepare_seconds_total={prepare_seconds_total:.3f} " @@ -696,6 +724,7 @@ async def sync_entity_vectors_batch( entities_total=result.entities_total, entities_synced=result.entities_synced, entities_failed=result.entities_failed, + entities_deferred=result.entities_deferred, entities_skipped=result.entities_skipped, chunks_total=result.chunks_total, chunks_skipped=result.chunks_skipped, @@ -849,14 +878,13 @@ async def _prepare_entity_vector_jobs(self, entity_id: int) -> _PreparedEntityVe prepare_seconds=prepare_seconds, ) - upsert_records: list[dict[str, str]] = [] - embedding_jobs: list[tuple[int, str]] = [] + pending_records: list[dict[str, str]] = [] skipped_chunks_count = 0 for record in chunk_records: current = existing_by_key.get(record["chunk_key"]) if current is None: - upsert_records.append(record) + pending_records.append(record) continue row_id = int(current["id"]) @@ -886,7 +914,19 @@ async def _prepare_entity_vector_jobs(self, entity_id: int) -> _PreparedEntityVe skipped_chunks_count += 1 continue - upsert_records.append(record) + pending_records.append(record) + + shard_plan = self._plan_entity_vector_shard(pending_records) + self._log_vector_shard_plan(entity_id=entity_id, shard_plan=shard_plan) + + scheduled_records = [ + record + for record in sorted(pending_records, key=lambda record: record["chunk_key"]) + if record["chunk_key"] in shard_plan.scheduled_chunk_keys + ] + + embedding_jobs: list[tuple[int, str]] = [] + upsert_records = list(scheduled_records) if upsert_records: upsert_params: dict[str, object] = { @@ -943,7 +983,10 @@ async def _prepare_entity_vector_jobs(self, entity_id: int) -> _PreparedEntityVe "stale_chunks_count={stale_chunks_count} " "orphan_chunks_count={orphan_chunks_count} " "chunks_skipped={chunks_skipped} " - "embedding_jobs_count={embedding_jobs_count}", + "embedding_jobs_count={embedding_jobs_count} " + "pending_jobs_total={pending_jobs_total} shard_index={shard_index} " + "shard_count={shard_count} remaining_jobs_after_shard={remaining_jobs_after_shard} " + "oversized_entity={oversized_entity} entity_complete={entity_complete}", project_id=self.project_id, entity_id=entity_id, existing_chunks_count=existing_chunks_count, @@ -951,6 +994,12 @@ async def _prepare_entity_vector_jobs(self, entity_id: int) -> _PreparedEntityVe orphan_chunks_count=orphan_chunks_count, chunks_skipped=skipped_chunks_count, embedding_jobs_count=len(embedding_jobs), + pending_jobs_total=shard_plan.pending_jobs_total, + shard_index=shard_plan.shard_index, + shard_count=shard_plan.shard_count, + remaining_jobs_after_shard=shard_plan.remaining_jobs_after_shard, + oversized_entity=shard_plan.oversized_entity, + entity_complete=shard_plan.entity_complete, ) prepare_seconds = time.perf_counter() - sync_start @@ -961,6 +1010,12 @@ async def _prepare_entity_vector_jobs(self, entity_id: int) -> _PreparedEntityVe embedding_jobs=embedding_jobs, chunks_total=built_chunk_records_count, chunks_skipped=skipped_chunks_count, + entity_complete=shard_plan.entity_complete, + oversized_entity=shard_plan.oversized_entity, + pending_jobs_total=shard_plan.pending_jobs_total, + shard_index=shard_plan.shard_index, + shard_count=shard_plan.shard_count, + remaining_jobs_after_shard=shard_plan.remaining_jobs_after_shard, prepare_seconds=prepare_seconds, ) diff --git a/src/basic_memory/repository/search_repository_base.py b/src/basic_memory/repository/search_repository_base.py index 7d7deb4e1..80c68b5f4 100644 --- a/src/basic_memory/repository/search_repository_base.py +++ b/src/basic_memory/repository/search_repository_base.py @@ -2,6 +2,7 @@ import hashlib import json +import math import re import time from abc import ABC, abstractmethod @@ -33,6 +34,7 @@ SMALL_NOTE_CONTENT_LIMIT = 2000 HEADER_LINE_PATTERN = re.compile(r"^\s*#{1,6}\s+") BULLET_PATTERN = re.compile(r"^[\-\*]\s+") +OVERSIZED_ENTITY_VECTOR_SHARD_SIZE = 256 @dataclass @@ -42,6 +44,7 @@ class VectorSyncBatchResult: entities_total: int entities_synced: int entities_failed: int + entities_deferred: int = 0 entities_skipped: int = 0 failed_entity_ids: list[int] = field(default_factory=list) chunks_total: int = 0 @@ -64,6 +67,12 @@ class _PreparedEntityVectorSync: chunks_total: int = 0 chunks_skipped: int = 0 entity_skipped: bool = False + entity_complete: bool = True + oversized_entity: bool = False + pending_jobs_total: int = 0 + shard_index: int = 1 + shard_count: int = 1 + remaining_jobs_after_shard: int = 0 prepare_seconds: float = 0.0 @@ -87,11 +96,30 @@ class _EntitySyncRuntime: chunks_total: int = 0 chunks_skipped: int = 0 entity_skipped: bool = False + entity_complete: bool = True + oversized_entity: bool = False + pending_jobs_total: int = 0 + shard_index: int = 1 + shard_count: int = 1 + remaining_jobs_after_shard: int = 0 prepare_seconds: float = 0.0 embed_seconds: float = 0.0 write_seconds: float = 0.0 +@dataclass(frozen=True) +class _EntityVectorShardPlan: + """Shard selection for one entity's pending embedding work.""" + + scheduled_chunk_keys: set[str] + pending_jobs_total: int + shard_index: int + shard_count: int + remaining_jobs_after_shard: int + oversized_entity: bool + entity_complete: bool + + class SearchRepositoryBase(ABC): """Abstract base class for backend-specific search repository implementations. @@ -524,6 +552,79 @@ def _embedding_model_key(self) -> str: f"{self._embedding_provider.dimensions}" ) + def _plan_entity_vector_shard( + self, + pending_records: list[dict[str, str]], + ) -> _EntityVectorShardPlan: + """Select the bounded shard to process for one entity sync invocation.""" + pending_jobs_total = len(pending_records) + if pending_jobs_total == 0: + return _EntityVectorShardPlan( + scheduled_chunk_keys=set(), + pending_jobs_total=0, + shard_index=1, + shard_count=1, + remaining_jobs_after_shard=0, + oversized_entity=False, + entity_complete=True, + ) + + ordered_pending_records = sorted(pending_records, key=lambda record: record["chunk_key"]) + scheduled_records = ordered_pending_records[:OVERSIZED_ENTITY_VECTOR_SHARD_SIZE] + remaining_jobs_after_shard = pending_jobs_total - len(scheduled_records) + return _EntityVectorShardPlan( + scheduled_chunk_keys={record["chunk_key"] for record in scheduled_records}, + pending_jobs_total=pending_jobs_total, + shard_index=1, + shard_count=max( + 1, + math.ceil(pending_jobs_total / OVERSIZED_ENTITY_VECTOR_SHARD_SIZE), + ), + remaining_jobs_after_shard=remaining_jobs_after_shard, + oversized_entity=pending_jobs_total > OVERSIZED_ENTITY_VECTOR_SHARD_SIZE, + entity_complete=remaining_jobs_after_shard == 0, + ) + + def _log_vector_shard_plan( + self, + *, + entity_id: int, + shard_plan: _EntityVectorShardPlan, + ) -> None: + """Emit shard planning logs once the pending work is known.""" + if shard_plan.pending_jobs_total == 0: + return + + scheduled_jobs_count = shard_plan.pending_jobs_total - shard_plan.remaining_jobs_after_shard + if shard_plan.oversized_entity: + logger.warning( + "Vector sync oversized entity detected: project_id={project_id} " + "entity_id={entity_id} pending_jobs_total={pending_jobs_total} " + "shard_size={shard_size} shard_count={shard_count}", + project_id=self.project_id, + entity_id=entity_id, + pending_jobs_total=shard_plan.pending_jobs_total, + shard_size=OVERSIZED_ENTITY_VECTOR_SHARD_SIZE, + shard_count=shard_plan.shard_count, + ) + + logger.info( + "Vector sync shard planned: project_id={project_id} entity_id={entity_id} " + "pending_jobs_total={pending_jobs_total} scheduled_jobs_count={scheduled_jobs_count} " + "shard_index={shard_index} shard_count={shard_count} " + "remaining_jobs_after_shard={remaining_jobs_after_shard} " + "oversized_entity={oversized_entity} entity_complete={entity_complete}", + project_id=self.project_id, + entity_id=entity_id, + pending_jobs_total=shard_plan.pending_jobs_total, + scheduled_jobs_count=scheduled_jobs_count, + shard_index=shard_plan.shard_index, + shard_count=shard_plan.shard_count, + remaining_jobs_after_shard=shard_plan.remaining_jobs_after_shard, + oversized_entity=shard_plan.oversized_entity, + entity_complete=shard_plan.entity_complete, + ) + # --- Text splitting --- def _split_text_into_chunks(self, text_value: str) -> list[str]: @@ -715,6 +816,7 @@ async def _sync_entity_vectors_internal( pending_jobs: list[_PendingEmbeddingJob] = [] entity_runtime: dict[int, _EntitySyncRuntime] = {} failed_entity_ids: set[int] = set() + deferred_entity_ids: set[int] = set() synced_entity_ids: set[int] = set() for index, entity_id in enumerate(entity_ids): @@ -745,7 +847,10 @@ async def _sync_entity_vectors_internal( result.prepare_seconds_total += prepared.prepare_seconds if embedding_jobs_count == 0: - synced_entity_ids.add(entity_id) + if prepared.entity_complete: + synced_entity_ids.add(entity_id) + else: + deferred_entity_ids.add(entity_id) total_seconds = time.perf_counter() - prepared.sync_start queue_wait_seconds = max(0.0, total_seconds - prepared.prepare_seconds) result.queue_wait_seconds_total += queue_wait_seconds @@ -761,6 +866,12 @@ async def _sync_entity_vectors_internal( chunks_skipped=prepared.chunks_skipped, embedding_jobs_count=0, entity_skipped=prepared.entity_skipped, + entity_complete=prepared.entity_complete, + oversized_entity=prepared.oversized_entity, + pending_jobs_total=prepared.pending_jobs_total, + shard_index=prepared.shard_index, + shard_count=prepared.shard_count, + remaining_jobs_after_shard=prepared.remaining_jobs_after_shard, ) continue @@ -772,6 +883,12 @@ async def _sync_entity_vectors_internal( chunks_total=prepared.chunks_total, chunks_skipped=prepared.chunks_skipped, entity_skipped=prepared.entity_skipped, + entity_complete=prepared.entity_complete, + oversized_entity=prepared.oversized_entity, + pending_jobs_total=prepared.pending_jobs_total, + shard_index=prepared.shard_index, + shard_count=prepared.shard_count, + remaining_jobs_after_shard=prepared.remaining_jobs_after_shard, prepare_seconds=prepared.prepare_seconds, ) pending_jobs.extend( @@ -795,12 +912,15 @@ async def _sync_entity_vectors_internal( (result.queue_wait_seconds_total) += self._finalize_completed_entity_syncs( entity_runtime=entity_runtime, synced_entity_ids=synced_entity_ids, + deferred_entity_ids=deferred_entity_ids, ) except Exception as exc: if not continue_on_error: raise affected_entity_ids = sorted({job.entity_id for job in flush_jobs}) failed_entity_ids.update(affected_entity_ids) + synced_entity_ids.difference_update(affected_entity_ids) + deferred_entity_ids.difference_update(affected_entity_ids) for failed_entity_id in affected_entity_ids: entity_runtime.pop(failed_entity_id, None) logger.warning( @@ -826,12 +946,15 @@ async def _sync_entity_vectors_internal( (result.queue_wait_seconds_total) += self._finalize_completed_entity_syncs( entity_runtime=entity_runtime, synced_entity_ids=synced_entity_ids, + deferred_entity_ids=deferred_entity_ids, ) except Exception as exc: if not continue_on_error: raise affected_entity_ids = sorted({job.entity_id for job in flush_jobs}) failed_entity_ids.update(affected_entity_ids) + synced_entity_ids.difference_update(affected_entity_ids) + deferred_entity_ids.difference_update(affected_entity_ids) for failed_entity_id in affected_entity_ids: entity_runtime.pop(failed_entity_id, None) logger.warning( @@ -849,6 +972,8 @@ async def _sync_entity_vectors_internal( if entity_runtime: orphan_runtime_entities = sorted(entity_runtime.keys()) failed_entity_ids.update(orphan_runtime_entities) + synced_entity_ids.difference_update(orphan_runtime_entities) + deferred_entity_ids.difference_update(orphan_runtime_entities) logger.warning( "Vector batch sync left unfinished entities after flushes: " "project_id={project_id} unfinished_entities={unfinished_entities}", @@ -858,13 +983,17 @@ async def _sync_entity_vectors_internal( # Keep result counters aligned with successful/failed terminal states. synced_entity_ids.difference_update(failed_entity_ids) + deferred_entity_ids.difference_update(failed_entity_ids) + deferred_entity_ids.difference_update(synced_entity_ids) result.failed_entity_ids = sorted(failed_entity_ids) result.entities_failed = len(result.failed_entity_ids) + result.entities_deferred = len(deferred_entity_ids) result.entities_synced = len(synced_entity_ids) logger.info( "Vector batch sync complete: project_id={project_id} entities_total={entities_total} " "entities_synced={entities_synced} entities_failed={entities_failed} " + "entities_deferred={entities_deferred} " "entities_skipped={entities_skipped} chunks_total={chunks_total} " "chunks_skipped={chunks_skipped} embedding_jobs_total={embedding_jobs_total} " "prepare_seconds_total={prepare_seconds_total:.3f} " @@ -874,6 +1003,7 @@ async def _sync_entity_vectors_internal( entities_total=result.entities_total, entities_synced=result.entities_synced, entities_failed=result.entities_failed, + entities_deferred=result.entities_deferred, entities_skipped=result.entities_skipped, chunks_total=result.chunks_total, chunks_skipped=result.chunks_skipped, @@ -1046,9 +1176,8 @@ async def _prepare_entity_vector_jobs(self, entity_id: int) -> _PreparedEntityVe prepare_seconds=prepare_seconds, ) - # --- Upsert changed / new chunks, collect embedding jobs --- timestamp_expr = self._timestamp_now_expr() - embedding_jobs: list[tuple[int, str]] = [] + pending_records: list[dict[str, str]] = [] skipped_chunks_count = 0 for record in chunk_records: current = existing_by_key.get(record["chunk_key"]) @@ -1081,9 +1210,29 @@ async def _prepare_entity_vector_jobs(self, entity_id: int) -> _PreparedEntityVe "embedding_model": current_embedding_model, }, ) - skipped_chunks_count += 1 - continue + skipped_chunks_count += 1 + continue + + pending_records.append(record) + shard_plan = self._plan_entity_vector_shard(pending_records) + self._log_vector_shard_plan(entity_id=entity_id, shard_plan=shard_plan) + + # Trigger: oversized entities can accumulate thousands of pending chunks. + # Why: scheduling only one deterministic shard bounds memory and wall clock. + # Outcome: future runs resume from the remaining chunk rows without redoing completed work. + scheduled_records = [ + record + for record in sorted(pending_records, key=lambda record: record["chunk_key"]) + if record["chunk_key"] in shard_plan.scheduled_chunk_keys + ] + + # --- Upsert scheduled changed / new chunks, collect embedding jobs --- + embedding_jobs: list[tuple[int, str]] = [] + for record in scheduled_records: + current = existing_by_key.get(record["chunk_key"]) + if current: + row_id = int(current.id) if ( current.source_hash != record["source_hash"] or current.entity_fingerprint != current_entity_fingerprint @@ -1139,7 +1288,10 @@ async def _prepare_entity_vector_jobs(self, entity_id: int) -> _PreparedEntityVe "stale_chunks_count={stale_chunks_count} " "orphan_chunks_count={orphan_chunks_count} " "chunks_skipped={chunks_skipped} " - "embedding_jobs_count={embedding_jobs_count}", + "embedding_jobs_count={embedding_jobs_count} " + "pending_jobs_total={pending_jobs_total} shard_index={shard_index} " + "shard_count={shard_count} remaining_jobs_after_shard={remaining_jobs_after_shard} " + "oversized_entity={oversized_entity} entity_complete={entity_complete}", project_id=self.project_id, entity_id=entity_id, existing_chunks_count=existing_chunks_count, @@ -1147,6 +1299,12 @@ async def _prepare_entity_vector_jobs(self, entity_id: int) -> _PreparedEntityVe orphan_chunks_count=orphan_chunks_count, chunks_skipped=skipped_chunks_count, embedding_jobs_count=len(embedding_jobs), + pending_jobs_total=shard_plan.pending_jobs_total, + shard_index=shard_plan.shard_index, + shard_count=shard_plan.shard_count, + remaining_jobs_after_shard=shard_plan.remaining_jobs_after_shard, + oversized_entity=shard_plan.oversized_entity, + entity_complete=shard_plan.entity_complete, ) await session.commit() @@ -1158,6 +1316,12 @@ async def _prepare_entity_vector_jobs(self, entity_id: int) -> _PreparedEntityVe embedding_jobs=embedding_jobs, chunks_total=built_chunk_records_count, chunks_skipped=skipped_chunks_count, + entity_complete=shard_plan.entity_complete, + oversized_entity=shard_plan.oversized_entity, + pending_jobs_total=shard_plan.pending_jobs_total, + shard_index=shard_plan.shard_index, + shard_count=shard_plan.shard_count, + remaining_jobs_after_shard=shard_plan.remaining_jobs_after_shard, prepare_seconds=prepare_seconds, ) @@ -1221,7 +1385,7 @@ async def _flush_embedding_jobs( runtime.embed_seconds += embed_seconds * flush_share runtime.write_seconds += write_seconds * flush_share - if runtime.remaining_jobs <= 0: + if runtime.remaining_jobs <= 0 and runtime.entity_complete: synced_entity_ids.add(entity_id) return embed_seconds, write_seconds @@ -1231,6 +1395,7 @@ def _finalize_completed_entity_syncs( *, entity_runtime: dict[int, _EntitySyncRuntime], synced_entity_ids: set[int], + deferred_entity_ids: set[int], ) -> float: """Finalize completed entities and return cumulative queue wait seconds.""" queue_wait_seconds_total = 0.0 @@ -1238,7 +1403,10 @@ def _finalize_completed_entity_syncs( if runtime.remaining_jobs > 0: continue - synced_entity_ids.add(entity_id) + if runtime.entity_complete: + synced_entity_ids.add(entity_id) + else: + deferred_entity_ids.add(entity_id) total_seconds = time.perf_counter() - runtime.sync_start queue_wait_seconds = max( 0.0, @@ -1260,6 +1428,12 @@ def _finalize_completed_entity_syncs( chunks_skipped=runtime.chunks_skipped, embedding_jobs_count=runtime.embedding_jobs_count, entity_skipped=runtime.entity_skipped, + entity_complete=runtime.entity_complete, + oversized_entity=runtime.oversized_entity, + pending_jobs_total=runtime.pending_jobs_total, + shard_index=runtime.shard_index, + shard_count=runtime.shard_count, + remaining_jobs_after_shard=runtime.remaining_jobs_after_shard, ) entity_runtime.pop(entity_id, None) @@ -1279,6 +1453,12 @@ def _log_vector_sync_complete( chunks_skipped: int, embedding_jobs_count: int, entity_skipped: bool, + entity_complete: bool, + oversized_entity: bool, + pending_jobs_total: int, + shard_index: int, + shard_count: int, + remaining_jobs_after_shard: int, ) -> None: """Log completion and slow-entity warnings with a consistent format.""" logger.info( @@ -1287,7 +1467,10 @@ def _log_vector_sync_complete( "queue_wait_seconds={queue_wait_seconds:.3f} embed_seconds={embed_seconds:.3f} " "write_seconds={write_seconds:.3f} source_rows_count={source_rows_count} " "chunks_total={chunks_total} chunks_skipped={chunks_skipped} " - "embedding_jobs_count={embedding_jobs_count} entity_skipped={entity_skipped}", + "embedding_jobs_count={embedding_jobs_count} entity_skipped={entity_skipped} " + "entity_complete={entity_complete} oversized_entity={oversized_entity} " + "pending_jobs_total={pending_jobs_total} shard_index={shard_index} " + "shard_count={shard_count} remaining_jobs_after_shard={remaining_jobs_after_shard}", project_id=self.project_id, entity_id=entity_id, total_seconds=total_seconds, @@ -1300,6 +1483,12 @@ def _log_vector_sync_complete( chunks_skipped=chunks_skipped, embedding_jobs_count=embedding_jobs_count, entity_skipped=entity_skipped, + entity_complete=entity_complete, + oversized_entity=oversized_entity, + pending_jobs_total=pending_jobs_total, + shard_index=shard_index, + shard_count=shard_count, + remaining_jobs_after_shard=remaining_jobs_after_shard, ) if total_seconds > 10: logger.warning( @@ -1308,7 +1497,10 @@ def _log_vector_sync_complete( "queue_wait_seconds={queue_wait_seconds:.3f} embed_seconds={embed_seconds:.3f} " "write_seconds={write_seconds:.3f} source_rows_count={source_rows_count} " "chunks_total={chunks_total} chunks_skipped={chunks_skipped} " - "embedding_jobs_count={embedding_jobs_count} entity_skipped={entity_skipped}", + "embedding_jobs_count={embedding_jobs_count} entity_skipped={entity_skipped} " + "entity_complete={entity_complete} oversized_entity={oversized_entity} " + "pending_jobs_total={pending_jobs_total} shard_index={shard_index} " + "shard_count={shard_count} remaining_jobs_after_shard={remaining_jobs_after_shard}", project_id=self.project_id, entity_id=entity_id, total_seconds=total_seconds, @@ -1321,6 +1513,12 @@ def _log_vector_sync_complete( chunks_skipped=chunks_skipped, embedding_jobs_count=embedding_jobs_count, entity_skipped=entity_skipped, + entity_complete=entity_complete, + oversized_entity=oversized_entity, + pending_jobs_total=pending_jobs_total, + shard_index=shard_index, + shard_count=shard_count, + remaining_jobs_after_shard=remaining_jobs_after_shard, ) async def _prepare_vector_session(self, session: AsyncSession) -> None: diff --git a/tests/repository/test_postgres_search_repository.py b/tests/repository/test_postgres_search_repository.py index a435bcdf5..9e9f3e1a8 100644 --- a/tests/repository/test_postgres_search_repository.py +++ b/tests/repository/test_postgres_search_repository.py @@ -11,6 +11,7 @@ from basic_memory import db from basic_memory.config import BasicMemoryConfig, DatabaseBackend +import basic_memory.repository.search_repository_base as search_repository_base_module from basic_memory.repository.postgres_search_repository import ( PostgresSearchRepository, _strip_nul_from_row, @@ -53,6 +54,13 @@ class StubEmbeddingProviderV2(StubEmbeddingProvider): model_name = "stub-v2" +def _oversized_entity_content(bullet_count: int) -> str: + """Build deterministic content that produces many vector chunks.""" + lines = ["# Oversized Entity"] + lines.extend(f"- embedding job {index}" for index in range(1, bullet_count + 1)) + return "\n".join(lines) + + async def _skip_if_pgvector_unavailable(session_maker) -> None: """Skip semantic pgvector tests when extension is not available in test Postgres image.""" async with db.scoped_session(session_maker) as session: @@ -552,6 +560,99 @@ async def test_postgres_vector_sync_skips_unchanged_and_reembeds_changed_content assert model_changed_result.embedding_jobs_total == model_changed_result.chunks_total +@pytest.mark.asyncio +async def test_postgres_vector_sync_shards_oversized_entity_and_resumes( + session_maker, test_project, monkeypatch +): + """Oversized entities should sync one deterministic shard per run and resume cleanly.""" + await _skip_if_pgvector_unavailable(session_maker) + monkeypatch.setattr(search_repository_base_module, "OVERSIZED_ENTITY_VECTOR_SHARD_SIZE", 2) + + app_config = BasicMemoryConfig( + env="test", + projects={"test-project": "/tmp/basic-memory-test"}, + default_project="test-project", + database_backend=DatabaseBackend.POSTGRES, + semantic_search_enabled=True, + ) + repo = PostgresSearchRepository( + session_maker, + project_id=test_project.id, + app_config=app_config, + embedding_provider=StubEmbeddingProvider(), + ) + await repo.init_search_index() + + now = datetime.now(timezone.utc) + content = _oversized_entity_content(5) + await repo.index_item( + SearchIndexRow( + project_id=test_project.id, + id=430, + title="Oversized Vector Entity", + content_stems=content, + content_snippet=content, + permalink="specs/oversized-vector-entity", + file_path="specs/oversized-vector-entity.md", + type=SearchItemType.ENTITY.value, + entity_id=430, + metadata={"note_type": "spec"}, + created_at=now, + updated_at=now, + ) + ) + + first_result = await repo.sync_entity_vectors_batch([430]) + assert first_result.entities_synced == 0 + assert first_result.entities_deferred == 1 + assert first_result.entities_failed == 0 + assert first_result.embedding_jobs_total == 2 + assert first_result.chunks_total == 6 + assert first_result.chunks_skipped == 0 + + second_result = await repo.sync_entity_vectors_batch([430]) + assert second_result.entities_synced == 0 + assert second_result.entities_deferred == 1 + assert second_result.entities_failed == 0 + assert second_result.embedding_jobs_total == 2 + assert second_result.chunks_total == 6 + assert second_result.chunks_skipped == 2 + + third_result = await repo.sync_entity_vectors_batch([430]) + assert third_result.entities_synced == 1 + assert third_result.entities_deferred == 0 + assert third_result.entities_failed == 0 + assert third_result.embedding_jobs_total == 2 + assert third_result.chunks_total == 6 + assert third_result.chunks_skipped == 4 + + unchanged_result = await repo.sync_entity_vectors_batch([430]) + assert unchanged_result.entities_synced == 1 + assert unchanged_result.entities_deferred == 0 + assert unchanged_result.entities_skipped == 1 + assert unchanged_result.embedding_jobs_total == 0 + assert unchanged_result.chunks_skipped == unchanged_result.chunks_total == 6 + + async with db.scoped_session(session_maker) as session: + chunk_count = await session.execute( + text( + "SELECT COUNT(*) FROM search_vector_chunks " + "WHERE project_id = :project_id AND entity_id = :entity_id" + ), + {"project_id": test_project.id, "entity_id": 430}, + ) + embedding_count = await session.execute( + text( + "SELECT COUNT(*) FROM search_vector_embeddings e " + "JOIN search_vector_chunks c ON c.id = e.chunk_id " + "WHERE c.project_id = :project_id AND c.entity_id = :entity_id" + ), + {"project_id": test_project.id, "entity_id": 430}, + ) + assert int(chunk_count.scalar_one()) == 6 + assert int(embedding_count.scalar_one()) == 6 + + @pytest.mark.asyncio async def test_postgres_vector_mode_rejects_non_text_query(session_maker, test_project): """Vector mode should fail fast for title-only queries.""" diff --git a/tests/repository/test_postgres_search_repository_unit.py b/tests/repository/test_postgres_search_repository_unit.py index 4a331e0de..c669bcd63 100644 --- a/tests/repository/test_postgres_search_repository_unit.py +++ b/tests/repository/test_postgres_search_repository_unit.py @@ -348,3 +348,79 @@ def _capture_log(**kwargs): for record in completion_records: assert record["prepare_seconds"] == pytest.approx(1.0) assert record["queue_wait_seconds"] == pytest.approx(1.5) + + +@pytest.mark.asyncio +async def test_postgres_batch_sync_tracks_deferred_oversized_entities(monkeypatch): + """Oversized shard runs should be deferred until the last shard completes.""" + repo = _make_repo( + semantic_enabled=True, + embedding_provider=StubEmbeddingProvider(), + ) + repo._semantic_embedding_sync_batch_size = 8 + repo._vector_tables_initialized = True + + async def _stub_prepare(entity_id: int) -> _PreparedEntityVectorSync: + if entity_id == 1: + return _PreparedEntityVectorSync( + entity_id=entity_id, + sync_start=0.0, + source_rows_count=1, + embedding_jobs=[(201, "chunk-1a"), (202, "chunk-1b")], + chunks_total=5, + pending_jobs_total=5, + entity_complete=False, + oversized_entity=True, + shard_index=1, + shard_count=3, + remaining_jobs_after_shard=3, + ) + return _PreparedEntityVectorSync( + entity_id=entity_id, + sync_start=0.0, + source_rows_count=1, + embedding_jobs=[(301, "chunk-2a")], + chunks_total=1, + pending_jobs_total=1, + entity_complete=True, + shard_index=1, + shard_count=1, + remaining_jobs_after_shard=0, + ) + + async def _stub_flush(flush_jobs, entity_runtime, synced_entity_ids): + for job in flush_jobs: + runtime = entity_runtime[job.entity_id] + runtime.remaining_jobs -= 1 + runtime.embed_seconds += 0.5 + runtime.write_seconds += 0.25 + return (1.5, 0.75) + + completion_records: list[dict] = [] + + def _capture_log(**kwargs): + completion_records.append(kwargs) + + monkeypatch.setattr(repo, "_ensure_vector_tables", AsyncMock()) + monkeypatch.setattr(repo, "_prepare_entity_vector_jobs", _stub_prepare) + monkeypatch.setattr(repo, "_flush_embedding_jobs", _stub_flush) + monkeypatch.setattr(repo, "_log_vector_sync_complete", _capture_log) + + result = await repo.sync_entity_vectors_batch([1, 2]) + + assert result.entities_total == 2 + assert result.entities_synced == 1 + assert result.entities_deferred == 1 + assert result.entities_failed == 0 + assert result.embedding_jobs_total == 3 + + deferred_record = next(record for record in completion_records if record["entity_id"] == 1) + assert deferred_record["entity_complete"] is False + assert deferred_record["oversized_entity"] is True + assert deferred_record["pending_jobs_total"] == 5 + assert deferred_record["shard_count"] == 3 + assert deferred_record["remaining_jobs_after_shard"] == 3 + + complete_record = next(record for record in completion_records if record["entity_id"] == 2) + assert complete_record["entity_complete"] is True + assert complete_record["oversized_entity"] is False From 2f34747c9fa967c3716794b1558a70025a59c3c1 Mon Sep 17 00:00:00 2001 From: phernandez Date: Tue, 7 Apr 2026 12:26:31 -0500 Subject: [PATCH 07/10] perf(core): make postgres vector sync tuning configurable Signed-off-by: phernandez --- src/basic_memory/config.py | 6 ++++++ .../repository/postgres_search_repository.py | 16 ++++++++-------- .../test_postgres_search_repository_unit.py | 5 ++++- tests/test_config.py | 16 ++++++++++++++++ 4 files changed, 34 insertions(+), 9 deletions(-) diff --git a/src/basic_memory/config.py b/src/basic_memory/config.py index e3f01eced..c6d895f4e 100644 --- a/src/basic_memory/config.py +++ b/src/basic_memory/config.py @@ -198,6 +198,12 @@ class BasicMemoryConfig(BaseSettings): description="Batch size for vector sync orchestration flushes.", gt=0, ) + semantic_postgres_prepare_concurrency: int = Field( + default=4, + description="Number of Postgres entity prepare tasks to run concurrently during vector sync. Postgres only; keep this low to avoid overdriving the database connection pool.", + gt=0, + le=16, + ) semantic_embedding_cache_dir: str | None = Field( default=None, description="Optional cache directory for FastEmbed model artifacts.", diff --git a/src/basic_memory/repository/postgres_search_repository.py b/src/basic_memory/repository/postgres_search_repository.py index bddd0dd69..a7bff5d00 100644 --- a/src/basic_memory/repository/postgres_search_repository.py +++ b/src/basic_memory/repository/postgres_search_repository.py @@ -28,9 +28,6 @@ from basic_memory.schemas.search import SearchItemType, SearchRetrievalMode -POSTGRES_VECTOR_PREPARE_CONCURRENCY = 4 - - def _strip_nul_from_row(row_data: dict) -> dict: """Strip NUL bytes from all string values in a row dict. @@ -71,6 +68,9 @@ def __init__( self._semantic_embedding_sync_batch_size = ( self._app_config.semantic_embedding_sync_batch_size ) + self._semantic_postgres_prepare_concurrency = ( + self._app_config.semantic_postgres_prepare_concurrency + ) self._embedding_provider = embedding_provider self._vector_dimensions = 384 self._vector_tables_initialized = False @@ -503,8 +503,8 @@ async def sync_entity_vectors_batch( Trigger: cloud indexing uses Neon Postgres where network latency dominates thousands of per-entity prepare queries. - Why: preparing a small window of entities concurrently hides round-trip latency - without exhausting the tenant connection pool. + Why: preparing a small config-driven window of entities concurrently hides + round-trip latency without exhausting the tenant connection pool. Outcome: Postgres vector sync keeps the existing flush semantics while reducing wall-clock time on large cloud projects. """ @@ -527,7 +527,7 @@ async def sync_entity_vectors_batch( project_id=self.project_id, entities_total=total_entities, sync_batch_size=self._semantic_embedding_sync_batch_size, - prepare_concurrency=POSTGRES_VECTOR_PREPARE_CONCURRENCY, + prepare_concurrency=self._semantic_postgres_prepare_concurrency, ) pending_jobs: list[_PendingEmbeddingJob] = [] @@ -536,9 +536,9 @@ async def sync_entity_vectors_batch( deferred_entity_ids: set[int] = set() synced_entity_ids: set[int] = set() - for window_start in range(0, total_entities, POSTGRES_VECTOR_PREPARE_CONCURRENCY): + for window_start in range(0, total_entities, self._semantic_postgres_prepare_concurrency): window_entity_ids = entity_ids[ - window_start : window_start + POSTGRES_VECTOR_PREPARE_CONCURRENCY + window_start : window_start + self._semantic_postgres_prepare_concurrency ] if progress_callback is not None: diff --git a/tests/repository/test_postgres_search_repository_unit.py b/tests/repository/test_postgres_search_repository_unit.py index c669bcd63..8fff44200 100644 --- a/tests/repository/test_postgres_search_repository_unit.py +++ b/tests/repository/test_postgres_search_repository_unit.py @@ -40,6 +40,7 @@ def _make_repo( *, semantic_enabled: bool = False, embedding_provider=None, + semantic_postgres_prepare_concurrency: int = 4, ) -> PostgresSearchRepository: """Build a PostgresSearchRepository with a no-op session maker.""" session_maker = MagicMock() @@ -49,6 +50,7 @@ def _make_repo( default_project="test-project", database_backend=DatabaseBackend.POSTGRES, semantic_search_enabled=semantic_enabled, + semantic_postgres_prepare_concurrency=semantic_postgres_prepare_concurrency, ) return PostgresSearchRepository( session_maker, @@ -255,6 +257,7 @@ async def test_sync_entity_vectors_batch_prepares_entities_concurrently(self, mo repo = _make_repo( semantic_enabled=True, embedding_provider=StubEmbeddingProvider(), + semantic_postgres_prepare_concurrency=2, ) repo._semantic_embedding_sync_batch_size = 8 repo._vector_tables_initialized = True @@ -283,7 +286,7 @@ async def _stub_prepare(entity_id: int) -> _PreparedEntityVectorSync: assert result.entities_total == 4 assert result.entities_synced == 4 assert result.entities_failed == 0 - assert max_active_prepares > 1 + assert max_active_prepares == 2 @pytest.mark.asyncio diff --git a/tests/test_config.py b/tests/test_config.py index 4ef94ec4c..5c4fe5d79 100644 --- a/tests/test_config.py +++ b/tests/test_config.py @@ -882,6 +882,22 @@ def test_semantic_embedding_dimensions_can_be_set(self): config = BasicMemoryConfig(semantic_embedding_dimensions=1536) assert config.semantic_embedding_dimensions == 1536 + def test_semantic_postgres_prepare_concurrency_defaults_to_4(self): + """Postgres prepare concurrency should default to a conservative window of 4.""" + config = BasicMemoryConfig() + assert config.semantic_postgres_prepare_concurrency == 4 + + def test_semantic_postgres_prepare_concurrency_validation(self): + """Postgres prepare concurrency must stay within the bounded safe range.""" + config = BasicMemoryConfig(semantic_postgres_prepare_concurrency=8) + assert config.semantic_postgres_prepare_concurrency == 8 + + with pytest.raises(Exception): + BasicMemoryConfig(semantic_postgres_prepare_concurrency=0) + + with pytest.raises(Exception): + BasicMemoryConfig(semantic_postgres_prepare_concurrency=17) + def test_semantic_search_enabled_description_mentions_both_backends(self): """Description should not say 'SQLite only' anymore.""" field_info = BasicMemoryConfig.model_fields["semantic_search_enabled"] From 4529bfa128a7d4d7235c6dbb08dd6f923def0b27 Mon Sep 17 00:00:00 2001 From: phernandez Date: Tue, 7 Apr 2026 12:37:13 -0500 Subject: [PATCH 08/10] refactor(core): share vector sync prepare orchestration Signed-off-by: phernandez --- .../repository/postgres_search_repository.py | 269 ++---------------- .../repository/search_repository_base.py | 220 +++++++------- 2 files changed, 146 insertions(+), 343 deletions(-) diff --git a/src/basic_memory/repository/postgres_search_repository.py b/src/basic_memory/repository/postgres_search_repository.py index a7bff5d00..47785c473 100644 --- a/src/basic_memory/repository/postgres_search_repository.py +++ b/src/basic_memory/repository/postgres_search_repository.py @@ -18,9 +18,6 @@ from basic_memory.repository.search_index_row import SearchIndexRow from basic_memory.repository.search_repository_base import ( SearchRepositoryBase, - VectorSyncBatchResult, - _EntitySyncRuntime, - _PendingEmbeddingJob, _PreparedEntityVectorSync, ) from basic_memory.repository.metadata_filters import parse_metadata_filters @@ -494,248 +491,22 @@ async def _run_vector_query( ) return [dict(row) for row in vector_result.mappings().all()] - async def sync_entity_vectors_batch( - self, - entity_ids: list[int], - progress_callback=None, - ) -> VectorSyncBatchResult: - """Sync semantic vectors with concurrent Postgres preparation windows. - - Trigger: cloud indexing uses Neon Postgres where network latency dominates - thousands of per-entity prepare queries. - Why: preparing a small config-driven window of entities concurrently hides - round-trip latency without exhausting the tenant connection pool. - Outcome: Postgres vector sync keeps the existing flush semantics while reducing - wall-clock time on large cloud projects. - """ - self._assert_semantic_available() - await self._ensure_vector_tables() - assert self._embedding_provider is not None - - total_entities = len(entity_ids) - result = VectorSyncBatchResult( - entities_total=total_entities, - entities_synced=0, - entities_failed=0, - ) - if total_entities == 0: - return result - - logger.info( - "Vector batch sync start: project_id={project_id} entities_total={entities_total} " - "sync_batch_size={sync_batch_size} prepare_concurrency={prepare_concurrency}", - project_id=self.project_id, - entities_total=total_entities, - sync_batch_size=self._semantic_embedding_sync_batch_size, - prepare_concurrency=self._semantic_postgres_prepare_concurrency, + def _vector_prepare_window_size(self) -> int: + """Use a bounded config-driven prepare window for Postgres vector sync.""" + return self._semantic_postgres_prepare_concurrency + + async def _prepare_entity_vector_jobs_window( + self, entity_ids: list[int] + ) -> list[_PreparedEntityVectorSync | BaseException]: + """Prepare one Postgres window concurrently to hide DB round-trip latency.""" + prepared_window = await asyncio.gather( + *(self._prepare_entity_vector_jobs(entity_id) for entity_id in entity_ids), + return_exceptions=True, ) - - pending_jobs: list[_PendingEmbeddingJob] = [] - entity_runtime: dict[int, _EntitySyncRuntime] = {} - failed_entity_ids: set[int] = set() - deferred_entity_ids: set[int] = set() - synced_entity_ids: set[int] = set() - - for window_start in range(0, total_entities, self._semantic_postgres_prepare_concurrency): - window_entity_ids = entity_ids[ - window_start : window_start + self._semantic_postgres_prepare_concurrency - ] - - if progress_callback is not None: - for offset, entity_id in enumerate(window_entity_ids, start=window_start): - progress_callback(entity_id, offset, total_entities) - - prepared_window = await asyncio.gather( - *(self._prepare_entity_vector_jobs(entity_id) for entity_id in window_entity_ids), - return_exceptions=True, - ) - - for entity_id, prepared in zip(window_entity_ids, prepared_window, strict=True): - if isinstance(prepared, BaseException): - failed_entity_ids.add(entity_id) - logger.warning( - "Vector batch sync entity prepare failed: project_id={project_id} " - "entity_id={entity_id} error={error}", - project_id=self.project_id, - entity_id=entity_id, - error=str(prepared), - ) - continue - - prepared_sync = cast(_PreparedEntityVectorSync, prepared) - - embedding_jobs_count = len(prepared_sync.embedding_jobs) - result.chunks_total += prepared_sync.chunks_total - result.chunks_skipped += prepared_sync.chunks_skipped - if prepared_sync.entity_skipped: - result.entities_skipped += 1 - result.embedding_jobs_total += embedding_jobs_count - result.prepare_seconds_total += prepared_sync.prepare_seconds - - if embedding_jobs_count == 0: - if prepared_sync.entity_complete: - synced_entity_ids.add(entity_id) - else: - deferred_entity_ids.add(entity_id) - total_seconds = time.perf_counter() - prepared_sync.sync_start - queue_wait_seconds = max(0.0, total_seconds - prepared_sync.prepare_seconds) - result.queue_wait_seconds_total += queue_wait_seconds - self._log_vector_sync_complete( - entity_id=entity_id, - total_seconds=total_seconds, - prepare_seconds=prepared_sync.prepare_seconds, - queue_wait_seconds=queue_wait_seconds, - embed_seconds=0.0, - write_seconds=0.0, - source_rows_count=prepared_sync.source_rows_count, - chunks_total=prepared_sync.chunks_total, - chunks_skipped=prepared_sync.chunks_skipped, - embedding_jobs_count=0, - entity_skipped=prepared_sync.entity_skipped, - entity_complete=prepared_sync.entity_complete, - oversized_entity=prepared_sync.oversized_entity, - pending_jobs_total=prepared_sync.pending_jobs_total, - shard_index=prepared_sync.shard_index, - shard_count=prepared_sync.shard_count, - remaining_jobs_after_shard=prepared_sync.remaining_jobs_after_shard, - ) - continue - - entity_runtime[entity_id] = _EntitySyncRuntime( - sync_start=prepared_sync.sync_start, - source_rows_count=prepared_sync.source_rows_count, - embedding_jobs_count=embedding_jobs_count, - remaining_jobs=embedding_jobs_count, - chunks_total=prepared_sync.chunks_total, - chunks_skipped=prepared_sync.chunks_skipped, - entity_skipped=prepared_sync.entity_skipped, - entity_complete=prepared_sync.entity_complete, - oversized_entity=prepared_sync.oversized_entity, - pending_jobs_total=prepared_sync.pending_jobs_total, - shard_index=prepared_sync.shard_index, - shard_count=prepared_sync.shard_count, - remaining_jobs_after_shard=prepared_sync.remaining_jobs_after_shard, - prepare_seconds=prepared_sync.prepare_seconds, - ) - pending_jobs.extend( - _PendingEmbeddingJob( - entity_id=entity_id, - chunk_row_id=row_id, - chunk_text=chunk_text, - ) - for row_id, chunk_text in prepared_sync.embedding_jobs - ) - - while len(pending_jobs) >= self._semantic_embedding_sync_batch_size: - flush_jobs = pending_jobs[: self._semantic_embedding_sync_batch_size] - pending_jobs = pending_jobs[self._semantic_embedding_sync_batch_size :] - try: - embed_seconds, write_seconds = await self._flush_embedding_jobs( - flush_jobs=flush_jobs, - entity_runtime=entity_runtime, - synced_entity_ids=synced_entity_ids, - ) - result.embed_seconds_total += embed_seconds - result.write_seconds_total += write_seconds - (result.queue_wait_seconds_total) += self._finalize_completed_entity_syncs( - entity_runtime=entity_runtime, - synced_entity_ids=synced_entity_ids, - deferred_entity_ids=deferred_entity_ids, - ) - except Exception as exc: - affected_entity_ids = sorted({job.entity_id for job in flush_jobs}) - failed_entity_ids.update(affected_entity_ids) - synced_entity_ids.difference_update(affected_entity_ids) - deferred_entity_ids.difference_update(affected_entity_ids) - for failed_entity_id in affected_entity_ids: - entity_runtime.pop(failed_entity_id, None) - logger.warning( - "Vector batch sync flush failed: project_id={project_id} " - "affected_entities={affected_entities} chunk_count={chunk_count} " - "error={error}", - project_id=self.project_id, - affected_entities=affected_entity_ids, - chunk_count=len(flush_jobs), - error=str(exc), - ) - - if pending_jobs: - flush_jobs = list(pending_jobs) - pending_jobs = [] - try: - embed_seconds, write_seconds = await self._flush_embedding_jobs( - flush_jobs=flush_jobs, - entity_runtime=entity_runtime, - synced_entity_ids=synced_entity_ids, - ) - result.embed_seconds_total += embed_seconds - result.write_seconds_total += write_seconds - (result.queue_wait_seconds_total) += self._finalize_completed_entity_syncs( - entity_runtime=entity_runtime, - synced_entity_ids=synced_entity_ids, - deferred_entity_ids=deferred_entity_ids, - ) - except Exception as exc: - affected_entity_ids = sorted({job.entity_id for job in flush_jobs}) - failed_entity_ids.update(affected_entity_ids) - synced_entity_ids.difference_update(affected_entity_ids) - deferred_entity_ids.difference_update(affected_entity_ids) - for failed_entity_id in affected_entity_ids: - entity_runtime.pop(failed_entity_id, None) - logger.warning( - "Vector batch sync final flush failed: project_id={project_id} " - "affected_entities={affected_entities} chunk_count={chunk_count} error={error}", - project_id=self.project_id, - affected_entities=affected_entity_ids, - chunk_count=len(flush_jobs), - error=str(exc), - ) - - if entity_runtime: - orphan_runtime_entities = sorted(entity_runtime.keys()) - failed_entity_ids.update(orphan_runtime_entities) - synced_entity_ids.difference_update(orphan_runtime_entities) - deferred_entity_ids.difference_update(orphan_runtime_entities) - logger.warning( - "Vector batch sync left unfinished entities after flushes: " - "project_id={project_id} unfinished_entities={unfinished_entities}", - project_id=self.project_id, - unfinished_entities=orphan_runtime_entities, - ) - - synced_entity_ids.difference_update(failed_entity_ids) - deferred_entity_ids.difference_update(failed_entity_ids) - deferred_entity_ids.difference_update(synced_entity_ids) - result.failed_entity_ids = sorted(failed_entity_ids) - result.entities_failed = len(result.failed_entity_ids) - result.entities_deferred = len(deferred_entity_ids) - result.entities_synced = len(synced_entity_ids) - - logger.info( - "Vector batch sync complete: project_id={project_id} entities_total={entities_total} " - "entities_synced={entities_synced} entities_failed={entities_failed} " - "entities_deferred={entities_deferred} " - "entities_skipped={entities_skipped} chunks_total={chunks_total} " - "chunks_skipped={chunks_skipped} embedding_jobs_total={embedding_jobs_total} " - "prepare_seconds_total={prepare_seconds_total:.3f} " - "queue_wait_seconds_total={queue_wait_seconds_total:.3f} " - "embed_seconds_total={embed_seconds_total:.3f} write_seconds_total={write_seconds_total:.3f}", - project_id=self.project_id, - entities_total=result.entities_total, - entities_synced=result.entities_synced, - entities_failed=result.entities_failed, - entities_deferred=result.entities_deferred, - entities_skipped=result.entities_skipped, - chunks_total=result.chunks_total, - chunks_skipped=result.chunks_skipped, - embedding_jobs_total=result.embedding_jobs_total, - prepare_seconds_total=result.prepare_seconds_total, - queue_wait_seconds_total=result.queue_wait_seconds_total, - embed_seconds_total=result.embed_seconds_total, - write_seconds_total=result.write_seconds_total, - ) - - return result + return [ + cast(_PreparedEntityVectorSync | BaseException, prepared) + for prepared in prepared_window + ] async def _prepare_entity_vector_jobs(self, entity_id: int) -> _PreparedEntityVectorSync: """Prepare chunk mutations with Postgres-specific bulk upserts.""" @@ -783,6 +554,7 @@ async def _prepare_entity_vector_jobs(self, entity_id: int) -> _PreparedEntityVe source_rows_count=source_rows_count, ) await self._delete_entity_chunks(session, entity_id) + await session.commit() prepare_seconds = time.perf_counter() - sync_start return _PreparedEntityVectorSync( entity_id=entity_id, @@ -807,6 +579,7 @@ async def _prepare_entity_vector_jobs(self, entity_id: int) -> _PreparedEntityVe ) if not chunk_records: await self._delete_entity_chunks(session, entity_id) + await session.commit() prepare_seconds = time.perf_counter() - sync_start return _PreparedEntityVectorSync( entity_id=entity_id, @@ -830,12 +603,12 @@ async def _prepare_entity_vector_jobs(self, entity_id: int) -> _PreparedEntityVe existing_rows = existing_rows_result.mappings().all() existing_by_key = {str(row["chunk_key"]): row for row in existing_rows} existing_chunks_count = len(existing_by_key) - incoming_by_key = {record["chunk_key"]: record for record in chunk_records} + incoming_chunk_keys = {record["chunk_key"] for record in chunk_records} stale_ids = [ int(row["id"]) for chunk_key, row in existing_by_key.items() - if chunk_key not in incoming_by_key + if chunk_key not in incoming_chunk_keys ] stale_chunks_count = len(stale_ids) if stale_ids: @@ -934,6 +707,8 @@ async def _prepare_entity_vector_jobs(self, entity_id: int) -> _PreparedEntityVe "entity_id": entity_id, } upsert_values: list[str] = [] + # The SQL template is built from integer enumerate() indices only. + # No user-controlled text is interpolated into the statement. for index, record in enumerate(upsert_records): upsert_params[f"chunk_key_{index}"] = record["chunk_key"] upsert_params[f"chunk_text_{index}"] = record["chunk_text"] @@ -1028,6 +803,8 @@ async def _write_embeddings( params: dict[str, object] = {"project_id": self.project_id} value_rows: list[str] = [] + # The SQL template is built from integer enumerate() indices only. + # No user-controlled text is interpolated into the statement. for index, ((row_id, _), vector) in enumerate(zip(jobs, embeddings, strict=True)): params[f"chunk_id_{index}"] = row_id params[f"embedding_{index}"] = self._format_pgvector_literal(vector) diff --git a/src/basic_memory/repository/search_repository_base.py b/src/basic_memory/repository/search_repository_base.py index 80c68b5f4..432e6c5e3 100644 --- a/src/basic_memory/repository/search_repository_base.py +++ b/src/basic_memory/repository/search_repository_base.py @@ -807,10 +807,11 @@ async def _sync_entity_vectors_internal( logger.info( "Vector batch sync start: project_id={project_id} entities_total={entities_total} " - "sync_batch_size={sync_batch_size}", + "sync_batch_size={sync_batch_size} prepare_window_size={prepare_window_size}", project_id=self.project_id, entities_total=total_entities, sync_batch_size=self._semantic_embedding_sync_batch_size, + prepare_window_size=self._vector_prepare_window_size(), ) pending_jobs: list[_PendingEmbeddingJob] = [] @@ -819,52 +820,74 @@ async def _sync_entity_vectors_internal( deferred_entity_ids: set[int] = set() synced_entity_ids: set[int] = set() - for index, entity_id in enumerate(entity_ids): + prepare_window_size = self._vector_prepare_window_size() + for window_start in range(0, total_entities, prepare_window_size): + window_entity_ids = entity_ids[window_start : window_start + prepare_window_size] + if progress_callback is not None: - progress_callback(entity_id, index, total_entities) + for offset, entity_id in enumerate(window_entity_ids, start=window_start): + progress_callback(entity_id, offset, total_entities) - try: - prepared = await self._prepare_entity_vector_jobs(entity_id) - except Exception as exc: - if not continue_on_error: - raise - failed_entity_ids.add(entity_id) - logger.warning( - "Vector batch sync entity prepare failed: project_id={project_id} " - "entity_id={entity_id} error={error}", - project_id=self.project_id, - entity_id=entity_id, - error=str(exc), - ) - continue + prepared_window = await self._prepare_entity_vector_jobs_window(window_entity_ids) - embedding_jobs_count = len(prepared.embedding_jobs) - result.chunks_total += prepared.chunks_total - result.chunks_skipped += prepared.chunks_skipped - if prepared.entity_skipped: - result.entities_skipped += 1 - result.embedding_jobs_total += embedding_jobs_count - result.prepare_seconds_total += prepared.prepare_seconds - - if embedding_jobs_count == 0: - if prepared.entity_complete: - synced_entity_ids.add(entity_id) - else: - deferred_entity_ids.add(entity_id) - total_seconds = time.perf_counter() - prepared.sync_start - queue_wait_seconds = max(0.0, total_seconds - prepared.prepare_seconds) - result.queue_wait_seconds_total += queue_wait_seconds - self._log_vector_sync_complete( - entity_id=entity_id, - total_seconds=total_seconds, - prepare_seconds=prepared.prepare_seconds, - queue_wait_seconds=queue_wait_seconds, - embed_seconds=0.0, - write_seconds=0.0, + for entity_id, prepared in zip(window_entity_ids, prepared_window, strict=True): + if isinstance(prepared, BaseException): + if not continue_on_error: + raise prepared + failed_entity_ids.add(entity_id) + logger.warning( + "Vector batch sync entity prepare failed: project_id={project_id} " + "entity_id={entity_id} error={error}", + project_id=self.project_id, + entity_id=entity_id, + error=str(prepared), + ) + continue + + embedding_jobs_count = len(prepared.embedding_jobs) + result.chunks_total += prepared.chunks_total + result.chunks_skipped += prepared.chunks_skipped + if prepared.entity_skipped: + result.entities_skipped += 1 + result.embedding_jobs_total += embedding_jobs_count + result.prepare_seconds_total += prepared.prepare_seconds + + if embedding_jobs_count == 0: + if prepared.entity_complete: + synced_entity_ids.add(entity_id) + else: + deferred_entity_ids.add(entity_id) + total_seconds = time.perf_counter() - prepared.sync_start + queue_wait_seconds = max(0.0, total_seconds - prepared.prepare_seconds) + result.queue_wait_seconds_total += queue_wait_seconds + self._log_vector_sync_complete( + entity_id=entity_id, + total_seconds=total_seconds, + prepare_seconds=prepared.prepare_seconds, + queue_wait_seconds=queue_wait_seconds, + embed_seconds=0.0, + write_seconds=0.0, + source_rows_count=prepared.source_rows_count, + chunks_total=prepared.chunks_total, + chunks_skipped=prepared.chunks_skipped, + embedding_jobs_count=0, + entity_skipped=prepared.entity_skipped, + entity_complete=prepared.entity_complete, + oversized_entity=prepared.oversized_entity, + pending_jobs_total=prepared.pending_jobs_total, + shard_index=prepared.shard_index, + shard_count=prepared.shard_count, + remaining_jobs_after_shard=prepared.remaining_jobs_after_shard, + ) + continue + + entity_runtime[entity_id] = _EntitySyncRuntime( + sync_start=prepared.sync_start, source_rows_count=prepared.source_rows_count, + embedding_jobs_count=embedding_jobs_count, + remaining_jobs=embedding_jobs_count, chunks_total=prepared.chunks_total, chunks_skipped=prepared.chunks_skipped, - embedding_jobs_count=0, entity_skipped=prepared.entity_skipped, entity_complete=prepared.entity_complete, oversized_entity=prepared.oversized_entity, @@ -872,65 +895,48 @@ async def _sync_entity_vectors_internal( shard_index=prepared.shard_index, shard_count=prepared.shard_count, remaining_jobs_after_shard=prepared.remaining_jobs_after_shard, + prepare_seconds=prepared.prepare_seconds, ) - continue - - entity_runtime[entity_id] = _EntitySyncRuntime( - sync_start=prepared.sync_start, - source_rows_count=prepared.source_rows_count, - embedding_jobs_count=embedding_jobs_count, - remaining_jobs=embedding_jobs_count, - chunks_total=prepared.chunks_total, - chunks_skipped=prepared.chunks_skipped, - entity_skipped=prepared.entity_skipped, - entity_complete=prepared.entity_complete, - oversized_entity=prepared.oversized_entity, - pending_jobs_total=prepared.pending_jobs_total, - shard_index=prepared.shard_index, - shard_count=prepared.shard_count, - remaining_jobs_after_shard=prepared.remaining_jobs_after_shard, - prepare_seconds=prepared.prepare_seconds, - ) - pending_jobs.extend( - _PendingEmbeddingJob( - entity_id=entity_id, chunk_row_id=row_id, chunk_text=chunk_text + pending_jobs.extend( + _PendingEmbeddingJob( + entity_id=entity_id, chunk_row_id=row_id, chunk_text=chunk_text + ) + for row_id, chunk_text in prepared.embedding_jobs ) - for row_id, chunk_text in prepared.embedding_jobs - ) - while len(pending_jobs) >= self._semantic_embedding_sync_batch_size: - flush_jobs = pending_jobs[: self._semantic_embedding_sync_batch_size] - pending_jobs = pending_jobs[self._semantic_embedding_sync_batch_size :] - try: - embed_seconds, write_seconds = await self._flush_embedding_jobs( - flush_jobs=flush_jobs, - entity_runtime=entity_runtime, - synced_entity_ids=synced_entity_ids, - ) - result.embed_seconds_total += embed_seconds - result.write_seconds_total += write_seconds - (result.queue_wait_seconds_total) += self._finalize_completed_entity_syncs( - entity_runtime=entity_runtime, - synced_entity_ids=synced_entity_ids, - deferred_entity_ids=deferred_entity_ids, - ) - except Exception as exc: - if not continue_on_error: - raise - affected_entity_ids = sorted({job.entity_id for job in flush_jobs}) - failed_entity_ids.update(affected_entity_ids) - synced_entity_ids.difference_update(affected_entity_ids) - deferred_entity_ids.difference_update(affected_entity_ids) - for failed_entity_id in affected_entity_ids: - entity_runtime.pop(failed_entity_id, None) - logger.warning( - "Vector batch sync flush failed: project_id={project_id} " - "affected_entities={affected_entities} chunk_count={chunk_count} error={error}", - project_id=self.project_id, - affected_entities=affected_entity_ids, - chunk_count=len(flush_jobs), - error=str(exc), - ) + while len(pending_jobs) >= self._semantic_embedding_sync_batch_size: + flush_jobs = pending_jobs[: self._semantic_embedding_sync_batch_size] + pending_jobs = pending_jobs[self._semantic_embedding_sync_batch_size :] + try: + embed_seconds, write_seconds = await self._flush_embedding_jobs( + flush_jobs=flush_jobs, + entity_runtime=entity_runtime, + synced_entity_ids=synced_entity_ids, + ) + result.embed_seconds_total += embed_seconds + result.write_seconds_total += write_seconds + (result.queue_wait_seconds_total) += self._finalize_completed_entity_syncs( + entity_runtime=entity_runtime, + synced_entity_ids=synced_entity_ids, + deferred_entity_ids=deferred_entity_ids, + ) + except Exception as exc: + if not continue_on_error: + raise + affected_entity_ids = sorted({job.entity_id for job in flush_jobs}) + failed_entity_ids.update(affected_entity_ids) + synced_entity_ids.difference_update(affected_entity_ids) + deferred_entity_ids.difference_update(affected_entity_ids) + for failed_entity_id in affected_entity_ids: + entity_runtime.pop(failed_entity_id, None) + logger.warning( + "Vector batch sync flush failed: project_id={project_id} " + "affected_entities={affected_entities} chunk_count={chunk_count} error={error}", + project_id=self.project_id, + affected_entities=affected_entity_ids, + chunk_count=len(flush_jobs), + error=str(exc), + ) if pending_jobs: flush_jobs = list(pending_jobs) @@ -1016,6 +1022,26 @@ async def _sync_entity_vectors_internal( return result + def _vector_prepare_window_size(self) -> int: + """Return the number of entities to prepare in one orchestration window.""" + return 1 + + async def _prepare_entity_vector_jobs_window( + self, entity_ids: list[int] + ) -> list[_PreparedEntityVectorSync | BaseException]: + """Prepare one window of entity vector jobs. + + Default implementation is sequential to preserve backend behavior. + Postgres overrides this to use a bounded concurrent gather. + """ + prepared_window: list[_PreparedEntityVectorSync | BaseException] = [] + for entity_id in entity_ids: + try: + prepared_window.append(await self._prepare_entity_vector_jobs(entity_id)) + except Exception as exc: + prepared_window.append(exc) + return prepared_window + async def _prepare_entity_vector_jobs(self, entity_id: int) -> _PreparedEntityVectorSync: """Prepare chunk mutations and embedding jobs for one entity.""" sync_start = time.perf_counter() From 16fc164b4ee0fe5773a4e24b0e0321d9673f3f94 Mon Sep 17 00:00:00 2001 From: phernandez Date: Tue, 7 Apr 2026 12:58:13 -0500 Subject: [PATCH 09/10] fix(core): restore semantic vector skip and benchmark gating Signed-off-by: phernandez --- src/basic_memory/repository/search_repository_base.py | 4 ++-- test-int/semantic/conftest.py | 6 ++++++ test-int/semantic/test_semantic_quality.py | 5 +++-- 3 files changed, 11 insertions(+), 4 deletions(-) diff --git a/src/basic_memory/repository/search_repository_base.py b/src/basic_memory/repository/search_repository_base.py index 432e6c5e3..52fcfc8c5 100644 --- a/src/basic_memory/repository/search_repository_base.py +++ b/src/basic_memory/repository/search_repository_base.py @@ -1236,8 +1236,8 @@ async def _prepare_entity_vector_jobs(self, entity_id: int) -> _PreparedEntityVe "embedding_model": current_embedding_model, }, ) - skipped_chunks_count += 1 - continue + skipped_chunks_count += 1 + continue pending_records.append(record) diff --git a/test-int/semantic/conftest.py b/test-int/semantic/conftest.py index 4369906b0..9d29a1362 100644 --- a/test-int/semantic/conftest.py +++ b/test-int/semantic/conftest.py @@ -65,6 +65,11 @@ class SearchCombo: SearchCombo("postgres-openai", DatabaseBackend.POSTGRES, "openai", 1536), ] +# Benchmark queries compare ranking quality across providers rather than enforcing +# the stricter production retrieval cutoff. OpenAI paraphrase matches cluster near +# ~0.37 in this corpus, so the default 0.55 filter hides otherwise-correct results. +BENCHMARK_MIN_SIMILARITY = 0.3 + # --- Skip guards --- @@ -229,6 +234,7 @@ async def create_search_service( default_project="bench-project", database_backend=combo.backend, semantic_search_enabled=semantic_enabled, + semantic_min_similarity=BENCHMARK_MIN_SIMILARITY, ) # Create search repository (backend-specific) diff --git a/test-int/semantic/test_semantic_quality.py b/test-int/semantic/test_semantic_quality.py index a74292fe2..cd6463a41 100644 --- a/test-int/semantic/test_semantic_quality.py +++ b/test-int/semantic/test_semantic_quality.py @@ -51,8 +51,9 @@ ("sqlite-fastembed", "paraphrase", "hybrid"): 0.25, ("postgres-fastembed", "lexical", "hybrid"): 0.37, ("postgres-fastembed", "paraphrase", "hybrid"): 0.25, - # OpenAI metrics are still recorded, but we do not gate on them yet. - # The current benchmark corpus is too small to make that combo stable. + # OpenAI hybrid should handle paraphrases better than FastEmbed. + ("postgres-openai", "lexical", "hybrid"): 0.37, + ("postgres-openai", "paraphrase", "hybrid"): 0.25, } From e1df151cd128e7fa61c3f1a6b21ec5ea658f2d0b Mon Sep 17 00:00:00 2001 From: phernandez Date: Tue, 7 Apr 2026 13:28:50 -0500 Subject: [PATCH 10/10] fix(core): address vector sync follow-up regressions Signed-off-by: phernandez --- .../repository/postgres_search_repository.py | 3 --- .../repository/search_repository_base.py | 8 +++----- .../repository/sqlite_search_repository.py | 7 ++++--- tests/cli/test_cli_exit.py | 5 ++++- .../test_project_service_embedding_status.py | 20 +++++++++++++++---- 5 files changed, 27 insertions(+), 16 deletions(-) diff --git a/src/basic_memory/repository/postgres_search_repository.py b/src/basic_memory/repository/postgres_search_repository.py index 47785c473..f9b425694 100644 --- a/src/basic_memory/repository/postgres_search_repository.py +++ b/src/basic_memory/repository/postgres_search_repository.py @@ -870,9 +870,6 @@ async def _delete_stale_chunks( stale_params, ) - async def _update_timestamp_sql(self) -> str: - return "NOW()" # pragma: no cover - def _distance_to_similarity(self, distance: float) -> float: """Convert pgvector cosine distance to cosine similarity. diff --git a/src/basic_memory/repository/search_repository_base.py b/src/basic_memory/repository/search_repository_base.py index 52fcfc8c5..8abcaf17f 100644 --- a/src/basic_memory/repository/search_repository_base.py +++ b/src/basic_memory/repository/search_repository_base.py @@ -288,11 +288,6 @@ async def _delete_stale_chunks( """Delete stale chunk rows (and their embeddings) by ID.""" pass - @abstractmethod - async def _update_timestamp_sql(self) -> str: - """Return the SQL expression for current timestamp in the backend.""" - pass # pragma: no cover - @abstractmethod def _distance_to_similarity(self, distance: float) -> float: """Convert a backend-specific vector distance to cosine similarity in [0, 1]. @@ -825,6 +820,9 @@ async def _sync_entity_vectors_internal( window_entity_ids = entity_ids[window_start : window_start + prepare_window_size] if progress_callback is not None: + # Trigger: Postgres prepares one bounded entity window concurrently. + # Why: callbacks still need per-entity progress positions before the gather starts. + # Outcome: progress advances in prepare_window_size bursts instead of strict one-by-one. for offset, entity_id in enumerate(window_entity_ids, start=window_start): progress_callback(entity_id, offset, total_entities) diff --git a/src/basic_memory/repository/sqlite_search_repository.py b/src/basic_memory/repository/sqlite_search_repository.py index a97c01887..cc42331b8 100644 --- a/src/basic_memory/repository/sqlite_search_repository.py +++ b/src/basic_memory/repository/sqlite_search_repository.py @@ -404,6 +404,10 @@ async def _ensure_vector_tables(self) -> None: } schema_mismatch = bool(chunks_columns) and set(chunks_columns) != expected_columns if schema_mismatch: + # Trigger: older SQLite installs are missing newly required chunk metadata columns. + # Why: vector tables store derived data only, so rebuilding them is safer than + # attempting piecemeal ALTER TABLE compatibility across sqlite-vec upgrades. + # Outcome: first startup after the schema change forces a clean re-embed. logger.warning("search_vector_chunks schema mismatch, recreating vector tables") await session.execute(text("DROP TABLE IF EXISTS search_vector_embeddings")) await session.execute(text("DROP TABLE IF EXISTS search_vector_chunks")) @@ -554,9 +558,6 @@ async def _delete_stale_chunks( stale_params, ) - async def _update_timestamp_sql(self) -> str: - return "CURRENT_TIMESTAMP" # pragma: no cover - def _distance_to_similarity(self, distance: float) -> float: """Convert L2 distance to cosine similarity for normalized embeddings. diff --git a/tests/cli/test_cli_exit.py b/tests/cli/test_cli_exit.py index f19754f7f..4a7defc78 100644 --- a/tests/cli/test_cli_exit.py +++ b/tests/cli/test_cli_exit.py @@ -28,7 +28,10 @@ def test_bm_help_exits_cleanly(): ["uv", "run", "bm", "--help"], capture_output=True, text=True, - timeout=10, + # Help builds the full command tree, so use a looser timeout than the + # version fast path. This test is guarding against hangs, not enforcing + # a tight performance budget under full-suite load. + timeout=20, cwd=Path(__file__).parent.parent.parent, ) assert result.returncode == 0 diff --git a/tests/services/test_project_service_embedding_status.py b/tests/services/test_project_service_embedding_status.py index ff17fa981..a542b8d53 100644 --- a/tests/services/test_project_service_embedding_status.py +++ b/tests/services/test_project_service_embedding_status.py @@ -105,8 +105,14 @@ async def test_embedding_status_orphaned_chunks( await project_service.repository.execute_query( text( "INSERT INTO search_vector_chunks " - "(entity_id, project_id, chunk_key, chunk_text, source_hash) " - "VALUES (:entity_id, :project_id, 'chunk-1', 'test text', 'abc123')" + "(" + "entity_id, project_id, chunk_key, chunk_text, source_hash, " + "entity_fingerprint, embedding_model" + ") " + "VALUES (" + ":entity_id, :project_id, 'chunk-1', 'test text', 'abc123', " + "'fp-abc123', 'bge-small-en-v1.5'" + ")" ), {"entity_id": entity_id, "project_id": test_project.id}, ) @@ -213,8 +219,14 @@ async def test_embedding_status_healthy(project_service: ProjectService, test_gr await project_service.repository.execute_query( text( "INSERT INTO search_vector_chunks " - "(id, entity_id, project_id, chunk_key, chunk_text, source_hash) " - "VALUES (:id, :entity_id, :project_id, :key, 'text', 'hash')" + "(" + "id, entity_id, project_id, chunk_key, chunk_text, source_hash, " + "entity_fingerprint, embedding_model" + ") " + "VALUES (" + ":id, :entity_id, :project_id, :key, 'text', 'hash', " + "'fp-hash', 'bge-small-en-v1.5'" + ")" ), { "id": chunk_id,