Skip to content

Commit

Permalink
Fix infer library for sentence transformers models (#1832)
Browse files Browse the repository at this point in the history
  • Loading branch information
echarlaix authored Apr 25, 2024
1 parent 3b5c486 commit c55f882
Showing 1 changed file with 4 additions and 4 deletions.
8 changes: 4 additions & 4 deletions optimum/exporters/tasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -1671,16 +1671,17 @@ def _infer_library_from_model(
if library_name is not None:
return library_name

if (
# SentenceTransformer models have no config attributes
if hasattr(model, "_model_config"):
library_name = "sentence_transformers"
elif (
hasattr(model, "pretrained_cfg")
or hasattr(model.config, "pretrained_cfg")
or hasattr(model.config, "architecture")
):
library_name = "timm"
elif hasattr(model.config, "_diffusers_version") or getattr(model, "config_name", "") == "model_index.json":
library_name = "diffusers"
elif hasattr(model, "_model_config"):
library_name = "sentence_transformers"
else:
library_name = "transformers"
return library_name
Expand Down Expand Up @@ -1905,7 +1906,6 @@ def get_model_from_task(
model_class = TasksManager.get_model_class_for_task(
task, framework, model_type=model_type, model_class_name=model_class_name, library=library_name
)

if library_name == "timm":
model = model_class(f"hf_hub:{model_name_or_path}", pretrained=True, exportable=True)
model = model.to(torch_dtype).to(device)
Expand Down

0 comments on commit c55f882

Please sign in to comment.