
Commit

fix
IlyasMoutawwakil committed Oct 15, 2024
1 parent 6183a8b commit 2d386d9
Showing 6 changed files with 58 additions and 28 deletions.
12 changes: 10 additions & 2 deletions optimum/exporters/onnx/model_configs.py
@@ -1132,7 +1132,11 @@ def inputs(self) -> Dict[str, Dict[int, str]]:
     @property
     def outputs(self) -> Dict[str, Dict[int, str]]:
         return {
-            "latent_parameters": {0: "batch_size", 2: "sample_height / len_blocks", 3: "sample_width / len_blocks"},
+            "latent_parameters": {
+                0: "batch_size",
+                2: "sample_height / down_scaling_factor",
+                3: "sample_width / down_scaling_factor",
+            },
         }


@@ -1156,7 +1160,11 @@ def inputs(self) -> Dict[str, Dict[int, str]]:
     @property
     def outputs(self) -> Dict[str, Dict[int, str]]:
         return {
-            "sample": {0: "batch_size", 2: "latent_height * len_blocks", 3: "latent_width * len_blocks"},
+            "sample": {
+                0: "batch_size",
+                2: "latent_height * up_scaling_factor",
+                3: "latent_width * up_scaling_factor",
+            },
         }


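Note on the renamed axes: they now name the VAE's spatial compression factor rather than its number of blocks. A rough worked example of what the encoder's dynamic output axes resolve to, assuming the geometry of a Stable Diffusion 1.x VAE (4 down blocks, values not taken from the commit):

# Assumed Stable Diffusion 1.x VAE geometry: 4 down blocks, all but the last
# halving the spatial resolution, so the overall factor is 2 ** (4 - 1) = 8.
num_down_blocks = 4
down_scaling_factor = 2 ** (num_down_blocks - 1)

sample_height = sample_width = 512
latent_height = sample_height // down_scaling_factor  # 64
latent_width = sample_width // down_scaling_factor     # 64

# The "latent_parameters" output is therefore shaped
# (batch_size, channels, sample_height / down_scaling_factor, sample_width / down_scaling_factor),
# i.e. (batch_size, channels, 64, 64) for a 512x512 input.
print(latent_height, latent_width)  # 64 64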
4 changes: 2 additions & 2 deletions optimum/onnxruntime/modeling_decoder.py
@@ -19,15 +19,15 @@
 from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union
 
 import numpy as np
+import onnx
 import torch
 from huggingface_hub.constants import HUGGINGFACE_HUB_CACHE
+from onnx.tools import update_model_dims
 from transformers import AutoModelForCausalLM, GenerationConfig
 from transformers.file_utils import add_end_docstrings, add_start_docstrings_to_model_forward
 from transformers.modeling_outputs import CausalLMOutputWithPast
 
-import onnx
 import onnxruntime
-from onnx.tools import update_model_dims
 
 from ..exporters.onnx import MODEL_TYPES_REQUIRING_POSITION_IDS, main_export
 from ..onnx.utils import check_model_uses_external_data
24 changes: 21 additions & 3 deletions optimum/onnxruntime/modeling_diffusion.py
@@ -205,10 +205,18 @@ def device(self) -> torch.device:
     def dtype(self) -> torch.dtype:
         return self._validate_same_attribute_value_across_components("dtype")
 
+    @property
+    def providers(self) -> Tuple[str]:
+        return self._validate_same_attribute_value_across_components("providers")
+
     @property
     def provider(self) -> str:
         return self._validate_same_attribute_value_across_components("provider")
 
+    @property
+    def providers_options(self) -> Dict[str, Dict[str, Any]]:
+        return self._validate_same_attribute_value_across_components("providers_options")
+
     @property
     def provider_options(self) -> Dict[str, Any]:
         return self._validate_same_attribute_value_across_components("provider_options")
@@ -461,7 +469,9 @@ def __init__(self, session: ort.InferenceSession, use_io_binding: Optional[bool]
         self.register_to_config(**self._dict_from_json_file(config_file_path))
 
         self.session = session
-        self.use_io_binding = use_io_binding or session.get_providers()[0] in ["CUDAExecutionProvider"]
+        self.use_io_binding = (
+            use_io_binding if use_io_binding is not None else session.get_providers()[0] in ["CUDAExecutionProvider"]
+        )
 
         self.input_names = {input_key.name: idx for idx, input_key in enumerate(self.session.get_inputs())}
         self.input_shapes = {input_key.name: input_key.shape for input_key in self.session.get_inputs()}
@@ -494,10 +504,18 @@ def _compile_shapes(self, shapes: Dict[str, Tuple[Union[int, str]]]) -> Dict[str
     def device(self):
         return self._device
 
+    @property
+    def providers(self):
+        return self._providers
+
     @property
     def provider(self):
         return self._providers[0]
 
+    @property
+    def providers_options(self):
+        return self._providers_options
+
     @property
     def provider_options(self):
         return self._providers_options[self._providers[0]]
@@ -802,7 +820,7 @@ def __init__(self, *args, **kwargs):
             )
             self.register_to_config(scaling_factor=0.18215)
 
-        self._known_symbols["len_blocks"] = len(self.config.block_out_channels)
+        self._known_symbols["down_scaling_factor"] = 2 ** (len(self.config.down_block_types) - 1)
 
     def forward(
         self,
@@ -848,7 +866,7 @@ def __init__(self, *args, **kwargs):
             )
             self.register_to_config(scaling_factor=0.18215)
 
-        self._known_symbols["len_blocks"] = len(self.config.block_out_channels)
+        self._known_symbols["up_scaling_factor"] = 2 ** (len(self.config.up_block_types) - 1)
 
     def forward(
         self,
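Note on the new symbol values: for a diffusers-style AutoencoderKL, len(config.block_out_channels) counts channel stages (4 for the Stable Diffusion VAE), while the spatial factor is 2 ** (n_blocks - 1) = 8, because every block except the last resamples. A minimal sketch checking this with an assumed VAE geometry (not taken from the commit; it assumes diffusers and torch are installed):

import torch
from diffusers import AutoencoderKL

# Construct a VAE with the Stable Diffusion 1.x geometry (assumed for illustration).
vae = AutoencoderKL(
    down_block_types=("DownEncoderBlock2D",) * 4,
    up_block_types=("UpDecoderBlock2D",) * 4,
    block_out_channels=(128, 256, 512, 512),
    latent_channels=4,
)

down_scaling_factor = 2 ** (len(vae.config.down_block_types) - 1)  # 2 ** 3 = 8
up_scaling_factor = 2 ** (len(vae.config.up_block_types) - 1)      # 2 ** 3 = 8
old_symbol_value = len(vae.config.block_out_channels)              # 4, which under-estimates the factor
print(down_scaling_factor, up_scaling_factor, old_symbol_value)    # 8 8 4

# Sanity check: a 1x3x512x512 image encodes to 1x4x64x64 latents (factor 8, not 4).
with torch.no_grad():
    latents = vae.encode(torch.randn(1, 3, 512, 512)).latent_dist.sample()
print(latents.shape)  # torch.Size([1, 4, 64, 64])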
40 changes: 21 additions & 19 deletions optimum/onnxruntime/modeling_ort.py
@@ -230,7 +230,9 @@ def shared_attributes_init(
         # OptimizedModel requires it
         self.preprocessors = preprocessors
 
-        self.use_io_binding = use_io_binding or model.get_providers()[0] == "CUDAExecutionProvider"
+        self.use_io_binding = (
+            use_io_binding if use_io_binding is not None else model.get_providers()[0] in ["CUDAExecutionProvider"]
+        )
 
         self._providers = model.get_providers()
         self._providers_options = model.get_provider_options()
@@ -265,6 +267,24 @@ def __init__(
 
         self._ordered_input_names = get_ordered_input_names(self.input_names.keys(), func=self.forward)
 
+    @property
+    def device(self):
+        return self._device
+
+    @property
+    def dtype(self) -> torch.dtype:
+        for dtype in self.input_dtypes.values():
+            torch_dtype = TypeHelper.ort_type_to_torch_type(dtype)
+            if torch_dtype.is_floating_point:
+                return torch_dtype
+
+        for dtype in self.output_dtypes.values():
+            torch_dtype = TypeHelper.ort_type_to_torch_type(dtype)
+            if torch_dtype.is_floating_point:
+                return torch_dtype
+
+        return None
+
     @property
     def providers(self):
         # all providers
@@ -285,24 +305,6 @@ def provider_options(self):
         # main provider options
         return self.providers_options[self.provider]
 
-    @property
-    def device(self):
-        return self._device
-
-    @property
-    def dtype(self) -> torch.dtype:
-        for dtype in self.input_dtypes.values():
-            torch_dtype = TypeHelper.ort_type_to_torch_type(dtype)
-            if torch_dtype.is_floating_point:
-                return torch_dtype
-
-        for dtype in self.output_dtypes.values():
-            torch_dtype = TypeHelper.ort_type_to_torch_type(dtype)
-            if torch_dtype.is_floating_point:
-                return torch_dtype
-
-        return None
-
     def to(self, *args, device: Optional[Union[torch.device, str, int]] = None, dtype: Optional[torch.dtype] = None):
         for arg in args:
             if isinstance(arg, torch.device):
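Note on the use_io_binding change (the same pattern is fixed in modeling_diffusion.py above): with `or`, an explicitly passed use_io_binding=False was silently promoted to True whenever the first execution provider was CUDA. The new expression only falls back to the provider-based default when the caller passed None. A minimal standalone illustration with an assumed provider list:

providers = ["CUDAExecutionProvider", "CPUExecutionProvider"]  # assumed session providers
use_io_binding = False  # the caller explicitly disabled IO binding

# Old logic: False or True -> True, so the explicit False is lost.
old = use_io_binding or providers[0] in ["CUDAExecutionProvider"]

# New logic: respect any non-None value, otherwise default based on the main provider.
new = use_io_binding if use_io_binding is not None else providers[0] in ["CUDAExecutionProvider"]

print(old, new)  # True False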
2 changes: 2 additions & 0 deletions optimum/subpackages.py
@@ -48,6 +48,8 @@ def load_namespace_modules(namespace: str, module: str):
         dist_name = dist.metadata["Name"]
         if not dist_name.startswith(f"{namespace}-"):
             continue
+        if dist_name == "optimum-benchmark":
+            continue
         package_import_name = dist_name.replace("-", ".")
         module_import_name = f"{package_import_name}.{module}"
         if module_import_name in sys.modules:
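Note: for context, a condensed sketch of the discovery loop this hunk sits in. Only the lines visible in the diff are taken from the source; the scaffolding around them (imports, the actual import call, and the error handling) is an assumption about how such a namespace loader is typically written:

import importlib
import sys
from importlib.metadata import distributions

def load_namespace_modules(namespace: str, module: str):
    """Import `<module>` from every installed `<namespace>-*` distribution (sketch)."""
    for dist in distributions():
        dist_name = dist.metadata["Name"]
        if not dist_name.startswith(f"{namespace}-"):
            continue
        # Skip optimum-benchmark explicitly, as introduced by this commit.
        if dist_name == "optimum-benchmark":
            continue
        package_import_name = dist_name.replace("-", ".")  # e.g. "optimum-intel" -> "optimum.intel"
        module_import_name = f"{package_import_name}.{module}"
        if module_import_name in sys.modules:
            continue  # already loaded
        try:
            importlib.import_module(module_import_name)
        except ModuleNotFoundError:
            # The distribution does not ship this optional module.
            continue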
4 changes: 2 additions & 2 deletions tests/onnxruntime/test_modeling.py
@@ -2354,14 +2354,14 @@ def test_load_model_from_hub_onnx(self):
         self.assertFalse(model.use_merged)
         self.assertTrue(model.use_cache)
         self.assertIsInstance(model.model, onnxruntime.InferenceSession)
-        self.assertEqual(model.onnx_paths[0].name, ONNX_DECODER_WITH_PAST_NAME)
+        self.assertEqual(model.model_path.name, ONNX_DECODER_WITH_PAST_NAME)
 
         model = ORTModelForCausalLM.from_pretrained("fxmarty/onnx-tiny-random-gpt2-with-merge")
 
         self.assertTrue(model.use_merged)
         self.assertTrue(model.use_cache)
         self.assertIsInstance(model.model, onnxruntime.InferenceSession)
-        self.assertEqual(model.onnx_paths[0].name, ONNX_DECODER_MERGED_NAME)
+        self.assertEqual(model.model_path.name, ONNX_DECODER_MERGED_NAME)
 
     def test_load_vanilla_transformers_which_is_not_supported(self):
         with self.assertRaises(Exception) as context:
