From ab4341b2bc89bf893b927e7a906d79e6d9937f2d Mon Sep 17 00:00:00 2001 From: Ilyas Moutawwakil <57442720+IlyasMoutawwakil@users.noreply.github.com> Date: Tue, 16 Jul 2024 10:52:52 +0200 Subject: [PATCH] Refactor diffusers tasks (#1947) * refactor diffusers tasks "stable-diffusion" and "stable-diffusion-xl" into "text-to-image", "image-to-image" and "inpainting" * warn depreated tasks * generalize diffusion export * fix * fix * fix * clean up * trocr * fix * standardise model/pipeline mapping task lookup * add latent consistency * test * fix * fix * final * refactor * fix * fix offline hub support * remove unnecessary * misc * test * style * update docs --- .../onnx/usage_guides/export_a_model.mdx | 2 +- .../tflite/usage_guides/export_a_model.mdx | 2 +- optimum/exporters/onnx/__init__.py | 4 +- optimum/exporters/onnx/__main__.py | 23 +- optimum/exporters/onnx/convert.py | 17 +- optimum/exporters/onnx/utils.py | 16 +- optimum/exporters/tasks.py | 599 +++++++++++------- optimum/exporters/utils.py | 65 +- optimum/utils/import_utils.py | 2 +- tests/exporters/exporters_utils.py | 3 +- .../exporters/onnx/test_exporters_onnx_cli.py | 16 +- tests/exporters/onnx/test_onnx_export.py | 14 +- 12 files changed, 462 insertions(+), 301 deletions(-) diff --git a/docs/source/exporters/onnx/usage_guides/export_a_model.mdx b/docs/source/exporters/onnx/usage_guides/export_a_model.mdx index 4d227e48c2..84c670579c 100644 --- a/docs/source/exporters/onnx/usage_guides/export_a_model.mdx +++ b/docs/source/exporters/onnx/usage_guides/export_a_model.mdx @@ -87,7 +87,7 @@ Required arguments: output Path indicating the directory where to store generated ONNX model. Optional arguments: - --task TASK The task to export the model for. If not specified, the task will be auto-inferred based on the model. Available tasks depend on the model, but are among: ['default', 'fill-mask', 'text-generation', 'text2text-generation', 'text-classification', 'token-classification', 'multiple-choice', 'object-detection', 'question-answering', 'image-classification', 'image-segmentation', 'masked-im', 'semantic-segmentation', 'automatic-speech-recognition', 'audio-classification', 'audio-frame-classification', 'automatic-speech-recognition', 'audio-xvector', 'image-to-text', 'stable-diffusion', 'zero-shot-object-detection']. For decoder models, use `xxx-with-past` to export the model using past key values in the decoder. + --task TASK The task to export the model for. If not specified, the task will be auto-inferred based on the model. Available tasks depend on the model, but are among: ['default', 'fill-mask', 'text-generation', 'text2text-generation', 'text-classification', 'token-classification', 'multiple-choice', 'object-detection', 'question-answering', 'image-classification', 'image-segmentation', 'masked-im', 'semantic-segmentation', 'automatic-speech-recognition', 'audio-classification', 'audio-frame-classification', 'automatic-speech-recognition', 'audio-xvector', 'image-to-text', 'zero-shot-object-detection', 'image-to-image', 'inpainting', 'text-to-image']. For decoder models, use `xxx-with-past` to export the model using past key values in the decoder. --monolith Force to export the model as a single ONNX file. By default, the ONNX exporter may break the model in several ONNX files, for example for encoder-decoder models where the encoder should be run only once while the decoder is looped over. --device DEVICE The device to use to do the export. Defaults to "cpu". --opset OPSET If specified, ONNX opset version to export the model with. Otherwise, the default opset will be used. diff --git a/docs/source/exporters/tflite/usage_guides/export_a_model.mdx b/docs/source/exporters/tflite/usage_guides/export_a_model.mdx index 8666f44543..ff06af8fb3 100644 --- a/docs/source/exporters/tflite/usage_guides/export_a_model.mdx +++ b/docs/source/exporters/tflite/usage_guides/export_a_model.mdx @@ -59,7 +59,7 @@ Optional arguments: the model, but are among: ['default', 'fill-mask', 'text-generation', 'text2text-generation', 'text-classification', 'token-classification', 'multiple-choice', 'object-detection', 'question-answering', 'image-classification', 'image-segmentation', 'masked-im', 'semantic- segmentation', 'automatic-speech-recognition', 'audio-classification', 'audio-frame-classification', 'automatic-speech-recognition', 'audio-xvector', 'vision2seq- - lm', 'stable-diffusion', 'zero-shot-object-detection']. For decoder models, use `xxx-with-past` to export the model using past key + lm', 'zero-shot-object-detection', 'text-to-image', 'image-to-image', 'inpainting']. For decoder models, use `xxx-with-past` to export the model using past key values in the decoder. --atol ATOL If specified, the absolute difference tolerance when validating the model. Otherwise, the default atol for the model will be used. --pad_token_id PAD_TOKEN_ID diff --git a/optimum/exporters/onnx/__init__.py b/optimum/exporters/onnx/__init__.py index 609096e37e..6b99e48457 100644 --- a/optimum/exporters/onnx/__init__.py +++ b/optimum/exporters/onnx/__init__.py @@ -31,7 +31,7 @@ "utils": [ "get_decoder_models_for_export", "get_encoder_decoder_models_for_export", - "get_stable_diffusion_models_for_export", + "get_diffusion_models_for_export", "MODEL_TYPES_REQUIRING_POSITION_IDS", ], "__main__": ["main_export"], @@ -50,7 +50,7 @@ from .utils import ( get_decoder_models_for_export, get_encoder_decoder_models_for_export, - get_stable_diffusion_models_for_export, + get_diffusion_models_for_export, MODEL_TYPES_REQUIRING_POSITION_IDS, ) from .__main__ import main_export diff --git a/optimum/exporters/onnx/__main__.py b/optimum/exporters/onnx/__main__.py index 1e36af06ad..703e98df3e 100644 --- a/optimum/exporters/onnx/__main__.py +++ b/optimum/exporters/onnx/__main__.py @@ -221,13 +221,24 @@ def main_export( " and passing it is not required anymore." ) + if task in ["stable-diffusion", "stable-diffusion-xl"]: + logger.warning( + f"The task `{task}` is deprecated and will be removed in a future release of Optimum. " + "Please use one of the following tasks instead: `text-to-image`, `image-to-image`, `inpainting`." + ) + original_task = task task = TasksManager.map_from_synonym(task) - framework = TasksManager.determine_framework(model_name_or_path, subfolder=subfolder, framework=framework) - library_name = TasksManager.infer_library_from_model( - model_name_or_path, subfolder=subfolder, library_name=library_name - ) + if framework is None: + framework = TasksManager.determine_framework( + model_name_or_path, subfolder=subfolder, revision=revision, cache_dir=cache_dir, token=token + ) + + if library_name is None: + library_name = TasksManager.infer_library_from_model( + model_name_or_path, subfolder=subfolder, revision=revision, cache_dir=cache_dir, token=token + ) torch_dtype = None if framework == "pt": @@ -321,9 +332,7 @@ def main_export( ) model.config.pad_token_id = pad_token_id - if "stable-diffusion" in task: - model_type = "stable-diffusion" - elif hasattr(model.config, "export_model_type"): + if hasattr(model.config, "export_model_type"): model_type = model.config.export_model_type.replace("_", "-") else: model_type = model.config.model_type.replace("_", "-") diff --git a/optimum/exporters/onnx/convert.py b/optimum/exporters/onnx/convert.py index 4d5a2afc37..63a9067b90 100644 --- a/optimum/exporters/onnx/convert.py +++ b/optimum/exporters/onnx/convert.py @@ -60,7 +60,7 @@ from transformers.modeling_utils import PreTrainedModel if is_diffusers_available(): - from diffusers import ModelMixin + from diffusers import DiffusionPipeline, ModelMixin if is_tf_available(): from transformers.modeling_tf_utils import TFPreTrainedModel @@ -264,7 +264,7 @@ def _run_validation( atol = config.ATOL_FOR_VALIDATION if "diffusers" in str(reference_model.__class__) and not is_diffusers_available(): - raise ImportError("The pip package `diffusers` is required to validate stable diffusion ONNX models.") + raise ImportError("The pip package `diffusers` is required to validate diffusion ONNX models.") framework = "pt" if is_torch_available() and isinstance(reference_model, nn.Module) else "tf" @@ -388,7 +388,7 @@ def _run_validation( logger.info(f"\t-[✓] ONNX model output names match reference model ({onnx_output_names})") if "diffusers" in str(reference_model.__class__) and not is_diffusers_available(): - raise ImportError("The pip package `diffusers` is required to validate stable diffusion ONNX models.") + raise ImportError("The pip package `diffusers` is required to validate diffusion ONNX models.") # Check the shape and values match shape_failures = [] @@ -854,7 +854,7 @@ def export( opset = config.DEFAULT_ONNX_OPSET if "diffusers" in str(model.__class__) and not is_diffusers_available(): - raise ImportError("The pip package `diffusers` is required to export stable diffusion models to ONNX.") + raise ImportError("The pip package `diffusers` is required to export diffusion models to ONNX.") if not config.is_transformers_support_available: import transformers @@ -912,7 +912,7 @@ def export( def onnx_export_from_model( - model: Union["PreTrainedModel", "TFPreTrainedModel"], + model: Union["PreTrainedModel", "TFPreTrainedModel", "DiffusionPipeline"], output: Union[str, Path], opset: Optional[int] = None, optimize: Optional[str] = None, @@ -999,15 +999,16 @@ def onnx_export_from_model( >>> onnx_export_from_model(model, output="gpt2_onnx/") ``` """ - library_name = TasksManager._infer_library_from_model(model) - TasksManager.standardize_model_attributes(model, library_name) + TasksManager.standardize_model_attributes(model) if hasattr(model.config, "export_model_type"): model_type = model.config.export_model_type.replace("_", "-") else: model_type = model.config.model_type.replace("_", "-") + library_name = TasksManager.infer_library_from_model(model) + custom_architecture = library_name == "transformers" and model_type not in TasksManager._SUPPORTED_MODEL_TYPE if task is not None: @@ -1191,7 +1192,7 @@ def onnx_export_from_model( optimizer.optimize(save_dir=output, optimization_config=optimization_config, file_suffix="") # Optionally post process the obtained ONNX file(s), for example to merge the decoder / decoder with past if any - # TODO: treating stable diffusion separately is quite ugly + # TODO: treating diffusion separately is quite ugly if not no_post_process and library_name != "diffusers": try: logger.info("Post-processing the exported models...") diff --git a/optimum/exporters/onnx/utils.py b/optimum/exporters/onnx/utils.py index 8ecba9231f..675566ba23 100644 --- a/optimum/exporters/onnx/utils.py +++ b/optimum/exporters/onnx/utils.py @@ -34,6 +34,9 @@ from ..utils import ( get_decoder_models_for_export as _get_decoder_models_for_export, ) +from ..utils import ( + get_diffusion_models_for_export as _get_diffusion_models_for_export, +) from ..utils import ( get_encoder_decoder_models_for_export as _get_encoder_decoder_models_for_export, ) @@ -43,9 +46,6 @@ from ..utils import ( get_speecht5_models_for_export as _get_speecht5_models_for_export, ) -from ..utils import ( - get_stable_diffusion_models_for_export as _get_stable_diffusion_models_for_export, -) logger = logging.get_logger() @@ -68,7 +68,7 @@ from transformers.modeling_tf_utils import TFPreTrainedModel if is_diffusers_available(): - from diffusers import ModelMixin, StableDiffusionPipeline + from diffusers import DiffusionPipeline, ModelMixin MODEL_TYPES_REQUIRING_POSITION_IDS = { @@ -219,13 +219,13 @@ def _get_submodels_and_onnx_configs( DEPRECATION_WARNING_GET_MODEL_FOR_EXPORT = "The usage of `optimum.exporters.onnx.utils.get_{model_type}_models_for_export` is deprecated and will be removed in a future release, please use `optimum.exporters.utils.get_{model_type}_models_for_export` instead." -def get_stable_diffusion_models_for_export( - pipeline: "StableDiffusionPipeline", +def get_diffusion_models_for_export( + pipeline: "DiffusionPipeline", int_dtype: str = "int64", float_dtype: str = "fp32", ) -> Dict[str, Tuple[Union["PreTrainedModel", "ModelMixin"], "ExportConfig"]]: - logger.warning(DEPRECATION_WARNING_GET_MODEL_FOR_EXPORT.format(model_type="stable_diffusion")) - return _get_stable_diffusion_models_for_export(pipeline, int_dtype, float_dtype, exporter="onnx") + logger.warning(DEPRECATION_WARNING_GET_MODEL_FOR_EXPORT.format(model_type="diffusion")) + return _get_diffusion_models_for_export(pipeline, int_dtype, float_dtype, exporter="onnx") def get_sam_models_for_export(model: Union["PreTrainedModel", "TFPreTrainedModel"], config: "ExportConfig"): diff --git a/optimum/exporters/tasks.py b/optimum/exporters/tasks.py index c0221f7bf6..4ea61ad1d9 100644 --- a/optimum/exporters/tasks.py +++ b/optimum/exporters/tasks.py @@ -15,8 +15,6 @@ """Model export tasks manager.""" import importlib -import inspect -import itertools import os import warnings from functools import partial @@ -31,14 +29,12 @@ from transformers import AutoConfig, PretrainedConfig, is_tf_available, is_torch_available from transformers.utils import SAFE_WEIGHTS_NAME, TF2_WEIGHTS_NAME, WEIGHTS_NAME, logging -from ..utils import CONFIG_NAME -from ..utils.import_utils import is_onnx_available +from ..utils.import_utils import is_diffusers_available, is_onnx_available if TYPE_CHECKING: from .base import ExportConfig - logger = logging.get_logger(__name__) # pylint: disable=invalid-name if not is_torch_available() and not is_tf_available(): @@ -54,6 +50,14 @@ if is_tf_available(): from transformers import TFPreTrainedModel +if is_diffusers_available(): + from diffusers import DiffusionPipeline + from diffusers.pipelines.auto_pipeline import ( + AUTO_IMAGE2IMAGE_PIPELINES_MAPPING, + AUTO_INPAINT_PIPELINES_MAPPING, + AUTO_TEXT2IMAGE_PIPELINES_MAPPING, + ) + ExportConfigConstructor = Callable[[PretrainedConfig], "ExportConfig"] TaskNameToExportConfigDict = Dict[str, ExportConfigConstructor] @@ -123,19 +127,45 @@ def supported_tasks_mapping( return mapping -def get_model_loaders_to_tasks(tasks_to_model_loaders: Dict[str, Union[str, Tuple[str]]]) -> Dict[str, str]: - """ - Reverses tasks_to_model_loaders while flattening the case where the same task maps to several - auto classes (e.g. automatic-speech-recognition). - """ - model_loaders_to_tasks = {} - for task, model_loaders in tasks_to_model_loaders.items(): +def get_diffusers_tasks_to_model_mapping(): + """task -> model mapping (model type -> model class)""" + + tasks_to_model_mapping = {} + + for task_name, model_mapping in ( + ("text-to-image", AUTO_TEXT2IMAGE_PIPELINES_MAPPING), + ("image-to-image", AUTO_IMAGE2IMAGE_PIPELINES_MAPPING), + ("inpainting", AUTO_INPAINT_PIPELINES_MAPPING), + ): + tasks_to_model_mapping[task_name] = {} + + for model_type, model_class in model_mapping.items(): + tasks_to_model_mapping[task_name][model_type] = model_class.__name__ + + return tasks_to_model_mapping + + +def get_transformers_tasks_to_model_mapping(tasks_to_model_loader, framework="pt"): + """task -> model mapping (model type -> model class)""" + + if framework == "pt": + auto_modeling_module = importlib.import_module("transformers.models.auto.modeling_auto") + elif framework == "tf": + auto_modeling_module = importlib.import_module("transformers.models.auto.modeling_tf_auto") + + tasks_to_model_mapping = {} + for task_name, model_loaders in tasks_to_model_loader.items(): if isinstance(model_loaders, str): - model_loaders_to_tasks[model_loaders] = task - else: - model_loaders_to_tasks.update({model_loader_name: task for model_loader_name in model_loaders}) + model_loaders = (model_loaders,) - return model_loaders_to_tasks + tasks_to_model_mapping[task_name] = {} + for model_loader in model_loaders: + model_loader_class = getattr(auto_modeling_module, model_loader, None) + if model_loader_class is not None: + # we can just update the model_type to model_class mapping since we only need one either way + tasks_to_model_mapping[task_name].update(model_loader_class._model_mapping._model_mapping) + + return tasks_to_model_mapping class TasksManager: @@ -149,10 +179,17 @@ class TasksManager: _TIMM_TASKS_TO_MODEL_LOADERS = {} _LIBRARY_TO_TASKS_TO_MODEL_LOADER_MAP = {} + # Torch model mappings + _TRANSFORMERS_TASKS_TO_MODEL_MAPPINGS = {} + _DIFFUSERS_TASKS_TO_MODEL_MAPPINGS = {} + # TF model loaders _TRANSFORMERS_TASKS_TO_TF_MODEL_LOADERS = {} _LIBRARY_TO_TF_TASKS_TO_MODEL_LOADER_MAP = {} + # TF model mappings + _TRANSFORMERS_TASKS_TO_MODEL_MAPPINGS = {} + if is_torch_available(): # Refer to https://huggingface.co/datasets/huggingface/transformers-metadata/blob/main/pipeline_tags.json # In case the same task (pipeline tag) may map to several loading classes, we use a tuple and the @@ -166,7 +203,6 @@ class TasksManager: "audio-frame-classification": "AutoModelForAudioFrameClassification", "audio-xvector": "AutoModelForAudioXVector", "automatic-speech-recognition": ("AutoModelForSpeechSeq2Seq", "AutoModelForCTC"), - "conversational": ("AutoModelForCausalLM", "AutoModelForSeq2SeqLM"), "depth-estimation": "AutoModelForDepthEstimation", "feature-extraction": "AutoModel", "fill-mask": "AutoModelForMaskedLM", @@ -189,10 +225,9 @@ class TasksManager: "zero-shot-object-detection": "AutoModelForZeroShotObjectDetection", } - _DIFFUSERS_TASKS_TO_MODEL_LOADERS = { - "stable-diffusion": "StableDiffusionPipeline", - "stable-diffusion-xl": "StableDiffusionXLImg2ImgPipeline", - } + _TRANSFORMERS_TASKS_TO_MODEL_MAPPINGS = get_transformers_tasks_to_model_mapping( + _TRANSFORMERS_TASKS_TO_MODEL_LOADERS, framework="pt" + ) _TIMM_TASKS_TO_MODEL_LOADERS = { "image-classification": "create_model", @@ -203,6 +238,15 @@ class TasksManager: "sentence-similarity": "SentenceTransformer", } + if is_diffusers_available(): + _DIFFUSERS_TASKS_TO_MODEL_LOADERS = { + "image-to-image": "AutoPipelineForImage2Image", + "inpainting": "AutoPipelineForInpainting", + "text-to-image": "AutoPipelineForText2Image", + } + + _DIFFUSERS_TASKS_TO_MODEL_MAPPINGS = get_diffusers_tasks_to_model_mapping() + _LIBRARY_TO_TASKS_TO_MODEL_LOADER_MAP = { "diffusers": _DIFFUSERS_TASKS_TO_MODEL_LOADERS, "sentence_transformers": _SENTENCE_TRANSFORMERS_TASKS_TO_MODEL_LOADERS, @@ -212,7 +256,6 @@ class TasksManager: if is_tf_available(): _TRANSFORMERS_TASKS_TO_TF_MODEL_LOADERS = { - "conversational": ("TFAutoModelForCausalLM", "TFAutoModelForSeq2SeqLM"), "document-question-answering": "TFAutoModelForDocumentQuestionAnswering", "feature-extraction": "TFAutoModel", "fill-mask": "TFAutoModelForMaskedLM", @@ -222,15 +265,12 @@ class TasksManager: "text-classification": "TFAutoModelForSequenceClassification", "token-classification": "TFAutoModelForTokenClassification", "multiple-choice": "TFAutoModelForMultipleChoice", - "object-detection": "TFAutoModelForObjectDetection", "question-answering": "TFAutoModelForQuestionAnswering", "image-segmentation": "TFAutoModelForImageSegmentation", "masked-im": "TFAutoModelForMaskedImageModeling", "semantic-segmentation": "TFAutoModelForSemanticSegmentation", "automatic-speech-recognition": "TFAutoModelForSpeechSeq2Seq", "audio-classification": "TFAutoModelForAudioClassification", - "audio-frame-classification": "TFAutoModelForAudioFrameClassification", - "audio-xvector": "TFAutoModelForAudioXVector", "image-to-text": "TFAutoModelForVision2Seq", "zero-shot-image-classification": "TFAutoModelForZeroShotImageClassification", "zero-shot-object-detection": "TFAutoModelForZeroShotObjectDetection", @@ -240,6 +280,10 @@ class TasksManager: "transformers": _TRANSFORMERS_TASKS_TO_TF_MODEL_LOADERS, } + _TRANSFORMERS_TASKS_TO_TF_MODEL_MAPPINGS = get_transformers_tasks_to_model_mapping( + _TRANSFORMERS_TASKS_TO_TF_MODEL_LOADERS, framework="tf" + ) + _SYNONYM_TASK_MAP = { "audio-ctc": "automatic-speech-recognition", "causal-lm": "text-generation", @@ -260,17 +304,11 @@ class TasksManager: "vision2seq-lm": "image-to-text", "zero-shot-classification": "text-classification", "image-feature-extraction": "feature-extraction", - } - - # Reverse dictionaries str -> str, where several model loaders may map to the same task - _LIBRARY_TO_MODEL_LOADERS_TO_TASKS_MAP = { - "diffusers": get_model_loaders_to_tasks(_DIFFUSERS_TASKS_TO_MODEL_LOADERS), - "sentence_transformers": get_model_loaders_to_tasks(_SENTENCE_TRANSFORMERS_TASKS_TO_MODEL_LOADERS), - "timm": get_model_loaders_to_tasks(_TIMM_TASKS_TO_MODEL_LOADERS), - "transformers": get_model_loaders_to_tasks(_TRANSFORMERS_TASKS_TO_MODEL_LOADERS), - } - _LIBRARY_TO_TF_MODEL_LOADERS_TO_TASKS_MAP = { - "transformers": get_model_loaders_to_tasks(_TRANSFORMERS_TASKS_TO_TF_MODEL_LOADERS), + # for backward compatibility and testing (where + # model task and model type are still the same) + "lcm": "text-to-image", + "stable-diffusion": "text-to-image", + "stable-diffusion-xl": "text-to-image", } _CUSTOM_CLASSES = { @@ -281,7 +319,6 @@ class TasksManager: ("pt", "vision-encoder-decoder", "document-question-answering"): ("transformers", "VisionEncoderDecoderModel"), } - # TODO: why feature-extraction-with-past is here? _ENCODER_DECODER_TASKS = ( "automatic-speech-recognition", "document-question-answering", @@ -1136,7 +1173,7 @@ class TasksManager: "vae-decoder", "clip-text-model", "clip-text-with-projection", - "trocr", # TODO: why? + "trocr", # supported through the vision-encoder-decoder model type } _SUPPORTED_CLI_MODEL_TYPE = ( set(_SUPPORTED_MODEL_TYPE.keys()) @@ -1411,7 +1448,8 @@ def get_model_files( token = use_auth_token request_exception = None - full_model_path = Path(model_name_or_path) / subfolder + full_model_path = Path(model_name_or_path, subfolder) + if full_model_path.is_dir(): all_files = [ os.path.relpath(os.path.join(dirpath, file), full_model_path) @@ -1431,23 +1469,18 @@ def get_model_files( if subfolder != "": all_files = [file[len(subfolder) + 1 :] for file in all_files if file.startswith(subfolder)] except (RequestsConnectionError, OfflineModeIsEnabled) as e: - request_exception = e - object_id = model_name_or_path.replace("/", "--") - full_model_path = Path(cache_dir, f"models--{object_id}") - if full_model_path.is_dir(): # explore the cache first - # Resolve refs (for instance to convert main to the associated commit sha) - if revision is None: - revision_file = Path(full_model_path, "refs", "main") - revision = "" - if revision_file.is_file(): - with open(revision_file) as f: - revision = f.read() - cached_path = Path(full_model_path, "snapshots", revision, subfolder) + snapshot_path = huggingface_hub.snapshot_download( + repo_id=model_name_or_path, revision=revision, cache_dir=cache_dir, token=token + ) + full_model_path = Path(snapshot_path, subfolder) + if full_model_path.is_dir(): all_files = [ - os.path.relpath(os.path.join(dirpath, file), cached_path) - for dirpath, _, filenames in os.walk(cached_path) + os.path.relpath(os.path.join(dirpath, file), full_model_path) + for dirpath, _, filenames in os.walk(full_model_path) for file in filenames ] + else: + request_exception = e return all_files, request_exception @@ -1455,8 +1488,9 @@ def get_model_files( def determine_framework( model_name_or_path: Union[str, Path], subfolder: str = "", - framework: Optional[str] = None, + revision: Optional[str] = None, cache_dir: str = HUGGINGFACE_HUB_CACHE, + token: Optional[Union[bool, str]] = None, ) -> str: """ Determines the framework to use for the export. @@ -1471,20 +1505,25 @@ def determine_framework( model_name_or_path (`Union[str, Path]`): Can be either the model id of a model repo on the Hugging Face Hub, or a path to a local directory containing a model. - subfolder (`str`, defaults to `""`): + subfolder (`str`, *optional*, defaults to `""`): In case the model files are located inside a subfolder of the model directory / repo on the Hugging Face Hub, you can specify the subfolder name here. - framework (`Optional[str]`, *optional*): - The framework to use for the export. See above for priority if none provided. + revision (`Optional[str]`, defaults to `None`): + Revision is the specific model version to use. It can be a branch name, a tag name, or a commit id. + cache_dir (`Optional[str]`, *optional*): + Path to a directory in which a downloaded pretrained model weights have been cached if the standard cache should not be used. + token (`Optional[Union[bool,str]]`, defaults to `None`): + The token to use as HTTP bearer authorization for remote files. If `True`, will use the token generated + when running `huggingface-cli login` (stored in `huggingface_hub.constants.HF_TOKEN_PATH`). Returns: `str`: The framework to use for the export. """ - if framework is not None: - return framework - all_files, request_exception = TasksManager.get_model_files(model_name_or_path, subfolder, cache_dir) + all_files, request_exception = TasksManager.get_model_files( + model_name_or_path, subfolder=subfolder, cache_dir=cache_dir, token=token, revision=revision + ) pt_weight_name = Path(WEIGHTS_NAME).stem pt_weight_extension = Path(WEIGHTS_NAME).suffix @@ -1507,7 +1546,7 @@ def determine_framework( elif "model_index.json" in all_files and any( file.endswith((pt_weight_extension, safe_weight_extension)) for file in all_files ): - # stable diffusion case + # diffusers case framework = "pt" elif "config_sentence_transformers.json" in all_files: # Sentence Transformers libary relies on PyTorch. @@ -1538,58 +1577,67 @@ def determine_framework( @classmethod def _infer_task_from_model_or_model_class( cls, - model: Optional[Union["PreTrainedModel", "TFPreTrainedModel"]] = None, - model_class: Optional[Type] = None, + model: Optional[Union["PreTrainedModel", "TFPreTrainedModel", "DiffusionPipeline"]] = None, + model_class: Optional[Type[Union["PreTrainedModel", "TFPreTrainedModel", "DiffusionPipeline"]]] = None, ) -> str: if model is not None and model_class is not None: raise ValueError("Either a model or a model class must be provided, but both were given here.") if model is None and model_class is None: raise ValueError("Either a model or a model class must be provided, but none were given here.") - target_name = model.__class__.__name__ if model is not None else model_class.__name__ - task_name = None - iterable = () - for _, model_loader in cls._LIBRARY_TO_MODEL_LOADERS_TO_TASKS_MAP.items(): - iterable += (model_loader.items(),) - for _, model_loader in cls._LIBRARY_TO_TF_MODEL_LOADERS_TO_TASKS_MAP.items(): - iterable += (model_loader.items(),) - - pt_auto_module = importlib.import_module("transformers.models.auto.modeling_auto") - tf_auto_module = importlib.import_module("transformers.models.auto.modeling_tf_auto") - for auto_cls_name, task in itertools.chain.from_iterable(iterable): - if any( - ( - target_name.startswith("Auto"), - target_name.startswith("TFAuto"), - "StableDiffusion" in target_name, - ) - ): - if target_name == auto_cls_name: - task_name = task - break - continue - - module = tf_auto_module if auto_cls_name.startswith("TF") else pt_auto_module - # getattr(module, auto_cls_name)._model_mapping is a _LazyMapping, it also has an attribute called - # "_model_mapping" that is what we want here: class names and not actual classes. - auto_cls = getattr(module, auto_cls_name, None) - # This is the case for StableDiffusionPipeline for instance. - if auto_cls is None: - continue - model_mapping = auto_cls._model_mapping._model_mapping - if target_name in model_mapping.values(): - task_name = task - break - if task_name is None: - raise ValueError(f"Could not infer the task name for {target_name}.") - - return task_name + target_class_name = model.__class__.__name__ if model is not None else model_class.__name__ + target_class_module = model.__class__.__module__ if model is not None else model_class.__module__ + + # using TASKS_TO_MODEL_LOADERS to infer the task name + tasks_to_model_loaders = None + + if target_class_name.startswith("AutoModel"): + tasks_to_model_loaders = cls._TRANSFORMERS_TASKS_TO_MODEL_LOADERS + elif target_class_name.startswith("TFAutoModel"): + tasks_to_model_loaders = cls._TRANSFORMERS_TASKS_TO_TF_MODEL_LOADERS + elif target_class_name.startswith("AutoPipeline"): + tasks_to_model_loaders = cls._DIFFUSERS_TASKS_TO_MODEL_LOADERS + + if tasks_to_model_loaders is not None: + for task_name, model_loaders in tasks_to_model_loaders.items(): + if isinstance(model_loaders, str): + model_loaders = (model_loaders,) + for model_loader_class_name in model_loaders: + if target_class_name == model_loader_class_name: + return task_name + + # using TASKS_TO_MODEL_MAPPINGS to infer the task name + tasks_to_model_mapping = None + + if target_class_module.startswith("transformers"): + if target_class_name.startswith("TF"): + tasks_to_model_mapping = cls._TRANSFORMERS_TASKS_TO_TF_MODEL_MAPPINGS + else: + tasks_to_model_mapping = cls._TRANSFORMERS_TASKS_TO_MODEL_MAPPINGS + elif target_class_module.startswith("diffusers"): + tasks_to_model_mapping = cls._DIFFUSERS_TASKS_TO_MODEL_MAPPINGS + + if tasks_to_model_mapping is not None: + for task_name, model_mapping in tasks_to_model_mapping.items(): + for model_type, model_class_name in model_mapping.items(): + if target_class_name == model_class_name: + return task_name + + raise ValueError( + "The task name could not be automatically inferred. If using the command-line, please provide the argument --task task-name. Example: `--task text-classification`." + ) @classmethod def _infer_task_from_model_name_or_path( - cls, model_name_or_path: str, subfolder: str = "", revision: Optional[str] = None + cls, + model_name_or_path: str, + subfolder: str = "", + revision: Optional[str] = None, + cache_dir: str = HUGGINGFACE_HUB_CACHE, + token: Optional[Union[bool, str]] = None, ) -> str: inferred_task_name = None + is_local = os.path.isdir(os.path.join(model_name_or_path, subfolder)) if is_local: @@ -1603,70 +1651,78 @@ def _infer_task_from_model_name_or_path( "Cannot infer the task from a model repo with a subfolder yet, please specify the task manually." ) try: - model_info = huggingface_hub.model_info(model_name_or_path, revision=revision) + model_info = huggingface_hub.model_info(model_name_or_path, revision=revision, token=token) except (RequestsConnectionError, OfflineModeIsEnabled): raise RuntimeError( f"Hugging Face Hub is not reachable and we cannot infer the task from a cached model. Make sure you are not offline, or otherwise please specify the `task` (or `--task` in command-line) argument ({', '.join(TasksManager.get_all_tasks())})." ) - library_name = TasksManager.infer_library_from_model(model_name_or_path, subfolder, revision) + library_name = cls.infer_library_from_model( + model_name_or_path, + subfolder=subfolder, + revision=revision, + cache_dir=cache_dir, + token=token, + ) - if library_name == "diffusers": - if model_info.config["diffusers"].get("class_name", None): - class_name = model_info.config["diffusers"]["class_name"] - elif model_info.config["diffusers"].get("_class_name", None): - class_name = model_info.config["diffusers"]["_class_name"] - else: - raise ValueError( - f"Could not automatically infer the class name for {model_name_or_path}. Please open an issue at https://github.com/huggingface/optimum/issues." - ) - inferred_task_name = "stable-diffusion-xl" if "StableDiffusionXL" in class_name else "stable-diffusion" - elif library_name == "timm": + if library_name == "timm": inferred_task_name = "image-classification" - else: - pipeline_tag = getattr(model_info, "pipeline_tag", None) - # The Hub task "conversational" is not a supported task per se, just an alias that may map to - # text-generaton or text2text-generation. - # The Hub task "object-detection" is not a supported task per se, as in Transformers this may map to either - # zero-shot-object-detection or object-detection. - if pipeline_tag is not None and pipeline_tag not in ["conversational", "object-detection"]: - inferred_task_name = TasksManager.map_from_synonym(model_info.pipeline_tag) - else: - transformers_info = model_info.transformersInfo - if transformers_info is not None and transformers_info.get("pipeline_tag") is not None: - inferred_task_name = TasksManager.map_from_synonym(transformers_info["pipeline_tag"]) - else: - # transformersInfo does not always have a pipeline_tag attribute - class_name_prefix = "" - if is_torch_available(): - tasks_to_automodels = TasksManager._LIBRARY_TO_TASKS_TO_MODEL_LOADER_MAP[library_name] - else: - tasks_to_automodels = TasksManager._LIBRARY_TO_TF_TASKS_TO_MODEL_LOADER_MAP[library_name] - class_name_prefix = "TF" - - auto_model_class_name = transformers_info["auto_model"] - if not auto_model_class_name.startswith("TF"): - auto_model_class_name = f"{class_name_prefix}{auto_model_class_name}" - for task_name, class_name_for_task in tasks_to_automodels.items(): - if class_name_for_task == auto_model_class_name: - inferred_task_name = task_name + elif library_name == "diffusers": + pipeline_tag = pipeline_tag = model_info.pipeline_tag + model_config = model_info.config + if pipeline_tag is not None: + inferred_task_name = cls.map_from_synonym(pipeline_tag) + elif model_config is not None: + if model_config is not None and model_config.get("diffusers", None) is not None: + diffusers_class_name = model_config["diffusers"]["_class_name"] + for task_name, model_mapping in cls._DIFFUSERS_TASKS_TO_MODEL_MAPPINGS.items(): + for model_type, model_class_name in model_mapping.items(): + if diffusers_class_name == model_class_name: + inferred_task_name = task_name + break + if inferred_task_name is not None: + break + elif library_name == "transformers": + pipeline_tag = model_info.pipeline_tag + transformers_info = model_info.transformersInfo + if pipeline_tag is not None: + inferred_task_name = cls.map_from_synonym(model_info.pipeline_tag) + elif transformers_info is not None: + transformers_pipeline_tag = transformers_info.get("pipeline_tag", None) + transformers_auto_model = transformers_info.get("auto_model", None) + if transformers_pipeline_tag is not None: + pipeline_tag = transformers_info["pipeline_tag"] + inferred_task_name = cls.map_from_synonym(pipeline_tag) + elif transformers_auto_model is not None: + transformers_auto_model = transformers_auto_model.replace("TF", "") + for task_name, model_loaders in cls._TRANSFORMERS_TASKS_TO_MODEL_LOADERS.items(): + if isinstance(model_loaders, str): + model_loaders = (model_loaders,) + for model_loader_class_name in model_loaders: + if transformers_auto_model == model_loader_class_name: + inferred_task_name = task_name + break + if inferred_task_name is not None: break if inferred_task_name is None: - raise KeyError(f"Could not find the proper task name for {auto_model_class_name}.") + raise KeyError(f"Could not find the proper task name for the model {model_name_or_path}.") + return inferred_task_name @classmethod def infer_task_from_model( cls, - model: Union[str, "PreTrainedModel", "TFPreTrainedModel", Type], + model: Union[str, "PreTrainedModel", "TFPreTrainedModel", "DiffusionPipeline", Type], subfolder: str = "", revision: Optional[str] = None, + cache_dir: str = HUGGINGFACE_HUB_CACHE, + token: Optional[Union[bool, str]] = None, ) -> str: """ - Infers the task from the model repo. + Infers the task from the model repo, model instance, or model class. Args: - model (`str`): + model (`Union[str, PreTrainedModel, TFPreTrainedModel, DiffusionPipeline, Type]`): The model to infer the task from. This can either be the name of a repo on the HuggingFace Hub, an instance of a model, or a model class. subfolder (`str`, *optional*, defaults to `""`): @@ -1674,64 +1730,82 @@ def infer_task_from_model( Face Hub, you can specify the subfolder name here. revision (`Optional[str]`, defaults to `None`): Revision is the specific model version to use. It can be a branch name, a tag name, or a commit id. + cache_dir (`Optional[str]`, *optional*): + Path to a directory in which a downloaded pretrained model weights have been cached if the standard cache should not be used. + token (`Optional[Union[bool,str]]`, defaults to `None`): + The token to use as HTTP bearer authorization for remote files. If `True`, will use the token generated + when running `huggingface-cli login` (stored in `huggingface_hub.constants.HF_TOKEN_PATH`). + Returns: - `str`: The task name automatically detected from the model repo. + `str`: The task name automatically detected from the HF hub repo, model instance, or model class. """ - is_torch_pretrained_model = is_torch_available() and isinstance(model, PreTrainedModel) - is_tf_pretrained_model = is_tf_available() and isinstance(model, TFPreTrainedModel) - task = None + inferred_task_name = None + if isinstance(model, str): - task = cls._infer_task_from_model_name_or_path(model, subfolder=subfolder, revision=revision) - elif is_torch_pretrained_model or is_tf_pretrained_model: - task = cls._infer_task_from_model_or_model_class(model=model) - elif inspect.isclass(model): - task = cls._infer_task_from_model_or_model_class(model_class=model) + inferred_task_name = cls._infer_task_from_model_name_or_path( + model_name_or_path=model, + subfolder=subfolder, + revision=revision, + cache_dir=cache_dir, + token=token, + ) + elif type(model) == type: + inferred_task_name = cls._infer_task_from_model_or_model_class(model_class=model) + else: + inferred_task_name = cls._infer_task_from_model_or_model_class(model=model) - if task is None: - raise ValueError(f"Could not infer the task from {model}.") + if inferred_task_name is None: + raise ValueError( + "The task name could not be automatically inferred. If using the command-line, please provide the argument --task task-name. Example: `--task text-classification`." + ) - return task + return inferred_task_name - @staticmethod - def _infer_library_from_model( - model: Union["PreTrainedModel", "TFPreTrainedModel"], library_name: Optional[str] = None + @classmethod + def _infer_library_from_model_or_model_class( + cls, + model: Optional[Union["PreTrainedModel", "TFPreTrainedModel", "DiffusionPipeline"]] = None, + model_class: Optional[Type[Union["PreTrainedModel", "TFPreTrainedModel", "DiffusionPipeline"]]] = None, ): - if library_name is not None: - return library_name + if model is not None and model_class is not None: + raise ValueError("Either a model or a model class must be provided, but both were given here.") + if model is None and model_class is None: + raise ValueError("Either a model or a model class must be provided, but none were given here.") + + target_class_module = model.__class__.__module__ if model is not None else model_class.__module__ - # SentenceTransformer models have no config attributes - if hasattr(model, "_model_config"): + if target_class_module.startswith("sentence_transformers"): library_name = "sentence_transformers" - elif ( - hasattr(model, "pretrained_cfg") - or hasattr(model.config, "pretrained_cfg") - or hasattr(model.config, "architecture") - ): - library_name = "timm" - elif hasattr(model.config, "_diffusers_version") or getattr(model, "config_name", "") == "model_index.json": - library_name = "diffusers" - else: + elif target_class_module.startswith("transformers"): library_name = "transformers" + elif target_class_module.startswith("diffusers"): + library_name = "diffusers" + elif target_class_module.startswith("timm"): + library_name = "timm" + + if library_name is None: + raise ValueError( + "The library name could not be automatically inferred. If using the command-line, please provide the argument --library {transformers,diffusers,timm,sentence_transformers}. Example: `--library diffusers`." + ) + return library_name @classmethod - def infer_library_from_model( + def _infer_library_from_model_name_or_path( cls, model_name_or_path: Union[str, Path], subfolder: str = "", revision: Optional[str] = None, cache_dir: str = HUGGINGFACE_HUB_CACHE, - library_name: Optional[str] = None, - use_auth_token: Optional[Union[bool, str]] = None, token: Optional[Union[bool, str]] = None, ): """ - Infers the library from the model repo. + Infers the library from the model name or path. Args: model_name_or_path (`str`): - The model to infer the task from. This can either be the name of a repo on the HuggingFace Hub, an - instance of a model, or a model class. + The model to infer the task from. This can either be the name of a repo on the HuggingFace Hub, or a path + to a local directory containing the model. subfolder (`str`, defaults to `""`): In case the model files are located inside a subfolder of the model directory / repo on the Hugging Face Hub, you can specify the subfolder name here. @@ -1739,10 +1813,6 @@ def infer_library_from_model( Revision is the specific model version to use. It can be a branch name, a tag name, or a commit id. cache_dir (`Optional[str]`, *optional*): Path to a directory in which a downloaded pretrained model weights have been cached if the standard cache should not be used. - library_name (`Optional[str]`, *optional*): - The library name of the model. Can be any of "transformers", "timm", "diffusers", "sentence_transformers". - use_auth_token (`Optional[Union[bool,str]]`, defaults to `None`): - Deprecated. Please use the `token` argument instead. token (`Optional[Union[bool,str]]`, defaults to `None`): The token to use as HTTP bearer authorization for remote files. If `True`, will use the token generated when running `huggingface-cli login` (stored in `huggingface_hub.constants.HF_TOKEN_PATH`). @@ -1751,72 +1821,64 @@ def infer_library_from_model( `str`: The library name automatically detected from the model repo. """ - if use_auth_token is not None: - warnings.warn( - "The `use_auth_token` argument is deprecated and will be removed soon. Please use the `token` argument instead.", - FutureWarning, - ) - if token is not None: - raise ValueError("You cannot use both `use_auth_token` and `token` arguments at the same time.") - token = use_auth_token - - if library_name is not None: - return library_name + inferred_library_name = None all_files, _ = TasksManager.get_model_files( - model_name_or_path, subfolder, cache_dir, token=token, revision=revision + model_name_or_path, + subfolder=subfolder, + cache_dir=cache_dir, + revision=revision, + token=token, ) if "model_index.json" in all_files: - library_name = "diffusers" + inferred_library_name = "diffusers" elif ( any(file_path.startswith("sentence_") for file_path in all_files) or "config_sentence_transformers.json" in all_files ): - library_name = "sentence_transformers" - elif CONFIG_NAME in all_files: - # We do not use PretrainedConfig.from_pretrained which has unwanted warnings about model type. + inferred_library_name = "sentence_transformers" + elif "config.json" in all_files: kwargs = { "subfolder": subfolder, "revision": revision, "cache_dir": cache_dir, "token": token, } + # We do not use PretrainedConfig.from_pretrained which has unwanted warnings about model type. config_dict, kwargs = PretrainedConfig.get_config_dict(model_name_or_path, **kwargs) model_config = PretrainedConfig.from_dict(config_dict, **kwargs) if hasattr(model_config, "pretrained_cfg") or hasattr(model_config, "architecture"): - library_name = "timm" + inferred_library_name = "timm" elif hasattr(model_config, "_diffusers_version"): - library_name = "diffusers" + inferred_library_name = "diffusers" else: - library_name = "transformers" - else: - library_name = "transformers" + inferred_library_name = "transformers" - if library_name is None: + if inferred_library_name is None: raise ValueError( "The library name could not be automatically inferred. If using the command-line, please provide the argument --library {transformers,diffusers,timm,sentence_transformers}. Example: `--library diffusers`." ) - return library_name + return inferred_library_name @classmethod - def standardize_model_attributes( + def infer_library_from_model( cls, - model: Union["PreTrainedModel", "TFPreTrainedModel"], - library_name: Optional[str] = None, + model: Union[str, "PreTrainedModel", "TFPreTrainedModel", "DiffusionPipeline", Type], + subfolder: str = "", + revision: Optional[str] = None, + cache_dir: str = HUGGINGFACE_HUB_CACHE, + token: Optional[Union[bool, str]] = None, ): """ - Updates the model for export. This function is suitable to make required changes to the models from different - libraries to follow transformers style. + Infers the library from the model repo, model instance, or model class. Args: - model_name_or_path (`Union[str, Path]`): - Can be either the model id of a model repo on the Hugging Face Hub, or a path to a local directory - containing a model. - model (`Union[PreTrainedModel, TFPreTrainedModel]`): - The instance of the model. + model (`Union[str, PreTrainedModel, TFPreTrainedModel, DiffusionPipeline, Type]`): + The model to infer the task from. This can either be the name of a repo on the HuggingFace Hub, an + instance of a model, or a model class. subfolder (`str`, defaults to `""`): In case the model files are located inside a subfolder of the model directory / repo on the Hugging Face Hub, you can specify the subfolder name here. @@ -1824,20 +1886,66 @@ def standardize_model_attributes( Revision is the specific model version to use. It can be a branch name, a tag name, or a commit id. cache_dir (`Optional[str]`, *optional*): Path to a directory in which a downloaded pretrained model weights have been cached if the standard cache should not be used. - library_name (`Optional[str]`, *optional*):: - The library name of the model. Can be any of "transformers", "timm", "diffusers", "sentence_transformers". + token (`Optional[Union[bool,str]]`, defaults to `None`): + The token to use as HTTP bearer authorization for remote files. If `True`, will use the token generated + when running `huggingface-cli login` (stored in `huggingface_hub.constants.HF_TOKEN_PATH`). + + Returns: + `str`: The library name automatically detected from the model repo, model instance, or model class. """ - library_name = TasksManager._infer_library_from_model(model, library_name) + + if isinstance(model, str): + library_name = cls._infer_library_from_model_name_or_path( + model_name_or_path=model, + subfolder=subfolder, + revision=revision, + cache_dir=cache_dir, + token=token, + ) + elif type(model) == type: + library_name = cls._infer_library_from_model_or_model_class(model_class=model) + else: + library_name = cls._infer_library_from_model_or_model_class(model=model) + + return library_name + + @classmethod + def standardize_model_attributes(cls, model: Union["PreTrainedModel", "TFPreTrainedModel", "DiffusionPipeline"]): + """ + Updates the model for export. This function is suitable to make required changes to the models from different + libraries to follow transformers style. + + Args: + model (`Union[PreTrainedModel, TFPreTrainedModel, DiffusionPipeline]`): + The instance of the model. + + """ + + library_name = TasksManager.infer_library_from_model(model) if library_name == "diffusers": - model.config.export_model_type = "stable-diffusion" - elif library_name == "timm": - # Retrieve model config - model_config = PretrainedConfig.from_dict(model.pretrained_cfg) + inferred_model_type = None + + for task_name, model_mapping in cls._DIFFUSERS_TASKS_TO_MODEL_MAPPINGS.items(): + for model_type, model_class_name in model_mapping.items(): + if model.__class__.__name__ == model_class_name: + inferred_model_type = model_type + break + if inferred_model_type is not None: + break + + if inferred_model_type is None: + raise ValueError( + f"The export of a DiffusionPipeline model with the class name {model.__class__.__name__} is currently not supported in Optimum. " + "Please open an issue or submit a PR to add the support." + ) - # Set config as in transformers - setattr(model, "config", model_config) + # `model_type` is a class attribute in Transformers, let's avoid modifying it. + model.config.export_model_type = inferred_model_type + elif library_name == "timm": + # Retrieve model config and set it like in transformers + model.config = PretrainedConfig.from_dict(model.pretrained_cfg) # `model_type` is a class attribute in Transformers, let's avoid modifying it. model.config.export_model_type = model.pretrained_cfg["architecture"] @@ -1881,13 +1989,14 @@ def get_model_from_task( model_name_or_path: Union[str, Path], subfolder: str = "", revision: Optional[str] = None, - framework: Optional[str] = None, cache_dir: str = HUGGINGFACE_HUB_CACHE, + token: Optional[Union[bool, str]] = None, + framework: Optional[str] = None, torch_dtype: Optional["torch.dtype"] = None, device: Optional[Union["torch.device", str]] = None, - library_name: str = None, + library_name: Optional[str] = None, **model_kwargs, - ) -> Union["PreTrainedModel", "TFPreTrainedModel"]: + ) -> Union["PreTrainedModel", "TFPreTrainedModel", "DiffusionPipeline"]: """ Retrieves a model from its name and the task to be enabled. @@ -1902,34 +2011,44 @@ def get_model_from_task( Face Hub, you can specify the subfolder name here. revision (`Optional[str]`, *optional*): Revision is the specific model version to use. It can be a branch name, a tag name, or a commit id. + cache_dir (`Optional[str]`, *optional*): + Path to a directory in which a downloaded pretrained model weights have been cached if the standard cache should not be used. + token (`Optional[Union[bool,str]]`, defaults to `None`): + The token to use as HTTP bearer authorization for remote files. If `True`, will use the token generated + when running `huggingface-cli login` (stored in `huggingface_hub.constants.HF_TOKEN_PATH`). framework (`Optional[str]`, *optional*): The framework to use for the export. See `TasksManager.determine_framework` for the priority should none be provided. - cache_dir (`Optional[str]`, *optional*): - Path to a directory in which a downloaded pretrained model weights have been cached if the standard cache should not be used. torch_dtype (`Optional[torch.dtype]`, defaults to `None`): Data type to load the model on. PyTorch-only argument. device (`Optional[torch.device]`, defaults to `None`): Device to initialize the model on. PyTorch-only argument. For PyTorch, defaults to "cpu". - model_kwargs (`Dict[str, Any]`, *optional*): - Keyword arguments to pass to the model `.from_pretrained()` method. library_name (`Optional[str]`, defaults to `None`): The library name of the model. Can be any of "transformers", "timm", "diffusers", "sentence_transformers". See `TasksManager.infer_library_from_model` for the priority should none be provided. + model_kwargs (`Dict[str, Any]`, *optional*): + Keyword arguments to pass to the model `.from_pretrained()` method. Returns: The instance of the model. """ - framework = TasksManager.determine_framework(model_name_or_path, subfolder=subfolder, framework=framework) + + if framework is None: + framework = TasksManager.determine_framework( + model_name_or_path, subfolder=subfolder, revision=revision, cache_dir=cache_dir, token=token + ) + + if library_name is None: + library_name = TasksManager.infer_library_from_model( + model_name_or_path, subfolder=subfolder, revision=revision, cache_dir=cache_dir, token=token + ) original_task = task if task == "auto": - task = TasksManager.infer_task_from_model(model_name_or_path, subfolder=subfolder, revision=revision) - - library_name = TasksManager.infer_library_from_model( - model_name_or_path, subfolder, revision, cache_dir, library_name - ) + task = TasksManager.infer_task_from_model( + model_name_or_path, subfolder=subfolder, revision=revision, cache_dir=cache_dir, token=token + ) model_type = None model_class_name = None @@ -2004,7 +2123,7 @@ def get_model_from_task( kwargs["from_pt"] = True model = model_class.from_pretrained(model_name_or_path, **kwargs) - TasksManager.standardize_model_attributes(model, library_name) + TasksManager.standardize_model_attributes(model) return model diff --git a/optimum/exporters/utils.py b/optimum/exporters/utils.py index 74d2d98385..902dd89f77 100644 --- a/optimum/exporters/utils.py +++ b/optimum/exporters/utils.py @@ -43,6 +43,18 @@ f"We found an older version of diffusers {_diffusers_version} but we require diffusers to be >= {DIFFUSERS_MINIMUM_VERSION}. " "Please update diffusers by running `pip install --upgrade diffusers`" ) + + from diffusers import ( + DiffusionPipeline, + LatentConsistencyModelImg2ImgPipeline, + LatentConsistencyModelPipeline, + StableDiffusionImg2ImgPipeline, + StableDiffusionInpaintPipeline, + StableDiffusionPipeline, + StableDiffusionXLImg2ImgPipeline, + StableDiffusionXLInpaintPipeline, + StableDiffusionXLPipeline, + ) from diffusers.models.attention_processor import ( Attention, AttnAddedKVProcessor, @@ -53,6 +65,7 @@ LoRAAttnProcessor2_0, ) + if TYPE_CHECKING: from .base import ExportConfig @@ -63,7 +76,7 @@ from transformers.modeling_tf_utils import TFPreTrainedModel if is_diffusers_available(): - from diffusers import ModelMixin, StableDiffusionPipeline + from diffusers import DiffusionPipeline, ModelMixin ENCODER_NAME = "encoder_model" @@ -72,23 +85,40 @@ DECODER_MERGED_NAME = "decoder_model_merged" -def _get_submodels_for_export_stable_diffusion( - pipeline: "StableDiffusionPipeline", +def _get_submodels_for_export_diffusion( + pipeline: "DiffusionPipeline", ) -> Dict[str, Union["PreTrainedModel", "ModelMixin"]]: """ Returns the components of a Stable Diffusion model. """ - from diffusers import StableDiffusionXLImg2ImgPipeline - models_for_export = {} - if isinstance(pipeline, StableDiffusionXLImg2ImgPipeline): + is_stable_diffusion = isinstance( + pipeline, (StableDiffusionPipeline, StableDiffusionImg2ImgPipeline, StableDiffusionInpaintPipeline) + ) + is_stable_diffusion_xl = isinstance( + pipeline, (StableDiffusionXLImg2ImgPipeline, StableDiffusionXLInpaintPipeline, StableDiffusionXLPipeline) + ) + is_latent_consistency_model = isinstance( + pipeline, (LatentConsistencyModelPipeline, LatentConsistencyModelImg2ImgPipeline) + ) + + if is_stable_diffusion_xl: projection_dim = pipeline.text_encoder_2.config.projection_dim - else: + elif is_stable_diffusion: projection_dim = pipeline.text_encoder.config.projection_dim + elif is_latent_consistency_model: + projection_dim = pipeline.text_encoder.config.projection_dim + else: + raise ValueError( + f"The export of a DiffusionPipeline model with the class name {pipeline.__class__.__name__} is currently not supported in Optimum. " + "Please open an issue or submit a PR to add the support." + ) + + models_for_export = {} # Text encoder if pipeline.text_encoder is not None: - if isinstance(pipeline, StableDiffusionXLImg2ImgPipeline): + if is_stable_diffusion_xl: pipeline.text_encoder.config.output_hidden_states = True models_for_export["text_encoder"] = pipeline.text_encoder @@ -97,6 +127,7 @@ def _get_submodels_for_export_stable_diffusion( is_torch_greater_or_equal_than_2_1 = version.parse(torch.__version__) >= version.parse("2.1.0") if not is_torch_greater_or_equal_than_2_1: pipeline.unet.set_attn_processor(AttnProcessor()) + pipeline.unet.config.text_encoder_projection_dim = projection_dim # The U-NET time_ids inputs shapes depends on the value of `requires_aesthetics_score` # https://github.com/huggingface/diffusers/blob/v0.18.2/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py#L571 @@ -258,17 +289,17 @@ def get_decoder_models_for_export( return models_for_export -def get_stable_diffusion_models_for_export( - pipeline: "StableDiffusionPipeline", +def get_diffusion_models_for_export( + pipeline: "DiffusionPipeline", int_dtype: str = "int64", float_dtype: str = "fp32", exporter: str = "onnx", ) -> Dict[str, Tuple[Union["PreTrainedModel", "ModelMixin"], "ExportConfig"]]: """ - Returns the components of a Stable Diffusion model and their subsequent export configs. + Returns the components of a Diffusion model and their subsequent export configs. Args: - pipeline ([`StableDiffusionPipeline`]): + pipeline ([`DiffusionPipeline`]): The model to export. int_dtype (`str`, defaults to `"int64"`): The data type of integer tensors, could be ["int64", "int32", "int8"], default to "int64". @@ -279,7 +310,7 @@ def get_stable_diffusion_models_for_export( `Dict[str, Tuple[Union[`PreTrainedModel`, `TFPreTrainedModel`], `ExportConfig`]: A Dict containing the model and export configs for the different components of the model. """ - models_for_export = _get_submodels_for_export_stable_diffusion(pipeline) + models_for_export = _get_submodels_for_export_diffusion(pipeline) # Text encoder if "text_encoder" in models_for_export: @@ -505,7 +536,7 @@ def override_diffusers_2_0_attn_processors(model): def _get_submodels_and_export_configs( - model: Union["PreTrainedModel", "TFPreTrainedModel"], + model: Union["PreTrainedModel", "TFPreTrainedModel", "DiffusionPipeline"], task: str, monolith: bool, custom_export_configs: Dict, @@ -523,7 +554,7 @@ def _get_submodels_and_export_configs( if not custom_architecture: if library_name == "diffusers": export_config = None - models_and_export_configs = get_stable_diffusion_models_for_export( + models_and_export_configs = get_diffusion_models_for_export( model, int_dtype=int_dtype, float_dtype=float_dtype, exporter=exporter ) else: @@ -575,7 +606,7 @@ def _get_submodels_and_export_configs( submodels_for_export = fn_get_submodels(model) else: if library_name == "diffusers": - submodels_for_export = _get_submodels_for_export_stable_diffusion(model) + submodels_for_export = _get_submodels_for_export_diffusion(model) elif ( model.config.is_encoder_decoder and task.startswith(TasksManager._ENCODER_DECODER_TASKS) @@ -599,7 +630,7 @@ def _get_submodels_and_export_configs( for key, custom_export_config in custom_export_configs.items(): models_and_export_configs[key] = (submodels_for_export[key], custom_export_config) - # Default to the first ONNX config for stable-diffusion and custom architecture case. + # Default to the first ONNX config for diffusion and custom architecture case. if export_config is None: export_config = next(iter(models_and_export_configs.values()))[1] diff --git a/optimum/utils/import_utils.py b/optimum/utils/import_utils.py index a5df9e2624..4a57fda79c 100644 --- a/optimum/utils/import_utils.py +++ b/optimum/utils/import_utils.py @@ -50,7 +50,7 @@ def _is_package_available(pkg_name: str, return_version: bool = False) -> Union[ TORCH_MINIMUM_VERSION = version.parse("1.11.0") TRANSFORMERS_MINIMUM_VERSION = version.parse("4.25.0") -DIFFUSERS_MINIMUM_VERSION = version.parse("0.18.0") +DIFFUSERS_MINIMUM_VERSION = version.parse("0.22.0") AUTOGPTQ_MINIMUM_VERSION = version.parse("0.4.99") # Allows 0.5.0.dev0 diff --git a/tests/exporters/exporters_utils.py b/tests/exporters/exporters_utils.py index 9c5d2c8991..a55c7a124d 100644 --- a/tests/exporters/exporters_utils.py +++ b/tests/exporters/exporters_utils.py @@ -295,9 +295,10 @@ "roberta": "roberta-base", } -PYTORCH_STABLE_DIFFUSION_MODEL = { +PYTORCH_DIFFUSION_MODEL = { "stable-diffusion": "hf-internal-testing/tiny-stable-diffusion-torch", "stable-diffusion-xl": "echarlaix/tiny-random-stable-diffusion-xl", + "lcm": "echarlaix/tiny-random-latent-consistency", } PYTORCH_TIMM_MODEL = { diff --git a/tests/exporters/onnx/test_exporters_onnx_cli.py b/tests/exporters/onnx/test_exporters_onnx_cli.py index 667206b006..ed611ade04 100644 --- a/tests/exporters/onnx/test_exporters_onnx_cli.py +++ b/tests/exporters/onnx/test_exporters_onnx_cli.py @@ -41,9 +41,9 @@ from ..exporters_utils import ( NO_DYNAMIC_AXES_EXPORT_SHAPES_TRANSFORMERS, + PYTORCH_DIFFUSION_MODEL, PYTORCH_EXPORT_MODELS_TINY, PYTORCH_SENTENCE_TRANSFORMERS_MODEL, - PYTORCH_STABLE_DIFFUSION_MODEL, PYTORCH_TIMM_MODEL, PYTORCH_TIMM_MODEL_NO_DYNAMIC_AXES, PYTORCH_TRANSFORMERS_MODEL_NO_DYNAMIC_AXES, @@ -252,29 +252,29 @@ def _onnx_export_no_dynamic_axes( except MinimumVersionError as e: pytest.skip(f"Skipping due to minimum version requirements not met. Full error: {e}") - @parameterized.expand(PYTORCH_STABLE_DIFFUSION_MODEL.items()) + @parameterized.expand(PYTORCH_DIFFUSION_MODEL.items()) @require_torch @require_vision @require_diffusers - def test_exporters_cli_pytorch_cpu_stable_diffusion(self, model_type: str, model_name: str): + def test_exporters_cli_pytorch_cpu_diffusion(self, model_type: str, model_name: str): self._onnx_export(model_name, model_type) - @parameterized.expand(PYTORCH_STABLE_DIFFUSION_MODEL.items()) + @parameterized.expand(PYTORCH_DIFFUSION_MODEL.items()) @require_torch_gpu @require_vision @require_diffusers @slow @pytest.mark.run_slow - def test_exporters_cli_pytorch_gpu_stable_diffusion(self, model_type: str, model_name: str): + def test_exporters_cli_pytorch_gpu_diffusion(self, model_type: str, model_name: str): self._onnx_export(model_name, model_type, device="cuda") - @parameterized.expand(PYTORCH_STABLE_DIFFUSION_MODEL.items()) + @parameterized.expand(PYTORCH_DIFFUSION_MODEL.items()) @require_torch_gpu @require_vision @require_diffusers @slow @pytest.mark.run_slow - def test_exporters_cli_fp16_stable_diffusion(self, model_type: str, model_name: str): + def test_exporters_cli_fp16_diffusion(self, model_type: str, model_name: str): self._onnx_export(model_name, model_type, device="cuda", fp16=True) @parameterized.expand( @@ -594,7 +594,7 @@ def test_trust_remote_code(self): check=True, ) - def test_stable_diffusion(self): + def test_diffusion(self): with TemporaryDirectory() as tmpdirname: subprocess.run( f"python3 -m optimum.exporters.onnx --model hf-internal-testing/tiny-stable-diffusion-torch --task stable-diffusion {tmpdirname}", diff --git a/tests/exporters/onnx/test_onnx_export.py b/tests/exporters/onnx/test_onnx_export.py index 9eddc4c86d..d1471aa218 100644 --- a/tests/exporters/onnx/test_onnx_export.py +++ b/tests/exporters/onnx/test_onnx_export.py @@ -32,8 +32,8 @@ OnnxConfigWithPast, export_models, get_decoder_models_for_export, + get_diffusion_models_for_export, get_encoder_decoder_models_for_export, - get_stable_diffusion_models_for_export, main_export, onnx_export_from_model, validate_models_outputs, @@ -48,9 +48,9 @@ from optimum.utils.testing_utils import grid_parameters, require_diffusers from ..exporters_utils import ( + PYTORCH_DIFFUSION_MODEL, PYTORCH_EXPORT_MODELS_TINY, PYTORCH_SENTENCE_TRANSFORMERS_MODEL, - PYTORCH_STABLE_DIFFUSION_MODEL, PYTORCH_TIMM_MODEL, TENSORFLOW_EXPORT_MODELS, VALIDATE_EXPORT_ON_SHAPES_SLOW, @@ -294,7 +294,7 @@ def _onnx_export( def _onnx_export_sd(self, model_type: str, model_name: str, device="cpu"): pipeline = TasksManager.get_model_from_task(model_type, model_name, device=device) - models_and_onnx_configs = get_stable_diffusion_models_for_export(pipeline) + models_and_onnx_configs = get_diffusion_models_for_export(pipeline) output_names = [os.path.join(name_dir, ONNX_WEIGHTS_NAME) for name_dir in models_and_onnx_configs] model, _ = models_and_onnx_configs["vae_encoder"] model.forward = lambda sample: {"latent_sample": model.encode(x=sample)["latent_dist"].parameters} @@ -398,14 +398,14 @@ def test_tensorflow_export( self._onnx_export(test_name, model_type, model_name, task, onnx_config_class_constructor, monolith=monolith) - @parameterized.expand(PYTORCH_STABLE_DIFFUSION_MODEL.items()) + @parameterized.expand(PYTORCH_DIFFUSION_MODEL.items()) @require_torch @require_vision @require_diffusers - def test_pytorch_export_for_stable_diffusion_models(self, model_type, model_name): + def test_pytorch_export_for_diffusion_models(self, model_type, model_name): self._onnx_export_sd(model_type, model_name) - @parameterized.expand(PYTORCH_STABLE_DIFFUSION_MODEL.items()) + @parameterized.expand(PYTORCH_DIFFUSION_MODEL.items()) @require_torch @require_vision @require_diffusers @@ -413,7 +413,7 @@ def test_pytorch_export_for_stable_diffusion_models(self, model_type, model_name @slow @pytest.mark.run_slow @pytest.mark.gpu_test - def test_pytorch_export_for_stable_diffusion_models_cuda(self, model_type, model_name): + def test_pytorch_export_for_diffusion_models_cuda(self, model_type, model_name): self._onnx_export_sd(model_type, model_name, device="cuda")