diff --git a/velox/common/base/BitUtil.h b/velox/common/base/BitUtil.h index fbe8600d0163..7b70e763700f 100644 --- a/velox/common/base/BitUtil.h +++ b/velox/common/base/BitUtil.h @@ -693,6 +693,13 @@ inline int32_t countLeadingZeros(uint64_t word) { return __builtin_clzll(word); } +inline int32_t countLeadingZerosUint128(__uint128_t word) { + uint64_t hi = word >> 64; + uint64_t lo = static_cast(word); + return (hi == 0) ? 64 + bits::countLeadingZeros(lo) + : bits::countLeadingZeros(hi); +} + inline uint64_t nextPowerOfTwo(uint64_t size) { if (size == 0) { return 0; diff --git a/velox/expression/CastExpr.cpp b/velox/expression/CastExpr.cpp index c15c0a0d58f1..f3a9893c8133 100644 --- a/velox/expression/CastExpr.cpp +++ b/velox/expression/CastExpr.cpp @@ -26,6 +26,7 @@ #include "velox/expression/StringWriter.h" #include "velox/external/date/tz.h" #include "velox/functions/lib/RowsTranslationUtil.h" +#include "velox/type/DecimalUtilOp.h" #include "velox/vector/ComplexVector.h" #include "velox/vector/FunctionVector.h" #include "velox/vector/SelectivityVector.h" @@ -201,6 +202,30 @@ void applyDoubleToDecimalCastKernel( } }); } + +template +void applyVarCharToDecimalCastKernel( + const SelectivityVector& rows, + const BaseVector& input, + exec::EvalCtx& context, + const TypePtr& toType, + VectorPtr castResult) { + auto sourceVector = input.as>(); + auto castResultRawBuffer = + castResult->asUnchecked>()->mutableRawValues(); + const auto& toPrecisionScale = getDecimalPrecisionScale(*toType); + context.applyToSelectedNoThrow(rows, [&](vector_size_t row) { + auto rescaledValue = DecimalUtilOp::rescaleVarchar( + sourceVector->valueAt(row), + toPrecisionScale.first, + toPrecisionScale.second); + if (rescaledValue.has_value()) { + castResultRawBuffer[row] = rescaledValue.value(); + } else { + castResult->setNull(row, true); + } + }); +} } // namespace template @@ -635,6 +660,16 @@ VectorPtr CastExpr::applyDecimal( } break; } + case TypeKind::VARCHAR: { + if (toType->kind() == TypeKind::SHORT_DECIMAL) { + applyVarCharToDecimalCastKernel( + rows, input, context, toType, castResult); + } else { + applyVarCharToDecimalCastKernel( + rows, input, context, toType, castResult); + } + break; + } default: VELOX_UNSUPPORTED( "Cast from {} to {} is not supported", diff --git a/velox/expression/tests/CastExprTest.cpp b/velox/expression/tests/CastExprTest.cpp index dd4fd72fcb0d..deb04a3ce001 100644 --- a/velox/expression/tests/CastExprTest.cpp +++ b/velox/expression/tests/CastExprTest.cpp @@ -826,6 +826,26 @@ TEST_F(CastExprTest, bigintToDecimal) { "Cannot cast BIGINT '100' to DECIMAL(17,16)"); } +TEST_F(CastExprTest, varcharToDecimal) { + // varchar to short decimal +// auto input = makeFlatVector({"-3", "177"}); +// testComplexCast( +// "c0", input, makeShortDecimalFlatVector({-300, 17700}, DECIMAL(6, 2))); + +// // varchar to long decimal +// auto input2 = makeFlatVector( +// {"-300000001234567891234.5", "1771234.5678912345678"}); +// testComplexCast( +// "c0", input2, makeLongDecimalFlatVector({-300, 17700}, DECIMAL(32, 7))); + + auto input3 = makeFlatVector({"9999999999.99", "9999999999.99"}); + testComplexCast( + "c0", input3, makeLongDecimalFlatVector( + {-30'000'000'000, + -20'000'000'000}, + DECIMAL(12, 2))); +} + TEST_F(CastExprTest, castInTry) { // Test try(cast(array(varchar) as array(bigint))) whose input vector is // wrapped in dictinary encoding. The row of ["2a"] should trigger an error diff --git a/velox/functions/prestosql/DecimalArithmetic.cpp b/velox/functions/prestosql/DecimalArithmetic.cpp index 053a86b0ffad..6e5948f82add 100644 --- a/velox/functions/prestosql/DecimalArithmetic.cpp +++ b/velox/functions/prestosql/DecimalArithmetic.cpp @@ -32,8 +32,22 @@ class DecimalBaseFunction : public exec::VectorFunction { DecimalBaseFunction( uint8_t aRescale, uint8_t bRescale, + uint8_t aPrecision, + uint8_t aScale, + uint8_t bPrecision, + uint8_t bScale, + uint8_t rPrecision, + uint8_t rScale, const TypePtr& resultType) - : aRescale_(aRescale), bRescale_(bRescale), resultType_(resultType) {} + : aRescale_(aRescale), + bRescale_(bRescale), + aPrecision_(aPrecision), + aScale_(aScale), + bPrecision_(bPrecision), + bScale_(bScale), + rPrecision_(rPrecision), + rScale_(rScale), + resultType_(resultType) {} void apply( const SelectivityVector& rows, @@ -48,8 +62,23 @@ class DecimalBaseFunction : public exec::VectorFunction { auto flatValues = args[1]->asUnchecked>(); auto rawValues = flatValues->mutableRawValues(); context.applyToSelectedNoThrow(rows, [&](auto row) { + bool overflow = false; Operation::template apply( - rawResults[row], constant, rawValues[row], aRescale_, bRescale_); + rawResults[row], + constant, + rawValues[row], + aRescale_, + bRescale_, + aPrecision_, + aScale_, + bPrecision_, + bScale_, + rPrecision_, + rScale_, + &overflow); + if (overflow) { + result->setNull(row, true); + } }); } else if (args[0]->isFlatEncoding() && args[1]->isConstantEncoding()) { // Fast path for (flat, const). @@ -57,8 +86,23 @@ class DecimalBaseFunction : public exec::VectorFunction { auto constant = args[1]->asUnchecked>()->valueAt(0); auto rawValues = flatValues->mutableRawValues(); context.applyToSelectedNoThrow(rows, [&](auto row) { + bool overflow = false; Operation::template apply( - rawResults[row], rawValues[row], constant, aRescale_, bRescale_); + rawResults[row], + rawValues[row], + constant, + aRescale_, + bRescale_, + aPrecision_, + aScale_, + bPrecision_, + bScale_, + rPrecision_, + rScale_, + &overflow); + if (overflow) { + result->setNull(row, true); + } }); } else if (args[0]->isFlatEncoding() && args[1]->isFlatEncoding()) { // Fast path for (flat, flat). @@ -66,9 +110,25 @@ class DecimalBaseFunction : public exec::VectorFunction { auto rawA = flatA->mutableRawValues(); auto flatB = args[1]->asUnchecked>(); auto rawB = flatB->mutableRawValues(); + context.applyToSelectedNoThrow(rows, [&](auto row) { + bool overflow = false; Operation::template apply( - rawResults[row], rawA[row], rawB[row], aRescale_, bRescale_); + rawResults[row], + rawA[row], + rawB[row], + aRescale_, + bRescale_, + aPrecision_, + aScale_, + bPrecision_, + bScale_, + rPrecision_, + rScale_, + &overflow); + if (overflow) { + result->setNull(row, true); + } }); } else { // Fast path if one or more arguments are encoded. @@ -76,12 +136,23 @@ class DecimalBaseFunction : public exec::VectorFunction { auto a = decodedArgs.at(0); auto b = decodedArgs.at(1); context.applyToSelectedNoThrow(rows, [&](auto row) { + bool overflow = false; Operation::template apply( rawResults[row], a->valueAt(row), b->valueAt(row), aRescale_, - bRescale_); + bRescale_, + aPrecision_, + aScale_, + bPrecision_, + bScale_, + rPrecision_, + rScale_, + &overflow); + if (overflow) { + result->setNull(row, true); + } }); } } @@ -101,14 +172,31 @@ class DecimalBaseFunction : public exec::VectorFunction { const uint8_t aRescale_; const uint8_t bRescale_; + const uint8_t aPrecision_; + const uint8_t aScale_; + const uint8_t bPrecision_; + const uint8_t bScale_; + const uint8_t rPrecision_; + const uint8_t rScale_; const TypePtr resultType_; }; class Addition { public: template - inline static void - apply(R& r, const A& a, const B& b, uint8_t aRescale, uint8_t bRescale) + inline static void apply( + R& r, + const A& a, + const B& b, + uint8_t aRescale, + uint8_t bRescale, + uint8_t /* aPrecision */, + uint8_t /* aScale */, + uint8_t /* bPrecision */, + uint8_t /* bScale */, + uint8_t /* rPrecision */, + uint8_t /* rScale */, + bool* overflow) #if defined(__has_feature) #if __has_feature(__address_sanitizer__) __attribute__((__no_sanitize__("signed-integer-overflow"))) @@ -128,7 +216,10 @@ class Addition { VELOX_ARITHMETIC_ERROR( "Decimal overflow: {} + {}", a.unscaledValue(), b.unscaledValue()); } - r = checkedPlus(R(aRescaled), R(bRescaled)); + auto res = R(aRescaled).plus(R(bRescaled), overflow); + if (!*overflow) { + r = res; + } } inline static uint8_t @@ -148,13 +239,38 @@ class Addition { std::max(aScale, bScale) + 1), std::max(aScale, bScale)}; } + + inline static std::pair adjustPrecisionScale( + const uint8_t rPrecision, + const uint8_t rScale) { + if (rPrecision <= 38) { + return {rPrecision, rScale}; + } else if (rScale < 0) { + return {38, rScale}; + } else { + int32_t minScale = std::min(static_cast(rScale), 6); + int32_t delta = rPrecision - 38; + return {38, std::max(rScale - delta, minScale)}; + } + } }; class Subtraction { public: template - inline static void - apply(R& r, const A& a, const B& b, uint8_t aRescale, uint8_t bRescale) + inline static void apply( + R& r, + const A& a, + const B& b, + uint8_t aRescale, + uint8_t bRescale, + uint8_t /* aPrecision */, + uint8_t /* aScale */, + uint8_t /* bPrecision */, + uint8_t /* bScale */, + uint8_t /* rPrecision */, + uint8_t /* rScale */, + bool* overflow) #if defined(__has_feature) #if __has_feature(__address_sanitizer__) __attribute__((__no_sanitize__("signed-integer-overflow"))) @@ -171,10 +287,13 @@ class Subtraction { b.unscaledValue(), DecimalUtil::kPowersOfTen[bRescale], &bRescaled)) { - VELOX_ARITHMETIC_ERROR( - "Decimal overflow: {} - {}", a.unscaledValue(), b.unscaledValue()); + *overflow = true; + return; + } + auto res = R(aRescaled).minus(R(bRescaled), overflow); + if (!*overflow) { + r = res; } - r = checkedMinus(R(aRescaled), R(bRescaled)); } inline static uint8_t @@ -195,11 +314,83 @@ class Subtraction { class Multiply { public: template - inline static void - apply(R& r, const A& a, const B& b, uint8_t aRescale, uint8_t bRescale) { - r = checkedMultiply( - checkedMultiply(R(a), R(b)), - R(DecimalUtil::kPowersOfTen[aRescale + bRescale])); + inline static void apply( + R& r, + const A& a, + const B& b, + uint8_t aRescale, + uint8_t bRescale, + uint8_t aPrecision, + uint8_t aScale, + uint8_t bPrecision, + uint8_t bScale, + uint8_t rPrecision, + uint8_t rScale, + bool* overflow) { + // derive from Arrow + if (rPrecision < 38) { + auto res = checkedMultiply( + R(a).multiply(R(b), overflow), + R(DecimalUtil::kPowersOfTen[aRescale + bRescale])); + if (!*overflow) { + r = res; + } + } else if (a.unscaledValue() == 0 && b.unscaledValue() == 0) { + // Handle this separately to avoid divide-by-zero errors. + r = R(0); + } else { + auto deltaScale = aScale + bScale - rScale; + if (deltaScale == 0) { + // No scale down + auto res = R(a).multiply(R(b), overflow); + if (!*overflow) { + r = res; + } + } else { + // scale down + // It's possible that the intermediate value does not fit in 128-bits, + // but the final value will (after scaling down). + int32_t total_leading_zeros = + a.countLeadingZeros() + b.countLeadingZeros(); + // This check is quick, but conservative. In some cases it will + // indicate that converting to 256 bits is necessary, when it's not + // actually the case. + if (UNLIKELY(total_leading_zeros <= 128)) { + // needs_int256 + int256_t aLarge = a.unscaledValue(); + int256_t blarge = b.unscaledValue(); + int256_t reslarge = aLarge * blarge; + reslarge = ReduceScaleBy(reslarge, deltaScale); + auto res = R::convert(reslarge, overflow); + if (!*overflow) { + r = res; + } + } else { + if (LIKELY(deltaScale <= 38)) { + // The largest value that result can have here is (2^64 - 1) * (2^63 + // - 1), which is greater than BasicDecimal128::kMaxValue. + auto res = R(a).multiply(R(b), overflow); + VELOX_DCHECK(!*overflow); + // Since deltaScale is greater than zero, result can now be at most + // ((2^64 - 1) * (2^63 - 1)) / 10, which is less than + // BasicDecimal128::kMaxValue, so there cannot be any overflow. + r = res / R(DecimalUtil::kPowersOfTen[deltaScale]); + } else { + // We are multiplying decimal(38, 38) by decimal(38, 38). The result + // should be a + // decimal(38, 37), so delta scale = 38 + 38 - 37 = 39. Since we are + // not in the 256 bit intermediate value case and we are scaling + // down by 39, then we are guaranteed that the result is 0 (even if + // we try to round). The largest possible intermediate result is 38 + // "9"s. If we scale down by 39, the leftmost 9 is now two digits to + // the right of the rightmost "visible" one. The reason why we have + // to handle this case separately is because a scale multiplier with + // a deltaScale 39 does not fit into 128 bit. + r = R(0); + } + } + } + } } inline static uint8_t @@ -212,16 +403,49 @@ class Multiply { const uint8_t aScale, const uint8_t bPrecision, const uint8_t bScale) { - return {std::min(38, aPrecision + bPrecision), aScale + bScale}; + return Addition::adjustPrecisionScale( + aPrecision + bPrecision + 1, aScale + bScale); + } + + private: + // derive from Arrow + inline static int256_t ReduceScaleBy(int256_t in, int32_t reduceBy) { + if (reduceBy == 0) { + // nothing to do. + return in; + } + + int256_t divisor = DecimalUtil::kPowersOfTen[reduceBy]; + DCHECK_GT(divisor, 0); + DCHECK_EQ(divisor % 2, 0); // multiple of 10. + auto result = in / divisor; + auto remainder = in % divisor; + // round up (same as BasicDecimal128::ReduceScaleBy) + if (abs(remainder) >= (divisor >> 1)) { + result += (in > 0 ? 1 : -1); + } + return result; } }; class Divide { public: template - inline static void - apply(R& r, const A& a, const B& b, uint8_t aRescale, uint8_t /*bRescale*/) { - DecimalUtilOp::divideWithRoundUp(r, a, b, false, aRescale, 0); + inline static void apply( + R& r, + const A& a, + const B& b, + uint8_t aRescale, + uint8_t /*bRescale*/, + uint8_t /* aPrecision */, + uint8_t /* aScale */, + uint8_t /* bPrecision */, + uint8_t /* bScale */, + uint8_t /* rPrecision */, + uint8_t /* rScale */, + bool* overflow) { + DecimalUtilOp::divideWithRoundUp( + r, a, b, false, aRescale, 0, overflow); } inline static uint8_t @@ -236,30 +460,25 @@ class Divide { const uint8_t bScale) { auto scale = std::max(6, aScale + bPrecision + 1); auto precision = aPrecision - aScale + bScale + scale; - if (precision > 38) { - int32_t min_scale = std::min(scale, 6); - int32_t delta = precision - 38; - precision = 38; - scale = std::max(scale - delta, min_scale); - } - return {precision, scale}; + return Addition::adjustPrecisionScale(precision, scale); } }; std::vector> decimalMultiplySignature() { - return { - exec::FunctionSignatureBuilder() - .integerVariable("a_precision") - .integerVariable("a_scale") - .integerVariable("b_precision") - .integerVariable("b_scale") - .integerVariable("r_precision", "min(38, a_precision + b_precision)") - .integerVariable("r_scale", "a_scale + b_scale") - .returnType("DECIMAL(r_precision, r_scale)") - .argumentType("DECIMAL(a_precision, a_scale)") - .argumentType("DECIMAL(b_precision, b_scale)") - .build()}; + return {exec::FunctionSignatureBuilder() + .integerVariable("a_precision") + .integerVariable("a_scale") + .integerVariable("b_precision") + .integerVariable("b_scale") + .integerVariable( + "r_precision", "min(38, a_precision + b_precision + 1)") + .integerVariable( + "r_scale", "a_scale") // not same with the result type + .returnType("DECIMAL(r_precision, r_scale)") + .argumentType("DECIMAL(a_precision, a_scale)") + .argumentType("DECIMAL(b_precision, b_scale)") + .build()}; } std::vector> @@ -295,9 +514,10 @@ std::vector> decimalDivideSignature() { "min(37, max(6, a_scale + b_precision + 1))") // if precision is // more than 38, // scale has new - // value, this check - // constrait is not - // same with result + // value, this + // check constrait + // is not same + // with result // type .returnType("DECIMAL(r_precision, r_scale)") .argumentType("DECIMAL(a_precision, a_scale)") @@ -325,38 +545,85 @@ std::shared_ptr createDecimalFunction( UnscaledLongDecimal /*result*/, UnscaledShortDecimal, UnscaledShortDecimal, - Operation>>(aRescale, bRescale, LONG_DECIMAL(rPrecision, rScale)); + Operation>>( + aRescale, + bRescale, + aPrecision, + aScale, + bPrecision, + bScale, + rPrecision, + rScale, + LONG_DECIMAL(rPrecision, rScale)); } else { // Arguments are short decimals and result is a short decimal. return std::make_shared>(aRescale, bRescale, SHORT_DECIMAL(rPrecision, rScale)); + Operation>>( + aRescale, + bRescale, + aPrecision, + aScale, + bPrecision, + bScale, + rPrecision, + rScale, + SHORT_DECIMAL(rPrecision, rScale)); } } else { - // LHS is short decimal and rhs is a long decimal, result is long decimal. + // LHS is short decimal and rhs is a long decimal, result is long + // decimal. return std::make_shared>(aRescale, bRescale, LONG_DECIMAL(rPrecision, rScale)); + Operation>>( + aRescale, + bRescale, + aPrecision, + aScale, + bPrecision, + bScale, + rPrecision, + rScale, + LONG_DECIMAL(rPrecision, rScale)); } } else { if (bType->kind() == TypeKind::SHORT_DECIMAL) { - // LHS is long decimal and rhs is short decimal, result is a long decimal. + // LHS is long decimal and rhs is short decimal, result is a long + // decimal. return std::make_shared>(aRescale, bRescale, LONG_DECIMAL(rPrecision, rScale)); + Operation>>( + aRescale, + bRescale, + aPrecision, + aScale, + bPrecision, + bScale, + rPrecision, + rScale, + LONG_DECIMAL(rPrecision, rScale)); } else { // Arguments and result are all long decimals. return std::make_shared>(aRescale, bRescale, LONG_DECIMAL(rPrecision, rScale)); + Operation>>( + aRescale, + bRescale, + aPrecision, + aScale, + bPrecision, + bScale, + rPrecision, + rScale, + LONG_DECIMAL(rPrecision, rScale)); } } VELOX_UNSUPPORTED(); diff --git a/velox/functions/prestosql/tests/DecimalArithmeticTest.cpp b/velox/functions/prestosql/tests/DecimalArithmeticTest.cpp index e25af702a83c..0083dbc22d69 100644 --- a/velox/functions/prestosql/tests/DecimalArithmeticTest.cpp +++ b/velox/functions/prestosql/tests/DecimalArithmeticTest.cpp @@ -142,12 +142,6 @@ TEST_F(DecimalArithmeticTest, add) { "Decimal overflow: 1 + 99999999999999999999999999999999999999"); } -TEST_F(DecimalArithmeticTest, int128Abs) { - int128_t va = UnscaledLongDecimal::min().unscaledValue(); - int128_t absVal = std::abs(va); -; -} - TEST_F(DecimalArithmeticTest, subtract) { auto shortFlatA = makeShortDecimalFlatVector({1000, 2000}, DECIMAL(18, 3)); // Subtract short and short, returning long. @@ -230,6 +224,36 @@ TEST_F(DecimalArithmeticTest, subtract) { "Decimal overflow: 1 - -99999999999999999999999999999999999999"); } +TEST_F(DecimalArithmeticTest, sparkMultiply) { + // auto shortFlat = makeShortDecimalFlatVector({1000, 2000}, DECIMAL(17, + // 3)); + // // Multiply short and short, returning long. + // testDecimalExpr( + // makeLongDecimalFlatVector({1000000, 4000000}, DECIMAL(35, 6)), + // "multiply(c0, c1)", + // {shortFlat, shortFlat}); + + // auto longFlat = makeLongDecimalFlatVector({1000, 2000}, DECIMAL(21, 3)); + // auto longFlat1 = makeLongDecimalFlatVector({1000, 2000}, DECIMAL(21, 2)); + // // Multiply short and short, returning long. + // testDecimalExpr( + // makeLongDecimalFlatVector({1000000, 4000000}, DECIMAL(38, 5)), + // "multiply(c0, c1)", + // {longFlat, longFlat1}); + + // testDecimalExpr( + // makeLongDecimalFlatVector({1000, 4000}, DECIMAL(38, 7)), + // "multiply(c0, c1)", + // {makeLongDecimalFlatVector({1000, 2000}, DECIMAL(20, 5)), + // makeLongDecimalFlatVector({1000, 2000}, DECIMAL(20, 5))}); + + testDecimalExpr( + makeLongDecimalFlatVector({1000}, DECIMAL(38, 7)), + "multiply(c0, c1)", + {makeShortDecimalFlatVector({1}, DECIMAL(10, 0)), + makeLongDecimalFlatVector({1123210000000000000}, DECIMAL(38, 18))}); +} + TEST_F(DecimalArithmeticTest, multiply) { auto shortFlat = makeShortDecimalFlatVector({1000, 2000}, DECIMAL(17, 3)); // Multiply short and short, returning long. diff --git a/velox/type/DecimalUtil.h b/velox/type/DecimalUtil.h index e02c7160c11c..3163380d8dd7 100644 --- a/velox/type/DecimalUtil.h +++ b/velox/type/DecimalUtil.h @@ -46,7 +46,8 @@ class DecimalUtil { const int fromScale, const int toPrecision, const int toScale, - bool nullOnOverflow = false) { + bool nullOnOverflow = false, + bool roundUp = true) { int128_t rescaledValue = inputValue.unscaledValue(); auto scaleDifference = toScale - fromScale; bool isOverflow = false; @@ -60,9 +61,10 @@ class DecimalUtil { const auto scalingFactor = DecimalUtil::kPowersOfTen[scaleDifference]; rescaledValue /= scalingFactor; int128_t remainder = inputValue.unscaledValue() % scalingFactor; - if (inputValue.unscaledValue() >= 0 && remainder >= scalingFactor / 2) { + if (roundUp && inputValue.unscaledValue() >= 0 && + remainder >= scalingFactor / 2) { ++rescaledValue; - } else if (remainder <= -scalingFactor / 2) { + } else if (roundUp && remainder <= -scalingFactor / 2) { --rescaledValue; } } @@ -98,6 +100,7 @@ class DecimalUtil { // Multiply decimal with the scale auto unscaled = inputValue * DecimalUtil::kPowersOfTen[toScale]; + bool isOverflow = std::isnan(unscaled); unscaled = std::round(unscaled); @@ -115,7 +118,7 @@ class DecimalUtil { if (rescaledValue < -DecimalUtil::kPowersOfTen[toPrecision] || rescaledValue > DecimalUtil::kPowersOfTen[toPrecision] || isOverflow) { VELOX_USER_FAIL( - "Cannot cast BIGINT '{}' to DECIMAL({},{})", + "Cannot cast DOUBLE '{}' to DECIMAL({},{})", inputValue, toPrecision, toScale); @@ -455,10 +458,10 @@ class DecimalUtil { } template - inline static int numDigits(T number) - { + inline static int numDigits(T number) { int digits = 0; - if (number < 0) digits = 1; // remove this line if '-' counts as a digit + if (number < 0) + digits = 1; // remove this line if '-' counts as a digit while (number) { number /= 10; digits++; diff --git a/velox/type/DecimalUtilOp.h b/velox/type/DecimalUtilOp.h index bdc6d7d41dc1..7b6fa09db1f5 100644 --- a/velox/type/DecimalUtilOp.h +++ b/velox/type/DecimalUtilOp.h @@ -24,11 +24,7 @@ #include "velox/type/UnscaledLongDecimal.h" #include "velox/type/UnscaledShortDecimal.h" -#include - namespace facebook::velox { -using boost::multiprecision::int256_t; -using uint128_t = __uint128_t; class DecimalUtilOp { public: @@ -57,60 +53,12 @@ class DecimalUtilOp { if constexpr (std::is_same_v) { num_occupied = 64 - bits::countLeadingZeros(valueAbs); } else { - uint64_t hi = valueAbs >> 64; - uint64_t lo = static_cast(valueAbs); - num_occupied = (hi == 0) ? 64 - bits::countLeadingZeros(lo) - : 64 - bits::countLeadingZeros(hi); + num_occupied = 128 - num.countLeadingZeros(); } return num_occupied + maxBitsRequiredIncreaseAfterScaling(aRescale); } - inline static int128_t ConvertToInt128(int256_t in) { - int128_t result; - int128_t INT128_MAX = int128_t(int128_t(-1L)) >> 1; - constexpr int256_t UINT128_MASK = std::numeric_limits::max(); - - int256_t in_abs = abs(in); - bool is_negative = in < 0; - - uint128_t unsignResult = (in_abs & UINT128_MASK).convert_to(); - in_abs >>= 128; - - if (in_abs > 0) { - // we've shifted in by 128-bit, so nothing should be left. - VELOX_FAIL("in_abs overflow"); - } else if (unsignResult > INT128_MAX) { - // the high-bit must not be set (signed 128-bit). - VELOX_FAIL("in_abs > int128 max"); - } else { - result = static_cast(unsignResult); - } - return is_negative ? -result : result; - } - - inline static int64_t ConvertToInt64(int256_t in) { - int64_t result; - constexpr int256_t UINT64_MASK = std::numeric_limits::max(); - - int256_t in_abs = abs(in); - bool is_negative = in < 0; - - uint128_t unsignResult = (in_abs & UINT64_MASK).convert_to(); - in_abs >>= 64; - - if (in_abs > 0) { - // we've shifted in by 128-bit, so nothing should be left. - VELOX_FAIL("in_abs overflow"); - } else if (unsignResult > INT64_MAX) { - // the high-bit must not be set (signed 128-bit). - VELOX_FAIL("in_abs > int64 max"); - } else { - result = static_cast(unsignResult); - } - return is_negative ? -result : result; - } - template inline static R divideWithRoundUp( R& r, @@ -118,8 +66,12 @@ class DecimalUtilOp { const B& b, bool noRoundUp, uint8_t aRescale, - uint8_t /*bRescale*/) { - VELOX_CHECK_NE(b, 0, "Division by zero"); + uint8_t /*bRescale*/, + bool* overflow) { + if (b.unscaledValue() == 0) { + *overflow = true; + return R(-1); + } int resultSign = 1; R unsignedDividendRescaled(a); int aSign = 1; @@ -136,10 +88,12 @@ class DecimalUtilOp { bSign = -1; } auto bitsRequiredAfterScaling = maxBitsRequiredAfterScaling(a, aRescale); - if (bitsRequiredAfterScaling <= 127) { - unsignedDividendRescaled = checkedMultiply( - unsignedDividendRescaled, R(DecimalUtil::kPowersOfTen[aRescale])); + unsignedDividendRescaled = unsignedDividendRescaled.multiply( + R(DecimalUtil::kPowersOfTen[aRescale]), overflow); + if (*overflow) { + return R(-1); + } R quotient = unsignedDividendRescaled / unsignedDivisor; R remainder = unsignedDividendRescaled % unsignedDivisor; if (!noRoundUp && remainder * 2 >= unsignedDivisor) { @@ -152,10 +106,8 @@ class DecimalUtilOp { std::is_same_v) { // Derives from Arrow BasicDecimal128 Divide if (aRescale > 38 && bitsRequiredAfterScaling > 255) { - VELOX_FAIL( - "Decimal overflow because rescale {} > 38 and bitsRequiredAfterScaling {} > 255", - aRescale, - bitsRequiredAfterScaling); + *overflow = true; + return R(-1); } int256_t aLarge = a.unscaledValue(); int256_t x_large_scaled_up = aLarge * DecimalUtil::kPowersOfTen[aRescale]; @@ -171,24 +123,135 @@ class DecimalUtilOp { // x -ve and y -ve, result is +ve => (-1 ^ -1) + 1 = 0 + 1 = +1 result_large += (aSign ^ bSign) + 1; } - if constexpr (std::is_same_v) { - int64_t result = ConvertToInt128(result_large); - if (!R::valueInRange(result)) { - VELOX_FAIL("overflow long decimal"); - } - r = UnscaledShortDecimal(result); - return UnscaledShortDecimal(ConvertToInt64(remainder_large)); + + auto result = R::convert(result_large, overflow); + auto remainder = R::convert(remainder_large, overflow); + if (!R::valueInRange(result.unscaledValue())) { + *overflow = true; } else { - int128_t result = ConvertToInt128(result_large); - if (!R::valueInRange(result)) { - VELOX_FAIL("overflow long decimal"); - } - r = UnscaledLongDecimal(result); - return UnscaledLongDecimal(ConvertToInt128(remainder_large)); + r = result; } + return remainder; } else { VELOX_FAIL("Should not reach here in DecimalUtilOp.h"); } } + + // return unscaled value and scale + inline static std::pair splitVarChar( + const StringView& value) { + std::string s = value.str(); + size_t pos = s.find('.'); + if (pos == std::string::npos) { + return {s.substr(0, pos), 0}; + } else { + return { + s.substr(0, pos) + s.substr(pos + 1, s.length()), s.length() - pos - 1}; + } + } + + static int128_t convertStringToInt128( + const std::string& value, + bool& nullOutput) { + // Handling integer target cases + const char* v = value.c_str(); + nullOutput = true; + bool negative = false; + int128_t result = 0; + int index = 0; + int len = value.size(); + if (len == 0) { + return -1; + } + // Setting negative flag + if (v[0] == '-') { + if (len == 1) { + return -1; + } + negative = true; + index = 1; + } + if (negative) { + for (; index < len; index++) { + if (!std::isdigit(v[index])) { + return -1; + } + result = result * 10 - (v[index] - '0'); + // Overflow check + if (result > 0) { + return -1; + } + } + } else { + for (; index < len; index++) { + if (!std::isdigit(v[index])) { + return -1; + } + result = result * 10 + (v[index] - '0'); + // Overflow check + if (result < 0) { + return -1; + } + } + } + // Final result + nullOutput = false; + return result; + } + + template + inline static std::optional rescaleVarchar( + const StringView inputValue, + const int toPrecision, + const int toScale) { + static_assert( + std::is_same_v || + std::is_same_v); + auto [unscaledStr, fromScale] = splitVarChar(inputValue); + uint8_t fromPrecision = unscaledStr.size(); + VELOX_CHECK_LE( + fromPrecision, DecimalType::kMaxPrecision); + if (fromPrecision <= 18) { + int64_t fromUnscaledValue = folly::to(unscaledStr); + return DecimalUtil::rescaleWithRoundUp( + UnscaledShortDecimal(fromUnscaledValue), + fromPrecision, + fromScale, + toPrecision, + toScale, + false, + false); + } else { + bool nullOutput = true; + int128_t decimalValue = convertStringToInt128(unscaledStr, nullOutput); + if (nullOutput) { + VELOX_USER_FAIL( + "Cannot cast StringView '{}' to DECIMAL({},{})", + inputValue, + toPrecision, + toScale); + } + return DecimalUtil::rescaleWithRoundUp( + UnscaledLongDecimal(decimalValue), + fromPrecision, + fromScale, + toPrecision, + toScale, + false, + false); + } + } + + template + inline static std::optional rescaleDouble( + const TInput inputValue, + const int toPrecision, + const int toScale) { + static_assert( + std::is_same_v || + std::is_same_v); + return rescaleVarchar( + velox::to(inputValue), toPrecision, toScale); + } }; } // namespace facebook::velox diff --git a/velox/type/UnscaledLongDecimal.h b/velox/type/UnscaledLongDecimal.h index b3966836e972..5a3d4bff361e 100644 --- a/velox/type/UnscaledLongDecimal.h +++ b/velox/type/UnscaledLongDecimal.h @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ +#include #include #include #include @@ -26,6 +27,8 @@ namespace facebook::velox { using int128_t = __int128_t; +using boost::multiprecision::int256_t; +using uint128_t = __uint128_t; constexpr int128_t buildInt128(uint64_t hi, uint64_t lo) { // GCC does not allow left shift negative value. @@ -173,6 +176,66 @@ struct UnscaledLongDecimal { memcpy(&ans.unscaledValue_, serializedData, sizeof(int128_t)); return ans; } + + UnscaledLongDecimal plus(const UnscaledLongDecimal& a, bool* overflow) { + int128_t result; + *overflow = + __builtin_add_overflow(unscaledValue_, a.unscaledValue(), &result); + if (UNLIKELY(*overflow || !UnscaledLongDecimal::valueInRange(result))) { + *overflow = true; + return UnscaledLongDecimal(-1); + } + return UnscaledLongDecimal(result); + } + + UnscaledLongDecimal minus(const UnscaledLongDecimal& a, bool* overflow) { + int128_t result; + *overflow = + __builtin_sub_overflow(unscaledValue_, a.unscaledValue(), &result); + if (UNLIKELY(*overflow || !UnscaledLongDecimal::valueInRange(result))) { + *overflow = true; + return UnscaledLongDecimal(-1); + } + return UnscaledLongDecimal(result); + } + + UnscaledLongDecimal multiply(const UnscaledLongDecimal& a, bool* overflow) { + int128_t result; + *overflow = + __builtin_mul_overflow(unscaledValue_, a.unscaledValue(), &result); + if (UNLIKELY(*overflow || !UnscaledLongDecimal::valueInRange(result))) { + *overflow = true; + return UnscaledLongDecimal(-1); + } + return UnscaledLongDecimal(result); + } + + int32_t countLeadingZeros() const { + auto abs = std::abs(unscaledValue_); + return bits::countLeadingZerosUint128(abs); + } + + static inline UnscaledLongDecimal convert(int256_t in, bool* overflow) { + int128_t result; + int128_t INT128_MAX = int128_t(int128_t(-1L)) >> 1; + constexpr int256_t UINT128_MASK = std::numeric_limits::max(); + + int256_t inAbs = abs(in); + bool isNegative = in < 0; + + uint128_t unsignResult = (inAbs & UINT128_MASK).convert_to(); + inAbs >>= 128; + + if (inAbs > 0) { + // we've shifted in by 128-bit, so nothing should be left. + *overflow = true; + } else if (unsignResult > INT128_MAX) { + *overflow = true; + } else { + result = static_cast(unsignResult); + } + return UnscaledLongDecimal(isNegative ? -result : result); + } private: static constexpr int128_t kMin = diff --git a/velox/type/UnscaledShortDecimal.h b/velox/type/UnscaledShortDecimal.h index 581bddee9306..1679550b3d7d 100644 --- a/velox/type/UnscaledShortDecimal.h +++ b/velox/type/UnscaledShortDecimal.h @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ +#include #include #include #include @@ -23,6 +24,8 @@ #pragma once namespace facebook::velox { +using boost::multiprecision::int256_t; +using uint128_t = __uint128_t; struct UnscaledShortDecimal { public: @@ -122,6 +125,66 @@ struct UnscaledShortDecimal { return *this; } + UnscaledShortDecimal plus(const UnscaledShortDecimal& a, bool* overflow) { + int64_t result; + *overflow = + __builtin_add_overflow(unscaledValue_, a.unscaledValue(), &result); + if (UNLIKELY(*overflow || !UnscaledShortDecimal::valueInRange(result))) { + *overflow = true; + return UnscaledShortDecimal(-1); + } + return UnscaledShortDecimal(result); + } + + UnscaledShortDecimal minus(const UnscaledShortDecimal& a, bool* overflow) { + int64_t result; + *overflow = + __builtin_sub_overflow(unscaledValue_, a.unscaledValue(), &result); + if (UNLIKELY(*overflow || !UnscaledShortDecimal::valueInRange(result))) { + *overflow = true; + return UnscaledShortDecimal(-1); + } + return UnscaledShortDecimal(result); + } + + UnscaledShortDecimal multiply(const UnscaledShortDecimal& a, bool* overflow) { + int64_t result; + *overflow = + __builtin_mul_overflow(unscaledValue_, a.unscaledValue(), &result); + if (UNLIKELY(*overflow || !UnscaledShortDecimal::valueInRange(result))) { + *overflow = true; + return UnscaledShortDecimal(-1); + } + return UnscaledShortDecimal(result); + } + + int32_t countLeadingZeros() const { + auto abs = std::abs(unscaledValue_); + return bits::countLeadingZeros(abs); + } + + static inline UnscaledShortDecimal convert(int256_t in, bool* overflow) { + int64_t result; + constexpr int256_t UINT64_MASK = std::numeric_limits::max(); + + int256_t inAbs = abs(in); + bool isNegative = in < 0; + + uint128_t unsignResult = (inAbs & UINT64_MASK).convert_to(); + inAbs >>= 64; + + if (inAbs > 0) { + // we've shifted in by 128-bit, so nothing should be left. + *overflow = true; + } else if (unsignResult > INT64_MAX) { + // the high-bit must not be set (signed 128-bit). + *overflow = true; + } else { + result = static_cast(unsignResult); + } + return UnscaledShortDecimal(isNegative ? -result : result); + } + private: static constexpr int64_t kMin = -1'000'000'000'000'000'000 + 1; static constexpr int64_t kMax = 1'000'000'000'000'000'000 - 1;