Skip to content

Commit

Permalink
[opt](nereids) refine expression estimation
Browse files Browse the repository at this point in the history
  • Loading branch information
zhongjian.xzj authored and zhongjian.xzj committed Sep 19, 2024
1 parent 07a5362 commit 0c0bc72
Show file tree
Hide file tree
Showing 2 changed files with 46 additions and 10 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@
import org.apache.doris.nereids.trees.expressions.VirtualSlotReference;
import org.apache.doris.nereids.trees.expressions.WhenClause;
import org.apache.doris.nereids.trees.expressions.functions.BoundFunction;
import org.apache.doris.nereids.trees.expressions.functions.agg.AggregateFunction;
import org.apache.doris.nereids.trees.expressions.functions.agg.Avg;
import org.apache.doris.nereids.trees.expressions.functions.agg.Count;
import org.apache.doris.nereids.trees.expressions.functions.agg.Max;
Expand Down Expand Up @@ -128,49 +129,60 @@ public static ColumnStatistic estimate(Expression expression, Statistics stats)

@Override
public ColumnStatistic visit(Expression expr, Statistics context) {
ColumnStatistic stats = context.findColumnStatistics(expr);
if (stats != null) {
return stats;
}
List<Expression> childrenExpr = expr.children();
if (CollectionUtils.isEmpty(childrenExpr)) {
return ColumnStatistic.UNKNOWN;
}
return expr.child(0).accept(this, context);
}

//TODO: case-when need to re-implemented
@Override
public ColumnStatistic visitCaseWhen(CaseWhen caseWhen, Statistics context) {
double ndv = caseWhen.getWhenClauses().size();
double width = 1;
if (caseWhen.getDefaultValue().isPresent()) {
ndv += 1;
}
for (WhenClause clause : caseWhen.getWhenClauses()) {
ColumnStatistic colStats = ExpressionEstimation.estimate(clause.getResult(), context);
ndv = Math.max(ndv, colStats.ndv);
width = Math.max(width, clause.getResult().getDataType().width());
}
if (caseWhen.getDefaultValue().isPresent()) {
ColumnStatistic colStats = ExpressionEstimation.estimate(caseWhen.getDefaultValue().get(), context);
ndv = Math.max(ndv, colStats.ndv);
width = Math.max(width, caseWhen.getDefaultValue().get().getDataType().width());
}
return new ColumnStatisticBuilder()
.setNdv(ndv)
.setMinValue(Double.NEGATIVE_INFINITY)
.setMaxValue(Double.POSITIVE_INFINITY)
.setAvgSizeByte(8)
.setAvgSizeByte(width)
.setNumNulls(0)
.build();
}

@Override
public ColumnStatistic visitIf(If ifClause, Statistics context) {
double ndv = 2;
double width = 1;
ColumnStatistic colStatsThen = ExpressionEstimation.estimate(ifClause.child(1), context);
ndv = Math.max(ndv, colStatsThen.ndv);
width = Math.max(width, ifClause.child(1).getDataType().width());

ColumnStatistic colStatsElse = ExpressionEstimation.estimate(ifClause.child(2), context);
ndv = Math.max(ndv, colStatsElse.ndv);
width = Math.max(width, ifClause.child(2).getDataType().width());

return new ColumnStatisticBuilder()
.setNdv(ndv)
.setMinValue(Double.NEGATIVE_INFINITY)
.setMaxValue(Double.POSITIVE_INFINITY)
.setAvgSizeByte(8)
.setAvgSizeByte(width)
.setNumNulls(0)
.build();
}
Expand Down Expand Up @@ -242,9 +254,9 @@ public ColumnStatistic visitLiteral(Literal literal, Statistics context) {
return new ColumnStatisticBuilder()
.setMaxValue(literalVal)
.setMinValue(literalVal)
.setNdv(1)
.setNdv(literal.isNullLiteral() ? 0 : 1)
.setNumNulls(literal.isNullLiteral() ? 1 : 0)
.setAvgSizeByte(1)
.setAvgSizeByte(literal.getWidth())
.setMinExpr(literal.toLegacyLiteral())
.setMaxExpr(literal.toLegacyLiteral())
.build();
Expand Down Expand Up @@ -343,8 +355,7 @@ public ColumnStatistic visitMin(Min min, Statistics context) {
return ColumnStatistic.UNKNOWN;
}
// if this is scalar agg, we will update count and ndv to 1 when visiting group clause
return new ColumnStatisticBuilder(columnStat)
.build();
return new ColumnStatisticBuilder(columnStat).build();
}

@Override
Expand All @@ -355,15 +366,21 @@ public ColumnStatistic visitMax(Max max, Statistics context) {
return ColumnStatistic.UNKNOWN;
}
// if this is scalar agg, we will update count and ndv to 1 when visiting group clause
return new ColumnStatisticBuilder(columnStat)
.build();
return new ColumnStatisticBuilder(columnStat).build();
}

@Override
public ColumnStatistic visitCount(Count count, Statistics context) {
double width = count.getDataType().width();
// for scalar agg, ndv and row count will be normalized by 1 in StatsCalculator.computeAggregate()
<<<<<<< HEAD
return new ColumnStatisticBuilder(ColumnStatistic.UNKNOWN).setAvgSizeByte(width).build();
=======
return new ColumnStatisticBuilder(ColumnStatistic.UNKNOWN)
.setCount(context.getRowCount())
.setAvgSizeByte(width)
.build();
>>>>>>> f0f7c7eaed ([opt](nereids) refine expression estimation)
}

// TODO: return a proper estimated stat after supports histogram
Expand Down Expand Up @@ -432,6 +449,24 @@ public ColumnStatistic visitAggregateExpression(AggregateExpression aggregateExp
return aggregateExpression.child().accept(this, context);
}

@Override
public ColumnStatistic visitAggregateFunction(AggregateFunction aggregateExpression,
Statistics context) {
if (aggregateExpression.children().size() == 1) {
ColumnStatistic columnStat = aggregateExpression.child(0).accept(this, context);
if (columnStat.isUnKnown) {
return ColumnStatistic.UNKNOWN;
}
// by default reset min/max as invalid to avoid misused
return new ColumnStatisticBuilder(columnStat)
.setMinExpr(null).setMinValue(Double.NEGATIVE_INFINITY)
.setMaxExpr(null).setMaxValue(Double.POSITIVE_INFINITY)
.build();
} else {
return visit(aggregateExpression, context);
}
}

@Override
public ColumnStatistic visitComparisonPredicate(ComparisonPredicate cp, Statistics context) {
ColumnStatistic leftStats = cp.left().accept(this, context);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -334,7 +334,8 @@ private Statistics estimateEqualTo(ComparisonPredicate cp, ColumnStatistic stats
} else {
double val = statsForRight.maxValue;
if (val > statsForLeft.maxValue || val < statsForLeft.minValue) {
selectivity = 0.0;
// do a lower bound protection to avoid using 0 directly
selectivity = RANGE_SELECTIVITY_THRESHOLD;
} else if (ndv >= 1.0) {
selectivity = StatsMathUtil.minNonNaN(1.0, 1.0 / ndv);
} else {
Expand Down

0 comments on commit 0c0bc72

Please sign in to comment.