Skip to content

Commit

Permalink
[nereids] pull up join from union all rule
Browse files Browse the repository at this point in the history
  • Loading branch information
zhongjian.xzj authored and zhongjian.xzj committed Dec 21, 2023
1 parent 48d5cbe commit a2473e0
Show file tree
Hide file tree
Showing 5 changed files with 92 additions and 11 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,10 @@ public String getReferencedColumnName(String column) {
return foreignToReference.get(column);
}

public ImmutableMap<String, String> getForeignToReference() {
return foreignToReference;
}

public Map<Column, Column> getForeignToPrimary(TableIf curTable) {
ImmutableMap.Builder<Column, Column> columnBuilder = new ImmutableMap.Builder<>();
TableIf refTable = referencedTable.toTableIf();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -289,8 +289,7 @@ public class Rewriter extends AbstractBatchJobExecutor {

// this rule should invoke after infer predicate and push down distinct, and before push down limit
custom(RuleType.ELIMINATE_JOIN_BY_FOREIGN_KEY, EliminateJoinByFK::new),

// this rule should be after topic "Column pruning and infer predicate"
// this rule should be after topic "Column pruning and infer predicate"
topic("Join pull up",
topDown(
new EliminateFilter(),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
package org.apache.doris.nereids.rules.rewrite;

import org.apache.doris.catalog.constraint.Constraint;
import org.apache.doris.catalog.constraint.ForeignKeyConstraint;
import org.apache.doris.catalog.constraint.PrimaryKeyConstraint;
import org.apache.doris.catalog.constraint.UniqueConstraint;
import org.apache.doris.nereids.jobs.JobContext;
Expand Down Expand Up @@ -303,7 +304,6 @@ private boolean checkAggrKeyOnUkOrPk(LogicalAggregate aggregate, LogicalCatalogR
}

private boolean checkJoinConditionOnPk(LogicalJoin joinRoot, LogicalCatalogRelation table, PullUpContext context) {
// get pk info from table
Set<String> pkInfos = getPkInfoFromConstraint(table);
if (pkInfos == null || pkInfos.size() != 1) {
return false;
Expand All @@ -314,6 +314,9 @@ private boolean checkJoinConditionOnPk(LogicalJoin joinRoot, LogicalCatalogRelat
boolean found = false;
for (LogicalJoin join : joinList) {
List<Expression> conditions = join.getHashJoinConjuncts();
List<LogicalCatalogRelation> basicTableList = new ArrayList<>();
basicTableList.addAll((Collection<? extends LogicalCatalogRelation>) join
.collect(LogicalCatalogRelation.class::isInstance));
for (Expression equalTo : conditions) {
if (equalTo instanceof EqualTo
&& ((EqualTo) equalTo).left() instanceof SlotReference
Expand All @@ -322,14 +325,38 @@ private boolean checkJoinConditionOnPk(LogicalJoin joinRoot, LogicalCatalogRelat
SlotReference rightSlot = (SlotReference) ((EqualTo) equalTo).right();
if (table.getOutputExprIds().contains(leftSlot.getExprId())
&& pkSlot.equals(leftSlot.getName())) {
found = true;
context.replaceColumns.add(rightSlot);
context.pullUpTableToPkSlotMap.put(table, leftSlot);
// pk-fk join condition, check other side's join key is on fk
LogicalCatalogRelation rightTable = findTableFromSlot(rightSlot, basicTableList);
if (rightTable != null && getFkInfoFromConstraint(rightTable) != null) {
ForeignKeyConstraint fkInfo = getFkInfoFromConstraint(rightTable);
if (fkInfo.getReferencedTable().getId() == table.getTable().getId()) {
for (Map.Entry<String, String> entry : fkInfo.getForeignToReference().entrySet()) {
if (entry.getValue().equals(pkSlot) && entry.getKey().equals(rightSlot.getName())) {
found = true;
context.replaceColumns.add(rightSlot);
context.pullUpTableToPkSlotMap.put(table, leftSlot);
break;
}
}
}
}
} else if (table.getOutputExprIds().contains(rightSlot.getExprId())
&& pkSlot.equals(rightSlot.getName())) {
found = true;
context.replaceColumns.add(leftSlot);
context.pullUpTableToPkSlotMap.put(table, leftSlot);
// pk-fk join condition, check other side's join key is on fk
LogicalCatalogRelation leftTable = findTableFromSlot(leftSlot, basicTableList);
if (leftTable != null && getFkInfoFromConstraint(leftTable) != null) {
ForeignKeyConstraint fkInfo = getFkInfoFromConstraint(leftTable);
if (fkInfo.getReferencedTable().getId() == table.getTable().getId()) {
for (Map.Entry<String, String> entry : fkInfo.getForeignToReference().entrySet()) {
if (entry.getValue().equals(pkSlot) && entry.getKey().equals(leftSlot.getName())) {
found = true;
context.replaceColumns.add(leftSlot);
context.pullUpTableToPkSlotMap.put(table, rightSlot);
break;
}
}
}
}
}
if (found) {
break;
Expand All @@ -343,6 +370,31 @@ private boolean checkJoinConditionOnPk(LogicalJoin joinRoot, LogicalCatalogRelat
return found;
}

private LogicalCatalogRelation findTableFromSlot(SlotReference targetSlot,
List<LogicalCatalogRelation> tableList) {
for (LogicalCatalogRelation table : tableList) {
if (table.getOutputExprIds().contains(targetSlot.getExprId())) {
return table;
}
}
return null;
}

private ForeignKeyConstraint getFkInfoFromConstraint(LogicalCatalogRelation table) {
table.getTable().readLock();
try {
for (Map.Entry<String, Constraint> constraintMap : table.getTable().getConstraintsMap().entrySet()) {
Constraint constraint = constraintMap.getValue();
if (constraint instanceof ForeignKeyConstraint) {
return (ForeignKeyConstraint) constraint;
}
}
return null;
} finally {
table.getTable().readUnlock();
}
}

private Set<String> getPkInfoFromConstraint(LogicalCatalogRelation table) {
table.getTable().readLock();
try {
Expand Down Expand Up @@ -386,7 +438,9 @@ private boolean checkJoinRoot(LogicalJoin joinRoot) {
List<LogicalJoin> allJoinNodes = Lists.newArrayList();
allJoinNodes.addAll((Collection<? extends LogicalJoin>) joinRoot.collect(LogicalJoin.class::isInstance));
boolean joinTypeMatch = allJoinNodes.stream().allMatch(e -> e.getJoinType() == JoinType.INNER_JOIN);
if (!joinTypeMatch) {
boolean joinConditionMatch = allJoinNodes.stream()
.allMatch(e -> !e.getHashJoinConjuncts().isEmpty() && e.getOtherJoinConjuncts().isEmpty());
if (!joinTypeMatch || !joinConditionMatch) {
return false;
}

Expand Down
12 changes: 12 additions & 0 deletions regression-test/suites/nereids_tpcds_shape_sf1000_p0/load.groovy
Original file line number Diff line number Diff line change
Expand Up @@ -815,6 +815,18 @@ suite("load") {
alter table customer add constraint uk unique (c_customer_id);
'''

sql '''
alter table store_sales add constraint ss_fk foreign key(ss_customer_sk) references customer(c_customer_sk);
'''

sql '''
alter table web_sales add constraint ws_fk foreign key(ws_bill_customer_sk) references customer(c_customer_sk);
'''

sql '''
alter table catalog_sales add constraint cs_fk foreign key(cs_bill_customer_sk) references customer(c_customer_sk);
'''

sql """
alter table customer_demographics modify column cd_dep_employed_count set stats ('row_count'='1920800', 'ndv'='7', 'num_nulls'='0', 'min_value'='0', 'max_value'='6', 'data_size'='7683200')
"""
Expand Down
14 changes: 13 additions & 1 deletion regression-test/suites/nereids_tpcds_shape_sf100_p0/load.groovy
Original file line number Diff line number Diff line change
Expand Up @@ -825,7 +825,19 @@ alter table customer add constraint pk primary key (c_customer_sk);
sql '''
alter table customer add constraint uk unique (c_customer_id);
'''


sql '''
alter table store_sales add constraint ss_fk foreign key(ss_customer_sk) references customer(c_customer_sk);
'''

sql '''
alter table web_sales add constraint ws_fk foreign key(ws_bill_customer_sk) references customer(c_customer_sk);
'''

sql '''
alter table catalog_sales add constraint cs_fk foreign key(cs_bill_customer_sk) references customer(c_customer_sk);
'''

sql """
alter table web_sales modify column ws_web_site_sk set stats ('row_count'='72001237', 'ndv'='24', 'min_value'='1', 'max_value'='24', 'avg_size'='576009896', 'max_size'='576009896' )
"""
Expand Down

0 comments on commit a2473e0

Please sign in to comment.