Merge branch 'main' into yifanl/trt_ver_update_1
yf711 committed Oct 26, 2024
2 parents 755b9cd + 008c909 commit 138245a
Showing 25 changed files with 152 additions and 81 deletions.
69 changes: 39 additions & 30 deletions cmake/CMakeLists.txt
@@ -291,12 +291,50 @@ if (onnxruntime_USE_ROCM)
message(FATAL_ERROR "ROCM does not support build with CUDA!")
endif()

# replicate strategy used by pytorch to get ROCM_VERSION
# https://github.com/pytorch/pytorch/blob/5c5b71b6eebae76d744261715231093e62f0d090/cmake/public/LoadHIP.cmake
# with modification
if (EXISTS "${onnxruntime_ROCM_HOME}/.info/version")
message("\n***** ROCm version from ${onnxruntime_ROCM_HOME}/.info/version ****\n")
file(READ "${onnxruntime_ROCM_HOME}/.info/version" ROCM_VERSION_DEV_RAW)
string(REGEX MATCH "^([0-9]+)\.([0-9]+)\.([0-9]+)-.*$" ROCM_VERSION_MATCH ${ROCM_VERSION_DEV_RAW})
elseif (EXISTS "${onnxruntime_ROCM_HOME}/include/rocm_version.h")
message("\n***** ROCm version from ${onnxruntime_ROCM_HOME}/include/rocm_version.h ****\n")
file(READ "${onnxruntime_ROCM_HOME}/include/rocm_version.h" ROCM_VERSION_H_RAW)
string(REGEX MATCH "\"([0-9]+)\.([0-9]+)\.([0-9]+).*\"" ROCM_VERSION_MATCH ${ROCM_VERSION_H_RAW})
elseif (EXISTS "${onnxruntime_ROCM_HOME}/include/rocm-core/rocm_version.h")
message("\n***** ROCm version from ${onnxruntime_ROCM_HOME}/include/rocm-core/rocm_version.h ****\n")
file(READ "${onnxruntime_ROCM_HOME}/include/rocm-core/rocm_version.h" ROCM_VERSION_H_RAW)
string(REGEX MATCH "\"([0-9]+)\.([0-9]+)\.([0-9]+).*\"" ROCM_VERSION_MATCH ${ROCM_VERSION_H_RAW})
endif()

if (ROCM_VERSION_MATCH)
set(ROCM_VERSION_DEV_MAJOR ${CMAKE_MATCH_1})
set(ROCM_VERSION_DEV_MINOR ${CMAKE_MATCH_2})
set(ROCM_VERSION_DEV_PATCH ${CMAKE_MATCH_3})
set(ROCM_VERSION_DEV "${ROCM_VERSION_DEV_MAJOR}.${ROCM_VERSION_DEV_MINOR}.${ROCM_VERSION_DEV_PATCH}")
math(EXPR ROCM_VERSION_DEV_INT "(${ROCM_VERSION_DEV_MAJOR}*10000) + (${ROCM_VERSION_DEV_MINOR}*100) + ${ROCM_VERSION_DEV_PATCH}")

message("ROCM_VERSION_DEV: ${ROCM_VERSION_DEV}")
message("ROCM_VERSION_DEV_MAJOR: ${ROCM_VERSION_DEV_MAJOR}")
message("ROCM_VERSION_DEV_MINOR: ${ROCM_VERSION_DEV_MINOR}")
message("ROCM_VERSION_DEV_PATCH: ${ROCM_VERSION_DEV_PATCH}")
message("ROCM_VERSION_DEV_INT: ${ROCM_VERSION_DEV_INT}")
else()
message(FATAL_ERROR "Cannot determine ROCm version string")
endif()


if (NOT CMAKE_HIP_COMPILER)
set(CMAKE_HIP_COMPILER "${onnxruntime_ROCM_HOME}/llvm/bin/clang++")
endif()

if (NOT CMAKE_HIP_ARCHITECTURES)
set(CMAKE_HIP_ARCHITECTURES "gfx908;gfx90a;gfx1030;gfx1100;gfx1101;gfx940;gfx941;gfx942;gfx1200;gfx1201")
if (ROCM_VERSION_DEV VERSION_LESS "6.2")
message(FATAL_ERROR "CMAKE_HIP_ARCHITECTURES is not set when ROCm version < 6.2")
else()
set(CMAKE_HIP_ARCHITECTURES "gfx908;gfx90a;gfx1030;gfx1100;gfx1101;gfx940;gfx941;gfx942;gfx1200;gfx1201")
endif()
endif()

file(GLOB rocm_cmake_components ${onnxruntime_ROCM_HOME}/lib/cmake/*)
@@ -328,35 +366,6 @@ if (onnxruntime_USE_ROCM)
set(onnxruntime_HIPIFY_PERL ${HIPIFY_PERL_PATH}/hipify-perl)
endif()

# replicate strategy used by pytorch to get ROCM_VERSION
# https://github.com/pytorch/pytorch/blob/5c5b71b6eebae76d744261715231093e62f0d090/cmake/public/LoadHIP.cmake
# with modification
if (EXISTS "${onnxruntime_ROCM_HOME}/.info/version")
file(READ "${onnxruntime_ROCM_HOME}/.info/version" ROCM_VERSION_DEV_RAW)
string(REGEX MATCH "^([0-9]+)\.([0-9]+)\.([0-9]+)-.*$" ROCM_VERSION_MATCH ${ROCM_VERSION_DEV_RAW})
elseif (EXISTS "${onnxruntime_ROCM_HOME}/include/rocm_version.h")
file(READ "${onnxruntime_ROCM_HOME}/include/rocm_version.h" ROCM_VERSION_H_RAW)
string(REGEX MATCH "\"([0-9]+)\.([0-9]+)\.([0-9]+).*\"" ROCM_VERSION_MATCH ${ROCM_VERSION_H_RAW})
elseif (EXISTS "${onnxruntime_ROCM_HOME}/include/rocm-core/rocm_version.h")
file(READ "${onnxruntime_ROCM_HOME}/include/rocm-core/rocm_version.h" ROCM_VERSION_H_RAW)
string(REGEX MATCH "\"([0-9]+)\.([0-9]+)\.([0-9]+).*\"" ROCM_VERSION_MATCH ${ROCM_VERSION_H_RAW})
endif()

if (ROCM_VERSION_MATCH)
set(ROCM_VERSION_DEV_MAJOR ${CMAKE_MATCH_1})
set(ROCM_VERSION_DEV_MINOR ${CMAKE_MATCH_2})
set(ROCM_VERSION_DEV_PATCH ${CMAKE_MATCH_3})
set(ROCM_VERSION_DEV "${ROCM_VERSION_DEV_MAJOR}.${ROCM_VERSION_DEV_MINOR}.${ROCM_VERSION_DEV_PATCH}")
math(EXPR ROCM_VERSION_DEV_INT "(${ROCM_VERSION_DEV_MAJOR}*10000) + (${ROCM_VERSION_DEV_MINOR}*100) + ${ROCM_VERSION_DEV_PATCH}")
else()
message(FATAL_ERROR "Cannot determine ROCm version string")
endif()
message("\n***** ROCm version from ${onnxruntime_ROCM_HOME}/.info/version ****\n")
message("ROCM_VERSION_DEV: ${ROCM_VERSION_DEV}")
message("ROCM_VERSION_DEV_MAJOR: ${ROCM_VERSION_DEV_MAJOR}")
message("ROCM_VERSION_DEV_MINOR: ${ROCM_VERSION_DEV_MINOR}")
message("ROCM_VERSION_DEV_PATCH: ${ROCM_VERSION_DEV_PATCH}")
message("ROCM_VERSION_DEV_INT: ${ROCM_VERSION_DEV_INT}")
message("\n***** HIP LANGUAGE CONFIG INFO ****\n")
message("CMAKE_HIP_COMPILER: ${CMAKE_HIP_COMPILER}")
message("CMAKE_HIP_ARCHITECTURES: ${CMAKE_HIP_ARCHITECTURES}")
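
The version-detection block relocated above mirrors PyTorch's LoadHIP.cmake strategy and folds the parsed version into ROCM_VERSION_DEV_INT via major*10000 + minor*100 + patch, so downstream checks can compare versions as plain integers. A minimal C++ sketch of that encoding follows; the helper name is illustrative, not an ONNX Runtime symbol:

```cpp
// Sketch of the ROCM_VERSION_DEV_INT encoding computed by the math() call in
// the CMake block above: major*10000 + minor*100 + patch, so ROCm 6.2.3 -> 60203.
// RocmVersionInt is an illustrative name, not part of the ONNX Runtime codebase.
constexpr int RocmVersionInt(int major, int minor, int patch) {
  return major * 10000 + minor * 100 + patch;
}

static_assert(RocmVersionInt(6, 2, 3) == 60203, "ROCm 6.2.3 encodes as 60203");
static_assert(RocmVersionInt(6, 2, 0) < RocmVersionInt(6, 2, 3),
              "version ordering is preserved by the packed integer");
```
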
2 changes: 1 addition & 1 deletion dockerfiles/Dockerfile.migraphx
@@ -5,7 +5,7 @@
# Dockerfile to run ONNXRuntime with MIGraphX integration
#--------------------------------------------------------------------------

FROM rocm/pytorch:rocm6.0_ubuntu20.04_py3.9_pytorch_2.1.1
FROM rocm/pytorch:rocm6.2.3_ubuntu22.04_py3.10_pytorch_release_2.3.0

ARG ONNXRUNTIME_REPO=https://github.com/Microsoft/onnxruntime
ARG ONNXRUNTIME_BRANCH=main
2 changes: 1 addition & 1 deletion dockerfiles/Dockerfile.rocm
@@ -5,7 +5,7 @@
# Dockerfile to run ONNXRuntime with ROCm integration
#--------------------------------------------------------------------------

FROM rocm/pytorch:rocm6.0_ubuntu20.04_py3.9_pytorch_2.1.1
FROM rocm/pytorch:rocm6.2.3_ubuntu22.04_py3.10_pytorch_release_2.3.0

ARG ONNXRUNTIME_REPO=https://github.com/Microsoft/onnxruntime
ARG ONNXRUNTIME_BRANCH=main
4 changes: 2 additions & 2 deletions dockerfiles/README.md
@@ -292,7 +292,7 @@ Nothing else from ONNX Runtime source tree will be copied/installed to the image
Note: When running the container you built in Docker, please either use 'nvidia-docker' command instead of 'docker', or use Docker command-line options to make sure NVIDIA runtime will be used and appropriate files mounted from host. Otherwise, CUDA libraries won't be found. You can also [set NVIDIA runtime as default in Docker](https://github.com/dusty-nv/jetson-containers#docker-default-runtime).
## MIGraphX
**Ubuntu 20.04, ROCm6.0, MIGraphX**
**Ubuntu 22.04, ROCm6.2.3, MIGraphX**
1. Build the docker image from the Dockerfile in this repository.
```
@@ -306,7 +306,7 @@ Note: When running the container you built in Docker, please either use 'nvidia-
```
## ROCm
**Ubuntu 20.04, ROCm6.0**
**Ubuntu 22.04, ROCm6.2.3**
1. Build the docker image from the Dockerfile in this repository.
```
4 changes: 3 additions & 1 deletion js/common/lib/tensor-impl.ts
@@ -179,7 +179,9 @@ export class Tensor implements TensorInterface {
type !== 'uint64' &&
type !== 'int8' &&
type !== 'uint8' &&
type !== 'bool'
type !== 'bool' &&
type !== 'uint4' &&
type !== 'int4'
) {
throw new TypeError(`unsupported type "${type}" to create tensor from MLTensor`);
}
4 changes: 3 additions & 1 deletion js/common/lib/tensor.ts
@@ -167,7 +167,9 @@ export declare namespace Tensor {
| 'uint32'
| 'int64'
| 'uint64'
| 'bool';
| 'bool'
| 'uint4'
| 'int4';

/**
* represent where the tensor data is stored
4 changes: 4 additions & 0 deletions js/web/lib/wasm/jsep/backend-webnn.ts
@@ -25,6 +25,8 @@ const onnxDataTypeToWebnnDataType = new Map<DataType, MLOperandDataType>([
[DataType.uint32, 'uint32'],
[DataType.int64, 'int64'],
[DataType.uint64, 'uint64'],
[DataType.int4, 'int4'],
[DataType.uint4, 'uint4'],
[DataType.int8, 'int8'],
[DataType.uint8, 'uint8'],
[DataType.bool, 'uint8'],
@@ -214,6 +216,8 @@ export class WebNNBackend {
case 'int8':
bufferView = new Int8Array(buffer);
break;
case 'int4':
case 'uint4':
case 'uint8':
bufferView = new Uint8Array(buffer);
break;
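
Both 4-bit types are routed to a Uint8Array view above because ONNX packs two 4-bit elements into each byte (element 0 in the low nibble), so an int4/uint4 tensor is just a stream of packed bytes. A minimal C++ sketch of that convention, assuming the low-nibble-first layout ONNX specifies for INT4/UINT4; the helper names are illustrative:

```cpp
#include <cstddef>
#include <cstdint>

// Two 4-bit elements per byte, element 0 in the low nibble:
// n elements occupy ceil(n / 2) bytes.
constexpr size_t PackedInt4ByteSize(size_t num_elements) {
  return (num_elements + 1) / 2;
}

// Fetch element i from a packed uint4 buffer.
inline uint8_t UnpackUint4(const uint8_t* data, size_t i) {
  const uint8_t b = data[i / 2];
  return (i % 2 == 0) ? static_cast<uint8_t>(b & 0x0F)
                      : static_cast<uint8_t>(b >> 4);
}
```

This is why the buffer-view switch treats 'int4' and 'uint4' exactly like 'uint8': the data already arrives as packed bytes.
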
2 changes: 1 addition & 1 deletion js/web/lib/wasm/jsep/webnn/webnn.d.ts
@@ -28,7 +28,7 @@ interface MLContext {
}
interface MLGraph {}
type MLInputOperandLayout = 'nchw'|'nhwc';
type MLOperandDataType = 'float32'|'float16'|'int32'|'uint32'|'int64'|'uint64'|'int8'|'uint8';
type MLOperandDataType = 'float32'|'float16'|'int32'|'uint32'|'int64'|'uint64'|'int8'|'uint8'|'int4'|'uint4';
interface MLOperandDescriptor {
dataType: MLOperandDataType;
shape?: readonly number[];
4 changes: 3 additions & 1 deletion js/web/lib/wasm/wasm-common.ts
@@ -252,7 +252,9 @@ export const isMLTensorSupportedType = (type: Tensor.Type): type is Tensor.MLTen
type === 'uint64' ||
type === 'int8' ||
type === 'uint8' ||
type === 'bool';
type === 'bool' ||
type === 'uint4' ||
type === 'int4';

/**
* Map string data location to integer value
2 changes: 1 addition & 1 deletion onnxruntime/core/providers/migraphx/gpu_data_transfer.cc
@@ -57,7 +57,7 @@ common::Status GPUDataTransfer::CopyTensorAsync(const Tensor& src, Tensor& dst,
HIP_CALL_THROW(hipMemcpyAsync(dst_data, src_data, bytes, hipMemcpyDeviceToDevice, static_cast<hipStream_t>(stream.GetHandle())));
} else {
// copy from other CPU memory to GPU, this is blocking
HIP_CALL_THROW(hipMemcpy(dst_data, src_data, bytes, hipMemcpyHostToDevice));
HIP_CALL_THROW(hipMemcpyWithStream(dst_data, src_data, bytes, hipMemcpyHostToDevice, static_cast<hipStream_t>(stream.GetHandle())));
}
} else if (src_device.Type() == OrtDevice::GPU) {
HIP_CALL_THROW(hipMemcpyAsync(dst_data, src_data, bytes, hipMemcpyDeviceToHost, static_cast<hipStream_t>(stream.GetHandle())));
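
The copy above switches from hipMemcpy, which issues the transfer on the default (null) stream and synchronizes against device-wide work, to hipMemcpyWithStream, which performs the same host-blocking copy but ordered on the caller's stream, so it no longer serializes unrelated streams. A minimal sketch of the changed call, with buffer and stream setup elided and error handling reduced to the return code:

```cpp
#include <hip/hip_runtime.h>

// Sketch of the ordering fix above. hipMemcpy implicitly uses the null stream
// and serializes against all device work; hipMemcpyWithStream keeps the copy
// ordered after prior work on `stream` only, while still blocking the host
// until the copy completes.
hipError_t CopyHostToDevice(void* dst, const void* src, size_t bytes,
                            hipStream_t stream) {
  // Before: hipMemcpy(dst, src, bytes, hipMemcpyHostToDevice);
  return hipMemcpyWithStream(dst, src, bytes, hipMemcpyHostToDevice, stream);
}
```
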
@@ -1445,7 +1445,11 @@ Status MIGraphXExecutionProvider::Compile(const std::vector<FusedNodeAndGraph>&
std::vector<int64_t> ort_shape{res_lens.begin(), res_lens.end()};
auto output_tensor = ctx.GetOutput(i, ort_shape.data(), ort_shape.size());
void* output_data = output_tensor.GetTensorMutableRawData();
HIP_CALL_THROW(hipMemcpy(output_data, gpu_res.data(), res_shape.bytes(), hipMemcpyDeviceToDevice));
HIP_CALL_THROW(hipMemcpyWithStream(output_data,
gpu_res.data(),
res_shape.bytes(),
hipMemcpyDeviceToDevice,
static_cast<hipStream_t>(rocm_stream)));
}
}
};
29 changes: 19 additions & 10 deletions onnxruntime/core/providers/webnn/builders/helper.cc
@@ -69,26 +69,24 @@ bool IsNodeSupported(const Node& node, const GraphViewer& graph_viewer, const We
}
}

bool IsInputSupported(const NodeArg& input, const std::string& parent_name, const logging::Logger& logger) {
const auto& input_name = input.Name();
const auto* shape_proto = input.Shape();
bool IsTensorShapeSupported(const NodeArg& node_arg, const std::string& parent_name, const logging::Logger& logger) {
const auto& node_arg_name = node_arg.Name();
const auto* shape_proto = node_arg.Shape();
// Optional tensors can be indicated by an empty name, just ignore it.
if (input_name.empty()) {
if (node_arg_name.empty()) {
return true;
}
// We do not support input with no shape.
// We do not support input/output with no shape.
if (!shape_proto) {
LOGS(logger, VERBOSE) << "Input [" << input_name << "] of [" << parent_name
<< "] has not shape";
LOGS(logger, VERBOSE) << "Node arg [" << node_arg_name << "] of [" << parent_name << "] has not shape";
return false;
}

for (const auto& dim : shape_proto->dim()) {
// WebNN doesn't support dynamic shape - use sessionOptions.freeDimensionOverrides to fix the shape.
if (!dim.has_dim_value()) {
LOGS(logger, VERBOSE) << "Dynamic shape is not supported, "
<< "use sessionOptions.FreeDimensionOverrides to set a fixed shape for input: "
<< input_name;
<< "use sessionOptions.FreeDimensionOverrides to set a fixed shape: " << node_arg_name;
return false;
}
}
@@ -104,7 +102,12 @@ std::vector<std::vector<NodeIndex>> GetSupportedNodes(const GraphViewer& graph_v
std::vector<std::vector<size_t>> supported_node_groups;

for (const auto* input : graph_viewer.GetInputs()) {
if (!IsInputSupported(*input, "graph", logger)) {
if (!IsTensorShapeSupported(*input, "graph", logger)) {
return supported_node_groups;
}
}
for (const auto* output : graph_viewer.GetOutputs()) {
if (!IsTensorShapeSupported(*output, "graph", logger)) {
return supported_node_groups;
}
}
@@ -226,6 +229,12 @@ bool GetBidirectionalBroadcastShape(std::vector<int64_t>& shape_a,

bool SetWebnnDataType(emscripten::val& desc, const int32_t data_type) {
switch (data_type) {
case ONNX_NAMESPACE::TensorProto_DataType_INT4:
desc.set("dataType", emscripten::val("int4"));
return true;
case ONNX_NAMESPACE::TensorProto_DataType_UINT4:
desc.set("dataType", emscripten::val("uint4"));
return true;
case ONNX_NAMESPACE::TensorProto_DataType_BOOL:
case ONNX_NAMESPACE::TensorProto_DataType_UINT8:
desc.set("dataType", emscripten::val("uint8"));
4 changes: 3 additions & 1 deletion onnxruntime/core/providers/webnn/builders/helper.h
@@ -180,7 +180,7 @@ inline bool IsEmptyTensor(const InitializedTensorSet& initializers, const std::s
return std::any_of(dims.begin(), dims.end(), [](auto d) { return d == 0; });
}

bool IsInputSupported(const NodeArg& node_arg, const std::string& parent_name, const logging::Logger& logger);
bool IsTensorShapeSupported(const NodeArg& node_arg, const std::string& parent_name, const logging::Logger& logger);

// Get a list of groups of supported nodes, each group represents a subgraph supported by WebNN EP.
std::vector<std::vector<NodeIndex>> GetSupportedNodes(const GraphViewer& graph_viewer,
@@ -303,6 +303,8 @@ inline bool GetWebNNOpType(const std::string& op_type, std::string& webnn_op_typ
}

static const InlinedHashMap<ONNX_NAMESPACE::TensorProto_DataType, std::string> onnx_to_webnn_data_type_map = {
{ONNX_NAMESPACE::TensorProto_DataType_INT4, "int4"},
{ONNX_NAMESPACE::TensorProto_DataType_UINT4, "uint4"},
{ONNX_NAMESPACE::TensorProto_DataType_BOOL, "uint8"},
{ONNX_NAMESPACE::TensorProto_DataType_INT8, "int8"},
{ONNX_NAMESPACE::TensorProto_DataType_UINT8, "uint8"},
16 changes: 14 additions & 2 deletions onnxruntime/core/providers/webnn/builders/impl/base_op_builder.cc
@@ -34,7 +34,7 @@ bool BaseOpBuilder::IsOpSupported(const InitializedTensorSet& initializers, cons
if (!HasSupportedInputs(node, wnn_limits, logger))
return false;

if (!HasSupportedOutputsImpl(node, wnn_limits, logger))
if (!HasSupportedOutputs(node, wnn_limits, logger))
return false;

if (!HasSupportedOpSet(node, logger))
@@ -47,7 +47,7 @@ bool BaseOpBuilder::HasSupportedInputs(const Node& node, const emscripten::val&
const logging::Logger& logger) const {
const auto node_name = MakeString("Node [", node.Name(), "] type [", node.OpType(), "]");
for (const auto* input : node.InputDefs()) {
if (!IsInputSupported(*input, node_name, logger)) {
if (!IsTensorShapeSupported(*input, node_name, logger)) {
return false;
}
}
@@ -68,6 +68,18 @@ bool BaseOpBuilder::HasSupportedInputsImpl(const Node& node,
return IsDataTypeSupportedByOp(op_type, input_type, wnn_limits, "input", "Input", logger);
}

bool BaseOpBuilder::HasSupportedOutputs(const Node& node, const emscripten::val& wnn_limits,
const logging::Logger& logger) const {
const auto node_name = MakeString("Node [", node.Name(), "] type [", node.OpType(), "]");
for (const auto* output : node.OutputDefs()) {
if (!IsTensorShapeSupported(*output, node_name, logger)) {
return false;
}
}

return HasSupportedOutputsImpl(node, wnn_limits, logger);
}

bool BaseOpBuilder::HasSupportedOutputsImpl(const Node& node,
const emscripten::val& wnn_limits,
const logging::Logger& logger) const {
@@ -54,6 +54,7 @@ class BaseOpBuilder : public IOpBuilder {
private:
bool HasSupportedOpSet(const Node& node, const logging::Logger& logger) const;
bool HasSupportedInputs(const Node& node, const emscripten::val& wnn_limits, const logging::Logger& logger) const;
bool HasSupportedOutputs(const Node& node, const emscripten::val& wnn_limits, const logging::Logger& logger) const;
};

} // namespace webnn
@@ -38,6 +38,12 @@ Status CastOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder,
const auto to_type = helper.Get("to", ONNX_NAMESPACE::TensorProto_DataType_FLOAT);
std::string operand_type;
switch (to_type) {
case ONNX_NAMESPACE::TensorProto_DataType_INT4:
operand_type = "int4";
break;
case ONNX_NAMESPACE::TensorProto_DataType_UINT4:
operand_type = "uint4";
break;
case ONNX_NAMESPACE::TensorProto_DataType_BOOL:
case ONNX_NAMESPACE::TensorProto_DataType_UINT8:
operand_type = "uint8";
8 changes: 8 additions & 0 deletions onnxruntime/core/providers/webnn/builders/model.cc
@@ -42,6 +42,8 @@ onnxruntime::common::Status Model::Compute(const InlinedHashMap<std::string, Onn
emscripten::val view = emscripten::val::undefined();
switch (tensor.tensor_info.data_type) {
case ONNX_NAMESPACE::TensorProto_DataType_BOOL:
case ONNX_NAMESPACE::TensorProto_DataType_INT4:
case ONNX_NAMESPACE::TensorProto_DataType_UINT4:
case ONNX_NAMESPACE::TensorProto_DataType_UINT8:
view = emscripten::val{emscripten::typed_memory_view(num_elements,
static_cast<const uint8_t*>(tensor.buffer))};
@@ -93,6 +95,8 @@ onnxruntime::common::Status Model::Compute(const InlinedHashMap<std::string, Onn
emscripten::val view = emscripten::val::undefined();
switch (tensor.tensor_info.data_type) {
case ONNX_NAMESPACE::TensorProto_DataType_BOOL:
case ONNX_NAMESPACE::TensorProto_DataType_INT4:
case ONNX_NAMESPACE::TensorProto_DataType_UINT4:
case ONNX_NAMESPACE::TensorProto_DataType_UINT8:
view = emscripten::val{emscripten::typed_memory_view(num_elements,
static_cast<const uint8_t*>(tensor.buffer))};
@@ -210,6 +214,8 @@ void Model::AllocateInputOutputBuffers() {
const auto data_type = input_info.data_type;
switch (data_type) {
case ONNX_NAMESPACE::TensorProto_DataType_BOOL:
case ONNX_NAMESPACE::TensorProto_DataType_INT4:
case ONNX_NAMESPACE::TensorProto_DataType_UINT4:
case ONNX_NAMESPACE::TensorProto_DataType_UINT8:
wnn_inputs_.set(input, emscripten::val::global("Uint8Array").new_(num_elements));
break;
@@ -245,6 +251,8 @@ void Model::AllocateInputOutputBuffers() {
const auto data_type = output_info.data_type;
switch (data_type) {
case ONNX_NAMESPACE::TensorProto_DataType_BOOL:
case ONNX_NAMESPACE::TensorProto_DataType_INT4:
case ONNX_NAMESPACE::TensorProto_DataType_UINT4:
case ONNX_NAMESPACE::TensorProto_DataType_UINT8:
wnn_outputs_.set(output, emscripten::val::global("Uint8Array").new_(num_elements));
break;
