From aea7929b0a04c5ea0ac85aba2b85fb58c718626f Mon Sep 17 00:00:00 2001 From: Anton Rydahl <44206479+AntonRydahl@users.noreply.github.com> Date: Thu, 23 Nov 2023 10:55:55 -0800 Subject: [PATCH] [libc++] Unify __is_trivial_equality_predicate and __is_trivial_plus_operation into __desugars_to (#68642) When working on an OpenMP offloading backend for standard parallel algorithms (https://github.com/llvm/llvm-project/pull/66968) we noticed the need of a generalization of `__is_trivial_plus_operation`. This patch merges `__is_trivial_equality_predicate` and `__is_trivial_plus_operation` into `__desugars_to`, and in the future we might extend the latter to support other binary operations as well. Co-authored-by: Louis Dionne --- libcxx/include/CMakeLists.txt | 1 - libcxx/include/__algorithm/comp.h | 6 ++-- libcxx/include/__algorithm/equal.h | 21 +++++++------- .../cpu_backends/transform_reduce.h | 28 +++++++++++-------- libcxx/include/__functional/operations.h | 20 ++++++------- .../include/__functional/ranges_operations.h | 8 ++++-- .../include/__numeric/pstl_transform_reduce.h | 2 +- .../include/__type_traits/operation_traits.h | 18 ++++++++++-- .../include/__type_traits/predicate_traits.h | 26 ----------------- libcxx/include/module.modulemap.in | 1 - 10 files changed, 61 insertions(+), 70 deletions(-) delete mode 100644 libcxx/include/__type_traits/predicate_traits.h diff --git a/libcxx/include/CMakeLists.txt b/libcxx/include/CMakeLists.txt index 889d7fedbf2965..aef54ea25fd52f 100644 --- a/libcxx/include/CMakeLists.txt +++ b/libcxx/include/CMakeLists.txt @@ -816,7 +816,6 @@ set(files __type_traits/negation.h __type_traits/noexcept_move_assign_container.h __type_traits/operation_traits.h - __type_traits/predicate_traits.h __type_traits/promote.h __type_traits/rank.h __type_traits/remove_all_extents.h diff --git a/libcxx/include/__algorithm/comp.h b/libcxx/include/__algorithm/comp.h index 9474536615ffb6..3902f7560304a1 100644 --- a/libcxx/include/__algorithm/comp.h +++ b/libcxx/include/__algorithm/comp.h @@ -11,7 +11,7 @@ #include <__config> #include <__type_traits/integral_constant.h> -#include <__type_traits/predicate_traits.h> +#include <__type_traits/operation_traits.h> #if !defined(_LIBCPP_HAS_NO_PRAGMA_SYSTEM_HEADER) # pragma GCC system_header @@ -26,8 +26,8 @@ struct __equal_to { } }; -template -struct __is_trivial_equality_predicate<__equal_to, _Lhs, _Rhs> : true_type {}; +template +struct __desugars_to<__equal_tag, __equal_to, _Tp, _Up> : true_type {}; // The definition is required because __less is part of the ABI, but it's empty // because all comparisons should be transparent. diff --git a/libcxx/include/__algorithm/equal.h b/libcxx/include/__algorithm/equal.h index b69aeff92bb928..ca2e49ca5679a4 100644 --- a/libcxx/include/__algorithm/equal.h +++ b/libcxx/include/__algorithm/equal.h @@ -23,7 +23,7 @@ #include <__type_traits/is_constant_evaluated.h> #include <__type_traits/is_equality_comparable.h> #include <__type_traits/is_volatile.h> -#include <__type_traits/predicate_traits.h> +#include <__type_traits/operation_traits.h> #include <__utility/move.h> #if !defined(_LIBCPP_HAS_NO_PRAGMA_SYSTEM_HEADER) @@ -41,13 +41,12 @@ _LIBCPP_NODISCARD inline _LIBCPP_HIDE_FROM_ABI _LIBCPP_CONSTEXPR_SINCE_CXX20 boo return true; } -template < - class _Tp, - class _Up, - class _BinaryPredicate, - __enable_if_t<__is_trivial_equality_predicate<_BinaryPredicate, _Tp, _Up>::value && !is_volatile<_Tp>::value && - !is_volatile<_Up>::value && __libcpp_is_trivially_equality_comparable<_Tp, _Up>::value, - int> = 0> +template ::value && !is_volatile<_Tp>::value && + !is_volatile<_Up>::value && __libcpp_is_trivially_equality_comparable<_Tp, _Up>::value, + int> = 0> _LIBCPP_NODISCARD inline _LIBCPP_HIDE_FROM_ABI _LIBCPP_CONSTEXPR_SINCE_CXX20 bool __equal_iter_impl(_Tp* __first1, _Tp* __last1, _Up* __first2, _BinaryPredicate&) { return std::__constexpr_memcmp_equal(__first1, __first2, __element_count(__last1 - __first1)); @@ -94,12 +93,12 @@ template ::value && __is_identity<_Proj1>::value && + __enable_if_t<__desugars_to<__equal_tag, _Pred, _Tp, _Up>::value && __is_identity<_Proj1>::value && __is_identity<_Proj2>::value && !is_volatile<_Tp>::value && !is_volatile<_Up>::value && __libcpp_is_trivially_equality_comparable<_Tp, _Up>::value, int> = 0> -_LIBCPP_NODISCARD inline _LIBCPP_HIDE_FROM_ABI _LIBCPP_CONSTEXPR_SINCE_CXX20 bool __equal_impl( - _Tp* __first1, _Tp* __last1, _Up* __first2, _Up*, _Pred&, _Proj1&, _Proj2&) { +_LIBCPP_NODISCARD inline _LIBCPP_HIDE_FROM_ABI _LIBCPP_CONSTEXPR_SINCE_CXX20 bool +__equal_impl(_Tp* __first1, _Tp* __last1, _Up* __first2, _Up*, _Pred&, _Proj1&, _Proj2&) { return std::__constexpr_memcmp_equal(__first1, __first2, __element_count(__last1 - __first1)); } diff --git a/libcxx/include/__algorithm/pstl_backends/cpu_backends/transform_reduce.h b/libcxx/include/__algorithm/pstl_backends/cpu_backends/transform_reduce.h index a5ca9c89d1ab23..ab2e3172b8b63b 100644 --- a/libcxx/include/__algorithm/pstl_backends/cpu_backends/transform_reduce.h +++ b/libcxx/include/__algorithm/pstl_backends/cpu_backends/transform_reduce.h @@ -29,12 +29,14 @@ _LIBCPP_BEGIN_NAMESPACE_STD -template < - typename _DifferenceType, - typename _Tp, - typename _BinaryOperation, - typename _UnaryOperation, - __enable_if_t<__is_trivial_plus_operation<_BinaryOperation, _Tp, _Tp>::value && is_arithmetic_v<_Tp>, int> = 0> +template , + __enable_if_t<__desugars_to<__plus_tag, _BinaryOperation, _Tp, _UnaryResult>::value && is_arithmetic_v<_Tp> && + is_arithmetic_v<_UnaryResult>, + int> = 0> _LIBCPP_HIDE_FROM_ABI _Tp __simd_transform_reduce(_DifferenceType __n, _Tp __init, _BinaryOperation, _UnaryOperation __f) noexcept { _PSTL_PRAGMA_SIMD_REDUCTION(+ : __init) @@ -43,12 +45,14 @@ __simd_transform_reduce(_DifferenceType __n, _Tp __init, _BinaryOperation, _Unar return __init; } -template < - typename _Size, - typename _Tp, - typename _BinaryOperation, - typename _UnaryOperation, - __enable_if_t::value && is_arithmetic_v<_Tp>), int> = 0> +template , + __enable_if_t::value && + is_arithmetic_v<_Tp> && is_arithmetic_v<_UnaryResult>), + int> = 0> _LIBCPP_HIDE_FROM_ABI _Tp __simd_transform_reduce(_Size __n, _Tp __init, _BinaryOperation __binary_op, _UnaryOperation __f) noexcept { const _Size __block_size = __lane_size / sizeof(_Tp); diff --git a/libcxx/include/__functional/operations.h b/libcxx/include/__functional/operations.h index 6cdb89d6b449bc..9812ccf8e4136f 100644 --- a/libcxx/include/__functional/operations.h +++ b/libcxx/include/__functional/operations.h @@ -15,7 +15,6 @@ #include <__functional/unary_function.h> #include <__type_traits/integral_constant.h> #include <__type_traits/operation_traits.h> -#include <__type_traits/predicate_traits.h> #include <__utility/forward.h> #if !defined(_LIBCPP_HAS_NO_PRAGMA_SYSTEM_HEADER) @@ -41,13 +40,13 @@ struct _LIBCPP_TEMPLATE_VIS plus }; _LIBCPP_CTAD_SUPPORTED_FOR_TYPE(plus); +// The non-transparent std::plus specialization is only equivalent to a raw plus +// operator when we don't perform an implicit conversion when calling it. template -struct __is_trivial_plus_operation, _Tp, _Tp> : true_type {}; +struct __desugars_to<__plus_tag, plus<_Tp>, _Tp, _Tp> : true_type {}; -#if _LIBCPP_STD_VER >= 14 template -struct __is_trivial_plus_operation, _Tp, _Up> : true_type {}; -#endif +struct __desugars_to<__plus_tag, plus, _Tp, _Up> : true_type {}; #if _LIBCPP_STD_VER >= 14 template <> @@ -352,13 +351,14 @@ struct _LIBCPP_TEMPLATE_VIS equal_to }; #endif +// The non-transparent std::equal_to specialization is only equivalent to a raw equality +// comparison when we don't perform an implicit conversion when calling it. template -struct __is_trivial_equality_predicate, _Tp, _Tp> : true_type {}; +struct __desugars_to<__equal_tag, equal_to<_Tp>, _Tp, _Tp> : true_type {}; -#if _LIBCPP_STD_VER >= 14 -template -struct __is_trivial_equality_predicate, _Tp, _Tp> : true_type {}; -#endif +// In the transparent case, we do not enforce that +template +struct __desugars_to<__equal_tag, equal_to, _Tp, _Up> : true_type {}; #if _LIBCPP_STD_VER >= 14 template diff --git a/libcxx/include/__functional/ranges_operations.h b/libcxx/include/__functional/ranges_operations.h index c344fc38f98ddd..b54589f8c0d879 100644 --- a/libcxx/include/__functional/ranges_operations.h +++ b/libcxx/include/__functional/ranges_operations.h @@ -14,7 +14,7 @@ #include <__concepts/totally_ordered.h> #include <__config> #include <__type_traits/integral_constant.h> -#include <__type_traits/predicate_traits.h> +#include <__type_traits/operation_traits.h> #include <__utility/forward.h> #if !defined(_LIBCPP_HAS_NO_PRAGMA_SYSTEM_HEADER) @@ -95,8 +95,10 @@ struct greater_equal { } // namespace ranges -template -struct __is_trivial_equality_predicate : true_type {}; +// For ranges we do not require that the types on each side of the equality +// operator are of the same type +template +struct __desugars_to<__equal_tag, ranges::equal_to, _Tp, _Up> : true_type {}; #endif // _LIBCPP_STD_VER >= 20 diff --git a/libcxx/include/__numeric/pstl_transform_reduce.h b/libcxx/include/__numeric/pstl_transform_reduce.h index 4127ee21e3045c..1127726046665c 100644 --- a/libcxx/include/__numeric/pstl_transform_reduce.h +++ b/libcxx/include/__numeric/pstl_transform_reduce.h @@ -84,7 +84,7 @@ _LIBCPP_HIDE_FROM_ABI _Tp transform_reduce( } // This overload doesn't get a customization point because it's trivial to detect (through e.g. -// __is_trivial_plus_operation) when specializing the more general variant, which should always be preferred +// __desugars_to) when specializing the more general variant, which should always be preferred template -struct __is_trivial_plus_operation : false_type {}; +// Tags to represent the canonical operations +struct __equal_tag {}; +struct __plus_tag {}; + +// This class template is used to determine whether an operation "desugars" +// (or boils down) to a given canonical operation. +// +// For example, `std::equal_to<>`, our internal `std::__equal_to` helper and +// `ranges::equal_to` are all just fancy ways of representing a transparent +// equality operation, so they all desugar to `__equal_tag`. +// +// This is useful to optimize some functions in cases where we know e.g. the +// predicate being passed is actually going to call a builtin operator, or has +// some specific semantics. +template +struct __desugars_to : false_type {}; _LIBCPP_END_NAMESPACE_STD diff --git a/libcxx/include/__type_traits/predicate_traits.h b/libcxx/include/__type_traits/predicate_traits.h deleted file mode 100644 index 872608e6ac3be3..00000000000000 --- a/libcxx/include/__type_traits/predicate_traits.h +++ /dev/null @@ -1,26 +0,0 @@ -//===----------------------------------------------------------------------===// -// -// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. -// See https://llvm.org/LICENSE.txt for license information. -// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -// -//===----------------------------------------------------------------------===// - -#ifndef _LIBCPP___TYPE_TRAITS_PREDICATE_TRAITS -#define _LIBCPP___TYPE_TRAITS_PREDICATE_TRAITS - -#include <__config> -#include <__type_traits/integral_constant.h> - -#if !defined(_LIBCPP_HAS_NO_PRAGMA_SYSTEM_HEADER) -# pragma GCC system_header -#endif - -_LIBCPP_BEGIN_NAMESPACE_STD - -template -struct __is_trivial_equality_predicate : false_type {}; - -_LIBCPP_END_NAMESPACE_STD - -#endif // _LIBCPP___TYPE_TRAITS_PREDICATE_TRAITS diff --git a/libcxx/include/module.modulemap.in b/libcxx/include/module.modulemap.in index 17ebe48f329963..b4a68c8ecde0ab 100644 --- a/libcxx/include/module.modulemap.in +++ b/libcxx/include/module.modulemap.in @@ -2013,7 +2013,6 @@ module std_private_type_traits_nat [system module std_private_type_traits_negation [system] { header "__type_traits/negation.h" } module std_private_type_traits_noexcept_move_assign_container [system] { header "__type_traits/noexcept_move_assign_container.h" } module std_private_type_traits_operation_traits [system] { header "__type_traits/operation_traits.h" } -module std_private_type_traits_predicate_traits [system] { header "__type_traits/predicate_traits.h" } module std_private_type_traits_promote [system] { header "__type_traits/promote.h" } module std_private_type_traits_rank [system] { header "__type_traits/rank.h" } module std_private_type_traits_remove_all_extents [system] { header "__type_traits/remove_all_extents.h" }