Skip to content

Commit

Permalink
better handle noninstalled libs error during model loading
Browse files Browse the repository at this point in the history
  • Loading branch information
eaidova committed Oct 21, 2024
1 parent 8af46e5 commit 6ffaca7
Showing 1 changed file with 18 additions and 6 deletions.
24 changes: 18 additions & 6 deletions optimum/exporters/tasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -1378,7 +1378,12 @@ def get_model_class_for_task(

if (framework, model_type, task) in TasksManager._CUSTOM_CLASSES:
library, class_name = TasksManager._CUSTOM_CLASSES[(framework, model_type, task)]
loaded_library = importlib.import_module(library)

try:
loaded_library = importlib.import_module(library)
except ModuleNotFoundError:
raise ValueError(f"`{library}` selected as model source, but `{library}` is not installed. Please install it.")


return getattr(loaded_library, class_name)
else:
Expand All @@ -1387,7 +1392,10 @@ def get_model_class_for_task(
else:
tasks_to_model_loader = TasksManager._LIBRARY_TO_TF_TASKS_TO_MODEL_LOADER_MAP[library]

loaded_library = importlib.import_module(library)
try:
loaded_library = importlib.import_module(library)
except ModuleNotFoundError:
raise ValueError(f"`{library}` selected as model source, but `{library}` is not installed. Please install it.")

if model_class_name is None:
if task not in tasks_to_model_loader:
Expand Down Expand Up @@ -2063,10 +2071,13 @@ def get_model_from_task(
model_class_name = config.architectures[0]

if library_name == "diffusers":
config = DiffusionPipeline.load_config(model_name_or_path, **kwargs)
class_name = config.get("_class_name", None)
loaded_library = importlib.import_module(library_name)
model_class = getattr(loaded_library, class_name)
if is_diffusers_available():
config = DiffusionPipeline.load_config(model_name_or_path, **kwargs)
class_name = config.get("_class_name", None)
loaded_library = importlib.import_module(library_name)
model_class = getattr(loaded_library, class_name)
else:
raise ValueError("`diffusers` library selected as model source, but `diffusers` is not installed. Please install it.")
else:
model_class = TasksManager.get_model_class_for_task(
task, framework, model_type=model_type, model_class_name=model_class_name, library=library_name
Expand Down Expand Up @@ -2219,3 +2230,4 @@ def get_exporter_config_constructor(
exporter_config_constructor = partial(exporter_config_constructor, **exporter_config_kwargs)

return exporter_config_constructor

0 comments on commit 6ffaca7

Please sign in to comment.