Skip to content

Commit

Permalink
Support float & double types in pmod function (#157)
Browse files Browse the repository at this point in the history
  • Loading branch information
PHILO-HE authored and zhejiangxiaomai committed Mar 29, 2023
1 parent 17d9853 commit 462979e
Show file tree
Hide file tree
Showing 3 changed files with 29 additions and 2 deletions.
16 changes: 15 additions & 1 deletion velox/functions/sparksql/Arithmetic.h
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@
namespace facebook::velox::functions::sparksql {

template <typename T>
struct PModFunction {
struct PModIntFunction {
template <typename TInput>
FOLLY_ALWAYS_INLINE bool call(TInput& result, const TInput a, const TInput n)
#if defined(__has_feature)
Expand All @@ -43,6 +43,20 @@ struct PModFunction {
}
};

template <typename T>
struct PModFloatFunction {
template <typename TInput>
FOLLY_ALWAYS_INLINE bool call(TInput& result, const TInput a, const TInput n)
{
if (UNLIKELY(n == (TInput)0)) {
return false;
}
TInput r = fmod(a, n);
result = (r > 0) ? r : fmod(r + n, n);
return true;
}
};

template <typename T>
struct RemainderFunction {
template <typename TInput>
Expand Down
3 changes: 2 additions & 1 deletion velox/functions/sparksql/RegisterArithmetic.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,8 @@ void registerArithmeticFunctions(const std::string& prefix) {
// Math functions.
registerUnaryNumeric<AbsFunction>({prefix + "abs"});
registerFunction<ExpFunction, double, double>({prefix + "exp"});
registerBinaryIntegral<PModFunction>({prefix + "pmod"});
registerBinaryIntegral<PModIntFunction>({prefix + "pmod"});
registerBinaryFloatingPoint<PModFloatFunction>({prefix + "pmod"});
registerFunction<PowerFunction, double, double, double>({prefix + "power"});
registerUnaryNumeric<RoundFunction>({prefix + "round"});
registerFunction<RoundFunction, int8_t, int8_t, int32_t>({prefix + "round"});
Expand Down
12 changes: 12 additions & 0 deletions velox/functions/sparksql/tests/ArithmeticTest.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,18 @@ TEST_F(PmodTest, int64) {
EXPECT_EQ(INT64_MAX - 1, pmod<int64_t>(INT64_MIN, INT64_MAX));
}

TEST_F(PmodTest, float) {
EXPECT_FLOAT_EQ(0.2, pmod<float>(0.5, 0.3).value());
EXPECT_FLOAT_EQ(0.9, pmod<float>(-1.1, 2).value());
EXPECT_EQ(std::nullopt, pmod<float>(2.14159, 0.0));
}

TEST_F(PmodTest, double) {
EXPECT_DOUBLE_EQ(0.2, pmod<double>(0.5, 0.3).value());
EXPECT_DOUBLE_EQ(0.9, pmod<double>(-1.1, 2).value());
EXPECT_EQ(std::nullopt, pmod<double>(2.14159, 0.0));
}

class RemainderTest : public SparkFunctionBaseTest {
protected:
template <typename T>
Expand Down

0 comments on commit 462979e

Please sign in to comment.