[reland] "[reland] _foreach_copy with different src/dst dtypes" (pytorch#127186)

Fixes pytorch#115171

Pull Request resolved: pytorch#127186
Approved by: https://github.com/ezyang
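
As a quick illustration of the user-visible effect, a minimal sketch (assuming a CUDA device and the in-place Python binding torch._foreach_copy_; not taken from the diff below): destination and source lists may now differ in dtype and still take the fused multi_tensor_apply fast path instead of a per-tensor fallback.

import torch

# Sketch only: fp32 destinations, bf16 sources, all on one CUDA device.
dst = [torch.zeros(1024, device="cuda", dtype=torch.float32) for _ in range(4)]
src = [torch.randn(1024, device="cuda").to(torch.bfloat16) for _ in range(4)]

# Each dst[i] receives src[i] cast to dst[i]'s dtype, matching Tensor.copy_.
torch._foreach_copy_(dst, src)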
crcrpar authored and petrex committed Jun 5, 2024
1 parent 120573a commit bafa43c
Showing 3 changed files with 199 additions and 16 deletions.
5 changes: 3 additions & 2 deletions aten/src/ATen/native/ForeachUtils.h
@@ -102,12 +102,13 @@ inline void check_foreach_api_restrictions(
// corresponding tensors (aligning in index across the tensorLists) share the
// same device and dtype.
inline bool _check_tensors_share_device_and_dtype(
- ArrayRef<TensorList> tensorLists) {
ArrayRef<TensorList> tensorLists,
const bool skip_dtype_check = false) {
const auto expected_dtype = tensorLists[0][0].dtype();
const auto expected_device = tensorLists[0][0].device();

auto is_tensor_okay = [&](const Tensor& tensor) {
- return tensor.dtype() == expected_dtype &&
return (skip_dtype_check || tensor.dtype() == expected_dtype) &&
tensor.device() == expected_device && tensor.layout() == at::kStrided &&
tensor.is_non_overlapping_and_dense();
};
188 changes: 174 additions & 14 deletions aten/src/ATen/native/cuda/ForeachBinaryOpList.cu
@@ -1,9 +1,11 @@
#define TORCH_ASSERT_ONLY_METHOD_OPERATORS
#include <ATen/Dispatch.h>
#include <ATen/cuda/CUDAContext.h>
#include <ATen/native/ForeachUtils.h>
#include <ATen/native/cuda/ForeachFunctors.cuh>
#include <ATen/native/cuda/ForeachMinMaxFunctors.cuh>
#include <functional>
#include <type_traits>

#ifndef AT_PER_OPERATOR_HEADERS
#include <ATen/NativeFunctions.h>
@@ -252,20 +254,156 @@ FOREACH_BINARY_OP_LIST(
power_functor,
/*division_op*/ true);

- template <typename T>
- struct Identity {
- __device__ __forceinline__ T operator()(const T& x) {
- return x;
template <typename dst_t, typename src_t = dst_t>
struct Copy {
__device__ __forceinline__ dst_t operator()(const src_t& x) {
return static_cast<dst_t>(x);
}
};

template <typename dst_t>
struct Copy<dst_t, c10::complex<double>> {
__device__ __forceinline__ dst_t operator()(const c10::complex<double>& x) {
if constexpr (!(std::is_same_v<dst_t, c10::complex<double>> ||
std::is_same_v<dst_t, c10::complex<float>>)) {
return static_cast<dst_t>(x.real());
} else {
return static_cast<dst_t>(x);
}
}
};

template <typename dst_t>
struct Copy<dst_t, c10::complex<float>> {
__device__ __forceinline__ dst_t operator()(const c10::complex<float>& x) {
if constexpr (!(std::is_same_v<dst_t, c10::complex<double>> ||
std::is_same_v<dst_t, c10::complex<float>>)) {
return static_cast<dst_t>(x.real());
} else {
return static_cast<dst_t>(x);
}
}
};

#define AT_DISPATCH_SOURCE_TYPES(TYPE, NAME, ...) \
AT_DISPATCH_SWITCH( \
TYPE, \
NAME, \
AT_PRIVATE_CASE_TYPE_USING_HINT( \
at::ScalarType::Byte, src_t, __VA_ARGS__) \
AT_PRIVATE_CASE_TYPE_USING_HINT( \
at::ScalarType::Char, src_t, __VA_ARGS__) \
AT_PRIVATE_CASE_TYPE_USING_HINT( \
at::ScalarType::Long, src_t, __VA_ARGS__) \
AT_PRIVATE_CASE_TYPE_USING_HINT( \
at::ScalarType::Short, src_t, __VA_ARGS__) \
AT_PRIVATE_CASE_TYPE_USING_HINT( \
at::ScalarType::Int, src_t, __VA_ARGS__) \
AT_PRIVATE_CASE_TYPE_USING_HINT( \
at::ScalarType::Double, src_t, __VA_ARGS__) \
AT_PRIVATE_CASE_TYPE_USING_HINT( \
at::ScalarType::Float, src_t, __VA_ARGS__) \
AT_PRIVATE_CASE_TYPE_USING_HINT( \
at::ScalarType::ComplexDouble, \
src_t, \
__VA_ARGS__) \
AT_PRIVATE_CASE_TYPE_USING_HINT( \
at::ScalarType::ComplexFloat, \
src_t, \
__VA_ARGS__) \
AT_PRIVATE_CASE_TYPE_USING_HINT( \
at::ScalarType::Half, \
src_t, \
__VA_ARGS__) \
AT_PRIVATE_CASE_TYPE_USING_HINT( \
at::ScalarType::BFloat16, \
src_t, \
__VA_ARGS__) \
AT_PRIVATE_CASE_TYPE_USING_HINT( \
at::ScalarType::Bool, \
src_t, \
__VA_ARGS__))

namespace {

template <
typename T,
typename src_t,
int depth,
int r_args_depth,
int res_arg_index>
struct CopyFunctor {
static_assert(depth == 2 && r_args_depth == 1 && res_arg_index == 1);
template <typename Op>
__device__ __forceinline__ void operator()(
int chunk_size,
TensorListMetadata<depth>& tl,
Op op) {
const auto tensor_loc = tl.block_to_tensor[blockIdx.x];
const auto chunk_idx = tl.block_to_chunk[blockIdx.x];
auto n = tl.numel_for_tensor[tensor_loc];

src_t* src_ptr = (src_t*)tl.addresses[0][tensor_loc];
src_ptr += chunk_idx * chunk_size;
T* self_ptr = (T*)tl.addresses[1][tensor_loc];
self_ptr += chunk_idx * chunk_size;

const bool all_aligned{is_aligned(src_ptr) && is_aligned(self_ptr)};

n -= chunk_idx * chunk_size;
src_t src_args[kILP];
T r_args[kILP];

// to make things simple, we put aligned case in a different code path
if (n % kILP == 0 && chunk_size % kILP == 0 && all_aligned) {
for (int64_t i_start = threadIdx.x;
i_start * kILP < n && i_start * kILP < chunk_size;
i_start += blockDim.x) {
// load
load_store(src_args, src_ptr, 0, i_start);
#pragma unroll
for (int ii = 0; ii < kILP; ii++) {
r_args[ii] = static_cast<T>(op(src_args[ii]));
}
// store
load_store(self_ptr, r_args, i_start, 0);
}
} else {
for (int64_t i_start = 0; i_start < n && i_start < chunk_size;
i_start += blockDim.x * kILP) {
#pragma unroll
for (int ii = 0; ii < kILP; ii++) {
const auto i = i_start + threadIdx.x + ii * blockDim.x;
if (i < n && i < chunk_size) {
src_args[ii] = src_ptr[i];
}
}
#pragma unroll
for (int ii = 0; ii < kILP; ii++) {
r_args[ii] = static_cast<T>(op(src_args[ii]));
}
store_args(self_ptr, r_args, i_start, chunk_size, n);
}
}
}
};

} // anonymous namespace

void foreach_tensor_copy_list_kernel_cuda_(
TensorList self,
TensorList src,
const bool non_blocking) {
check_foreach_api_restrictions(self, src);
- if (!can_use_fast_route(
- self, src, /* does_op_promote_integer_inputs_to_float */ false)) {
if (!(_check_tensors_share_device_and_dtype(
{self, src}, /* skip_dtype_check */ true) &&
std::all_of(
src.cbegin(),
src.cend(),
[&](const auto& t) -> bool {
return t.dtype() == src[0].dtype();
}) &&
_check_tensors_share_sizes_and_strides({self, src}))) {
return at::native::foreach_tensor_copy_list_kernel_slow_(
self, src, non_blocking);
}
@@ -280,16 +418,38 @@ void foreach_tensor_copy_list_kernel_cuda_(
"foreach_tensor_copy",
[&]() {
using opmath_t = at::opmath_type<scalar_t>;
- multi_tensor_apply<2>(
- tensor_lists,
- UnaryOpFunctor<
- scalar_t,
- /* depth */ 2,
- /* r_args_depth */ 1,
- /* res_arg_index */ 1>(),
- Identity<opmath_t>());
AT_DISPATCH_SOURCE_TYPES(src[0].scalar_type(), "foreach_tensor_copy", [&] {
if constexpr (std::is_same_v<scalar_t, src_t>) {
multi_tensor_apply<2>(
tensor_lists,
UnaryOpFunctor<
scalar_t,
/* depth */ 2,
/* r_args_depth */ 1,
/* res_arg_index */ 1>(),
Copy<opmath_t, opmath_t>());
} else {
// Ref:
// https://github.com/pytorch/pytorch/blob/656134c38f4737d13c3f43fc5c59470bc23c1d2f/aten/src/ATen/native/Copy.cpp#L299-L301
if (!self[0].is_complex() && src[0].is_complex()) {
TORCH_WARN_ONCE(
"Casting complex values to real discards the imaginary part");
}
multi_tensor_apply<2>(
tensor_lists,
CopyFunctor<
scalar_t,
src_t,
/* depth */ 2,
/* r_args_depth */ 1,
/* res_arg_index */ 1>(),
Copy<scalar_t, src_t>());
}
});
});
increment_version(self);
}

#undef AT_DISPATCH_SOURCE_TYPES

} // namespace at::native
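
The complex-to-real handling added above mirrors Tensor.copy_ (see the Copy.cpp reference in the diff): the imaginary part is dropped and a warning is emitted once. A minimal sketch of that behaviour, under the same assumptions as before (CUDA device, torch._foreach_copy_ binding):

import torch

dst = [torch.zeros(8, device="cuda", dtype=torch.float32) for _ in range(2)]
src = [torch.full((8,), 1 + 2j, device="cuda", dtype=torch.complex64) for _ in range(2)]

# Warns once: "Casting complex values to real discards the imaginary part".
torch._foreach_copy_(dst, src)
# dst now holds only the real parts, i.e. tensors filled with 1.0 here.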
22 changes: 22 additions & 0 deletions test/test_foreach.py
@@ -1248,6 +1248,28 @@ def test_foreach_copy_with_multi_device_inputs(self, device, dtype, op):
copy_(t, s, non_blocking)
self.assertEqual(ref_input, sample.input)

@onlyCUDA
@ops(filter(lambda op: op.name == "_foreach_copy", foreach_binary_op_db))
def test_foreach_copy_with_multi_dtypes(self, device, dtype, op):
# check (a) multi_tensor_apply is called and (b) numerical parity with for-loop and Tensor.copy_
foreach_copy_ = ForeachFuncWrapper(op.inplace_variant)
for sample in op.sample_inputs(device, dtype, noncontiguous=False):
for src_dtype in floating_types_and(torch.half, torch.bfloat16):
if src_dtype == dtype:
continue
self_tensors = [t.clone() for t in sample.input]
src_tensors = [t.to(src_dtype) for t in self_tensors]
out = foreach_copy_(
(self_tensors, src_tensors), is_cuda=True, expect_fastpath=True
)
self.assertEqual(
out,
[
torch.empty_like(t).copy_(s)
for t, s in zip(self_tensors, src_tensors)
],
)

# Test reverse-mode & forward-mode AD if supported.
@onlyCUDA
@ops(
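
Per the new fast-path gate in ForeachBinaryOpList.cu, the fused kernel is used only when all tensors share a device, are strided and non-overlapping-and-dense, match in sizes and strides, and every tensor in src has one common dtype; otherwise the call falls back to foreach_tensor_copy_list_kernel_slow_. Either way the result should match a per-tensor Tensor.copy_ loop. A sketch of the fallback case (same assumptions as above: CUDA device, torch._foreach_copy_ binding):

import torch

dst = [torch.zeros(16, device="cuda") for _ in range(2)]
# Mixed dtypes *within* src: still numerically correct, but via the slow path.
src = [torch.randn(16, device="cuda", dtype=torch.float64),
       torch.randn(16, device="cuda", dtype=torch.float16)]

ref = [torch.zeros_like(d).copy_(s) for d, s in zip(dst, src)]
torch._foreach_copy_(dst, src)
assert all(torch.equal(d, r) for d, r in zip(dst, ref))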
