Skip to content

Commit

Permalink
Merge branch 'main' into adrianl/quant-tool-updates
Browse files Browse the repository at this point in the history
  • Loading branch information
adrianlizarraga committed Oct 24, 2024
2 parents a8559f0 + 3ae7c3c commit 0026eca
Show file tree
Hide file tree
Showing 13 changed files with 603 additions and 123 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -95,11 +95,6 @@ bool ExpandOpBuilder::IsOpSupportedImpl(const InitializedTensorSet& initializers
return false;
}

if (input_shape.empty()) {
LOGS(logger, VERBOSE) << "Expand does not support empty input's shape.";
return false;
}

std::vector<int64_t> output_shape;
if (!GetBidirectionalBroadcastShape(input_shape, new_shape, output_shape)) {
LOGS(logger, VERBOSE) << "The input cannot expand to shape " << GetShapeString(new_shape);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -44,21 +44,25 @@ Status ReshapeOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder,
const auto& input_defs = node.InputDefs();
const auto& initializers(model_builder.GetInitializerTensors());
const auto& target_shape_tensor = *initializers.at(input_defs[1]->Name());
const int64_t* raw_target_shape = target_shape_tensor.int64_data().empty()
? reinterpret_cast<const int64_t*>(target_shape_tensor.raw_data().data())
: target_shape_tensor.int64_data().data();
const auto& target_shape_tensor_dims = target_shape_tensor.dims();
std::vector<uint32_t> new_shape;
// Do nothing if target shape is an empty shape, which means converting to a scalar.
if (!target_shape_tensor_dims.empty()) {
const int64_t* raw_target_shape = target_shape_tensor.int64_data().empty()
? reinterpret_cast<const int64_t*>(target_shape_tensor.raw_data().data())
: target_shape_tensor.int64_data().data();

const auto size = target_shape_tensor_dims[0];
TensorShapeVector target_shape{raw_target_shape, raw_target_shape + size};
std::vector<int64_t> input_shape;
ORT_RETURN_IF_NOT(GetShape(*input_defs[0], input_shape, logger), "Cannot get shape");
ReshapeHelper helper(TensorShape(input_shape), target_shape);
std::transform(target_shape.cbegin(), target_shape.cend(),
std::back_inserter(new_shape),
[](int64_t dim) -> uint32_t { return SafeInt<uint32_t>(dim); });
}

const auto size = target_shape_tensor.dims()[0];
TensorShapeVector target_shape{raw_target_shape, raw_target_shape + size};
std::vector<int64_t> input_shape;
ORT_RETURN_IF_NOT(GetShape(*input_defs[0], input_shape, logger), "Cannot get shape");
ReshapeHelper helper(TensorShape(input_shape), target_shape);
emscripten::val input = model_builder.GetOperand(input_defs[0]->Name());
std::vector<int32_t> new_shape;
std::transform(target_shape.cbegin(), target_shape.cend(),
std::back_inserter(new_shape),
[](int64_t dim) -> uint32_t { return SafeInt<int32_t>(dim); });

emscripten::val options = emscripten::val::object();
options.set("label", node.Name());
emscripten::val output = model_builder.GetBuilder().call<emscripten::val>("reshape",
Expand All @@ -76,6 +80,11 @@ bool ReshapeOpBuilder::IsOpSupportedImpl(const InitializedTensorSet& initializer
const WebnnDeviceType /* device_type */,
const logging::Logger& logger) const {
const auto& input_defs = node.InputDefs();

std::vector<int64_t> input_shape;
if (!GetShape(*input_defs[0], input_shape, logger))
return false;

const auto& perm_name = input_defs[1]->Name();
if (!Contains(initializers, perm_name)) {
LOGS(logger, VERBOSE) << "New shape of reshape must be a constant initializer";
Expand All @@ -92,24 +101,11 @@ bool ReshapeOpBuilder::IsOpSupportedImpl(const InitializedTensorSet& initializer

const int64_t* raw_new_shape = reinterpret_cast<const int64_t*>(unpacked_tensor.data());
const auto& perm_dims = perm_tensor.dims();
if (perm_dims.empty() || perm_dims[0] == 0) {
LOGS(logger, VERBOSE) << "New shape of reshape cannot be empty";
return false;
}

std::vector<int64_t> input_shape;
if (!GetShape(*input_defs[0], input_shape, logger))
return false;

if (input_shape.empty()) {
LOGS(logger, VERBOSE) << "Reshape does not support empty input shape";
return false;
}

// WebNN reshape does not support 0 as dimension.
NodeAttrHelper helper(node);
const bool allow_zero = helper.Get("allowzero ", 0) == 1;
if (allow_zero) {
const bool allow_zero = helper.Get("allowzero", 0) == 1;
if (allow_zero && !perm_dims.empty()) {
for (int64_t i = 0; i < perm_dims[0]; i++) {
if (raw_new_shape[i] == 0) {
LOGS_DEFAULT(VERBOSE) << "Reshape doesn't support 0 reshape dimension when allowzero is enabled";
Expand Down
Loading

0 comments on commit 0026eca

Please sign in to comment.