Skip to content

Commit 4af616a

Browse files
committed
fix: batch planner/executor bug fixes and improvements (#563)
- Fix status mismatch: executor writes 'success' to match BatchState.success_count - Pass rename_operations to get_vector_datatype_changes - Validate failure_policy early (reject unknown values) - Make update_fields applicability rename-aware - Fix progress position during resume (correct offset) - Fix fail-fast: leave remaining in state for checkpoint resume - Atomic checkpoint writes (write to .tmp then rename) - Sanitize index_name in report filenames (path traversal) - Add assert guard for fnmatch pattern type
1 parent 6b430d1 commit 4af616a

2 files changed

Lines changed: 54 additions & 22 deletions

File tree

redisvl/migration/batch_executor.py

Lines changed: 33 additions & 19 deletions
Original file line number | Diff line number | Diff line change
@@ -80,12 +80,17 @@ def apply(
8080
applicable_indexes = [idx for idx in batch_plan.indexes if idx.applicable]
8181
total = len(applicable_indexes)
8282

83+
# Calculate the correct starting position for progress reporting
84+
# (accounts for already-completed indexes during resume)
85+
already_completed = len(state.completed)
86+
8387
# Process each remaining index
84-
for position, index_name in enumerate(state.remaining[:], start=1):
88+
for offset, index_name in enumerate(state.remaining[:]):
8589
state.current_index = index_name
8690
state.updated_at = timestamp_utc()
8791
self._write_state(state, state_path)
8892

93+
position = already_completed + offset + 1
8994
if progress_callback:
9095
progress_callback(index_name, position, total, "starting")
9196

@@ -133,18 +138,8 @@ def apply(
133138
index_state.status == "failed"
134139
and batch_plan.failure_policy == "fail_fast"
135140
):
136-
# Mark remaining as skipped
137-
for remaining_name in state.remaining[:]:
138-
state.remaining.remove(remaining_name)
139-
state.completed.append(
140-
BatchIndexState(
141-
name=remaining_name,
142-
status="skipped",
143-
completed_at=timestamp_utc(),
144-
)
145-
)
146-
state.updated_at = timestamp_utc()
147-
self._write_state(state, state_path)
141+
# Leave remaining indexes in state.remaining so that
142+
# checkpoint resume can pick them up later.
148143
break
149144

150145
# Build final report
@@ -225,13 +220,14 @@ def _migrate_single_index(
225220
redis_client=redis_client,
226221
)
227222

228-
# Write individual report
229-
report_file = report_dir / f"{index_name}_report.yaml"
223+
# Sanitize index_name to prevent path traversal
224+
safe_name = index_name.replace("/", "_").replace("\\", "_").replace("..", "_")
225+
report_file = report_dir / f"{safe_name}_report.yaml"
230226
write_yaml(report.model_dump(exclude_none=True), str(report_file))
231227

232228
return BatchIndexState(
233229
name=index_name,
234-
status="succeeded" if report.result == "succeeded" else "failed",
230+
status="success" if report.result == "succeeded" else "failed",
235231
completed_at=timestamp_utc(),
236232
report_path=str(report_file),
237233
error=report.validation.errors[0] if report.validation.errors else None,
@@ -277,11 +273,18 @@ def _init_or_load_state(
277273
)
278274

279275
def _write_state(self, state: BatchState, state_path: str) -> None:
280-
"""Write checkpoint state to file."""
276+
"""Write checkpoint state to file atomically.
277+
278+
Writes to a temporary file first, then renames to avoid corruption
279+
if the process crashes mid-write.
280+
"""
281281
path = Path(state_path).resolve()
282282
path.parent.mkdir(parents=True, exist_ok=True)
283-
with open(path, "w") as f:
283+
tmp_path = path.with_suffix(".tmp")
284+
with open(tmp_path, "w") as f:
284285
yaml.safe_dump(state.model_dump(exclude_none=True), f, sort_keys=False)
286+
f.flush()
287+
tmp_path.replace(path)
285288

286289
def _load_state(self, state_path: str) -> BatchState:
287290
"""Load checkpoint state from file."""
@@ -323,13 +326,24 @@ def _build_batch_report(
323326
error=idx_state.error,
324327
)
325328
)
326-
if idx_state.status == "succeeded":
329+
if idx_state.status in ("succeeded", "success"):
327330
succeeded += 1
328331
elif idx_state.status == "failed":
329332
failed += 1
330333
else:
331334
skipped += 1
332335

336+
# Add remaining indexes (fail-fast left them pending) as skipped
337+
for remaining_name in state.remaining:
338+
index_reports.append(
339+
BatchIndexReport(
340+
name=remaining_name,
341+
status="skipped",
342+
error="Skipped due to fail_fast policy",
343+
)
344+
)
345+
skipped += 1
346+
333347
# Add non-applicable indexes as skipped
334348
for idx in batch_plan.indexes:
335349
if not idx.applicable:

redisvl/migration/batch_planner.py

Lines changed: 21 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -37,6 +37,7 @@ def create_batch_plan(
3737
redis_client: Optional[Any] = None,
3838
failure_policy: str = "fail_fast",
3939
) -> BatchPlan:
40+
# --- NEW: validate failure_policy early ---
4041
"""Create a batch migration plan for multiple indexes.
4142
4243
Args:
@@ -51,6 +52,13 @@ def create_batch_plan(
5152
Returns:
5253
BatchPlan with shared patch and per-index applicability.
5354
"""
55+
_VALID_FAILURE_POLICIES = {"fail_fast", "continue_on_error"}
56+
if failure_policy not in _VALID_FAILURE_POLICIES:
57+
raise ValueError(
58+
f"Invalid failure_policy '{failure_policy}'. "
59+
f"Must be one of: {sorted(_VALID_FAILURE_POLICIES)}"
60+
)
61+
5462
# Get Redis client
5563
client = redis_client
5664
if client is None:
@@ -95,6 +103,7 @@ def create_batch_plan(
95103
datatype_changes = MigrationPlanner.get_vector_datatype_changes(
96104
plan.source.schema_snapshot,
97105
plan.merged_target_schema,
106+
rename_operations=plan.rename_operations,
98107
)
99108
if datatype_changes:
100109
requires_quantization = True
@@ -134,7 +143,8 @@ def _resolve_index_names(
134143
if indexes_file:
135144
return self._load_indexes_from_file(indexes_file)
136145

137-
# Pattern matching
146+
# Pattern matching -- pattern is guaranteed non-None at this point
147+
assert pattern is not None, "pattern must be set when reaching fnmatch"
138148
all_indexes = list_indexes(redis_client=redis_client)
139149
matched = [idx for idx in all_indexes if fnmatch.fnmatch(idx, pattern)]
140150
return sorted(matched)
@@ -167,10 +177,18 @@ def _check_index_applicability(
167177
schema_dict = index.schema.to_dict()
168178
field_names = {f["name"] for f in schema_dict.get("fields", [])}
169179

170-
# Check that all update_fields exist in this index
180+
# Build a set of field names that includes rename targets so
181+
# that update_fields referencing the NEW name of a renamed field
182+
# are considered applicable.
183+
rename_target_names = {
184+
fr.new_name for fr in shared_patch.changes.rename_fields
185+
}
186+
effective_field_names = field_names | rename_target_names
187+
188+
# Check that all update_fields exist in this index (or are rename targets)
171189
missing_fields = []
172190
for field_update in shared_patch.changes.update_fields:
173-
if field_update.name not in field_names:
191+
if field_update.name not in effective_field_names:
174192
missing_fields.append(field_update.name)
175193

176194
if missing_fields:

0 commit comments

Comments (0)