Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -640,6 +640,55 @@ class JoinITCase(miniBatch: MiniBatchMode, state: StateBackendMode, enableAsyncS
assertThat(sink.getRetractResults.sorted).isEqualTo(expected.sorted)
}

@TestTemplate
def testThreeWayMultiJoinWithoutPk(): Unit = {
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ideally, it would be great to follow the new store restore test approach.
It would be great if you could do that instead.

Copy link
Copy Markdown
Contributor Author

@ldadima ldadima Mar 23, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@snuyanzin Thanks for the review.
This test, as in the previous PR with a similar problem, is only temporary. It was decided to expand this PR with the necessary tests.
Is that OK? What do you think?

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah, let's also add this one to the new suite. I'll take a look at the new suite next.

env.setParallelism(1)
val data1 = new mutable.MutableList[(Int, Long)]
data1.+=((1, 1L))
data1.+=((1, 2L))
data1.+=((1, 2L))
data1.+=((1, 5L))
data1.+=((2, 7L))
data1.+=((1, 9L))
data1.+=((1, 8L))
data1.+=((3, 8L))

val data2 = new mutable.MutableList[(Int, Long)]
data2.+=((1, 1L))
data2.+=((2, 2L))
data2.+=((3, 2L))
data2.+=((1, 4L))

val data3 = new mutable.MutableList[(Int, Long)]
data3.+=((1, 1L))
data3.+=((2, 2L))
data3.+=((3, 2L))
data3.+=((2, 1L))

val a = failingDataSource(data1).toTable(tEnv, 'a1, 'a2)
val b = failingDataSource(data2).toTable(tEnv, 'b1, 'b2)
val c = failingDataSource(data3).toTable(tEnv, 'c1, 'c2)

tEnv.createTemporaryView("Atable", a)
tEnv.createTemporaryView("Btable", b)
tEnv.createTemporaryView("Ctable", c)

tEnv.getConfig.getConfiguration
.setString(OptimizerConfigOptions.TABLE_OPTIMIZER_MULTI_JOIN_ENABLED.key(), "true")
val query1 = "SELECT SUM(a2) AS a2, a1 FROM Atable group by a1"
val query2 = "SELECT SUM(b2) AS b2, b1 FROM Btable group by b1"
val query3 = "SELECT SUM(c2) AS c2, c1 FROM Ctable group by c1"
val query =
s"SELECT a1, b1, c1 FROM ($query1) JOIN ($query2) ON a1 = b1 JOIN ($query3) ON c2 = b2"

val sink = new TestingRetractSink
tEnv.sqlQuery(query).toRetractStream[Row].addSink(sink).setParallelism(1)
env.execute()

val expected = Seq("2,2,3", "3,3,3")
assertThat(sink.getRetractResults.sorted).isEqualTo(expected.sorted)
}

@TestTemplate
def testInnerJoinWithPk(): Unit = {
val query1 = "SELECT SUM(a2) AS a2, a1 FROM A group by a1"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -817,8 +817,8 @@ private void initializeStateHandlers() {
"Keyed state store not found when initializing keyed state store handlers.");
}

boolean prohibitReuseRow = isHeapBackend();
if (prohibitReuseRow) {
boolean requiresKeyDeepCopy = isHeapBackend();
if (requiresKeyDeepCopy) {
this.keyExtractor.requiresKeyDeepCopy();
}

Expand All @@ -835,7 +835,8 @@ private void initializeStateHandlers() {
inputSpecs.get(i),
joinKeyType,
inputTypes.get(i),
stateRetentionTime[i]);
stateRetentionTime[i],
requiresKeyDeepCopy);
stateHandlers.add(stateView);
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,8 @@ public static MultiJoinStateView create(
RowType
joinKeyType, /* joinKeyType is null for inputId = 0, see {@link InputSideHasUniqueKey}*/
RowType recordType,
long retentionTime) {
long retentionTime,
boolean requiresKeyDeepCopy) {
StateTtlConfig ttlConfig = createTtlConfig(retentionTime);

if (inputSideSpec.hasUniqueKey()) {
Expand All @@ -83,10 +84,12 @@ public static MultiJoinStateView create(
recordType,
inputSideSpec.getUniqueKeyType(),
inputSideSpec.getUniqueKeySelector(),
ttlConfig);
ttlConfig,
requiresKeyDeepCopy);
}
} else {
return new InputSideHasNoUniqueKey(ctx, stateName, joinKeyType, recordType, ttlConfig);
return new InputSideHasNoUniqueKey(
ctx, stateName, joinKeyType, recordType, ttlConfig, requiresKeyDeepCopy);
}
}

Expand Down Expand Up @@ -182,7 +185,9 @@ private static final class InputSideHasUniqueKey implements MultiJoinStateView {
private final MapState<RowData, RowData> recordState;
private final KeySelector<RowData, RowData> uniqueKeySelector;
private RowDataSerializer joinKeySerializer; // Null if joinKeyType is null
private RowDataSerializer stateKeySerializer; // Null if joinKeyType is null
private int joinKeyFieldCount = 0; // 0 if joinKeyType is null
private final boolean requiresKeyDeepCopy;

private InputSideHasUniqueKey(
RuntimeContext ctx,
Expand All @@ -191,7 +196,8 @@ private InputSideHasUniqueKey(
final RowType recordType,
final InternalTypeInfo<RowData> uniqueKeyType,
final KeySelector<RowData, RowData> uniqueKeySelector,
final StateTtlConfig ttlConfig) {
final StateTtlConfig ttlConfig,
final boolean requiresKeyDeepCopy) {
checkNotNull(uniqueKeyType);
checkNotNull(uniqueKeySelector);
this.uniqueKeySelector = uniqueKeySelector;
Expand All @@ -208,6 +214,7 @@ private InputSideHasUniqueKey(
// Composite key type: RowData with 2 fields (joinKey, uniqueKey)
// The composite key is a RowData with joinKey at index 0 and uniqueKey at index 1.
final RowType keyRowType = RowType.of(joinKeyType, uniqueKeyType.toRowType());
this.stateKeySerializer = new RowDataSerializer(keyRowType);
keyStateType = InternalTypeInfo.of(keyRowType);
}

Expand All @@ -216,6 +223,7 @@ private InputSideHasUniqueKey(
stateName, keyStateType, InternalTypeInfo.of(recordType), ttlConfig);

this.recordState = ctx.getMapState(recordStateDesc);
this.requiresKeyDeepCopy = requiresKeyDeepCopy;
}

private boolean joinKeysEqual(RowData joinKey, RowData currentJoinKeyInState) {
Expand All @@ -231,7 +239,14 @@ private RowData getStateKey(RowData joinKey, RowData uniqueKey) {
GenericRowData compositeKey = new GenericRowData(2);
compositeKey.setField(0, joinKey);
compositeKey.setField(1, uniqueKey);
return compositeKey;

// Need to make a deep copy when the heap state backend is used,
// because generic row data and binary row data are not equivalent.
if (requiresKeyDeepCopy) {
return stateKeySerializer.toBinaryRow(compositeKey, true);
} else {
return compositeKey;
}
}
}

Expand Down Expand Up @@ -313,16 +328,19 @@ public Iterator<RowData> iterator() {
private static final class InputSideHasNoUniqueKey implements MultiJoinStateView {
private final MapState<RowData, Integer> recordState;
private RowDataSerializer joinKeySerializer; // Null if joinKeyType is null
private RowDataSerializer stateKeySerializer; // Null if joinKeyType is null
private int joinKeyFieldCount; // 0 if joinKeyType is null
private final int recordFieldCount;
@Nullable private final RowType joinKeyType; // Store to check for null
private final boolean requiresKeyDeepCopy;

private InputSideHasNoUniqueKey(
RuntimeContext ctx,
final String stateName,
@Nullable final RowType joinKeyType, // Can be null
final RowType recordType,
final StateTtlConfig ttlConfig) {
final StateTtlConfig ttlConfig,
final boolean requiresKeyDeepCopy) {
this.joinKeyType = joinKeyType;
this.recordFieldCount = recordType.getFieldCount();

Expand All @@ -335,13 +353,15 @@ private InputSideHasNoUniqueKey(
this.joinKeyFieldCount = this.joinKeyType.getFieldCount();
// Composite key type: RowData with 2 fields (joinKey, record)
final RowType keyRowType = RowType.of(this.joinKeyType, recordType);
this.stateKeySerializer = new RowDataSerializer(keyRowType);
keyStateType = InternalTypeInfo.of(keyRowType);
}

MapStateDescriptor<RowData, Integer> recordStateDesc =
createStateDescriptor(stateName, keyStateType, Types.INT, ttlConfig);

this.recordState = ctx.getMapState(recordStateDesc);
this.requiresKeyDeepCopy = requiresKeyDeepCopy;
}

private boolean joinKeysEqual(RowData joinKeyToLookup, RowData currentJoinKeyInState) {
Expand All @@ -357,7 +377,14 @@ private RowData getStateKey(@Nullable RowData joinKey, RowData record) {
GenericRowData compositeKey = new GenericRowData(2);
compositeKey.setField(0, joinKey);
compositeKey.setField(1, record);
return compositeKey;

// Need to make a deep copy when the heap state backend is used,
// because generic row data and binary row data are not equivalent.
if (requiresKeyDeepCopy) {
return stateKeySerializer.toBinaryRow(compositeKey, true);
} else {
return compositeKey;
}
}
}

Expand Down