Skip to content

Commit

Permalink
[CALCITE-6322] Casts to DECIMAL types are ignored
Browse files Browse the repository at this point in the history
Signed-off-by: Mihai Budiu <[email protected]>
  • Loading branch information
mihaibudiu committed Apr 22, 2024
1 parent 1566663 commit 7ef62e8
Show file tree
Hide file tree
Showing 16 changed files with 326 additions and 207 deletions.
8 changes: 8 additions & 0 deletions babel/src/test/resources/sql/big-query.iq
Original file line number Diff line number Diff line change
Expand Up @@ -660,6 +660,7 @@ SELECT SAFE_ADD(CAST(1.7e308 as DOUBLE), CAST(1.7e308 as DOUBLE)) as double_over

!ok

!if (fixed.calcite6328) {
SELECT SAFE_ADD(9, cast(9.999999999999999999e75 as DECIMAL(38, 19))) as decimal_overflow;
+------------------+
| decimal_overflow |
Expand All @@ -669,6 +670,7 @@ SELECT SAFE_ADD(9, cast(9.999999999999999999e75 as DECIMAL(38, 19))) as decimal_
(1 row)

!ok
!}

# NaN arguments should return NaN
SELECT SAFE_ADD(CAST('NaN' AS DOUBLE), CAST(3 as BIGINT)) as NaN_result;
Expand Down Expand Up @@ -720,6 +722,7 @@ SELECT SAFE_DIVIDE(CAST(1.7e308 as DOUBLE),

!ok

!if (fixed.calcite6328) {
SELECT SAFE_DIVIDE(CAST(-3.5e75 AS DECIMAL(76, 0)),
CAST(3.5e-75 AS DECIMAL(76, 0))) as decimal_overflow;
+------------------+
Expand All @@ -730,6 +733,7 @@ SELECT SAFE_DIVIDE(CAST(-3.5e75 AS DECIMAL(76, 0)),
(1 row)

!ok
!}

# NaN arguments should return NaN
SELECT SAFE_DIVIDE(CAST('NaN' AS DOUBLE), CAST(3 as BIGINT)) as NaN_result;
Expand Down Expand Up @@ -801,6 +805,7 @@ SELECT SAFE_MULTIPLY(CAST(1.7e308 as DOUBLE), CAST(3 as BIGINT)) as double_overf

!ok

!if (fixed.calcite6328) {
SELECT SAFE_MULTIPLY(CAST(-3.5e75 AS DECIMAL(76, 0)), CAST(10 AS BIGINT)) as decimal_overflow;
+------------------+
| decimal_overflow |
Expand All @@ -810,6 +815,7 @@ SELECT SAFE_MULTIPLY(CAST(-3.5e75 AS DECIMAL(76, 0)), CAST(10 AS BIGINT)) as dec
(1 row)

!ok
!}

# NaN arguments should return NaN
SELECT SAFE_MULTIPLY(CAST('NaN' AS DOUBLE), CAST(3 as BIGINT)) as NaN_result;
Expand Down Expand Up @@ -916,6 +922,7 @@ SELECT SAFE_SUBTRACT(CAST(1.7e308 as DOUBLE), CAST(-1.7e308 as DOUBLE)) as doubl

!ok

!if (fixed.calcite6328) {
SELECT SAFE_SUBTRACT(9, cast(-9.999999999999999999e75 as DECIMAL(38, 19))) as decimal_overflow;
+------------------+
| decimal_overflow |
Expand All @@ -925,6 +932,7 @@ SELECT SAFE_SUBTRACT(9, cast(-9.999999999999999999e75 as DECIMAL(38, 19))) as de
(1 row)

!ok
!}

# NaN arguments should return NaN
SELECT SAFE_SUBTRACT(CAST('NaN' AS DOUBLE), CAST(3 as BIGINT)) as NaN_result;
Expand Down
106 changes: 53 additions & 53 deletions babel/src/test/resources/sql/redshift.iq
Original file line number Diff line number Diff line change
Expand Up @@ -184,7 +184,7 @@ select approximate percentile_disc(0.5) within group (order by sal) from emp gro
# AVG
select avg(sal) from emp;
EXPR$0
2073.214285714286
2073.21
!ok

# COUNT
Expand Down Expand Up @@ -288,12 +288,12 @@ select percentile_disc(0.6) within group (order by sal) from emp group by deptno
# STDDEV_SAMP and STDDEV_POP
select stddev_samp(sal) from emp;
EXPR$0
1182.503223516271873450023122131824493408203125
1182.50
!ok

select stddev_pop(sal) from emp;
EXPR$0
1139.488618295281639802851714193820953369140625
1139.48
!ok

# SUM
Expand All @@ -308,24 +308,24 @@ EXPR$0
!ok

# VAR_SAMP and VAR_POP
select var_samp(sal) from emp;
select var_samp(CAST(sal AS DECIMAL(12, 4))) from emp;
EXPR$0
1398313.873626374
1398313.8736
!ok

select var_samp(distinct sal) from emp;
select var_samp(distinct CAST(sal AS DECIMAL(12, 4))) from emp;
EXPR$0
1512779.356060606
1512779.3560
!ok

select var_samp(all sal) from emp;
select var_samp(all CAST(sal AS DECIMAL(12, 4))) from emp;
EXPR$0
1398313.873626374
1398313.8736
!ok

select var_pop(sal) from emp;
select var_pop(CAST(sal AS DECIMAL(12, 4))) from emp;
EXPR$0
1298434.31122449
1298434.3112
!ok

# 4 Bit-Wise Aggregate Functions
Expand Down Expand Up @@ -378,10 +378,10 @@ select empno, avg(sal) over (order by empno rows unbounded preceding) from emp w
EMPNO, EXPR$1
7499, 1600.00
7521, 1425.00
7654, 1366.666666666667
7654, 1366.66
7698, 1737.50
7844, 1690.00
7900, 1566.666666666667
7900, 1566.66
!ok

# COUNT
Expand Down Expand Up @@ -523,78 +523,78 @@ select deptno, ratio_to_report(sal) over (partition by deptno) from emp;
!}

# STDDEV_POP
select empno, stddev_pop(comm) over (order by empno rows unbounded preceding) from emp where deptno = 30 order by 1;
select empno, stddev_pop(CAST(comm AS DECIMAL(12, 4))) over (order by empno rows unbounded preceding) from emp where deptno = 30 order by 1;
EMPNO, EXPR$1
7499, 0
7521, 100
7654, 478.42333648024424519462627358734607696533203125
7698, 478.42333648024424519462627358734607696533203125
7844, 522.0153254455275373402400873601436614990234375
7900, 522.0153254455275373402400873601436614990234375
7499, 0.0000
7521, 100.0000
7654, 478.4233
7698, 478.4233
7844, 522.0153
7900, 522.0153
!ok

# STDDEV_SAMP (synonym for STDDEV)
select empno, stddev_samp(comm) over (order by empno rows unbounded preceding) from emp where deptno = 30 order by 1;
select empno, stddev_samp(CAST(comm AS DECIMAL(12, 4))) over (order by empno rows unbounded preceding) from emp where deptno = 30 order by 1;
EMPNO, EXPR$1
7499, null
7521, 141.421356237309510106570087373256683349609375
7654, 585.9465277082316561063635163009166717529296875
7698, 585.9465277082316561063635163009166717529296875
7844, 602.7713773341707792496890760958194732666015625
7900, 602.7713773341707792496890760958194732666015625
7521, 141.4213
7654, 585.9465
7698, 585.9465
7844, 602.7713
7900, 602.7713
!ok

select empno, stddev(comm) over (order by empno rows unbounded preceding) from emp where deptno = 30 order by 1;
select empno, stddev(CAST(comm AS DECIMAL(12, 4))) over (order by empno rows unbounded preceding) from emp where deptno = 30 order by 1;
EMPNO, EXPR$1
7499, null
7521, 141.421356237309510106570087373256683349609375
7654, 585.9465277082316561063635163009166717529296875
7698, 585.9465277082316561063635163009166717529296875
7844, 602.7713773341707792496890760958194732666015625
7900, 602.7713773341707792496890760958194732666015625
7521, 141.4213
7654, 585.9465
7698, 585.9465
7844, 602.7713
7900, 602.7713
!ok

# SUM
select empno, sum(comm) over (order by empno rows unbounded preceding) from emp where deptno = 30 order by 1;
select empno, sum(CAST(comm AS DECIMAL(12, 4))) over (order by empno rows unbounded preceding) from emp where deptno = 30 order by 1;
EMPNO, EXPR$1
7499, 300.00
7521, 800.00
7654, 2200.00
7698, 2200.00
7844, 2200.00
7900, 2200.00
7499, 300.0000
7521, 800.0000
7654, 2200.0000
7698, 2200.0000
7844, 2200.0000
7900, 2200.0000
!ok

# VAR_POP
select empno, var_pop(comm) over (order by empno rows unbounded preceding) from emp where deptno = 30 order by 1;
select empno, var_pop(CAST(comm AS DECIMAL(12, 4))) over (order by empno rows unbounded preceding) from emp where deptno = 30 order by 1;
EMPNO, EXPR$1
7499, 0.0000
7521, 10000.0000
7654, 228888.888888889
7698, 228888.888888889
7654, 228888.8889
7698, 228888.8889
7844, 272500.0000
7900, 272500.0000
!ok

# VAR_SAMP (synonym for VARIANCE)
select empno, var_samp(comm) over (order by empno rows unbounded preceding) from emp where deptno = 30 order by 1;
select empno, var_samp(CAST(comm AS DECIMAL(12, 4))) over (order by empno rows unbounded preceding) from emp where deptno = 30 order by 1;
EMPNO, EXPR$1
7499, null
7521, 20000.0000
7654, 343333.3333333335
7698, 343333.3333333335
7844, 363333.3333333333
7900, 363333.3333333333
7654, 343333.3333
7698, 343333.3333
7844, 363333.3333
7900, 363333.3333
!ok

select empno, variance(comm) over (order by empno rows unbounded preceding) from emp where deptno = 30 order by 1;
select empno, variance(CAST(comm AS DECIMAL(12, 4))) over (order by empno rows unbounded preceding) from emp where deptno = 30 order by 1;
EMPNO, EXPR$1
7499, null
7521, 20000.0000
7654, 343333.3333333335
7698, 343333.3333333335
7844, 363333.3333333333
7900, 363333.3333333333
7654, 343333.3333
7698, 343333.3333
7844, 363333.3333
7900, 363333.3333
!ok

# 5.2 Ranking functions
Expand Down Expand Up @@ -2013,7 +2013,7 @@ SELECT "JSON_EXTRACT_PATH_TEXT"('{"f2":{"f3":1},"f4":{"f5":99,"f6":"star"}}', 'f
# CAST and CONVERT
select cast(stddev_samp(sal) as dec(14, 2)) from emp;
EXPR$0
1182.503223516271873450023122131824493408203125
1182.50
!ok

select 123.456::decimal(8,4);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -442,6 +442,36 @@ private Expression getConvertExpression(
return defaultExpression.get();
}

case DECIMAL: {
int precision = targetType.getPrecision();
int scale = targetType.getScale();
if (precision != RelDataType.PRECISION_NOT_SPECIFIED
&& scale != RelDataType.SCALE_NOT_SPECIFIED) {
if (sourceType.getSqlTypeName() == SqlTypeName.DECIMAL) {
// Cast from DECIMAL to DECIMAL, may adjust scale and precision.
return Expressions.call(
BuiltInMethod.DECIMAL_DECIMAL_CAST.method,
operand,
Expressions.constant(precision),
Expressions.constant(scale));
} else if (SqlTypeName.INT_TYPES.contains(sourceType.getSqlTypeName())) {
// Cast from INTEGER to DECIMAL, check for overflow
return Expressions.call(
BuiltInMethod.INTEGER_DECIMAL_CAST.method,
operand,
Expressions.constant(precision),
Expressions.constant(scale));
} else if (SqlTypeName.APPROX_TYPES.contains(sourceType.getSqlTypeName())) {
// Cast from FLOAT/DOUBLE to DECIMAL
return Expressions.call(
BuiltInMethod.FP_DECIMAL_CAST.method,
operand,
Expressions.constant(precision),
Expressions.constant(scale));
}
}
return defaultExpression.get();
}
case BIGINT:
case INTEGER:
case TINYINT:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1836,7 +1836,7 @@ public static boolean isValidDecimalValue(@Nullable BigDecimal value, RelDataTyp
case DECIMAL:
final int intDigits = value.precision() - value.scale();
final int maxIntDigits = toType.getPrecision() - toType.getScale();
return intDigits <= maxIntDigits;
return (intDigits <= maxIntDigits) && (value.scale() <= toType.getScale());
default:
return true;
}
Expand Down
6 changes: 6 additions & 0 deletions core/src/main/java/org/apache/calcite/util/Bug.java
Original file line number Diff line number Diff line change
Expand Up @@ -215,6 +215,12 @@ public abstract class Bug {
* Fix failing quidem tests for FORMAT in CAST</a> is fixed. */
public static final boolean CALCITE_6375_FIXED = false;

/** Whether
* <a href="https://issues.apache.org/jira/browse/CALCITE-6328">[CALCITE-6328]
* The BigQuery functions SAFE_* do not match the BigQuery specification</a>
* is fixed. */
public static final boolean CALCITE_6328_FIXED = false;

/**
* Use this to flag temporary code.
*/
Expand Down
6 changes: 6 additions & 0 deletions core/src/main/java/org/apache/calcite/util/BuiltInMethod.java
Original file line number Diff line number Diff line change
Expand Up @@ -292,6 +292,12 @@ public enum BuiltInMethod {
ENUMERABLE_TO_LIST(ExtendedEnumerable.class, "toList"),
ENUMERABLE_TO_MAP(ExtendedEnumerable.class, "toMap", Function1.class, Function1.class),
AS_LIST(Primitive.class, "asList", Object.class),
DECIMAL_DECIMAL_CAST(Primitive.class, "decimalDecimalCast",
BigDecimal.class, int.class, int.class),
INTEGER_DECIMAL_CAST(Primitive.class, "integerDecimalCast",
Object.class, int.class, int.class),
FP_DECIMAL_CAST(Primitive.class, "fpDecimalCast",
Object.class, int.class, int.class),
INTEGER_CAST(Primitive.class, "integerCast", Primitive.class, Object.class),
MEMORY_GET0(MemoryFactory.Memory.class, "get"),
MEMORY_GET1(MemoryFactory.Memory.class, "get", int.class),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3814,7 +3814,8 @@ private SqlSpecialOperatorWithPolicy(String name, SqlKind kind, int prec, boolea
RexNode one = literal(1);

RexNode b = vDecimalNotNull(2);
RexNode half = literal(new BigDecimal(0.5), b.getType());
RelDataType decimal = typeFactory.createSqlType(SqlTypeName.DECIMAL, 2, 1);
RexNode half = literal(new BigDecimal("0.5"), decimal);

checkSimplify(add(a, zero), "?0.notNullInt1");
checkSimplify(add(zero, a), "?0.notNullInt1");
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -579,7 +579,7 @@ private static void assertRows(Interpreter interpreter,
final String sql = "select x, min(y), max(y), sum(y), avg(y)\n"
+ "from (values ('a', -1.2), ('a', 2.3), ('a', 15)) as t(x, y)\n"
+ "group by x";
sql(sql).returnsRows("[a, -1.2, 15.0, 16.1, 5.366666666666667]");
sql(sql).returnsRows("[a, -1.2, 15.0, 16.1, 5.3]");
}

@Test void testInterpretUnnest() {
Expand Down
Loading

0 comments on commit 7ef62e8

Please sign in to comment.