Skip to content

Commit

Permalink
rename audio_encoder_decode.onnx to encodec_decode.onnx
Browse files · Browse the repository at this point in the history
  • Loading branch information
fxmarty committed Mar 27, 2024
1 parent b2e2f07 commit e73f5ef
Show file tree
Hide file tree
Showing 3 changed files with 11 additions and 11 deletions.
14 changes: 7 additions & 7 deletions optimum/exporters/onnx/model_configs.py
Original file line number Diff line number Diff line change
Expand Up @@ -1391,7 +1391,7 @@ class MusicgenOnnxConfig(OnnxSeq2SeqConfigWithPast):
DEFAULT_ONNX_OPSET = 13 # Needed to avoid a bug in T5 encoder SelfAttention.

VARIANTS = {
"text-conditional-with-past": "Exports Musicgen to ONNX to generate audio samples conditioned on a text prompt (Reference: https://huggingface.co/docs/transformers/model_doc/musicgen#text-conditional-generation). This uses the decoder KV cache. The following subcomponents are exported:\n\t\t* text_encoder.onnx: corresponds to the text encoder part in https://github.com/huggingface/transformers/blob/v4.39.1/src/transformers/models/musicgen/modeling_musicgen.py#L1457.\n\t\t* audio_encoder_decode.onnx: corresponds to the Encodec audio encoder part in https://github.com/huggingface/transformers/blob/v4.39.1/src/transformers/models/musicgen/modeling_musicgen.py#L2472-L2480.\n\t\t* decoder_model.onnx: The Musicgen decoder, without past key values input, and computing cross attention.\n\t\t* decoder_with_past_model.onnx: The Musicgen decoder, with past_key_values input (KV cache filled), not computing cross attention.\n\t\t* decoder_model_merged.onnx: The two previous models fused in one, to avoid duplicating weights. A boolean input `use_cache_branch` allows to select the branch to use. In the first forward pass where the KV cache is empty, dummy past key values inputs need to be passed and are ignored with use_cache_branch=False.",
"text-conditional-with-past": "Exports Musicgen to ONNX to generate audio samples conditioned on a text prompt (Reference: https://huggingface.co/docs/transformers/model_doc/musicgen#text-conditional-generation). This uses the decoder KV cache. The following subcomponents are exported:\n\t\t* text_encoder.onnx: corresponds to the text encoder part in https://github.com/huggingface/transformers/blob/v4.39.1/src/transformers/models/musicgen/modeling_musicgen.py#L1457.\n\t\t* encodec_decode.onnx: corresponds to the Encodec audio encoder part in https://github.com/huggingface/transformers/blob/v4.39.1/src/transformers/models/musicgen/modeling_musicgen.py#L2472-L2480.\n\t\t* decoder_model.onnx: The Musicgen decoder, without past key values input, and computing cross attention.\n\t\t* decoder_with_past_model.onnx: The Musicgen decoder, with past_key_values input (KV cache filled), not computing cross attention.\n\t\t* decoder_model_merged.onnx: The two previous models fused in one, to avoid duplicating weights. A boolean input `use_cache_branch` allows to select the branch to use. In the first forward pass where the KV cache is empty, dummy past key values inputs need to be passed and are ignored with use_cache_branch=False.",
}
# TODO: support audio-prompted generation (- audio_encoder_encode.onnx: corresponds to the audio encoder part in https://github.com/huggingface/transformers/blob/f01e1609bf4dba146d1347c1368c8c49df8636f6/src/transformers/models/musicgen/modeling_musicgen.py#L2087.\n\t)
# With that, we have full Encodec support.
Expand All @@ -1417,7 +1417,7 @@ def __init__(
use_past_in_inputs: bool = False,
behavior: ConfigBehavior = ConfigBehavior.ENCODER,
preprocessors: Optional[List[Any]] = None,
model_part: Optional[Literal["text_encoder", "audio_encoder_decode", "decoder"]] = None,
model_part: Optional[Literal["text_encoder", "encodec_decode", "decoder"]] = None,
legacy: bool = False,
variant: str = "text-conditional-with-past",
):
Expand All @@ -1435,7 +1435,7 @@ def __init__(
if legacy:
raise ValueError("Musicgen does not support legacy=True.")

if model_part in ["text_encoder", "audio_encoder_decode"] and behavior != ConfigBehavior.ENCODER:
if model_part in ["text_encoder", "encodec_decode"] and behavior != ConfigBehavior.ENCODER:
raise ValueError(
f"model_part is {model_part} and behavior is {behavior}. This is not supported, please open an issue at https://github.com/huggingface/optimum/issues."
)
Expand Down Expand Up @@ -1478,7 +1478,7 @@ def inputs(self) -> Dict[str, Dict[int, str]]:
"input_ids": {0: "batch_size", 1: "encoder_sequence_length"},
"attention_mask": {0: "batch_size", 1: "encoder_sequence_length"},
}
elif self.model_part == "audio_encoder_decode":
elif self.model_part == "encodec_decode":
# 0: always 1 for chunk_length_s=None, 2: num_quantizers fixed.
common_inputs = {"audio_codes": {1: "batch_size", 3: "chunk_length"}}
elif self._behavior is ConfigBehavior.DECODER:
Expand Down Expand Up @@ -1510,7 +1510,7 @@ def outputs(self) -> Dict[str, Dict[int, str]]:

if self.model_part == "text_encoder":
common_outputs = super().outputs
elif self.model_part == "audio_encoder_decode":
elif self.model_part == "encodec_decode":
common_outputs["audio_values"] = {0: "batch_size", 2: "audio_length"}
elif self._behavior is ConfigBehavior.DECODER:
common_outputs = super().outputs
Expand Down Expand Up @@ -1601,11 +1601,11 @@ def post_process_exported_models(

# In order to do the validation of the two branches on the same file
text_encoder_path = onnx_files_subpaths[0]
audio_encoder_decode_path = onnx_files_subpaths[1]
encodec_decode_path = onnx_files_subpaths[1]

onnx_files_subpaths_new = [
text_encoder_path,
audio_encoder_decode_path,
encodec_decode_path,
decoder_merged_path.name,
decoder_merged_path.name,
]
Expand Down
2 changes: 1 addition & 1 deletion optimum/exporters/onnx/model_patcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -806,7 +806,7 @@ def __init__(
):
super().__init__(config, model, model_kwargs)

if config.model_part == "audio_encoder_decode":
if config.model_part == "encodec_decode":
# EncodecModel.forward -> EncodecModel.decode
@functools.wraps(self.orig_forward)
def patched_forward(
Expand Down
6 changes: 3 additions & 3 deletions optimum/exporters/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -348,7 +348,7 @@ def get_stable_diffusion_models_for_export(
def get_musicgen_models_for_export(model: Union["PreTrainedModel", "TFPreTrainedModel"], config: "ExportConfig"):
models_for_export = {
"text_encoder": model.text_encoder,
"audio_encoder_decode": model.audio_encoder,
"encodec_decode": model.audio_encoder,
# For the decoder, we do not pass model.decoder because we may need to export model.enc_to_dec_proj
DECODER_NAME: model,
DECODER_WITH_PAST_NAME: model,
Expand All @@ -360,9 +360,9 @@ def get_musicgen_models_for_export(model: Union["PreTrainedModel", "TFPreTrained
models_for_export["text_encoder"] = (models_for_export["text_encoder"], text_encoder_config)

audio_encoder_config = config.__class__(
model.config, task=config.task, legacy=False, model_part="audio_encoder_decode", variant=config.variant
model.config, task=config.task, legacy=False, model_part="encodec_decode", variant=config.variant
)
models_for_export["audio_encoder_decode"] = (models_for_export["audio_encoder_decode"], audio_encoder_config)
models_for_export["encodec_decode"] = (models_for_export["encodec_decode"], audio_encoder_config)

use_past = "with-past" in config.variant
decoder_export_config = config.with_behavior("decoder", use_past=use_past, use_past_in_inputs=False)
Expand Down

0 comments on commit e73f5ef

Please sign in to comment.