Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[opt](nereids) refine expression estimation #40698

Merged
merged 2 commits into from
Sep 24, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -128,49 +128,60 @@ public static ColumnStatistic estimate(Expression expression, Statistics stats)

@Override
public ColumnStatistic visit(Expression expr, Statistics context) {
ColumnStatistic stats = context.findColumnStatistics(expr);
if (stats != null) {
return stats;
}
List<Expression> childrenExpr = expr.children();
if (CollectionUtils.isEmpty(childrenExpr)) {
return ColumnStatistic.UNKNOWN;
}
return expr.child(0).accept(this, context);
}

//TODO: case-when need to re-implemented
@Override
public ColumnStatistic visitCaseWhen(CaseWhen caseWhen, Statistics context) {
double ndv = caseWhen.getWhenClauses().size();
double width = 1;
if (caseWhen.getDefaultValue().isPresent()) {
ndv += 1;
}
for (WhenClause clause : caseWhen.getWhenClauses()) {
ColumnStatistic colStats = ExpressionEstimation.estimate(clause.getResult(), context);
ndv = Math.max(ndv, colStats.ndv);
width = Math.max(width, clause.getResult().getDataType().width());
}
if (caseWhen.getDefaultValue().isPresent()) {
ColumnStatistic colStats = ExpressionEstimation.estimate(caseWhen.getDefaultValue().get(), context);
ndv = Math.max(ndv, colStats.ndv);
width = Math.max(width, caseWhen.getDefaultValue().get().getDataType().width());
}
return new ColumnStatisticBuilder()
.setNdv(ndv)
.setMinValue(Double.NEGATIVE_INFINITY)
.setMaxValue(Double.POSITIVE_INFINITY)
.setAvgSizeByte(8)
.setAvgSizeByte(width)
.setNumNulls(0)
.build();
}

@Override
public ColumnStatistic visitIf(If ifClause, Statistics context) {
double ndv = 2;
double width = 1;
ColumnStatistic colStatsThen = ExpressionEstimation.estimate(ifClause.child(1), context);
ndv = Math.max(ndv, colStatsThen.ndv);
width = Math.max(width, ifClause.child(1).getDataType().width());

ColumnStatistic colStatsElse = ExpressionEstimation.estimate(ifClause.child(2), context);
ndv = Math.max(ndv, colStatsElse.ndv);
width = Math.max(width, ifClause.child(2).getDataType().width());

return new ColumnStatisticBuilder()
.setNdv(ndv)
.setMinValue(Double.NEGATIVE_INFINITY)
.setMaxValue(Double.POSITIVE_INFINITY)
.setAvgSizeByte(8)
.setAvgSizeByte(width)
.setNumNulls(0)
.build();
}
Expand Down Expand Up @@ -242,9 +253,9 @@ public ColumnStatistic visitLiteral(Literal literal, Statistics context) {
return new ColumnStatisticBuilder()
.setMaxValue(literalVal)
.setMinValue(literalVal)
.setNdv(1)
.setNdv(literal.isNullLiteral() ? 0 : 1)
.setNumNulls(literal.isNullLiteral() ? 1 : 0)
.setAvgSizeByte(1)
.setAvgSizeByte(literal.getDataType().width())
.setMinExpr(literal.toLegacyLiteral())
.setMaxExpr(literal.toLegacyLiteral())
.build();
Expand Down Expand Up @@ -343,8 +354,7 @@ public ColumnStatistic visitMin(Min min, Statistics context) {
return ColumnStatistic.UNKNOWN;
}
// if this is scalar agg, we will update count and ndv to 1 when visiting group clause
return new ColumnStatisticBuilder(columnStat)
.build();
return new ColumnStatisticBuilder(columnStat).build();
}

@Override
Expand All @@ -355,8 +365,7 @@ public ColumnStatistic visitMax(Max max, Statistics context) {
return ColumnStatistic.UNKNOWN;
}
// if this is scalar agg, we will update count and ndv to 1 when visiting group clause
return new ColumnStatisticBuilder(columnStat)
.build();
return new ColumnStatisticBuilder(columnStat).build();
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -334,6 +334,7 @@ private Statistics estimateEqualTo(ComparisonPredicate cp, ColumnStatistic stats
} else {
double val = statsForRight.maxValue;
if (val > statsForLeft.maxValue || val < statsForLeft.minValue) {
// TODO: will fix this in the next pr by adding RangeScalable protection
selectivity = 0.0;
} else if (ndv >= 1.0) {
selectivity = StatsMathUtil.minNonNaN(1.0, 1.0 / ndv);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,11 +31,6 @@ private StringType() {
super(-1);
}

@Override
public int width() {
return len;
}

@Override
public Type toCatalogDataType() {
return Type.STRING;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,8 +26,8 @@
*/
public abstract class CharacterType extends PrimitiveType {

public static final int DEFAULT_SLOT_SIZE = 20;
private static final int WIDTH = 16;
public static final int DEFAULT_WIDTH = WIDTH;

protected final int len;

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -199,8 +199,7 @@ public void normalizeAvgSizeByte(SlotReference slot) {
// When defining SQL schemas, users often tend to set the length of string \
// fields much longer than actually needed for storage.
if (slot.getDataType() instanceof CharacterType) {
avgSizeByte = Math.min(avgSizeByte,
CharacterType.DEFAULT_SLOT_SIZE);
avgSizeByte = Math.min(avgSizeByte, CharacterType.DEFAULT_WIDTH);
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -151,7 +151,7 @@ public double computeTupleSize(List<Slot> slots) {
for (Slot slot : slots) {
ColumnStatistic s = expressionToColumnStats.get(slot);
if (s != null) {
tempSize += Math.max(1, Math.min(CharacterType.DEFAULT_SLOT_SIZE, s.avgSizeByte));
tempSize += Math.max(1, Math.min(CharacterType.DEFAULT_WIDTH, s.avgSizeByte));
}
}
tupleSize = Math.max(1, tempSize);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,14 @@
import org.apache.doris.nereids.trees.expressions.functions.agg.Max;
import org.apache.doris.nereids.trees.expressions.functions.agg.Min;
import org.apache.doris.nereids.trees.expressions.functions.scalar.If;
import org.apache.doris.nereids.trees.expressions.literal.BigIntLiteral;
import org.apache.doris.nereids.trees.expressions.literal.BooleanLiteral;
import org.apache.doris.nereids.trees.expressions.literal.DateTimeLiteral;
import org.apache.doris.nereids.trees.expressions.literal.DateV2Literal;
import org.apache.doris.nereids.trees.expressions.literal.DecimalLiteral;
import org.apache.doris.nereids.trees.expressions.literal.DoubleLiteral;
import org.apache.doris.nereids.trees.expressions.literal.NullLiteral;
import org.apache.doris.nereids.trees.expressions.literal.VarcharLiteral;
import org.apache.doris.nereids.types.DateType;
import org.apache.doris.nereids.types.DoubleType;
import org.apache.doris.nereids.types.IntegerType;
Expand All @@ -44,6 +51,7 @@
import org.junit.jupiter.api.Assertions;
import org.junit.jupiter.api.Test;

import java.math.BigDecimal;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
Expand Down Expand Up @@ -357,6 +365,7 @@ public void testCaseWhen() {
CaseWhen caseWhen = new CaseWhen(whens);
ColumnStatistic est = ExpressionEstimation.estimate(caseWhen, stats);
Assertions.assertEquals(est.ndv, 100);
Assertions.assertEquals(est.avgSizeByte, 16);
}

@Test
Expand All @@ -383,5 +392,59 @@ public void testIf() {
If ifClause = new If(BooleanLiteral.TRUE, a, b);
ColumnStatistic est = ExpressionEstimation.estimate(ifClause, stats);
Assertions.assertEquals(est.ndv, 100);
Assertions.assertEquals(est.avgSizeByte, 16);
}

@Test
public void testLiteral() {
Statistics stats = new Statistics(1000, new HashMap<>());

BigIntLiteral l1 = new BigIntLiteral(1000000);
ColumnStatistic est = ExpressionEstimation.estimate(l1, stats);
Assertions.assertEquals(est.ndv, 1);
Assertions.assertEquals(est.avgSizeByte, 8);
Assertions.assertEquals(est.numNulls, 0);

VarcharLiteral l2 = new VarcharLiteral("abcdefghij");
est = ExpressionEstimation.estimate(l2, stats);
Assertions.assertEquals(est.ndv, 1);
Assertions.assertEquals(est.avgSizeByte, 10);
Assertions.assertEquals(est.numNulls, 0);

DoubleLiteral l3 = new DoubleLiteral(0.01);
est = ExpressionEstimation.estimate(l3, stats);
Assertions.assertEquals(est.ndv, 1);
Assertions.assertEquals(est.avgSizeByte, 8);
Assertions.assertEquals(est.numNulls, 0);

DateV2Literal l4 = new DateV2Literal("2024-09-10");
est = ExpressionEstimation.estimate(l4, stats);
Assertions.assertEquals(est.ndv, 1);
Assertions.assertEquals(est.avgSizeByte, 4);
Assertions.assertEquals(est.numNulls, 0);

DateTimeLiteral l5 = new DateTimeLiteral("2024-09-10 00:00:00");
est = ExpressionEstimation.estimate(l5, stats);
Assertions.assertEquals(est.ndv, 1);
Assertions.assertEquals(est.avgSizeByte, 16);
Assertions.assertEquals(est.numNulls, 0);

BooleanLiteral l6 = BooleanLiteral.TRUE;
est = ExpressionEstimation.estimate(l6, stats);
Assertions.assertEquals(est.ndv, 1);
Assertions.assertEquals(est.avgSizeByte, 1);
Assertions.assertEquals(est.numNulls, 0);

DecimalLiteral l7 = new DecimalLiteral(BigDecimal.valueOf(2024.0928));
est = ExpressionEstimation.estimate(l7, stats);
Assertions.assertEquals(est.ndv, 1);
Assertions.assertEquals(est.avgSizeByte, 16);
Assertions.assertEquals(est.numNulls, 0);

NullLiteral l8 = new NullLiteral();
est = ExpressionEstimation.estimate(l8, stats);
Assertions.assertEquals(est.ndv, 0);
Assertions.assertEquals(est.avgSizeByte, 1);
Assertions.assertEquals(est.numNulls, 1);
}
}
Loading