diff --git a/velox/docs/functions/presto/decimal.rst b/velox/docs/functions/presto/decimal.rst index c1c862256f52..2cd7f2150a2e 100644 --- a/velox/docs/functions/presto/decimal.rst +++ b/velox/docs/functions/presto/decimal.rst @@ -136,7 +136,7 @@ type decimal(p2, s2) with unscaled value B. :: a = A / 10^s1 - a = B / 10^s2 + b = B / 10^s2 The result type precision and scale are: @@ -174,6 +174,60 @@ digits after the decimal point, hence, max(s1, s2). SELECT 1.2 / 0.01 +Modulus +------- + +For the modulus operation :code:`a % b`, when `a` and `b` are integers, the +magnitude of the result `r` is less than that of `b` and no greater than that +of `a`. Hence the number of digits needed to represent `r` is no more than the +minimum of the number of digits needed to represent `a` or `b`. We can extend this to decimal inputs `a` and +`b` by computing the modulus of their unscaled values. However, we should +first make sure that `a` and `b` have the same scale. This can be achieved by +scaling up the input with lesser scale by the difference in the inputs' scales, +so both `a` and `b` have scale s. Once `a` and `b` have the same scale, we +compute the modulus of their unscaled values, A and B. `r` has s digits after +the decimal point, and since `r` needs no more digits before the decimal +point than the smaller of what `a` and `b` need, the result precision is s +plus the smaller of the differences between the precision and scale of the +two inputs, i.e. min(p1 - s1, p2 - s2). 
Hence the result type precision and scale are: + +:: + + s = max(s1, s2) + p = min(p2 - s2, p1 - s1) + max(s1, s2) + +To compute R, we first rescale A and B to the common scale `s`: + +:: + + A = a * 10^s1 + B = b * 10^s2 + + A' = a * 10^s + B' = b * 10^s + +Then we compute the modulus of the rescaled values: + +:: + + R = A' % B' = r * 10^s + +For example, say `a` = 12.3 and `b` = 1.21, `r` = :code:`a % b` is calculated +as follows: + +:: + + s = max(1, 2) = 2 + p = min(3 - 2, 3 - 1) + s = 3 + + A = 12.3 * 10^1 = 123 + B = 1.21 * 10^2 = 121 + + A' = 12.3 * 10^2 = 1230 + B' = 1.21 * 10^2 = 121 + + R = 1230 % 121 = 20 = 0.20 * 100 + Decimal Functions ----------------- @@ -216,6 +270,19 @@ Decimal Functions Throws if result cannot be represented using precision calculated above. +.. function:: modulus(x: decimal(p1, s1), y: decimal(p2, s2)) -> r: decimal(p, s) + + Returns the remainder from division of x by y (r = x % y). + + x and y are decimal values with possibly different precisions and scales. The + precision and scale of the result are calculated as follows: + :: + + p = min(p2 - s2, p1 - s1) + max(s1, s2) + s = max(s1, s2) + + Throws if y is zero. + .. function:: multiply(x: decimal(p1, s1), y: decimal(p2, s2)) -> r: decimal(p, s) Returns the result of multiplying x by y (r = x * y). 
diff --git a/velox/functions/prestosql/DecimalFunctions.cpp b/velox/functions/prestosql/DecimalFunctions.cpp index fa89a50cbf94..3ec752154a82 100644 --- a/velox/functions/prestosql/DecimalFunctions.cpp +++ b/velox/functions/prestosql/DecimalFunctions.cpp @@ -163,6 +163,56 @@ struct DecimalDivideFunction { uint8_t aRescale_; }; +template +struct DecimalModulusFunction { + VELOX_DEFINE_FUNCTION_TYPES(TExec); + + template + void initialize( + const std::vector& inputTypes, + const core::QueryConfig& /*config*/, + A* /*a*/, + B* /*b*/) { + const auto& aType = inputTypes[0]; + const auto& bType = inputTypes[1]; + auto [aPrecision, aScale] = getDecimalPrecisionScale(*aType); + auto [bPrecision, bScale] = getDecimalPrecisionScale(*bType); + aRescale_ = std::max(0, bScale - aScale); + bRescale_ = std::max(0, aScale - bScale); + } + + template + void call(R& out, const A& a, const B& b) { + VELOX_USER_CHECK_NE(b, 0, "Modulus by zero"); + int remainderSign = 1; + R unsignedDividendRescaled(a); + if (a < 0) { + remainderSign *= -1; + unsignedDividendRescaled *= -1; + } + unsignedDividendRescaled = checkedMultiply( + unsignedDividendRescaled, + R(DecimalUtil::kPowersOfTen[aRescale_]), + "Decimal"); + + R unsignedDivisorRescaled(b); + if (b < 0) { + unsignedDivisorRescaled *= -1; + } + unsignedDivisorRescaled = checkedMultiply( + unsignedDivisorRescaled, + R(DecimalUtil::kPowersOfTen[bRescale_]), + "Decimal"); + + R remainder = unsignedDividendRescaled % unsignedDivisorRescaled; + out = remainder * remainderSign; + } + + private: + uint8_t aRescale_; + uint8_t bRescale_; +}; + template struct DecimalRoundFunction { VELOX_DEFINE_FUNCTION_TYPES(TExec); @@ -371,6 +421,69 @@ void registerDecimalDivide(const std::string& prefix) { ShortDecimal>({prefix + "divide"}, constraints); } +void registerDecimalModulus(const std::string& prefix) { + std::vector constraints = { + exec::SignatureVariable( + P3::name(), + fmt::format( + "min({b_precision} - {b_scale}, {a_precision} - {a_scale}) 
+ max({a_scale}, {b_scale})", + fmt::arg("a_precision", P1::name()), + fmt::arg("a_scale", S1::name()), + fmt::arg("b_precision", P2::name()), + fmt::arg("b_scale", S2::name())), + exec::ParameterType::kIntegerParameter), + exec::SignatureVariable( + S3::name(), + fmt::format( + "max({a_scale}, {b_scale})", + fmt::arg("a_scale", S1::name()), + fmt::arg("b_scale", S2::name())), + exec::ParameterType::kIntegerParameter), + }; + + // (short, short) -> short + registerFunction< + DecimalModulusFunction, + ShortDecimal, + ShortDecimal, + ShortDecimal>({prefix + "mod"}, constraints); + + // (short, long) -> short + registerFunction< + DecimalModulusFunction, + ShortDecimal, + ShortDecimal, + LongDecimal>({prefix + "mod"}, constraints); + + // (long, short) -> short + registerFunction< + DecimalModulusFunction, + ShortDecimal, + LongDecimal, + ShortDecimal>({prefix + "mod"}, constraints); + + // (short, long) -> long + registerFunction< + DecimalModulusFunction, + LongDecimal, + ShortDecimal, + LongDecimal>({prefix + "mod"}, constraints); + + // (long, short) -> long + registerFunction< + DecimalModulusFunction, + LongDecimal, + LongDecimal, + ShortDecimal>({prefix + "mod"}, constraints); + + // (long, long) -> long + registerFunction< + DecimalModulusFunction, + LongDecimal, + LongDecimal, + LongDecimal>({prefix + "mod"}, constraints); +} + void registerDecimalFloor(const std::string& prefix) { std::vector constraints = { exec::SignatureVariable( diff --git a/velox/functions/prestosql/DecimalFunctions.h b/velox/functions/prestosql/DecimalFunctions.h index 69f7c9aaa4f5..981073fbc414 100644 --- a/velox/functions/prestosql/DecimalFunctions.h +++ b/velox/functions/prestosql/DecimalFunctions.h @@ -24,6 +24,8 @@ void registerDecimalMultiply(const std::string& prefix); void registerDecimalDivide(const std::string& prefix); +void registerDecimalModulus(const std::string& prefix); + void registerDecimalFloor(const std::string& prefix); void registerDecimalRound(const std::string& 
prefix); diff --git a/velox/functions/prestosql/registration/MathematicalOperatorsRegistration.cpp b/velox/functions/prestosql/registration/MathematicalOperatorsRegistration.cpp index 0d2440451873..f5d806589bd3 100644 --- a/velox/functions/prestosql/registration/MathematicalOperatorsRegistration.cpp +++ b/velox/functions/prestosql/registration/MathematicalOperatorsRegistration.cpp @@ -67,6 +67,7 @@ void registerMathematicalOperators(const std::string& prefix = "") { registerDecimalMinus(prefix); registerDecimalMultiply(prefix); registerDecimalDivide(prefix); + registerDecimalModulus(prefix); } } // namespace facebook::velox::functions diff --git a/velox/functions/prestosql/tests/DecimalArithmeticTest.cpp b/velox/functions/prestosql/tests/DecimalArithmeticTest.cpp index b363c71bb96b..21536c923e11 100644 --- a/velox/functions/prestosql/tests/DecimalArithmeticTest.cpp +++ b/velox/functions/prestosql/tests/DecimalArithmeticTest.cpp @@ -403,6 +403,72 @@ TEST_F(DecimalArithmeticTest, decimalDivDifferentTypes) { makeFlatVector({100, 200, -300, 400}, DECIMAL(12, 2))}); } +TEST_F(DecimalArithmeticTest, decimalMod) { + // short % short -> short. + testDecimalExpr( + makeFlatVector({0, 0}, DECIMAL(2, 1)), + "mod(c0, c1)", + {makeFlatVector({0, 50}, DECIMAL(2, 1)), + makeFlatVector({20, 25}, DECIMAL(2, 1))}); + testDecimalExpr( + makeFlatVector({3, -3, 3, -3}, DECIMAL(2, 1)), + "mod(c0, c1)", + {makeFlatVector({13, -13, 13, -13}, DECIMAL(3, 1)), + makeFlatVector({5, 5, -5, -5}, DECIMAL(2, 1))}); + testDecimalExpr( + makeFlatVector({90, -245, 245, -90}, DECIMAL(3, 2)), + "mod(c0, c1)", + {makeFlatVector({50, -50, 50, -50}, DECIMAL(2, 1)), + makeFlatVector({205, 255, -255, -205}, DECIMAL(3, 2))}); + testDecimalExpr( + makeFlatVector({2500, -12000}, DECIMAL(5, 3)), + "mod(c0, c1)", + {makeFlatVector({2500, -12000}, DECIMAL(5, 3)), + makeFlatVector({600, 5000}, DECIMAL(5, 2))}); + + // short % long -> short. 
+ testDecimalExpr( + makeFlatVector({1000, -600, 1000, -600}, DECIMAL(17, 15)), + "mod(c0, c1)", + {makeFlatVector({1000, -600, 1000, -600}, DECIMAL(17, 15)), + makeFlatVector({13, 17, -13, -17}, DECIMAL(20, 10))}); + + // long % short -> short. + testDecimalExpr( + makeFlatVector({8, -11, 8, -11}, DECIMAL(17, 15)), + "mod(c0, c1)", + {makeFlatVector({500, -4000, 500, -4000}, DECIMAL(20, 10)), + makeFlatVector({17, 19, -17, -19}, DECIMAL(17, 15))}); + + // short % long -> long. + testDecimalExpr( + makeFlatVector({0, -16, 0, -16}, DECIMAL(25, 10)), + "mod(c0, c1)", + {makeFlatVector({1000, -600, 1000, -600}, DECIMAL(17, 2)), + makeFlatVector({400, 38, -400, -38}, DECIMAL(30, 10))}); + + // long % short -> long. + testDecimalExpr( + makeFlatVector({500, -4000, 500, -4000}, DECIMAL(25, 10)), + "mod(c0, c1)", + {makeFlatVector({500, -4000, 500, -4000}, DECIMAL(30, 10)), + makeFlatVector({1000, 2000, -1000, -2000}, DECIMAL(17, 2))}); + + // long % long -> long. + testDecimalExpr( + makeFlatVector({2500, -12000, 2500, -12000}, DECIMAL(23, 5)), + "mod(c0, c1)", + {makeFlatVector({2500, -12000, 2500, -12000}, DECIMAL(25, 5)), + makeFlatVector({500, 4000, -500, -4000}, DECIMAL(20, 2))}); + + VELOX_ASSERT_USER_THROW( + testDecimalExpr( + {}, + "c0 % 0.0", + {makeFlatVector({1000, 2000}, DECIMAL(17, 3))}), + "Modulus by zero"); +} + TEST_F(DecimalArithmeticTest, round) { // Round short decimals. testDecimalExpr(