Skip to content

Commit

Permalink
fix
Browse files Browse the repository at this point in the history
  • Loading branch information
IlyasMoutawwakil committed Jun 3, 2024
1 parent 30c21a5 commit 2c8c843
Show file tree
Hide file tree
Showing 2 changed files with 4 additions and 5 deletions.
4 changes: 2 additions & 2 deletions optimum/onnxruntime/modeling_decoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -275,9 +275,9 @@ def forward(
"labels": labels,
}

onnx_inputs = self._prepare_onnx_inputs(use_torch=use_torch, **model_inputs, **kwargs)
onnx_inputs = self._prepare_onnx_inputs(use_torch, **model_inputs)
onnx_outputs = self.model.run(None, onnx_inputs)
model_outputs = self._prepare_onnx_outputs(use_torch=use_torch, *onnx_outputs)
model_outputs = self._prepare_onnx_outputs(use_torch, *onnx_outputs)

logits = model_outputs.get("logits")
loss = model_outputs.get("loss", None)
Expand Down
5 changes: 2 additions & 3 deletions optimum/onnxruntime/modeling_ort.py
Original file line number Diff line number Diff line change
Expand Up @@ -1018,7 +1018,6 @@ def forward(
self.model.run_with_iobinding(io_binding)
io_binding.synchronize_outputs()

model_outputs = self.prepare_io_binding_outputs(output_shapes, output_buffers)
last_hidden_state = output_buffers["last_hidden_state"].view(output_shapes["last_hidden_state"])
else:
model_inputs = {"input_ids": input_ids, "attention_mask": attention_mask, "token_type_ids": token_type_ids}
Expand Down Expand Up @@ -1171,7 +1170,7 @@ def forward(
else:
model_inputs = {"input_ids": input_ids, "attention_mask": attention_mask, "token_type_ids": token_type_ids}

onnx_inputs = self._prepare_onnx_inputs(use_torch, **model_inputs, **kwargs)
onnx_inputs = self._prepare_onnx_inputs(use_torch, **model_inputs)
onnx_outputs = self.model.run(None, onnx_inputs)
model_outputs = self._prepare_onnx_outputs(use_torch, *onnx_outputs)

Expand Down Expand Up @@ -1267,7 +1266,7 @@ def forward(
else:
model_inputs = {"input_ids": input_ids, "attention_mask": attention_mask, "token_type_ids": token_type_ids}

onnx_inputs = self._prepare_onnx_inputs(use_torch, **model_inputs, **kwargs)
onnx_inputs = self._prepare_onnx_inputs(use_torch, **model_inputs)
onnx_outputs = self.model.run(None, onnx_inputs)
model_outputs = self._prepare_onnx_outputs(use_torch, *onnx_outputs)

Expand Down

0 comments on commit 2c8c843

Please sign in to comment.