Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[libcxx] Unifying __is_trivial_equality_predicate and __is_trivial_plus_operation into __desugars_to #68642

Merged
merged 14 commits into from
Nov 23, 2023

Conversation

AntonRydahl
Copy link
Contributor

When working on an OpenMP offloading backend for standard parallel algorithms (#66968) we noticed the need of a generalization of __is_trivial_plus_operation. For now, I have converted __is_trivial_equality_predicate and __is_trivial_plus_operation into __desugars_to, and we may then extend the latter to support other binary operations as well.

@AntonRydahl AntonRydahl requested a review from a team as a code owner October 9, 2023 22:40
@llvmbot llvmbot added the libc++ libc++ C++ Standard Library. Not GNU libstdc++. Not libc++abi. label Oct 9, 2023
@llvmbot
Copy link
Collaborator

llvmbot commented Oct 9, 2023

@llvm/pr-subscribers-libcxx

Changes

When working on an OpenMP offloading backend for standard parallel algorithms (#66968) we noticed the need of a generalization of __is_trivial_plus_operation. For now, I have converted __is_trivial_equality_predicate and __is_trivial_plus_operation into __desugars_to, and we may then extend the latter to support other binary operations as well.


Full diff: https://github.com/llvm/llvm-project/pull/68642.diff

10 Files Affected:

  • (modified) libcxx/include/CMakeLists.txt (-1)
  • (modified) libcxx/include/__algorithm/comp.h (+4-3)
  • (modified) libcxx/include/__algorithm/equal.h (+11-11)
  • (modified) libcxx/include/__algorithm/pstl_backends/cpu_backends/transform_reduce.h (+11-12)
  • (modified) libcxx/include/__functional/operations.h (+7-8)
  • (modified) libcxx/include/__functional/ranges_operations.h (+4-3)
  • (modified) libcxx/include/__numeric/pstl_transform_reduce.h (+1-1)
  • (modified) libcxx/include/__type_traits/operation_traits.h (+2-2)
  • (removed) libcxx/include/__type_traits/predicate_traits.h (-26)
  • (modified) libcxx/include/module.modulemap.in (-1)
diff --git a/libcxx/include/CMakeLists.txt b/libcxx/include/CMakeLists.txt
index 340353f8ebb41c4..7eb09a06ccd482e 100644
--- a/libcxx/include/CMakeLists.txt
+++ b/libcxx/include/CMakeLists.txt
@@ -813,7 +813,6 @@ set(files
   __type_traits/negation.h
   __type_traits/noexcept_move_assign_container.h
   __type_traits/operation_traits.h
-  __type_traits/predicate_traits.h
   __type_traits/promote.h
   __type_traits/rank.h
   __type_traits/remove_all_extents.h
diff --git a/libcxx/include/__algorithm/comp.h b/libcxx/include/__algorithm/comp.h
index 9474536615ffb67..0993c37cce36a6b 100644
--- a/libcxx/include/__algorithm/comp.h
+++ b/libcxx/include/__algorithm/comp.h
@@ -10,8 +10,9 @@
 #define _LIBCPP___ALGORITHM_COMP_H
 
 #include <__config>
+#include <__functional/operations.h>
 #include <__type_traits/integral_constant.h>
-#include <__type_traits/predicate_traits.h>
+#include <__type_traits/operation_traits.h>
 
 #if !defined(_LIBCPP_HAS_NO_PRAGMA_SYSTEM_HEADER)
 #  pragma GCC system_header
@@ -26,8 +27,8 @@ struct __equal_to {
   }
 };
 
-template <class _Lhs, class _Rhs>
-struct __is_trivial_equality_predicate<__equal_to, _Lhs, _Rhs> : true_type {};
+template <>
+struct __desugars_to<__equal_to, std::equal_to<>> : true_type {};
 
 // The definition is required because __less is part of the ABI, but it's empty
 // because all comparisons should be transparent.
diff --git a/libcxx/include/__algorithm/equal.h b/libcxx/include/__algorithm/equal.h
index b69aeff92bb9289..35e82da15e4d058 100644
--- a/libcxx/include/__algorithm/equal.h
+++ b/libcxx/include/__algorithm/equal.h
@@ -15,6 +15,7 @@
 #include <__config>
 #include <__functional/identity.h>
 #include <__functional/invoke.h>
+#include <__functional/operations.h>
 #include <__iterator/distance.h>
 #include <__iterator/iterator_traits.h>
 #include <__string/constexpr_c_functions.h>
@@ -23,7 +24,7 @@
 #include <__type_traits/is_constant_evaluated.h>
 #include <__type_traits/is_equality_comparable.h>
 #include <__type_traits/is_volatile.h>
-#include <__type_traits/predicate_traits.h>
+#include <__type_traits/operation_traits.h>
 #include <__utility/move.h>
 
 #if !defined(_LIBCPP_HAS_NO_PRAGMA_SYSTEM_HEADER)
@@ -41,13 +42,12 @@ _LIBCPP_NODISCARD inline _LIBCPP_HIDE_FROM_ABI _LIBCPP_CONSTEXPR_SINCE_CXX20 boo
   return true;
 }
 
-template <
-    class _Tp,
-    class _Up,
-    class _BinaryPredicate,
-    __enable_if_t<__is_trivial_equality_predicate<_BinaryPredicate, _Tp, _Up>::value && !is_volatile<_Tp>::value &&
-                      !is_volatile<_Up>::value && __libcpp_is_trivially_equality_comparable<_Tp, _Up>::value,
-                  int> = 0>
+template < class _Tp,
+           class _Up,
+           class _BinaryPredicate,
+           __enable_if_t<__desugars_to<_BinaryPredicate, std::equal_to<>>::value && !is_volatile<_Tp>::value &&
+                             !is_volatile<_Up>::value && __libcpp_is_trivially_equality_comparable<_Tp, _Up>::value,
+                         int> = 0>
 _LIBCPP_NODISCARD inline _LIBCPP_HIDE_FROM_ABI _LIBCPP_CONSTEXPR_SINCE_CXX20 bool
 __equal_iter_impl(_Tp* __first1, _Tp* __last1, _Up* __first2, _BinaryPredicate&) {
   return std::__constexpr_memcmp_equal(__first1, __first2, __element_count(__last1 - __first1));
@@ -94,12 +94,12 @@ template <class _Tp,
           class _Pred,
           class _Proj1,
           class _Proj2,
-          __enable_if_t<__is_trivial_equality_predicate<_Pred, _Tp, _Up>::value && __is_identity<_Proj1>::value &&
+          __enable_if_t<__desugars_to<_Pred, std::equal_to<>>::value && __is_identity<_Proj1>::value &&
                             __is_identity<_Proj2>::value && !is_volatile<_Tp>::value && !is_volatile<_Up>::value &&
                             __libcpp_is_trivially_equality_comparable<_Tp, _Up>::value,
                         int> = 0>
-_LIBCPP_NODISCARD inline _LIBCPP_HIDE_FROM_ABI _LIBCPP_CONSTEXPR_SINCE_CXX20 bool __equal_impl(
-    _Tp* __first1, _Tp* __last1, _Up* __first2, _Up*, _Pred&, _Proj1&, _Proj2&) {
+_LIBCPP_NODISCARD inline _LIBCPP_HIDE_FROM_ABI _LIBCPP_CONSTEXPR_SINCE_CXX20 bool
+__equal_impl(_Tp* __first1, _Tp* __last1, _Up* __first2, _Up*, _Pred&, _Proj1&, _Proj2&) {
   return std::__constexpr_memcmp_equal(__first1, __first2, __element_count(__last1 - __first1));
 }
 
diff --git a/libcxx/include/__algorithm/pstl_backends/cpu_backends/transform_reduce.h b/libcxx/include/__algorithm/pstl_backends/cpu_backends/transform_reduce.h
index a5ca9c89d1ab23b..3f4bc31632956b3 100644
--- a/libcxx/include/__algorithm/pstl_backends/cpu_backends/transform_reduce.h
+++ b/libcxx/include/__algorithm/pstl_backends/cpu_backends/transform_reduce.h
@@ -11,6 +11,7 @@
 
 #include <__algorithm/pstl_backends/cpu_backends/backend.h>
 #include <__config>
+#include <__functional/operations.h>
 #include <__iterator/concepts.h>
 #include <__iterator/iterator_traits.h>
 #include <__numeric/transform_reduce.h>
@@ -29,12 +30,11 @@
 
 _LIBCPP_BEGIN_NAMESPACE_STD
 
-template <
-    typename _DifferenceType,
-    typename _Tp,
-    typename _BinaryOperation,
-    typename _UnaryOperation,
-    __enable_if_t<__is_trivial_plus_operation<_BinaryOperation, _Tp, _Tp>::value && is_arithmetic_v<_Tp>, int> = 0>
+template < typename _DifferenceType,
+           typename _Tp,
+           typename _BinaryOperation,
+           typename _UnaryOperation,
+           __enable_if_t<__desugars_to<_BinaryOperation, std::plus<>>::value && is_arithmetic_v<_Tp>, int> = 0>
 _LIBCPP_HIDE_FROM_ABI _Tp
 __simd_transform_reduce(_DifferenceType __n, _Tp __init, _BinaryOperation, _UnaryOperation __f) noexcept {
   _PSTL_PRAGMA_SIMD_REDUCTION(+ : __init)
@@ -43,12 +43,11 @@ __simd_transform_reduce(_DifferenceType __n, _Tp __init, _BinaryOperation, _Unar
   return __init;
 }
 
-template <
-    typename _Size,
-    typename _Tp,
-    typename _BinaryOperation,
-    typename _UnaryOperation,
-    __enable_if_t<!(__is_trivial_plus_operation<_BinaryOperation, _Tp, _Tp>::value && is_arithmetic_v<_Tp>), int> = 0>
+template < typename _Size,
+           typename _Tp,
+           typename _BinaryOperation,
+           typename _UnaryOperation,
+           __enable_if_t<!(__desugars_to<_BinaryOperation, std::plus<>>::value && is_arithmetic_v<_Tp>), int> = 0>
 _LIBCPP_HIDE_FROM_ABI _Tp
 __simd_transform_reduce(_Size __n, _Tp __init, _BinaryOperation __binary_op, _UnaryOperation __f) noexcept {
   const _Size __block_size = __lane_size / sizeof(_Tp);
diff --git a/libcxx/include/__functional/operations.h b/libcxx/include/__functional/operations.h
index 6cdb89d6b449bcd..d6c8ff547cfbdcd 100644
--- a/libcxx/include/__functional/operations.h
+++ b/libcxx/include/__functional/operations.h
@@ -15,7 +15,6 @@
 #include <__functional/unary_function.h>
 #include <__type_traits/integral_constant.h>
 #include <__type_traits/operation_traits.h>
-#include <__type_traits/predicate_traits.h>
 #include <__utility/forward.h>
 
 #if !defined(_LIBCPP_HAS_NO_PRAGMA_SYSTEM_HEADER)
@@ -41,12 +40,12 @@ struct _LIBCPP_TEMPLATE_VIS plus
 };
 _LIBCPP_CTAD_SUPPORTED_FOR_TYPE(plus);
 
-template <class _Tp>
-struct __is_trivial_plus_operation<plus<_Tp>, _Tp, _Tp> : true_type {};
+template <class _Pred>
+struct __desugars_to<plus<_Pred>, plus<_Pred>> : true_type {};
 
 #if _LIBCPP_STD_VER >= 14
-template <class _Tp, class _Up>
-struct __is_trivial_plus_operation<plus<>, _Tp, _Up> : true_type {};
+template <>
+struct __desugars_to<plus<>, plus<>> : true_type {};
 #endif
 
 #if _LIBCPP_STD_VER >= 14
@@ -353,11 +352,11 @@ struct _LIBCPP_TEMPLATE_VIS equal_to<void>
 #endif
 
 template <class _Tp>
-struct __is_trivial_equality_predicate<equal_to<_Tp>, _Tp, _Tp> : true_type {};
+struct __desugars_to<equal_to<_Tp>, std::equal_to<_Tp>> : true_type {};
 
 #if _LIBCPP_STD_VER >= 14
-template <class _Tp>
-struct __is_trivial_equality_predicate<equal_to<>, _Tp, _Tp> : true_type {};
+template <>
+struct __desugars_to<equal_to<>, std::equal_to<>> : true_type {};
 #endif
 
 #if _LIBCPP_STD_VER >= 14
diff --git a/libcxx/include/__functional/ranges_operations.h b/libcxx/include/__functional/ranges_operations.h
index c344fc38f98ddd9..22d6f7fef3b3da5 100644
--- a/libcxx/include/__functional/ranges_operations.h
+++ b/libcxx/include/__functional/ranges_operations.h
@@ -13,8 +13,9 @@
 #include <__concepts/equality_comparable.h>
 #include <__concepts/totally_ordered.h>
 #include <__config>
+#include <__functional/operations.h>
 #include <__type_traits/integral_constant.h>
-#include <__type_traits/predicate_traits.h>
+#include <__type_traits/operation_traits.h>
 #include <__utility/forward.h>
 
 #if !defined(_LIBCPP_HAS_NO_PRAGMA_SYSTEM_HEADER)
@@ -95,8 +96,8 @@ struct greater_equal {
 
 } // namespace ranges
 
-template <class _Lhs, class _Rhs>
-struct __is_trivial_equality_predicate<ranges::equal_to, _Lhs, _Rhs> : true_type {};
+template <>
+struct __desugars_to<ranges::equal_to, std::equal_to<>> : true_type {};
 
 #endif // _LIBCPP_STD_VER >= 20
 
diff --git a/libcxx/include/__numeric/pstl_transform_reduce.h b/libcxx/include/__numeric/pstl_transform_reduce.h
index 4127ee21e3045c8..1127726046665c0 100644
--- a/libcxx/include/__numeric/pstl_transform_reduce.h
+++ b/libcxx/include/__numeric/pstl_transform_reduce.h
@@ -84,7 +84,7 @@ _LIBCPP_HIDE_FROM_ABI _Tp transform_reduce(
 }
 
 // This overload doesn't get a customization point because it's trivial to detect (through e.g.
-// __is_trivial_plus_operation) when specializing the more general variant, which should always be preferred
+// __desugars_to) when specializing the more general variant, which should always be preferred
 template <class _ExecutionPolicy,
           class _ForwardIterator1,
           class _ForwardIterator2,
diff --git a/libcxx/include/__type_traits/operation_traits.h b/libcxx/include/__type_traits/operation_traits.h
index 7dda93e9083a404..d03bf6209a6cc77 100644
--- a/libcxx/include/__type_traits/operation_traits.h
+++ b/libcxx/include/__type_traits/operation_traits.h
@@ -18,8 +18,8 @@
 
 _LIBCPP_BEGIN_NAMESPACE_STD
 
-template <class _Pred, class _Lhs, class _Rhs>
-struct __is_trivial_plus_operation : false_type {};
+template <class _Pred, class _Reference>
+struct __desugars_to : false_type {};
 
 _LIBCPP_END_NAMESPACE_STD
 
diff --git a/libcxx/include/__type_traits/predicate_traits.h b/libcxx/include/__type_traits/predicate_traits.h
deleted file mode 100644
index 872608e6ac3be3f..000000000000000
--- a/libcxx/include/__type_traits/predicate_traits.h
+++ /dev/null
@@ -1,26 +0,0 @@
-//===----------------------------------------------------------------------===//
-//
-// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
-// See https://llvm.org/LICENSE.txt for license information.
-// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-//
-//===----------------------------------------------------------------------===//
-
-#ifndef _LIBCPP___TYPE_TRAITS_PREDICATE_TRAITS
-#define _LIBCPP___TYPE_TRAITS_PREDICATE_TRAITS
-
-#include <__config>
-#include <__type_traits/integral_constant.h>
-
-#if !defined(_LIBCPP_HAS_NO_PRAGMA_SYSTEM_HEADER)
-#  pragma GCC system_header
-#endif
-
-_LIBCPP_BEGIN_NAMESPACE_STD
-
-template <class _Pred, class _Lhs, class _Rhs>
-struct __is_trivial_equality_predicate : false_type {};
-
-_LIBCPP_END_NAMESPACE_STD
-
-#endif // _LIBCPP___TYPE_TRAITS_PREDICATE_TRAITS
diff --git a/libcxx/include/module.modulemap.in b/libcxx/include/module.modulemap.in
index 26657b78b8440c6..3c08c2fb52ce65f 100644
--- a/libcxx/include/module.modulemap.in
+++ b/libcxx/include/module.modulemap.in
@@ -2015,7 +2015,6 @@ module std_private_type_traits_nat                                       [system
 module std_private_type_traits_negation                                  [system] { header "__type_traits/negation.h" }
 module std_private_type_traits_noexcept_move_assign_container            [system] { header "__type_traits/noexcept_move_assign_container.h" }
 module std_private_type_traits_operation_traits                          [system] { header "__type_traits/operation_traits.h" }
-module std_private_type_traits_predicate_traits                          [system] { header "__type_traits/predicate_traits.h" }
 module std_private_type_traits_promote                                   [system] { header "__type_traits/promote.h" }
 module std_private_type_traits_rank                                      [system] { header "__type_traits/rank.h" }
 module std_private_type_traits_remove_all_extents                        [system] { header "__type_traits/remove_all_extents.h" }

Copy link
Member

@ldionne ldionne left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I have a few suggestions but I quite like where this is going. @philnik777 what do you think?

libcxx/include/__type_traits/operation_traits.h Outdated Show resolved Hide resolved
libcxx/include/__functional/operations.h Outdated Show resolved Hide resolved
libcxx/include/__functional/operations.h Outdated Show resolved Hide resolved
libcxx/include/__functional/operations.h Outdated Show resolved Hide resolved
libcxx/include/__algorithm/comp.h Outdated Show resolved Hide resolved
@AntonRydahl
Copy link
Contributor Author

@ldionne do you think it looks right at this point? 😄

@philnik777 philnik777 added enhancement Improving things as opposed to bug fixing, e.g. new or missing feature code-cleanup and removed enhancement Improving things as opposed to bug fixing, e.g. new or missing feature labels Oct 27, 2023
@@ -18,8 +18,11 @@

_LIBCPP_BEGIN_NAMESPACE_STD

template <class _Pred, class _Lhs, class _Rhs>
struct __is_trivial_plus_operation : false_type {};
template <class _Operation, class _Canonical>
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think what we want is something like this:

struct __equal_tag {};
struct __plus_tag {};
// etc...

template <class _CanonicalTag, class _Operation, class ..._Args>
struct __desugars_to : false_type {};


// std::equal_to and friends
template <class _Tp> struct __desugars_to<__equal_tag, equal_to<_Tp>, _Tp, _Tp> : true_type {};
template <class _Tp, class _Up> struct __desugars_to<__equal_tag, equal_to<void>, _Tp, _Up> : true_type {};
template <class _Tp, class _Up> struct __desugars_to<__equal_tag, __equal, _Tp, _Up> : true_type {};
template <class _Tp, class _Up> struct __desugars_to<__equal_tag, ranges::equal_to, _Tp, _Up> : true_type {};

// std::plus and friends
etc...

I originally thought that using std::equal_to<> as a "tag" to represent the canonical operation was a good idea, but since it doesn't exist in older standards we end up having to use std::equal_to<void> explicitly, and that really obfuscates the fact that it's meant to be a tag.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hi @ldionne!

I am unsure about what I am supposed to do with the following:

template <class _Tp, class _Up> struct __desugars_to<__equal_tag, __equal, _Tp, _Up> : true_type {};

Do we want to match the function from include/__algorithm/equal.h?

libcxx/include/__type_traits/operation_traits.h Outdated Show resolved Hide resolved
Copy link

github-actions bot commented Nov 6, 2023

✅ With the latest revision this PR passed the C/C++ code formatter.

Copy link
Member

@ldionne ldionne left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

With my latest changes this is good to go.

@ldionne ldionne merged commit aea7929 into llvm:main Nov 23, 2023
36 of 39 checks passed
Guzhu-AMD pushed a commit to GPUOpen-Drivers/llvm-project that referenced this pull request Nov 30, 2023
Local branch amd-gfx afee350 Merged main:381efa496000 into amd-gfx:df4f5070dfea
Remote branch main aea7929 [libc++] Unify __is_trivial_equality_predicate and __is_trivial_plus_operation into __desugars_to (llvm#68642)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
code-cleanup libc++ libc++ C++ Standard Library. Not GNU libstdc++. Not libc++abi.
Projects
None yet
Development

Successfully merging this pull request may close these issues.

4 participants