diff --git a/optimum/exporters/tasks.py b/optimum/exporters/tasks.py index 4ea61ad1d9..f02f176923 100644 --- a/optimum/exporters/tasks.py +++ b/optimum/exporters/tasks.py @@ -162,7 +162,9 @@ def get_transformers_tasks_to_model_mapping(tasks_to_model_loader, framework="pt for model_loader in model_loaders: model_loader_class = getattr(auto_modeling_module, model_loader, None) if model_loader_class is not None: - # we can just update the model_type to model_class mapping since we only need one either way + # we can just update the model_type to model_class mapping since + # we can only have one task->model_type->model_class either way + # e.g. we merge automatic-speech-recognition's SpeechSeq2Seq and CTC models tasks_to_model_mapping[task_name].update(model_loader_class._model_mapping._model_mapping) return tasks_to_model_mapping @@ -1767,6 +1769,7 @@ def _infer_library_from_model_or_model_class( model: Optional[Union["PreTrainedModel", "TFPreTrainedModel", "DiffusionPipeline"]] = None, model_class: Optional[Type[Union["PreTrainedModel", "TFPreTrainedModel", "DiffusionPipeline"]]] = None, ): + inferred_library_name = None if model is not None and model_class is not None: raise ValueError("Either a model or a model class must be provided, but both were given here.") if model is None and model_class is None: @@ -1775,20 +1778,20 @@ def _infer_library_from_model_or_model_class( target_class_module = model.__class__.__module__ if model is not None else model_class.__module__ if target_class_module.startswith("sentence_transformers"): - library_name = "sentence_transformers" + inferred_library_name = "sentence_transformers" elif target_class_module.startswith("transformers"): - library_name = "transformers" + inferred_library_name = "transformers" elif target_class_module.startswith("diffusers"): - library_name = "diffusers" + inferred_library_name = "diffusers" elif target_class_module.startswith("timm"): - library_name = "timm" + inferred_library_name = "timm" - if library_name is None: + if inferred_library_name is None: raise ValueError( "The library name could not be automatically inferred. If using the command-line, please provide the argument --library {transformers,diffusers,timm,sentence_transformers}. Example: `--library diffusers`." ) - return library_name + return inferred_library_name @classmethod def _infer_library_from_model_name_or_path(