Add implementation of WebGPU EP #22591

Status: Open. Wants to merge 14 commits into base: main.

Changes from 10 commits
4 changes: 4 additions & 0 deletions cmake/CMakeLists.txt
@@ -148,6 +148,7 @@ option(onnxruntime_TVM_USE_HASH "Build ipp-crypto library for support hash algor
option(onnxruntime_USE_XNNPACK "Build with XNNPACK support. Provides an alternative math library on ARM, WebAssembly and x86." OFF)
option(onnxruntime_USE_WEBNN "Build with WebNN support. Enable hardware acceleration in web browsers." OFF)
option(onnxruntime_USE_WEBGPU "Build with WebGPU support. Enable WebGPU via C/C++ interface." OFF)
option(onnxruntime_USE_EXTERNAL_DAWN "Build with treating Dawn as external dependency. Will not link Dawn at build time." OFF)

# Options related to reducing the binary size produced by the build
# XNNPACK EP requires the internal NHWC contrib ops to be available, so this option must be OFF when onnxruntime_USE_XNNPACK is ON
@@ -948,6 +949,9 @@ if (onnxruntime_USE_WEBGPU)
  list(APPEND ORT_PROVIDER_FLAGS -DUSE_WEBGPU=1)
  list(APPEND ORT_PROVIDER_CMAKE_FLAGS -Donnxruntime_USE_WEBGPU=1)
  list(APPEND ONNXRUNTIME_PROVIDER_NAMES webgpu)
  if (onnxruntime_USE_EXTERNAL_DAWN)
    list(APPEND ORT_PROVIDER_FLAGS -DUSE_EXTERNAL_DAWN=1)
  endif()
endif()
if (onnxruntime_USE_CANN)
  list(APPEND ORT_PROVIDER_FLAGS -DUSE_CANN=1)
7 changes: 6 additions & 1 deletion cmake/external/onnxruntime_external_deps.cmake
@@ -656,11 +656,16 @@ if (onnxruntime_USE_WEBGPU)

# Vulkan may optionally be included in a Windows build. Exclude until we have an explicit use case that requires it.
set(DAWN_ENABLE_VULKAN OFF CACHE BOOL "" FORCE)
# We are currently always using the D3D12 backend.
set(DAWN_ENABLE_D3D11 OFF CACHE BOOL "" FORCE)
endif()

onnxruntime_fetchcontent_makeavailable(dawn)

list(APPEND onnxruntime_EXTERNAL_LIBRARIES dawn::dawn_native dawn::dawn_proc)
if (NOT onnxruntime_USE_EXTERNAL_DAWN)
list(APPEND onnxruntime_EXTERNAL_LIBRARIES dawn::dawn_native)
endif()
list(APPEND onnxruntime_EXTERNAL_LIBRARIES dawn::dawn_proc)
endif()

set(onnxruntime_LINK_DIRS)
5 changes: 4 additions & 1 deletion cmake/onnxruntime_providers_webgpu.cmake
@@ -22,6 +22,9 @@
onnxruntime_add_static_library(onnxruntime_providers_webgpu ${onnxruntime_providers_webgpu_cc_srcs})
onnxruntime_add_include_to_target(onnxruntime_providers_webgpu
onnxruntime_common dawn::dawncpp_headers dawn::dawn_headers onnx onnx_proto flatbuffers::flatbuffers Boost::mp11 safeint_interface)
target_link_libraries(onnxruntime_providers_webgpu dawn::dawn_native dawn::dawn_proc)
if (NOT onnxruntime_USE_EXTERNAL_DAWN)
target_link_libraries(onnxruntime_providers_webgpu dawn::dawn_native)
endif()
target_link_libraries(onnxruntime_providers_webgpu dawn::dawn_proc)

set_target_properties(onnxruntime_providers_webgpu PROPERTIES FOLDER "ONNXRuntime")
12 changes: 12 additions & 0 deletions cmake/onnxruntime_unittests.cmake
@@ -523,6 +523,9 @@ set (onnxruntime_global_thread_pools_test_SRC
${ONNXRUNTIME_GLOBAL_THREAD_POOLS_TEST_SRC_DIR}/test_main.cc
${ONNXRUNTIME_GLOBAL_THREAD_POOLS_TEST_SRC_DIR}/test_inference.cc)

set (onnxruntime_webgpu_external_dawn_test_SRC
${TEST_SRC_DIR}/webgpu/external_dawn/main.cc)

# tests from lowest level library up.
# the order of libraries should be maintained, with higher libraries being added first in the list

@@ -1884,4 +1887,13 @@ if (NOT onnxruntime_MINIMAL_BUILD AND NOT onnxruntime_EXTENDED_MINIMAL_BUILD
endif()
endif()

if (onnxruntime_USE_WEBGPU AND onnxruntime_USE_EXTERNAL_DAWN)
  AddTest(TARGET onnxruntime_webgpu_external_dawn_test
          SOURCES ${onnxruntime_webgpu_external_dawn_test_SRC}
          LIBS dawn::dawn_native ${onnxruntime_test_providers_libs}
          DEPENDS ${all_dependencies}
  )
  onnxruntime_add_include_to_target(onnxruntime_webgpu_external_dawn_test dawn::dawncpp_headers dawn::dawn_headers)
endif()

include(onnxruntime_fuzz_test.cmake)
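Note on the new onnxruntime_webgpu_external_dawn_test target: it covers the "external Dawn" configuration, in which the embedding application rather than onnxruntime links dawn_native and supplies the WebGPU proc table. Below is a minimal sketch of what such a test entry point could look like; dawnProcSetProcs and dawn::native::GetProcs are standard Dawn APIs, but how the PR's actual main.cc hands the procs to the WebGPU EP is an assumption here, not taken from this change.

// Hypothetical sketch of an "external Dawn" test entry point (not the PR's main.cc).
// Assumption: the WebGPU EP uses the process-wide proc table installed below.
#include "dawn/dawn_proc.h"
#include "dawn/native/DawnNative.h"

int main() {
  // The application, not onnxruntime, links dawn_native and owns the instance.
  dawn::native::Instance instance;

  // Install Dawn's native proc table for all WebGPU C API calls in this process.
  dawnProcSetProcs(&dawn::native::GetProcs());

  // ... create an ORT session with the WebGPU EP enabled and run a model ...

  // Clear the proc table before the instance is destroyed.
  dawnProcSetProcs(nullptr);
  return 0;
}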
45 changes: 30 additions & 15 deletions cmake/patches/dawn/dawn.patch
@@ -15,40 +15,55 @@ index 9c0bd6fa4e..bf8a57aeac 100644
###############################################################################
# Do the 'complete_lib' build.
diff --git a/src/dawn/native/Surface_metal.mm b/src/dawn/native/Surface_metal.mm
index ce55acbd43..baa4835362 100644
index ce55acbd43..2cfd363479 100644
--- a/src/dawn/native/Surface_metal.mm
+++ b/src/dawn/native/Surface_metal.mm
@@ -36,7 +36,13 @@
@@ -33,10 +33,18 @@

#import <QuartzCore/CAMetalLayer.h>

+#include "dawn/common/Platform.h"
+
namespace dawn::native {

bool InheritsFromCAMetalLayer(void* obj) {
- id<NSObject> object = static_cast<id>(obj);
+ id<NSObject> object =
+#if TARGET_OS_IOS
+#if DAWN_PLATFORM_IS(IOS)
+ (__bridge id)obj;
+#else
+#else // DAWN_PLATFORM_IS(IOS)
+ static_cast<id>(obj);
+#endif
+#endif // DAWN_PLATFORM_IS(IOS)
+
return [object isKindOfClass:[CAMetalLayer class]];
}

diff --git a/src/dawn/native/metal/SharedFenceMTL.mm b/src/dawn/native/metal/SharedFenceMTL.mm
index bde8bfea07..f2f6459e91 100644
index bde8bfea07..8906185d6f 100644
--- a/src/dawn/native/metal/SharedFenceMTL.mm
+++ b/src/dawn/native/metal/SharedFenceMTL.mm
@@ -40,7 +40,13 @@ ResultOrError<Ref<SharedFence>> SharedFence::Create(
@@ -25,6 +25,8 @@
// OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

+#include "dawn/common/Platform.h"
+
#include "dawn/native/metal/SharedFenceMTL.h"

#include "dawn/native/ChainUtils.h"
@@ -39,8 +41,13 @@ ResultOrError<Ref<SharedFence>> SharedFence::Create(
const SharedFenceMTLSharedEventDescriptor* descriptor) {
DAWN_INVALID_IF(descriptor->sharedEvent == nullptr, "MTLSharedEvent is missing.");
if (@available(macOS 10.14, iOS 12.0, *)) {
return AcquireRef(new SharedFence(
- return AcquireRef(new SharedFence(
- device, label, static_cast<id<MTLSharedEvent>>(descriptor->sharedEvent)));
+ device, label,
+#if TARGET_OS_IOS
+ (__bridge id<MTLSharedEvent>)(descriptor->sharedEvent)
+#else
+ static_cast<id<MTLSharedEvent>>(descriptor->sharedEvent)
+#endif
+ ));
+ return AcquireRef(new SharedFence(device, label,
+#if DAWN_PLATFORM_IS(IOS)
+ (__bridge id<MTLSharedEvent>)(descriptor->sharedEvent)
+#else // DAWN_PLATFORM_IS(IOS)
+ static_cast<id<MTLSharedEvent>>(descriptor->sharedEvent)
+#endif // DAWN_PLATFORM_IS(IOS)
+ ));
} else {
return DAWN_INTERNAL_ERROR("MTLSharedEvent not supported.");
}
84 changes: 84 additions & 0 deletions onnxruntime/contrib_ops/webgpu/bert/fast_gelu.cc
@@ -0,0 +1,84 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

#include "core/providers/webgpu/shader_helper.h"
#include "core/providers/webgpu/webgpu_supported_types.h"
#include "core/providers/webgpu/math/unary_elementwise_ops.h"
#include "contrib_ops/webgpu/bert/fast_gelu.h"
#include "contrib_ops/webgpu/webgpu_contrib_kernels.h"

namespace onnxruntime {
namespace contrib {
namespace webgpu {

ONNX_OPERATOR_KERNEL_EX(
    FastGelu,
    kMSDomain,
    1,
    kWebGpuExecutionProvider,
    (*KernelDefBuilder::Create())
        .TypeConstraint("T", WebGpuSupportedFloatTypes()),
    FastGelu);

Status FastGeluProgram::GenerateShaderCode(ShaderHelper& shader) const {
  const auto& x = shader.AddInput("x", ShaderUsage::UseUniform | ShaderUsage::UseValueTypeAlias);
  const auto& y = shader.AddOutput("y", ShaderUsage::UseUniform);

  shader.AdditionalImplementation() << TanhImpl;
  shader.MainFunctionBody() << shader.GuardAgainstOutOfBoundsWorkgroupSizes("uniforms.vec_size")
                            << " var a = " << x.GetByOffset("global_idx") << ";\n";
  if (Inputs().size() > 1) {
    const auto& bias = shader.AddInput("bias", ShaderUsage::UseUniform | ShaderUsage::UseShapeAndStride);
    if (bias_components_ == 1) {
      shader.MainFunctionBody() << " let bias_offset = global_idx * 4;\n"
                                   " a += x_value_t("
                                << bias.GetByOffset("bias_offset % uniforms.bias_shape") << ", "
                                << bias.GetByOffset("(bias_offset + 1) % uniforms.bias_shape") << ", "
                                << bias.GetByOffset("(bias_offset + 2) % uniforms.bias_shape") << ", "
                                << bias.GetByOffset("(bias_offset + 3) % uniforms.bias_shape") << ");\n";
    } else {
      shader.MainFunctionBody() << " a += " << bias.GetByOffset("global_idx % uniforms.bias_shape") + ";\n";
    }
  }
  shader.MainFunctionBody() << y.SetByOffset("global_idx", onnxruntime::webgpu::FastGeluExpr);

  return Status::OK();
}
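The final expression written to the output, onnxruntime::webgpu::FastGeluExpr, comes from the core WebGPU unary-elementwise ops included at the top of this file. For reference, a scalar sketch of the tanh-based FastGelu approximation that expression is expected to evaluate per element (reference math only, not this PR's WGSL; the constants are the usual FastGelu values and are stated here as an assumption):

// Scalar reference for the FastGelu approximation (assumed form of FastGeluExpr).
#include <cmath>

inline float FastGeluRef(float x) {
  const float kAlpha = 0.7978845608028654f;  // sqrt(2 / pi)
  const float kGamma = 0.044715f;
  return 0.5f * x * (1.0f + std::tanh(kAlpha * (x + kGamma * x * x * x)));
}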

Status FastGelu::ComputeInternal(onnxruntime::webgpu::ComputeContext& context) const {
  const auto* input = context.Input(0);
  const auto* bias = context.Input(1);
  auto* output = context.Output(0, input->Shape());

  uint32_t data_size = gsl::narrow<uint32_t>(output->Shape().Size());
  if (data_size == 0) {
    return Status::OK();
  }

  const auto vec_size = (data_size + 3) / 4;
  uint32_t bias_size = 0;
  int bias_components = 1;

  if (bias != nullptr) {
    bias_size = gsl::narrow<uint32_t>(bias->Shape().Size());
    if (bias_size % 4 == 0) {
      bias_components = 4;
      bias_size = bias_size / 4;
    }
  }

  FastGeluProgram program{bias_components};
  program.AddInput({input, ProgramTensorMetadataDependency::Type, {vec_size}, 4})
      .AddOutput({output, ProgramTensorMetadataDependency::None, {vec_size}, 4})
      .SetDispatchGroupSize((vec_size + WORKGROUP_SIZE - 1) / WORKGROUP_SIZE)
      .AddUniformVariable({vec_size});

  if (bias != nullptr) {
    program.AddInput({bias, ProgramTensorMetadataDependency::TypeAndRank, {bias_size}, bias_components});
  }
  return context.RunProgram(program);
}

} // namespace webgpu
} // namespace contrib
} // namespace onnxruntime
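A worked example of the vec4 packing in ComputeInternal above, with illustrative numbers that are not taken from the PR (the WORKGROUP_SIZE value is an assumption):

// Standalone check of the packing arithmetic used by FastGelu::ComputeInternal.
#include <cassert>
#include <cstdint>

int main() {
  constexpr uint32_t kWorkgroupSize = 64;  // assumption: WORKGROUP_SIZE in the EP
  uint32_t data_size = 1000;                                             // output elements
  uint32_t vec_size = (data_size + 3) / 4;                               // 250 vec4 values
  uint32_t dispatch = (vec_size + kWorkgroupSize - 1) / kWorkgroupSize;  // 4 workgroups

  uint32_t bias_size = 3072;                           // divisible by 4, so packed as vec4
  int bias_components = (bias_size % 4 == 0) ? 4 : 1;  // 4
  if (bias_components == 4) bias_size /= 4;            // 768 packed bias elements

  assert(vec_size == 250 && dispatch == 4 && bias_size == 768);
  return 0;
}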
38 changes: 38 additions & 0 deletions onnxruntime/contrib_ops/webgpu/bert/fast_gelu.h
@@ -0,0 +1,38 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

#pragma once

#include "core/providers/webgpu/program.h"
#include "core/providers/webgpu/webgpu_kernel.h"

namespace onnxruntime {
namespace contrib {
namespace webgpu {

using namespace onnxruntime::webgpu;

using onnxruntime::webgpu::ComputeContext;

class FastGeluProgram final : public Program<FastGeluProgram> {
 public:
  FastGeluProgram(int bias_components) : Program{"FastGelu"}, bias_components_{bias_components} {
  }

  Status GenerateShaderCode(ShaderHelper& sh) const override;

  WEBGPU_PROGRAM_DEFINE_UNIFORM_VARIABLES({"vec_size", ProgramUniformVariableDataType::Uint32});

 private:
  int bias_components_;
};

class FastGelu final : public WebGpuKernel {
 public:
  FastGelu(const OpKernelInfo& info) : WebGpuKernel(info) {}

  Status ComputeInternal(ComputeContext& context) const override;
};

} // namespace webgpu
} // namespace contrib
} // namespace onnxruntime
36 changes: 36 additions & 0 deletions onnxruntime/contrib_ops/webgpu/bert/layer_norm.cc
@@ -0,0 +1,36 @@

// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

#include "core/providers/webgpu/shader_helper.h"
#include "core/providers/webgpu/webgpu_supported_types.h"
#include "core/providers/webgpu/nn/layer_norm.h"
#include "contrib_ops/webgpu/webgpu_contrib_kernels.h"

namespace onnxruntime {
namespace contrib {
namespace webgpu {

using namespace onnxruntime::webgpu;

using onnxruntime::webgpu::ComputeContext;

ONNX_OPERATOR_VERSIONED_KERNEL_EX(
    LayerNormalization,
    kOnnxDomain,
    1,
    16,
    kWebGpuExecutionProvider,
    (*KernelDefBuilder::Create()).TypeConstraint("T", WebGpuSupportedFloatTypes()),
    onnxruntime::webgpu::LayerNorm<false>);

ONNX_OPERATOR_KERNEL_EX(
    SimplifiedLayerNormalization,
    kOnnxDomain,
    1,
    kWebGpuExecutionProvider,
    (*KernelDefBuilder::Create()).TypeConstraint("T", WebGpuSupportedFloatTypes()),
    onnxruntime::webgpu::LayerNorm<true>);

} // namespace webgpu
} // namespace contrib
} // namespace onnxruntime
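For reference, both registrations above share one implementation, onnxruntime::webgpu::LayerNorm<simplified>: LayerNormalization (the <false> instantiation) subtracts the mean and divides by the standard deviation, while SimplifiedLayerNormalization (<true>) skips the mean, i.e. RMS-style normalization. A scalar sketch of that reference math for a single normalized row follows (reference semantics only, not the EP's WGSL; the helper name and signature are made up for illustration):

// Reference semantics for one normalized row; 'simplified' corresponds to LayerNorm<true>.
#include <cmath>
#include <cstddef>

void LayerNormRowRef(const float* x, const float* scale, const float* bias,
                     float* y, size_t n, float epsilon, bool simplified) {
  float mean = 0.0f;
  if (!simplified) {
    for (size_t i = 0; i < n; ++i) mean += x[i];
    mean /= static_cast<float>(n);
  }
  float var = 0.0f;  // mean squared deviation; mean of squares when simplified
  for (size_t i = 0; i < n; ++i) {
    const float d = x[i] - mean;
    var += d * d;
  }
  var /= static_cast<float>(n);
  const float inv_std = 1.0f / std::sqrt(var + epsilon);
  for (size_t i = 0; i < n; ++i) {
    y[i] = (x[i] - mean) * inv_std * scale[i] + (bias != nullptr ? bias[i] : 0.0f);
  }
}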