Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add support for multiple controlnet #691

Merged
merged 12 commits into from
Sep 18, 2024
3 changes: 3 additions & 0 deletions .github/workflows/test_inf2_inference.yml
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,9 @@ jobs:
sudo apt-get update -y
sudo apt-get install aws-neuronx-tools=2.17.1.0 aws-neuronx-runtime-lib=2.20.22.0-1b3ca6425 aws-neuronx-collectives=2.20.22.0-c101c322e -y
export PATH=/opt/aws/neuron/bin:$PATH
- name: Install cv2 dependencies
run: |
sudo apt-get install ffmpeg libsm6 libxext6 -y
- name: Checkout
uses: actions/checkout@v2
- name: Install python dependencies
Expand Down
62 changes: 62 additions & 0 deletions docs/source/inference_tutorials/stable_diffusion.mdx
Original file line number Diff line number Diff line change
Expand Up @@ -635,6 +635,68 @@ compare.save("compare.png")
/>


### MultiControlNet

With Optimum Neuron, you can also compose multiple ControlNet conditionings from different image inputs:

* Compile multiple ControlNet for SD1.5

```bash
optimum-cli export neuron --inline-weights-neff --model jyoung105/stable-diffusion-v1-5 --task stable-diffusion --auto_cast matmul --auto_cast_type bf16 --batch_size 1 --num_images_per_prompt 1 --controlnet_ids lllyasviel/control_v11p_sd15_openpose lllyasviel/control_v11f1p_sd15_depth --height 512 --width 512 sd15-512x512-bf16-openpose-depth
```

* Run SD1.5 with OpenPose and Depth conditionings:

```python
import numpy as np
import torch
from PIL import Image

from controlnet_aux import OpenposeDetector
from transformers import pipeline
from diffusers import UniPCMultistepScheduler
from diffusers.utils import load_image
from optimum.neuron import NeuronStableDiffusionControlNetPipeline


# OpenPose+Depth ControlNet
model_id = "sd15-512x512-bf16-openpose-depth"

# Load ControlNet images

# 1. openpose
image = load_image("https://huggingface.co/lllyasviel/control_v11p_sd15_openpose/resolve/main/images/input.png")
processor = OpenposeDetector.from_pretrained('lllyasviel/ControlNet')
openpose_image = processor(image)

# 2. depth
image = load_image("https://huggingface.co/lllyasviel/control_v11p_sd15_depth/resolve/main/images/input.png")
depth_estimator = pipeline('depth-estimation')
image = depth_estimator(image)['depth']
image = np.array(image)
image = image[:, :, None]
image = np.concatenate([image, image, image], axis=2)
depth_image = Image.fromarray(image)

images = [openpose_image.resize((512, 512)), depth_image.resize((512, 512))]

# Inference
JingyaHuang marked this conversation as resolved.
Show resolved Hide resolved
pipe = NeuronStableDiffusionControlNetPipeline.from_pretrained(model_id)
pipe.scheduler = UniPCMultistepScheduler.from_config(pipe.scheduler.config)
prompt = "a giant in a fantasy landscape, best quality"
negative_prompt = "monochrome, lowres, bad anatomy, worst quality, low quality"

image = pipe(prompt=prompt, image=images).images[0]
image.save('out.png')
```

<img
src="https://huggingface.co/datasets/Jingya/document_images/resolve/main/optimum/neuron/multicontrolnet.png"
width="768"
height="256"
alt="stable diffusion 1.5 generated image with OpenPose and Depth controlnet."
/>


## ControlNet with Stable Diffusion XL

Expand Down
8 changes: 4 additions & 4 deletions optimum/exporters/neuron/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -472,7 +472,7 @@ def load_models_and_neuron_configs(
revision: str,
force_download: bool,
local_files_only: bool,
use_auth_token: Optional[Union[bool, str]],
token: Optional[Union[bool, str]],
submodels: Optional[Dict[str, Union[Path, str]]],
lora_model_ids: Optional[Union[str, List[str]]],
lora_weight_names: Optional[Union[str, List[str]]],
Expand All @@ -494,7 +494,7 @@ def load_models_and_neuron_configs(
"subfolder": subfolder,
"revision": revision,
"cache_dir": cache_dir,
"use_auth_token": use_auth_token,
"token": token,
"local_files_only": local_files_only,
"force_download": force_download,
"trust_remote_code": trust_remote_code,
Expand Down Expand Up @@ -544,7 +544,7 @@ def main_export(
revision: str = "main",
force_download: bool = False,
local_files_only: bool = False,
use_auth_token: Optional[Union[bool, str]] = None,
token: Optional[Union[bool, str]] = None,
do_validation: bool = True,
submodels: Optional[Dict[str, Union[Path, str]]] = None,
output_attentions: bool = False,
Expand Down Expand Up @@ -575,7 +575,7 @@ def main_export(
revision=revision,
force_download=force_download,
local_files_only=local_files_only,
use_auth_token=use_auth_token,
token=token,
submodels=submodels,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
Expand Down
2 changes: 1 addition & 1 deletion optimum/neuron/generation/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -719,7 +719,7 @@ def generate(
batch_size=batch_size,
model_input_name=model_input_name,
model_kwargs=model_kwargs,
decoder_start_token_id=generation_config.decoder_start_token_id,
decoder_start_token_id=generation_config._decoder_start_token_tensor,
device=inputs_tensor.device,
)
else:
Expand Down
21 changes: 21 additions & 0 deletions optimum/neuron/modeling.py
Original file line number Diff line number Diff line change
Expand Up @@ -684,6 +684,13 @@ class NeuronModelForImageClassification(NeuronTracedModel):

auto_model_class = AutoModelForImageClassification

@property
def dtype(self) -> Optional["torch.dtype"]:
"""
Torch dtype of the inputs to avoid error in transformers on casting a BatchFeature to type None.
"""
return getattr(self.config.neuron, "input_dtype", torch.float32)

@add_start_docstrings_to_model_forward(
NEURON_IMAGE_INPUTS_DOCSTRING.format("batch_size, num_channels, height, width")
+ IMAGE_CLASSIFICATION_EXAMPLE.format(
Expand Down Expand Up @@ -763,6 +770,13 @@ class NeuronModelForSemanticSegmentation(NeuronTracedModel):

auto_model_class = AutoModelForSemanticSegmentation

@property
def dtype(self) -> Optional["torch.dtype"]:
"""
Torch dtype of the inputs to avoid error in transformers on casting a BatchFeature to type None.
"""
return getattr(self.config.neuron, "input_dtype", torch.float32)

@add_start_docstrings_to_model_forward(
NEURON_IMAGE_INPUTS_DOCSTRING.format("batch_size, num_channels, height, width")
+ SEMANTIC_SEGMENTATION_EXAMPLE.format(
Expand Down Expand Up @@ -843,6 +857,13 @@ class NeuronModelForObjectDetection(NeuronTracedModel):

auto_model_class = AutoModelForObjectDetection

@property
def dtype(self) -> Optional["torch.dtype"]:
"""
Torch dtype of the inputs to avoid error in transformers on casting a BatchFeature to type None.
"""
return getattr(self.config.neuron, "input_dtype", torch.float32)

@add_start_docstrings_to_model_forward(
NEURON_IMAGE_INPUTS_DOCSTRING.format("batch_size, num_channels, height, width")
+ OBJECT_DETECTION_EXAMPLE.format(
Expand Down
38 changes: 13 additions & 25 deletions optimum/neuron/modeling_decoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,8 +25,7 @@
from tempfile import TemporaryDirectory
from typing import TYPE_CHECKING, Optional, Tuple, Union

from huggingface_hub import HfApi, get_token, snapshot_download
from huggingface_hub.utils import is_google_colab
from huggingface_hub import HfApi, snapshot_download
from transformers import AutoConfig, AutoModel, GenerationConfig

from ..exporters.neuron.model_configs import * # noqa: F403
Expand Down Expand Up @@ -225,7 +224,7 @@ def __init__(
def _create_checkpoint(
cls,
model_id: str,
use_auth_token: Optional[Union[bool, str]] = None,
token: Optional[Union[bool, str]] = None,
revision: Optional[str] = None,
force_download: bool = False,
cache_dir: Optional[str] = None,
Expand All @@ -243,7 +242,7 @@ def _create_checkpoint(
revision=revision,
framework="pt",
cache_dir=cache_dir,
use_auth_token=use_auth_token,
token=token,
local_files_only=local_files_only,
force_download=force_download,
trust_remote_code=trust_remote_code,
Expand All @@ -269,7 +268,7 @@ def get_export_config(
cls,
model_id: str,
config: "PretrainedConfig",
use_auth_token: Optional[str] = None,
token: Optional[Union[bool, str]] = None,
revision: Optional[str] = None,
task: Optional[str] = None,
batch_size: Optional[int] = None,
Expand All @@ -286,7 +285,7 @@ def get_export_config(
else:
checkpoint_id = model_id
# Get the exact checkpoint revision (SHA1)
api = HfApi(token=use_auth_token)
api = HfApi(token=token)
model_info = api.repo_info(model_id, revision=revision)
checkpoint_revision = model_info.sha

Expand Down Expand Up @@ -337,7 +336,7 @@ def _export(
cls,
model_id: str,
config: "PretrainedConfig",
use_auth_token: Optional[str] = None,
token: Optional[Union[bool, str]] = None,
revision: Optional[str] = None,
task: Optional[str] = None,
batch_size: Optional[int] = None,
Expand All @@ -353,7 +352,7 @@ def _export(
new_config = cls.get_export_config(
model_id,
config,
use_auth_token=use_auth_token,
token=token,
revision=revision,
task=task,
batch_size=batch_size,
Expand Down Expand Up @@ -396,7 +395,7 @@ def _from_pretrained(
cls,
model_id: Union[str, Path],
config: "PretrainedConfig",
use_auth_token: Optional[str] = None,
token: Optional[Union[bool, str]] = None,
revision: Optional[str] = None,
**kwargs,
) -> "NeuronDecoderModel":
Expand All @@ -411,7 +410,7 @@ def _from_pretrained(

model_path = model_id
if not os.path.isdir(model_id):
model_path = snapshot_download(model_id, token=use_auth_token, revision=revision)
model_path = snapshot_download(model_id, token=token, revision=revision)

checkpoint_dir, compiled_dir = cls._get_neuron_dirs(model_path)
if not os.path.isdir(checkpoint_dir):
Expand All @@ -425,7 +424,7 @@ def _from_pretrained(
checkpoint_id,
task=task,
revision=checkpoint_revision,
use_auth_token=use_auth_token,
token=token,
**kwargs,
)
assert os.path.isdir(compiled_dir)
Expand Down Expand Up @@ -467,24 +466,13 @@ def push_to_hub(
repository_id: str,
private: Optional[bool] = None,
revision: Optional[str] = None,
use_auth_token: Union[bool, str] = True,
token: Union[bool, str] = True,
endpoint: Optional[str] = None,
) -> str:
if isinstance(use_auth_token, str):
huggingface_token = use_auth_token
elif use_auth_token:
huggingface_token = get_token()
else:
raise ValueError("You need to provide `use_auth_token` to be able to push to the hub")
api = HfApi(endpoint=endpoint)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: I actually think we should pass the token here, so that it can be omitted later.

cc @Wauplin

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

that sounds good, I just did it the way it was to pass the CIs, I will leave it to you if you want to change! (don't want to wait for CIs again hhh


user = api.whoami(huggingface_token)
if is_google_colab():
# Only in Google Colab to avoid the warning message
self.git_config_username_and_email(git_email=user["email"], git_user=user["fullname"])

api.create_repo(
token=huggingface_token,
token=token,
repo_id=repository_id,
exist_ok=True,
private=private,
Expand All @@ -498,7 +486,7 @@ def push_to_hub(
api.upload_folder(
repo_id=repository_id,
folder_path=save_directory,
token=huggingface_token,
token=token,
revision=revision,
ignore_patterns=ignore_patterns,
)
23 changes: 14 additions & 9 deletions optimum/neuron/modeling_diffusion.py
Original file line number Diff line number Diff line change
Expand Up @@ -548,7 +548,7 @@ def _from_pretrained(
cls,
model_id: Union[str, Path],
config: Dict[str, Any],
use_auth_token: Optional[Union[bool, str]] = None,
token: Optional[Union[bool, str]] = None,
revision: Optional[str] = None,
force_download: bool = False,
cache_dir: Optional[str] = None,
Expand Down Expand Up @@ -592,7 +592,7 @@ def _from_pretrained(
model_id,
cache_dir=cache_dir,
local_files_only=local_files_only,
use_auth_token=use_auth_token,
token=token,
revision=revision,
force_download=force_download,
allow_patterns=allow_patterns,
Expand Down Expand Up @@ -720,7 +720,7 @@ def _export(
model_id: Union[str, Path],
config: Dict[str, Any],
unet_id: Optional[Union[str, Path]] = None,
use_auth_token: Optional[Union[bool, str]] = None,
token: Optional[Union[bool, str]] = None,
revision: str = "main",
force_download: bool = True,
cache_dir: Optional[str] = None,
Expand Down Expand Up @@ -758,9 +758,9 @@ def _export(
configuration files of compatible classes.
unet_id (`Optional[Union[str, Path]]`, defaults to `None`):
A string or a path point to the U-NET model to replace the one in the original pipeline.
use_auth_token (`Optional[Union[bool, str]]`, defaults to `None`):
token (`Optional[Union[bool,str]]`, defaults to `None`):
JingyaHuang marked this conversation as resolved.
Show resolved Hide resolved
The token to use as HTTP bearer authorization for remote files. If `True`, will use the token generated
when running `transformers-cli login` (stored in `~/.huggingface`).
when running `huggingface-cli login` (stored in `huggingface_hub.constants.HF_TOKEN_PATH`).
revision (`str`, defaults to `"main"`):
The specific model version to use (can be a branch name, tag name or commit id).
force_download (`bool`, defaults to `True`):
Expand Down Expand Up @@ -837,7 +837,7 @@ def _export(
framework="pt",
library_name=cls.library_name,
cache_dir=cache_dir,
use_auth_token=use_auth_token,
token=token,
local_files_only=local_files_only,
force_download=force_download,
trust_remote_code=trust_remote_code,
Expand All @@ -863,7 +863,7 @@ def _export(
revision=revision,
force_download=force_download,
local_files_only=local_files_only,
use_auth_token=use_auth_token,
token=token,
submodels=submodels,
output_hidden_states=output_hidden_states,
lora_model_ids=lora_model_ids,
Expand Down Expand Up @@ -938,7 +938,7 @@ def _export(
revision=revision,
force_download=force_download,
local_files_only=local_files_only,
use_auth_token=use_auth_token,
token=token,
do_validation=False,
submodels={"unet": unet_id},
output_hidden_states=output_hidden_states,
Expand Down Expand Up @@ -1189,9 +1189,14 @@ def forward(
encoder_hidden_states: torch.Tensor,
controlnet_cond: torch.Tensor,
conditioning_scale: float = 1.0,
guess_mode: bool = False,
return_dict: bool = True,
) -> Union["ControlNetOutput", Tuple[Tuple[torch.Tensor, ...], torch.Tensor]]:
for i, (image, scale, controlnet) in enumerate(zip(controlnet_cond, conditioning_scale, self.model)):
if guess_mode:
logger.info(
"Guess mode is not yet supported. File us an issue on: https://github.com/huggingface/optimum-neuron/issues."
)
for i, (image, scale, controlnet) in enumerate(zip(controlnet_cond, conditioning_scale, self.nets)):
inputs = (sample, timestep, encoder_hidden_states, image, scale)
down_samples, mid_sample = controlnet(*inputs)

Expand Down
Loading
Loading