From 0f8176dee0a15893323b58b9e631dd9f45f1b6fb Mon Sep 17 00:00:00 2001 From: starocean999 <40539150+starocean999@users.noreply.github.com> Date: Thu, 12 Sep 2024 14:08:50 +0800 Subject: [PATCH] [fix](nereids) build agg for random distributed agg table in bindRelation phase (#40181) (#40702) pick from master #40181 --- .../doris/nereids/jobs/executor/Analyzer.java | 3 - .../apache/doris/nereids/rules/RuleType.java | 4 - .../nereids/rules/analysis/BindRelation.java | 151 +++++++++- .../BuildAggForRandomDistributedTable.java | 271 ------------------ .../nereids/rules/analysis/CheckPolicy.java | 21 +- .../rules/analysis/BindRelationTest.java | 31 +- .../rules/analysis/CheckRowPolicyTest.java | 97 +++++++ .../select_random_distributed_tbl.out | 14 +- .../select_random_distributed_tbl.groovy | 19 +- 9 files changed, 302 insertions(+), 309 deletions(-) delete mode 100644 fe/fe-core/src/main/java/org/apache/doris/nereids/rules/analysis/BuildAggForRandomDistributedTable.java diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/executor/Analyzer.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/executor/Analyzer.java index 605a848181c16f..1ffbac97d741a4 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/executor/Analyzer.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/executor/Analyzer.java @@ -26,7 +26,6 @@ import org.apache.doris.nereids.rules.analysis.BindRelation; import org.apache.doris.nereids.rules.analysis.BindRelation.CustomTableResolver; import org.apache.doris.nereids.rules.analysis.BindSink; -import org.apache.doris.nereids.rules.analysis.BuildAggForRandomDistributedTable; import org.apache.doris.nereids.rules.analysis.CheckAfterBind; import org.apache.doris.nereids.rules.analysis.CheckAnalysis; import org.apache.doris.nereids.rules.analysis.CheckPolicy; @@ -163,8 +162,6 @@ private static List buildAnalyzerJobs(Optional topDown(new EliminateGroupByConstant()), topDown(new SimplifyAggGroupBy()), - // run BuildAggForRandomDistributedTable before NormalizeAggregate in order to optimize the agg plan - topDown(new BuildAggForRandomDistributedTable()), topDown(new NormalizeAggregate()), topDown(new HavingToFilter()), bottomUp(new SemiJoinCommute()), diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/RuleType.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/RuleType.java index db0f0703dcb9a0..082ee72fbed220 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/RuleType.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/RuleType.java @@ -336,10 +336,6 @@ public enum RuleType { // topn opts DEFER_MATERIALIZE_TOP_N_RESULT(RuleTypeClass.REWRITE), - // pre agg for random distributed table - BUILD_AGG_FOR_RANDOM_DISTRIBUTED_TABLE_PROJECT_SCAN(RuleTypeClass.REWRITE), - BUILD_AGG_FOR_RANDOM_DISTRIBUTED_TABLE_FILTER_SCAN(RuleTypeClass.REWRITE), - BUILD_AGG_FOR_RANDOM_DISTRIBUTED_TABLE_AGG_SCAN(RuleTypeClass.REWRITE), // short circuit rule SHOR_CIRCUIT_POINT_QUERY(RuleTypeClass.REWRITE), // exploration rules diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/analysis/BindRelation.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/analysis/BindRelation.java index e6f550305e3a9d..67000b3fee997f 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/analysis/BindRelation.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/analysis/BindRelation.java @@ -17,10 +17,17 @@ package org.apache.doris.nereids.rules.analysis; +import org.apache.doris.catalog.AggStateType; +import org.apache.doris.catalog.AggregateType; import org.apache.doris.catalog.Column; +import org.apache.doris.catalog.DistributionInfo; +import org.apache.doris.catalog.Env; +import org.apache.doris.catalog.FunctionRegistry; +import org.apache.doris.catalog.KeysType; import org.apache.doris.catalog.OlapTable; import org.apache.doris.catalog.Partition; import org.apache.doris.catalog.TableIf; +import org.apache.doris.catalog.Type; import org.apache.doris.catalog.View; import org.apache.doris.common.Config; import org.apache.doris.common.Pair; @@ -44,13 +51,26 @@ import org.apache.doris.nereids.properties.PhysicalProperties; import org.apache.doris.nereids.rules.Rule; import org.apache.doris.nereids.rules.RuleType; +import org.apache.doris.nereids.trees.expressions.Alias; import org.apache.doris.nereids.trees.expressions.EqualTo; +import org.apache.doris.nereids.trees.expressions.ExprId; import org.apache.doris.nereids.trees.expressions.Expression; +import org.apache.doris.nereids.trees.expressions.NamedExpression; import org.apache.doris.nereids.trees.expressions.Slot; +import org.apache.doris.nereids.trees.expressions.SlotReference; +import org.apache.doris.nereids.trees.expressions.functions.AggCombinerFunctionBuilder; +import org.apache.doris.nereids.trees.expressions.functions.FunctionBuilder; +import org.apache.doris.nereids.trees.expressions.functions.agg.BitmapUnion; +import org.apache.doris.nereids.trees.expressions.functions.agg.HllUnion; +import org.apache.doris.nereids.trees.expressions.functions.agg.Max; +import org.apache.doris.nereids.trees.expressions.functions.agg.Min; +import org.apache.doris.nereids.trees.expressions.functions.agg.QuantileUnion; +import org.apache.doris.nereids.trees.expressions.functions.agg.Sum; import org.apache.doris.nereids.trees.expressions.literal.TinyIntLiteral; import org.apache.doris.nereids.trees.plans.Plan; import org.apache.doris.nereids.trees.plans.PreAggStatus; import org.apache.doris.nereids.trees.plans.algebra.Relation; +import org.apache.doris.nereids.trees.plans.logical.LogicalAggregate; import org.apache.doris.nereids.trees.plans.logical.LogicalCTEConsumer; import org.apache.doris.nereids.trees.plans.logical.LogicalEsScan; import org.apache.doris.nereids.trees.plans.logical.LogicalFileScan; @@ -74,6 +94,7 @@ import com.google.common.collect.Sets; import org.apache.commons.collections.CollectionUtils; +import java.util.ArrayList; import java.util.List; import java.util.Optional; import java.util.function.Function; @@ -215,25 +236,127 @@ private LogicalPlan makeOlapScan(TableIf table, UnboundRelation unboundRelation, unboundRelation.getTableSample()); } } - if (!Util.showHiddenColumns() && scan.getTable().hasDeleteSign() - && !ConnectContext.get().getSessionVariable().skipDeleteSign()) { - // table qualifier is catalog.db.table, we make db.table.column - Slot deleteSlot = null; - for (Slot slot : scan.getOutput()) { - if (slot.getName().equals(Column.DELETE_SIGN)) { - deleteSlot = slot; + if (needGenerateLogicalAggForRandomDistAggTable(scan)) { + // it's a random distribution agg table + // add agg on olap scan + return preAggForRandomDistribution(scan); + } else { + // it's a duplicate, unique or hash distribution agg table + // add delete sign filter on olap scan if needed + if (!Util.showHiddenColumns() && scan.getTable().hasDeleteSign() + && !ConnectContext.get().getSessionVariable().skipDeleteSign()) { + // table qualifier is catalog.db.table, we make db.table.column + Slot deleteSlot = null; + for (Slot slot : scan.getOutput()) { + if (slot.getName().equals(Column.DELETE_SIGN)) { + deleteSlot = slot; + break; + } + } + Preconditions.checkArgument(deleteSlot != null); + Expression conjunct = new EqualTo(new TinyIntLiteral((byte) 0), deleteSlot); + if (!((OlapTable) table).getEnableUniqueKeyMergeOnWrite()) { + scan = scan.withPreAggStatus( + PreAggStatus.off(Column.DELETE_SIGN + " is used as conjuncts.")); + } + return new LogicalFilter<>(Sets.newHashSet(conjunct), scan); + } + return scan; + } + } + + private boolean needGenerateLogicalAggForRandomDistAggTable(LogicalOlapScan olapScan) { + if (ConnectContext.get() != null && ConnectContext.get().getState() != null + && ConnectContext.get().getState().isQuery()) { + // we only need to add an agg node for query, and should not do it for deleting + // from random distributed table. see https://github.com/apache/doris/pull/37985 for more info + OlapTable olapTable = olapScan.getTable(); + KeysType keysType = olapTable.getKeysType(); + DistributionInfo distributionInfo = olapTable.getDefaultDistributionInfo(); + return keysType == KeysType.AGG_KEYS + && distributionInfo.getType() == DistributionInfo.DistributionInfoType.RANDOM; + } else { + return false; + } + } + + /** + * add LogicalAggregate above olapScan for preAgg + * @param olapScan olap scan plan + * @return rewritten plan + */ + private LogicalPlan preAggForRandomDistribution(LogicalOlapScan olapScan) { + OlapTable olapTable = olapScan.getTable(); + List childOutputSlots = olapScan.computeOutput(); + List groupByExpressions = new ArrayList<>(); + List outputExpressions = new ArrayList<>(); + List columns = olapTable.getBaseSchema(); + + for (Column col : columns) { + // use exist slot in the plan + SlotReference slot = SlotReference.fromColumn(olapTable, col, col.getName(), olapScan.qualified()); + ExprId exprId = slot.getExprId(); + for (Slot childSlot : childOutputSlots) { + if (childSlot instanceof SlotReference && ((SlotReference) childSlot).getName() == col.getName()) { + exprId = childSlot.getExprId(); + slot = slot.withExprId(exprId); break; } } - Preconditions.checkArgument(deleteSlot != null); - Expression conjunct = new EqualTo(new TinyIntLiteral((byte) 0), deleteSlot); - if (!((OlapTable) table).getEnableUniqueKeyMergeOnWrite()) { - scan = scan.withPreAggStatus(PreAggStatus.off( - Column.DELETE_SIGN + " is used as conjuncts.")); + if (col.isKey()) { + groupByExpressions.add(slot); + outputExpressions.add(slot); + } else { + Expression function = generateAggFunction(slot, col); + // DO NOT rewrite + if (function == null) { + return olapScan; + } + Alias alias = new Alias(exprId, ImmutableList.of(function), col.getName(), + olapScan.qualified(), true); + outputExpressions.add(alias); } - return new LogicalFilter<>(Sets.newHashSet(conjunct), scan); } - return scan; + LogicalAggregate aggregate = new LogicalAggregate<>(groupByExpressions, outputExpressions, + olapScan); + return aggregate; + } + + /** + * generate aggregation function according to the aggType of column + * + * @param slot slot of column + * @return aggFunction generated + */ + private Expression generateAggFunction(SlotReference slot, Column column) { + AggregateType aggregateType = column.getAggregationType(); + switch (aggregateType) { + case SUM: + return new Sum(slot); + case MAX: + return new Max(slot); + case MIN: + return new Min(slot); + case HLL_UNION: + return new HllUnion(slot); + case BITMAP_UNION: + return new BitmapUnion(slot); + case QUANTILE_UNION: + return new QuantileUnion(slot); + case GENERIC: + Type type = column.getType(); + if (!type.isAggStateType()) { + return null; + } + AggStateType aggState = (AggStateType) type; + // use AGGREGATE_FUNCTION_UNION to aggregate multiple agg_state into one + String funcName = aggState.getFunctionName() + AggCombinerFunctionBuilder.UNION_SUFFIX; + FunctionRegistry functionRegistry = Env.getCurrentEnv().getFunctionRegistry(); + FunctionBuilder builder = functionRegistry.findFunctionBuilder(funcName, slot); + return builder.build(funcName, ImmutableList.of(slot)).first; + default: + return null; + } } private LogicalPlan getLogicalPlan(TableIf table, UnboundRelation unboundRelation, diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/analysis/BuildAggForRandomDistributedTable.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/analysis/BuildAggForRandomDistributedTable.java deleted file mode 100644 index e547a55f9e39fe..00000000000000 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/analysis/BuildAggForRandomDistributedTable.java +++ /dev/null @@ -1,271 +0,0 @@ -// Licensed to the Apache Software Foundation (ASF) under one -// or more contributor license agreements. See the NOTICE file -// distributed with this work for additional information -// regarding copyright ownership. The ASF licenses this file -// to you under the Apache License, Version 2.0 (the -// "License"); you may not use this file except in compliance -// with the License. You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, -// software distributed under the License is distributed on an -// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -// KIND, either express or implied. See the License for the -// specific language governing permissions and limitations -// under the License. - -package org.apache.doris.nereids.rules.analysis; - -import org.apache.doris.catalog.AggStateType; -import org.apache.doris.catalog.AggregateType; -import org.apache.doris.catalog.Column; -import org.apache.doris.catalog.DistributionInfo; -import org.apache.doris.catalog.DistributionInfo.DistributionInfoType; -import org.apache.doris.catalog.Env; -import org.apache.doris.catalog.FunctionRegistry; -import org.apache.doris.catalog.KeysType; -import org.apache.doris.catalog.OlapTable; -import org.apache.doris.catalog.Type; -import org.apache.doris.nereids.rules.Rule; -import org.apache.doris.nereids.rules.RuleType; -import org.apache.doris.nereids.trees.expressions.Alias; -import org.apache.doris.nereids.trees.expressions.ExprId; -import org.apache.doris.nereids.trees.expressions.Expression; -import org.apache.doris.nereids.trees.expressions.NamedExpression; -import org.apache.doris.nereids.trees.expressions.Slot; -import org.apache.doris.nereids.trees.expressions.SlotReference; -import org.apache.doris.nereids.trees.expressions.functions.AggCombinerFunctionBuilder; -import org.apache.doris.nereids.trees.expressions.functions.FunctionBuilder; -import org.apache.doris.nereids.trees.expressions.functions.agg.AggregateFunction; -import org.apache.doris.nereids.trees.expressions.functions.agg.BitmapFunction; -import org.apache.doris.nereids.trees.expressions.functions.agg.BitmapUnion; -import org.apache.doris.nereids.trees.expressions.functions.agg.Count; -import org.apache.doris.nereids.trees.expressions.functions.agg.HllFunction; -import org.apache.doris.nereids.trees.expressions.functions.agg.HllUnion; -import org.apache.doris.nereids.trees.expressions.functions.agg.Max; -import org.apache.doris.nereids.trees.expressions.functions.agg.Min; -import org.apache.doris.nereids.trees.expressions.functions.agg.QuantileUnion; -import org.apache.doris.nereids.trees.expressions.functions.agg.Sum; -import org.apache.doris.nereids.trees.plans.Plan; -import org.apache.doris.nereids.trees.plans.logical.LogicalAggregate; -import org.apache.doris.nereids.trees.plans.logical.LogicalOlapScan; -import org.apache.doris.nereids.trees.plans.logical.LogicalPlan; -import org.apache.doris.qe.ConnectContext; - -import com.google.common.collect.ImmutableList; - -import java.util.ArrayList; -import java.util.List; -import java.util.Set; - -/** - * build agg plan for querying random distributed table - */ -public class BuildAggForRandomDistributedTable implements AnalysisRuleFactory { - - @Override - public List buildRules() { - return ImmutableList.of( - // Project(Scan) -> project(agg(scan)) - logicalProject(logicalOlapScan()) - .when(this::isQuery) - .when(project -> isRandomDistributedTbl(project.child())) - .then(project -> preAggForRandomDistribution(project, project.child())) - .toRule(RuleType.BUILD_AGG_FOR_RANDOM_DISTRIBUTED_TABLE_PROJECT_SCAN), - // agg(scan) -> agg(agg(scan)), agg(agg) may optimized by MergeAggregate - logicalAggregate(logicalOlapScan()) - .when(this::isQuery) - .when(agg -> isRandomDistributedTbl(agg.child())) - .whenNot(agg -> { - Set functions = agg.getAggregateFunctions(); - List groupByExprs = agg.getGroupByExpressions(); - // check if need generate an inner agg plan or not - // should not rewrite twice if we had rewritten olapScan to aggregate(olapScan) - return functions.stream().allMatch(this::aggTypeMatch) && groupByExprs.stream() - .allMatch(this::isKeyOrConstantExpr); - }) - .then(agg -> preAggForRandomDistribution(agg, agg.child())) - .toRule(RuleType.BUILD_AGG_FOR_RANDOM_DISTRIBUTED_TABLE_AGG_SCAN), - // filter(scan) -> filter(agg(scan)) - logicalFilter(logicalOlapScan()) - .when(this::isQuery) - .when(filter -> isRandomDistributedTbl(filter.child())) - .then(filter -> preAggForRandomDistribution(filter, filter.child())) - .toRule(RuleType.BUILD_AGG_FOR_RANDOM_DISTRIBUTED_TABLE_FILTER_SCAN)); - - } - - /** - * check the olapTable of olapScan is randomDistributed table - * - * @param olapScan olap scan plan - * @return true if olapTable is randomDistributed table - */ - private boolean isRandomDistributedTbl(LogicalOlapScan olapScan) { - OlapTable olapTable = olapScan.getTable(); - KeysType keysType = olapTable.getKeysType(); - DistributionInfo distributionInfo = olapTable.getDefaultDistributionInfo(); - return keysType == KeysType.AGG_KEYS && distributionInfo.getType() == DistributionInfoType.RANDOM; - } - - private boolean isQuery(LogicalPlan plan) { - return ConnectContext.get() != null - && ConnectContext.get().getState() != null - && ConnectContext.get().getState().isQuery(); - } - - /** - * add LogicalAggregate above olapScan for preAgg - * - * @param logicalPlan parent plan of olapScan - * @param olapScan olap scan plan, it may be LogicalProject, LogicalFilter, LogicalAggregate - * @return rewritten plan - */ - private Plan preAggForRandomDistribution(LogicalPlan logicalPlan, LogicalOlapScan olapScan) { - OlapTable olapTable = olapScan.getTable(); - List childOutputSlots = olapScan.computeOutput(); - List groupByExpressions = new ArrayList<>(); - List outputExpressions = new ArrayList<>(); - List columns = olapTable.getBaseSchema(); - - for (Column col : columns) { - // use exist slot in the plan - SlotReference slot = SlotReference.fromColumn(olapTable, col, col.getName(), olapScan.getQualifier()); - ExprId exprId = slot.getExprId(); - for (Slot childSlot : childOutputSlots) { - if (childSlot instanceof SlotReference && ((SlotReference) childSlot).getName() == col.getName()) { - exprId = childSlot.getExprId(); - slot = slot.withExprId(exprId); - break; - } - } - if (col.isKey()) { - groupByExpressions.add(slot); - outputExpressions.add(slot); - } else { - Expression function = generateAggFunction(slot, col); - // DO NOT rewrite - if (function == null) { - return logicalPlan; - } - Alias alias = new Alias(exprId, function, col.getName()); - outputExpressions.add(alias); - } - } - LogicalAggregate aggregate = new LogicalAggregate<>(groupByExpressions, outputExpressions, - olapScan); - return logicalPlan.withChildren(aggregate); - } - - /** - * generate aggregation function according to the aggType of column - * - * @param slot slot of column - * @return aggFunction generated - */ - private Expression generateAggFunction(SlotReference slot, Column column) { - AggregateType aggregateType = column.getAggregationType(); - switch (aggregateType) { - case SUM: - return new Sum(slot); - case MAX: - return new Max(slot); - case MIN: - return new Min(slot); - case HLL_UNION: - return new HllUnion(slot); - case BITMAP_UNION: - return new BitmapUnion(slot); - case QUANTILE_UNION: - return new QuantileUnion(slot); - case GENERIC: - Type type = column.getType(); - if (!type.isAggStateType()) { - return null; - } - AggStateType aggState = (AggStateType) type; - // use AGGREGATE_FUNCTION_UNION to aggregate multiple agg_state into one - String funcName = aggState.getFunctionName() + AggCombinerFunctionBuilder.UNION_SUFFIX; - FunctionRegistry functionRegistry = Env.getCurrentEnv().getFunctionRegistry(); - FunctionBuilder builder = functionRegistry.findFunctionBuilder(funcName, slot); - return builder.build(funcName, ImmutableList.of(slot)).first; - default: - return null; - } - } - - /** - * if the agg type of AggregateFunction is as same as the agg type of column, DO NOT need to rewrite - * - * @param function agg function to check - * @return true if agg type match - */ - private boolean aggTypeMatch(AggregateFunction function) { - List children = function.children(); - if (function.getName().equalsIgnoreCase("count")) { - Count count = (Count) function; - // do not rewrite for count distinct for key column - if (count.isDistinct()) { - return children.stream().allMatch(this::isKeyOrConstantExpr); - } - if (count.isStar()) { - return false; - } - } - return children.stream().allMatch(child -> aggTypeMatch(function, child)); - } - - /** - * check if the agg type of functionCall match the agg type of column - * - * @param function the functionCall - * @param expression expr to check - * @return true if agg type match - */ - private boolean aggTypeMatch(AggregateFunction function, Expression expression) { - if (expression.children().isEmpty()) { - if (expression instanceof SlotReference && ((SlotReference) expression).getColumn().isPresent()) { - Column col = ((SlotReference) expression).getColumn().get(); - String functionName = function.getName(); - if (col.isKey()) { - return functionName.equalsIgnoreCase("max") || functionName.equalsIgnoreCase("min"); - } - if (col.isAggregated()) { - AggregateType aggType = col.getAggregationType(); - // agg type not mach - if (aggType == AggregateType.GENERIC) { - return col.getType().isAggStateType(); - } - if (aggType == AggregateType.HLL_UNION) { - return function instanceof HllFunction; - } - if (aggType == AggregateType.BITMAP_UNION) { - return function instanceof BitmapFunction; - } - return functionName.equalsIgnoreCase(aggType.name()); - } - } - return false; - } - List children = expression.children(); - return children.stream().allMatch(child -> aggTypeMatch(function, child)); - } - - /** - * check if the columns in expr is key column or constant, if group by clause contains value column, need rewrite - * - * @param expr expr to check - * @return true if all columns is key column or constant - */ - private boolean isKeyOrConstantExpr(Expression expr) { - if (expr instanceof SlotReference && ((SlotReference) expr).getColumn().isPresent()) { - Column col = ((SlotReference) expr).getColumn().get(); - return col.isKey(); - } else if (expr.isConstant()) { - return true; - } - List children = expr.children(); - return children.stream().allMatch(this::isKeyOrConstantExpr); - } -} diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/analysis/CheckPolicy.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/analysis/CheckPolicy.java index 94f7c36b10827d..4beed413d0908e 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/analysis/CheckPolicy.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/analysis/CheckPolicy.java @@ -23,6 +23,7 @@ import org.apache.doris.nereids.rules.RuleType; import org.apache.doris.nereids.trees.expressions.Expression; import org.apache.doris.nereids.trees.plans.Plan; +import org.apache.doris.nereids.trees.plans.logical.LogicalAggregate; import org.apache.doris.nereids.trees.plans.logical.LogicalCheckPolicy; import org.apache.doris.nereids.trees.plans.logical.LogicalCheckPolicy.RelatedPolicy; import org.apache.doris.nereids.trees.plans.logical.LogicalFilter; @@ -49,12 +50,23 @@ public List buildRules() { logicalCheckPolicy(any().when(child -> !(child instanceof UnboundRelation))).thenApply(ctx -> { LogicalCheckPolicy checkPolicy = ctx.root; LogicalFilter upperFilter = null; + Plan upAgg = null; Plan child = checkPolicy.child(); // Because the unique table will automatically include a filter condition - if (child instanceof LogicalFilter && child.bound() && child - .child(0) instanceof LogicalRelation) { + if ((child instanceof LogicalFilter) && child.bound()) { upperFilter = (LogicalFilter) child; + if (child.child(0) instanceof LogicalRelation) { + child = child.child(0); + } else if (child.child(0) instanceof LogicalAggregate + && child.child(0).child(0) instanceof LogicalRelation) { + upAgg = child.child(0); + child = child.child(0).child(0); + } + } + if ((child instanceof LogicalAggregate) + && child.bound() && child.child(0) instanceof LogicalRelation) { + upAgg = child; child = child.child(0); } if (!(child instanceof LogicalRelation) @@ -76,16 +88,17 @@ public List buildRules() { RelatedPolicy relatedPolicy = checkPolicy.findPolicy(relation, ctx.cascadesContext); relatedPolicy.rowPolicyFilter.ifPresent(expression -> combineFilter.addAll( ExpressionUtils.extractConjunctionToSet(expression))); - Plan result = relation; + Plan result = upAgg != null ? upAgg.withChildren(relation) : relation; if (upperFilter != null) { combineFilter.addAll(upperFilter.getConjuncts()); } if (!combineFilter.isEmpty()) { - result = new LogicalFilter<>(combineFilter, relation); + result = new LogicalFilter<>(combineFilter, result); } if (relatedPolicy.dataMaskProjects.isPresent()) { result = new LogicalProject<>(relatedPolicy.dataMaskProjects.get(), result); } + return result; }) ) diff --git a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/analysis/BindRelationTest.java b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/analysis/BindRelationTest.java index 0926d8024cec59..67115e676871b4 100644 --- a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/analysis/BindRelationTest.java +++ b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/analysis/BindRelationTest.java @@ -29,6 +29,7 @@ import org.apache.doris.nereids.rules.analysis.BindRelation.CustomTableResolver; import org.apache.doris.nereids.trees.expressions.StatementScopeIdGenerator; import org.apache.doris.nereids.trees.plans.Plan; +import org.apache.doris.nereids.trees.plans.logical.LogicalAggregate; import org.apache.doris.nereids.trees.plans.logical.LogicalOlapScan; import org.apache.doris.nereids.util.PlanChecker; import org.apache.doris.nereids.util.PlanRewriter; @@ -54,6 +55,12 @@ protected void runBeforeAll() throws Exception { + ")ENGINE=OLAP\n" + "DISTRIBUTED BY HASH(`a`) BUCKETS 3\n" + "PROPERTIES (\"replication_num\"= \"1\");"); + createTable("CREATE TABLE db1.tagg ( \n" + + " \ta INT,\n" + + " \tb INT SUM\n" + + ")ENGINE=OLAP AGGREGATE KEY(a)\n " + + "DISTRIBUTED BY random BUCKETS 3\n" + + "PROPERTIES (\"replication_num\"= \"1\");"); connectContext.getSessionVariable().setDisableNereidsRules("PRUNE_EMPTY_PARTITION"); } @@ -117,14 +124,30 @@ public boolean hasDeleteSign() { .customAnalyzer(Optional.of(customTableResolver)) // analyze internal relation .matches( logicalJoin( - logicalSubQueryAlias( - logicalOlapScan().when(r -> r.getTable() == externalOlapTable) - ), - logicalOlapScan().when(r -> r.getTable().getName().equals("t")) + logicalSubQueryAlias( + logicalOlapScan().when(r -> r.getTable() == externalOlapTable) + ), + logicalOlapScan().when(r -> r.getTable().getName().equals("t")) ) ); } + @Test + void bindRandomAggTable() { + connectContext.setDatabase(DEFAULT_CLUSTER_PREFIX + DB1); + connectContext.getState().setIsQuery(true); + Plan plan = PlanRewriter.bottomUpRewrite(new UnboundRelation(StatementScopeIdGenerator.newRelationId(), ImmutableList.of("tagg")), + connectContext, new BindRelation()); + + Assertions.assertTrue(plan instanceof LogicalAggregate); + Assertions.assertEquals( + ImmutableList.of("internal", DEFAULT_CLUSTER_PREFIX + DB1, "tagg"), + plan.getOutput().get(0).getQualifier()); + Assertions.assertEquals( + ImmutableList.of("internal", DEFAULT_CLUSTER_PREFIX + DB1, "tagg"), + plan.getOutput().get(1).getQualifier()); + } + @Override public RulePromise defaultPromise() { return RulePromise.REWRITE; diff --git a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/analysis/CheckRowPolicyTest.java b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/analysis/CheckRowPolicyTest.java index 196d99037e2e63..b807bbbbc7a4bd 100644 --- a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/analysis/CheckRowPolicyTest.java +++ b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/analysis/CheckRowPolicyTest.java @@ -34,6 +34,9 @@ import org.apache.doris.catalog.PartitionInfo; import org.apache.doris.catalog.Type; import org.apache.doris.common.FeConstants; +import org.apache.doris.mysql.privilege.AccessControllerManager; +import org.apache.doris.mysql.privilege.DataMaskPolicy; +import org.apache.doris.nereids.analyzer.UnboundRelation; import org.apache.doris.nereids.exceptions.AnalysisException; import org.apache.doris.nereids.trees.expressions.EqualTo; import org.apache.doris.nereids.trees.expressions.StatementScopeIdGenerator; @@ -41,6 +44,7 @@ import org.apache.doris.nereids.trees.plans.logical.LogicalCheckPolicy; import org.apache.doris.nereids.trees.plans.logical.LogicalFilter; import org.apache.doris.nereids.trees.plans.logical.LogicalOlapScan; +import org.apache.doris.nereids.trees.plans.logical.LogicalProject; import org.apache.doris.nereids.trees.plans.logical.LogicalRelation; import org.apache.doris.nereids.util.PlanRewriter; import org.apache.doris.thrift.TStorageType; @@ -48,17 +52,22 @@ import com.google.common.collect.ImmutableList; import com.google.common.collect.Lists; +import mockit.Mock; +import mockit.MockUp; import org.junit.jupiter.api.Assertions; import org.junit.jupiter.api.Test; import java.util.Arrays; import java.util.List; +import java.util.Optional; public class CheckRowPolicyTest extends TestWithFeService { private static String dbName = "check_row_policy"; private static String fullDbName = "" + dbName; private static String tableName = "table1"; + + private static String tableNameRanddomDist = "tableRandomDist"; private static String userName = "user1"; private static String policyName = "policy1"; @@ -76,6 +85,10 @@ protected void runBeforeAll() throws Exception { + tableName + " (k1 int, k2 int) distributed by hash(k1) buckets 1" + " properties(\"replication_num\" = \"1\");"); + createTable("create table " + + tableNameRanddomDist + + " (k1 int, k2 int) AGGREGATE KEY(k1, k2) distributed by random buckets 1" + + " properties(\"replication_num\" = \"1\");"); Database db = Env.getCurrentInternalCatalog().getDbOrMetaException(fullDbName); long tableId = db.getTableOrMetaException("table1").getId(); olapTable.setId(tableId); @@ -85,6 +98,7 @@ protected void runBeforeAll() throws Exception { 0, 0, (short) 0, TStorageType.COLUMN, KeysType.PRIMARY_KEYS); + // create user UserIdentity user = new UserIdentity(userName, "%"); user.analyze(); @@ -98,6 +112,27 @@ protected void runBeforeAll() throws Exception { Analyzer analyzer = new Analyzer(connectContext.getEnv(), connectContext); grantStmt.analyze(analyzer); Env.getCurrentEnv().getAuth().grant(grantStmt); + + new MockUp() { + @Mock + public Optional evalDataMaskPolicy(UserIdentity currentUser, String ctl, + String db, String tbl, String col) { + return tbl.equalsIgnoreCase(tableNameRanddomDist) + ? Optional.of(new DataMaskPolicy() { + @Override + public String getMaskTypeDef() { + return String.format("concat(%s, '_****_', %s)", col, col); + } + + @Override + public String getPolicyIdent() { + return String.format("custom policy: concat(%s, '_****_', %s)", col, + col); + } + }) + : Optional.empty(); + } + }; } @Test @@ -115,6 +150,24 @@ public void checkUser() throws AnalysisException, org.apache.doris.common.Analys Assertions.assertEquals(plan, relation); } + @Test + public void checkUserRandomDist() throws AnalysisException, org.apache.doris.common.AnalysisException { + connectContext.getState().setIsQuery(true); + Plan plan = PlanRewriter.bottomUpRewrite(new UnboundRelation(StatementScopeIdGenerator.newRelationId(), + ImmutableList.of(tableNameRanddomDist)), connectContext, new BindRelation()); + LogicalCheckPolicy checkPolicy = new LogicalCheckPolicy(plan); + + useUser("root"); + Plan rewrittenPlan = PlanRewriter.bottomUpRewrite(checkPolicy, connectContext, new CheckPolicy(), + new BindExpression()); + Assertions.assertEquals(plan, rewrittenPlan); + + useUser("notFound"); + rewrittenPlan = PlanRewriter.bottomUpRewrite(checkPolicy, connectContext, new CheckPolicy(), + new BindExpression()); + Assertions.assertEquals(plan, rewrittenPlan.child(0)); + } + @Test public void checkNoPolicy() throws org.apache.doris.common.AnalysisException { useUser(userName); @@ -125,6 +178,18 @@ public void checkNoPolicy() throws org.apache.doris.common.AnalysisException { Assertions.assertEquals(plan, relation); } + @Test + public void checkNoPolicyRandomDist() throws org.apache.doris.common.AnalysisException { + useUser(userName); + connectContext.getState().setIsQuery(true); + Plan plan = PlanRewriter.bottomUpRewrite(new UnboundRelation(StatementScopeIdGenerator.newRelationId(), + ImmutableList.of(tableNameRanddomDist)), connectContext, new BindRelation()); + LogicalCheckPolicy checkPolicy = new LogicalCheckPolicy(plan); + Plan rewrittenPlan = PlanRewriter.bottomUpRewrite(checkPolicy, connectContext, new CheckPolicy(), + new BindExpression()); + Assertions.assertEquals(plan, rewrittenPlan.child(0)); + } + @Test public void checkOnePolicy() throws Exception { useUser(userName); @@ -152,4 +217,36 @@ public void checkOnePolicy() throws Exception { + " ON " + tableName); } + + @Test + public void checkOnePolicyRandomDist() throws Exception { + useUser(userName); + connectContext.getState().setIsQuery(true); + Plan plan = PlanRewriter.bottomUpRewrite(new UnboundRelation(StatementScopeIdGenerator.newRelationId(), + ImmutableList.of(tableNameRanddomDist)), connectContext, new BindRelation()); + + LogicalCheckPolicy checkPolicy = new LogicalCheckPolicy(plan); + connectContext.getSessionVariable().setEnableNereidsPlanner(true); + createPolicy("CREATE ROW POLICY " + + policyName + + " ON " + + tableNameRanddomDist + + " AS PERMISSIVE TO " + + userName + + " USING (k1 = 1)"); + Plan rewrittenPlan = PlanRewriter.bottomUpRewrite(checkPolicy, connectContext, new CheckPolicy(), + new BindExpression()); + + Assertions.assertTrue(rewrittenPlan instanceof LogicalProject + && rewrittenPlan.child(0) instanceof LogicalFilter); + LogicalFilter filter = (LogicalFilter) rewrittenPlan.child(0); + Assertions.assertEquals(filter.child(), plan); + Assertions.assertTrue(ImmutableList.copyOf(filter.getConjuncts()).get(0) instanceof EqualTo); + Assertions.assertTrue(filter.getConjuncts().toString().contains("k1#0 = 1")); + + dropPolicy("DROP ROW POLICY " + + policyName + + " ON " + + tableNameRanddomDist); + } } diff --git a/regression-test/data/query_p0/aggregate/select_random_distributed_tbl.out b/regression-test/data/query_p0/aggregate/select_random_distributed_tbl.out index c03e72c8f9e398..eb099225960a2b 100644 --- a/regression-test/data/query_p0/aggregate/select_random_distributed_tbl.out +++ b/regression-test/data/query_p0/aggregate/select_random_distributed_tbl.out @@ -217,13 +217,25 @@ -- !sql_17 -- 1 +3 -- !sql_18 -- 1 +3 -- !sql_19 -- -1 +999999999999999.99 +1999999999999999.98 -- !sql_20 -- 1 +3 + +-- !sql_21 -- +1 +3 + +-- !sql_22 -- +999999999999999.99 +1999999999999999.98 diff --git a/regression-test/suites/query_p0/aggregate/select_random_distributed_tbl.groovy b/regression-test/suites/query_p0/aggregate/select_random_distributed_tbl.groovy index c818454c261262..5c99a0a4aa02de 100644 --- a/regression-test/suites/query_p0/aggregate/select_random_distributed_tbl.groovy +++ b/regression-test/suites/query_p0/aggregate/select_random_distributed_tbl.groovy @@ -123,7 +123,8 @@ suite("select_random_distributed_tbl") { // test all keys are NOT NULL for AGG table sql "drop table if exists random_distributed_tbl_test_2;" sql """ CREATE TABLE random_distributed_tbl_test_2 ( - `k1` LARGEINT NOT NULL + `k1` LARGEINT NOT NULL, + `k2` DECIMAL(18, 2) SUM NOT NULL ) ENGINE=OLAP AGGREGATE KEY(`k1`) COMMENT 'OLAP' @@ -133,17 +134,19 @@ suite("select_random_distributed_tbl") { ); """ - sql """ insert into random_distributed_tbl_test_2 values(1); """ - sql """ insert into random_distributed_tbl_test_2 values(1); """ - sql """ insert into random_distributed_tbl_test_2 values(1); """ + sql """ insert into random_distributed_tbl_test_2 values(1, 999999999999999.99); """ + sql """ insert into random_distributed_tbl_test_2 values(1, 999999999999999.99); """ + sql """ insert into random_distributed_tbl_test_2 values(3, 999999999999999.99); """ sql "set enable_nereids_planner = false;" - qt_sql_17 "select k1 from random_distributed_tbl_test_2;" - qt_sql_18 "select distinct k1 from random_distributed_tbl_test_2;" + qt_sql_17 "select k1 from random_distributed_tbl_test_2 order by k1;" + qt_sql_18 "select distinct k1 from random_distributed_tbl_test_2 order by k1;" + qt_sql_19 "select k2 from random_distributed_tbl_test_2 order by k2;" sql "set enable_nereids_planner = true;" - qt_sql_19 "select k1 from random_distributed_tbl_test_2;" - qt_sql_20 "select distinct k1 from random_distributed_tbl_test_2;" + qt_sql_20 "select k1 from random_distributed_tbl_test_2 order by k1;" + qt_sql_21 "select distinct k1 from random_distributed_tbl_test_2 order by k1;" + qt_sql_22 "select k2 from random_distributed_tbl_test_2 order by k2;" sql "drop table random_distributed_tbl_test_2;" }