From 8f8d017529556e286d077ddc4010104ec9f8f494 Mon Sep 17 00:00:00 2001 From: phernandez Date: Tue, 7 Apr 2026 23:21:41 -0500 Subject: [PATCH 1/9] perf(sync): batch file indexing in core Signed-off-by: phernandez --- src/basic_memory/cli/commands/db.py | 74 ++- src/basic_memory/config.py | 30 + src/basic_memory/indexing/__init__.py | 27 + src/basic_memory/indexing/batch_indexer.py | 550 ++++++++++++++++++ src/basic_memory/indexing/batching.py | 63 ++ src/basic_memory/indexing/models.py | 83 +++ .../repository/embedding_provider_factory.py | 13 +- .../repository/openai_provider.py | 50 +- src/basic_memory/sync/sync_service.py | 341 +++++++++-- tests/indexing/test_batch_indexer.py | 400 +++++++++++++ tests/indexing/test_batching.py | 85 +++ tests/repository/test_openai_provider.py | 96 +++ tests/sync/test_sync_service_batching.py | 92 +++ tests/sync/test_sync_service_telemetry.py | 12 +- 14 files changed, 1850 insertions(+), 66 deletions(-) create mode 100644 src/basic_memory/indexing/__init__.py create mode 100644 src/basic_memory/indexing/batch_indexer.py create mode 100644 src/basic_memory/indexing/batching.py create mode 100644 src/basic_memory/indexing/models.py create mode 100644 tests/indexing/test_batch_indexer.py create mode 100644 tests/indexing/test_batching.py create mode 100644 tests/sync/test_sync_service_batching.py diff --git a/src/basic_memory/cli/commands/db.py b/src/basic_memory/cli/commands/db.py index 08595511..dadc94fe 100644 --- a/src/basic_memory/cli/commands/db.py +++ b/src/basic_memory/cli/commands/db.py @@ -1,5 +1,6 @@ """Database management commands.""" +from dataclasses import dataclass from pathlib import Path import typer @@ -12,6 +13,7 @@ from basic_memory.cli.app import app from basic_memory.cli.commands.command_utils import run_with_cleanup from basic_memory.config import ConfigManager, ProjectMode +from basic_memory.indexing import IndexProgress from basic_memory.repository import ProjectRepository from 
basic_memory.services.initialization import reconcile_projects_with_config from basic_memory.sync.sync_service import get_sync_service @@ -19,6 +21,39 @@ console = Console() +@dataclass(slots=True) +class EmbeddingProgress: + """Typed CLI progress payload for embedding backfills.""" + + entity_id: int + index: int + total: int + + +def _format_eta(seconds: float | None) -> str: + """Render a compact ETA string for CLI progress descriptions.""" + if seconds is None: + return "--:--" + + whole_seconds = max(int(seconds), 0) + minutes, remaining_seconds = divmod(whole_seconds, 60) + hours, remaining_minutes = divmod(minutes, 60) + if hours: + return f"{hours:d}:{remaining_minutes:02d}:{remaining_seconds:02d}" + return f"{remaining_minutes:02d}:{remaining_seconds:02d}" + + +def _format_index_progress(progress: IndexProgress) -> str: + """Render typed index progress as a compact Rich task description.""" + files_per_minute = int(progress.files_per_minute) if progress.files_per_minute else 0 + return ( + " Indexing files... " + f"{progress.files_processed}/{progress.files_total} files | " + f"{progress.batches_completed}/{progress.batches_total} batches | " + f"{files_per_minute}/min | ETA {_format_eta(progress.eta_seconds)}" + ) + + async def _reindex_projects(app_config): """Reindex all projects in a single async context. @@ -185,10 +220,34 @@ async def _reindex(app_config, search: bool, embeddings: bool, project: str | No console.print(f"\n[bold]Project: [cyan]{proj.name}[/cyan][/bold]") if search: - console.print(" Rebuilding full-text search index...") sync_service = await get_sync_service(proj) sync_dir = Path(proj.path) - await sync_service.sync(sync_dir, project_name=proj.name) + with Progress( + SpinnerColumn(), + TextColumn("[progress.description]{task.description}"), + BarColumn(), + TaskProgressColumn(), + console=console, + ) as progress: + task = progress.add_task(" Indexing files... 
scanning changes", total=1) + + async def on_index_progress(update: IndexProgress) -> None: + total = update.files_total or 1 + completed = update.files_processed if update.files_total else 1 + progress.update( + task, + description=_format_index_progress(update), + total=total, + completed=min(completed, total), + ) + + await sync_service.sync( + sync_dir, + project_name=proj.name, + progress_callback=on_index_progress, + ) + progress.update(task, completed=progress.tasks[task].total or 1) + console.print(" [green]✓[/green] Full-text search index rebuilt") if embeddings: @@ -213,7 +272,16 @@ async def _reindex(app_config, search: bool, embeddings: bool, project: str | No task = progress.add_task(" Embedding entities...", total=None) def on_progress(entity_id, index, total): - progress.update(task, total=total, completed=index) + embedding_progress = EmbeddingProgress( + entity_id=entity_id, + index=index, + total=total, + ) + progress.update( + task, + total=embedding_progress.total, + completed=embedding_progress.index, + ) stats = await search_service.reindex_vectors(progress_callback=on_progress) progress.update(task, completed=stats["total_entities"]) diff --git a/src/basic_memory/config.py b/src/basic_memory/config.py index c6d895f4..62579bee 100644 --- a/src/basic_memory/config.py +++ b/src/basic_memory/config.py @@ -193,6 +193,11 @@ class BasicMemoryConfig(BaseSettings): description="Batch size for embedding generation.", gt=0, ) + semantic_embedding_request_concurrency: int = Field( + default=4, + description="Maximum number of concurrent provider requests for batched embedding generation when the active provider supports request-level concurrency.", + gt=0, + ) semantic_embedding_sync_batch_size: int = Field( default=64, description="Batch size for vector sync orchestration flushes.", @@ -286,6 +291,31 @@ class BasicMemoryConfig(BaseSettings): description="Maximum number of files to process concurrently during sync. 
Limits memory usage on large projects (2000+ files). Lower values reduce memory consumption.", gt=0, ) + index_batch_size: int = Field( + default=32, + description="Maximum number of changed files to load into one indexing batch.", + gt=0, + ) + index_batch_max_bytes: int = Field( + default=8 * 1024 * 1024, + description="Maximum total bytes to load into one indexing batch. Large files still run as single-file batches.", + gt=0, + ) + index_parse_max_concurrent: int = Field( + default=8, + description="Maximum number of markdown parse tasks to run concurrently inside one indexing batch.", + gt=0, + ) + index_entity_max_concurrent: int = Field( + default=4, + description="Maximum number of entity create/update tasks to run concurrently inside one indexing batch.", + gt=0, + ) + index_metadata_update_max_concurrent: int = Field( + default=4, + description="Maximum number of metadata/search refresh tasks to run concurrently inside one indexing batch.", + gt=0, + ) kebab_filenames: bool = Field( default=False, diff --git a/src/basic_memory/indexing/__init__.py b/src/basic_memory/indexing/__init__.py new file mode 100644 index 00000000..369da770 --- /dev/null +++ b/src/basic_memory/indexing/__init__.py @@ -0,0 +1,27 @@ +"""Reusable indexing primitives shared by local sync and future remote callers.""" + +from basic_memory.indexing.batch_indexer import BatchIndexer +from basic_memory.indexing.batching import build_index_batches +from basic_memory.indexing.models import ( + IndexedEntity, + IndexBatch, + IndexFileMetadata, + IndexFileWriter, + IndexFrontmatterUpdate, + IndexingBatchResult, + IndexInputFile, + IndexProgress, +) + +__all__ = [ + "BatchIndexer", + "IndexedEntity", + "IndexBatch", + "IndexFileMetadata", + "IndexFileWriter", + "IndexFrontmatterUpdate", + "IndexingBatchResult", + "IndexInputFile", + "IndexProgress", + "build_index_batches", +] diff --git a/src/basic_memory/indexing/batch_indexer.py b/src/basic_memory/indexing/batch_indexer.py new file mode 
100644 index 00000000..09f74114 --- /dev/null +++ b/src/basic_memory/indexing/batch_indexer.py @@ -0,0 +1,550 @@ +"""Reusable batch executor for bounded-parallel file indexing.""" + +from __future__ import annotations + +import asyncio +import time +from dataclasses import dataclass +from datetime import datetime +from pathlib import Path +from typing import Awaitable, Callable, Mapping, TypeVar + +from loguru import logger +from sqlalchemy.exc import IntegrityError + +from basic_memory.config import BasicMemoryConfig +from basic_memory.file_utils import compute_checksum, has_frontmatter +from basic_memory.markdown.schemas import EntityMarkdown +from basic_memory.indexing.models import ( + IndexedEntity, + IndexFileWriter, + IndexFrontmatterUpdate, + IndexingBatchResult, + IndexInputFile, + IndexProgress, +) +from basic_memory.models import Entity, Relation +from basic_memory.services import EntityService +from basic_memory.services.search_service import SearchService +from basic_memory.repository import EntityRepository, RelationRepository + +T = TypeVar("T") + + +@dataclass(slots=True) +class _PreparedMarkdownFile: + file: IndexInputFile + content: str + final_checksum: str + markdown: EntityMarkdown + file_contains_frontmatter: bool + + +@dataclass(slots=True) +class _PreparedEntity: + path: str + entity_id: int + checksum: str + content_type: str | None + search_content: str | None + + +class BatchIndexer: + """Index already-loaded files without assuming where they came from.""" + + def __init__( + self, + *, + app_config: BasicMemoryConfig, + entity_service: EntityService, + entity_repository: EntityRepository, + relation_repository: RelationRepository, + search_service: SearchService, + file_writer: IndexFileWriter, + ) -> None: + self.app_config = app_config + self.entity_service = entity_service + self.entity_repository = entity_repository + self.relation_repository = relation_repository + self.search_service = search_service + self.file_writer = 
file_writer + + async def index_files( + self, + files: Mapping[str, IndexInputFile], + *, + max_concurrent: int, + parse_max_concurrent: int | None = None, + progress_callback: Callable[[IndexProgress], Awaitable[None]] | None = None, + ) -> IndexingBatchResult: + """Index one batch of loaded files with bounded concurrency.""" + if max_concurrent <= 0: + raise ValueError("max_concurrent must be greater than zero") + + ordered_paths = sorted(files) + if not ordered_paths: + result = IndexingBatchResult() + if progress_callback is not None: + await progress_callback( + IndexProgress( + files_total=0, + files_processed=0, + batches_total=0, + batches_completed=0, + ) + ) + return result + + parse_limit = parse_max_concurrent or max_concurrent + batch_start = time.monotonic() + error_by_path: dict[str, str] = {} + + markdown_paths = [path for path in ordered_paths if self._is_markdown(files[path])] + regular_paths = [path for path in ordered_paths if path not in markdown_paths] + + prepared_markdown, parse_errors = await self._run_bounded( + markdown_paths, + limit=parse_limit, + worker=lambda path: self._prepare_markdown_file(files[path]), + ) + error_by_path.update(parse_errors) + + prepared_markdown, normalization_errors = await self._normalize_markdown_batch( + prepared_markdown + ) + error_by_path.update(normalization_errors) + + indexed_entities: list[IndexedEntity] = [] + resolved_count = 0 + unresolved_count = 0 + search_indexed = 0 + + prepared_entities: dict[str, _PreparedEntity] = {} + + markdown_upserts, markdown_errors = await self._run_bounded( + [path for path in markdown_paths if path not in error_by_path], + limit=max_concurrent, + worker=lambda path: self._upsert_markdown_file(prepared_markdown[path]), + ) + error_by_path.update(markdown_errors) + prepared_entities.update(markdown_upserts) + + regular_upserts, regular_errors = await self._run_bounded( + regular_paths, + limit=max_concurrent, + worker=lambda path: 
self._upsert_regular_file(files[path]), + ) + error_by_path.update(regular_errors) + prepared_entities.update(regular_upserts) + + markdown_entity_ids = [ + prepared_entities[path].entity_id + for path in markdown_paths + if path in prepared_entities + ] + if markdown_entity_ids: + resolved_count, unresolved_count = await self._resolve_batch_relations( + markdown_entity_ids, + max_concurrent=max_concurrent, + ) + + refreshed_entities = await self.entity_repository.find_by_ids( + [prepared.entity_id for prepared in prepared_entities.values()] + ) + entities_by_id = {entity.id: entity for entity in refreshed_entities} + + refreshed, refresh_errors = await self._run_bounded( + [path for path in ordered_paths if path in prepared_entities], + limit=self.app_config.index_metadata_update_max_concurrent, + worker=lambda path: self._refresh_search_index( + prepared_entities[path], + entities_by_id[prepared_entities[path].entity_id], + ), + ) + error_by_path.update(refresh_errors) + + for path in ordered_paths: + indexed = refreshed.get(path) + if indexed is not None: + indexed_entities.append(indexed) + + search_indexed = len(indexed_entities) + + if progress_callback is not None: + elapsed_seconds = max(time.monotonic() - batch_start, 0.001) + files_per_minute = len(ordered_paths) / elapsed_seconds * 60 + await progress_callback( + IndexProgress( + files_total=len(ordered_paths), + files_processed=len(ordered_paths), + batches_total=1, + batches_completed=1, + current_batch_bytes=sum(max(files[path].size, 0) for path in ordered_paths), + files_per_minute=files_per_minute, + eta_seconds=0.0, + ) + ) + + return IndexingBatchResult( + indexed=indexed_entities, + errors=[(path, error_by_path[path]) for path in ordered_paths if path in error_by_path], + relations_resolved=resolved_count, + relations_unresolved=unresolved_count, + search_indexed=search_indexed, + ) + + # --- Preparation --- + + async def _prepare_markdown_file(self, file: IndexInputFile) -> 
_PreparedMarkdownFile: + if file.content is None: + raise ValueError(f"Missing content for markdown file: {file.path}") + + content = file.content.decode("utf-8") + file_contains_frontmatter = has_frontmatter(content) + final_checksum = await self._resolve_checksum(file) + entity_markdown = await self.entity_service.entity_parser.parse_markdown_content( + file_path=Path(file.path), + content=content, + mtime=file.last_modified.timestamp() if file.last_modified else None, + ctime=file.created_at.timestamp() if file.created_at else None, + ) + + return _PreparedMarkdownFile( + file=file, + content=content, + final_checksum=final_checksum, + markdown=entity_markdown, + file_contains_frontmatter=file_contains_frontmatter, + ) + + async def _normalize_markdown_batch( + self, + prepared_markdown: dict[str, _PreparedMarkdownFile], + ) -> tuple[dict[str, _PreparedMarkdownFile], dict[str, str]]: + if not prepared_markdown: + return {}, {} + + batch_paths = set(prepared_markdown) + existing_permalink_by_path = await self.entity_repository.get_file_path_to_permalink_map() + reserved_permalinks = { + permalink + for path, permalink in existing_permalink_by_path.items() + if path not in batch_paths and permalink + } + + normalized: dict[str, _PreparedMarkdownFile] = {} + errors: dict[str, str] = {} + + for path in sorted(prepared_markdown): + try: + normalized[path] = await self._normalize_markdown_file( + prepared_markdown[path], + reserved_permalinks, + ) + except Exception as exc: + errors[path] = str(exc) + logger.warning("Batch markdown normalization failed", path=path, error=str(exc)) + + return normalized, errors + + async def _normalize_markdown_file( + self, + prepared: _PreparedMarkdownFile, + reserved_permalinks: set[str], + ) -> _PreparedMarkdownFile: + final_checksum = prepared.final_checksum + final_permalink = await self._resolve_batch_permalink(prepared, reserved_permalinks) + + # Trigger: markdown file has no frontmatter and sync enforcement is enabled. 
+ # Why: downstream indexing relies on normalized metadata and stable permalinks. + # Outcome: write derived metadata back through the storage-agnostic writer. + if not prepared.file_contains_frontmatter and self.app_config.ensure_frontmatter_on_sync: + frontmatter_updates = { + "title": prepared.markdown.frontmatter.title, + "type": prepared.markdown.frontmatter.type, + "permalink": final_permalink, + } + final_checksum = await self.file_writer.write_frontmatter( + IndexFrontmatterUpdate(path=prepared.file.path, metadata=frontmatter_updates) + ) + prepared.markdown.frontmatter.metadata.update(frontmatter_updates) + + # Trigger: existing markdown frontmatter may lack the canonical permalink. + # Why: batch sync keeps permalinks stable without forcing a full rewrite when unchanged. + # Outcome: only the permalink field is updated when it actually differs. + elif ( + prepared.file_contains_frontmatter + and not self.app_config.disable_permalinks + and final_permalink != prepared.markdown.frontmatter.permalink + ): + prepared.markdown.frontmatter.metadata["permalink"] = final_permalink + final_checksum = await self.file_writer.write_frontmatter( + IndexFrontmatterUpdate( + path=prepared.file.path, + metadata={"permalink": final_permalink}, + ) + ) + + return _PreparedMarkdownFile( + file=prepared.file, + content=prepared.content, + final_checksum=final_checksum, + markdown=prepared.markdown, + file_contains_frontmatter=prepared.file_contains_frontmatter, + ) + + async def _resolve_batch_permalink( + self, + prepared: _PreparedMarkdownFile, + reserved_permalinks: set[str], + ) -> str | None: + should_resolve_permalink = ( + not prepared.file_contains_frontmatter and self.app_config.ensure_frontmatter_on_sync + ) or (prepared.file_contains_frontmatter and not self.app_config.disable_permalinks) + if not should_resolve_permalink: + permalink = prepared.markdown.frontmatter.permalink + if permalink: + reserved_permalinks.add(permalink) + return permalink + + 
desired_permalink = await self.entity_service.resolve_permalink( + prepared.file.path, + markdown=prepared.markdown, + skip_conflict_check=True, + ) + return self._reserve_batch_permalink(desired_permalink, reserved_permalinks) + + def _reserve_batch_permalink( + self, + desired_permalink: str, + reserved_permalinks: set[str], + ) -> str: + permalink = desired_permalink + suffix = 1 + while permalink in reserved_permalinks: + permalink = f"{desired_permalink}-{suffix}" + suffix += 1 + reserved_permalinks.add(permalink) + return permalink + + # --- Persistence --- + + async def _upsert_markdown_file(self, prepared: _PreparedMarkdownFile) -> _PreparedEntity: + existing = await self.entity_repository.get_by_file_path( + prepared.file.path, + load_relations=False, + ) + entity = await self.entity_service.upsert_entity_from_markdown( + Path(prepared.file.path), + prepared.markdown, + is_new=existing is None, + ) + updated = await self.entity_repository.update( + entity.id, + self._entity_metadata_updates(prepared.file, prepared.final_checksum), + ) + if updated is None: + raise ValueError(f"Failed to update markdown entity metadata for {prepared.file.path}") + + return _PreparedEntity( + path=prepared.file.path, + entity_id=updated.id, + checksum=prepared.final_checksum, + content_type=prepared.file.content_type, + search_content=prepared.markdown.content, + ) + + async def _upsert_regular_file(self, file: IndexInputFile) -> _PreparedEntity: + checksum = await self._resolve_checksum(file) + existing = await self.entity_repository.get_by_file_path(file.path, load_relations=False) + + if existing is None: + await self.entity_service.resolve_permalink(file.path, skip_conflict_check=True) + entity = Entity( + note_type="file", + file_path=file.path, + checksum=checksum, + title=Path(file.path).name, + created_at=file.created_at or datetime.now().astimezone(), + updated_at=file.last_modified or datetime.now().astimezone(), + content_type=file.content_type or "text/plain", + 
mtime=file.last_modified.timestamp() if file.last_modified else None, + size=file.size, + ) + + try: + created = await self.entity_repository.add(entity) + entity_id = created.id + except IntegrityError as exc: + message = str(exc) + if ( + "UNIQUE constraint failed: entity.file_path" in message + or "uix_entity_file_path_project" in message + or ( + "duplicate key value violates unique constraint" in message + and "file_path" in message + ) + ): + existing = await self.entity_repository.get_by_file_path( + file.path, + load_relations=False, + ) + if existing is None: + raise ValueError( + f"Entity not found after file_path conflict: {file.path}" + ) from exc + entity_id = existing.id + else: + raise + else: + entity_id = existing.id + + updated = await self.entity_repository.update( + entity_id, + self._entity_metadata_updates(file, checksum, include_created_at=existing is None), + ) + if updated is None: + raise ValueError(f"Failed to update file entity metadata for {file.path}") + + return _PreparedEntity( + path=file.path, + entity_id=updated.id, + checksum=checksum, + content_type=file.content_type, + search_content=None, + ) + + # --- Relations --- + + async def _resolve_batch_relations( + self, + entity_ids: list[int], + *, + max_concurrent: int, + ) -> tuple[int, int]: + unresolved_relations: list[Relation] = [] + for entity_id in entity_ids: + unresolved_relations.extend( + await self.relation_repository.find_unresolved_relations_for_entity(entity_id) + ) + + if not unresolved_relations: + return 0, 0 + + semaphore = asyncio.Semaphore(max_concurrent) + + async def resolve_relation(relation: Relation) -> int: + async with semaphore: + try: + resolved_entity = await self.entity_service.link_resolver.resolve_link( + relation.to_name + ) + if resolved_entity is None or resolved_entity.id == relation.from_id: + return 0 + + try: + await self.relation_repository.update( + relation.id, + { + "to_id": resolved_entity.id, + "to_name": resolved_entity.title, + }, + 
) + except IntegrityError: + await self.relation_repository.delete(relation.id) + return 1 + except Exception as exc: # pragma: no cover - defensive logging + logger.warning( + "Batch relation resolution failed", + relation_id=relation.id, + from_id=relation.from_id, + to_name=relation.to_name, + error=str(exc), + ) + return 0 + + resolved_counts = await asyncio.gather( + *(resolve_relation(relation) for relation in unresolved_relations) + ) + + remaining_unresolved = 0 + for entity_id in entity_ids: + remaining_unresolved += len( + await self.relation_repository.find_unresolved_relations_for_entity(entity_id) + ) + + return sum(resolved_counts), remaining_unresolved + + # --- Search refresh --- + + async def _refresh_search_index( + self, prepared: _PreparedEntity, entity: Entity + ) -> IndexedEntity: + await self.search_service.index_entity_data(entity, content=prepared.search_content) + return IndexedEntity( + path=prepared.path, + entity_id=entity.id, + permalink=entity.permalink, + checksum=prepared.checksum, + content_type=prepared.content_type, + ) + + # --- Helpers --- + + async def _resolve_checksum(self, file: IndexInputFile) -> str: + if file.checksum is not None: + return file.checksum + if file.content is None: + raise ValueError(f"Missing checksum and content for file: {file.path}") + return await compute_checksum(file.content) + + def _entity_metadata_updates( + self, + file: IndexInputFile, + checksum: str, + *, + include_created_at: bool = True, + ) -> dict[str, object]: + updates: dict[str, object] = { + "file_path": file.path, + "checksum": checksum, + "size": file.size, + } + if include_created_at and file.created_at is not None: + updates["created_at"] = file.created_at + if file.last_modified is not None: + updates["updated_at"] = file.last_modified + updates["mtime"] = file.last_modified.timestamp() + if file.content_type is not None: + updates["content_type"] = file.content_type + return updates + + def _is_markdown(self, file: 
IndexInputFile) -> bool: + if file.content_type is not None: + return file.content_type == "text/markdown" + return Path(file.path).suffix.lower() in {".md", ".markdown"} + + async def _run_bounded( + self, + paths: list[str], + *, + limit: int, + worker: Callable[[str], Awaitable[T]], + ) -> tuple[dict[str, T], dict[str, str]]: + if not paths: + return {}, {} + + semaphore = asyncio.Semaphore(limit) + results: dict[str, T] = {} + errors: dict[str, str] = {} + + async def run(path: str) -> None: + async with semaphore: + try: + results[path] = await worker(path) + except Exception as exc: + errors[path] = str(exc) + logger.warning("Batch indexing failed", path=path, error=str(exc)) + + await asyncio.gather(*(run(path) for path in paths)) + return results, errors diff --git a/src/basic_memory/indexing/batching.py b/src/basic_memory/indexing/batching.py new file mode 100644 index 00000000..397f7af0 --- /dev/null +++ b/src/basic_memory/indexing/batching.py @@ -0,0 +1,63 @@ +"""Deterministic helpers for planning bounded indexing batches.""" + +from __future__ import annotations + +from collections.abc import Mapping, Sequence + +from basic_memory.indexing.models import IndexBatch, IndexFileMetadata + + +def build_index_batches( + paths: Sequence[str], + metadata_by_path: Mapping[str, IndexFileMetadata], + *, + max_files: int, + max_bytes: int, +) -> list[IndexBatch]: + """Build deterministic batches bounded by file count and total bytes.""" + if max_files <= 0: + raise ValueError("max_files must be greater than zero") + if max_bytes <= 0: + raise ValueError("max_bytes must be greater than zero") + + ordered_paths = sorted(paths) + batches: list[IndexBatch] = [] + current_paths: list[str] = [] + current_bytes = 0 + + for path in ordered_paths: + metadata = metadata_by_path.get(path) + if metadata is None: + raise KeyError(f"Missing metadata for path: {path}") + + file_bytes = max(metadata.size, 0) + + # Trigger: the next file would overflow the active batch. 
+ # Why: keep batches memory-bounded and predictable for both local and remote callers. + # Outcome: flush the current batch before placing the next file. + if current_paths and ( + len(current_paths) >= max_files or current_bytes + file_bytes > max_bytes + ): + batches.append(IndexBatch(paths=current_paths, total_bytes=current_bytes)) + current_paths = [] + current_bytes = 0 + + # Trigger: one file is larger than the configured byte budget. + # Why: we still need to index it, but splitting a single file is out of scope. + # Outcome: emit a dedicated single-file batch that may exceed max_bytes. + if file_bytes > max_bytes: + batches.append(IndexBatch(paths=[path], total_bytes=file_bytes)) + continue + + current_paths.append(path) + current_bytes += file_bytes + + if len(current_paths) >= max_files or current_bytes >= max_bytes: + batches.append(IndexBatch(paths=current_paths, total_bytes=current_bytes)) + current_paths = [] + current_bytes = 0 + + if current_paths: + batches.append(IndexBatch(paths=current_paths, total_bytes=current_bytes)) + + return batches diff --git a/src/basic_memory/indexing/models.py b/src/basic_memory/indexing/models.py new file mode 100644 index 00000000..583cca65 --- /dev/null +++ b/src/basic_memory/indexing/models.py @@ -0,0 +1,83 @@ +"""Typed models for the reusable indexing execution path.""" + +from __future__ import annotations + +from dataclasses import dataclass, field +from datetime import datetime +from typing import Any, Protocol + + +@dataclass(slots=True) +class IndexFileMetadata: + """Storage-agnostic metadata for a file queued for indexing.""" + + path: str + size: int + checksum: str | None = None + content_type: str | None = None + last_modified: datetime | None = None + created_at: datetime | None = None + + +@dataclass(slots=True) +class IndexInputFile(IndexFileMetadata): + """Fully loaded file payload consumed by the batch executor.""" + + content: bytes | None = None + + +@dataclass(slots=True) +class IndexBatch: + 
"""A deterministic batch of files bounded by count and total bytes.""" + + paths: list[str] + total_bytes: int + + +@dataclass(slots=True) +class IndexProgress: + """Batch indexing progress emitted to callers such as the CLI.""" + + files_total: int + files_processed: int + batches_total: int + batches_completed: int + current_batch_bytes: int = 0 + files_per_minute: float = 0.0 + eta_seconds: float | None = None + + +@dataclass(slots=True) +class IndexFrontmatterUpdate: + """A typed frontmatter write request for a single file.""" + + path: str + metadata: dict[str, Any] + + +@dataclass(slots=True) +class IndexedEntity: + """Stable output describing one file that finished indexing successfully.""" + + path: str + entity_id: int + permalink: str | None + checksum: str + content_type: str | None = None + + +@dataclass(slots=True) +class IndexingBatchResult: + """Outcome for one batch execution.""" + + indexed: list[IndexedEntity] = field(default_factory=list) + errors: list[tuple[str, str]] = field(default_factory=list) + relations_resolved: int = 0 + relations_unresolved: int = 0 + search_indexed: int = 0 + + +class IndexFileWriter(Protocol): + """Narrow protocol for frontmatter writes during indexing.""" + + async def write_frontmatter(self, update: IndexFrontmatterUpdate) -> str: ... 
diff --git a/src/basic_memory/repository/embedding_provider_factory.py b/src/basic_memory/repository/embedding_provider_factory.py index 856ed6b2..e259c62b 100644 --- a/src/basic_memory/repository/embedding_provider_factory.py +++ b/src/basic_memory/repository/embedding_provider_factory.py @@ -5,7 +5,16 @@ from basic_memory.config import BasicMemoryConfig from basic_memory.repository.embedding_provider import EmbeddingProvider -type ProviderCacheKey = tuple[str, str, int | None, int, str | None, int | None, int | None] +type ProviderCacheKey = tuple[ + str, + str, + int | None, + int, + int, + str | None, + int | None, + int | None, +] _EMBEDDING_PROVIDER_CACHE: dict[ProviderCacheKey, EmbeddingProvider] = {} _EMBEDDING_PROVIDER_CACHE_LOCK = Lock() @@ -18,6 +27,7 @@ def _provider_cache_key(app_config: BasicMemoryConfig) -> ProviderCacheKey: app_config.semantic_embedding_model, app_config.semantic_embedding_dimensions, app_config.semantic_embedding_batch_size, + app_config.semantic_embedding_request_concurrency, app_config.semantic_embedding_cache_dir, app_config.semantic_embedding_threads, app_config.semantic_embedding_parallel, @@ -73,6 +83,7 @@ def create_embedding_provider(app_config: BasicMemoryConfig) -> EmbeddingProvide provider = OpenAIEmbeddingProvider( model_name=model_name, batch_size=app_config.semantic_embedding_batch_size, + request_concurrency=app_config.semantic_embedding_request_concurrency, **extra_kwargs, ) else: diff --git a/src/basic_memory/repository/openai_provider.py b/src/basic_memory/repository/openai_provider.py index 4d3198b9..479bce12 100644 --- a/src/basic_memory/repository/openai_provider.py +++ b/src/basic_memory/repository/openai_provider.py @@ -18,6 +18,7 @@ def __init__( model_name: str = "text-embedding-3-small", *, batch_size: int = 64, + request_concurrency: int = 4, dimensions: int = 1536, api_key: str | None = None, base_url: str | None = None, @@ -26,6 +27,7 @@ def __init__( self.model_name = model_name self.dimensions = 
dimensions self.batch_size = batch_size + self.request_concurrency = request_concurrency self._api_key = api_key self._base_url = base_url self._timeout = timeout @@ -67,25 +69,49 @@ async def embed_documents(self, texts: list[str]) -> list[list[float]]: return [] client = await self._get_client() - all_vectors: list[list[float]] = [] + batches = [ + texts[start : start + self.batch_size] + for start in range(0, len(texts), self.batch_size) + ] + batch_vectors: list[list[list[float]] | None] = [None] * len(batches) + semaphore = asyncio.Semaphore(self.request_concurrency) + + async def embed_batch(batch_index: int, batch: list[str]) -> None: + async with semaphore: + response = await client.embeddings.create( + model=self.model_name, + input=batch, + ) - for start in range(0, len(texts), self.batch_size): - batch = texts[start : start + self.batch_size] - response = await client.embeddings.create( - model=self.model_name, - input=batch, - ) - vectors_by_index: dict[int, list[float]] = { - int(item.index): [float(value) for value in item.embedding] - for item in response.data - } + vectors_by_index: dict[int, list[float]] = {} + for item in response.data: + response_index = int(item.index) + if response_index in vectors_by_index: + raise RuntimeError( + "OpenAI embedding response returned duplicate vector indexes." + ) + vectors_by_index[response_index] = [float(value) for value in item.embedding] + + ordered_vectors: list[list[float]] = [] for index in range(len(batch)): vector = vectors_by_index.get(index) if vector is None: raise RuntimeError( "OpenAI embedding response is missing expected vector index." 
) - all_vectors.append(vector) + ordered_vectors.append(vector) + + batch_vectors[batch_index] = ordered_vectors + + await asyncio.gather( + *(embed_batch(batch_index, batch) for batch_index, batch in enumerate(batches)) + ) + + all_vectors: list[list[float]] = [] + for vectors in batch_vectors: + if vectors is None: + raise RuntimeError("OpenAI embedding batch did not produce vectors.") + all_vectors.extend(vectors) if all_vectors and len(all_vectors[0]) != self.dimensions: raise RuntimeError( diff --git a/src/basic_memory/sync/sync_service.py b/src/basic_memory/sync/sync_service.py index 63e59882..a781f474 100644 --- a/src/basic_memory/sync/sync_service.py +++ b/src/basic_memory/sync/sync_service.py @@ -8,7 +8,7 @@ from dataclasses import dataclass, field from datetime import datetime from pathlib import Path -from typing import AsyncIterator, Dict, List, Optional, Set, Tuple +from typing import AsyncIterator, Awaitable, Callable, Dict, List, Optional, Set, Tuple import aiofiles.os @@ -19,6 +19,9 @@ from basic_memory import db from basic_memory.config import BasicMemoryConfig, ConfigManager from basic_memory.file_utils import has_frontmatter +from basic_memory.indexing import BatchIndexer, IndexFileMetadata, IndexInputFile, IndexProgress +from basic_memory.indexing.batching import build_index_batches +from basic_memory.indexing.models import IndexedEntity, IndexFileWriter, IndexFrontmatterUpdate from basic_memory.ignore_utils import load_bmignore_patterns, should_ignore_path from basic_memory.markdown import EntityParser, MarkdownProcessor from basic_memory.models import Entity, Project @@ -118,6 +121,16 @@ class ScanResult: errors: Dict[str, str] = field(default_factory=dict) +class _FileServiceIndexWriter(IndexFileWriter): + """Adapt FileService frontmatter updates to the indexing writer protocol.""" + + def __init__(self, file_service: FileService) -> None: + self.file_service = file_service + + async def write_frontmatter(self, update: IndexFrontmatterUpdate) 
-> str: + return await self.file_service.update_frontmatter(update.path, update.metadata) + + class SyncService: """Syncs documents and knowledge files with database.""" @@ -146,6 +159,14 @@ def __init__( # Use OrderedDict for LRU behavior with bounded size to prevent unbounded memory growth self._file_failures: OrderedDict[str, FileFailureInfo] = OrderedDict() self._max_tracked_failures = 100 # Limit failure cache size + self.batch_indexer = BatchIndexer( + app_config=app_config, + entity_service=entity_service, + entity_repository=entity_repository, + relation_repository=relation_repository, + search_service=search_service, + file_writer=_FileServiceIndexWriter(file_service), + ) async def _should_skip_file(self, path: str) -> bool: """Check if file should be skipped due to repeated failures. @@ -255,7 +276,11 @@ def _clear_failure(self, path: str) -> None: del self._file_failures[path] async def sync( - self, directory: Path, project_name: Optional[str] = None, force_full: bool = False + self, + directory: Path, + project_name: Optional[str] = None, + force_full: bool = False, + progress_callback: Callable[[IndexProgress], Awaitable[None]] | None = None, ) -> SyncReport: """Sync all files with database and update scan watermark. 
@@ -263,6 +288,7 @@ async def sync( directory: Directory to sync project_name: Optional project name force_full: If True, force a full scan bypassing watermark optimization + progress_callback: Optional callback for typed indexing progress updates """ start_time = time.time() @@ -310,42 +336,14 @@ async def sync( for path in report.deleted: await self.handle_delete(path) - # then new and modified — collect entity IDs for batch vector embedding - synced_entity_ids: list[int] = [] - - for path in report.new: - entity, _ = await self.sync_file(path, new=True) - - if entity is not None: - synced_entity_ids.append(entity.id) - # Track if file was skipped - elif await self._should_skip_file(path): - failure_info = self._file_failures[path] - report.skipped_files.append( - SkippedFile( - path=path, - reason=failure_info.last_error, - failure_count=failure_info.count, - first_failed=failure_info.first_failure, - ) - ) - - for path in report.modified: - entity, _ = await self.sync_file(path, new=False) - - if entity is not None: - synced_entity_ids.append(entity.id) - # Track if file was skipped - elif await self._should_skip_file(path): - failure_info = self._file_failures[path] - report.skipped_files.append( - SkippedFile( - path=path, - reason=failure_info.last_error, - failure_count=failure_info.count, - first_failed=failure_info.first_failure, - ) - ) + changed_paths = sorted(report.new | report.modified) + indexed_entities, skipped_files = await self._index_changed_files( + changed_paths, + report.checksums, + progress_callback=progress_callback, + ) + report.skipped_files.extend(skipped_files) + synced_entity_ids = [indexed.entity_id for indexed in indexed_entities] # Only resolve relations if there were actual changes # If no files changed, no new unresolved relations could have been created @@ -353,11 +351,12 @@ async def sync( with telemetry.scope( "sync.project.resolve_relations", relation_scope="all_pending" ): - await self.resolve_relations() + 
synced_entity_ids.extend(await self.resolve_relations()) else: logger.info("Skipping relation resolution - no file changes detected") # Batch-generate vector embeddings for all synced entities + synced_entity_ids = list(dict.fromkeys(synced_entity_ids)) if synced_entity_ids and self.app_config.semantic_search_enabled: try: with telemetry.scope( @@ -425,6 +424,253 @@ async def sync( return report + async def _index_changed_files( + self, + changed_paths: list[str], + checksums_by_path: dict[str, str], + *, + progress_callback: Callable[[IndexProgress], Awaitable[None]] | None = None, + ) -> tuple[list[IndexedEntity], list[SkippedFile]]: + """Load, batch, and index changed files without processing them serially.""" + if not changed_paths: + if progress_callback is not None: + await progress_callback( + IndexProgress( + files_total=0, + files_processed=0, + batches_total=0, + batches_completed=0, + ) + ) + return [], [] + + started_at = time.monotonic() + files_total = len(changed_paths) + files_processed = 0 + batches_completed = 0 + skipped_files: list[SkippedFile] = [] + skipped_paths: set[str] = set() + candidate_paths: list[str] = [] + + # Trigger: a file exceeded the retry threshold in a previous sync. + # Why: repeated retries on unchanged broken files waste the entire batch budget. + # Outcome: skip it up front and still count it in progress. 
+ for path in changed_paths: + if await self._should_skip_file(path): + self._append_skipped_file(path, skipped_files, skipped_paths) + files_processed += 1 + else: + candidate_paths.append(path) + + ( + metadata_by_path, + metadata_errors, + missing_metadata_paths, + ) = await self._load_index_file_metadata( + candidate_paths, + checksums_by_path, + ) + files_processed += len(missing_metadata_paths) + files_processed += len(metadata_errors) + for path, error in metadata_errors: + await self._record_index_failure(path, error, skipped_files, skipped_paths) + + batch_paths = sorted(metadata_by_path) + batches = build_index_batches( + batch_paths, + metadata_by_path, + max_files=self.app_config.index_batch_size, + max_bytes=self.app_config.index_batch_max_bytes, + ) + + await self._emit_index_progress( + progress_callback, + files_total=files_total, + files_processed=files_processed, + batches_total=len(batches), + batches_completed=0, + current_batch_bytes=0, + started_at=started_at, + ) + + indexed_entities: list[IndexedEntity] = [] + for batch in batches: + loaded_files, load_errors = await self._load_index_batch_files( + batch.paths, metadata_by_path + ) + for path, error in load_errors: + await self._record_index_failure(path, error, skipped_files, skipped_paths) + + if loaded_files: + batch_result = await self.batch_indexer.index_files( + loaded_files, + max_concurrent=self.app_config.index_entity_max_concurrent, + parse_max_concurrent=self.app_config.index_parse_max_concurrent, + ) + indexed_entities.extend(batch_result.indexed) + + indexed_paths = {indexed.path for indexed in batch_result.indexed} + for path in indexed_paths: + self._clear_failure(path) + + for path, error in batch_result.errors: + await self._record_index_failure(path, error, skipped_files, skipped_paths) + + files_processed += len(batch.paths) + batches_completed += 1 + await self._emit_index_progress( + progress_callback, + files_total=files_total, + files_processed=files_processed, + 
batches_total=len(batches), + batches_completed=batches_completed, + current_batch_bytes=batch.total_bytes, + started_at=started_at, + ) + + return indexed_entities, skipped_files + + async def _load_index_file_metadata( + self, + paths: list[str], + checksums_by_path: dict[str, str], + ) -> tuple[dict[str, IndexFileMetadata], list[tuple[str, str]], list[str]]: + """Load typed metadata for batch planning before any file content is read.""" + if not paths: + return {}, [], [] + + semaphore = asyncio.Semaphore(self.app_config.sync_max_concurrent_files) + metadata_by_path: dict[str, IndexFileMetadata] = {} + errors: dict[str, str] = {} + missing_paths: list[str] = [] + + async def load(path: str) -> None: + async with semaphore: + try: + file_metadata = await self.file_service.get_file_metadata(path) + metadata_by_path[path] = IndexFileMetadata( + path=path, + size=file_metadata.size, + checksum=checksums_by_path.get(path), + content_type=self.file_service.content_type(path), + last_modified=file_metadata.modified_at, + created_at=file_metadata.created_at, + ) + except FileNotFoundError: + await self.handle_delete(path) + missing_paths.append(path) + except Exception as exc: + errors[path] = str(exc) + + await asyncio.gather(*(load(path) for path in paths)) + return ( + metadata_by_path, + [(path, errors[path]) for path in sorted(errors)], + sorted(missing_paths), + ) + + async def _load_index_batch_files( + self, + paths: list[str], + metadata_by_path: dict[str, IndexFileMetadata], + ) -> tuple[dict[str, IndexInputFile], list[tuple[str, str]]]: + """Read one batch of file contents into typed input objects.""" + if not paths: + return {}, [] + + semaphore = asyncio.Semaphore(self.app_config.sync_max_concurrent_files) + files: dict[str, IndexInputFile] = {} + errors: dict[str, str] = {} + + async def load(path: str) -> None: + async with semaphore: + metadata = metadata_by_path[path] + try: + content = await self.file_service.read_file_bytes(path) + files[path] = 
IndexInputFile( + path=metadata.path, + size=metadata.size, + checksum=metadata.checksum, + content_type=metadata.content_type, + last_modified=metadata.last_modified, + created_at=metadata.created_at, + content=content, + ) + except FileNotFoundError: + await self.handle_delete(path) + except Exception as exc: + errors[path] = str(exc) + + await asyncio.gather(*(load(path) for path in paths)) + return files, [(path, errors[path]) for path in sorted(errors)] + + async def _record_index_failure( + self, + path: str, + error: str, + skipped_files: list[SkippedFile], + skipped_paths: set[str], + ) -> None: + """Record a per-file batch failure and promote it to skipped when threshold is reached.""" + await self._record_failure(path, error) + if await self._should_skip_file(path): + self._append_skipped_file(path, skipped_files, skipped_paths) + + def _append_skipped_file( + self, + path: str, + skipped_files: list[SkippedFile], + skipped_paths: set[str], + ) -> None: + """Append one skipped file record once per sync run.""" + if path in skipped_paths or path not in self._file_failures: + return + + failure_info = self._file_failures[path] + skipped_files.append( + SkippedFile( + path=path, + reason=failure_info.last_error, + failure_count=failure_info.count, + first_failed=failure_info.first_failure, + ) + ) + skipped_paths.add(path) + + async def _emit_index_progress( + self, + progress_callback: Callable[[IndexProgress], Awaitable[None]] | None, + *, + files_total: int, + files_processed: int, + batches_total: int, + batches_completed: int, + current_batch_bytes: int, + started_at: float, + ) -> None: + """Emit a typed indexing progress update when the caller requested one.""" + if progress_callback is None: + return + + elapsed_seconds = max(time.monotonic() - started_at, 0.001) + files_per_minute = files_processed / elapsed_seconds * 60 if files_processed else 0.0 + eta_seconds = None + if files_processed and files_total > files_processed: + files_per_second = 
files_processed / elapsed_seconds + eta_seconds = (files_total - files_processed) / files_per_second + + await progress_callback( + IndexProgress( + files_total=files_total, + files_processed=files_processed, + batches_total=batches_total, + batches_completed=batches_completed, + current_batch_bytes=current_batch_bytes, + files_per_minute=files_per_minute, + eta_seconds=eta_seconds, + ) + ) + async def scan(self, directory, force_full: bool = False): """Smart scan using watermark and file count for large project optimization. @@ -1081,13 +1327,17 @@ async def handle_move(self, old_path, new_path): # update search index await self.search_service.index_entity(updated) - async def resolve_relations(self, entity_id: int | None = None): + async def resolve_relations(self, entity_id: int | None = None) -> set[int]: """Try to resolve unresolved relations. Args: entity_id: If provided, only resolve relations for this specific entity. Otherwise, resolve all unresolved relations in the database. + + Returns: + Set of source entity IDs whose outgoing relations changed. 
""" + affected_entity_ids: set[int] = set() if entity_id: # Only get unresolved relations for the specific entity @@ -1131,8 +1381,7 @@ async def resolve_relations(self, entity_id: int | None = None): "to_name": resolved_entity.title, }, ) - # update search index only on successful resolution - await self.search_service.index_entity(resolved_entity) + affected_entity_ids.add(relation.from_id) except IntegrityError: with telemetry.scope( "sync.relation.resolve_conflict", @@ -1164,6 +1413,14 @@ async def resolve_relations(self, entity_id: int | None = None): logger.debug( f"Could not delete duplicate relation {relation.id}: {e}" ) + affected_entity_ids.add(relation.from_id) + + for affected_entity_id in sorted(affected_entity_ids): + source_entity = await self.entity_repository.find_by_id(affected_entity_id) + if source_entity is not None: + await self.search_service.index_entity(source_entity) + + return affected_entity_ids async def _quick_count_files(self, directory: Path) -> int: """Fast file count using find command. 
diff --git a/tests/indexing/test_batch_indexer.py b/tests/indexing/test_batch_indexer.py new file mode 100644 index 00000000..f054d90e --- /dev/null +++ b/tests/indexing/test_batch_indexer.py @@ -0,0 +1,400 @@ +"""Tests for the reusable batch indexing executor.""" + +from __future__ import annotations + +import asyncio +from pathlib import Path +from textwrap import dedent + +import pytest +from sqlalchemy import text + +from basic_memory.indexing import BatchIndexer, IndexFrontmatterUpdate, IndexInputFile + + +class _TestFileWriter: + """Adapt the real FileService for batch indexer tests.""" + + def __init__(self, file_service) -> None: + self.file_service = file_service + + async def write_frontmatter(self, update: IndexFrontmatterUpdate) -> str: + return await self.file_service.update_frontmatter(update.path, update.metadata) + + +async def _create_file(path: Path, content: str | bytes) -> None: + path.parent.mkdir(parents=True, exist_ok=True) + if isinstance(content, bytes): + path.write_bytes(content) + else: + path.write_text(content) + + +async def _load_input(file_service, path: str) -> IndexInputFile: + metadata = await file_service.get_file_metadata(path) + return IndexInputFile( + path=path, + size=metadata.size, + checksum=await file_service.compute_checksum(path), + content_type=file_service.content_type(path), + last_modified=metadata.modified_at, + created_at=metadata.created_at, + content=await file_service.read_file_bytes(path), + ) + + +def _make_batch_indexer( + app_config, entity_service, entity_repository, relation_repository, search_service, file_service +) -> BatchIndexer: + return BatchIndexer( + app_config=app_config, + entity_service=entity_service, + entity_repository=entity_repository, + relation_repository=relation_repository, + search_service=search_service, + file_writer=_TestFileWriter(file_service), + ) + + +@pytest.mark.asyncio +async def test_batch_indexer_parses_markdown_with_parallel_path( + app_config, + entity_service, + 
entity_repository, + relation_repository, + search_service, + file_service, + project_config, +): + path_one = "notes/one.md" + path_two = "notes/two.md" + await _create_file( + project_config.home / path_one, + dedent( + """ + --- + title: One + type: note + --- + # One + """ + ).strip(), + ) + await _create_file( + project_config.home / path_two, + dedent( + """ + --- + title: Two + type: note + --- + # Two + """ + ).strip(), + ) + + files = { + path_one: await _load_input(file_service, path_one), + path_two: await _load_input(file_service, path_two), + } + batch_indexer = _make_batch_indexer( + app_config, + entity_service, + entity_repository, + relation_repository, + search_service, + file_service, + ) + + original_parse = entity_service.entity_parser.parse_markdown_content + in_flight = 0 + max_in_flight = 0 + + async def spy_parse(*args, **kwargs): + nonlocal in_flight, max_in_flight + in_flight += 1 + max_in_flight = max(max_in_flight, in_flight) + await asyncio.sleep(0.05) + try: + return await original_parse(*args, **kwargs) + finally: + in_flight -= 1 + + entity_service.entity_parser.parse_markdown_content = spy_parse + try: + result = await batch_indexer.index_files( + files, + max_concurrent=2, + parse_max_concurrent=2, + ) + finally: + entity_service.entity_parser.parse_markdown_content = original_parse + + assert max_in_flight >= 2 + assert len(result.indexed) == 2 + assert result.errors == [] + + +@pytest.mark.asyncio +async def test_batch_indexer_creates_entities_with_parallel_path( + app_config, + entity_service, + entity_repository, + relation_repository, + search_service, + file_service, + project_config, +): + path_one = "notes/alpha.md" + path_two = "notes/beta.md" + await _create_file( + project_config.home / path_one, + dedent( + """ + --- + title: Alpha + type: note + --- + # Alpha + """ + ).strip(), + ) + await _create_file( + project_config.home / path_two, + dedent( + """ + --- + title: Beta + type: note + --- + # Beta + """ + ).strip(), 
+ ) + + files = { + path_one: await _load_input(file_service, path_one), + path_two: await _load_input(file_service, path_two), + } + batch_indexer = _make_batch_indexer( + app_config, + entity_service, + entity_repository, + relation_repository, + search_service, + file_service, + ) + + original_upsert = entity_service.upsert_entity_from_markdown + in_flight = 0 + max_in_flight = 0 + + async def spy_upsert(*args, **kwargs): + nonlocal in_flight, max_in_flight + in_flight += 1 + max_in_flight = max(max_in_flight, in_flight) + await asyncio.sleep(0.05) + try: + return await original_upsert(*args, **kwargs) + finally: + in_flight -= 1 + + entity_service.upsert_entity_from_markdown = spy_upsert + try: + result = await batch_indexer.index_files( + files, + max_concurrent=2, + parse_max_concurrent=2, + ) + finally: + entity_service.upsert_entity_from_markdown = original_upsert + + assert max_in_flight >= 2 + assert len(result.indexed) == 2 + assert result.errors == [] + + +@pytest.mark.asyncio +async def test_batch_indexer_indexes_non_markdown_files( + app_config, + entity_service, + entity_repository, + relation_repository, + search_service, + file_service, + project_config, +): + pdf_path = "assets/doc.pdf" + image_path = "assets/image.png" + await _create_file(project_config.home / pdf_path, b"%PDF-1.4 test") + await _create_file(project_config.home / image_path, b"\x89PNG\r\n\x1a\nrest") + + files = { + pdf_path: await _load_input(file_service, pdf_path), + image_path: await _load_input(file_service, image_path), + } + batch_indexer = _make_batch_indexer( + app_config, + entity_service, + entity_repository, + relation_repository, + search_service, + file_service, + ) + + result = await batch_indexer.index_files( + files, + max_concurrent=2, + parse_max_concurrent=2, + ) + + assert {indexed.path for indexed in result.indexed} == {pdf_path, image_path} + + pdf_entity = await entity_repository.get_by_file_path(pdf_path) + image_entity = await 
entity_repository.get_by_file_path(image_path) + assert pdf_entity is not None + assert pdf_entity.content_type == "application/pdf" + assert image_entity is not None + assert image_entity.content_type == "image/png" + + +@pytest.mark.asyncio +async def test_batch_indexer_resolves_relations_and_refreshes_search( + app_config, + entity_service, + entity_repository, + relation_repository, + search_service, + search_repository, + file_service, + project_config, +): + source_path = "notes/source.md" + target_path = "notes/target.md" + await _create_file( + project_config.home / source_path, + dedent( + """ + --- + title: Source + type: note + --- + # Source + + - depends_on [[Target]] + """ + ).strip(), + ) + await _create_file( + project_config.home / target_path, + dedent( + """ + --- + title: Target + type: note + --- + # Target + """ + ).strip(), + ) + + files = { + source_path: await _load_input(file_service, source_path), + target_path: await _load_input(file_service, target_path), + } + batch_indexer = _make_batch_indexer( + app_config, + entity_service, + entity_repository, + relation_repository, + search_service, + file_service, + ) + + result = await batch_indexer.index_files( + files, + max_concurrent=2, + parse_max_concurrent=2, + ) + + source = await entity_repository.get_by_file_path(source_path) + target = await entity_repository.get_by_file_path(target_path) + assert source is not None + assert target is not None + assert len(source.outgoing_relations) == 1 + assert source.outgoing_relations[0].to_id == target.id + assert result.relations_unresolved == 0 + assert result.search_indexed == 2 + + relation_rows = await search_repository.execute_query( + text( + "SELECT COUNT(*) FROM search_index " + "WHERE entity_id = :entity_id AND type = 'relation' AND to_id IS NOT NULL" + ), + {"entity_id": source.id}, + ) + assert relation_rows.scalar_one() == 1 + + +@pytest.mark.asyncio +async def test_batch_indexer_assigns_unique_permalinks_for_batch_local_conflicts( 
+ app_config, + entity_service, + entity_repository, + relation_repository, + search_service, + file_service, + project_config, +): + path_one = "notes/basic memory bug.md" + path_two = "notes/basic-memory-bug.md" + await _create_file( + project_config.home / path_one, + dedent( + """ + --- + title: Basic Memory Bug + type: note + --- + # Basic Memory Bug + """ + ).strip(), + ) + await _create_file( + project_config.home / path_two, + dedent( + """ + --- + title: Basic Memory Bug Report + type: note + --- + # Basic Memory Bug Report + """ + ).strip(), + ) + + files = { + path_one: await _load_input(file_service, path_one), + path_two: await _load_input(file_service, path_two), + } + batch_indexer = _make_batch_indexer( + app_config, + entity_service, + entity_repository, + relation_repository, + search_service, + file_service, + ) + + result = await batch_indexer.index_files( + files, + max_concurrent=2, + parse_max_concurrent=2, + ) + + assert result.errors == [] + + entities = await entity_repository.find_all() + assert len(entities) == 2 + permalinks = [entity.permalink for entity in entities if entity.permalink] + assert len(set(permalinks)) == 2 diff --git a/tests/indexing/test_batching.py b/tests/indexing/test_batching.py new file mode 100644 index 00000000..634fade1 --- /dev/null +++ b/tests/indexing/test_batching.py @@ -0,0 +1,85 @@ +"""Tests for deterministic indexing batch planning.""" + +from basic_memory.indexing import IndexFileMetadata +from basic_memory.indexing.batching import build_index_batches + + +def test_build_index_batches_respects_max_files() -> None: + metadata = { + f"note-{index}.md": IndexFileMetadata(path=f"note-{index}.md", size=10) + for index in range(5) + } + + batches = build_index_batches( + list(metadata), + metadata, + max_files=2, + max_bytes=10_000, + ) + + assert [batch.paths for batch in batches] == [ + ["note-0.md", "note-1.md"], + ["note-2.md", "note-3.md"], + ["note-4.md"], + ] + + +def 
test_build_index_batches_respects_max_bytes() -> None: + metadata = { + "a.md": IndexFileMetadata(path="a.md", size=30), + "b.md": IndexFileMetadata(path="b.md", size=40), + "c.md": IndexFileMetadata(path="c.md", size=50), + } + + batches = build_index_batches( + ["c.md", "a.md", "b.md"], + metadata, + max_files=10, + max_bytes=70, + ) + + assert [(batch.paths, batch.total_bytes) for batch in batches] == [ + (["a.md", "b.md"], 70), + (["c.md"], 50), + ] + + +def test_build_index_batches_puts_giant_file_in_single_file_batch() -> None: + metadata = { + "alpha.md": IndexFileMetadata(path="alpha.md", size=10), + "giant.md": IndexFileMetadata(path="giant.md", size=500), + "omega.md": IndexFileMetadata(path="omega.md", size=10), + } + + batches = build_index_batches( + list(metadata), + metadata, + max_files=10, + max_bytes=100, + ) + + assert [(batch.paths, batch.total_bytes) for batch in batches] == [ + (["alpha.md"], 10), + (["giant.md"], 500), + (["omega.md"], 10), + ] + + +def test_build_index_batches_is_deterministic() -> None: + metadata = { + "notes/b.md": IndexFileMetadata(path="notes/b.md", size=10), + "notes/a.md": IndexFileMetadata(path="notes/a.md", size=10), + "notes/c.md": IndexFileMetadata(path="notes/c.md", size=10), + } + + batches = build_index_batches( + ["notes/c.md", "notes/a.md", "notes/b.md"], + metadata, + max_files=2, + max_bytes=1_000, + ) + + assert [batch.paths for batch in batches] == [ + ["notes/a.md", "notes/b.md"], + ["notes/c.md"], + ] diff --git a/tests/repository/test_openai_provider.py b/tests/repository/test_openai_provider.py index fc0b062a..e9db6bbe 100644 --- a/tests/repository/test_openai_provider.py +++ b/tests/repository/test_openai_provider.py @@ -1,5 +1,6 @@ """Tests for OpenAIEmbeddingProvider and embedding provider factory.""" +import asyncio import builtins import sys from types import SimpleNamespace @@ -40,6 +41,34 @@ def __init__(self, *, api_key: str, base_url=None, timeout=30.0): _StubAsyncOpenAI.init_count += 1 
+class _ConcurrentEmbeddingsApi: + def __init__(self): + self.calls: list[tuple[str, list[str]]] = [] + self.in_flight = 0 + self.max_in_flight = 0 + + async def create(self, *, model: str, input: list[str]): + self.calls.append((model, input)) + self.in_flight += 1 + self.max_in_flight = max(self.max_in_flight, self.in_flight) + try: + await asyncio.sleep(0.05) + vectors = [] + for index, value in enumerate(input): + base = float(len(value)) + vectors.append( + SimpleNamespace(index=index, embedding=[base, base + 1.0, base + 2.0]) + ) + return SimpleNamespace(data=vectors) + finally: + self.in_flight -= 1 + + +class _MalformedEmbeddingsApi: + async def create(self, *, model: str, input: list[str]): + return SimpleNamespace(data=[SimpleNamespace(index=0, embedding=[1.0, 2.0, 3.0])]) + + @pytest.fixture(autouse=True) def _reset_embedding_provider_cache_fixture(): reset_embedding_provider_cache() @@ -260,6 +289,57 @@ def test_embedding_provider_factory_reuses_provider_for_same_cache_key(): assert provider_a is provider_b +@pytest.mark.asyncio +async def test_openai_provider_runs_batches_concurrently_and_preserves_output_order(monkeypatch): + """Concurrent request fan-out should keep batch order stable.""" + + shared_api = _ConcurrentEmbeddingsApi() + + class _ConcurrentAsyncOpenAI: + def __init__(self, *, api_key: str, base_url=None, timeout=30.0): + self.embeddings = shared_api + + module = type(sys)("openai") + module.AsyncOpenAI = _ConcurrentAsyncOpenAI + monkeypatch.setitem(sys.modules, "openai", module) + monkeypatch.setenv("OPENAI_API_KEY", "test-key") + + provider = OpenAIEmbeddingProvider( + model_name="text-embedding-3-small", + batch_size=2, + request_concurrency=2, + dimensions=3, + ) + + vectors = await provider.embed_documents(["a", "bbbb", "ccc", "dd"]) + + assert shared_api.max_in_flight >= 2 + assert vectors == [ + [1.0, 2.0, 3.0], + [4.0, 5.0, 6.0], + [3.0, 4.0, 5.0], + [2.0, 3.0, 4.0], + ] + + +@pytest.mark.asyncio +async def 
test_openai_provider_fails_fast_on_malformed_concurrent_batch(monkeypatch): + """Missing batch indexes should still raise even when requests run concurrently.""" + + class _MalformedAsyncOpenAI: + def __init__(self, *, api_key: str, base_url=None, timeout=30.0): + self.embeddings = _MalformedEmbeddingsApi() + + module = type(sys)("openai") + module.AsyncOpenAI = _MalformedAsyncOpenAI + monkeypatch.setitem(sys.modules, "openai", module) + monkeypatch.setenv("OPENAI_API_KEY", "test-key") + + provider = OpenAIEmbeddingProvider(batch_size=2, request_concurrency=2, dimensions=3) + with pytest.raises(RuntimeError, match="missing expected vector index"): + await provider.embed_documents(["one", "two", "three", "four"]) + + def test_embedding_provider_factory_creates_new_provider_for_different_cache_key(): """Factory should create distinct providers when cache key fields differ.""" config_a = BasicMemoryConfig( @@ -285,6 +365,22 @@ def test_embedding_provider_factory_creates_new_provider_for_different_cache_key assert provider_a is not provider_b +def test_embedding_provider_factory_forwards_openai_request_concurrency(): + """Factory should forward provider request concurrency for API-backed batching.""" + config = BasicMemoryConfig( + env="test", + projects={"test-project": "/tmp/basic-memory-test"}, + default_project="test-project", + semantic_search_enabled=True, + semantic_embedding_provider="openai", + semantic_embedding_request_concurrency=6, + ) + + provider = create_embedding_provider(config) + assert isinstance(provider, OpenAIEmbeddingProvider) + assert provider.request_concurrency == 6 + + def test_embedding_provider_factory_reset_clears_cache(): """Cache reset helper should force provider recreation for the same config.""" config = BasicMemoryConfig( diff --git a/tests/sync/test_sync_service_batching.py b/tests/sync/test_sync_service_batching.py new file mode 100644 index 00000000..41186aca --- /dev/null +++ b/tests/sync/test_sync_service_batching.py @@ -0,0 
+1,92 @@ +"""Targeted tests for batched sync indexing behavior.""" + +from __future__ import annotations + +from pathlib import Path +from textwrap import dedent + +import pytest +from sqlalchemy import text + +from basic_memory.indexing import IndexProgress + + +async def _create_file(path: Path, content: str) -> None: + path.parent.mkdir(parents=True, exist_ok=True) + path.write_text(content) + + +@pytest.mark.asyncio +async def test_sync_batches_changed_files_emits_typed_progress_and_resolves_forward_refs( + app_config, + sync_service, + search_repository, + entity_repository, + project_config, +): + app_config.index_batch_size = 1 + app_config.index_batch_max_bytes = 1_024 + + source_path = project_config.home / "notes/source.md" + target_path = project_config.home / "notes/target.md" + + await _create_file( + source_path, + dedent( + """ + --- + title: Source + type: note + --- + # Source + + - depends_on [[Target]] + """ + ).strip(), + ) + await _create_file( + target_path, + dedent( + """ + --- + title: Target + type: note + --- + # Target + """ + ).strip(), + ) + + progress_updates: list[IndexProgress] = [] + + async def on_progress(update: IndexProgress) -> None: + progress_updates.append(update) + + await sync_service.sync( + project_config.home, + project_name=project_config.name, + progress_callback=on_progress, + ) + + assert progress_updates + assert all(isinstance(update, IndexProgress) for update in progress_updates) + assert progress_updates[-1].files_total == 2 + assert progress_updates[-1].files_processed == 2 + assert progress_updates[-1].batches_total == 2 + assert progress_updates[-1].batches_completed == 2 + + source = await entity_repository.get_by_file_path("notes/source.md") + target = await entity_repository.get_by_file_path("notes/target.md") + assert source is not None + assert target is not None + assert len(source.outgoing_relations) == 1 + assert source.outgoing_relations[0].to_id == target.id + + relation_rows = await 
search_repository.execute_query( + text( + "SELECT COUNT(*) FROM search_index " + "WHERE entity_id = :entity_id AND type = 'relation' AND to_id IS NOT NULL" + ), + {"entity_id": source.id}, + ) + assert relation_rows.scalar_one() == 1 diff --git a/tests/sync/test_sync_service_telemetry.py b/tests/sync/test_sync_service_telemetry.py index f243b39d..9a657508 100644 --- a/tests/sync/test_sync_service_telemetry.py +++ b/tests/sync/test_sync_service_telemetry.py @@ -49,14 +49,11 @@ async def fake_handle_move(old_path, new_path): async def fake_handle_delete(path): return None - async def fake_sync_file(path, new=True): - return None, None - - async def fake_should_skip_file(path): - return False + async def fake_index_changed_files(changed_paths, checksums_by_path, progress_callback=None): + return [], [] async def fake_resolve_relations(entity_id=None): - return None + return set() async def fake_quick_count_files(directory): return 3 @@ -72,8 +69,7 @@ async def fake_update(project_id, values): monkeypatch.setattr(sync_service, "scan", fake_scan) monkeypatch.setattr(sync_service, "handle_move", fake_handle_move) monkeypatch.setattr(sync_service, "handle_delete", fake_handle_delete) - monkeypatch.setattr(sync_service, "sync_file", fake_sync_file) - monkeypatch.setattr(sync_service, "_should_skip_file", fake_should_skip_file) + monkeypatch.setattr(sync_service, "_index_changed_files", fake_index_changed_files) monkeypatch.setattr(sync_service, "resolve_relations", fake_resolve_relations) monkeypatch.setattr(sync_service, "_quick_count_files", fake_quick_count_files) monkeypatch.setattr(sync_service.project_repository, "find_by_id", fake_find_by_id) From 2f3cf950c0ae1076f2c215cbd5dedc0d6494df4e Mon Sep 17 00:00:00 2001 From: phernandez Date: Tue, 7 Apr 2026 23:34:34 -0500 Subject: [PATCH 2/9] perf(sync): return final markdown from batch indexer Signed-off-by: phernandez --- src/basic_memory/indexing/__init__.py | 2 + src/basic_memory/indexing/batch_indexer.py | 19 ++++-- 
src/basic_memory/indexing/models.py | 13 +++- src/basic_memory/services/file_service.py | 29 ++++++-- src/basic_memory/sync/sync_service.py | 16 ++++- tests/indexing/test_batch_indexer.py | 78 +++++++++++++++++++++- 6 files changed, 140 insertions(+), 17 deletions(-) diff --git a/src/basic_memory/indexing/__init__.py b/src/basic_memory/indexing/__init__.py index 369da770..79742d6f 100644 --- a/src/basic_memory/indexing/__init__.py +++ b/src/basic_memory/indexing/__init__.py @@ -8,6 +8,7 @@ IndexFileMetadata, IndexFileWriter, IndexFrontmatterUpdate, + IndexFrontmatterWriteResult, IndexingBatchResult, IndexInputFile, IndexProgress, @@ -20,6 +21,7 @@ "IndexFileMetadata", "IndexFileWriter", "IndexFrontmatterUpdate", + "IndexFrontmatterWriteResult", "IndexingBatchResult", "IndexInputFile", "IndexProgress", diff --git a/src/basic_memory/indexing/batch_indexer.py b/src/basic_memory/indexing/batch_indexer.py index 09f74114..0d1f0b17 100644 --- a/src/basic_memory/indexing/batch_indexer.py +++ b/src/basic_memory/indexing/batch_indexer.py @@ -13,7 +13,7 @@ from sqlalchemy.exc import IntegrityError from basic_memory.config import BasicMemoryConfig -from basic_memory.file_utils import compute_checksum, has_frontmatter +from basic_memory.file_utils import compute_checksum, has_frontmatter, remove_frontmatter from basic_memory.markdown.schemas import EntityMarkdown from basic_memory.indexing.models import ( IndexedEntity, @@ -47,6 +47,7 @@ class _PreparedEntity: checksum: str content_type: str | None search_content: str | None + markdown_content: str | None = None class BatchIndexer: @@ -253,6 +254,7 @@ async def _normalize_markdown_file( reserved_permalinks: set[str], ) -> _PreparedMarkdownFile: final_checksum = prepared.final_checksum + final_content = prepared.content final_permalink = await self._resolve_batch_permalink(prepared, reserved_permalinks) # Trigger: markdown file has no frontmatter and sync enforcement is enabled. 
@@ -264,9 +266,11 @@ async def _normalize_markdown_file( "type": prepared.markdown.frontmatter.type, "permalink": final_permalink, } - final_checksum = await self.file_writer.write_frontmatter( + write_result = await self.file_writer.write_frontmatter( IndexFrontmatterUpdate(path=prepared.file.path, metadata=frontmatter_updates) ) + final_checksum = write_result.checksum + final_content = write_result.content prepared.markdown.frontmatter.metadata.update(frontmatter_updates) # Trigger: existing markdown frontmatter may lack the canonical permalink. @@ -278,16 +282,18 @@ async def _normalize_markdown_file( and final_permalink != prepared.markdown.frontmatter.permalink ): prepared.markdown.frontmatter.metadata["permalink"] = final_permalink - final_checksum = await self.file_writer.write_frontmatter( + write_result = await self.file_writer.write_frontmatter( IndexFrontmatterUpdate( path=prepared.file.path, metadata={"permalink": final_permalink}, ) ) + final_checksum = write_result.checksum + final_content = write_result.content return _PreparedMarkdownFile( file=prepared.file, - content=prepared.content, + content=final_content, final_checksum=final_checksum, markdown=prepared.markdown, file_contains_frontmatter=prepared.file_contains_frontmatter, @@ -351,7 +357,8 @@ async def _upsert_markdown_file(self, prepared: _PreparedMarkdownFile) -> _Prepa entity_id=updated.id, checksum=prepared.final_checksum, content_type=prepared.file.content_type, - search_content=prepared.markdown.content, + search_content=remove_frontmatter(prepared.content), + markdown_content=prepared.content, ) async def _upsert_regular_file(self, file: IndexInputFile) -> _PreparedEntity: @@ -412,6 +419,7 @@ async def _upsert_regular_file(self, file: IndexInputFile) -> _PreparedEntity: checksum=checksum, content_type=file.content_type, search_content=None, + markdown_content=None, ) # --- Relations --- @@ -487,6 +495,7 @@ async def _refresh_search_index( permalink=entity.permalink, 
checksum=prepared.checksum, content_type=prepared.content_type, + markdown_content=prepared.markdown_content, ) # --- Helpers --- diff --git a/src/basic_memory/indexing/models.py b/src/basic_memory/indexing/models.py index 583cca65..aaa5b869 100644 --- a/src/basic_memory/indexing/models.py +++ b/src/basic_memory/indexing/models.py @@ -55,6 +55,14 @@ class IndexFrontmatterUpdate: metadata: dict[str, Any] +@dataclass(slots=True) +class IndexFrontmatterWriteResult: + """Typed result for a frontmatter write performed during indexing.""" + + checksum: str + content: str + + @dataclass(slots=True) class IndexedEntity: """Stable output describing one file that finished indexing successfully.""" @@ -64,6 +72,7 @@ class IndexedEntity: permalink: str | None checksum: str content_type: str | None = None + markdown_content: str | None = None @dataclass(slots=True) @@ -80,4 +89,6 @@ class IndexingBatchResult: class IndexFileWriter(Protocol): """Narrow protocol for frontmatter writes during indexing.""" - async def write_frontmatter(self, update: IndexFrontmatterUpdate) -> str: ... + async def write_frontmatter( + self, update: IndexFrontmatterUpdate + ) -> IndexFrontmatterWriteResult: ... diff --git a/src/basic_memory/services/file_service.py b/src/basic_memory/services/file_service.py index 86cb3328..60f855da 100644 --- a/src/basic_memory/services/file_service.py +++ b/src/basic_memory/services/file_service.py @@ -3,6 +3,7 @@ import asyncio import hashlib import mimetypes +from dataclasses import dataclass from datetime import datetime from pathlib import Path from typing import TYPE_CHECKING, Any, Dict, Optional, Tuple, Union @@ -25,6 +26,14 @@ from loguru import logger +@dataclass(slots=True) +class FrontmatterUpdateResult: + """Final content emitted by a frontmatter rewrite without a follow-up reread.""" + + checksum: str + content: str + + class FileService: """Service for handling file operations with concurrency control. 
@@ -401,12 +410,14 @@ async def move_file(self, source: FilePath, destination: FilePath) -> None: ) raise FileOperationError(f"Failed to move file {source} -> {destination}: {e}") - async def update_frontmatter(self, path: FilePath, updates: Dict[str, Any]) -> str: - """Update frontmatter fields in a file while preserving all content. + async def update_frontmatter_with_result( + self, path: FilePath, updates: Dict[str, Any] + ) -> FrontmatterUpdateResult: + """Update frontmatter and return the exact final written markdown content. Only modifies the frontmatter section, leaving all content untouched. Creates frontmatter section if none exists. - Returns checksum of updated file. + Returns both checksum and final content so callers do not need a reread. Uses aiofiles for true async I/O (non-blocking). @@ -415,7 +426,7 @@ async def update_frontmatter(self, path: FilePath, updates: Dict[str, Any]) -> s updates: Dict of frontmatter fields to update Returns: - Checksum of updated file + Typed result containing checksum and final content Raises: FileOperationError: If file operations fail @@ -467,7 +478,10 @@ async def update_frontmatter(self, path: FilePath, updates: Dict[str, Any]) -> s if formatted_content is not None: content_for_checksum = formatted_content # pragma: no cover - return await file_utils.compute_checksum(content_for_checksum) + return FrontmatterUpdateResult( + checksum=await file_utils.compute_checksum(content_for_checksum), + content=content_for_checksum, + ) except Exception as e: # pragma: no cover # Only log real errors (not YAML parsing, which is handled above) @@ -479,6 +493,11 @@ async def update_frontmatter(self, path: FilePath, updates: Dict[str, Any]) -> s ) raise FileOperationError(f"Failed to update frontmatter: {e}") + async def update_frontmatter(self, path: FilePath, updates: Dict[str, Any]) -> str: + """Update frontmatter fields in a file while preserving all content.""" + result = await self.update_frontmatter_with_result(path, 
updates) + return result.checksum + async def compute_checksum(self, path: FilePath) -> str: """Compute checksum for a file using true async I/O. diff --git a/src/basic_memory/sync/sync_service.py b/src/basic_memory/sync/sync_service.py index a781f474..29507ec1 100644 --- a/src/basic_memory/sync/sync_service.py +++ b/src/basic_memory/sync/sync_service.py @@ -21,7 +21,12 @@ from basic_memory.file_utils import has_frontmatter from basic_memory.indexing import BatchIndexer, IndexFileMetadata, IndexInputFile, IndexProgress from basic_memory.indexing.batching import build_index_batches -from basic_memory.indexing.models import IndexedEntity, IndexFileWriter, IndexFrontmatterUpdate +from basic_memory.indexing.models import ( + IndexedEntity, + IndexFileWriter, + IndexFrontmatterUpdate, + IndexFrontmatterWriteResult, +) from basic_memory.ignore_utils import load_bmignore_patterns, should_ignore_path from basic_memory.markdown import EntityParser, MarkdownProcessor from basic_memory.models import Entity, Project @@ -127,8 +132,13 @@ class _FileServiceIndexWriter(IndexFileWriter): def __init__(self, file_service: FileService) -> None: self.file_service = file_service - async def write_frontmatter(self, update: IndexFrontmatterUpdate) -> str: - return await self.file_service.update_frontmatter(update.path, update.metadata) + async def write_frontmatter( + self, update: IndexFrontmatterUpdate + ) -> IndexFrontmatterWriteResult: + result = await self.file_service.update_frontmatter_with_result( + update.path, update.metadata + ) + return IndexFrontmatterWriteResult(checksum=result.checksum, content=result.content) class SyncService: diff --git a/tests/indexing/test_batch_indexer.py b/tests/indexing/test_batch_indexer.py index f054d90e..fc683a4a 100644 --- a/tests/indexing/test_batch_indexer.py +++ b/tests/indexing/test_batch_indexer.py @@ -9,7 +9,12 @@ import pytest from sqlalchemy import text -from basic_memory.indexing import BatchIndexer, IndexFrontmatterUpdate, 
IndexInputFile +from basic_memory.indexing import ( + BatchIndexer, + IndexFrontmatterUpdate, + IndexFrontmatterWriteResult, + IndexInputFile, +) class _TestFileWriter: @@ -18,8 +23,13 @@ class _TestFileWriter: def __init__(self, file_service) -> None: self.file_service = file_service - async def write_frontmatter(self, update: IndexFrontmatterUpdate) -> str: - return await self.file_service.update_frontmatter(update.path, update.metadata) + async def write_frontmatter( + self, update: IndexFrontmatterUpdate + ) -> IndexFrontmatterWriteResult: + result = await self.file_service.update_frontmatter_with_result( + update.path, update.metadata + ) + return IndexFrontmatterWriteResult(checksum=result.checksum, content=result.content) async def _create_file(path: Path, content: str | bytes) -> None: @@ -214,6 +224,51 @@ async def spy_upsert(*args, **kwargs): assert result.errors == [] +@pytest.mark.asyncio +async def test_batch_indexer_returns_original_markdown_content_when_no_frontmatter_rewrite( + app_config, + entity_service, + entity_repository, + relation_repository, + search_service, + file_service, + project_config, +): + app_config.disable_permalinks = True + + path = "notes/original.md" + original_content = dedent( + """ + --- + title: Original + type: note + --- + # Original + """ + ).strip() + await _create_file(project_config.home / path, original_content) + + files = {path: await _load_input(file_service, path)} + batch_indexer = _make_batch_indexer( + app_config, + entity_service, + entity_repository, + relation_repository, + search_service, + file_service, + ) + + result = await batch_indexer.index_files( + files, + max_concurrent=1, + parse_max_concurrent=1, + ) + + assert result.errors == [] + assert len(result.indexed) == 1 + assert result.indexed[0].markdown_content == original_content + + @pytest.mark.asyncio async def test_batch_indexer_indexes_non_markdown_files( app_config, @@ -249,6 +304,7 @@ async def 
test_batch_indexer_indexes_non_markdown_files( ) assert {indexed.path for indexed in result.indexed} == {pdf_path, image_path} + assert all(indexed.markdown_content is None for indexed in result.indexed) pdf_entity = await entity_repository.get_by_file_path(pdf_path) image_entity = await entity_repository.get_by_file_path(image_path) @@ -377,6 +433,11 @@ async def test_batch_indexer_assigns_unique_permalinks_for_batch_local_conflicts path_one: await _load_input(file_service, path_one), path_two: await _load_input(file_service, path_two), } + original_contents = { + path: file.content.decode("utf-8") + for path, file in files.items() + if file.content is not None + } batch_indexer = _make_batch_indexer( app_config, entity_service, @@ -393,6 +454,17 @@ async def test_batch_indexer_assigns_unique_permalinks_for_batch_local_conflicts ) assert result.errors == [] + indexed_by_path = {indexed.path: indexed for indexed in result.indexed} + assert indexed_by_path[path_one].markdown_content is not None + assert indexed_by_path[path_two].markdown_content is not None + assert indexed_by_path[path_one].markdown_content != original_contents[path_one] + assert indexed_by_path[path_two].markdown_content != original_contents[path_two] + assert indexed_by_path[path_one].markdown_content == await file_service.read_file_content( + path_one + ) + assert indexed_by_path[path_two].markdown_content == await file_service.read_file_content( + path_two + ) entities = await entity_repository.find_all() assert len(entities) == 2 From 49a40be4544ac7c99992cf0651680a31ee66506a Mon Sep 17 00:00:00 2001 From: phernandez Date: Wed, 8 Apr 2026 00:21:13 -0500 Subject: [PATCH 3/9] fix(core): address batch indexing review issues Signed-off-by: phernandez --- src/basic_memory/indexing/batch_indexer.py | 83 ++++---- src/basic_memory/indexing/batching.py | 2 +- .../repository/postgres_search_repository.py | 57 ----- .../repository/search_repository_base.py | 156 -------------- 
src/basic_memory/services/file_service.py | 4 + src/basic_memory/sync/sync_service.py | 24 ++- test-int/test_sync_batching_integration.py | 201 ++++++++++++++++++ tests/indexing/test_batch_indexer.py | 79 +++++++ tests/sync/test_sync_service_batching.py | 176 ++++++++++++++- 9 files changed, 517 insertions(+), 265 deletions(-) create mode 100644 test-int/test_sync_batching_integration.py diff --git a/src/basic_memory/indexing/batch_indexer.py b/src/basic_memory/indexing/batch_indexer.py index 0d1f0b17..63792512 100644 --- a/src/basic_memory/indexing/batch_indexer.py +++ b/src/basic_memory/indexing/batch_indexer.py @@ -3,7 +3,6 @@ from __future__ import annotations import asyncio -import time from dataclasses import dataclass from datetime import datetime from pathlib import Path @@ -13,7 +12,7 @@ from sqlalchemy.exc import IntegrityError from basic_memory.config import BasicMemoryConfig -from basic_memory.file_utils import compute_checksum, has_frontmatter, remove_frontmatter +from basic_memory.file_utils import compute_checksum, has_frontmatter from basic_memory.markdown.schemas import EntityMarkdown from basic_memory.indexing.models import ( IndexedEntity, @@ -21,10 +20,10 @@ IndexFrontmatterUpdate, IndexingBatchResult, IndexInputFile, - IndexProgress, ) from basic_memory.models import Entity, Relation from basic_memory.services import EntityService +from basic_memory.services.exceptions import SyncFatalError from basic_memory.services.search_service import SearchService from basic_memory.repository import EntityRepository, RelationRepository @@ -76,7 +75,7 @@ async def index_files( *, max_concurrent: int, parse_max_concurrent: int | None = None, - progress_callback: Callable[[IndexProgress], Awaitable[None]] | None = None, + existing_permalink_by_path: dict[str, str | None] | None = None, ) -> IndexingBatchResult: """Index one batch of loaded files with bounded concurrency.""" if max_concurrent <= 0: @@ -84,20 +83,9 @@ async def index_files( ordered_paths = 
sorted(files) if not ordered_paths: - result = IndexingBatchResult() - if progress_callback is not None: - await progress_callback( - IndexProgress( - files_total=0, - files_processed=0, - batches_total=0, - batches_completed=0, - ) - ) - return result + return IndexingBatchResult() parse_limit = parse_max_concurrent or max_concurrent - batch_start = time.monotonic() error_by_path: dict[str, str] = {} markdown_paths = [path for path in ordered_paths if self._is_markdown(files[path])] @@ -111,7 +99,8 @@ async def index_files( error_by_path.update(parse_errors) prepared_markdown, normalization_errors = await self._normalize_markdown_batch( - prepared_markdown + prepared_markdown, + existing_permalink_by_path=existing_permalink_by_path, ) error_by_path.update(normalization_errors) @@ -171,21 +160,6 @@ async def index_files( search_indexed = len(indexed_entities) - if progress_callback is not None: - elapsed_seconds = max(time.monotonic() - batch_start, 0.001) - files_per_minute = len(ordered_paths) / elapsed_seconds * 60 - await progress_callback( - IndexProgress( - files_total=len(ordered_paths), - files_processed=len(ordered_paths), - batches_total=1, - batches_completed=1, - current_batch_bytes=sum(max(files[path].size, 0) for path in ordered_paths), - files_per_minute=files_per_minute, - eta_seconds=0.0, - ) - ) - return IndexingBatchResult( indexed=indexed_entities, errors=[(path, error_by_path[path]) for path in ordered_paths if path in error_by_path], @@ -221,12 +195,21 @@ async def _prepare_markdown_file(self, file: IndexInputFile) -> _PreparedMarkdow async def _normalize_markdown_batch( self, prepared_markdown: dict[str, _PreparedMarkdownFile], + *, + existing_permalink_by_path: dict[str, str | None] | None = None, ) -> tuple[dict[str, _PreparedMarkdownFile], dict[str, str]]: if not prepared_markdown: return {}, {} + if existing_permalink_by_path is None: + existing_permalink_by_path = { + path: permalink + for path, permalink in ( + await 
self.entity_repository.get_file_path_to_permalink_map() + ).items() + } + batch_paths = set(prepared_markdown) - existing_permalink_by_path = await self.entity_repository.get_file_path_to_permalink_map() reserved_permalinks = { permalink for path, permalink in existing_permalink_by_path.items() @@ -242,6 +225,7 @@ async def _normalize_markdown_batch( prepared_markdown[path], reserved_permalinks, ) + existing_permalink_by_path[path] = normalized[path].markdown.frontmatter.permalink except Exception as exc: errors[path] = str(exc) logger.warning("Batch markdown normalization failed", path=path, error=str(exc)) @@ -357,13 +341,18 @@ async def _upsert_markdown_file(self, prepared: _PreparedMarkdownFile) -> _Prepa entity_id=updated.id, checksum=prepared.final_checksum, content_type=prepared.file.content_type, - search_content=remove_frontmatter(prepared.content), + search_content=( + prepared.markdown.content + if prepared.markdown.content is not None + else prepared.content + ), markdown_content=prepared.content, ) async def _upsert_regular_file(self, file: IndexInputFile) -> _PreparedEntity: checksum = await self._resolve_checksum(file) existing = await self.entity_repository.get_by_file_path(file.path, load_relations=False) + is_new_entity = existing is None if existing is None: await self.entity_service.resolve_permalink(file.path, skip_conflict_check=True) @@ -408,7 +397,7 @@ async def _upsert_regular_file(self, file: IndexInputFile) -> _PreparedEntity: updated = await self.entity_repository.update( entity_id, - self._entity_metadata_updates(file, checksum, include_created_at=existing is None), + self._entity_metadata_updates(file, checksum, include_created_at=is_new_entity), ) if updated is None: raise ValueError(f"Failed to update file entity metadata for {file.path}") @@ -430,11 +419,15 @@ async def _resolve_batch_relations( *, max_concurrent: int, ) -> tuple[int, int]: - unresolved_relations: list[Relation] = [] - for entity_id in entity_ids: - 
unresolved_relations.extend(
-                await self.relation_repository.find_unresolved_relations_for_entity(entity_id)
+        unresolved_relation_lists = await asyncio.gather(
+            *(
+                self.relation_repository.find_unresolved_relations_for_entity(entity_id)
+                for entity_id in entity_ids
             )
+        )
+        unresolved_relations = [
+            relation for relation_list in unresolved_relation_lists for relation in relation_list
+        ]
 
         if not unresolved_relations:
             return 0, 0
@@ -475,11 +468,13 @@ async def resolve_relation(relation: Relation) -> int:
             *(resolve_relation(relation) for relation in unresolved_relations)
         )
 
-        remaining_unresolved = 0
-        for entity_id in entity_ids:
-            remaining_unresolved += len(
-                await self.relation_repository.find_unresolved_relations_for_entity(entity_id)
+        remaining_relation_lists = await asyncio.gather(
+            *(
+                self.relation_repository.find_unresolved_relations_for_entity(entity_id)
+                for entity_id in entity_ids
             )
+        )
+        remaining_unresolved = sum(len(relations) for relations in remaining_relation_lists)
 
         return sum(resolved_counts), remaining_unresolved
 
@@ -552,6 +547,8 @@ async def run(path: str) -> None:
             try:
                 results[path] = await worker(path)
             except Exception as exc:
+                if isinstance(exc, SyncFatalError) or isinstance(exc.__cause__, SyncFatalError):
+                    raise
                 errors[path] = str(exc)
                 logger.warning("Batch indexing failed", path=path, error=str(exc))
 
diff --git a/src/basic_memory/indexing/batching.py b/src/basic_memory/indexing/batching.py
index 397f7af0..4ff9c2b8 100644
--- a/src/basic_memory/indexing/batching.py
+++ b/src/basic_memory/indexing/batching.py
@@ -52,7 +52,7 @@ def build_index_batches(
         current_paths.append(path)
         current_bytes += file_bytes
 
-        if len(current_paths) >= max_files or current_bytes >= max_bytes:
+        if len(current_paths) >= max_files or current_bytes >= max_bytes:
             batches.append(IndexBatch(paths=current_paths, total_bytes=current_bytes))
             current_paths = []
             current_bytes = 0
diff --git a/src/basic_memory/repository/postgres_search_repository.py
b/src/basic_memory/repository/postgres_search_repository.py index f9b42569..fd321bf7 100644 --- a/src/basic_memory/repository/postgres_search_repository.py +++ b/src/basic_memory/repository/postgres_search_repository.py @@ -512,12 +512,6 @@ async def _prepare_entity_vector_jobs(self, entity_id: int) -> _PreparedEntityVe """Prepare chunk mutations with Postgres-specific bulk upserts.""" sync_start = time.perf_counter() - logger.info( - "Vector sync start: project_id={project_id} entity_id={entity_id}", - project_id=self.project_id, - entity_id=entity_id, - ) - async with db.scoped_session(self.session_maker) as session: await self._prepare_vector_session(session) @@ -546,13 +540,6 @@ async def _prepare_entity_vector_jobs(self, entity_id: int) -> _PreparedEntityVe source_rows_count = len(rows) if not rows: - logger.info( - "Vector sync source prepared: project_id={project_id} entity_id={entity_id} " - "source_rows_count={source_rows_count} built_chunk_records_count=0", - project_id=self.project_id, - entity_id=entity_id, - source_rows_count=source_rows_count, - ) await self._delete_entity_chunks(session, entity_id) await session.commit() prepare_seconds = time.perf_counter() - sync_start @@ -568,15 +555,6 @@ async def _prepare_entity_vector_jobs(self, entity_id: int) -> _PreparedEntityVe built_chunk_records_count = len(chunk_records) current_entity_fingerprint = self._build_entity_fingerprint(chunk_records) current_embedding_model = self._embedding_model_key() - logger.info( - "Vector sync source prepared: project_id={project_id} entity_id={entity_id} " - "source_rows_count={source_rows_count} " - "built_chunk_records_count={built_chunk_records_count}", - project_id=self.project_id, - entity_id=entity_id, - source_rows_count=source_rows_count, - built_chunk_records_count=built_chunk_records_count, - ) if not chunk_records: await self._delete_entity_chunks(session, entity_id) await session.commit() @@ -629,16 +607,6 @@ async def _prepare_entity_vector_jobs(self, 
entity_id: int) -> _PreparedEntityVe ) ) if skip_unchanged_entity: - logger.info( - "Vector sync skipped unchanged entity: project_id={project_id} " - "entity_id={entity_id} chunks_skipped={chunks_skipped} " - "entity_fingerprint={entity_fingerprint} embedding_model={embedding_model}", - project_id=self.project_id, - entity_id=entity_id, - chunks_skipped=built_chunk_records_count, - entity_fingerprint=current_entity_fingerprint, - embedding_model=current_embedding_model, - ) prepare_seconds = time.perf_counter() - sync_start return _PreparedEntityVectorSync( entity_id=entity_id, @@ -752,31 +720,6 @@ async def _prepare_entity_vector_jobs(self, entity_id: int) -> _PreparedEntityVe row_id = upserted_ids_by_key[record["chunk_key"]] embedding_jobs.append((row_id, record["chunk_text"])) - logger.info( - "Vector sync diff complete: project_id={project_id} entity_id={entity_id} " - "existing_chunks_count={existing_chunks_count} " - "stale_chunks_count={stale_chunks_count} " - "orphan_chunks_count={orphan_chunks_count} " - "chunks_skipped={chunks_skipped} " - "embedding_jobs_count={embedding_jobs_count} " - "pending_jobs_total={pending_jobs_total} shard_index={shard_index} " - "shard_count={shard_count} remaining_jobs_after_shard={remaining_jobs_after_shard} " - "oversized_entity={oversized_entity} entity_complete={entity_complete}", - project_id=self.project_id, - entity_id=entity_id, - existing_chunks_count=existing_chunks_count, - stale_chunks_count=stale_chunks_count, - orphan_chunks_count=orphan_chunks_count, - chunks_skipped=skipped_chunks_count, - embedding_jobs_count=len(embedding_jobs), - pending_jobs_total=shard_plan.pending_jobs_total, - shard_index=shard_plan.shard_index, - shard_count=shard_plan.shard_count, - remaining_jobs_after_shard=shard_plan.remaining_jobs_after_shard, - oversized_entity=shard_plan.oversized_entity, - entity_complete=shard_plan.entity_complete, - ) - prepare_seconds = time.perf_counter() - sync_start return _PreparedEntityVectorSync( 
entity_id=entity_id, diff --git a/src/basic_memory/repository/search_repository_base.py b/src/basic_memory/repository/search_repository_base.py index 8abcaf17..f252df05 100644 --- a/src/basic_memory/repository/search_repository_base.py +++ b/src/basic_memory/repository/search_repository_base.py @@ -590,7 +590,6 @@ def _log_vector_shard_plan( if shard_plan.pending_jobs_total == 0: return - scheduled_jobs_count = shard_plan.pending_jobs_total - shard_plan.remaining_jobs_after_shard if shard_plan.oversized_entity: logger.warning( "Vector sync oversized entity detected: project_id={project_id} " @@ -603,23 +602,6 @@ def _log_vector_shard_plan( shard_count=shard_plan.shard_count, ) - logger.info( - "Vector sync shard planned: project_id={project_id} entity_id={entity_id} " - "pending_jobs_total={pending_jobs_total} scheduled_jobs_count={scheduled_jobs_count} " - "shard_index={shard_index} shard_count={shard_count} " - "remaining_jobs_after_shard={remaining_jobs_after_shard} " - "oversized_entity={oversized_entity} entity_complete={entity_complete}", - project_id=self.project_id, - entity_id=entity_id, - pending_jobs_total=shard_plan.pending_jobs_total, - scheduled_jobs_count=scheduled_jobs_count, - shard_index=shard_plan.shard_index, - shard_count=shard_plan.shard_count, - remaining_jobs_after_shard=shard_plan.remaining_jobs_after_shard, - oversized_entity=shard_plan.oversized_entity, - entity_complete=shard_plan.entity_complete, - ) - # --- Text splitting --- def _split_text_into_chunks(self, text_value: str) -> list[str]: @@ -1044,12 +1026,6 @@ async def _prepare_entity_vector_jobs(self, entity_id: int) -> _PreparedEntityVe """Prepare chunk mutations and embedding jobs for one entity.""" sync_start = time.perf_counter() - logger.info( - "Vector sync start: project_id={project_id} entity_id={entity_id}", - project_id=self.project_id, - entity_id=entity_id, - ) - async with db.scoped_session(self.session_maker) as session: await self._prepare_vector_session(session) @@ 
-1080,15 +1056,6 @@ async def _prepare_entity_vector_jobs(self, entity_id: int) -> _PreparedEntityVe # No search_index rows → delete all chunk/embedding data for this entity. if not rows: - logger.info( - "Vector sync source prepared: project_id={project_id} entity_id={entity_id} " - "source_rows_count={source_rows_count} " - "built_chunk_records_count={built_chunk_records_count}", - project_id=self.project_id, - entity_id=entity_id, - source_rows_count=source_rows_count, - built_chunk_records_count=built_chunk_records_count, - ) await self._delete_entity_chunks(session, entity_id) await session.commit() prepare_seconds = time.perf_counter() - sync_start @@ -1104,15 +1071,6 @@ async def _prepare_entity_vector_jobs(self, entity_id: int) -> _PreparedEntityVe built_chunk_records_count = len(chunk_records) current_entity_fingerprint = self._build_entity_fingerprint(chunk_records) current_embedding_model = self._embedding_model_key() - logger.info( - "Vector sync source prepared: project_id={project_id} entity_id={entity_id} " - "source_rows_count={source_rows_count} " - "built_chunk_records_count={built_chunk_records_count}", - project_id=self.project_id, - entity_id=entity_id, - source_rows_count=source_rows_count, - built_chunk_records_count=built_chunk_records_count, - ) if not chunk_records: await self._delete_entity_chunks(session, entity_id) await session.commit() @@ -1178,16 +1136,6 @@ async def _prepare_entity_vector_jobs(self, entity_id: int) -> _PreparedEntityVe ) ) if skip_unchanged_entity: - logger.info( - "Vector sync skipped unchanged entity: project_id={project_id} " - "entity_id={entity_id} chunks_skipped={chunks_skipped} " - "entity_fingerprint={entity_fingerprint} embedding_model={embedding_model}", - project_id=self.project_id, - entity_id=entity_id, - chunks_skipped=built_chunk_records_count, - entity_fingerprint=current_entity_fingerprint, - embedding_model=current_embedding_model, - ) prepare_seconds = time.perf_counter() - sync_start return 
_PreparedEntityVectorSync( entity_id=entity_id, @@ -1305,31 +1253,6 @@ async def _prepare_entity_vector_jobs(self, entity_id: int) -> _PreparedEntityVe ) row_id = int(inserted.scalar_one()) embedding_jobs.append((row_id, record["chunk_text"])) - - logger.info( - "Vector sync diff complete: project_id={project_id} entity_id={entity_id} " - "existing_chunks_count={existing_chunks_count} " - "stale_chunks_count={stale_chunks_count} " - "orphan_chunks_count={orphan_chunks_count} " - "chunks_skipped={chunks_skipped} " - "embedding_jobs_count={embedding_jobs_count} " - "pending_jobs_total={pending_jobs_total} shard_index={shard_index} " - "shard_count={shard_count} remaining_jobs_after_shard={remaining_jobs_after_shard} " - "oversized_entity={oversized_entity} entity_complete={entity_complete}", - project_id=self.project_id, - entity_id=entity_id, - existing_chunks_count=existing_chunks_count, - stale_chunks_count=stale_chunks_count, - orphan_chunks_count=orphan_chunks_count, - chunks_skipped=skipped_chunks_count, - embedding_jobs_count=len(embedding_jobs), - pending_jobs_total=shard_plan.pending_jobs_total, - shard_index=shard_plan.shard_index, - shard_count=shard_plan.shard_count, - remaining_jobs_after_shard=shard_plan.remaining_jobs_after_shard, - oversized_entity=shard_plan.oversized_entity, - entity_complete=shard_plan.entity_complete, - ) await session.commit() prepare_seconds = time.perf_counter() - sync_start @@ -1364,15 +1287,6 @@ async def _flush_embedding_jobs( texts = [job.chunk_text for job in flush_jobs] embeddings = await self._embedding_provider.embed_documents(texts) embed_seconds = time.perf_counter() - embed_start - embed_rate = (len(flush_jobs) / embed_seconds) if embed_seconds > 0 else 0.0 - logger.info( - "Vector batch embed flush: project_id={project_id} chunk_count={chunk_count} " - "embed_seconds={embed_seconds:.3f} embed_rate_chunks_per_second={embed_rate:.2f}", - project_id=self.project_id, - chunk_count=len(flush_jobs), - 
embed_seconds=embed_seconds, - embed_rate=embed_rate, - ) if len(embeddings) != len(flush_jobs): raise RuntimeError("Embedding provider returned an unexpected number of vectors.") @@ -1383,15 +1297,6 @@ async def _flush_embedding_jobs( await self._write_embeddings(session, write_jobs, embeddings) await session.commit() write_seconds = time.perf_counter() - write_start - write_rate = (len(flush_jobs) / write_seconds) if write_seconds > 0 else 0.0 - logger.info( - "Vector batch write flush: project_id={project_id} row_count={row_count} " - "write_seconds={write_seconds:.3f} write_rate_rows_per_second={write_rate:.2f}", - project_id=self.project_id, - row_count=len(flush_jobs), - write_seconds=write_seconds, - write_rate=write_rate, - ) flush_size = len(flush_jobs) entity_job_counts: dict[int, int] = {} @@ -1485,35 +1390,6 @@ def _log_vector_sync_complete( remaining_jobs_after_shard: int, ) -> None: """Log completion and slow-entity warnings with a consistent format.""" - logger.info( - "Vector sync complete: project_id={project_id} entity_id={entity_id} " - "total_seconds={total_seconds:.3f} prepare_seconds={prepare_seconds:.3f} " - "queue_wait_seconds={queue_wait_seconds:.3f} embed_seconds={embed_seconds:.3f} " - "write_seconds={write_seconds:.3f} source_rows_count={source_rows_count} " - "chunks_total={chunks_total} chunks_skipped={chunks_skipped} " - "embedding_jobs_count={embedding_jobs_count} entity_skipped={entity_skipped} " - "entity_complete={entity_complete} oversized_entity={oversized_entity} " - "pending_jobs_total={pending_jobs_total} shard_index={shard_index} " - "shard_count={shard_count} remaining_jobs_after_shard={remaining_jobs_after_shard}", - project_id=self.project_id, - entity_id=entity_id, - total_seconds=total_seconds, - prepare_seconds=prepare_seconds, - queue_wait_seconds=queue_wait_seconds, - embed_seconds=embed_seconds, - write_seconds=write_seconds, - source_rows_count=source_rows_count, - chunks_total=chunks_total, - 
chunks_skipped=chunks_skipped, - embedding_jobs_count=embedding_jobs_count, - entity_skipped=entity_skipped, - entity_complete=entity_complete, - oversized_entity=oversized_entity, - pending_jobs_total=pending_jobs_total, - shard_index=shard_index, - shard_count=shard_count, - remaining_jobs_after_shard=remaining_jobs_after_shard, - ) if total_seconds > 10: logger.warning( "Vector sync slow entity: project_id={project_id} entity_id={entity_id} " @@ -1719,22 +1595,6 @@ def _log_vector_summary() -> None: return total_ms = (time.perf_counter() - query_start) * 1000 - logger.info( - "Semantic query timing: project_id={project_id} retrieval_mode={retrieval_mode} " - "query_length={query_length} candidate_limit={candidate_limit} " - "vector_row_count={vector_row_count} embed_ms={embed_ms:.2f} " - "vector_query_ms={vector_query_ms:.2f} hydrate_ms={hydrate_ms:.2f} " - "total_ms={total_ms:.2f}", - project_id=self.project_id, - retrieval_mode="vector", - query_length=len(query_text), - candidate_limit=candidate_limit, - vector_row_count=vector_row_count, - embed_ms=embed_ms, - vector_query_ms=vector_query_ms, - hydrate_ms=hydrate_ms, - total_ms=total_ms, - ) if total_ms > 2000: logger.warning( "[SEMANTIC_SLOW_QUERY] Semantic query timing: project_id={project_id} " @@ -2073,22 +1933,6 @@ async def _search_hybrid( output.append(replace(row, score=fused_score)) fusion_ms = (time.perf_counter() - fusion_start) * 1000 total_ms = (time.perf_counter() - query_start) * 1000 - logger.info( - "Semantic query timing: project_id={project_id} retrieval_mode={retrieval_mode} " - "query_length={query_length} candidate_limit={candidate_limit} " - "fts_count={fts_count} vector_count={vector_count} fts_ms={fts_ms:.2f} " - "vector_ms={vector_ms:.2f} fusion_ms={fusion_ms:.2f} total_ms={total_ms:.2f}", - project_id=self.project_id, - retrieval_mode="hybrid", - query_length=len(query_text), - candidate_limit=candidate_limit, - fts_count=len(fts_results), - vector_count=len(vector_results), - 
fts_ms=fts_ms, - vector_ms=vector_ms, - fusion_ms=fusion_ms, - total_ms=total_ms, - ) if total_ms > 2500: logger.warning( "[SEMANTIC_SLOW_QUERY] Semantic query timing: project_id={project_id} " diff --git a/src/basic_memory/services/file_service.py b/src/basic_memory/services/file_service.py index 60f855da..7ed823b9 100644 --- a/src/basic_memory/services/file_service.py +++ b/src/basic_memory/services/file_service.py @@ -308,6 +308,10 @@ async def read_file_bytes(self, path: FilePath) -> bytes: ) return content + except FileNotFoundError: + # Preserve FileNotFoundError so callers (e.g. sync) can treat it as deletion. + logger.warning("File not found", operation="read_file_bytes", path=str(full_path)) + raise except Exception as e: logger.exception("File read error", path=str(full_path), error=str(e)) raise FileOperationError(f"Failed to read file: {e}") diff --git a/src/basic_memory/sync/sync_service.py b/src/basic_memory/sync/sync_service.py index 29507ec1..de507f12 100644 --- a/src/basic_memory/sync/sync_service.py +++ b/src/basic_memory/sync/sync_service.py @@ -18,7 +18,7 @@ from basic_memory import telemetry from basic_memory import db from basic_memory.config import BasicMemoryConfig, ConfigManager -from basic_memory.file_utils import has_frontmatter +from basic_memory.file_utils import compute_checksum, has_frontmatter from basic_memory.indexing import BatchIndexer, IndexFileMetadata, IndexInputFile, IndexProgress from basic_memory.indexing.batching import build_index_batches from basic_memory.indexing.models import ( @@ -135,6 +135,8 @@ def __init__(self, file_service: FileService) -> None: async def write_frontmatter( self, update: IndexFrontmatterUpdate ) -> IndexFrontmatterWriteResult: + # Why: IndexFrontmatterWriteResult lives in indexing/models.py so the indexing + # layer does not need to import FileService. This adapter keeps that boundary intact. 
result = await self.file_service.update_frontmatter_with_result( update.path, update.metadata ) @@ -504,6 +506,22 @@ async def _index_changed_files( ) indexed_entities: list[IndexedEntity] = [] + shared_permalink_by_path: dict[str, str | None] | None = None + if any( + metadata.content_type == "text/markdown" + or ( + metadata.content_type is None + and Path(metadata.path).suffix.lower() in {".md", ".markdown"} + ) + for metadata in metadata_by_path.values() + ): + shared_permalink_by_path = { + path: permalink + for path, permalink in ( + await self.entity_repository.get_file_path_to_permalink_map() + ).items() + } + for batch in batches: loaded_files, load_errors = await self._load_index_batch_files( batch.paths, metadata_by_path @@ -516,6 +534,7 @@ async def _index_changed_files( loaded_files, max_concurrent=self.app_config.index_entity_max_concurrent, parse_max_concurrent=self.app_config.index_parse_max_concurrent, + existing_permalink_by_path=shared_permalink_by_path, ) indexed_entities.extend(batch_result.indexed) @@ -597,10 +616,11 @@ async def load(path: str) -> None: metadata = metadata_by_path[path] try: content = await self.file_service.read_file_bytes(path) + loaded_checksum = await compute_checksum(content) files[path] = IndexInputFile( path=metadata.path, size=metadata.size, - checksum=metadata.checksum, + checksum=loaded_checksum, content_type=metadata.content_type, last_modified=metadata.last_modified, created_at=metadata.created_at, diff --git a/test-int/test_sync_batching_integration.py b/test-int/test_sync_batching_integration.py new file mode 100644 index 00000000..73a40ab5 --- /dev/null +++ b/test-int/test_sync_batching_integration.py @@ -0,0 +1,201 @@ +"""Integration coverage for batched sync indexing.""" + +from __future__ import annotations + +from pathlib import Path +from textwrap import dedent + +import pytest + +from basic_memory.markdown import EntityParser, MarkdownProcessor +from basic_memory.repository import ( + EntityRepository, + 
ObservationRepository, + ProjectRepository, + RelationRepository, +) +from basic_memory.repository.search_repository import create_search_repository +from basic_memory.services import FileService +from basic_memory.services.entity_service import EntityService +from basic_memory.services.link_resolver import LinkResolver +from basic_memory.services.search_service import SearchService +from basic_memory.sync.sync_service import MAX_CONSECUTIVE_FAILURES, SyncService + + +async def _create_file(path: Path, content: str) -> None: + path.parent.mkdir(parents=True, exist_ok=True) + path.write_text(content) + + +async def _create_binary_file(path: Path, content: bytes) -> None: + path.parent.mkdir(parents=True, exist_ok=True) + path.write_bytes(content) + + +async def _build_sync_service( + project_root: Path, + engine_factory, + app_config, + test_project, +) -> SyncService: + _, session_maker = engine_factory + + entity_repository = EntityRepository(session_maker, project_id=test_project.id) + observation_repository = ObservationRepository(session_maker, project_id=test_project.id) + relation_repository = RelationRepository(session_maker, project_id=test_project.id) + project_repository = ProjectRepository(session_maker) + search_repository = create_search_repository(session_maker, project_id=test_project.id) + + entity_parser = EntityParser(project_root) + markdown_processor = MarkdownProcessor(entity_parser) + file_service = FileService(project_root, markdown_processor) + search_service = SearchService(search_repository, entity_repository, file_service) + await search_service.init_search_index() + link_resolver = LinkResolver(entity_repository, search_service) + + entity_service = EntityService( + entity_parser=entity_parser, + entity_repository=entity_repository, + observation_repository=observation_repository, + relation_repository=relation_repository, + file_service=file_service, + link_resolver=link_resolver, + app_config=app_config, + ) + + return SyncService( + 
app_config=app_config, + entity_service=entity_service, + entity_parser=entity_parser, + entity_repository=entity_repository, + relation_repository=relation_repository, + project_repository=project_repository, + search_service=search_service, + file_service=file_service, + ) + + +@pytest.mark.asyncio +async def test_sync_batching_handles_large_single_file_batches_and_resolves_forward_refs( + engine_factory, + app_config, + test_project, +): + app_config.index_batch_size = 2 + app_config.index_batch_max_bytes = 256 + + project_root = Path(test_project.path) + sync_service = await _build_sync_service(project_root, engine_factory, app_config, test_project) + + await _create_file( + project_root / "notes/alpha.md", + dedent( + """ + --- + title: Alpha + type: note + --- + # Alpha + + - depends_on [[Target]] + """ + ).strip(), + ) + await _create_file( + project_root / "notes/large.md", + dedent( + f""" + --- + title: Large + type: note + --- + # Large + + {"x" * 2048} + """ + ).strip(), + ) + await _create_file( + project_root / "notes/target.md", + dedent( + """ + --- + title: Target + type: note + --- + # Target + """ + ).strip(), + ) + + report = await sync_service.sync( + project_root, + project_name=test_project.name, + force_full=True, + ) + + alpha = await sync_service.entity_repository.get_by_file_path("notes/alpha.md") + large = await sync_service.entity_repository.get_by_file_path("notes/large.md") + target = await sync_service.entity_repository.get_by_file_path("notes/target.md") + + assert report.total == 3 + assert alpha is not None + assert large is not None + assert target is not None + assert large.size is not None + assert large.size > app_config.index_batch_max_bytes + assert len(alpha.outgoing_relations) == 1 + assert alpha.outgoing_relations[0].to_id == target.id + + +@pytest.mark.asyncio +async def test_sync_batching_circuit_breaker_skips_unchanged_broken_markdown_after_threshold( + engine_factory, + app_config, + test_project, +): + 
app_config.index_batch_size = 1 + app_config.index_batch_max_bytes = 256 + + project_root = Path(test_project.path) + sync_service = await _build_sync_service(project_root, engine_factory, app_config, test_project) + + await _create_binary_file(project_root / "notes/broken.md", b"\xff\xfe\xfd") + + last_report = None + for _ in range(MAX_CONSECUTIVE_FAILURES): + last_report = await sync_service.sync( + project_root, + project_name=test_project.name, + force_full=True, + ) + + assert last_report is not None + assert [skipped.path for skipped in last_report.skipped_files] == ["notes/broken.md"] + assert sync_service._file_failures["notes/broken.md"].count == MAX_CONSECUTIVE_FAILURES + + await _create_file( + project_root / "notes/good.md", + dedent( + """ + --- + title: Good + type: note + --- + # Good + """ + ).strip(), + ) + + report = await sync_service.sync( + project_root, + project_name=test_project.name, + force_full=True, + ) + + good = await sync_service.entity_repository.get_by_file_path("notes/good.md") + broken = await sync_service.entity_repository.get_by_file_path("notes/broken.md") + + assert [skipped.path for skipped in report.skipped_files] == ["notes/broken.md"] + assert good is not None + assert broken is None diff --git a/tests/indexing/test_batch_indexer.py b/tests/indexing/test_batch_indexer.py index fc683a4a..cacf55aa 100644 --- a/tests/indexing/test_batch_indexer.py +++ b/tests/indexing/test_batch_indexer.py @@ -15,6 +15,7 @@ IndexFrontmatterWriteResult, IndexInputFile, ) +from basic_memory.services.exceptions import SyncFatalError class _TestFileWriter: @@ -470,3 +471,81 @@ async def test_batch_indexer_assigns_unique_permalinks_for_batch_local_conflicts assert len(entities) == 2 permalinks = [entity.permalink for entity in entities if entity.permalink] assert len(set(permalinks)) == 2 + + +@pytest.mark.asyncio +async def test_batch_indexer_uses_parsed_markdown_body_for_malformed_frontmatter_delimiters( + app_config, + entity_service, + 
entity_repository, + relation_repository, + search_service, + file_service, + project_config, +): + app_config.disable_permalinks = True + app_config.ensure_frontmatter_on_sync = False + + path = "notes/malformed.md" + malformed_content = dedent( + """ + --- + this is not valid frontmatter + # Malformed Frontmatter + + The parser should still index this file. + """ + ).strip() + await _create_file(project_config.home / path, malformed_content) + + files = {path: await _load_input(file_service, path)} + batch_indexer = _make_batch_indexer( + app_config, + entity_service, + entity_repository, + relation_repository, + search_service, + file_service, + ) + + result = await batch_indexer.index_files( + files, + max_concurrent=1, + parse_max_concurrent=1, + ) + + assert result.errors == [] + assert len(result.indexed) == 1 + assert result.indexed[0].markdown_content == malformed_content + + entity = await entity_repository.get_by_file_path(path) + assert entity is not None + + +@pytest.mark.asyncio +async def test_batch_indexer_re_raises_fatal_sync_errors( + app_config, + entity_service, + entity_repository, + relation_repository, + search_service, + file_service, +): + batch_indexer = _make_batch_indexer( + app_config, + entity_service, + entity_repository, + relation_repository, + search_service, + file_service, + ) + + async def fatal_worker(path: str) -> str: + raise SyncFatalError(f"fatal batch failure for {path}") + + with pytest.raises(SyncFatalError, match="fatal batch failure"): + await batch_indexer._run_bounded( + ["notes/fatal.md"], + limit=1, + worker=fatal_worker, + ) diff --git a/tests/sync/test_sync_service_batching.py b/tests/sync/test_sync_service_batching.py index 41186aca..498633c6 100644 --- a/tests/sync/test_sync_service_batching.py +++ b/tests/sync/test_sync_service_batching.py @@ -8,7 +8,9 @@ import pytest from sqlalchemy import text -from basic_memory.indexing import IndexProgress +from basic_memory.file_utils import compute_checksum +from 
basic_memory.indexing import IndexFileMetadata, IndexProgress +from basic_memory.sync.sync_service import MAX_CONSECUTIVE_FAILURES async def _create_file(path: Path, content: str) -> None: @@ -58,15 +60,26 @@ async def test_sync_batches_changed_files_emits_typed_progress_and_resolves_forw ) progress_updates: list[IndexProgress] = [] + original_get_permalink_map = entity_repository.get_file_path_to_permalink_map + permalink_map_calls = 0 async def on_progress(update: IndexProgress) -> None: progress_updates.append(update) - await sync_service.sync( - project_config.home, - project_name=project_config.name, - progress_callback=on_progress, - ) + async def spy_get_permalink_map() -> dict[str, str]: + nonlocal permalink_map_calls + permalink_map_calls += 1 + return await original_get_permalink_map() + + entity_repository.get_file_path_to_permalink_map = spy_get_permalink_map + try: + await sync_service.sync( + project_config.home, + project_name=project_config.name, + progress_callback=on_progress, + ) + finally: + entity_repository.get_file_path_to_permalink_map = original_get_permalink_map assert progress_updates assert all(isinstance(update, IndexProgress) for update in progress_updates) @@ -74,6 +87,7 @@ async def on_progress(update: IndexProgress) -> None: assert progress_updates[-1].files_processed == 2 assert progress_updates[-1].batches_total == 2 assert progress_updates[-1].batches_completed == 2 + assert permalink_map_calls == 1 source = await entity_repository.get_by_file_path("notes/source.md") target = await entity_repository.get_by_file_path("notes/target.md") @@ -90,3 +104,153 @@ async def on_progress(update: IndexProgress) -> None: {"entity_id": source.id}, ) assert relation_rows.scalar_one() == 1 + + +@pytest.mark.asyncio +async def test_index_changed_files_returns_empty_result_and_zero_progress(sync_service): + progress_updates: list[IndexProgress] = [] + + async def on_progress(update: IndexProgress) -> None: + progress_updates.append(update) + + 
indexed_entities, skipped_files = await sync_service._index_changed_files( + [], + {}, + progress_callback=on_progress, + ) + + assert indexed_entities == [] + assert skipped_files == [] + assert len(progress_updates) == 1 + assert progress_updates[0] == IndexProgress( + files_total=0, + files_processed=0, + batches_total=0, + batches_completed=0, + ) + + +@pytest.mark.asyncio +async def test_index_changed_files_skips_paths_blocked_by_circuit_breaker( + sync_service, + project_config, +): + skipped_path = "notes/skipped.md" + indexed_path = "notes/indexed.md" + await _create_file(project_config.home / skipped_path, "# Skipped\n") + await _create_file(project_config.home / indexed_path, "# Indexed\n") + + for attempt in range(MAX_CONSECUTIVE_FAILURES): + await sync_service._record_failure(skipped_path, f"failure {attempt}") + + indexed_entities, skipped_files = await sync_service._index_changed_files( + [skipped_path, indexed_path], + { + skipped_path: await sync_service.file_service.compute_checksum(skipped_path), + indexed_path: await sync_service.file_service.compute_checksum(indexed_path), + }, + ) + + assert [indexed.path for indexed in indexed_entities] == [indexed_path] + assert [skipped.path for skipped in skipped_files] == [skipped_path] + + +@pytest.mark.asyncio +async def test_load_index_file_metadata_tracks_missing_and_error_paths( + sync_service, + project_config, + monkeypatch, +): + error_path = "notes/error.md" + missing_path = "notes/missing.md" + await _create_file(project_config.home / error_path, "# Error\n") + + deleted_paths: list[str] = [] + original_get_file_metadata = sync_service.file_service.get_file_metadata + + async def spy_handle_delete(path: str) -> None: + deleted_paths.append(path) + + async def fake_get_file_metadata(path: str): + if path == error_path: + raise ValueError("metadata boom") + return await original_get_file_metadata(path) + + monkeypatch.setattr(sync_service, "handle_delete", spy_handle_delete) + 
monkeypatch.setattr(sync_service.file_service, "get_file_metadata", fake_get_file_metadata) + + metadata_by_path, errors, missing_paths = await sync_service._load_index_file_metadata( + [missing_path, error_path], + {}, + ) + + assert metadata_by_path == {} + assert errors == [(error_path, "metadata boom")] + assert missing_paths == [missing_path] + assert deleted_paths == [missing_path] + + +@pytest.mark.asyncio +async def test_load_index_batch_files_recomputes_checksum_from_loaded_bytes_and_tracks_errors( + sync_service, + project_config, + monkeypatch, +): + good_path = "notes/good.md" + error_path = "notes/error.md" + missing_path = "notes/missing.md" + await _create_file(project_config.home / good_path, "# Good\n") + await _create_file(project_config.home / error_path, "# Error\n") + + good_metadata = await sync_service.file_service.get_file_metadata(good_path) + error_metadata = await sync_service.file_service.get_file_metadata(error_path) + metadata_by_path = { + good_path: IndexFileMetadata( + path=good_path, + size=good_metadata.size, + checksum="stale-checksum", + content_type=sync_service.file_service.content_type(good_path), + last_modified=good_metadata.modified_at, + created_at=good_metadata.created_at, + ), + error_path: IndexFileMetadata( + path=error_path, + size=error_metadata.size, + checksum="ignored", + content_type=sync_service.file_service.content_type(error_path), + last_modified=error_metadata.modified_at, + created_at=error_metadata.created_at, + ), + missing_path: IndexFileMetadata( + path=missing_path, + size=0, + checksum="missing", + content_type="text/markdown", + ), + } + + deleted_paths: list[str] = [] + original_read_file_bytes = sync_service.file_service.read_file_bytes + + async def spy_handle_delete(path: str) -> None: + deleted_paths.append(path) + + async def fake_read_file_bytes(path: str) -> bytes: + if path == good_path: + return b"# Loaded\n" + if path == error_path: + raise ValueError("load boom") + return await 
original_read_file_bytes(path) + + monkeypatch.setattr(sync_service, "handle_delete", spy_handle_delete) + monkeypatch.setattr(sync_service.file_service, "read_file_bytes", fake_read_file_bytes) + + files, errors = await sync_service._load_index_batch_files( + [good_path, error_path, missing_path], + metadata_by_path, + ) + + assert files[good_path].checksum == await compute_checksum(b"# Loaded\n") + assert files[good_path].checksum != "stale-checksum" + assert errors == [(error_path, "load boom")] + assert deleted_paths == [missing_path] From d986c4dfc4693d3b67b9eedf38117f374ffae100 Mon Sep 17 00:00:00 2001 From: phernandez Date: Wed, 8 Apr 2026 00:22:45 -0500 Subject: [PATCH 4/9] fix(core): preserve file service read contract Signed-off-by: phernandez --- src/basic_memory/services/file_service.py | 6 +----- src/basic_memory/sync/sync_service.py | 10 +++++++++- 2 files changed, 10 insertions(+), 6 deletions(-) diff --git a/src/basic_memory/services/file_service.py b/src/basic_memory/services/file_service.py index 7ed823b9..c7081e15 100644 --- a/src/basic_memory/services/file_service.py +++ b/src/basic_memory/services/file_service.py @@ -308,13 +308,9 @@ async def read_file_bytes(self, path: FilePath) -> bytes: ) return content - except FileNotFoundError: - # Preserve FileNotFoundError so callers (e.g. sync) can treat it as deletion. - logger.warning("File not found", operation="read_file_bytes", path=str(full_path)) - raise except Exception as e: logger.exception("File read error", path=str(full_path), error=str(e)) - raise FileOperationError(f"Failed to read file: {e}") + raise FileOperationError(f"Failed to read file: {e}") from e async def read_file(self, path: FilePath) -> Tuple[str, str]: """Read file and compute checksum using true async I/O. 
diff --git a/src/basic_memory/sync/sync_service.py b/src/basic_memory/sync/sync_service.py index de507f12..4e75c040 100644 --- a/src/basic_memory/sync/sync_service.py +++ b/src/basic_memory/sync/sync_service.py @@ -39,7 +39,7 @@ from basic_memory.repository.search_repository import create_search_repository from basic_memory.services import EntityService, FileService from basic_memory.repository.semantic_errors import SemanticDependenciesMissingError -from basic_memory.services.exceptions import SyncFatalError +from basic_memory.services.exceptions import FileOperationError, SyncFatalError from basic_memory.services.link_resolver import LinkResolver from basic_memory.services.search_service import SearchService @@ -628,6 +628,14 @@ async def load(path: str) -> None: ) except FileNotFoundError: await self.handle_delete(path) + except FileOperationError as exc: + # Trigger: FileService wraps binary read failures in FileOperationError. + # Why: the service contract should stay consistent for direct callers. + # Outcome: sync still treats wrapped missing-file reads as deletions. 
+ if isinstance(exc.__cause__, FileNotFoundError): + await self.handle_delete(path) + else: + errors[path] = str(exc) except Exception as exc: errors[path] = str(exc) From cdffd879a9e884f02e24db95d0a3522b24647612 Mon Sep 17 00:00:00 2001 From: phernandez Date: Wed, 8 Apr 2026 00:34:44 -0500 Subject: [PATCH 5/9] feat(core): support note-level embedding opt-out Signed-off-by: phernandez --- src/basic_memory/services/search_service.py | 106 +++++++++++++++- tests/services/test_semantic_search.py | 128 ++++++++++++++++++++ 2 files changed, 231 insertions(+), 3 deletions(-) diff --git a/src/basic_memory/services/search_service.py b/src/basic_memory/services/search_service.py index b1a77051..12e42cbb 100644 --- a/src/basic_memory/services/search_service.py +++ b/src/basic_memory/services/search_service.py @@ -427,6 +427,15 @@ async def index_entity_data( async def sync_entity_vectors(self, entity_id: int) -> None: """Refresh vector chunks for one entity in repositories that support semantic indexing.""" + entity = await self.entity_repository.find_by_id(entity_id) + if entity is None: + await self._clear_entity_vectors(entity_id) + return + + if not self._entity_embeddings_enabled(entity): + await self._clear_entity_vectors(entity_id) + return + await self.repository.sync_entity_vectors(entity_id) async def sync_entity_vectors_batch( @@ -435,10 +444,47 @@ async def sync_entity_vectors_batch( progress_callback=None, ) -> VectorSyncBatchResult: """Refresh vector chunks for a batch of entities.""" - return await self.repository.sync_entity_vectors_batch( - entity_ids, + if not entity_ids: + return VectorSyncBatchResult( + entities_total=0, + entities_synced=0, + entities_failed=0, + ) + + entities_by_id = { + entity.id: entity for entity in await self.entity_repository.find_by_ids(entity_ids) + } + opted_out_ids = [ + entity_id + for entity_id in entity_ids + if ( + (entity := entities_by_id.get(entity_id)) is not None + and not self._entity_embeddings_enabled(entity) + ) + 
] + for entity_id in opted_out_ids: + await self._clear_entity_vectors(entity_id) + + eligible_entity_ids = [ + entity_id + for entity_id in entity_ids + if entity_id in entities_by_id and entity_id not in opted_out_ids + ] + if not eligible_entity_ids: + return VectorSyncBatchResult( + entities_total=len(entity_ids), + entities_synced=0, + entities_failed=0, + entities_skipped=len(opted_out_ids), + ) + + batch_result = await self.repository.sync_entity_vectors_batch( + eligible_entity_ids, progress_callback=progress_callback, ) + batch_result.entities_total = len(entity_ids) + batch_result.entities_skipped += len(opted_out_ids) + return batch_result async def reindex_vectors(self, progress_callback=None) -> dict: """Rebuild vector embeddings for all entities. @@ -463,7 +509,7 @@ async def reindex_vectors(self, progress_callback=None) -> dict: stats = { "total_entities": batch_result.entities_total, "embedded": batch_result.entities_synced, - "skipped": 0, + "skipped": batch_result.entities_skipped, "errors": batch_result.entities_failed, } @@ -518,6 +564,60 @@ async def _purge_stale_search_rows(self) -> None: logger.info("Purged stale search rows for deleted entities", project_id=project_id) + @staticmethod + def _entity_embeddings_enabled(entity: Entity) -> bool: + """Return whether semantic embeddings should be generated for this entity.""" + if not entity.entity_metadata: + return True + + embed_value = entity.entity_metadata.get("embed") + if embed_value is None: + return True + if isinstance(embed_value, bool): + return embed_value + if isinstance(embed_value, str): + normalized = embed_value.strip().lower() + if normalized in {"false", "0", "no", "off"}: + return False + if normalized in {"true", "1", "yes", "on"}: + return True + if isinstance(embed_value, (int, float)): + return bool(embed_value) + + # Default unknown values to enabled so malformed metadata does not silently + # remove notes from semantic search. 
+ return True + + async def _clear_entity_vectors(self, entity_id: int) -> None: + """Delete derived vector rows for one entity.""" + from basic_memory.repository.search_repository_base import SearchRepositoryBase + from basic_memory.repository.sqlite_search_repository import SQLiteSearchRepository + + # Trigger: semantic indexing is disabled for this repository instance. + # Why: repositories only create vector tables when semantic search is enabled. + # Outcome: skip cleanup because there are no active derived vector rows to maintain. + if isinstance(self.repository, SearchRepositoryBase) and not self.repository._semantic_enabled: + return + + params = {"project_id": self.repository.project_id, "entity_id": entity_id} + if isinstance(self.repository, SQLiteSearchRepository): + await self.repository.execute_query( + text( + "DELETE FROM search_vector_embeddings WHERE rowid IN (" + "SELECT id FROM search_vector_chunks " + "WHERE project_id = :project_id AND entity_id = :entity_id)" + ), + params, + ) + + await self.repository.execute_query( + text( + "DELETE FROM search_vector_chunks " + "WHERE project_id = :project_id AND entity_id = :entity_id" + ), + params, + ) + async def index_entity_file( self, entity: Entity, diff --git a/tests/services/test_semantic_search.py b/tests/services/test_semantic_search.py index 0ddb6962..a316137c 100644 --- a/tests/services/test_semantic_search.py +++ b/tests/services/test_semantic_search.py @@ -1,7 +1,13 @@ """Semantic search service regression tests for local SQLite search.""" +from datetime import datetime +from types import SimpleNamespace +from unittest.mock import AsyncMock + import pytest +from basic_memory.repository import EntityRepository +from basic_memory.repository.search_repository_base import VectorSyncBatchResult from basic_memory.repository.semantic_errors import ( SemanticDependenciesMissingError, SemanticSearchDisabledError, @@ -89,3 +95,125 @@ async def 
test_semantic_fts_mode_still_returns_observations(search_service, test assert results assert any(result.type == SearchItemType.OBSERVATION.value for result in results) + + +@pytest.mark.asyncio +async def test_semantic_vector_sync_skips_embed_opt_out_and_clears_vectors( + search_service, monkeypatch +): + """Embed opt-out should clear stale vectors instead of regenerating them.""" + repository = _sqlite_repo(search_service) + repository._semantic_enabled = True + + monkeypatch.setattr( + search_service.entity_repository, + "find_by_id", + AsyncMock(return_value=SimpleNamespace(id=42, entity_metadata={"embed": False})), + ) + sync_vectors = AsyncMock() + execute_query = AsyncMock() + monkeypatch.setattr(repository, "sync_entity_vectors", sync_vectors) + monkeypatch.setattr(repository, "execute_query", execute_query) + + await search_service.sync_entity_vectors(42) + + sync_vectors.assert_not_awaited() + assert execute_query.await_count == 2 + + +@pytest.mark.asyncio +async def test_semantic_vector_sync_resumes_when_embed_opt_out_removed( + search_service, monkeypatch +): + """Removing the opt-out should restore normal embedding sync.""" + repository = _sqlite_repo(search_service) + repository._semantic_enabled = True + + monkeypatch.setattr( + search_service.entity_repository, + "find_by_id", + AsyncMock(return_value=SimpleNamespace(id=42, entity_metadata={})), + ) + sync_vectors = AsyncMock() + execute_query = AsyncMock() + monkeypatch.setattr(repository, "sync_entity_vectors", sync_vectors) + monkeypatch.setattr(repository, "execute_query", execute_query) + + await search_service.sync_entity_vectors(42) + + sync_vectors.assert_awaited_once_with(42) + execute_query.assert_not_awaited() + + +@pytest.mark.asyncio +async def test_semantic_vector_sync_batch_skips_embed_opt_out_and_reports_skips( + search_service, monkeypatch +): + """Batch vector sync should only embed eligible notes and report skipped opt-outs.""" + repository = _sqlite_repo(search_service) + 
repository._semantic_enabled = True + + monkeypatch.setattr( + search_service.entity_repository, + "find_by_ids", + AsyncMock( + return_value=[ + SimpleNamespace(id=41, entity_metadata={"embed": False}), + SimpleNamespace(id=42, entity_metadata={}), + ] + ), + ) + sync_batch = AsyncMock( + return_value=VectorSyncBatchResult( + entities_total=1, + entities_synced=1, + entities_failed=0, + ) + ) + execute_query = AsyncMock() + monkeypatch.setattr(repository, "sync_entity_vectors_batch", sync_batch) + monkeypatch.setattr(repository, "execute_query", execute_query) + + result = await search_service.sync_entity_vectors_batch([41, 42]) + + sync_batch.assert_awaited_once() + assert sync_batch.await_args.args[0] == [42] + assert result.entities_total == 2 + assert result.entities_synced == 1 + assert result.entities_skipped == 1 + assert execute_query.await_count == 2 + + +@pytest.mark.asyncio +async def test_embed_opt_out_note_still_participates_in_fts( + search_service, session_maker, test_project +): + """Per-note semantic opt-out should not remove the note from FTS search.""" + entity_repo = EntityRepository(session_maker, project_id=test_project.id) + entity = await entity_repo.create( + { + "title": "FTS Opt Out", + "note_type": "note", + "entity_metadata": {"embed": False}, + "content_type": "text/markdown", + "file_path": "test/fts-opt-out.md", + "permalink": "test/fts-opt-out", + "project_id": test_project.id, + "created_at": datetime.now(), + "updated_at": datetime.now(), + } + ) + + await search_service.index_entity( + entity, + content="This note should stay searchable through full text indexing.", + ) + + results = await search_service.search( + SearchQuery( + text="stay searchable", + retrieval_mode=SearchRetrievalMode.FTS, + ) + ) + + assert any(result.entity_id == entity.id for result in results) From 1207e8c664ddc3b7487f58deea7baccac0e24d96 Mon Sep 17 00:00:00 2001 From: phernandez Date: Wed, 8 Apr 2026 00:43:41 -0500 Subject: [PATCH 6/9] fix(core): 
respect embedding opt-outs during reindex Signed-off-by: phernandez --- src/basic_memory/services/search_service.py | 9 +++-- src/basic_memory/sync/sync_service.py | 2 -- tests/services/test_semantic_search.py | 37 +++++++++++++++++++++ 3 files changed, 43 insertions(+), 5 deletions(-) diff --git a/src/basic_memory/services/search_service.py b/src/basic_memory/services/search_service.py index 12e42cbb..ae13bc1d 100644 --- a/src/basic_memory/services/search_service.py +++ b/src/basic_memory/services/search_service.py @@ -1,5 +1,6 @@ """Service for search operations.""" +import asyncio import ast import re from datetime import datetime @@ -462,8 +463,10 @@ async def sync_entity_vectors_batch( and not self._entity_embeddings_enabled(entity) ) ] - for entity_id in opted_out_ids: - await self._clear_entity_vectors(entity_id) + if opted_out_ids: + await asyncio.gather( + *(self._clear_entity_vectors(entity_id) for entity_id in opted_out_ids) + ) eligible_entity_ids = [ entity_id @@ -502,7 +505,7 @@ async def reindex_vectors(self, progress_callback=None) -> dict: # that reference entity_ids no longer in the entity table await self._purge_stale_search_rows() - batch_result = await self.repository.sync_entity_vectors_batch( + batch_result = await self.sync_entity_vectors_batch( entity_ids, progress_callback=progress_callback, ) diff --git a/src/basic_memory/sync/sync_service.py b/src/basic_memory/sync/sync_service.py index 4e75c040..c912d549 100644 --- a/src/basic_memory/sync/sync_service.py +++ b/src/basic_memory/sync/sync_service.py @@ -626,8 +626,6 @@ async def load(path: str) -> None: created_at=metadata.created_at, content=content, ) - except FileNotFoundError: - await self.handle_delete(path) except FileOperationError as exc: # Trigger: FileService wraps binary read failures in FileOperationError. # Why: the service contract should stay consistent for direct callers. 
diff --git a/tests/services/test_semantic_search.py b/tests/services/test_semantic_search.py index a316137c..bd106e22 100644 --- a/tests/services/test_semantic_search.py +++ b/tests/services/test_semantic_search.py @@ -217,3 +217,40 @@ async def test_embed_opt_out_note_still_participates_in_fts( ) assert any(result.entity_id == entity.id for result in results) + + +@pytest.mark.asyncio +async def test_reindex_vectors_respects_embed_opt_out(search_service, monkeypatch): + """Full vector reindex should route through the service-level opt-out filter.""" + monkeypatch.setattr( + search_service.entity_repository, + "find_all", + AsyncMock( + return_value=[ + SimpleNamespace(id=41, entity_metadata={"embed": False}), + SimpleNamespace(id=42, entity_metadata={}), + ] + ), + ) + purge_stale_rows = AsyncMock() + sync_batch = AsyncMock( + return_value=VectorSyncBatchResult( + entities_total=2, + entities_synced=1, + entities_failed=0, + entities_skipped=1, + ) + ) + monkeypatch.setattr(search_service, "_purge_stale_search_rows", purge_stale_rows) + monkeypatch.setattr(search_service, "sync_entity_vectors_batch", sync_batch) + + stats = await search_service.reindex_vectors() + + purge_stale_rows.assert_awaited_once() + sync_batch.assert_awaited_once_with([41, 42], progress_callback=None) + assert stats == { + "total_entities": 2, + "embedded": 1, + "skipped": 1, + "errors": 0, + } From fe9fc7e2f3cc9c74866023bad050d7eb02e17dca Mon Sep 17 00:00:00 2001 From: phernandez Date: Wed, 8 Apr 2026 00:55:20 -0500 Subject: [PATCH 7/9] fix(core): remove automatic embedding backfill Signed-off-by: phernandez --- src/basic_memory/db.py | 103 ------------ src/basic_memory/mcp/server.py | 38 +---- tests/services/test_initialization.py | 225 -------------------------- 3 files changed, 2 insertions(+), 364 deletions(-) diff --git a/src/basic_memory/db.py b/src/basic_memory/db.py index 51792d01..6133e953 100644 --- a/src/basic_memory/db.py +++ b/src/basic_memory/db.py @@ -44,101 +44,6 @@ 
_session_maker: Optional[async_sessionmaker[AsyncSession]] = None -async def _needs_semantic_embedding_backfill( - app_config: BasicMemoryConfig, - session_maker: async_sessionmaker[AsyncSession], -) -> bool: - """Check if entities exist but vector embeddings are empty. - - This is the reliable way to detect that embeddings need to be generated, - regardless of how migrations were applied (fresh DB, upgrade, reset, etc.). - """ - if not app_config.semantic_search_enabled: - return False - - try: - async with scoped_session(session_maker) as session: - entity_count = ( - await session.execute(text("SELECT COUNT(*) FROM entity")) - ).scalar() or 0 - if entity_count == 0: - return False - - # Check if vector chunks table exists and is empty - embedding_count = ( - await session.execute(text("SELECT COUNT(*) FROM search_vector_chunks")) - ).scalar() or 0 - - return embedding_count == 0 - except Exception as exc: - # Table might not exist yet (pre-migration) - logger.debug(f"Could not check embedding status: {exc}") - return False - - -async def _run_semantic_embedding_backfill( - app_config: BasicMemoryConfig, - session_maker: async_sessionmaker[AsyncSession], -) -> None: - """Backfill semantic embeddings for all active projects/entities.""" - if not app_config.semantic_search_enabled: - logger.info("Skipping automatic semantic embedding backfill: semantic search is disabled.") - return - - async with scoped_session(session_maker) as session: - project_result = await session.execute( - text("SELECT id, name FROM project WHERE is_active = :is_active ORDER BY id"), - {"is_active": True}, - ) - projects = [(int(row[0]), str(row[1])) for row in project_result.fetchall()] - - if not projects: - logger.info("Skipping automatic semantic embedding backfill: no active projects found.") - return - - repository_class = ( - PostgresSearchRepository - if app_config.database_backend == DatabaseBackend.POSTGRES - else SQLiteSearchRepository - ) - - total_entities = 0 - for 
project_id, project_name in projects: - async with scoped_session(session_maker) as session: - entity_result = await session.execute( - text("SELECT id FROM entity WHERE project_id = :project_id ORDER BY id"), - {"project_id": project_id}, - ) - entity_ids = [int(row[0]) for row in entity_result.fetchall()] - - if not entity_ids: - continue - - total_entities += len(entity_ids) - logger.info( - "Automatic semantic embedding backfill: " - f"project={project_name}, entities={len(entity_ids)}" - ) - - search_repository = repository_class( - session_maker, - project_id=project_id, - app_config=app_config, - ) - batch_result = await search_repository.sync_entity_vectors_batch(entity_ids) - if batch_result.entities_failed > 0: - logger.warning( - "Automatic semantic embedding backfill encountered entity failures: " - f"project={project_name}, failed={batch_result.entities_failed}, " - f"failed_entity_ids={batch_result.failed_entity_ids}" - ) - - logger.info( - "Automatic semantic embedding backfill complete: " - f"projects={len(projects)}, entities={total_entities}" - ) - - class DatabaseType(Enum): """Types of supported databases.""" @@ -521,14 +426,6 @@ async def run_migrations( else: await SQLiteSearchRepository(session_maker, 1).init_search_index() - # Check if backfill is needed — actual backfill runs in background - # from the MCP server lifespan to avoid blocking startup. - if await _needs_semantic_embedding_backfill(app_config, session_maker): - logger.info( - "Semantic embeddings missing — backfill will run in background after startup" - ) - else: - logger.info("Semantic embeddings: up to date") except Exception as e: # pragma: no cover logger.error(f"Error running migrations: {e}") raise diff --git a/src/basic_memory/mcp/server.py b/src/basic_memory/mcp/server.py index 9b10b38d..e2a9eaa8 100644 --- a/src/basic_memory/mcp/server.py +++ b/src/basic_memory/mcp/server.py @@ -2,7 +2,6 @@ Basic Memory FastMCP server. 
""" -import asyncio import time from contextlib import asynccontextmanager @@ -13,12 +12,7 @@ from basic_memory import db from basic_memory.cli.auth import CLIAuth -from basic_memory.config import BasicMemoryConfig -from basic_memory.db import ( - scoped_session, - _needs_semantic_embedding_backfill, - _run_semantic_embedding_backfill, -) +from basic_memory.db import scoped_session from basic_memory.mcp.container import McpContainer, set_container from basic_memory.services.initialization import initialize_app from basic_memory import telemetry @@ -43,7 +37,7 @@ async def _log_embedding_status(session_maker: async_sessionmaker[AsyncSession]) elif embedding_count == 0: logger.warning( f"Semantic embeddings: EMPTY — {entity_count} entities have no embeddings. " - "Backfill running in background..." + "Run 'bm reindex --embeddings' to build them." ) else: logger.info( @@ -54,20 +48,6 @@ async def _log_embedding_status(session_maker: async_sessionmaker[AsyncSession]) logger.debug(f"Could not check embedding status at startup: {exc}") -async def _background_embedding_backfill( - config: BasicMemoryConfig, - session_maker: async_sessionmaker[AsyncSession], -) -> None: - """Run semantic embedding backfill in the background without blocking startup.""" - try: - if await _needs_semantic_embedding_backfill(config, session_maker): - logger.info("Background embedding backfill starting...") - await _run_semantic_embedding_backfill(config, session_maker) - await _log_embedding_status(session_maker) - except Exception as exc: - logger.error(f"Background embedding backfill failed: {exc}") - - @asynccontextmanager async def lifespan(app: FastMCP): """Lifecycle manager for the MCP server. 
@@ -133,14 +113,8 @@ async def lifespan(app: FastMCP): await initialize_app(container.config) # Log embedding status so it's easy to spot in the logs - backfill_task: asyncio.Task | None = None # type: ignore[type-arg] if config.semantic_search_enabled and db._session_maker is not None: await _log_embedding_status(db._session_maker) - # Launch backfill in background so MCP server is ready immediately - backfill_task = asyncio.create_task( - _background_embedding_backfill(config, db._session_maker), - name="embedding-backfill", - ) # Create and start sync coordinator (lifecycle centralized in coordinator) sync_coordinator = container.create_sync_coordinator() @@ -157,14 +131,6 @@ async def lifespan(app: FastMCP): ): logger.debug("Shutting down Basic Memory MCP server") - # Cancel embedding backfill if still running - if backfill_task is not None and not backfill_task.done(): - backfill_task.cancel() - try: - await backfill_task - except asyncio.CancelledError: - logger.info("Background embedding backfill cancelled during shutdown") - await sync_coordinator.stop() # Only shutdown DB if we created it (not if test fixture provided it) diff --git a/tests/services/test_initialization.py b/tests/services/test_initialization.py index 7ca22135..37c28285 100644 --- a/tests/services/test_initialization.py +++ b/tests/services/test_initialization.py @@ -6,7 +6,6 @@ from __future__ import annotations -from datetime import datetime from unittest.mock import AsyncMock import pytest @@ -198,227 +197,3 @@ def capture_warning(message: str) -> None: for message in warnings ) - -@pytest.mark.asyncio -async def test_run_migrations_triggers_embedding_backfill_when_entities_exist_but_no_embeddings( - monkeypatch, app_config: BasicMemoryConfig -): - """run_migrations checks for missing embeddings (actual backfill runs in background from MCP).""" - - class StubSearchRepository: - def __init__(self, *args, **kwargs): - pass - - async def init_search_index(self): - return None - - 
original_session_maker = db._session_maker # pyright: ignore [reportPrivateUsage] - try: - session_marker = object() - db._session_maker = session_marker # pyright: ignore [reportPrivateUsage] - - monkeypatch.setattr( - "basic_memory.db.command.upgrade", - lambda *args, **kwargs: None, - ) - monkeypatch.setattr("basic_memory.db.SQLiteSearchRepository", StubSearchRepository) - monkeypatch.setattr("basic_memory.db.PostgresSearchRepository", StubSearchRepository) - - needs_backfill_mock = AsyncMock(return_value=True) - monkeypatch.setattr( - "basic_memory.db._needs_semantic_embedding_backfill", needs_backfill_mock - ) - - await db.run_migrations(app_config) - - # Verifies the check runs — backfill itself is launched by MCP lifespan - needs_backfill_mock.assert_awaited_once_with(app_config, session_marker) - finally: - db._session_maker = original_session_maker # pyright: ignore [reportPrivateUsage] - - -@pytest.mark.asyncio -async def test_run_migrations_skips_embedding_backfill_when_embeddings_already_exist( - monkeypatch, app_config: BasicMemoryConfig -): - """When embeddings already exist, no backfill is needed.""" - - class StubSearchRepository: - def __init__(self, *args, **kwargs): - pass - - async def init_search_index(self): - return None - - original_session_maker = db._session_maker # pyright: ignore [reportPrivateUsage] - try: - session_marker = object() - db._session_maker = session_marker # pyright: ignore [reportPrivateUsage] - - monkeypatch.setattr( - "basic_memory.db.command.upgrade", - lambda *args, **kwargs: None, - ) - monkeypatch.setattr("basic_memory.db.SQLiteSearchRepository", StubSearchRepository) - monkeypatch.setattr("basic_memory.db.PostgresSearchRepository", StubSearchRepository) - - needs_backfill_mock = AsyncMock(return_value=False) - monkeypatch.setattr( - "basic_memory.db._needs_semantic_embedding_backfill", needs_backfill_mock - ) - - await db.run_migrations(app_config) - - needs_backfill_mock.assert_awaited_once_with(app_config, 
session_marker) - finally: - db._session_maker = original_session_maker # pyright: ignore [reportPrivateUsage] - - -@pytest.mark.asyncio -async def test_semantic_embedding_backfill_syncs_each_entity( - monkeypatch, - app_config: BasicMemoryConfig, - session_maker, - test_project, -): - """Automatic backfill should run sync_entity_vectors for every entity in active projects.""" - from basic_memory.repository.entity_repository import EntityRepository - - entity_repository = EntityRepository(session_maker, project_id=test_project.id) - created_entity_ids: list[int] = [] - for i in range(3): - entity = await entity_repository.create( - { - "title": f"Backfill Entity {i}", - "note_type": "note", - "entity_metadata": {}, - "content_type": "text/markdown", - "file_path": f"test/backfill-{i}.md", - "permalink": f"test/backfill-{i}", - "project_id": test_project.id, - "created_at": datetime.now(), - "updated_at": datetime.now(), - } - ) - created_entity_ids.append(entity.id) - - synced_pairs: list[tuple[int, int]] = [] - - class StubSearchRepository: - def __init__(self, _session_maker, project_id: int, app_config=None): - self.project_id = project_id - - async def sync_entity_vectors_batch(self, entity_ids: list[int], progress_callback=None): - for entity_id in entity_ids: - synced_pairs.append((self.project_id, entity_id)) - from basic_memory.repository.search_repository_base import VectorSyncBatchResult - - return VectorSyncBatchResult( - entities_total=len(entity_ids), - entities_synced=len(entity_ids), - entities_failed=0, - failed_entity_ids=[], - embedding_jobs_total=0, - embed_seconds_total=0.0, - write_seconds_total=0.0, - ) - - monkeypatch.setattr("basic_memory.db.SQLiteSearchRepository", StubSearchRepository) - monkeypatch.setattr("basic_memory.db.PostgresSearchRepository", StubSearchRepository) - - app_config.semantic_search_enabled = True - - await db._run_semantic_embedding_backfill(app_config, session_maker) # pyright: ignore [reportPrivateUsage] - - 
expected_pairs = {(test_project.id, entity_id) for entity_id in created_entity_ids} - assert expected_pairs.issubset(set(synced_pairs)) - - -@pytest.mark.asyncio -async def test_semantic_embedding_backfill_skips_when_semantic_disabled( - monkeypatch, - app_config: BasicMemoryConfig, - session_maker, -): - """Automatic backfill should no-op when semantic search is disabled.""" - called = False - - class StubSearchRepository: - def __init__(self, *args, **kwargs): - nonlocal called - called = True - - async def sync_entity_vectors_batch(self, entity_ids: list[int], progress_callback=None): - from basic_memory.repository.search_repository_base import VectorSyncBatchResult - - return VectorSyncBatchResult( - entities_total=len(entity_ids), - entities_synced=len(entity_ids), - entities_failed=0, - failed_entity_ids=[], - embedding_jobs_total=0, - embed_seconds_total=0.0, - write_seconds_total=0.0, - ) - - monkeypatch.setattr("basic_memory.db.SQLiteSearchRepository", StubSearchRepository) - monkeypatch.setattr("basic_memory.db.PostgresSearchRepository", StubSearchRepository) - - app_config.semantic_search_enabled = False - await db._run_semantic_embedding_backfill(app_config, session_maker) # pyright: ignore [reportPrivateUsage] - assert called is False - - -@pytest.mark.asyncio -async def test_needs_semantic_embedding_backfill_true_when_entities_exist_no_embeddings( - app_config: BasicMemoryConfig, - session_maker, - test_project, -): - """Should return True when entities exist but vector chunks table is empty.""" - from basic_memory.repository.entity_repository import EntityRepository - - entity_repository = EntityRepository(session_maker, project_id=test_project.id) - await entity_repository.create( - { - "title": "Test Entity", - "note_type": "note", - "entity_metadata": {}, - "content_type": "text/markdown", - "file_path": "test/backfill-check.md", - "permalink": "test/backfill-check", - "project_id": test_project.id, - "created_at": datetime.now(), - "updated_at": 
datetime.now(), - } - ) - - # Clear any embeddings left by other tests in the shared DB - async with db.scoped_session(session_maker) as session: - await session.execute(db.text("DELETE FROM search_vector_chunks")) - - app_config.semantic_search_enabled = True - result = await db._needs_semantic_embedding_backfill(app_config, session_maker) # pyright: ignore [reportPrivateUsage] - assert result is True - - -@pytest.mark.asyncio -async def test_needs_semantic_embedding_backfill_false_when_no_entities( - app_config: BasicMemoryConfig, - session_maker, -): - """Should return False when no entities exist (nothing to backfill).""" - app_config.semantic_search_enabled = True - result = await db._needs_semantic_embedding_backfill(app_config, session_maker) # pyright: ignore [reportPrivateUsage] - assert result is False - - -@pytest.mark.asyncio -async def test_needs_semantic_embedding_backfill_false_when_semantic_disabled( - app_config: BasicMemoryConfig, - session_maker, -): - """Should return False when semantic search is disabled.""" - app_config.semantic_search_enabled = False - result = await db._needs_semantic_embedding_backfill(app_config, session_maker) # pyright: ignore [reportPrivateUsage] - assert result is False From a83bee9322b35c108cf538d46625a76d4ee84eca Mon Sep 17 00:00:00 2001 From: phernandez Date: Wed, 8 Apr 2026 01:06:13 -0500 Subject: [PATCH 8/9] fix(core): clean up stale vector ids Signed-off-by: phernandez --- src/basic_memory/services/search_service.py | 45 ++++++++++++++++++--- tests/services/test_semantic_search.py | 38 +++++++++++++++++ 2 files changed, 77 insertions(+), 6 deletions(-) diff --git a/src/basic_memory/services/search_service.py b/src/basic_memory/services/search_service.py index ae13bc1d..0a6d1430 100644 --- a/src/basic_memory/services/search_service.py +++ b/src/basic_memory/services/search_service.py @@ -455,6 +455,7 @@ async def sync_entity_vectors_batch( entities_by_id = { entity.id: entity for entity in await 
self.entity_repository.find_by_ids(entity_ids) } + unknown_ids = [entity_id for entity_id in entity_ids if entity_id not in entities_by_id] opted_out_ids = [ entity_id for entity_id in entity_ids @@ -468,12 +469,27 @@ async def sync_entity_vectors_batch( *(self._clear_entity_vectors(entity_id) for entity_id in opted_out_ids) ) + repository_results: list[VectorSyncBatchResult] = [] + if unknown_ids: + # Trigger: a caller passes entity IDs that were deleted after the batch was built. + # Why: repository sync still owns stale chunk cleanup for IDs with no source rows. + # Outcome: deleted entities do not silently keep orphaned vector rows forever. + repository_results.append(await self.repository.sync_entity_vectors_batch(unknown_ids)) + eligible_entity_ids = [ entity_id for entity_id in entity_ids if entity_id in entities_by_id and entity_id not in opted_out_ids ] - if not eligible_entity_ids: + if eligible_entity_ids: + repository_results.append( + await self.repository.sync_entity_vectors_batch( + eligible_entity_ids, + progress_callback=progress_callback, + ) + ) + + if not repository_results: return VectorSyncBatchResult( entities_total=len(entity_ids), entities_synced=0, @@ -481,12 +497,29 @@ async def sync_entity_vectors_batch( entities_skipped=len(opted_out_ids), ) - batch_result = await self.repository.sync_entity_vectors_batch( - eligible_entity_ids, - progress_callback=progress_callback, + batch_result = VectorSyncBatchResult( + entities_total=len(entity_ids), + entities_synced=sum(result.entities_synced for result in repository_results), + entities_failed=sum(result.entities_failed for result in repository_results), + entities_deferred=sum(result.entities_deferred for result in repository_results), + entities_skipped=( + len(opted_out_ids) + sum(result.entities_skipped for result in repository_results) + ), + failed_entity_ids=[ + failed_entity_id + for result in repository_results + for failed_entity_id in result.failed_entity_ids + ], + 
chunks_total=sum(result.chunks_total for result in repository_results), + chunks_skipped=sum(result.chunks_skipped for result in repository_results), + embedding_jobs_total=sum(result.embedding_jobs_total for result in repository_results), + prepare_seconds_total=sum(result.prepare_seconds_total for result in repository_results), + queue_wait_seconds_total=sum( + result.queue_wait_seconds_total for result in repository_results + ), + embed_seconds_total=sum(result.embed_seconds_total for result in repository_results), + write_seconds_total=sum(result.write_seconds_total for result in repository_results), ) - batch_result.entities_total = len(entity_ids) - batch_result.entities_skipped += len(opted_out_ids) return batch_result async def reindex_vectors(self, progress_callback=None) -> dict: diff --git a/tests/services/test_semantic_search.py b/tests/services/test_semantic_search.py index bd106e22..db5c01df 100644 --- a/tests/services/test_semantic_search.py +++ b/tests/services/test_semantic_search.py @@ -254,3 +254,41 @@ async def test_reindex_vectors_respects_embed_opt_out(search_service, monkeypatc "skipped": 1, "errors": 0, } + + +@pytest.mark.asyncio +async def test_semantic_vector_sync_batch_cleans_up_unknown_ids(search_service, monkeypatch): + """Deleted entity IDs should still flow through repository cleanup instead of being dropped.""" + repository = _sqlite_repo(search_service) + repository._semantic_enabled = True + + monkeypatch.setattr( + search_service.entity_repository, + "find_by_ids", + AsyncMock(return_value=[SimpleNamespace(id=42, entity_metadata={})]), + ) + sync_batch = AsyncMock( + side_effect=[ + VectorSyncBatchResult( + entities_total=1, + entities_synced=1, + entities_failed=0, + ), + VectorSyncBatchResult( + entities_total=1, + entities_synced=1, + entities_failed=0, + ), + ] + ) + monkeypatch.setattr(repository, "sync_entity_vectors_batch", sync_batch) + + result = await search_service.sync_entity_vectors_batch([41, 42]) + + assert 
sync_batch.await_count == 2 + assert sync_batch.await_args_list[0].args[0] == [41] + assert sync_batch.await_args_list[1].args[0] == [42] + assert result.entities_total == 2 + assert result.entities_synced == 2 + assert result.entities_failed == 0 + assert result.entities_skipped == 0 From 1fee913950dab3324d16f3944f24291e78fc6e59 Mon Sep 17 00:00:00 2001 From: phernandez Date: Wed, 8 Apr 2026 01:16:53 -0500 Subject: [PATCH 9/9] fix(core): parallelize stale vector cleanup Signed-off-by: phernandez --- src/basic_memory/services/search_service.py | 36 +++++++++++++-------- tests/services/test_semantic_search.py | 14 ++++++-- 2 files changed, 33 insertions(+), 17 deletions(-) diff --git a/src/basic_memory/services/search_service.py b/src/basic_memory/services/search_service.py index 0a6d1430..102d69be 100644 --- a/src/basic_memory/services/search_service.py +++ b/src/basic_memory/services/search_service.py @@ -469,25 +469,31 @@ async def sync_entity_vectors_batch( *(self._clear_entity_vectors(entity_id) for entity_id in opted_out_ids) ) - repository_results: list[VectorSyncBatchResult] = [] - if unknown_ids: - # Trigger: a caller passes entity IDs that were deleted after the batch was built. - # Why: repository sync still owns stale chunk cleanup for IDs with no source rows. - # Outcome: deleted entities do not silently keep orphaned vector rows forever. 
- repository_results.append(await self.repository.sync_entity_vectors_batch(unknown_ids)) - eligible_entity_ids = [ entity_id for entity_id in entity_ids if entity_id in entities_by_id and entity_id not in opted_out_ids ] - if eligible_entity_ids: - repository_results.append( - await self.repository.sync_entity_vectors_batch( - eligible_entity_ids, - progress_callback=progress_callback, - ) + + cleanup_task = ( + self.repository.sync_entity_vectors_batch(unknown_ids) if unknown_ids else None + ) + eligible_task = ( + self.repository.sync_entity_vectors_batch( + eligible_entity_ids, + progress_callback=progress_callback, ) + if eligible_entity_ids + else None + ) + repository_results = [ + result + for result in await asyncio.gather( + cleanup_task if cleanup_task is not None else asyncio.sleep(0, result=None), + eligible_task if eligible_task is not None else asyncio.sleep(0, result=None), + ) + if result is not None + ] if not repository_results: return VectorSyncBatchResult( @@ -503,7 +509,9 @@ async def sync_entity_vectors_batch( entities_failed=sum(result.entities_failed for result in repository_results), entities_deferred=sum(result.entities_deferred for result in repository_results), entities_skipped=( - len(opted_out_ids) + sum(result.entities_skipped for result in repository_results) + len(opted_out_ids) + + sum(result.entities_skipped for result in repository_results) + - len(unknown_ids) ), failed_entity_ids=[ failed_entity_id diff --git a/tests/services/test_semantic_search.py b/tests/services/test_semantic_search.py index db5c01df..19a1f949 100644 --- a/tests/services/test_semantic_search.py +++ b/tests/services/test_semantic_search.py @@ -273,6 +273,7 @@ async def test_semantic_vector_sync_batch_cleans_up_unknown_ids(search_service, entities_total=1, entities_synced=1, entities_failed=0, + entities_skipped=1, ), VectorSyncBatchResult( entities_total=1, @@ -282,12 +283,19 @@ async def test_semantic_vector_sync_batch_cleans_up_unknown_ids(search_service, 
] ) monkeypatch.setattr(repository, "sync_entity_vectors_batch", sync_batch) + progress_callback = AsyncMock() - result = await search_service.sync_entity_vectors_batch([41, 42]) + result = await search_service.sync_entity_vectors_batch([41, 42], progress_callback) assert sync_batch.await_count == 2 - assert sync_batch.await_args_list[0].args[0] == [41] - assert sync_batch.await_args_list[1].args[0] == [42] + called_entity_ids = {tuple(call.args[0]) for call in sync_batch.await_args_list} + assert called_entity_ids == {(41,), (42,)} + progress_callback_calls = [ + call for call in sync_batch.await_args_list if call.kwargs.get("progress_callback") is not None + ] + assert len(progress_callback_calls) == 1 + assert progress_callback_calls[0].args[0] == [42] + assert progress_callback_calls[0].kwargs["progress_callback"] is progress_callback assert result.entities_total == 2 assert result.entities_synced == 2 assert result.entities_failed == 0