Skip to content

Commit

Permalink
default value for weights vector
Browse files Browse the repository at this point in the history
  • Loading branch information
samutamm committed Oct 1, 2024
1 parent 8440c11 commit 64475e8
Showing 1 changed file with 4 additions and 1 deletion.
5 changes: 4 additions & 1 deletion lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -4350,14 +4350,17 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ(
if (binder.s64IntegerArrayAttr(ngram_counts, "ngram_counts", {}) ||
binder.s64IntegerArrayAttr(ngram_indexes, "ngram_indexes", {}) ||
binder.s64IntegerArrayAttr(pool_int64s, "pool_int64s", {}) ||
binder.f32FloatArrayAttr(weights, "weights", {}) ||
binder.customOpNameStringAttr(mode, "mode", "") ||
binder.s64IntegerAttr(min_gram_length, "min_gram_length", 0) ||
binder.s64IntegerAttr(max_gram_length, "max_gram_length", 0) ||
binder.s64IntegerAttr(max_skip_count, "max_skip_count", 0) ||
binder.tensorOperand(input) || binder.tensorResultType(resultType))
return failure();

llvm::SmallVector<float> defaultWeights(ngram_indexes.size(), 1.0f);
if (binder.f32FloatArrayAttr(weights, "weights", defaultWeights))
return failure();

if (pool_int64s.size() == 0)
return rewriter.notifyMatchFailure(
binder.op, "pool_int64s empty, only integers supported");
Expand Down

0 comments on commit 64475e8

Please sign in to comment.