Skip to content

Commit

Permalink
Add decimal mod operator (facebookincubator#9351)
Browse files Browse the repository at this point in the history
Summary: Pull Request resolved: facebookincubator#9351

Reviewed By: Yuhta

Differential Revision: D57478854

Pulled By: kgpai

fbshipit-source-id: bfc65a847ee0072198867a2950e4583d5183e336
  • Loading branch information
pramodsatya authored and facebook-github-bot committed Jun 3, 2024
1 parent 8648f50 commit ab06a77
Show file tree
Hide file tree
Showing 5 changed files with 250 additions and 1 deletion.
69 changes: 68 additions & 1 deletion velox/docs/functions/presto/decimal.rst
Original file line number Diff line number Diff line change
Expand Up @@ -136,7 +136,7 @@ type decimal(p2, s2) with unscaled value B.
::

a = A / 10^s1
a = B / 10^s2
b = B / 10^s2

The result type precision and scale are:

Expand Down Expand Up @@ -174,6 +174,60 @@ digits after the decimal point, hence, max(s1, s2).

SELECT 1.2 / 0.01

Modulus
-------

For the modulus operation :code:`a % b`, when a and b are integers, the result
`r` is less than `b` and less than or equal to `a`. Hence the number of digits
needed to represent `r` is no more than the minimum of the number of digits
needed to represent `a` or `b`. We can extend this to decimal inputs `a` and
`b` by computing the modulus of their unscaled values. However, we should
first make sure that `a` and `b` have the same scale. This can be achieved by
scaling up the input with lesser scale by the difference in the inputs' scales,
so both `a` and `b` have scale s. Once `a` and `b` have the same scale, we
compute the modulus of their unscaled values, A and B. `r` has s digits after
the decimal point, and since `r` does not need any more digits than the
minimum number of digits needed to represent `a` or `b`, the result precision
needs to be increased by the smaller of the differences in the precision and
scale of either inputs. Hence the result type precision and scale are:

::

s = max(s1, s2)
p = min(p2 - s2, p1 - s1) + max(s1, s2)

To compute R, we first rescale A and B to 's':

::

A = a * 10^s1
B = b * 10^s2

A' = a * 10^s
B' = b * 10^s

Then we compute modulus of the rescaled values:

::

R = A' % B' = r * 10^s

For example, say `a` = 12.3 and `b` = 1.21, `r` = :code:`a % b` is calculated
as follows:

::

s = max(1, 2) = 2
p = min(2, 1) + s = 3

A = 12.3 * 10^1 = 123
B = 1.21 * 10^2 = 121

A' = 12.3 * 10^2 = 1230
B' = 1.21 * 10^2 = 121

R = 1230 % 121 = 20 = 0.20 * 100

Decimal Functions
-----------------

Expand Down Expand Up @@ -216,6 +270,19 @@ Decimal Functions

Throws if result cannot be represented using precision calculated above.

.. function:: modulus(x: decimal(p1, s1), y: decimal(p2, s2)) -> r: decimal(p, s)

Returns the remainder from division of x by y (r = x % y).

x and y are decimal values with possibly different precisions and scales. The
precision and scale of the result are calculated as follows:
::

p = min(p2 - s2, p1 - s1) + max(s1, s2)
s = max(s1, s2)

Throws if y is zero.

.. function:: multiply(x: decimal(p1, s1), y: decimal(p2, s2)) -> r: decimal(p, s)

Returns the result of multiplying x by y (r = x * y).
Expand Down
113 changes: 113 additions & 0 deletions velox/functions/prestosql/DecimalFunctions.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -163,6 +163,56 @@ struct DecimalDivideFunction {
uint8_t aRescale_;
};

template <typename TExec>
struct DecimalModulusFunction {
VELOX_DEFINE_FUNCTION_TYPES(TExec);

template <typename A, typename B>
void initialize(
const std::vector<TypePtr>& inputTypes,
const core::QueryConfig& /*config*/,
A* /*a*/,
B* /*b*/) {
const auto& aType = inputTypes[0];
const auto& bType = inputTypes[1];
auto [aPrecision, aScale] = getDecimalPrecisionScale(*aType);
auto [bPrecision, bScale] = getDecimalPrecisionScale(*bType);
aRescale_ = std::max(0, bScale - aScale);
bRescale_ = std::max(0, aScale - bScale);
}

template <typename R, typename A, typename B>
void call(R& out, const A& a, const B& b) {
VELOX_USER_CHECK_NE(b, 0, "Modulus by zero");
int remainderSign = 1;
R unsignedDividendRescaled(a);
if (a < 0) {
remainderSign *= -1;
unsignedDividendRescaled *= -1;
}
unsignedDividendRescaled = checkedMultiply<R>(
unsignedDividendRescaled,
R(DecimalUtil::kPowersOfTen[aRescale_]),
"Decimal");

R unsignedDivisorRescaled(b);
if (b < 0) {
unsignedDivisorRescaled *= -1;
}
unsignedDivisorRescaled = checkedMultiply<B>(
unsignedDivisorRescaled,
R(DecimalUtil::kPowersOfTen[bRescale_]),
"Decimal");

R remainder = unsignedDividendRescaled % unsignedDivisorRescaled;
out = remainder * remainderSign;
}

private:
uint8_t aRescale_;
uint8_t bRescale_;
};

template <typename TExec>
struct DecimalRoundFunction {
VELOX_DEFINE_FUNCTION_TYPES(TExec);
Expand Down Expand Up @@ -371,6 +421,69 @@ void registerDecimalDivide(const std::string& prefix) {
ShortDecimal<P2, S2>>({prefix + "divide"}, constraints);
}

void registerDecimalModulus(const std::string& prefix) {
std::vector<exec::SignatureVariable> constraints = {
exec::SignatureVariable(
P3::name(),
fmt::format(
"min({b_precision} - {b_scale}, {a_precision} - {a_scale}) + max({a_scale}, {b_scale})",
fmt::arg("a_precision", P1::name()),
fmt::arg("a_scale", S1::name()),
fmt::arg("b_precision", P2::name()),
fmt::arg("b_scale", S2::name())),
exec::ParameterType::kIntegerParameter),
exec::SignatureVariable(
S3::name(),
fmt::format(
"max({a_scale}, {b_scale})",
fmt::arg("a_scale", S1::name()),
fmt::arg("b_scale", S2::name())),
exec::ParameterType::kIntegerParameter),
};

// (short, short) -> short
registerFunction<
DecimalModulusFunction,
ShortDecimal<P3, S3>,
ShortDecimal<P1, S1>,
ShortDecimal<P2, S2>>({prefix + "mod"}, constraints);

// (short, long) -> short
registerFunction<
DecimalModulusFunction,
ShortDecimal<P3, S3>,
ShortDecimal<P1, S1>,
LongDecimal<P2, S2>>({prefix + "mod"}, constraints);

// (long, short) -> short
registerFunction<
DecimalModulusFunction,
ShortDecimal<P3, S3>,
LongDecimal<P1, S1>,
ShortDecimal<P2, S2>>({prefix + "mod"}, constraints);

// (short, long) -> long
registerFunction<
DecimalModulusFunction,
LongDecimal<P3, S3>,
ShortDecimal<P1, S1>,
LongDecimal<P2, S2>>({prefix + "mod"}, constraints);

// (long, short) -> long
registerFunction<
DecimalModulusFunction,
LongDecimal<P3, S3>,
LongDecimal<P1, S1>,
ShortDecimal<P2, S2>>({prefix + "mod"}, constraints);

// (long, long) -> long
registerFunction<
DecimalModulusFunction,
LongDecimal<P3, S3>,
LongDecimal<P1, S1>,
LongDecimal<P2, S2>>({prefix + "mod"}, constraints);
}

void registerDecimalFloor(const std::string& prefix) {
std::vector<exec::SignatureVariable> constraints = {
exec::SignatureVariable(
Expand Down
2 changes: 2 additions & 0 deletions velox/functions/prestosql/DecimalFunctions.h
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,8 @@ void registerDecimalMultiply(const std::string& prefix);

void registerDecimalDivide(const std::string& prefix);

void registerDecimalModulus(const std::string& prefix);

void registerDecimalFloor(const std::string& prefix);

void registerDecimalRound(const std::string& prefix);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,7 @@ void registerMathematicalOperators(const std::string& prefix = "") {
registerDecimalMinus(prefix);
registerDecimalMultiply(prefix);
registerDecimalDivide(prefix);
registerDecimalModulus(prefix);
}

} // namespace facebook::velox::functions
66 changes: 66 additions & 0 deletions velox/functions/prestosql/tests/DecimalArithmeticTest.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -403,6 +403,72 @@ TEST_F(DecimalArithmeticTest, decimalDivDifferentTypes) {
makeFlatVector<int64_t>({100, 200, -300, 400}, DECIMAL(12, 2))});
}

TEST_F(DecimalArithmeticTest, decimalMod) {
// short % short -> short.
testDecimalExpr<TypeKind::BIGINT>(
makeFlatVector<int64_t>({0, 0}, DECIMAL(2, 1)),
"mod(c0, c1)",
{makeFlatVector<int64_t>({0, 50}, DECIMAL(2, 1)),
makeFlatVector<int64_t>({20, 25}, DECIMAL(2, 1))});
testDecimalExpr<TypeKind::BIGINT>(
makeFlatVector<int64_t>({3, -3, 3, -3}, DECIMAL(2, 1)),
"mod(c0, c1)",
{makeFlatVector<int64_t>({13, -13, 13, -13}, DECIMAL(3, 1)),
makeFlatVector<int64_t>({5, 5, -5, -5}, DECIMAL(2, 1))});
testDecimalExpr<TypeKind::BIGINT>(
makeFlatVector<int64_t>({90, -245, 245, -90}, DECIMAL(3, 2)),
"mod(c0, c1)",
{makeFlatVector<int64_t>({50, -50, 50, -50}, DECIMAL(2, 1)),
makeFlatVector<int64_t>({205, 255, -255, -205}, DECIMAL(3, 2))});
testDecimalExpr<TypeKind::BIGINT>(
makeFlatVector<int64_t>({2500, -12000}, DECIMAL(5, 3)),
"mod(c0, c1)",
{makeFlatVector<int64_t>({2500, -12000}, DECIMAL(5, 3)),
makeFlatVector<int64_t>({600, 5000}, DECIMAL(5, 2))});

// short % long -> short.
testDecimalExpr<TypeKind::BIGINT>(
makeFlatVector<int64_t>({1000, -600, 1000, -600}, DECIMAL(17, 15)),
"mod(c0, c1)",
{makeFlatVector<int64_t>({1000, -600, 1000, -600}, DECIMAL(17, 15)),
makeFlatVector<int128_t>({13, 17, -13, -17}, DECIMAL(20, 10))});

// long % short -> short.
testDecimalExpr<TypeKind::BIGINT>(
makeFlatVector<int64_t>({8, -11, 8, -11}, DECIMAL(17, 15)),
"mod(c0, c1)",
{makeFlatVector<int128_t>({500, -4000, 500, -4000}, DECIMAL(20, 10)),
makeFlatVector<int64_t>({17, 19, -17, -19}, DECIMAL(17, 15))});

// short % long -> long.
testDecimalExpr<TypeKind::HUGEINT>(
makeFlatVector<int128_t>({0, -16, 0, -16}, DECIMAL(25, 10)),
"mod(c0, c1)",
{makeFlatVector<int64_t>({1000, -600, 1000, -600}, DECIMAL(17, 2)),
makeFlatVector<int128_t>({400, 38, -400, -38}, DECIMAL(30, 10))});

// long % short -> long.
testDecimalExpr<TypeKind::HUGEINT>(
makeFlatVector<int128_t>({500, -4000, 500, -4000}, DECIMAL(25, 10)),
"mod(c0, c1)",
{makeFlatVector<int128_t>({500, -4000, 500, -4000}, DECIMAL(30, 10)),
makeFlatVector<int64_t>({1000, 2000, -1000, -2000}, DECIMAL(17, 2))});

// long % long -> long.
testDecimalExpr<TypeKind::HUGEINT>(
makeFlatVector<int128_t>({2500, -12000, 2500, -12000}, DECIMAL(23, 5)),
"mod(c0, c1)",
{makeFlatVector<int128_t>({2500, -12000, 2500, -12000}, DECIMAL(25, 5)),
makeFlatVector<int128_t>({500, 4000, -500, -4000}, DECIMAL(20, 2))});

VELOX_ASSERT_USER_THROW(
testDecimalExpr<TypeKind::BIGINT>(
{},
"c0 % 0.0",
{makeFlatVector<int64_t>({1000, 2000}, DECIMAL(17, 3))}),
"Modulus by zero");
}

TEST_F(DecimalArithmeticTest, round) {
// Round short decimals.
testDecimalExpr<TypeKind::BIGINT>(
Expand Down

0 comments on commit ab06a77

Please sign in to comment.