From 067ede14d8331c8ad2c53d8acf96ef12789bb6f0 Mon Sep 17 00:00:00 2001 From: "zhongjian.xzj" Date: Sat, 14 Sep 2024 17:45:18 +0800 Subject: [PATCH] [opt](nereids) refine expression estimation --- .../nereids/stats/ExpressionEstimation.java | 51 ++++++++++--------- 1 file changed, 27 insertions(+), 24 deletions(-) diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/stats/ExpressionEstimation.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/stats/ExpressionEstimation.java index 710a01d59b34ce4..578ba1b00446350 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/stats/ExpressionEstimation.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/stats/ExpressionEstimation.java @@ -41,6 +41,7 @@ import org.apache.doris.nereids.trees.expressions.VirtualSlotReference; import org.apache.doris.nereids.trees.expressions.WhenClause; import org.apache.doris.nereids.trees.expressions.functions.BoundFunction; +import org.apache.doris.nereids.trees.expressions.functions.agg.AggregateFunction; import org.apache.doris.nereids.trees.expressions.functions.agg.Avg; import org.apache.doris.nereids.trees.expressions.functions.agg.Count; import org.apache.doris.nereids.trees.expressions.functions.agg.Max; @@ -188,6 +189,10 @@ public ColumnStatistic visitIf(If ifClause, Statistics context) { @Override public ColumnStatistic visitCast(Cast cast, Statistics context) { + ColumnStatistic stats = context.findColumnStatistics(cast); + if (stats != null) { + return stats; + } ColumnStatistic childColStats = cast.child().accept(this, context); Preconditions.checkNotNull(childColStats, "childColStats is null"); return castMinMax(childColStats, cast.getDataType()); @@ -350,10 +355,7 @@ public ColumnStatistic visitMin(Min min, Statistics context) { return ColumnStatistic.UNKNOWN; } // if this is scalar agg, we will update count and ndv to 1 when visiting group clause - return new ColumnStatisticBuilder(columnStat) - 
.setMinExpr(null).setMinValue(Double.NEGATIVE_INFINITY) - .setMaxExpr(null).setMaxValue(Double.POSITIVE_INFINITY) - .build(); + return new ColumnStatisticBuilder(columnStat).build(); } @Override @@ -364,10 +366,7 @@ public ColumnStatistic visitMax(Max max, Statistics context) { return ColumnStatistic.UNKNOWN; } // if this is scalar agg, we will update count and ndv to 1 when visiting group clause - return new ColumnStatisticBuilder(columnStat) - .setMinExpr(null).setMinValue(Double.NEGATIVE_INFINITY) - .setMaxExpr(null).setMaxValue(Double.POSITIVE_INFINITY) - .build(); + return new ColumnStatisticBuilder(columnStat).build(); } @Override @@ -383,27 +382,13 @@ public ColumnStatistic visitCount(Count count, Statistics context) { // TODO: return a proper estimated stat after supports histogram @Override public ColumnStatistic visitSum(Sum sum, Statistics context) { - ColumnStatistic columnStat = sum.child().accept(this, context); - if (columnStat.isUnKnown) { - return ColumnStatistic.UNKNOWN; - } - return new ColumnStatisticBuilder(columnStat) - .setMinExpr(null).setMinValue(Double.NEGATIVE_INFINITY) - .setMaxExpr(null).setMaxValue(Double.POSITIVE_INFINITY) - .build(); + return sum.child().accept(this, context); } // TODO: return a proper estimated stat after supports histogram @Override public ColumnStatistic visitAvg(Avg avg, Statistics context) { - ColumnStatistic columnStat = avg.child().accept(this, context); - if (columnStat.isUnKnown) { - return ColumnStatistic.UNKNOWN; - } - return new ColumnStatisticBuilder(columnStat) - .setMinExpr(null).setMinValue(Double.NEGATIVE_INFINITY) - .setMaxExpr(null).setMaxValue(Double.POSITIVE_INFINITY) - .build(); + return avg.child().accept(this, context); } @Override @@ -467,6 +452,24 @@ public ColumnStatistic visitAggregateExpression(AggregateExpression aggregateExp .build(); } + @Override + public ColumnStatistic visitAggregateFunction(AggregateFunction aggregateExpression, + Statistics context) { + if 
(aggregateExpression.children().size() == 1) { + ColumnStatistic columnStat = aggregateExpression.child(0).accept(this, context); + if (columnStat.isUnKnown) { + return ColumnStatistic.UNKNOWN; + } + // by default, reset min/max to invalid values to avoid misuse + return new ColumnStatisticBuilder(columnStat) + .setMinExpr(null).setMinValue(Double.NEGATIVE_INFINITY) + .setMaxExpr(null).setMaxValue(Double.POSITIVE_INFINITY) + .build(); + } else { + return visit(aggregateExpression, context); + } + } + @Override public ColumnStatistic visitComparisonPredicate(ComparisonPredicate cp, Statistics context) { ColumnStatistic leftStats = cp.left().accept(this, context);