Skip to content

Commit 9a40788

Browse files
[CALCITE-7278] Correlated subqueries in the join condition cannot reference both join inputs
1 parent 077d355 commit 9a40788

3 files changed

Lines changed: 526 additions & 4 deletions

File tree

core/src/main/java/org/apache/calcite/rel/rules/SubQueryRemoveRule.java

Lines changed: 352 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@
3333
import org.apache.calcite.rex.LogicVisitor;
3434
import org.apache.calcite.rex.RexBuilder;
3535
import org.apache.calcite.rex.RexCorrelVariable;
36+
import org.apache.calcite.rex.RexFieldAccess;
3637
import org.apache.calcite.rex.RexInputRef;
3738
import org.apache.calcite.rex.RexLiteral;
3839
import org.apache.calcite.rex.RexNode;
@@ -47,15 +48,19 @@
4748
import org.apache.calcite.sql2rel.RelDecorrelator;
4849
import org.apache.calcite.tools.RelBuilder;
4950
import org.apache.calcite.util.ImmutableBitSet;
51+
import org.apache.calcite.util.Litmus;
5052
import org.apache.calcite.util.Pair;
5153

5254
import com.google.common.collect.ImmutableList;
55+
import com.google.common.collect.ImmutableSet;
5356
import com.google.common.collect.Iterables;
5457

5558
import org.immutables.value.Value;
5659

5760
import java.util.ArrayList;
61+
import java.util.HashMap;
5862
import java.util.List;
63+
import java.util.Map;
5964
import java.util.Set;
6065
import java.util.stream.Collectors;
6166

@@ -967,10 +972,8 @@ private static void matchJoin(SubQueryRemoveRule rule, RelOptRuleCall call) {
967972
boolean inputIntersectsRightSide =
968973
inputSet.intersects(ImmutableBitSet.range(nFieldsLeft, nFieldsLeft + nFieldsRight));
969974
if (inputIntersectsLeftSide && inputIntersectsRightSide) {
970-
// The current existential rewrite needs to make join with one side of the origin join and
971-
// generate a new condition to replace the on clause. But for RexNode whose operands are
972-
// on either side of the join, we can't push them into join. So this rewriting is not
973-
// supported.
975+
rewriteSubQueryOnDomain(rule, call, e, join, nFieldsLeft, nFieldsRight,
976+
inputSet, builder, variablesSet);
974977
return;
975978
}
976979

@@ -1079,6 +1082,232 @@ private static void matchJoin(SubQueryRemoveRule rule, RelOptRuleCall call) {
10791082
call.transformTo(builder.build());
10801083
}
10811084

1085+
/**
1086+
* Rewrites a sub-query that references columns from both the left and right inputs of a Join.
1087+
*
1088+
* <p>This method handles the complex case where a sub-query in a Join condition is correlated
1089+
* with both sides of the Join. It performs the following steps:
1090+
* <ol>
1091+
* <li>Identifies the "Domain" of values from the left and right inputs that are relevant
1092+
* to the sub-query.</li>
1093+
* <li>Constructs a "Computation Domain" by cross-joining the distinct keys from the left
1094+
* and right domains.</li>
1095+
* <li>Remaps the sub-query to operate on this Computation Domain.</li>
1096+
* <li>Rewrites the sub-query using the standard {@link #apply} method, but applied to the
1097+
* Domain.</li>
1098+
* <li>Re-integrates the result of the sub-query rewrite back into the original Join structure,
1099+
* ensuring correct join types and conditions are maintained.</li>
1100+
* </ol>
1101+
*
1102+
* @param rule The rule instance
1103+
* @param call The rule call
1104+
* @param e The sub-query to rewrite
1105+
* @param join The join containing the sub-query
1106+
* @param nFieldsLeft Number of fields in the left input
1107+
* @param nFieldsRight Number of fields in the right input
1108+
* @param inputSet BitSet of columns used by the sub-query
1109+
* @param builder The RelBuilder
1110+
* @param variablesSet Set of correlation variables used by the sub-query
1111+
*/
1112+
private static void rewriteSubQueryOnDomain(SubQueryRemoveRule rule,
1113+
RelOptRuleCall call,
1114+
RexSubQuery e,
1115+
Join join,
1116+
int nFieldsLeft,
1117+
int nFieldsRight,
1118+
ImmutableBitSet inputSet,
1119+
RelBuilder builder,
1120+
Set<CorrelationId> variablesSet) {
1121+
// Map to store the offset of each correlation variable
1122+
final Map<CorrelationId, Integer> idToOffset = new HashMap<>();
1123+
// Helper to determine offset for each correlation variable
1124+
e.rel.accept(new CorrelationOffsetFinder(idToOffset, join, nFieldsLeft));
1125+
1126+
// 1. Identify which columns from Left and Right are used by the subquery.
1127+
// These will form the "Domain" on which the subquery is calculated.
1128+
final ImmutableBitSet leftUsed =
1129+
inputSet.intersect(ImmutableBitSet.range(0, nFieldsLeft));
1130+
final ImmutableBitSet rightUsed =
1131+
inputSet.intersect(ImmutableBitSet.range(nFieldsLeft, nFieldsLeft + nFieldsRight));
1132+
1133+
// 2. Build the "Computation Domain".
1134+
// This is a Cross Join of the distinct keys from Left and Right.
1135+
// Domain = Distinct(Project(LeftUsed)) x Distinct(Project(RightUsed))
1136+
1137+
// 2a. Left Domain
1138+
builder.push(join.getLeft());
1139+
builder.project(builder.fields(leftUsed));
1140+
builder.distinct();
1141+
1142+
// 2b. Right Domain
1143+
builder.push(join.getRight());
1144+
// We must shift the bitset to be 0-based for the Right input
1145+
ImmutableBitSet rightUsedShifted = rightUsed.shift(-nFieldsLeft);
1146+
builder.project(builder.fields(rightUsedShifted));
1147+
builder.distinct();
1148+
1149+
// 2c. Create Domain Cross Join
1150+
builder.join(JoinRelType.INNER, builder.literal(true));
1151+
1152+
// 3. Remap the SubQuery to run on the Domain.
1153+
// We need to map original field indices to their new positions in the Domain.
1154+
// Original: [LeftFields... | RightFields...]
1155+
// Domain: [LeftUsed... | RightUsed...]
1156+
final Map<Integer, Integer> mapping = new HashMap<>();
1157+
int targetIdx = 0;
1158+
for (int source : leftUsed) {
1159+
mapping.put(source, targetIdx++);
1160+
}
1161+
for (int source : rightUsed) {
1162+
mapping.put(source, targetIdx++);
1163+
}
1164+
1165+
final RexBuilder rexBuilder = builder.getRexBuilder();
1166+
final CorrelationId domainCorrId = join.getCluster().createCorrel();
1167+
final RexNode domainCorrVar = rexBuilder.makeCorrel(builder.peek().getRowType(), domainCorrId);
1168+
1169+
// Shuttle to replace InputRefs and Correlations with references to the Domain
1170+
RexShuttle shuttle = new InputRefAndCorrelationReplacer(mapping, variablesSet, idToOffset);
1171+
// Create the new subquery with operands remapped to the Domain
1172+
RexNode newSubQueryNode = e.accept(shuttle);
1173+
1174+
// Rewrite e.rel to use domainCorrId
1175+
RelNode newRel = e.rel.accept(new DomainRewriter(variablesSet, idToOffset, mapping,
1176+
rexBuilder, domainCorrVar));
1177+
1178+
if (newSubQueryNode instanceof RexSubQuery) {
1179+
newSubQueryNode = ((RexSubQuery) newSubQueryNode).clone(newRel);
1180+
}
1181+
1182+
// We introduced a new correlation variable domainCorrId.
1183+
Set<CorrelationId> newVariablesSet = ImmutableSet.of(domainCorrId);
1184+
1185+
final RelOptUtil.Logic logic =
1186+
LogicVisitor.find(join.getJoinType().generatesNullsOnRight()
1187+
? RelOptUtil.Logic.TRUE_FALSE_UNKNOWN : RelOptUtil.Logic.TRUE,
1188+
ImmutableList.of(join.getCondition()), e);
1189+
1190+
// 4. Apply the standard rewriting rule to the Domain.
1191+
// The builder is currently sitting on the Domain Join.
1192+
// 'target' is the CASE expression (or similar) resulting from the rewrite.
1193+
// The builder stack now has the result of the rewrite (e.g. Domain Left Join Aggregate).
1194+
assert newSubQueryNode instanceof RexSubQuery;
1195+
final RexNode target =
1196+
rule.apply((RexSubQuery) newSubQueryNode, newVariablesSet, logic, builder,
1197+
1, builder.peek().getRowType().getFieldCount(), 0);
1198+
1199+
// The target references the Domain Result (which is currently at the top of the builder).
1200+
// In the final plan, the Domain Result will be joined to the right of the original inputs.
1201+
// Furthermore, since we use a LEFT JOIN, the Domain Result columns become nullable.
1202+
// So we need to shift the references in target AND make them nullable.
1203+
final int offset = nFieldsLeft + nFieldsRight;
1204+
final RexShuttle shiftAndNullableShuttle = new RexShuttle() {
1205+
@Override public RexNode visitInputRef(RexInputRef inputRef) {
1206+
// Shift the index
1207+
int newIndex = inputRef.getIndex() + offset;
1208+
return new RexInputRef(newIndex, inputRef.getType());
1209+
}
1210+
};
1211+
final RexNode shiftedTarget = target.accept(shiftAndNullableShuttle);
1212+
1213+
// 5. Re-integrate with Original Inputs
1214+
// Stack has: [RewriteResult]
1215+
RelNode domainResult = builder.build();
1216+
1217+
// Rebuild the original Join structure
1218+
// We want to construct: Left JOIN (Right JOIN Domain) ON ...
1219+
// This preserves the JoinRelType of the original join.
1220+
JoinRelType joinType = join.getJoinType();
1221+
if (joinType == JoinRelType.RIGHT) {
1222+
// Symmetric to LEFT/INNER/FULL but attached to Left
1223+
builder.push(join.getLeft());
1224+
builder.push(domainResult);
1225+
1226+
// Join Left and Domain on Left Keys
1227+
List<RexNode> leftJoinConditions = new ArrayList<>();
1228+
int domainIdx = 0; // Left Keys are at start of Domain
1229+
for (int source : leftUsed) {
1230+
leftJoinConditions.add(
1231+
builder.equals(
1232+
builder.field(2, 0, source),
1233+
builder.field(2, 1, domainIdx++)));
1234+
}
1235+
builder.join(JoinRelType.INNER, builder.and(leftJoinConditions));
1236+
1237+
// Now Join Right
1238+
builder.push(join.getRight());
1239+
// Stack: (Left+Domain), Right
1240+
1241+
// Join Condition: Original + Right Keys match
1242+
List<RexNode> rightJoinConditions = new ArrayList<>();
1243+
// Domain starts after Left. Right Keys in Domain are after Left Keys.
1244+
int domainRightKeyIdx = nFieldsLeft + leftUsed.cardinality();
1245+
for (int source : rightUsed) {
1246+
// Right input (index 1)
1247+
RexInputRef field = builder.field(2, 1, source - nFieldsLeft);
1248+
// (Left+Domain) input (index 0)
1249+
RexInputRef field1 = builder.field(2, 0, domainRightKeyIdx++);
1250+
rightJoinConditions.add(builder.equals(field, field1));
1251+
}
1252+
1253+
RexShuttle replaceShuttle = new ReplaceSubQueryShuttle(e, shiftedTarget);
1254+
RexNode newJoinCondition = join.getCondition().accept(replaceShuttle);
1255+
1256+
builder.join(joinType, builder.and(builder.and(rightJoinConditions), newJoinCondition));
1257+
1258+
builder.project(fields(builder, nFieldsLeft + nFieldsRight));
1259+
} else {
1260+
// For INNER, LEFT, FULL join, we can attach Domain to Right, then Join Left.
1261+
// 1. Build (Right JOIN Domain)
1262+
builder.push(join.getRight());
1263+
builder.push(domainResult);
1264+
1265+
// Join Right and Domain on Right Keys
1266+
// Domain layout: [LeftKeys, RightKeys]
1267+
List<RexNode> rightJoinConditions = new ArrayList<>();
1268+
// Skip Left Keys
1269+
int domainIdx = leftUsed.cardinality();
1270+
for (int source : rightUsed) {
1271+
rightJoinConditions.add(
1272+
builder.equals(
1273+
builder.field(2, 0, source - nFieldsLeft), // Right input
1274+
builder.field(2, 1, domainIdx++))); // Domain input
1275+
}
1276+
// We use INNER join here to expand Right with Domain values.
1277+
// Since Domain contains all distinct Right keys, this is safe.
1278+
builder.join(JoinRelType.INNER, builder.and(rightJoinConditions));
1279+
1280+
// 2. Join Left with (Right JOIN Domain)
1281+
RelNode rightWithDomain = builder.build();
1282+
builder.push(join.getLeft());
1283+
builder.push(rightWithDomain);
1284+
1285+
// Join Condition: Original Condition (rewritten) AND Left.LeftKeys = Domain.LeftKeys
1286+
List<RexNode> leftJoinConditions = new ArrayList<>();
1287+
// In (Right+Domain), Domain fields start after Right fields
1288+
int domainStartInCombined = nFieldsRight;
1289+
int domainLeftKeyIdx = domainStartInCombined; // Left Keys are at start of Domain
1290+
1291+
for (int source : leftUsed) {
1292+
// Left input
1293+
RexInputRef field = builder.field(2, 0, source);
1294+
// (Right+Domain) input
1295+
RexInputRef field1 = builder.field(2, 1, domainLeftKeyIdx++);
1296+
leftJoinConditions.add(builder.equals(field, field1));
1297+
}
1298+
1299+
RexShuttle replaceShuttle = new ReplaceSubQueryShuttle(e, shiftedTarget);
1300+
RexNode newJoinCondition = join.getCondition().accept(replaceShuttle);
1301+
1302+
builder.join(joinType, builder.and(builder.and(leftJoinConditions), newJoinCondition));
1303+
1304+
// Project original fields (remove Domain columns)
1305+
builder.project(fields(builder, nFieldsLeft + nFieldsRight));
1306+
}
1307+
1308+
call.transformTo(builder.build());
1309+
}
1310+
10821311
private static void matchFilterEnableMarkJoin(SubQueryRemoveRule rule, RelOptRuleCall call) {
10831312
final Filter filter = call.rel(0);
10841313
final Set<CorrelationId> variablesSet = filter.getVariablesSet();
@@ -1212,6 +1441,125 @@ private static class ReplaceSubQueryShuttle extends RexShuttle {
12121441
return subQuery.equals(this.subQuery) ? replacement : subQuery;
12131442
}
12141443
}
1444+
1445+
/**
1446+
* Shuttle that finds correlation variables and determines their offset.
1447+
*/
1448+
private static class CorrelationOffsetFinder extends RelHomogeneousShuttle {
1449+
private final Map<CorrelationId, Integer> idToOffset;
1450+
private final Join join;
1451+
private final int nFieldsLeft;
1452+
1453+
CorrelationOffsetFinder(Map<CorrelationId, Integer> idToOffset, Join join, int nFieldsLeft) {
1454+
this.idToOffset = idToOffset;
1455+
this.join = join;
1456+
this.nFieldsLeft = nFieldsLeft;
1457+
}
1458+
1459+
@Override public RelNode visit(RelNode other) {
1460+
other.accept(new RexShuttle() {
1461+
@Override public RexNode visitCorrelVariable(RexCorrelVariable correlVariable) {
1462+
if (!idToOffset.containsKey(correlVariable.id)) {
1463+
// Check if type matches Left
1464+
if (RelOptUtil.eq("type1", correlVariable.getType(),
1465+
"type2", join.getLeft().getRowType(), Litmus.IGNORE)) {
1466+
idToOffset.put(correlVariable.id, 0);
1467+
} else if (RelOptUtil.eq("type1", correlVariable.getType(),
1468+
"type2", join.getRight().getRowType(), Litmus.IGNORE)) {
1469+
idToOffset.put(correlVariable.id, nFieldsLeft);
1470+
} else {
1471+
// Default to 0 if unknown
1472+
idToOffset.put(correlVariable.id, 0);
1473+
}
1474+
}
1475+
return super.visitCorrelVariable(correlVariable);
1476+
}
1477+
});
1478+
return super.visit(other);
1479+
}
1480+
}
1481+
1482+
/**
1483+
* Shuttle that replaces InputRefs and Correlations with references to the Domain.
1484+
*/
1485+
private static class InputRefAndCorrelationReplacer extends RexShuttle {
1486+
private final Map<Integer, Integer> mapping;
1487+
private final Set<CorrelationId> variablesSet;
1488+
private final Map<CorrelationId, Integer> idToOffset;
1489+
1490+
InputRefAndCorrelationReplacer(Map<Integer, Integer> mapping,
1491+
Set<CorrelationId> variablesSet, Map<CorrelationId, Integer> idToOffset) {
1492+
this.mapping = mapping;
1493+
this.variablesSet = variablesSet;
1494+
this.idToOffset = idToOffset;
1495+
}
1496+
1497+
@Override public RexNode visitInputRef(RexInputRef inputRef) {
1498+
Integer newIndex = mapping.get(inputRef.getIndex());
1499+
if (newIndex != null) {
1500+
return new RexInputRef(newIndex, inputRef.getType());
1501+
}
1502+
return super.visitInputRef(inputRef);
1503+
}
1504+
1505+
@Override public RexNode visitFieldAccess(RexFieldAccess fieldAccess) {
1506+
RexNode refExpr = fieldAccess.getReferenceExpr();
1507+
if (refExpr instanceof RexCorrelVariable) {
1508+
CorrelationId id = ((RexCorrelVariable) refExpr).id;
1509+
if (variablesSet.contains(id)) {
1510+
int fieldIdx = fieldAccess.getField().getIndex();
1511+
int offset = idToOffset.getOrDefault(id, 0);
1512+
Integer newIndex = mapping.get(fieldIdx + offset);
1513+
if (newIndex != null) {
1514+
return new RexInputRef(newIndex, fieldAccess.getType());
1515+
}
1516+
}
1517+
}
1518+
return super.visitFieldAccess(fieldAccess);
1519+
}
1520+
}
1521+
1522+
/**
1523+
* Shuttle that rewrites RelNodes to use the Domain correlation variable.
1524+
*/
1525+
private static class DomainRewriter extends RelHomogeneousShuttle {
1526+
private final Set<CorrelationId> variablesSet;
1527+
private final Map<CorrelationId, Integer> idToOffset;
1528+
private final Map<Integer, Integer> mapping;
1529+
private final RexBuilder rexBuilder;
1530+
private final RexNode domainCorrVar;
1531+
1532+
DomainRewriter(Set<CorrelationId> variablesSet, Map<CorrelationId, Integer> idToOffset,
1533+
Map<Integer, Integer> mapping, RexBuilder rexBuilder, RexNode domainCorrVar) {
1534+
this.variablesSet = variablesSet;
1535+
this.idToOffset = idToOffset;
1536+
this.mapping = mapping;
1537+
this.rexBuilder = rexBuilder;
1538+
this.domainCorrVar = domainCorrVar;
1539+
}
1540+
1541+
@Override public RelNode visit(RelNode other) {
1542+
return super.visit(
1543+
other.accept(new RexShuttle() {
1544+
@Override public RexNode visitFieldAccess(RexFieldAccess fieldAccess) {
1545+
RexNode refExpr = fieldAccess.getReferenceExpr();
1546+
if (refExpr instanceof RexCorrelVariable) {
1547+
CorrelationId id = ((RexCorrelVariable) refExpr).id;
1548+
if (variablesSet.contains(id)) {
1549+
int fieldIdx = fieldAccess.getField().getIndex();
1550+
int offset = idToOffset.getOrDefault(id, 0);
1551+
Integer newIndex = mapping.get(fieldIdx + offset);
1552+
if (newIndex != null) {
1553+
return rexBuilder.makeFieldAccess(domainCorrVar, newIndex);
1554+
}
1555+
}
1556+
}
1557+
return super.visitFieldAccess(fieldAccess);
1558+
}
1559+
}));
1560+
}
1561+
}
1562+
12151563
/** Rule configuration. */
12161564
@Value.Immutable(singleton = false)
12171565
public interface Config extends RelRule.Config {

0 commit comments

Comments
 (0)