Skip to content

Commit

Permalink
Document some transform iterator corner cases (NVIDIA#2740)
Browse files Browse the repository at this point in the history
  • Loading branch information
bernhardmgruber authored and fbusato committed Nov 12, 2024
1 parent a2ea4b7 commit 5af1a55
Show file tree
Hide file tree
Showing 2 changed files with 68 additions and 8 deletions.
51 changes: 50 additions & 1 deletion thrust/testing/transform_iterator.cu
Original file line number Diff line number Diff line change
Expand Up @@ -122,7 +122,7 @@ struct pass_ref
}
};

// TODO(bgruber): replace by libc++ with C++14
// a user provided functor that forwards its argument
struct forward
{
template <class _Tp>
Expand Down Expand Up @@ -157,6 +157,16 @@ void TestTransformIteratorReferenceAndValueType()
static_assert(is_same<decltype(it_tr_fwd)::reference, bool&&>::value, "");
static_assert(is_same<decltype(it_tr_fwd)::value_type, bool>::value, "");
(void) it_tr_fwd;

auto it_tr_tid = thrust::make_transform_iterator(it, thrust::identity<bool>{});
static_assert(is_same<decltype(it_tr_tid)::reference, bool>::value, ""); // identity<bool>::value_type
static_assert(is_same<decltype(it_tr_tid)::value_type, bool>::value, "");
(void) it_tr_tid;

auto it_tr_cid = thrust::make_transform_iterator(it, cuda::std::__identity{});
static_assert(is_same<decltype(it_tr_cid)::reference, bool&&>::value, ""); // inferred, like forward
static_assert(is_same<decltype(it_tr_cid)::value_type, bool>::value, "");
(void) it_tr_cid;
}

{
Expand All @@ -180,6 +190,16 @@ void TestTransformIteratorReferenceAndValueType()
static_assert(is_same<decltype(it_tr_fwd)::reference, bool&&>::value, ""); // wrapped reference is decayed
static_assert(is_same<decltype(it_tr_fwd)::value_type, bool>::value, "");
(void) it_tr_fwd;

auto it_tr_tid = thrust::make_transform_iterator(it, thrust::identity<bool>{});
static_assert(is_same<decltype(it_tr_tid)::reference, bool>::value, ""); // identity<bool>::value_type
static_assert(is_same<decltype(it_tr_tid)::value_type, bool>::value, "");
(void) it_tr_tid;

auto it_tr_cid = thrust::make_transform_iterator(it, cuda::std::__identity{});
static_assert(is_same<decltype(it_tr_cid)::reference, bool&&>::value, ""); // inferred, like forward
static_assert(is_same<decltype(it_tr_cid)::value_type, bool>::value, "");
(void) it_tr_cid;
}

{
Expand All @@ -203,6 +223,35 @@ void TestTransformIteratorReferenceAndValueType()
static_assert(is_same<decltype(it_tr_fwd)::reference, bool&&>::value, ""); // proxy reference is decayed
static_assert(is_same<decltype(it_tr_fwd)::value_type, bool>::value, "");
(void) it_tr_fwd;

auto it_tr_ide = thrust::make_transform_iterator(it, thrust::identity<bool>{});
static_assert(is_same<decltype(it_tr_ide)::reference, bool>::value, ""); // identity<bool>::value_type
static_assert(is_same<decltype(it_tr_ide)::value_type, bool>::value, "");
(void) it_tr_ide;

auto it_tr_tid = thrust::make_transform_iterator(it, thrust::identity<bool>{});
static_assert(is_same<decltype(it_tr_tid)::reference, bool>::value, ""); // identity<bool>::value_type
static_assert(is_same<decltype(it_tr_tid)::value_type, bool>::value, "");
(void) it_tr_tid;

auto it_tr_cid = thrust::make_transform_iterator(it, cuda::std::__identity{});
static_assert(is_same<decltype(it_tr_cid)::reference, bool&&>::value, ""); // inferred, like forward
static_assert(is_same<decltype(it_tr_cid)::value_type, bool>::value, "");
(void) it_tr_cid;
}
}
DECLARE_UNITTEST(TestTransformIteratorReferenceAndValueType);

void TestTransformIteratorIdentity()
{
thrust::device_vector<int> v(3, 42);

ASSERT_EQUAL(*thrust::make_transform_iterator(v.begin(), thrust::identity<int>{}), 42);
// FIXME(bgruber): fix transform_iterator to get these tests compiling:
// ASSERT_EQUAL(*thrust::make_transform_iterator(v.begin(), thrust::identity<>{}), 42);
// ASSERT_EQUAL(*thrust::make_transform_iterator(v.begin(), cuda::std::identity{}), 42);
// using namespace thrust::placeholders;
// ASSERT_EQUAL(*thrust::make_transform_iterator(v.begin(), _1), 42);
}

DECLARE_UNITTEST(TestTransformIteratorIdentity);
25 changes: 18 additions & 7 deletions thrust/thrust/iterator/transform_iterator.h
Original file line number Diff line number Diff line change
Expand Up @@ -305,15 +305,26 @@ class transform_iterator
_CCCL_EXEC_CHECK_DISABLE
_CCCL_HOST_DEVICE typename super_t::reference dereference() const
{
// TODO(bgruber): we should do as `std::ranges::transform_view::iterator` does:
// `return std::invoke(m_f, *this->base());` and return `decltype(auto)`

// Create a temporary to allow iterators with wrapped references to convert to their value type before calling m_f.
// Note that this disallows non-constant operations through m_f.
typename thrust::iterator_value<Iterator>::type const& x = *this->base();
// TODO(bgruber): we should ideally do as `std::ranges::transform_view::iterator` does:
// `return std::invoke(m_f, *this->base());` and return `decltype(auto)`. However, `*this->base()` may return a
// wrapped reference (`device_reference<T>`), which is a temporary value. If `m_f` forwards this value, e.g. as a
// `device_reference<T>&&` if `m_f` is `identity<void>`, (and `super_t::reference` is thus deduced as
// `device_reference<T>&&` as well), we return a dangling reference. So we cannot do as
// `std::ranges::transform_view::iterator` does.

// Interestingly, C++20 ranges have the same bug. The following program crashes because the transform iterator also
// returns a reference to an expired temporary (given by the iota iterator upon dereferencing)
// for (auto e : std::views::iota(10) | std::views::transform(std::identity{}))
// std::cout << e << '\n';
// See: https://godbolt.org/z/jrKcnMqhK

// The workaround is to create a temporary to allow iterators with wrapped/proxy references to convert to their
// value type before calling m_f. This also loads values from a different memory space (cf. `device_reference`).
// Note that this disallows mutable operations through m_f.
iterator_value_t<Iterator> const& x = *this->base();
// FIXME(bgruber): x may be a reference to a temporary (e.g. if the base iterator is a counting_iterator). If `m_f`
// does not produce an independent copy and super_t::reference is a reference, we return a dangling reference (e.g.
// `thrust::identity<T>` because it does not forward. `thrust::identity<void>` shoud work).
// for any `[thrust|::cuda::std]::identity` functor).
return m_f(x);
}

Expand Down

0 comments on commit 5af1a55

Please sign in to comment.