Skip to content

Commit

Permalink
add more case
Browse files Browse the repository at this point in the history
  • Loading branch information
starocean999 committed Sep 10, 2024
1 parent 3111597 commit 03f5155
Show file tree
Hide file tree
Showing 5 changed files with 368 additions and 112 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,8 @@
* Resolve having clause to the aggregation/repeat.
* need Top to Down to traverse plan,
* because we need to process FILL_UP_SORT_HAVING_AGGREGATE before FILL_UP_HAVING_AGGREGATE.
* Be aware that when filling up the missing slots, we must exclude the outer query's correlated slots:
* these slots belong to the outer query, so we should not try to find them in the child node.
*/
public class FillUpMissingSlots implements AnalysisRuleFactory {
@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -32,13 +32,15 @@
import org.apache.doris.nereids.trees.expressions.literal.BooleanLiteral;
import org.apache.doris.nereids.trees.expressions.visitor.DefaultExpressionRewriter;
import org.apache.doris.nereids.trees.plans.Plan;
import org.apache.doris.nereids.trees.plans.PlanType;
import org.apache.doris.nereids.trees.plans.logical.LogicalAggregate;
import org.apache.doris.nereids.trees.plans.logical.LogicalJoin;
import org.apache.doris.nereids.trees.plans.logical.LogicalLimit;
import org.apache.doris.nereids.trees.plans.logical.LogicalOneRowRelation;
import org.apache.doris.nereids.trees.plans.logical.LogicalPlan;
import org.apache.doris.nereids.trees.plans.logical.LogicalProject;
import org.apache.doris.nereids.trees.plans.logical.LogicalSort;
import org.apache.doris.nereids.trees.plans.visitor.PlanVisitor;
import org.apache.doris.nereids.util.ExpressionUtils;

import com.google.common.collect.ImmutableList;
Expand Down Expand Up @@ -122,27 +124,29 @@ public Expression visitScalarSubquery(ScalarSubquery scalar, T context) {
AnalyzedResult analyzedResult = analyzeSubquery(scalar);
boolean isCorrelated = analyzedResult.isCorrelated();
LogicalPlan analyzedSubqueryPlan = analyzedResult.logicalPlan;
checkOutputColumn(analyzedSubqueryPlan);
if (isCorrelated) {
if (analyzedSubqueryPlan instanceof LogicalLimit) {
if (ScalarSubquery.findTopLevelScalarAgg(analyzedResult.logicalPlan) == null) {
LogicalLimit limit = (LogicalLimit) analyzedSubqueryPlan;
if (limit.getOffset() == 0 && limit.getLimit() == 1) {
analyzedSubqueryPlan = (LogicalPlan) analyzedSubqueryPlan.child(0);
} else {
throw new AnalysisException("limit is not supported in correlated subquery "
+ analyzedResult.getLogicalPlan());
} else {
analyzedSubqueryPlan = (LogicalPlan) analyzedSubqueryPlan.child(0);
}
}
if (analyzedSubqueryPlan instanceof LogicalSort) {
// skip useless sort node
analyzedResult = new AnalyzedResult((LogicalPlan) analyzedSubqueryPlan.child(0),
analyzedResult.correlatedSlots);
}
CorrelatedSlotsValidator validator =
new CorrelatedSlotsValidator(ImmutableSet.copyOf(analyzedResult.correlatedSlots));
List<PlanNodeCorrelatedInfo> nodeInfoList = new ArrayList<>(16);
Set<LogicalAggregate> topAgg = new HashSet<>();
validateSubquery(analyzedResult.logicalPlan, validator, nodeInfoList, topAgg);
}
checkOutputColumn(analyzedResult.getLogicalPlan());
checkHasNoGroupBy(analyzedResult);
checkNoCorrelatedSlotsInAgg(analyzedResult);
checkNoCorrelatedSlotsUnderJoin(analyzedResult);

LogicalPlan subqueryPlan = analyzedResult.getLogicalPlan();
if (analyzedResult.getLogicalPlan() instanceof LogicalProject) {
LogicalProject project = (LogicalProject) analyzedResult.getLogicalPlan();
if (project.child() instanceof LogicalOneRowRelation
Expand All @@ -155,10 +159,6 @@ public Expression visitScalarSubquery(ScalarSubquery scalar, T context) {
return alias.child();
}
} else if (isCorrelated) {
if (ExpressionUtils.containsWindowExpression(project.getProjects())) {
throw new AnalysisException("window function is not supported in correlated subquery's output "
+ analyzedResult.getLogicalPlan());
}
Set<Slot> correlatedSlots = new HashSet<>(analyzedResult.getCorrelatedSlots());
if (!Sets.intersection(ExpressionUtils.getInputSlotSet(project.getProjects()),
correlatedSlots).isEmpty()) {
Expand All @@ -169,7 +169,7 @@ public Expression visitScalarSubquery(ScalarSubquery scalar, T context) {
}
}

return new ScalarSubquery(subqueryPlan, analyzedResult.getCorrelatedSlots());
return new ScalarSubquery(analyzedResult.getLogicalPlan(), analyzedResult.getCorrelatedSlots());
}

private void checkOutputColumn(LogicalPlan plan) {
Expand All @@ -179,16 +179,6 @@ private void checkOutputColumn(LogicalPlan plan) {
}
}

/**
 * Rejects a correlated subquery for which {@code hasGroupBy()} reports true.
 * A non-correlated subquery is accepted without consulting {@code hasGroupBy()}.
 *
 * @param analyzedResult the analyzed subquery to check
 * @throws AnalysisException if the subquery is correlated and has a group by
 */
private void checkHasNoGroupBy(AnalyzedResult analyzedResult) {
    // short-circuit keeps the original evaluation order: hasGroupBy() only runs when correlated
    boolean correlatedWithGroupBy = analyzedResult.isCorrelated() && analyzedResult.hasGroupBy();
    if (correlatedWithGroupBy) {
        throw new AnalysisException("Unsupported correlated subquery with grouping and/or aggregation "
                + analyzedResult.getLogicalPlan());
    }
}

private void checkNoCorrelatedSlotsUnderAgg(AnalyzedResult analyzedResult) {
if (analyzedResult.hasCorrelatedSlotsUnderAgg()) {
throw new AnalysisException(
Expand All @@ -197,22 +187,6 @@ private void checkNoCorrelatedSlotsUnderAgg(AnalyzedResult analyzedResult) {
}
}

/**
 * Fails analysis when the subquery reports correlated slots under a join operator.
 *
 * @param analyzedResult the analyzed subquery to check
 * @throws AnalysisException if {@code hasCorrelatedSlotsUnderJoin()} is true; the message lists
 *         the correlated slots
 */
private void checkNoCorrelatedSlotsUnderJoin(AnalyzedResult analyzedResult) {
    if (!analyzedResult.hasCorrelatedSlotsUnderJoin()) {
        return;
    }
    String message = String.format("Unsupported accesss outer join's column under join operator : %s",
            analyzedResult.getCorrelatedSlots());
    throw new AnalysisException(message);
}

/**
 * Fails analysis when the subquery reports correlated slots inside an aggregation operator.
 *
 * @param analyzedResult the analyzed subquery to check
 * @throws AnalysisException if {@code hasCorrelatedSlotsInAgg()} is true; the message lists
 *         the correlated slots
 */
private void checkNoCorrelatedSlotsInAgg(AnalyzedResult analyzedResult) {
    if (!analyzedResult.hasCorrelatedSlotsInAgg()) {
        return;
    }
    String message = String.format(
            "outer query's column is not supported in subquery's aggregation operator : %s",
            analyzedResult.getCorrelatedSlots());
    throw new AnalysisException(message);
}

private void checkRootIsLimit(AnalyzedResult analyzedResult) {
if (!analyzedResult.isCorrelated()) {
return;
Expand Down Expand Up @@ -271,61 +245,14 @@ public boolean isCorrelated() {
return !correlatedSlots.isEmpty();
}

/**
 * Tells whether the subquery plan contains an aggregate.
 *
 * @return true if {@code anyMatch} finds at least one {@link LogicalAggregate} instance
 */
public boolean hasAgg() {
    return logicalPlan.anyMatch(node -> node instanceof LogicalAggregate);
}

/**
 * Tells whether the first collected aggregate of the subquery plan has group by expressions.
 *
 * @return false when the plan has no aggregate; otherwise whether the first
 *         {@link LogicalAggregate} returned by {@code collect} has a non-empty group by list
 */
public boolean hasGroupBy() {
    if (!hasAgg()) {
        return false;
    }
    // only the first collected aggregate is inspected, exactly as before
    ImmutableSet aggregates = (ImmutableSet) logicalPlan.collect(LogicalAggregate.class::isInstance);
    LogicalAggregate firstAggregate = (LogicalAggregate) aggregates.asList().get(0);
    return !firstAggregate.getGroupByExpressions().isEmpty();
}

public boolean hasCorrelatedSlotsUnderAgg() {
return correlatedSlots.isEmpty() ? false
: findCorrelatedSlotsUnderNode(logicalPlan,
: hasCorrelatedSlotsUnderNode(logicalPlan,
ImmutableSet.copyOf(correlatedSlots), LogicalAggregate.class);
}

/**
 * Tells whether correlated slots are found under a {@link LogicalJoin} of the subquery plan.
 *
 * @return false when there are no correlated slots; otherwise the result of
 *         {@code findCorrelatedSlotsUnderNode} for {@link LogicalJoin}
 */
public boolean hasCorrelatedSlotsUnderJoin() {
    if (correlatedSlots.isEmpty()) {
        return false;
    }
    return findCorrelatedSlotsUnderNode(logicalPlan,
            ImmutableSet.copyOf(correlatedSlots), LogicalJoin.class);
}

/**
 * Tells whether a {@link LogicalAggregate} of the subquery plan uses correlated slots.
 *
 * @return false when there are no correlated slots; otherwise the result of
 *         {@code findCorrelatedSlotsInNode} for {@link LogicalAggregate}
 */
public boolean hasCorrelatedSlotsInAgg() {
    if (correlatedSlots.isEmpty()) {
        return false;
    }
    return findCorrelatedSlotsInNode(logicalPlan, ImmutableSet.copyOf(correlatedSlots),
            LogicalAggregate.class);
}

/**
 * Breadth-first search for a node whose class is exactly {@code clazz} and whose own
 * expressions read at least one of {@code slots}.
 *
 * <p>Children of a node whose class equals {@code clazz} are never enqueued, so the search
 * does not look below a matching node.
 *
 * @param rootPlan the plan to start from
 * @param slots the slots to look for
 * @param clazz the exact plan class to inspect (compared with {@code equals}, not instanceof)
 * @return true if such a node is found, false otherwise
 */
private static <T> boolean findCorrelatedSlotsInNode(Plan rootPlan,
        ImmutableSet<Slot> slots, Class<T> clazz) {
    ArrayDeque<Plan> pending = new ArrayDeque<>();
    pending.add(rootPlan);
    Plan current;
    // ArrayDeque never stores null, so a null poll() result means the queue is drained
    while ((current = pending.poll()) != null) {
        if (!current.getClass().equals(clazz)) {
            pending.addAll(current.children());
            continue;
        }
        boolean usesGivenSlots = !Sets
                .intersection(slots, ExpressionUtils.getInputSlotSet(current.getExpressions()))
                .isEmpty();
        if (usesGivenSlots) {
            return true;
        }
    }
    return false;
}

private static <T> boolean findCorrelatedSlotsUnderNode(Plan rootPlan,
ImmutableSet<Slot> slots, Class<T> clazz) {
private static <T> boolean hasCorrelatedSlotsUnderNode(Plan rootPlan,
ImmutableSet<Slot> slots, Class<T> clazz) {
ArrayDeque<Plan> planQueue = new ArrayDeque<>();
planQueue.add(rootPlan);
while (!planQueue.isEmpty()) {
Expand Down Expand Up @@ -355,4 +282,171 @@ public boolean rootIsLimitZero() {
return logicalPlan instanceof LogicalLimit && ((LogicalLimit<?>) logicalPlan).getLimit() == 0;
}
}

/**
 * Immutable per-node summary: the node's plan type, whether the node contains correlated slots,
 * and, when built from an aggregate, that aggregate and whether it has group by expressions.
 */
private static class PlanNodeCorrelatedInfo {
    private final PlanType planType;
    private final boolean containCorrelatedSlots;
    private final boolean hasGroupBy;
    // null when this info was not built from an aggregate node
    private final LogicalAggregate aggregate;

    /** Creates an info without an aggregate; {@code hasGroupBy} is then false. */
    public PlanNodeCorrelatedInfo(PlanType planType, boolean containCorrelatedSlots) {
        this(planType, containCorrelatedSlots, null);
    }

    /**
     * Creates an info, deriving {@code hasGroupBy} from the given aggregate.
     *
     * @param planType the plan type to record for the node
     * @param containCorrelatedSlots whether the node contains correlated slots
     * @param aggregate the aggregate node, or null
     */
    public PlanNodeCorrelatedInfo(PlanType planType, boolean containCorrelatedSlots,
            LogicalAggregate aggregate) {
        this.planType = planType;
        this.containCorrelatedSlots = containCorrelatedSlots;
        this.aggregate = aggregate;
        this.hasGroupBy = aggregate != null && !aggregate.getGroupByExpressions().isEmpty();
    }
}

/**
 * Visitor that inspects a single plan node (it does not descend into children itself) and
 * produces a {@link PlanNodeCorrelatedInfo} for it.
 *
 * <p>Project, aggregate, join and sort nodes must not reference the outer query's correlated
 * slots in their own expressions; an {@link AnalysisException} is thrown if they do. Any node
 * handled by {@link #visit(Plan, Void)} is only recorded, together with whether it references
 * correlated slots.
 */
private static class CorrelatedSlotsValidator
        extends PlanVisitor<PlanNodeCorrelatedInfo, Void> {
    private final ImmutableSet<Slot> correlatedSlots;

    public CorrelatedSlotsValidator(ImmutableSet<Slot> correlatedSlots) {
        this.correlatedSlots = correlatedSlots;
    }

    /** Records the node's type and whether its expressions use correlated slots. */
    @Override
    public PlanNodeCorrelatedInfo visit(Plan plan, Void context) {
        return new PlanNodeCorrelatedInfo(plan.getType(), findCorrelatedSlots(plan));
    }

    /**
     * Rejects correlated slots in a project. A project that contains a window expression is
     * reported with type {@code LOGICAL_WINDOW} instead of its own type.
     */
    public PlanNodeCorrelatedInfo visitLogicalProject(LogicalProject plan, Void context) {
        rejectCorrelatedSlots(plan, "access outer query's column in project is not supported");
        PlanType planType = ExpressionUtils.containsWindowExpression(
                ((LogicalProject<?>) plan).getProjects()) ? PlanType.LOGICAL_WINDOW : plan.getType();
        return new PlanNodeCorrelatedInfo(planType, false);
    }

    /** Rejects correlated slots in an aggregate and keeps the aggregate node in the result. */
    public PlanNodeCorrelatedInfo visitLogicalAggregate(LogicalAggregate plan, Void context) {
        rejectCorrelatedSlots(plan, "access outer query's column in aggregate is not supported");
        return new PlanNodeCorrelatedInfo(plan.getType(), false, plan);
    }

    /** Rejects correlated slots in a join. */
    public PlanNodeCorrelatedInfo visitLogicalJoin(LogicalJoin plan, Void context) {
        rejectCorrelatedSlots(plan, "access outer query's column in join is not supported");
        return new PlanNodeCorrelatedInfo(plan.getType(), false);
    }

    /** Rejects correlated slots in a sort (order by). */
    public PlanNodeCorrelatedInfo visitLogicalSort(LogicalSort plan, Void context) {
        rejectCorrelatedSlots(plan, "access outer query's column in order by is not supported");
        return new PlanNodeCorrelatedInfo(plan.getType(), false);
    }

    /**
     * Throws if the node's own expressions use any correlated slot.
     *
     * <p>The previous code passed the slots and plan to {@code String.format} with a format
     * string that had no specifier, so they never reached the message. The original text is
     * kept as the message prefix and the details are appended after it.
     *
     * @param plan the node to check
     * @param reason the original error text, used as the message prefix
     * @throws AnalysisException if a correlated slot is used by the node
     */
    private void rejectCorrelatedSlots(Plan plan, String reason) {
        if (findCorrelatedSlots(plan)) {
            throw new AnalysisException(
                    String.format("%s, correlated slots: %s, plan: %s", reason, correlatedSlots, plan));
        }
    }

    /** Returns true if any expression of the node reads at least one correlated slot. */
    private boolean findCorrelatedSlots(Plan plan) {
        return plan.getExpressions().stream().anyMatch(expression -> !Sets
                .intersection(correlatedSlots, expression.getInputSlots()).isEmpty());
    }
}

/**
 * Validates one chain of node infos, scanning from the last element back to the first.
 *
 * <p>Once a node containing correlated slots has been seen, every element that comes earlier
 * in the list is checked against the set of node types allowed there. At most one
 * element of the list may contain correlated slots.
 *
 * @param nodeInfoList node infos of one chain; an empty list is accepted
 * @return the aggregate (without group by) found earlier in the list than a correlated node,
 *         or null if there is none
 * @throws AnalysisException if an unsupported node type is found or more than one node
 *         contains correlated slots
 */
private LogicalAggregate validateNodeInfoList(List<PlanNodeCorrelatedInfo> nodeInfoList) {
    LogicalAggregate aggregateFound = null;
    int correlatedNodeCount = 0;
    boolean correlatedNodeSeen = false;
    boolean aggregateSeen = false;
    for (int idx = nodeInfoList.size() - 1; idx >= 0; idx--) {
        PlanNodeCorrelatedInfo info = nodeInfoList.get(idx);
        if (correlatedNodeSeen) {
            switch (info.planType) {
                case LOGICAL_LIMIT:
                    throw new AnalysisException(
                            "limit is not supported in correlated subquery");
                case LOGICAL_GENERATE:
                    throw new AnalysisException(
                            "access outer query's column before lateral view is not supported");
                case LOGICAL_WINDOW:
                    throw new AnalysisException(
                            "access outer query's column before window function is not supported");
                case LOGICAL_JOIN:
                    throw new AnalysisException(
                            "access outer query's column before join is not supported");
                case LOGICAL_AGGREGATE:
                    if (aggregateSeen) {
                        throw new AnalysisException(
                                "access outer query's column before two agg nodes is not supported");
                    }
                    if (info.hasGroupBy) {
                        // TODO support later
                        throw new AnalysisException(
                                "access outer query's column before agg with group by is not supported");
                    }
                    aggregateSeen = true;
                    aggregateFound = info.aggregate;
                    break;
                case LOGICAL_SORT:
                    // allow any sort node, the sort node will be removed by ELIMINATE_ORDER_BY_UNDER_SUBQUERY
                case LOGICAL_PROJECT:
                    // allow any project node
                case LOGICAL_SUBQUERY_ALIAS:
                    // allow any subquery alias
                    break;
                default:
                    if (aggregateSeen) {
                        throw new AnalysisException(
                                "only project, sort and subquery alias node is allowed after agg node");
                    }
                    break;
            }
        }
        // the flag is raised after the check, so the correlated node itself is not type-checked
        if (info.containCorrelatedSlots) {
            correlatedNodeCount++;
            correlatedNodeSeen = true;
        }
    }

    // only support 1 correlated node for now
    if (correlatedNodeCount > 1) {
        throw new AnalysisException(
                "access outer query's column in two places is not supported");
    }
    return aggregateFound;
}

/**
 * Depth-first walk over the plan that keeps, in {@code nodeInfoList}, the node infos of the
 * path from the starting node down to the node currently visited. Each time a leaf is reached
 * the whole path is validated, and any aggregate returned by that validation is collected.
 *
 * @param plan the node to visit
 * @param validator visitor producing the info of each node
 * @param nodeInfoList the current path; restored to its incoming content before returning
 * @param topAgg receives the aggregates returned by {@code validateNodeInfoList}
 */
private void validateSubquery(Plan plan, CorrelatedSlotsValidator validator,
        List<PlanNodeCorrelatedInfo> nodeInfoList, Set<LogicalAggregate> topAgg) {
    PlanNodeCorrelatedInfo currentInfo = plan.accept(validator, null);
    nodeInfoList.add(currentInfo);
    List<Plan> children = plan.children();
    if (children.isEmpty()) {
        // leaf reached: nodeInfoList now holds one complete path
        LogicalAggregate aggOnPath = validateNodeInfoList(nodeInfoList);
        if (aggOnPath != null) {
            topAgg.add(aggOnPath);
        }
    } else {
        for (Plan child : children) {
            validateSubquery(child, validator, nodeInfoList, topAgg);
        }
    }
    // pop this node so siblings see only their own ancestors
    nodeInfoList.remove(nodeInfoList.size() - 1);
}
}
Loading

0 comments on commit 03f5155

Please sign in to comment.