diff --git a/core/src/main/java/org/apache/calcite/adapter/enumerable/RexToLixTranslator.java b/core/src/main/java/org/apache/calcite/adapter/enumerable/RexToLixTranslator.java index 5389fdccf721..4ea1d23701f5 100644 --- a/core/src/main/java/org/apache/calcite/adapter/enumerable/RexToLixTranslator.java +++ b/core/src/main/java/org/apache/calcite/adapter/enumerable/RexToLixTranslator.java @@ -31,6 +31,7 @@ import org.apache.calcite.linq4j.tree.ParameterExpression; import org.apache.calcite.linq4j.tree.Primitive; import org.apache.calcite.linq4j.tree.Statement; +import org.apache.calcite.linq4j.tree.Types; import org.apache.calcite.rel.type.RelDataType; import org.apache.calcite.rex.RexBuilder; import org.apache.calcite.rex.RexCall; @@ -1374,11 +1375,26 @@ private Result toInnerStorageType(Result result, Type storageType) { } final Type storageType = currentStorageType != null ? currentStorageType : typeFactory.getJavaClass(dynamicParam.getType()); - final Expression valueExpression = + + final boolean isNumeric = SqlTypeFamily.NUMERIC.contains(dynamicParam.getType()); + + // For numeric types, use java.lang.Number to prevent cast exception + // when the parameter type differs from the target type + Expression argumentExpression = EnumUtils.convert( Expressions.call(root, BuiltInMethod.DATA_CONTEXT_GET.method, Expressions.constant("?" + dynamicParam.getIndex())), - storageType); + isNumeric ? java.lang.Number.class : storageType); + + // Short-circuit if the expression evaluates to null. The cast + // may throw a NullPointerException as it calls methods on the + // object such as longValue(). + Expression valueExpression = + Expressions.condition( + Expressions.equal(argumentExpression, Expressions.constant(null)), + Expressions.constant(null), + Types.castIfNecessary(storageType, argumentExpression)); + final ParameterExpression valueVariable = Expressions.parameter(valueExpression.getType(), list.newName("value_dynamic_param")); diff --git a/core/src/test/java/org/apache/calcite/test/JdbcTest.java b/core/src/test/java/org/apache/calcite/test/JdbcTest.java index 087b3854d13d..8fe3de21ac09 100644 --- a/core/src/test/java/org/apache/calcite/test/JdbcTest.java +++ b/core/src/test/java/org/apache/calcite/test/JdbcTest.java @@ -84,6 +84,7 @@ import org.apache.calcite.sql.parser.SqlParser; import org.apache.calcite.sql.parser.SqlParserPos; import org.apache.calcite.sql.parser.impl.SqlParserImpl; +import org.apache.calcite.sql.type.SqlTypeName; import org.apache.calcite.sql2rel.SqlToRelConverter.Config; import org.apache.calcite.test.schemata.catchall.CatchallSchema; import org.apache.calcite.test.schemata.foodmart.FoodmartSchema; @@ -8423,6 +8424,98 @@ private void checkGetTimestamp(Connection con) throws SQLException { }); } + @Test void bindByteParameter() { + for (SqlTypeName tpe : SqlTypeName.INT_TYPES) { + final String sql = + "with cte as (select cast(100 as " + tpe.getName() + ") as empid)" + + "select * from cte where empid = ?"; + CalciteAssert.hr() + .query(sql) + .consumesPreparedStatement(p -> { + p.setByte(1, (byte) 100); + }) + .returnsUnordered("EMPID=100"); + } + } + + @Test void bindShortParameter() { + for (SqlTypeName tpe : SqlTypeName.INT_TYPES) { + final String sql = + "with cte as (select cast(100 as " + tpe.getName() + ") as empid)" + + "select * from cte where empid = ?"; + + CalciteAssert.hr() + .query(sql) + .consumesPreparedStatement(p -> { + p.setShort(1, (short) 100); + }) + .returnsUnordered("EMPID=100"); + } + } + + @Test void bindOverflowingTinyIntParameter() { + final String sql = + "with cte as (select cast(300 as smallint) as empid)" + + "select * from cte where empid = cast(? as tinyint)"; + + java.sql.SQLException t = + assertThrows( + java.sql.SQLException.class, + () -> CalciteAssert.hr() + .query(sql) + .consumesPreparedStatement(p -> { + p.setShort(1, (short) 300); + }) + .returns("")); + + assertThat( + "message matches", + t.getMessage().contains("value is outside the range of java.lang.Byte")); + } + + @Test void bindIntParameter() { + for (SqlTypeName tpe : SqlTypeName.INT_TYPES) { + final String sql = + "with cte as (select cast(100 as " + tpe.getName() + ") as empid)" + + "select * from cte where empid = ?"; + + CalciteAssert.hr() + .query(sql) + .consumesPreparedStatement(p -> { + p.setInt(1, 100); + }) + .returnsUnordered("EMPID=100"); + } + } + + @Test void bindLongParameter() { + for (SqlTypeName tpe : SqlTypeName.INT_TYPES) { + final String sql = + "with cte as (select cast(100 as " + tpe.getName() + ") as empid)" + + "select * from cte where empid = ?"; + + CalciteAssert.hr() + .query(sql) + .consumesPreparedStatement(p -> { + p.setLong(1, 100); + }) + .returnsUnordered("EMPID=100"); + } + } + + @Test void bindNumericParameter() { + final String sql = + "with cte as (select cast(100 as numeric(5)) as empid)" + + "select * from cte where empid = ?"; + + CalciteAssert.hr() + .query(sql) + .consumesPreparedStatement(p -> { + p.setLong(1, 100); + }) + .returnsUnordered("EMPID=100"); + } + private static String sums(int n, boolean c) { final StringBuilder b = new StringBuilder(); for (int i = 0; i < n; i++) { diff --git a/linq4j/src/main/java/org/apache/calcite/linq4j/tree/Expressions.java b/linq4j/src/main/java/org/apache/calcite/linq4j/tree/Expressions.java index 2e49cbcb05c1..a7f7545b8e8a 100644 --- a/linq4j/src/main/java/org/apache/calcite/linq4j/tree/Expressions.java +++ b/linq4j/src/main/java/org/apache/calcite/linq4j/tree/Expressions.java @@ -622,9 +622,49 @@ public static UnaryExpression convert_(Expression expression, Type type, * operation that throws an exception if the target type is * overflowed. */ - public static UnaryExpression convertChecked(Expression expression, + public static Expression convertChecked(Expression expression, Type type) { - throw Extensions.todo(); + if (type == Byte.class + || type == Short.class + || type == Integer.class + || type == Long.class) { + Class typeClass = (Class) type; + + Object minValue; + Object maxValue; + + try { + minValue = typeClass.getField("MIN_VALUE").get(null); + maxValue = typeClass.getField("MAX_VALUE").get(null); + } catch (IllegalAccessException | NoSuchFieldException e) { + throw new RuntimeException(e); + } + + ThrowStatement throwStmt = + Expressions.throw_( + Expressions.new_( + IllegalArgumentException.class, + Expressions.constant("value is outside the range of " + typeClass.getName()))); + + // Covers all lower precision types + Expression longValue = Expressions.call(expression, "longValue"); + + Expression minCheck = Expressions.lessThan(longValue, Expressions.constant(minValue)); + Expression maxCheck = Expressions.greaterThan(longValue, Expressions.constant(maxValue)); + + Primitive primitive = requireNonNull(Primitive.ofBox(type)); + String primitiveName = requireNonNull(primitive.primitiveName); + Expression convertExpr = Expressions.call(expression, primitiveName + "Value"); + + return Expressions.convert_( + Expressions.makeTernary( + ExpressionType.Conditional, + Expressions.or(minCheck, maxCheck), + Expressions.fromStatement(throwStmt), + convertExpr), type); + } + + throw new IllegalArgumentException("Type " + type.getTypeName() + " is not supported yet"); } /** @@ -2822,6 +2862,18 @@ public static SymbolDocumentInfo symbolDocument(String filename, throw Extensions.todo(); } + /** + * Create an expression from a statement. + */ + public static Expression fromStatement(Statement statement) { + FunctionExpression> lambda = + Expressions.lambda( + Blocks.toFunctionBlock(statement), + Collections.emptyList()); + + return Expressions.call(lambda, "apply"); + } + /** * Creates a statement that represents the throwing of an exception. */ diff --git a/linq4j/src/main/java/org/apache/calcite/linq4j/tree/Types.java b/linq4j/src/main/java/org/apache/calcite/linq4j/tree/Types.java index 0be03c9fc08c..3dec960cf7a3 100644 --- a/linq4j/src/main/java/org/apache/calcite/linq4j/tree/Types.java +++ b/linq4j/src/main/java/org/apache/calcite/linq4j/tree/Types.java @@ -28,6 +28,7 @@ import java.lang.reflect.ParameterizedType; import java.lang.reflect.Type; import java.lang.reflect.TypeVariable; +import java.math.BigDecimal; import java.util.ArrayList; import java.util.Arrays; import java.util.Collection; @@ -428,11 +429,25 @@ public static Expression castIfNecessary(Type returnType, && Number.class.isAssignableFrom((Class) returnType) && type instanceof Class && Number.class.isAssignableFrom((Class) type)) { - // E.g. - // Integer foo(BigDecimal o) { - // return o.intValue(); - // } - return Expressions.unbox(expression, requireNonNull(Primitive.ofBox(returnType))); + + if (returnType == BigDecimal.class) { + return Expressions.call( + BigDecimal.class, + "valueOf", + Expressions.call(expression, "longValue")); + } else if ( + returnType == Byte.class + || returnType == Short.class + || returnType == Integer.class + || returnType == Long.class) { + return Expressions.convertChecked(expression, returnType); + } else { + // E.g. + // Integer foo(BigDecimal o) { + // return o.intValue(); + // } + return Expressions.unbox(expression, requireNonNull(Primitive.ofBox(returnType))); + } } if (Primitive.is(returnType) && !Primitive.is(type)) { // E.g.