Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[libc++] Add input validation for set_intersection() in debug mode. #101508

Open
wants to merge 3 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 7 additions & 7 deletions libcxx/include/__algorithm/is_sorted_until.h
Original file line number Diff line number Diff line change
Expand Up @@ -20,18 +20,18 @@

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can you add a release note in 20.rst?

_LIBCPP_BEGIN_NAMESPACE_STD

template <class _Compare, class _ForwardIterator>
template <class _Compare, class _ForwardIterator, class _Sent>
_LIBCPP_HIDE_FROM_ABI _LIBCPP_CONSTEXPR_SINCE_CXX20 _ForwardIterator
__is_sorted_until(_ForwardIterator __first, _ForwardIterator __last, _Compare __comp) {
__is_sorted_until(_ForwardIterator __first, _Sent __last, _Compare&& __comp) {
if (__first != __last) {
_ForwardIterator __i = __first;
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
_ForwardIterator __i = __first;
_ForwardIterator __prev = __first;

This makes the code quite a bit clearer.

while (++__i != __last) {
if (__comp(*__i, *__first))
return __i;
__first = __i;
while (++__first != __last) {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Question: what is the reason to swap first and i here?

if (__comp(*__first, *__i))
return __first;
__i = __first;
}
}
return __last;
return __first;
}

template <class _ForwardIterator, class _Compare>
Expand Down
11 changes: 11 additions & 0 deletions libcxx/include/__algorithm/set_intersection.h
Original file line number Diff line number Diff line change
Expand Up @@ -11,12 +11,15 @@

#include <__algorithm/comp.h>
#include <__algorithm/comp_ref_type.h>
#include <__algorithm/is_sorted_until.h>
#include <__algorithm/iterator_operations.h>
#include <__algorithm/lower_bound.h>
#include <__assert>
#include <__config>
#include <__functional/identity.h>
#include <__iterator/iterator_traits.h>
#include <__iterator/next.h>
#include <__type_traits/is_constant_evaluated.h>
#include <__type_traits/is_same.h>
#include <__utility/exchange.h>
#include <__utility/move.h>
Expand Down Expand Up @@ -95,6 +98,14 @@ __set_intersection(
_Compare&& __comp,
std::forward_iterator_tag,
std::forward_iterator_tag) {
#if _LIBCPP_HARDENING_MODE == _LIBCPP_HARDENING_MODE_DEBUG
if (!__libcpp_is_constant_evaluated()) {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why is this required?

Copy link
Contributor Author

@ichaer ichaer Aug 6, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Because__builtin_expect(), which_LIBCPP_ASSERT() expands to, can't be constant-evaluated. I learned that from a compilation error, btw.

Edit: Sorry, now that I said it I'm not sure =/. Maybe it wasn't __builtin_expect(), but _LIBCPP_VERBOSE_ABORT()? Anyway, something inside _LIBCPP_ASSERT() can't be constant-evaluated. I had been using __check_strict_weak_ordering_sorted() as my blueprint for this change, but I left that bit out and the compiler explained to me why I couldn't.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I still don't understand. Can you try removing the if(is-constant-evaluated) and let's see what the CI says? You're probably right, but I'd like to see what the error is.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

+1 -- also curious to see the error.

_LIBCPP_ASSERT_INTERNAL(
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This should likely be _LIBCPP_ASSERT_ARGUMENT_WITHIN_DOMAIN but I'd like @var-const to chime in.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I should have added a comment: I didn't do that because _LIBCPP_ASSERT_ARGUMENT_WITHIN_DOMAIN is enabled in _LIBCPP_HARDENING_MODE_EXTENSIVE, and I thought the cost wasn't appropriate. More about this in my response to https://github.com/llvm/llvm-project/pull/101508/files#r1702696227.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ah, I think this should be _LIBCPP_ASSERT_SEMANTIC_REQUIREMENT instead, actually. That matches what we do for __check_strict_weak_ordering_sorted.

Copy link
Member

@var-const var-const Sep 12, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

  1. +1 to Louis' comment -- internal is meant for checks that aim to catch bugs in our own implementation, bugs that are independent from user input. Since this checks the user-provided arguments, we need to find some other category.
  2. My first intuition is that argument-within-domain is a somewhat better match than semantic-requirement. In __check_strict_weak_ordering_sorted, we're checking at the resulting (presumably) sorted sequence as a way to validate the given comparator -- the comparator has the semantic requirement to provide strict weak ordering, but we cannot check that without resorting to an imperfect heuristic. Here, however, we are checking the given argument directly, and the check is very straightforward, just expensive. argument-within-domain is essentially a catch-all for "the given argument is valid but if it's not, it won't cause UB within our code (but will produce an incorrect result that might well cause UB in user code)", which seems to apply to the situation here.
  3. Since we're wrapping the whole thing in a conditional, it's not really important which modes enable the assertion category we choose -- e.g. if we choose argument-within-domain that is enabled in both extensive and debug, the check for _LIBCPP_HARDENING_MODE_DEBUG still makes sure it only runs in debug. It's a little inelegant, but we already have precedent in __check_strict_weak_ordering_sorted, so I wouldn't try to fix that within this patch.

std::__is_sorted_until(__first1, __last1, __comp) == __last1, "set_intersection: input range 1 must be sorted");
_LIBCPP_ASSERT_INTERNAL(
std::__is_sorted_until(__first2, __last2, __comp) == __last2, "set_intersection: input range 2 must be sorted");
}
#endif
_LIBCPP_CONSTEXPR std::__identity __proj;
bool __prev_may_be_equal = false;

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -43,33 +43,32 @@

#include "test_iterators.h"

namespace {

// __debug_less will perform an additional comparison in an assertion
static constexpr unsigned std_less_comparison_count_multiplier() noexcept {
#if _LIBCPP_HARDENING_MODE == _LIBCPP_HARDENING_MODE_DEBUG
return 2;
// debug mode provides no complexity guarantees, testing them would be a waste of effort
ichaer marked this conversation as resolved.
Show resolved Hide resolved
// but we still want to run this test, to ensure we don't trigger any assertions
#ifdef _LIBCPP_HARDENING_MODE_DEBUG
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is wrong. _LIBCPP_HARDENING_MODE_DEBUG is always defined.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
#ifdef _LIBCPP_HARDENING_MODE_DEBUG
#if defined(_LIBCPP_HARDENING_MODE_DEBUG) && _LIBCPP_HARDENING_MODE_DEBUG

We still need to check for if defined(_LIBCPP_HARDENING_MODE_DEBUG) to avoid -Wundef when testing non-libc++ libraries.

# define ASSERT_COMPLEXITY(expression)
ichaer marked this conversation as resolved.
Show resolved Hide resolved
#else
return 1;
# define ASSERT_COMPLEXITY(expression) assert(expression)
#endif
}

namespace {

struct [[nodiscard]] OperationCounts {
std::size_t comparisons{};
struct PerInput {
std::size_t proj{};
IteratorOpCounts iterops;

[[nodiscard]] constexpr bool isNotBetterThan(const PerInput& other) {
[[nodiscard]] constexpr bool isNotBetterThan(const PerInput& other) const noexcept {
return proj >= other.proj && iterops.increments + iterops.decrements + iterops.zero_moves >=
other.iterops.increments + other.iterops.decrements + other.iterops.zero_moves;
}
};
std::array<PerInput, 2> in;

[[nodiscard]] constexpr bool isNotBetterThan(const OperationCounts& expect) {
return std_less_comparison_count_multiplier() * comparisons >= expect.comparisons &&
in[0].isNotBetterThan(expect.in[0]) && in[1].isNotBetterThan(expect.in[1]);
[[nodiscard]] constexpr bool isNotBetterThan(const OperationCounts& expect) const noexcept {
return comparisons >= expect.comparisons && in[0].isNotBetterThan(expect.in[0]) &&
in[1].isNotBetterThan(expect.in[1]);
}
};

Expand All @@ -80,16 +79,17 @@ struct counted_set_intersection_result {

constexpr counted_set_intersection_result() = default;

constexpr explicit counted_set_intersection_result(std::array<int, ResultSize>&& contents) : result{contents} {}
constexpr explicit counted_set_intersection_result(std::array<int, ResultSize>&& contents) noexcept
: result{contents} {}

constexpr void assertNotBetterThan(const counted_set_intersection_result& other) {
constexpr void assertNotBetterThan(const counted_set_intersection_result& other) const noexcept {
assert(result == other.result);
assert(opcounts.isNotBetterThan(other.opcounts));
ASSERT_COMPLEXITY(opcounts.isNotBetterThan(other.opcounts));
}
};

template <std::size_t ResultSize>
counted_set_intersection_result(std::array<int, ResultSize>) -> counted_set_intersection_result<ResultSize>;
counted_set_intersection_result(std::array<int, ResultSize>) noexcept -> counted_set_intersection_result<ResultSize>;

template <template <class...> class InIterType1,
template <class...>
Expand Down Expand Up @@ -306,7 +306,7 @@ constexpr bool testComplexityBasic() {
std::array<int, 5> r2{2, 4, 6, 8, 10};
std::array<int, 0> expected{};

const std::size_t maxOperation = std_less_comparison_count_multiplier() * (2 * (r1.size() + r2.size()) - 1);
[[maybe_unused]] const std::size_t maxOperation = 2 * (r1.size() + r2.size()) - 1;

// std::set_intersection
{
Expand All @@ -321,7 +321,7 @@ constexpr bool testComplexityBasic() {
std::set_intersection(r1.begin(), r1.end(), r2.begin(), r2.end(), out.data(), comp);

assert(std::ranges::equal(out, expected));
assert(numberOfComp <= maxOperation);
ASSERT_COMPLEXITY(numberOfComp <= maxOperation);
}

// ranges::set_intersection iterator overload
Expand Down Expand Up @@ -349,9 +349,9 @@ constexpr bool testComplexityBasic() {
std::ranges::set_intersection(r1.begin(), r1.end(), r2.begin(), r2.end(), out.data(), comp, proj1, proj2);

assert(std::ranges::equal(out, expected));
assert(numberOfComp <= maxOperation);
assert(numberOfProj1 <= maxOperation);
assert(numberOfProj2 <= maxOperation);
ASSERT_COMPLEXITY(numberOfComp <= maxOperation);
ASSERT_COMPLEXITY(numberOfProj1 <= maxOperation);
ASSERT_COMPLEXITY(numberOfProj2 <= maxOperation);
}

// ranges::set_intersection range overload
Expand Down Expand Up @@ -379,9 +379,9 @@ constexpr bool testComplexityBasic() {
std::ranges::set_intersection(r1, r2, out.data(), comp, proj1, proj2);

assert(std::ranges::equal(out, expected));
assert(numberOfComp < maxOperation);
assert(numberOfProj1 < maxOperation);
assert(numberOfProj2 < maxOperation);
ASSERT_COMPLEXITY(numberOfComp < maxOperation);
ASSERT_COMPLEXITY(numberOfProj1 < maxOperation);
ASSERT_COMPLEXITY(numberOfProj2 < maxOperation);
}
return true;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -40,44 +40,45 @@ constexpr bool test_all() {
constexpr auto operator<=>(const A&) const = default;
};

std::array in = {1, 2, 3};
std::array in2 = {A{4}, A{5}, A{6}};
const std::array in = {1, 2, 3};
const std::array in2 = {A{4}, A{5}, A{6}};

std::array output = {7, 8, 9, 10, 11, 12};
auto out = output.begin();
std::array output2 = {A{7}, A{8}, A{9}, A{10}, A{11}, A{12}};
auto out2 = output2.begin();

std::ranges::equal_to eq;
std::ranges::less less;
auto sum = [](int lhs, A rhs) { return lhs + rhs.x; };
auto proj1 = [](int x) { return x * -1; };
auto proj2 = [](A a) { return a.x * -1; };
const std::ranges::equal_to eq;
const std::ranges::less less;
const std::ranges::greater greater;
const auto sum = [](int lhs, A rhs) { return lhs + rhs.x; };
const auto proj1 = [](int x) { return x * -1; };
const auto proj2 = [](A a) { return a.x * -1; };

#if TEST_STD_VER >= 23
test(std::ranges::ends_with, in, in2, eq, proj1, proj2);
#endif
test(std::ranges::equal, in, in2, eq, proj1, proj2);
test(std::ranges::lexicographical_compare, in, in2, eq, proj1, proj2);
test(std::ranges::is_permutation, in, in2, eq, proj1, proj2);
test(std::ranges::includes, in, in2, less, proj1, proj2);
test(std::ranges::includes, in, in2, greater, proj1, proj2);
test(std::ranges::find_first_of, in, in2, eq, proj1, proj2);
test(std::ranges::mismatch, in, in2, eq, proj1, proj2);
test(std::ranges::search, in, in2, eq, proj1, proj2);
test(std::ranges::find_end, in, in2, eq, proj1, proj2);
test(std::ranges::transform, in, in2, out, sum, proj1, proj2);
test(std::ranges::transform, in, in2, out2, sum, proj1, proj2);
test(std::ranges::partial_sort_copy, in, in2, less, proj1, proj2);
test(std::ranges::merge, in, in2, out, less, proj1, proj2);
test(std::ranges::merge, in, in2, out2, less, proj1, proj2);
test(std::ranges::set_intersection, in, in2, out, less, proj1, proj2);
test(std::ranges::set_intersection, in, in2, out2, less, proj1, proj2);
test(std::ranges::set_difference, in, in2, out, less, proj1, proj2);
test(std::ranges::set_difference, in, in2, out2, less, proj1, proj2);
test(std::ranges::set_symmetric_difference, in, in2, out, less, proj1, proj2);
test(std::ranges::set_symmetric_difference, in, in2, out2, less, proj1, proj2);
test(std::ranges::set_union, in, in2, out, less, proj1, proj2);
test(std::ranges::set_union, in, in2, out2, less, proj1, proj2);
test(std::ranges::partial_sort_copy, in, output, less, proj1, proj2);
test(std::ranges::merge, in, in2, out, greater, proj1, proj2);
test(std::ranges::merge, in, in2, out2, greater, proj1, proj2);
test(std::ranges::set_intersection, in, in2, out, greater, proj1, proj2);
test(std::ranges::set_intersection, in, in2, out2, greater, proj1, proj2);
test(std::ranges::set_difference, in, in2, out, greater, proj1, proj2);
test(std::ranges::set_difference, in, in2, out2, greater, proj1, proj2);
test(std::ranges::set_symmetric_difference, in, in2, out, greater, proj1, proj2);
test(std::ranges::set_symmetric_difference, in, in2, out2, greater, proj1, proj2);
test(std::ranges::set_union, in, in2, out, greater, proj1, proj2);
test(std::ranges::set_union, in, in2, out2, greater, proj1, proj2);
#if TEST_STD_VER > 20
test(std::ranges::starts_with, in, in2, eq, proj1, proj2);
#endif
Expand Down
Loading