Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[feature](nereids)support correlated scalar subquery without scalar agg #39471

Merged
merged 8 commits into from
Sep 20, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -188,7 +188,12 @@ public class Rewriter extends AbstractBatchJobExecutor {
// after doing NormalizeAggregate in analysis job
// we need to run the following 2 rules to make AGG_SCALAR_SUBQUERY_TO_WINDOW_FUNCTION work
bottomUp(new PullUpProjectUnderApply()),
topDown(new PushDownFilterThroughProject()),
topDown(
new PushDownFilterThroughProject(),
// the subquery may have both a WHERE and a HAVING clause,
// so there may be two adjacent filters that we need to merge
new MergeFilters()
),
custom(RuleType.AGG_SCALAR_SUBQUERY_TO_WINDOW_FUNCTION,
AggScalarSubQueryToWindowFunction::new),
bottomUp(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,8 @@ public class PushDownFilterThroughProject extends PlanPostProcessor {
public Plan visitPhysicalFilter(PhysicalFilter<? extends Plan> filter, CascadesContext context) {
filter = (PhysicalFilter<? extends Plan>) super.visit(filter, context);
Plan child = filter.child();
if (!(child instanceof PhysicalProject)) {
// don't push down filter if child project contains NoneMovableFunction
if (!(child instanceof PhysicalProject) || ((PhysicalProject) child).containsNoneMovableFunction()) {
return filter;
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,8 @@ public Plan visitPhysicalFilter(PhysicalFilter<? extends Plan> filter, CascadesC

Plan child = filter.child();
// Forbidden filter-project, we must make filter-project -> project-filter.
if (child instanceof PhysicalProject) {
// unless the project contains a NoneMovableFunction
if (child instanceof PhysicalProject && !((PhysicalProject<?>) child).containsNoneMovableFunction()) {
throw new AnalysisException(
"Nereids generate a filter-project plan, but backend not support:\n" + filter.treeString());
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -65,15 +65,12 @@
import org.apache.doris.nereids.trees.expressions.WhenClause;
import org.apache.doris.nereids.trees.expressions.functions.BoundFunction;
import org.apache.doris.nereids.trees.expressions.functions.FunctionBuilder;
import org.apache.doris.nereids.trees.expressions.functions.agg.Count;
import org.apache.doris.nereids.trees.expressions.functions.scalar.ElementAt;
import org.apache.doris.nereids.trees.expressions.functions.scalar.Lambda;
import org.apache.doris.nereids.trees.expressions.functions.scalar.Nvl;
import org.apache.doris.nereids.trees.expressions.functions.udf.AliasUdfBuilder;
import org.apache.doris.nereids.trees.expressions.functions.udf.JavaUdaf;
import org.apache.doris.nereids.trees.expressions.functions.udf.JavaUdf;
import org.apache.doris.nereids.trees.expressions.functions.udf.UdfBuilder;
import org.apache.doris.nereids.trees.expressions.literal.BigIntLiteral;
import org.apache.doris.nereids.trees.expressions.literal.IntegerLikeLiteral;
import org.apache.doris.nereids.trees.expressions.literal.Literal;
import org.apache.doris.nereids.trees.expressions.literal.StringLiteral;
Expand Down Expand Up @@ -425,18 +422,6 @@ public Expression visitUnboundFunction(UnboundFunction unboundFunction, Expressi
return buildResult.first;
} else {
Expression castFunction = TypeCoercionUtils.processBoundFunction((BoundFunction) buildResult.first);
if (castFunction instanceof Count
&& context != null
&& context.cascadesContext.getOuterScope().isPresent()
&& !context.cascadesContext.getOuterScope().get().getCorrelatedSlots().isEmpty()) {
// consider sql: SELECT * FROM t1 WHERE t1.a <= (SELECT COUNT(t2.a) FROM t2 WHERE (t1.b = t2.b));
// when unnest correlated subquery, we create a left join node.
// outer query is left table and subquery is right one
// if there is no match, the row from right table is filled with nulls
// but COUNT function is always not nullable.
// so wrap COUNT with Nvl to ensure its result is 0 instead of null to get the correct result
castFunction = new Nvl(castFunction, new BigIntLiteral(0));
}
return castFunction;
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
package org.apache.doris.nereids.rules.analysis;

import org.apache.doris.common.Pair;
import org.apache.doris.nereids.analyzer.Scope;
import org.apache.doris.nereids.exceptions.AnalysisException;
import org.apache.doris.nereids.properties.OrderKey;
import org.apache.doris.nereids.rules.Rule;
Expand Down Expand Up @@ -52,21 +53,27 @@
* Resolve having clause to the aggregation/repeat.
* need Top to Down to traverse plan,
* because we need to process FILL_UP_SORT_HAVING_AGGREGATE before FILL_UP_HAVING_AGGREGATE.
* Be aware that when filling up the missing slots, we must exclude the outer query's correlated slots:
* these correlated slots belong to the outer query, so we should not try to find them in the child node.
*/
public class FillUpMissingSlots implements AnalysisRuleFactory {
@Override
public List<Rule> buildRules() {
return ImmutableList.of(
RuleType.FILL_UP_SORT_PROJECT.build(
logicalSort(logicalProject())
.then(sort -> {
.thenApply(ctx -> {
LogicalSort<LogicalProject<Plan>> sort = ctx.root;
Optional<Scope> outerScope = ctx.cascadesContext.getOuterScope();
LogicalProject<Plan> project = sort.child();
Set<Slot> projectOutputSet = project.getOutputSet();
Set<Slot> notExistedInProject = sort.getOrderKeys().stream()
.map(OrderKey::getExpr)
.map(Expression::getInputSlots)
.flatMap(Set::stream)
.filter(s -> !projectOutputSet.contains(s))
.filter(s -> !projectOutputSet.contains(s)
&& (!outerScope.isPresent() || !outerScope.get()
.getCorrelatedSlots().contains(s)))
Comment on lines +75 to +76
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

add some comments?

.collect(Collectors.toSet());
if (notExistedInProject.isEmpty()) {
return null;
Expand All @@ -82,7 +89,9 @@ public List<Rule> buildRules() {
aggregate(logicalHaving(aggregate()))
.when(a -> a.getOutputExpressions().stream().allMatch(SlotReference.class::isInstance))
).when(this::checkSort)
.then(sort -> processDistinctProjectWithAggregate(sort, sort.child(), sort.child().child().child()))
.thenApply(ctx -> processDistinctProjectWithAggregate(ctx.root,
ctx.root.child(), ctx.root.child().child().child(),
ctx.cascadesContext.getOuterScope()))
),
// ATTN: process aggregate with distinct project, must run this rule before FILL_UP_SORT_AGGREGATE
// because this pattern will always fail in FILL_UP_SORT_AGGREGATE
Expand All @@ -91,14 +100,17 @@ public List<Rule> buildRules() {
aggregate(aggregate())
.when(a -> a.getOutputExpressions().stream().allMatch(SlotReference.class::isInstance))
).when(this::checkSort)
.then(sort -> processDistinctProjectWithAggregate(sort, sort.child(), sort.child().child()))
.thenApply(ctx -> processDistinctProjectWithAggregate(ctx.root,
ctx.root.child(), ctx.root.child().child(),
ctx.cascadesContext.getOuterScope()))
),
RuleType.FILL_UP_SORT_AGGREGATE.build(
logicalSort(aggregate())
.when(this::checkSort)
.then(sort -> {
.thenApply(ctx -> {
LogicalSort<Aggregate<Plan>> sort = ctx.root;
Aggregate<Plan> agg = sort.child();
Resolver resolver = new Resolver(agg);
Resolver resolver = new Resolver(agg, ctx.cascadesContext.getOuterScope());
sort.getExpressions().forEach(resolver::resolve);
return createPlan(resolver, agg, (r, a) -> {
List<OrderKey> newOrderKeys = sort.getOrderKeys().stream()
Expand All @@ -118,10 +130,11 @@ public List<Rule> buildRules() {
RuleType.FILL_UP_SORT_HAVING_AGGREGATE.build(
logicalSort(logicalHaving(aggregate()))
.when(this::checkSort)
.then(sort -> {
.thenApply(ctx -> {
LogicalSort<LogicalHaving<Aggregate<Plan>>> sort = ctx.root;
LogicalHaving<Aggregate<Plan>> having = sort.child();
Aggregate<Plan> agg = having.child();
Resolver resolver = new Resolver(agg);
Resolver resolver = new Resolver(agg, ctx.cascadesContext.getOuterScope());
sort.getExpressions().forEach(resolver::resolve);
return createPlan(resolver, agg, (r, a) -> {
List<OrderKey> newOrderKeys = sort.getOrderKeys().stream()
Expand All @@ -138,13 +151,17 @@ public List<Rule> buildRules() {
})
),
RuleType.FILL_UP_SORT_HAVING_PROJECT.build(
logicalSort(logicalHaving(logicalProject())).then(sort -> {
logicalSort(logicalHaving(logicalProject())).thenApply(ctx -> {
LogicalSort<LogicalHaving<LogicalProject<Plan>>> sort = ctx.root;
Optional<Scope> outerScope = ctx.cascadesContext.getOuterScope();
Set<Slot> childOutput = sort.child().getOutputSet();
Set<Slot> notExistedInProject = sort.getOrderKeys().stream()
.map(OrderKey::getExpr)
.map(Expression::getInputSlots)
.flatMap(Set::stream)
.filter(s -> !childOutput.contains(s))
.filter(s -> !childOutput.contains(s)
&& (!outerScope.isPresent() || !outerScope.get()
.getCorrelatedSlots().contains(s)))
.collect(Collectors.toSet());
if (notExistedInProject.isEmpty()) {
return null;
Expand All @@ -158,9 +175,10 @@ public List<Rule> buildRules() {
})
),
RuleType.FILL_UP_HAVING_AGGREGATE.build(
logicalHaving(aggregate()).then(having -> {
logicalHaving(aggregate()).thenApply(ctx -> {
LogicalHaving<Aggregate<Plan>> having = ctx.root;
Aggregate<Plan> agg = having.child();
Resolver resolver = new Resolver(agg);
Resolver resolver = new Resolver(agg, ctx.cascadesContext.getOuterScope());
having.getConjuncts().forEach(resolver::resolve);
return createPlan(resolver, agg, (r, a) -> {
Set<Expression> newConjuncts = ExpressionUtils.replace(
Expand All @@ -175,7 +193,9 @@ public List<Rule> buildRules() {
),
// Convert having to filter
RuleType.FILL_UP_HAVING_PROJECT.build(
logicalHaving(logicalProject()).then(having -> {
logicalHaving(logicalProject()).thenApply(ctx -> {
LogicalHaving<LogicalProject<Plan>> having = ctx.root;
Optional<Scope> outerScope = ctx.cascadesContext.getOuterScope();
if (having.getExpressions().stream().anyMatch(e -> e.containsType(AggregateFunction.class))) {
// This is very weird pattern.
// There are some aggregate functions in having, but its child is project.
Expand All @@ -198,7 +218,7 @@ public List<Rule> buildRules() {
ImmutableList.of(), ImmutableList.of(), project.child());
// avoid throw exception even if having have slot from its child.
// because we will add a project between having and project.
Resolver resolver = new Resolver(agg, false);
Resolver resolver = new Resolver(agg, false, outerScope);
having.getConjuncts().forEach(resolver::resolve);
agg = agg.withAggOutput(resolver.getNewOutputSlots());
Set<Expression> newConjuncts = ExpressionUtils.replace(
Expand All @@ -212,7 +232,9 @@ public List<Rule> buildRules() {
Set<Slot> notExistedInProject = having.getExpressions().stream()
.map(Expression::getInputSlots)
.flatMap(Set::stream)
.filter(s -> !projectOutputSet.contains(s))
.filter(s -> !projectOutputSet.contains(s)
&& (!outerScope.isPresent() || !outerScope.get()
.getCorrelatedSlots().contains(s)))
.collect(Collectors.toSet());
if (notExistedInProject.isEmpty()) {
return null;
Expand All @@ -235,18 +257,28 @@ static class Resolver {
private final List<NamedExpression> newOutputSlots = Lists.newArrayList();
private final Map<Slot, Expression> outputSubstitutionMap;
private final boolean checkSlot;
private final Optional<Scope> outerScope;

Resolver(Aggregate<?> aggregate, boolean checkSlot) {
Resolver(Aggregate<?> aggregate, boolean checkSlot, Optional<Scope> outerScope) {
outputExpressions = aggregate.getOutputExpressions();
groupByExpressions = aggregate.getGroupByExpressions();
outputSubstitutionMap = outputExpressions.stream().filter(Alias.class::isInstance)
.collect(Collectors.toMap(NamedExpression::toSlot, alias -> alias.child(0),
(k1, k2) -> k1));
this.checkSlot = checkSlot;
this.outerScope = outerScope;
}

Resolver(Aggregate<?> aggregate, boolean checkSlot) {
this(aggregate, checkSlot, Optional.empty());
}

Resolver(Aggregate<?> aggregate) {
this(aggregate, true);
this(aggregate, true, Optional.empty());
}

Resolver(Aggregate<?> aggregate, Optional<Scope> outerScope) {
this(aggregate, true, outerScope);
}

public void resolve(Expression expression) {
Expand Down Expand Up @@ -274,7 +306,8 @@ public void resolve(Expression expression) {
// We couldn't find the equivalent expression in output expressions and group-by expressions,
// so we should check whether the expression is valid.
if (expression instanceof SlotReference) {
if (checkSlot) {
if (checkSlot && (!outerScope.isPresent()
|| !outerScope.get().getCorrelatedSlots().contains(expression))) {
Comment on lines +309 to +310
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

add some comments

throw new AnalysisException(expression.toSql() + " should be grouped by.");
}
} else if (expression instanceof AggregateFunction) {
Expand Down Expand Up @@ -401,8 +434,8 @@ private boolean checkSort(LogicalSort<? extends Plan> logicalSort) {
* @return filled up plan
*/
private Plan processDistinctProjectWithAggregate(LogicalSort<?> sort,
Aggregate<?> upperAggregate, Aggregate<Plan> bottomAggregate) {
Resolver resolver = new Resolver(bottomAggregate);
Aggregate<?> upperAggregate, Aggregate<Plan> bottomAggregate, Optional<Scope> outerScope) {
Resolver resolver = new Resolver(bottomAggregate, outerScope);
sort.getExpressions().forEach(resolver::resolve);
return createPlan(resolver, bottomAggregate, (r, a) -> {
List<OrderKey> newOrderKeys = sort.getOrderKeys().stream()
Expand Down
Loading
Loading