From 651fdd8072d6817ca8ea9c5db9a2d8373500baf5 Mon Sep 17 00:00:00 2001 From: Nitin Kanukolanu Date: Wed, 1 Apr 2026 18:37:01 -0400 Subject: [PATCH 1/7] feat(migrate): [3/7] add async migration with non-blocking planner, executor, validator, and readiness utilities Async versions of planner, executor, and validator for non-blocking migration workflows. Async executor mirrors the sync drop/recreate flow with async key enumeration, prefix/field renames, vector re-encoding with checkpoint resume, and readiness polling. Includes async utilities for index listing, readiness polling, and source snapshot validation. Adds async unit and integration tests. --- redisvl/migration/__init__.py | 16 + redisvl/migration/async_executor.py | 1093 ++++++++++++++++++ redisvl/migration/async_planner.py | 271 +++++ redisvl/migration/async_utils.py | 101 ++ redisvl/migration/async_validation.py | 186 +++ tests/integration/test_async_migration_v1.py | 150 +++ tests/unit/test_async_migration_executor.py | 1092 +++++++++++++++++ tests/unit/test_async_migration_planner.py | 319 +++++ 8 files changed, 3228 insertions(+) create mode 100644 redisvl/migration/async_executor.py create mode 100644 redisvl/migration/async_planner.py create mode 100644 redisvl/migration/async_utils.py create mode 100644 redisvl/migration/async_validation.py create mode 100644 tests/integration/test_async_migration_v1.py create mode 100644 tests/unit/test_async_migration_executor.py create mode 100644 tests/unit/test_async_migration_planner.py diff --git a/redisvl/migration/__init__.py b/redisvl/migration/__init__.py index 267ad4c31..255803473 100644 --- a/redisvl/migration/__init__.py +++ b/redisvl/migration/__init__.py @@ -1,3 +1,11 @@ +from redisvl.migration.async_executor import AsyncMigrationExecutor +from redisvl.migration.async_planner import AsyncMigrationPlanner +from redisvl.migration.async_utils import ( + async_current_source_matches_snapshot, + async_list_indexes, + async_wait_for_index_ready, +) +from 
redisvl.migration.async_validation import AsyncMigrationValidator from redisvl.migration.executor import MigrationExecutor from redisvl.migration.models import ( DiskSpaceEstimate, @@ -11,6 +19,7 @@ from redisvl.migration.validation import MigrationValidator __all__ = [ + # Sync "DiskSpaceEstimate", "FieldRename", "MigrationExecutor", @@ -20,4 +29,11 @@ "MigrationValidator", "RenameOperations", "SchemaPatch", + # Async + "AsyncMigrationExecutor", + "AsyncMigrationPlanner", + "AsyncMigrationValidator", + "async_current_source_matches_snapshot", + "async_list_indexes", + "async_wait_for_index_ready", ] diff --git a/redisvl/migration/async_executor.py b/redisvl/migration/async_executor.py new file mode 100644 index 000000000..7d00b83a8 --- /dev/null +++ b/redisvl/migration/async_executor.py @@ -0,0 +1,1093 @@ +from __future__ import annotations + +import asyncio +import logging +import time +from typing import Any, AsyncGenerator, Callable, Dict, List, Optional + +from redis.exceptions import ResponseError + +from redisvl.index import AsyncSearchIndex +from redisvl.migration.async_planner import AsyncMigrationPlanner +from redisvl.migration.async_validation import AsyncMigrationValidator +from redisvl.migration.models import ( + MigrationBenchmarkSummary, + MigrationPlan, + MigrationReport, + MigrationTimings, + MigrationValidation, +) +from redisvl.migration.reliability import ( + BatchUndoBuffer, + QuantizationCheckpoint, + async_trigger_bgsave_and_wait, + is_already_quantized, + is_same_width_dtype_conversion, +) +from redisvl.migration.utils import ( + build_scan_match_patterns, + estimate_disk_space, + get_schema_field_path, + normalize_keys, + timestamp_utc, +) +from redisvl.redis.utils import array_to_buffer, buffer_to_array +from redisvl.types import AsyncRedisClient + +logger = logging.getLogger(__name__) + + +class AsyncMigrationExecutor: + """Async migration executor for document-preserving drop/recreate flows. 
+ + This is the async version of MigrationExecutor. It uses AsyncSearchIndex + and async Redis operations for better performance on large indexes, + especially during vector quantization. + """ + + def __init__(self, validator: Optional[AsyncMigrationValidator] = None): + self.validator = validator or AsyncMigrationValidator() + + async def _detect_aof_enabled(self, client: Any) -> bool: + """Best-effort detection of whether AOF is enabled on the live Redis.""" + try: + info = await client.info("persistence") + if isinstance(info, dict) and "aof_enabled" in info: + return bool(int(info["aof_enabled"])) + except Exception: + logger.debug("Could not read Redis INFO persistence for AOF detection.") + + try: + config = await client.config_get("appendonly") + if isinstance(config, dict): + value = config.get("appendonly") + if value is not None: + return str(value).lower() in {"yes", "1", "true", "on"} + except Exception: + logger.debug("Could not read Redis CONFIG GET appendonly.") + + return False + + async def _enumerate_indexed_keys( + self, + client: AsyncRedisClient, + index_name: str, + batch_size: int = 1000, + key_separator: str = ":", + ) -> AsyncGenerator[str, None]: + """Async version: Enumerate document keys using FT.AGGREGATE with SCAN fallback. + + Uses FT.AGGREGATE WITHCURSOR for efficient enumeration when the index + has no indexing failures. Falls back to SCAN if: + - Index has hash_indexing_failures > 0 (would miss failed docs) + - FT.AGGREGATE command fails for any reason + """ + # Check for indexing failures - if any, fall back to SCAN + try: + info = await client.ft(index_name).info() + failures = int(info.get("hash_indexing_failures", 0) or 0) + if failures > 0: + logger.warning( + f"Index '{index_name}' has {failures} indexing failures. " + "Using SCAN for complete enumeration." 
+ ) + async for key in self._enumerate_with_scan( + client, index_name, batch_size, key_separator + ): + yield key + return + except Exception as e: + logger.warning(f"Failed to check index info: {e}. Using SCAN fallback.") + async for key in self._enumerate_with_scan( + client, index_name, batch_size, key_separator + ): + yield key + return + + # Try FT.AGGREGATE enumeration + try: + async for key in self._enumerate_with_aggregate( + client, index_name, batch_size + ): + yield key + except ResponseError as e: + logger.warning( + f"FT.AGGREGATE failed: {e}. Falling back to SCAN enumeration." + ) + async for key in self._enumerate_with_scan( + client, index_name, batch_size, key_separator + ): + yield key + + async def _enumerate_with_aggregate( + self, + client: AsyncRedisClient, + index_name: str, + batch_size: int = 1000, + ) -> AsyncGenerator[str, None]: + """Async version: Enumerate keys using FT.AGGREGATE WITHCURSOR.""" + cursor_id: Optional[int] = None + + try: + # Initial aggregate call with LOAD 1 __key + result = await client.execute_command( + "FT.AGGREGATE", + index_name, + "*", + "LOAD", + "1", + "__key", + "WITHCURSOR", + "COUNT", + str(batch_size), + ) + + while True: + results_data, cursor_id = result + + # Extract keys from results + for item in results_data[1:]: + if isinstance(item, (list, tuple)) and len(item) >= 2: + key = item[1] + yield key.decode() if isinstance(key, bytes) else str(key) + + if cursor_id == 0: + break + + result = await client.execute_command( + "FT.CURSOR", + "READ", + index_name, + str(cursor_id), + "COUNT", + str(batch_size), + ) + finally: + if cursor_id and cursor_id != 0: + try: + await client.execute_command( + "FT.CURSOR", "DEL", index_name, str(cursor_id) + ) + except Exception: + pass + + async def _enumerate_with_scan( + self, + client: AsyncRedisClient, + index_name: str, + batch_size: int = 1000, + key_separator: str = ":", + ) -> AsyncGenerator[str, None]: + """Async version: Enumerate keys using SCAN with 
prefix matching.""" + # Get prefix from index info + try: + info = await client.ft(index_name).info() + if isinstance(info, dict): + prefixes = info.get("index_definition", {}).get("prefixes", []) + else: + prefixes = [] + for i, item in enumerate(info): + if item == b"index_definition" or item == "index_definition": + defn = info[i + 1] + if isinstance(defn, dict): + prefixes = defn.get("prefixes", []) + elif isinstance(defn, list): + for j, d in enumerate(defn): + if d in (b"prefixes", "prefixes") and j + 1 < len(defn): + prefixes = defn[j + 1] + break + normalized_prefixes = [ + p.decode() if isinstance(p, bytes) else str(p) for p in prefixes + ] + except Exception as e: + logger.warning(f"Failed to get prefix from index info: {e}") + normalized_prefixes = [] + + seen_keys: set[str] = set() + for match_pattern in build_scan_match_patterns( + normalized_prefixes, key_separator + ): + cursor: int = 0 + while True: + cursor, keys = await client.scan( + cursor=cursor, + match=match_pattern, + count=batch_size, + ) + for key in keys: + key_str = key.decode() if isinstance(key, bytes) else str(key) + if key_str not in seen_keys: + seen_keys.add(key_str) + yield key_str + + if cursor == 0: + break + + async def _rename_keys( + self, + client: AsyncRedisClient, + keys: List[str], + old_prefix: str, + new_prefix: str, + progress_callback: Optional[Callable[[int, int], None]] = None, + ) -> int: + """Async version: Rename keys from old prefix to new prefix. + + Uses RENAMENX to avoid overwriting existing destination keys. + Raises on collision to prevent silent data loss. 
+ """ + renamed = 0 + total = len(keys) + pipeline_size = 100 + collisions: List[str] = [] + + for i in range(0, total, pipeline_size): + batch = keys[i : i + pipeline_size] + pipe = client.pipeline(transaction=False) + batch_new_keys: List[str] = [] + + for key in batch: + if key.startswith(old_prefix): + new_key = new_prefix + key[len(old_prefix) :] + else: + logger.warning( + f"Key '{key}' does not start with prefix '{old_prefix}'" + ) + continue + pipe.renamenx(key, new_key) + batch_new_keys.append(new_key) + + try: + results = await pipe.execute() + for j, r in enumerate(results): + if r is True or r == 1: + renamed += 1 + else: + collisions.append(batch_new_keys[j]) + except Exception as e: + logger.warning(f"Error in rename batch: {e}") + raise + + if progress_callback: + progress_callback(min(i + pipeline_size, total), total) + + if collisions: + raise RuntimeError( + f"Prefix rename aborted: {len(collisions)} destination key(s) already exist " + f"(first 5: {collisions[:5]}). This would overwrite existing data. " + f"Remove conflicting keys or choose a different prefix." 
+ ) + + return renamed + + async def _rename_field_in_hash( + self, + client: AsyncRedisClient, + keys: List[str], + old_name: str, + new_name: str, + progress_callback: Optional[Callable[[int, int], None]] = None, + ) -> int: + """Async version: Rename a field in hash documents.""" + renamed = 0 + total = len(keys) + pipeline_size = 100 + + for i in range(0, total, pipeline_size): + batch = keys[i : i + pipeline_size] + + pipe = client.pipeline(transaction=False) + for key in batch: + pipe.hget(key, old_name) + values = await pipe.execute() + + pipe = client.pipeline(transaction=False) + batch_ops = 0 + for key, value in zip(batch, values): + if value is not None: + pipe.hset(key, new_name, value) + pipe.hdel(key, old_name) + batch_ops += 1 + + try: + await pipe.execute() + # Count by number of keys that had old field values, + # not by HSET return (HSET returns 0 for existing field updates) + renamed += batch_ops + except Exception as e: + logger.warning(f"Error in field rename batch: {e}") + raise + + if progress_callback: + progress_callback(min(i + pipeline_size, total), total) + + return renamed + + async def _rename_field_in_json( + self, + client: AsyncRedisClient, + keys: List[str], + old_path: str, + new_path: str, + progress_callback: Optional[Callable[[int, int], None]] = None, + ) -> int: + """Async version: Rename a field in JSON documents.""" + renamed = 0 + total = len(keys) + pipeline_size = 100 + + for i in range(0, total, pipeline_size): + batch = keys[i : i + pipeline_size] + + pipe = client.pipeline(transaction=False) + for key in batch: + pipe.json().get(key, old_path) + values = await pipe.execute() + + # JSONPath GET returns results as a list; unwrap single-element + # results to preserve the original document shape. + # Missing paths return None or [] depending on Redis version. 
+ pipe = client.pipeline(transaction=False) + batch_ops = 0 + for key, value in zip(batch, values): + if value is None or value == []: + continue + if isinstance(value, list) and len(value) == 1: + value = value[0] + pipe.json().set(key, new_path, value) + pipe.json().delete(key, old_path) + batch_ops += 1 + try: + await pipe.execute() + # Count by number of keys that had old field values, + # not by JSON.SET return value + renamed += batch_ops + except Exception as e: + logger.warning(f"Error in JSON field rename batch: {e}") + raise + + if progress_callback: + progress_callback(min(i + pipeline_size, total), total) + + return renamed + + async def apply( + self, + plan: MigrationPlan, + *, + redis_url: Optional[str] = None, + redis_client: Optional[AsyncRedisClient] = None, + query_check_file: Optional[str] = None, + progress_callback: Optional[Callable[[str, Optional[str]], None]] = None, + checkpoint_path: Optional[str] = None, + ) -> MigrationReport: + """Apply a migration plan asynchronously. + + Args: + plan: The migration plan to apply. + redis_url: Redis connection URL. + redis_client: Optional existing async Redis client. + query_check_file: Optional file with query checks. + progress_callback: Optional callback(step, detail) for progress updates. + checkpoint_path: Optional path for quantization checkpoint file. + When provided, enables crash-safe resume for vector re-encoding. + """ + started_at = timestamp_utc() + started = time.perf_counter() + + report = MigrationReport( + source_index=plan.source.index_name, + target_index=plan.merged_target_schema["index"]["name"], + result="failed", + started_at=started_at, + finished_at=started_at, + warnings=list(plan.warnings), + ) + + if not plan.diff_classification.supported: + report.validation.errors.extend(plan.diff_classification.blocked_reasons) + report.manual_actions.append( + "This change requires document migration, which is not yet supported." 
+ ) + report.finished_at = timestamp_utc() + return report + + # Check if we are resuming from a checkpoint (post-drop crash). + # If so, the source index may no longer exist in Redis, so we + # skip live schema validation and construct from the plan snapshot. + resuming_from_checkpoint = False + if checkpoint_path: + existing_checkpoint = QuantizationCheckpoint.load(checkpoint_path) + if existing_checkpoint is not None: + # Validate checkpoint belongs to this migration and is incomplete + if existing_checkpoint.index_name != plan.source.index_name: + logger.warning( + "Checkpoint index '%s' does not match plan index '%s', ignoring", + existing_checkpoint.index_name, + plan.source.index_name, + ) + elif existing_checkpoint.status == "completed": + logger.info( + "Checkpoint at %s is already completed, ignoring", + checkpoint_path, + ) + else: + resuming_from_checkpoint = True + logger.info( + "Checkpoint found at %s, skipping source index validation " + "(index may have been dropped before crash)", + checkpoint_path, + ) + + if not resuming_from_checkpoint: + if not await self._async_current_source_matches_snapshot( + plan.source.index_name, + plan.source.schema_snapshot, + redis_url=redis_url, + redis_client=redis_client, + ): + report.validation.errors.append( + "The current live source schema no longer matches the saved source snapshot." + ) + report.manual_actions.append( + "Re-run `rvl migrate plan` to refresh the migration plan before applying." + ) + report.finished_at = timestamp_utc() + return report + + source_index = await AsyncSearchIndex.from_existing( + plan.source.index_name, + redis_url=redis_url, + redis_client=redis_client, + ) + else: + # Source index was dropped before crash; reconstruct from snapshot + # to get a valid AsyncSearchIndex with a Redis client attached. 
+ source_index = AsyncSearchIndex.from_dict( + plan.source.schema_snapshot, + redis_url=redis_url, + redis_client=redis_client, + ) + + target_index = AsyncSearchIndex.from_dict( + plan.merged_target_schema, + redis_url=redis_url, + redis_client=redis_client, + ) + + enumerate_duration = 0.0 + drop_duration = 0.0 + quantize_duration = 0.0 + field_rename_duration = 0.0 + key_rename_duration = 0.0 + recreate_duration = 0.0 + indexing_duration = 0.0 + target_info: Dict[str, Any] = {} + docs_quantized = 0 + keys_to_process: List[str] = [] + storage_type = plan.source.keyspace.storage_type + + datatype_changes = AsyncMigrationPlanner.get_vector_datatype_changes( + plan.source.schema_snapshot, plan.merged_target_schema + ) + + # Check for rename operations + rename_ops = plan.rename_operations + has_prefix_change = bool(rename_ops.change_prefix) + has_field_renames = bool(rename_ops.rename_fields) + needs_quantization = bool(datatype_changes) and storage_type != "json" + needs_enumeration = needs_quantization or has_prefix_change or has_field_renames + has_same_width_quantization = any( + is_same_width_dtype_conversion(change["source"], change["target"]) + for change in datatype_changes.values() + ) + + if checkpoint_path and has_same_width_quantization: + report.validation.errors.append( + "Crash-safe resume is not supported for same-width datatype " + "changes (float16<->bfloat16 or int8<->uint8)." + ) + report.manual_actions.append( + "Re-run without --resume for same-width vector conversions, or " + "split the migration to avoid same-width datatype changes." 
+ ) + report.finished_at = timestamp_utc() + return report + + def _notify(step: str, detail: Optional[str] = None) -> None: + if progress_callback: + progress_callback(step, detail) + + try: + client = source_index._redis_client + if client is None: + raise ValueError("Failed to get Redis client from source index") + aof_enabled = await self._detect_aof_enabled(client) + disk_estimate = estimate_disk_space(plan, aof_enabled=aof_enabled) + if disk_estimate.has_quantization: + logger.info( + "Disk space estimate: RDB ~%d bytes, AOF ~%d bytes, total ~%d bytes", + disk_estimate.rdb_snapshot_disk_bytes, + disk_estimate.aof_growth_bytes, + disk_estimate.total_new_disk_bytes, + ) + report.disk_space_estimate = disk_estimate + + if resuming_from_checkpoint: + # On resume after a post-drop crash, the index no longer + # exists. Enumerate keys via SCAN using the plan prefix, + # and skip BGSAVE / field renames / drop (already done). + if needs_enumeration: + _notify("enumerate", "Enumerating documents via SCAN (resume)...") + enumerate_started = time.perf_counter() + prefixes = list(plan.source.keyspace.prefixes) + if has_prefix_change and rename_ops.change_prefix: + prefixes = [rename_ops.change_prefix] + seen_keys: set[str] = set() + for match_pattern in build_scan_match_patterns( + prefixes, plan.source.keyspace.key_separator + ): + cursor: int = 0 + while True: + cursor, scanned = await client.scan( # type: ignore[misc] + cursor=cursor, + match=match_pattern, + count=1000, + ) + for k in scanned: + key = k.decode() if isinstance(k, bytes) else str(k) + if key not in seen_keys: + seen_keys.add(key) + keys_to_process.append(key) + if cursor == 0: + break + keys_to_process = normalize_keys(keys_to_process) + enumerate_duration = round( + time.perf_counter() - enumerate_started, 3 + ) + _notify( + "enumerate", + f"found {len(keys_to_process):,} documents ({enumerate_duration}s)", + ) + + _notify("bgsave", "skipped (resume)") + _notify("drop", "skipped (already dropped)") + 
else: + # Normal (non-resume) path + # STEP 1: Enumerate keys BEFORE any modifications + if needs_enumeration: + _notify("enumerate", "Enumerating indexed documents...") + enumerate_started = time.perf_counter() + keys_to_process = [ + key + async for key in self._enumerate_indexed_keys( + client, + plan.source.index_name, + batch_size=1000, + key_separator=plan.source.keyspace.key_separator, + ) + ] + keys_to_process = normalize_keys(keys_to_process) + enumerate_duration = round( + time.perf_counter() - enumerate_started, 3 + ) + _notify( + "enumerate", + f"found {len(keys_to_process):,} documents ({enumerate_duration}s)", + ) + + # BGSAVE safety net: snapshot data before mutations begin + if needs_enumeration and keys_to_process: + _notify("bgsave", "Triggering BGSAVE safety snapshot...") + try: + await async_trigger_bgsave_and_wait(client) + _notify("bgsave", "done") + except Exception as e: + logger.warning("BGSAVE safety snapshot failed: %s", e) + _notify("bgsave", f"skipped ({e})") + + # STEP 2: Field renames (before dropping index) + if has_field_renames and keys_to_process: + _notify("field_rename", "Renaming fields in documents...") + field_rename_started = time.perf_counter() + for field_rename in rename_ops.rename_fields: + if storage_type == "json": + old_path = get_schema_field_path( + plan.source.schema_snapshot, field_rename.old_name + ) + new_path = get_schema_field_path( + plan.merged_target_schema, field_rename.new_name + ) + if not old_path or not new_path or old_path == new_path: + continue + await self._rename_field_in_json( + client, + keys_to_process, + old_path, + new_path, + progress_callback=lambda done, total: _notify( + "field_rename", + f"{field_rename.old_name} -> {field_rename.new_name}: {done:,}/{total:,}", + ), + ) + else: + await self._rename_field_in_hash( + client, + keys_to_process, + field_rename.old_name, + field_rename.new_name, + progress_callback=lambda done, total: _notify( + "field_rename", + f"{field_rename.old_name} -> 
{field_rename.new_name}: {done:,}/{total:,}", + ), + ) + field_rename_duration = round( + time.perf_counter() - field_rename_started, 3 + ) + _notify("field_rename", f"done ({field_rename_duration}s)") + + # STEP 3: Drop the index + _notify("drop", "Dropping index definition...") + drop_started = time.perf_counter() + await source_index.delete(drop=False) + drop_duration = round(time.perf_counter() - drop_started, 3) + _notify("drop", f"done ({drop_duration}s)") + + # STEP 4: Key renames (after drop, before recreate) + # On resume, key renames were already done before the crash. + if has_prefix_change and keys_to_process and not resuming_from_checkpoint: + _notify("key_rename", "Renaming keys...") + key_rename_started = time.perf_counter() + old_prefix = plan.source.keyspace.prefixes[0] + new_prefix = rename_ops.change_prefix + assert new_prefix is not None + renamed_count = await self._rename_keys( + client, + keys_to_process, + old_prefix, + new_prefix, + progress_callback=lambda done, total: _notify( + "key_rename", f"{done:,}/{total:,} keys" + ), + ) + key_rename_duration = round(time.perf_counter() - key_rename_started, 3) + _notify( + "key_rename", + f"done ({renamed_count:,} keys in {key_rename_duration}s)", + ) + + # STEP 5: Re-encode vectors using pre-enumerated keys + if needs_quantization and keys_to_process: + _notify("quantize", "Re-encoding vectors...") + quantize_started = time.perf_counter() + # If we renamed keys (non-resume), update keys_to_process + if ( + has_prefix_change + and rename_ops.change_prefix + and not resuming_from_checkpoint + ): + old_prefix = plan.source.keyspace.prefixes[0] + new_prefix = rename_ops.change_prefix + keys_to_process = [ + ( + new_prefix + k[len(old_prefix) :] + if k.startswith(old_prefix) + else k + ) + for k in keys_to_process + ] + keys_to_process = normalize_keys(keys_to_process) + docs_quantized = await self._async_quantize_vectors( + source_index, + datatype_changes, + keys_to_process, + 
progress_callback=lambda done, total: _notify( + "quantize", f"{done:,}/{total:,} docs" + ), + checkpoint_path=checkpoint_path, + ) + quantize_duration = round(time.perf_counter() - quantize_started, 3) + _notify( + "quantize", + f"done ({docs_quantized:,} docs in {quantize_duration}s)", + ) + report.warnings.append( + f"Re-encoded {docs_quantized} documents for vector quantization: " + f"{datatype_changes}" + ) + elif datatype_changes and storage_type == "json": + # No checkpoint for JSON: vectors are re-indexed on recreate, + # so there is nothing to resume. Creating one would leave a + # stale in-progress checkpoint that misleads future runs. + _notify("quantize", "skipped (JSON vectors are re-indexed on recreate)") + + _notify("create", "Creating index with new schema...") + recreate_started = time.perf_counter() + await target_index.create() + recreate_duration = round(time.perf_counter() - recreate_started, 3) + _notify("create", f"done ({recreate_duration}s)") + + _notify("index", "Waiting for re-indexing...") + + def _index_progress(indexed: int, total: int, pct: float) -> None: + _notify("index", f"{indexed:,}/{total:,} docs ({pct:.0f}%)") + + target_info, indexing_duration = await self._async_wait_for_index_ready( + target_index, progress_callback=_index_progress + ) + _notify("index", f"done ({indexing_duration}s)") + + _notify("validate", "Validating migration...") + validation, target_info, validation_duration = ( + await self.validator.validate( + plan, + redis_url=redis_url, + redis_client=redis_client, + query_check_file=query_check_file, + ) + ) + _notify("validate", f"done ({validation_duration}s)") + report.validation = validation + total_duration = round(time.perf_counter() - started, 3) + report.timings = MigrationTimings( + total_migration_duration_seconds=total_duration, + drop_duration_seconds=drop_duration, + quantize_duration_seconds=( + quantize_duration if quantize_duration else None + ), + field_rename_duration_seconds=( + 
field_rename_duration if field_rename_duration else None + ), + key_rename_duration_seconds=( + key_rename_duration if key_rename_duration else None + ), + recreate_duration_seconds=recreate_duration, + initial_indexing_duration_seconds=indexing_duration, + validation_duration_seconds=validation_duration, + downtime_duration_seconds=round( + drop_duration + + field_rename_duration + + key_rename_duration + + quantize_duration + + recreate_duration + + indexing_duration, + 3, + ), + ) + report.benchmark_summary = self._build_benchmark_summary( + plan, + target_info, + report.timings, + ) + report.result = "succeeded" if not validation.errors else "failed" + if validation.errors: + report.manual_actions.append( + "Review validation errors before treating the migration as complete." + ) + except Exception as exc: + total_duration = round(time.perf_counter() - started, 3) + report.timings = MigrationTimings( + total_migration_duration_seconds=total_duration, + drop_duration_seconds=drop_duration or None, + quantize_duration_seconds=quantize_duration or None, + field_rename_duration_seconds=field_rename_duration or None, + key_rename_duration_seconds=key_rename_duration or None, + recreate_duration_seconds=recreate_duration or None, + initial_indexing_duration_seconds=indexing_duration or None, + downtime_duration_seconds=( + round( + drop_duration + + field_rename_duration + + key_rename_duration + + quantize_duration + + recreate_duration + + indexing_duration, + 3, + ) + if drop_duration + or field_rename_duration + or key_rename_duration + or quantize_duration + or recreate_duration + or indexing_duration + else None + ), + ) + report.validation = MigrationValidation( + errors=[f"Migration execution failed: {exc}"] + ) + report.manual_actions.extend( + [ + "Inspect the Redis index state before retrying.", + "If the source index was dropped, recreate it from the saved migration plan.", + ] + ) + finally: + report.finished_at = timestamp_utc() + + return report + + 
async def _async_quantize_vectors( + self, + source_index: AsyncSearchIndex, + datatype_changes: Dict[str, Dict[str, Any]], + keys: List[str], + progress_callback: Optional[Callable[[int, int], None]] = None, + checkpoint_path: Optional[str] = None, + ) -> int: + """Re-encode vectors in documents for datatype changes (quantization). + + Uses pre-enumerated keys (from _enumerate_indexed_keys) to process + only the documents that were in the index, avoiding full keyspace scan. + Includes idempotent skip (already-quantized vectors), bounded undo + buffer for per-batch rollback, and optional checkpointing for resume. + + Args: + source_index: The source AsyncSearchIndex (already dropped but client available) + datatype_changes: Dict mapping field_name -> {"source", "target", "dims"} + keys: Pre-enumerated list of document keys to process + progress_callback: Optional callback(docs_done, total_docs) + checkpoint_path: Optional path for checkpoint file (enables resume) + + Returns: + Number of documents processed + """ + client = source_index._redis_client + if client is None: + raise ValueError("Failed to get Redis client from source index") + + total_keys = len(keys) + docs_processed = 0 + docs_quantized = 0 + skipped = 0 + batch_size = 500 + + # Load or create checkpoint for resume support + checkpoint: Optional[QuantizationCheckpoint] = None + if checkpoint_path: + checkpoint = QuantizationCheckpoint.load(checkpoint_path) + if checkpoint: + # Validate checkpoint matches current migration BEFORE + # checking completion status to avoid skipping quantization + # for an unrelated completed checkpoint. + if checkpoint.index_name != source_index.name: + raise ValueError( + f"Checkpoint index '{checkpoint.index_name}' does not match " + f"source index '{source_index.name}'. " + f"Use the correct checkpoint file or remove it to start fresh." 
+ ) + # Skip if checkpoint shows a completed migration + if checkpoint.status == "completed": + logger.info( + "Checkpoint already marked as completed for index '%s'. " + "Skipping quantization. Remove the checkpoint file to force re-run.", + checkpoint.index_name, + ) + return 0 + if checkpoint.total_keys != total_keys: + if checkpoint.processed_keys: + current_keys = set(keys) + missing_processed = [ + key + for key in checkpoint.processed_keys + if key not in current_keys + ] + if missing_processed or total_keys < checkpoint.total_keys: + raise ValueError( + f"Checkpoint total_keys={checkpoint.total_keys} does not match " + f"the current key set ({total_keys}). " + "Use the correct checkpoint file or remove it to start fresh." + ) + logger.warning( + "Checkpoint total_keys=%d differs from current key set size=%d. " + "Proceeding because all legacy processed keys are present.", + checkpoint.total_keys, + total_keys, + ) + else: + raise ValueError( + f"Checkpoint total_keys={checkpoint.total_keys} does not match " + f"the current key set ({total_keys}). " + "Use the correct checkpoint file or remove it to start fresh." 
+ ) + remaining = checkpoint.get_remaining_keys(keys) + logger.info( + "Resuming from checkpoint: %d/%d keys already processed", + total_keys - len(remaining), + total_keys, + ) + docs_processed = total_keys - len(remaining) + keys = remaining + total_keys_for_progress = total_keys + else: + checkpoint = QuantizationCheckpoint( + index_name=source_index.name, + total_keys=total_keys, + checkpoint_path=checkpoint_path, + ) + checkpoint.save() + total_keys_for_progress = total_keys + else: + total_keys_for_progress = total_keys + + remaining_keys = len(keys) + + for i in range(0, remaining_keys, batch_size): + batch = keys[i : i + batch_size] + pipe = client.pipeline() + undo = BatchUndoBuffer() + keys_updated_in_batch: set[str] = set() + + try: + for key in batch: + for field_name, change in datatype_changes.items(): + field_data: bytes | None = await client.hget(key, field_name) # type: ignore[misc,assignment] + if not field_data: + continue + + # Idempotent: skip if already converted to target dtype + dims = change.get("dims", 0) + if dims and is_already_quantized( + field_data, dims, change["source"], change["target"] + ): + skipped += 1 + continue + + undo.store(key, field_name, field_data) + array = buffer_to_array(field_data, change["source"]) + new_bytes = array_to_buffer(array, change["target"]) + pipe.hset(key, field_name, new_bytes) # type: ignore[arg-type] + keys_updated_in_batch.add(key) + + if keys_updated_in_batch: + await pipe.execute() + except Exception: + logger.warning( + "Batch %d failed, rolling back %d entries", + i // batch_size, + undo.size, + ) + rollback_pipe = client.pipeline() + await undo.async_rollback(rollback_pipe) + if checkpoint: + checkpoint.save() + raise + finally: + undo.clear() + + docs_quantized += len(keys_updated_in_batch) + docs_processed += len(batch) + + if checkpoint: + # Record all keys in batch (including skipped) so they + # are not re-scanned on resume + checkpoint.record_batch(batch) + checkpoint.save() + + if 
progress_callback: + progress_callback(docs_processed, total_keys_for_progress) + + if checkpoint: + checkpoint.mark_complete() + checkpoint.save() + + if skipped: + logger.info("Skipped %d already-quantized vector fields", skipped) + logger.info( + "Quantized %d documents across %d fields", + docs_quantized, + len(datatype_changes), + ) + return docs_quantized + + async def _async_wait_for_index_ready( + self, + index: AsyncSearchIndex, + *, + timeout_seconds: int = 1800, + poll_interval_seconds: float = 0.5, + progress_callback: Optional[Callable[[int, int, float], None]] = None, + ) -> tuple[Dict[str, Any], float]: + """Wait for index to finish indexing all documents (async version).""" + start = time.perf_counter() + deadline = start + timeout_seconds + latest_info = await index.info() + + stable_ready_checks: Optional[int] = None + while time.perf_counter() < deadline: + latest_info = await index.info() + indexing = latest_info.get("indexing") + percent_indexed = latest_info.get("percent_indexed") + + if percent_indexed is not None or indexing is not None: + ready = float(percent_indexed or 0) >= 1.0 and not bool(indexing) + if progress_callback: + total_docs = int(latest_info.get("num_docs", 0)) + pct = float(percent_indexed or 0) + indexed_docs = int(total_docs * pct) + progress_callback(indexed_docs, total_docs, pct * 100) + else: + current_docs = latest_info.get("num_docs") + if current_docs is None: + ready = True + else: + if stable_ready_checks is None: + stable_ready_checks = int(current_docs) + await asyncio.sleep(poll_interval_seconds) + continue + current = int(current_docs) + if current == stable_ready_checks: + ready = True + else: + # num_docs changed; update baseline and keep waiting + stable_ready_checks = current + + if ready: + return latest_info, round(time.perf_counter() - start, 3) + + await asyncio.sleep(poll_interval_seconds) + + raise TimeoutError( + f"Index {index.schema.index.name} did not become ready within {timeout_seconds} 
seconds" + ) + + async def _async_current_source_matches_snapshot( + self, + index_name: str, + expected_schema: Dict[str, Any], + *, + redis_url: Optional[str] = None, + redis_client: Optional[AsyncRedisClient] = None, + ) -> bool: + """Check if current source schema matches the snapshot (async version).""" + from redisvl.migration.utils import schemas_equal + + current_index = await AsyncSearchIndex.from_existing( + index_name, + redis_url=redis_url, + redis_client=redis_client, + ) + return schemas_equal(current_index.schema.to_dict(), expected_schema) + + def _build_benchmark_summary( + self, + plan: MigrationPlan, + target_info: dict, + timings: MigrationTimings, + ) -> MigrationBenchmarkSummary: + source_index_size = float( + plan.source.stats_snapshot.get("vector_index_sz_mb", 0) or 0 + ) + target_index_size = float(target_info.get("vector_index_sz_mb", 0) or 0) + source_num_docs = int(plan.source.stats_snapshot.get("num_docs", 0) or 0) + indexed_per_second = None + indexing_time = timings.initial_indexing_duration_seconds + if indexing_time and indexing_time > 0: + indexed_per_second = round(source_num_docs / indexing_time, 3) + + return MigrationBenchmarkSummary( + documents_indexed_per_second=indexed_per_second, + source_index_size_mb=round(source_index_size, 3), + target_index_size_mb=round(target_index_size, 3), + index_size_delta_mb=round(target_index_size - source_index_size, 3), + ) diff --git a/redisvl/migration/async_planner.py b/redisvl/migration/async_planner.py new file mode 100644 index 000000000..703714496 --- /dev/null +++ b/redisvl/migration/async_planner.py @@ -0,0 +1,271 @@ +from __future__ import annotations + +from typing import List, Optional + +from redisvl.index import AsyncSearchIndex +from redisvl.migration.models import ( + KeyspaceSnapshot, + MigrationPlan, + SchemaPatch, + SourceSnapshot, +) +from redisvl.migration.planner import MigrationPlanner +from redisvl.redis.connection import supports_svs_async +from redisvl.schema.schema 
import IndexSchema +from redisvl.types import AsyncRedisClient + + +class AsyncMigrationPlanner: + """Async migration planner for document-preserving drop/recreate flows. + + This is the async version of MigrationPlanner. It uses AsyncSearchIndex + and async Redis operations for better performance on large indexes. + + The classification logic, schema merging, and diff analysis are delegated + to a sync MigrationPlanner instance (they are CPU-bound and don't need async). + """ + + def __init__(self, key_sample_limit: int = 10): + self.key_sample_limit = key_sample_limit + # Delegate to sync planner for CPU-bound operations + self._sync_planner = MigrationPlanner(key_sample_limit=key_sample_limit) + + # Expose static methods from MigrationPlanner for convenience + get_vector_datatype_changes = staticmethod( + MigrationPlanner.get_vector_datatype_changes + ) + + async def create_plan( + self, + index_name: str, + *, + redis_url: Optional[str] = None, + schema_patch_path: Optional[str] = None, + target_schema_path: Optional[str] = None, + redis_client: Optional[AsyncRedisClient] = None, + ) -> MigrationPlan: + if not schema_patch_path and not target_schema_path: + raise ValueError( + "Must provide either --schema-patch or --target-schema for migration planning" + ) + if schema_patch_path and target_schema_path: + raise ValueError( + "Provide only one of --schema-patch or --target-schema for migration planning" + ) + + snapshot = await self.snapshot_source( + index_name, + redis_url=redis_url, + redis_client=redis_client, + ) + source_schema = IndexSchema.from_dict(snapshot.schema_snapshot) + + if schema_patch_path: + schema_patch = self._sync_planner.load_schema_patch(schema_patch_path) + else: + # target_schema_path is guaranteed to be not None here + assert target_schema_path is not None + schema_patch = self._sync_planner.normalize_target_schema_to_patch( + source_schema, target_schema_path + ) + + return await self.create_plan_from_patch( + index_name, + 
schema_patch=schema_patch, + redis_url=redis_url, + redis_client=redis_client, + ) + + async def create_plan_from_patch( + self, + index_name: str, + *, + schema_patch: SchemaPatch, + redis_url: Optional[str] = None, + redis_client: Optional[AsyncRedisClient] = None, + ) -> MigrationPlan: + snapshot = await self.snapshot_source( + index_name, + redis_url=redis_url, + redis_client=redis_client, + ) + source_schema = IndexSchema.from_dict(snapshot.schema_snapshot) + merged_target_schema = self._sync_planner.merge_patch( + source_schema, schema_patch + ) + + # Extract rename operations first + rename_operations, rename_warnings = ( + self._sync_planner._extract_rename_operations(source_schema, schema_patch) + ) + + # Classify diff with awareness of rename operations + diff_classification = self._sync_planner.classify_diff( + source_schema, schema_patch, merged_target_schema, rename_operations + ) + + # Build warnings list + warnings = ["Index downtime is required"] + warnings.extend(rename_warnings) + + # Check for SVS-VAMANA in target schema and add appropriate warnings + svs_warnings = await self._check_svs_vamana_requirements( + merged_target_schema, + redis_url=redis_url, + redis_client=redis_client, + ) + warnings.extend(svs_warnings) + + return MigrationPlan( + source=snapshot, + requested_changes=schema_patch.model_dump(exclude_none=True), + merged_target_schema=merged_target_schema.to_dict(), + diff_classification=diff_classification, + rename_operations=rename_operations, + warnings=warnings, + ) + + async def _check_svs_vamana_requirements( + self, + target_schema: IndexSchema, + *, + redis_url: Optional[str] = None, + redis_client: Optional[AsyncRedisClient] = None, + ) -> List[str]: + """Async version: Check SVS-VAMANA requirements and return warnings.""" + warnings: List[str] = [] + target_dict = target_schema.to_dict() + + # Check if any vector field uses SVS-VAMANA + uses_svs = False + uses_compression = False + compression_type = None + + for field in 
target_dict.get("fields", []): + if field.get("type") != "vector": + continue + attrs = field.get("attrs", {}) + algo = attrs.get("algorithm", "").upper() + if algo == "SVS-VAMANA": + uses_svs = True + compression = attrs.get("compression", "") + if compression: + uses_compression = True + compression_type = compression + + if not uses_svs: + return warnings + + # Check Redis version support + try: + if redis_client: + client = redis_client + elif redis_url: + from redis.asyncio import Redis + + client = Redis.from_url(redis_url) + else: + client = None + + if client and not await supports_svs_async(client): + warnings.append( + "SVS-VAMANA requires Redis >= 8.2.0 and Redis Search >= 2.8.10. " + "The target Redis instance may not support this algorithm. " + "Migration will fail at apply time if requirements are not met." + ) + except Exception: + warnings.append( + "SVS-VAMANA requires Redis >= 8.2.0 and Redis Search >= 2.8.10. " + "Verify your Redis instance supports this algorithm before applying." + ) + + # Intel hardware warning for compression + if uses_compression: + warnings.append( + f"SVS-VAMANA with {compression_type} compression: " + "LVQ and LeanVec optimizations require Intel hardware with AVX-512 support. " + "On non-Intel platforms or Redis Open Source, these fall back to basic " + "8-bit scalar quantization with reduced performance benefits." + ) + else: + warnings.append( + "SVS-VAMANA: For optimal performance, Intel hardware with AVX-512 support " + "is recommended. LVQ/LeanVec compression options provide additional memory " + "savings on supported hardware." 
+ ) + + return warnings + + async def snapshot_source( + self, + index_name: str, + *, + redis_url: Optional[str] = None, + redis_client: Optional[AsyncRedisClient] = None, + ) -> SourceSnapshot: + index = await AsyncSearchIndex.from_existing( + index_name, + redis_url=redis_url, + redis_client=redis_client, + ) + schema_dict = index.schema.to_dict() + stats_snapshot = await index.info() + prefixes = index.schema.index.prefix + prefix_list = prefixes if isinstance(prefixes, list) else [prefixes] + + client = index.client + if client is None: + raise ValueError("Failed to get Redis client from index") + + return SourceSnapshot( + index_name=index_name, + schema_snapshot=schema_dict, + stats_snapshot=stats_snapshot, + keyspace=KeyspaceSnapshot( + storage_type=index.schema.index.storage_type.value, + prefixes=prefix_list, + key_separator=index.schema.index.key_separator, + key_sample=await self._async_sample_keys( + client=client, + prefixes=prefix_list, + key_separator=index.schema.index.key_separator, + ), + ), + ) + + async def _async_sample_keys( + self, *, client: AsyncRedisClient, prefixes: List[str], key_separator: str + ) -> List[str]: + """Async version of _sample_keys.""" + key_sample: List[str] = [] + if self.key_sample_limit <= 0: + return key_sample + + for prefix in prefixes: + if len(key_sample) >= self.key_sample_limit: + break + match_pattern = ( + f"{prefix}*" + if prefix.endswith(key_separator) + else f"{prefix}{key_separator}*" + ) + cursor: int = 0 + while True: + cursor, keys = await client.scan( + cursor=cursor, + match=match_pattern, + count=max(self.key_sample_limit, 10), + ) + for key in keys: + decoded_key = key.decode() if isinstance(key, bytes) else str(key) + if decoded_key not in key_sample: + key_sample.append(decoded_key) + if len(key_sample) >= self.key_sample_limit: + return key_sample + if cursor == 0: + break + return key_sample + + def write_plan(self, plan: MigrationPlan, plan_out: str) -> None: + """Delegate to sync planner for 
file I/O.""" + self._sync_planner.write_plan(plan, plan_out) diff --git a/redisvl/migration/async_utils.py b/redisvl/migration/async_utils.py new file mode 100644 index 000000000..8e7af632a --- /dev/null +++ b/redisvl/migration/async_utils.py @@ -0,0 +1,101 @@ +from __future__ import annotations + +import asyncio +import time +from typing import Any, Callable, Dict, List, Optional, Tuple + +from redisvl.index import AsyncSearchIndex +from redisvl.migration.utils import schemas_equal +from redisvl.redis.connection import RedisConnectionFactory +from redisvl.types import AsyncRedisClient + + +async def async_list_indexes( + *, redis_url: Optional[str] = None, redis_client: Optional[AsyncRedisClient] = None +) -> List[str]: + """List all search indexes in Redis (async version).""" + if redis_client is None: + if not redis_url: + raise ValueError("Must provide either redis_url or redis_client") + redis_client = await RedisConnectionFactory._get_aredis_connection( + redis_url=redis_url + ) + index = AsyncSearchIndex.from_dict( + {"index": {"name": "__redisvl_migration_helper__"}, "fields": []}, + redis_client=redis_client, + ) + return await index.listall() + + +async def async_wait_for_index_ready( + index: AsyncSearchIndex, + *, + timeout_seconds: int = 1800, + poll_interval_seconds: float = 0.5, + progress_callback: Optional[Callable[[int, int, float], None]] = None, +) -> Tuple[Dict[str, Any], float]: + """Wait for index to finish indexing all documents (async version). + + Args: + index: The AsyncSearchIndex to monitor. + timeout_seconds: Maximum time to wait. + poll_interval_seconds: How often to check status. + progress_callback: Optional callback(indexed_docs, total_docs, percent). 
+ """ + start = time.perf_counter() + deadline = start + timeout_seconds + latest_info = await index.info() + + stable_ready_checks: Optional[int] = None + while time.perf_counter() < deadline: + latest_info = await index.info() + indexing = latest_info.get("indexing") + percent_indexed = latest_info.get("percent_indexed") + + if percent_indexed is not None or indexing is not None: + ready = float(percent_indexed or 0) >= 1.0 and not bool(indexing) + if progress_callback: + total_docs = int(latest_info.get("num_docs", 0)) + pct = float(percent_indexed or 0) + indexed_docs = int(total_docs * pct) + progress_callback(indexed_docs, total_docs, pct * 100) + else: + current_docs = latest_info.get("num_docs") + if current_docs is None: + ready = True + else: + if stable_ready_checks is None: + stable_ready_checks = int(current_docs) + await asyncio.sleep(poll_interval_seconds) + continue + current = int(current_docs) + if current == stable_ready_checks: + ready = True + else: + # num_docs changed; update baseline and keep waiting + stable_ready_checks = current + + if ready: + return latest_info, round(time.perf_counter() - start, 3) + + await asyncio.sleep(poll_interval_seconds) + + raise TimeoutError( + f"Index {index.schema.index.name} did not become ready within {timeout_seconds} seconds" + ) + + +async def async_current_source_matches_snapshot( + index_name: str, + expected_schema: Dict[str, Any], + *, + redis_url: Optional[str] = None, + redis_client: Optional[AsyncRedisClient] = None, +) -> bool: + """Check if current source schema matches the snapshot (async version).""" + current_index = await AsyncSearchIndex.from_existing( + index_name, + redis_url=redis_url, + redis_client=redis_client, + ) + return schemas_equal(current_index.schema.to_dict(), expected_schema) diff --git a/redisvl/migration/async_validation.py b/redisvl/migration/async_validation.py new file mode 100644 index 000000000..7b9691c4b --- /dev/null +++ b/redisvl/migration/async_validation.py @@ 
-0,0 +1,186 @@
from __future__ import annotations

import time
from typing import Any, Dict, List, Optional

from redis.commands.search.query import Query

from redisvl.index import AsyncSearchIndex
from redisvl.migration.models import (
    MigrationPlan,
    MigrationValidation,
    QueryCheckResult,
)
from redisvl.migration.utils import load_yaml, schemas_equal
from redisvl.types import AsyncRedisClient


class AsyncMigrationValidator:
    """Async migration validator for post-migration checks.

    This is the async version of MigrationValidator. It uses AsyncSearchIndex
    and async Redis operations for better performance.
    """

    async def validate(
        self,
        plan: MigrationPlan,
        *,
        redis_url: Optional[str] = None,
        redis_client: Optional[AsyncRedisClient] = None,
        query_check_file: Optional[str] = None,
    ) -> tuple[MigrationValidation, Dict[str, Any], float]:
        """Validate a completed migration against its plan.

        Compares the live target index's schema, document count, and indexing
        failures against the plan's source snapshot, verifies a sample of
        source keys still exists, then runs automatic functional checks plus
        optional user-supplied query checks from ``query_check_file``.

        Returns:
            Tuple of (validation result, raw target FT.INFO payload,
            elapsed seconds rounded to 3 places).
        """
        started = time.perf_counter()
        target_index = await AsyncSearchIndex.from_existing(
            plan.merged_target_schema["index"]["name"],
            redis_url=redis_url,
            redis_client=redis_client,
        )
        target_info = await target_index.info()
        validation = MigrationValidation()

        live_schema = target_index.schema.to_dict()
        # Exclude query-time and creation-hint attributes (ef_runtime, epsilon,
        # initial_cap, phonetic_matcher) that are not part of index structure
        # validation. Confirmed by RediSearch team as not relevant for this check.
        validation.schema_match = schemas_equal(
            live_schema, plan.merged_target_schema, strip_excluded=True
        )

        # Doc counts: the target must index exactly as many docs as the
        # source snapshot reported.
        source_num_docs = int(plan.source.stats_snapshot.get("num_docs", 0) or 0)
        target_num_docs = int(target_info.get("num_docs", 0) or 0)
        validation.doc_count_match = source_num_docs == target_num_docs

        # Indexing failures are compared as a delta so pre-existing source
        # failures do not fail validation.
        source_failures = int(
            plan.source.stats_snapshot.get("hash_indexing_failures", 0) or 0
        )
        target_failures = int(target_info.get("hash_indexing_failures", 0) or 0)
        validation.indexing_failures_delta = target_failures - source_failures

        key_sample = plan.source.keyspace.key_sample
        client = target_index.client
        if not key_sample:
            # Nothing to check; vacuously true.
            validation.key_sample_exists = True
        elif client is None:
            validation.key_sample_exists = False
            validation.errors.append("Failed to get Redis client for key sample check")
        else:
            # Handle prefix change: transform key_sample to use new prefix
            keys_to_check = key_sample
            if plan.rename_operations.change_prefix:
                old_prefix = plan.source.keyspace.prefixes[0]
                new_prefix = plan.rename_operations.change_prefix
                keys_to_check = [
                    new_prefix + k[len(old_prefix) :] if k.startswith(old_prefix) else k
                    for k in key_sample
                ]
            existing_count = await client.exists(*keys_to_check)
            validation.key_sample_exists = existing_count == len(keys_to_check)

        # Run automatic functional checks (always)
        functional_checks = await self._run_functional_checks(
            target_index, source_num_docs
        )
        validation.query_checks.extend(functional_checks)

        # Run user-provided query checks (if file provided)
        if query_check_file:
            user_checks = await self._run_query_checks(target_index, query_check_file)
            validation.query_checks.extend(user_checks)

        # Aggregate human-readable errors from the individual check outcomes.
        if not validation.schema_match:
            validation.errors.append("Live schema does not match merged_target_schema.")
        if not validation.doc_count_match:
            validation.errors.append(
                "Live document count does not match source num_docs."
            )
        if validation.indexing_failures_delta != 0:
            validation.errors.append("Indexing failures increased during migration.")
        if not validation.key_sample_exists:
            validation.errors.append(
                "One or more sampled source keys is missing after migration."
            )
        if any(not query_check.passed for query_check in validation.query_checks):
            validation.errors.append("One or more query checks failed.")

        return validation, target_info, round(time.perf_counter() - started, 3)

    async def _run_query_checks(
        self,
        target_index: AsyncSearchIndex,
        query_check_file: str,
    ) -> list[QueryCheckResult]:
        """Run user-defined checks from a YAML file.

        Supports two sections: ``fetch_ids`` (documents fetched via the index)
        and ``keys_exist`` (raw Redis key existence checks).
        """
        query_checks = load_yaml(query_check_file)
        results: list[QueryCheckResult] = []

        for doc_id in query_checks.get("fetch_ids", []):
            fetched = await target_index.fetch(doc_id)
            results.append(
                QueryCheckResult(
                    name=f"fetch:{doc_id}",
                    passed=fetched is not None,
                    details=(
                        "Document fetched successfully"
                        if fetched is not None
                        else "Document not found"
                    ),
                )
            )

        client = target_index.client
        for key in query_checks.get("keys_exist", []):
            if client is None:
                results.append(
                    QueryCheckResult(
                        name=f"key:{key}",
                        passed=False,
                        details="Failed to get Redis client",
                    )
                )
            else:
                exists = bool(await client.exists(key))
                results.append(
                    QueryCheckResult(
                        name=f"key:{key}",
                        passed=exists,
                        details="Key exists" if exists else "Key not found",
                    )
                )

        return results

    async def _run_functional_checks(
        self, target_index: AsyncSearchIndex, expected_doc_count: int
    ) -> List[QueryCheckResult]:
        """Run automatic functional checks to verify the index is operational.

        These checks run automatically after every migration to prove the index
        actually works, not just that the schema looks correct.
        """
        results: List[QueryCheckResult] = []

        # Check 1: Wildcard search - proves the index responds and returns docs
        try:
            search_result = await target_index.search(Query("*").paging(0, 1))
            total_found = search_result.total
            passed = total_found == expected_doc_count
            results.append(
                QueryCheckResult(
                    name="functional:wildcard_search",
                    passed=passed,
                    details=(
                        f"Wildcard search returned {total_found} docs "
                        f"(expected {expected_doc_count})"
                    ),
                )
            )
        except Exception as e:
            # A failing search is reported as a failed check rather than
            # aborting validation.
            results.append(
                QueryCheckResult(
                    name="functional:wildcard_search",
                    passed=False,
                    details=f"Wildcard search failed: {str(e)}",
                )
            )

        return results
diff --git a/tests/integration/test_async_migration_v1.py b/tests/integration/test_async_migration_v1.py
new file mode 100644
index 000000000..c50fdaf84
--- /dev/null
+++ b/tests/integration/test_async_migration_v1.py
@@ -0,0 +1,150 @@
"""Integration tests for async migration (Phase 1.5).

These tests verify the async migration components work correctly with a real
Redis instance, mirroring the sync tests in test_migration_v1.py.
"""

import uuid

import pytest
import yaml

from redisvl.index import AsyncSearchIndex
from redisvl.migration import (
    AsyncMigrationExecutor,
    AsyncMigrationPlanner,
    AsyncMigrationValidator,
)
from redisvl.migration.utils import load_migration_plan, schemas_equal
from redisvl.redis.utils import array_to_buffer


@pytest.mark.asyncio
async def test_async_drop_recreate_plan_apply_validate_flow(
    redis_url, worker_id, tmp_path
):
    """Test full async migration flow: plan -> apply -> validate.

    Builds a small hash index, applies a patch that adds/removes/updates
    fields, then asserts the executor report and a standalone validator pass.
    """
    # Unique names per test worker so parallel runs do not collide.
    unique_id = str(uuid.uuid4())[:8]
    index_name = f"async_migration_v1_{worker_id}_{unique_id}"
    prefix = f"async_migration_v1:{worker_id}:{unique_id}"

    source_index = AsyncSearchIndex.from_dict(
        {
            "index": {
                "name": index_name,
                "prefix": prefix,
                "storage_type": "hash",
            },
            "fields": [
                {"name": "doc_id", "type": "tag"},
                {"name": "title", "type": "text"},
                {"name": "price", "type": "numeric"},
                {
                    "name": "embedding",
                    "type": "vector",
                    "attrs": {
                        "algorithm": "hnsw",
                        "dims": 3,
                        "distance_metric": "cosine",
                        "datatype": "float32",
                    },
                },
            ],
        },
        redis_url=redis_url,
    )

    # "category" is stored on the docs but not in the source schema; the patch
    # below promotes it to an indexed tag field.
    docs = [
        {
            "doc_id": "1",
            "title": "alpha",
            "price": 1,
            "category": "news",
            "embedding": array_to_buffer([0.1, 0.2, 0.3], "float32"),
        },
        {
            "doc_id": "2",
            "title": "beta",
            "price": 2,
            "category": "sports",
            "embedding": array_to_buffer([0.2, 0.1, 0.4], "float32"),
        },
    ]

    await source_index.create(overwrite=True)
    await source_index.load(docs, id_field="doc_id")

    # Create schema patch
    patch_path = tmp_path / "schema_patch.yaml"
    patch_path.write_text(
        yaml.safe_dump(
            {
                "version": 1,
                "changes": {
                    "add_fields": [
                        {
                            "name": "category",
                            "type": "tag",
                            "attrs": {"separator": ","},
                        }
                    ],
                    "remove_fields": ["price"],
                    "update_fields": [{"name": "title", "attrs": {"sortable": True}}],
                },
            },
            sort_keys=False,
        )
    )

    # Create plan using async planner
    plan_path = tmp_path / "migration_plan.yaml"
    planner = AsyncMigrationPlanner()
    plan = await planner.create_plan(
        index_name,
        redis_url=redis_url,
        schema_patch_path=str(patch_path),
    )
    assert plan.diff_classification.supported is True
    planner.write_plan(plan, str(plan_path))

    # Create query checks
    query_check_path = tmp_path / "query_checks.yaml"
    query_check_path.write_text(
        yaml.safe_dump({"fetch_ids": ["1", "2"]}, sort_keys=False)
    )

    # Apply migration using async executor
    executor = AsyncMigrationExecutor()
    report = await executor.apply(
        load_migration_plan(str(plan_path)),
        redis_url=redis_url,
        query_check_file=str(query_check_path),
    )

    # Verify migration succeeded
    assert report.result == "succeeded"
    assert report.validation.schema_match is True
    assert report.validation.doc_count_match is True
    assert report.validation.key_sample_exists is True
    assert report.validation.indexing_failures_delta == 0
    assert not report.validation.errors
    assert report.benchmark_summary.documents_indexed_per_second is not None

    # Verify schema matches target
    live_index = await AsyncSearchIndex.from_existing(index_name, redis_url=redis_url)
    assert schemas_equal(live_index.schema.to_dict(), plan.merged_target_schema)

    # Test standalone async validator
    validator = AsyncMigrationValidator()
    validation, _target_info, _duration = await validator.validate(
        load_migration_plan(str(plan_path)),
        redis_url=redis_url,
        query_check_file=str(query_check_path),
    )
    assert validation.schema_match is True
    assert validation.doc_count_match is True
    assert validation.key_sample_exists is True
    assert not validation.errors

    # Cleanup
    await live_index.delete(drop=True)
diff --git a/tests/unit/test_async_migration_executor.py b/tests/unit/test_async_migration_executor.py
new file mode 100644
index 000000000..ac65c3842
--- /dev/null
+++ b/tests/unit/test_async_migration_executor.py
@@ -0,0 +1,1092 @@
"""Unit tests for migration executors and
disk space estimator.

These tests mirror the sync MigrationExecutor patterns but use async/await.
Also includes pure-calculation tests for estimate_disk_space().
"""

from pathlib import Path
from unittest.mock import AsyncMock, MagicMock

import pytest

from redisvl.migration import AsyncMigrationExecutor, MigrationExecutor
from redisvl.migration.models import (
    DiffClassification,
    KeyspaceSnapshot,
    MigrationPlan,
    SourceSnapshot,
    ValidationPolicy,
    _format_bytes,
)
from redisvl.migration.utils import (
    build_scan_match_patterns,
    estimate_disk_space,
    normalize_keys,
)


def _make_basic_plan():
    """Create a basic migration plan for testing.

    Source uses a FLAT vector index; the merged target switches it to HNSW,
    which is the supported drop/recreate change exercised by these tests.
    """
    return MigrationPlan(
        mode="drop_recreate",
        source=SourceSnapshot(
            index_name="test_index",
            keyspace=KeyspaceSnapshot(
                storage_type="hash",
                prefixes=["test"],
                key_separator=":",
                key_sample=["test:1", "test:2"],
            ),
            schema_snapshot={
                "index": {
                    "name": "test_index",
                    "prefix": "test",
                    "storage_type": "hash",
                },
                "fields": [
                    {"name": "title", "type": "text"},
                    {
                        "name": "embedding",
                        "type": "vector",
                        "attrs": {
                            "algorithm": "flat",
                            "dims": 3,
                            "distance_metric": "cosine",
                            "datatype": "float32",
                        },
                    },
                ],
            },
            stats_snapshot={"num_docs": 2},
        ),
        requested_changes={},
        merged_target_schema={
            "index": {
                "name": "test_index",
                "prefix": "test",
                "storage_type": "hash",
            },
            "fields": [
                {"name": "title", "type": "text"},
                {
                    "name": "embedding",
                    "type": "vector",
                    "attrs": {
                        "algorithm": "hnsw",  # Changed from flat
                        "dims": 3,
                        "distance_metric": "cosine",
                        "datatype": "float32",
                    },
                },
            ],
        },
        diff_classification=DiffClassification(
            supported=True,
            blocked_reasons=[],
        ),
        validation=ValidationPolicy(
            require_doc_count_match=True,
        ),
        warnings=["Index downtime is required"],
    )


def test_async_executor_instantiation():
    """Test AsyncMigrationExecutor can be instantiated."""
    executor = AsyncMigrationExecutor()
    assert executor is not None
    assert executor.validator is not None


def test_async_executor_with_validator():
    """Test AsyncMigrationExecutor with custom validator."""
    from redisvl.migration import AsyncMigrationValidator

    custom_validator = AsyncMigrationValidator()
    executor = AsyncMigrationExecutor(validator=custom_validator)
    assert executor.validator is custom_validator


@pytest.mark.asyncio
async def test_async_executor_handles_unsupported_plan():
    """Test executor returns error report for unsupported plan."""
    plan = _make_basic_plan()
    plan.diff_classification.supported = False
    plan.diff_classification.blocked_reasons = ["Test blocked reason"]

    executor = AsyncMigrationExecutor()

    # The executor doesn't raise an error - it returns a report with errors
    report = await executor.apply(plan, redis_url="redis://localhost:6379")
    assert report.result == "failed"
    assert "Test blocked reason" in report.validation.errors


@pytest.mark.asyncio
async def test_async_executor_validates_redis_url():
    """Test executor requires redis_url or redis_client."""
    plan = _make_basic_plan()
    executor = AsyncMigrationExecutor()

    # The executor should raise an error internally when trying to connect
    # but let's verify it doesn't crash before it tries to apply
    # For a proper test, we'd need to mock AsyncSearchIndex.from_existing
    # For now, we just verify the executor is created
    assert executor is not None


# =============================================================================
# Disk Space Estimator Tests
# =============================================================================


def _make_quantize_plan(
    source_dtype="float32",
    target_dtype="float16",
    dims=3072,
    doc_count=100_000,
    storage_type="hash",
):
    """Helper to create a migration plan with a vector datatype change."""
    return MigrationPlan(
        mode="drop_recreate",
        source=SourceSnapshot(
            index_name="test_index",
            keyspace=KeyspaceSnapshot(
                storage_type=storage_type,
                prefixes=["test"],
                key_separator=":",
            ),
            schema_snapshot={
                "index": {
                    "name": "test_index",
                    "prefix": "test",
                    "storage_type": storage_type,
                },
                "fields": [
                    {"name": "title", "type": "text"},
                    {
                        "name": "embedding",
                        "type": "vector",
                        "attrs": {
                            "algorithm": "hnsw",
                            "dims": dims,
                            "distance_metric": "cosine",
                            "datatype": source_dtype,
                        },
                    },
                ],
            },
            stats_snapshot={"num_docs": doc_count},
        ),
        requested_changes={},
        merged_target_schema={
            "index": {
                "name": "test_index",
                "prefix": "test",
                "storage_type": storage_type,
            },
            "fields": [
                {"name": "title", "type": "text"},
                {
                    "name": "embedding",
                    "type": "vector",
                    "attrs": {
                        "algorithm": "hnsw",
                        "dims": dims,
                        "distance_metric": "cosine",
                        "datatype": target_dtype,
                    },
                },
            ],
        },
        diff_classification=DiffClassification(supported=True, blocked_reasons=[]),
        validation=ValidationPolicy(require_doc_count_match=True),
    )


def test_estimate_fp32_to_fp16():
    """FP32->FP16 with 3072 dims, 100K docs should produce expected byte counts."""
    plan = _make_quantize_plan("float32", "float16", dims=3072, doc_count=100_000)
    est = estimate_disk_space(plan)

    assert est.has_quantization is True
    assert len(est.vector_fields) == 1
    vf = est.vector_fields[0]
    assert vf.source_bytes_per_doc == 3072 * 4  # 12288
    assert vf.target_bytes_per_doc == 3072 * 2  # 6144

    assert est.total_source_vector_bytes == 100_000 * 12288
    assert est.total_target_vector_bytes == 100_000 * 6144
    assert est.memory_savings_after_bytes == 100_000 * (12288 - 6144)

    # RDB = source * 0.95
    assert est.rdb_snapshot_disk_bytes == int(100_000 * 12288 * 0.95)
    # COW = full source
    assert est.rdb_cow_memory_if_concurrent_bytes == 100_000 * 12288
    # AOF disabled by default
    assert est.aof_enabled is False
    assert est.aof_growth_bytes == 0
    assert est.total_new_disk_bytes == est.rdb_snapshot_disk_bytes


def test_estimate_with_aof_enabled():
    """AOF growth should include RESP overhead per HSET."""
    plan = _make_quantize_plan("float32", "float16", dims=3072, doc_count=100_000)
    est = estimate_disk_space(plan, aof_enabled=True)

    assert est.aof_enabled is True
    target_vec_size = 3072 * 2
    expected_aof = 100_000 * (target_vec_size + 114)  # 114 = HSET overhead
    assert est.aof_growth_bytes == expected_aof
    assert est.total_new_disk_bytes == est.rdb_snapshot_disk_bytes + expected_aof


def test_estimate_json_storage_aof():
    """JSON storage quantization should not report in-place rewrite costs."""
    plan = _make_quantize_plan(
        "float32", "float16", dims=128, doc_count=1000, storage_type="json"
    )
    est = estimate_disk_space(plan, aof_enabled=True)

    assert est.has_quantization is False
    assert est.aof_growth_bytes == 0
    assert est.total_new_disk_bytes == 0


def test_estimate_no_quantization():
    """Same dtype source and target should produce empty estimate."""
    plan = _make_quantize_plan("float32", "float32", dims=128, doc_count=1000)
    est = estimate_disk_space(plan)

    assert est.has_quantization is False
    assert len(est.vector_fields) == 0
    assert est.total_new_disk_bytes == 0
    assert est.memory_savings_after_bytes == 0


def test_estimate_fp32_to_int8():
    """FP32->INT8 should use 1 byte per element."""
    plan = _make_quantize_plan("float32", "int8", dims=768, doc_count=50_000)
    est = estimate_disk_space(plan)

    assert est.vector_fields[0].source_bytes_per_doc == 768 * 4
    assert est.vector_fields[0].target_bytes_per_doc == 768 * 1
    assert est.memory_savings_after_bytes == 50_000 * 768 * 3


def test_estimate_summary_with_quantization():
    """Summary string should contain key information."""
    plan = _make_quantize_plan("float32", "float16", dims=128, doc_count=1000)
    est = estimate_disk_space(plan)
    summary = est.summary()

    assert "Pre-migration disk space estimate" in summary
    assert
"test_index" in summary + assert "1,000 documents" in summary + assert "float32 -> float16" in summary + assert "RDB snapshot" in summary + assert "memory savings" in summary + + +def test_estimate_summary_no_quantization(): + """Summary for non-quantization migration should say no disk needed.""" + plan = _make_quantize_plan("float32", "float32", dims=128, doc_count=1000) + est = estimate_disk_space(plan) + summary = est.summary() + + assert "No vector quantization" in summary + + +def test_format_bytes_gb(): + assert _format_bytes(1_073_741_824) == "1.00 GB" + assert _format_bytes(2_147_483_648) == "2.00 GB" + + +def test_format_bytes_mb(): + assert _format_bytes(1_048_576) == "1.0 MB" + assert _format_bytes(10_485_760) == "10.0 MB" + + +def test_format_bytes_kb(): + assert _format_bytes(1024) == "1.0 KB" + assert _format_bytes(2048) == "2.0 KB" + + +def test_format_bytes_bytes(): + assert _format_bytes(500) == "500 bytes" + assert _format_bytes(0) == "0 bytes" + + +def test_savings_pct(): + """Verify savings percentage calculation.""" + plan = _make_quantize_plan("float32", "float16", dims=128, doc_count=100) + est = estimate_disk_space(plan) + # FP32->FP16 = 50% savings + assert est._savings_pct() == 50 + + +# ============================================================================= +# TDD RED Phase: Idempotent Dtype Detection Tests +# ============================================================================= +# These test detect_vector_dtype() and is_already_quantized() which inspect +# raw vector bytes to determine whether a key needs conversion or can be skipped. 
+ + +def test_detect_dtype_float32_by_size(): + """A 3072-dim vector stored as FP32 should be 12288 bytes.""" + import numpy as np + + from redisvl.migration.reliability import detect_vector_dtype + + vec = np.random.randn(3072).astype(np.float32).tobytes() + detected = detect_vector_dtype(vec, expected_dims=3072) + assert detected == "float32" + + +def test_detect_dtype_float16_by_size(): + """A 3072-dim vector stored as FP16 should be 6144 bytes.""" + import numpy as np + + from redisvl.migration.reliability import detect_vector_dtype + + vec = np.random.randn(3072).astype(np.float16).tobytes() + detected = detect_vector_dtype(vec, expected_dims=3072) + assert detected == "float16" + + +def test_detect_dtype_int8_by_size(): + """A 768-dim vector stored as INT8 should be 768 bytes.""" + import numpy as np + + from redisvl.migration.reliability import detect_vector_dtype + + vec = np.zeros(768, dtype=np.int8).tobytes() + detected = detect_vector_dtype(vec, expected_dims=768) + assert detected == "int8" + + +def test_detect_dtype_bfloat16_by_size(): + """A 768-dim bfloat16 vector should be 1536 bytes (same as float16).""" + import numpy as np + + from redisvl.migration.reliability import detect_vector_dtype + + # bfloat16 and float16 are both 2 bytes per element + vec = np.random.randn(768).astype(np.float16).tobytes() + detected = detect_vector_dtype(vec, expected_dims=768) + # Cannot distinguish float16 from bfloat16 by size alone; returns "float16" + assert detected in ("float16", "bfloat16") + + +def test_detect_dtype_empty_returns_none(): + """Empty bytes should return None.""" + from redisvl.migration.reliability import detect_vector_dtype + + assert detect_vector_dtype(b"", expected_dims=128) is None + + +def test_detect_dtype_unknown_size(): + """Bytes that don't match any known dtype should return None.""" + from redisvl.migration.reliability import detect_vector_dtype + + # 7 bytes doesn't match any dtype for 3 dims + assert detect_vector_dtype(b"\x00" * 
7, expected_dims=3) is None + + +def test_is_already_quantized_skip(): + """If source is float32 and vector is already float16, should return True.""" + import numpy as np + + from redisvl.migration.reliability import is_already_quantized + + vec = np.random.randn(128).astype(np.float16).tobytes() + result = is_already_quantized( + vec, expected_dims=128, source_dtype="float32", target_dtype="float16" + ) + assert result is True + + +def test_is_already_quantized_needs_conversion(): + """If source is float32 and vector IS float32, should return False.""" + import numpy as np + + from redisvl.migration.reliability import is_already_quantized + + vec = np.random.randn(128).astype(np.float32).tobytes() + result = is_already_quantized( + vec, expected_dims=128, source_dtype="float32", target_dtype="float16" + ) + assert result is False + + +def test_is_already_quantized_bfloat16_target(): + """If target is bfloat16 and vector is 2-bytes-per-element, should return True. + + bfloat16 and float16 share the same byte width (2 bytes per element) + and are treated as the same dtype family for idempotent detection. + """ + import numpy as np + + from redisvl.migration.reliability import is_already_quantized + + vec = np.random.randn(128).astype(np.float16).tobytes() + result = is_already_quantized( + vec, expected_dims=128, source_dtype="float32", target_dtype="bfloat16" + ) + assert result is True + + +def test_is_already_quantized_uint8_target(): + """If target is uint8 and vector is 1-byte-per-element, should return True. + + uint8 and int8 share the same byte width (1 byte per element) + and are treated as the same dtype family for idempotent detection. 
+ """ + import numpy as np + + from redisvl.migration.reliability import is_already_quantized + + vec = np.random.randn(128).astype(np.int8).tobytes() + result = is_already_quantized( + vec, expected_dims=128, source_dtype="float32", target_dtype="uint8" + ) + assert result is True + + +def test_is_already_quantized_same_width_float16_to_bfloat16(): + """float16 -> bfloat16 should NOT be skipped (same byte width, different encoding).""" + import numpy as np + + from redisvl.migration.reliability import is_already_quantized + + vec = np.random.randn(128).astype(np.float16).tobytes() + result = is_already_quantized( + vec, expected_dims=128, source_dtype="float16", target_dtype="bfloat16" + ) + assert result is False + + +def test_is_already_quantized_same_width_int8_to_uint8(): + """int8 -> uint8 should NOT be skipped (same byte width, different encoding).""" + import numpy as np + + from redisvl.migration.reliability import is_already_quantized + + vec = np.random.randn(128).astype(np.int8).tobytes() + result = is_already_quantized( + vec, expected_dims=128, source_dtype="int8", target_dtype="uint8" + ) + assert result is False + + +# ============================================================================= +# TDD RED Phase: Checkpoint File Tests +# ============================================================================= + + +def test_checkpoint_create_new(tmp_path): + """Creating a new checkpoint should initialize with zero progress.""" + from redisvl.migration.reliability import QuantizationCheckpoint + + cp = QuantizationCheckpoint( + index_name="test_index", + total_keys=10000, + checkpoint_path=str(tmp_path / "checkpoint.yaml"), + ) + assert cp.index_name == "test_index" + assert cp.total_keys == 10000 + assert cp.completed_keys == 0 + assert cp.completed_batches == 0 + assert cp.last_batch_keys == [] + assert cp.status == "in_progress" + + +def test_checkpoint_save_and_load(tmp_path): + """Checkpoint should persist to disk and reload with same 
state.""" + from redisvl.migration.reliability import QuantizationCheckpoint + + path = str(tmp_path / "checkpoint.yaml") + cp = QuantizationCheckpoint( + index_name="test_index", + total_keys=5000, + checkpoint_path=path, + ) + cp.record_batch(["key:1", "key:2", "key:3"]) + cp.save() + + loaded = QuantizationCheckpoint.load(path) + assert loaded.index_name == "test_index" + assert loaded.total_keys == 5000 + assert loaded.completed_keys == 3 + assert loaded.completed_batches == 1 + assert loaded.last_batch_keys == ["key:1", "key:2", "key:3"] + + +def test_checkpoint_record_multiple_batches(tmp_path): + """Recording multiple batches should accumulate counts.""" + from redisvl.migration.reliability import QuantizationCheckpoint + + cp = QuantizationCheckpoint( + index_name="idx", + total_keys=100, + checkpoint_path=str(tmp_path / "cp.yaml"), + ) + cp.record_batch(["k1", "k2"]) + cp.record_batch(["k3", "k4", "k5"]) + + assert cp.completed_keys == 5 + assert cp.completed_batches == 2 + assert cp.last_batch_keys == ["k3", "k4", "k5"] + + +def test_checkpoint_mark_complete(tmp_path): + """Marking complete should set status to 'completed'.""" + from redisvl.migration.reliability import QuantizationCheckpoint + + cp = QuantizationCheckpoint( + index_name="idx", + total_keys=2, + checkpoint_path=str(tmp_path / "cp.yaml"), + ) + cp.record_batch(["k1", "k2"]) + cp.mark_complete() + + assert cp.status == "completed" + + +def test_checkpoint_get_remaining_keys(tmp_path): + """get_remaining_keys should return only keys not yet processed.""" + from redisvl.migration.reliability import QuantizationCheckpoint + + cp = QuantizationCheckpoint( + index_name="idx", + total_keys=5, + checkpoint_path=str(tmp_path / "cp.yaml"), + ) + all_keys = ["k1", "k2", "k3", "k4", "k5"] + cp.record_batch(["k1", "k2"]) + + remaining = cp.get_remaining_keys(all_keys) + assert remaining == ["k3", "k4", "k5"] + + +def test_checkpoint_get_remaining_keys_uses_completed_offset_when_compact(tmp_path): + 
"""Compact checkpoints should resume via completed_keys ordering.""" + from redisvl.migration.reliability import QuantizationCheckpoint + + cp = QuantizationCheckpoint( + index_name="idx", + total_keys=5, + checkpoint_path=str(tmp_path / "cp.yaml"), + ) + cp.record_batch(["k1", "k2"]) + + remaining = cp.get_remaining_keys(["k1", "k2", "k3", "k4", "k5"]) + assert remaining == ["k3", "k4", "k5"] + + +def test_checkpoint_save_excludes_processed_keys(tmp_path): + """New checkpoints should persist compact state without processed_keys.""" + from redisvl.migration.reliability import QuantizationCheckpoint + + path = tmp_path / "checkpoint.yaml" + cp = QuantizationCheckpoint( + index_name="idx", + total_keys=3, + checkpoint_path=str(path), + ) + cp.save() + + raw = path.read_text() + assert "processed_keys" not in raw + + +def test_checkpoint_load_nonexistent_returns_none(tmp_path): + """Loading a nonexistent checkpoint file should return None.""" + from redisvl.migration.reliability import QuantizationCheckpoint + + result = QuantizationCheckpoint.load( + str(tmp_path / "nonexistent_checkpoint_xyz.yaml") + ) + assert result is None + + +def test_checkpoint_load_forces_path(tmp_path): + """load() should set checkpoint_path to the file used to load, not the stored value.""" + from redisvl.migration.reliability import QuantizationCheckpoint + + original_path = str(tmp_path / "original.yaml") + cp = QuantizationCheckpoint( + index_name="idx", + total_keys=10, + checkpoint_path=original_path, + ) + cp.record_batch(["k1"]) + cp.save() + + # Move the file to a new location + new_path = str(tmp_path / "moved.yaml") + import shutil + + shutil.copy(original_path, new_path) + + loaded = QuantizationCheckpoint.load(new_path) + assert loaded.checkpoint_path == new_path # should use load path, not stored + + +def test_checkpoint_save_preserves_legacy_processed_keys(tmp_path): + """Legacy checkpoints should keep processed_keys across subsequent saves.""" + from 
redisvl.migration.reliability import QuantizationCheckpoint + + path = tmp_path / "legacy.yaml" + path.write_text( + "index_name: idx\n" + "total_keys: 4\n" + "processed_keys:\n" + " - k1\n" + " - k2\n" + "status: in_progress\n" + ) + + checkpoint = QuantizationCheckpoint.load(str(path)) + checkpoint.record_batch(["k3"]) + checkpoint.save() + + reloaded = QuantizationCheckpoint.load(str(path)) + assert reloaded.processed_keys == ["k1", "k2", "k3"] + assert reloaded.completed_keys == 3 + + +def test_quantize_vectors_saves_checkpoint_before_processing(monkeypatch, tmp_path): + """Checkpoint save should happen before the first HGET in a fresh run.""" + import numpy as np + + executor = MigrationExecutor() + checkpoint_path = str(tmp_path / "checkpoint.yaml") + field_bytes = np.array([1.0, 2.0], dtype=np.float32).tobytes() + events: list[str] = [] + + original_save = executor._quantize_vectors.__globals__[ + "QuantizationCheckpoint" + ].save + + def tracking_save(self): + events.append("save") + return original_save(self) + + monkeypatch.setattr( + executor._quantize_vectors.__globals__["QuantizationCheckpoint"], + "save", + tracking_save, + ) + + client = MagicMock() + client.hget.side_effect = lambda key, field: (events.append("hget") or field_bytes) + pipe = MagicMock() + client.pipeline.return_value = pipe + source_index = MagicMock() + source_index._redis_client = client + source_index.name = "idx" + + result = executor._quantize_vectors( + source_index, + {"embedding": {"source": "float32", "target": "float16", "dims": 2}}, + ["doc:1"], + checkpoint_path=checkpoint_path, + ) + + assert result == 1 + assert events[0] == "save" + assert Path(checkpoint_path).exists() + + +def test_quantize_vectors_returns_reencoded_docs_not_scanned_docs(): + """Quantize count should reflect converted docs, not skipped docs.""" + import numpy as np + + executor = MigrationExecutor() + already_quantized = np.array([1.0, 2.0], dtype=np.float16).tobytes() + needs_quantization = 
np.array([1.0, 2.0], dtype=np.float32).tobytes() + + client = MagicMock() + client.hget.side_effect = lambda key, field: { + "doc:1": already_quantized, + "doc:2": needs_quantization, + }[key] + pipe = MagicMock() + client.pipeline.return_value = pipe + source_index = MagicMock() + source_index._redis_client = client + source_index.name = "idx" + + progress: list[tuple[int, int]] = [] + result = executor._quantize_vectors( + source_index, + {"embedding": {"source": "float32", "target": "float16", "dims": 2}}, + ["doc:1", "doc:2"], + progress_callback=lambda done, total: progress.append((done, total)), + ) + + assert result == 1 + assert progress[-1] == (2, 2) + + +def test_build_scan_match_patterns_uses_separator(): + assert build_scan_match_patterns(["test"], ":") == ["test:*"] + assert build_scan_match_patterns(["test:"], ":") == ["test:*"] + assert build_scan_match_patterns([], ":") == ["*"] + assert build_scan_match_patterns(["b", "a"], ":") == ["a:*", "b:*"] + + +def test_normalize_keys_dedupes_and_sorts(): + assert normalize_keys(["b", "a", "b"]) == ["a", "b"] + + +def test_detect_aof_enabled_from_info(): + from redisvl.migration.utils import detect_aof_enabled + + client = MagicMock() + client.info.return_value = {"aof_enabled": 1} + assert detect_aof_enabled(client) is True + + +@pytest.mark.asyncio +async def test_async_detect_aof_enabled_from_info(): + executor = AsyncMigrationExecutor() + client = MagicMock() + client.info = AsyncMock(return_value={"aof_enabled": 1}) + assert await executor._detect_aof_enabled(client) is True + + +def test_estimate_json_quantization_is_noop(): + """JSON datatype changes should not report in-place rewrite costs.""" + plan = _make_quantize_plan( + "float32", "float16", dims=128, doc_count=1000, storage_type="json" + ) + est = estimate_disk_space(plan, aof_enabled=True) + + assert est.has_quantization is False + assert est.total_new_disk_bytes == 0 + assert est.aof_growth_bytes == 0 + + +def 
test_estimate_unknown_dtype_raises(): + plan = _make_quantize_plan("madeup32", "float16", dims=128, doc_count=10) + + with pytest.raises(ValueError, match="Unknown source vector datatype"): + estimate_disk_space(plan) + + +def test_enumerate_with_scan_uses_all_prefixes(): + executor = MigrationExecutor() + client = MagicMock() + client.ft.return_value.info.return_value = { + "index_definition": {"prefixes": ["alpha", "beta"]} + } + client.scan.side_effect = [ + (0, [b"alpha:1", b"shared:1"]), + (0, [b"beta:2", b"shared:1"]), + ] + + keys = list(executor._enumerate_with_scan(client, "idx", batch_size=1000)) + + assert keys == ["alpha:1", "shared:1", "beta:2"] + + +@pytest.mark.asyncio +async def test_async_enumerate_with_scan_uses_all_prefixes(): + executor = AsyncMigrationExecutor() + client = MagicMock() + client.ft.return_value.info = AsyncMock( + return_value={"index_definition": {"prefixes": ["alpha", "beta"]}} + ) + client.scan = AsyncMock( + side_effect=[ + (0, [b"alpha:1", b"shared:1"]), + (0, [b"beta:2", b"shared:1"]), + ] + ) + + keys = [ + key + async for key in executor._enumerate_with_scan(client, "idx", batch_size=1000) + ] + + assert keys == ["alpha:1", "shared:1", "beta:2"] + + +def test_apply_rejects_same_width_resume(monkeypatch): + plan = _make_quantize_plan("float16", "bfloat16", dims=2, doc_count=1) + executor = MigrationExecutor() + + def _make_index(*args, **kwargs): + index = MagicMock() + index._redis_client = MagicMock() + index.name = "test_index" + return index + + monkeypatch.setattr( + "redisvl.migration.executor.current_source_matches_snapshot", + lambda *args, **kwargs: True, + ) + monkeypatch.setattr( + "redisvl.migration.executor.SearchIndex.from_existing", + _make_index, + ) + monkeypatch.setattr( + "redisvl.migration.executor.SearchIndex.from_dict", + _make_index, + ) + + report = executor.apply( + plan, + redis_client=MagicMock(), + checkpoint_path="resume.yaml", + ) + + assert report.result == "failed" + assert "same-width 
datatype changes" in report.validation.errors[0] + + +@pytest.mark.asyncio +async def test_async_quantize_vectors_saves_checkpoint_before_processing( + monkeypatch, tmp_path +): + """Async checkpoint save should happen before the first HGET in a fresh run.""" + import numpy as np + + executor = AsyncMigrationExecutor() + checkpoint_path = str(tmp_path / "checkpoint.yaml") + field_bytes = np.array([1.0, 2.0], dtype=np.float32).tobytes() + events: list[str] = [] + + original_save = executor._async_quantize_vectors.__globals__[ + "QuantizationCheckpoint" + ].save + + def tracking_save(self): + events.append("save") + return original_save(self) + + monkeypatch.setattr( + executor._async_quantize_vectors.__globals__["QuantizationCheckpoint"], + "save", + tracking_save, + ) + + client = MagicMock() + client.hget = AsyncMock( + side_effect=lambda key, field: (events.append("hget") or field_bytes) + ) + pipe = MagicMock() + pipe.execute = AsyncMock(return_value=[]) + client.pipeline.return_value = pipe + source_index = MagicMock() + source_index._redis_client = client + source_index.name = "idx" + + result = await executor._async_quantize_vectors( + source_index, + {"embedding": {"source": "float32", "target": "float16", "dims": 2}}, + ["doc:1"], + checkpoint_path=checkpoint_path, + ) + + assert result == 1 + assert events[0] == "save" + assert Path(checkpoint_path).exists() + + +@pytest.mark.asyncio +async def test_async_quantize_vectors_returns_reencoded_docs_not_scanned_docs(): + """Async quantize count should reflect converted docs, not skipped docs.""" + import numpy as np + + executor = AsyncMigrationExecutor() + already_quantized = np.array([1.0, 2.0], dtype=np.float16).tobytes() + needs_quantization = np.array([1.0, 2.0], dtype=np.float32).tobytes() + + client = MagicMock() + client.hget = AsyncMock( + side_effect=lambda key, field: { + "doc:1": already_quantized, + "doc:2": needs_quantization, + }[key] + ) + pipe = MagicMock() + pipe.execute = 
AsyncMock(return_value=[]) + client.pipeline.return_value = pipe + source_index = MagicMock() + source_index._redis_client = client + source_index.name = "idx" + + progress: list[tuple[int, int]] = [] + result = await executor._async_quantize_vectors( + source_index, + {"embedding": {"source": "float32", "target": "float16", "dims": 2}}, + ["doc:1", "doc:2"], + progress_callback=lambda done, total: progress.append((done, total)), + ) + + assert result == 1 + assert progress[-1] == (2, 2) + + +# ============================================================================= +# TDD RED Phase: BGSAVE Safety Net Tests +# ============================================================================= + + +def test_trigger_bgsave_success(): + """BGSAVE should be triggered and waited on; returns True on success.""" + from unittest.mock import MagicMock + + from redisvl.migration.reliability import trigger_bgsave_and_wait + + mock_client = MagicMock() + mock_client.bgsave.return_value = True + mock_client.info.return_value = {"rdb_bgsave_in_progress": 0} + + result = trigger_bgsave_and_wait(mock_client, timeout_seconds=5) + assert result is True + mock_client.bgsave.assert_called_once() + + +def test_trigger_bgsave_already_in_progress(): + """If BGSAVE is already running, wait for it instead of starting a new one.""" + from unittest.mock import MagicMock, call + + from redisvl.migration.reliability import trigger_bgsave_and_wait + + mock_client = MagicMock() + # First bgsave raises because one is already in progress + mock_client.bgsave.side_effect = Exception("Background save already in progress") + # First check: still running; second check: done + mock_client.info.side_effect = [ + {"rdb_bgsave_in_progress": 1}, + {"rdb_bgsave_in_progress": 0}, + ] + + result = trigger_bgsave_and_wait(mock_client, timeout_seconds=5, poll_interval=0.01) + assert result is True + + +@pytest.mark.asyncio +async def test_async_trigger_bgsave_success(): + """Async BGSAVE should work the same as 
sync.""" + from unittest.mock import AsyncMock + + from redisvl.migration.reliability import async_trigger_bgsave_and_wait + + mock_client = AsyncMock() + mock_client.bgsave.return_value = True + mock_client.info.return_value = {"rdb_bgsave_in_progress": 0} + + result = await async_trigger_bgsave_and_wait(mock_client, timeout_seconds=5) + assert result is True + mock_client.bgsave.assert_called_once() + + +# ============================================================================= +# TDD RED Phase: Bounded Undo Buffer Tests +# ============================================================================= + + +def test_undo_buffer_store_and_rollback(): + """Undo buffer should store original values and rollback via pipeline.""" + from unittest.mock import MagicMock + + from redisvl.migration.reliability import BatchUndoBuffer + + buf = BatchUndoBuffer() + buf.store("key:1", "embedding", b"\x00\x01\x02\x03") + buf.store("key:2", "embedding", b"\x04\x05\x06\x07") + + assert buf.size == 2 + + mock_pipe = MagicMock() + buf.rollback(mock_pipe) + + # Should have called hset twice to restore originals + assert mock_pipe.hset.call_count == 2 + mock_pipe.execute.assert_called_once() + + +def test_undo_buffer_clear(): + """After clear, buffer should be empty.""" + from redisvl.migration.reliability import BatchUndoBuffer + + buf = BatchUndoBuffer() + buf.store("key:1", "field", b"\x00") + assert buf.size == 1 + + buf.clear() + assert buf.size == 0 + + +def test_undo_buffer_empty_rollback(): + """Rolling back an empty buffer should be a no-op.""" + from unittest.mock import MagicMock + + from redisvl.migration.reliability import BatchUndoBuffer + + buf = BatchUndoBuffer() + mock_pipe = MagicMock() + buf.rollback(mock_pipe) + + # No hset calls, no execute + mock_pipe.hset.assert_not_called() + mock_pipe.execute.assert_not_called() + + +def test_undo_buffer_multiple_fields_same_key(): + """Should handle multiple fields for the same key.""" + from unittest.mock import MagicMock 
+ + from redisvl.migration.reliability import BatchUndoBuffer + + buf = BatchUndoBuffer() + buf.store("key:1", "embedding", b"\x00\x01") + buf.store("key:1", "embedding2", b"\x02\x03") + + assert buf.size == 2 + + mock_pipe = MagicMock() + buf.rollback(mock_pipe) + assert mock_pipe.hset.call_count == 2 + + +@pytest.mark.asyncio +async def test_undo_buffer_async_rollback(): + """async_rollback should await pipe.execute() for async Redis pipelines.""" + from unittest.mock import AsyncMock, MagicMock + + from redisvl.migration.reliability import BatchUndoBuffer + + buf = BatchUndoBuffer() + buf.store("key:1", "embedding", b"\x00\x01") + buf.store("key:2", "embedding", b"\x02\x03") + + mock_pipe = MagicMock() + mock_pipe.execute = AsyncMock() + + await buf.async_rollback(mock_pipe) + assert mock_pipe.hset.call_count == 2 + mock_pipe.execute.assert_awaited_once() + + +@pytest.mark.asyncio +async def test_undo_buffer_async_rollback_empty(): + """async_rollback on empty buffer should be a no-op.""" + from unittest.mock import AsyncMock, MagicMock + + from redisvl.migration.reliability import BatchUndoBuffer + + buf = BatchUndoBuffer() + mock_pipe = MagicMock() + mock_pipe.execute = AsyncMock() + + await buf.async_rollback(mock_pipe) + mock_pipe.hset.assert_not_called() + mock_pipe.execute.assert_not_awaited() diff --git a/tests/unit/test_async_migration_planner.py b/tests/unit/test_async_migration_planner.py new file mode 100644 index 000000000..93ce3d49d --- /dev/null +++ b/tests/unit/test_async_migration_planner.py @@ -0,0 +1,319 @@ +"""Unit tests for AsyncMigrationPlanner. + +These tests mirror the sync MigrationPlanner tests but use async/await patterns. 
+""" + +from fnmatch import fnmatch + +import pytest +import yaml + +from redisvl.migration import AsyncMigrationPlanner, MigrationPlanner +from redisvl.schema.schema import IndexSchema + + +class AsyncDummyClient: + """Async mock Redis client for testing.""" + + def __init__(self, keys): + self.keys = keys + + async def scan(self, cursor=0, match=None, count=None): + matched = [] + for key in self.keys: + decoded_key = key.decode() if isinstance(key, bytes) else str(key) + if match is None or fnmatch(decoded_key, match): + matched.append(key) + return 0, matched + + +class AsyncDummyIndex: + """Async mock SearchIndex for testing.""" + + def __init__(self, schema, stats, keys): + self.schema = schema + self._stats = stats + self._client = AsyncDummyClient(keys) + + @property + def client(self): + return self._client + + async def info(self): + return self._stats + + +def _make_source_schema(): + return IndexSchema.from_dict( + { + "index": { + "name": "docs", + "prefix": "docs", + "key_separator": ":", + "storage_type": "json", + }, + "fields": [ + { + "name": "title", + "type": "text", + "path": "$.title", + "attrs": {"sortable": False}, + }, + { + "name": "price", + "type": "numeric", + "path": "$.price", + "attrs": {"sortable": True}, + }, + { + "name": "embedding", + "type": "vector", + "path": "$.embedding", + "attrs": { + "algorithm": "flat", + "dims": 3, + "distance_metric": "cosine", + "datatype": "float32", + }, + }, + ], + } + ) + + +@pytest.mark.asyncio +async def test_async_create_plan_from_schema_patch(monkeypatch, tmp_path): + """Test async planner creates valid plan from schema patch.""" + source_schema = _make_source_schema() + dummy_index = AsyncDummyIndex( + source_schema, + {"num_docs": 2, "indexing": False}, + [b"docs:1", b"docs:2", b"docs:3"], + ) + + async def mock_from_existing(*args, **kwargs): + return dummy_index + + monkeypatch.setattr( + "redisvl.migration.async_planner.AsyncSearchIndex.from_existing", + mock_from_existing, + ) + + 
patch_path = tmp_path / "schema_patch.yaml" + patch_path.write_text( + yaml.safe_dump( + { + "version": 1, + "changes": { + "add_fields": [ + { + "name": "category", + "type": "tag", + "path": "$.category", + "attrs": {"separator": ","}, + } + ], + "remove_fields": ["price"], + "update_fields": [ + { + "name": "title", + "options": {"sortable": True}, + } + ], + }, + }, + sort_keys=False, + ) + ) + + planner = AsyncMigrationPlanner(key_sample_limit=2) + plan = await planner.create_plan( + "docs", + redis_url="redis://localhost:6379", + schema_patch_path=str(patch_path), + ) + + assert plan.diff_classification.supported is True + assert plan.source.index_name == "docs" + assert plan.source.keyspace.storage_type == "json" + assert plan.source.keyspace.prefixes == ["docs"] + assert plan.source.keyspace.key_separator == ":" + assert plan.source.keyspace.key_sample == ["docs:1", "docs:2"] + assert plan.warnings == ["Index downtime is required"] + + merged_fields = { + field["name"]: field for field in plan.merged_target_schema["fields"] + } + assert plan.merged_target_schema["index"]["prefix"] == "docs" + assert merged_fields["title"]["attrs"]["sortable"] is True + assert "price" not in merged_fields + assert merged_fields["category"]["type"] == "tag" + + # Test write_plan works (delegates to sync) + plan_path = tmp_path / "migration_plan.yaml" + planner.write_plan(plan, str(plan_path)) + written_plan = yaml.safe_load(plan_path.read_text()) + assert written_plan["mode"] == "drop_recreate" + assert written_plan["diff_classification"]["supported"] is True + + +@pytest.mark.asyncio +async def test_async_planner_datatype_change_allowed(monkeypatch, tmp_path): + """Changing vector datatype (quantization) is allowed - executor will re-encode.""" + source_schema = _make_source_schema() + dummy_index = AsyncDummyIndex(source_schema, {"num_docs": 2}, [b"docs:1"]) + + async def mock_from_existing(*args, **kwargs): + return dummy_index + + monkeypatch.setattr( + 
"redisvl.migration.async_planner.AsyncSearchIndex.from_existing", + mock_from_existing, + ) + + target_schema_path = tmp_path / "target_schema.yaml" + target_schema_path.write_text( + yaml.safe_dump( + { + "index": { + "name": "docs", + "prefix": "docs", + "key_separator": ":", + "storage_type": "json", + }, + "fields": [ + {"name": "title", "type": "text", "path": "$.title"}, + {"name": "price", "type": "numeric", "path": "$.price"}, + { + "name": "embedding", + "type": "vector", + "path": "$.embedding", + "attrs": { + "algorithm": "flat", + "dims": 3, + "distance_metric": "cosine", + "datatype": "float16", # Changed from float32 + }, + }, + ], + }, + sort_keys=False, + ) + ) + + planner = AsyncMigrationPlanner() + plan = await planner.create_plan( + "docs", + redis_url="redis://localhost:6379", + target_schema_path=str(target_schema_path), + ) + + assert plan.diff_classification.supported is True + assert len(plan.diff_classification.blocked_reasons) == 0 + + # Verify datatype changes are detected + datatype_changes = MigrationPlanner.get_vector_datatype_changes( + plan.source.schema_snapshot, plan.merged_target_schema + ) + assert "embedding" in datatype_changes + assert datatype_changes["embedding"]["source"] == "float32" + assert datatype_changes["embedding"]["target"] == "float16" + + +@pytest.mark.asyncio +async def test_async_planner_algorithm_change_allowed(monkeypatch, tmp_path): + """Changing vector algorithm is allowed (index-only change).""" + source_schema = _make_source_schema() + dummy_index = AsyncDummyIndex(source_schema, {"num_docs": 2}, [b"docs:1"]) + + async def mock_from_existing(*args, **kwargs): + return dummy_index + + monkeypatch.setattr( + "redisvl.migration.async_planner.AsyncSearchIndex.from_existing", + mock_from_existing, + ) + + target_schema_path = tmp_path / "target_schema.yaml" + target_schema_path.write_text( + yaml.safe_dump( + { + "index": { + "name": "docs", + "prefix": "docs", + "key_separator": ":", + "storage_type": "json", 
+ }, + "fields": [ + {"name": "title", "type": "text", "path": "$.title"}, + {"name": "price", "type": "numeric", "path": "$.price"}, + { + "name": "embedding", + "type": "vector", + "path": "$.embedding", + "attrs": { + "algorithm": "hnsw", # Changed from flat + "dims": 3, + "distance_metric": "cosine", + "datatype": "float32", + }, + }, + ], + }, + sort_keys=False, + ) + ) + + planner = AsyncMigrationPlanner() + plan = await planner.create_plan( + "docs", + redis_url="redis://localhost:6379", + target_schema_path=str(target_schema_path), + ) + + assert plan.diff_classification.supported is True + assert len(plan.diff_classification.blocked_reasons) == 0 + + +@pytest.mark.asyncio +async def test_async_planner_prefix_change_is_supported(monkeypatch, tmp_path): + """Prefix change is supported: executor will rename keys.""" + source_schema = _make_source_schema() + dummy_index = AsyncDummyIndex(source_schema, {"num_docs": 2}, [b"docs:1"]) + + async def mock_from_existing(*args, **kwargs): + return dummy_index + + monkeypatch.setattr( + "redisvl.migration.async_planner.AsyncSearchIndex.from_existing", + mock_from_existing, + ) + + target_schema_path = tmp_path / "target_schema.yaml" + target_schema_path.write_text( + yaml.safe_dump( + { + "index": { + "name": "docs", + "prefix": "docs_v2", # Changed prefix + "key_separator": ":", + "storage_type": "json", + }, + "fields": source_schema.to_dict()["fields"], + }, + sort_keys=False, + ) + ) + + planner = AsyncMigrationPlanner() + plan = await planner.create_plan( + "docs", + redis_url="redis://localhost:6379", + target_schema_path=str(target_schema_path), + ) + + # Prefix change is now supported + assert plan.diff_classification.supported is True + assert plan.rename_operations.change_prefix == "docs_v2" + # Should have a warning about key renaming + assert any("prefix" in w.lower() for w in plan.warnings) From 787c64d394ae3b2b02458da1925cbadcaf960640 Mon Sep 17 00:00:00 2001 From: Nitin Kanukolanu Date: Wed, 1 Apr 2026 
19:43:09 -0400 Subject: [PATCH 2/7] fix: async executor/validation bug fixes (#562) - Fix unbound 'ready' variable in async_utils.py and async_executor.py - Fix completed checkpoint: resume from post-drop state - Pass rename_operations to get_vector_datatype_changes - Fix has_prefix_change falsy check for empty string prefixes - Fix partial key renames: fail fast on collision - Warn when field rename overwrites existing destination field - Fix async_validation prefix handling and indexing failure delta --- redisvl/migration/async_executor.py | 46 ++++++++++++++++++++------- redisvl/migration/async_utils.py | 1 + redisvl/migration/async_validation.py | 23 ++++++++++---- 3 files changed, 52 insertions(+), 18 deletions(-) diff --git a/redisvl/migration/async_executor.py b/redisvl/migration/async_executor.py index 7d00b83a8..73d8c275e 100644 --- a/redisvl/migration/async_executor.py +++ b/redisvl/migration/async_executor.py @@ -268,16 +268,18 @@ async def _rename_keys( logger.warning(f"Error in rename batch: {e}") raise + # Fail fast on collisions to avoid partial renames across batches. + if collisions: + raise RuntimeError( + f"Prefix rename aborted after {renamed} successful rename(s): " + f"{len(collisions)} destination key(s) already exist " + f"(first 5: {collisions[:5]}). This would overwrite existing data. " + f"Remove conflicting keys or choose a different prefix." + ) + if progress_callback: progress_callback(min(i + pipeline_size, total), total) - if collisions: - raise RuntimeError( - f"Prefix rename aborted: {len(collisions)} destination key(s) already exist " - f"(first 5: {collisions[:5]}). This would overwrite existing data. " - f"Remove conflicting keys or choose a different prefix." 
- ) - return renamed async def _rename_field_in_hash( @@ -296,15 +298,28 @@ async def _rename_field_in_hash( for i in range(0, total, pipeline_size): batch = keys[i : i + pipeline_size] + # Get old field values AND check if destination exists pipe = client.pipeline(transaction=False) for key in batch: pipe.hget(key, old_name) - values = await pipe.execute() + pipe.hexists(key, new_name) + raw_results = await pipe.execute() + # Interleaved: [hget_0, hexists_0, hget_1, hexists_1, ...] + values = raw_results[0::2] + dest_exists = raw_results[1::2] pipe = client.pipeline(transaction=False) batch_ops = 0 - for key, value in zip(batch, values): + for key, value, exists in zip(batch, values, dest_exists): if value is not None: + if exists: + logger.warning( + "Field '%s' already exists in key '%s'; " + "overwriting with value from '%s'", + new_name, + key, + old_name, + ) pipe.hset(key, new_name, value) pipe.hdel(key, old_name) batch_ops += 1 @@ -427,8 +442,12 @@ async def apply( plan.source.index_name, ) elif existing_checkpoint.status == "completed": + # Quantization completed before the crash -- still need + # to resume from post-drop state (index recreation). 
+ resuming_from_checkpoint = True logger.info( - "Checkpoint at %s is already completed, ignoring", + "Checkpoint at %s is already completed; resuming " + "index recreation from post-drop state", checkpoint_path, ) else: @@ -488,12 +507,14 @@ async def apply( storage_type = plan.source.keyspace.storage_type datatype_changes = AsyncMigrationPlanner.get_vector_datatype_changes( - plan.source.schema_snapshot, plan.merged_target_schema + plan.source.schema_snapshot, + plan.merged_target_schema, + rename_operations=plan.rename_operations, ) # Check for rename operations rename_ops = plan.rename_operations - has_prefix_change = bool(rename_ops.change_prefix) + has_prefix_change = rename_ops.change_prefix is not None has_field_renames = bool(rename_ops.rename_fields) needs_quantization = bool(datatype_changes) and storage_type != "json" needs_enumeration = needs_quantization or has_prefix_change or has_field_renames @@ -1015,6 +1036,7 @@ async def _async_wait_for_index_ready( stable_ready_checks: Optional[int] = None while time.perf_counter() < deadline: + ready = False latest_info = await index.info() indexing = latest_info.get("indexing") percent_indexed = latest_info.get("percent_indexed") diff --git a/redisvl/migration/async_utils.py b/redisvl/migration/async_utils.py index 8e7af632a..571d2273d 100644 --- a/redisvl/migration/async_utils.py +++ b/redisvl/migration/async_utils.py @@ -48,6 +48,7 @@ async def async_wait_for_index_ready( stable_ready_checks: Optional[int] = None while time.perf_counter() < deadline: + ready = False latest_info = await index.info() indexing = latest_info.get("indexing") percent_indexed = latest_info.get("percent_indexed") diff --git a/redisvl/migration/async_validation.py b/redisvl/migration/async_validation.py index 7b9691c4b..dabdf218d 100644 --- a/redisvl/migration/async_validation.py +++ b/redisvl/migration/async_validation.py @@ -67,13 +67,24 @@ async def validate( else: # Handle prefix change: transform key_sample to use new prefix 
keys_to_check = key_sample - if plan.rename_operations.change_prefix: + if plan.rename_operations.change_prefix is not None: old_prefix = plan.source.keyspace.prefixes[0] new_prefix = plan.rename_operations.change_prefix - keys_to_check = [ - new_prefix + k[len(old_prefix) :] if k.startswith(old_prefix) else k - for k in key_sample - ] + # Normalize separator: strip trailing separator from both + # prefixes to avoid double/missing separator in transformed keys + sep = ":" + old_base = old_prefix.rstrip(sep) + new_base = new_prefix.rstrip(sep) if new_prefix else "" + keys_to_check = [] + for k in key_sample: + if k.startswith(old_prefix): + suffix = k[len(old_prefix):] + if new_base: + keys_to_check.append(f"{new_base}{sep}{suffix}") + else: + keys_to_check.append(suffix) + else: + keys_to_check.append(k) existing_count = await client.exists(*keys_to_check) validation.key_sample_exists = existing_count == len(keys_to_check) @@ -94,7 +105,7 @@ async def validate( validation.errors.append( "Live document count does not match source num_docs." 
) - if validation.indexing_failures_delta != 0: + if validation.indexing_failures_delta > 0: validation.errors.append("Indexing failures increased during migration.") if not validation.key_sample_exists: validation.errors.append( From 1c91ffeebf0a5e8b0531867aaff0018f26267447 Mon Sep 17 00:00:00 2001 From: Nitin Kanukolanu Date: Wed, 1 Apr 2026 19:55:40 -0400 Subject: [PATCH 3/7] fix: async minor cleanups (#562) - Fix _quantize_vectors docstring: 'documents quantized' not 'processed' - Close internally-created Redis client in async_list_indexes --- redisvl/migration/async_executor.py | 2 +- redisvl/migration/async_utils.py | 16 +++++++++++----- 2 files changed, 12 insertions(+), 6 deletions(-) diff --git a/redisvl/migration/async_executor.py b/redisvl/migration/async_executor.py index 73d8c275e..08a3c1fd3 100644 --- a/redisvl/migration/async_executor.py +++ b/redisvl/migration/async_executor.py @@ -871,7 +871,7 @@ async def _async_quantize_vectors( checkpoint_path: Optional path for checkpoint file (enables resume) Returns: - Number of documents processed + Number of documents quantized """ client = source_index._redis_client if client is None: diff --git a/redisvl/migration/async_utils.py b/redisvl/migration/async_utils.py index 571d2273d..6dc0c6738 100644 --- a/redisvl/migration/async_utils.py +++ b/redisvl/migration/async_utils.py @@ -14,17 +14,23 @@ async def async_list_indexes( *, redis_url: Optional[str] = None, redis_client: Optional[AsyncRedisClient] = None ) -> List[str]: """List all search indexes in Redis (async version).""" + created_client = False if redis_client is None: if not redis_url: raise ValueError("Must provide either redis_url or redis_client") redis_client = await RedisConnectionFactory._get_aredis_connection( redis_url=redis_url ) - index = AsyncSearchIndex.from_dict( - {"index": {"name": "__redisvl_migration_helper__"}, "fields": []}, - redis_client=redis_client, - ) - return await index.listall() + created_client = True + try: + index = 
AsyncSearchIndex.from_dict( + {"index": {"name": "__redisvl_migration_helper__"}, "fields": []}, + redis_client=redis_client, + ) + return await index.listall() + finally: + if created_client: + await redis_client.aclose() # type: ignore[union-attr] async def async_wait_for_index_ready( From 5d8811af9a0c122fdd2838a0cbb6191439405a8d Mon Sep 17 00:00:00 2001 From: Nitin Kanukolanu Date: Wed, 1 Apr 2026 20:28:22 -0400 Subject: [PATCH 4/7] fix(validation): fix double-colon bug in async prefix key transform + formatting --- redisvl/migration/async_validation.py | 15 ++++----------- 1 file changed, 4 insertions(+), 11 deletions(-) diff --git a/redisvl/migration/async_validation.py b/redisvl/migration/async_validation.py index dabdf218d..40763a287 100644 --- a/redisvl/migration/async_validation.py +++ b/redisvl/migration/async_validation.py @@ -65,24 +65,17 @@ async def validate( validation.key_sample_exists = False validation.errors.append("Failed to get Redis client for key sample check") else: - # Handle prefix change: transform key_sample to use new prefix + # Handle prefix change: transform key_sample to use new prefix. 
+ # Must match the executor's RENAME logic exactly: + # new_key = new_prefix + key[len(old_prefix):] keys_to_check = key_sample if plan.rename_operations.change_prefix is not None: old_prefix = plan.source.keyspace.prefixes[0] new_prefix = plan.rename_operations.change_prefix - # Normalize separator: strip trailing separator from both - # prefixes to avoid double/missing separator in transformed keys - sep = ":" - old_base = old_prefix.rstrip(sep) - new_base = new_prefix.rstrip(sep) if new_prefix else "" keys_to_check = [] for k in key_sample: if k.startswith(old_prefix): - suffix = k[len(old_prefix):] - if new_base: - keys_to_check.append(f"{new_base}{sep}{suffix}") - else: - keys_to_check.append(suffix) + keys_to_check.append(new_prefix + k[len(old_prefix) :]) else: keys_to_check.append(k) existing_count = await client.exists(*keys_to_check) From 2b932b12c2fd0738826224f77ee59e39c16b2962 Mon Sep 17 00:00:00 2001 From: Nitin Kanukolanu Date: Wed, 1 Apr 2026 21:46:40 -0400 Subject: [PATCH 5/7] fix: address review round 3 for migrate-async (#562) - Pass existing snapshot to create_plan_from_patch to avoid double Redis round-trip - Use _get_client() instead of _redis_client for lazy async client initialization - Remap datatype_changes keys to post-rename field names before quantization - Only resume from completed checkpoint when source index is actually gone --- redisvl/migration/async_executor.py | 47 +++++++++++++++++++++++------ redisvl/migration/async_planner.py | 16 ++++++---- 2 files changed, 47 insertions(+), 16 deletions(-) diff --git a/redisvl/migration/async_executor.py b/redisvl/migration/async_executor.py index 08a3c1fd3..02492103a 100644 --- a/redisvl/migration/async_executor.py +++ b/redisvl/migration/async_executor.py @@ -442,14 +442,29 @@ async def apply( plan.source.index_name, ) elif existing_checkpoint.status == "completed": - # Quantization completed before the crash -- still need - # to resume from post-drop state (index recreation). 
- resuming_from_checkpoint = True - logger.info( - "Checkpoint at %s is already completed; resuming " - "index recreation from post-drop state", - checkpoint_path, + # Quantization completed previously. Only resume if + # the source index is actually gone (post-drop crash). + source_still_exists = ( + await self._async_current_source_matches_snapshot( + plan.source.index_name, + plan.source.schema_snapshot, + redis_url=redis_url, + redis_client=redis_client, + ) ) + if source_still_exists: + logger.info( + "Checkpoint at %s is completed and source index " + "still exists; treating as fresh run", + checkpoint_path, + ) + else: + resuming_from_checkpoint = True + logger.info( + "Checkpoint at %s is already completed; resuming " + "index recreation from post-drop state", + checkpoint_path, + ) else: resuming_from_checkpoint = True logger.info( @@ -540,7 +555,7 @@ def _notify(step: str, detail: Optional[str] = None) -> None: progress_callback(step, detail) try: - client = source_index._redis_client + client = await source_index._get_client() if client is None: raise ValueError("Failed to get Redis client from source index") aof_enabled = await self._detect_aof_enabled(client) @@ -718,9 +733,21 @@ def _notify(step: str, detail: Optional[str] = None) -> None: for k in keys_to_process ] keys_to_process = normalize_keys(keys_to_process) + # Remap datatype_changes keys from source to target field + # names when field renames exist, since quantization runs + # after field renames (step 2). 
+ effective_changes = datatype_changes + if has_field_renames and not resuming_from_checkpoint: + field_rename_map = { + fr.old_name: fr.new_name for fr in rename_ops.rename_fields + } + effective_changes = { + field_rename_map.get(k, k): v + for k, v in datatype_changes.items() + } docs_quantized = await self._async_quantize_vectors( source_index, - datatype_changes, + effective_changes, keys_to_process, progress_callback=lambda done, total: _notify( "quantize", f"{done:,}/{total:,} docs" @@ -873,7 +900,7 @@ async def _async_quantize_vectors( Returns: Number of documents quantized """ - client = source_index._redis_client + client = await source_index._get_client() if client is None: raise ValueError("Failed to get Redis client from source index") diff --git a/redisvl/migration/async_planner.py b/redisvl/migration/async_planner.py index 703714496..5d4f90799 100644 --- a/redisvl/migration/async_planner.py +++ b/redisvl/migration/async_planner.py @@ -1,6 +1,6 @@ from __future__ import annotations -from typing import List, Optional +from typing import Any, List, Optional from redisvl.index import AsyncSearchIndex from redisvl.migration.models import ( @@ -74,6 +74,7 @@ async def create_plan( schema_patch=schema_patch, redis_url=redis_url, redis_client=redis_client, + _snapshot=snapshot, ) async def create_plan_from_patch( @@ -83,12 +84,15 @@ async def create_plan_from_patch( schema_patch: SchemaPatch, redis_url: Optional[str] = None, redis_client: Optional[AsyncRedisClient] = None, + _snapshot: Optional[Any] = None, ) -> MigrationPlan: - snapshot = await self.snapshot_source( - index_name, - redis_url=redis_url, - redis_client=redis_client, - ) + if _snapshot is None: + _snapshot = await self.snapshot_source( + index_name, + redis_url=redis_url, + redis_client=redis_client, + ) + snapshot = _snapshot source_schema = IndexSchema.from_dict(snapshot.schema_snapshot) merged_target_schema = self._sync_planner.merge_patch( source_schema, schema_patch From 
72abbf0e8408024da941b9bf020afb00cc2689b5 Mon Sep 17 00:00:00 2001 From: Nitin Kanukolanu Date: Thu, 2 Apr 2026 10:18:43 -0400 Subject: [PATCH 6/7] fix: address review round 4 for migrate-async (#562) - Switch from import logging to redisvl.utils.log.get_logger in async_executor --- redisvl/migration/async_executor.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/redisvl/migration/async_executor.py b/redisvl/migration/async_executor.py index 02492103a..b5e1571b1 100644 --- a/redisvl/migration/async_executor.py +++ b/redisvl/migration/async_executor.py @@ -1,7 +1,7 @@ from __future__ import annotations import asyncio -import logging +from redisvl.utils.log import get_logger import time from typing import Any, AsyncGenerator, Callable, Dict, List, Optional @@ -34,7 +34,7 @@ from redisvl.redis.utils import array_to_buffer, buffer_to_array from redisvl.types import AsyncRedisClient -logger = logging.getLogger(__name__) +logger = get_logger(__name__) class AsyncMigrationExecutor: From 28fb8261d6b083dfbde020616321e709a74f8275 Mon Sep 17 00:00:00 2001 From: Nitin Kanukolanu Date: Thu, 2 Apr 2026 10:43:54 -0400 Subject: [PATCH 7/7] fix: address review round 5 for migrate-async (#562) - Honor ValidationPolicy flags (require_schema_match, require_doc_count_match) in async validator - Handle missing source index in async current_source_matches_snapshot - Always remap datatype_changes keys on resume in async executor - Delete stale completed checkpoint on fresh run in async executor --- redisvl/migration/async_executor.py | 21 +++++++++++++++------ redisvl/migration/async_validation.py | 4 ++-- 2 files changed, 17 insertions(+), 8 deletions(-) diff --git a/redisvl/migration/async_executor.py b/redisvl/migration/async_executor.py index b5e1571b1..960d3eb2a 100644 --- a/redisvl/migration/async_executor.py +++ b/redisvl/migration/async_executor.py @@ -1,6 +1,8 @@ from __future__ import annotations import asyncio +from pathlib import Path + from 
redisvl.utils.log import get_logger import time from typing import Any, AsyncGenerator, Callable, Dict, List, Optional @@ -458,6 +460,9 @@ async def apply( "still exists; treating as fresh run", checkpoint_path, ) + # Remove the stale checkpoint so that downstream + # steps (e.g. _quantize_vectors) don't skip work. + Path(checkpoint_path).unlink(missing_ok=True) else: resuming_from_checkpoint = True logger.info( @@ -737,7 +742,7 @@ def _notify(step: str, detail: Optional[str] = None) -> None: # names when field renames exist, since quantization runs # after field renames (step 2). effective_changes = datatype_changes - if has_field_renames and not resuming_from_checkpoint: + if has_field_renames: field_rename_map = { fr.old_name: fr.new_name for fr in rename_ops.rename_fields } @@ -1111,11 +1116,15 @@ async def _async_current_source_matches_snapshot( """Check if current source schema matches the snapshot (async version).""" from redisvl.migration.utils import schemas_equal - current_index = await AsyncSearchIndex.from_existing( - index_name, - redis_url=redis_url, - redis_client=redis_client, - ) + try: + current_index = await AsyncSearchIndex.from_existing( + index_name, + redis_url=redis_url, + redis_client=redis_client, + ) + except Exception: + # Index no longer exists (e.g. 
already dropped during migration) + return False return schemas_equal(current_index.schema.to_dict(), expected_schema) def _build_benchmark_summary( diff --git a/redisvl/migration/async_validation.py b/redisvl/migration/async_validation.py index 40763a287..bd0dcf876 100644 --- a/redisvl/migration/async_validation.py +++ b/redisvl/migration/async_validation.py @@ -92,9 +92,9 @@ async def validate( user_checks = await self._run_query_checks(target_index, query_check_file) validation.query_checks.extend(user_checks) - if not validation.schema_match: + if not validation.schema_match and plan.validation.require_schema_match: validation.errors.append("Live schema does not match merged_target_schema.") - if not validation.doc_count_match: + if not validation.doc_count_match and plan.validation.require_doc_count_match: validation.errors.append( "Live document count does not match source num_docs." )