diff --git a/redisvl/migration/executor.py b/redisvl/migration/executor.py index 565f9e10f..91977c60b 100644 --- a/redisvl/migration/executor.py +++ b/redisvl/migration/executor.py @@ -1,6 +1,7 @@ from __future__ import annotations import time +from pathlib import Path from typing import Any, Callable, Dict, Generator, List, Optional from redis.exceptions import ResponseError @@ -14,6 +15,13 @@ MigrationValidation, ) from redisvl.migration.planner import MigrationPlanner +from redisvl.migration.reliability import ( + BatchUndoBuffer, + QuantizationCheckpoint, + is_already_quantized, + is_same_width_dtype_conversion, + trigger_bgsave_and_wait, +) from redisvl.migration.utils import ( build_scan_match_patterns, current_source_matches_snapshot, @@ -25,6 +33,7 @@ wait_for_index_ready, ) from redisvl.migration.validation import MigrationValidator +from redisvl.redis.utils import array_to_buffer, buffer_to_array from redisvl.types import SyncRedisClient from redisvl.utils.log import get_logger @@ -264,6 +273,8 @@ def _rename_keys( raise # Fail fast on collisions to avoid partial renames across batches. + # Keys already renamed in THIS batch are not rolled back -- caller + # can inspect the error to understand which keys moved. if collisions: raise RuntimeError( f"Prefix rename aborted after {renamed} successful rename(s): " @@ -328,6 +339,8 @@ def _rename_field_in_hash( try: pipe.execute() + # Count by number of keys that had old field values, + # not by HSET return (HSET returns 0 for existing field updates) renamed += batch_ops except Exception as e: logger.warning(f"Error in field rename batch: {e}") @@ -367,6 +380,9 @@ def _rename_field_in_json( values = pipe.execute() # Now set new field and delete old + # JSONPath GET returns results as a list; unwrap single-element + # results to preserve the original document shape. + # Missing paths return None or [] depending on Redis version. 
pipe = client.pipeline(transaction=False) batch_ops = 0 for key, value in zip(batch, values): @@ -379,6 +395,8 @@ def _rename_field_in_json( batch_ops += 1 try: pipe.execute() + # Count by number of keys that had old field values, + # not by JSON.SET return value renamed += batch_ops except Exception as e: logger.warning(f"Error in JSON field rename batch: {e}") @@ -397,6 +415,7 @@ def apply( redis_client: Optional[Any] = None, query_check_file: Optional[str] = None, progress_callback: Optional[Callable[[str, Optional[str]], None]] = None, + checkpoint_path: Optional[str] = None, ) -> MigrationReport: """Apply a migration plan. @@ -406,8 +425,10 @@ def apply( redis_client: Optional existing Redis client. query_check_file: Optional file with query checks. progress_callback: Optional callback(step, detail) for progress updates. - step: Current step name (e.g., "drop", "create", "index", "validate") + step: Current step name (e.g., "drop", "quantize", "create", "index", "validate") detail: Optional detail string (e.g., "1000/5000 docs (20%)") + checkpoint_path: Optional path for quantization checkpoint file. + When provided, enables crash-safe resume for vector re-encoding. """ started_at = timestamp_utc() started = time.perf_counter() @@ -429,26 +450,84 @@ def apply( report.finished_at = timestamp_utc() return report - if not current_source_matches_snapshot( - plan.source.index_name, - plan.source.schema_snapshot, - redis_url=redis_url, - redis_client=redis_client, - ): - report.validation.errors.append( - "The current live source schema no longer matches the saved source snapshot." + # Check if we are resuming from a checkpoint (post-drop crash). + # If so, the source index may no longer exist in Redis, so we + # skip live schema validation and construct from the plan snapshot. 
+ resuming_from_checkpoint = False + if checkpoint_path: + existing_checkpoint = QuantizationCheckpoint.load(checkpoint_path) + if existing_checkpoint is not None: + # Validate checkpoint belongs to this migration and is incomplete + if existing_checkpoint.index_name != plan.source.index_name: + logger.warning( + "Checkpoint index '%s' does not match plan index '%s', ignoring", + existing_checkpoint.index_name, + plan.source.index_name, + ) + elif existing_checkpoint.status == "completed": + # Quantization completed previously. Only resume if + # the source index is actually gone (post-drop crash). + # If the source still exists, this is a fresh run and + # the stale checkpoint should be ignored. + source_still_exists = current_source_matches_snapshot( + plan.source.index_name, + plan.source.schema_snapshot, + redis_url=redis_url, + redis_client=redis_client, + ) + if source_still_exists: + logger.info( + "Checkpoint at %s is completed and source index " + "still exists; treating as fresh run", + checkpoint_path, + ) + # Remove the stale checkpoint so that downstream + # steps (e.g. _quantize_vectors) don't skip work. + Path(checkpoint_path).unlink(missing_ok=True) + else: + resuming_from_checkpoint = True + logger.info( + "Checkpoint at %s is already completed; resuming " + "index recreation from post-drop state", + checkpoint_path, + ) + else: + resuming_from_checkpoint = True + logger.info( + "Checkpoint found at %s, skipping source index validation " + "(index may have been dropped before crash)", + checkpoint_path, + ) + + if not resuming_from_checkpoint: + if not current_source_matches_snapshot( + plan.source.index_name, + plan.source.schema_snapshot, + redis_url=redis_url, + redis_client=redis_client, + ): + report.validation.errors.append( + "The current live source schema no longer matches the saved source snapshot." + ) + report.manual_actions.append( + "Re-run `rvl migrate plan` to refresh the migration plan before applying." 
+ ) + report.finished_at = timestamp_utc() + return report + + source_index = SearchIndex.from_existing( + plan.source.index_name, + redis_url=redis_url, + redis_client=redis_client, ) - report.manual_actions.append( - "Re-run `rvl migrate plan` to refresh the migration plan before applying." + else: + # Source index was dropped before crash; reconstruct from snapshot + # to get a valid SearchIndex with a Redis client attached. + source_index = SearchIndex.from_dict( + plan.source.schema_snapshot, + redis_url=redis_url, + redis_client=redis_client, ) - report.finished_at = timestamp_utc() - return report - - source_index = SearchIndex.from_existing( - plan.source.index_name, - redis_url=redis_url, - redis_client=redis_client, - ) target_index = SearchIndex.from_dict( plan.merged_target_schema, @@ -458,19 +537,45 @@ def apply( enumerate_duration = 0.0 drop_duration = 0.0 + quantize_duration = 0.0 field_rename_duration = 0.0 key_rename_duration = 0.0 recreate_duration = 0.0 indexing_duration = 0.0 target_info: Dict[str, Any] = {} + docs_quantized = 0 keys_to_process: List[str] = [] storage_type = plan.source.keyspace.storage_type + # Check if we need to re-encode vectors for datatype changes + datatype_changes = MigrationPlanner.get_vector_datatype_changes( + plan.source.schema_snapshot, + plan.merged_target_schema, + rename_operations=plan.rename_operations, + ) + # Check for rename operations rename_ops = plan.rename_operations has_prefix_change = rename_ops.change_prefix is not None has_field_renames = bool(rename_ops.rename_fields) - needs_enumeration = has_prefix_change or has_field_renames + needs_quantization = bool(datatype_changes) and storage_type != "json" + needs_enumeration = needs_quantization or has_prefix_change or has_field_renames + has_same_width_quantization = any( + is_same_width_dtype_conversion(change["source"], change["target"]) + for change in datatype_changes.values() + ) + + if checkpoint_path and has_same_width_quantization: + 
report.validation.errors.append( + "Crash-safe resume is not supported for same-width datatype " + "changes (float16<->bfloat16 or int8<->uint8)." + ) + report.manual_actions.append( + "Re-run without --resume for same-width vector conversions, or " + "split the migration to avoid same-width datatype changes." + ) + report.finished_at = timestamp_utc() + return report def _notify(step: str, detail: Optional[str] = None) -> None: if progress_callback: @@ -480,77 +585,141 @@ def _notify(step: str, detail: Optional[str] = None) -> None: client = source_index._redis_client aof_enabled = detect_aof_enabled(client) disk_estimate = estimate_disk_space(plan, aof_enabled=aof_enabled) + if disk_estimate.has_quantization: + logger.info( + "Disk space estimate: RDB ~%d bytes, AOF ~%d bytes, total ~%d bytes", + disk_estimate.rdb_snapshot_disk_bytes, + disk_estimate.aof_growth_bytes, + disk_estimate.total_new_disk_bytes, + ) report.disk_space_estimate = disk_estimate - # STEP 1: Enumerate keys BEFORE any modifications - # Needed for: prefix change or field renames - if needs_enumeration: - _notify("enumerate", "Enumerating indexed documents...") - enumerate_started = time.perf_counter() - keys_to_process = list( - self._enumerate_indexed_keys( - client, - plan.source.index_name, - batch_size=1000, - key_separator=plan.source.keyspace.key_separator, + if resuming_from_checkpoint: + # On resume after a post-drop crash, the index no longer + # exists. Enumerate keys via SCAN using the plan prefix, + # and skip BGSAVE / field renames / drop (already done). + if needs_enumeration: + _notify("enumerate", "Enumerating documents via SCAN (resume)...") + enumerate_started = time.perf_counter() + prefixes = list(plan.source.keyspace.prefixes) + # If a prefix change was part of the migration, keys + # were already renamed before the crash, so scan with + # the new prefix instead. 
+ if has_prefix_change and rename_ops.change_prefix: + prefixes = [rename_ops.change_prefix] + seen_keys: set[str] = set() + for match_pattern in build_scan_match_patterns( + prefixes, plan.source.keyspace.key_separator + ): + cursor: int = 0 + while True: + cursor, scanned = client.scan( # type: ignore[misc] + cursor=cursor, + match=match_pattern, + count=1000, + ) + for k in scanned: + key = k.decode() if isinstance(k, bytes) else str(k) + if key not in seen_keys: + seen_keys.add(key) + keys_to_process.append(key) + if cursor == 0: + break + keys_to_process = normalize_keys(keys_to_process) + enumerate_duration = round( + time.perf_counter() - enumerate_started, 3 + ) + _notify( + "enumerate", + f"found {len(keys_to_process):,} documents ({enumerate_duration}s)", ) - ) - keys_to_process = normalize_keys(keys_to_process) - enumerate_duration = round(time.perf_counter() - enumerate_started, 3) - _notify( - "enumerate", - f"found {len(keys_to_process):,} documents ({enumerate_duration}s)", - ) - # STEP 2: Field renames (before dropping index) - if has_field_renames and keys_to_process: - _notify("field_rename", "Renaming fields in documents...") - field_rename_started = time.perf_counter() - for field_rename in rename_ops.rename_fields: - if storage_type == "json": - old_path = get_schema_field_path( - plan.source.schema_snapshot, field_rename.old_name - ) - new_path = get_schema_field_path( - plan.merged_target_schema, field_rename.new_name - ) - if not old_path or not new_path or old_path == new_path: - continue - self._rename_field_in_json( - client, - keys_to_process, - old_path, - new_path, - progress_callback=lambda done, total: _notify( - "field_rename", - f"{field_rename.old_name} -> {field_rename.new_name}: {done:,}/{total:,}", - ), - ) - else: - self._rename_field_in_hash( + _notify("bgsave", "skipped (resume)") + _notify("drop", "skipped (already dropped)") + else: + # Normal (non-resume) path + # STEP 1: Enumerate keys BEFORE any modifications + # Needed 
for: quantization, prefix change, or field renames + if needs_enumeration: + _notify("enumerate", "Enumerating indexed documents...") + enumerate_started = time.perf_counter() + keys_to_process = list( + self._enumerate_indexed_keys( client, - keys_to_process, - field_rename.old_name, - field_rename.new_name, - progress_callback=lambda done, total: _notify( - "field_rename", - f"{field_rename.old_name} -> {field_rename.new_name}: {done:,}/{total:,}", - ), + plan.source.index_name, + batch_size=1000, + key_separator=plan.source.keyspace.key_separator, ) - field_rename_duration = round( - time.perf_counter() - field_rename_started, 3 - ) - _notify("field_rename", f"done ({field_rename_duration}s)") + ) + keys_to_process = normalize_keys(keys_to_process) + enumerate_duration = round( + time.perf_counter() - enumerate_started, 3 + ) + _notify( + "enumerate", + f"found {len(keys_to_process):,} documents ({enumerate_duration}s)", + ) - # STEP 3: Drop the index - _notify("drop", "Dropping index definition...") - drop_started = time.perf_counter() - source_index.delete(drop=False) - drop_duration = round(time.perf_counter() - drop_started, 3) - _notify("drop", f"done ({drop_duration}s)") + # BGSAVE safety net: snapshot data before mutations begin + if needs_enumeration and keys_to_process: + _notify("bgsave", "Triggering BGSAVE safety snapshot...") + try: + trigger_bgsave_and_wait(client) + _notify("bgsave", "done") + except Exception as e: + logger.warning("BGSAVE safety snapshot failed: %s", e) + _notify("bgsave", f"skipped ({e})") + + # STEP 2: Field renames (before dropping index) + if has_field_renames and keys_to_process: + _notify("field_rename", "Renaming fields in documents...") + field_rename_started = time.perf_counter() + for field_rename in rename_ops.rename_fields: + if storage_type == "json": + old_path = get_schema_field_path( + plan.source.schema_snapshot, field_rename.old_name + ) + new_path = get_schema_field_path( + plan.merged_target_schema, 
field_rename.new_name + ) + if not old_path or not new_path or old_path == new_path: + continue + self._rename_field_in_json( + client, + keys_to_process, + old_path, + new_path, + progress_callback=lambda done, total: _notify( + "field_rename", + f"{field_rename.old_name} -> {field_rename.new_name}: {done:,}/{total:,}", + ), + ) + else: + self._rename_field_in_hash( + client, + keys_to_process, + field_rename.old_name, + field_rename.new_name, + progress_callback=lambda done, total: _notify( + "field_rename", + f"{field_rename.old_name} -> {field_rename.new_name}: {done:,}/{total:,}", + ), + ) + field_rename_duration = round( + time.perf_counter() - field_rename_started, 3 + ) + _notify("field_rename", f"done ({field_rename_duration}s)") + + # STEP 3: Drop the index + _notify("drop", "Dropping index definition...") + drop_started = time.perf_counter() + source_index.delete(drop=False) + drop_duration = round(time.perf_counter() - drop_started, 3) + _notify("drop", f"done ({drop_duration}s)") # STEP 4: Key renames (after drop, before recreate) - if has_prefix_change and keys_to_process: + # On resume, key renames were already done before the crash. 
+ if has_prefix_change and keys_to_process and not resuming_from_checkpoint: _notify("key_rename", "Renaming keys...") key_rename_started = time.perf_counter() old_prefix = plan.source.keyspace.prefixes[0] @@ -571,6 +740,65 @@ def _notify(step: str, detail: Optional[str] = None) -> None: f"done ({renamed_count:,} keys in {key_rename_duration}s)", ) + # STEP 5: Re-encode vectors using pre-enumerated keys + if needs_quantization and keys_to_process: + _notify("quantize", "Re-encoding vectors...") + quantize_started = time.perf_counter() + # If we renamed keys (non-resume), update keys_to_process + if ( + has_prefix_change + and rename_ops.change_prefix + and not resuming_from_checkpoint + ): + old_prefix = plan.source.keyspace.prefixes[0] + new_prefix = rename_ops.change_prefix + keys_to_process = [ + ( + new_prefix + k[len(old_prefix) :] + if k.startswith(old_prefix) + else k + ) + for k in keys_to_process + ] + keys_to_process = normalize_keys(keys_to_process) + # Remap datatype_changes keys from source to target field + # names when field renames exist, since quantization runs + # after field renames (step 2). The plan always stores + # datatype_changes keyed by source field names, so the + # remap is needed regardless of whether we are resuming. 
+ effective_changes = datatype_changes + if has_field_renames: + field_rename_map = { + fr.old_name: fr.new_name for fr in rename_ops.rename_fields + } + effective_changes = { + field_rename_map.get(k, k): v + for k, v in datatype_changes.items() + } + docs_quantized = self._quantize_vectors( + source_index, + effective_changes, + keys_to_process, + progress_callback=lambda done, total: _notify( + "quantize", f"{done:,}/{total:,} docs" + ), + checkpoint_path=checkpoint_path, + ) + quantize_duration = round(time.perf_counter() - quantize_started, 3) + _notify( + "quantize", + f"done ({docs_quantized:,} docs in {quantize_duration}s)", + ) + report.warnings.append( + f"Re-encoded {docs_quantized} documents for vector quantization: " + f"{datatype_changes}" + ) + elif datatype_changes and storage_type == "json": + # No checkpoint for JSON: vectors are re-indexed on recreate, + # so there is nothing to resume. Creating one would leave a + # stale in-progress checkpoint that misleads future runs. 
+ _notify("quantize", "skipped (JSON vectors are re-indexed on recreate)") + _notify("create", "Creating index with new schema...") recreate_started = time.perf_counter() target_index.create() @@ -600,6 +828,9 @@ def _index_progress(indexed: int, total: int, pct: float) -> None: report.timings = MigrationTimings( total_migration_duration_seconds=total_duration, drop_duration_seconds=drop_duration, + quantize_duration_seconds=( + quantize_duration if quantize_duration else None + ), field_rename_duration_seconds=( field_rename_duration if field_rename_duration else None ), @@ -613,6 +844,7 @@ def _index_progress(indexed: int, total: int, pct: float) -> None: drop_duration + field_rename_duration + key_rename_duration + + quantize_duration + recreate_duration + indexing_duration, 3, @@ -633,6 +865,7 @@ def _index_progress(indexed: int, total: int, pct: float) -> None: report.timings = MigrationTimings( total_migration_duration_seconds=total_duration, drop_duration_seconds=drop_duration or None, + quantize_duration_seconds=quantize_duration or None, field_rename_duration_seconds=field_rename_duration or None, key_rename_duration_seconds=key_rename_duration or None, recreate_duration_seconds=recreate_duration or None, @@ -642,6 +875,7 @@ def _index_progress(indexed: int, total: int, pct: float) -> None: drop_duration + field_rename_duration + key_rename_duration + + quantize_duration + recreate_duration + indexing_duration, 3, @@ -649,6 +883,7 @@ def _index_progress(indexed: int, total: int, pct: float) -> None: if drop_duration or field_rename_duration or key_rename_duration + or quantize_duration or recreate_duration or indexing_duration else None @@ -668,6 +903,176 @@ def _index_progress(indexed: int, total: int, pct: float) -> None: return report + def _quantize_vectors( + self, + source_index: SearchIndex, + datatype_changes: Dict[str, Dict[str, Any]], + keys: List[str], + progress_callback: Optional[Callable[[int, int], None]] = None, + checkpoint_path: 
Optional[str] = None, + ) -> int: + """Re-encode vectors in documents for datatype changes (quantization). + + Uses pre-enumerated keys (from _enumerate_indexed_keys) to process + only the documents that were in the index, avoiding full keyspace scan. + Includes idempotent skip (already-quantized vectors), bounded undo + buffer for per-batch rollback, and optional checkpointing for resume. + + Args: + source_index: The source SearchIndex (already dropped but client available) + datatype_changes: Dict mapping field_name -> {"source", "target", "dims"} + keys: Pre-enumerated list of document keys to process + progress_callback: Optional callback(docs_done, total_docs) + checkpoint_path: Optional path for checkpoint file (enables resume) + + Returns: + Number of documents quantized + """ + client = source_index._redis_client + total_keys = len(keys) + docs_processed = 0 + docs_quantized = 0 + skipped = 0 + batch_size = 500 + + # Load or create checkpoint for resume support + checkpoint: Optional[QuantizationCheckpoint] = None + if checkpoint_path: + checkpoint = QuantizationCheckpoint.load(checkpoint_path) + if checkpoint: + # Validate checkpoint matches current migration BEFORE + # checking completion status to avoid skipping quantization + # for an unrelated completed checkpoint. + if checkpoint.index_name != source_index.name: + raise ValueError( + f"Checkpoint index '{checkpoint.index_name}' does not match " + f"source index '{source_index.name}'. " + f"Use the correct checkpoint file or remove it to start fresh." + ) + # Skip if checkpoint shows a completed migration + if checkpoint.status == "completed": + logger.info( + "Checkpoint already marked as completed for index '%s'. " + "Skipping quantization. 
Remove the checkpoint file to force re-run.", + checkpoint.index_name, + ) + return 0 + if checkpoint.total_keys != total_keys: + if checkpoint.processed_keys: + current_keys = set(keys) + missing_processed = [ + key + for key in checkpoint.processed_keys + if key not in current_keys + ] + if missing_processed or total_keys < checkpoint.total_keys: + raise ValueError( + f"Checkpoint total_keys={checkpoint.total_keys} does not match " + f"the current key set ({total_keys}). " + "Use the correct checkpoint file or remove it to start fresh." + ) + logger.warning( + "Checkpoint total_keys=%d differs from current key set size=%d. " + "Proceeding because all legacy processed keys are present.", + checkpoint.total_keys, + total_keys, + ) + else: + raise ValueError( + f"Checkpoint total_keys={checkpoint.total_keys} does not match " + f"the current key set ({total_keys}). " + "Use the correct checkpoint file or remove it to start fresh." + ) + remaining = checkpoint.get_remaining_keys(keys) + logger.info( + "Resuming from checkpoint: %d/%d keys already processed", + total_keys - len(remaining), + total_keys, + ) + docs_processed = total_keys - len(remaining) + keys = remaining + total_keys_for_progress = total_keys + else: + checkpoint = QuantizationCheckpoint( + index_name=source_index.name, + total_keys=total_keys, + checkpoint_path=checkpoint_path, + ) + checkpoint.save() + total_keys_for_progress = total_keys + else: + total_keys_for_progress = total_keys + + remaining_keys = len(keys) + + for i in range(0, remaining_keys, batch_size): + batch = keys[i : i + batch_size] + pipe = client.pipeline() + undo = BatchUndoBuffer() + keys_updated_in_batch: set[str] = set() + + try: + for key in batch: + for field_name, change in datatype_changes.items(): + field_data: bytes | None = client.hget(key, field_name) # type: ignore[misc,assignment] + if not field_data: + continue + + # Idempotent: skip if already converted to target dtype + dims = change.get("dims", 0) + if dims and 
is_already_quantized( + field_data, dims, change["source"], change["target"] + ): + skipped += 1 + continue + + undo.store(key, field_name, field_data) + array = buffer_to_array(field_data, change["source"]) + new_bytes = array_to_buffer(array, change["target"]) + pipe.hset(key, field_name, new_bytes) # type: ignore[arg-type] + keys_updated_in_batch.add(key) + + if keys_updated_in_batch: + pipe.execute() + except Exception: + logger.warning( + "Batch %d failed, rolling back %d entries", + i // batch_size, + undo.size, + ) + rollback_pipe = client.pipeline() + undo.rollback(rollback_pipe) + if checkpoint: + checkpoint.save() + raise + finally: + undo.clear() + + docs_quantized += len(keys_updated_in_batch) + docs_processed += len(batch) + + if checkpoint: + # Record all keys in batch (including skipped) so they + # are not re-scanned on resume + checkpoint.record_batch(batch) + checkpoint.save() + + if progress_callback: + progress_callback(docs_processed, total_keys_for_progress) + + if checkpoint: + checkpoint.mark_complete() + checkpoint.save() + + if skipped: + logger.info("Skipped %d already-quantized vector fields", skipped) + logger.info( + "Quantized %d documents across %d fields", + docs_quantized, + len(datatype_changes), + ) + return docs_quantized + def _build_benchmark_summary( self, plan: MigrationPlan, diff --git a/redisvl/migration/reliability.py b/redisvl/migration/reliability.py new file mode 100644 index 000000000..71f6e672e --- /dev/null +++ b/redisvl/migration/reliability.py @@ -0,0 +1,340 @@ +"""Crash-safe quantization utilities for index migration. + +Provides idempotent dtype detection, checkpointing, BGSAVE safety, +and bounded undo buffering for reliable vector re-encoding. 
+""" + +import asyncio +import os +import tempfile +import time +from pathlib import Path +from typing import Any, Dict, List, Optional, Tuple + +import yaml +from pydantic import BaseModel, Field + +from redisvl.migration.models import DTYPE_BYTES +from redisvl.utils.log import get_logger + +logger = get_logger(__name__) + +# Dtypes that share byte widths and are functionally interchangeable +# for idempotent detection purposes (same byte length per element). +_DTYPE_FAMILY: Dict[str, str] = { + "float64": "8byte", + "float32": "4byte", + "float16": "2byte", + "bfloat16": "2byte", + "int8": "1byte", + "uint8": "1byte", +} + + +def is_same_width_dtype_conversion(source_dtype: str, target_dtype: str) -> bool: + """Return True when two dtypes share byte width but differ in encoding.""" + if source_dtype == target_dtype: + return False + source_family = _DTYPE_FAMILY.get(source_dtype) + target_family = _DTYPE_FAMILY.get(target_dtype) + if source_family is None or target_family is None: + return False + return source_family == target_family + + +# --------------------------------------------------------------------------- +# Idempotent Dtype Detection +# --------------------------------------------------------------------------- + + +def detect_vector_dtype(data: bytes, expected_dims: int) -> Optional[str]: + """Inspect raw vector bytes and infer the storage dtype. + + Uses byte length and expected dimensions to determine which dtype + the vector is currently stored as. Returns the canonical representative + for each byte-width family (float16 for 2-byte, int8 for 1-byte), + since dtypes within a family cannot be distinguished by length alone. + + Args: + data: Raw vector bytes from Redis. + expected_dims: Number of dimensions expected for this vector field. + + Returns: + Detected dtype string (e.g. "float32", "float16", "int8") or None + if the size does not match any known dtype. 
+ """ + if not data or expected_dims <= 0: + return None + + nbytes = len(data) + + # Check each dtype in decreasing element size to avoid ambiguity. + # Only canonical representatives are checked (float16 covers bfloat16, + # int8 covers uint8) since they share byte widths. + for dtype in ("float64", "float32", "float16", "int8"): + if nbytes == expected_dims * DTYPE_BYTES[dtype]: + return dtype + + return None + + +def is_already_quantized( + data: bytes, + expected_dims: int, + source_dtype: str, + target_dtype: str, +) -> bool: + """Check whether a vector has already been converted to the target dtype. + + Uses byte-width families to handle ambiguous dtypes. For example, + if source is float32 and target is float16, a vector detected as + 2-bytes-per-element is considered already quantized (the byte width + shrank from 4 to 2, so conversion already happened). + + However, same-width conversions (e.g. float16 -> bfloat16 or + int8 -> uint8) are NOT skipped because the encoding semantics + differ even though the byte length is identical. We cannot + distinguish these by length, so we must always re-encode. + + Args: + data: Raw vector bytes. + expected_dims: Number of dimensions. + source_dtype: The dtype the vector was originally stored as. + target_dtype: The dtype we want to convert to. + + Returns: + True if the vector already matches the target dtype (skip conversion). + """ + detected = detect_vector_dtype(data, expected_dims) + if detected is None: + return False + + detected_family = _DTYPE_FAMILY.get(detected) + target_family = _DTYPE_FAMILY.get(target_dtype) + source_family = _DTYPE_FAMILY.get(source_dtype) + + # If detected byte-width matches target family, the vector looks converted. + # But if source and target share the same byte-width family (e.g. + # float16 -> bfloat16), we cannot tell whether conversion happened, + # so we must NOT skip -- always re-encode for same-width migrations. 
+ if source_family == target_family: + return False + + return detected_family == target_family + + +# --------------------------------------------------------------------------- +# Quantization Checkpoint +# --------------------------------------------------------------------------- + + +class QuantizationCheckpoint(BaseModel): + """Tracks migration progress for crash-safe resume.""" + + index_name: str + total_keys: int + completed_keys: int = 0 + completed_batches: int = 0 + last_batch_keys: List[str] = Field(default_factory=list) + # Retained for backward compatibility with older checkpoint files. + # New checkpoints rely on completed_keys with deterministic key ordering + # instead of rewriting an ever-growing processed key list on every batch. + processed_keys: List[str] = Field(default_factory=list) + status: str = "in_progress" + checkpoint_path: str = "" + + def record_batch(self, keys: List[str]) -> None: + """Record a successfully processed batch. + + Does not auto-save to disk. Call save() after record_batch() + to persist the checkpoint for crash recovery. + """ + self.completed_keys += len(keys) + self.completed_batches += 1 + self.last_batch_keys = list(keys) + if self.processed_keys: + self.processed_keys.extend(keys) + + def mark_complete(self) -> None: + """Mark the migration as completed.""" + self.status = "completed" + + def save(self) -> None: + """Persist checkpoint to disk atomically. + + Writes to a temporary file first, then renames. This ensures a + crash mid-write does not corrupt the checkpoint file. 
+ """ + path = Path(self.checkpoint_path) + path.parent.mkdir(parents=True, exist_ok=True) + fd, tmp_path = tempfile.mkstemp( + dir=path.parent, suffix=".tmp", prefix=".checkpoint_" + ) + try: + exclude = set() + if not self.processed_keys: + exclude.add("processed_keys") + with os.fdopen(fd, "w") as f: + yaml.safe_dump( + self.model_dump(exclude=exclude), + f, + sort_keys=False, + ) + os.replace(tmp_path, str(path)) + except BaseException: + # Clean up temp file on any failure + try: + os.unlink(tmp_path) + except OSError: + pass + raise + + @classmethod + def load(cls, path: str) -> Optional["QuantizationCheckpoint"]: + """Load a checkpoint from disk. Returns None if file does not exist. + + Always sets checkpoint_path to the path used to load, not the + value stored in the file. This ensures subsequent save() calls + write to the correct location even if the file was moved. + """ + p = Path(path) + if not p.exists(): + return None + with open(p, "r") as f: + data = yaml.safe_load(f) + if not data: + return None + checkpoint = cls.model_validate(data) + if checkpoint.processed_keys and checkpoint.completed_keys < len( + checkpoint.processed_keys + ): + checkpoint.completed_keys = len(checkpoint.processed_keys) + checkpoint.checkpoint_path = str(p) + return checkpoint + + def get_remaining_keys(self, all_keys: List[str]) -> List[str]: + """Return keys that have not yet been processed.""" + if self.processed_keys: + done = set(self.processed_keys) + return [k for k in all_keys if k not in done] + + if self.completed_keys <= 0: + return list(all_keys) + + return all_keys[self.completed_keys :] + + +# --------------------------------------------------------------------------- +# BGSAVE Safety Net +# --------------------------------------------------------------------------- + + +def trigger_bgsave_and_wait( + client: Any, + *, + timeout_seconds: int = 300, + poll_interval: float = 1.0, +) -> bool: + """Trigger a Redis BGSAVE and wait for it to complete. 
+ + If a BGSAVE is already in progress, waits for it instead. + + Args: + client: Sync Redis client. + timeout_seconds: Max seconds to wait for BGSAVE to finish. + poll_interval: Seconds between status polls. + + Returns: + True if BGSAVE completed successfully. + """ + try: + client.bgsave() + except Exception as exc: + if "already in progress" not in str(exc).lower(): + raise + logger.info("BGSAVE already in progress, waiting for it to finish.") + + deadline = time.monotonic() + timeout_seconds + while time.monotonic() < deadline: + info = client.info("persistence") + if isinstance(info, dict) and not info.get("rdb_bgsave_in_progress", 0): + status = info.get("rdb_last_bgsave_status", "ok") + if status != "ok": + logger.warning("BGSAVE completed with status: %s", status) + return False + return True + time.sleep(poll_interval) + + raise TimeoutError(f"BGSAVE did not complete within {timeout_seconds}s") + + +async def async_trigger_bgsave_and_wait( + client: Any, + *, + timeout_seconds: int = 300, + poll_interval: float = 1.0, +) -> bool: + """Async version of trigger_bgsave_and_wait.""" + try: + await client.bgsave() + except Exception as exc: + if "already in progress" not in str(exc).lower(): + raise + logger.info("BGSAVE already in progress, waiting for it to finish.") + + deadline = time.monotonic() + timeout_seconds + while time.monotonic() < deadline: + info = await client.info("persistence") + if isinstance(info, dict) and not info.get("rdb_bgsave_in_progress", 0): + status = info.get("rdb_last_bgsave_status", "ok") + if status != "ok": + logger.warning("BGSAVE completed with status: %s", status) + return False + return True + await asyncio.sleep(poll_interval) + + raise TimeoutError(f"BGSAVE did not complete within {timeout_seconds}s") + + +# --------------------------------------------------------------------------- +# Bounded Undo Buffer +# --------------------------------------------------------------------------- + + +class BatchUndoBuffer: + 
class BatchUndoBuffer:
    """Keeps the pre-mutation vector values of the batch in flight.

    The buffer is bounded by the batch size: it only ever holds the
    entries of a single batch. Call clear() once a batch has been
    committed successfully so the next batch starts empty.
    """

    def __init__(self) -> None:
        # (key, field, original bytes) in the order they were recorded.
        self._entries: List[Tuple[str, str, bytes]] = []

    @property
    def size(self) -> int:
        """Number of originals currently held."""
        return len(self._entries)

    def store(self, key: str, field: str, original_value: bytes) -> None:
        """Remember what *field* of *key* contained before it is overwritten."""
        record = (key, field, original_value)
        self._entries.append(record)

    def _queue_restores(self, pipe: Any) -> bool:
        """Queue one HSET per stored original; report whether any were queued."""
        queued = False
        for key, field, value in self._entries:
            pipe.hset(key, field, value)
            queued = True
        return queued

    def rollback(self, pipe: Any) -> None:
        """Write every stored original back through *pipe* (sync).

        Does nothing, and never touches the pipeline, when the buffer is empty.
        """
        if self._queue_restores(pipe):
            pipe.execute()

    async def async_rollback(self, pipe: Any) -> None:
        """Write every stored original back through *pipe* (async).

        Does nothing, and never touches the pipeline, when the buffer is empty.
        """
        if self._queue_restores(pipe):
            await pipe.execute()

    def clear(self) -> None:
        """Drop all stored entries."""
        self._entries.clear()
def create_source_index(redis_url, worker_id, source_attrs):
    """Create a uniquely named JSON-backed source index for a migration test.

    The vector field starts from a fixed default (128-dim float32 cosine
    FLAT); entries in *source_attrs* override those defaults.

    Args:
        redis_url: Connection URL of the Redis instance under test.
        worker_id: Test-worker identifier, embedded in the index name and
            key prefix together with a random suffix.
        source_attrs: Vector attribute overrides for the ``embedding`` field.

    Returns:
        Tuple of (created SearchIndex, index name).
    """
    suffix = str(uuid.uuid4())[:8]
    index_name = f"mig_route_{worker_id}_{suffix}"
    key_prefix = f"mig_route:{worker_id}:{suffix}"

    vector_attrs = {
        "dims": 128,
        "datatype": "float32",
        "distance_metric": "cosine",
        "algorithm": "flat",
        **source_attrs,
    }
    schema = {
        "index": {"name": index_name, "prefix": key_prefix, "storage_type": "json"},
        "fields": [
            {"name": "title", "type": "text", "path": "$.title"},
            {
                "name": "embedding",
                "type": "vector",
                "path": "$.embedding",
                "attrs": vector_attrs,
            },
        ],
    }

    index = SearchIndex.from_dict(schema, redis_url=redis_url)
    index.create(overwrite=True)
    return index, index_name
def run_migration(redis_url, index_name, patch_attrs):
    """Plan and apply a migration that only updates the ``embedding`` field.

    Args:
        redis_url: Connection URL of the Redis instance under test.
        index_name: Name of the existing source index.
        patch_attrs: Attribute changes to apply to the ``embedding`` field.

    Returns:
        Tuple of (migration report, migration plan).
    """
    embedding_update = FieldUpdate(name="embedding", attrs=patch_attrs)
    changes = {
        "add_fields": [],
        "remove_fields": [],
        "update_fields": [embedding_update],
        "rename_fields": [],
        "index": {},
    }
    patch = SchemaPatch(version=1, changes=changes)

    plan = MigrationPlanner().create_plan_from_patch(
        index_name, schema_patch=patch, redis_url=redis_url
    )
    report = MigrationExecutor().apply(plan, redis_url=redis_url)
    return report, plan


class TestAlgorithmChanges:
    """Test algorithm migration routes."""

    def test_hnsw_to_flat(self, redis_url, worker_id):
        """HNSW source index migrates to FLAT and the live schema reflects it."""
        source_attrs = {"algorithm": "hnsw"}
        index, index_name = create_source_index(redis_url, worker_id, source_attrs)
        try:
            outcome, _ = run_migration(redis_url, index_name, {"algorithm": "flat"})
            assert outcome.result == "succeeded"
            assert outcome.validation.schema_match is True

            migrated = SearchIndex.from_existing(index_name, redis_url=redis_url)
            algorithm = migrated.schema.fields["embedding"].attrs.algorithm
            assert str(algorithm).endswith("FLAT")
        finally:
            index.delete(drop=True)

    def test_flat_to_hnsw_with_params(self, redis_url, worker_id):
        """FLAT source index migrates to HNSW with explicit build parameters."""
        source_attrs = {"algorithm": "flat"}
        index, index_name = create_source_index(redis_url, worker_id, source_attrs)
        try:
            target = {"algorithm": "hnsw", "m": 32, "ef_construction": 200}
            outcome, _ = run_migration(redis_url, index_name, target)
            assert outcome.result == "succeeded"
            assert outcome.validation.schema_match is True

            migrated = SearchIndex.from_existing(index_name, redis_url=redis_url)
            attrs = migrated.schema.fields["embedding"].attrs
            assert str(attrs.algorithm).endswith("HNSW")
            assert attrs.m == 32
            assert attrs.ef_construction == 200
        finally:
            index.delete(drop=True)
class TestDatatypeChanges:
    """Test datatype migration routes."""

    @pytest.mark.parametrize(
        "source_dtype,target_dtype",
        [
            ("float32", "float16"),
            ("float32", "bfloat16"),
            ("float16", "float32"),
        ],
    )
    def test_flat_datatype_change(
        self, redis_url, worker_id, source_dtype, target_dtype
    ):
        """Float datatype conversions on a FLAT index."""
        index, index_name = create_source_index(
            redis_url, worker_id, {"algorithm": "flat", "datatype": source_dtype}
        )
        try:
            report, _ = run_migration(redis_url, index_name, {"datatype": target_dtype})
            assert report.result == "succeeded"
            assert report.validation.schema_match is True
        finally:
            index.delete(drop=True)

    @pytest.mark.parametrize("target_dtype", ["int8", "uint8"])
    def test_flat_quantized_datatype(self, redis_url, worker_id, target_dtype):
        """Test INT8/UINT8 datatypes (requires Redis 8.0+)."""
        # The client exists only for the version gate; the context manager
        # closes it even when the skip is raised, so no connection leaks.
        with Redis.from_url(redis_url) as client:
            skip_if_redis_version_below(
                client, "8.0.0", "INT8/UINT8 requires Redis 8.0+"
            )
        index, index_name = create_source_index(
            redis_url, worker_id, {"algorithm": "flat"}
        )
        try:
            report, _ = run_migration(redis_url, index_name, {"datatype": target_dtype})
            assert report.result == "succeeded"
            assert report.validation.schema_match is True
        finally:
            index.delete(drop=True)

    @pytest.mark.parametrize(
        "source_dtype,target_dtype",
        [
            ("float32", "float16"),
            ("float32", "bfloat16"),
        ],
    )
    def test_hnsw_datatype_change(
        self, redis_url, worker_id, source_dtype, target_dtype
    ):
        """Float datatype conversions on an HNSW index."""
        index, index_name = create_source_index(
            redis_url, worker_id, {"algorithm": "hnsw", "datatype": source_dtype}
        )
        try:
            report, _ = run_migration(redis_url, index_name, {"datatype": target_dtype})
            assert report.result == "succeeded"
            assert report.validation.schema_match is True
        finally:
            index.delete(drop=True)

    @pytest.mark.parametrize("target_dtype", ["int8", "uint8"])
    def test_hnsw_quantized_datatype(self, redis_url, worker_id, target_dtype):
        """Test INT8/UINT8 datatypes with HNSW (requires Redis 8.0+)."""
        # Close the version-gate client deterministically (see above).
        with Redis.from_url(redis_url) as client:
            skip_if_redis_version_below(
                client, "8.0.0", "INT8/UINT8 requires Redis 8.0+"
            )
        index, index_name = create_source_index(
            redis_url, worker_id, {"algorithm": "hnsw"}
        )
        try:
            report, _ = run_migration(redis_url, index_name, {"datatype": target_dtype})
            assert report.result == "succeeded"
            assert report.validation.schema_match is True
        finally:
            index.delete(drop=True)
class TestDistanceMetricChanges:
    """Test distance metric migration routes."""

    @pytest.mark.parametrize(
        "source_metric,target_metric",
        [
            ("cosine", "l2"),
            ("cosine", "ip"),
            ("l2", "cosine"),
            ("ip", "l2"),
        ],
    )
    def test_distance_metric_change(
        self, redis_url, worker_id, source_metric, target_metric
    ):
        """Switch the distance metric of a FLAT index and validate the result."""
        source_attrs = {"algorithm": "flat", "distance_metric": source_metric}
        index, index_name = create_source_index(redis_url, worker_id, source_attrs)
        try:
            patch_attrs = {"distance_metric": target_metric}
            outcome, _ = run_migration(redis_url, index_name, patch_attrs)
            assert outcome.result == "succeeded"
            assert outcome.validation.schema_match is True
        finally:
            index.delete(drop=True)
class TestHNSWTuningParameters:
    """Test HNSW parameter tuning routes."""

    def test_hnsw_m_parameter(self, redis_url, worker_id):
        """Changing ``m`` succeeds and is visible on the live index."""
        source_attrs = {"algorithm": "hnsw"}
        index, index_name = create_source_index(redis_url, worker_id, source_attrs)
        try:
            outcome, _ = run_migration(redis_url, index_name, {"m": 64})
            assert outcome.result == "succeeded"
            assert outcome.validation.schema_match is True

            migrated = SearchIndex.from_existing(index_name, redis_url=redis_url)
            assert migrated.schema.fields["embedding"].attrs.m == 64
        finally:
            index.delete(drop=True)

    def test_hnsw_ef_construction_parameter(self, redis_url, worker_id):
        """Changing ``ef_construction`` succeeds and is visible on the live index."""
        source_attrs = {"algorithm": "hnsw"}
        index, index_name = create_source_index(redis_url, worker_id, source_attrs)
        try:
            outcome, _ = run_migration(redis_url, index_name, {"ef_construction": 500})
            assert outcome.result == "succeeded"
            assert outcome.validation.schema_match is True

            migrated = SearchIndex.from_existing(index_name, redis_url=redis_url)
            assert migrated.schema.fields["embedding"].attrs.ef_construction == 500
        finally:
            index.delete(drop=True)

    def test_hnsw_ef_runtime_parameter(self, redis_url, worker_id):
        """Changing ``ef_runtime`` succeeds and passes schema validation."""
        source_attrs = {"algorithm": "hnsw"}
        index, index_name = create_source_index(redis_url, worker_id, source_attrs)
        try:
            outcome, _ = run_migration(redis_url, index_name, {"ef_runtime": 50})
            assert outcome.result == "succeeded"
            assert outcome.validation.schema_match is True
        finally:
            index.delete(drop=True)

    def test_hnsw_epsilon_parameter(self, redis_url, worker_id):
        """Changing ``epsilon`` succeeds and passes schema validation."""
        source_attrs = {"algorithm": "hnsw"}
        index, index_name = create_source_index(redis_url, worker_id, source_attrs)
        try:
            outcome, _ = run_migration(redis_url, index_name, {"epsilon": 0.1})
            assert outcome.result == "succeeded"
            assert outcome.validation.schema_match is True
        finally:
            index.delete(drop=True)

    def test_hnsw_all_params_combined(self, redis_url, worker_id):
        """All four HNSW tuning parameters can be changed in one migration."""
        source_attrs = {"algorithm": "hnsw"}
        index, index_name = create_source_index(redis_url, worker_id, source_attrs)
        try:
            tuning = {"m": 48, "ef_construction": 300, "ef_runtime": 75, "epsilon": 0.05}
            outcome, _ = run_migration(redis_url, index_name, tuning)
            assert outcome.result == "succeeded"
            assert outcome.validation.schema_match is True

            migrated = SearchIndex.from_existing(index_name, redis_url=redis_url)
            attrs = migrated.schema.fields["embedding"].attrs
            assert attrs.m == 48
            assert attrs.ef_construction == 300
        finally:
            index.delete(drop=True)
class TestCombinedChanges:
    """Test combined migration routes (multiple changes at once)."""

    def test_flat_to_hnsw_with_datatype_and_metric(self, redis_url, worker_id):
        """Algorithm, datatype and distance metric change in a single migration."""
        index, index_name = create_source_index(
            redis_url, worker_id, {"algorithm": "flat"}
        )
        try:
            report, _ = run_migration(
                redis_url,
                index_name,
                {"algorithm": "hnsw", "datatype": "float16", "distance_metric": "l2"},
            )
            assert report.result == "succeeded"
            assert report.validation.schema_match is True

            live = SearchIndex.from_existing(index_name, redis_url=redis_url)
            attrs = live.schema.fields["embedding"].attrs
            assert str(attrs.algorithm).endswith("HNSW")
        finally:
            index.delete(drop=True)

    def test_flat_to_hnsw_with_int8(self, redis_url, worker_id):
        """Combined algorithm + quantized datatype (requires Redis 8.0+)."""
        # The client is needed only for the version gate; the context manager
        # closes it even when the skip is raised, so no connection leaks.
        with Redis.from_url(redis_url) as client:
            skip_if_redis_version_below(client, "8.0.0", "INT8 requires Redis 8.0+")
        index, index_name = create_source_index(
            redis_url, worker_id, {"algorithm": "flat"}
        )
        try:
            report, _ = run_migration(
                redis_url,
                index_name,
                {"algorithm": "hnsw", "datatype": "int8"},
            )
            assert report.result == "succeeded"
            assert report.validation.schema_match is True
        finally:
            index.delete(drop=True)