From deb11587cee52646fa32e28bbb1959b37c8ea6eb Mon Sep 17 00:00:00 2001 From: Nitin Kanukolanu Date: Thu, 2 Apr 2026 11:45:13 -0400 Subject: [PATCH 1/2] feat(migrate): [2/6] interactive migration wizard with CLI subcommand Adds guided migration builder for interactive plan creation: - wizard.py: MigrationWizard with index selection, field operations, vector tuning, quantization, and preview - cli/migrate.py: adds 'wizard' subcommand (rvl migrate wizard --index ) - Unit tests for wizard logic (41 tests) --- redisvl/cli/migrate.py | 61 +- redisvl/migration/__init__.py | 2 + redisvl/migration/wizard.py | 814 ++++++++++++++++++ tests/unit/test_migration_wizard.py | 1190 +++++++++++++++++++++++++++ 4 files changed, 2066 insertions(+), 1 deletion(-) create mode 100644 redisvl/migration/wizard.py create mode 100644 tests/unit/test_migration_wizard.py diff --git a/redisvl/cli/migrate.py b/redisvl/cli/migrate.py index 5bc70bb36..9b6d66adc 100644 --- a/redisvl/cli/migrate.py +++ b/redisvl/cli/migrate.py @@ -3,7 +3,12 @@ from typing import Optional from redisvl.cli.utils import add_redis_connection_options, create_redis_url -from redisvl.migration import MigrationExecutor, MigrationPlanner, MigrationValidator +from redisvl.migration import ( + MigrationExecutor, + MigrationPlanner, + MigrationValidator, + MigrationWizard, +) from redisvl.migration.utils import ( detect_aof_enabled, estimate_disk_space, @@ -25,6 +30,7 @@ class Migrate: "Commands:", "\thelper Show migration guidance and supported capabilities", "\tlist List all available indexes", + "\twizard Interactively build a migration plan and schema patch", "\tplan Generate a migration plan for a document-preserving drop/recreate migration", "\tapply Execute a reviewed drop/recreate migration plan", "\testimate Estimate disk space required for a migration plan (dry-run, no mutations)", @@ -84,6 +90,7 @@ def helper(self): Commands: rvl migrate list List all indexes + rvl migrate wizard --index Guided migration builder rvl 
def wizard(self):
    """Run the interactive migration wizard (``rvl migrate wizard``).

    Parses the sub-command arguments from ``sys.argv[3:]``, builds a
    ``MigrationWizard`` around a ``MigrationPlanner`` configured with the
    requested key sample limit, runs the interactive session and finally
    prints a summary of the plan that was written to ``--plan-out``.

    The usage line lists every option the parser accepts, including
    ``--target-schema-out`` and ``--key-sample-limit``.
    """
    parser = argparse.ArgumentParser(
        usage=(
            "rvl migrate wizard [--index <name>] "
            "[--patch <schema_patch.yaml>] "
            "[--plan-out <migration_plan.yaml>] "
            "[--patch-out <schema_patch.yaml>] "
            "[--target-schema-out <target_schema.yaml>] "
            "[--key-sample-limit <n>]"
        )
    )
    # The index is optional: when omitted the wizard lists the available
    # indexes and asks the user to pick one.
    parser.add_argument("-i", "--index", help="Source index name", required=False)
    parser.add_argument(
        "--patch",
        help="Load an existing schema patch to continue editing",
        default=None,
    )
    parser.add_argument(
        "--plan-out",
        help="Path to write migration_plan.yaml",
        default="migration_plan.yaml",
    )
    parser.add_argument(
        "--patch-out",
        help="Path to write schema_patch.yaml (for later editing)",
        default="schema_patch.yaml",
    )
    parser.add_argument(
        "--target-schema-out",
        help="Optional path to write the merged target schema",
        default=None,
    )
    parser.add_argument(
        "--key-sample-limit",
        help="Maximum number of keys to sample from the index keyspace",
        type=int,
        default=10,
    )
    parser = add_redis_connection_options(parser)
    # argv[0:3] is "rvl migrate wizard"; everything after belongs to us.
    args = parser.parse_args(sys.argv[3:])

    redis_url = create_redis_url(args)
    wizard = MigrationWizard(
        planner=MigrationPlanner(key_sample_limit=args.key_sample_limit)
    )
    plan = wizard.run(
        index_name=args.index,
        redis_url=redis_url,
        existing_patch_path=args.patch,
        plan_out=args.plan_out,
        patch_out=args.patch_out,
        target_schema_out=args.target_schema_out,
    )
    self._print_plan_summary(args.plan_out, plan)
# Field types the wizard can add to an index. Vector fields are excluded
# because adding one would require embedding every existing document.
SUPPORTED_FIELD_TYPES = ["text", "tag", "numeric", "geo"]
# Field types whose attributes the wizard can edit in place.
UPDATABLE_FIELD_TYPES = ["text", "tag", "numeric", "geo", "vector"]


class MigrationWizard:
    """Interactive, prompt-driven builder for schema patches and plans.

    The wizard snapshots the source index through the planner, walks the
    user through a menu of staged changes (add / update / remove / rename
    fields, rename index, change prefix), and turns the result into a
    ``SchemaPatch`` and a migration plan.

    Staged changes follow the conventions of ``_apply_staged_changes``:
    ``remove_fields`` and ``rename_fields[*].old_name`` refer to *source*
    field names, while ``update_fields[*].name`` refers to the field name
    *after* renames have been applied.
    """

    def __init__(self, planner: Optional[MigrationPlanner] = None):
        """Create a wizard, using a default ``MigrationPlanner`` if none is given."""
        self.planner = planner or MigrationPlanner()

    def run(
        self,
        *,
        index_name: Optional[str] = None,
        redis_url: Optional[str] = None,
        redis_client: Optional[Any] = None,
        existing_patch_path: Optional[str] = None,
        plan_out: str = "migration_plan.yaml",
        patch_out: Optional[str] = None,
        target_schema_out: Optional[str] = None,
    ):
        """Run the full interactive session and write the resulting files.

        Args:
            index_name: Source index; when omitted the user is asked to pick one.
            redis_url: Redis connection URL passed through to the planner.
            redis_client: Existing Redis client passed through to the planner.
            existing_patch_path: Optional schema patch file to continue editing.
            plan_out: Path the migration plan is written to.
            patch_out: Optional path for the schema patch YAML.
            target_schema_out: Optional path for the merged target schema YAML.

        Returns:
            The plan returned by ``planner.create_plan_from_patch``.
        """
        resolved_index_name = self._resolve_index_name(
            index_name=index_name,
            redis_url=redis_url,
            redis_client=redis_client,
        )
        snapshot = self.planner.snapshot_source(
            resolved_index_name,
            redis_url=redis_url,
            redis_client=redis_client,
        )
        source_schema = IndexSchema.from_dict(snapshot.schema_snapshot)

        print(f"Building a migration plan for index '{resolved_index_name}'")
        self._print_source_schema(source_schema.to_dict())

        # Load existing patch if provided
        existing_changes = None
        if existing_patch_path:
            existing_changes = self._load_existing_patch(existing_patch_path)

        schema_patch = self._build_patch(
            source_schema.to_dict(), existing_changes=existing_changes
        )
        plan = self.planner.create_plan_from_patch(
            resolved_index_name,
            schema_patch=schema_patch,
            redis_url=redis_url,
            redis_client=redis_client,
        )
        self.planner.write_plan(plan, plan_out)

        if patch_out:
            write_yaml(schema_patch.model_dump(exclude_none=True), patch_out)
        if target_schema_out:
            write_yaml(plan.merged_target_schema, target_schema_out)

        return plan

    def _load_existing_patch(self, patch_path: str) -> SchemaPatchChanges:
        """Load and validate a schema patch file, print a summary, return its changes."""
        from redisvl.migration.utils import load_yaml

        data = load_yaml(patch_path)
        patch = SchemaPatch.model_validate(data)
        print(f"Loaded existing patch from {patch_path}")
        print(f" Add fields: {len(patch.changes.add_fields)}")
        print(f" Update fields: {len(patch.changes.update_fields)}")
        print(f" Remove fields: {len(patch.changes.remove_fields)}")
        print(f" Rename fields: {len(patch.changes.rename_fields)}")
        if patch.changes.index:
            print(f" Index changes: {list(patch.changes.index.keys())}")
        return patch.changes

    def _resolve_index_name(
        self,
        *,
        index_name: Optional[str],
        redis_url: Optional[str],
        redis_client: Optional[Any],
    ) -> str:
        """Return ``index_name`` or interactively pick one of the existing indexes.

        Raises:
            ValueError: If no index name was given and Redis has no indexes.
        """
        if index_name:
            return index_name

        indexes = list_indexes(redis_url=redis_url, redis_client=redis_client)
        if not indexes:
            raise ValueError("No indexes found in Redis")

        print("Available indexes:")
        for position, name in enumerate(indexes, start=1):
            print(f"{position}. {name}")

        while True:
            choice = input("Select an index by number or name: ").strip()
            if choice in indexes:
                return choice
            # isdecimal() (not isdigit()) so that int() can never raise on
            # characters such as superscript digits.
            if choice.isdecimal():
                offset = int(choice) - 1
                if 0 <= offset < len(indexes):
                    return indexes[offset]
            print("Invalid selection. Please try again.")

    @staticmethod
    def _filter_staged_adds(
        working_schema: Dict[str, Any], staged_add_names: set
    ) -> Dict[str, Any]:
        """Return a copy of working_schema with staged-add fields removed.

        This prevents staged additions from appearing in update/rename
        candidate lists.
        """
        import copy

        filtered = copy.deepcopy(working_schema)
        filtered["fields"] = [
            f for f in filtered["fields"] if f["name"] not in staged_add_names
        ]
        return filtered

    def _apply_staged_changes(
        self,
        source_schema: Dict[str, Any],
        changes: SchemaPatchChanges,
    ) -> Dict[str, Any]:
        """Build a working copy of source_schema reflecting staged changes.

        This ensures subsequent prompts show the current state of the schema
        after renames, removes, and adds have been queued. ``source_schema``
        itself is never modified.
        """
        import copy

        working = copy.deepcopy(source_schema)

        # Apply removes (keyed by source field name)
        removed_names = set(changes.remove_fields)
        working["fields"] = [
            f for f in working["fields"] if f["name"] not in removed_names
        ]

        # Apply renames (single pass: old source name -> new name)
        rename_map = {r.old_name: r.new_name for r in changes.rename_fields}
        for field in working["fields"]:
            if field["name"] in rename_map:
                field["name"] = rename_map[field["name"]]

        # Apply updates (keyed by the post-rename field name)
        update_map = {u.name: u for u in changes.update_fields}
        for field in working["fields"]:
            if field["name"] in update_map:
                upd = update_map[field["name"]]
                if upd.attrs:
                    field.setdefault("attrs", {}).update(upd.attrs)
                if upd.type:
                    field["type"] = upd.type

        # Apply adds
        for added in changes.add_fields:
            working["fields"].append(added)

        # Apply index-level changes (name, prefix) so preview reflects them
        if changes.index:
            for key, value in changes.index.items():
                working["index"][key] = value

        return working

    @staticmethod
    def _stage_removal(changes: SchemaPatchChanges, field_name: str) -> None:
        """Stage removal of ``field_name`` as shown in the working schema.

        * A staged addition is simply cancelled.
        * Otherwise the name is mapped back to its source name (the user sees
          post-rename names), the source name is queued in ``remove_fields``
          and any staged rename/update of that field is dropped.
        """
        staged_add_names = {f["name"] for f in changes.add_fields}
        if field_name in staged_add_names:
            changes.add_fields = [
                f for f in changes.add_fields if f["name"] != field_name
            ]
            print(f"Cancelled staged addition of '{field_name}'.")
            return

        # The field may be displayed under a staged new name; removals must
        # reference the name that exists in the source schema.
        source_name = next(
            (r.old_name for r in changes.rename_fields if r.new_name == field_name),
            field_name,
        )
        if source_name not in changes.remove_fields:
            changes.remove_fields.append(source_name)

        # Drop queued updates and renames that target the removed field.
        stale_names = {field_name, source_name}
        changes.update_fields = [
            u for u in changes.update_fields if u.name not in stale_names
        ]
        changes.rename_fields = [
            r
            for r in changes.rename_fields
            if r.old_name not in stale_names and r.new_name != field_name
        ]

    @staticmethod
    def _stage_rename(changes: SchemaPatchChanges, field_rename: FieldRename) -> None:
        """Stage a rename, collapsing chains so ``old_name`` is a source name.

        Renaming an already-renamed field (a->b, then b->c) is stored as a
        single a->c entry; renaming back to the original name cancels the
        rename. Staged updates follow the field to its new name.
        """
        staged_add_names = {f["name"] for f in changes.add_fields}
        # Check rename target doesn't collide with staged additions
        if field_rename.new_name in staged_add_names:
            print(
                f"Cannot rename to '{field_rename.new_name}': "
                "a field with that name is already staged for addition."
            )
            return

        old_name = field_rename.old_name
        new_name = field_rename.new_name
        prior = next((r for r in changes.rename_fields if r.new_name == old_name), None)
        if prior is None:
            changes.rename_fields.append(field_rename)
        else:
            changes.rename_fields = [
                r for r in changes.rename_fields if r is not prior
            ]
            # Renaming back to the source name means "no rename at all".
            if prior.old_name != new_name:
                changes.rename_fields.append(
                    FieldRename(old_name=prior.old_name, new_name=new_name)
                )

        # Updates are keyed by the post-rename name; keep them attached.
        for update in changes.update_fields:
            if update.name == old_name:
                update.name = new_name

    @staticmethod
    def _stage_update(changes: SchemaPatchChanges, update: FieldUpdate) -> None:
        """Stage an update, merging it into an existing update of the same field."""
        existing = next(
            (u for u in changes.update_fields if u.name == update.name),
            None,
        )
        if existing is None:
            changes.update_fields.append(update)
            return
        if update.attrs:
            existing.attrs = {**(existing.attrs or {}), **update.attrs}
        if update.type:
            existing.type = update.type

    def _build_patch(
        self,
        source_schema: Dict[str, Any],
        existing_changes: Optional[SchemaPatchChanges] = None,
    ) -> SchemaPatch:
        """Run the interactive menu loop and return the resulting patch.

        Args:
            source_schema: Schema dict of the source index (never modified).
            existing_changes: Previously staged changes to continue editing.

        Returns:
            A version-1 ``SchemaPatch`` wrapping the staged changes.
        """
        changes = existing_changes if existing_changes else SchemaPatchChanges()
        # A loaded patch may omit index-level changes entirely.
        if changes.index is None:
            changes.index = {}

        while True:
            # Refresh working schema to reflect staged changes
            working_schema = self._apply_staged_changes(source_schema, changes)
            staged_add_names = {f["name"] for f in changes.add_fields}

            print("\nChoose an action:")
            print("1. Add field (text, tag, numeric, geo)")
            print("2. Update field (sortable, weight, separator, vector config)")
            print("3. Remove field")
            print("4. Rename field (rename field in all documents)")
            print("5. Rename index (change index name)")
            print("6. Change prefix (rename all keys)")
            print("7. Preview patch (show pending changes as YAML)")
            print("8. Finish")
            action = input("Enter a number: ").strip()

            if action == "1":
                field = self._prompt_add_field(working_schema)
                if field:
                    if field["name"] in staged_add_names:
                        print(
                            f"Field '{field['name']}' is already staged for addition."
                        )
                    else:
                        changes.add_fields.append(field)
            elif action == "2":
                # Filter out staged additions from update candidates
                update = self._prompt_update_field(
                    self._filter_staged_adds(working_schema, staged_add_names)
                )
                if update:
                    self._stage_update(changes, update)
            elif action == "3":
                field_name = self._prompt_remove_field(working_schema)
                if field_name:
                    self._stage_removal(changes, field_name)
            elif action == "4":
                # Filter out staged additions from rename candidates
                field_rename = self._prompt_rename_field(
                    self._filter_staged_adds(working_schema, staged_add_names)
                )
                if field_rename:
                    self._stage_rename(changes, field_rename)
            elif action == "5":
                new_name = self._prompt_rename_index(working_schema)
                if new_name:
                    changes.index["name"] = new_name
            elif action == "6":
                new_prefix = self._prompt_change_prefix(working_schema)
                if new_prefix:
                    changes.index["prefix"] = new_prefix
            elif action == "7":
                print(
                    yaml.safe_dump(
                        {
                            "version": 1,
                            "changes": changes.model_dump(exclude_none=True),
                        },
                        sort_keys=False,
                    )
                )
            elif action == "8":
                break
            else:
                print("Invalid action. Please choose 1-8.")

        return SchemaPatch(version=1, changes=changes)

    def _prompt_add_field(
        self, source_schema: Dict[str, Any]
    ) -> Optional[Dict[str, Any]]:
        """Prompt for a new field definition; return ``None`` if input is rejected."""
        field_name = input("Field name: ").strip()
        existing_names = {field["name"] for field in source_schema["fields"]}
        if not field_name:
            print("Field name is required.")
            return None
        if field_name in existing_names:
            print(f"Field '{field_name}' already exists in the source schema.")
            return None

        field_type = self._prompt_from_choices(
            "Field type",
            SUPPORTED_FIELD_TYPES,
            block_message="Vector fields cannot be added (requires embedding all documents). Only text, tag, numeric, and geo are supported.",
        )
        if not field_type:
            return None

        field: Dict[str, Any] = {"name": field_name, "type": field_type}
        storage_type = source_schema["index"]["storage_type"]
        if storage_type == "json":
            # JSON indexes need to know where the value lives in the document.
            print(" JSON path: location in document where this field is stored")
            path = (
                input(f"JSON path [default $.{field_name}]: ").strip()
                or f"$.{field_name}"
            )
            field["path"] = path

        attrs = self._prompt_common_attrs(field_type)
        if attrs:
            field["attrs"] = attrs
        return field

    def _prompt_update_field(
        self, source_schema: Dict[str, Any]
    ) -> Optional[FieldUpdate]:
        """Prompt for a field and its new attributes; ``None`` if nothing changed."""
        fields = [
            field
            for field in source_schema["fields"]
            if field["type"] in UPDATABLE_FIELD_TYPES
        ]
        if not fields:
            print("No updatable fields are available.")
            return None

        print("Updatable fields:")
        for position, field in enumerate(fields, start=1):
            print(f"{position}. {field['name']} ({field['type']})")

        choice = input("Select a field to update by number or name: ").strip()
        selected: Optional[Dict[str, Any]] = None
        for position, field in enumerate(fields, start=1):
            if choice == str(position) or choice == field["name"]:
                selected = field
                break
        if not selected:
            print("Invalid field selection.")
            return None

        if selected["type"] == "vector":
            attrs = self._prompt_vector_attrs(selected)
        else:
            attrs = self._prompt_common_attrs(selected["type"], allow_blank=True)
        if not attrs:
            print("No changes collected.")
            return None
        return FieldUpdate(name=selected["name"], attrs=attrs)

    def _prompt_remove_field(self, source_schema: Dict[str, Any]) -> Optional[str]:
        """Prompt for a field to remove; vector fields need explicit confirmation."""
        removable_fields = [field["name"] for field in source_schema["fields"]]
        if not removable_fields:
            print("No fields available to remove.")
            return None

        print("Removable fields:")
        for position, field in enumerate(source_schema["fields"], start=1):
            field_type = field["type"]
            warning = " [WARNING: vector field]" if field_type == "vector" else ""
            print(f"{position}. {field['name']} ({field_type}){warning}")

        choice = input("Select a field to remove by number or name: ").strip()
        selected_name: Optional[str] = None
        if choice in removable_fields:
            selected_name = choice
        elif choice.isdecimal():  # isdecimal() guarantees int() cannot raise
            offset = int(choice) - 1
            if 0 <= offset < len(removable_fields):
                selected_name = removable_fields[offset]

        if not selected_name:
            print("Invalid field selection.")
            return None

        # Check if it's a vector field and require confirmation
        selected_field = next(
            (f for f in source_schema["fields"] if f["name"] == selected_name), None
        )
        if selected_field and selected_field["type"] == "vector":
            print(
                f"\n WARNING: Removing vector field '{selected_name}' will:\n"
                " - Remove it from the search index\n"
                " - Leave vector data in documents (wasted storage)\n"
                " - Require re-embedding if you want to restore it later"
            )
            confirm = input("Type 'yes' to confirm removal: ").strip().lower()
            if confirm != "yes":
                print("Cancelled.")
                return None

        return selected_name

    def _prompt_rename_field(
        self, source_schema: Dict[str, Any]
    ) -> Optional[FieldRename]:
        """Prompt user to rename a field in all documents."""
        fields = source_schema["fields"]
        if not fields:
            print("No fields available to rename.")
            return None

        print("Fields available for renaming:")
        for position, field in enumerate(fields, start=1):
            print(f"{position}. {field['name']} ({field['type']})")

        choice = input("Select a field to rename by number or name: ").strip()
        selected: Optional[Dict[str, Any]] = None
        for position, field in enumerate(fields, start=1):
            if choice == str(position) or choice == field["name"]:
                selected = field
                break
        if not selected:
            print("Invalid field selection.")
            return None

        old_name = selected["name"]
        print(f"Renaming field '{old_name}'")
        print(
            " Warning: This will modify all documents to rename the field. "
            "This is an expensive operation for large datasets."
        )
        new_name = input("New field name: ").strip()
        if not new_name:
            print("New field name is required.")
            return None
        if new_name == old_name:
            print("New name is the same as the old name.")
            return None

        existing_names = {f["name"] for f in fields}
        if new_name in existing_names:
            print(f"Field '{new_name}' already exists.")
            return None

        return FieldRename(old_name=old_name, new_name=new_name)

    def _prompt_rename_index(self, source_schema: Dict[str, Any]) -> Optional[str]:
        """Prompt user to rename the index."""
        current_name = source_schema["index"]["name"]
        print(f"Current index name: {current_name}")
        print(
            " Note: This only changes the index name. "
            "Documents and keys are unchanged."
        )
        new_name = input("New index name: ").strip()
        if not new_name:
            print("New index name is required.")
            return None
        if new_name == current_name:
            print("New name is the same as the current name.")
            return None
        return new_name

    def _prompt_change_prefix(self, source_schema: Dict[str, Any]) -> Optional[str]:
        """Prompt user to change the key prefix."""
        current_prefix = source_schema["index"]["prefix"]
        print(f"Current prefix: {current_prefix}")
        print(
            " Warning: This will RENAME all keys from the old prefix to the new prefix. "
            "This is an expensive operation for large datasets."
        )
        new_prefix = input("New prefix: ").strip()
        if not new_prefix:
            print("New prefix is required.")
            return None
        if new_prefix == current_prefix:
            print("New prefix is the same as the current prefix.")
            return None
        return new_prefix

    def _prompt_common_attrs(
        self, field_type: str, allow_blank: bool = False
    ) -> Dict[str, Any]:
        """Prompt for attributes shared by non-vector fields plus type-specific ones.

        With ``allow_blank`` the user may skip any question; skipped
        attributes are left out of the returned dict.
        """
        attrs: Dict[str, Any] = {}

        # Sortable - available for all non-vector types
        print(" Sortable: enables sorting and aggregation on this field")
        sortable = self._prompt_bool("Sortable", allow_blank=allow_blank)
        if sortable is not None:
            attrs["sortable"] = sortable

        # Index missing - available for all types (requires Redis Search 2.10+)
        print(
            " Index missing: enables ismissing() queries for documents without this field"
        )
        index_missing = self._prompt_bool("Index missing", allow_blank=allow_blank)
        if index_missing is not None:
            attrs["index_missing"] = index_missing

        # Index empty - index documents where field value is empty string
        print(
            " Index empty: enables isempty() queries for documents with empty string values"
        )
        index_empty = self._prompt_bool("Index empty", allow_blank=allow_blank)
        if index_empty is not None:
            attrs["index_empty"] = index_empty

        # Type-specific attributes
        if field_type == "text":
            self._prompt_text_attrs(attrs, allow_blank)
        elif field_type == "tag":
            self._prompt_tag_attrs(attrs, allow_blank)
        elif field_type == "numeric":
            self._prompt_numeric_attrs(attrs, allow_blank, sortable)

        # No index - only meaningful with sortable
        if sortable:
            print(" No index: store field for sorting only, not searchable")
            no_index = self._prompt_bool("No index", allow_blank=allow_blank)
            if no_index is not None:
                attrs["no_index"] = no_index

        return attrs

    def _prompt_text_attrs(self, attrs: Dict[str, Any], allow_blank: bool) -> None:
        """Prompt for text field specific attributes."""
        # No stem
        print(
            " Disable stemming: prevents word variations (running/runs) from matching"
        )
        no_stem = self._prompt_bool("Disable stemming", allow_blank=allow_blank)
        if no_stem is not None:
            attrs["no_stem"] = no_stem

        # Weight
        print(" Weight: relevance multiplier for full-text search (default: 1.0)")
        weight_input = input("Weight [leave blank for default]: ").strip()
        if weight_input:
            try:
                weight = float(weight_input)
            except ValueError:
                print("Invalid weight value.")
            else:
                if weight > 0:
                    attrs["weight"] = weight
                else:
                    print("Weight must be positive.")

        # Phonetic matcher
        print(
            " Phonetic matcher: enables phonetic matching (e.g., 'dm:en' for Metaphone)"
        )
        phonetic = input("Phonetic matcher [leave blank for none]: ").strip()
        if phonetic:
            attrs["phonetic_matcher"] = phonetic

        # UNF (only if sortable)
        if attrs.get("sortable"):
            print(" UNF: preserve original form (no lowercasing) for sorting")
            unf = self._prompt_bool("UNF (un-normalized form)", allow_blank=allow_blank)
            if unf is not None:
                attrs["unf"] = unf

    def _prompt_tag_attrs(self, attrs: Dict[str, Any], allow_blank: bool) -> None:
        """Prompt for tag field specific attributes."""
        # Separator
        print(" Separator: character that splits multiple values (default: comma)")
        separator = input("Separator [leave blank to keep existing/default]: ").strip()
        if separator:
            attrs["separator"] = separator

        # Case sensitive
        print(" Case sensitive: match tags with exact case (default: false)")
        case_sensitive = self._prompt_bool("Case sensitive", allow_blank=allow_blank)
        if case_sensitive is not None:
            attrs["case_sensitive"] = case_sensitive

    def _prompt_numeric_attrs(
        self, attrs: Dict[str, Any], allow_blank: bool, sortable: Optional[bool]
    ) -> None:
        """Prompt for numeric field specific attributes."""
        # UNF (only if sortable)
        if sortable or attrs.get("sortable"):
            print(" UNF: preserve exact numeric representation for sorting")
            unf = self._prompt_bool("UNF (un-normalized form)", allow_blank=allow_blank)
            if unf is not None:
                attrs["unf"] = unf

    def _prompt_positive_int(self, prompt: str, label: str) -> Optional[int]:
        """Read a positive integer; blank keeps the current value (``None``).

        Anything that is not a positive decimal integer is reported and
        ignored, mirroring how invalid epsilon values are handled.
        """
        raw = input(prompt).strip()
        if not raw:
            return None
        # isdecimal() only accepts characters int() can parse.
        if raw.isdecimal() and int(raw) > 0:
            return int(raw)
        print(f" Invalid {label} value (expected a positive integer), ignoring.")
        return None

    def _prompt_vector_attrs(self, field: Dict[str, Any]) -> Dict[str, Any]:
        """Prompt for vector field tuning; return only the attributes to change.

        Blank answers keep the current value. Invalid answers are reported
        and ignored, except that an incompatible datatype is replaced by
        ``float32`` when the effective algorithm is SVS-VAMANA.
        """
        attrs: Dict[str, Any] = {}
        current = field.get("attrs") or {}
        field_name = field["name"]

        print(f"Current vector config for '{field_name}':")
        # Normalize the same way as user input so SVS_VAMANA == SVS-VAMANA.
        current_algo = current.get("algorithm", "hnsw").upper().replace("_", "-")
        print(f" algorithm: {current_algo}")
        print(f" datatype: {current.get('datatype', 'float32')}")
        print(f" distance_metric: {current.get('distance_metric', 'cosine')}")
        print(f" dims: {current.get('dims')} (cannot be changed)")
        if current_algo == "HNSW":
            print(f" m: {current.get('m', 16)}")
            print(f" ef_construction: {current.get('ef_construction', 200)}")

        print("\nLeave blank to keep current value.")

        # Algorithm
        print(
            " Algorithm: vector search method (FLAT=brute force, HNSW=graph, SVS-VAMANA=compressed graph)"
        )
        algo = (
            input(f"Algorithm [current: {current_algo}]: ")
            .strip()
            .upper()
            .replace("_", "-")  # Normalize SVS_VAMANA to SVS-VAMANA
        )
        if algo in ("FLAT", "HNSW", "SVS-VAMANA"):
            if algo != current_algo:
                attrs["algorithm"] = algo
        elif algo:
            print(f" Unknown algorithm '{algo}', keeping {current_algo}.")

        # Datatype (quantization) - show algorithm-specific options
        effective_algo = attrs.get("algorithm", current_algo)
        valid_datatypes: tuple[str, ...]
        if effective_algo == "SVS-VAMANA":
            # SVS-VAMANA only supports float16, float32
            print(
                " Datatype for SVS-VAMANA: float16, float32 "
                "(float16 reduces memory by ~50%)"
            )
            valid_datatypes = ("float16", "float32")
        else:
            # FLAT/HNSW support: float16, float32, bfloat16, float64, int8, uint8
            print(
                " Datatype: float16, float32, bfloat16, float64, int8, uint8\n"
                " (float16 reduces memory ~50%, int8/uint8 reduce ~75%)"
            )
            valid_datatypes = (
                "float16",
                "float32",
                "bfloat16",
                "float64",
                "int8",
                "uint8",
            )
        current_datatype = current.get("datatype", "float32")
        # If switching to SVS-VAMANA and current datatype is incompatible,
        # require the user to pick a valid one.
        force_datatype = (
            effective_algo == "SVS-VAMANA" and current_datatype not in valid_datatypes
        )
        if force_datatype:
            print(
                f" Current datatype '{current_datatype}' is not compatible with SVS-VAMANA. "
                "You must select a valid datatype."
            )
        datatype = input(f"Datatype [current: {current_datatype}]: ").strip().lower()
        if datatype in valid_datatypes:
            attrs["datatype"] = datatype
        elif force_datatype:
            # Default to float32 when user skips but current dtype is incompatible
            print(" Defaulting to float32 for SVS-VAMANA compatibility.")
            attrs["datatype"] = "float32"
        elif datatype:
            print(f" Unsupported datatype '{datatype}', keeping {current_datatype}.")

        # Distance metric
        print(" Distance metric: how similarity is measured (cosine, l2, ip)")
        metric = (
            input(
                f"Distance metric [current: {current.get('distance_metric', 'cosine')}]: "
            )
            .strip()
            .lower()
        )
        if metric in ("cosine", "l2", "ip"):
            attrs["distance_metric"] = metric
        elif metric:
            print(f" Unknown distance metric '{metric}', ignoring.")

        # Algorithm-specific params (effective_algo already computed above)
        if effective_algo == "HNSW":
            print(
                " M: number of connections per node (higher=better recall, more memory)"
            )
            m_val = self._prompt_positive_int(
                f"M [current: {current.get('m', 16)}]: ", "M"
            )
            if m_val is not None:
                attrs["m"] = m_val

            print(
                " EF_CONSTRUCTION: build-time search depth (higher=better recall, slower build)"
            )
            ef_construction_val = self._prompt_positive_int(
                f"EF_CONSTRUCTION [current: {current.get('ef_construction', 200)}]: ",
                "EF_CONSTRUCTION",
            )
            if ef_construction_val is not None:
                attrs["ef_construction"] = ef_construction_val

            print(
                " EF_RUNTIME: query-time search depth (higher=better recall, slower queries)"
            )
            ef_runtime_val = self._prompt_positive_int(
                f"EF_RUNTIME [current: {current.get('ef_runtime', 10)}]: ",
                "EF_RUNTIME",
            )
            if ef_runtime_val is not None:
                attrs["ef_runtime"] = ef_runtime_val

            print(
                " EPSILON: relative factor for range queries (0.0-1.0, lower=more accurate)"
            )
            epsilon_input = input(
                f"EPSILON [current: {current.get('epsilon', 0.01)}]: "
            ).strip()
            if epsilon_input:
                try:
                    epsilon_val = float(epsilon_input)
                except ValueError:
                    print(" Invalid epsilon value, ignoring.")
                else:
                    if 0.0 <= epsilon_val <= 1.0:
                        attrs["epsilon"] = epsilon_val
                    else:
                        print(" Epsilon must be between 0.0 and 1.0, ignoring.")

        elif effective_algo == "SVS-VAMANA":
            print(
                " GRAPH_MAX_DEGREE: max edges per node (higher=better recall, more memory)"
            )
            gmd_val = self._prompt_positive_int(
                f"GRAPH_MAX_DEGREE [current: {current.get('graph_max_degree', 40)}]: ",
                "GRAPH_MAX_DEGREE",
            )
            if gmd_val is not None:
                attrs["graph_max_degree"] = gmd_val

            print(" COMPRESSION: optional vector compression for memory savings")
            print(" Options: LVQ4, LVQ8, LVQ4x4, LVQ4x8, LeanVec4x8, LeanVec8x8")
            print(
                " Note: LVQ/LeanVec optimizations require Intel hardware with AVX-512"
            )
            compression_input = (
                input("COMPRESSION [leave blank for none]: ").strip().upper()
            )
            # Map input to correct enum case (CompressionType expects exact case)
            compression_map = {
                "LVQ4": "LVQ4",
                "LVQ8": "LVQ8",
                "LVQ4X4": "LVQ4x4",
                "LVQ4X8": "LVQ4x8",
                "LEANVEC4X8": "LeanVec4x8",
                "LEANVEC8X8": "LeanVec8x8",
            }
            compression = compression_map.get(compression_input)
            if compression_input and not compression:
                print(f" Unknown compression '{compression_input}', ignoring.")
            if compression:
                attrs["compression"] = compression

                # Prompt for REDUCE if LeanVec compression is selected
                if compression.startswith("LeanVec"):
                    # dims may be missing or None in the stored attrs
                    dims = current.get("dims") or 0
                    recommended = dims // 2 if dims > 0 else None
                    print(
                        f" REDUCE: dimensionality reduction for LeanVec (must be < {dims})"
                    )
                    if recommended:
                        print(
                            f" Recommended: {recommended} (dims/2 for balanced performance)"
                        )
                    reduce_input = input("REDUCE [leave blank to skip]: ").strip()
                    if reduce_input and reduce_input.isdecimal():
                        reduce_val = int(reduce_input)
                        if 0 < reduce_val < dims:
                            attrs["reduce"] = reduce_val
                        else:
                            print(
                                f" Invalid: reduce must be > 0 and < {dims}, ignoring."
                            )

        return attrs

    def _prompt_bool(self, label: str, allow_blank: bool = False) -> Optional[bool]:
        """Ask a yes/no question until a valid answer is given.

        Returns ``True``/``False``; with ``allow_blank`` a blank or "skip"
        answer returns ``None``, otherwise a blank answer means ``False``.
        """
        suffix = " [y/n]" if not allow_blank else " [y/n/skip]"
        while True:
            value = input(f"{label}{suffix}: ").strip().lower()
            if value in ("y", "yes"):
                return True
            if value in ("n", "no"):
                return False
            if allow_blank and value in ("", "skip", "s"):
                return None
            if not allow_blank and value == "":
                return False
            hint = "y, n, or skip" if allow_blank else "y or n"
            print(f"Please answer {hint}.")

    def _prompt_from_choices(
        self,
        label: str,
        choices: List[str],
        *,
        block_message: str,
    ) -> Optional[str]:
        """Ask for one of ``choices``; print ``block_message`` and return ``None`` otherwise."""
        print(f"{label} options: {', '.join(choices)}")
        value = input(f"{label}: ").strip().lower()
        if value not in choices:
            print(block_message)
            return None
        return value

    def _print_source_schema(self, schema_dict: Dict[str, Any]) -> None:
        """Print the index name, storage type and fields of ``schema_dict``."""
        print("Current schema:")
        print(f"- Index name: {schema_dict['index']['name']}")
        print(f"- Storage type: {schema_dict['index']['storage_type']}")
        for field in schema_dict["fields"]:
            path = field.get("path")
            suffix = f" path={path}" if path else ""
            print(f" - {field['name']} ({field['type']}){suffix}")
100644 index 000000000..bd53d6415 --- /dev/null +++ b/tests/unit/test_migration_wizard.py @@ -0,0 +1,1190 @@ +from redisvl.migration.wizard import MigrationWizard + + +def _make_vector_source_schema(algorithm="hnsw", datatype="float32"): + """Helper to create a source schema with a vector field.""" + return { + "index": { + "name": "test_index", + "prefix": "test:", + "storage_type": "hash", + }, + "fields": [ + {"name": "title", "type": "text"}, + { + "name": "embedding", + "type": "vector", + "attrs": { + "algorithm": algorithm, + "dims": 384, + "distance_metric": "cosine", + "datatype": datatype, + "m": 16, + "ef_construction": 200, + }, + }, + ], + } + + +def test_wizard_builds_patch_from_interactive_inputs(monkeypatch): + source_schema = { + "index": { + "name": "docs", + "prefix": "docs", + "storage_type": "json", + }, + "fields": [ + {"name": "title", "type": "text", "path": "$.title"}, + {"name": "category", "type": "tag", "path": "$.category"}, + { + "name": "embedding", + "type": "vector", + "path": "$.embedding", + "attrs": { + "algorithm": "flat", + "dims": 3, + "distance_metric": "cosine", + "datatype": "float32", + }, + }, + ], + } + + answers = iter( + [ + # Add field + "1", + "status", # field name + "tag", # field type + "$.status", # JSON path + "y", # sortable + "n", # index_missing + "n", # index_empty + "|", # separator (tag-specific) + "n", # case_sensitive (tag-specific) + "n", # no_index (prompted since sortable=y) + # Update field + "2", + "title", # select field + "y", # sortable + "n", # index_missing + "n", # index_empty + "n", # no_stem (text-specific) + "", # weight (blank to skip, text-specific) + "", # phonetic_matcher (blank to skip) + "n", # unf (prompted since sortable=y) + "n", # no_index (prompted since sortable=y) + # Remove field + "3", + "category", + # Finish + "8", + ] + ) + monkeypatch.setattr("builtins.input", lambda _prompt="": next(answers)) + + wizard = MigrationWizard() + patch = wizard._build_patch(source_schema) # 
noqa: SLF001 + + assert patch.changes.add_fields == [ + { + "name": "status", + "type": "tag", + "path": "$.status", + "attrs": { + "sortable": True, + "index_missing": False, + "index_empty": False, + "separator": "|", + "case_sensitive": False, + "no_index": False, + }, + } + ] + assert patch.changes.remove_fields == ["category"] + assert len(patch.changes.update_fields) == 1 + assert patch.changes.update_fields[0].name == "title" + assert patch.changes.update_fields[0].attrs["sortable"] is True + assert patch.changes.update_fields[0].attrs["no_stem"] is False + + +# ============================================================================= +# Vector Algorithm Tests +# ============================================================================= + + +class TestVectorAlgorithmChanges: + """Test wizard handling of vector algorithm changes.""" + + def test_hnsw_to_flat(self, monkeypatch): + """Test changing from HNSW to FLAT algorithm.""" + source_schema = _make_vector_source_schema(algorithm="hnsw") + + answers = iter( + [ + "2", # Update field + "embedding", # Select vector field + "FLAT", # Change to FLAT + "", # datatype (keep current) + "", # distance_metric (keep current) + # No HNSW params prompted for FLAT + "8", # Finish + ] + ) + monkeypatch.setattr("builtins.input", lambda _prompt="": next(answers)) + + wizard = MigrationWizard() + patch = wizard._build_patch(source_schema) + + assert len(patch.changes.update_fields) == 1 + update = patch.changes.update_fields[0] + assert update.name == "embedding" + assert update.attrs["algorithm"] == "FLAT" + + def test_flat_to_hnsw_with_params(self, monkeypatch): + """Test changing from FLAT to HNSW with custom M and EF_CONSTRUCTION.""" + source_schema = _make_vector_source_schema(algorithm="flat") + + answers = iter( + [ + "2", # Update field + "embedding", # Select vector field + "HNSW", # Change to HNSW + "", # datatype (keep current) + "", # distance_metric (keep current) + "32", # M + "400", # EF_CONSTRUCTION + 
"", # EF_RUNTIME (keep current) + "", # EPSILON (keep current) + "8", # Finish + ] + ) + monkeypatch.setattr("builtins.input", lambda _prompt="": next(answers)) + + wizard = MigrationWizard() + patch = wizard._build_patch(source_schema) + + update = patch.changes.update_fields[0] + assert update.attrs["algorithm"] == "HNSW" + assert update.attrs["m"] == 32 + assert update.attrs["ef_construction"] == 400 + + def test_hnsw_to_svs_vamana_with_underscore(self, monkeypatch): + """Test changing to SVS_VAMANA (underscore format) is normalized.""" + source_schema = _make_vector_source_schema(algorithm="hnsw") + + answers = iter( + [ + "2", # Update field + "embedding", # Select vector field + "SVS_VAMANA", # Underscore format (should be normalized) + "float16", # SVS only supports float16/float32 + "", # distance_metric (keep current) + "64", # GRAPH_MAX_DEGREE + "LVQ8", # COMPRESSION + "8", # Finish + ] + ) + monkeypatch.setattr("builtins.input", lambda _prompt="": next(answers)) + + wizard = MigrationWizard() + patch = wizard._build_patch(source_schema) + + update = patch.changes.update_fields[0] + assert update.attrs["algorithm"] == "SVS-VAMANA" # Normalized to hyphen + assert update.attrs["datatype"] == "float16" + assert update.attrs["graph_max_degree"] == 64 + assert update.attrs["compression"] == "LVQ8" + + def test_hnsw_to_svs_vamana_with_hyphen(self, monkeypatch): + """Test changing to SVS-VAMANA (hyphen format) works directly.""" + source_schema = _make_vector_source_schema(algorithm="hnsw") + + answers = iter( + [ + "2", # Update field + "embedding", # Select vector field + "SVS-VAMANA", # Hyphen format + "", # datatype (keep current) + "", # distance_metric (keep current) + "", # GRAPH_MAX_DEGREE (keep default) + "", # COMPRESSION (none) + "8", # Finish + ] + ) + monkeypatch.setattr("builtins.input", lambda _prompt="": next(answers)) + + wizard = MigrationWizard() + patch = wizard._build_patch(source_schema) + + update = patch.changes.update_fields[0] + assert 
update.attrs["algorithm"] == "SVS-VAMANA" + + def test_svs_vamana_with_leanvec_compression(self, monkeypatch): + """Test SVS-VAMANA with LeanVec compression type.""" + source_schema = _make_vector_source_schema(algorithm="hnsw") + + answers = iter( + [ + "2", # Update field + "embedding", # Select vector field + "SVS-VAMANA", + "float16", + "", # distance_metric + "48", # GRAPH_MAX_DEGREE + "LEANVEC8X8", # COMPRESSION + "192", # REDUCE (dims/2) + "8", # Finish + ] + ) + monkeypatch.setattr("builtins.input", lambda _prompt="": next(answers)) + + wizard = MigrationWizard() + patch = wizard._build_patch(source_schema) + + update = patch.changes.update_fields[0] + assert update.attrs["algorithm"] == "SVS-VAMANA" + assert update.attrs["compression"] == "LeanVec8x8" + assert update.attrs["reduce"] == 192 + + +# ============================================================================= +# Vector Datatype (Quantization) Tests +# ============================================================================= + + +class TestVectorDatatypeChanges: + """Test wizard handling of vector datatype/quantization changes.""" + + def test_float32_to_float16(self, monkeypatch): + """Test quantization from float32 to float16.""" + source_schema = _make_vector_source_schema(datatype="float32") + + answers = iter( + [ + "2", # Update field + "embedding", + "", # algorithm (keep current) + "float16", # datatype + "", # distance_metric + "", # M (keep current) + "", # EF_CONSTRUCTION (keep current) + "", # EF_RUNTIME (keep current) + "", # EPSILON (keep current) + "8", # Finish + ] + ) + monkeypatch.setattr("builtins.input", lambda _prompt="": next(answers)) + + wizard = MigrationWizard() + patch = wizard._build_patch(source_schema) + + update = patch.changes.update_fields[0] + assert update.attrs["datatype"] == "float16" + + def test_float16_to_float32(self, monkeypatch): + """Test changing from float16 back to float32.""" + source_schema = _make_vector_source_schema(datatype="float16") + 
+ answers = iter( + [ + "2", # Update field + "embedding", + "", # algorithm + "float32", # datatype + "", # distance_metric + "", # M + "", # EF_CONSTRUCTION + "", # EF_RUNTIME (keep current) + "", # EPSILON (keep current) + "8", # Finish + ] + ) + monkeypatch.setattr("builtins.input", lambda _prompt="": next(answers)) + + wizard = MigrationWizard() + patch = wizard._build_patch(source_schema) + + update = patch.changes.update_fields[0] + assert update.attrs["datatype"] == "float32" + + def test_int8_accepted_for_hnsw(self, monkeypatch): + """Test that int8 is accepted for HNSW/FLAT (but not SVS-VAMANA).""" + source_schema = _make_vector_source_schema(datatype="float32") + + answers = iter( + [ + "2", # Update field + "embedding", + "", # algorithm (keep HNSW) + "int8", # Valid for HNSW/FLAT + "", # distance_metric + "", # M + "", # EF_CONSTRUCTION + "", # EF_RUNTIME (keep current) + "", # EPSILON (keep current) + "8", # Finish + ] + ) + monkeypatch.setattr("builtins.input", lambda _prompt="": next(answers)) + + wizard = MigrationWizard() + patch = wizard._build_patch(source_schema) + + # int8 is now valid for HNSW/FLAT + update = patch.changes.update_fields[0] + assert update.attrs["datatype"] == "int8" + + +# ============================================================================= +# Distance Metric Tests +# ============================================================================= + + +class TestDistanceMetricChanges: + """Test wizard handling of distance metric changes.""" + + def test_cosine_to_l2(self, monkeypatch): + """Test changing distance metric from cosine to L2.""" + source_schema = _make_vector_source_schema() + + answers = iter( + [ + "2", # Update field + "embedding", + "", # algorithm + "", # datatype + "l2", # distance_metric + "", # M + "", # EF_CONSTRUCTION + "", # EF_RUNTIME (keep current) + "", # EPSILON (keep current) + "8", # Finish + ] + ) + monkeypatch.setattr("builtins.input", lambda _prompt="": next(answers)) + + wizard = 
MigrationWizard() + patch = wizard._build_patch(source_schema) + + update = patch.changes.update_fields[0] + assert update.attrs["distance_metric"] == "l2" + + def test_cosine_to_ip(self, monkeypatch): + """Test changing distance metric from cosine to inner product.""" + source_schema = _make_vector_source_schema() + + answers = iter( + [ + "2", # Update field + "embedding", + "", # algorithm + "", # datatype + "ip", # distance_metric (inner product) + "", # M + "", # EF_CONSTRUCTION + "", # EF_RUNTIME (keep current) + "", # EPSILON (keep current) + "8", # Finish + ] + ) + monkeypatch.setattr("builtins.input", lambda _prompt="": next(answers)) + + wizard = MigrationWizard() + patch = wizard._build_patch(source_schema) + + update = patch.changes.update_fields[0] + assert update.attrs["distance_metric"] == "ip" + + +# ============================================================================= +# Combined Changes Tests +# ============================================================================= + + +class TestCombinedVectorChanges: + """Test wizard handling of multiple vector attribute changes.""" + + def test_algorithm_datatype_and_metric_change(self, monkeypatch): + """Test changing algorithm, datatype, and distance metric together.""" + source_schema = _make_vector_source_schema(algorithm="flat", datatype="float32") + + answers = iter( + [ + "2", # Update field + "embedding", + "HNSW", # algorithm + "float16", # datatype + "l2", # distance_metric + "24", # M + "300", # EF_CONSTRUCTION + "", # EF_RUNTIME (keep current) + "", # EPSILON (keep current) + "8", # Finish + ] + ) + monkeypatch.setattr("builtins.input", lambda _prompt="": next(answers)) + + wizard = MigrationWizard() + patch = wizard._build_patch(source_schema) + + update = patch.changes.update_fields[0] + assert update.attrs["algorithm"] == "HNSW" + assert update.attrs["datatype"] == "float16" + assert update.attrs["distance_metric"] == "l2" + assert update.attrs["m"] == 24 + assert 
update.attrs["ef_construction"] == 300 + + def test_svs_vamana_full_config(self, monkeypatch): + """Test SVS-VAMANA with all parameters configured.""" + source_schema = _make_vector_source_schema(algorithm="hnsw", datatype="float32") + + answers = iter( + [ + "2", # Update field + "embedding", + "SVS-VAMANA", # algorithm + "float16", # datatype (required for SVS) + "ip", # distance_metric + "50", # GRAPH_MAX_DEGREE + "LVQ4X8", # COMPRESSION + "8", # Finish + ] + ) + monkeypatch.setattr("builtins.input", lambda _prompt="": next(answers)) + + wizard = MigrationWizard() + patch = wizard._build_patch(source_schema) + + update = patch.changes.update_fields[0] + assert update.attrs["algorithm"] == "SVS-VAMANA" + assert update.attrs["datatype"] == "float16" + assert update.attrs["distance_metric"] == "ip" + assert update.attrs["graph_max_degree"] == 50 + assert update.attrs["compression"] == "LVQ4x8" + + def test_no_changes_when_all_blank(self, monkeypatch): + """Test that blank inputs result in no changes.""" + source_schema = _make_vector_source_schema() + + answers = iter( + [ + "2", # Update field + "embedding", + "", # algorithm (keep current) + "", # datatype (keep current) + "", # distance_metric (keep current) + "", # M (keep current) + "", # EF_CONSTRUCTION (keep current) + "", # EF_RUNTIME (keep current) + "", # EPSILON (keep current) + "8", # Finish + ] + ) + monkeypatch.setattr("builtins.input", lambda _prompt="": next(answers)) + + wizard = MigrationWizard() + patch = wizard._build_patch(source_schema) + + # No changes collected means no update_fields + assert len(patch.changes.update_fields) == 0 + + +# ============================================================================= +# Adversarial / Edge Case Tests +# ============================================================================= + + +class TestWizardAdversarialInputs: + """Test wizard robustness against malformed, malicious, or edge case inputs.""" + + # 
------------------------------------------------------------------------- + # Invalid Algorithm Inputs + # ------------------------------------------------------------------------- + + def test_typo_in_algorithm_ignored(self, monkeypatch): + """Test that typos in algorithm name are ignored.""" + source_schema = _make_vector_source_schema(algorithm="hnsw") + + answers = iter( + [ + "2", + "embedding", + "HNSW_TYPO", # Invalid algorithm + "", # datatype + "", # distance_metric + "", # M + "", # EF_CONSTRUCTION + "", # EF_RUNTIME + "", # EPSILON + "8", + ] + ) + monkeypatch.setattr("builtins.input", lambda _prompt="": next(answers)) + + wizard = MigrationWizard() + patch = wizard._build_patch(source_schema) + + # Invalid algorithm should be ignored, no changes + assert len(patch.changes.update_fields) == 0 + + def test_partial_algorithm_name_ignored(self, monkeypatch): + """Test that partial algorithm names are ignored.""" + source_schema = _make_vector_source_schema(algorithm="hnsw") + + answers = iter( + [ + "2", + "embedding", + "HNS", # Partial name + "", # datatype + "", # distance_metric + "", # M + "", # EF_CONSTRUCTION + "", # EF_RUNTIME + "", # EPSILON + "8", + ] + ) + monkeypatch.setattr("builtins.input", lambda _prompt="": next(answers)) + + wizard = MigrationWizard() + patch = wizard._build_patch(source_schema) + + assert len(patch.changes.update_fields) == 0 + + def test_algorithm_with_special_chars_ignored(self, monkeypatch): + """Test that algorithm with special characters is ignored.""" + source_schema = _make_vector_source_schema(algorithm="hnsw") + + answers = iter( + [ + "2", + "embedding", + "HNSW; DROP TABLE users;--", # SQL injection attempt + "", # datatype + "", # distance_metric + "", # M + "", # EF_CONSTRUCTION + "", # EF_RUNTIME + "", # EPSILON + "8", + ] + ) + monkeypatch.setattr("builtins.input", lambda _prompt="": next(answers)) + + wizard = MigrationWizard() + patch = wizard._build_patch(source_schema) + + assert 
len(patch.changes.update_fields) == 0 + + def test_algorithm_lowercase_works(self, monkeypatch): + """Test that lowercase algorithm names work (case insensitive).""" + source_schema = _make_vector_source_schema(algorithm="hnsw") + + answers = iter( + [ + "2", + "embedding", + "flat", # lowercase + "", + "", + "8", + ] + ) + monkeypatch.setattr("builtins.input", lambda _prompt="": next(answers)) + + wizard = MigrationWizard() + patch = wizard._build_patch(source_schema) + + update = patch.changes.update_fields[0] + assert update.attrs["algorithm"] == "FLAT" + + def test_algorithm_mixed_case_works(self, monkeypatch): + """Test that mixed case algorithm names work.""" + source_schema = _make_vector_source_schema(algorithm="hnsw") + + answers = iter( + [ + "2", + "embedding", + "SvS_VaMaNa", # Mixed case with underscore + "", + "", + "", + "", + "8", + ] + ) + monkeypatch.setattr("builtins.input", lambda _prompt="": next(answers)) + + wizard = MigrationWizard() + patch = wizard._build_patch(source_schema) + + update = patch.changes.update_fields[0] + assert update.attrs["algorithm"] == "SVS-VAMANA" + + # ------------------------------------------------------------------------- + # Invalid Numeric Inputs + # ------------------------------------------------------------------------- + + def test_negative_m_ignored(self, monkeypatch): + """Test that negative M value is ignored.""" + source_schema = _make_vector_source_schema(algorithm="hnsw") + + answers = iter( + [ + "2", + "embedding", + "HNSW", + "", # datatype + "", # distance_metric + "-16", # Negative M + "", # EF_CONSTRUCTION + "", # EF_RUNTIME + "", # EPSILON + "8", + ] + ) + monkeypatch.setattr("builtins.input", lambda _prompt="": next(answers)) + + wizard = MigrationWizard() + patch = wizard._build_patch(source_schema) + + # Negative M is ignored, and since algorithm/datatype/metric are unchanged, + # no update should be generated at all + assert len(patch.changes.update_fields) == 0 + + def 
test_float_m_ignored(self, monkeypatch): + """Test that float M value is ignored.""" + source_schema = _make_vector_source_schema(algorithm="hnsw") + + answers = iter( + [ + "2", + "embedding", + "HNSW", + "", # datatype + "", # distance_metric + "16.5", # Float M + "", # EF_CONSTRUCTION + "", # EF_RUNTIME + "", # EPSILON + "8", + ] + ) + monkeypatch.setattr("builtins.input", lambda _prompt="": next(answers)) + + wizard = MigrationWizard() + patch = wizard._build_patch(source_schema) + + # Float M is ignored, and since algorithm/datatype/metric are unchanged, + # no update should be generated at all + assert len(patch.changes.update_fields) == 0 + + def test_string_m_ignored(self, monkeypatch): + """Test that string M value is ignored.""" + source_schema = _make_vector_source_schema(algorithm="hnsw") + + answers = iter( + [ + "2", + "embedding", + "HNSW", + "", # datatype + "", # distance_metric + "sixteen", # String M + "", # EF_CONSTRUCTION + "", # EF_RUNTIME + "", # EPSILON + "8", + ] + ) + monkeypatch.setattr("builtins.input", lambda _prompt="": next(answers)) + + wizard = MigrationWizard() + patch = wizard._build_patch(source_schema) + + # String M is ignored, and since algorithm/datatype/metric are unchanged, + # no update should be generated at all + assert len(patch.changes.update_fields) == 0 + + def test_zero_m_accepted(self, monkeypatch): + """Test that zero M is accepted (validation happens at schema level).""" + source_schema = _make_vector_source_schema(algorithm="hnsw") + + answers = iter( + [ + "2", + "embedding", + "HNSW", + "", # datatype + "", # distance_metric + "0", # Zero M + "", # EF_CONSTRUCTION + "", # EF_RUNTIME + "", # EPSILON + "8", + ] + ) + monkeypatch.setattr("builtins.input", lambda _prompt="": next(answers)) + + wizard = MigrationWizard() + patch = wizard._build_patch(source_schema) + + # Zero is a valid digit, wizard accepts it (validation at apply time) + update = patch.changes.update_fields[0] + assert update.attrs.get("m") == 0 
+ + def test_very_large_ef_construction_accepted(self, monkeypatch): + """Test that very large EF_CONSTRUCTION is accepted by wizard.""" + source_schema = _make_vector_source_schema(algorithm="hnsw") + + answers = iter( + [ + "2", + "embedding", + "HNSW", + "", + "", + "", + "999999999", # Very large EF_CONSTRUCTION + "", # EF_RUNTIME (keep current) + "", # EPSILON (keep current) + "8", + ] + ) + monkeypatch.setattr("builtins.input", lambda _prompt="": next(answers)) + + wizard = MigrationWizard() + patch = wizard._build_patch(source_schema) + + update = patch.changes.update_fields[0] + assert update.attrs["ef_construction"] == 999999999 + + # ------------------------------------------------------------------------- + # Invalid Datatype Inputs + # ------------------------------------------------------------------------- + + def test_bfloat16_accepted_for_hnsw(self, monkeypatch): + """Test that bfloat16 is accepted for HNSW/FLAT.""" + source_schema = _make_vector_source_schema(datatype="float32") + + answers = iter( + [ + "2", + "embedding", + "", # algorithm + "bfloat16", # Valid for HNSW/FLAT + "", # distance_metric + "", # M + "", # EF_CONSTRUCTION + "", # EF_RUNTIME + "", # EPSILON + "8", + ] + ) + monkeypatch.setattr("builtins.input", lambda _prompt="": next(answers)) + + wizard = MigrationWizard() + patch = wizard._build_patch(source_schema) + + update = patch.changes.update_fields[0] + assert update.attrs["datatype"] == "bfloat16" + + def test_uint8_accepted_for_hnsw(self, monkeypatch): + """Test that uint8 is accepted for HNSW/FLAT.""" + source_schema = _make_vector_source_schema(datatype="float32") + + answers = iter( + [ + "2", + "embedding", + "", # algorithm + "uint8", # Valid for HNSW/FLAT + "", # distance_metric + "", # M + "", # EF_CONSTRUCTION + "", # EF_RUNTIME + "", # EPSILON + "8", + ] + ) + monkeypatch.setattr("builtins.input", lambda _prompt="": next(answers)) + + wizard = MigrationWizard() + patch = wizard._build_patch(source_schema) + + update 
= patch.changes.update_fields[0] + assert update.attrs["datatype"] == "uint8" + + def test_int8_rejected_for_svs_vamana(self, monkeypatch): + """Test that int8 is rejected for SVS-VAMANA (only float16/float32 allowed).""" + source_schema = _make_vector_source_schema(datatype="float32", algorithm="hnsw") + + answers = iter( + [ + "2", + "embedding", + "SVS-VAMANA", # Switch to SVS-VAMANA + "int8", # Invalid for SVS-VAMANA + "", + "", + "", # graph_max_degree + "", # compression + "8", + ] + ) + monkeypatch.setattr("builtins.input", lambda _prompt="": next(answers)) + + wizard = MigrationWizard() + patch = wizard._build_patch(source_schema) + + # Should have algorithm change but NOT datatype + update = patch.changes.update_fields[0] + assert update.attrs["algorithm"] == "SVS-VAMANA" + assert "datatype" not in update.attrs # int8 rejected + + # ------------------------------------------------------------------------- + # Invalid Distance Metric Inputs + # ------------------------------------------------------------------------- + + def test_invalid_distance_metric_ignored(self, monkeypatch): + """Test that invalid distance metric is ignored.""" + source_schema = _make_vector_source_schema() + + answers = iter( + [ + "2", + "embedding", + "", # algorithm + "", # datatype + "euclidean", # Invalid (should be 'l2') + "", # M + "", # EF_CONSTRUCTION + "", # EF_RUNTIME + "", # EPSILON + "8", + ] + ) + monkeypatch.setattr("builtins.input", lambda _prompt="": next(answers)) + + wizard = MigrationWizard() + patch = wizard._build_patch(source_schema) + + assert len(patch.changes.update_fields) == 0 + + def test_distance_metric_uppercase_works(self, monkeypatch): + """Test that uppercase distance metric works.""" + source_schema = _make_vector_source_schema() + + answers = iter( + [ + "2", + "embedding", + "", # algorithm + "", # datatype + "L2", # Uppercase + "", # M + "", # EF_CONSTRUCTION + "", # EF_RUNTIME + "", # EPSILON + "8", + ] + ) + 
monkeypatch.setattr("builtins.input", lambda _prompt="": next(answers)) + + wizard = MigrationWizard() + patch = wizard._build_patch(source_schema) + + update = patch.changes.update_fields[0] + assert update.attrs["distance_metric"] == "l2" + + # ------------------------------------------------------------------------- + # Invalid Compression Inputs + # ------------------------------------------------------------------------- + + def test_invalid_compression_ignored(self, monkeypatch): + """Test that invalid compression type is ignored.""" + source_schema = _make_vector_source_schema(algorithm="hnsw") + + answers = iter( + [ + "2", + "embedding", + "SVS-VAMANA", + "", + "", + "", + "INVALID_COMPRESSION", # Invalid + "8", + ] + ) + monkeypatch.setattr("builtins.input", lambda _prompt="": next(answers)) + + wizard = MigrationWizard() + patch = wizard._build_patch(source_schema) + + update = patch.changes.update_fields[0] + assert "compression" not in update.attrs + + def test_compression_lowercase_works(self, monkeypatch): + """Test that lowercase compression works.""" + source_schema = _make_vector_source_schema(algorithm="hnsw") + + answers = iter( + [ + "2", + "embedding", + "SVS-VAMANA", + "", + "", + "", + "lvq8", # lowercase + "8", + ] + ) + monkeypatch.setattr("builtins.input", lambda _prompt="": next(answers)) + + wizard = MigrationWizard() + patch = wizard._build_patch(source_schema) + + update = patch.changes.update_fields[0] + assert update.attrs["compression"] == "LVQ8" + + # ------------------------------------------------------------------------- + # Whitespace and Special Character Inputs + # ------------------------------------------------------------------------- + + def test_whitespace_only_treated_as_blank(self, monkeypatch): + """Test that whitespace-only input is treated as blank.""" + source_schema = _make_vector_source_schema() + + answers = iter( + [ + "2", + "embedding", + " ", # Whitespace only (algorithm) + " ", # datatype + " ", # 
distance_metric + " ", # M + " ", # EF_CONSTRUCTION + " ", # EF_RUNTIME + " ", # EPSILON + "8", + ] + ) + monkeypatch.setattr("builtins.input", lambda _prompt="": next(answers)) + + wizard = MigrationWizard() + patch = wizard._build_patch(source_schema) + + assert len(patch.changes.update_fields) == 0 + + def test_algorithm_with_leading_trailing_whitespace(self, monkeypatch): + """Test that algorithm with whitespace is trimmed and works.""" + source_schema = _make_vector_source_schema(algorithm="hnsw") + + answers = iter( + [ + "2", + "embedding", + " FLAT ", # Whitespace around (FLAT has no extra params) + "", # datatype + "", # distance_metric + "8", + ] + ) + monkeypatch.setattr("builtins.input", lambda _prompt="": next(answers)) + + wizard = MigrationWizard() + patch = wizard._build_patch(source_schema) + + update = patch.changes.update_fields[0] + assert update.attrs["algorithm"] == "FLAT" + + def test_unicode_input_ignored(self, monkeypatch): + """Test that unicode/emoji inputs are ignored.""" + source_schema = _make_vector_source_schema() + + answers = iter( + [ + "2", + "embedding", + "HNSW\U0001f680", # Unicode emoji + "", # datatype + "", # distance_metric + "", # M + "", # EF_CONSTRUCTION + "", # EF_RUNTIME + "", # EPSILON + "8", + ] + ) + monkeypatch.setattr("builtins.input", lambda _prompt="": next(answers)) + + wizard = MigrationWizard() + patch = wizard._build_patch(source_schema) + + assert len(patch.changes.update_fields) == 0 + + def test_very_long_input_ignored(self, monkeypatch): + """Test that very long inputs are ignored.""" + source_schema = _make_vector_source_schema() + + answers = iter( + [ + "2", + "embedding", + "A" * 10000, # Very long string + "", # datatype + "", # distance_metric + "", # M + "", # EF_CONSTRUCTION + "", # EF_RUNTIME + "", # EPSILON + "8", + ] + ) + monkeypatch.setattr("builtins.input", lambda _prompt="": next(answers)) + + wizard = MigrationWizard() + patch = wizard._build_patch(source_schema) + + assert 
len(patch.changes.update_fields) == 0 + + # ------------------------------------------------------------------------- + # Field Selection Edge Cases + # ------------------------------------------------------------------------- + + def test_nonexistent_field_selection(self, monkeypatch): + """Test selecting a nonexistent field.""" + source_schema = _make_vector_source_schema() + + answers = iter( + [ + "2", + "nonexistent_field", # Doesn't exist + "8", + ] + ) + monkeypatch.setattr("builtins.input", lambda _prompt="": next(answers)) + + wizard = MigrationWizard() + patch = wizard._build_patch(source_schema) + + # Should print "Invalid field selection" and continue + assert len(patch.changes.update_fields) == 0 + + def test_field_selection_by_number_out_of_range(self, monkeypatch): + """Test selecting a field by out-of-range number.""" + source_schema = _make_vector_source_schema() + + answers = iter( + [ + "2", + "99", # Out of range + "8", + ] + ) + monkeypatch.setattr("builtins.input", lambda _prompt="": next(answers)) + + wizard = MigrationWizard() + patch = wizard._build_patch(source_schema) + + assert len(patch.changes.update_fields) == 0 + + def test_field_selection_negative_number(self, monkeypatch): + """Test selecting a field with negative number.""" + source_schema = _make_vector_source_schema() + + answers = iter( + [ + "2", + "-1", # Negative + "8", + ] + ) + monkeypatch.setattr("builtins.input", lambda _prompt="": next(answers)) + + wizard = MigrationWizard() + patch = wizard._build_patch(source_schema) + + assert len(patch.changes.update_fields) == 0 + + # ------------------------------------------------------------------------- + # Menu Action Edge Cases + # ------------------------------------------------------------------------- + + def test_invalid_menu_action(self, monkeypatch): + """Test invalid menu action selection.""" + source_schema = _make_vector_source_schema() + + answers = iter( + [ + "99", # Invalid action + "abc", # Invalid action + "", 
# Empty + "8", # Finally finish + ] + ) + monkeypatch.setattr("builtins.input", lambda _prompt="": next(answers)) + + wizard = MigrationWizard() + patch = wizard._build_patch(source_schema) + + # Should handle invalid actions gracefully and eventually finish + assert patch is not None + + # ------------------------------------------------------------------------- + # SVS-VAMANA Specific Edge Cases + # ------------------------------------------------------------------------- + + def test_svs_vamana_negative_graph_max_degree_ignored(self, monkeypatch): + """Test that negative GRAPH_MAX_DEGREE is ignored.""" + source_schema = _make_vector_source_schema(algorithm="hnsw") + + answers = iter( + [ + "2", + "embedding", + "SVS-VAMANA", + "", + "", + "-40", # Negative + "", + "8", + ] + ) + monkeypatch.setattr("builtins.input", lambda _prompt="": next(answers)) + + wizard = MigrationWizard() + patch = wizard._build_patch(source_schema) + + update = patch.changes.update_fields[0] + assert "graph_max_degree" not in update.attrs + + def test_svs_vamana_string_graph_max_degree_ignored(self, monkeypatch): + """Test that string GRAPH_MAX_DEGREE is ignored.""" + source_schema = _make_vector_source_schema(algorithm="hnsw") + + answers = iter( + [ + "2", + "embedding", + "SVS-VAMANA", + "", + "", + "forty", # String + "", + "8", + ] + ) + monkeypatch.setattr("builtins.input", lambda _prompt="": next(answers)) + + wizard = MigrationWizard() + patch = wizard._build_patch(source_schema) + + update = patch.changes.update_fields[0] + assert "graph_max_degree" not in update.attrs From dfa069a2d8b9536d1080ec334bd760fd597e32d7 Mon Sep 17 00:00:00 2001 From: Nitin Kanukolanu Date: Thu, 2 Apr 2026 16:02:28 -0400 Subject: [PATCH 2/2] fix: address codex review for PR2 (wizard) - Improve field removal to clean up renames by both old_name and new_name - Resolve update names through rename map in working schema preview - Add multi-prefix guard to reject indexes with multiple prefixes - Fix 
dependent prompts (UNF, no_index) when field is already sortable - Pass existing field attrs to common attrs prompts for update mode --- redisvl/migration/wizard.py | 67 +++++++++++++++++++++++++++++++------ 1 file changed, 56 insertions(+), 11 deletions(-) diff --git a/redisvl/migration/wizard.py b/redisvl/migration/wizard.py index 3127b0497..24ec15cde 100644 --- a/redisvl/migration/wizard.py +++ b/redisvl/migration/wizard.py @@ -21,6 +21,7 @@ class MigrationWizard: def __init__(self, planner: Optional[MigrationPlanner] = None): self.planner = planner or MigrationPlanner() + self._existing_sortable: bool = False def run( self, @@ -45,6 +46,15 @@ def run( ) source_schema = IndexSchema.from_dict(snapshot.schema_snapshot) + # Guard: the wizard does not support indexes with multiple prefixes. + prefixes = source_schema.index.prefix + if isinstance(prefixes, list) and len(prefixes) > 1: + raise ValueError( + f"Index '{resolved_index_name}' has multiple prefixes " + f"({prefixes}). The migration wizard only supports single-prefix " + "indexes. Use the planner API directly for multi-prefix indexes." + ) + print(f"Building a migration plan for index '{resolved_index_name}'") self._print_source_schema(source_schema.to_dict()) @@ -156,8 +166,13 @@ def _apply_staged_changes( if field["name"] in rename_map: field["name"] = rename_map[field["name"]] - # Apply updates (reflect attribute changes in working schema) - update_map = {u.name: u for u in changes.update_fields} + # Apply updates (reflect attribute changes in working schema). + # Resolve update names through the rename map so that updates staged + # before a rename (referencing the old name) still match. 
+ update_map = {} + for u in changes.update_fields: + resolved = rename_map.get(u.name, u.name) + update_map[resolved] = u for field in working["fields"]: if field["name"] in update_map: upd = update_map[field["name"]] @@ -245,12 +260,27 @@ def _build_patch( print(f"Cancelled staged addition of '{field_name}'.") else: changes.remove_fields.append(field_name) - # Also remove any queued updates or renames for this field + # Also remove any queued updates or renames for this field. + # Check both old_name and new_name so that: + # - renames FROM this field are dropped (old_name match) + # - renames TO this field are dropped (new_name match) + # Also drop updates referencing either the field itself or + # any pre-rename name that mapped to it. + rename_aliases = {field_name} + for r in changes.rename_fields: + if r.new_name == field_name: + rename_aliases.add(r.old_name) + if r.old_name == field_name: + rename_aliases.add(r.new_name) changes.update_fields = [ - u for u in changes.update_fields if u.name != field_name + u + for u in changes.update_fields + if u.name not in rename_aliases ] changes.rename_fields = [ - r for r in changes.rename_fields if r.old_name != field_name + r + for r in changes.rename_fields + if r.old_name != field_name and r.new_name != field_name ] elif action == "4": # Filter out staged additions from rename candidates @@ -357,7 +387,11 @@ def _prompt_update_field( if selected["type"] == "vector": attrs = self._prompt_vector_attrs(selected) else: - attrs = self._prompt_common_attrs(selected["type"], allow_blank=True) + attrs = self._prompt_common_attrs( + selected["type"], + allow_blank=True, + existing_attrs=selected.get("attrs"), + ) if not attrs: print("No changes collected.") return None @@ -485,7 +519,10 @@ def _prompt_change_prefix(self, source_schema: Dict[str, Any]) -> Optional[str]: return new_prefix def _prompt_common_attrs( - self, field_type: str, allow_blank: bool = False + self, + field_type: str, + allow_blank: bool = False, + 
existing_attrs: Optional[Dict[str, Any]] = None, ) -> Dict[str, Any]: attrs: Dict[str, Any] = {} @@ -511,6 +548,11 @@ def _prompt_common_attrs( if index_empty is not None: attrs["index_empty"] = index_empty + # Track whether the field was already sortable so that type-specific + # prompt helpers (text UNF, numeric UNF) can offer dependent prompts + # even when the user leaves sortable blank during an update. + self._existing_sortable = (existing_attrs or {}).get("sortable", False) + # Type-specific attributes if field_type == "text": self._prompt_text_attrs(attrs, allow_blank) @@ -519,8 +561,11 @@ def _prompt_common_attrs( elif field_type == "numeric": self._prompt_numeric_attrs(attrs, allow_blank, sortable) - # No index - only meaningful with sortable - if sortable or (allow_blank and attrs.get("sortable")): + # No index - only meaningful with sortable. + # When updating (allow_blank), also check the existing field's sortable + # state so we offer dependent prompts even if the user left sortable blank. 
+ _existing_sortable = self._existing_sortable + if sortable or (allow_blank and (_existing_sortable or attrs.get("sortable"))): print(" No index: store field for sorting only, not searchable") no_index = self._prompt_bool("No index", allow_blank=allow_blank) if no_index is not None: @@ -560,7 +605,7 @@ def _prompt_text_attrs(self, attrs: Dict[str, Any], allow_blank: bool) -> None: attrs["phonetic_matcher"] = phonetic # UNF (only if sortable) - if attrs.get("sortable"): + if attrs.get("sortable") or self._existing_sortable: print(" UNF: preserve original form (no lowercasing) for sorting") unf = self._prompt_bool("UNF (un-normalized form)", allow_blank=allow_blank) if unf is not None: @@ -585,7 +630,7 @@ def _prompt_numeric_attrs( ) -> None: """Prompt for numeric field specific attributes.""" # UNF (only if sortable) - if sortable or attrs.get("sortable"): + if sortable or attrs.get("sortable") or self._existing_sortable: print(" UNF: preserve exact numeric representation for sorting") unf = self._prompt_bool("UNF (un-normalized form)", allow_blank=allow_blank) if unf is not None: