Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[JS/WegGPU] Initial changes to support wasm64. #21260

Closed
wants to merge 51 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
51 commits
Select commit Hold shift + click to select a range
60cd3db
Initial changes to support wasm64.
satyajandhyala Jul 9, 2024
95d1707
Lint
satyajandhyala Jul 10, 2024
b8d5db7
Add missing files.
satyajandhyala Jul 10, 2024
4726758
Revert changes to conv3d_native_webgpu.ts
satyajandhyala Jul 10, 2024
7001f39
Set ASYNCIFY=1 with MEMORY64.
satyajandhyala Jul 11, 2024
29fc6c0
Added Chrome Canary as a browser option.
satyajandhyala Jul 12, 2024
a6357e5
Use wasm.PTR_SIZE instead of hardcoding to 4.
satyajandhyala Jul 12, 2024
37a6d15
Modified SIGNATURE_CONVERSION for OrtCreateTensor.
satyajandhyala Jul 12, 2024
acc7e5a
Fix data type.
satyajandhyala Jul 12, 2024
3b46edd
Use etValue/etValue instead of directly accessing heap.
satyajandhyala Jul 13, 2024
edcaa64
Use uintptr_t instead of uint32_t.
satyajandhyala Jul 13, 2024
05d0426
Removed WASM_MEMORY64 macro
satyajandhyala Jul 13, 2024
af8a685
Clean up
satyajandhyala Jul 13, 2024
757229d
Fix OrtRun integer arguments type.
satyajandhyala Jul 14, 2024
343f812
Added Number type conversions.
satyajandhyala Jul 16, 2024
7862a8f
Number type conversion.
satyajandhyala Jul 16, 2024
0b4b040
Added ASYNCIFY_IMPORT, and signature convertions.
satyajandhyala Jul 17, 2024
786b58c
Removed unused settings.
satyajandhyala Jul 17, 2024
4523acc
Added missing function in SIGNATURE_CONVERSIONS.
satyajandhyala Jul 17, 2024
4d563c3
clean-up
satyajandhyala Jul 18, 2024
5f504c5
Miscellaneous edits.
satyajandhyala Jul 18, 2024
5e07c97
Use Number cast to jsepRunKernel
satyajandhyala Jul 19, 2024
c8b7d20
Use uint32_t instead of size_t.
satyajandhyala Jul 26, 2024
a31c5de
Revert unnecessary compiler flags.
satyajandhyala Jul 26, 2024
bfbb2d6
Fixed SIGNATURE_CONVERSIONS
satyajandhyala Jul 26, 2024
603013c
minor change
satyajandhyala Jul 26, 2024
ab91914
lint
satyajandhyala Jul 26, 2024
a4fad86
Revert changes to gemm.h
satyajandhyala Jul 26, 2024
11bbf26
Keep static assertion guarded by ifdef.
satyajandhyala Jul 26, 2024
939a740
Switch back to using size_t instead of uint32_t
satyajandhyala Jul 29, 2024
b3aaea9
Specify setValue/getValue type argument 'i32' or 'i64' based on wasm3…
satyajandhyala Jul 29, 2024
07d6e3e
Enable exception catching with wasm64
satyajandhyala Jul 29, 2024
b8bf40d
Convert dims to Number
satyajandhyala Jul 29, 2024
33e6f54
Make Ort api functions return
satyajandhyala Jul 29, 2024
b6604c6
Modified SIGNATURE_CONVERSIONS.
satyajandhyala Jul 29, 2024
3bf0347
Fixed ORT api functions return type consistently
satyajandhyala Jul 30, 2024
8b1be68
Skip jsepDownload
satyajandhyala Jul 30, 2024
3d78ac8
Merge branch 'main' of github.com:microsoft/onnxruntime into sajandhy…
satyajandhyala Jul 31, 2024
ea944ff
Remove unnecessary SIGNATURE_CONVERSIONS.
satyajandhyala Jul 31, 2024
523527e
Check return value
satyajandhyala Jul 31, 2024
cd5ed5c
Lint/format
satyajandhyala Aug 1, 2024
993273b
Merge branch 'main' of github.com:microsoft/onnxruntime into sajandhy…
satyajandhyala Aug 6, 2024
80f3f49
Revert "Remove unnecessary SIGNATURE_CONVERSIONS."
satyajandhyala Aug 7, 2024
26d5dda
Revert "Skip jsepDownload"
satyajandhyala Aug 12, 2024
9f1e6cb
Merge branch 'main' of github.com:microsoft/onnxruntime into sajandhy…
satyajandhyala Aug 14, 2024
e8bc234
User Number convertion.
satyajandhyala Aug 14, 2024
0b70e66
Merge branch 'main' of github.com:microsoft/onnxruntime into sajandhy…
satyajandhyala Aug 14, 2024
d07c3a4
Merge branch 'main' of github.com:microsoft/onnxruntime into sajandhy…
satyajandhyala Aug 15, 2024
bcfb312
Added Number conversion
satyajandhyala Aug 16, 2024
8a7ecdb
Formatting and other minor changes.
satyajandhyala Aug 16, 2024
ca37f7f
Merge branch 'main' into sajandhy/webgpu_support_64_bit_integer
satyajandhyala Aug 16, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions cmake/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -201,6 +201,7 @@ option(onnxruntime_WEBASSEMBLY_RUN_TESTS_IN_BROWSER "Enable this option to run t
option(onnxruntime_ENABLE_WEBASSEMBLY_DEBUG_INFO "Enable this option to turn on DWARF format debug info" OFF)
option(onnxruntime_ENABLE_WEBASSEMBLY_PROFILING "Enable this option to turn on WebAssembly profiling and preserve function names" OFF)
option(onnxruntime_ENABLE_WEBASSEMBLY_OUTPUT_OPTIMIZED_MODEL "Enable this option to allow WebAssembly to output optimized model" OFF)
option(onnxruntime_ENABLE_WEBASSEMBLY_MEMORY64 "Enable this option to allow WebAssembly to use 64bit memory" OFF)

# Enable bitcode for iOS
option(onnxruntime_ENABLE_BITCODE "Enable bitcode for iOS only" OFF)
Expand Down
5 changes: 5 additions & 0 deletions cmake/adjust_global_compile_flags.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,11 @@ if (CMAKE_SYSTEM_NAME STREQUAL "Emscripten")
string(APPEND CMAKE_CXX_FLAGS " -s DISABLE_EXCEPTION_CATCHING=0")
endif()

if (onnxruntime_ENABLE_WEBASSEMBLY_MEMORY64)
string(APPEND CMAKE_C_FLAGS " -DORT_WASM64")
string(APPEND CMAKE_CXX_FLAGS " -DORT_WASM64")
endif()

# Build WebAssembly with multi-threads support.
if (onnxruntime_ENABLE_WEBASSEMBLY_THREADS)
string(APPEND CMAKE_C_FLAGS " -pthread -Wno-pthreads-mem-growth")
Expand Down
144 changes: 141 additions & 3 deletions cmake/onnxruntime_webassembly.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -193,7 +193,7 @@
re2::re2
)

set(EXPORTED_RUNTIME_METHODS "'stackAlloc','stackRestore','stackSave','UTF8ToString','stringToUTF8','lengthBytesUTF8'")
set(EXPORTED_RUNTIME_METHODS "'stackAlloc','stackRestore','stackSave','UTF8ToString','stringToUTF8','lengthBytesUTF8','getValue','setValue'")

if (onnxruntime_USE_XNNPACK)
target_link_libraries(onnxruntime_webassembly PRIVATE XNNPACK)
Expand All @@ -215,10 +215,109 @@
set(EXPORTED_FUNCTIONS "_malloc,_free")
endif()

if (onnxruntime_ENABLE_WEBASSEMBLY_MEMORY64)
set(MAXIMUM_MEMORY "17179869184")
target_link_options(onnxruntime_webassembly PRIVATE
"SHELL:-s MEMORY64=1"
)
string(APPEND CMAKE_C_FLAGS " -sMEMORY64 -Wno-experimental")
string(APPEND CMAKE_CXX_FLAGS " -sMEMORY64 -Wno-experimental")
set(SMEMORY_FLAG "-sMEMORY64")

target_compile_options(onnx PRIVATE ${SMEMORY_FLAG} -Wno-experimental)
target_compile_options(onnxruntime_common PRIVATE ${SMEMORY_FLAG} -Wno-experimental)
target_compile_options(onnxruntime_session PRIVATE ${SMEMORY_FLAG} -Wno-experimental)
target_compile_options(onnxruntime_framework PRIVATE ${SMEMORY_FLAG} -Wno-experimental)
target_compile_options(nsync_cpp PRIVATE ${SMEMORY_FLAG} -Wno-experimental)
target_compile_options(onnx_proto PRIVATE ${SMEMORY_FLAG} -Wno-experimental)
# target_compile_options(protoc PRIVATE ${SMEMORY_FLAG} -Wno-experimental)
target_compile_options(libprotobuf-lite PRIVATE ${SMEMORY_FLAG} -Wno-experimental)
target_compile_options(onnxruntime_providers PRIVATE ${SMEMORY_FLAG} -Wno-experimental)
target_compile_options(onnxruntime_optimizer PRIVATE ${SMEMORY_FLAG} -Wno-experimental)
target_compile_options(onnxruntime_mlas PRIVATE ${SMEMORY_FLAG} -Wno-experimental)
target_compile_options(onnxruntime_optimizer PRIVATE ${SMEMORY_FLAG} -Wno-experimental)
target_compile_options(onnxruntime_graph PRIVATE ${SMEMORY_FLAG} -Wno-experimental)
target_compile_options(onnxruntime_flatbuffers PRIVATE ${SMEMORY_FLAG} -Wno-experimental)
target_compile_options(onnxruntime_util PRIVATE ${SMEMORY_FLAG} -Wno-experimental)
target_compile_options(re2 PRIVATE ${SMEMORY_FLAG} -Wno-experimental)
target_compile_options(absl_flags_private_handle_accessor PRIVATE ${SMEMORY_FLAG} -Wno-experimental)
target_compile_options(absl_flags_internal PRIVATE ${SMEMORY_FLAG} -Wno-experimental)
target_compile_options(absl_flags_commandlineflag PRIVATE ${SMEMORY_FLAG} -Wno-experimental)
target_compile_options(absl_flags_commandlineflag_internal PRIVATE ${SMEMORY_FLAG} -Wno-experimental)
target_compile_options(absl_flags_marshalling PRIVATE ${SMEMORY_FLAG} -Wno-experimental)

Check warning on line 247 in cmake/onnxruntime_webassembly.cmake

View workflow job for this annotation

GitHub Actions / Optional Lint

[misspell] reported by reviewdog 🐶 "marshalling" is a misspelling of "marshaling" Raw Output: ./cmake/onnxruntime_webassembly.cmake:247:38: "marshalling" is a misspelling of "marshaling"
target_compile_options(absl_flags_reflection PRIVATE ${SMEMORY_FLAG} -Wno-experimental)
target_compile_options(absl_flags_config PRIVATE ${SMEMORY_FLAG} -Wno-experimental)
target_compile_options(absl_flags_program_name PRIVATE ${SMEMORY_FLAG} -Wno-experimental)
target_compile_options(absl_cord PRIVATE ${SMEMORY_FLAG} -Wno-experimental)
target_compile_options(absl_cordz_info PRIVATE ${SMEMORY_FLAG} -Wno-experimental)
target_compile_options(absl_cord_internal PRIVATE ${SMEMORY_FLAG} -Wno-experimental)
target_compile_options(absl_cordz_functions PRIVATE ${SMEMORY_FLAG} -Wno-experimental)
target_compile_options(absl_cordz_handle PRIVATE ${SMEMORY_FLAG} -Wno-experimental)
target_compile_options(absl_crc_cord_state PRIVATE ${SMEMORY_FLAG} -Wno-experimental)
target_compile_options(absl_crc32c PRIVATE ${SMEMORY_FLAG} -Wno-experimental)
target_compile_options(absl_crc_internal PRIVATE ${SMEMORY_FLAG} -Wno-experimental)
target_compile_options(absl_crc_cpu_detect PRIVATE ${SMEMORY_FLAG} -Wno-experimental)
target_compile_options(absl_raw_hash_set PRIVATE ${SMEMORY_FLAG} -Wno-experimental)
target_compile_options(absl_hashtablez_sampler PRIVATE ${SMEMORY_FLAG} -Wno-experimental)
target_compile_options(absl_exponential_biased PRIVATE ${SMEMORY_FLAG} -Wno-experimental)
target_compile_options(absl_log_internal_conditions PRIVATE ${SMEMORY_FLAG} -Wno-experimental)
target_compile_options(absl_log_internal_check_op PRIVATE ${SMEMORY_FLAG} -Wno-experimental)
target_compile_options(absl_log_internal_message PRIVATE ${SMEMORY_FLAG} -Wno-experimental)
target_compile_options(absl_log_internal_format PRIVATE ${SMEMORY_FLAG} -Wno-experimental)
target_compile_options(absl_str_format_internal PRIVATE ${SMEMORY_FLAG} -Wno-experimental)
target_compile_options(absl_log_internal_log_sink_set PRIVATE ${SMEMORY_FLAG} -Wno-experimental)
target_compile_options(absl_log_internal_globals PRIVATE ${SMEMORY_FLAG} -Wno-experimental)
target_compile_options(absl_log_sink PRIVATE ${SMEMORY_FLAG} -Wno-experimental)
target_compile_options(absl_log_entry PRIVATE ${SMEMORY_FLAG} -Wno-experimental)
target_compile_options(absl_log_globals PRIVATE ${SMEMORY_FLAG} -Wno-experimental)
target_compile_options(absl_hash PRIVATE ${SMEMORY_FLAG} -Wno-experimental)
target_compile_options(absl_city PRIVATE ${SMEMORY_FLAG} -Wno-experimental)
target_compile_options(absl_low_level_hash PRIVATE ${SMEMORY_FLAG} -Wno-experimental)
target_compile_options(absl_bad_variant_access PRIVATE ${SMEMORY_FLAG} -Wno-experimental)
target_compile_options(absl_vlog_config_internal PRIVATE ${SMEMORY_FLAG} -Wno-experimental)
target_compile_options(absl_synchronization PRIVATE ${SMEMORY_FLAG} -Wno-experimental)
target_compile_options(absl_kernel_timeout_internal PRIVATE ${SMEMORY_FLAG} -Wno-experimental)
target_compile_options(absl_time PRIVATE ${SMEMORY_FLAG} -Wno-experimental)
target_compile_options(absl_time_zone PRIVATE ${SMEMORY_FLAG} -Wno-experimental)
target_compile_options(absl_civil_time PRIVATE ${SMEMORY_FLAG} -Wno-experimental)
target_compile_options(absl_graphcycles_internal PRIVATE ${SMEMORY_FLAG} -Wno-experimental)
target_compile_options(absl_bad_optional_access PRIVATE ${SMEMORY_FLAG} -Wno-experimental)
target_compile_options(absl_log_internal_fnmatch PRIVATE ${SMEMORY_FLAG} -Wno-experimental)
target_compile_options(absl_examine_stack PRIVATE ${SMEMORY_FLAG} -Wno-experimental)
target_compile_options(absl_symbolize PRIVATE ${SMEMORY_FLAG} -Wno-experimental)
target_compile_options(absl_malloc_internal PRIVATE ${SMEMORY_FLAG} -Wno-experimental)
target_compile_options(absl_demangle_internal PRIVATE ${SMEMORY_FLAG} -Wno-experimental)
target_compile_options(absl_demangle_rust PRIVATE ${SMEMORY_FLAG} -Wno-experimental)
target_compile_options(absl_decode_rust_punycode PRIVATE ${SMEMORY_FLAG} -Wno-experimental)
target_compile_options(absl_utf8_for_code_point PRIVATE ${SMEMORY_FLAG} -Wno-experimental)
target_compile_options(absl_stacktrace PRIVATE ${SMEMORY_FLAG} -Wno-experimental)
target_compile_options(absl_debugging_internal PRIVATE ${SMEMORY_FLAG} -Wno-experimental)
target_compile_options(absl_log_internal_proto PRIVATE ${SMEMORY_FLAG} -Wno-experimental)
target_compile_options(absl_strerror PRIVATE ${SMEMORY_FLAG} -Wno-experimental)
target_compile_options(absl_log_internal_nullguard PRIVATE ${SMEMORY_FLAG} -Wno-experimental)
target_compile_options(absl_strings PRIVATE ${SMEMORY_FLAG} -Wno-experimental)
target_compile_options(absl_strings_internal PRIVATE ${SMEMORY_FLAG} -Wno-experimental)
target_compile_options(absl_int128 PRIVATE ${SMEMORY_FLAG} -Wno-experimental)
target_compile_options(absl_string_view PRIVATE ${SMEMORY_FLAG} -Wno-experimental)
target_compile_options(absl_base PRIVATE ${SMEMORY_FLAG} -Wno-experimental)
target_compile_options(absl_spinlock_wait PRIVATE ${SMEMORY_FLAG} -Wno-experimental)
target_compile_options(absl_throw_delegate PRIVATE ${SMEMORY_FLAG} -Wno-experimental)
target_compile_options(absl_raw_logging_internal PRIVATE ${SMEMORY_FLAG} -Wno-experimental)
target_compile_options(absl_log_severity PRIVATE ${SMEMORY_FLAG} -Wno-experimental)
target_link_options(onnxruntime_webassembly PRIVATE
--post-js "${ONNXRUNTIME_ROOT}/wasm/js_post_js_64.js"
)
else ()
set(MAXIMUM_MEMORY "4294967296")
target_link_options(onnxruntime_webassembly PRIVATE
--post-js "${ONNXRUNTIME_ROOT}/wasm/js_post_js.js"
)
endif ()

target_link_options(onnxruntime_webassembly PRIVATE
"SHELL:-s EXPORTED_RUNTIME_METHODS=[${EXPORTED_RUNTIME_METHODS}]"
"SHELL:-s EXPORTED_FUNCTIONS=${EXPORTED_FUNCTIONS}"
"SHELL:-s MAXIMUM_MEMORY=4294967296"
"SHELL:-s MAXIMUM_MEMORY=${MAXIMUM_MEMORY}"
"SHELL:-s EXIT_RUNTIME=0"
"SHELL:-s ALLOW_MEMORY_GROWTH=1"
"SHELL:-s MODULARIZE=1"
Expand All @@ -231,6 +330,41 @@
--no-entry
"SHELL:--pre-js \"${ONNXRUNTIME_ROOT}/wasm/pre.js\""
)
if (onnxruntime_ENABLE_WEBASSEMBLY_MEMORY64)
set(SIGNATURE_CONVERSIONS "OrtRun:_pppppppp,\
OrtRunWithBinding:_ppppp,\
OrtGetTensorData:_ppppp,\
OrtCreateTensor:p_pppp_,\
OrtCreateSession:pppp,\
OrtReleaseSession:_p,\
OrtGetInputOutputCount:_ppp,\
OrtCreateSessionOptions:pp__p_ppppp,\
OrtReleaseSessionOptions:_p,\
OrtAppendExecutionProvider:_pp,\
OrtAddSessionConfigEntry:_ppp,\
OrtGetInputName:ppp,\
OrtGetOutputName:ppp,\
OrtCreateRunOptions:ppp_p,\
OrtReleaseRunOptions:_p,\
OrtReleaseTensor:_p,\
OrtFree:_p,\
OrtCreateBinding:_p,\
OrtBindInput:_ppp,\
OrtBindOutput:_ppp_,\
OrtClearBoundOutputs:_p,\
OrtReleaseBinding:_p,\
OrtGetLastError:_pp,\
JsepOutput:pp_p,\
JsepGetNodeName:pp,\
JsepOutput:pp_p,\
jsepCopy:_pp_,\
jsepCopyAsync:_pp_,\
jsepDownload:_pp_")
target_link_options(onnxruntime_webassembly PRIVATE
"SHELL:-s ERROR_ON_UNDEFINED_SYMBOLS=0"
"SHELL:-s SIGNATURE_CONVERSIONS='${SIGNATURE_CONVERSIONS}'"
)
endif ()
set_target_properties(onnxruntime_webassembly PROPERTIES LINK_DEPENDS ${ONNXRUNTIME_ROOT}/wasm/pre.js)

if (onnxruntime_USE_JSEP)
Expand All @@ -243,6 +377,8 @@
"SHELL:--pre-js \"${ONNXRUNTIME_ROOT}/wasm/pre-jsep.js\""
"SHELL:-s ASYNCIFY=1"
"SHELL:-s ASYNCIFY_STACK_SIZE=65536"
"SHELL:-s ASYNCIFY_EXPORTS=['OrtRun']"
"SHELL:-s ASYNCIFY_IMPORTS=['Module.jsepCopy','Module.jsepCopyAsync','jsepDownload']"
)
set_target_properties(onnxruntime_webassembly PROPERTIES LINK_DEPENDS ${ONNXRUNTIME_ROOT}/wasm/pre-jsep.js)
endif()
Expand Down Expand Up @@ -279,7 +415,9 @@
endif()

# Set link flag to enable exceptions support, this will override default disabling exception throwing behavior when disable exceptions.
target_link_options(onnxruntime_webassembly PRIVATE "SHELL:-s DISABLE_EXCEPTION_THROWING=0")
if (NOT onnxruntime_ENABLE_WEBASSEMBLY_MEMORY64)
target_link_options(onnxruntime_webassembly PRIVATE "SHELL:-s DISABLE_EXCEPTION_THROWING=0")
endif()

if (onnxruntime_ENABLE_WEBASSEMBLY_PROFILING)
target_link_options(onnxruntime_webassembly PRIVATE --profiling --profiling-funcs)
Expand Down
64 changes: 40 additions & 24 deletions js/web/lib/wasm/jsep/init.ts
Original file line number Diff line number Diff line change
Expand Up @@ -77,24 +77,25 @@ class ComputeContextImpl implements ComputeContext {
contextDataOffset: number,
) {
this.adapterInfo = backend.adapterInfo;
const heapU32 = module.HEAPU32;

// extract context data
let dataIndex = contextDataOffset >>> 2;
this.opKernelContext = heapU32[dataIndex++];
const inputCount = heapU32[dataIndex++];
this.outputCount = heapU32[dataIndex++];
this.customDataOffset = heapU32[dataIndex++];
this.customDataSize = heapU32[dataIndex++];
const ptrSize = module.PTR_SIZE;
let dataIndex = contextDataOffset / module.PTR_SIZE;
const type = ptrSize === 4 ? 'i32' : 'i64';
this.opKernelContext = Number(module.getValue(ptrSize * dataIndex++, type));
const inputCount = Number(module.getValue(ptrSize * dataIndex++, type));
this.outputCount = Number(module.getValue(ptrSize * dataIndex++, type));
this.customDataOffset = Number(module.getValue(ptrSize * dataIndex++, '*'));
this.customDataSize = Number(module.getValue(ptrSize * dataIndex++, type));

const inputs: TensorView[] = [];
for (let i = 0; i < inputCount; i++) {
const dataType = heapU32[dataIndex++];
const data = heapU32[dataIndex++];
const dim = heapU32[dataIndex++];
const dataType = Number(module.getValue(ptrSize * dataIndex++, type));
const data = Number(module.getValue(ptrSize * dataIndex++, '*'));
const dim = Number(module.getValue(ptrSize * dataIndex++, type));
const dims: number[] = [];
for (let d = 0; d < dim; d++) {
dims.push(heapU32[dataIndex++]);
dims.push(Number(module.getValue(ptrSize * dataIndex++, type)));
}
inputs.push(new TensorViewImpl(module, dataType, data, dims));
}
Expand Down Expand Up @@ -142,11 +143,12 @@ class ComputeContextImpl implements ComputeContext {
output(index: number, dims: readonly number[]): number {
const stack = this.module.stackSave();
try {
const data = this.module.stackAlloc((1 + dims.length) * 4 /* sizeof(size_t) */);
let offset = data >> 2;
this.module.HEAPU32[offset++] = dims.length;
const ptrSize = this.module.PTR_SIZE;
const type = ptrSize === 4 ? 'i32' : 'i64';
const data = this.module.stackAlloc((1 + dims.length) * ptrSize /* sizeof(size_t) */);
this.module.setValue(data, dims.length, type);
for (let i = 0; i < dims.length; i++) {
this.module.HEAPU32[offset++] = dims[i];
this.module.setValue(data + ptrSize * (i + 1), dims[i], type);
}
return this.module._JsepOutput!(this.opKernelContext, index, data);
} catch (e) {
Expand Down Expand Up @@ -213,12 +215,19 @@ export const init = async (
// jsepCopy(src, dst, size, isSourceGpu)
(src: number, dst: number, size: number, isSourceGpu = false) => {
if (isSourceGpu) {
LOG_DEBUG('verbose', () => `[WebGPU] jsepCopyGpuToGpu: src=${src}, dst=${dst}, size=${size}`);
backend.memcpy(src, dst);
LOG_DEBUG(
'verbose',
() => `[WebGPU] jsepCopyGpuToGpu: src=${Number(src)}, dst=${Number(dst)}, size=${Number(size)}`,
);
backend.memcpy(Number(src), Number(dst));
} else {
LOG_DEBUG('verbose', () => `[WebGPU] jsepCopyCpuToGpu: dataOffset=${src}, gpuDataId=${dst}, size=${size}`);
const data = module.HEAPU8.subarray(src >>> 0, (src >>> 0) + size);
backend.upload(dst, data);
LOG_DEBUG(
'verbose',
() =>
`[WebGPU] jsepCopyCpuToGpu: dataOffset=${Number(src)}, gpuDataId=${Number(dst)}, size=${Number(size)}`,
);
const data = module.HEAPU8.subarray(Number(src >>> 0), Number(src >>> 0) + Number(size));
backend.upload(Number(dst), data);
}
},

Expand All @@ -229,12 +238,19 @@ export const init = async (
() => `[WebGPU] jsepCopyGpuToCpu: gpuDataId=${gpuDataId}, dataOffset=${dataOffset}, size=${size}`,
);

await backend.download(gpuDataId, () => module.HEAPU8.subarray(dataOffset >>> 0, (dataOffset >>> 0) + size));
await backend.download(Number(gpuDataId), () =>
module.HEAPU8.subarray(Number(dataOffset) >>> 0, Number(dataOffset + size) >>> 0),
);
},

// jsepCreateKernel
(kernelType: string, kernelId: number, attribute: unknown) =>
backend.createKernel(kernelType, kernelId, attribute, module.UTF8ToString(module._JsepGetNodeName!(kernelId))),
backend.createKernel(
kernelType,
Number(kernelId),
attribute,
module.UTF8ToString(module._JsepGetNodeName!(Number(kernelId))),
),

// jsepReleaseKernel
(kernel: number) => backend.releaseKernel(kernel),
Expand All @@ -246,8 +262,8 @@ export const init = async (
() =>
`[WebGPU] jsepRun: sessionHandle=${sessionHandle}, kernel=${kernel}, contextDataOffset=${contextDataOffset}`,
);
const context = new ComputeContextImpl(module, backend, contextDataOffset);
return backend.computeKernel(kernel, context, errors);
const context = new ComputeContextImpl(module, backend, Number(contextDataOffset));
return backend.computeKernel(Number(kernel), context, errors);
},
// jsepCaptureBegin
() => backend.captureBegin(),
Expand Down
2 changes: 1 addition & 1 deletion js/web/lib/wasm/jsep/util.ts
Original file line number Diff line number Diff line change
Expand Up @@ -167,7 +167,7 @@ export class ShapeUtil {
'cannot get valid size from specified dimension range. Most likely the range contains negative values in them.',
);
}
size *= dims[i];
size *= Number(dims[i]);
}
return size;
}
Expand Down
13 changes: 6 additions & 7 deletions js/web/lib/wasm/jsep/webgpu/gpu-data-manager.ts
Original file line number Diff line number Diff line change
Expand Up @@ -112,7 +112,7 @@ const bucketArr: number[] = [];
/**
* normalize the buffer size so that it fits the 128-bits (16 bytes) alignment.
*/
const calcNormalizedBufferSize = (size: number) => Math.ceil(size / 16) * 16;
const calcNormalizedBufferSize = (size: number) => Math.ceil(Number(size) / 16) * 16;

/**
* calculate the buffer size so that it fits into buckets.
Expand Down Expand Up @@ -295,9 +295,7 @@ class GpuDataManagerImpl implements GpuDataManager {
LOG_DEBUG(
'verbose',
() =>
`[WebGPU] GpuDataManager.registerExternalBuffer(size=${originalSize}) => id=${
id
}, buffer is the same, skip.`,
`[WebGPU] GpuDataManager.registerExternalBuffer(size=${originalSize}) => id=${id}, buffer is the same, skip.`,
);
return id;
} else if (this.backend.capturedCommandList.has(this.backend.currentSessionId!)) {
Expand Down Expand Up @@ -358,7 +356,7 @@ class GpuDataManagerImpl implements GpuDataManager {
}

const gpuData = { id: createNewGpuDataId(), type: GpuDataType.default, buffer: gpuBuffer };
this.storageCache.set(gpuData.id, { gpuData, originalSize: size });
this.storageCache.set(gpuData.id, { gpuData, originalSize: Number(size) });

LOG_DEBUG('verbose', () => `[WebGPU] GpuDataManager.create(size=${size}) => id=${gpuData.id}`);
return gpuData;
Expand All @@ -368,7 +366,8 @@ class GpuDataManagerImpl implements GpuDataManager {
return this.storageCache.get(id)?.gpuData;
}

release(id: GpuDataId): number {
release(idInput: GpuDataId): number {
const id = typeof idInput === 'bigint' ? Number(idInput) : idInput;
const cachedData = this.storageCache.get(id);
if (!cachedData) {
throw new Error('releasing data does not exist');
Expand All @@ -384,7 +383,7 @@ class GpuDataManagerImpl implements GpuDataManager {
}

async download(id: GpuDataId, getTargetBuffer: () => Uint8Array): Promise<void> {
const cachedData = this.storageCache.get(id);
const cachedData = this.storageCache.get(Number(id));
if (!cachedData) {
throw new Error('data does not exist');
}
Expand Down
Loading
Loading