simplify model saving
IlyasMoutawwakil committed Sep 16, 2024
1 parent b70b641 commit 7f77b1c
Showing 2 changed files with 82 additions and 89 deletions.
16 changes: 16 additions & 0 deletions optimum/onnx/utils.py
@@ -71,6 +71,22 @@ def _get_external_data_paths(src_paths: List[Path], dst_paths: List[Path]) -> Tu
return src_paths, dst_paths


def _get_model_external_data_paths(model_path: Path) -> List[Path]:
"""
Gets external data paths from the model.
"""

onnx_model = onnx.load(str(model_path), load_external_data=False)
model_tensors = _get_initializer_tensors(onnx_model)
# filter out tensors that are not external data
model_tensors_ext = [
ExternalDataInfo(tensor).location
for tensor in model_tensors
if tensor.HasField("data_location") and tensor.data_location == onnx.TensorProto.EXTERNAL
]
return [model_path.parent / tensor_name for tensor_name in model_tensors_ext]
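
As a rough usage sketch (not part of the diff; the helper name copy_onnx_model_with_external_data and the destination directory are made up), the new helper pairs naturally with shutil to copy a model file together with every external data file it references:

import shutil
from pathlib import Path

def copy_onnx_model_with_external_data(model_path: Path, destination_dir: Path) -> None:
    destination_dir.mkdir(parents=True, exist_ok=True)
    # copy the ONNX graph itself
    shutil.copyfile(model_path, destination_dir / model_path.name)
    # copy every external data file the graph references (e.g. large weight shards)
    for data_path in _get_model_external_data_paths(model_path):
        shutil.copyfile(data_path, destination_dir / data_path.name)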


def check_model_uses_external_data(model: onnx.ModelProto) -> bool:
"""
Checks if the model uses external data.
155 changes: 66 additions & 89 deletions optimum/onnxruntime/modeling_diffusion.py
@@ -24,24 +24,23 @@

import numpy as np
import torch
from diffusers import (
from diffusers.configuration_utils import ConfigMixin, FrozenDict
from diffusers.image_processor import VaeImageProcessor
from diffusers.models.autoencoders.vae import DiagonalGaussianDistribution
from diffusers.pipelines import (
AutoPipelineForImage2Image,
AutoPipelineForInpainting,
AutoPipelineForText2Image,
ConfigMixin,
LatentConsistencyModelImg2ImgPipeline,
LatentConsistencyModelPipeline,
SchedulerMixin,
StableDiffusionImg2ImgPipeline,
StableDiffusionInpaintPipeline,
StableDiffusionPipeline,
StableDiffusionXLImg2ImgPipeline,
StableDiffusionXLInpaintPipeline,
StableDiffusionXLPipeline,
)
from diffusers.configuration_utils import FrozenDict
from diffusers.image_processor import VaeImageProcessor
from diffusers.models.autoencoders.vae import DiagonalGaussianDistribution
from diffusers.schedulers import SchedulerMixin
from diffusers.schedulers.scheduling_utils import SCHEDULER_CONFIG_NAME
from diffusers.utils import CONFIG_NAME, is_invisible_watermark_available
from huggingface_hub import snapshot_download
@@ -54,7 +53,7 @@
import onnxruntime as ort

from ..exporters.onnx import main_export
from ..onnx.utils import _get_external_data_paths
from ..onnx.utils import _get_model_external_data_paths
from ..utils import (
DIFFUSION_MODEL_TEXT_ENCODER_2_SUBFOLDER,
DIFFUSION_MODEL_TEXT_ENCODER_SUBFOLDER,
@@ -90,15 +89,15 @@ class ORTPipeline(ORTModel):
def __init__(
self,
config: Dict[str, Any],
tokenizer: CLIPTokenizer,
scheduler: SchedulerMixin,
unet_session: ort.InferenceSession,
feature_extractor: Optional[CLIPFeatureExtractor] = None,
vae_encoder_session: Optional[ort.InferenceSession] = None,
vae_decoder_session: Optional[ort.InferenceSession] = None,
text_encoder_session: Optional[ort.InferenceSession] = None,
text_encoder_2_session: Optional[ort.InferenceSession] = None,
unet: ort.InferenceSession,
vae_encoder: Optional[ort.InferenceSession] = None,
vae_decoder: Optional[ort.InferenceSession] = None,
text_encoder: Optional[ort.InferenceSession] = None,
text_encoder_2: Optional[ort.InferenceSession] = None,
tokenizer: Optional[CLIPTokenizer] = None,
tokenizer_2: Optional[CLIPTokenizer] = None,
feature_extractor: Optional[CLIPFeatureExtractor] = None,
use_io_binding: Optional[bool] = None,
model_save_dir: Optional[Union[str, Path, TemporaryDirectory]] = None,
**kwargs,
@@ -114,13 +113,13 @@ def __init__(
for the text encoder.
scheduler (`Union[DDIMScheduler, PNDMScheduler, LMSDiscreteScheduler]`):
A scheduler to be used in combination with the U-NET component to denoise the encoded image latents.
unet_session (`ort.InferenceSession`):
unet (`ort.InferenceSession`):
The ONNX Runtime inference session associated to the U-NET.
feature_extractor (`Optional[CLIPFeatureExtractor]`, defaults to `None`):
A model extracting features from generated images to be used as inputs for the `safety_checker`
vae_encoder_session (`Optional[ort.InferenceSession]`, defaults to `None`):
vae_encoder (`Optional[ort.InferenceSession]`, defaults to `None`):
The ONNX Runtime inference session associated to the VAE encoder.
text_encoder_session (`Optional[ort.InferenceSession]`, defaults to `None`):
text_encoder (`Optional[ort.InferenceSession]`, defaults to `None`):
The ONNX Runtime inference session associated to the text encoder.
tokenizer_2 (`Optional[CLIPTokenizer]`, defaults to `None`):
Tokenizer of class
@@ -133,36 +132,11 @@
The directory under which the model exported to ONNX was saved.
"""

# Text encoder
if text_encoder_session is not None:
self.text_encoder_model_path = Path(text_encoder_session._model_path)
self.text_encoder = ORTModelTextEncoder(text_encoder_session, self)
else:
self.text_encoder_model_path = None
self.text_encoder = None

# U-Net
self.unet = ORTModelUnet(unet_session, self)
self.unet_model_path = Path(unet_session._model_path)

# Text encoder 2
if text_encoder_2_session is not None:
self.text_encoder_2_model_path = Path(text_encoder_2_session._model_path)
self.text_encoder_2 = ORTModelTextEncoder(text_encoder_2_session, self)
else:
self.text_encoder_2_model_path = None
self.text_encoder_2 = None

# VAE
self.vae_decoder = ORTModelVaeDecoder(vae_decoder_session, self)
self.vae_decoder_model_path = Path(vae_decoder_session._model_path)

if vae_encoder_session is not None:
self.vae_encoder = ORTModelVaeEncoder(vae_encoder_session, self)
self.vae_encoder_model_path = Path(vae_encoder_session._model_path)
else:
self.vae_encoder = None
self.vae_encoder_model_path = None
self.unet = ORTModelUnet(unet, self)
self.vae_encoder = ORTModelVaeEncoder(vae_encoder, self) if vae_encoder is not None else None
self.vae_decoder = ORTModelVaeDecoder(vae_decoder, self) if vae_decoder is not None else None
self.text_encoder = ORTModelTextEncoder(text_encoder, self) if text_encoder is not None else None
self.text_encoder_2 = ORTModelTextEncoder(text_encoder_2, self) if text_encoder_2 is not None else None

# We create VAE encoder & decoder and wrap them in one object to
# be used by the pipeline mixins with minimal code changes (simulating the diffusers API)
@@ -185,19 +159,19 @@ def __init__(
)

sub_models = {
DIFFUSION_MODEL_UNET_SUBFOLDER: self.unet,
DIFFUSION_MODEL_VAE_DECODER_SUBFOLDER: self.vae_decoder,
DIFFUSION_MODEL_VAE_ENCODER_SUBFOLDER: self.vae_encoder,
DIFFUSION_MODEL_TEXT_ENCODER_SUBFOLDER: self.text_encoder,
DIFFUSION_MODEL_TEXT_ENCODER_2_SUBFOLDER: self.text_encoder_2,
self.unet: DIFFUSION_MODEL_UNET_SUBFOLDER,
self.vae_decoder: DIFFUSION_MODEL_VAE_DECODER_SUBFOLDER,
self.vae_encoder: DIFFUSION_MODEL_VAE_ENCODER_SUBFOLDER,
self.text_encoder: DIFFUSION_MODEL_TEXT_ENCODER_SUBFOLDER,
self.text_encoder_2: DIFFUSION_MODEL_TEXT_ENCODER_2_SUBFOLDER,
}

# Modify config to keep the resulting model compatible with diffusers pipelines
for name in sub_models.keys():
config[name] = ("diffusers", "OnnxRuntimeModel") if sub_models[name] is not None else (None, None)
for model, model_name in sub_models.items():
config[model_name] = ("optimum", model.__class__.__name__) if model is not None else (None, None)

self._internal_dict = FrozenDict(config)
self.shared_attributes_init(model=unet_session, use_io_binding=use_io_binding, model_save_dir=model_save_dir)
self.shared_attributes_init(model=unet, use_io_binding=use_io_binding, model_save_dir=model_save_dir)
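
For intuition, and assuming the DIFFUSION_MODEL_*_SUBFOLDER constants resolve to their usual folder names, the loop above leaves diffusers-style (library name, class name) entries like these in the pipeline config when run on a hypothetical pipeline without a second text encoder:

# illustrative resulting config entries, one per component
config = {
    "unet": ("optimum", "ORTModelUnet"),
    "vae_decoder": ("optimum", "ORTModelVaeDecoder"),
    "vae_encoder": ("optimum", "ORTModelVaeEncoder"),
    "text_encoder": ("optimum", "ORTModelTextEncoder"),
    "text_encoder_2": (None, None),  # component not loaded for this pipeline
}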

@staticmethod
def load_model(
@@ -253,41 +227,37 @@ def load_model(

def _save_pretrained(self, save_directory: Union[str, Path]):
save_directory = Path(save_directory)
src_to_dst_path = {
self.vae_decoder_model_path: save_directory / DIFFUSION_MODEL_VAE_DECODER_SUBFOLDER / ONNX_WEIGHTS_NAME,
self.text_encoder_model_path: save_directory / DIFFUSION_MODEL_TEXT_ENCODER_SUBFOLDER / ONNX_WEIGHTS_NAME,
self.unet_model_path: save_directory / DIFFUSION_MODEL_UNET_SUBFOLDER / ONNX_WEIGHTS_NAME,
}

sub_models_to_save = {
self.vae_encoder_model_path: DIFFUSION_MODEL_VAE_ENCODER_SUBFOLDER,
self.text_encoder_2_model_path: DIFFUSION_MODEL_TEXT_ENCODER_2_SUBFOLDER,
self.unet: DIFFUSION_MODEL_UNET_SUBFOLDER,
self.vae_decoder: DIFFUSION_MODEL_VAE_DECODER_SUBFOLDER,
self.vae_encoder: DIFFUSION_MODEL_VAE_ENCODER_SUBFOLDER,
self.text_encoder: DIFFUSION_MODEL_TEXT_ENCODER_SUBFOLDER,
self.text_encoder_2: DIFFUSION_MODEL_TEXT_ENCODER_2_SUBFOLDER,
}
for path, subfolder in sub_models_to_save.items():
if path is not None:
src_to_dst_path[path] = save_directory / subfolder / ONNX_WEIGHTS_NAME

# TODO: Modify _get_external_data_paths to give dictionary
src_paths = list(src_to_dst_path.keys())
dst_paths = list(src_to_dst_path.values())
# Add external data paths in case of large models
src_paths, dst_paths = _get_external_data_paths(src_paths, dst_paths)

for src_path, dst_path in zip(src_paths, dst_paths):
dst_path.parent.mkdir(parents=True, exist_ok=True)
shutil.copyfile(src_path, dst_path)
config_path = src_path.parent / self.sub_component_config_name
if config_path.is_file():
shutil.copyfile(config_path, dst_path.parent / self.sub_component_config_name)

for model, model_folder in sub_models_to_save.items():
if model is not None:
model_path = Path(model.session._model_path)
model_save_path = save_directory / model_folder / ONNX_WEIGHTS_NAME
model_save_path.parent.mkdir(parents=True, exist_ok=True)
# copy onnx model
shutil.copyfile(model_path, model_save_path)
# copy external data
external_data_paths = _get_model_external_data_paths(model_path)
for external_data_path in external_data_paths:
shutil.copyfile(external_data_path, model_save_path.parent / external_data_path.name)
# copy config
shutil.copyfile(model_path.parent / CONFIG_NAME, model_save_path.parent / CONFIG_NAME)

self.scheduler.save_pretrained(save_directory / "scheduler")

if self.feature_extractor is not None:
self.feature_extractor.save_pretrained(save_directory / "feature_extractor")
if self.tokenizer is not None:
self.tokenizer.save_pretrained(save_directory / "tokenizer")
if self.tokenizer_2 is not None:
self.tokenizer_2.save_pretrained(save_directory / "tokenizer_2")
if self.feature_extractor is not None:
self.feature_extractor.save_pretrained(save_directory / "feature_extractor")
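
As a usage sketch (the model id, output directory and export=True flag are assumptions about the surrounding optimum API, not shown in this diff), saving a pipeline with the new logic stays a single call:

from optimum.onnxruntime import ORTStableDiffusionPipeline

pipeline = ORTStableDiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5", export=True
)
# copies each sub-model's model.onnx, any external data files and its config
pipeline.save_pretrained("./stable_diffusion_onnx")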

@classmethod
def _from_pretrained(
@@ -389,16 +359,16 @@ def _from_pretrained(
)

return cls(
vae_decoder_session=vae_decoder,
text_encoder_session=text_encoder,
unet_session=unet,
unet=unet,
config=config,
tokenizer=sub_models.get("tokenizer", None),
vae_encoder=vae_encoder,
vae_decoder=vae_decoder,
text_encoder=text_encoder,
text_encoder_2=text_encoder_2,
scheduler=sub_models.get("scheduler"),
feature_extractor=sub_models.get("feature_extractor", None),
tokenizer=sub_models.get("tokenizer", None),
tokenizer_2=sub_models.get("tokenizer_2", None),
vae_encoder_session=vae_encoder,
text_encoder_2_session=text_encoder_2,
feature_extractor=sub_models.get("feature_extractor", None),
use_io_binding=use_io_binding,
model_save_dir=model_save_dir,
)
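
For reference, the on-disk layout that _from_pretrained consumes (and that _save_pretrained above produces) looks roughly like the following; the exact file set depends on which sub-models, tokenizers and feature extractor the pipeline actually has:

# stable_diffusion_onnx/
# ├── model_index.json           # diffusers-style pipeline config
# ├── scheduler/
# ├── tokenizer/                 # plus tokenizer_2/ and feature_extractor/ when present
# ├── text_encoder/model.onnx    # plus text_encoder_2/ for SDXL-style pipelines
# ├── unet/model.onnx            # large models also carry external data files here
# ├── vae_decoder/model.onnx
# └── vae_encoder/model.onnx     # each sub-model folder also keeps its own config file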
@@ -482,13 +452,20 @@ def to(self, device: Union[torch.device, str, int]):
if device.type == "cuda" and self.providers[0] == "TensorrtExecutionProvider":
return self

self.vae_decoder.session.set_providers([provider], provider_options=[provider_options])
self.text_encoder.session.set_providers([provider], provider_options=[provider_options])
self.unet.session.set_providers([provider], provider_options=[provider_options])

if self.vae_encoder is not None:
self.vae_encoder.session.set_providers([provider], provider_options=[provider_options])

if self.vae_decoder is not None:
self.vae_decoder.session.set_providers([provider], provider_options=[provider_options])

if self.text_encoder is not None:
self.text_encoder.session.set_providers([provider], provider_options=[provider_options])

if self.text_encoder_2 is not None:
self.text_encoder_2.session.set_providers([provider], provider_options=[provider_options])

self.providers = self.vae_decoder.session.get_providers()
self._device = device
