Skip to content

Commit

Permalink
created auto task mappings
Browse files Browse the repository at this point in the history
  • Loading branch information
IlyasMoutawwakil committed Jul 16, 2024
1 parent b865809 commit fcb1690
Showing 1 changed file with 33 additions and 10 deletions.
43 changes: 33 additions & 10 deletions optimum/onnxruntime/modeling_diffusion.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
import shutil
import warnings
from abc import abstractmethod
from collections import OrderedDict
from pathlib import Path
from tempfile import TemporaryDirectory
from typing import Any, Dict, Optional, Union
Expand All @@ -26,6 +27,7 @@
import torch
from diffusers import (
DDIMScheduler,
DiffusionPipeline,
LMSDiscreteScheduler,
PNDMScheduler,
StableDiffusionPipeline,
Expand Down Expand Up @@ -69,8 +71,8 @@
logger = logging.getLogger(__name__)


class ORTStableDiffusionPipelineBase(ORTModel):
auto_model_class = StableDiffusionPipeline
class ORTDiffusionPipeline(ORTModel):
auto_model_class = DiffusionPipeline
main_input_name = "input_ids"
base_model_prefix = "onnx_model"
config_name = "model_index.json"
Expand Down Expand Up @@ -350,9 +352,9 @@ def _from_pretrained(
text_encoder_path=new_model_save_dir / DIFFUSION_MODEL_TEXT_ENCODER_SUBFOLDER / text_encoder_file_name,
unet_path=new_model_save_dir / DIFFUSION_MODEL_UNET_SUBFOLDER / unet_file_name,
vae_encoder_path=new_model_save_dir / DIFFUSION_MODEL_VAE_ENCODER_SUBFOLDER / vae_encoder_file_name,
text_encoder_2_path=new_model_save_dir
/ DIFFUSION_MODEL_TEXT_ENCODER_2_SUBFOLDER
/ text_encoder_2_file_name,
text_encoder_2_path=(
new_model_save_dir / DIFFUSION_MODEL_TEXT_ENCODER_2_SUBFOLDER / text_encoder_2_file_name
),
provider=provider,
session_options=session_options,
provider_options=provider_options,
Expand Down Expand Up @@ -561,7 +563,7 @@ def forward(self, sample: np.ndarray):


@add_end_docstrings(ONNX_MODEL_END_DOCSTRING)
class ORTStableDiffusionPipeline(ORTStableDiffusionPipelineBase, StableDiffusionPipelineMixin):
class ORTStableDiffusionPipeline(ORTDiffusionPipeline, StableDiffusionPipelineMixin):
"""
ONNX Runtime-powered stable diffusion pipeline corresponding to [diffusers.StableDiffusionPipeline](https://huggingface.co/docs/diffusers/api/pipelines/stable_diffusion/text2img#diffusers.StableDiffusionPipeline).
"""
Expand All @@ -570,7 +572,7 @@ class ORTStableDiffusionPipeline(ORTStableDiffusionPipelineBase, StableDiffusion


@add_end_docstrings(ONNX_MODEL_END_DOCSTRING)
class ORTStableDiffusionImg2ImgPipeline(ORTStableDiffusionPipelineBase, StableDiffusionImg2ImgPipelineMixin):
class ORTStableDiffusionImg2ImgPipeline(ORTDiffusionPipeline, StableDiffusionImg2ImgPipelineMixin):
"""
ONNX Runtime-powered stable diffusion pipeline corresponding to [diffusers.StableDiffusionImg2ImgPipeline](https://huggingface.co/docs/diffusers/api/pipelines/stable_diffusion/img2img#diffusers.StableDiffusionImg2ImgPipeline).
"""
Expand All @@ -579,7 +581,7 @@ class ORTStableDiffusionImg2ImgPipeline(ORTStableDiffusionPipelineBase, StableDi


@add_end_docstrings(ONNX_MODEL_END_DOCSTRING)
class ORTStableDiffusionInpaintPipeline(ORTStableDiffusionPipelineBase, StableDiffusionInpaintPipelineMixin):
class ORTStableDiffusionInpaintPipeline(ORTDiffusionPipeline, StableDiffusionInpaintPipelineMixin):
"""
ONNX Runtime-powered stable diffusion pipeline corresponding to [diffusers.StableDiffusionInpaintPipeline](https://huggingface.co/docs/diffusers/api/pipelines/stable_diffusion/inpaint#diffusers.StableDiffusionInpaintPipeline).
"""
Expand All @@ -588,15 +590,15 @@ class ORTStableDiffusionInpaintPipeline(ORTStableDiffusionPipelineBase, StableDi


@add_end_docstrings(ONNX_MODEL_END_DOCSTRING)
class ORTLatentConsistencyModelPipeline(ORTStableDiffusionPipelineBase, LatentConsistencyPipelineMixin):
class ORTLatentConsistencyModelPipeline(ORTDiffusionPipeline, LatentConsistencyPipelineMixin):
"""
ONNX Runtime-powered stable diffusion pipeline corresponding to [diffusers.LatentConsistencyModelPipeline](https://huggingface.co/docs/diffusers/api/pipelines/stable_diffusion/latent_consistency#diffusers.LatentConsistencyModelPipeline).
"""

__call__ = LatentConsistencyPipelineMixin.__call__


class ORTStableDiffusionXLPipelineBase(ORTStableDiffusionPipelineBase):
class ORTStableDiffusionXLPipelineBase(ORTDiffusionPipeline):
auto_model_class = StableDiffusionXLImg2ImgPipeline

def __init__(
Expand Down Expand Up @@ -661,3 +663,24 @@ class ORTStableDiffusionXLImg2ImgPipeline(ORTStableDiffusionXLPipelineBase, Stab
"""

__call__ = StableDiffusionXLImg2ImgPipelineMixin.__call__


AUTO_TEXT2IMAGE_PIPELINES_MAPPING = OrderedDict(
[
("stable-diffusion", ORTStableDiffusionPipeline),
("stable-diffusion-xl", ORTStableDiffusionXLPipeline),
]
)

AUTO_IMAGE2IMAGE_PIPELINES_MAPPING = OrderedDict(
[
("stable-diffusion", ORTStableDiffusionImg2ImgPipeline),
("stable-diffusion-xl", ORTStableDiffusionXLImg2ImgPipeline),
]
)

AUTO_INPAINT_PIPELINES_MAPPING = OrderedDict(
[
("stable-diffusion", ORTStableDiffusionInpaintPipeline),
]
)

0 comments on commit fcb1690

Please sign in to comment.