
Commit: cleanup
IlyasMoutawwakil committed Sep 18, 2024
1 parent 4933c7c commit 7d50df3
Showing 1 changed file with 68 additions and 95 deletions.
163 changes: 68 additions & 95 deletions optimum/onnxruntime/modeling_diffusion.py
@@ -16,7 +16,6 @@
 import logging
 import os
 import shutil
-import warnings
 from collections import OrderedDict
 from pathlib import Path
 from tempfile import TemporaryDirectory
@@ -31,6 +30,7 @@
     AutoPipelineForImage2Image,
     AutoPipelineForInpainting,
     AutoPipelineForText2Image,
+    DiffusionPipeline,
     LatentConsistencyModelImg2ImgPipeline,
     LatentConsistencyModelPipeline,
     StableDiffusionImg2ImgPipeline,
@@ -42,7 +42,7 @@
 )
 from diffusers.schedulers import SchedulerMixin
 from diffusers.schedulers.scheduling_utils import SCHEDULER_CONFIG_NAME
-from diffusers.utils import CONFIG_NAME, is_invisible_watermark_available
+from diffusers.utils import is_invisible_watermark_available
 from huggingface_hub import snapshot_download
 from huggingface_hub.constants import HUGGINGFACE_HUB_CACHE
 from huggingface_hub.utils import validate_hf_hub_args
@@ -79,7 +79,7 @@
 logger = logging.getLogger(__name__)
 
 
-class ORTPipeline(ORTModel):
+class ORTPipeline(ORTModel, ConfigMixin):
     auto_model_class = None
     model_type = "onnx_pipeline"
 
@@ -89,12 +89,12 @@ class ORTPipeline(ORTModel):
     def __init__(
         self,
         config: Dict[str, Any],
-        scheduler: SchedulerMixin,
-        unet: ort.InferenceSession,
+        unet: Optional[ort.InferenceSession] = None,
         vae_encoder: Optional[ort.InferenceSession] = None,
         vae_decoder: Optional[ort.InferenceSession] = None,
         text_encoder: Optional[ort.InferenceSession] = None,
         text_encoder_2: Optional[ort.InferenceSession] = None,
+        scheduler: Optional[SchedulerMixin] = None,
         tokenizer: Optional[CLIPTokenizer] = None,
         tokenizer_2: Optional[CLIPTokenizer] = None,
         feature_extractor: Optional[CLIPFeatureExtractor] = None,
@@ -151,52 +151,51 @@ def __init__(
         if hasattr(self.vae.config, "block_out_channels"):
             self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
         else:
-            self.vae_scale_factor = 8
+            self.vae_scale_factor = 8  # for old configs without block_out_channels
 
         self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
         self.mask_processor = VaeImageProcessor(
             vae_scale_factor=self.vae_scale_factor, do_normalize=False, do_binarize=True, do_convert_grayscale=True
         )
 
-        sub_models = {
+        # Modify config to keep the resulting model compatible with diffusers pipelines
+        models_to_subfolder = {
             self.unet: DIFFUSION_MODEL_UNET_SUBFOLDER,
             self.vae_decoder: DIFFUSION_MODEL_VAE_DECODER_SUBFOLDER,
             self.vae_encoder: DIFFUSION_MODEL_VAE_ENCODER_SUBFOLDER,
             self.text_encoder: DIFFUSION_MODEL_TEXT_ENCODER_SUBFOLDER,
             self.text_encoder_2: DIFFUSION_MODEL_TEXT_ENCODER_2_SUBFOLDER,
         }
-
-        # Modify config to keep the resulting model compatible with diffusers pipelines
-        for model, model_name in sub_models.items():
-            config[model_name] = ("optimum", model.__class__.__name__) if model is not None else (None, None)
+        for model, model_subfolder in models_to_subfolder.items():
+            config[model_subfolder] = ("optimum", model.__class__.__name__) if model is not None else (None, None)
 
         self._internal_dict = FrozenDict(config)
         self.shared_attributes_init(model=unet, use_io_binding=use_io_binding, model_save_dir=model_save_dir)
 
     @staticmethod
     def load_model(
-        vae_decoder_path: Union[str, Path],
-        text_encoder_path: Union[str, Path],
         unet_path: Union[str, Path],
         vae_encoder_path: Optional[Union[str, Path]] = None,
+        vae_decoder_path: Optional[Union[str, Path]] = None,
+        text_encoder_path: Optional[Union[str, Path]] = None,
         text_encoder_2_path: Optional[Union[str, Path]] = None,
         provider: str = "CPUExecutionProvider",
        session_options: Optional[ort.SessionOptions] = None,
         provider_options: Optional[Dict] = None,
     ):
         """
-        Creates three inference sessions for respectively the VAE decoder, the text encoder and the U-NET models.
+        Creates inference sessions for the components of a Diffusion Pipeline (U-NET, VAE, Text Encoders).
         The default provider is `CPUExecutionProvider` to match the default behaviour in PyTorch/TensorFlow/JAX.
         Args:
-            vae_decoder_path (`Union[str, Path]`):
-                The path to the VAE decoder ONNX model.
-            text_encoder_path (`Union[str, Path]`):
-                The path to the text encoder ONNX model.
             unet_path (`Union[str, Path]`):
                 The path to the U-NET ONNX model.
             vae_encoder_path (`Union[str, Path]`, defaults to `None`):
                 The path to the VAE encoder ONNX model.
+            vae_decoder_path (`Union[str, Path]`, defaults to `None`):
+                The path to the VAE decoder ONNX model.
+            text_encoder_path (`Union[str, Path]`, defaults to `None`):
+                The path to the text encoder ONNX model.
             text_encoder_2_path (`Union[str, Path]`, defaults to `None`):
                 The path to the second text encoder ONNX model.
             provider (`str`, defaults to `"CPUExecutionProvider"`):
@@ -208,50 +207,48 @@ def load_model(
                 Provider option dictionary corresponding to the provider used. See available options
                 for each provider: https://onnxruntime.ai/docs/api/c/group___global.html . Defaults to `None`.
         """
-        vae_decoder = ORTModel.load_model(vae_decoder_path, provider, session_options, provider_options)
-        unet = ORTModel.load_model(unet_path, provider, session_options, provider_options)
-
-        sessions = {
+        paths = {
+            "unet": unet_path,
             "vae_encoder": vae_encoder_path,
+            "vae_decoder": vae_decoder_path,
             "text_encoder": text_encoder_path,
             "text_encoder_2": text_encoder_2_path,
         }
 
-        for key, value in sessions.items():
-            if value is not None and value.is_file():
-                sessions[key] = ORTModel.load_model(value, provider, session_options, provider_options)
+        sessions = {}
+        for model_name, model_path in paths.items():
+            if model_path is not None and model_path.is_file():
+                sessions[model_name] = ORTModel.load_model(model_path, provider, session_options, provider_options)
             else:
-                sessions[key] = None
+                sessions[model_name] = None
 
-        return vae_decoder, sessions["text_encoder"], unet, sessions["vae_encoder"], sessions["text_encoder_2"]
+        return sessions
 
     def _save_pretrained(self, save_directory: Union[str, Path]):
         save_directory = Path(save_directory)
 
-        sub_models_to_save = {
-            self.unet: DIFFUSION_MODEL_UNET_SUBFOLDER,
-            self.vae_decoder: DIFFUSION_MODEL_VAE_DECODER_SUBFOLDER,
-            self.vae_encoder: DIFFUSION_MODEL_VAE_ENCODER_SUBFOLDER,
-            self.text_encoder: DIFFUSION_MODEL_TEXT_ENCODER_SUBFOLDER,
-            self.text_encoder_2: DIFFUSION_MODEL_TEXT_ENCODER_2_SUBFOLDER,
+        models_to_save_paths = {
+            self.unet: save_directory / DIFFUSION_MODEL_UNET_SUBFOLDER / ONNX_WEIGHTS_NAME,
+            self.vae_decoder: save_directory / DIFFUSION_MODEL_VAE_DECODER_SUBFOLDER / ONNX_WEIGHTS_NAME,
+            self.vae_encoder: save_directory / DIFFUSION_MODEL_VAE_ENCODER_SUBFOLDER / ONNX_WEIGHTS_NAME,
+            self.text_encoder: save_directory / DIFFUSION_MODEL_TEXT_ENCODER_SUBFOLDER / ONNX_WEIGHTS_NAME,
+            self.text_encoder_2: save_directory / DIFFUSION_MODEL_TEXT_ENCODER_2_SUBFOLDER / ONNX_WEIGHTS_NAME,
         }
 
-        for model, model_folder in sub_models_to_save.items():
+        for model, model_save_path in models_to_save_paths.items():
             if model is not None:
                 model_path = Path(model.session._model_path)
-                model_save_path = save_directory / model_folder / ONNX_WEIGHTS_NAME
                 model_save_path.parent.mkdir(parents=True, exist_ok=True)
                 # copy onnx model
                 shutil.copyfile(model_path, model_save_path)
-                # copy external data
+                # copy external onnx data
                 external_data_paths = _get_model_external_data_paths(model_path)
                 for external_data_path in external_data_paths:
                     shutil.copyfile(external_data_path, model_save_path.parent / external_data_path.name)
-                # copy config
-                shutil.copyfile(
-                    model_path.parent / self.sub_component_config_name,
-                    model_save_path.parent / self.sub_component_config_name,
-                )
+                # copy model config
+                config_path = model_path.parent / self.sub_component_config_name
+                if config_path.is_file():
+                    config_save_path = model_save_path.parent / self.sub_component_config_name
+                    shutil.copyfile(config_path, config_save_path)
 
         self.scheduler.save_pretrained(save_directory / "scheduler")
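With this change, `load_model` returns a dict of named `InferenceSession` objects (missing components map to `None`) instead of a fixed five-element tuple, so callers can unpack only the components a given pipeline actually ships. A minimal sketch of the new call pattern, assuming a local ONNX export laid out in the usual diffusers subfolders; the `./sd_onnx` path and file layout are illustrative, not part of this diff:

```python
from pathlib import Path

from optimum.onnxruntime.modeling_diffusion import ORTPipeline

model_dir = Path("./sd_onnx")  # hypothetical local export

# Only components whose files exist on disk get an InferenceSession;
# absent ones (e.g. text_encoder_2 for non-SDXL models) come back as None.
sessions = ORTPipeline.load_model(
    unet_path=model_dir / "unet" / "model.onnx",
    vae_encoder_path=model_dir / "vae_encoder" / "model.onnx",
    vae_decoder_path=model_dir / "vae_decoder" / "model.onnx",
    text_encoder_path=model_dir / "text_encoder" / "model.onnx",
)

print({name: type(session).__name__ for name, session in sessions.items()})
```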

Expand All @@ -267,35 +264,22 @@ def _from_pretrained(
cls,
model_id: Union[str, Path],
config: Dict[str, Any],
use_auth_token: Optional[Union[bool, str]] = None,
token: Optional[Union[bool, str]] = None,
local_files_only: bool = False,
revision: Optional[str] = None,
cache_dir: str = HUGGINGFACE_HUB_CACHE,
vae_decoder_file_name: str = ONNX_WEIGHTS_NAME,
text_encoder_file_name: str = ONNX_WEIGHTS_NAME,
token: Optional[Union[bool, str]] = None,
unet_file_name: str = ONNX_WEIGHTS_NAME,
vae_encoder_file_name: str = ONNX_WEIGHTS_NAME,
vae_decoder_file_name: str = ONNX_WEIGHTS_NAME,
text_encoder_file_name: str = ONNX_WEIGHTS_NAME,
text_encoder_2_file_name: str = ONNX_WEIGHTS_NAME,
local_files_only: bool = False,
provider: str = "CPUExecutionProvider",
session_options: Optional[ort.SessionOptions] = None,
provider_options: Optional[Dict[str, Any]] = None,
use_io_binding: Optional[bool] = None,
model_save_dir: Optional[Union[str, Path, TemporaryDirectory]] = None,
**kwargs,
):
if use_auth_token is not None:
warnings.warn(
"The `use_auth_token` argument is deprecated and will be removed soon. Please use the `token` argument instead.",
FutureWarning,
)
if token is not None:
raise ValueError("You cannot use both `use_auth_token` and `token` arguments at the same time.")
token = use_auth_token

if provider == "TensorrtExecutionProvider":
raise ValueError("The provider `'TensorrtExecutionProvider'` is not supported")

model_id = str(model_id)
patterns = set(config.keys())
sub_models_to_load = patterns.intersection({"feature_extractor", "tokenizer", "tokenizer_2", "scheduler"})
@@ -305,13 +289,13 @@ def _from_pretrained(
         allow_patterns = {os.path.join(k, "*") for k in patterns if not k.startswith("_")}
         allow_patterns.update(
             {
-                vae_decoder_file_name,
-                text_encoder_file_name,
                 unet_file_name,
                 vae_encoder_file_name,
+                vae_decoder_file_name,
+                text_encoder_file_name,
                 text_encoder_2_file_name,
                 cls.sub_component_config_name,
                 SCHEDULER_CONFIG_NAME,
-                CONFIG_NAME,
+                cls.config_name,
             }
         )
@@ -340,14 +324,18 @@
             else:
                 sub_models[name] = load_method(new_model_save_dir)
 
-        vae_decoder, text_encoder, unet, vae_encoder, text_encoder_2 = cls.load_model(
-            vae_decoder_path=new_model_save_dir / DIFFUSION_MODEL_VAE_DECODER_SUBFOLDER / vae_decoder_file_name,
-            text_encoder_path=new_model_save_dir / DIFFUSION_MODEL_TEXT_ENCODER_SUBFOLDER / text_encoder_file_name,
-            unet_path=new_model_save_dir / DIFFUSION_MODEL_UNET_SUBFOLDER / unet_file_name,
-            vae_encoder_path=new_model_save_dir / DIFFUSION_MODEL_VAE_ENCODER_SUBFOLDER / vae_encoder_file_name,
-            text_encoder_2_path=(
+        model_paths = {
+            "unet_path": new_model_save_dir / DIFFUSION_MODEL_UNET_SUBFOLDER / unet_file_name,
+            "vae_encoder_path": new_model_save_dir / DIFFUSION_MODEL_VAE_ENCODER_SUBFOLDER / vae_encoder_file_name,
+            "vae_decoder_path": new_model_save_dir / DIFFUSION_MODEL_VAE_DECODER_SUBFOLDER / vae_decoder_file_name,
+            "text_encoder_path": new_model_save_dir / DIFFUSION_MODEL_TEXT_ENCODER_SUBFOLDER / text_encoder_file_name,
+            "text_encoder_2_path": (
                 new_model_save_dir / DIFFUSION_MODEL_TEXT_ENCODER_2_SUBFOLDER / text_encoder_2_file_name
             ),
+        }
+
+        model_sessions = cls.load_model(
+            **model_paths,
             provider=provider,
             session_options=session_options,
             provider_options=provider_options,
@@ -362,29 +350,21 @@
         )
 
         return cls(
-            unet=unet,
             config=config,
-            vae_encoder=vae_encoder,
-            vae_decoder=vae_decoder,
-            text_encoder=text_encoder,
-            text_encoder_2=text_encoder_2,
-            scheduler=sub_models.get("scheduler"),
-            tokenizer=sub_models.get("tokenizer", None),
-            tokenizer_2=sub_models.get("tokenizer_2", None),
-            feature_extractor=sub_models.get("feature_extractor", None),
+            **model_sessions,
+            **sub_models,
             use_io_binding=use_io_binding,
             model_save_dir=model_save_dir,
         )
 
     @classmethod
-    def _from_transformers(
+    def _export(
         cls,
         model_id: str,
-        config: Optional[str] = None,
-        use_auth_token: Optional[Union[bool, str]] = None,
+        config: Optional[Dict[str, Any]] = None,
         token: Optional[Union[bool, str]] = None,
-        revision: str = "main",
-        force_download: bool = True,
+        revision: Optional[str] = None,
+        force_download: bool = False,
         cache_dir: str = HUGGINGFACE_HUB_CACHE,
         subfolder: str = "",
         local_files_only: bool = False,
@@ -395,15 +375,6 @@ def _from_transformers(
         use_io_binding: Optional[bool] = None,
         task: Optional[str] = None,
     ) -> "ORTPipeline":
-        if use_auth_token is not None:
-            warnings.warn(
-                "The `use_auth_token` argument is deprecated and will be removed soon. Please use the `token` argument instead.",
-                FutureWarning,
-            )
-            if token is not None:
-                raise ValueError("You cannot use both `use_auth_token` and `token` arguments at the same time.")
-            token = use_auth_token
-
         if task is None:
             task = cls._auto_model_to_task(cls.auto_model_class)
@@ -866,7 +837,7 @@ class ORTLatentConsistencyModelImg2ImgPipeline(ORTPipeline, LatentConsistencyMod
 ]
 
 
-def _get_pipeline_class(pipeline_class_name: str, throw_error_if_not_exist: bool = True):
+def _get_ort_class(pipeline_class_name: str, throw_error_if_not_exist: bool = True):
     for ort_pipeline_class in SUPPORTED_ORT_PIPELINES:
         if (
             ort_pipeline_class.__name__ == pipeline_class_name
@@ -879,6 +850,7 @@
 
 
 class ORTDiffusionPipeline(ConfigMixin):
+    auto_model_class = DiffusionPipeline
     config_name = "model_index.json"
 
     @classmethod
@classmethod
Expand All @@ -898,7 +870,7 @@ def from_pretrained(cls, pretrained_model_or_path, **kwargs) -> ORTPipeline:
config = config[0] if isinstance(config, tuple) else config
class_name = config["_class_name"]

ort_pipeline_class = _get_pipeline_class(class_name)
ort_pipeline_class = _get_ort_class(class_name)

return ort_pipeline_class.from_pretrained(pretrained_model_or_path, **kwargs)
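`ORTDiffusionPipeline.from_pretrained` mirrors the auto-dispatch of diffusers' `DiffusionPipeline`: it loads the repo's `model_index.json`, reads `_class_name`, and resolves the matching ORT class via `_get_ort_class`. A minimal usage sketch, assuming the class is re-exported from `optimum.onnxruntime` and using a placeholder model id:

```python
from optimum.onnxruntime import ORTDiffusionPipeline

# "org/sd15-onnx" is a placeholder id for an already-exported ONNX checkpoint;
# the concrete class (e.g. ORTStableDiffusionPipeline) is picked from model_index.json.
pipeline = ORTDiffusionPipeline.from_pretrained("org/sd15-onnx")
image = pipeline(prompt="an astronaut riding a horse").images[0]
image.save("astronaut.png")
```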

@@ -933,7 +905,7 @@ def from_pretrained(cls, pretrained_model_or_path, **kwargs) -> ORTPipeline:
 ]
 
 
-def _get_task_class(mapping, pipeline_class_name):
+def _get_task_ort_class(mapping, pipeline_class_name):
     def _get_model_name(pipeline_class_name):
         for ort_pipelines_mapping in SUPPORTED_ORT_PIPELINES_MAPPINGS:
             for model_name, ort_pipeline_class in ort_pipelines_mapping.items():
@@ -954,6 +926,7 @@ def _get_model_name(pipeline_class_name):
 
 
 class ORTPipelineForTask(ConfigMixin):
+    auto_model_class = DiffusionPipeline
     config_name = "model_index.json"
 
     @classmethod
@@ -972,7 +945,7 @@ def from_pretrained(cls, pretrained_model_or_path, **kwargs) -> ORTPipeline:
         config = config[0] if isinstance(config, tuple) else config
         class_name = config["_class_name"]
 
-        ort_pipeline_class = _get_task_class(cls.ort_pipelines_mapping, class_name)
+        ort_pipeline_class = _get_task_ort_class(cls.ort_pipelines_mapping, class_name)
 
         return ort_pipeline_class.from_pretrained(pretrained_model_or_path, **kwargs)
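The task-level entry points resolve through `_get_task_ort_class` instead, mapping the checkpoint's model family to the pipeline class for the requested task, analogous to diffusers' `AutoPipelineForText2Image`. A sketch under the same assumptions (`ORTPipelineForText2Image` as a re-exported task subclass, placeholder model id):

```python
from optimum.onnxruntime import ORTPipelineForText2Image

pipeline = ORTPipelineForText2Image.from_pretrained("org/sd15-onnx")  # placeholder id
image = pipeline(prompt="a watercolor fox", num_inference_steps=20).images[0]
```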
