diff --git a/velox/core/QueryConfig.h b/velox/core/QueryConfig.h index 7f2b110331d99..8db1aa8718d64 100644 --- a/velox/core/QueryConfig.h +++ b/velox/core/QueryConfig.h @@ -101,6 +101,9 @@ class QueryConfig { static constexpr const char* kCastMatchStructByName = "cast_match_struct_by_name"; + // This flags forces to bound the decimal precision. + static constexpr const char* kAllowPrecisionLoss = "allow_precision_loss"; + /// Used for backpressure to block local exchange producers when the local /// exchange buffer reaches or exceeds this size. static constexpr const char* kMaxLocalExchangeBufferSize = @@ -496,6 +499,10 @@ class QueryConfig { return get(kCastMatchStructByName, false); } + bool isAllowPrecisionLoss() const { + return get(kAllowPrecisionLoss, true); + } + bool codegenEnabled() const { return get(kCodegenEnabled, false); } diff --git a/velox/functions/sparksql/DecimalArithmetic.cpp b/velox/functions/sparksql/DecimalArithmetic.cpp index 96b7e0cc0200e..e4f09ca42f6fa 100644 --- a/velox/functions/sparksql/DecimalArithmetic.cpp +++ b/velox/functions/sparksql/DecimalArithmetic.cpp @@ -413,14 +413,17 @@ class Addition { } inline static std::pair computeResultPrecisionScale( - uint8_t aPrecision, - uint8_t aScale, - uint8_t bPrecision, - uint8_t bScale) { + const uint8_t aPrecision, + const uint8_t aScale, + const uint8_t bPrecision, + const uint8_t bScale, + const bool allowPrecisionLoss) { auto precision = std::max(aPrecision - aScale, bPrecision - bScale) + std::max(aScale, bScale) + 1; auto scale = std::max(aScale, bScale); - return DecimalUtil::adjustPrecisionScale(precision, scale); + return allowPrecisionLoss + ? DecimalUtil::adjustPrecisionScale(precision, scale) + : DecimalUtil::bounded(precision, scale); } }; @@ -461,12 +464,13 @@ class Subtraction { } inline static std::pair computeResultPrecisionScale( - uint8_t aPrecision, - uint8_t aScale, - uint8_t bPrecision, - uint8_t bScale) { + const uint8_t aPrecision, + const uint8_t aScale, + const uint8_t bPrecision, + const uint8_t bScale, + const bool allowPrecisionLoss) { return Addition::computeResultPrecisionScale( - aPrecision, aScale, bPrecision, bScale); + aPrecision, aScale, bPrecision, bScale, allowPrecisionLoss); } }; @@ -566,9 +570,12 @@ class Multiply { uint8_t aPrecision, uint8_t aScale, uint8_t bPrecision, - uint8_t bScale) { - return DecimalUtil::adjustPrecisionScale( - aPrecision + bPrecision + 1, aScale + bScale); + uint8_t bScale, + const bool allowPrecisionLoss) { + return allowPrecisionLoss + ? DecimalUtil::adjustPrecisionScale( + aPrecision + bPrecision + 1, aScale + bScale) + : DecimalUtil::bounded(aPrecision + bPrecision + 1, aScale + bScale); } private: @@ -616,10 +623,22 @@ class Divide { uint8_t aPrecision, uint8_t aScale, uint8_t bPrecision, - uint8_t bScale) { - auto scale = std::max(6, aScale + bPrecision + 1); - auto precision = aPrecision - aScale + bScale + scale; - return DecimalUtil::adjustPrecisionScale(precision, scale); + uint8_t bScale, + bool allowPrecisionLoss) { + if (allowPrecisionLoss) { + auto scale = std::max(6, aScale + bPrecision + 1); + auto precision = aPrecision - aScale + bScale + scale; + return DecimalUtil::adjustPrecisionScale(precision, scale); + } else { + auto intDig = std::min(38, aPrecision - aScale + bScale); + auto decDig = + std::min(38, std::max(6, aScale + bPrecision + 1)) auto diff = + (intDig + decDig) - 38; + if (diff > 0) { + decDig -= diff / 2 + 1 intDig = 38 - decDig + } + return DecimalUtil::bounded(intDig + decDig, decDig); + } } }; @@ -694,8 +713,9 @@ std::shared_ptr createDecimalFunction( const auto& bType = inputArgs[1].type; const auto [aPrecision, aScale] = getDecimalPrecisionScale(*aType); const auto [bPrecision, bScale] = getDecimalPrecisionScale(*bType); + const bool allowPrecisionLoss = config.isAllowPrecisionLoss(); const auto [rPrecision, rScale] = Operation::computeResultPrecisionScale( - aPrecision, aScale, bPrecision, bScale); + aPrecision, aScale, bPrecision, bScale, allowPrecisionLoss); const uint8_t aRescale = Operation::computeRescaleFactor(aScale, bScale, rScale); const uint8_t bRescale = diff --git a/velox/functions/sparksql/DecimalUtil.h b/velox/functions/sparksql/DecimalUtil.h index fbe5da77809ec..d638395f4c5d3 100644 --- a/velox/functions/sparksql/DecimalUtil.h +++ b/velox/functions/sparksql/DecimalUtil.h @@ -46,6 +46,12 @@ class DecimalUtil { } } + inline static std::pair bounded( + const uint8_t rPrecision, + const uint8_t rScale) { + return {std::min(rPrecision, 38), std::min(rScale, 38)}; + } + /// @brief Convert int256 value to int64 or int128, set overflow to true if /// value cannot convert to specific type. /// @return The converted value.