GH-41011: [C++] Add an output type resolver for decimal types in CompareFunction so can be casted correctly #41012

Open
wants to merge 2 commits into
base: main
40 changes: 40 additions & 0 deletions cpp/src/arrow/compute/expression_test.cc
@@ -845,6 +845,46 @@ void ExpectExecute(Expression expr, Datum in, Datum* actual_out = NULLPTR) {
}
}

TEST(Expression, ExecuteCallWithDecimalComparisonOps) {
// GH-41011: make sure decimal comparison operands are cast during expression
// bind, so that expression execution produces correct results
ExpectExecute(
call("not_equal", {field_ref("d1"), field_ref("d2")}),
ArrayFromJSON(struct_({field("d1", decimal(2, 0)), field("d2", decimal(2, 1))}),
R"([
{"d1": "40", "d2": "4.0"},
{"d1": "20", "d2": "2.0"}
])"));

ExpectExecute(
call("less", {field_ref("d1"), field_ref("d2")}),
ArrayFromJSON(struct_({field("d1", decimal(2, 1)), field("d2", decimal(2, 0))}),
R"([
{"d1": "4.0", "d2": "40"},
{"d1": "2.0", "d2": "20"}
])"));

for (std::string fname : {"less_equal", "equal"}) {
ExpectExecute(
call(fname, {field_ref("d1"), field_ref("d2")}),
ArrayFromJSON(struct_({field("d1", decimal(3, 2)), field("d2", decimal(2, 1))}),
R"([
{"d1": "3.10", "d2": "3.1"},
{"d1": "2.10", "d2": "2.1"}
])"));
}

for (std::string fname : {"greater_equal", "greater"}) {
ExpectExecute(
call(fname, {field_ref("d1"), field_ref("d2")}),
ArrayFromJSON(struct_({field("d1", decimal(2, 0)), field("d2", decimal(2, 1))}),
R"([
{"d1": "4", "d2": "3.0"},
{"d1": "3", "d2": "2.0"}
])"));
}
}
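
To see why these cases need a cast at bind time, it may help to recall that a decimal value is stored as an unscaled integer plus a scale. The sketch below is a simplified stand-in, not Arrow code: the `Dec` struct and the helper names are hypothetical, and it uses `int64_t` rather than 128-bit or 256-bit storage. It shows that comparing the raw integers of decimal(2, 0) "40" and decimal(2, 1) "4.0" wrongly reports equality, while rescaling both to the larger scale first gives the expected answer.

```cpp
#include <cstdint>

// Hypothetical model of a decimal value: value = unscaled / 10^scale.
struct Dec {
  int64_t unscaled;
  int32_t scale;
};

// 10^n, with no overflow handling (fine for tiny illustrative values).
constexpr int64_t Pow10(int32_t n) {
  int64_t r = 1;
  for (int32_t i = 0; i < n; ++i) r *= 10;
  return r;
}

// Naive comparison that ignores the scale, as a kernel that assumes
// equally typed inputs might do if handed mismatched scales.
constexpr bool RawEqual(Dec a, Dec b) { return a.unscaled == b.unscaled; }

// Rescale both operands to max(a.scale, b.scale) before comparing.
constexpr bool RescaledEqual(Dec a, Dec b) {
  const int32_t s = a.scale > b.scale ? a.scale : b.scale;
  return a.unscaled * Pow10(s - a.scale) == b.unscaled * Pow10(s - b.scale);
}

constexpr Dec kForty{40, 0};          // decimal(2, 0) "40"
constexpr Dec kFourPointZero{40, 1};  // decimal(2, 1) "4.0"
```

Under this model, `RawEqual(kForty, kFourPointZero)` is true even though the logical values differ, which is the kind of wrong result the test above guards against.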

TEST(Expression, ExecuteCall) {
ExpectExecute(add(field_ref("a"), literal(3.5)),
ArrayFromJSON(struct_({field("a", float64())}), R"([
14 changes: 12 additions & 2 deletions cpp/src/arrow/compute/kernel.cc
@@ -430,6 +430,11 @@ bool InputType::Matches(const Datum& value) const {
return Matches(*value.type());
}

bool InputType::Matches(const std::vector<TypeHolder>& types) const {
DCHECK_EQ(InputType::USE_TYPE_MATCHER, kind_);
return type_matcher_->Matches(types);
}

const std::shared_ptr<DataType>& InputType::type() const {
DCHECK_EQ(InputType::EXACT_TYPE, kind_);
return type_;
@@ -505,9 +510,14 @@ bool KernelSignature::Equals(const KernelSignature& other) const {
}

bool KernelSignature::MatchesInputs(const std::vector<TypeHolder>& types) const {
auto is_match_combination_types = [&](const InputType& in_type) {
Collaborator Author

@ZhangHuiGui ZhangHuiGui May 17, 2024

Add a type-matcher acceptance rule for combinations of input types.

return in_type.kind() == InputType::USE_TYPE_MATCHER ? in_type.Matches(types) : true;
};

if (is_varargs_) {
for (size_t i = 0; i < types.size(); ++i) {
- if (!in_types_[std::min(i, in_types_.size() - 1)].Matches(*types[i])) {
+ const auto& in_type = in_types_[std::min(i, in_types_.size() - 1)];
+ if (!in_type.Matches(*types[i]) || !is_match_combination_types(in_type)) {
return false;
}
}
@@ -516,7 +526,7 @@ bool KernelSignature::MatchesInputs(const std::vector<TypeHolder>& types) const
return false;
}
for (size_t i = 0; i < in_types_.size(); ++i) {
- if (!in_types_[i].Matches(*types[i])) {
+ if (!in_types_[i].Matches(*types[i]) || !is_match_combination_types(in_types_[i])) {
return false;
}
}
7 changes: 7 additions & 0 deletions cpp/src/arrow/compute/kernel.h
@@ -109,6 +109,9 @@ struct ARROW_EXPORT TypeMatcher {
/// \brief Return true if this matcher accepts the data type.
virtual bool Matches(const DataType& type) const = 0;

/// \brief Return true if this matcher accepts the combination types
virtual bool Matches(const std::vector<TypeHolder>& types) const { return true; }

Comment on lines +112 to +114
Contributor

This class is for matching a single matcher (this) against a single type.

I don't think it's reasonable to extend the type matching this way as it makes type checking Turing-complete (wouldn't be easy/possible to encode the typing rules in a provably-decidable type system).

I'm sorry for the back and forth, since I was the one who said "type matching machinery". I still think this logic shouldn't go in the output type resolver, so I have a more precise suggestion this time. Since functions already have a way of implementing custom type-matching rules (DispatchBest), you can add this logic in

if (HasDecimal(*types)) {
RETURN_NOT_OK(CastBinaryDecimalArgs(DecimalPromotion::kAdd, types));
}

and return arrow::compute::detail::NoMatchingKernel(this, *types), or a Status::NotImplemented with a message explaining that you can't compare two decimals with different scales.

Collaborator Author

@ZhangHuiGui ZhangHuiGui Jun 6, 2024

Ah, the root cause of the problem this PR is trying to solve is that, when BindNonRecursive binds a CompareFunction call on two decimals with different scales, it can never reach the DispatchBest logic.

// First try and bind exactly
Result<const Kernel*> maybe_exact_match = call.function->DispatchExact(types);
if (maybe_exact_match.ok()) {
call.kernel = *maybe_exact_match;
if (FinishBind().ok()) {
return Expression(std::move(call));
}
}

This means DispatchExact always returns ok, so binding never reaches DispatchBest when the expression system calls CompareFunction with two different decimal scales.

This is why we previously checked the input types in the output type resolver: the first FinishBind then returns not-ok, and binding falls through to DispatchBest.

DispatchBest needs no additional checks, because it casts according to the decimal rules, bringing different input scales to the same scale under DecimalPromotion::kAdd.
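
The control flow described in this comment can be modeled in a few lines. This is only a sketch with hypothetical stand-in names (`DecType`, `ExactMatches`, `Bind`), not Arrow's binding code, and it folds DispatchExact and FinishBind into a single step: when the exact match looks only at the type id, two decimals with different scales are accepted and the casting path is never taken; once the exact match also requires equal scales, the mismatch is rejected and binding falls back to the casting path.

```cpp
#include <cstdint>

// Hypothetical stand-ins; none of these are Arrow types.
struct DecType {
  int32_t precision;
  int32_t scale;
};

enum class Path { kExact, kBest };

// Exact dispatch modeled two ways: matching on the type id alone (both
// inputs are decimals, so it always succeeds), or also requiring that the
// two scales agree.
constexpr bool ExactMatches(DecType l, DecType r, bool check_scales) {
  return !check_scales || l.scale == r.scale;
}

// Mirrors the shape of the bind logic quoted above: try the exact match
// first and fall back to the casting path only when it is rejected.
constexpr Path Bind(DecType l, DecType r, bool check_scales) {
  return ExactMatches(l, r, check_scales) ? Path::kExact : Path::kBest;
}
```

With `check_scales` set to false, `Bind(DecType{2, 0}, DecType{2, 1}, false)` takes the exact path, which corresponds to the behavior reported in GH-41011; with it set to true, the same inputs take the fallback path.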

/// \brief A human-interpretable string representation of what the type
/// matcher checks for, usable when printing KernelSignature or formatting
/// error messages.
@@ -241,6 +244,10 @@ class ARROW_EXPORT InputType {
/// \brief Return true if the type matches this InputType
bool Matches(const DataType& type) const;

/// \brief Return true if the combination of input types satisfies the rules
/// of this InputType's type matcher.
bool Matches(const std::vector<TypeHolder>& types) const;

/// \brief The type matching rule that this InputType uses.
Kind kind() const { return kind_; }

52 changes: 50 additions & 2 deletions cpp/src/arrow/compute/kernels/scalar_compare.cc
@@ -385,6 +385,54 @@ struct VarArgsCompareFunction : ScalarFunction {
}
};

class DecimalTypesCompareMatcher : public TypeMatcher {
public:
explicit DecimalTypesCompareMatcher(std::shared_ptr<TypeMatcher> decimal_type_matcher)
: decimal_type_matcher(std::move(decimal_type_matcher)) {}

bool Matches(const DataType& type) const override {
return decimal_type_matcher->Matches(type);
}

bool Matches(const std::vector<TypeHolder>& types) const override {
DCHECK_EQ(types.size(), 2);
if (!is_decimal(*types[0]) || !is_decimal(*types[1])) {
return true;
}

// Below match logic should only be executed when types are both decimal
//
const auto& left_type = checked_cast<const DecimalType&>(*types[0]);
const auto& right_type = checked_cast<const DecimalType&>(*types[1]);

// check the decimal types' scales according kAdd promotion rule
const int32_t s1 = left_type.scale();
const int32_t s2 = right_type.scale();
Comment on lines +409 to +410
Member

So the same scale with different precisions can be matched here?

Collaborator Author

@ZhangHuiGui ZhangHuiGui May 17, 2024

Yes, the kAdd promotion rule only cares about the scale, and the implicit cast under this rule only makes the scales equal.

Result<TypeHolder> ResolveDecimalAdditionOrSubtractionOutput(
KernelContext*, const std::vector<TypeHolder>& types) {
return ResolveDecimalBinaryOperationOutput(
types,
[](int32_t p1, int32_t s1, int32_t p2,
int32_t s2) -> Result<std::pair<int32_t, int32_t>> {
if (s1 != s2) {
return Status::Invalid("Addition or subtraction of two decimal ",
"types scale1 != scale2. (", s1, s2, ").");
}
DCHECK_EQ(s1, s2);
const int32_t scale = s1;
const int32_t precision = std::max(p1 - s1, p2 - s2) + scale + 1;
return std::make_pair(precision, scale);
});
}
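
A small worked example may make the answer concrete. The sketch below assumes, as a simplification, that the kAdd rule scales each operand up to the larger of the two scales and widens its precision by the same number of digits; `DecType`, `ScaleUpTo`, and `AddOutputPrecision` are hypothetical names, and only the precision formula is taken from the snippet quoted above. For decimal(2, 0) and decimal(2, 1), the operands become decimal(3, 1) and decimal(2, 1): the scales now agree while the precisions still differ.

```cpp
#include <cstdint>

// Hypothetical stand-in for a decimal type's parameters.
struct DecType {
  int32_t precision;
  int32_t scale;
};

constexpr int32_t Max(int32_t a, int32_t b) { return a > b ? a : b; }

// Assumed model of the implicit cast: raise the scale to `target_scale` and
// grow the precision by the same number of digits so no value is lost.
constexpr DecType ScaleUpTo(DecType t, int32_t target_scale) {
  return DecType{t.precision + (target_scale - t.scale), target_scale};
}

// Precision formula from the quoted resolver; requires a.scale == b.scale.
constexpr int32_t AddOutputPrecision(DecType a, DecType b) {
  return Max(a.precision - a.scale, b.precision - b.scale) + a.scale + 1;
}

constexpr int32_t kCommonScale = Max(0, 1);
constexpr DecType kLeft = ScaleUpTo(DecType{2, 0}, kCommonScale);   // decimal(3, 1)
constexpr DecType kRight = ScaleUpTo(DecType{2, 1}, kCommonScale);  // decimal(2, 1)
```

Under these assumptions `AddOutputPrecision(kLeft, kRight)` evaluates to 4, but for a comparison only the equal scales matter, since the result type is always boolean.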

if (s1 != s2) {
return false;
}
return true;
}

bool Equals(const TypeMatcher& other) const override {
if (this == &other) {
return true;
}
const auto* casted = dynamic_cast<const DecimalTypesCompareMatcher*>(&other);
return casted != nullptr &&
decimal_type_matcher->Equals(*casted->decimal_type_matcher);
}

std::string ToString() const override { return "decimal-types-matcher"; }

private:
std::shared_ptr<TypeMatcher> decimal_type_matcher;
};

std::shared_ptr<TypeMatcher> DecimalTypesMatcher(Type::type type_id) {
return std::make_shared<DecimalTypesCompareMatcher>(match::SameTypeId(type_id));
}

Contributor

It's an interesting approach to use the output-type resolver to assert on the input types, but what you're really doing here is asserting on the inputs. I think you should write a custom matcher instead. Compare always returns boolean(), so there is nothing to resolve.

That would also help (in the future) dispatching to kernels that can deal with different scales dynamically. Values can still be logically equal even though they are represented physically in memory with different scales.

Contributor

@felipecrv felipecrv May 10, 2024

I've read the issue and it seems that the problem here is that we need a special cast when the scales don't match because the generic cast can't currently handle that.

Contributor

Or maybe we want users to be explicit with casts before comparing. I still think this should be handled by the input matching machinery.

Collaborator Author

"but what you're really doing here is asserting on the inputs."

Yes, the line DCHECK_EQ(left_type.id(), right_type.id()); in ResolveDecimalCompareOutputType is problematic, because we might be comparing float and decimal, or decimal128 and decimal256...

"Or maybe we want users to be explicit with casts before comparing. I still think this should be handled by the input matching machinery."

The current matcher system only matches individual builtin types, whereas the compare kernel function needs to check whether the relationship between its builtin input types meets the requirements. This requires adding a bool Matches(const std::vector<TypeHolder>& types) const; interface to TypeMatcher.

That means the previous semantics of TypeMatcher was to check the validity of a single builtin type; now it would also have to check the validity of the combinations that arise when functions use builtin types. Is this reasonable for the design of TypeMatcher?

That said, it is indeed more reasonable to use a matcher to determine whether a comparison of decimal types is legal.

Collaborator Author

Besides, there is a similar case in the IfElse-related kernel functions: #41363.
Is it reasonable to check the input types in the output resolver in the IfElse case?

Contributor

"Is this reasonable for the design of TypeMatcher?"

As long as it checks only the types and the code is simple enough that we can reason about the set of combinations handled.

If some combinations are invalid, but you want to generate a message, I think you could have the type matcher route bad combinations to these fail kernels that produce an error instead of automatically casting inputs.

Contributor

That way, the output type of "compare" remains boolean as it should be. No extra logic needed for output type resolution, but the function is not total — for some input types, it produces a bad Status.
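
The "fail kernel" idea from the last two comments could look roughly like the sketch below. Everything in it is a hypothetical model rather than Arrow's kernel API: `SelectKernel` plays the role of the type matcher, `LessKernel` is an ordinary comparison on unscaled values, and `FailKernel` is registered for combinations whose scales differ. The result type is always a boolean plus a status code, so no output-type resolution is needed, but the function is not total.

```cpp
#include <cstdint>

// Hypothetical stand-ins; not Arrow types.
struct DecType {
  int32_t precision;
  int32_t scale;
};

enum class Code { kOk, kNotImplemented };

// The output is always a boolean, paired with a status code.
struct BoolResult {
  Code code;
  bool value;
};

using Kernel = BoolResult (*)(int64_t, int64_t);

// Regular kernel: valid only when both unscaled values share one scale.
constexpr BoolResult LessKernel(int64_t l, int64_t r) {
  return {Code::kOk, l < r};
}

// Fail kernel: reports an error instead of silently casting or comparing.
constexpr BoolResult FailKernel(int64_t, int64_t) {
  return {Code::kNotImplemented, false};
}

// Plays the role of the matcher: inspects types only, never values.
constexpr Kernel SelectKernel(DecType l, DecType r) {
  return l.scale == r.scale ? &LessKernel : &FailKernel;
}
```

For example, `SelectKernel(DecType{2, 0}, DecType{2, 1})(40, 40)` yields `Code::kNotImplemented`, while inputs with equal scales are compared normally.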

Collaborator Author

"As long as it checks only the types and the code is simple enough that we can reason about the set of combinations handled."

Thanks, it's more concise to use TypeMatcher to check the input types.

template <typename Op>
std::shared_ptr<ScalarFunction> MakeCompareFunction(std::string name, FunctionDoc doc) {
auto func = std::make_shared<CompareFunction>(name, Arity::Binary(), std::move(doc));
@@ -433,9 +481,9 @@ std::shared_ptr<ScalarFunction> MakeCompareFunction(std::string name, FunctionDo
}

for (const auto id : {Type::DECIMAL128, Type::DECIMAL256}) {
InputType in_type(DecimalTypesMatcher(id));
auto exec = GenerateDecimal<applicator::ScalarBinaryEqualTypes, BooleanType, Op>(id);
- DCHECK_OK(
- func->AddKernel({InputType(id), InputType(id)}, boolean(), std::move(exec)));
+ DCHECK_OK(func->AddKernel({in_type, in_type}, boolean(), std::move(exec)));
}

{