diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/stats/ExpressionEstimation.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/stats/ExpressionEstimation.java
index 126e90417213125..f76f63218ca0d8c 100644
--- a/fe/fe-core/src/main/java/org/apache/doris/nereids/stats/ExpressionEstimation.java
+++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/stats/ExpressionEstimation.java
@@ -41,6 +41,7 @@
 import org.apache.doris.nereids.trees.expressions.VirtualSlotReference;
 import org.apache.doris.nereids.trees.expressions.WhenClause;
 import org.apache.doris.nereids.trees.expressions.functions.BoundFunction;
+import org.apache.doris.nereids.trees.expressions.functions.agg.AggregateFunction;
 import org.apache.doris.nereids.trees.expressions.functions.agg.Avg;
 import org.apache.doris.nereids.trees.expressions.functions.agg.Count;
 import org.apache.doris.nereids.trees.expressions.functions.agg.Max;
@@ -128,6 +129,10 @@ public static ColumnStatistic estimate(Expression expression, Statistics stats)
 
     @Override
     public ColumnStatistic visit(Expression expr, Statistics context) {
+        ColumnStatistic stats = context.findColumnStatistics(expr);
+        if (stats != null) {
+            return stats;
+        }
         List<Expression> childrenExpr = expr.children();
         if (CollectionUtils.isEmpty(childrenExpr)) {
             return ColumnStatistic.UNKNOWN;
@@ -135,26 +140,28 @@ public ColumnStatistic visit(Expression expr, Statistics context) {
         return expr.child(0).accept(this, context);
     }
 
-    //TODO: case-when need to re-implemented
     @Override
     public ColumnStatistic visitCaseWhen(CaseWhen caseWhen, Statistics context) {
         double ndv = caseWhen.getWhenClauses().size();
+        double width = 1;
         if (caseWhen.getDefaultValue().isPresent()) {
             ndv += 1;
         }
         for (WhenClause clause : caseWhen.getWhenClauses()) {
             ColumnStatistic colStats = ExpressionEstimation.estimate(clause.getResult(), context);
             ndv = Math.max(ndv, colStats.ndv);
+            width = Math.max(width, clause.getResult().getDataType().width());
         }
         if (caseWhen.getDefaultValue().isPresent()) {
             ColumnStatistic colStats = ExpressionEstimation.estimate(caseWhen.getDefaultValue().get(), context);
             ndv = Math.max(ndv, colStats.ndv);
+            width = Math.max(width, caseWhen.getDefaultValue().get().getDataType().width());
         }
         return new ColumnStatisticBuilder()
                 .setNdv(ndv)
                 .setMinValue(Double.NEGATIVE_INFINITY)
                 .setMaxValue(Double.POSITIVE_INFINITY)
-                .setAvgSizeByte(8)
+                .setAvgSizeByte(width)
                 .setNumNulls(0)
                 .build();
     }
@@ -162,15 +169,20 @@ public ColumnStatistic visitCaseWhen(CaseWhen caseWhen, Statistics context) {
     @Override
     public ColumnStatistic visitIf(If ifClause, Statistics context) {
         double ndv = 2;
+        double width = 1;
         ColumnStatistic colStatsThen = ExpressionEstimation.estimate(ifClause.child(1), context);
         ndv = Math.max(ndv, colStatsThen.ndv);
+        width = Math.max(width, ifClause.child(1).getDataType().width());
+
         ColumnStatistic colStatsElse = ExpressionEstimation.estimate(ifClause.child(2), context);
         ndv = Math.max(ndv, colStatsElse.ndv);
+        width = Math.max(width, ifClause.child(2).getDataType().width());
+
         return new ColumnStatisticBuilder()
                 .setNdv(ndv)
                 .setMinValue(Double.NEGATIVE_INFINITY)
                 .setMaxValue(Double.POSITIVE_INFINITY)
-                .setAvgSizeByte(8)
+                .setAvgSizeByte(width)
                 .setNumNulls(0)
                 .build();
     }
@@ -242,9 +254,9 @@ public ColumnStatistic visitLiteral(Literal literal, Statistics context) {
         return new ColumnStatisticBuilder()
                 .setMaxValue(literalVal)
                 .setMinValue(literalVal)
-                .setNdv(1)
+                .setNdv(literal.isNullLiteral() ? 0 : 1)
                 .setNumNulls(literal.isNullLiteral() ? 1 : 0)
-                .setAvgSizeByte(1)
+                .setAvgSizeByte(literal.getWidth())
                 .setMinExpr(literal.toLegacyLiteral())
                 .setMaxExpr(literal.toLegacyLiteral())
                 .build();
@@ -343,8 +355,7 @@ public ColumnStatistic visitMin(Min min, Statistics context) {
             return ColumnStatistic.UNKNOWN;
         }
         // if this is scalar agg, we will update count and ndv to 1 when visiting group clause
-        return new ColumnStatisticBuilder(columnStat)
-                .build();
+        return new ColumnStatisticBuilder(columnStat).build();
     }
 
     @Override
@@ -355,15 +366,14 @@ public ColumnStatistic visitMax(Max max, Statistics context) {
             return ColumnStatistic.UNKNOWN;
         }
         // if this is scalar agg, we will update count and ndv to 1 when visiting group clause
-        return new ColumnStatisticBuilder(columnStat)
-                .build();
+        return new ColumnStatisticBuilder(columnStat).build();
     }
 
     @Override
     public ColumnStatistic visitCount(Count count, Statistics context) {
         double width = count.getDataType().width();
         // for scalar agg, ndv and row count will be normalized by 1 in StatsCalculator.computeAggregate()
         return new ColumnStatisticBuilder(ColumnStatistic.UNKNOWN).setAvgSizeByte(width).build();
     }
 
     // TODO: return a proper estimated stat after supports histogram
@@ -432,6 +442,24 @@ public ColumnStatistic visitAggregateExpression(AggregateExpression aggregateExp
         return aggregateExpression.child().accept(this, context);
     }
 
+    @Override
+    public ColumnStatistic visitAggregateFunction(AggregateFunction aggregateExpression,
+            Statistics context) {
+        if (aggregateExpression.children().size() == 1) {
+            ColumnStatistic columnStat = aggregateExpression.child(0).accept(this, context);
+            if (columnStat.isUnKnown) {
+                return ColumnStatistic.UNKNOWN;
+            }
+            // by default, reset min/max to invalid values to avoid misuse
+            return new ColumnStatisticBuilder(columnStat)
+                    .setMinExpr(null).setMinValue(Double.NEGATIVE_INFINITY)
+                    .setMaxExpr(null).setMaxValue(Double.POSITIVE_INFINITY)
+                    .build();
+        } else {
+            return visit(aggregateExpression, context);
+        }
+    }
+
     @Override
     public ColumnStatistic visitComparisonPredicate(ComparisonPredicate cp, Statistics context) {
         ColumnStatistic leftStats = cp.left().accept(this, context);
diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/stats/FilterEstimation.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/stats/FilterEstimation.java
index b3576a0e58e61e6..7a8a698713ac35e 100644
--- a/fe/fe-core/src/main/java/org/apache/doris/nereids/stats/FilterEstimation.java
+++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/stats/FilterEstimation.java
@@ -334,7 +334,8 @@ private Statistics estimateEqualTo(ComparisonPredicate cp, ColumnStatistic stats
             } else {
                 double val = statsForRight.maxValue;
                 if (val > statsForLeft.maxValue || val < statsForLeft.minValue) {
-                    selectivity = 0.0;
+                    // do a lower bound protection to avoid using 0 directly
+                    selectivity = RANGE_SELECTIVITY_THRESHOLD;
                 } else if (ndv >= 1.0) {
                     selectivity = StatsMathUtil.minNonNaN(1.0, 1.0 / ndv);
                 } else {