diff --git a/optimum/onnxruntime/modeling_diffusion.py b/optimum/onnxruntime/modeling_diffusion.py
index 4d52dd9cd9..0c965d76bc 100644
--- a/optimum/onnxruntime/modeling_diffusion.py
+++ b/optimum/onnxruntime/modeling_diffusion.py
@@ -16,7 +16,6 @@
 import logging
 import os
 import shutil
-import warnings
 from collections import OrderedDict
 from pathlib import Path
 from tempfile import TemporaryDirectory
@@ -31,6 +30,7 @@
     AutoPipelineForImage2Image,
     AutoPipelineForInpainting,
     AutoPipelineForText2Image,
+    DiffusionPipeline,
     LatentConsistencyModelImg2ImgPipeline,
     LatentConsistencyModelPipeline,
     StableDiffusionImg2ImgPipeline,
@@ -42,7 +42,7 @@
 )
 from diffusers.schedulers import SchedulerMixin
 from diffusers.schedulers.scheduling_utils import SCHEDULER_CONFIG_NAME
-from diffusers.utils import CONFIG_NAME, is_invisible_watermark_available
+from diffusers.utils import is_invisible_watermark_available
 from huggingface_hub import snapshot_download
 from huggingface_hub.constants import HUGGINGFACE_HUB_CACHE
 from huggingface_hub.utils import validate_hf_hub_args
@@ -79,7 +79,7 @@
 logger = logging.getLogger(__name__)
 
 
-class ORTPipeline(ORTModel):
+class ORTPipeline(ORTModel, ConfigMixin):
     auto_model_class = None
     model_type = "onnx_pipeline"
 
@@ -89,12 +89,12 @@ class ORTPipeline(ORTModel):
     def __init__(
         self,
         config: Dict[str, Any],
-        scheduler: SchedulerMixin,
-        unet: ort.InferenceSession,
+        unet: Optional[ort.InferenceSession] = None,
         vae_encoder: Optional[ort.InferenceSession] = None,
         vae_decoder: Optional[ort.InferenceSession] = None,
         text_encoder: Optional[ort.InferenceSession] = None,
         text_encoder_2: Optional[ort.InferenceSession] = None,
+        scheduler: Optional[SchedulerMixin] = None,
         tokenizer: Optional[CLIPTokenizer] = None,
         tokenizer_2: Optional[CLIPTokenizer] = None,
         feature_extractor: Optional[CLIPFeatureExtractor] = None,
@@ -151,52 +151,51 @@ def __init__(
         if hasattr(self.vae.config, "block_out_channels"):
             self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
         else:
-            self.vae_scale_factor = 8
+            self.vae_scale_factor = 8  # for old configs without block_out_channels
 
         self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
         self.mask_processor = VaeImageProcessor(
             vae_scale_factor=self.vae_scale_factor, do_normalize=False, do_binarize=True, do_convert_grayscale=True
         )
 
-        sub_models = {
+        # Modify config to keep the resulting model compatible with diffusers pipelines
+        models_to_subfolder = {
             self.unet: DIFFUSION_MODEL_UNET_SUBFOLDER,
             self.vae_decoder: DIFFUSION_MODEL_VAE_DECODER_SUBFOLDER,
             self.vae_encoder: DIFFUSION_MODEL_VAE_ENCODER_SUBFOLDER,
             self.text_encoder: DIFFUSION_MODEL_TEXT_ENCODER_SUBFOLDER,
             self.text_encoder_2: DIFFUSION_MODEL_TEXT_ENCODER_2_SUBFOLDER,
         }
-
-        # Modify config to keep the resulting model compatible with diffusers pipelines
-        for model, model_name in sub_models.items():
-            config[model_name] = ("optimum", model.__class__.__name__) if model is not None else (None, None)
+        for model, model_subfolder in models_to_subfolder.items():
+            config[model_subfolder] = ("optimum", model.__class__.__name__) if model is not None else (None, None)
 
         self._internal_dict = FrozenDict(config)
         self.shared_attributes_init(model=unet, use_io_binding=use_io_binding, model_save_dir=model_save_dir)
 
     @staticmethod
     def load_model(
-        vae_decoder_path: Union[str, Path],
-        text_encoder_path: Union[str, Path],
         unet_path: Union[str, Path],
         vae_encoder_path: Optional[Union[str, Path]] = None,
+        vae_decoder_path: Optional[Union[str, Path]] = None,
+        text_encoder_path: Optional[Union[str, Path]] = None,
         text_encoder_2_path: Optional[Union[str, Path]] = None,
         provider: str = "CPUExecutionProvider",
         session_options: Optional[ort.SessionOptions] = None,
         provider_options: Optional[Dict] = None,
     ):
         """
-        Creates three inference sessions for respectively the VAE decoder, the text encoder and the U-NET models.
+        Creates inference sessions for the components of a Diffusion Pipeline (U-Net, VAE, text encoders).
         The default provider is `CPUExecutionProvider` to match the default behaviour in PyTorch/TensorFlow/JAX.
 
         Args:
-            vae_decoder_path (`Union[str, Path]`):
-                The path to the VAE decoder ONNX model.
-            text_encoder_path (`Union[str, Path]`):
-                The path to the text encoder ONNX model.
             unet_path (`Union[str, Path]`):
                 The path to the U-NET ONNX model.
             vae_encoder_path (`Union[str, Path]`, defaults to `None`):
                 The path to the VAE encoder ONNX model.
+            vae_decoder_path (`Union[str, Path]`, defaults to `None`):
+                The path to the VAE decoder ONNX model.
+            text_encoder_path (`Union[str, Path]`, defaults to `None`):
+                The path to the text encoder ONNX model.
             text_encoder_2_path (`Union[str, Path]`, defaults to `None`):
                 The path to the second text decoder ONNX model.
             provider (`str`, defaults to `"CPUExecutionProvider"`):
@@ -208,50 +207,48 @@ def load_model(
                 Provider option dictionary corresponding to the provider used. See available options
                 for each provider: https://onnxruntime.ai/docs/api/c/group___global.html . Defaults to `None`.
         """
-        vae_decoder = ORTModel.load_model(vae_decoder_path, provider, session_options, provider_options)
-        unet = ORTModel.load_model(unet_path, provider, session_options, provider_options)
-
-        sessions = {
+        paths = {
+            "unet": unet_path,
             "vae_encoder": vae_encoder_path,
+            "vae_decoder": vae_decoder_path,
             "text_encoder": text_encoder_path,
             "text_encoder_2": text_encoder_2_path,
         }
 
-        for key, value in sessions.items():
-            if value is not None and value.is_file():
-                sessions[key] = ORTModel.load_model(value, provider, session_options, provider_options)
+        sessions = {}
+        for model_name, model_path in paths.items():
+            if model_path is not None and model_path.is_file():
+                sessions[model_name] = ORTModel.load_model(model_path, provider, session_options, provider_options)
             else:
-                sessions[key] = None
+                sessions[model_name] = None
 
-        return vae_decoder, sessions["text_encoder"], unet, sessions["vae_encoder"], sessions["text_encoder_2"]
+        return sessions
 
     def _save_pretrained(self, save_directory: Union[str, Path]):
         save_directory = Path(save_directory)
 
-        sub_models_to_save = {
-            self.unet: DIFFUSION_MODEL_UNET_SUBFOLDER,
-            self.vae_decoder: DIFFUSION_MODEL_VAE_DECODER_SUBFOLDER,
-            self.vae_encoder: DIFFUSION_MODEL_VAE_ENCODER_SUBFOLDER,
-            self.text_encoder: DIFFUSION_MODEL_TEXT_ENCODER_SUBFOLDER,
-            self.text_encoder_2: DIFFUSION_MODEL_TEXT_ENCODER_2_SUBFOLDER,
+        models_to_save_paths = {
+            self.unet: save_directory / DIFFUSION_MODEL_UNET_SUBFOLDER / ONNX_WEIGHTS_NAME,
+            self.vae_decoder: save_directory / DIFFUSION_MODEL_VAE_DECODER_SUBFOLDER / ONNX_WEIGHTS_NAME,
+            self.vae_encoder: save_directory / DIFFUSION_MODEL_VAE_ENCODER_SUBFOLDER / ONNX_WEIGHTS_NAME,
+            self.text_encoder: save_directory / DIFFUSION_MODEL_TEXT_ENCODER_SUBFOLDER / ONNX_WEIGHTS_NAME,
+            self.text_encoder_2: save_directory / DIFFUSION_MODEL_TEXT_ENCODER_2_SUBFOLDER / ONNX_WEIGHTS_NAME,
         }
-
-        for model, model_folder in sub_models_to_save.items():
+        for model, model_save_path in models_to_save_paths.items():
             if model is not None:
                 model_path = Path(model.session._model_path)
-                model_save_path = save_directory / model_folder / ONNX_WEIGHTS_NAME
                 model_save_path.parent.mkdir(parents=True, exist_ok=True)
                 # copy onnx model
                 shutil.copyfile(model_path, model_save_path)
-                # copy external data
+                # copy external onnx data
                 external_data_paths = _get_model_external_data_paths(model_path)
                 for external_data_path in external_data_paths:
                     shutil.copyfile(external_data_path, model_save_path.parent / external_data_path.name)
-                # copy config
-                shutil.copyfile(
-                    model_path.parent / self.sub_component_config_name,
-                    model_save_path.parent / self.sub_component_config_name,
-                )
+                # copy model config
+                config_path = model_path.parent / self.sub_component_config_name
+                if config_path.is_file():
+                    config_save_path = model_save_path.parent / self.sub_component_config_name
+                    shutil.copyfile(config_path, config_save_path)
 
         self.scheduler.save_pretrained(save_directory / "scheduler")
 
@@ -267,16 +264,15 @@ def _from_pretrained(
         cls,
         model_id: Union[str, Path],
         config: Dict[str, Any],
-        use_auth_token: Optional[Union[bool, str]] = None,
-        token: Optional[Union[bool, str]] = None,
+        local_files_only: bool = False,
         revision: Optional[str] = None,
         cache_dir: str = HUGGINGFACE_HUB_CACHE,
-        vae_decoder_file_name: str = ONNX_WEIGHTS_NAME,
-        text_encoder_file_name: str = ONNX_WEIGHTS_NAME,
+        token: Optional[Union[bool, str]] = None,
         unet_file_name: str = ONNX_WEIGHTS_NAME,
         vae_encoder_file_name: str = ONNX_WEIGHTS_NAME,
+        vae_decoder_file_name: str = ONNX_WEIGHTS_NAME,
+        text_encoder_file_name: str = ONNX_WEIGHTS_NAME,
         text_encoder_2_file_name: str = ONNX_WEIGHTS_NAME,
-        local_files_only: bool = False,
         provider: str = "CPUExecutionProvider",
         session_options: Optional[ort.SessionOptions] = None,
         provider_options: Optional[Dict[str, Any]] = None,
@@ -284,18 +280,6 @@ def _from_pretrained(
         model_save_dir: Optional[Union[str, Path, TemporaryDirectory]] = None,
         **kwargs,
     ):
-        if use_auth_token is not None:
-            warnings.warn(
-                "The `use_auth_token` argument is deprecated and will be removed soon. Please use the `token` argument instead.",
-                FutureWarning,
-            )
-            if token is not None:
-                raise ValueError("You cannot use both `use_auth_token` and `token` arguments at the same time.")
-            token = use_auth_token
-
-        if provider == "TensorrtExecutionProvider":
-            raise ValueError("The provider `'TensorrtExecutionProvider'` is not supported")
-
         model_id = str(model_id)
         patterns = set(config.keys())
         sub_models_to_load = patterns.intersection({"feature_extractor", "tokenizer", "tokenizer_2", "scheduler"})
@@ -305,13 +289,13 @@ def _from_pretrained(
             allow_patterns = {os.path.join(k, "*") for k in patterns if not k.startswith("_")}
             allow_patterns.update(
                 {
-                    vae_decoder_file_name,
-                    text_encoder_file_name,
                     unet_file_name,
                     vae_encoder_file_name,
+                    vae_decoder_file_name,
+                    text_encoder_file_name,
                     text_encoder_2_file_name,
+                    cls.sub_component_config_name,
                     SCHEDULER_CONFIG_NAME,
-                    CONFIG_NAME,
                     cls.config_name,
                 }
             )
@@ -340,14 +324,18 @@ def _from_pretrained(
             else:
                 sub_models[name] = load_method(new_model_save_dir)
 
-        vae_decoder, text_encoder, unet, vae_encoder, text_encoder_2 = cls.load_model(
-            vae_decoder_path=new_model_save_dir / DIFFUSION_MODEL_VAE_DECODER_SUBFOLDER / vae_decoder_file_name,
-            text_encoder_path=new_model_save_dir / DIFFUSION_MODEL_TEXT_ENCODER_SUBFOLDER / text_encoder_file_name,
-            unet_path=new_model_save_dir / DIFFUSION_MODEL_UNET_SUBFOLDER / unet_file_name,
-            vae_encoder_path=new_model_save_dir / DIFFUSION_MODEL_VAE_ENCODER_SUBFOLDER / vae_encoder_file_name,
-            text_encoder_2_path=(
+        model_paths = {
+            "unet_path": new_model_save_dir / DIFFUSION_MODEL_UNET_SUBFOLDER / unet_file_name,
+            "vae_encoder_path": new_model_save_dir / DIFFUSION_MODEL_VAE_ENCODER_SUBFOLDER / vae_encoder_file_name,
+            "vae_decoder_path": new_model_save_dir / DIFFUSION_MODEL_VAE_DECODER_SUBFOLDER / vae_decoder_file_name,
+            "text_encoder_path": new_model_save_dir / DIFFUSION_MODEL_TEXT_ENCODER_SUBFOLDER / text_encoder_file_name,
+            "text_encoder_2_path": (
                 new_model_save_dir / DIFFUSION_MODEL_TEXT_ENCODER_2_SUBFOLDER / text_encoder_2_file_name
             ),
+        }
+
+        model_sessions = cls.load_model(
+            **model_paths,
             provider=provider,
             session_options=session_options,
             provider_options=provider_options,
@@ -362,29 +350,21 @@ def _from_pretrained(
         )
 
         return cls(
-            unet=unet,
             config=config,
-            vae_encoder=vae_encoder,
-            vae_decoder=vae_decoder,
-            text_encoder=text_encoder,
-            text_encoder_2=text_encoder_2,
-            scheduler=sub_models.get("scheduler"),
-            tokenizer=sub_models.get("tokenizer", None),
-            tokenizer_2=sub_models.get("tokenizer_2", None),
-            feature_extractor=sub_models.get("feature_extractor", None),
+            **model_sessions,
+            **sub_models,
            use_io_binding=use_io_binding,
             model_save_dir=model_save_dir,
         )
 
     @classmethod
-    def _from_transformers(
+    def _export(
         cls,
         model_id: str,
-        config: Optional[str] = None,
-        use_auth_token: Optional[Union[bool, str]] = None,
+        config: Optional[Dict[str, Any]] = None,
         token: Optional[Union[bool, str]] = None,
-        revision: str = "main",
-        force_download: bool = True,
+        revision: Optional[str] = None,
+        force_download: bool = False,
         cache_dir: str = HUGGINGFACE_HUB_CACHE,
         subfolder: str = "",
         local_files_only: bool = False,
@@ -395,15 +375,6 @@ def _export(
         use_io_binding: Optional[bool] = None,
         task: Optional[str] = None,
     ) -> "ORTPipeline":
-        if use_auth_token is not None:
-            warnings.warn(
-                "The `use_auth_token` argument is deprecated and will be removed soon. Please use the `token` argument instead.",
-                FutureWarning,
-            )
-            if token is not None:
-                raise ValueError("You cannot use both `use_auth_token` and `token` arguments at the same time.")
-            token = use_auth_token
-
         if task is None:
             task = cls._auto_model_to_task(cls.auto_model_class)
@@ -866,7 +837,7 @@ class ORTLatentConsistencyModelImg2ImgPipeline(ORTPipeline, LatentConsistencyMod
 ]
 
 
-def _get_pipeline_class(pipeline_class_name: str, throw_error_if_not_exist: bool = True):
+def _get_ort_class(pipeline_class_name: str, throw_error_if_not_exist: bool = True):
     for ort_pipeline_class in SUPPORTED_ORT_PIPELINES:
         if (
             ort_pipeline_class.__name__ == pipeline_class_name
@@ -879,6 +850,7 @@ def _get_ort_class(pipeline_class_name: str, throw_error_if_not_exist: bool
 
 class ORTDiffusionPipeline(ConfigMixin):
+    auto_model_class = DiffusionPipeline
     config_name = "model_index.json"
 
     @classmethod
@@ -898,7 +870,7 @@ def from_pretrained(cls, pretrained_model_or_path, **kwargs) -> ORTPipeline:
         config = config[0] if isinstance(config, tuple) else config
         class_name = config["_class_name"]
 
-        ort_pipeline_class = _get_pipeline_class(class_name)
+        ort_pipeline_class = _get_ort_class(class_name)
 
         return ort_pipeline_class.from_pretrained(pretrained_model_or_path, **kwargs)
@@ -933,7 +905,7 @@ def from_pretrained(cls, pretrained_model_or_path, **kwargs) -> ORTPipeline:
 ]
 
 
-def _get_task_class(mapping, pipeline_class_name):
+def _get_task_ort_class(mapping, pipeline_class_name):
     def _get_model_name(pipeline_class_name):
         for ort_pipelines_mapping in SUPPORTED_ORT_PIPELINES_MAPPINGS:
             for model_name, ort_pipeline_class in ort_pipelines_mapping.items():
@@ -954,6 +926,7 @@ def _get_model_name(pipeline_class_name):
 
 class ORTPipelineForTask(ConfigMixin):
+    auto_model_class = DiffusionPipeline
     config_name = "model_index.json"
 
     @classmethod
@@ -972,7 +945,7 @@ def from_pretrained(cls, pretrained_model_or_path, **kwargs) -> ORTPipeline:
         config = config[0] if isinstance(config, tuple) else config
         class_name = config["_class_name"]
 
-        ort_pipeline_class = _get_task_class(cls.ort_pipelines_mapping, class_name)
+        ort_pipeline_class = _get_task_ort_class(cls.ort_pipelines_mapping, class_name)
 
         return ort_pipeline_class.from_pretrained(pretrained_model_or_path, **kwargs)
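
A note on the core refactor above: `load_model` now returns a dict keyed by component name instead of a fixed tuple, and every component in `__init__` is optional with a matching keyword name. That is what lets `_from_pretrained` forward both the ONNX sessions and the loaded sub-models with a plain double-splat. Below is a minimal, self-contained sketch of that contract; the class and component names are illustrative, not the actual optimum API:

```python
from typing import Any, Dict, Optional


class TinyPipeline:
    # Every component is optional, and parameter names match the dict keys
    # below, mirroring the new ORTPipeline.__init__ signature.
    def __init__(
        self,
        config: Dict[str, Any],
        unet: Optional[object] = None,
        vae_decoder: Optional[object] = None,
        scheduler: Optional[object] = None,
    ):
        self.config = config
        self.unet = unet
        self.vae_decoder = vae_decoder
        self.scheduler = scheduler


def load_components(paths: Dict[str, Optional[str]]) -> Dict[str, Optional[str]]:
    # Mirrors the new load_model: missing components become None instead of raising.
    return {name: f"session({path})" if path else None for name, path in paths.items()}


sessions = load_components({"unet": "unet/model.onnx", "vae_decoder": None})
sub_models = {"scheduler": "scheduler-instance"}

# _from_pretrained can now call `cls(config=config, **model_sessions, **sub_models)`
# because keys and keyword names line up and everything defaults to None.
pipe = TinyPipeline(config={}, **sessions, **sub_models)
assert pipe.unet == "session(unet/model.onnx)" and pipe.vae_decoder is None
```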
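
For reference, a sketch of the on-disk layout `_save_pretrained` produces after this change, assuming the default `ONNX_WEIGHTS_NAME` (`model.onnx`) and a pipeline with all components present; `model_index.json` is written by the surrounding `save_pretrained` machinery, not by the hunk shown here, and per-component `config.json` files are now copied only when they exist:

```
save_directory/
├── model_index.json             # pipeline config (cls.config_name)
├── scheduler/                   # written by scheduler.save_pretrained
├── unet/model.onnx              # + config.json and external data, if present
├── vae_encoder/model.onnx
├── vae_decoder/model.onnx
├── text_encoder/model.onnx
└── text_encoder_2/model.onnx    # only for pipelines with a second text encoder
```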
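
Finally, a hedged usage sketch of the `ORTDiffusionPipeline` entry point that the renamed `_get_ort_class`/`_get_task_ort_class` helpers serve: `from_pretrained` reads `_class_name` from the repo's `model_index.json` and dispatches to the matching ORT pipeline class. The model id below is hypothetical and assumes a repository that already contains ONNX exports of each component:

```python
from optimum.onnxruntime import ORTDiffusionPipeline

# model_index.json's "_class_name" (e.g. "StableDiffusionPipeline") selects the
# concrete ORT subclass via _get_ort_class; kwargs are forwarded unchanged.
pipeline = ORTDiffusionPipeline.from_pretrained("sd-onnx/some-exported-model")
image = pipeline(prompt="a sailboat at dawn").images[0]
```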