Skip to content

Commit

Permalink
fix decimal compare wrong in compute expression
Browse files Browse the repository at this point in the history
  • Loading branch information
ZhangHuiGui committed Apr 10, 2024
1 parent 831b94a commit 8ca5f83
Show file tree
Hide file tree
Showing 2 changed files with 59 additions and 2 deletions.
40 changes: 40 additions & 0 deletions cpp/src/arrow/compute/expression_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -845,6 +845,46 @@ void ExpectExecute(Expression expr, Datum in, Datum* actual_out = NULLPTR) {
}
}

TEST(Expression, ExecuteCallWithDecimalComparisonOps) {
// GH-41011, make sure the decimal's comparison operations are casted
// in expression bind and make correct results in expression execute
ExpectExecute(
call("not_equal", {field_ref("d1"), field_ref("d2")}),
ArrayFromJSON(struct_({field("d1", decimal(2, 0)), field("d2", decimal(2, 1))}),
R"([
{"d1": "40", "d2": "4.0"},
{"d1": "20", "d2": "2.0"}
])"));

ExpectExecute(
call("less", {field_ref("d1"), field_ref("d2")}),
ArrayFromJSON(struct_({field("d1", decimal(2, 1)), field("d2", decimal(2, 0))}),
R"([
{"d1": "4.0", "d2": "40"},
{"d1": "2.0", "d2": "20"}
])"));

for (std::string fname : {"less_equal", "equal"}) {
ExpectExecute(
call(fname, {field_ref("d1"), field_ref("d2")}),
ArrayFromJSON(struct_({field("d1", decimal(3, 2)), field("d2", decimal(2, 1))}),
R"([
{"d1": "3.10", "d2": "3.1"},
{"d1": "2.10", "d2": "2.1"}
])"));
}

for (std::string fname : {"greater_equal", "greater"}) {
ExpectExecute(
call(fname, {field_ref("d1"), field_ref("d2")}),
ArrayFromJSON(struct_({field("d1", decimal(2, 0)), field("d2", decimal(2, 1))}),
R"([
{"d1": "4", "d2": "3.0"},
{"d1": "3", "d2": "2.0"}
])"));
}
}

TEST(Expression, ExecuteCall) {
ExpectExecute(add(field_ref("a"), literal(3.5)),
ArrayFromJSON(struct_({field("a", float64())}), R"([
Expand Down
21 changes: 19 additions & 2 deletions cpp/src/arrow/compute/kernels/scalar_compare.cc
Original file line number Diff line number Diff line change
Expand Up @@ -385,6 +385,23 @@ struct VarArgsCompareFunction : ScalarFunction {
}
};

Result<TypeHolder> ResolveDecimalCompareOutputType(KernelContext*,
const std::vector<TypeHolder>& types) {
// casted types should be same size decimals
const auto& left_type = checked_cast<const DecimalType&>(*types[0]);
const auto& right_type = checked_cast<const DecimalType&>(*types[1]);
DCHECK_EQ(left_type.id(), right_type.id());

// check the casted decimal scales according kAdd promotion rule
const int32_t s1 = left_type.scale();
const int32_t s2 = right_type.scale();
if (s1 != s2) {
return Status::Invalid("Comparison of two decimal ", "types s1 != s2. (", s1, s2,
").");
}
return boolean();
}

template <typename Op>
std::shared_ptr<ScalarFunction> MakeCompareFunction(std::string name, FunctionDoc doc) {
auto func = std::make_shared<CompareFunction>(name, Arity::Binary(), std::move(doc));
Expand Down Expand Up @@ -433,9 +450,9 @@ std::shared_ptr<ScalarFunction> MakeCompareFunction(std::string name, FunctionDo
}

for (const auto id : {Type::DECIMAL128, Type::DECIMAL256}) {
OutputType out_type(ResolveDecimalCompareOutputType);
auto exec = GenerateDecimal<applicator::ScalarBinaryEqualTypes, BooleanType, Op>(id);
DCHECK_OK(
func->AddKernel({InputType(id), InputType(id)}, boolean(), std::move(exec)));
DCHECK_OK(func->AddKernel({InputType(id), InputType(id)}, out_type, std::move(exec)));
}

{
Expand Down

0 comments on commit 8ca5f83

Please sign in to comment.