Skip to content

Commit

Permalink
fix
Browse files Browse the repository at this point in the history
  • Loading branch information
IlyasMoutawwakil committed Oct 14, 2024
1 parent b55b06f commit e956e81
Showing 1 changed file with 7 additions and 4 deletions.
11 changes: 7 additions & 4 deletions optimum/onnxruntime/modeling_ort.py
Original file line number Diff line number Diff line change
Expand Up @@ -187,17 +187,20 @@ def shared_attributes_init(
model: ort.InferenceSession,
use_io_binding: Optional[bool] = None,
model_save_dir: Optional[Union[str, Path, TemporaryDirectory]] = None,
preprocessors: Optional[List[Any]] = None,
**kwargs,
):
"""
Initializes attributes that may be shared among several ONNX Runtime inference sesssions.
"""

# TODO: remove at version 2.0
if kwargs.pop("latest_model_name", None) is not None:
logger.warning(
f"The latest_model_name argument to create an {self.__class__.__name__} is deprecated, and not used "
"anymore."
)

if kwargs:
raise ValueError(
f"{self.__class__.__name__} received {', '.join(kwargs.keys())}, but do not accept those arguments."
Expand Down Expand Up @@ -227,9 +230,6 @@ def shared_attributes_init(
else:
self.model_save_dir = Path(model._model_path).parent

# because OptimizedModel requires it
self.preprocessors = kwargs.pop("preprocessors", [])

# Registers the ORTModelForXXX classes into the transformers AutoModel classes to avoid warnings when creating
# a pipeline https://github.com/huggingface/transformers/blob/cad61b68396a1a387287a8e2e2fef78a25b79383/src/transformers/pipelines/base.py#L863
AutoConfig.register(self.model_type, AutoConfig)
Expand All @@ -239,13 +239,16 @@ def shared_attributes_init(
# Define the pattern here to avoid recomputing it everytime.
self.output_shape_inference_pattern = re.compile(r"([a-zA-Z_]+)|([0-9]+)|([+-/*])|([\(\)])")

# because OptimizedModel requires it
self.preprocessors = preprocessors

def __init__(
self,
model: ort.InferenceSession,
config: "PretrainedConfig",
use_io_binding: Optional[bool] = None,
model_save_dir: Optional[Union[str, Path, TemporaryDirectory]] = None,
preprocessors: Optional[List] = None,
preprocessors: Optional[List[Any]] = None,
**kwargs,
):
super().__init__(model, config)
Expand Down

0 comments on commit e956e81

Please sign in to comment.