Distributed Reduction (microsoft#18206)
This PR implements distributed reduction for Llama 2. This version
doesn't handle any cases requiring re-sharding because we haven't seen
any such use cases.

Intuitive examples (the sketch after this list illustrates the RRS[0] notation):
- [supported] [2,4,6]-tensor with spec=RRS[0] and device_mesh=[0,1] -> Reduce(axes=[0]) -> [1,4,6]-tensor with spec=RRS[0] and device_mesh=[0,1]
- [supported] [2,4,6]-tensor with spec=RRS[0] and device_mesh=[0,1] -> Reduce(axes=[1]) -> [2,1,6]-tensor with spec=RRS[0] and device_mesh=[0,1]
- [not supported] [2,4,6]-tensor with spec=RRS[0] and device_mesh=[0,1] -> Reduce(axes=[2]) -> [2,4,1]-tensor with spec=RRS[0] and device_mesh=[0,1]
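
Here, `RRS[0]` means axes 0 and 1 are replicated (`R`) while axis 2 is sharded (`S[0]`) along axis 0 of the device mesh. A minimal NumPy sketch of that layout (illustration only, not part of this PR):

```python
import numpy as np

# spec=RRS[0], device_mesh=[0,1]: replicate axes 0 and 1, shard axis 2 across two devices.
full = np.arange(2 * 4 * 6).reshape(2, 4, 6)
device_mesh = [0, 1]

shards = np.split(full, len(device_mesh), axis=2)  # one [2,4,3] shard per device
assert all(shard.shape == (2, 4, 3) for shard in shards)
```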

Algorithm:
When the reduced axes are not sharded, each device can perform the
reduction directly on its local shard. The output sharding spec is
identical to the input sharding spec. We currently throw when the input
and output sharding specs differ.
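
A minimal NumPy sketch (illustration only, not part of this PR) of why the supported cases need no communication: reducing a replicated axis on each local shard produces exactly the corresponding shard of the full result, while reducing the sharded axis would leave per-device partial results that require a cross-device exchange:

```python
import numpy as np

full = np.arange(2 * 4 * 6, dtype=np.float32).reshape(2, 4, 6)
shards = np.split(full, 2, axis=2)  # spec=RRS[0] over device_mesh=[0,1]

# Reduce(axes=[0], keepdims=1): axis 0 is replicated, so each device reduces locally.
local = [s.sum(axis=0, keepdims=True) for s in shards]            # two [1,4,3] shards
expected = np.split(full.sum(axis=0, keepdims=True), 2, axis=2)   # shards of the full result
assert all(np.array_equal(a, b) for a, b in zip(local, expected))

# Reduce(axes=[2]) would sum across the sharded axis; each device only holds a partial sum,
# so an all-reduce-style exchange (i.e. resharding) would be required -- the unsupported case.
```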

Review guidelines:
- Check 97b8d2f for the new ops' schemas and how the new ops are registered.
- Read the tests in 2450f93 to get familiar with the behavior of these ops.
- Check the implementation details in 753d9af.
wschin authored Nov 1, 2023
1 parent d87216b commit 9e8ad39
Showing 8 changed files with 638 additions and 108 deletions.
1 change: 1 addition & 0 deletions cmake/onnxruntime_providers_cuda.cmake
@@ -40,6 +40,7 @@
"${ONNXRUNTIME_ROOT}/contrib_ops/cuda/collective/distributed_slice.cc"
"${ONNXRUNTIME_ROOT}/contrib_ops/cuda/collective/distributed_reshape.cc"
"${ONNXRUNTIME_ROOT}/contrib_ops/cuda/collective/distributed_expand.cc"
"${ONNXRUNTIME_ROOT}/contrib_ops/cuda/collective/distributed_reduce.cc"
)
endif()
# add using ONNXRUNTIME_ROOT so they show up under the 'contrib_ops' folder in Visual Studio
1 change: 1 addition & 0 deletions cmake/onnxruntime_rocm_hipify.cmake
@@ -109,6 +109,7 @@ if (NOT onnxruntime_USE_NCCL)
list(APPEND contrib_ops_excluded_files "collective/distributed_slice.cc")
list(APPEND contrib_ops_excluded_files "collective/distributed_reshape.cc")
list(APPEND contrib_ops_excluded_files "collective/distributed_expand.cc")
list(APPEND contrib_ops_excluded_files "collective/distributed_reduce.cc")
endif()

set(provider_excluded_files
175 changes: 175 additions & 0 deletions onnxruntime/contrib_ops/cuda/collective/distributed_reduce.cc
@@ -0,0 +1,175 @@

// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

// Distributed computation.
#include "distributed_reduce.h"
#include "sharding.h"
#include "sharding_spec.h"
#include "nccl_kernels.h"
#include "mpi_include.h"

// ORT system.
#include "core/providers/cuda/cudnn_common.h"
#include "core/providers/cuda/reduction/reduction_ops.h"

// std C++.
#include <iostream>

namespace onnxruntime {
namespace contrib {
namespace cuda {

#if defined(ORT_USE_NCCL)

template <typename T>
DistributedReduceBase<T>::DistributedReduceBase(
    const OpKernelInfo& info,
    cudnnReduceTensorOp_t cudnn_reduce_op) : DistributedKernel(info) {
  keepdims_ = info.GetAttrOrDefault<int64_t>("keepdims", 1);
  cudnn_reduce_op_ = cudnn_reduce_op;
};

template <typename T>
Status DistributedReduceBase<T>::ComputeInternal(OpKernelContext* context) const {
  const auto& input_sharding_spec = input_shard_specs_.at(0);
  const auto& axes_sharding_spec = input_shard_specs_.at(1);
  const auto& output_sharding_spec = output_shard_specs_.at(0);

  ORT_ENFORCE(axes_sharding_spec.HasNoShard(),
              "It's not worthy to shard axes tensor. "
              "If sharding axes is needed, please submit a feature request.");

  const Tensor* input_tensor = context->Input<Tensor>(0);
  const Tensor* axes_tensor = context->Input<Tensor>(1);
  ORT_ENFORCE(axes_tensor->Shape().NumDimensions() == 1, "Axes tensor must be an 1-D tensor.");
  auto axes_span = axes_tensor->DataAsSpan<int64_t>();

  // Case 1: empty axes means treating this reduction as an identity.
  if (axes_span.empty()) {
    ORT_ENFORCE(
        input_sharding_spec == output_sharding_spec,
        "Input and output sharding specs should be the same. Otherwise, resharding is needed.");
    auto* output_tensor = context->Output(0, input_tensor->Shape());
    CUDA_RETURN_IF_ERROR(cudaMemcpyAsync(output_tensor->MutableData<T>(), input_tensor->Data<T>(), input_tensor->SizeInBytes(),
                                         cudaMemcpyDeviceToDevice, Stream(context)));
    return Status::OK();
  }

  // Case 2: this is a valid reduction. Let's prepare for it.

  bool sharding_on_reduced_axes = false;
  for (auto axis_it = axes_span.begin(); input_sharding_spec.HasShard() && axis_it != axes_span.end(); ++axis_it) {
    if (*axis_it == input_sharding_spec.GetPartitionAxis()) {
      sharding_on_reduced_axes = true;
      break;
    }
  }

  if (sharding_on_reduced_axes) {
    // Case 2-1: sharding on reduced axes.
    ORT_THROW(onnxruntime::common::ONNXRUNTIME, onnxruntime::common::FAIL, "Not implemented. Resharding is required to make reduced axes replica.");
  } else {
    // Case 2-2: sharding on passing-through axes or no shard.
    ORT_ENFORCE(
        input_sharding_spec == output_sharding_spec,
        "Input and output sharding specs should be the same. Otherwise, resharding is needed.");
    onnxruntime::cuda::PrepareReduceMetadata metadata;
    ORT_RETURN_IF_ERROR(
        onnxruntime::cuda::PrepareForReduce(input_tensor, keepdims_, axes_span, metadata));
    auto output_tensor = context->Output(0, metadata.squeezed_output_dims);

    // Fast reduction is not deterministic, so sometimes we want to turn it off.
    const bool enable_fast_but_non_deterministic_reduction = !context->GetUseDeterministicCompute();
    return onnxruntime::cuda::ReduceComputeCore<T, CUDNN_REDUCE_TENSOR_NO_INDICES>(
        /* GPU allocator */ Info().GetAllocator(OrtMemType::OrtMemTypeDefault),
        *input_tensor, metadata, *output_tensor, cudnn_reduce_op_, axes_span,
        /* calculate_log */ false, /* calculate_sqt */ false, /* log_sum_exp_ */ false,
        enable_fast_but_non_deterministic_reduction, context->GetComputeStream());
  }
  return Status::OK();
}

template <typename T>
DistributedReduceSum<T>::DistributedReduceSum(
    const OpKernelInfo& info) : DistributedReduceBase<T>(info, CUDNN_REDUCE_TENSOR_ADD){};

template <typename T>
DistributedReduceMean<T>::DistributedReduceMean(
    const OpKernelInfo& info) : DistributedReduceBase<T>(info, CUDNN_REDUCE_TENSOR_AVG){};

template <typename T>
DistributedReduceMax<T>::DistributedReduceMax(
    const OpKernelInfo& info) : DistributedReduceBase<T>(info, CUDNN_REDUCE_TENSOR_MAX){};

// ReduceSum
ONNX_OPERATOR_TYPED_KERNEL_EX(
    DistributedReduceSum,
    kMSDomain,
    1,
    float,
    kCudaExecutionProvider,
    (*KernelDefBuilder::Create())
        .TypeConstraint("T", DataTypeImpl::GetTensorType<float>())
        .InputMemoryType(OrtMemTypeCPUInput, 1),
    DistributedReduceSum<float>);
ONNX_OPERATOR_TYPED_KERNEL_EX(
    DistributedReduceSum,
    kMSDomain,
    1,
    MLFloat16,
    kCudaExecutionProvider,
    (*KernelDefBuilder::Create())
        .TypeConstraint("T", DataTypeImpl::GetTensorType<MLFloat16>())
        .InputMemoryType(OrtMemTypeCPUInput, 1),
    DistributedReduceSum<MLFloat16>);

// ReduceMean
ONNX_OPERATOR_TYPED_KERNEL_EX(
    DistributedReduceMean,
    kMSDomain,
    1,
    float,
    kCudaExecutionProvider,
    (*KernelDefBuilder::Create())
        .TypeConstraint("T", DataTypeImpl::GetTensorType<float>())
        .InputMemoryType(OrtMemTypeCPUInput, 1),
    DistributedReduceMean<float>);
ONNX_OPERATOR_TYPED_KERNEL_EX(
    DistributedReduceMean,
    kMSDomain,
    1,
    MLFloat16,
    kCudaExecutionProvider,
    (*KernelDefBuilder::Create())
        .TypeConstraint("T", DataTypeImpl::GetTensorType<MLFloat16>())
        .InputMemoryType(OrtMemTypeCPUInput, 1),
    DistributedReduceMean<MLFloat16>);

// ReduceMax
ONNX_OPERATOR_TYPED_KERNEL_EX(
    DistributedReduceMax,
    kMSDomain,
    1,
    float,
    kCudaExecutionProvider,
    (*KernelDefBuilder::Create())
        .TypeConstraint("T", DataTypeImpl::GetTensorType<float>())
        .InputMemoryType(OrtMemTypeCPUInput, 1),
    DistributedReduceMax<float>);
ONNX_OPERATOR_TYPED_KERNEL_EX(
    DistributedReduceMax,
    kMSDomain,
    1,
    MLFloat16,
    kCudaExecutionProvider,
    (*KernelDefBuilder::Create())
        .TypeConstraint("T", DataTypeImpl::GetTensorType<MLFloat16>())
        .InputMemoryType(OrtMemTypeCPUInput, 1),
    DistributedReduceMax<MLFloat16>);

#endif

} // namespace cuda
} // namespace contrib
} // namespace onnxruntime
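
For orientation, the dispatch in `ComputeInternal` above can be restated as a short Python sketch; this is illustration only (not part of this PR) and assumes a simplified shard spec, namely just the index of the sharded axis or `None` when fully replicated:

```python
import numpy as np

def distributed_reduce(local_shard, axes, sharded_axis, reduce_fn=np.sum, keepdims=True):
    """Per-device view of DistributedReduceSum/Mean/Max (reduce_fn = np.sum / np.mean / np.max)."""
    # Case 1: empty axes acts as an identity copy (input and output specs must match).
    if len(axes) == 0:
        return local_shard.copy()
    # Case 2-1: a reduced axis is sharded; resharding would be required, so the kernel throws.
    if sharded_axis is not None and sharded_axis in axes:
        raise NotImplementedError("Resharding is required to make reduced axes replica.")
    # Case 2-2: reduced axes are replicated, so each device reduces its own shard locally.
    return reduce_fn(local_shard, axis=tuple(axes), keepdims=keepdims)
```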
59 changes: 59 additions & 0 deletions onnxruntime/contrib_ops/cuda/collective/distributed_reduce.h
@@ -0,0 +1,59 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

#include "sharding_spec.h"
#include "sharding.h"
#include "core/providers/cuda/cuda_kernel.h"

#include <algorithm>
#include <tuple>
#include <optional>
#include <string>
#include <nccl.h>
#include <sstream>

#pragma once

namespace onnxruntime {
namespace contrib {
namespace cuda {

#if defined(ORT_USE_NCCL)

template <typename T>
class DistributedReduceBase : public DistributedKernel {
 public:
  explicit DistributedReduceBase(const OpKernelInfo& info, cudnnReduceTensorOp_t cudnn_reduce_op);

  Status ComputeInternal(OpKernelContext* context) const override;

 private:
  // ONNX attribute. If true, reduced axes are retained as dimensions with size one.
  // Otherwise, drop reduced axes.
  bool keepdims_;
  cudnnReduceTensorOp_t cudnn_reduce_op_;
};

template <typename T>
class DistributedReduceSum final : public DistributedReduceBase<T> {
 public:
  explicit DistributedReduceSum(const OpKernelInfo& info);
};

template <typename T>
class DistributedReduceMean final : public DistributedReduceBase<T> {
 public:
  explicit DistributedReduceMean(const OpKernelInfo& info);
};

template <typename T>
class DistributedReduceMax final : public DistributedReduceBase<T> {
 public:
  explicit DistributedReduceMax(const OpKernelInfo& info);
};

#endif

} // namespace cuda
} // namespace contrib
} // namespace onnxruntime
18 changes: 18 additions & 0 deletions onnxruntime/contrib_ops/cuda/cuda_contrib_kernels.cc
@@ -175,6 +175,15 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, int64_t, DistributedExpand);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, float, DistributedExpand);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, MLFloat16, DistributedExpand);

class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, float, DistributedReduceSum);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, MLFloat16, DistributedReduceSum);

class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, float, DistributedReduceMax);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, MLFloat16, DistributedReduceMax);

class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, float, DistributedReduceMean);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, MLFloat16, DistributedReduceMean);
#endif

template <>
@@ -354,6 +363,15 @@ Status RegisterCudaContribKernels(KernelRegistry& kernel_registry) {
      BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, int64_t, DistributedExpand)>,
      BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, float, DistributedExpand)>,
      BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, MLFloat16, DistributedExpand)>,

      BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, float, DistributedReduceSum)>,
      BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, MLFloat16, DistributedReduceSum)>,

      BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, float, DistributedReduceMax)>,
      BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, MLFloat16, DistributedReduceMax)>,

      BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, float, DistributedReduceMean)>,
      BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, MLFloat16, DistributedReduceMean)>,
#endif

};