From e73f5efcb47e7f72dbc9c0d7c4331cf5fa8ae9a7 Mon Sep 17 00:00:00 2001 From: fxmarty <9808326+fxmarty@users.noreply.github.com> Date: Wed, 27 Mar 2024 18:21:14 +0100 Subject: [PATCH] rename audio_encoder_decode.onnx to encodec_decode.onnx --- optimum/exporters/onnx/model_configs.py | 14 +++++++------- optimum/exporters/onnx/model_patcher.py | 2 +- optimum/exporters/utils.py | 6 +++--- 3 files changed, 11 insertions(+), 11 deletions(-) diff --git a/optimum/exporters/onnx/model_configs.py b/optimum/exporters/onnx/model_configs.py index 18cf47cf09..37d78b4cba 100644 --- a/optimum/exporters/onnx/model_configs.py +++ b/optimum/exporters/onnx/model_configs.py @@ -1391,7 +1391,7 @@ class MusicgenOnnxConfig(OnnxSeq2SeqConfigWithPast): DEFAULT_ONNX_OPSET = 13 # Needed to avoid a bug in T5 encoder SelfAttention. VARIANTS = { - "text-conditional-with-past": "Exports Musicgen to ONNX to generate audio samples conditioned on a text prompt (Reference: https://huggingface.co/docs/transformers/model_doc/musicgen#text-conditional-generation). This uses the decoder KV cache. The following subcomponents are exported:\n\t\t* text_encoder.onnx: corresponds to the text encoder part in https://github.com/huggingface/transformers/blob/v4.39.1/src/transformers/models/musicgen/modeling_musicgen.py#L1457.\n\t\t* audio_encoder_decode.onnx: corresponds to the Encodec audio encoder part in https://github.com/huggingface/transformers/blob/v4.39.1/src/transformers/models/musicgen/modeling_musicgen.py#L2472-L2480.\n\t\t* decoder_model.onnx: The Musicgen decoder, without past key values input, and computing cross attention.\n\t\t* decoder_with_past_model.onnx: The Musicgen decoder, with past_key_values input (KV cache filled), not computing cross attention.\n\t\t* decoder_model_merged.onnx: The two previous models fused in one, to avoid duplicating weights. A boolean input `use_cache_branch` allows to select the branch to use. In the first forward pass where the KV cache is empty, dummy past key values inputs need to be passed and are ignored with use_cache_branch=False.", + "text-conditional-with-past": "Exports Musicgen to ONNX to generate audio samples conditioned on a text prompt (Reference: https://huggingface.co/docs/transformers/model_doc/musicgen#text-conditional-generation). This uses the decoder KV cache. The following subcomponents are exported:\n\t\t* text_encoder.onnx: corresponds to the text encoder part in https://github.com/huggingface/transformers/blob/v4.39.1/src/transformers/models/musicgen/modeling_musicgen.py#L1457.\n\t\t* encodec_decode.onnx: corresponds to the Encodec audio encoder part in https://github.com/huggingface/transformers/blob/v4.39.1/src/transformers/models/musicgen/modeling_musicgen.py#L2472-L2480.\n\t\t* decoder_model.onnx: The Musicgen decoder, without past key values input, and computing cross attention.\n\t\t* decoder_with_past_model.onnx: The Musicgen decoder, with past_key_values input (KV cache filled), not computing cross attention.\n\t\t* decoder_model_merged.onnx: The two previous models fused in one, to avoid duplicating weights. A boolean input `use_cache_branch` allows to select the branch to use. In the first forward pass where the KV cache is empty, dummy past key values inputs need to be passed and are ignored with use_cache_branch=False.", } # TODO: support audio-prompted generation (- audio_encoder_encode.onnx: corresponds to the audio encoder part in https://github.com/huggingface/transformers/blob/f01e1609bf4dba146d1347c1368c8c49df8636f6/src/transformers/models/musicgen/modeling_musicgen.py#L2087.\n\t) # With that, we have full Encodec support. @@ -1417,7 +1417,7 @@ def __init__( use_past_in_inputs: bool = False, behavior: ConfigBehavior = ConfigBehavior.ENCODER, preprocessors: Optional[List[Any]] = None, - model_part: Optional[Literal["text_encoder", "audio_encoder_decode", "decoder"]] = None, + model_part: Optional[Literal["text_encoder", "encodec_decode", "decoder"]] = None, legacy: bool = False, variant: str = "text-conditional-with-past", ): @@ -1435,7 +1435,7 @@ def __init__( if legacy: raise ValueError("Musicgen does not support legacy=True.") - if model_part in ["text_encoder", "audio_encoder_decode"] and behavior != ConfigBehavior.ENCODER: + if model_part in ["text_encoder", "encodec_decode"] and behavior != ConfigBehavior.ENCODER: raise ValueError( f"model_part is {model_part} and behavior is {behavior}. This is not supported, please open an issue at https://github.com/huggingface/optimum/issues." ) @@ -1478,7 +1478,7 @@ def inputs(self) -> Dict[str, Dict[int, str]]: "input_ids": {0: "batch_size", 1: "encoder_sequence_length"}, "attention_mask": {0: "batch_size", 1: "encoder_sequence_length"}, } - elif self.model_part == "audio_encoder_decode": + elif self.model_part == "encodec_decode": # 0: always 1 for chunk_length_s=None, 2: num_quantizers fixed. common_inputs = {"audio_codes": {1: "batch_size", 3: "chunk_length"}} elif self._behavior is ConfigBehavior.DECODER: @@ -1510,7 +1510,7 @@ def outputs(self) -> Dict[str, Dict[int, str]]: if self.model_part == "text_encoder": common_outputs = super().outputs - elif self.model_part == "audio_encoder_decode": + elif self.model_part == "encodec_decode": common_outputs["audio_values"] = {0: "batch_size", 2: "audio_length"} elif self._behavior is ConfigBehavior.DECODER: common_outputs = super().outputs @@ -1601,11 +1601,11 @@ def post_process_exported_models( # In order to do the validation of the two branches on the same file text_encoder_path = onnx_files_subpaths[0] - audio_encoder_decode_path = onnx_files_subpaths[1] + encodec_decode_path = onnx_files_subpaths[1] onnx_files_subpaths_new = [ text_encoder_path, - audio_encoder_decode_path, + encodec_decode_path, decoder_merged_path.name, decoder_merged_path.name, ] diff --git a/optimum/exporters/onnx/model_patcher.py b/optimum/exporters/onnx/model_patcher.py index 23d4b3a183..d9e800d59d 100644 --- a/optimum/exporters/onnx/model_patcher.py +++ b/optimum/exporters/onnx/model_patcher.py @@ -806,7 +806,7 @@ def __init__( ): super().__init__(config, model, model_kwargs) - if config.model_part == "audio_encoder_decode": + if config.model_part == "encodec_decode": # EncodecModel.forward -> EncodecModel.decode @functools.wraps(self.orig_forward) def patched_forward( diff --git a/optimum/exporters/utils.py b/optimum/exporters/utils.py index b056e1a9a0..fd4aa33f72 100644 --- a/optimum/exporters/utils.py +++ b/optimum/exporters/utils.py @@ -348,7 +348,7 @@ def get_stable_diffusion_models_for_export( def get_musicgen_models_for_export(model: Union["PreTrainedModel", "TFPreTrainedModel"], config: "ExportConfig"): models_for_export = { "text_encoder": model.text_encoder, - "audio_encoder_decode": model.audio_encoder, + "encodec_decode": model.audio_encoder, # For the decoder, we do not pass model.decoder because we may need to export model.enc_to_dec_proj DECODER_NAME: model, DECODER_WITH_PAST_NAME: model, @@ -360,9 +360,9 @@ def get_musicgen_models_for_export(model: Union["PreTrainedModel", "TFPreTrained models_for_export["text_encoder"] = (models_for_export["text_encoder"], text_encoder_config) audio_encoder_config = config.__class__( - model.config, task=config.task, legacy=False, model_part="audio_encoder_decode", variant=config.variant + model.config, task=config.task, legacy=False, model_part="encodec_decode", variant=config.variant ) - models_for_export["audio_encoder_decode"] = (models_for_export["audio_encoder_decode"], audio_encoder_config) + models_for_export["encodec_decode"] = (models_for_export["encodec_decode"], audio_encoder_config) use_past = "with-past" in config.variant decoder_export_config = config.with_behavior("decoder", use_past=use_past, use_past_in_inputs=False)