Skip to content

Commit

Permalink
fix
Browse files Browse the repository at this point in the history
  • Loading branch information
fxmarty committed Aug 28, 2023
1 parent 83d1725 commit 2f1fe6b
Show file tree
Hide file tree
Showing 6 changed files with 346 additions and 608 deletions.
260 changes: 5 additions & 255 deletions optimum/onnxruntime/modeling_diffusion.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,21 +40,11 @@

from ..exporters.onnx import main_export
from ..onnx.utils import _get_external_data_paths
from ..pipelines.diffusers.pipeline_stable_diffusion import (
StableDiffusionPipelineMixin,
)
from ..pipelines.diffusers.pipeline_stable_diffusion_img2img import (
StableDiffusionImg2ImgPipelineMixin,
)
from ..pipelines.diffusers.pipeline_stable_diffusion_inpaint import (
StableDiffusionInpaintPipelineMixin,
)
from ..pipelines.diffusers.pipeline_stable_diffusion_xl import (
StableDiffusionXLPipelineMixin,
)
from ..pipelines.diffusers.pipeline_stable_diffusion_xl_img2img import (
StableDiffusionXLImg2ImgPipelineMixin,
)
from ..pipelines.diffusers.pipeline_stable_diffusion import StableDiffusionPipelineMixin
from ..pipelines.diffusers.pipeline_stable_diffusion_img2img import StableDiffusionImg2ImgPipelineMixin
from ..pipelines.diffusers.pipeline_stable_diffusion_inpaint import StableDiffusionInpaintPipelineMixin
from ..pipelines.diffusers.pipeline_stable_diffusion_xl import StableDiffusionXLPipelineMixin
from ..pipelines.diffusers.pipeline_stable_diffusion_xl_img2img import StableDiffusionXLImg2ImgPipelineMixin
from ..utils import (
DIFFUSION_MODEL_TEXT_ENCODER_2_SUBFOLDER,
DIFFUSION_MODEL_TEXT_ENCODER_SUBFOLDER,
Expand Down Expand Up @@ -551,48 +541,6 @@ class ORTStableDiffusionPipeline(ORTStableDiffusionPipelineBase, StableDiffusion
"""

__call__ = StableDiffusionPipelineMixin.__call__
"""
@add_end_docstrings(STABLE_DIFFUSION_PIPELINE_CALL_DOCSTRING)
def __call__(
self,
prompt: Optional[Union[str, List[str]]] = None,
height: Optional[int] = None,
width: Optional[int] = None,
num_inference_steps: int = 50,
guidance_scale: float = 7.5,
negative_prompt: Optional[Union[str, List[str]]] = None,
num_images_per_prompt: int = 1,
eta: float = 0.0,
generator: Optional[np.random.RandomState] = None,
latents: Optional[np.ndarray] = None,
prompt_embeds: Optional[np.ndarray] = None,
negative_prompt_embeds: Optional[np.ndarray] = None,
output_type: str = "pil",
return_dict: bool = True,
callback: Optional[Callable[[int, int, np.ndarray], None]] = None,
callback_steps: int = 1,
guidance_rescale: float = 0.0,
):
return StableDiffusionPipelineMixin.__call__(
prompt,
height,
width,
num_inference_steps,
guidance_scale,
negative_prompt,
num_images_per_prompt,
eta,
generator,
latents,
prompt_embeds,
negative_prompt_embeds,
output_type,
return_dict,
callback,
callback_steps,
guidance_rescale,
)
"""


@add_end_docstrings(ONNX_MODEL_END_DOCSTRING)
Expand All @@ -603,45 +551,6 @@ class ORTStableDiffusionImg2ImgPipeline(ORTStableDiffusionPipelineBase, StableDi

__call__ = StableDiffusionImg2ImgPipelineMixin.__call__

"""
@add_end_docstrings(STABLE_DIFFUSION_PIPELINE_IMG2IMG_CALL_DOCSTRING)
def __call__(
self,
prompt: Optional[Union[str, List[str]]] = None,
image: Union[np.ndarray, PIL.Image.Image] = None,
strength: float = 0.8,
num_inference_steps: int = 50,
guidance_scale: float = 7.5,
negative_prompt: Optional[Union[str, List[str]]] = None,
num_images_per_prompt: int = 1,
eta: float = 0.0,
generator: Optional[np.random.RandomState] = None,
prompt_embeds: Optional[np.ndarray] = None,
negative_prompt_embeds: Optional[np.ndarray] = None,
output_type: str = "pil",
return_dict: bool = True,
callback: Optional[Callable[[int, int, np.ndarray], None]] = None,
callback_steps: int = 1,
):
StableDiffusionImg2ImgPipelineMixin.__call__(
prompt,
image,
strength,
num_inference_steps,
guidance_scale,
negative_prompt,
num_images_per_prompt,
eta,
generator,
prompt_embeds,
negative_prompt_embeds,
output_type,
return_dict,
callback,
callback_steps,
)
"""


@add_end_docstrings(ONNX_MODEL_END_DOCSTRING)
class ORTStableDiffusionInpaintPipeline(ORTStableDiffusionPipelineBase, StableDiffusionInpaintPipelineMixin):
Expand All @@ -651,51 +560,6 @@ class ORTStableDiffusionInpaintPipeline(ORTStableDiffusionPipelineBase, StableDi

__call__ = StableDiffusionInpaintPipelineMixin.__call__

"""
@add_end_docstrings(STABLE_DIFFUSION_PIPELINE_INPAINT_CALL_DOCSTRING)
def __call__(
self,
prompt: Union[str, List[str]],
image: PIL.Image.Image,
mask_image: PIL.Image.Image,
height: Optional[int] = None,
width: Optional[int] = None,
num_inference_steps: int = 50,
guidance_scale: float = 7.5,
negative_prompt: Optional[Union[str, List[str]]] = None,
num_images_per_prompt: int = 1,
eta: float = 0.0,
generator: Optional[np.random.RandomState] = None,
latents: Optional[np.ndarray] = None,
prompt_embeds: Optional[np.ndarray] = None,
negative_prompt_embeds: Optional[np.ndarray] = None,
output_type: str = "pil",
return_dict: bool = True,
callback: Optional[Callable[[int, int, np.ndarray], None]] = None,
callback_steps: int = 1,
):
StableDiffusionInpaintPipelineMixin.__call__(
prompt,
image,
mask_image,
height,
width,
num_inference_steps,
guidance_scale,
negative_prompt,
num_images_per_prompt,
eta,
generator,
latents,
prompt_embeds,
negative_prompt_embeds,
output_type,
return_dict,
callback,
callback_steps,
)
"""


class ORTStableDiffusionXLPipelineBase(ORTStableDiffusionPipelineBase):
auto_model_class = StableDiffusionXLImg2ImgPipeline
Expand Down Expand Up @@ -744,61 +608,6 @@ class ORTStableDiffusionXLPipeline(ORTStableDiffusionXLPipelineBase, StableDiffu

__call__ = StableDiffusionXLPipelineMixin.__call__

"""
@add_end_docstrings(STABLE_DIFFUSION_PIPELINE_XL_CALL_DOCSTRING)
def __call__(
self,
prompt: Optional[Union[str, List[str]]] = None,
height: Optional[int] = None,
width: Optional[int] = None,
num_inference_steps: int = 50,
guidance_scale: float = 5.0,
negative_prompt: Optional[Union[str, List[str]]] = None,
num_images_per_prompt: int = 1,
eta: float = 0.0,
generator: Optional[np.random.RandomState] = None,
latents: Optional[np.ndarray] = None,
prompt_embeds: Optional[np.ndarray] = None,
negative_prompt_embeds: Optional[np.ndarray] = None,
pooled_prompt_embeds: Optional[np.ndarray] = None,
negative_pooled_prompt_embeds: Optional[np.ndarray] = None,
output_type: str = "pil",
return_dict: bool = True,
callback: Optional[Callable[[int, int, np.ndarray], None]] = None,
callback_steps: int = 1,
cross_attention_kwargs: Optional[Dict[str, Any]] = None,
guidance_rescale: float = 0.0,
original_size: Optional[Tuple[int, int]] = None,
crops_coords_top_left: Tuple[int, int] = (0, 0),
target_size: Optional[Tuple[int, int]] = None,
):
StableDiffusionXLPipelineMixin.__call__(
prompt,
height,
width,
num_inference_steps,
guidance_scale,
negative_prompt,
num_images_per_prompt,
eta,
generator,
latents,
prompt_embeds,
negative_prompt_embeds,
pooled_prompt_embeds,
negative_pooled_prompt_embeds,
output_type,
return_dict,
callback,
callback_steps,
cross_attention_kwargs,
guidance_rescale,
original_size,
crops_coords_top_left,
target_size,
)
"""


@add_end_docstrings(ONNX_MODEL_END_DOCSTRING)
class ORTStableDiffusionXLImg2ImgPipeline(ORTStableDiffusionXLPipelineBase, StableDiffusionXLImg2ImgPipelineMixin):
Expand All @@ -807,62 +616,3 @@ class ORTStableDiffusionXLImg2ImgPipeline(ORTStableDiffusionXLPipelineBase, Stab
"""

__call__ = StableDiffusionXLImg2ImgPipelineMixin.__call__

"""
@add_end_docstrings(STABLE_DIFFUSION_PIPELINE_XL_IMG2IMG_CALL_DOCSTRING)
def __call__(
self,
prompt: Optional[Union[str, List[str]]] = None,
image: Union[np.ndarray, PIL.Image.Image] = None,
strength: float = 0.3,
num_inference_steps: int = 50,
guidance_scale: float = 5.0,
negative_prompt: Optional[Union[str, List[str]]] = None,
num_images_per_prompt: int = 1,
eta: float = 0.0,
generator: Optional[np.random.RandomState] = None,
latents: Optional[np.ndarray] = None,
prompt_embeds: Optional[np.ndarray] = None,
negative_prompt_embeds: Optional[np.ndarray] = None,
pooled_prompt_embeds: Optional[np.ndarray] = None,
negative_pooled_prompt_embeds: Optional[np.ndarray] = None,
output_type: str = "pil",
return_dict: bool = True,
callback: Optional[Callable[[int, int, np.ndarray], None]] = None,
callback_steps: int = 1,
cross_attention_kwargs: Optional[Dict[str, Any]] = None,
guidance_rescale: float = 0.0,
original_size: Optional[Tuple[int, int]] = None,
crops_coords_top_left: Tuple[int, int] = (0, 0),
target_size: Optional[Tuple[int, int]] = None,
aesthetic_score: float = 6.0,
negative_aesthetic_score: float = 2.5,
):
StableDiffusionXLImg2ImgPipelineMixin.__call__(
prompt,
image,
strength,
num_inference_steps,
guidance_scale,
negative_prompt,
num_images_per_prompt,
eta,
generator,
latents,
prompt_embeds,
negative_prompt_embeds,
pooled_prompt_embeds,
negative_pooled_prompt_embeds,
output_type,
return_dict,
callback,
callback_steps,
cross_attention_kwargs,
guidance_rescale,
original_size,
crops_coords_top_left,
target_size,
aesthetic_score,
negative_aesthetic_score,
)
"""
Loading

0 comments on commit 2f1fe6b

Please sign in to comment.