diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/CascadesContext.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/CascadesContext.java index 19d38873f9a014..ebef71feb850b0 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/CascadesContext.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/CascadesContext.java @@ -199,7 +199,7 @@ public synchronized boolean isTimeout() { } public void toMemo() { - this.memo = new Memo(plan); + this.memo = new Memo(getConnectContext(), plan); } public Analyzer newAnalyzer() { @@ -358,7 +358,7 @@ public TableIf getTableInMinidumpCache(String tableName) { return table; } } - if (ConnectContext.get().getSessionVariable().isPlayNereidsDump()) { + if (getConnectContext().getSessionVariable().isPlayNereidsDump()) { throw new AnalysisException("Minidump cache can not find table:" + tableName); } return null; diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/PlanContext.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/PlanContext.java index 07a2cf4eed2e7a..f560dabf8b34b5 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/PlanContext.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/PlanContext.java @@ -19,6 +19,8 @@ import org.apache.doris.nereids.memo.GroupExpression; import org.apache.doris.nereids.trees.plans.Plan; +import org.apache.doris.qe.ConnectContext; +import org.apache.doris.qe.SessionVariable; import org.apache.doris.statistics.Statistics; import java.util.ArrayList; @@ -31,7 +33,7 @@ * Inspired by GPORCA-CExpressionHandle. */ public class PlanContext { - + private final ConnectContext connectContext; private final List childrenStats; private final Statistics planStats; private final int arity; @@ -41,7 +43,8 @@ public class PlanContext { /** * Constructor for PlanContext. */ - public PlanContext(GroupExpression groupExpression) { + public PlanContext(ConnectContext connectContext, GroupExpression groupExpression) { + this.connectContext = connectContext; this.arity = groupExpression.arity(); this.planStats = groupExpression.getOwnerGroup().getStatistics(); this.isStatsReliable = groupExpression.getOwnerGroup().isStatsReliable(); @@ -51,12 +54,8 @@ public PlanContext(GroupExpression groupExpression) { } } - // This is used in GraphSimplifier - public PlanContext(Statistics planStats, List childrenStats) { - this.planStats = planStats; - this.childrenStats = childrenStats; - this.isStatsReliable = false; - this.arity = this.childrenStats.size(); + public SessionVariable getSessionVariable() { + return connectContext.getSessionVariable(); } public void setBroadcastJoin() { diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/cost/Cost.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/cost/Cost.java index 1e0e448dd69f45..7b7c6776aa4c31 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/cost/Cost.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/cost/Cost.java @@ -17,7 +17,7 @@ package org.apache.doris.nereids.cost; -import org.apache.doris.qe.ConnectContext; +import org.apache.doris.qe.SessionVariable; /** * Cost encapsulate the real cost with double type. @@ -27,21 +27,22 @@ public interface Cost { double getValue(); - /** - * return zero cost - */ - static Cost zero() { - if (ConnectContext.get().getSessionVariable().getEnableNewCostModel()) { + static Cost zero(SessionVariable sessionVariable) { + if (sessionVariable.getEnableNewCostModel()) { return CostV2.zero(); } return CostV1.zero(); } - static Cost infinite() { - if (ConnectContext.get().getSessionVariable().getEnableNewCostModel()) { + static Cost infinite(SessionVariable sessionVariable) { + if (sessionVariable.getEnableNewCostModel()) { return CostV2.infinite(); } return CostV1.infinite(); } + + static Cost zeroV1() { + return CostV1.zero(); + } } diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/cost/CostCalculator.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/cost/CostCalculator.java index 7beb65239bda1e..2bcc6e2a79a5d1 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/cost/CostCalculator.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/cost/CostCalculator.java @@ -18,53 +18,42 @@ package org.apache.doris.nereids.cost; import org.apache.doris.nereids.PlanContext; -import org.apache.doris.nereids.annotation.Developing; import org.apache.doris.nereids.memo.GroupExpression; import org.apache.doris.nereids.properties.DistributionSpecReplicated; import org.apache.doris.nereids.properties.PhysicalProperties; import org.apache.doris.nereids.trees.plans.Plan; import org.apache.doris.qe.ConnectContext; +import org.apache.doris.qe.SessionVariable; import java.util.List; /** * Calculate the cost of a plan. */ -@Developing -//TODO: memory cost and network cost should be estimated by byte size. +// TODO: memory cost and network cost should be estimated by byte size. public class CostCalculator { /** * Calculate cost for groupExpression */ - public static Cost calculateCost(GroupExpression groupExpression, List childrenProperties) { - PlanContext planContext = new PlanContext(groupExpression); + public static Cost calculateCost(ConnectContext connectContext, GroupExpression groupExpression, + List childrenProperties) { + PlanContext planContext = new PlanContext(connectContext, groupExpression); if (childrenProperties.size() >= 2 && childrenProperties.get(1).getDistributionSpec() instanceof DistributionSpecReplicated) { planContext.setBroadcastJoin(); } - CostModelV1 costModelV1 = new CostModelV1(); + CostModelV1 costModelV1 = new CostModelV1(connectContext); return groupExpression.getPlan().accept(costModelV1, planContext); } - /** - * Calculate cost without groupExpression - */ - public static Cost calculateCost(Plan plan, PlanContext planContext) { - if (ConnectContext.get().getSessionVariable().getEnableNewCostModel()) { - CostModelV2 costModel = new CostModelV2(); - return plan.accept(costModel, planContext); - } else { - CostModelV1 costModel = new CostModelV1(); - return plan.accept(costModel, planContext); - } - } - - public static Cost addChildCost(Plan plan, Cost planCost, Cost childCost, int index) { - if (ConnectContext.get().getSessionVariable().getEnableNewCostModel()) { + public static Cost addChildCost(ConnectContext connectContext, Plan plan, Cost planCost, Cost childCost, + int index) { + SessionVariable sessionVariable = connectContext.getSessionVariable(); + if (sessionVariable.getEnableNewCostModel()) { return CostModelV2.addChildCost(plan, planCost, childCost, index); } - return CostModelV1.addChildCost(plan, planCost, childCost, index); + return CostModelV1.addChildCost(sessionVariable, planCost, childCost); } } diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/cost/CostModelV1.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/cost/CostModelV1.java index c6bb6e6fdff2d9..33b3c80171ac43 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/cost/CostModelV1.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/cost/CostModelV1.java @@ -43,23 +43,12 @@ import org.apache.doris.nereids.trees.plans.physical.PhysicalTopN; import org.apache.doris.nereids.trees.plans.visitor.PlanVisitor; import org.apache.doris.qe.ConnectContext; +import org.apache.doris.qe.SessionVariable; import org.apache.doris.statistics.Statistics; import com.google.common.base.Preconditions; class CostModelV1 extends PlanVisitor { - /** - * The intuition behind `HEAVY_OPERATOR_PUNISH_FACTOR` is we need to avoid this form of join patterns: - * Plan1: L join ( AGG1(A) join AGG2(B)) - * But - * Plan2: L join AGG1(A) join AGG2(B) is welcomed. - * AGG is time-consuming operator. From the perspective of rowCount, nereids may choose Plan1, - * because `Agg1 join Agg2` generates few tuples. But in Plan1, Agg1 and Agg2 are done in serial, in Plan2, Agg1 and - * Agg2 are done in parallel. And hence, Plan1 should be punished. - *

- * An example is tpch q15. - */ - static final double HEAVY_OPERATOR_PUNISH_FACTOR = 0.0; // for a join, skew = leftRowCount/rightRowCount // the higher skew is, the more we prefer broadcast join than shuffle join @@ -69,22 +58,24 @@ class CostModelV1 extends PlanVisitor { static final double BROADCAST_JOIN_SKEW_PENALTY_LIMIT = 2.0; private final int beNumber; - public CostModelV1() { - if (ConnectContext.get().getSessionVariable().isPlayNereidsDump()) { + public CostModelV1(ConnectContext connectContext) { + SessionVariable sessionVariable = connectContext.getSessionVariable(); + if (sessionVariable.isPlayNereidsDump()) { // TODO: @bingfeng refine minidump setting, and pass testMinidumpUt beNumber = 1; - } else if (ConnectContext.get().getSessionVariable().getBeNumberForTest() != -1) { - beNumber = ConnectContext.get().getSessionVariable().getBeNumberForTest(); + } else if (sessionVariable.getBeNumberForTest() != -1) { + beNumber = sessionVariable.getBeNumberForTest(); } else { beNumber = Math.max(1, ConnectContext.get().getEnv().getClusterInfo().getBackendsNumber(true)); } } - public static Cost addChildCost(Plan plan, Cost planCost, Cost childCost, int index) { + public static Cost addChildCost(SessionVariable sessionVariable, Cost planCost, Cost childCost) { Preconditions.checkArgument(childCost instanceof CostV1 && planCost instanceof CostV1); CostV1 childCostV1 = (CostV1) childCost; CostV1 planCostV1 = (CostV1) planCost; - return new CostV1(childCostV1.getCpuCost() + planCostV1.getCpuCost(), + return new CostV1(sessionVariable, + childCostV1.getCpuCost() + planCostV1.getCpuCost(), childCostV1.getMemoryCost() + planCostV1.getMemoryCost(), childCostV1.getNetworkCost() + planCostV1.getNetworkCost()); } @@ -97,7 +88,7 @@ public Cost visit(Plan plan, PlanContext context) { @Override public Cost visitPhysicalOlapScan(PhysicalOlapScan physicalOlapScan, PlanContext context) { Statistics statistics = context.getStatisticsWithCheck(); - return CostV1.ofCpu(statistics.getRowCount()); + return CostV1.ofCpu(context.getSessionVariable(), statistics.getRowCount()); } @Override @@ -108,7 +99,7 @@ public Cost visitPhysicalDeferMaterializeOlapScan(PhysicalDeferMaterializeOlapSc public Cost visitPhysicalSchemaScan(PhysicalSchemaScan physicalSchemaScan, PlanContext context) { Statistics statistics = context.getStatisticsWithCheck(); - return CostV1.ofCpu(statistics.getRowCount()); + return CostV1.ofCpu(context.getSessionVariable(), statistics.getRowCount()); } @Override @@ -116,31 +107,31 @@ public Cost visitPhysicalStorageLayerAggregate( PhysicalStorageLayerAggregate storageLayerAggregate, PlanContext context) { CostV1 costValue = (CostV1) storageLayerAggregate.getRelation().accept(this, context); // multiply a factor less than 1, so we can select PhysicalStorageLayerAggregate as far as possible - return new CostV1(costValue.getCpuCost() * 0.7, costValue.getMemoryCost(), + return new CostV1(context.getSessionVariable(), costValue.getCpuCost() * 0.7, costValue.getMemoryCost(), costValue.getNetworkCost()); } @Override public Cost visitPhysicalFileScan(PhysicalFileScan physicalFileScan, PlanContext context) { Statistics statistics = context.getStatisticsWithCheck(); - return CostV1.ofCpu(statistics.getRowCount()); + return CostV1.ofCpu(context.getSessionVariable(), statistics.getRowCount()); } @Override public Cost visitPhysicalProject(PhysicalProject physicalProject, PlanContext context) { - return CostV1.ofCpu(1); + return CostV1.ofCpu(context.getSessionVariable(), 1); } @Override public Cost visitPhysicalJdbcScan(PhysicalJdbcScan physicalJdbcScan, PlanContext context) { Statistics statistics = context.getStatisticsWithCheck(); - return CostV1.ofCpu(statistics.getRowCount()); + return CostV1.ofCpu(context.getSessionVariable(), statistics.getRowCount()); } @Override public Cost visitPhysicalEsScan(PhysicalEsScan physicalEsScan, PlanContext context) { Statistics statistics = context.getStatisticsWithCheck(); - return CostV1.ofCpu(statistics.getRowCount()); + return CostV1.ofCpu(context.getSessionVariable(), statistics.getRowCount()); } @Override @@ -156,7 +147,7 @@ public Cost visitPhysicalQuickSort( // Now we do more like two-phase sort, so penalise one-phase sort rowCount *= 100; } - return CostV1.of(childRowCount, rowCount, childRowCount); + return CostV1.of(context.getSessionVariable(), childRowCount, rowCount, childRowCount); } @Override @@ -171,7 +162,7 @@ public Cost visitPhysicalTopN(PhysicalTopN topN, PlanContext con // Now we do more like two-phase sort, so penalise one-phase sort rowCount *= 100; } - return CostV1.of(childRowCount, rowCount, childRowCount); + return CostV1.of(context.getSessionVariable(), childRowCount, rowCount, childRowCount); } @Override @@ -184,10 +175,10 @@ public Cost visitPhysicalDeferMaterializeTopN(PhysicalDeferMaterializeTopN partitionTopN, PlanContext context) { Statistics statistics = context.getStatisticsWithCheck(); Statistics childStatistics = context.getChildStatistics(0); - return CostV1.of( - childStatistics.getRowCount(), - statistics.getRowCount(), - childStatistics.getRowCount()); + return CostV1.of(context.getSessionVariable(), + childStatistics.getRowCount(), + statistics.getRowCount(), + childStatistics.getRowCount()); } @Override @@ -199,7 +190,7 @@ public Cost visitPhysicalDistribute( // shuffle if (spec instanceof DistributionSpecHash) { - return CostV1.of( + return CostV1.of(context.getSessionVariable(), 0, 0, intputRowCount * childStatistics.dataSizeFactor() / beNumber); @@ -210,7 +201,7 @@ public Cost visitPhysicalDistribute( // estimate broadcast cost by an experience formula: beNumber^0.5 * rowCount // - sender number and receiver number is not available at RBO stage now, so we use beNumber // - senders and receivers work in parallel, that why we use square of beNumber - return CostV1.of( + return CostV1.of(context.getSessionVariable(), 0, 0, intputRowCount * childStatistics.dataSizeFactor()); @@ -219,14 +210,14 @@ public Cost visitPhysicalDistribute( // gather if (spec instanceof DistributionSpecGather) { - return CostV1.of( + return CostV1.of(context.getSessionVariable(), 0, 0, intputRowCount * childStatistics.dataSizeFactor() / beNumber); } // any - return CostV1.of( + return CostV1.of(context.getSessionVariable(), intputRowCount, 0, 0); @@ -237,11 +228,11 @@ public Cost visitPhysicalHashAggregate( PhysicalHashAggregate aggregate, PlanContext context) { Statistics inputStatistics = context.getChildStatistics(0); if (aggregate.getAggPhase().isLocal()) { - return CostV1.of(inputStatistics.getRowCount() / beNumber, + return CostV1.of(context.getSessionVariable(), inputStatistics.getRowCount() / beNumber, inputStatistics.getRowCount() / beNumber, 0); } else { // global - return CostV1.of(inputStatistics.getRowCount(), + return CostV1.of(context.getSessionVariable(), inputStatistics.getRowCount(), inputStatistics.getRowCount(), 0); } } @@ -278,7 +269,7 @@ public Cost visitPhysicalHashJoin( in pattern2, join1 and join2 takes more time, but Agg1 and agg2 can be processed in parallel. */ if (physicalHashJoin.getJoinType().isCrossJoin()) { - return CostV1.of(leftRowCount + rightRowCount + outputRowCount, + return CostV1.of(context.getSessionVariable(), leftRowCount + rightRowCount + outputRowCount, 0, leftRowCount + rightRowCount ); @@ -293,8 +284,8 @@ public Cost visitPhysicalHashJoin( // bigger cost for ProbeWhenBuildSideOutput effort and ProbeWhenSearchHashTableTime // on the output rows, taken on outputRowCount() double probeSideFactor = 1.0; - double buildSideFactor = ConnectContext.get().getSessionVariable().getBroadcastRightTableScaleFactor(); - int parallelInstance = Math.max(1, ConnectContext.get().getSessionVariable().getParallelExecInstanceNum()); + double buildSideFactor = context.getSessionVariable().getBroadcastRightTableScaleFactor(); + int parallelInstance = Math.max(1, context.getSessionVariable().getParallelExecInstanceNum()); int totalInstanceNumber = parallelInstance * beNumber; if (buildSideFactor <= 1.0) { // use totalInstanceNumber to the power of 2 as the default factor value @@ -304,22 +295,24 @@ public Cost visitPhysicalHashJoin( // will refine this in next generation cost model. if (!context.isStatsReliable()) { // forbid broadcast join when stats is unknown - return CostV1.of(rightRowCount * buildSideFactor + 1 / leftRowCount, + return CostV1.of(context.getSessionVariable(), rightRowCount * buildSideFactor + 1 / leftRowCount, rightRowCount, 0 ); } - return CostV1.of(leftRowCount + rightRowCount * buildSideFactor + outputRowCount * probeSideFactor, + return CostV1.of(context.getSessionVariable(), + leftRowCount + rightRowCount * buildSideFactor + outputRowCount * probeSideFactor, rightRowCount, 0 ); } if (!context.isStatsReliable()) { - return CostV1.of(rightRowCount + 1 / leftRowCount, + return CostV1.of(context.getSessionVariable(), + rightRowCount + 1 / leftRowCount, rightRowCount, 0); } - return CostV1.of(leftRowCount + rightRowCount + outputRowCount, + return CostV1.of(context.getSessionVariable(), leftRowCount + rightRowCount + outputRowCount, rightRowCount, 0 ); @@ -334,11 +327,12 @@ public Cost visitPhysicalNestedLoopJoin( Statistics leftStatistics = context.getChildStatistics(0); Statistics rightStatistics = context.getChildStatistics(1); if (!context.isStatsReliable()) { - return CostV1.of(rightStatistics.getRowCount() + 1 / leftStatistics.getRowCount(), + return CostV1.of(context.getSessionVariable(), + rightStatistics.getRowCount() + 1 / leftStatistics.getRowCount(), rightStatistics.getRowCount(), 0); } - return CostV1.of( + return CostV1.of(context.getSessionVariable(), leftStatistics.getRowCount() * rightStatistics.getRowCount(), rightStatistics.getRowCount(), 0); @@ -347,7 +341,7 @@ public Cost visitPhysicalNestedLoopJoin( @Override public Cost visitPhysicalAssertNumRows(PhysicalAssertNumRows assertNumRows, PlanContext context) { - return CostV1.of( + return CostV1.of(context.getSessionVariable(), assertNumRows.getAssertNumRowsElement().getDesiredNumOfRows(), assertNumRows.getAssertNumRowsElement().getDesiredNumOfRows(), 0 @@ -357,7 +351,7 @@ public Cost visitPhysicalAssertNumRows(PhysicalAssertNumRows ass @Override public Cost visitPhysicalGenerate(PhysicalGenerate generate, PlanContext context) { Statistics statistics = context.getStatisticsWithCheck(); - return CostV1.of( + return CostV1.of(context.getSessionVariable(), statistics.getRowCount(), statistics.getRowCount(), 0 diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/cost/CostModelV2.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/cost/CostModelV2.java index c90f28c6b8f3e0..7f1c6a6c5ac9fc 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/cost/CostModelV2.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/cost/CostModelV2.java @@ -44,6 +44,7 @@ import org.apache.doris.nereids.trees.plans.physical.PhysicalUnion; import org.apache.doris.nereids.trees.plans.physical.PhysicalWindow; import org.apache.doris.nereids.trees.plans.visitor.PlanVisitor; +import org.apache.doris.qe.SessionVariable; import org.apache.doris.statistics.Statistics; import com.google.common.base.Preconditions; @@ -58,6 +59,12 @@ class CostModelV2 extends PlanVisitor { static double CMP_COST = 1.5; static double PUSH_DOWN_AGG_COST = 0.1; + private final SessionVariable sessionVariable; + + CostModelV2(SessionVariable sessionVariable) { + this.sessionVariable = sessionVariable; + } + public static Cost addChildCost(Plan plan, Cost planCost, Cost childCost, int index) { Preconditions.checkArgument(childCost instanceof CostV2 && planCost instanceof CostV2); CostV2 planCostV2 = (CostV2) planCost; @@ -103,7 +110,7 @@ public Cost visitPhysicalStorageLayerAggregate(PhysicalStorageLayerAggregate sto double ioCost = stats.computeSize(); - double runCost1 = CostWeight.get().weightSum(0, ioCost, 0) / stats.getBENumber(); + double runCost1 = CostWeight.get(sessionVariable).weightSum(0, ioCost, 0) / stats.getBENumber(); // Note the stats of this operator is the stats of relation. // We need add a plenty for this cost. Maybe changing rowCount of storageLayer is better @@ -125,7 +132,7 @@ public Cost visitPhysicalProject(PhysicalProject physicalProject double cpuCost = statistics.getRowCount() * ExprCostModel.calculateExprCost(physicalProject.getProjects()); double startCost = 0; - double runCost = CostWeight.get().weightSum(cpuCost, 0, 0) / statistics.getBENumber(); + double runCost = CostWeight.get(sessionVariable).weightSum(cpuCost, 0, 0) / statistics.getBENumber(); return new CostV2(startCost, runCost, 0); } @@ -185,7 +192,7 @@ public Cost visitPhysicalDistribute(PhysicalDistribute distribut } double startCost = 0; - double runCost = CostWeight.get().weightSum(0, 0, netCost) / childStatistics.getBENumber(); + double runCost = CostWeight.get(sessionVariable).weightSum(0, 0, netCost) / childStatistics.getBENumber(); return new CostV2(startCost, runCost, 0); } @@ -212,8 +219,8 @@ public Cost visitPhysicalHashJoin(PhysicalHashJoin generate, Pla double cpuCost = exprCost * statistics.getRowCount(); double startCost = 0; - double runCost = CostWeight.get().weightSum(cpuCost, 0, 0) / statistics.getBENumber(); + double runCost = CostWeight.get(sessionVariable).weightSum(cpuCost, 0, 0) / statistics.getBENumber(); return new CostV2(startCost, runCost, 0); } @@ -274,7 +281,7 @@ public Cost visitPhysicalWindow(PhysicalWindow window, PlanConte double cpuCost = stats.getRowCount() * exprCost; double startCost = 0; - double runCost = CostWeight.get().weightSum(cpuCost, 0, 0) / stats.getBENumber(); + double runCost = CostWeight.get(sessionVariable).weightSum(cpuCost, 0, 0) / stats.getBENumber(); return new CostV2(startCost, runCost, 0); } @@ -293,7 +300,7 @@ public Cost visitPhysicalSetOperation(PhysicalSetOperation intersect, PlanContex size += childStats.computeSize(); } - double startCost = CostWeight.get().weightSum(rowCount * HASH_COST, 0, 0); + double startCost = CostWeight.get(sessionVariable).weightSum(rowCount * HASH_COST, 0, 0); double runCost = 0; return new CostV2(startCost, runCost, size); @@ -307,7 +314,7 @@ public Cost visitPhysicalFilter(PhysicalFilter physicalFilter, PlanContext conte double cpuCost = exprCost * stats.getRowCount(); double startCost = 0; - double runCost = CostWeight.get().weightSum(cpuCost, 0, 0) / stats.getBENumber(); + double runCost = CostWeight.get(sessionVariable).weightSum(cpuCost, 0, 0) / stats.getBENumber(); return new CostV2(startCost, runCost, 0); } @@ -316,13 +323,13 @@ private CostV2 calculateScanWithoutRF(Statistics stats) { //TODO: consider runtimeFilter double io = stats.computeSize(); double startCost = 0; - double runCost = CostWeight.get().weightSum(0, io, 0) / stats.getBENumber(); + double runCost = CostWeight.get(sessionVariable).weightSum(0, io, 0) / stats.getBENumber(); return new CostV2(startCost, runCost, 0); } private CostV2 calculateAggregate(Statistics stats, Statistics childStats, double exprCost) { // Build HashTable - double startCost = CostWeight.get() + double startCost = CostWeight.get(sessionVariable) .weightSum(HASH_COST * childStats.getRowCount() + exprCost * childStats.getRowCount(), 0, 0); double runCost = 0; return new CostV2(startCost, runCost, stats.computeSize()); diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/cost/CostV1.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/cost/CostV1.java index fb00bacc2877a7..ea47e1dedbeda5 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/cost/CostV1.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/cost/CostV1.java @@ -17,10 +17,13 @@ package org.apache.doris.nereids.cost; +import org.apache.doris.qe.SessionVariable; + class CostV1 implements Cost { private static final CostV1 INFINITE = new CostV1(Double.POSITIVE_INFINITY, Double.POSITIVE_INFINITY, + Double.POSITIVE_INFINITY, Double.POSITIVE_INFINITY); - private static final CostV1 ZERO = new CostV1(0, 0, 0); + private static final CostV1 ZERO = new CostV1(0, 0, 0, 0); private final double cpuCost; private final double memoryCost; @@ -29,9 +32,9 @@ class CostV1 implements Cost { private final double cost; /** - * Constructor of CostEstimate. + * Constructor of CostV1. */ - public CostV1(double cpuCost, double memoryCost, double networkCost) { + public CostV1(SessionVariable sessionVariable, double cpuCost, double memoryCost, double networkCost) { // TODO: fix stats cpuCost = Double.max(0, cpuCost); memoryCost = Double.max(0, memoryCost); @@ -40,11 +43,18 @@ public CostV1(double cpuCost, double memoryCost, double networkCost) { this.memoryCost = memoryCost; this.networkCost = networkCost; - CostWeight costWeight = CostWeight.get(); + CostWeight costWeight = CostWeight.get(sessionVariable); this.cost = costWeight.cpuWeight * cpuCost + costWeight.memoryWeight * memoryCost + costWeight.networkWeight * networkCost; } + private CostV1(double cost, double cpuCost, double memoryCost, double networkCost) { + this.cost = cost; + this.cpuCost = cpuCost; + this.memoryCost = memoryCost; + this.networkCost = networkCost; + } + public static CostV1 infinite() { return INFINITE; } @@ -69,16 +79,12 @@ public double getValue() { return cost; } - public static CostV1 of(double cpuCost, double maxMemory, double networkCost) { - return new CostV1(cpuCost, maxMemory, networkCost); - } - - public static CostV1 ofCpu(double cpuCost) { - return new CostV1(cpuCost, 0, 0); + public static CostV1 of(SessionVariable sessionVariable, double cpuCost, double maxMemory, double networkCost) { + return new CostV1(sessionVariable, cpuCost, maxMemory, networkCost); } - public static CostV1 ofMemory(double memoryCost) { - return new CostV1(0, memoryCost, 0); + public static CostV1 ofCpu(SessionVariable sessionVariable, double cpuCost) { + return new CostV1(sessionVariable, cpuCost, 0, 0); } @Override diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/cost/CostWeight.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/cost/CostWeight.java index 3c62f8857428a2..f92ac0115a8815 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/cost/CostWeight.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/cost/CostWeight.java @@ -18,6 +18,7 @@ package org.apache.doris.nereids.cost; import org.apache.doris.qe.ConnectContext; +import org.apache.doris.qe.SessionVariable; import com.google.common.base.Preconditions; @@ -66,11 +67,18 @@ public CostWeight(double cpuWeight, double memoryWeight, double networkWeight, d } public static CostWeight get() { - double cpuWeight = ConnectContext.get().getSessionVariable().getCboCpuWeight(); - double memWeight = ConnectContext.get().getSessionVariable().getCboMemWeight(); - double netWeight = ConnectContext.get().getSessionVariable().getCboNetWeight(); - return new CostWeight(cpuWeight, memWeight, netWeight, - ConnectContext.get().getSessionVariable().getNereidsCboPenaltyFactor()); + SessionVariable sessionVariable = ConnectContext.get().getSessionVariable(); + double cpuWeight = sessionVariable.getCboCpuWeight(); + double memWeight = sessionVariable.getCboMemWeight(); + double netWeight = sessionVariable.getCboNetWeight(); + return new CostWeight(cpuWeight, memWeight, netWeight, sessionVariable.getNereidsCboPenaltyFactor()); + } + + public static CostWeight get(SessionVariable sessionVariable) { + double cpuWeight = sessionVariable.getCboCpuWeight(); + double memWeight = sessionVariable.getCboMemWeight(); + double netWeight = sessionVariable.getCboNetWeight(); + return new CostWeight(cpuWeight, memWeight, netWeight, sessionVariable.getNereidsCboPenaltyFactor()); } //TODO: add it in session variable diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/glue/translator/PhysicalPlanTranslator.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/glue/translator/PhysicalPlanTranslator.java index 5e4097bbfec472..ebf7a8dcd86e57 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/glue/translator/PhysicalPlanTranslator.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/glue/translator/PhysicalPlanTranslator.java @@ -244,11 +244,11 @@ public PlanFragment translatePlan(PhysicalPlan physicalPlan) { Collections.reverse(context.getPlanFragments()); // TODO: maybe we need to trans nullable directly? and then we could remove call computeMemLayout context.getDescTable().computeMemLayout(); - if (ConnectContext.get() != null && ConnectContext.get().getSessionVariable().forbidUnknownColStats) { + if (context.getSessionVariable() != null && context.getSessionVariable().forbidUnknownColStats) { Set scans = context.getScanNodeWithUnknownColumnStats(); if (!scans.isEmpty()) { StringBuilder builder = new StringBuilder(); - scans.forEach(scanNode -> builder.append(scanNode)); + scans.forEach(builder::append); throw new AnalysisException("tables with unknown column stats: " + builder); } } @@ -607,7 +607,7 @@ public PlanFragment visitPhysicalOlapScan(PhysicalOlapScan olapScan, PlanTransla // TODO: move all node set cardinality into one place if (olapScan.getStats() != null) { olapScanNode.setCardinality((long) olapScan.getStats().getRowCount()); - if (ConnectContext.get().getSessionVariable().forbidUnknownColStats) { + if (context.getSessionVariable() != null && context.getSessionVariable().forbidUnknownColStats) { for (int i = 0; i < slots.size(); i++) { Slot slot = slots.get(i); if (olapScan.getStats().findColumnStatistics(slot).isUnKnown() @@ -1027,7 +1027,7 @@ public PlanFragment visitPhysicalFilter(PhysicalFilter filter, P updateLegacyPlanIdToPhysicalPlan(inputFragment.getPlanRoot(), filter); } } - //in ut, filter.stats may be null + // in ut, filter.stats may be null if (filter.getStats() != null) { inputFragment.getPlanRoot().setCardinalityAfterFilter((long) filter.getStats().getRowCount()); } @@ -1226,7 +1226,7 @@ public PlanFragment visitPhysicalHashJoin( sd = context.getDescTable().copySlotDescriptor(intermediateDescriptor, leftSlotDescriptor); } else { sd = context.createSlotDesc(intermediateDescriptor, sf, leftSlotDescriptor.getParent().getTable()); - //sd = context.createSlotDesc(intermediateDescriptor, sf); + // sd = context.createSlotDesc(intermediateDescriptor, sf); if (hashOutputSlotReferenceMap.get(sf.getExprId()) != null) { hashJoinNode.addSlotIdToHashOutputSlotIds(leftSlotDescriptor.getId()); hashJoinNode.getHashOutputExprSlotIdMap().put(sf.getExprId(), leftSlotDescriptor.getId()); @@ -1247,7 +1247,7 @@ public PlanFragment visitPhysicalHashJoin( sd = context.getDescTable().copySlotDescriptor(intermediateDescriptor, rightSlotDescriptor); } else { sd = context.createSlotDesc(intermediateDescriptor, sf, rightSlotDescriptor.getParent().getTable()); - //sd = context.createSlotDesc(intermediateDescriptor, sf); + // sd = context.createSlotDesc(intermediateDescriptor, sf); if (hashOutputSlotReferenceMap.get(sf.getExprId()) != null) { hashJoinNode.addSlotIdToHashOutputSlotIds(rightSlotDescriptor.getId()); hashJoinNode.getHashOutputExprSlotIdMap().put(sf.getExprId(), rightSlotDescriptor.getId()); @@ -1267,7 +1267,7 @@ public PlanFragment visitPhysicalHashJoin( sd = context.getDescTable().copySlotDescriptor(intermediateDescriptor, leftSlotDescriptor); } else { sd = context.createSlotDesc(intermediateDescriptor, sf, leftSlotDescriptor.getParent().getTable()); - //sd = context.createSlotDesc(intermediateDescriptor, sf); + // sd = context.createSlotDesc(intermediateDescriptor, sf); if (hashOutputSlotReferenceMap.get(sf.getExprId()) != null) { hashJoinNode.addSlotIdToHashOutputSlotIds(leftSlotDescriptor.getId()); hashJoinNode.getHashOutputExprSlotIdMap().put(sf.getExprId(), leftSlotDescriptor.getId()); @@ -1286,7 +1286,7 @@ public PlanFragment visitPhysicalHashJoin( sd = context.getDescTable().copySlotDescriptor(intermediateDescriptor, rightSlotDescriptor); } else { sd = context.createSlotDesc(intermediateDescriptor, sf, rightSlotDescriptor.getParent().getTable()); - //sd = context.createSlotDesc(intermediateDescriptor, sf); + // sd = context.createSlotDesc(intermediateDescriptor, sf); if (hashOutputSlotReferenceMap.get(sf.getExprId()) != null) { hashJoinNode.addSlotIdToHashOutputSlotIds(rightSlotDescriptor.getId()); hashJoinNode.getHashOutputExprSlotIdMap().put(sf.getExprId(), rightSlotDescriptor.getId()); @@ -1454,7 +1454,7 @@ public PlanFragment visitPhysicalNestedLoopJoin( sd = context.getDescTable().copySlotDescriptor(intermediateDescriptor, leftSlotDescriptor); } else { sd = context.createSlotDesc(intermediateDescriptor, sf, leftSlotDescriptor.getParent().getTable()); - //sd = context.createSlotDesc(intermediateDescriptor, sf); + // sd = context.createSlotDesc(intermediateDescriptor, sf); } leftIntermediateSlotDescriptor.add(sd); } @@ -1469,7 +1469,7 @@ public PlanFragment visitPhysicalNestedLoopJoin( sd = context.getDescTable().copySlotDescriptor(intermediateDescriptor, rightSlotDescriptor); } else { sd = context.createSlotDesc(intermediateDescriptor, sf, rightSlotDescriptor.getParent().getTable()); - //sd = context.createSlotDesc(intermediateDescriptor, sf); + // sd = context.createSlotDesc(intermediateDescriptor, sf); } rightIntermediateSlotDescriptor.add(sd); } @@ -1494,7 +1494,7 @@ public PlanFragment visitPhysicalNestedLoopJoin( .map(e -> ExpressionTranslator.translate(e, context)).collect(Collectors.toList()); if (!nestedLoopJoin.isBitMapRuntimeFilterConditionsEmpty() && joinConjuncts.isEmpty()) { - //left semi join need at least one conjunct. otherwise left-semi-join fallback to cross-join + // left semi join need at least one conjunct. otherwise left-semi-join fallback to cross-join joinConjuncts.add(new BoolLiteral(true)); } @@ -2099,7 +2099,8 @@ private void updateScanSlotsMaterialization(ScanNode scanNode, scanNode.getTupleDesc().getSlots().add(smallest); } try { - if (ConnectContext.get() != null && ConnectContext.get().getSessionVariable().forbidUnknownColStats + if (context.getSessionVariable() != null + && context.getSessionVariable().forbidUnknownColStats && !StatisticConstants.isSystemTable(scanNode.getTupleDesc().getTable())) { for (SlotId slotId : requiredByProjectSlotIdSet) { if (context.isColumnStatsUnknown(scanNode, slotId)) { @@ -2316,7 +2317,7 @@ private void injectRowIdColumnSlot(TupleDescriptor tupleDesc) { private boolean checkPushSort(SortNode sortNode, OlapTable olapTable) { // Ensure limit is less than threshold if (sortNode.getLimit() <= 0 - || sortNode.getLimit() > ConnectContext.get().getSessionVariable().topnOptLimitThreshold) { + || sortNode.getLimit() > context.getSessionVariable().topnOptLimitThreshold) { return false; } @@ -2365,7 +2366,7 @@ private List translateToLegacyConjuncts(Set conjuncts) { List outputExprs = Lists.newArrayList(); if (conjuncts != null) { conjuncts.stream() - .map(e -> ExpressionTranslator.translate(e, context)) + .map(e -> ExpressionTranslator.translate(e, context)) .forEach(outputExprs::add); } return outputExprs; diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/glue/translator/PlanTranslatorContext.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/glue/translator/PlanTranslatorContext.java index 721eea37b77379..8f273d4ec089e3 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/glue/translator/PlanTranslatorContext.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/glue/translator/PlanTranslatorContext.java @@ -43,6 +43,8 @@ import org.apache.doris.planner.PlanNode; import org.apache.doris.planner.PlanNodeId; import org.apache.doris.planner.ScanNode; +import org.apache.doris.qe.ConnectContext; +import org.apache.doris.qe.SessionVariable; import org.apache.doris.thrift.TPushAggOp; import com.google.common.annotations.VisibleForTesting; @@ -62,6 +64,7 @@ * Context of physical plan. */ public class PlanTranslatorContext { + private final ConnectContext connectContext; private final List planFragments = Lists.newArrayList(); private final DescriptorTable descTable = new DescriptorTable(); @@ -110,12 +113,14 @@ public class PlanTranslatorContext { private final Map> statsUnknownColumnsMap = Maps.newHashMap(); public PlanTranslatorContext(CascadesContext ctx) { + this.connectContext = ctx.getConnectContext(); this.translator = new RuntimeFilterTranslator(ctx.getRuntimeFilterContext()); } @VisibleForTesting public PlanTranslatorContext() { - translator = null; + this.connectContext = null; + this.translator = null; } /** @@ -142,6 +147,10 @@ public void removeScanFromStatsUnknownColumnsMap(ScanNode scan) { statsUnknownColumnsMap.remove(scan); } + public SessionVariable getSessionVariable() { + return connectContext == null ? null : connectContext.getSessionVariable(); + } + public Set getScanNodeWithUnknownColumnStats() { return statsUnknownColumnsMap.keySet(); } diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/cascades/CostAndEnforcerJob.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/cascades/CostAndEnforcerJob.java index 7cb73a332d1c23..8126a0aee01da9 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/cascades/CostAndEnforcerJob.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/cascades/CostAndEnforcerJob.java @@ -31,6 +31,8 @@ import org.apache.doris.nereids.properties.EnforceMissingPropertiesHelper; import org.apache.doris.nereids.properties.PhysicalProperties; import org.apache.doris.nereids.properties.RequestPropertyDeriver; +import org.apache.doris.qe.ConnectContext; +import org.apache.doris.qe.SessionVariable; import com.google.common.collect.Lists; import org.apache.logging.log4j.LogManager; @@ -79,6 +81,14 @@ public CostAndEnforcerJob(GroupExpression groupExpression, JobContext context) { this.groupExpression = groupExpression; } + private ConnectContext getConnectContext() { + return context.getCascadesContext().getConnectContext(); + } + + private SessionVariable getSessionVariable() { + return context.getCascadesContext().getConnectContext().getSessionVariable(); + } + /*- * Please read the ORCA paper * - 4.1.4 Optimization. @@ -113,17 +123,19 @@ public void execute() { return; } + SessionVariable sessionVariable = getSessionVariable(); + countJobExecutionTimesOfGroupExpressions(groupExpression); // Do init logic of root plan/groupExpr of `subplan`, only run once per task. if (curChildIndex == -1) { - curNodeCost = Cost.zero(); - curTotalCost = Cost.zero(); + curNodeCost = Cost.zero(sessionVariable); + curTotalCost = Cost.zero(sessionVariable); curChildIndex = 0; // List // [ child item: [leftProperties, rightProperties]] // like :[ [Properties {"", ANY}, Properties {"", BROADCAST}], // [Properties {"", SHUFFLE_JOIN}, Properties {"", SHUFFLE_JOIN}] ] - RequestPropertyDeriver requestPropertyDeriver = new RequestPropertyDeriver(context); + RequestPropertyDeriver requestPropertyDeriver = new RequestPropertyDeriver(getConnectContext(), context); requestChildrenPropertiesList = requestPropertyDeriver.getRequestChildrenPropertyList(groupExpression); for (List requestChildrenProperties : requestChildrenPropertiesList) { outputChildrenPropertiesList.add(new ArrayList<>(requestChildrenProperties)); @@ -139,7 +151,8 @@ public void execute() { = outputChildrenPropertiesList.get(requestPropertiesIndex); // Calculate cost if (curChildIndex == 0 && prevChildIndex == -1) { - curNodeCost = CostCalculator.calculateCost(groupExpression, requestChildrenProperties); + curNodeCost = CostCalculator.calculateCost(getConnectContext(), groupExpression, + requestChildrenProperties); groupExpression.setCost(curNodeCost); curTotalCost = curNodeCost; } @@ -184,7 +197,9 @@ public void execute() { // plan's requestChildProperty).getOutputProperties(current plan's requestChildProperty) == child // plan's outputProperties`, the outputProperties must satisfy the origin requestChildProperty outputChildrenProperties.set(curChildIndex, outputProperties); - curTotalCost = CostCalculator.addChildCost(groupExpression.getPlan(), + curTotalCost = CostCalculator.addChildCost( + getConnectContext(), + groupExpression.getPlan(), curNodeCost, lowestCostExpr.getCostValueByProperties(requestChildProperty), curChildIndex); @@ -194,7 +209,7 @@ public void execute() { // Group1 : betterExpr, currentExpr(child: Group2), otherExpr(child: Group) // steps // 1. CostAndEnforce(currentExpr) with upperBound betterExpr.cost - // 2. OptimzeGroup(Group2) with upperBound bestExpr.cost - currentExpr.nodeCost + // 2. OptimizeGroup(Group2) with upperBound bestExpr.cost - currentExpr.nodeCost // 3. CostAndEnforce(Expr in Group2) trigger here and exit // ... // n. CostAndEnforce(otherExpr) can trigger optimize group2 again for the same requireProp @@ -240,7 +255,8 @@ private boolean calculateEnforce(List requestChildrenPropert ChildOutputPropertyDeriver childOutputPropertyDeriver = new ChildOutputPropertyDeriver(outputChildrenProperties); // the physical properties the group expression support for its parent. - PhysicalProperties outputProperty = childOutputPropertyDeriver.getOutputProperties(groupExpression); + PhysicalProperties outputProperty = childOutputPropertyDeriver.getOutputProperties(getConnectContext(), + groupExpression); // update current group statistics and re-compute costs. if (groupExpression.children().stream().anyMatch(group -> group.getStatistics() == null) @@ -251,12 +267,14 @@ private boolean calculateEnforce(List requestChildrenPropert } // recompute cost after adjusting property - curNodeCost = CostCalculator.calculateCost(groupExpression, requestChildrenProperties); + curNodeCost = CostCalculator.calculateCost(getConnectContext(), groupExpression, requestChildrenProperties); groupExpression.setCost(curNodeCost); curTotalCost = curNodeCost; for (int i = 0; i < outputChildrenProperties.size(); i++) { PhysicalProperties childProperties = outputChildrenProperties.get(i); - curTotalCost = CostCalculator.addChildCost(groupExpression.getPlan(), + curTotalCost = CostCalculator.addChildCost( + getConnectContext(), + groupExpression.getPlan(), curTotalCost, groupExpression.child(i).getLowestCostPlan(childProperties).get().first, i); @@ -301,7 +319,7 @@ private void enforce(PhysicalProperties outputProperty, List } EnforceMissingPropertiesHelper enforceMissingPropertiesHelper - = new EnforceMissingPropertiesHelper(context, groupExpression, curTotalCost); + = new EnforceMissingPropertiesHelper(getConnectContext(), groupExpression, curTotalCost); PhysicalProperties addEnforcedProperty = enforceMissingPropertiesHelper .enforceProperty(outputProperty, requiredProperties); curTotalCost = enforceMissingPropertiesHelper.getCurTotalCost(); @@ -338,8 +356,8 @@ private void clear() { lowestCostChildren.clear(); prevChildIndex = -1; curChildIndex = 0; - curTotalCost = Cost.zero(); - curNodeCost = Cost.zero(); + curTotalCost = Cost.zero(getSessionVariable()); + curNodeCost = Cost.zero(getSessionVariable()); } /** diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/cascades/DeriveStatsJob.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/cascades/DeriveStatsJob.java index ec65b9af14d604..8aa940589f85a6 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/cascades/DeriveStatsJob.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/cascades/DeriveStatsJob.java @@ -30,6 +30,7 @@ import org.apache.doris.nereids.trees.expressions.CTEId; import org.apache.doris.nereids.trees.plans.algebra.Project; import org.apache.doris.qe.ConnectContext; +import org.apache.doris.qe.SessionVariable; import org.apache.doris.statistics.Statistics; import java.util.HashMap; @@ -102,20 +103,19 @@ public void execute() { } } } else { + ConnectContext connectContext = context.getCascadesContext().getConnectContext(); + SessionVariable sessionVariable = connectContext.getSessionVariable(); StatsCalculator statsCalculator = StatsCalculator.estimate(groupExpression, - context.getCascadesContext().getConnectContext().getSessionVariable().getForbidUnknownColStats(), - context.getCascadesContext().getConnectContext().getTotalColumnStatisticMap(), - context.getCascadesContext().getConnectContext().getSessionVariable().isPlayNereidsDump(), + sessionVariable.getForbidUnknownColStats(), + connectContext.getTotalColumnStatisticMap(), + sessionVariable.isPlayNereidsDump(), cteIdToStats, context.getCascadesContext()); STATS_STATE_TRACER.log(StatsStateEvent.of(groupExpression, groupExpression.getOwnerGroup().getStatistics())); - if (ConnectContext.get().getSessionVariable().isEnableMinidump() - && !ConnectContext.get().getSessionVariable().isPlayNereidsDump()) { - context.getCascadesContext().getConnectContext().getTotalColumnStatisticMap() - .putAll(statsCalculator.getTotalColumnStatisticMap()); - context.getCascadesContext().getConnectContext().getTotalHistogramMap() - .putAll(statsCalculator.getTotalHistogramMap()); + if (sessionVariable.isEnableMinidump() && !sessionVariable.isPlayNereidsDump()) { + connectContext.getTotalColumnStatisticMap().putAll(statsCalculator.getTotalColumnStatisticMap()); + connectContext.getTotalHistogramMap().putAll(statsCalculator.getTotalHistogramMap()); } if (groupExpression.getPlan() instanceof Project) { diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/memo/Memo.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/memo/Memo.java index ea35558c42e4db..5b75cb2d76e193 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/memo/Memo.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/memo/Memo.java @@ -68,6 +68,7 @@ public class Memo { private static final EventProducer GROUP_MERGE_TRACER = new EventProducer(GroupMergeEvent.class, EventChannel.getDefaultChannel().addConsumers(new LogConsumer(GroupMergeEvent.class, EventChannel.LOG))); private static long stateId = 0; + private final ConnectContext connectContext; private final IdGenerator groupIdGenerator = GroupId.createGenerator(); private final Map groups = Maps.newLinkedHashMap(); // we could not use Set, because Set does not have get method. @@ -76,11 +77,13 @@ public class Memo { // FOR TEST ONLY public Memo() { - root = null; + this.root = null; + this.connectContext = null; } - public Memo(Plan plan) { - root = init(plan); + public Memo(ConnectContext connectContext, Plan plan) { + this.root = init(plan); + this.connectContext = connectContext; } public static long getStateId() { @@ -214,8 +217,7 @@ public CopyInResult copyIn(Plan plan, @Nullable Group target, boolean rewrite) { } private void maybeAddStateId(CopyInResult result) { - if (ConnectContext.get() != null - && ConnectContext.get().getSessionVariable().isEnableNereidsTrace() + if (connectContext != null && connectContext.getSessionVariable().isEnableNereidsTrace() && result.generateNewExpression) { stateId++; } @@ -850,11 +852,12 @@ private List> rankGroupExpression(GroupExpression groupExpressi List>> childrenId = new ArrayList<>(); permute(children, 0, childrenId, new ArrayList<>()); - Cost cost = CostCalculator.calculateCost(groupExpression, inputProperties); + Cost cost = CostCalculator.calculateCost(connectContext, groupExpression, inputProperties); for (Pair> c : childrenId) { Cost totalCost = cost; for (int i = 0; i < children.size(); i++) { - totalCost = CostCalculator.addChildCost(groupExpression.getPlan(), + totalCost = CostCalculator.addChildCost(connectContext, + groupExpression.getPlan(), totalCost, children.get(i).get(c.second.get(i)).second, i); @@ -942,7 +945,7 @@ private List> extractInputProperties(GroupExpression gr // return any if exits except RequirePropertiesSupplier and SetOperators // Because PropRegulator could change their input properties - RequestPropertyDeriver requestPropertyDeriver = new RequestPropertyDeriver(prop); + RequestPropertyDeriver requestPropertyDeriver = new RequestPropertyDeriver(connectContext, prop); List> requestList = requestPropertyDeriver .getRequestChildrenPropertyList(groupExpression); Optional> any = requestList.stream() diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/minidump/MinidumpUtils.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/minidump/MinidumpUtils.java index 44e14b4c8cdde4..fc7d3939dabc05 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/minidump/MinidumpUtils.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/minidump/MinidumpUtils.java @@ -286,8 +286,6 @@ private static void serializeStatsUsed(JSONObject jsonObj, List tables) /** * serialize output plan to dump file and persistent into disk - * @param resultPlan - * */ public static void serializeOutputToDumpFile(Plan resultPlan) { if (ConnectContext.get().getSessionVariable().isPlayNereidsDump() @@ -401,22 +399,22 @@ private static JSONObject serializeChangedSessionVariable(SessionVariable sessio } switch (field.getType().getSimpleName()) { case "boolean": - root.put(attr.name(), (Boolean) field.get(sessionVariable)); + root.put(attr.name(), field.get(sessionVariable)); break; case "int": - root.put(attr.name(), (Integer) field.get(sessionVariable)); + root.put(attr.name(), field.get(sessionVariable)); break; case "long": - root.put(attr.name(), (Long) field.get(sessionVariable)); + root.put(attr.name(), field.get(sessionVariable)); break; case "float": - root.put(attr.name(), (Float) field.get(sessionVariable)); + root.put(attr.name(), field.get(sessionVariable)); break; case "double": - root.put(attr.name(), (Double) field.get(sessionVariable)); + root.put(attr.name(), field.get(sessionVariable)); break; case "String": - root.put(attr.name(), (String) field.get(sessionVariable)); + root.put(attr.name(), field.get(sessionVariable)); break; default: // Unsupported type variable. diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/properties/ChildOutputPropertyDeriver.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/properties/ChildOutputPropertyDeriver.java index 1c94cb65ea2549..3b07f2bbe985b9 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/properties/ChildOutputPropertyDeriver.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/properties/ChildOutputPropertyDeriver.java @@ -60,6 +60,7 @@ import org.apache.doris.nereids.trees.plans.visitor.PlanVisitor; import org.apache.doris.nereids.types.DataType; import org.apache.doris.nereids.util.JoinUtils; +import org.apache.doris.qe.ConnectContext; import com.google.common.base.Preconditions; import com.google.common.collect.Lists; @@ -89,8 +90,8 @@ public ChildOutputPropertyDeriver(List childrenOutputPropert this.childrenOutputProperties = Objects.requireNonNull(childrenOutputProperties); } - public PhysicalProperties getOutputProperties(GroupExpression groupExpression) { - return groupExpression.getPlan().accept(this, new PlanContext(groupExpression)); + public PhysicalProperties getOutputProperties(ConnectContext connectContext, GroupExpression groupExpression) { + return groupExpression.getPlan().accept(this, new PlanContext(connectContext, groupExpression)); } @Override diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/properties/ChildrenPropertiesRegulator.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/properties/ChildrenPropertiesRegulator.java index 93f6e947894c83..1174602da773d8 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/properties/ChildrenPropertiesRegulator.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/properties/ChildrenPropertiesRegulator.java @@ -43,6 +43,7 @@ import org.apache.doris.nereids.trees.plans.physical.PhysicalSetOperation; import org.apache.doris.nereids.trees.plans.visitor.PlanVisitor; import org.apache.doris.nereids.util.JoinUtils; +import org.apache.doris.qe.ConnectContext; import com.google.common.base.Preconditions; import com.google.common.collect.ImmutableList; @@ -453,7 +454,7 @@ private List calAnotherSideRequiredShuffleIds(DistributionSpecHash notSh * * @param shuffleType real output shuffle type * @param notShuffleSideOutput not shuffle side real output used hash spec - * @param shuffleSideOutput shuffle side real output used hash spec + * @param shuffleSideOutput shuffle side real output used hash spec * @param notShuffleSideRequired not shuffle side required used hash spec * @param shuffleSideRequired shuffle side required hash spec * @return shuffle side new required hash spec @@ -481,7 +482,7 @@ private void updateChildEnforceAndCost(int index, PhysicalProperties targetPrope private void updateChildEnforceAndCost(GroupExpression child, PhysicalProperties childOutput, DistributionSpec target, Cost currentCost) { if (child.getPlan() instanceof PhysicalDistribute) { - //To avoid continuous distribute operator, we just enforce the child's child + // To avoid continuous distribute operator, we just enforce the child's child childOutput = child.getInputPropertiesList(childOutput).get(0); Pair newChildAndCost = child.getOwnerGroup().getLowestCostPlan(childOutput).get(); child = newChildAndCost.second; @@ -491,8 +492,9 @@ private void updateChildEnforceAndCost(GroupExpression child, PhysicalProperties PhysicalProperties newOutputProperty = new PhysicalProperties(target); GroupExpression enforcer = target.addEnforcer(child.getOwnerGroup()); child.getOwnerGroup().addEnforcer(enforcer); - Cost totalCost = CostCalculator.addChildCost(enforcer.getPlan(), - CostCalculator.calculateCost(enforcer, Lists.newArrayList(childOutput)), + ConnectContext connectContext = jobContext.getCascadesContext().getConnectContext(); + Cost totalCost = CostCalculator.addChildCost(connectContext, enforcer.getPlan(), + CostCalculator.calculateCost(connectContext, enforcer, Lists.newArrayList(childOutput)), currentCost, 0); diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/properties/EnforceMissingPropertiesHelper.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/properties/EnforceMissingPropertiesHelper.java index d548e3254c47d8..31d99765beb233 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/properties/EnforceMissingPropertiesHelper.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/properties/EnforceMissingPropertiesHelper.java @@ -19,7 +19,6 @@ import org.apache.doris.nereids.cost.Cost; import org.apache.doris.nereids.cost.CostCalculator; -import org.apache.doris.nereids.jobs.JobContext; import org.apache.doris.nereids.memo.GroupExpression; import org.apache.doris.nereids.metrics.EventChannel; import org.apache.doris.nereids.metrics.EventProducer; @@ -28,6 +27,7 @@ import org.apache.doris.nereids.minidump.NereidsTracer; import org.apache.doris.nereids.properties.DistributionSpecHash.ShuffleType; import org.apache.doris.nereids.trees.plans.physical.PhysicalPlan; +import org.apache.doris.qe.ConnectContext; import com.google.common.collect.Lists; @@ -38,13 +38,13 @@ public class EnforceMissingPropertiesHelper { private static final EventProducer ENFORCER_TRACER = new EventProducer(EnforcerEvent.class, EventChannel.getDefaultChannel().addConsumers(new LogConsumer(EnforcerEvent.class, EventChannel.LOG))); - private final JobContext context; + private final ConnectContext connectContext; private final GroupExpression groupExpression; private Cost curTotalCost; - public EnforceMissingPropertiesHelper(JobContext context, GroupExpression groupExpression, + public EnforceMissingPropertiesHelper(ConnectContext connectContext, GroupExpression groupExpression, Cost curTotalCost) { - this.context = context; + this.connectContext = connectContext; this.groupExpression = groupExpression; this.curTotalCost = curTotalCost; } @@ -155,12 +155,15 @@ private void addEnforcerUpdateCost(GroupExpression enforcer, ENFORCER_TRACER.log(EnforcerEvent.of(groupExpression, ((PhysicalPlan) enforcer.getPlan()), oldOutputProperty, newOutputProperty)); enforcer.setEstOutputRowCount(enforcer.getOwnerGroup().getStatistics().getRowCount()); - Cost enforcerCost = CostCalculator.calculateCost(enforcer, Lists.newArrayList(oldOutputProperty)); + Cost enforcerCost = CostCalculator.calculateCost(connectContext, enforcer, + Lists.newArrayList(oldOutputProperty)); enforcer.setCost(enforcerCost); - curTotalCost = CostCalculator.addChildCost(enforcer.getPlan(), - enforcerCost, - curTotalCost, - 0); + curTotalCost = CostCalculator.addChildCost( + connectContext, + enforcer.getPlan(), + enforcerCost, + curTotalCost, + 0); if (enforcer.updateLowestCostTable(newOutputProperty, Lists.newArrayList(oldOutputProperty), curTotalCost)) { enforcer.putOutputPropertiesMap(newOutputProperty, newOutputProperty); diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/properties/RequestPropertyDeriver.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/properties/RequestPropertyDeriver.java index 7d9f8f994e3d98..ef7d72b095ed5c 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/properties/RequestPropertyDeriver.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/properties/RequestPropertyDeriver.java @@ -65,14 +65,17 @@ public class RequestPropertyDeriver extends PlanVisitor { * ▼ * requestPropertyToChildren */ + private final ConnectContext connectContext; private final PhysicalProperties requestPropertyFromParent; private List> requestPropertyToChildren; - public RequestPropertyDeriver(JobContext context) { + public RequestPropertyDeriver(ConnectContext connectContext, JobContext context) { + this.connectContext = connectContext; this.requestPropertyFromParent = context.getRequiredProperties(); } - public RequestPropertyDeriver(PhysicalProperties requestPropertyFromParent) { + public RequestPropertyDeriver(ConnectContext connectContext, PhysicalProperties requestPropertyFromParent) { + this.connectContext = connectContext; this.requestPropertyFromParent = requestPropertyFromParent; } @@ -81,7 +84,7 @@ public RequestPropertyDeriver(PhysicalProperties requestPropertyFromParent) { */ public List> getRequestChildrenPropertyList(GroupExpression groupExpression) { requestPropertyToChildren = Lists.newArrayList(); - groupExpression.getPlan().accept(this, new PlanContext(groupExpression)); + groupExpression.getPlan().accept(this, new PlanContext(connectContext, groupExpression)); return requestPropertyToChildren; } @@ -110,8 +113,7 @@ public Void visit(Plan plan, PlanContext context) { @Override public Void visitPhysicalOlapTableSink(PhysicalOlapTableSink olapTableSink, PlanContext context) { - if (ConnectContext.get() != null && ConnectContext.get().getSessionVariable() != null - && !ConnectContext.get().getSessionVariable().enableStrictConsistencyDml) { + if (connectContext != null && !connectContext.getSessionVariable().enableStrictConsistencyDml) { addRequestPropertyToChildren(PhysicalProperties.ANY); } else { addRequestPropertyToChildren(olapTableSink.getRequirePhysicalProperties()); diff --git a/fe/fe-core/src/test/java/org/apache/doris/nereids/memo/MemoTest.java b/fe/fe-core/src/test/java/org/apache/doris/nereids/memo/MemoTest.java index b1757a4d3f0f35..62d62ebf4324fc 100644 --- a/fe/fe-core/src/test/java/org/apache/doris/nereids/memo/MemoTest.java +++ b/fe/fe-core/src/test/java/org/apache/doris/nereids/memo/MemoTest.java @@ -96,7 +96,7 @@ void testMergeGroup() { FakePlan fakePlan = new FakePlan(); GroupExpression srcParentExpression = new GroupExpression(fakePlan, Lists.newArrayList(srcGroup)); Group srcParentGroup = new Group(new GroupId(0), srcParentExpression, new LogicalProperties(ArrayList::new)); - srcParentGroup.setBestPlan(srcParentExpression, Cost.zero(), PhysicalProperties.ANY); + srcParentGroup.setBestPlan(srcParentExpression, Cost.zeroV1(), PhysicalProperties.ANY); GroupExpression dstParentExpression = new GroupExpression(fakePlan, Lists.newArrayList(dstGroup)); Group dstParentGroup = new Group(new GroupId(1), dstParentExpression, new LogicalProperties(ArrayList::new)); @@ -1069,7 +1069,7 @@ void testRewriteMiddlePlans() { ); // Project -> Project -> Relation - Memo memo = new Memo(rootProject); + Memo memo = new Memo(null, rootProject); Group leafGroup = memo.getGroups().stream().filter(g -> g.getGroupId().asInt() == 0).findFirst().get(); Group targetGroup = memo.getGroups().stream().filter(g -> g.getGroupId().asInt() == 1).findFirst().get(); LogicalProject rewriteInsideProject = new LogicalProject<>( @@ -1126,7 +1126,7 @@ void testEliminateRootWithChildPlanThreeLevels() { ); // Project -> Project -> Relation - Memo memo = new Memo(rootProject); + Memo memo = new Memo(null, rootProject); Group leafGroup = memo.getGroups().stream().filter(g -> g.getGroupId().asInt() == 0).findFirst().get(); Group targetGroup = memo.getGroups().stream().filter(g -> g.getGroupId().asInt() == 2).findFirst().get(); LogicalPlan rewriteProject = insideProject.withChildren(Lists.newArrayList(new GroupPlan(leafGroup))); diff --git a/fe/fe-core/src/test/java/org/apache/doris/nereids/pattern/GroupExpressionMatchingTest.java b/fe/fe-core/src/test/java/org/apache/doris/nereids/pattern/GroupExpressionMatchingTest.java index 6c110d0b158819..53a459859b2448 100644 --- a/fe/fe-core/src/test/java/org/apache/doris/nereids/pattern/GroupExpressionMatchingTest.java +++ b/fe/fe-core/src/test/java/org/apache/doris/nereids/pattern/GroupExpressionMatchingTest.java @@ -48,7 +48,7 @@ public class GroupExpressionMatchingTest { public void testLeafNode() { Pattern pattern = new Pattern<>(PlanType.LOGICAL_UNBOUND_RELATION); - Memo memo = new Memo(new UnboundRelation(StatementScopeIdGenerator.newRelationId(), Lists.newArrayList("test"))); + Memo memo = new Memo(null, new UnboundRelation(StatementScopeIdGenerator.newRelationId(), Lists.newArrayList("test"))); GroupExpressionMatching groupExpressionMatching = new GroupExpressionMatching(pattern, memo.getRoot().getLogicalExpression()); @@ -69,7 +69,7 @@ public void testDepth2() { LogicalProject root = new LogicalProject(ImmutableList .of(new SlotReference("name", StringType.INSTANCE, true, ImmutableList.of("test"))), leaf); - Memo memo = new Memo(root); + Memo memo = new Memo(null, root); Plan anotherLeaf = new UnboundRelation(StatementScopeIdGenerator.newRelationId(), Lists.newArrayList("test2")); memo.copyIn(anotherLeaf, memo.getRoot().getLogicalExpression().child(0), false); @@ -100,7 +100,7 @@ public void testDepth2WithGroup() { LogicalProject root = new LogicalProject(ImmutableList .of(new SlotReference("name", StringType.INSTANCE, true, ImmutableList.of("test"))), leaf); - Memo memo = new Memo(root); + Memo memo = new Memo(null, root); Plan anotherLeaf = new UnboundRelation(StatementScopeIdGenerator.newRelationId(), Lists.newArrayList("test2")); memo.copyIn(anotherLeaf, memo.getRoot().getLogicalExpression().child(0), false); @@ -122,7 +122,7 @@ public void testDepth2WithGroup() { public void testLeafAny() { Pattern pattern = Pattern.ANY; - Memo memo = new Memo(new UnboundRelation(StatementScopeIdGenerator.newRelationId(), Lists.newArrayList("test"))); + Memo memo = new Memo(null, new UnboundRelation(StatementScopeIdGenerator.newRelationId(), Lists.newArrayList("test"))); GroupExpressionMatching groupExpressionMatching = new GroupExpressionMatching(pattern, memo.getRoot().getLogicalExpression()); @@ -140,7 +140,7 @@ public void testAnyWithChild() { ImmutableList.of(new SlotReference("name", StringType.INSTANCE, true, ImmutableList.of("test"))), new UnboundRelation(StatementScopeIdGenerator.newRelationId(), Lists.newArrayList("test"))); - Memo memo = new Memo(root); + Memo memo = new Memo(null, root); Plan anotherLeaf = new UnboundRelation(StatementScopeIdGenerator.newRelationId(), ImmutableList.of("test2")); memo.copyIn(anotherLeaf, memo.getRoot().getLogicalExpression().child(0), false); @@ -165,7 +165,7 @@ public void testInnerLogicalJoinMatch() { new UnboundRelation(StatementScopeIdGenerator.newRelationId(), ImmutableList.of("b")) ); - Memo memo = new Memo(root); + Memo memo = new Memo(null, root); GroupExpressionMatching groupExpressionMatching = new GroupExpressionMatching(patterns().innerLogicalJoin().pattern, @@ -187,7 +187,7 @@ public void testInnerLogicalJoinMismatch() { new UnboundRelation(StatementScopeIdGenerator.newRelationId(), ImmutableList.of("b")) ); - Memo memo = new Memo(root); + Memo memo = new Memo(null, root); GroupExpressionMatching groupExpressionMatching = new GroupExpressionMatching(patterns().innerLogicalJoin().pattern, @@ -204,7 +204,7 @@ public void testTopMatchButChildrenNotMatch() { new UnboundRelation(StatementScopeIdGenerator.newRelationId(), ImmutableList.of("b")) ); - Memo memo = new Memo(root); + Memo memo = new Memo(null, root); Pattern pattern = patterns() .innerLogicalJoin(patterns().logicalFilter(), patterns().any()).pattern; @@ -332,7 +332,7 @@ private org.apache.doris.nereids.pattern.GeneratedMemoPatterns patterns() { } private Iterator match(Plan root, Pattern pattern) { - Memo memo = new Memo(root); + Memo memo = new Memo(null, root); GroupExpressionMatching groupExpressionMatching = new GroupExpressionMatching(pattern, memo.getRoot().getLogicalExpression()); return groupExpressionMatching.iterator(); diff --git a/fe/fe-core/src/test/java/org/apache/doris/nereids/properties/ChildOutputPropertyDeriverTest.java b/fe/fe-core/src/test/java/org/apache/doris/nereids/properties/ChildOutputPropertyDeriverTest.java index be0f64dba9b6eb..f283e6fb87c8fe 100644 --- a/fe/fe-core/src/test/java/org/apache/doris/nereids/properties/ChildOutputPropertyDeriverTest.java +++ b/fe/fe-core/src/test/java/org/apache/doris/nereids/properties/ChildOutputPropertyDeriverTest.java @@ -126,7 +126,7 @@ void testInnerJoin() { List childrenOutputProperties = Lists.newArrayList(left, right); ChildOutputPropertyDeriver deriver = new ChildOutputPropertyDeriver(childrenOutputProperties); - PhysicalProperties result = deriver.getOutputProperties(groupExpression); + PhysicalProperties result = deriver.getOutputProperties(null, groupExpression); Assertions.assertTrue(result.getOrderSpec().getOrderKeys().isEmpty()); Assertions.assertTrue(result.getDistributionSpec() instanceof DistributionSpecHash); DistributionSpecHash actual = (DistributionSpecHash) result.getDistributionSpec(); @@ -165,7 +165,7 @@ void testCrossJoin() { List childrenOutputProperties = Lists.newArrayList(left, right); ChildOutputPropertyDeriver deriver = new ChildOutputPropertyDeriver(childrenOutputProperties); - PhysicalProperties result = deriver.getOutputProperties(groupExpression); + PhysicalProperties result = deriver.getOutputProperties(null, groupExpression); Assertions.assertTrue(result.getOrderSpec().getOrderKeys().isEmpty()); Assertions.assertTrue(result.getDistributionSpec() instanceof DistributionSpecHash); DistributionSpecHash actual = (DistributionSpecHash) result.getDistributionSpec(); @@ -204,7 +204,7 @@ void testLeftOuterJoin() { List childrenOutputProperties = Lists.newArrayList(left, right); ChildOutputPropertyDeriver deriver = new ChildOutputPropertyDeriver(childrenOutputProperties); - PhysicalProperties result = deriver.getOutputProperties(groupExpression); + PhysicalProperties result = deriver.getOutputProperties(null, groupExpression); Assertions.assertTrue(result.getOrderSpec().getOrderKeys().isEmpty()); Assertions.assertTrue(result.getDistributionSpec() instanceof DistributionSpecHash); DistributionSpecHash actual = (DistributionSpecHash) result.getDistributionSpec(); @@ -244,7 +244,7 @@ void testLeftSemiJoin() { List childrenOutputProperties = Lists.newArrayList(left, right); ChildOutputPropertyDeriver deriver = new ChildOutputPropertyDeriver(childrenOutputProperties); - PhysicalProperties result = deriver.getOutputProperties(groupExpression); + PhysicalProperties result = deriver.getOutputProperties(null, groupExpression); Assertions.assertTrue(result.getOrderSpec().getOrderKeys().isEmpty()); Assertions.assertTrue(result.getDistributionSpec() instanceof DistributionSpecHash); DistributionSpecHash actual = (DistributionSpecHash) result.getDistributionSpec(); @@ -284,7 +284,7 @@ void testLeftAntiJoin() { List childrenOutputProperties = Lists.newArrayList(left, right); ChildOutputPropertyDeriver deriver = new ChildOutputPropertyDeriver(childrenOutputProperties); - PhysicalProperties result = deriver.getOutputProperties(groupExpression); + PhysicalProperties result = deriver.getOutputProperties(null, groupExpression); Assertions.assertTrue(result.getOrderSpec().getOrderKeys().isEmpty()); Assertions.assertTrue(result.getDistributionSpec() instanceof DistributionSpecHash); DistributionSpecHash actual = (DistributionSpecHash) result.getDistributionSpec(); @@ -324,7 +324,7 @@ void testNullAwareLeftAntiJoin() { List childrenOutputProperties = Lists.newArrayList(left, right); ChildOutputPropertyDeriver deriver = new ChildOutputPropertyDeriver(childrenOutputProperties); - PhysicalProperties result = deriver.getOutputProperties(groupExpression); + PhysicalProperties result = deriver.getOutputProperties(null, groupExpression); Assertions.assertTrue(result.getOrderSpec().getOrderKeys().isEmpty()); Assertions.assertTrue(result.getDistributionSpec() instanceof DistributionSpecHash); DistributionSpecHash actual = (DistributionSpecHash) result.getDistributionSpec(); @@ -364,7 +364,7 @@ void testRightSemiJoin() { List childrenOutputProperties = Lists.newArrayList(left, right); ChildOutputPropertyDeriver deriver = new ChildOutputPropertyDeriver(childrenOutputProperties); - PhysicalProperties result = deriver.getOutputProperties(groupExpression); + PhysicalProperties result = deriver.getOutputProperties(null, groupExpression); Assertions.assertTrue(result.getOrderSpec().getOrderKeys().isEmpty()); Assertions.assertTrue(result.getDistributionSpec() instanceof DistributionSpecHash); DistributionSpecHash actual = (DistributionSpecHash) result.getDistributionSpec(); @@ -405,7 +405,7 @@ void testRightAntiJoin() { List childrenOutputProperties = Lists.newArrayList(left, right); ChildOutputPropertyDeriver deriver = new ChildOutputPropertyDeriver(childrenOutputProperties); - PhysicalProperties result = deriver.getOutputProperties(groupExpression); + PhysicalProperties result = deriver.getOutputProperties(null, groupExpression); Assertions.assertTrue(result.getOrderSpec().getOrderKeys().isEmpty()); Assertions.assertTrue(result.getDistributionSpec() instanceof DistributionSpecHash); DistributionSpecHash actual = (DistributionSpecHash) result.getDistributionSpec(); @@ -446,7 +446,7 @@ void testRightOuterJoin() { List childrenOutputProperties = Lists.newArrayList(left, right); ChildOutputPropertyDeriver deriver = new ChildOutputPropertyDeriver(childrenOutputProperties); - PhysicalProperties result = deriver.getOutputProperties(groupExpression); + PhysicalProperties result = deriver.getOutputProperties(null, groupExpression); Assertions.assertTrue(result.getOrderSpec().getOrderKeys().isEmpty()); Assertions.assertTrue(result.getDistributionSpec() instanceof DistributionSpecHash); DistributionSpecHash actual = (DistributionSpecHash) result.getDistributionSpec(); @@ -487,7 +487,7 @@ void testFullOuterJoin() { List childrenOutputProperties = Lists.newArrayList(left, right); ChildOutputPropertyDeriver deriver = new ChildOutputPropertyDeriver(childrenOutputProperties); - PhysicalProperties result = deriver.getOutputProperties(groupExpression); + PhysicalProperties result = deriver.getOutputProperties(null, groupExpression); Assertions.assertTrue(result.getOrderSpec().getOrderKeys().isEmpty()); Assertions.assertTrue(result.getDistributionSpec() instanceof DistributionSpecAny); } @@ -530,7 +530,7 @@ Pair, List> getOnClauseUsedSlots( List childrenOutputProperties = Lists.newArrayList(left, right); ChildOutputPropertyDeriver deriver = new ChildOutputPropertyDeriver(childrenOutputProperties); - PhysicalProperties result = deriver.getOutputProperties(groupExpression); + PhysicalProperties result = deriver.getOutputProperties(null, groupExpression); Assertions.assertTrue(result.getOrderSpec().getOrderKeys().isEmpty()); Assertions.assertTrue(result.getDistributionSpec() instanceof DistributionSpecHash); DistributionSpecHash actual = (DistributionSpecHash) result.getDistributionSpec(); @@ -580,7 +580,7 @@ Pair, List> getOnClauseUsedSlots( List childrenOutputProperties = Lists.newArrayList(left, right); ChildOutputPropertyDeriver deriver = new ChildOutputPropertyDeriver(childrenOutputProperties); - PhysicalProperties result = deriver.getOutputProperties(groupExpression); + PhysicalProperties result = deriver.getOutputProperties(null, groupExpression); Assertions.assertTrue(result.getOrderSpec().getOrderKeys().isEmpty()); Assertions.assertTrue(result.getDistributionSpec() instanceof DistributionSpecHash); DistributionSpecHash actual = (DistributionSpecHash) result.getDistributionSpec(); @@ -613,7 +613,7 @@ void testNestedLoopJoin() { List childrenOutputProperties = Lists.newArrayList(left, right); ChildOutputPropertyDeriver deriver = new ChildOutputPropertyDeriver(childrenOutputProperties); - PhysicalProperties result = deriver.getOutputProperties(groupExpression); + PhysicalProperties result = deriver.getOutputProperties(null, groupExpression); Assertions.assertTrue(result.getOrderSpec().getOrderKeys().isEmpty()); Assertions.assertTrue(result.getDistributionSpec() instanceof DistributionSpecHash); DistributionSpecHash actual = (DistributionSpecHash) result.getDistributionSpec(); @@ -639,7 +639,7 @@ void testLocalPhaseAggregate() { new OrderKey(new SlotReference("ignored", IntegerType.INSTANCE), true, true)))); ChildOutputPropertyDeriver deriver = new ChildOutputPropertyDeriver(Lists.newArrayList(child)); - PhysicalProperties result = deriver.getOutputProperties(groupExpression); + PhysicalProperties result = deriver.getOutputProperties(null, groupExpression); Assertions.assertTrue(result.getOrderSpec().getOrderKeys().isEmpty()); Assertions.assertEquals(child.getDistributionSpec(), result.getDistributionSpec()); } @@ -666,7 +666,7 @@ void testGlobalPhaseAggregate() { new OrderKey(new SlotReference("ignored", IntegerType.INSTANCE), true, true)))); ChildOutputPropertyDeriver deriver = new ChildOutputPropertyDeriver(Lists.newArrayList(child)); - PhysicalProperties result = deriver.getOutputProperties(groupExpression); + PhysicalProperties result = deriver.getOutputProperties(null, groupExpression); Assertions.assertTrue(result.getOrderSpec().getOrderKeys().isEmpty()); Assertions.assertTrue(result.getDistributionSpec() instanceof DistributionSpecHash); DistributionSpecHash actual = (DistributionSpecHash) result.getDistributionSpec(); @@ -695,7 +695,7 @@ void testAggregateWithoutGroupBy() { new OrderKey(new SlotReference("ignored", IntegerType.INSTANCE), true, true)))); ChildOutputPropertyDeriver deriver = new ChildOutputPropertyDeriver(Lists.newArrayList(child)); - PhysicalProperties result = deriver.getOutputProperties(groupExpression); + PhysicalProperties result = deriver.getOutputProperties(null, groupExpression); Assertions.assertEquals(PhysicalProperties.GATHER, result); } @@ -711,7 +711,7 @@ void testLocalQuickSort() { new OrderKey(new SlotReference("ignored", IntegerType.INSTANCE), true, true)))); ChildOutputPropertyDeriver deriver = new ChildOutputPropertyDeriver(Lists.newArrayList(child)); - PhysicalProperties result = deriver.getOutputProperties(groupExpression); + PhysicalProperties result = deriver.getOutputProperties(null, groupExpression); Assertions.assertEquals(orderKeys, result.getOrderSpec().getOrderKeys()); Assertions.assertEquals(DistributionSpecReplicated.INSTANCE, result.getDistributionSpec()); } @@ -728,7 +728,7 @@ void testQuickSort() { new OrderKey(new SlotReference("ignored", IntegerType.INSTANCE), true, true)))); ChildOutputPropertyDeriver deriver = new ChildOutputPropertyDeriver(Lists.newArrayList(child)); - PhysicalProperties result = deriver.getOutputProperties(groupExpression); + PhysicalProperties result = deriver.getOutputProperties(null, groupExpression); Assertions.assertEquals(orderKeys, result.getOrderSpec().getOrderKeys()); Assertions.assertEquals(DistributionSpecGather.INSTANCE, result.getDistributionSpec()); } @@ -746,7 +746,7 @@ void testTopN() { new OrderKey(new SlotReference("ignored", IntegerType.INSTANCE), true, true)))); ChildOutputPropertyDeriver deriver = new ChildOutputPropertyDeriver(Lists.newArrayList(child)); - PhysicalProperties result = deriver.getOutputProperties(groupExpression); + PhysicalProperties result = deriver.getOutputProperties(null, groupExpression); Assertions.assertEquals(orderKeys, result.getOrderSpec().getOrderKeys()); Assertions.assertEquals(DistributionSpecReplicated.INSTANCE, result.getDistributionSpec()); // merge/gather sort requires gather @@ -758,7 +758,7 @@ void testTopN() { new OrderKey(new SlotReference("ignored", IntegerType.INSTANCE), true, true)))); deriver = new ChildOutputPropertyDeriver(Lists.newArrayList(child)); - result = deriver.getOutputProperties(groupExpression); + result = deriver.getOutputProperties(null, groupExpression); Assertions.assertEquals(orderKeys, result.getOrderSpec().getOrderKeys()); Assertions.assertEquals(DistributionSpecGather.INSTANCE, result.getDistributionSpec()); } @@ -774,7 +774,7 @@ void testLimit() { new OrderSpec(orderKeys)); ChildOutputPropertyDeriver deriver = new ChildOutputPropertyDeriver(Lists.newArrayList(child)); - PhysicalProperties result = deriver.getOutputProperties(groupExpression); + PhysicalProperties result = deriver.getOutputProperties(null, groupExpression); Assertions.assertEquals(orderKeys, result.getOrderSpec().getOrderKeys()); Assertions.assertEquals(DistributionSpecGather.INSTANCE, result.getDistributionSpec()); } @@ -790,7 +790,7 @@ void testAssertNumRows() { new Group(null, groupExpression, null); PhysicalProperties child = new PhysicalProperties(DistributionSpecGather.INSTANCE, new OrderSpec()); ChildOutputPropertyDeriver deriver = new ChildOutputPropertyDeriver(Lists.newArrayList(child)); - PhysicalProperties result = deriver.getOutputProperties(groupExpression); + PhysicalProperties result = deriver.getOutputProperties(null, groupExpression); Assertions.assertEquals(PhysicalProperties.GATHER, result); } } diff --git a/fe/fe-core/src/test/java/org/apache/doris/nereids/properties/RequestPropertyDeriverTest.java b/fe/fe-core/src/test/java/org/apache/doris/nereids/properties/RequestPropertyDeriverTest.java index 37767273b399cc..e6eaf3884e4fdc 100644 --- a/fe/fe-core/src/test/java/org/apache/doris/nereids/properties/RequestPropertyDeriverTest.java +++ b/fe/fe-core/src/test/java/org/apache/doris/nereids/properties/RequestPropertyDeriverTest.java @@ -90,7 +90,7 @@ void testNestedLoopJoin() { GroupExpression groupExpression = new GroupExpression(join); new Group(null, groupExpression, null); - RequestPropertyDeriver requestPropertyDeriver = new RequestPropertyDeriver(jobContext); + RequestPropertyDeriver requestPropertyDeriver = new RequestPropertyDeriver(null, jobContext); List> actual = requestPropertyDeriver.getRequestChildrenPropertyList(groupExpression); @@ -115,7 +115,7 @@ Pair, List> getHashConjunctsExprIds() { GroupExpression groupExpression = new GroupExpression(join, Lists.newArrayList(group, group)); new Group(null, groupExpression, null); - RequestPropertyDeriver requestPropertyDeriver = new RequestPropertyDeriver(jobContext); + RequestPropertyDeriver requestPropertyDeriver = new RequestPropertyDeriver(null, jobContext); List> actual = requestPropertyDeriver.getRequestChildrenPropertyList(groupExpression); @@ -151,7 +151,7 @@ ConnectContext get() { GroupExpression groupExpression = new GroupExpression(join, Lists.newArrayList(group, group)); new Group(null, groupExpression, null); - RequestPropertyDeriver requestPropertyDeriver = new RequestPropertyDeriver(jobContext); + RequestPropertyDeriver requestPropertyDeriver = new RequestPropertyDeriver(null, jobContext); List> actual = requestPropertyDeriver.getRequestChildrenPropertyList(groupExpression); @@ -179,7 +179,7 @@ void testLocalAggregate() { ); GroupExpression groupExpression = new GroupExpression(aggregate); new Group(null, groupExpression, null); - RequestPropertyDeriver requestPropertyDeriver = new RequestPropertyDeriver(jobContext); + RequestPropertyDeriver requestPropertyDeriver = new RequestPropertyDeriver(null, jobContext); List> actual = requestPropertyDeriver.getRequestChildrenPropertyList(groupExpression); List> expected = Lists.newArrayList(); @@ -202,7 +202,7 @@ void testGlobalAggregate() { ); GroupExpression groupExpression = new GroupExpression(aggregate); new Group(null, groupExpression, null); - RequestPropertyDeriver requestPropertyDeriver = new RequestPropertyDeriver(jobContext); + RequestPropertyDeriver requestPropertyDeriver = new RequestPropertyDeriver(null, jobContext); List> actual = requestPropertyDeriver.getRequestChildrenPropertyList(groupExpression); List> expected = Lists.newArrayList(); @@ -227,7 +227,7 @@ void testGlobalAggregateWithoutPartition() { ); GroupExpression groupExpression = new GroupExpression(aggregate); new Group(null, groupExpression, null); - RequestPropertyDeriver requestPropertyDeriver = new RequestPropertyDeriver(jobContext); + RequestPropertyDeriver requestPropertyDeriver = new RequestPropertyDeriver(null, jobContext); List> actual = requestPropertyDeriver.getRequestChildrenPropertyList(groupExpression); List> expected = Lists.newArrayList(); @@ -244,7 +244,7 @@ void testAssertNumRows() { ); GroupExpression groupExpression = new GroupExpression(assertNumRows); new Group(null, groupExpression, null); - RequestPropertyDeriver requestPropertyDeriver = new RequestPropertyDeriver(jobContext); + RequestPropertyDeriver requestPropertyDeriver = new RequestPropertyDeriver(null, jobContext); List> actual = requestPropertyDeriver.getRequestChildrenPropertyList(groupExpression); List> expected = Lists.newArrayList(); diff --git a/fe/fe-core/src/test/java/org/apache/doris/nereids/util/MatchingUtils.java b/fe/fe-core/src/test/java/org/apache/doris/nereids/util/MatchingUtils.java index c698cdc0f4f555..852d13cd0f3467 100644 --- a/fe/fe-core/src/test/java/org/apache/doris/nereids/util/MatchingUtils.java +++ b/fe/fe-core/src/test/java/org/apache/doris/nereids/util/MatchingUtils.java @@ -33,7 +33,7 @@ public class MatchingUtils { public static void assertMatches(Plan plan, PatternDescriptor patternDesc) { - Memo memo = new Memo(plan); + Memo memo = new Memo(null, plan); if (plan instanceof PhysicalPlan) { assertMatches(memo, () -> new GroupExpressionMatching(patternDesc.pattern, memo.getRoot().getPhysicalExpressions().get(0)).iterator().hasNext(), diff --git a/fe/fe-core/src/test/java/org/apache/doris/nereids/util/PlanParseChecker.java b/fe/fe-core/src/test/java/org/apache/doris/nereids/util/PlanParseChecker.java index 1e3484ae406d84..f4e5a6946ca07b 100644 --- a/fe/fe-core/src/test/java/org/apache/doris/nereids/util/PlanParseChecker.java +++ b/fe/fe-core/src/test/java/org/apache/doris/nereids/util/PlanParseChecker.java @@ -36,13 +36,13 @@ public PlanParseChecker(String sql) { public PlanParseChecker matches(PatternDescriptor patternDesc) { assertMatches(() -> MatchingUtils.topDownFindMatching( - new Memo(parsedSupplier.get()).getRoot(), patternDesc.pattern)); + new Memo(null, parsedSupplier.get()).getRoot(), patternDesc.pattern)); return this; } public PlanParseChecker matchesFromRoot(PatternDescriptor patternDesc) { assertMatches(() -> new GroupExpressionMatching(patternDesc.pattern, - new Memo(parsedSupplier.get()).getRoot().getLogicalExpression()) + new Memo(null, parsedSupplier.get()).getRoot().getLogicalExpression()) .iterator().hasNext()); return this; }