-
Notifications
You must be signed in to change notification settings - Fork 3.5k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
GH-41011: [C++] Add an output type resolver for decimal types in CompareFunction so can be casted correctly #41012
base: main
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change | ||||||||||||||||||||||
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
|
@@ -109,6 +109,9 @@ struct ARROW_EXPORT TypeMatcher { | |||||||||||||||||||||||
/// \brief Return true if this matcher accepts the data type. | ||||||||||||||||||||||||
virtual bool Matches(const DataType& type) const = 0; | ||||||||||||||||||||||||
|
||||||||||||||||||||||||
/// \brief Return true if this matcher accepts the combination types | ||||||||||||||||||||||||
virtual bool Matches(const std::vector<TypeHolder>& types) const { return true; } | ||||||||||||||||||||||||
|
||||||||||||||||||||||||
Comment on lines
+112
to
+114
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This class is for matching a single matcher ( I don't think it's reasonable to extend the type matching this way as it makes type checking Turing-complete (wouldn't be easy/possible to encode the typing rules in a provably-decidable type system). I'm sorry for this back and forth because I was the one that said "type matching machinery". I still think this logic shouldn't go on the output type resolver so I have a more precise suggestion this time. Since functions already have a way of implementing custom type matching rules -- arrow/cpp/src/arrow/compute/kernels/scalar_compare.cc Lines 343 to 345 in 9ee6ea7
and return There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Ah, the root cause of the problem that this PR is trying to solve is actually CompareFunction with two different decimal scales in BindNonRecursive , it is impossible to enter the logic of arrow/cpp/src/arrow/compute/expression.cc Lines 556 to 563 in 9ee6ea7
Which means the This is why we did type judgment in the output type Resolver before, so that we can return not-ok in the first
|
||||||||||||||||||||||||
/// \brief A human-interpretable string representation of what the type | ||||||||||||||||||||||||
/// matcher checks for, usable when printing KernelSignature or formatting | ||||||||||||||||||||||||
/// error messages. | ||||||||||||||||||||||||
|
@@ -241,6 +244,10 @@ class ARROW_EXPORT InputType { | |||||||||||||||||||||||
/// \brief Return true if the type matches this InputType | ||||||||||||||||||||||||
bool Matches(const DataType& type) const; | ||||||||||||||||||||||||
|
||||||||||||||||||||||||
/// \brief Return true if the input combination types matches this | ||||||||||||||||||||||||
/// InputType's type_matcher matched rules. | ||||||||||||||||||||||||
bool Matches(const std::vector<TypeHolder>& types) const; | ||||||||||||||||||||||||
|
||||||||||||||||||||||||
/// \brief The type matching rule that this InputType uses. | ||||||||||||||||||||||||
Kind kind() const { return kind_; } | ||||||||||||||||||||||||
|
||||||||||||||||||||||||
|
Original file line number | Diff line number | Diff line change | ||||||||||||||||||||||||||||||
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
|
@@ -385,6 +385,54 @@ struct VarArgsCompareFunction : ScalarFunction { | |||||||||||||||||||||||||||||||
} | ||||||||||||||||||||||||||||||||
}; | ||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||
class DecimalTypesCompareMatcher : public TypeMatcher { | ||||||||||||||||||||||||||||||||
public: | ||||||||||||||||||||||||||||||||
explicit DecimalTypesCompareMatcher(std::shared_ptr<TypeMatcher> decimal_type_matcher) | ||||||||||||||||||||||||||||||||
: decimal_type_matcher(std::move(decimal_type_matcher)) {} | ||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||
bool Matches(const DataType& type) const override { | ||||||||||||||||||||||||||||||||
return decimal_type_matcher->Matches(type); | ||||||||||||||||||||||||||||||||
} | ||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||
bool Matches(const std::vector<TypeHolder>& types) const override { | ||||||||||||||||||||||||||||||||
DCHECK_EQ(types.size(), 2); | ||||||||||||||||||||||||||||||||
if (!is_decimal(*types[0]) || !is_decimal(*types[1])) { | ||||||||||||||||||||||||||||||||
return true; | ||||||||||||||||||||||||||||||||
} | ||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||
// Below match logic should only be executed when types are both decimal | ||||||||||||||||||||||||||||||||
// | ||||||||||||||||||||||||||||||||
const auto& left_type = checked_cast<const DecimalType&>(*types[0]); | ||||||||||||||||||||||||||||||||
const auto& right_type = checked_cast<const DecimalType&>(*types[1]); | ||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||
// check the decimal types' scales according kAdd promotion rule | ||||||||||||||||||||||||||||||||
const int32_t s1 = left_type.scale(); | ||||||||||||||||||||||||||||||||
const int32_t s2 = right_type.scale(); | ||||||||||||||||||||||||||||||||
Comment on lines
+409
to
+410
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. So same scale with different precision can be matched here? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Yes, arrow/cpp/src/arrow/compute/kernels/scalar_arithmetic.cc Lines 509 to 523 in 2dbc5e2
|
||||||||||||||||||||||||||||||||
if (s1 != s2) { | ||||||||||||||||||||||||||||||||
return false; | ||||||||||||||||||||||||||||||||
} | ||||||||||||||||||||||||||||||||
return true; | ||||||||||||||||||||||||||||||||
} | ||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||
bool Equals(const TypeMatcher& other) const override { | ||||||||||||||||||||||||||||||||
if (this == &other) { | ||||||||||||||||||||||||||||||||
return true; | ||||||||||||||||||||||||||||||||
} | ||||||||||||||||||||||||||||||||
const auto* casted = dynamic_cast<const DecimalTypesCompareMatcher*>(&other); | ||||||||||||||||||||||||||||||||
return casted != nullptr && | ||||||||||||||||||||||||||||||||
decimal_type_matcher->Equals(*casted->decimal_type_matcher); | ||||||||||||||||||||||||||||||||
} | ||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||
std::string ToString() const override { return "decimal-types-matcher"; } | ||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||
private: | ||||||||||||||||||||||||||||||||
std::shared_ptr<TypeMatcher> decimal_type_matcher; | ||||||||||||||||||||||||||||||||
}; | ||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||
std::shared_ptr<TypeMatcher> DecimalTypesMatcher(Type::type type_id) { | ||||||||||||||||||||||||||||||||
return std::make_shared<DecimalTypesCompareMatcher>(match::SameTypeId(type_id)); | ||||||||||||||||||||||||||||||||
} | ||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. It's an interesting approach to use the output-type resolver to assert on the input types, but what you really doing here is asserting on the inputs. I think you should write a custom matcher instead. Compare always returns That would also help (in the future) dispatching to kernels that can deal with different scales dynamically. Values can still be logically equal even though they are represented physically in memory with different scales. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I've read the issue and it seems that the problem here is that we need a special cast when the scales don't match because the generic cast can't currently handle that. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Or maybe we want users to be explicit with casts before comparing. I still think this should be handled by the input matching machinery. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Yes, this line of code
The current matcher system only works on matching of builtin types, and what the compare kernel function needs is whether the dependencies between builtin types meet the requirements. This requires adding a That means the previous semantics of But in fact, it is indeed more reasonable to use matcher to determine whether the comparison operation of decimal types is legal. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Besides, there is a similar case in There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
As long as it checks only the types and the code is simple enough that we can reason about the set of combinations handled. If some combinations are invalid, but you want to generate a message, I think you could have the type matcher route bad combinations to these fail kernels that produce an error instead of automatically casting inputs. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. That way, the output type of "compare" remains boolean as it should be. No extra logic needed for output type resolution, but the function is not total — for some input types, it produces a bad There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Thanks, it's more concise to use |
||||||||||||||||||||||||||||||||
template <typename Op> | ||||||||||||||||||||||||||||||||
std::shared_ptr<ScalarFunction> MakeCompareFunction(std::string name, FunctionDoc doc) { | ||||||||||||||||||||||||||||||||
auto func = std::make_shared<CompareFunction>(name, Arity::Binary(), std::move(doc)); | ||||||||||||||||||||||||||||||||
|
@@ -433,9 +481,9 @@ std::shared_ptr<ScalarFunction> MakeCompareFunction(std::string name, FunctionDo | |||||||||||||||||||||||||||||||
} | ||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||
for (const auto id : {Type::DECIMAL128, Type::DECIMAL256}) { | ||||||||||||||||||||||||||||||||
InputType in_type(DecimalTypesMatcher(id)); | ||||||||||||||||||||||||||||||||
auto exec = GenerateDecimal<applicator::ScalarBinaryEqualTypes, BooleanType, Op>(id); | ||||||||||||||||||||||||||||||||
DCHECK_OK( | ||||||||||||||||||||||||||||||||
func->AddKernel({InputType(id), InputType(id)}, boolean(), std::move(exec))); | ||||||||||||||||||||||||||||||||
DCHECK_OK(func->AddKernel({in_type, in_type}, boolean(), std::move(exec))); | ||||||||||||||||||||||||||||||||
} | ||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||
{ | ||||||||||||||||||||||||||||||||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Add type matcher acceptable rule for the combination types.