From c76102688c64c631af31b310663c727221a1d2b3 Mon Sep 17 00:00:00 2001
From: IlyasMoutawwakil
Date: Thu, 25 Jul 2024 16:15:54 +0200
Subject: [PATCH] revert

---
 optimum/onnxruntime/modeling_diffusion.py | 5 ++---
 1 file changed, 2 insertions(+), 3 deletions(-)

diff --git a/optimum/onnxruntime/modeling_diffusion.py b/optimum/onnxruntime/modeling_diffusion.py
index 268c89088c..4bbfb2eda2 100644
--- a/optimum/onnxruntime/modeling_diffusion.py
+++ b/optimum/onnxruntime/modeling_diffusion.py
@@ -56,9 +56,9 @@
     DIFFUSION_MODEL_VAE_DECODER_SUBFOLDER,
     DIFFUSION_MODEL_VAE_ENCODER_SUBFOLDER,
 )
-from .io_binding import TypeHelper
 from .modeling_ort import ONNX_MODEL_END_DOCSTRING, ORTModel
 from .utils import (
+    _ORT_TO_NP_TYPE,
     ONNX_WEIGHTS_NAME,
     get_provider_for_device,
     parse_device,
@@ -496,8 +496,7 @@ def __init__(self, session: ort.InferenceSession, parent_model: ORTModel):
         self.output_names = {output_key.name: idx for idx, output_key in enumerate(self.session.get_outputs())}
         config_path = Path(session._model_path).parent / self.CONFIG_NAME
         self.config = self.parent_model._dict_from_json_file(config_path) if config_path.is_file() else {}
-        self.input_dtype = {inputs.name: inputs.type for inputs in self.session.get_inputs()}
-        self.output_dtype = {outputs.name: outputs.type for outputs in self.session.get_outputs()}
+        self.input_dtype = {inputs.name: _ORT_TO_NP_TYPE[inputs.type] for inputs in self.session.get_inputs()}

     @property
     def device(self):
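
Note (not part of the patch): a minimal sketch of what a mapping such as _ORT_TO_NP_TYPE in optimum/onnxruntime/utils.py typically provides, and how the resulting input_dtype dictionary can be consumed. The dictionary entries and the cast_inputs helper below are illustrative assumptions, not code copied from the repository.

import numpy as np

# Assumed shape of the mapping: ONNX Runtime tensor type strings -> numpy dtypes.
ORT_TO_NP_TYPE = {
    "tensor(float)": np.float32,
    "tensor(float16)": np.float16,
    "tensor(int32)": np.int32,
    "tensor(int64)": np.int64,
    "tensor(bool)": np.bool_,
}

def cast_inputs(feeds, input_dtype):
    # With input_dtype built as {name: ORT_TO_NP_TYPE[type_string]}, each feed value
    # can be cast to the dtype its session input expects before calling session.run().
    return {name: np.asarray(value).astype(input_dtype[name]) for name, value in feeds.items()}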