Add Stable Diffusion ControlNet support #622

Merged · 22 commits · Jun 28, 2024
2 changes: 1 addition & 1 deletion .github/workflows/check_code_quality.yml
@@ -52,4 +52,4 @@ jobs:
- name: Check style with ruff
run: |
source venv/bin/activate
-ruff .
+ruff check .
5 changes: 5 additions & 0 deletions docs/source/package_reference/modeling.mdx
@@ -106,6 +106,11 @@ The following Neuron model classes are available for stable diffusion tasks.
[[autodoc]] modeling_diffusion.NeuronLatentConsistencyModelPipeline
- __call__

### NeuronStableDiffusionControlNetPipeline

[[autodoc]] modeling_diffusion.NeuronStableDiffusionControlNetPipeline
- __call__

### NeuronStableDiffusionXLPipeline

[[autodoc]] modeling_diffusion.NeuronStableDiffusionXLPipeline
85 changes: 85 additions & 0 deletions docs/source/tutorials/stable_diffusion.mdx
@@ -516,4 +516,89 @@ image = pipe(prompt, num_inference_steps=4, guidance_scale=0).images[0]
alt="stable diffusion generated image with LoRA adapter."
/>


## ControlNet

ControlNet conditions the stable diffusion model with an additional input image. In Optimum Neuron, we support the compilation of one or multiple ControlNets along with the stable diffusion checkpoint. You can then use the compiled artifacts to generate styled images.

### Compile ControlNet

You can compile one or multiple ControlNets either via the Optimum CLI or programmatically with the `NeuronStableDiffusionControlNetPipeline` class by passing `controlnet_ids`.

* Export via the Optimum CLI

```bash
optimum-cli export neuron -m runwayml/stable-diffusion-v1-5 --task stable-diffusion --batch_size 1 --height 512 --width 512 --controlnet_ids lllyasviel/sd-controlnet-canny --num_images_per_prompt 1 sd_neuron_controlnet/
```

* Export via Python API

```python
from optimum.neuron import NeuronStableDiffusionControlNetPipeline

model_id = "runwayml/stable-diffusion-v1-5"
controlnet_id = "lllyasviel/sd-controlnet-canny"

# [Neuron] pipeline
input_shapes = {"batch_size": 1, "height": 512, "width": 512, "num_images_per_prompt": 1}
compiler_args = {"auto_cast": "matmul", "auto_cast_type": "bf16"}
pipe = NeuronStableDiffusionControlNetPipeline.from_pretrained(
model_id,
controlnet_ids=controlnet_id,
export=True,
**input_shapes,
**compiler_args,
)
pipe.save_pretrained("sd_neuron_controlnet")
```
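
To compile several ControlNets at once, pass a list of model ids: the CLI flag `--controlnet_ids` takes a space-separated list, and the Python argument `controlnet_ids` accepts a list. Below is a minimal sketch of the Python variant; the second checkpoint, `lllyasviel/sd-controlnet-openpose`, and the output directory are illustrative choices, not values used elsewhere in this guide.

```python
from optimum.neuron import NeuronStableDiffusionControlNetPipeline

model_id = "runwayml/stable-diffusion-v1-5"
# Illustrative pair of ControlNets; any compatible checkpoints work the same way.
controlnet_ids = ["lllyasviel/sd-controlnet-canny", "lllyasviel/sd-controlnet-openpose"]

input_shapes = {"batch_size": 1, "height": 512, "width": 512, "num_images_per_prompt": 1}
compiler_args = {"auto_cast": "matmul", "auto_cast_type": "bf16"}
pipe = NeuronStableDiffusionControlNetPipeline.from_pretrained(
    model_id,
    controlnet_ids=controlnet_ids,  # a list compiles one ControlNet per id
    export=True,
    **input_shapes,
    **compiler_args,
)
pipe.save_pretrained("sd_neuron_multi_controlnet")
```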

### Text-to-Image

For text-to-image, we can specify an additional conditioning input.

Here is an example with a canny image: the white edge outline of an image on a black background. The ControlNet uses the canny image as a control to guide the model to generate an image with the same outline.

```python
import cv2
import numpy as np
from diffusers import UniPCMultistepScheduler
from diffusers.utils import load_image, make_image_grid
from PIL import Image

from optimum.neuron import NeuronStableDiffusionControlNetPipeline


# prepare canny image
original_image = load_image(
"https://hf.co/datasets/huggingface/documentation-images/resolve/main/diffusers/input_image_vermeer.png"
)

image = np.array(original_image)

low_threshold = 100
high_threshold = 200

image = cv2.Canny(image, low_threshold, high_threshold)
image = image[:, :, None]
image = np.concatenate([image, image, image], axis=2)
canny_image = Image.fromarray(image)

# load pre-compiled neuron model
pipe = NeuronStableDiffusionControlNetPipeline.from_pretrained("sd_neuron_controlnet")
pipe.scheduler = UniPCMultistepScheduler.from_config(pipe.scheduler.config)

# inference
output = pipe("the mona lisa", image=canny_image).images[0]
compare = make_image_grid([original_image, canny_image, output], rows=1, cols=3)
compare.save("compare.png")
```
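
If the Neuron pipeline forwards the diffusers-style `controlnet_conditioning_scale` argument (an assumption made here for illustration, not a documented guarantee), the strength of the conditioning can also be tuned per call:

```python
# Assumption: `controlnet_conditioning_scale` is forwarded as in diffusers;
# values below 1.0 make the output follow the canny edges less strictly.
output = pipe(
    "the mona lisa",
    image=canny_image,
    controlnet_conditioning_scale=0.7,
).images[0]
```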

<img
src="https://huggingface.co/datasets/optimum/documentation-images/resolve/main/neuron/models/10-sd-text2img-controlnet.png?download=true"
width="768"
height="256"
alt="stable diffusion 1.5 generated image with controlnet."
/>


Are there any other stable diffusion features that you want us to support in 🤗 `Optimum-neuron`? Please file an issue in the [`Optimum-neuron` GitHub repo](https://github.com/huggingface/optimum-neuron) or discuss with us on [Hugging Face's community forum](https://discuss.huggingface.co/c/optimum/), cheers 🤗!
7 changes: 7 additions & 0 deletions optimum/commands/export/neuronx.py
@@ -172,6 +172,13 @@ def parse_args_neuronx(parser: "ArgumentParser"):
type=float,
help="List of scaling factors for the lora adapters.",
)
optional_group.add_argument(
"--controlnet_ids",
default=None,
nargs="*",
type=str,
help="List of model ids (eg. `thibaud/controlnet-openpose-sdxl-1.0`) of ControlNet models.",
)
optional_group.add_argument(
"--output_attentions",
action="store_true",
45 changes: 42 additions & 3 deletions optimum/exporters/neuron/__main__.py
@@ -26,6 +26,7 @@

from ...neuron.utils import (
DECODER_NAME,
DIFFUSION_MODEL_CONTROLNET_NAME,
DIFFUSION_MODEL_TEXT_ENCODER_2_NAME,
DIFFUSION_MODEL_TEXT_ENCODER_NAME,
DIFFUSION_MODEL_UNET_NAME,
@@ -51,6 +52,7 @@
check_mandatory_input_shapes,
get_encoder_decoder_models_for_export,
get_stable_diffusion_models_for_export,
load_controlnets,
replace_stable_diffusion_submodels,
)

@@ -74,7 +76,7 @@
from transformers import PreTrainedModel

if is_diffusers_available():
from diffusers import DiffusionPipeline, ModelMixin, StableDiffusionPipeline
from diffusers import ControlNetModel, DiffusionPipeline, ModelMixin, StableDiffusionPipeline


logger = logging.get_logger()
@@ -205,6 +207,7 @@ def normalize_stable_diffusion_input_shapes(
def infer_stable_diffusion_shapes_from_diffusers(
input_shapes: Dict[str, Dict[str, int]],
model: Union["StableDiffusionPipeline", "StableDiffusionXLPipeline"],
controlnets: Optional[List["ControlNetModel"]] = None,
):
if model.tokenizer is not None:
sequence_length = model.tokenizer.model_max_length
@@ -232,11 +235,24 @@ def infer_stable_diffusion_shapes_from_diffusers(
"width": scaled_width,
}
)
input_shapes["unet"]["vae_scale_factor"] = vae_scale_factor
input_shapes["vae_encoder"].update({"num_channels": vae_encoder_num_channels, "height": height, "width": width})
input_shapes["vae_decoder"].update(
{"num_channels": vae_decoder_num_channels, "height": scaled_height, "width": scaled_width}
)

# ControlNet
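# ControlNets are compiled with the UNet's batch size, channels and latent
# resolution, plus the text encoder's hidden size for the conditioning embeddings.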
if controlnets:
input_shapes["controlnet"] = {
"batch_size": input_shapes["unet"]["batch_size"],
"sequence_length": sequence_length,
"num_channels": unet_num_channels,
"height": scaled_height,
"width": scaled_width,
"vae_scale_factor": vae_scale_factor,
"encoder_hidden_size": model.text_encoder.config.hidden_size,
}

return input_shapes


@@ -256,6 +272,7 @@ def get_submodels_and_neuron_configs(
lora_weight_names: Optional[Union[str, List[str]]] = None,
lora_adapter_names: Optional[Union[str, List[str]]] = None,
lora_scales: Optional[Union[float, List[float]]] = None,
controlnets: Optional[List["ControlNetModel"]] = None,
):
is_stable_diffusion = "stable-diffusion" in task
is_encoder_decoder = (
@@ -278,6 +295,7 @@
lora_weight_names=lora_weight_names,
lora_adapter_names=lora_adapter_names,
lora_scales=lora_scales,
controlnets=controlnets,
)
elif is_encoder_decoder:
optional_outputs = {"output_attentions": output_attentions, "output_hidden_states": output_hidden_states}
@@ -338,14 +356,19 @@ def _get_submodels_and_neuron_configs_for_stable_diffusion(
lora_weight_names: Optional[Union[str, List[str]]] = None,
lora_adapter_names: Optional[Union[str, List[str]]] = None,
lora_scales: Optional[Union[float, List[float]]] = None,
controlnets: Optional[List["ControlNetModel"]] = None,
):
check_compiler_compatibility_for_stable_diffusion()
model = replace_stable_diffusion_submodels(model, submodels)
if is_neuron_available():
raise RuntimeError(
"Stable diffusion export is not supported by neuron-cc on inf1, please use neuronx-cc on either inf2/trn1 instead."
)
input_shapes = infer_stable_diffusion_shapes_from_diffusers(input_shapes, model)
input_shapes = infer_stable_diffusion_shapes_from_diffusers(
input_shapes=input_shapes,
model=model,
controlnets=controlnets,
)

# Saving the model config and preprocessor as this is needed sometimes.
model.scheduler.save_pretrained(output.joinpath("scheduler"))
@@ -373,6 +396,8 @@ def _get_submodels_and_neuron_configs_for_stable_diffusion(
lora_weight_names=lora_weight_names,
lora_adapter_names=lora_adapter_names,
lora_scales=lora_scales,
controlnets=controlnets,
controlnet_input_shapes=input_shapes.get("controlnet", None),
)
output_model_names = {
DIFFUSION_MODEL_UNET_NAME: os.path.join(DIFFUSION_MODEL_UNET_NAME, NEURON_FILE_NAME),
@@ -387,7 +412,15 @@ def _get_submodels_and_neuron_configs_for_stable_diffusion(
output_model_names[DIFFUSION_MODEL_TEXT_ENCODER_2_NAME] = os.path.join(
DIFFUSION_MODEL_TEXT_ENCODER_2_NAME, NEURON_FILE_NAME
)

# ControlNet models
if controlnets:
for idx in range(len(controlnets)):
controlnet_name = DIFFUSION_MODEL_CONTROLNET_NAME + "_" + str(idx)
output_model_names[controlnet_name] = os.path.join(controlnet_name, NEURON_FILE_NAME)

del model
del controlnets

return models_and_neuron_configs, output_model_names

@@ -442,6 +475,7 @@ def load_models_and_neuron_configs(
lora_weight_names: Optional[Union[str, List[str]]],
lora_adapter_names: Optional[Union[str, List[str]]],
lora_scales: Optional[Union[float, List[float]]],
controlnet_ids: Optional[Union[str, List[str]]],
output_attentions: bool = False,
output_hidden_states: bool = False,
library_name: Optional[str] = None,
@@ -466,6 +500,7 @@
}
if model is None:
model = TasksManager.get_model_from_task(**model_kwargs)
controlnets = load_controlnets(controlnet_ids)

models_and_neuron_configs, output_model_names = get_submodels_and_neuron_configs(
model=model,
@@ -483,6 +518,7 @@
lora_weight_names=lora_weight_names,
lora_adapter_names=lora_adapter_names,
lora_scales=lora_scales,
controlnets=controlnets,
)

return models_and_neuron_configs, output_model_names
@@ -516,6 +552,7 @@ def main_export(
lora_weight_names: Optional[Union[str, List[str]]] = None,
lora_adapter_names: Optional[Union[str, List[str]]] = None,
lora_scales: Optional[Union[float, List[float]]] = None,
controlnet_ids: Optional[Union[str, List[str]]] = None,
**input_shapes,
):
output = Path(output)
@@ -545,6 +582,7 @@
lora_weight_names=lora_weight_names,
lora_adapter_names=lora_adapter_names,
lora_scales=lora_scales,
controlnet_ids=controlnet_ids,
**input_shapes,
)

@@ -565,7 +603,7 @@
is_stable_diffusion = "stable-diffusion" in task
if is_stable_diffusion:
# Do not validate vae encoder due to the sampling randomness
-del neuron_outputs[-2]  # -2 is the index of `vae_encoder`
+neuron_outputs.pop("vae_encoder")
models_and_neuron_configs.pop("vae_encoder", None)
output_model_names.pop("vae_encoder", None)

@@ -687,6 +725,7 @@ def main():
lora_weight_names=getattr(args, "lora_weight_names", None),
lora_adapter_names=getattr(args, "lora_adapter_names", None),
lora_scales=getattr(args, "lora_scales", None),
controlnet_ids=getattr(args, "controlnet_ids", None),
**optional_outputs,
**input_shapes,
)
29 changes: 29 additions & 0 deletions optimum/exporters/neuron/base.py
@@ -15,6 +15,7 @@
"""Neuron configuration base classes."""

import importlib
import re
from abc import ABC, abstractmethod
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union

@@ -162,6 +163,8 @@ def __init__(
point_batch_size: Optional[int] = None,
nb_points_per_image: Optional[int] = None,
num_beams: Optional[int] = None,
vae_scale_factor: Optional[int] = None,
encoder_hidden_size: Optional[int] = None,
output_attentions: bool = False,
output_hidden_states: bool = False,
# TODO: add custom dtype after optimum 1.13 release
@@ -197,6 +200,8 @@ def __init__(
"num_beams": num_beams,
"image_size": image_size or getattr(self._config, "image_size", None),
"patch_size": patch_size or getattr(self._config, "patch_size", None),
"vae_scale_factor": vae_scale_factor,
"encoder_hidden_size": encoder_hidden_size,
}
input_shapes = {}
for name, value in axes_values.items():
@@ -331,6 +336,30 @@ def flatten_inputs(cls, inputs: Dict[str, Any]) -> Dict[str, Any]:
flatten[name] = value
return flatten

@classmethod
def unflatten_inputs(cls, inputs: Dict[str, Any]) -> Dict[str, Any]:
"""
Re-construct inputs that have been flattened for tracing.
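Example (hypothetical names): `{"inputs_0": a, "inputs_1": b}` becomes `{"inputs": (a, b)}`.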
"""
unflatten = {}
to_group = {}
for name, value in inputs.items():
name_with_idx = re.findall(r"(.*?)_(\d+)", name)
if len(name_with_idx) > 0:
if name_with_idx[0][0] in to_group:
to_group[name_with_idx[0][0]].append((int(name_with_idx[0][1]), value))
else:
to_group[name_with_idx[0][0]] = [(int(name_with_idx[0][1]), value)]
else:
unflatten[name] = value

if to_group:
for name, values in to_group.items():
ordered = sorted(values, key=lambda x: x[0])
unflatten[name] = tuple([item[1] for item in ordered])

return unflatten

def patch_model_for_export(
self,
model: "PreTrainedModel",