Restore optimized attention score for sd15 & fix the generated images quality issue #646

Merged · 3 commits · Jul 1, 2024
optimum/exporters/neuron/utils.py (16 changes: 4 additions & 12 deletions)

```diff
@@ -30,8 +30,7 @@
     DIFFUSION_MODEL_VAE_DECODER_NAME,
     DIFFUSION_MODEL_VAE_ENCODER_NAME,
     ENCODER_NAME,
-    get_attention_scores_sd2,
-    get_attention_scores_sd15,
+    get_attention_scores_sd,
     get_attention_scores_sdxl,
 )
 from ...utils import (
@@ -54,10 +53,7 @@
             "Please update diffusers by running `pip install --upgrade diffusers`"
         )
     from diffusers import ControlNetModel, UNet2DConditionModel
-    from diffusers.models.attention_processor import (
-        Attention,
-        AttnProcessor,
-    )
+    from diffusers.models.attention_processor import Attention
@@ -388,7 +384,6 @@ def get_submodels_for_export_stable_diffusion(
         models_for_export.append((DIFFUSION_MODEL_TEXT_ENCODER_2_NAME, copy.deepcopy(text_encoder_2)))

     # U-NET
-    pipeline.unet.set_attn_processor(AttnProcessor())
     pipeline.unet.config.text_encoder_projection_dim = projection_dim
     # The U-NET time_ids inputs shapes depends on the value of `requires_aesthetics_score`
     # https://github.com/huggingface/diffusers/blob/v0.18.2/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py#L571
@@ -400,12 +395,9 @@ def get_submodels_for_export_stable_diffusion(
         if is_sdxl:
             logger.info("Applying optimized attention score computation for sdxl.")
             Attention.get_attention_scores = get_attention_scores_sdxl
-        elif "v1-5" in pipeline.config._name_or_path:
-            logger.info("Applying optimized attention score computation for stable diffusion 1.5.")
-            Attention.get_attention_scores = get_attention_scores_sd15
         else:
-            logger.info("Applying optimized attention score computation for stable diffusion 2.")
-            Attention.get_attention_scores = get_attention_scores_sd2
+            logger.info("Applying optimized attention score computation for stable diffusion.")
+            Attention.get_attention_scores = get_attention_scores_sd
     else:
         logger.warning(
             "You are not applying optimized attention score computation. If you want better performance, please"
```
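For context, this export path patches diffusers' `Attention.get_attention_scores` at the class level, so every attention block in the traced UNet picks up the Neuron-friendly kernel. Below is a minimal sketch of the dispatch after this change; the wrapper function name is hypothetical, and only the branch bodies come from the diff:

```python
# Hypothetical wrapper illustrating the class-level patch applied above.
# `apply_optimized_attention` is not a name from the source; the branch
# bodies mirror the diff.
from diffusers.models.attention_processor import Attention

from optimum.neuron.utils import get_attention_scores_sd, get_attention_scores_sdxl


def apply_optimized_attention(is_sdxl: bool) -> None:
    if is_sdxl:
        # SDXL keeps its dedicated kernel.
        Attention.get_attention_scores = get_attention_scores_sdxl
    else:
        # SD 1.5 and SD 2 now share one kernel. The old dispatch keyed on
        # `"v1-5" in pipeline.config._name_or_path`, which could miss SD 1.5
        # checkpoints loaded from a path that does not contain "v1-5".
        Attention.get_attention_scores = get_attention_scores_sd
```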
optimum/neuron/utils/__init__.py (6 changes: 2 additions & 4 deletions)

```diff
@@ -52,8 +52,7 @@
     ],
     "model_utils": ["get_tied_parameters_dict", "tie_parameters"],
     "optimization_utils": [
-        "get_attention_scores_sd2",
-        "get_attention_scores_sd15",
+        "get_attention_scores_sd",
         "get_attention_scores_sdxl",
     ],
     "patching": [
@@ -105,8 +104,7 @@
     )
     from .model_utils import get_tied_parameters_dict, tie_parameters
     from .optimization_utils import (
-        get_attention_scores_sd2,
-        get_attention_scores_sd15,
+        get_attention_scores_sd,
         get_attention_scores_sdxl,
     )
     from .patching import (
```
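Both hunks touch the same export because this `__init__.py` uses a transformers-style lazy module: each name is declared once in `_import_structure` (resolved at runtime on first access) and once under `TYPE_CHECKING` (so static analyzers see the symbols), and the two lists must stay in sync. A minimal sketch of that pattern, assuming transformers' `_LazyModule` helper; optimum-neuron's actual plumbing may differ:

```python
# Sketch of the lazy-import pattern behind this __init__.py (assumption:
# transformers' _LazyModule; the real module may wire this up differently).
import sys
from typing import TYPE_CHECKING

from transformers.utils import _LazyModule

_import_structure = {
    "optimization_utils": [
        "get_attention_scores_sd",  # runtime export, imported on first access
        "get_attention_scores_sdxl",
    ],
}

if TYPE_CHECKING:
    # Mirrors _import_structure so IDEs and type checkers resolve the names.
    from .optimization_utils import get_attention_scores_sd, get_attention_scores_sdxl
else:
    sys.modules[__name__] = _LazyModule(__name__, globals()["__file__"], _import_structure)
```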
optimum/neuron/utils/optimization_utils.py (39 changes: 2 additions & 37 deletions)

```diff
@@ -17,43 +17,8 @@
 import torch


-def get_attention_scores_sd15(self, query, key, attention_mask) -> torch.Tensor:
-    """Optimized attention for Stable Diffusion 1.5 UNET."""
-    dtype = query.dtype
-
-    if self.upcast_attention:
-        query = query.float()
-        key = key.float()
-
-    baddbmm_input = torch.empty(query.shape[0], query.shape[1], key.shape[1], dtype=query.dtype, device=query.device)
-    beta = 0
-
-    attention_scores = torch.baddbmm(
-        baddbmm_input,
-        query,
-        key.transpose(-1, -2),
-        beta=beta,
-        alpha=self.scale,
-    )
-    del baddbmm_input
-
-    # TODO: following line is supposed to give the same result and reduce unnecessary overhead(no attention mask)
-    # however the compiled model output is far off from the one on cpu, need to further investigate.
-    # attention_scores = self.scale * torch.bmm(query, key.transpose(-1, -2)) # -> bad perf, max diff: 5.696073055267334 (atol: 0.001)
-
-    if self.upcast_softmax:
-        attention_scores = attention_scores.float()
-
-    attention_probs = torch.nn.functional.softmax(attention_scores, dim=-1)
-    del attention_scores
-
-    attention_probs = attention_probs.to(dtype)
-
-    return attention_probs
-
-
-def get_attention_scores_sd2(self, query, key, attn_mask):
-    """Optimized attention for Stable Diffusion 2 UNET."""
+def get_attention_scores_sd(self, query, key, attn_mask):
+    """Optimized attention for Stable Diffusion UNET."""
     dtype = query.dtype

     if self.upcast_attention:
```
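The diff truncates the body of the new `get_attention_scores_sd` after the upcast check. Since the removed SD 1.5 variant is shown in full and the retained SD 2 body begins identically, the unified kernel presumably follows the same baddbmm pattern; here is a sketch under that assumption (mask handling is omitted, as it was in the SD 1.5 variant):

```python
# Assumed body of the unified Stable Diffusion attention kernel, reconstructed
# from the removed get_attention_scores_sd15 above; not a verbatim copy of the
# merged code.
import torch


def get_attention_scores_sd(self, query, key, attn_mask) -> torch.Tensor:
    """Optimized attention for the Stable Diffusion UNet (sketch)."""
    dtype = query.dtype

    if self.upcast_attention:
        query = query.float()
        key = key.float()

    # baddbmm with beta=0 ignores the (uninitialized) bias tensor, yet it is
    # kept over `self.scale * torch.bmm(query, key.transpose(-1, -2))`: per the
    # TODO removed above, the bmm form compiled to outputs far off from CPU
    # (max diff ~5.70 at atol 1e-3) on Neuron.
    baddbmm_input = torch.empty(
        query.shape[0], query.shape[1], key.shape[1], dtype=query.dtype, device=query.device
    )
    attention_scores = torch.baddbmm(
        baddbmm_input,
        query,
        key.transpose(-1, -2),
        beta=0,
        alpha=self.scale,
    )
    del baddbmm_input

    if self.upcast_softmax:
        attention_scores = attention_scores.float()

    attention_probs = torch.nn.functional.softmax(attention_scores, dim=-1)
    del attention_scores

    return attention_probs.to(dtype)
```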