diff --git a/optimum/onnx/utils.py b/optimum/onnx/utils.py
index b52c4f4cda..c014c1b342 100644
--- a/optimum/onnx/utils.py
+++ b/optimum/onnx/utils.py
@@ -71,6 +71,22 @@ def _get_external_data_paths(src_paths: List[Path], dst_paths: List[Path]) -> Tu
     return src_paths, dst_paths


+def _get_model_external_data_paths(model_path: Path) -> List[Path]:
+    """
+    Gets external data paths from the model.
+    """
+
+    onnx_model = onnx.load(str(model_path), load_external_data=False)
+    model_tensors = _get_initializer_tensors(onnx_model)
+    # filter out tensors that are not external data
+    model_tensors_ext = [
+        ExternalDataInfo(tensor).location
+        for tensor in model_tensors
+        if tensor.HasField("data_location") and tensor.data_location == onnx.TensorProto.EXTERNAL
+    ]
+    return [model_path.parent / tensor_name for tensor_name in model_tensors_ext]
+
+
 def check_model_uses_external_data(model: onnx.ModelProto) -> bool:
     """
     Checks if the model uses external data.
diff --git a/optimum/onnxruntime/modeling_diffusion.py b/optimum/onnxruntime/modeling_diffusion.py
index a0a33d3cb7..e91142dfa8 100644
--- a/optimum/onnxruntime/modeling_diffusion.py
+++ b/optimum/onnxruntime/modeling_diffusion.py
@@ -24,14 +24,15 @@
 import numpy as np
 import torch
-from diffusers import (
+from diffusers.configuration_utils import ConfigMixin, FrozenDict
+from diffusers.image_processor import VaeImageProcessor
+from diffusers.models.autoencoders.vae import DiagonalGaussianDistribution
+from diffusers.pipelines import (
     AutoPipelineForImage2Image,
     AutoPipelineForInpainting,
     AutoPipelineForText2Image,
-    ConfigMixin,
     LatentConsistencyModelImg2ImgPipeline,
     LatentConsistencyModelPipeline,
-    SchedulerMixin,
     StableDiffusionImg2ImgPipeline,
     StableDiffusionInpaintPipeline,
     StableDiffusionPipeline,
@@ -39,9 +40,7 @@
     StableDiffusionXLInpaintPipeline,
     StableDiffusionXLPipeline,
 )
-from diffusers.configuration_utils import FrozenDict
-from diffusers.image_processor import VaeImageProcessor
-from diffusers.models.autoencoders.vae import DiagonalGaussianDistribution
+from diffusers.schedulers import SchedulerMixin
 from diffusers.schedulers.scheduling_utils import SCHEDULER_CONFIG_NAME
 from diffusers.utils import CONFIG_NAME, is_invisible_watermark_available
 from huggingface_hub import snapshot_download
@@ -54,7 +53,7 @@
 import onnxruntime as ort

 from ..exporters.onnx import main_export
-from ..onnx.utils import _get_external_data_paths
+from ..onnx.utils import _get_model_external_data_paths
 from ..utils import (
     DIFFUSION_MODEL_TEXT_ENCODER_2_SUBFOLDER,
     DIFFUSION_MODEL_TEXT_ENCODER_SUBFOLDER,
@@ -90,15 +89,15 @@ class ORTPipeline(ORTModel):
     def __init__(
         self,
         config: Dict[str, Any],
-        tokenizer: CLIPTokenizer,
         scheduler: SchedulerMixin,
-        unet_session: ort.InferenceSession,
-        feature_extractor: Optional[CLIPFeatureExtractor] = None,
-        vae_encoder_session: Optional[ort.InferenceSession] = None,
-        vae_decoder_session: Optional[ort.InferenceSession] = None,
-        text_encoder_session: Optional[ort.InferenceSession] = None,
-        text_encoder_2_session: Optional[ort.InferenceSession] = None,
+        unet: ort.InferenceSession,
+        vae_encoder: Optional[ort.InferenceSession] = None,
+        vae_decoder: Optional[ort.InferenceSession] = None,
+        text_encoder: Optional[ort.InferenceSession] = None,
+        text_encoder_2: Optional[ort.InferenceSession] = None,
+        tokenizer: Optional[CLIPTokenizer] = None,
         tokenizer_2: Optional[CLIPTokenizer] = None,
+        feature_extractor: Optional[CLIPFeatureExtractor] = None,
         use_io_binding: Optional[bool] = None,
         model_save_dir: Optional[Union[str, Path, TemporaryDirectory]] = None,
         **kwargs,
@@ -114,13 +113,13 @@ def __init__(
                 for the text encoder.
             scheduler (`Union[DDIMScheduler, PNDMScheduler, LMSDiscreteScheduler]`):
                 A scheduler to be used in combination with the U-NET component to denoise the encoded image latents.
-            unet_session (`ort.InferenceSession`):
+            unet (`ort.InferenceSession`):
                 The ONNX Runtime inference session associated to the U-NET.
             feature_extractor (`Optional[CLIPFeatureExtractor]`, defaults to `None`):
                 A model extracting features from generated images to be used as inputs for the `safety_checker`
-            vae_encoder_session (`Optional[ort.InferenceSession]`, defaults to `None`):
+            vae_encoder (`Optional[ort.InferenceSession]`, defaults to `None`):
                 The ONNX Runtime inference session associated to the VAE encoder.
-            text_encoder_session (`Optional[ort.InferenceSession]`, defaults to `None`):
+            text_encoder (`Optional[ort.InferenceSession]`, defaults to `None`):
                 The ONNX Runtime inference session associated to the text encoder.
             tokenizer_2 (`Optional[CLIPTokenizer]`, defaults to `None`):
                 Tokenizer of class
@@ -133,36 +132,11 @@ def __init__(
                 The directory under which the model exported to ONNX was saved.
         """

-        # Text encoder
-        if text_encoder_session is not None:
-            self.text_encoder_model_path = Path(text_encoder_session._model_path)
-            self.text_encoder = ORTModelTextEncoder(text_encoder_session, self)
-        else:
-            self.text_encoder_model_path = None
-            self.text_encoder = None
-
-        # U-Net
-        self.unet = ORTModelUnet(unet_session, self)
-        self.unet_model_path = Path(unet_session._model_path)
-
-        # Text encoder 2
-        if text_encoder_2_session is not None:
-            self.text_encoder_2_model_path = Path(text_encoder_2_session._model_path)
-            self.text_encoder_2 = ORTModelTextEncoder(text_encoder_2_session, self)
-        else:
-            self.text_encoder_2_model_path = None
-            self.text_encoder_2 = None
-
-        # VAE
-        self.vae_decoder = ORTModelVaeDecoder(vae_decoder_session, self)
-        self.vae_decoder_model_path = Path(vae_decoder_session._model_path)
-
-        if vae_encoder_session is not None:
-            self.vae_encoder = ORTModelVaeEncoder(vae_encoder_session, self)
-            self.vae_encoder_model_path = Path(vae_encoder_session._model_path)
-        else:
-            self.vae_encoder = None
-            self.vae_encoder_model_path = None
+        self.unet = ORTModelUnet(unet, self)
+        self.vae_encoder = ORTModelVaeEncoder(vae_encoder, self) if vae_encoder is not None else None
+        self.vae_decoder = ORTModelVaeDecoder(vae_decoder, self) if vae_decoder is not None else None
+        self.text_encoder = ORTModelTextEncoder(text_encoder, self) if text_encoder is not None else None
+        self.text_encoder_2 = ORTModelTextEncoder(text_encoder_2, self) if text_encoder_2 is not None else None

         # We create VAE encoder & decoder and wrap them in one object to
         # be used by the pipeline mixins with minimal code changes (simulating the diffusers API)
@@ -185,19 +159,19 @@ def __init__(
         )

         sub_models = {
-            DIFFUSION_MODEL_UNET_SUBFOLDER: self.unet,
-            DIFFUSION_MODEL_VAE_DECODER_SUBFOLDER: self.vae_decoder,
-            DIFFUSION_MODEL_VAE_ENCODER_SUBFOLDER: self.vae_encoder,
-            DIFFUSION_MODEL_TEXT_ENCODER_SUBFOLDER: self.text_encoder,
-            DIFFUSION_MODEL_TEXT_ENCODER_2_SUBFOLDER: self.text_encoder_2,
+            self.unet: DIFFUSION_MODEL_UNET_SUBFOLDER,
+            self.vae_decoder: DIFFUSION_MODEL_VAE_DECODER_SUBFOLDER,
+            self.vae_encoder: DIFFUSION_MODEL_VAE_ENCODER_SUBFOLDER,
+            self.text_encoder: DIFFUSION_MODEL_TEXT_ENCODER_SUBFOLDER,
+            self.text_encoder_2: DIFFUSION_MODEL_TEXT_ENCODER_2_SUBFOLDER,
         }

         # Modify config to keep the resulting model compatible with diffusers pipelines
-        for name in sub_models.keys():
-            config[name] = ("diffusers", "OnnxRuntimeModel") if sub_models[name] is not None else (None, None)
+        for model, model_name in sub_models.items():
+            config[model_name] = ("optimum", model.__class__.__name__) if model is not None else (None, None)

         self._internal_dict = FrozenDict(config)
-        self.shared_attributes_init(model=unet_session, use_io_binding=use_io_binding, model_save_dir=model_save_dir)
+        self.shared_attributes_init(model=unet, use_io_binding=use_io_binding, model_save_dir=model_save_dir)

     @staticmethod
     def load_model(
@@ -253,41 +227,37 @@ def load_model(
     def _save_pretrained(self, save_directory: Union[str, Path]):
         save_directory = Path(save_directory)

-        src_to_dst_path = {
-            self.vae_decoder_model_path: save_directory / DIFFUSION_MODEL_VAE_DECODER_SUBFOLDER / ONNX_WEIGHTS_NAME,
-            self.text_encoder_model_path: save_directory / DIFFUSION_MODEL_TEXT_ENCODER_SUBFOLDER / ONNX_WEIGHTS_NAME,
-            self.unet_model_path: save_directory / DIFFUSION_MODEL_UNET_SUBFOLDER / ONNX_WEIGHTS_NAME,
-        }
         sub_models_to_save = {
-            self.vae_encoder_model_path: DIFFUSION_MODEL_VAE_ENCODER_SUBFOLDER,
-            self.text_encoder_2_model_path: DIFFUSION_MODEL_TEXT_ENCODER_2_SUBFOLDER,
+            self.unet: DIFFUSION_MODEL_UNET_SUBFOLDER,
+            self.vae_decoder: DIFFUSION_MODEL_VAE_DECODER_SUBFOLDER,
+            self.vae_encoder: DIFFUSION_MODEL_VAE_ENCODER_SUBFOLDER,
+            self.text_encoder: DIFFUSION_MODEL_TEXT_ENCODER_SUBFOLDER,
+            self.text_encoder_2: DIFFUSION_MODEL_TEXT_ENCODER_2_SUBFOLDER,
         }
-        for path, subfolder in sub_models_to_save.items():
-            if path is not None:
-                src_to_dst_path[path] = save_directory / subfolder / ONNX_WEIGHTS_NAME
-
-        # TODO: Modify _get_external_data_paths to give dictionnary
-        src_paths = list(src_to_dst_path.keys())
-        dst_paths = list(src_to_dst_path.values())
-        # Add external data paths in case of large models
-        src_paths, dst_paths = _get_external_data_paths(src_paths, dst_paths)
-
-        for src_path, dst_path in zip(src_paths, dst_paths):
-            dst_path.parent.mkdir(parents=True, exist_ok=True)
-            shutil.copyfile(src_path, dst_path)
-            config_path = src_path.parent / self.sub_component_config_name
-            if config_path.is_file():
-                shutil.copyfile(config_path, dst_path.parent / self.sub_component_config_name)
+
+        for model, model_folder in sub_models_to_save.items():
+            if model is not None:
+                model_path = Path(model.session._model_path)
+                model_save_path = save_directory / model_folder / ONNX_WEIGHTS_NAME
+                model_save_path.parent.mkdir(parents=True, exist_ok=True)
+                # copy onnx model
+                shutil.copyfile(model_path, model_save_path)
+                # copy external data
+                external_data_paths = _get_model_external_data_paths(model_path)
+                for external_data_path in external_data_paths:
+                    shutil.copyfile(external_data_path, model_save_path.parent / external_data_path.name)
+                # copy config
+                shutil.copyfile(model_path.parent / CONFIG_NAME, model_save_path.parent / CONFIG_NAME)

         self.scheduler.save_pretrained(save_directory / "scheduler")

-        if self.feature_extractor is not None:
-            self.feature_extractor.save_pretrained(save_directory / "feature_extractor")
         if self.tokenizer is not None:
             self.tokenizer.save_pretrained(save_directory / "tokenizer")
         if self.tokenizer_2 is not None:
             self.tokenizer_2.save_pretrained(save_directory / "tokenizer_2")
+        if self.feature_extractor is not None:
+            self.feature_extractor.save_pretrained(save_directory / "feature_extractor")

     @classmethod
     def _from_pretrained(
@@ -389,16 +359,16 @@ def _from_pretrained(
         )

         return cls(
-            vae_decoder_session=vae_decoder,
-            text_encoder_session=text_encoder,
-            unet_session=unet,
+            unet=unet,
             config=config,
-            tokenizer=sub_models.get("tokenizer", None),
+            vae_encoder=vae_encoder,
+            vae_decoder=vae_decoder,
+            text_encoder=text_encoder,
+            text_encoder_2=text_encoder_2,
             scheduler=sub_models.get("scheduler"),
-            feature_extractor=sub_models.get("feature_extractor", None),
+            tokenizer=sub_models.get("tokenizer", None),
             tokenizer_2=sub_models.get("tokenizer_2", None),
-            vae_encoder_session=vae_encoder,
-            text_encoder_2_session=text_encoder_2,
+            feature_extractor=sub_models.get("feature_extractor", None),
             use_io_binding=use_io_binding,
             model_save_dir=model_save_dir,
         )
@@ -482,13 +452,20 @@ def to(self, device: Union[torch.device, str, int]):
         if device.type == "cuda" and self.providers[0] == "TensorrtExecutionProvider":
             return self

-        self.vae_decoder.session.set_providers([provider], provider_options=[provider_options])
-        self.text_encoder.session.set_providers([provider], provider_options=[provider_options])
         self.unet.session.set_providers([provider], provider_options=[provider_options])

         if self.vae_encoder is not None:
             self.vae_encoder.session.set_providers([provider], provider_options=[provider_options])

+        if self.vae_decoder is not None:
+            self.vae_decoder.session.set_providers([provider], provider_options=[provider_options])
+
+        if self.text_encoder is not None:
+            self.text_encoder.session.set_providers([provider], provider_options=[provider_options])
+
+        if self.text_encoder_2 is not None:
+            self.text_encoder_2.session.set_providers([provider], provider_options=[provider_options])
+
         self.providers = self.vae_decoder.session.get_providers()
         self._device = device
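
Below is a minimal usage sketch (not part of the patch) of how the refactored pipeline API might be exercised end to end through the public `ORTStableDiffusionPipeline` entry point: exporting, then round-tripping with `save_pretrained`, which now copies each sub-model's ONNX file, its external data (resolved via `_get_model_external_data_paths`), and its config. The model id and output directory are illustrative assumptions.

    # Usage sketch, assuming the public optimum.onnxruntime API; checkpoint id and paths are illustrative.
    from optimum.onnxruntime import ORTStableDiffusionPipeline

    # Export a diffusers checkpoint to ONNX Runtime sessions on the fly.
    pipeline = ORTStableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5", export=True)

    # save_pretrained walks the sub-models (unet, vae_encoder, vae_decoder, text_encoder, ...),
    # copying each .onnx file, any external data files it references, and its config.
    pipeline.save_pretrained("./sd_onnx")

    # Reloading goes through _from_pretrained, which forwards the sessions with the
    # renamed keyword arguments (unet=..., vae_decoder=..., text_encoder=..., etc.).
    pipeline = ORTStableDiffusionPipeline.from_pretrained("./sd_onnx")
    image = pipeline("an astronaut riding a horse").images[0]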