Skip to content

Commit

Permalink
[nereids] pull up join from union all rule
Browse files Browse the repository at this point in the history
  • Loading branch information
zhongjian.xzj committed Dec 22, 2023
1 parent c10723c commit 83ed78d
Show file tree
Hide file tree
Showing 6 changed files with 19 additions and 17 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,7 @@
import org.apache.doris.nereids.rules.rewrite.EliminateDedupJoinCondition;
import org.apache.doris.nereids.rules.rewrite.EliminateEmptyRelation;
import org.apache.doris.nereids.rules.rewrite.EliminateFilter;
import org.apache.doris.nereids.rules.rewrite.EliminateJoinByFK;
import org.apache.doris.nereids.rules.rewrite.EliminateJoinCondition;
import org.apache.doris.nereids.rules.rewrite.EliminateLimit;
import org.apache.doris.nereids.rules.rewrite.EliminateNotNull;
Expand Down Expand Up @@ -287,7 +288,7 @@ public class Rewriter extends AbstractBatchJobExecutor {
),

// this rule should invoke after infer predicate and push down distinct, and before push down limit
//custom(RuleType.ELIMINATE_JOIN_BY_FOREIGN_KEY, EliminateJoinByFK::new),
custom(RuleType.ELIMINATE_JOIN_BY_FOREIGN_KEY, EliminateJoinByFK::new),
// this rule should be after topic "Column pruning and infer predicate"
topic("Join pull up",
topDown(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -220,7 +220,7 @@ public void replace(Map<Slot, Slot> replaceMap) {
.map(s -> replaceMap.getOrDefault(s, s))
.collect(Collectors.toSet());
slotSets = slotSets.stream()
.map(set -> set.stream().map(replaceMap::get).collect(ImmutableSet.toImmutableSet()))
.map(set -> set.stream().map(s -> replaceMap.getOrDefault(s, s)).collect(ImmutableSet.toImmutableSet()))
.collect(Collectors.toSet());
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -326,7 +326,11 @@ private boolean isPredicateCompatible(BiMap<Slot, Slot> equalSlots, Map<Column,
.map(e -> e.rewriteUp(
s -> s instanceof Slot ? primarySlotToForeign.getOrDefault(s, (Slot) s) : s))
.collect(Collectors.toSet());
return columnWithPredicates.get(fp.getKey()).containsAll(primaryPredicates);
if (columnWithPredicates.get(fp.getKey()) == null && !columnWithPredicates.isEmpty()) {
return false;
} else {
return columnWithPredicates.get(fp.getKey()).containsAll(primaryPredicates);
}
});
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,7 @@ public class PullUpJoinFromUnionAll extends OneRewriteRuleFactory {
);

private static class PullUpContext {
public final String unifiedOutputAlias = "PULL_UP_UNIFIED_OUTPUT_ALIAS";
public static final String unifiedOutputAlias = "PULL_UP_UNIFIED_OUTPUT_ALIAS";
public final Map<String, List<LogicalCatalogRelation>> pullUpCandidatesMaps = Maps.newHashMap();
public final Map<LogicalCatalogRelation, LogicalJoin> tableToJoinRootMap = Maps.newHashMap();
public final Map<LogicalCatalogRelation, LogicalAggregate> tableToAggrRootMap = Maps.newHashMap();
Expand Down Expand Up @@ -141,11 +141,8 @@ private boolean checkUnionPattern(LogicalUnion union, PullUpContext context) {
int tableListNumber = -1;
for (Plan child : union.children()) {
if (!(child instanceof LogicalProject
&& child.child(0) != null
&& child.child(0) instanceof LogicalAggregate
&& child.child(0).child(0) != null
&& child.child(0).child(0) instanceof LogicalProject
&& child.child(0).child(0).child(0) != null
&& child.child(0).child(0).child(0) instanceof LogicalJoin)) {
return false;
}
Expand Down
10 changes: 5 additions & 5 deletions regression-test/suites/nereids_tpcds_shape_sf1000_p0/load.groovy
Original file line number Diff line number Diff line change
Expand Up @@ -808,23 +808,23 @@ suite("load") {
'''

sql '''
alter table customer add constraint pk primary key (c_customer_sk);
alter table customer add constraint customer_pk_${database} primary key (c_customer_sk);
'''

sql '''
alter table customer add constraint uk unique (c_customer_id);
alter table customer add constraint customer_uk_${database} unique (c_customer_id);
'''

sql '''
alter table store_sales add constraint ss_fk foreign key(ss_customer_sk) references customer(c_customer_sk);
alter table store_sales add constraint ss_fk_${database} foreign key(ss_customer_sk) references customer(c_customer_sk);
'''

sql '''
alter table web_sales add constraint ws_fk foreign key(ws_bill_customer_sk) references customer(c_customer_sk);
alter table web_sales add constraint ws_fk_${database} foreign key(ws_bill_customer_sk) references customer(c_customer_sk);
'''

sql '''
alter table catalog_sales add constraint cs_fk foreign key(cs_bill_customer_sk) references customer(c_customer_sk);
alter table catalog_sales add constraint cs_fk_${database} foreign key(cs_bill_customer_sk) references customer(c_customer_sk);
'''

sql """
Expand Down
10 changes: 5 additions & 5 deletions regression-test/suites/nereids_tpcds_shape_sf100_p0/load.groovy
Original file line number Diff line number Diff line change
Expand Up @@ -819,23 +819,23 @@ suite("load") {
'''

sql '''
alter table customer add constraint pk primary key (c_customer_sk);
alter table customer add constraint customer_pk_${database} primary key (c_customer_sk);
'''

sql '''
alter table customer add constraint uk unique (c_customer_id);
alter table customer add constraint customer_uk_${database} unique (c_customer_id);
'''

sql '''
alter table store_sales add constraint ss_fk foreign key(ss_customer_sk) references customer(c_customer_sk);
alter table store_sales add constraint ss_fk_${database} foreign key(ss_customer_sk) references customer(c_customer_sk);
'''

sql '''
alter table web_sales add constraint ws_fk foreign key(ws_bill_customer_sk) references customer(c_customer_sk);
alter table web_sales add constraint ws_fk_${database} foreign key(ws_bill_customer_sk) references customer(c_customer_sk);
'''

sql '''
alter table catalog_sales add constraint cs_fk foreign key(cs_bill_customer_sk) references customer(c_customer_sk);
alter table catalog_sales add constraint cs_fk_${database} foreign key(cs_bill_customer_sk) references customer(c_customer_sk);
'''

sql """
Expand Down

0 comments on commit 83ed78d

Please sign in to comment.