Skip to content

Commit

Permalink
Add Stable Diffusion XL inference support (#212)
Browse files Browse the repository at this point in the history
* fix sdxl unet inf

* inference done

* add post processing and doc

* fix style

* Update docs/source/guides/models.mdx

Co-authored-by: Pedro Cuenca <[email protected]>

* add test

* update doc prompt

* fix num images per prompt issue

* fix test

* remove useless

* Update optimum/neuron/pipelines/diffusers/pipeline_stable_diffusion_xl.py

Co-authored-by: Michael Benayoun <[email protected]>

* fix docstring

* update image

* fix

* remove text encoder 2 empty folder

---------

Co-authored-by: JingyaHuang <[email protected]>
Co-authored-by: Pedro Cuenca <[email protected]>
Co-authored-by: Michael Benayoun <[email protected]>
  • Loading branch information
4 people authored Sep 8, 2023
1 parent 5b531fd commit fd29acd
Show file tree
Hide file tree
Showing 16 changed files with 1,015 additions and 122 deletions.
Binary file added docs/assets/guides/models/02-sdxl-image.jpeg
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
34 changes: 33 additions & 1 deletion docs/source/guides/export_model.mdx
Original file line number Diff line number Diff line change
Expand Up @@ -239,7 +239,39 @@ optimum-cli export neuron --model stabilityai/stable-diffusion-2-1-base \
--batch_size 1 \
--height 512 `# height in pixels of generated image, eg. 512, 768` \
--width 512 `# width in pixels of generated image, eg. 512, 768` \
--num_image_per_prompt 4 `# number of images to generate per prompt, defaults to 1` \
--num_images_per_prompt 4 `# number of images to generate per prompt, defaults to 1` \
--auto_cast matmul `# cast only matrix multiplication operations` \
--auto_cast_type bf16 `# cast operations from FP32 to BF16` \
sd_neuron/
```
## Exporting Stable Diffusion XL to Neuron
Similar to Stable Diffusion, you will be able to use Optimum CLI to compile components in the SDXL pipeline for inference on neuron devices.
We support the export of following components in the pipeline to boost the speed:
* Text encoder
* Second text encoder
* U-Net (a three times larger UNet than the one in Stable Diffusion pipeline)
* VAE encoder
* VAE decoder
<Tip>
"Stable Diffusion XL works especially well with images between 768 and 1024."
</Tip>
Exporting a SDXL checkpoint can be done using the CLI:
```bash
optimum-cli export neuron --model stabilityai/stable-diffusion-xl-base-1.0 \
--task stable-diffusion-xl \
--batch_size 1 \
--height 1024 `# height in pixels of generated image, eg. 768, 1024` \
--width 1024 `# width in pixels of generated image, eg. 768, 1024` \
--num_images_per_prompt 4 `# number of images to generate per prompt, defaults to 1` \
--auto_cast matmul `# cast only matrix multiplication operations` \
--auto_cast_type bf16 `# cast operations from FP32 to BF16` \
sd_neuron/
Expand Down
41 changes: 37 additions & 4 deletions docs/source/guides/models.mdx
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ And the next time when you want to run inference, just load your compiled model
As you see, there is no need to pass the neuron arguments used during the export as they are
saved in a `config.json` file, and will be restored automatically by `NeuronModelForXXX` class.

## Export and inference of Discriminative NLP models
## Discriminative NLP models

As explained in the previous section, you will need only few modifications to your Transformers code to export and run NLP models:

Expand Down Expand Up @@ -133,7 +133,7 @@ No worries, `NeuronModelForXXX` class will pad your inputs to an eligible shape.

</Tip>

## Export and inference of Generative NLP models
## Generative NLP models

As explained before, you will need only a few modifications to your Transformers code to export and run NLP models:

Expand Down Expand Up @@ -196,7 +196,7 @@ with torch.inference_mode():
print(outputs)
```

## Inference of Stable Diffusion Models
## Stable Diffusion

Optimum extends 🤗`Diffusers` to support inference on Neuron. To get started, make sure you have installed Diffusers:

Expand Down Expand Up @@ -247,7 +247,40 @@ Now generate an image with a prompt on neuron:

<img
src="https://raw.githubusercontent.com/huggingface/optimum-neuron/main/docs/assets/guides/models/01-sd-image.png"
alt="search ami"
alt="stable diffusion generated image"
/>

## Stable Diffusion XL

Similar to Stable Diffusion, you will be able to use `NeuronStableDiffusionXLPipeline` API to export and run inference on Neuron devices with SDXL models.

```python
>>> from optimum.neuron import NeuronStableDiffusionXLPipeline

>>> model_id = "stabilityai/stable-diffusion-xl-base-1.0"
>>> compiler_args = {"auto_cast": "matmul", "auto_cast_type": "bf16"}
>>> input_shapes = {"batch_size": 1, "height": 1024, "width": 1024}

>>> stable_diffusion_xl = NeuronStableDiffusionXLPipeline.from_pretrained(model_id, export=True, **compiler_args, **input_shapes)

# Save locally or upload to the HuggingFace Hub
>>> save_directory = "sd_neuron_xl/"
>>> stable_diffusion_xl.save_pretrained(save_directory)
>>> stable_diffusion_xl.push_to_hub(
... save_directory, repository_id="my-neuron-repo", use_auth_token=True
... )
```

Now generate an image with a prompt on neuron:

```python
>>> prompt = "Astronaut in a jungle, cold color palette, muted colors, detailed, 8k"
>>> image = stable_diffusion_xl(prompt).images[0]
```

<img
src="https://raw.githubusercontent.com/huggingface/optimum-neuron/main/docs/assets/guides/models/02-sdxl-image.jpeg"
alt="sdxl generated image"
/>


Expand Down
7 changes: 6 additions & 1 deletion optimum/exporters/neuron/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,12 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from .__main__ import main_export, normalize_input_shapes, normalize_stable_diffusion_input_shapes
from .__main__ import (
infer_stable_diffusion_shapes_from_diffusers,
main_export,
normalize_input_shapes,
normalize_stable_diffusion_input_shapes,
)
from .base import NeuronConfig
from .convert import export, export_models, validate_model_outputs, validate_models_outputs
from .utils import (
Expand Down
47 changes: 23 additions & 24 deletions optimum/exporters/neuron/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,21 +116,20 @@ def normalize_stable_diffusion_input_shapes(
) -> Dict[str, Dict[str, int]]:
args = vars(args) if isinstance(args, argparse.Namespace) else args
mandatory_axes = set(getattr(inspect.getfullargspec(build_stable_diffusion_components_mandatory_shapes), "args"))
# Remove `sequence_length` as diffusers will pad it to the max and remove number of channels .
# Remove `sequence_length` as diffusers will pad it to the max and remove number of channels.
mandatory_axes = mandatory_axes - {
"sequence_length",
"unet_num_channels",
"vae_encoder_num_channels",
"vae_decoder_num_channels",
"num_images_per_prompt", # default to 1
}
if not mandatory_axes.issubset(set(args.keys())):
raise AttributeError(
f"Shape of {mandatory_axes} are mandatory for neuron compilation, while {mandatory_axes.difference(args.keys())} are not given."
)
mandatory_shapes = {name: args[name] for name in mandatory_axes}
if "num_images_per_prompt" in args and args["num_images_per_prompt"] > 1:
batch_size = args["num_images_per_prompt"] * args["batch_size"]
mandatory_shapes["batch_size"] = batch_size
mandatory_shapes["num_images_per_prompt"] = args.get("num_images_per_prompt", 1)
input_shapes = build_stable_diffusion_components_mandatory_shapes(**mandatory_shapes)
return input_shapes

Expand Down Expand Up @@ -184,18 +183,19 @@ def main_export(

task = TasksManager.map_from_synonym(task)

model = TasksManager.get_model_from_task(
task=task,
model_name_or_path=model_name_or_path,
subfolder=subfolder,
revision=revision,
cache_dir=cache_dir,
use_auth_token=use_auth_token,
local_files_only=local_files_only,
force_download=force_download,
trust_remote_code=trust_remote_code,
framework="pt",
)
model_kwargs = {
"task": task,
"model_name_or_path": model_name_or_path,
"subfolder": subfolder,
"revision": revision,
"cache_dir": cache_dir,
"use_auth_token": use_auth_token,
"local_files_only": local_files_only,
"force_download": force_download,
"trust_remote_code": trust_remote_code,
"framework": "pt",
}
model = TasksManager.get_model_from_task(**model_kwargs)

is_stable_diffusion = "stable-diffusion" in task
if not is_stable_diffusion:
Expand All @@ -216,12 +216,6 @@ def main_export(
"Stable diffusion export is not supported by neuron-cc on inf1, please use neuronx-cc on either inf2/trn1 instead."
)
input_shapes = infer_stable_diffusion_shapes_from_diffusers(input_shapes, model)
models_and_neuron_configs = get_stable_diffusion_models_for_export(
pipeline=model,
task=task,
dynamic_batch_size=dynamic_batch_size,
**input_shapes,
)

# Saving the model config and preprocessor as this is needed sometimes.
model.scheduler.save_pretrained(output.joinpath("scheduler"))
Expand All @@ -232,6 +226,12 @@ def main_export(
model.feature_extractor.save_pretrained(output.joinpath("feature_extractor"))
model.save_config(output)

models_and_neuron_configs = get_stable_diffusion_models_for_export(
pipeline=model,
task=task,
dynamic_batch_size=dynamic_batch_size,
**input_shapes,
)
output_model_names = {
DIFFUSION_MODEL_TEXT_ENCODER_NAME: os.path.join(DIFFUSION_MODEL_TEXT_ENCODER_NAME, NEURON_FILE_NAME),
DIFFUSION_MODEL_UNET_NAME: os.path.join(DIFFUSION_MODEL_UNET_NAME, NEURON_FILE_NAME),
Expand All @@ -242,6 +242,7 @@ def main_export(
output_model_names[DIFFUSION_MODEL_TEXT_ENCODER_2_NAME] = os.path.join(
DIFFUSION_MODEL_TEXT_ENCODER_2_NAME, NEURON_FILE_NAME
)
del model

_, neuron_outputs = export_models(
models_and_neuron_configs=models_and_neuron_configs,
Expand All @@ -250,8 +251,6 @@ def main_export(
compiler_kwargs=compiler_kwargs,
)

del model

# Validate compiled model
if do_validation is True:
if is_stable_diffusion:
Expand Down
27 changes: 25 additions & 2 deletions optimum/exporters/neuron/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -276,12 +276,28 @@ def generate_dummy_inputs(
else:
return dummy_inputs

@classmethod
def flatten_inputs(cls, inputs: Dict[str, Any]) -> Dict[str, Any]:
"""
Flatten nested structure in dummy inputs, e.g `addition_embed_type` of unet model.
"""
flatten = {}
for name, value in inputs.items():
if isinstance(value, dict):
for sub_name, sub_value in value.items():
flatten[sub_name] = sub_value
else:
flatten[name] = value
return flatten

def check_model_inputs_order(
self,
model: "PreTrainedModel",
dummy_inputs: Dict[str, torch.Tensor],
dummy_inputs: Optional[Dict[str, torch.Tensor]] = None,
forward_with_tuple: bool = False,
eligible_outputs: Optional[List[Union[str, int]]] = None,
custom_model_wrapper: Optional[torch.nn.Module] = None,
custom_wrapper_kwargs: Optional[Dict] = None,
):
"""
Checks if inputs order of the model's forward pass correspond to the generated dummy inputs to ensure the dummy inputs tuple used for
Expand Down Expand Up @@ -320,7 +336,14 @@ def forward(self, *input):

return outputs

return ModelWrapper(model, list(dummy_inputs.keys()))
if custom_model_wrapper:
return (
custom_model_wrapper(model)
if custom_wrapper_kwargs is None
else custom_model_wrapper(model, **custom_wrapper_kwargs)
)
else:
return ModelWrapper(model, list(dummy_inputs.keys()))


class NeuronDecoderConfig(ExportConfig):
Expand Down
39 changes: 27 additions & 12 deletions optimum/exporters/neuron/convert.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,16 +13,15 @@
# See the License for the specific language governing permissions and
# limitations under the License.
"""Neuron compiled model check and export functions."""

import copy
import os
import time
from collections import OrderedDict
from pathlib import Path
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union

import numpy as np
import torch
from packaging import version
from transformers import PretrainedConfig

from ...exporters.error_utils import OutputMatchError, ShapeError
Expand Down Expand Up @@ -178,7 +177,7 @@ def validate_model_outputs(
neuron_inputs = ref_inputs
else:
ref_outputs = reference_model(**ref_inputs)
neuron_inputs = tuple(ref_inputs.values())
neuron_inputs = tuple(config.flatten_inputs(ref_inputs).values())

# Neuron outputs
neuron_model = torch.jit.load(neuron_model_path)
Expand Down Expand Up @@ -292,6 +291,7 @@ def export_models(
)

failed_models = []
total_compilation_time = 0
for i, model_name in enumerate(models_and_neuron_configs.keys()):
logger.info(f"***** Compiling {model_name} *****")
submodel, sub_neuron_config = models_and_neuron_configs[model_name]
Expand All @@ -311,6 +311,7 @@ def export_models(
**compiler_kwargs,
)
compilation_time = time.time() - start_time
total_compilation_time += compilation_time
logger.info(f"[Compilation Time] {np.round(compilation_time, 2)} seconds.")
outputs.append((neuron_inputs, neuron_outputs))
# Add neuron specific configs to model components' original config
Expand Down Expand Up @@ -347,6 +348,7 @@ def export_models(
f"An error occured when trying to trace {model_name} with the error message: {e}.\n"
f"The export is failed and {model_name} neuron model won't be stored."
)
logger.info(f"[Total compilation Time] {np.round(total_compilation_time, 2)} seconds.")

# remove models failed to export
for i, model_name in failed_models:
Expand Down Expand Up @@ -421,6 +423,7 @@ def export_neuronx(
input_shapes[axis] = getattr(config, axis)

dummy_inputs = config.generate_dummy_inputs(**input_shapes)
dummy_inputs = config.flatten_inputs(dummy_inputs)
dummy_inputs_tuple = tuple(dummy_inputs.values())
checked_model = config.check_model_inputs_order(model, dummy_inputs)

Expand All @@ -435,6 +438,10 @@ def export_neuronx(
else:
compiler_args = ["--auto-cast", "none"]

# WARNING: Enabled experimental parallel compilation
compiler_args.extend(["--enable-experimental-O1"])
compiler_args.extend(["--num-parallel-jobs", str(os.cpu_count())])

# diffusers specific
compiler_args = add_stable_diffusion_compiler_args(config, compiler_args)

Expand All @@ -447,31 +454,36 @@ def export_neuronx(
improve_stable_diffusion_loading(config, neuron_model)

torch.jit.save(neuron_model, output)
del model
del checked_model
del dummy_inputs
del neuron_model

return config.inputs, config.outputs


def add_stable_diffusion_compiler_args(config, compiler_args):
if hasattr(config._config, "_name_or_path"):
sd_components = ["text_encoder", "unet", "vae", "vae_encoder", "vae_decoder"]
sd_components = ["text_encoder", "vae", "vae_encoder", "vae_decoder"]
if any(component in config._config._name_or_path.lower() for component in sd_components):
compiler_args.extend(["--enable-fast-loading-neuron-binaries"])
# unet
if "unet" in config._config._name_or_path.lower():
# SDXL unet doesn't support fast loading neuron binaries
if "stable-diffusion-xl" not in config._config._name_or_path.lower():
compiler_args.extend(["--enable-fast-loading-neuron-binaries"])
compiler_args.extend(["--model-type=unet-inference"])
return compiler_args


def improve_stable_diffusion_loading(config, neuron_model):
if version.parse(neuronx.__version__) >= version.parse("1.13.1.1.9.0"):
if hasattr(config._config, "_name_or_path"):
sd_components = ["text_encoder", "unet", "vae", "vae_encoder", "vae_decoder"]
if any(component in config._config._name_or_path.lower() for component in sd_components):
neuronx.async_load(neuron_model)
# unet
if "unet" in config._config._name_or_path.lower():
neuronx.lazy_load(neuron_model)
if hasattr(config._config, "_name_or_path"):
sd_components = ["text_encoder", "unet", "vae", "vae_encoder", "vae_decoder"]
if any(component in config._config._name_or_path.lower() for component in sd_components):
neuronx.async_load(neuron_model)
# unet
if "unet" in config._config._name_or_path.lower():
neuronx.lazy_load(neuron_model)


def export_neuron(
Expand Down Expand Up @@ -537,6 +549,9 @@ def export_neuron(
fallback=not disable_fallback,
)
torch.jit.save(neuron_model, output)
del model
del checked_model
del dummy_inputs
del neuron_model

return config.inputs, config.outputs
Loading

0 comments on commit fd29acd

Please sign in to comment.