From 9571e077b3386c1ade2eb5c02566b0b79f343487 Mon Sep 17 00:00:00 2001 From: Yuan Zhou Date: Wed, 28 Aug 2024 14:53:03 +0800 Subject: [PATCH] [10383 ] Support decimal operation not precision loss mode (10383) Signed-off-by: Yuan Zhou --- velox/docs/functions/spark/config.rst | 21 ++ velox/docs/functions/spark/decimal.rst | 25 ++- velox/docs/spark_functions.rst | 2 + .../functions/sparksql/DecimalArithmetic.cpp | 205 ++++++++++++++---- velox/functions/sparksql/DecimalArithmetic.h | 12 +- velox/functions/sparksql/DecimalUtil.h | 10 + velox/functions/sparksql/Register.cpp | 10 +- velox/functions/sparksql/Register.h | 5 + .../functions/sparksql/RegisterArithmetic.cpp | 13 +- velox/functions/sparksql/RegisterArithmetic.h | 5 +- velox/functions/sparksql/RegistrationConfig.h | 27 +++ .../sparksql/tests/DecimalArithmeticTest.cpp | 66 ++++++ .../sparksql/tests/DecimalUtilTest.cpp | 13 ++ 13 files changed, 356 insertions(+), 58 deletions(-) create mode 100644 velox/docs/functions/spark/config.rst create mode 100644 velox/functions/sparksql/RegistrationConfig.h diff --git a/velox/docs/functions/spark/config.rst b/velox/docs/functions/spark/config.rst new file mode 100644 index 000000000000..146d61efd741 --- /dev/null +++ b/velox/docs/functions/spark/config.rst @@ -0,0 +1,21 @@ +================================ +SparkRegistration Configuration +================================ + +struct SparkRegistrationConfig +--------------------- +.. list-table:: + :widths: 20 10 10 70 + :header-rows: 1 + + * - Property Name + - Type + - Default Value + - Description + * - allowPrecisionLoss + - bool + - true + - When true, establishing the result type of an arithmetic operation according to Hive behavior and SQL ANSI 2011 specification, i.e. + rounding the decimal part of the result if an exact representation is not + possible. Otherwise, NULL is returned when the actual result cannot be represented with the calculated decimal type. Now we support add, + subtract, multiply and divide operations. \ No newline at end of file diff --git a/velox/docs/functions/spark/decimal.rst b/velox/docs/functions/spark/decimal.rst index 19eee325f4b3..75814f6d8d25 100644 --- a/velox/docs/functions/spark/decimal.rst +++ b/velox/docs/functions/spark/decimal.rst @@ -33,8 +33,11 @@ Division p = p1 - s1 + s2 + max(6, s1 + p2 + 1) s = max(6, s1 + p2 + 1) +Decimal Precision and Scale Adjustment +<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<< + For above arithmetic operators, when the precision of result exceeds 38, -caps p at 38 and reduces the scale, in order to prevent the truncation of +caps p at 38 and reduces the scale when allowing precision loss, in order to prevent the truncation of the integer part of the decimals. Below formula illustrates how the result precision and scale are adjusted. @@ -43,6 +46,26 @@ precision and scale are adjusted. precision = 38 scale = max(38 - (p - s), min(s, 6)) +Caps p and s at 38 when not allowing precision loss. +For decimal addition, subtraction, multiplication, the precision and scale computation logic is same, +but for decimal division, it is different as following: +:: + + wholeDigits = min(38, p1 - s1 + s2); + fractionalDigits = min(38, max(6, s1 + p2 + 1)); + +If ``wholeDigits + fractionalDigits`` is more than 38: +:: + + p = 38 + s = fractionalDigits - (wholeDigits + fractionalDigits - 38) / 2 - 1 + +Otherwise: +:: + + p = wholeDigits + fractionalDigits + s = fractionalDigits + Users experience runtime errors when the actual result cannot be represented with the calculated decimal type. diff --git a/velox/docs/spark_functions.rst b/velox/docs/spark_functions.rst index 24c825ac1ef5..b42460ccfdac 100644 --- a/velox/docs/spark_functions.rst +++ b/velox/docs/spark_functions.rst @@ -4,6 +4,8 @@ Spark Functions The semantics of Spark functions match Spark 3.5 with ANSI OFF. +Spark functions can be registered by :doc:`struct SparkRegistrationConfig `. + .. toctree:: :maxdepth: 1 diff --git a/velox/functions/sparksql/DecimalArithmetic.cpp b/velox/functions/sparksql/DecimalArithmetic.cpp index 61599bce10ea..e8b987dc3cab 100644 --- a/velox/functions/sparksql/DecimalArithmetic.cpp +++ b/velox/functions/sparksql/DecimalArithmetic.cpp @@ -23,13 +23,14 @@ namespace { struct DecimalAddSubtractBase { protected: + template void initializeBase(const std::vector& inputTypes) { auto [aPrecision, aScale] = getDecimalPrecisionScale(*inputTypes[0]); auto [bPrecision, bScale] = getDecimalPrecisionScale(*inputTypes[1]); aScale_ = aScale; bScale_ = bScale; - auto [rPrecision, rScale] = - computeResultPrecisionScale(aPrecision, aScale_, bPrecision, bScale_); + auto [rPrecision, rScale] = computeResultPrecisionScale( + aPrecision, aScale_, bPrecision, bScale_); rPrecision_ = rPrecision; rScale_ = rScale; aRescale_ = computeRescaleFactor(aScale_, bScale_); @@ -253,10 +254,12 @@ struct DecimalAddSubtractBase { } // Computes the result precision and scale for decimal add and subtract - // operations following Hive's formulas. + // operations following Hive's formulas when `allowPrecisionLoss` is true. // If result is representable with long decimal, the result // scale is the maximum of 'aScale' and 'bScale'. If not, reduces result scale // and caps the result precision at 38. + // Caps p and s at 38 when not allowing precision loss. + template static std::pair computeResultPrecisionScale( uint8_t aPrecision, uint8_t aScale, @@ -265,7 +268,11 @@ struct DecimalAddSubtractBase { auto precision = std::max(aPrecision - aScale, bPrecision - bScale) + std::max(aScale, bScale) + 1; auto scale = std::max(aScale, bScale); - return sparksql::DecimalUtil::adjustPrecisionScale(precision, scale); + if constexpr (allowPrecisionLoss) { + return sparksql::DecimalUtil::adjustPrecisionScale(precision, scale); + } else { + return sparksql::DecimalUtil::bounded(precision, scale); + } } static uint8_t computeRescaleFactor(uint8_t fromScale, uint8_t toScale) { @@ -280,7 +287,7 @@ struct DecimalAddSubtractBase { uint8_t rScale_; }; -template +template struct DecimalAddFunction : DecimalAddSubtractBase { VELOX_DEFINE_FUNCTION_TYPES(TExec); @@ -290,7 +297,7 @@ struct DecimalAddFunction : DecimalAddSubtractBase { const core::QueryConfig& /*config*/, A* /*a*/, B* /*b*/) { - initializeBase(inputTypes); + initializeBase(inputTypes); } template @@ -299,7 +306,7 @@ struct DecimalAddFunction : DecimalAddSubtractBase { } }; -template +template struct DecimalSubtractFunction : DecimalAddSubtractBase { VELOX_DEFINE_FUNCTION_TYPES(TExec); @@ -309,7 +316,7 @@ struct DecimalSubtractFunction : DecimalAddSubtractBase { const core::QueryConfig& /*config*/, A* /*a*/, B* /*b*/) { - initializeBase(inputTypes); + initializeBase(inputTypes); } template @@ -318,7 +325,7 @@ struct DecimalSubtractFunction : DecimalAddSubtractBase { } }; -template +template struct DecimalMultiplyFunction { VELOX_DEFINE_FUNCTION_TYPES(TExec); @@ -330,8 +337,10 @@ struct DecimalMultiplyFunction { B* /*b*/) { auto [aPrecision, aScale] = getDecimalPrecisionScale(*inputTypes[0]); auto [bPrecision, bScale] = getDecimalPrecisionScale(*inputTypes[1]); - auto [rPrecision, rScale] = DecimalUtil::adjustPrecisionScale( - aPrecision + bPrecision + 1, aScale + bScale); + auto [rPrecision, rScale] = allowPrecisionLoss + ? DecimalUtil::adjustPrecisionScale( + aPrecision + bPrecision + 1, aScale + bScale) + : DecimalUtil::bounded(aPrecision + bPrecision + 1, aScale + bScale); rPrecision_ = rPrecision; deltaScale_ = aScale + bScale - rScale; } @@ -426,7 +435,7 @@ struct DecimalMultiplyFunction { int32_t deltaScale_; }; -template +template struct DecimalDivideFunction { VELOX_DEFINE_FUNCTION_TYPES(TExec); @@ -458,65 +467,102 @@ struct DecimalDivideFunction { uint8_t aScale, uint8_t bPrecision, uint8_t bScale) { - auto scale = std::max(6, aScale + bPrecision + 1); - auto precision = aPrecision - aScale + bScale + scale; - return DecimalUtil::adjustPrecisionScale(precision, scale); + if constexpr (allowPrecisionLoss) { + auto scale = std::max(6, aScale + bPrecision + 1); + auto precision = aPrecision - aScale + bScale + scale; + return DecimalUtil::adjustPrecisionScale(precision, scale); + } else { + auto wholeDigits = std::min(38, aPrecision - aScale + bScale); + auto fractionDigits = std::min(38, std::max(6, aScale + bPrecision + 1)); + auto diff = (wholeDigits + fractionDigits) - 38; + if (diff > 0) { + fractionDigits -= diff / 2 + 1; + wholeDigits = 38 - fractionDigits; + } + return DecimalUtil::bounded(wholeDigits + fractionDigits, fractionDigits); + } } uint8_t aRescale_; uint8_t rPrecision_; }; -template