Skip to content

Commit

Permalink
Initial changes to support wasm64.
Browse files Browse the repository at this point in the history
  • Loading branch information
satyajandhyala committed Jul 10, 2024
1 parent 4c3c809 commit 3ea20e9
Show file tree
Hide file tree
Showing 25 changed files with 188 additions and 79 deletions.
3 changes: 2 additions & 1 deletion cmake/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -201,6 +201,7 @@ option(onnxruntime_WEBASSEMBLY_RUN_TESTS_IN_BROWSER "Enable this option to run t
option(onnxruntime_ENABLE_WEBASSEMBLY_DEBUG_INFO "Enable this option to turn on DWARF format debug info" OFF)
option(onnxruntime_ENABLE_WEBASSEMBLY_PROFILING "Enable this option to turn on WebAssembly profiling and preserve function names" OFF)
option(onnxruntime_ENABLE_WEBASSEMBLY_OUTPUT_OPTIMIZED_MODEL "Enable this option to allow WebAssembly to output optimized model" OFF)
option(onnxruntime_ENABLE_WEBASSEMBLY_MEMORY64 "Enable this option to allow WebAssembly to use 64bit memory" OFF)

# Enable bitcode for iOS
option(onnxruntime_ENABLE_BITCODE "Enable bitcode for iOS only" OFF)
Expand Down Expand Up @@ -241,7 +242,7 @@ option(onnxruntime_ENABLE_TRITON "Enable Triton" OFF)

# composable kernel is managed automatically, unless user want to explicitly disable it, it should not be manually set
option(onnxruntime_USE_COMPOSABLE_KERNEL "Enable composable kernel for ROCm EP" ON)
# ck_tile for the ROCm composable-kernel EP. The same cache variable was
# declared twice (stale pre-change cmake_dependent_option line left next to
# the replacement); keep only the plain option() form.
option(onnxruntime_USE_COMPOSABLE_KERNEL_CK_TILE "Enable ck_tile for composable kernel" ON)
option(onnxruntime_USE_ROCBLAS_EXTENSION_API "Enable rocblas tuning for ROCm EP" OFF)
option(onnxruntime_USE_TRITON_KERNEL "Enable triton compiled kernel" OFF)
option(onnxruntime_BUILD_KERNEL_EXPLORER "Build Kernel Explorer for testing and profiling GPU kernels" OFF)
Expand Down
15 changes: 13 additions & 2 deletions cmake/adjust_global_compile_flags.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -52,15 +52,26 @@ if (CMAKE_SYSTEM_NAME STREQUAL "Emscripten")
endif()

if (onnxruntime_ENABLE_WEBASSEMBLY_EXCEPTION_CATCHING)
  # Exception-catching compile flags. The two mechanisms are alternatives and
  # must not be combined; the stale unconditional appends of
  # DISABLE_EXCEPTION_CATCHING=0 that preceded this conditional (which would
  # also have applied to wasm64 builds alongside -fwasm-exceptions) were
  # removed.
  if (onnxruntime_ENABLE_WEBASSEMBLY_MEMORY64)
    # wasm64: use native WebAssembly exceptions.
    string(APPEND CMAKE_C_FLAGS " -fwasm-exceptions")
    string(APPEND CMAKE_CXX_FLAGS " -fwasm-exceptions")
  else()
    # wasm32: use Emscripten's JS-based exception catching.
    string(APPEND CMAKE_C_FLAGS " -s DISABLE_EXCEPTION_CATCHING=0")
    string(APPEND CMAKE_CXX_FLAGS " -s DISABLE_EXCEPTION_CATCHING=0")
  endif()
endif()

# Multi-threaded WebAssembly: compile with pthread support and silence
# Emscripten's warning about combining pthreads with memory growth.
if (onnxruntime_ENABLE_WEBASSEMBLY_THREADS)
  foreach(wasm_flags_var CMAKE_C_FLAGS CMAKE_CXX_FLAGS)
    string(APPEND ${wasm_flags_var} " -pthread -Wno-pthreads-mem-growth")
  endforeach()
endif()

# 64-bit WebAssembly (wasm64): compile with -sMEMORY64. The feature is
# flagged experimental by Emscripten, so suppress that warning.
if (onnxruntime_ENABLE_WEBASSEMBLY_MEMORY64)
  foreach(wasm_flags_var CMAKE_C_FLAGS CMAKE_CXX_FLAGS)
    string(APPEND ${wasm_flags_var} " -sMEMORY64 -Wno-experimental")
  endforeach()
endif()
endif()

if (onnxruntime_EXTERNAL_TRANSFORMER_SRC_PATH)
Expand Down
74 changes: 67 additions & 7 deletions cmake/onnxruntime_webassembly.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -168,9 +168,11 @@ else()
"${ONNXRUNTIME_ROOT}/wasm/api.cc"
"${ONNXRUNTIME_ROOT}/core/session/onnxruntime_c_api.cc"
)
# Per-source exception catching for the API entry points (wasm32 only; the
# wasm64 build handles exceptions via -fwasm-exceptions instead — see
# adjust_global_compile_flags.cmake). The stale unguarded copies of these
# three statements that preceded the guard were removed.
if (NOT onnxruntime_ENABLE_WEBASSEMBLY_MEMORY64)
  set (WASM_API_EXCEPTION_CATCHING "-s DISABLE_EXCEPTION_CATCHING=0")
  message(STATUS "onnxruntime_ENABLE_WEBASSEMBLY_EXCEPTION_CATCHING_ON_API set")
  set_source_files_properties(${onnxruntime_webassembly_src_exc} PROPERTIES COMPILE_FLAGS ${WASM_API_EXCEPTION_CATCHING})
endif()
endif()

target_link_libraries(onnxruntime_webassembly PRIVATE
Expand All @@ -193,7 +195,7 @@ else()
re2::re2
)

# Emscripten runtime helpers exported on the generated Module object.
# getValue/setValue let the JS side read/write pointer-sized values in WASM
# memory (used by the wasm64 bindings). The stale duplicate set() of this
# variable (without getValue/setValue) that preceded this line was removed.
set(EXPORTED_RUNTIME_METHODS "'stackAlloc','stackRestore','stackSave','UTF8ToString','stringToUTF8','lengthBytesUTF8','getValue','setValue'")

if (onnxruntime_USE_XNNPACK)
target_link_libraries(onnxruntime_webassembly PRIVATE XNNPACK)
Expand All @@ -215,10 +217,55 @@ else()
set(EXPORTED_FUNCTIONS "_malloc,_free")
endif()

if (onnxruntime_ENABLE_WEBASSEMBLY_MEMORY64)
  # wasm64 build: ASYNCIFY mode 2 and a 16 GiB (2^34) address-space cap.
  set(ASYNCIFY 2)
  set(MAXIMUM_MEMORY "17179869184")
  target_link_options(onnxruntime_webassembly PRIVATE
    "SHELL:-s MEMORY64=1"
  )
  string(APPEND CMAKE_C_FLAGS " -DWASM_MEMORY64 -sMEMORY64 -Wno-experimental")
  string(APPEND CMAKE_CXX_FLAGS " -DWASM_MEMORY64 -sMEMORY64 -Wno-experimental")
  set(SMEMORY_FLAG "-sMEMORY64")

  # Every statically-linked dependency must also be compiled with -sMEMORY64,
  # otherwise 32-bit and 64-bit pointer ABIs would be mixed at link time.
  # The original hand-written list named nsync_cpp and onnxruntime_optimizer
  # twice; duplicates removed here.
  # NOTE(review): protoc is deliberately excluded (it was commented out in
  # the original) — presumably because it runs on the host; confirm.
  set(onnxruntime_wasm64_compile_targets
    onnx
    onnx_proto
    libprotobuf-lite
    onnxruntime_common
    onnxruntime_session
    onnxruntime_framework
    onnxruntime_providers
    onnxruntime_optimizer
    onnxruntime_mlas
    onnxruntime_graph
    onnxruntime_flatbuffers
    onnxruntime_util
    nsync_cpp
    re2
    absl_base
    absl_hash
    absl_raw_hash_set
    absl_throw_delegate
    absl_city
    absl_low_level_hash
  )
  foreach(wasm64_target IN LISTS onnxruntime_wasm64_compile_targets)
    target_compile_options(${wasm64_target} PRIVATE ${SMEMORY_FLAG} -Wno-experimental)
  endforeach()

  target_link_options(onnxruntime_webassembly PRIVATE
    --post-js "${ONNXRUNTIME_ROOT}/wasm/js_post_js_64.js"
  )
else ()
  # wasm32 build: classic ASYNCIFY and the 4 GiB 32-bit address-space limit.
  set(ASYNCIFY 1)
  set(MAXIMUM_MEMORY "4294967296")
  target_link_options(onnxruntime_webassembly PRIVATE
    --post-js "${ONNXRUNTIME_ROOT}/wasm/js_post_js.js"
  )
endif ()

target_link_options(onnxruntime_webassembly PRIVATE
"SHELL:-s EXPORTED_RUNTIME_METHODS=[${EXPORTED_RUNTIME_METHODS}]"
"SHELL:-s EXPORTED_FUNCTIONS=${EXPORTED_FUNCTIONS}"
"SHELL:-s MAXIMUM_MEMORY=4294967296"
"SHELL:-s MAXIMUM_MEMORY=${MAXIMUM_MEMORY}"
"SHELL:-s EXIT_RUNTIME=0"
"SHELL:-s ALLOW_MEMORY_GROWTH=1"
"SHELL:-s MODULARIZE=1"
Expand All @@ -231,6 +278,12 @@ else()
--no-entry
"SHELL:--pre-js \"${ONNXRUNTIME_ROOT}/wasm/pre.js\""
)
if (onnxruntime_ENABLE_WEBASSEMBLY_MEMORY64)
# wasm64-only link settings.
# - ERROR_ON_UNDEFINED_SYMBOLS=0: NOTE(review): this suppresses link errors
#   for undefined symbols — presumably a workaround while the JS glue is
#   being ported to wasm64; confirm it is still required before release.
# - SIGNATURE_CONVERSIONS: declares which argument/return slots of each
#   exported function are pointers ('p') vs. plain i32 ('_'), so Emscripten's
#   JS glue converts between BigInt (i64 pointers under MEMORY64) and Number.
#   NOTE(review): OrtAddSessionConfigEntry appears twice in this list —
#   looks like an accidental duplicate; verify and drop one occurrence.
target_link_options(onnxruntime_webassembly PRIVATE
"SHELL:-s ERROR_ON_UNDEFINED_SYMBOLS=0"
"SHELL:-s SIGNATURE_CONVERSIONS=OrtRun:_pppppppp,OrtGetTensorData:_ppppp,OrtCreateTensor:p_pppp,OrtCreateSession:pppp,OrtReleaseSession:_p,OrtGetInputOutputCount:pppp,OrtCreateSessionOptions:pp__p_ppppp,OrtAddSessionConfigEntry:pppp,OrtReleaseSessionOptions:_p,OrtAppendExecutionProvider:ppp,OrtAddSessionConfigEntry:pppp,OrtGetInputName:ppp,OrtGetOutputName:ppp,OrtCreateRunOptions:ppp_p,OrtReleaseRunOptions:pp,OrtReleaseTensor:_p,OrtFree:_p,OrtGetLastError:_pp,JsepOutput:pp_p"
)
endif ()
set_target_properties(onnxruntime_webassembly PROPERTIES LINK_DEPENDS ${ONNXRUNTIME_ROOT}/wasm/pre.js)

if (onnxruntime_USE_JSEP)
Expand All @@ -241,8 +294,11 @@ else()
target_compile_definitions(onnxruntime_webassembly PRIVATE USE_JSEP=1)
target_link_options(onnxruntime_webassembly PRIVATE
  "SHELL:--pre-js \"${ONNXRUNTIME_ROOT}/wasm/pre-jsep.js\""
  # ASYNCIFY is 1 (classic) for wasm32 and 2 for wasm64; the value is set in
  # the MEMORY64 branch earlier in this file. The stale literal
  # "-s ASYNCIFY=1" line that sat next to this one (passing the setting twice
  # with conflicting values under wasm64) was removed.
  "SHELL:-s ASYNCIFY=${ASYNCIFY}"
  #"SHELL:-s JSPI"
  #"SHELL:-s ASYNCIFY_IGNORE_INDIRECT=1"
  "SHELL:-s ASYNCIFY_STACK_SIZE=65536"
  "SHELL:-s ASYNCIFY_EXPORTS=['OrtRun']"
)
set_target_properties(onnxruntime_webassembly PROPERTIES LINK_DEPENDS ${ONNXRUNTIME_ROOT}/wasm/pre-jsep.js)
endif()
Expand Down Expand Up @@ -279,7 +335,11 @@ else()
endif()

# Set link flag to enable exceptions support, this will override default disabling exception throwing behavior when disable exceptions.
# The stale unconditional DISABLE_EXCEPTION_THROWING=0 line that preceded
# this conditional (which would also have applied to wasm64 builds on top of
# -fwasm-exceptions) was removed.
if (onnxruntime_ENABLE_WEBASSEMBLY_MEMORY64)
  # wasm64: native WebAssembly exceptions, matching the compile flags.
  target_link_options(onnxruntime_webassembly PRIVATE "-fwasm-exceptions")
else()
  # wasm32: re-enable Emscripten's JS-based exception throwing.
  target_link_options(onnxruntime_webassembly PRIVATE "SHELL:-s DISABLE_EXCEPTION_THROWING=0")
endif()

if (onnxruntime_ENABLE_WEBASSEMBLY_PROFILING)
target_link_options(onnxruntime_webassembly PRIVATE --profiling --profiling-funcs)
Expand Down
8 changes: 4 additions & 4 deletions js/web/lib/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -23,11 +23,11 @@ if (!BUILD_DEFS.DISABLE_WASM) {
// Create the wasm backend (training variant when training support is built
// in) and register it under every backend name it serves. The stale
// pre-change registerBackend calls with the old priorities (5 and 10) that
// appeared alongside the new priority-1 calls were removed — each backend
// must be registered exactly once.
const wasmBackend = BUILD_DEFS.DISABLE_TRAINING ? require('./backend-wasm-inference').wasmBackend :
    require('./backend-wasm-training').wasmBackend;
if (!BUILD_DEFS.DISABLE_JSEP) {
  registerBackend('webgpu', wasmBackend, 1);
  registerBackend('webnn', wasmBackend, 1);
}
registerBackend('cpu', wasmBackend, 1);
registerBackend('wasm', wasmBackend, 1);
}

Object.defineProperty(env.versions, 'web', {value: version, enumerable: true});
3 changes: 2 additions & 1 deletion js/web/lib/wasm/jsep/backend-webgpu.ts
Original file line number Diff line number Diff line change
Expand Up @@ -219,6 +219,7 @@ export class WebGpuBackend {
maxComputeWorkgroupSizeX: adapter.limits.maxComputeWorkgroupSizeX,
maxComputeWorkgroupSizeY: adapter.limits.maxComputeWorkgroupSizeY,
maxComputeWorkgroupSizeZ: adapter.limits.maxComputeWorkgroupSizeZ,
maxBindingsPerBindGroup: adapter.limits.maxBindingsPerBindGroup,
},
requiredFeatures,
};
Expand Down Expand Up @@ -449,7 +450,7 @@ export class WebGpuBackend {
const isPersistent = validatedOutputIndices[i] === -2;
const tensorView = (isTemporary || isPersistent) ?
createIntermediateOutput(outputs[i].dataType, outputs[i].dims) :
createKernelOutput(validatedOutputIndices[i], outputs[i].dataType, outputs[i].dims);
createKernelOutput(outputs[i].outputIndex || validatedOutputIndices[i], outputs[i].dataType, outputs[i].dims);
outputTensorViews.push(tensorView);
// if tensor view data is 0, it means the output is zero-sized tensor, and there is no GPU data for it.
if (tensorView.data === 0) {
Expand Down
43 changes: 24 additions & 19 deletions js/web/lib/wasm/jsep/init.ts
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,8 @@

import {Env} from 'onnxruntime-common';

import type {OrtWasmModule} from '../wasm-types';
import {DataType, getTensorElementSize} from '../wasm-common';
import type {OrtWasmModule} from '../wasm-types';

import {WebGpuBackend} from './backend-webgpu';
import {LOG_DEBUG} from './log';
Expand Down Expand Up @@ -68,24 +68,24 @@ class ComputeContextImpl implements ComputeContext {
private customDataSize = 0;
constructor(private module: OrtWasmModule, private backend: WebGpuBackend, contextDataOffset: number) {
this.adapterInfo = backend.adapterInfo;
const heapU32 = module.HEAPU32;
const heap = module.PTR_SIZE === 4 ? module.HEAPU32 : module.HEAPU64;

// extract context data
let dataIndex = (contextDataOffset >>> 2);
this.opKernelContext = heapU32[dataIndex++];
const inputCount = heapU32[dataIndex++];
this.outputCount = heapU32[dataIndex++];
this.customDataOffset = heapU32[dataIndex++];
this.customDataSize = heapU32[dataIndex++];
let dataIndex = module.PTR_SIZE === 8 ? (contextDataOffset / 2 ** 3) : (contextDataOffset >> 2);
this.opKernelContext = Number(heap[dataIndex++]);
const inputCount = Number(heap[dataIndex++]);
this.outputCount = Number(heap[dataIndex++]);
this.customDataOffset = Number(heap[dataIndex++]);
this.customDataSize = Number(heap[dataIndex++]);

const inputs: TensorView[] = [];
for (let i = 0; i < inputCount; i++) {
const dataType = heapU32[dataIndex++];
const data = heapU32[dataIndex++];
const dim = heapU32[dataIndex++];
const dataType = Number(heap[dataIndex++]);
const data = Number(heap[dataIndex++]);
const dim = Number(heap[dataIndex++]);
const dims: number[] = [];
for (let d = 0; d < dim; d++) {
dims.push(heapU32[dataIndex++]);
dims.push(Number(heap[dataIndex++]));
}
inputs.push(new TensorViewImpl(module, dataType, data, dims));
}
Expand Down Expand Up @@ -127,11 +127,11 @@ class ComputeContextImpl implements ComputeContext {
output(index: number, dims: readonly number[]): number {
const stack = this.module.stackSave();
try {
const data = this.module.stackAlloc((1 + dims.length) * 4 /* sizeof(size_t) */);
let offset = data >> 2;
this.module.HEAPU32[offset++] = dims.length;
const ptrSize = this.module.PTR_SIZE;
const data = this.module.stackAlloc((1 + dims.length) * ptrSize /* sizeof(size_t) */);
this.module.setValue(data, dims.length, '*');
for (let i = 0; i < dims.length; i++) {
this.module.HEAPU32[offset++] = dims[i];
this.module.setValue(data + ptrSize * (i + 1), dims[i], '*');
}
return this.module._JsepOutput!(this.opKernelContext, index, data);
} catch (e) {
Expand Down Expand Up @@ -193,10 +193,15 @@ export const init =
// jsepCopy(src, dst, size, isSourceGpu)
(src: number, dst: number, size: number, isSourceGpu = false) => {
if (isSourceGpu) {
LOG_DEBUG('verbose', () => `[WebGPU] jsepCopyGpuToGpu: src=${src}, dst=${dst}, size=${size}`);
LOG_DEBUG(
'verbose',
() => `[WebGPU] jsepCopyGpuToGpu: src=${Number(src)}, dst=${Number(dst)}, size=${Number(size)}`);
backend.memcpy(src, dst);
} else {
LOG_DEBUG('verbose', () => `[WebGPU] jsepCopyCpuToGpu: dataOffset=${src}, gpuDataId=${dst}, size=${size}`);
LOG_DEBUG(
'verbose',
() => `[WebGPU] jsepCopyCpuToGpu: dataOffset=${Number(src)}, gpuDataId=${Number(dst)}, size=${
Number(size)}`);
const data = module.HEAPU8.subarray(src >>> 0, (src >>> 0) + size);
backend.upload(dst, data);
}
Expand Down Expand Up @@ -226,7 +231,7 @@ export const init =
'verbose',
() => `[WebGPU] jsepRun: sessionHandle=${sessionHandle}, kernel=${kernel}, contextDataOffset=${
contextDataOffset}`);
const context = new ComputeContextImpl(module, backend, contextDataOffset);
const context = new ComputeContextImpl(module, backend, Number(contextDataOffset));
return backend.computeKernel(kernel, context, errors);
},
// jsepCaptureBegin
Expand Down
7 changes: 4 additions & 3 deletions js/web/lib/wasm/jsep/webgpu/gpu-data-manager.ts
Original file line number Diff line number Diff line change
Expand Up @@ -112,7 +112,7 @@ const bucketArr: number[] = [];
/**
* normalize the buffer size so that it fits the 128-bits (16 bytes) alignment.
*/
// Round `size` up to the next multiple of 16 bytes (128-bit alignment).
// Number(size) coerces a possible BigInt (wasm64 builds pass i64 sizes)
// before the arithmetic. The stale duplicate `const` declaration of this
// function (without the Number() coercion) was removed — redeclaring a
// const is a compile error.
const calcNormalizedBufferSize = (size: number) => Math.ceil(Number(size) / 16) * 16;

/**
* calculate the buffer size so that it fits into buckets.
Expand Down Expand Up @@ -342,7 +342,7 @@ class GpuDataManagerImpl implements GpuDataManager {
}

const gpuData = {id: createNewGpuDataId(), type: GpuDataType.default, buffer: gpuBuffer};
this.storageCache.set(gpuData.id, {gpuData, originalSize: size});
this.storageCache.set(gpuData.id, {gpuData, originalSize: Number(size)});

LOG_DEBUG('verbose', () => `[WebGPU] GpuDataManager.create(size=${size}) => id=${gpuData.id}`);
return gpuData;
Expand All @@ -352,7 +352,8 @@ class GpuDataManagerImpl implements GpuDataManager {
return this.storageCache.get(id)?.gpuData;
}

release(id: GpuDataId): number {
release(idInput: GpuDataId): number {
const id = typeof idInput === 'bigint' ? Number(idInput) : idInput;
const cachedData = this.storageCache.get(id);
if (!cachedData) {
throw new Error('releasing data does not exist');
Expand Down
1 change: 1 addition & 0 deletions js/web/lib/wasm/jsep/webgpu/types.ts
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ export interface GpuData {
// Describes a tensor produced by a WebGPU program.
export interface TensorInfo {
// Tensor dimensions.
dims: readonly number[];
// Numeric element-type code. NOTE(review): presumably a DataType enum
// value from wasm-common — confirm.
dataType: number;
// Optional explicit kernel-output index; when present it takes precedence
// over the validated output index when creating the kernel output (see the
// `outputs[i].outputIndex || validatedOutputIndices[i]` use in
// backend-webgpu.ts).
outputIndex?: number;
}

export interface ProgramUniform {
Expand Down
5 changes: 5 additions & 0 deletions js/web/lib/wasm/wasm-types.ts
Original file line number Diff line number Diff line change
Expand Up @@ -209,10 +209,15 @@ export interface OrtTrainingAPIs {
*/
export interface OrtWasmModule extends EmscriptenModule, OrtInferenceAPIs, Partial<OrtTrainingAPIs>,
Partial<JSEP.Module> {
HEAP64: BigInt64Array;
HEAPU64: BigUint64Array;
PTR_SIZE: number;
// #region emscripten functions
stackSave(): number;
stackRestore(stack: number): void;
stackAlloc(size: number): number;
getValue(ptr: number, type: string): number;
setValue(ptr: number, value: number, type: string): void;

UTF8ToString(offset: number, maxBytesToRead?: number): string;
lengthBytesUTF8(str: string): number;
Expand Down
2 changes: 1 addition & 1 deletion onnxruntime/core/framework/tensorprotoutils.cc
Original file line number Diff line number Diff line change
Expand Up @@ -958,7 +958,7 @@ Status GetExtDataFromTensorProto(const Env& env, const std::filesystem::path& mo
try {
// Copy the file data (fileData,offset,length) into WebAssembly memory
// (HEAPU8,buffer,length).
HEAPU8.set(fileData.subarray(offset, offset + length), buffer);
HEAPU8.set(fileData.subarray(Number(offset), Number(offset) + length), buffer);
return 0;
} catch {
return 4;
Expand Down
2 changes: 1 addition & 1 deletion onnxruntime/core/graph/model.cc
Original file line number Diff line number Diff line change
Expand Up @@ -553,7 +553,7 @@ static Status SaveModel(Model& model, const T& file_path) {
const buffer_size = $1;
const file_path = UTF8ToString($2);
const bytes = new Uint8Array(buffer_size);
bytes.set(HEAPU8.subarray(buffer, buffer + buffer_size));
bytes.set(HEAPU8.subarray(Number(buffer), Number(buffer) + buffer_size));
if (typeof process == 'object' && typeof process.versions == 'object' && typeof process.versions.node == 'string') {
// Node.js
require('fs').writeFileSync(file_path, bytes);
Expand Down
Loading

0 comments on commit 3ea20e9

Please sign in to comment.