Skip to content

Commit

Permalink
[WebNN EP] Add labels for all WebNN operators (#21516)
Browse files Browse the repository at this point in the history
In order to provide more diagnosable error messages for developers.

Spec change: webmachinelearning/webnn#742
  • Loading branch information
Honry authored Jul 29, 2024
1 parent 5bc12bf commit 94eb70d
Show file tree
Hide file tree
Showing 30 changed files with 180 additions and 63 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ Status ActivationOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder,

NodeAttrHelper helper(node);
emscripten::val options = emscripten::val::object();
options.set("label", node.Name());
if (op_type == "Elu") {
options.set("alpha", helper.Get("alpha", 1.0f));
output = model_builder.GetBuilder().call<emscripten::val>("elu", input, options);
Expand All @@ -46,20 +47,20 @@ Status ActivationOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder,
options.set("beta", helper.Get("beta", 0.5f));
output = model_builder.GetBuilder().call<emscripten::val>("hardSigmoid", input, options);
} else if (op_type == "HardSwish") {
output = model_builder.GetBuilder().call<emscripten::val>("hardSwish", input);
output = model_builder.GetBuilder().call<emscripten::val>("hardSwish", input, options);
} else if (op_type == "LeakyRelu") {
options.set("alpha", helper.Get("alpha", 0.0f));
output = model_builder.GetBuilder().call<emscripten::val>("leakyRelu", input, options);
} else if (op_type == "Relu") {
output = model_builder.GetBuilder().call<emscripten::val>("relu", input);
output = model_builder.GetBuilder().call<emscripten::val>("relu", input, options);
} else if (op_type == "Sigmoid") {
output = model_builder.GetBuilder().call<emscripten::val>("sigmoid", input);
output = model_builder.GetBuilder().call<emscripten::val>("sigmoid", input, options);
} else if (op_type == "Softplus") {
output = model_builder.GetBuilder().call<emscripten::val>("softplus", input);
output = model_builder.GetBuilder().call<emscripten::val>("softplus", input, options);
} else if (op_type == "Softsign") {
output = model_builder.GetBuilder().call<emscripten::val>("softsign", input);
output = model_builder.GetBuilder().call<emscripten::val>("softsign", input, options);
} else if (op_type == "Tanh") {
output = model_builder.GetBuilder().call<emscripten::val>("tanh", input);
output = model_builder.GetBuilder().call<emscripten::val>("tanh", input, options);
} else {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
"ActivationOpBuilder::AddToModelBuilderImpl, unknown op: ", op_type);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@ Status ArgMaxMinOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder,
options.set("keepDimensions", keep_dims == 1);
// TODO(Honry): check whether int64 output data type is supported by WebNN opSupportLimits() API.
options.set("outputDataType", "int64");
options.set("label", node.Name());
emscripten::val output = emscripten::val::object();

const auto& op_type = node.OpType();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -35,18 +35,21 @@ Status BinaryOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, const
emscripten::val input0 = model_builder.GetOperand(node.InputDefs()[0]->Name());
emscripten::val input1 = model_builder.GetOperand(node.InputDefs()[1]->Name());
emscripten::val output = emscripten::val::object();
emscripten::val options = emscripten::val::object();
options.set("label", node.Name());

if (op_type == "Add") {
output = model_builder.GetBuilder().call<emscripten::val>("add", input0, input1);
output = model_builder.GetBuilder().call<emscripten::val>("add", input0, input1, options);
} else if (op_type == "Sub") {
output = model_builder.GetBuilder().call<emscripten::val>("sub", input0, input1);
output = model_builder.GetBuilder().call<emscripten::val>("sub", input0, input1, options);
} else if (op_type == "Mul") {
output = model_builder.GetBuilder().call<emscripten::val>("mul", input0, input1);
output = model_builder.GetBuilder().call<emscripten::val>("mul", input0, input1, options);
} else if (op_type == "Div") {
output = model_builder.GetBuilder().call<emscripten::val>("div", input0, input1);
output = model_builder.GetBuilder().call<emscripten::val>("div", input0, input1, options);
} else if (op_type == "Pow") {
output = model_builder.GetBuilder().call<emscripten::val>("pow", input0, input1);
output = model_builder.GetBuilder().call<emscripten::val>("pow", input0, input1, options);
} else if (op_type == "PRelu") {
output = model_builder.GetBuilder().call<emscripten::val>("prelu", input0, input1);
output = model_builder.GetBuilder().call<emscripten::val>("prelu", input0, input1, options);
} else {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
"BinaryOpBuilder::AddToModelBuilderImpl, unknown op: ", op_type);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -69,8 +69,11 @@ Status CastOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder,
node.Name(), " type: ", to_type);
}

emscripten::val options = emscripten::val::object();
options.set("label", node.Name());

emscripten::val output =
model_builder.GetBuilder().call<emscripten::val>("cast", input, emscripten::val(operand_type));
model_builder.GetBuilder().call<emscripten::val>("cast", input, emscripten::val(operand_type), options);

model_builder.AddOperand(node.OutputDefs()[0]->Name(), std::move(output));
return Status::OK();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,7 @@ Status ClipOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder,
"GetClipMinMax failed");
options.set("minValue", minValue);
options.set("maxValue", maxValue);
options.set("label", node.Name());
emscripten::val input = model_builder.GetOperand(input_name);
emscripten::val output = model_builder.GetBuilder().call<emscripten::val>("clamp", input, options);

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -42,8 +42,11 @@ Status ConcatOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder,
inputs.push_back(model_builder.GetOperand(input->Name()));
}

emscripten::val options = emscripten::val::object();
options.set("label", node.Name());

emscripten::val output =
model_builder.GetBuilder().call<emscripten::val>("concat", emscripten::val::array(inputs), axis);
model_builder.GetBuilder().call<emscripten::val>("concat", emscripten::val::array(inputs), axis, options);

model_builder.AddOperand(node.OutputDefs()[0]->Name(), std::move(output));
return Status::OK();
Expand Down
16 changes: 14 additions & 2 deletions onnxruntime/core/providers/webnn/builders/impl/conv_op_builder.cc
Original file line number Diff line number Diff line change
Expand Up @@ -242,6 +242,7 @@ Status ConvOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, const N
}

emscripten::val options = emscripten::val::object();
options.set("label", node.Name());
ORT_RETURN_IF_ERROR(SetConvBaseOptions(
model_builder, node, options, input_shape, weight_shape, strides, dilations, pads, is_nhwc, is_conv1d, logger));
bool depthwise = false;
Expand Down Expand Up @@ -276,7 +277,12 @@ Status ConvOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, const N
if (!is_nhwc || !is_constant_weight) {
// The weight_shape has been appended 1's, reshape weight operand.
std::vector<uint32_t> new_shape = GetVecUint32FromVecInt64(weight_shape);
filter = model_builder.GetBuilder().call<emscripten::val>("reshape", filter, emscripten::val::array(new_shape));
emscripten::val reshape_options = emscripten::val::object();
reshape_options.set("label", node.Name() + "_reshape_filter");
filter = model_builder.GetBuilder().call<emscripten::val>("reshape",
filter,
emscripten::val::array(new_shape),
reshape_options);
}
}

Expand All @@ -293,6 +299,7 @@ Status ConvOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, const N
perm = {0, 2, 3, 1}; // L_0231
}
transpose_options.set("permutation", emscripten::val::array(perm));
transpose_options.set("label", node.Name() + "_transpose_filter");
filter = model_builder.GetBuilder().call<emscripten::val>("transpose", filter, transpose_options);
}

Expand Down Expand Up @@ -323,7 +330,12 @@ Status ConvOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, const N
std::vector<int64_t> output_shape;
ORT_RETURN_IF_NOT(GetShape(*output_defs[0], output_shape, logger), "Cannot get output shape");
std::vector<uint32_t> new_shape = GetVecUint32FromVecInt64(output_shape);
output = model_builder.GetBuilder().call<emscripten::val>("reshape", output, emscripten::val::array(new_shape));
emscripten::val reshape_options = emscripten::val::object();
reshape_options.set("label", node.Name() + "_reshape_output");
output = model_builder.GetBuilder().call<emscripten::val>("reshape",
output,
emscripten::val::array(new_shape),
reshape_options);
}

model_builder.AddOperand(node.OutputDefs()[0]->Name(), std::move(output));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -50,11 +50,22 @@ Status DequantizeLinearOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_buil
std::vector<int32_t> target_shape{static_cast<int>(input_shape[axis])};
target_shape.insert(target_shape.begin(), axis, 1);
target_shape.insert(target_shape.end(), input_shape.size() - axis - 1, 1);
scale = model_builder.GetBuilder().call<emscripten::val>("reshape", scale, emscripten::val::array(target_shape));
emscripten::val reshape_scale_options = emscripten::val::object();
reshape_scale_options.set("label", node.Name() + "_reshape_scale");
scale = model_builder.GetBuilder().call<emscripten::val>("reshape",
scale,
emscripten::val::array(target_shape),
reshape_scale_options);
emscripten::val reshape_zero_point_options = emscripten::val::object();
reshape_zero_point_options.set("label", node.Name() + "_reshape_zero_point");
zero_point = model_builder.GetBuilder().call<emscripten::val>("reshape",
zero_point, emscripten::val::array(target_shape));
zero_point,
emscripten::val::array(target_shape),
reshape_zero_point_options);
}
output = model_builder.GetBuilder().call<emscripten::val>("dequantizeLinear", input, scale, zero_point);
emscripten::val options = emscripten::val::object();
options.set("label", node.Name());
output = model_builder.GetBuilder().call<emscripten::val>("dequantizeLinear", input, scale, zero_point, options);

model_builder.AddOperand(node.OutputDefs()[0]->Name(), std::move(output));

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,8 +31,9 @@ Status DynamicQuantizaLinearOpBuilder::AddToModelBuilderImpl(ModelBuilder& model
std::vector<int64_t> input_shape;
ORT_RETURN_IF_NOT(GetShape(*input_defs[0], input_shape, logger), "Cannot get shape");
emscripten::val options = emscripten::val::object();
options.set("label", node.Name());

output_array = model_builder.GetBuilder().call<emscripten::val>("dynamicQuantizeLinear", input);
output_array = model_builder.GetBuilder().call<emscripten::val>("dynamicQuantizeLinear", input, options);

for (size_t i = 0, count = output_array["length"].as<size_t>(); i < count; i++) {
model_builder.AddOperand(node.OutputDefs()[i]->Name(), std::move(output_array[i]));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -53,10 +53,14 @@ Status ExpandOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder,
std::vector<int64_t> output_shape;
ORT_RETURN_IF_NOT(GetBidirectionalBroadcastShape(input_shape, new_shape, output_shape), "Cannot get output shape.");

emscripten::val options = emscripten::val::object();
options.set("label", node.Name());

emscripten::val output =
model_builder.GetBuilder().call<emscripten::val>("expand",
input,
emscripten::val::array(GetVecUint32FromVecInt64(output_shape)));
emscripten::val::array(GetVecUint32FromVecInt64(output_shape)),
options);
model_builder.AddOperand(node.OutputDefs()[0]->Name(), std::move(output));
return Status::OK();
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -52,8 +52,10 @@ Status FlattenOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder,
SafeInt<uint32_t>(num_post_axis_elements)};

emscripten::val inputs = model_builder.GetOperand(input_defs[0]->Name());
emscripten::val options = emscripten::val::object();
options.set("label", node.Name());
emscripten::val output = model_builder.GetBuilder().call<emscripten::val>(
"reshape", inputs, emscripten::val::array(new_shape));
"reshape", inputs, emscripten::val::array(new_shape), options);

model_builder.AddOperand(node.OutputDefs()[0]->Name(), std::move(output));
return Status::OK();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@ Status GatherOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder,
emscripten::val indices = model_builder.GetOperand(input_defs[1]->Name());
emscripten::val options = emscripten::val::object();
options.set("axis", axis);
options.set("label", node.Name());
emscripten::val output = model_builder.GetBuilder().call<emscripten::val>("gather", input, indices, options);

model_builder.AddOperand(node.OutputDefs()[0]->Name(), std::move(output));
Expand Down
39 changes: 31 additions & 8 deletions onnxruntime/core/providers/webnn/builders/impl/gemm_op_builder.cc
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,8 @@ Status GemmOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, const N
emscripten::val a = model_builder.GetOperand(node.InputDefs()[a_idx]->Name());
emscripten::val b = model_builder.GetOperand(node.InputDefs()[b_idx]->Name());
emscripten::val output = emscripten::val::object();
emscripten::val options = emscripten::val::object();
options.set("label", node.Name());
if (op_type == "MatMul") {
std::vector<int64_t> a_shape;
if (!GetShape(*input_defs[a_idx], a_shape, logger)) {
Expand All @@ -53,23 +55,34 @@ Status GemmOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, const N
if (a_shape.size() == 1) {
extended_a_shape = true;
a_shape.insert(a_shape.begin(), 1);
emscripten::val reshape_a_options = emscripten::val::object();
reshape_a_options.set("label", node.Name() + "_reshape_a");
a = model_builder.GetBuilder().call<emscripten::val>("reshape", a,
emscripten::val::array(GetVecUint32FromVecInt64(a_shape)));
emscripten::val::array(GetVecUint32FromVecInt64(a_shape)),
reshape_a_options);
}
// If the second argument is 1-D, it is promoted to a matrix by appending a 1 to its dimensions.
bool extended_b_shape = false;
if (b_shape.size() == 1) {
extended_b_shape = true;
b_shape.push_back(1);
emscripten::val reshape_b_options = emscripten::val::object();
reshape_b_options.set("label", node.Name() + "_reshape_b");
b = model_builder.GetBuilder().call<emscripten::val>("reshape", b,
emscripten::val::array(GetVecUint32FromVecInt64(b_shape)));
emscripten::val::array(GetVecUint32FromVecInt64(b_shape)),
reshape_b_options);
}

output = model_builder.GetBuilder().call<emscripten::val>("matmul", a, b);
output = model_builder.GetBuilder().call<emscripten::val>("matmul", a, b, options);

emscripten::val reshape_output_options = emscripten::val::object();
reshape_output_options.set("label", node.Name() + "_reshape_output");
// If the inputs are both 1D, reduce the output to a scalar.
if (extended_a_shape && extended_b_shape) {
output = model_builder.GetBuilder().call<emscripten::val>("reshape", output, emscripten::val::array());
output = model_builder.GetBuilder().call<emscripten::val>("reshape",
output,
emscripten::val::array(),
reshape_output_options);
}
// After matrix multiplication the prepended 1 is removed.
else if (extended_a_shape) {
Expand All @@ -78,15 +91,21 @@ Status GemmOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, const N
new_shape.push_back(narrow<uint32_t>(b_shape[i]));
}
new_shape.push_back(narrow<uint32_t>(b_shape.back()));
output = model_builder.GetBuilder().call<emscripten::val>("reshape", output, emscripten::val::array(new_shape));
output = model_builder.GetBuilder().call<emscripten::val>("reshape",
output,
emscripten::val::array(new_shape),
reshape_output_options);
}
// After matrix multiplication the appended 1 is removed.
else if (extended_b_shape) {
std::vector<uint32_t> new_shape;
for (size_t i = 0; i < a_shape.size() - 1; i++) {
new_shape.push_back(narrow<uint32_t>(a_shape[i]));
}
output = model_builder.GetBuilder().call<emscripten::val>("reshape", output, emscripten::val::array(new_shape));
output = model_builder.GetBuilder().call<emscripten::val>("reshape",
output,
emscripten::val::array(new_shape),
reshape_output_options);
}
} else if (op_type == "MatMulInteger") {
emscripten::val a_zero_point = emscripten::val::null();
Expand All @@ -101,9 +120,13 @@ Status GemmOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, const N
} else {
b_zero_point = model_builder.GetZeroConstant("uint8");
}
output = model_builder.GetBuilder().call<emscripten::val>("matmulInteger", a, a_zero_point, b, b_zero_point);
output = model_builder.GetBuilder().call<emscripten::val>("matmulInteger",
a,
a_zero_point,
b,
b_zero_point,
options);
} else { // Gemm
emscripten::val options = emscripten::val::object();
NodeAttrHelper helper(node);
const auto transA = helper.Get("transA", 0);
options.set("aTranspose", emscripten::val(transA == 1));
Expand Down
Loading

0 comments on commit 94eb70d

Please sign in to comment.