Skip to content

Commit

Permalink
test
Browse files Browse the repository at this point in the history
  • Loading branch information
IlyasMoutawwakil committed Jul 15, 2024
1 parent f46493c commit 2d5c926
Showing 1 changed file with 49 additions and 40 deletions.
89 changes: 49 additions & 40 deletions optimum/exporters/tasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -181,15 +181,15 @@ class TasksManager:
_LIBRARY_TO_TASKS_TO_MODEL_LOADER_MAP = {}

# Torch model mappings
_TRANSFORMERS_TASKS_TO_MODEL_MAPPING = {}
_DIFFUSERS_TASKS_TO_MODEL_MAPPING = {}
_TRANSFORMERS_TASKS_TO_MODEL_MAPPINGS = {}
_DIFFUSERS_TASKS_TO_MODEL_MAPPINGS = {}

# TF model loaders
_TRANSFORMERS_TASKS_TO_TF_MODEL_LOADERS = {}
_LIBRARY_TO_TF_TASKS_TO_MODEL_LOADER_MAP = {}

# TF model mappings
_TRANSFORMERS_TASKS_TO_MODEL_MAPPING = {}
_TRANSFORMERS_TASKS_TO_MODEL_MAPPINGS = {}

if is_torch_available():
# Refer to https://huggingface.co/datasets/huggingface/transformers-metadata/blob/main/pipeline_tags.json
Expand Down Expand Up @@ -226,7 +226,7 @@ class TasksManager:
"zero-shot-object-detection": "AutoModelForZeroShotObjectDetection",
}

_TRANSFORMERS_TASKS_TO_MODEL_MAPPING = get_transformers_tasks_to_model_mapping(
_TRANSFORMERS_TASKS_TO_MODEL_MAPPINGS = get_transformers_tasks_to_model_mapping(
_TRANSFORMERS_TASKS_TO_MODEL_LOADERS, framework="pt"
)

Expand All @@ -246,7 +246,7 @@ class TasksManager:
"text-to-image": "AutoPipelineForText2Image",
}

_DIFFUSERS_TASKS_TO_MODEL_MAPPING = get_diffusers_tasks_to_model_mapping()
_DIFFUSERS_TASKS_TO_MODEL_MAPPINGS = get_diffusers_tasks_to_model_mapping()

_LIBRARY_TO_TASKS_TO_MODEL_LOADER_MAP = {
"diffusers": _DIFFUSERS_TASKS_TO_MODEL_LOADERS,
Expand Down Expand Up @@ -281,7 +281,7 @@ class TasksManager:
"transformers": _TRANSFORMERS_TASKS_TO_TF_MODEL_LOADERS,
}

_TRANSFORMERS_TASKS_TO_TF_MODEL_MAPPING = get_transformers_tasks_to_model_mapping(
_TRANSFORMERS_TASKS_TO_TF_MODEL_MAPPINGS = get_transformers_tasks_to_model_mapping(
_TRANSFORMERS_TASKS_TO_TF_MODEL_LOADERS, framework="tf"
)

Expand Down Expand Up @@ -1605,16 +1605,16 @@ def _infer_task_from_model_or_model_class(
if target_class_name == model_loader_class_name:
return task_name

# using TASKS_TO_MODEL_MAPPING to infer the task name
# using TASKS_TO_MODEL_MAPPINGS to infer the task name
tasks_to_model_mapping = None

if target_class_module.startswith("transformers"):
if target_class_name.startswith("TF"):
tasks_to_model_mapping = cls._TRANSFORMERS_TASKS_TO_TF_MODEL_MAPPING
tasks_to_model_mapping = cls._TRANSFORMERS_TASKS_TO_TF_MODEL_MAPPINGS
else:
tasks_to_model_mapping = cls._TRANSFORMERS_TASKS_TO_MODEL_MAPPING
tasks_to_model_mapping = cls._TRANSFORMERS_TASKS_TO_MODEL_MAPPINGS
elif target_class_module.startswith("diffusers"):
tasks_to_model_mapping = cls._DIFFUSERS_TASKS_TO_MODEL_MAPPING
tasks_to_model_mapping = cls._DIFFUSERS_TASKS_TO_MODEL_MAPPINGS

if tasks_to_model_mapping is not None:
for task_name, model_mapping in tasks_to_model_mapping.items():
Expand Down Expand Up @@ -1661,7 +1661,7 @@ def _infer_task_from_model_name_or_path(
raise RuntimeError(
f"Hugging Face Hub is not reachable and we cannot infer the task from a cached model. Make sure you are not offline, or otherwise please specify the `task` (or `--task` in command-line) argument ({', '.join(TasksManager.get_all_tasks())})."
)
library_name = TasksManager.infer_library_from_model(
library_name = cls.infer_library_from_model(
model_name_or_path,
subfolder=subfolder,
revision=revision,
Expand All @@ -1671,37 +1671,46 @@ def _infer_task_from_model_name_or_path(

if library_name == "timm":
inferred_task_name = "image-classification"
else:
pipeline_tag = getattr(model_info, "pipeline_tag", None)
# The Hub task "conversational" is not a supported task per se, just an alias that may map to
# text-generation or text2text-generation.
# The Hub task "object-detection" is not a supported task per se, as in Transformers this may map to either
# zero-shot-object-detection or object-detection.
if pipeline_tag is not None and pipeline_tag not in ["conversational", "object-detection"]:
inferred_task_name = TasksManager.map_from_synonym(model_info.pipeline_tag)
elif library_name == "transformers":
transformers_info = model_info.transformersInfo
if transformers_info is not None and transformers_info.get("pipeline_tag") is not None:
inferred_task_name = TasksManager.map_from_synonym(transformers_info["pipeline_tag"])
else:
# transformersInfo does not always have a pipeline_tag attribute
class_name_prefix = ""
if is_torch_available():
tasks_to_automodels = TasksManager._LIBRARY_TO_TASKS_TO_MODEL_LOADER_MAP[library_name]
else:
tasks_to_automodels = TasksManager._LIBRARY_TO_TF_TASKS_TO_MODEL_LOADER_MAP[library_name]
class_name_prefix = "TF"

auto_model_class_name = transformers_info["auto_model"]
if not auto_model_class_name.startswith("TF"):
auto_model_class_name = f"{class_name_prefix}{auto_model_class_name}"
for task, class_name_for_task in tasks_to_automodels.items():
if class_name_for_task == auto_model_class_name:
inferred_task_name = task
elif library_name == "diffusers":
pipeline_tag = pipeline_tag = model_info.pipeline_tag
model_config = model_info.config
if pipeline_tag is not None:
inferred_task_name = cls.map_from_synonym(pipeline_tag)
elif model_config is not None:
if model_config is not None and model_config.get("diffusers", None) is not None:
diffusers_class_name = model_config["diffusers"]["_class_name"]
for task_name, model_mapping in cls._DIFFUSERS_TASKS_TO_MODEL_MAPPINGS.items():
for model_type, model_class_name in model_mapping.items():
if diffusers_class_name == model_class_name:
inferred_task_name = task_name
break
if inferred_task_name is not None:
break
elif library_name == "transformers":
pipeline_tag = model_info.pipeline_tag
transformers_info = model_info.transformersInfo
if pipeline_tag is not None:
inferred_task_name = cls.map_from_synonym(model_info.pipeline_tag)
elif transformers_info is not None:
transformers_pipeline_tag = transformers_info.get("pipeline_tag", None)
transformers_auto_model = transformers_info.get("auto_model", None)
if transformers_pipeline_tag is not None:
pipeline_tag = transformers_info["pipeline_tag"]
inferred_task_name = cls.map_from_synonym(pipeline_tag)
elif transformers_auto_model is not None:
transformers_auto_model = transformers_auto_model.replace("TF", "")
for task_name, model_loaders in cls._TRANSFORMERS_TASKS_TO_MODEL_LOADERS.items():
if isinstance(model_loaders, str):
model_loaders = (model_loaders,)
for model_loader_class_name in model_loaders:
if transformers_auto_model == model_loader_class_name:
inferred_task_name = task_name
break
if inferred_task_name is not None:
break

if inferred_task_name is None:
raise KeyError(f"Could not find the proper task name for {auto_model_class_name}.")
raise KeyError(f"Could not find the proper task name for the model {model_name_or_path}.")

return inferred_task_name

Expand Down Expand Up @@ -1912,7 +1921,7 @@ def standardize_model_attributes(cls, model: Union["PreTrainedModel", "TFPreTrai
if library_name == "diffusers":
inferred_model_type = None

for task_name, model_mapping in cls._DIFFUSERS_TASKS_TO_MODEL_MAPPING.items():
for task_name, model_mapping in cls._DIFFUSERS_TASKS_TO_MODEL_MAPPINGS.items():
for model_type, model_class_name in model_mapping.items():
if model.__class__.__name__ == model_class_name:
inferred_model_type = model_type
Expand Down

0 comments on commit 2d5c926

Please sign in to comment.