diff --git a/paimon-python/pypaimon/ray/data_evolution_merge_into.py b/paimon-python/pypaimon/ray/data_evolution_merge_into.py index cbfcef907d81..871985a369f4 100644 --- a/paimon-python/pypaimon/ray/data_evolution_merge_into.py +++ b/paimon-python/pypaimon/ray/data_evolution_merge_into.py @@ -26,6 +26,7 @@ from pypaimon.ray.data_evolution_merge_join import ( build_matched_update_ds, build_not_matched_insert_ds, + build_self_merge_update_ds, distributed_update_apply, distributed_write_collect_msgs, ) @@ -53,6 +54,7 @@ class _PrepareCtx: update_pa_schema: pa.Schema full_pa_schema: pa.Schema catalog_options: Dict[str, str] + is_self_merge: bool = False def merge_into( @@ -178,33 +180,45 @@ def _prepare(target, source, catalog_options, when_matched, when_not_matched, on _NormalizedClause(spec=spec, condition=c.condition) ) - source_snapshot_id = None - if isinstance(source, str): - source_snapshot = ( - catalog.get_table(source) - .snapshot_manager() - .get_latest_snapshot() + is_self_merge = _is_self_merge(target, source, target_on_cols, source_on_cols) + if is_self_merge and not_matched_specs: + raise ValueError( + "Self-merge (source == target with ON _ROW_ID) does not " + "support WHEN NOT MATCHED clauses." ) - if source_snapshot is not None: - source_snapshot_id = source_snapshot.id - source_ds = _normalize_source( - source, catalog_options, source_snapshot_id=source_snapshot_id, - ) - _validate_source_on_cols(source_ds, source_on_cols) + if is_self_merge: + source_ds = None + source_col_names = set(full_target_field_names) | set(source_on_cols) + else: + source_snapshot_id = None + if isinstance(source, str): + source_snapshot = ( + catalog.get_table(source) + .snapshot_manager() + .get_latest_snapshot() + ) + if source_snapshot is not None: + source_snapshot_id = source_snapshot.id + source_ds = _normalize_source( + source, catalog_options, source_snapshot_id=source_snapshot_id, + ) + _validate_source_on_cols(source_ds, source_on_cols) + source_col_names = set(_source_schema_or_raise(source_ds).names) _validate_source_has_target_cols( - source_ds, matched_specs + not_matched_specs, + source_col_names, matched_specs + not_matched_specs, ) if has_condition: from pypaimon.ray.merge_condition import extract_columns - source_names = set(_source_schema_or_raise(source_ds).names) target_names = set(full_target_field_names) + if is_self_merge: + target_names |= set(target_on_cols) for c in list(when_matched) + list(when_not_matched): if c.condition is not None: for ref in extract_columns(c.condition): prefix, col = ref.split(".", 1) - if prefix == "s" and col not in source_names: + if prefix == "s" and col not in source_col_names: raise ValueError( f"condition references unknown source " f"column '{col}'" @@ -233,10 +247,20 @@ def _prepare(target, source, catalog_options, when_matched, when_not_matched, on update_pa_schema=update_pa_schema, full_pa_schema=full_pa_schema, catalog_options=catalog_options, + is_self_merge=is_self_merge, ) return table, source_ds, matched_specs, not_matched_specs, ctx +def _is_self_merge(target, source, target_on_cols, source_on_cols) -> bool: + from pypaimon.table.special_fields import SpecialFields + row_id_name = SpecialFields.ROW_ID.name + return (isinstance(source, str) + and source == target + and target_on_cols == [row_id_name] + and source_on_cols == [row_id_name]) + + def _build_datasets( target, source_ds, matched_specs, not_matched_specs, ctx: "_PrepareCtx", base_snapshot, num_partitions, ray_remote_args, @@ -250,6 +274,22 @@ def _build_datasets( insert_ds = None update_cols_union: List[str] = [] + if ctx.is_self_merge: + if matched_specs and base_snapshot is not None: + update_cols_union = _union_update_cols(matched_specs) + update_ds = build_self_merge_update_ds( + target_identifier=target, + clauses=matched_specs, + target_field_names=ctx.full_target_field_names, + target_pa_schema=ctx.update_pa_schema, + update_cols=update_cols_union, + catalog_options=ctx.catalog_options, + resolve_target_projection=_resolve_target_projection, + snapshot_id=base_snapshot_id, + ray_remote_args=ray_remote_args, + ) + return update_ds, insert_ds, update_cols_union + # Mirror Spark: matched/not-matched run as two independent joins # (inner / left_anti). One unified left_outer join would force # joined.materialize() to feed both branches, which can OOM on large merges. @@ -561,16 +601,15 @@ def _validate_source_on_cols(source_ds, on: Sequence[str]) -> None: def _validate_source_has_target_cols( - source_ds, + source_col_names: set, specs: List[_NormalizedClause], ) -> None: - names = set(_source_schema_or_raise(source_ds).names) needed = set() for clause in specs: for val in clause.spec.values(): if isinstance(val, SourceColumnRef): needed.add(val.column) - missing = sorted(needed - names) + missing = sorted(needed - source_col_names) if missing: raise ValueError( f"source is missing columns {missing} referenced by SET spec" diff --git a/paimon-python/pypaimon/ray/data_evolution_merge_join.py b/paimon-python/pypaimon/ray/data_evolution_merge_join.py index 14088979f893..4ad91e7da117 100644 --- a/paimon-python/pypaimon/ray/data_evolution_merge_join.py +++ b/paimon-python/pypaimon/ray/data_evolution_merge_join.py @@ -21,6 +21,7 @@ import pyarrow as pa from pypaimon.ray.data_evolution_merge_transform import ( + SourceColumnRef, _NormalizedClause, build_update_schema, vectorized_insert_transform, @@ -40,6 +41,137 @@ def _map_kwargs( return kwargs +def _build_matched_transform( + clauses: List[_NormalizedClause], + on_map: Dict[str, str], + on_pairs: List[Tuple[str, str]], + update_cols: List[str], + row_id_name: str, + update_schema: pa.Schema, +): + prepared_clauses = [] + for clause in clauses: + rewritten = None + if clause.condition is not None: + from pypaimon.ray.merge_condition import ( + remap_source_on_keys, rewrite_condition, + ) + rewritten = remap_source_on_keys( + rewrite_condition(clause.condition), on_map, + ) + prepared_clauses.append((clause.spec, rewritten)) + + _filter_batch = None + if any(r is not None for _, r in prepared_clauses): + from pypaimon.ray.merge_condition import filter_batch as _filter_batch + + def _transform(batch: pa.Table) -> pa.Table: + remaining = batch + parts = [] + for spec, rewritten in prepared_clauses: + if remaining.num_rows == 0: + break + if rewritten is not None: + matched = _filter_batch( + remaining, rewritten, _pre_rewritten=True, + ) + else: + matched = remaining + if matched.num_rows == 0: + continue + parts.append(vectorized_matched_transform( + matched, spec, on_pairs, + update_cols, row_id_name, + update_schema, + )) + if rewritten is not None and matched.num_rows < remaining.num_rows: + not_cond = f"COALESCE(NOT ({rewritten}), TRUE)" + remaining = _filter_batch( + remaining, not_cond, _pre_rewritten=True, + ) + else: + remaining = remaining.slice(0, 0) + if not parts: + return update_schema.empty_table() + return pa.concat_tables(parts) + + return _transform + + +def build_self_merge_update_ds( + *, + target_identifier: str, + clauses: List[_NormalizedClause], + target_field_names: Sequence[str], + target_pa_schema: pa.Schema, + update_cols: Sequence[str], + catalog_options: Dict[str, str], + resolve_target_projection, + snapshot_id: Optional[int] = None, + ray_remote_args: Optional[Dict[str, Any]] = None, +) -> Tuple: + from pypaimon.ray.ray_paimon import read_paimon + from pypaimon.table.special_fields import SpecialFields + + row_id_name = SpecialFields.ROW_ID.name + needed_cols = set(resolve_target_projection( + clauses, [row_id_name], update_cols, target_field_names, + )) + for clause in clauses: + for value in clause.spec.values(): + if isinstance(value, SourceColumnRef): + needed_cols.add(value.column) + target_set = set(target_field_names) + for clause in clauses: + if clause.condition is not None: + from pypaimon.ray.merge_condition import extract_columns + for ref in extract_columns(clause.condition): + prefix, col = ref.split(".", 1) + if prefix == "s" and col in target_set: + needed_cols.add(col) + projection = [row_id_name] + [ + c for c in target_field_names if c in needed_cols + ] + + target_ds = read_paimon( + target_identifier, catalog_options, + projection=projection, snapshot_id=snapshot_id, + ) + update_schema = build_update_schema(target_pa_schema, update_cols, row_id_name) + + orig_names = target_ds.schema().names + target_renamed = target_ds.rename_columns( + {c: f"t.{c}" for c in orig_names} + ) + + def _add_source_aliases(batch: pa.Table) -> pa.Table: + columns = list(batch.columns) + names = list(batch.schema.names) + for orig in orig_names: + if orig == row_id_name: + continue + t_col_name = f"t.{orig}" + if t_col_name in names: + idx = names.index(t_col_name) + columns.append(columns[idx]) + names.append(f"s.{orig}") + return pa.table(columns, names=names) + + aliased = target_renamed.map_batches( + _add_source_aliases, **_map_kwargs(ray_remote_args), + ) + + _transform = _build_matched_transform( + clauses, + on_map={row_id_name: row_id_name}, + on_pairs=[(row_id_name, row_id_name)], + update_cols=list(update_cols), + row_id_name=row_id_name, + update_schema=update_schema, + ) + return aliased.map_batches(_transform, **_map_kwargs(ray_remote_args)) + + def build_matched_update_ds( *, target_identifier: str, @@ -87,58 +219,14 @@ def build_matched_update_ds( right_on=tuple(f"s.{c}" for c in source_on), ) - captured_update_cols = list(update_cols) - captured_row_id_name = row_id_name - captured_on_pairs = list(zip(source_on, target_on)) - captured_schema = update_schema - - on_map = dict(zip(source_on, target_on)) - prepared_clauses = [] - for clause in clauses: - rewritten = None - if clause.condition is not None: - from pypaimon.ray.merge_condition import ( - remap_source_on_keys, rewrite_condition, - ) - rewritten = remap_source_on_keys( - rewrite_condition(clause.condition), on_map, - ) - prepared_clauses.append((clause.spec, rewritten)) - - _filter_batch = None - if any(r is not None for _, r in prepared_clauses): - from pypaimon.ray.merge_condition import filter_batch as _filter_batch - - def _transform(batch: pa.Table) -> pa.Table: - remaining = batch - parts = [] - for spec, rewritten in prepared_clauses: - if remaining.num_rows == 0: - break - if rewritten is not None: - matched = _filter_batch( - remaining, rewritten, _pre_rewritten=True, - ) - else: - matched = remaining - if matched.num_rows == 0: - continue - parts.append(vectorized_matched_transform( - matched, spec, captured_on_pairs, - captured_update_cols, captured_row_id_name, - captured_schema, - )) - if rewritten is not None and matched.num_rows < remaining.num_rows: - not_cond = f"COALESCE(NOT ({rewritten}), TRUE)" - remaining = _filter_batch( - remaining, not_cond, _pre_rewritten=True, - ) - else: - remaining = remaining.slice(0, 0) - if not parts: - return captured_schema.empty_table() - return pa.concat_tables(parts) - + _transform = _build_matched_transform( + clauses, + on_map=dict(zip(source_on, target_on)), + on_pairs=list(zip(source_on, target_on)), + update_cols=list(update_cols), + row_id_name=row_id_name, + update_schema=update_schema, + ) return joined.map_batches(_transform, **_map_kwargs(ray_remote_args)) diff --git a/paimon-python/pypaimon/tests/ray_data_evolution_merge_into_test.py b/paimon-python/pypaimon/tests/ray_data_evolution_merge_into_test.py index b54eeb5cf0c9..b40844cdf10b 100644 --- a/paimon-python/pypaimon/tests/ray_data_evolution_merge_into_test.py +++ b/paimon-python/pypaimon/tests/ray_data_evolution_merge_into_test.py @@ -1648,6 +1648,363 @@ def test_multi_clause_duplicate_both_actionable_raises(self): ) self.assertIn('multiple source rows', str(ctx.exception)) + def test_self_merge_update_literal(self): + target = self._create_table() + self._write( + target, + pa.Table.from_pydict( + { + 'id': pa.array([1, 2, 3], type=pa.int32()), + 'name': ['a', 'b', 'c'], + 'age': pa.array([10, 20, 30], type=pa.int32()), + }, + schema=self.pa_schema, + ), + ) + + result = merge_into( + target=target, + source=target, + catalog_options=self.catalog_options, + on=['_ROW_ID'], + when_matched=[WhenMatched(update={'age': lit(99)})], + ) + + self.assertEqual(result['num_matched'], 3) + out = self._read_sorted(target) + self.assertEqual(out['age'], [99, 99, 99]) + self.assertEqual(out['name'], ['a', 'b', 'c']) + + def test_self_merge_update_star(self): + target = self._create_table() + self._write( + target, + pa.Table.from_pydict( + { + 'id': pa.array([1, 2, 3], type=pa.int32()), + 'name': ['a', 'b', 'c'], + 'age': pa.array([10, 20, 30], type=pa.int32()), + }, + schema=self.pa_schema, + ), + ) + + result = merge_into( + target=target, + source=target, + catalog_options=self.catalog_options, + on=['_ROW_ID'], + when_matched=[WhenMatched(update='*')], + ) + + self.assertEqual(result['num_matched'], 3) + out = self._read_sorted(target) + self.assertEqual(out['id'], [1, 2, 3]) + self.assertEqual(out['name'], ['a', 'b', 'c']) + self.assertEqual(out['age'], [10, 20, 30]) + + @unittest.skipIf(_SKIP_CONDITION, _SKIP_REASON) + def test_self_merge_with_condition(self): + target = self._create_table() + self._write( + target, + pa.Table.from_pydict( + { + 'id': pa.array([1, 2, 3], type=pa.int32()), + 'name': ['a', 'b', 'c'], + 'age': pa.array([10, 20, 30], type=pa.int32()), + }, + schema=self.pa_schema, + ), + ) + + result = merge_into( + target=target, + source=target, + catalog_options=self.catalog_options, + on=['_ROW_ID'], + when_matched=[WhenMatched(update={'age': lit(99)}, condition='t.age > 15')], + ) + + self.assertEqual(result['num_matched'], 2) + out = self._read_sorted(target) + self.assertEqual(out['age'], [10, 99, 99]) + + @unittest.skipIf(_SKIP_CONDITION, _SKIP_REASON) + def test_self_merge_with_source_condition(self): + target = self._create_table() + self._write( + target, + pa.Table.from_pydict( + { + 'id': pa.array([1, 2, 3], type=pa.int32()), + 'name': ['a', 'b', 'c'], + 'age': pa.array([10, 20, 30], type=pa.int32()), + }, + schema=self.pa_schema, + ), + ) + + result = merge_into( + target=target, + source=target, + catalog_options=self.catalog_options, + on=['_ROW_ID'], + when_matched=[WhenMatched( + update={'name': lit('updated')}, + condition='s.age > 15', + )], + ) + + self.assertEqual(result['num_matched'], 2) + out = self._read_sorted(target) + self.assertEqual(out['name'], ['a', 'updated', 'updated']) + self.assertEqual(out['age'], [10, 20, 30]) + + def test_self_merge_rejects_not_matched(self): + target = self._create_table() + self._write(target, self._source(ids=(1,))) + + with self.assertRaises(ValueError) as ctx: + merge_into( + target=target, + source=target, + catalog_options=self.catalog_options, + on=['_ROW_ID'], + when_matched=[WhenMatched(update='*')], + when_not_matched=[WhenNotMatched(insert='*')], + ) + self.assertIn('Self-merge', str(ctx.exception)) + + def test_self_merge_partial_set(self): + target = self._create_table() + self._write( + target, + pa.Table.from_pydict( + { + 'id': pa.array([1, 2], type=pa.int32()), + 'name': ['old_a', 'old_b'], + 'age': pa.array([10, 20], type=pa.int32()), + }, + schema=self.pa_schema, + ), + ) + + result = merge_into( + target=target, + source=target, + catalog_options=self.catalog_options, + on=['_ROW_ID'], + when_matched=[WhenMatched(update={'name': lit('updated')})], + ) + + self.assertEqual(result['num_matched'], 2) + out = self._read_sorted(target) + self.assertEqual(out['name'], ['updated', 'updated']) + self.assertEqual(out['age'], [10, 20]) + + def test_self_merge_source_col_row_id(self): + target = self._create_table() + self._write( + target, + pa.Table.from_pydict( + { + 'id': pa.array([1, 2], type=pa.int32()), + 'name': ['a', 'b'], + 'age': pa.array([10, 20], type=pa.int32()), + }, + schema=self.pa_schema, + ), + ) + + result = merge_into( + target=target, + source=target, + catalog_options=self.catalog_options, + on=['_ROW_ID'], + when_matched=[WhenMatched(update={'name': source_col('_ROW_ID')})], + ) + + self.assertEqual(result['num_matched'], 2) + out = self._read_sorted(target) + for v in out['name']: + self.assertTrue(int(v) >= 0) + + @unittest.skipIf(_SKIP_CONDITION, _SKIP_REASON) + def test_self_merge_condition_on_row_id(self): + target = self._create_table() + self._write( + target, + pa.Table.from_pydict( + { + 'id': pa.array([1, 2, 3], type=pa.int32()), + 'name': ['a', 'b', 'c'], + 'age': pa.array([10, 20, 30], type=pa.int32()), + }, + schema=self.pa_schema, + ), + ) + + result = merge_into( + target=target, + source=target, + catalog_options=self.catalog_options, + on=['_ROW_ID'], + when_matched=[ + WhenMatched( + update={'age': lit(99)}, + condition='s._ROW_ID >= 0', + ), + ], + ) + + self.assertEqual(result['num_matched'], 3) + out = self._read_sorted(target) + self.assertEqual(out['age'], [99, 99, 99]) + + @unittest.skipIf(_SKIP_CONDITION, _SKIP_REASON) + def test_self_merge_condition_on_target_row_id(self): + target = self._create_table() + self._write( + target, + pa.Table.from_pydict( + { + 'id': pa.array([1, 2, 3], type=pa.int32()), + 'name': ['a', 'b', 'c'], + 'age': pa.array([10, 20, 30], type=pa.int32()), + }, + schema=self.pa_schema, + ), + ) + + result = merge_into( + target=target, + source=target, + catalog_options=self.catalog_options, + on=['_ROW_ID'], + when_matched=[ + WhenMatched( + update={'age': lit(99)}, + condition='t._ROW_ID >= 0', + ), + ], + ) + + self.assertEqual(result['num_matched'], 3) + out = self._read_sorted(target) + self.assertEqual(out['age'], [99, 99, 99]) + + @unittest.skipIf(_SKIP_CONDITION, _SKIP_REASON) + def test_self_merge_multi_clause_fall_through(self): + target = self._create_table() + self._write( + target, + pa.Table.from_pydict( + { + 'id': pa.array([1, 2, 3], type=pa.int32()), + 'name': ['a', 'b', 'c'], + 'age': pa.array([10, 20, 30], type=pa.int32()), + }, + schema=self.pa_schema, + ), + ) + + result = merge_into( + target=target, + source=target, + catalog_options=self.catalog_options, + on=['_ROW_ID'], + when_matched=[ + WhenMatched(update={'name': lit('old')}, condition='s.age <= 10'), + WhenMatched(update={'name': lit('young')}, condition='s.age <= 20'), + WhenMatched(update={'name': lit('senior')}), + ], + ) + + self.assertEqual(result['num_matched'], 3) + out = self._read_sorted(target) + self.assertEqual(out['name'], ['old', 'young', 'senior']) + self.assertEqual(out['age'], [10, 20, 30]) + + @unittest.skip("blocked by blob DE sequence bug fix, see PR #8147") + @unittest.skipIf(_SKIP_CONDITION, _SKIP_REASON) + def test_self_merge_blob_source_condition(self): + blob_schema = pa.schema([ + ('id', pa.int32()), + ('name', pa.string()), + ('picture', pa.large_binary()), + ]) + tbl_name = f'default.tbl_{uuid.uuid4().hex[:8]}' + s = Schema.from_pyarrow_schema(blob_schema, options=self.de_options) + self.catalog.create_table(tbl_name, s, False) + + self._write( + tbl_name, + pa.Table.from_pydict( + { + 'id': pa.array([1, 2], type=pa.int32()), + 'name': ['a', 'b'], + 'picture': [None, None], + }, + schema=blob_schema, + ), + ) + + result = merge_into( + target=tbl_name, + source=tbl_name, + catalog_options=self.catalog_options, + on=['_ROW_ID'], + when_matched=[ + WhenMatched( + update={'name': lit('updated')}, + condition='s.picture IS NULL', + ), + ], + ) + + self.assertEqual(result['num_matched'], 2) + out = self._read_sorted(tbl_name) + self.assertEqual(out['name'], ['updated', 'updated']) + + @unittest.skipIf(_SKIP_CONDITION, _SKIP_REASON) + def test_self_merge_blob_target_condition_rejected(self): + blob_schema = pa.schema([ + ('id', pa.int32()), + ('name', pa.string()), + ('picture', pa.large_binary()), + ]) + tbl_name = f'default.tbl_{uuid.uuid4().hex[:8]}' + s = Schema.from_pyarrow_schema(blob_schema, options=self.de_options) + self.catalog.create_table(tbl_name, s, False) + + self._write( + tbl_name, + pa.Table.from_pydict( + { + 'id': pa.array([1], type=pa.int32()), + 'name': ['a'], + 'picture': [None], + }, + schema=blob_schema, + ), + ) + + with self.assertRaises(ValueError) as ctx: + merge_into( + target=tbl_name, + source=tbl_name, + catalog_options=self.catalog_options, + on=['_ROW_ID'], + when_matched=[ + WhenMatched( + update={'name': lit('x')}, + condition='t.picture IS NOT NULL', + ), + ], + ) + self.assertIn('blob', str(ctx.exception).lower()) + class TargetProjectionTest(unittest.TestCase):