Skip to content

Commit

Permalink
[fix](Nereids) simplify decimal comparison wrong when cast to smaller…
Browse files Browse the repository at this point in the history
… scale (apache#41151) (apache#41364)

pick from master apache#41151
  • Loading branch information
morrySnow authored Sep 26, 2024
1 parent cb98961 commit 446c6b5
Show file tree
Hide file tree
Showing 14 changed files with 2,212 additions and 2,039 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@
import com.google.common.collect.ImmutableList;

import java.math.BigDecimal;
import java.math.RoundingMode;
import java.util.List;

/**
Expand Down Expand Up @@ -112,8 +113,14 @@ public static Expression simplifyCast(Cast cast) {
return new DecimalV3Literal(decimalV3Type,
new BigDecimal(((BigIntLiteral) child).getValue()));
} else if (child instanceof DecimalV3Literal) {
return new DecimalV3Literal(decimalV3Type,
((DecimalV3Literal) child).getValue());
DecimalV3Type childType = (DecimalV3Type) child.getDataType();
if (childType.getRange() <= decimalV3Type.getRange()) {
return new DecimalV3Literal(decimalV3Type,
((DecimalV3Literal) child).getValue()
.setScale(decimalV3Type.getScale(), RoundingMode.HALF_UP));
} else {
return cast;
}
}
}
} catch (Throwable t) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -234,9 +234,10 @@ private static Expression processDecimalV3TypeCoercion(ComparisonPredicate compa
int toScale = ((DecimalV3Type) left.getDataType()).getScale();
if (comparisonPredicate instanceof EqualTo) {
try {
return comparisonPredicate.withChildren(left,
new DecimalV3Literal((DecimalV3Type) left.getDataType(),
literal.getValue().setScale(toScale)));
return TypeCoercionUtils.processComparisonPredicate((ComparisonPredicate)
comparisonPredicate.withChildren(left,
new DecimalV3Literal((DecimalV3Type) left.getDataType(),
literal.getValue().setScale(toScale, RoundingMode.UNNECESSARY))));
} catch (ArithmeticException e) {
if (left.nullable()) {
// TODO: the ideal way is to return an If expr like:
Expand All @@ -253,24 +254,25 @@ private static Expression processDecimalV3TypeCoercion(ComparisonPredicate compa
}
} else if (comparisonPredicate instanceof NullSafeEqual) {
try {
return comparisonPredicate.withChildren(left,
new DecimalV3Literal((DecimalV3Type) left.getDataType(),
literal.getValue().setScale(toScale)));
return TypeCoercionUtils.processComparisonPredicate((ComparisonPredicate)
comparisonPredicate.withChildren(left,
new DecimalV3Literal((DecimalV3Type) left.getDataType(),
literal.getValue().setScale(toScale, RoundingMode.UNNECESSARY))));
} catch (ArithmeticException e) {
return BooleanLiteral.of(false);
}
} else if (comparisonPredicate instanceof GreaterThan
|| comparisonPredicate instanceof LessThanEqual) {
return comparisonPredicate.withChildren(left, literal.roundFloor(toScale));
return TypeCoercionUtils.processComparisonPredicate((ComparisonPredicate)
comparisonPredicate.withChildren(left, literal.roundFloor(toScale)));
} else if (comparisonPredicate instanceof LessThan
|| comparisonPredicate instanceof GreaterThanEqual) {
return comparisonPredicate.withChildren(left,
literal.roundCeiling(toScale));
return TypeCoercionUtils.processComparisonPredicate((ComparisonPredicate)
comparisonPredicate.withChildren(left, literal.roundCeiling(toScale)));
}
}
} else if (left.getDataType().isIntegerLikeType()) {
return processIntegerDecimalLiteralComparison(comparisonPredicate, left,
literal.getValue());
return processIntegerDecimalLiteralComparison(comparisonPredicate, left, literal.getValue());
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@

package org.apache.doris.nereids.rules.expression.rules;

import org.apache.doris.nereids.exceptions.AnalysisException;
import org.apache.doris.nereids.rules.expression.ExpressionPatternMatcher;
import org.apache.doris.nereids.rules.expression.ExpressionPatternRuleFactory;
import org.apache.doris.nereids.trees.expressions.Cast;
Expand All @@ -25,7 +26,6 @@
import org.apache.doris.nereids.trees.expressions.literal.DecimalV3Literal;
import org.apache.doris.nereids.types.DecimalV3Type;

import com.google.common.base.Preconditions;
import com.google.common.collect.ImmutableList;

import java.math.BigDecimal;
Expand Down Expand Up @@ -58,15 +58,17 @@ public static Expression simplify(ComparisonPredicate cp) {
if (left.getDataType() instanceof DecimalV3Type
&& left instanceof Cast
&& ((Cast) left).child().getDataType() instanceof DecimalV3Type
&& ((DecimalV3Type) left.getDataType()).getScale()
>= ((DecimalV3Type) ((Cast) left).child().getDataType()).getScale()
&& right instanceof DecimalV3Literal) {
return doProcess(cp, (Cast) left, (DecimalV3Literal) right);
try {
return doProcess(cp, (Cast) left, (DecimalV3Literal) right);
} catch (ArithmeticException e) {
return cp;
}
}

if (left != cp.left() || right != cp.right()) {
return cp.withChildren(left, right);
} else {
return cp;
}
return cp;
}

private static Expression doProcess(ComparisonPredicate cp, Cast left, DecimalV3Literal right) {
Expand All @@ -80,13 +82,16 @@ private static Expression doProcess(ComparisonPredicate cp, Cast left, DecimalV3
}

Expression castChild = left.child();
Preconditions.checkState(castChild.getDataType() instanceof DecimalV3Type);
if (!(castChild.getDataType() instanceof DecimalV3Type)) {
throw new AnalysisException("cast child's type should be DecimalV3Type, but its type is "
+ castChild.getDataType().toSql());
}
DecimalV3Type leftType = (DecimalV3Type) castChild.getDataType();
if (scale <= leftType.getScale() && precision - scale <= leftType.getPrecision() - leftType.getScale()) {
if (scale <= leftType.getScale() && precision - scale <= leftType.getRange()) {
// precision and scale of literal all smaller than left, we don't need the cast
DecimalV3Literal newRight = new DecimalV3Literal(
DecimalV3Type.createDecimalV3TypeLooseCheck(leftType.getPrecision(), leftType.getScale()),
trailingZerosValue);
trailingZerosValue.setScale(leftType.getScale(), RoundingMode.UNNECESSARY));
return cp.withChildren(castChild, newRight);
} else {
return cp;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -73,15 +73,11 @@ public double getDouble() {
}

public DecimalV3Literal roundCeiling(int newScale) {
return new DecimalV3Literal(DecimalV3Type
.createDecimalV3Type(((DecimalV3Type) dataType).getPrecision(), newScale),
value.setScale(newScale, RoundingMode.CEILING));
return new DecimalV3Literal(value.setScale(newScale, RoundingMode.CEILING));
}

public DecimalV3Literal roundFloor(int newScale) {
return new DecimalV3Literal(DecimalV3Type
.createDecimalV3Type(((DecimalV3Type) dataType).getPrecision(), newScale),
value.setScale(newScale, RoundingMode.FLOOR));
return new DecimalV3Literal(value.setScale(newScale, RoundingMode.FLOOR));
}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -182,7 +182,7 @@ public static DataType convertPrimitiveFromStrings(List<String> types) {
case "decimalv3":
switch (types.size()) {
case 1:
dataType = DecimalV3Type.CATALOG_DEFAULT;
dataType = DecimalV3Type.createDecimalV3Type(38, 9);
break;
case 2:
dataType = DecimalV3Type.createDecimalV3Type(Integer.parseInt(types.get(1)));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,6 @@
import org.apache.doris.nereids.trees.expressions.literal.TinyIntLiteral;
import org.apache.doris.nereids.trees.expressions.literal.VarcharLiteral;
import org.apache.doris.nereids.types.BigIntType;
import org.apache.doris.nereids.types.DataType;
import org.apache.doris.nereids.types.DecimalV2Type;
import org.apache.doris.nereids.types.DecimalV3Type;
import org.apache.doris.nereids.types.IntegerType;
Expand All @@ -42,7 +41,6 @@
import org.apache.doris.nereids.types.VarcharType;

import com.google.common.collect.ImmutableList;
import org.junit.jupiter.api.Assertions;
import org.junit.jupiter.api.Test;

import java.math.BigDecimal;
Expand All @@ -54,17 +52,17 @@ public void testSimplify() {
executor = new ExpressionRuleExecutor(ImmutableList.of(
ExpressionRewrite.bottomUp(SimplifyCastRule.INSTANCE))
);
assertRewriteAfterSimplify("CAST('1' AS STRING)", "'1'", StringType.INSTANCE);
assertRewriteAfterSimplify("CAST('1' AS VARCHAR)", "'1'",
VarcharType.createVarcharType(-1));
assertRewriteAfterSimplify("CAST(1 AS DECIMAL)", "1.000000000",
DecimalV3Type.createDecimalV3Type(38, 9));
assertRewriteAfterSimplify("CAST(1000 AS DECIMAL)", "1000.000000000",
DecimalV3Type.createDecimalV3Type(38, 9));
assertRewriteAfterSimplify("CAST(1 AS DECIMALV3)", "1",
DecimalV3Type.createDecimalV3Type(9, 0));
assertRewriteAfterSimplify("CAST(1000 AS DECIMALV3)", "1000",
DecimalV3Type.createDecimalV3Type(9, 0));

assertRewrite(new Cast(new VarcharLiteral("1"), StringType.INSTANCE),
new StringLiteral("1"));
assertRewrite(new Cast(new VarcharLiteral("1"), VarcharType.SYSTEM_DEFAULT),
new VarcharLiteral("1", -1));
assertRewrite(new Cast(new TinyIntLiteral((byte) 1), DecimalV3Type.SYSTEM_DEFAULT),
new DecimalV3Literal(DecimalV3Type.SYSTEM_DEFAULT, new BigDecimal("1.000000000")));
assertRewrite(new Cast(new SmallIntLiteral((short) 1000), DecimalV3Type.SYSTEM_DEFAULT),
new DecimalV3Literal(DecimalV3Type.SYSTEM_DEFAULT, new BigDecimal("1000.000000000")));
assertRewrite(new Cast(new VarcharLiteral("1"), VarcharType.SYSTEM_DEFAULT), new VarcharLiteral("1", -1));
assertRewrite(new Cast(new VarcharLiteral("1"), VarcharType.SYSTEM_DEFAULT), new VarcharLiteral("1", -1));

Expression tinyIntLiteral = new TinyIntLiteral((byte) 12);
// cast tinyint as tinyint
Expand Down Expand Up @@ -143,17 +141,20 @@ public void testSimplify() {
// cast char(5) as string
assertRewrite(new Cast(charLiteral, StringType.INSTANCE), new StringLiteral("12345"));

Expression decimalV3Literal = new DecimalV3Literal(DecimalV3Type.createDecimalV3Type(3, 1),
new BigDecimal("12.0"));
Expression decimalV3Literal = new DecimalV3Literal(DecimalV3Type.createDecimalV3Type(5, 3),
new BigDecimal("12.000"));
// cast decimalv3(3,1) as decimalv3(5,1)
assertRewrite(new Cast(decimalV3Literal, DecimalV3Type.createDecimalV3Type(5, 1)),
new DecimalV3Literal(DecimalV3Type.createDecimalV3Type(5, 1),
assertRewrite(new Cast(decimalV3Literal, DecimalV3Type.createDecimalV3Type(7, 3)),
new DecimalV3Literal(DecimalV3Type.createDecimalV3Type(7, 3),
new BigDecimal("12.000")));
assertRewrite(new Cast(decimalV3Literal, DecimalV3Type.createDecimalV3Type(3, 1)),
new DecimalV3Literal(DecimalV3Type.createDecimalV3Type(3, 1),
new BigDecimal("12.0")));

assertRewrite(new Cast(decimalV3Literal, DecimalV3Type.createDecimalV3Type(2, 1)),
new Cast(decimalV3Literal, DecimalV3Type.createDecimalV3Type(2, 1)));

// TODO unsupported but should?
// TODO unsupported, supported by org.apache.doris.nereids.trees.expressions.literal.Literal.uncheckedCastTo
// cast tinyint as smallint
assertRewrite(new Cast(tinyIntLiteral, SmallIntType.INSTANCE),
new Cast(tinyIntLiteral, SmallIntType.INSTANCE));
Expand Down Expand Up @@ -186,13 +187,4 @@ public void testSimplify() {
new DecimalV3Literal(DecimalV3Type.createDecimalV3Type(6, 1),
new BigDecimal("12.0")));
}

private void assertRewriteAfterSimplify(String expr, String expected, DataType expectedType) {
Expression needRewriteExpression = PARSER.parseExpression(expr);
Expression rewritten = executor.rewrite(needRewriteExpression, context);
Expression expectedExpression = PARSER.parseExpression(expected);
Assertions.assertEquals(expectedExpression.toSql(), rewritten.toSql());
Assertions.assertEquals(expectedType, rewritten.getDataType());
}

}
Loading

0 comments on commit 446c6b5

Please sign in to comment.