Commit fc48e94
Enable having only the second tokenizer and text encoder
echarlaix committed Jul 12, 2023
1 parent 00d26ba commit fc48e94
Showing 7 changed files with 67 additions and 64 deletions.
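The diff below makes the first tokenizer and text encoder optional throughout export and inference, which is what SDXL-refiner-style checkpoints need: they ship only `tokenizer_2` and `text_encoder_2`. A minimal sketch of the export this unblocks (a sketch, not part of the commit; the task name is an assumption):

from optimum.exporters.onnx import main_export

# The SDXL refiner has no first tokenizer/text encoder, only the second pair.
main_export(
    "stabilityai/stable-diffusion-xl-refiner-1.0",  # illustrative model id
    output="sdxl_refiner_onnx",
    task="stable-diffusion-xl",  # assumed task name for SDXL pipelines
)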
3 changes: 2 additions & 1 deletion optimum/exporters/onnx/__main__.py
@@ -421,10 +421,11 @@ def main_export(
         onnx_files_subpaths = [os.path.join(name_dir, ONNX_WEIGHTS_NAME) for name_dir in models_and_onnx_configs]

         # Saving the additional components needed to perform inference.
-        model.tokenizer.save_pretrained(output.joinpath("tokenizer"))
         model.scheduler.save_pretrained(output.joinpath("scheduler"))
         if getattr(model, "feature_extractor", None) is not None:
             model.feature_extractor.save_pretrained(output.joinpath("feature_extractor"))
+        if getattr(model, "tokenizer", None) is not None:
+            model.tokenizer.save_pretrained(output.joinpath("tokenizer"))
         if getattr(model, "tokenizer_2", None) is not None:
             model.tokenizer_2.save_pretrained(output.joinpath("tokenizer_2"))
         model.save_config(output)
18 changes: 10 additions & 8 deletions optimum/exporters/onnx/utils.py
@@ -103,15 +103,16 @@ def _get_submodels_for_export_stable_diffusion(
     from diffusers import StableDiffusionXLPipeline

     models_for_export = {}

     if isinstance(pipeline, StableDiffusionXLPipeline):
-        pipeline.text_encoder.config.output_hidden_states = True
         projection_dim = pipeline.text_encoder_2.config.projection_dim
     else:
         projection_dim = pipeline.text_encoder.config.projection_dim

     # Text encoder
-    models_for_export["text_encoder"] = pipeline.text_encoder
+    if pipeline.text_encoder is not None:
+        if isinstance(pipeline, StableDiffusionXLPipeline):
+            pipeline.text_encoder.config.output_hidden_states = True
+        models_for_export["text_encoder"] = pipeline.text_encoder

     # U-NET
     # PyTorch does not support the ONNX export of torch.nn.functional.scaled_dot_product_attention
@@ -262,11 +263,12 @@ def get_stable_diffusion_models_for_export(
     models_for_export = _get_submodels_for_export_stable_diffusion(pipeline)

     # Text encoder
-    text_encoder_config_constructor = TasksManager.get_exporter_config_constructor(
-        model=pipeline.text_encoder, exporter="onnx", task="feature-extraction"
-    )
-    text_encoder_onnx_config = text_encoder_config_constructor(pipeline.text_encoder.config)
-    models_for_export["text_encoder"] = (models_for_export["text_encoder"], text_encoder_onnx_config)
+    if "text_encoder" in models_for_export:
+        text_encoder_config_constructor = TasksManager.get_exporter_config_constructor(
+            model=pipeline.text_encoder, exporter="onnx", task="feature-extraction"
+        )
+        text_encoder_onnx_config = text_encoder_config_constructor(pipeline.text_encoder.config)
+        models_for_export["text_encoder"] = (models_for_export["text_encoder"], text_encoder_onnx_config)

     # U-NET
     onnx_config_constructor = TasksManager.get_exporter_config_constructor(
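Under these guards, a pipeline whose `text_encoder` is `None` simply produces no "text_encoder" entry, and the ONNX config construction above is skipped for it. A hedged sketch of the resulting submodel dictionary for a refiner-like pipeline (keys only; values elided):

# Hypothetical shape of `models_for_export` when pipeline.text_encoder is None:
models_for_export = {
    # "text_encoder" intentionally absent
    "unet": ...,
    "vae_encoder": ...,
    "vae_decoder": ...,
    "text_encoder_2": ...,
}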
8 changes: 4 additions & 4 deletions optimum/onnxruntime/__init__.py
@@ -72,15 +72,15 @@
"ORTStableDiffusionImg2ImgPipeline",
"ORTStableDiffusionInpaintPipeline",
"ORTStableDiffusionXLPipeline",
"StableDiffusionXLImg2ImgPipelineMixin",
"ORTStableDiffusionXLImg2ImgPipeline",
]
else:
_import_structure["modeling_diffusion"] = [
"ORTStableDiffusionPipeline",
"ORTStableDiffusionImg2ImgPipeline",
"ORTStableDiffusionInpaintPipeline",
"ORTStableDiffusionXLPipeline",
"StableDiffusionXLImg2ImgPipelineMixin",
"ORTStableDiffusionXLImg2ImgPipeline",
]


@@ -128,16 +128,16 @@
             ORTStableDiffusionImg2ImgPipeline,
             ORTStableDiffusionInpaintPipeline,
             ORTStableDiffusionPipeline,
+            ORTStableDiffusionXLImg2ImgPipeline,
             ORTStableDiffusionXLPipeline,
-            StableDiffusionXLImg2ImgPipelineMixin,
         )
     else:
         from .modeling_diffusion import (
             ORTStableDiffusionImg2ImgPipeline,
             ORTStableDiffusionInpaintPipeline,
             ORTStableDiffusionPipeline,
+            ORTStableDiffusionXLImg2ImgPipeline,
             ORTStableDiffusionXLPipeline,
-            StableDiffusionXLImg2ImgPipelineMixin,
         )
 else:
     import sys
72 changes: 37 additions & 35 deletions optimum/onnxruntime/modeling_diffusion.py
@@ -120,11 +120,16 @@ def __init__(
         self._internal_dict = config
         self.vae_decoder = ORTModelVaeDecoder(vae_decoder_session, self)
         self.vae_decoder_model_path = Path(vae_decoder_session._model_path)
-        self.text_encoder = ORTModelTextEncoder(text_encoder_session, self)
-        self.text_encoder_model_path = Path(text_encoder_session._model_path)
         self.unet = ORTModelUnet(unet_session, self)
         self.unet_model_path = Path(unet_session._model_path)

+        if text_encoder_session is not None:
+            self.text_encoder_model_path = Path(text_encoder_session._model_path)
+            self.text_encoder = ORTModelTextEncoder(text_encoder_session, self)
+        else:
+            self.text_encoder_model_path = None
+            self.text_encoder = None
+
         if vae_encoder_session is not None:
             self.vae_encoder_model_path = Path(vae_encoder_session._model_path)
             self.vae_encoder = ORTModelVaeEncoder(vae_encoder_session, self)
@@ -200,23 +205,22 @@ def load_model(
                 Provider option dictionary corresponding to the provider used. See available options
                 for each provider: https://onnxruntime.ai/docs/api/c/group___global.html . Defaults to `None`.
         """
-        vae_decoder_session = ORTModel.load_model(vae_decoder_path, provider, session_options, provider_options)
-        text_encoder_session = ORTModel.load_model(text_encoder_path, provider, session_options, provider_options)
-        unet_session = ORTModel.load_model(unet_path, provider, session_options, provider_options)
+        vae_decoder = ORTModel.load_model(vae_decoder_path, provider, session_options, provider_options)
+        unet = ORTModel.load_model(unet_path, provider, session_options, provider_options)

-        if vae_encoder_path is not None:
-            vae_encoder_session = ORTModel.load_model(vae_encoder_path, provider, session_options, provider_options)
-        else:
-            vae_encoder_session = None
+        sessions = {
+            "vae_encoder": vae_encoder_path,
+            "text_encoder": text_encoder_path,
+            "text_encoder_2": text_encoder_2_path,
+        }

-        if text_encoder_2_path is not None:
-            text_encoder_2_session = ORTModel.load_model(
-                text_encoder_2_path, provider, session_options, provider_options
-            )
-        else:
-            text_encoder_2_session = None
+        for key, value in sessions.items():
+            if value is not None and value.is_file():
+                sessions[key] = ORTModel.load_model(value, provider, session_options, provider_options)
+            else:
+                sessions[key] = None

-        return vae_decoder_session, text_encoder_session, unet_session, vae_encoder_session, text_encoder_2_session
+        return vae_decoder, sessions["text_encoder"], unet, sessions["vae_encoder"], sessions["text_encoder_2"]

     def _save_pretrained(self, save_directory: Union[str, Path]):
         save_directory = Path(save_directory)
@@ -247,10 +251,12 @@ def _save_pretrained(self, save_directory: Union[str, Path]):
             if config_path.is_file():
                 shutil.copyfile(config_path, dst_path.parent / self.sub_component_config_name)

-        self.tokenizer.save_pretrained(save_directory / "tokenizer")
-
         self.scheduler.save_pretrained(save_directory / "scheduler")

         if self.feature_extractor is not None:
             self.feature_extractor.save_pretrained(save_directory / "feature_extractor")
+        if self.tokenizer is not None:
+            self.tokenizer.save_pretrained(save_directory / "tokenizer")
+        if self.tokenizer_2 is not None:
+            self.tokenizer_2.save_pretrained(save_directory / "tokenizer_2")

@@ -322,20 +328,14 @@ def _from_pretrained(
             else:
                 sub_models[name] = load_method(new_model_save_dir)

-        vae_encoder_path = new_model_save_dir / DIFFUSION_MODEL_VAE_ENCODER_SUBFOLDER / vae_encoder_file_name
-        text_encoder_2_path = new_model_save_dir / DIFFUSION_MODEL_TEXT_ENCODER_2_SUBFOLDER / text_encoder_2_file_name
-
-        if not vae_encoder_path.is_file():
-            logger.warning(
-                f"VAE encoder not found in {model_id} and will not be loaded for inference. This component is needed for some tasks."
-            )
-
-        inference_sessions = cls.load_model(
+        vae_decoder, text_encoder, unet, vae_encoder, text_encoder_2 = cls.load_model(
             vae_decoder_path=new_model_save_dir / DIFFUSION_MODEL_VAE_DECODER_SUBFOLDER / vae_decoder_file_name,
             text_encoder_path=new_model_save_dir / DIFFUSION_MODEL_TEXT_ENCODER_SUBFOLDER / text_encoder_file_name,
             unet_path=new_model_save_dir / DIFFUSION_MODEL_UNET_SUBFOLDER / unet_file_name,
-            vae_encoder_path=vae_encoder_path if vae_encoder_path.is_file() else None,
-            text_encoder_2_path=text_encoder_2_path if text_encoder_2_path.is_file() else None,
+            vae_encoder_path=new_model_save_dir / DIFFUSION_MODEL_VAE_ENCODER_SUBFOLDER / vae_encoder_file_name,
+            text_encoder_2_path=new_model_save_dir
+            / DIFFUSION_MODEL_TEXT_ENCODER_2_SUBFOLDER
+            / text_encoder_2_file_name,
             provider=provider,
             session_options=session_options,
             provider_options=provider_options,
@@ -350,14 +350,16 @@
         )

         return cls(
-            *inference_sessions[:-2],
+            vae_decoder_session=vae_decoder,
+            text_encoder_session=text_encoder,
+            unet_session=unet,
             config=config,
-            tokenizer=sub_models["tokenizer"],
-            scheduler=sub_models["scheduler"],
-            feature_extractor=sub_models.pop("feature_extractor", None),
-            tokenizer_2=sub_models.pop("tokenizer_2", None),
-            vae_encoder_session=inference_sessions[-2],
-            text_encoder_2_session=inference_sessions[-1],
+            tokenizer=sub_models.get("tokenizer", None),
+            scheduler=sub_models.get("scheduler"),
+            feature_extractor=sub_models.get("feature_extractor", None),
+            tokenizer_2=sub_models.get("tokenizer_2", None),
+            vae_encoder_session=vae_encoder,
+            text_encoder_2_session=text_encoder_2,
             use_io_binding=use_io_binding,
             model_save_dir=model_save_dir,
         )
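With the constructor and `load_model` changes above, components missing from the export directory now come back as `None` rather than raising. An illustrative check (the local path is hypothetical, and it assumes the pipeline exposes the second pair as `tokenizer_2`/`text_encoder_2`):

from optimum.onnxruntime import ORTStableDiffusionXLImg2ImgPipeline

# "./sdxl_refiner_onnx" stands in for a local ONNX export of the SDXL refiner.
pipe = ORTStableDiffusionXLImg2ImgPipeline.from_pretrained("./sdxl_refiner_onnx")

assert pipe.text_encoder is None and pipe.tokenizer is None  # first pair absent
assert pipe.text_encoder_2 is not None and pipe.tokenizer_2 is not None  # second pair loaded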
13 changes: 6 additions & 7 deletions optimum/pipelines/diffusers/pipeline_stable_diffusion_xl.py
@@ -144,20 +144,19 @@ def _encode_prompt(
                 padding="max_length",
                 max_length=max_length,
                 truncation=True,
-                return_tensors="pt",
+                return_tensors="np",
             )
             negative_prompt_embeds = text_encoder(
                 input_ids=uncond_input.input_ids.astype(text_encoder.input_dtype.get("input_ids", np.int32))
             )
             negative_pooled_prompt_embeds = negative_prompt_embeds[0]
             negative_prompt_embeds = negative_prompt_embeds[-2]

-            if do_classifier_free_guidance:
-                # duplicate unconditional embeddings for each generation per prompt, using mps friendly method
-                negative_prompt_embeds = np.repeat(negative_prompt_embeds, num_images_per_prompt, axis=0)
-                # For classifier free guidance, we need to do two forward passes.
-                # Here we concatenate the unconditional and text embeddings into a single batch
-                # to avoid doing two forward passes
+            # duplicate unconditional embeddings for each generation per prompt, using mps friendly method
+            negative_prompt_embeds = np.repeat(negative_prompt_embeds, num_images_per_prompt, axis=0)
+            # For classifier free guidance, we need to do two forward passes.
+            # Here we concatenate the unconditional and text embeddings into a single batch
+            # to avoid doing two forward passes
             negative_prompt_embeds_list.append(negative_prompt_embeds)
             negative_prompt_embeds = np.concatenate(negative_prompt_embeds, axis=-1)

13 changes: 6 additions & 7 deletions optimum/pipelines/diffusers/pipeline_stable_diffusion_xl_img2img.py
@@ -145,20 +145,19 @@ def _encode_prompt(
                 padding="max_length",
                 max_length=max_length,
                 truncation=True,
-                return_tensors="pt",
+                return_tensors="np",
             )

             negative_prompt_embeds = text_encoder(
                 input_ids=uncond_input.input_ids.astype(text_encoder.input_dtype.get("input_ids", np.int32))
            )
             negative_pooled_prompt_embeds = negative_prompt_embeds[0]
             negative_prompt_embeds = negative_prompt_embeds[-2]

-            if do_classifier_free_guidance:
-                # duplicate unconditional embeddings for each generation per prompt, using mps friendly method
-                negative_prompt_embeds = np.repeat(negative_prompt_embeds, num_images_per_prompt, axis=0)
-                # For classifier free guidance, we need to do two forward passes.
-                # Here we concatenate the unconditional and text embeddings into a single batch
-                # to avoid doing two forward passes
+            # duplicate unconditional embeddings for each generation per prompt, using mps friendly method
+            negative_prompt_embeds = np.repeat(negative_prompt_embeds, num_images_per_prompt, axis=0)
+            # For classifier free guidance, we need to do two forward passes.
+            # Here we concatenate the unconditional and text embeddings into a single batch
+            # to avoid doing two forward passes
             negative_prompt_embeds_list.append(negative_prompt_embeds)
             negative_prompt_embeds = np.concatenate(negative_prompt_embeds, axis=-1)

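For context on the indexing in both `_encode_prompt` hunks: with `output_hidden_states=True` set at export time (see the utils.py change above), the exported text encoder is assumed to return the projected pooled embedding first and the hidden states flattened at the end, so `outputs[0]` is the pooled embedding and `outputs[-2]` the penultimate hidden state that SDXL conditions the UNet on. A hedged helper spelling that out:

def split_text_encoder_outputs(outputs):
    # Illustrative only: assumes the flattened output layout described above.
    pooled_embeds = outputs[0]   # projected (pooled) text embedding
    prompt_embeds = outputs[-2]  # penultimate hidden state fed to the UNet
    return pooled_embeds, prompt_embeds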
2 changes: 1 addition & 1 deletion optimum/utils/dummy_diffusers_objects.py
@@ -59,7 +59,7 @@ def from_pretrained(cls, *args, **kwargs):
         requires_backends(cls, ["diffusers"])


-class StableDiffusionXLImg2ImgPipelineMixin(metaclass=DummyObject):
+class ORTStableDiffusionXLImg2ImgPipeline(metaclass=DummyObject):
     _backends = ["diffusers"]

     def __init__(self, *args, **kwargs):
