Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
75 changes: 57 additions & 18 deletions paimon-python/pypaimon/ray/data_evolution_merge_into.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
from pypaimon.ray.data_evolution_merge_join import (
build_matched_update_ds,
build_not_matched_insert_ds,
build_self_merge_update_ds,
distributed_update_apply,
distributed_write_collect_msgs,
)
Expand Down Expand Up @@ -53,6 +54,7 @@ class _PrepareCtx:
update_pa_schema: pa.Schema
full_pa_schema: pa.Schema
catalog_options: Dict[str, str]
is_self_merge: bool = False


def merge_into(
Expand Down Expand Up @@ -178,33 +180,45 @@ def _prepare(target, source, catalog_options, when_matched, when_not_matched, on
_NormalizedClause(spec=spec, condition=c.condition)
)

source_snapshot_id = None
if isinstance(source, str):
source_snapshot = (
catalog.get_table(source)
.snapshot_manager()
.get_latest_snapshot()
is_self_merge = _is_self_merge(target, source, target_on_cols, source_on_cols)
if is_self_merge and not_matched_specs:
raise ValueError(
"Self-merge (source == target with ON _ROW_ID) does not "
"support WHEN NOT MATCHED clauses."
)
if source_snapshot is not None:
source_snapshot_id = source_snapshot.id

source_ds = _normalize_source(
source, catalog_options, source_snapshot_id=source_snapshot_id,
)
_validate_source_on_cols(source_ds, source_on_cols)
if is_self_merge:
source_ds = None
source_col_names = set(full_target_field_names) | set(source_on_cols)
else:
source_snapshot_id = None
if isinstance(source, str):
source_snapshot = (
catalog.get_table(source)
.snapshot_manager()
.get_latest_snapshot()
)
if source_snapshot is not None:
source_snapshot_id = source_snapshot.id
source_ds = _normalize_source(
source, catalog_options, source_snapshot_id=source_snapshot_id,
)
_validate_source_on_cols(source_ds, source_on_cols)
source_col_names = set(_source_schema_or_raise(source_ds).names)
_validate_source_has_target_cols(
source_ds, matched_specs + not_matched_specs,
source_col_names, matched_specs + not_matched_specs,
)

if has_condition:
from pypaimon.ray.merge_condition import extract_columns
source_names = set(_source_schema_or_raise(source_ds).names)
target_names = set(full_target_field_names)
if is_self_merge:
target_names |= set(target_on_cols)
for c in list(when_matched) + list(when_not_matched):
if c.condition is not None:
for ref in extract_columns(c.condition):
prefix, col = ref.split(".", 1)
if prefix == "s" and col not in source_names:
if prefix == "s" and col not in source_col_names:
raise ValueError(
f"condition references unknown source "
f"column '{col}'"
Expand Down Expand Up @@ -233,10 +247,20 @@ def _prepare(target, source, catalog_options, when_matched, when_not_matched, on
update_pa_schema=update_pa_schema,
full_pa_schema=full_pa_schema,
catalog_options=catalog_options,
is_self_merge=is_self_merge,
)
return table, source_ds, matched_specs, not_matched_specs, ctx


def _is_self_merge(target, source, target_on_cols, source_on_cols) -> bool:
from pypaimon.table.special_fields import SpecialFields
row_id_name = SpecialFields.ROW_ID.name
return (isinstance(source, str)
and source == target
and target_on_cols == [row_id_name]
and source_on_cols == [row_id_name])


def _build_datasets(
target, source_ds, matched_specs, not_matched_specs,
ctx: "_PrepareCtx", base_snapshot, num_partitions, ray_remote_args,
Expand All @@ -250,6 +274,22 @@ def _build_datasets(
insert_ds = None
update_cols_union: List[str] = []

if ctx.is_self_merge:
if matched_specs and base_snapshot is not None:
update_cols_union = _union_update_cols(matched_specs)
update_ds = build_self_merge_update_ds(
target_identifier=target,
clauses=matched_specs,
target_field_names=ctx.full_target_field_names,
target_pa_schema=ctx.update_pa_schema,
update_cols=update_cols_union,
catalog_options=ctx.catalog_options,
resolve_target_projection=_resolve_target_projection,
snapshot_id=base_snapshot_id,
ray_remote_args=ray_remote_args,
)
return update_ds, insert_ds, update_cols_union

# Mirror Spark: matched/not-matched run as two independent joins
# (inner / left_anti). One unified left_outer join would force
# joined.materialize() to feed both branches, which can OOM on large merges.
Expand Down Expand Up @@ -561,16 +601,15 @@ def _validate_source_on_cols(source_ds, on: Sequence[str]) -> None:


def _validate_source_has_target_cols(
source_ds,
source_col_names: set,
specs: List[_NormalizedClause],
) -> None:
names = set(_source_schema_or_raise(source_ds).names)
needed = set()
for clause in specs:
for val in clause.spec.values():
if isinstance(val, SourceColumnRef):
needed.add(val.column)
missing = sorted(needed - names)
missing = sorted(needed - source_col_names)
if missing:
raise ValueError(
f"source is missing columns {missing} referenced by SET spec"
Expand Down
192 changes: 140 additions & 52 deletions paimon-python/pypaimon/ray/data_evolution_merge_join.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
import pyarrow as pa

from pypaimon.ray.data_evolution_merge_transform import (
SourceColumnRef,
_NormalizedClause,
build_update_schema,
vectorized_insert_transform,
Expand All @@ -40,6 +41,137 @@ def _map_kwargs(
return kwargs


def _build_matched_transform(
clauses: List[_NormalizedClause],
on_map: Dict[str, str],
on_pairs: List[Tuple[str, str]],
update_cols: List[str],
row_id_name: str,
update_schema: pa.Schema,
):
prepared_clauses = []
for clause in clauses:
rewritten = None
if clause.condition is not None:
from pypaimon.ray.merge_condition import (
remap_source_on_keys, rewrite_condition,
)
rewritten = remap_source_on_keys(
rewrite_condition(clause.condition), on_map,
)
prepared_clauses.append((clause.spec, rewritten))

_filter_batch = None
if any(r is not None for _, r in prepared_clauses):
from pypaimon.ray.merge_condition import filter_batch as _filter_batch

def _transform(batch: pa.Table) -> pa.Table:
remaining = batch
parts = []
for spec, rewritten in prepared_clauses:
if remaining.num_rows == 0:
break
if rewritten is not None:
matched = _filter_batch(
remaining, rewritten, _pre_rewritten=True,
)
else:
matched = remaining
if matched.num_rows == 0:
continue
parts.append(vectorized_matched_transform(
matched, spec, on_pairs,
update_cols, row_id_name,
update_schema,
))
if rewritten is not None and matched.num_rows < remaining.num_rows:
not_cond = f"COALESCE(NOT ({rewritten}), TRUE)"
remaining = _filter_batch(
remaining, not_cond, _pre_rewritten=True,
)
else:
remaining = remaining.slice(0, 0)
if not parts:
return update_schema.empty_table()
return pa.concat_tables(parts)

return _transform


def build_self_merge_update_ds(
*,
target_identifier: str,
clauses: List[_NormalizedClause],
target_field_names: Sequence[str],
target_pa_schema: pa.Schema,
update_cols: Sequence[str],
catalog_options: Dict[str, str],
resolve_target_projection,
snapshot_id: Optional[int] = None,
ray_remote_args: Optional[Dict[str, Any]] = None,
) -> Tuple:
from pypaimon.ray.ray_paimon import read_paimon
from pypaimon.table.special_fields import SpecialFields

row_id_name = SpecialFields.ROW_ID.name
needed_cols = set(resolve_target_projection(
clauses, [row_id_name], update_cols, target_field_names,
))
for clause in clauses:
for value in clause.spec.values():
if isinstance(value, SourceColumnRef):
needed_cols.add(value.column)
target_set = set(target_field_names)
for clause in clauses:
if clause.condition is not None:
from pypaimon.ray.merge_condition import extract_columns
for ref in extract_columns(clause.condition):
prefix, col = ref.split(".", 1)
if prefix == "s" and col in target_set:
needed_cols.add(col)
projection = [row_id_name] + [
c for c in target_field_names if c in needed_cols
]

target_ds = read_paimon(
target_identifier, catalog_options,
projection=projection, snapshot_id=snapshot_id,
)
update_schema = build_update_schema(target_pa_schema, update_cols, row_id_name)

orig_names = target_ds.schema().names
target_renamed = target_ds.rename_columns(
{c: f"t.{c}" for c in orig_names}
)

def _add_source_aliases(batch: pa.Table) -> pa.Table:
columns = list(batch.columns)
names = list(batch.schema.names)
for orig in orig_names:
if orig == row_id_name:
continue
t_col_name = f"t.{orig}"
if t_col_name in names:
idx = names.index(t_col_name)
columns.append(columns[idx])
names.append(f"s.{orig}")
return pa.table(columns, names=names)

aliased = target_renamed.map_batches(
_add_source_aliases, **_map_kwargs(ray_remote_args),
)

_transform = _build_matched_transform(
clauses,
on_map={row_id_name: row_id_name},
on_pairs=[(row_id_name, row_id_name)],
update_cols=list(update_cols),
row_id_name=row_id_name,
update_schema=update_schema,
)
return aliased.map_batches(_transform, **_map_kwargs(ray_remote_args))


def build_matched_update_ds(
*,
target_identifier: str,
Expand Down Expand Up @@ -87,58 +219,14 @@ def build_matched_update_ds(
right_on=tuple(f"s.{c}" for c in source_on),
)

captured_update_cols = list(update_cols)
captured_row_id_name = row_id_name
captured_on_pairs = list(zip(source_on, target_on))
captured_schema = update_schema

on_map = dict(zip(source_on, target_on))
prepared_clauses = []
for clause in clauses:
rewritten = None
if clause.condition is not None:
from pypaimon.ray.merge_condition import (
remap_source_on_keys, rewrite_condition,
)
rewritten = remap_source_on_keys(
rewrite_condition(clause.condition), on_map,
)
prepared_clauses.append((clause.spec, rewritten))

_filter_batch = None
if any(r is not None for _, r in prepared_clauses):
from pypaimon.ray.merge_condition import filter_batch as _filter_batch

def _transform(batch: pa.Table) -> pa.Table:
remaining = batch
parts = []
for spec, rewritten in prepared_clauses:
if remaining.num_rows == 0:
break
if rewritten is not None:
matched = _filter_batch(
remaining, rewritten, _pre_rewritten=True,
)
else:
matched = remaining
if matched.num_rows == 0:
continue
parts.append(vectorized_matched_transform(
matched, spec, captured_on_pairs,
captured_update_cols, captured_row_id_name,
captured_schema,
))
if rewritten is not None and matched.num_rows < remaining.num_rows:
not_cond = f"COALESCE(NOT ({rewritten}), TRUE)"
remaining = _filter_batch(
remaining, not_cond, _pre_rewritten=True,
)
else:
remaining = remaining.slice(0, 0)
if not parts:
return captured_schema.empty_table()
return pa.concat_tables(parts)

_transform = _build_matched_transform(
clauses,
on_map=dict(zip(source_on, target_on)),
on_pairs=list(zip(source_on, target_on)),
update_cols=list(update_cols),
row_id_name=row_id_name,
update_schema=update_schema,
)
return joined.map_batches(_transform, **_map_kwargs(ray_remote_args))


Expand Down
Loading
Loading