From 2a6a0c7627b5619c1ee2f4cf64093c1bf492a10b Mon Sep 17 00:00:00 2001 From: fxmarty <9808326+fxmarty@users.noreply.github.com> Date: Wed, 10 Apr 2024 10:54:34 +0200 Subject: [PATCH] Musicgen ONNX export (text-conditional only) (#1779) * WIP but need to work on encodec first * musicgen onnx export * better logs * add tests * rename audio_encoder_decode.onnx to encodec_decode.onnx * fix num heads in pkv * nits * add build_delay_pattern_mask * fix wrong hidden_size for cross attention pkv * fix tests * update doc --- docs/source/exporters/onnx/overview.mdx | 1 + optimum/exporters/onnx/config.py | 7 + optimum/exporters/onnx/constants.py | 1 + optimum/exporters/onnx/convert.py | 5 +- optimum/exporters/onnx/model_configs.py | 307 +++++++++++++++++++++++- optimum/exporters/onnx/model_patcher.py | 137 ++++++++++- optimum/exporters/tasks.py | 7 +- optimum/exporters/utils.py | 46 ++++ optimum/onnx/transformations_utils.py | 2 +- optimum/utils/__init__.py | 3 + optimum/utils/input_generators.py | 132 +++++++++- tests/exporters/exporters_utils.py | 2 + 12 files changed, 637 insertions(+), 13 deletions(-) diff --git a/docs/source/exporters/onnx/overview.mdx b/docs/source/exporters/onnx/overview.mdx index 0dd4c823fd..095c8721e9 100644 --- a/docs/source/exporters/onnx/overview.mdx +++ b/docs/source/exporters/onnx/overview.mdx @@ -71,6 +71,7 @@ Supported architectures from [🤗 Transformers](https://huggingface.co/docs/tra - MobileNet v2 - MPNet - MT5 +- Musicgen (text-conditional only) - Nystromformer - OWL-ViT - Pegasus diff --git a/optimum/exporters/onnx/config.py b/optimum/exporters/onnx/config.py index c505237948..0faf5048f6 100644 --- a/optimum/exporters/onnx/config.py +++ b/optimum/exporters/onnx/config.py @@ -383,6 +383,13 @@ def __init__( ) self._normalized_config.DECODER_NORMALIZED_CONFIG_CLASS = self._decoder_onnx_config._normalized_config + self._normalized_config.DECODER_NORMALIZED_CONFIG_CLASS = self._decoder_onnx_config._normalized_config + self._normalized_config.DECODER_NORMALIZED_CONFIG_CLASS.encoder_num_attention_heads = ( + self._decoder_onnx_config._normalized_config.num_attention_heads + ) + self._normalized_config.DECODER_NORMALIZED_CONFIG_CLASS.decoder_num_attention_heads = ( + self._decoder_onnx_config._normalized_config.num_attention_heads + ) if isinstance(self._decoder_onnx_config, OnnxSeq2SeqConfigWithPast): self._past_key_values_generator = ( diff --git a/optimum/exporters/onnx/constants.py b/optimum/exporters/onnx/constants.py index 02a0654b8d..bac31b73a0 100644 --- a/optimum/exporters/onnx/constants.py +++ b/optimum/exporters/onnx/constants.py @@ -37,5 +37,6 @@ SDPA_ARCHS_ONNX_EXPORT_NOT_SUPPORTED = [ "bart", + "musicgen", "whisper", ] diff --git a/optimum/exporters/onnx/convert.py b/optimum/exporters/onnx/convert.py index 1ad0f89681..053a7a5aeb 100644 --- a/optimum/exporters/onnx/convert.py +++ b/optimum/exporters/onnx/convert.py @@ -258,7 +258,7 @@ def _run_validation( model_kwargs = model_kwargs if model_kwargs is not None else {} - logger.info(f"Validating ONNX model {onnx_model.as_posix()}...") + logger.info(f"\nValidating ONNX model {onnx_model.as_posix()}...") if atol is None: atol = config.ATOL_FOR_VALIDATION @@ -764,6 +764,9 @@ def export_models( output_path = output_dir / output_name output_path.parent.mkdir(parents=True, exist_ok=True) + logger.info( + f"\n***** Exporting submodel {i + 1}/{len(models_and_onnx_configs)}: {submodel.__class__.__name__} *****" + ) outputs.append( export( model=submodel, diff --git a/optimum/exporters/onnx/model_configs.py b/optimum/exporters/onnx/model_configs.py index f7a0208c3d..ea14711659 100644 --- a/optimum/exporters/onnx/model_configs.py +++ b/optimum/exporters/onnx/model_configs.py @@ -14,17 +14,22 @@ # limitations under the License. """Model specific ONNX configurations.""" import random -from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union +from pathlib import Path +from typing import TYPE_CHECKING, Any, Dict, List, Literal, Optional, Tuple, Union from packaging import version from transformers.utils import is_tf_available +from ...onnx import merge_decoders from ...utils import ( DEFAULT_DUMMY_SHAPES, BloomDummyPastKeyValuesGenerator, DummyAudioInputGenerator, + DummyCodegenDecoderTextInputGenerator, DummyDecoderTextInputGenerator, + DummyEncodecInputGenerator, DummyInputGenerator, + DummyIntGenerator, DummyPastKeyValuesGenerator, DummyPix2StructInputGenerator, DummyPointsGenerator, @@ -47,6 +52,7 @@ NormalizedTextAndVisionConfig, NormalizedTextConfig, NormalizedVisionConfig, + is_diffusers_available, logging, ) from ...utils.normalized_config import NormalizedConfigManager @@ -62,8 +68,10 @@ TextSeq2SeqOnnxConfig, VisionOnnxConfig, ) +from .constants import ONNX_DECODER_MERGED_NAME, ONNX_DECODER_NAME, ONNX_DECODER_WITH_PAST_NAME from .model_patcher import ( FalconModelPatcher, + MusicgenModelPatcher, SAMModelPatcher, SentenceTransformersCLIPPatcher, SentenceTransformersTransformerPatcher, @@ -82,6 +90,9 @@ if is_tf_available(): from transformers.modeling_tf_utils import TFPreTrainedModel + if is_diffusers_available(): + from diffusers import ModelMixin + logger = logging.get_logger(__name__) @@ -1400,10 +1411,302 @@ def outputs(self) -> Dict[str, Dict[int, str]]: return common_outputs +class MusicgenOnnxConfig(OnnxSeq2SeqConfigWithPast): + # NOTE: Several warnings during the export are not to worry about: + # * for i, indices in enumerate(codes): --> can be unrolled, fixed length (num_quantizers). + # * max_pad = max(padding_left, padding_right) --> does not impact later controlflows. + # if length <= max_pad: --> appears to be always False for Musicgen. + + # opset>=13 needed to avoid a bug in T5 encoder SelfAttention. + # opset>=14 needed for torch.tril export. + DEFAULT_ONNX_OPSET = 14 + + VARIANTS = { + "text-conditional-with-past": "Exports Musicgen to ONNX to generate audio samples conditioned on a text prompt (Reference: https://huggingface.co/docs/transformers/model_doc/musicgen#text-conditional-generation). This uses the decoder KV cache. The following subcomponents are exported:\n\t\t* text_encoder.onnx: corresponds to the text encoder part in https://github.com/huggingface/transformers/blob/v4.39.1/src/transformers/models/musicgen/modeling_musicgen.py#L1457.\n\t\t* encodec_decode.onnx: corresponds to the Encodec audio encoder part in https://github.com/huggingface/transformers/blob/v4.39.1/src/transformers/models/musicgen/modeling_musicgen.py#L2472-L2480.\n\t\t* decoder_model.onnx: The Musicgen decoder, without past key values input, and computing cross attention. Not required at inference (use decoder_model_merged.onnx instead).\n\t\t* decoder_with_past_model.onnx: The Musicgen decoder, with past_key_values input (KV cache filled), not computing cross attention. Not required at inference (use decoder_model_merged.onnx instead).\n\t\t* decoder_model_merged.onnx: The two previous models fused in one, to avoid duplicating weights. A boolean input `use_cache_branch` allows to select the branch to use. In the first forward pass where the KV cache is empty, dummy past key values inputs need to be passed and are ignored with use_cache_branch=False.\n\t\t* build_delay_pattern_mask.onnx: A model taking as input `input_ids`, `pad_token_id`, `max_length`, and building a delayed pattern mask to the input_ids. Implements https://github.com/huggingface/transformers/blob/v4.39.3/src/transformers/models/musicgen/modeling_musicgen.py#L1054.", + } + # TODO: support audio-prompted generation (- audio_encoder_encode.onnx: corresponds to the audio encoder part in https://github.com/huggingface/transformers/blob/f01e1609bf4dba146d1347c1368c8c49df8636f6/src/transformers/models/musicgen/modeling_musicgen.py#L2087.\n\t) + # With that, we have full Encodec support. + DEFAULT_VARIANT = "text-conditional-with-past" + + NORMALIZED_CONFIG_CLASS = NormalizedEncoderDecoderConfig + + DUMMY_INPUT_GENERATOR_CLASSES = ( + DummyTextInputGenerator, + DummyCodegenDecoderTextInputGenerator, + DummySeq2SeqPastKeyValuesGenerator, + DummyEncodecInputGenerator, + DummyIntGenerator, + ) + DUMMY_PKV_GENERATOR_CLASS = DummySeq2SeqPastKeyValuesGenerator + + def __init__( + self, + config: "PretrainedConfig", + task: str = "feature-extraction", + int_dtype: str = "int64", + float_dtype: str = "fp32", + use_past: bool = False, + use_past_in_inputs: bool = False, + behavior: ConfigBehavior = ConfigBehavior.ENCODER, + preprocessors: Optional[List[Any]] = None, + model_part: Optional[Literal["text_encoder", "encodec_decode", "decoder", "build_delay_pattern_mask"]] = None, + legacy: bool = False, + variant: str = "text-conditional-with-past", + ): + super().__init__( + config=config, + task=task, + int_dtype=int_dtype, + float_dtype=float_dtype, + use_past=use_past, + use_past_in_inputs=use_past_in_inputs, + behavior=behavior, + preprocessors=preprocessors, + legacy=legacy, + ) + if legacy: + raise ValueError("Musicgen does not support legacy=True.") + + if ( + model_part in ["text_encoder", "encodec_decode", "build_delay_pattern_mask"] + and behavior != ConfigBehavior.ENCODER + ): + raise ValueError( + f"model_part is {model_part} and behavior is {behavior}. This is not supported, please open an issue at https://github.com/huggingface/optimum/issues." + ) + + if model_part == "decoder" and behavior != ConfigBehavior.DECODER: + raise ValueError( + f"model_part is {model_part} and behavior is {behavior}. This is not supported, please open an issue at https://github.com/huggingface/optimum/issues." + ) + + if behavior == ConfigBehavior.MONOLITH: + raise ValueError( + "Musicgen does not support behavior=ConfigBehavior.MONOLITH. Please open an issue at https://github.com/huggingface/optimum/issues." + ) + + if config.audio_encoder.model_type != "encodec": + raise ValueError( + f"Optimum ONNX export for Musicgen supports only Encodec as the audio encoder, got: {config.audio_encoder.model_type}. Please open an issue at https://github.com/huggingface/optimum/issues." + ) + + # Handling it would require to trace the audio_encoder.decode with torch.jit.script as we than have an unrollable loop. + if config.audio_encoder.chunk_length_s is not None: + raise ValueError( + f"Musicgen ONNX export currently does not support audio_encoder.chunk_length_s not None (got {config.audio_encoder.chunk_length_s}). Please open an issue at https://github.com/huggingface/optimum/issues." + ) + + self.model_part = model_part + if self.model_part == "decoder": + self.use_past = True # without past is not supported, hard-code it here. + + self._normalized_config.ENCODER_NORMALIZED_CONFIG_CLASS = NormalizedTextConfig(self._config.text_encoder) + self._normalized_config.DECODER_NORMALIZED_CONFIG_CLASS = NormalizedConfig(self._config.decoder) + self._normalized_config.decoder_num_layers = self._config.decoder.num_hidden_layers + self._normalized_config.DECODER_NORMALIZED_CONFIG_CLASS.num_layers = self._config.decoder.num_hidden_layers + self._normalized_config.DECODER_NORMALIZED_CONFIG_CLASS.encoder_num_attention_heads = ( + self._config.decoder.num_attention_heads + ) + self._normalized_config.DECODER_NORMALIZED_CONFIG_CLASS.decoder_num_attention_heads = ( + self._config.decoder.num_attention_heads + ) + + @property + def inputs(self) -> Dict[str, Dict[int, str]]: + # Batched inference is not supported in Transformers. + if self.model_part == "text_encoder": + common_inputs = { + "input_ids": {0: "batch_size", 1: "encoder_sequence_length"}, + "attention_mask": {0: "batch_size", 1: "encoder_sequence_length"}, + } + elif self.model_part == "encodec_decode": + # 0: always 1 for chunk_length_s=None, 2: num_quantizers fixed. + common_inputs = {"audio_codes": {1: "batch_size", 3: "chunk_length"}} + elif self.model_part == "build_delay_pattern_mask": + common_inputs = { + "input_ids": {0: "batch_size_x_num_codebooks"}, + "pad_token_id": {}, + "max_length": {}, + } + elif self._behavior is ConfigBehavior.DECODER: + # Naming it total_batch_size as in case we use guidance_scale, the dimension 0 may be larger than simply the batch_size. + # Reference: https://github.com/huggingface/transformers/blob/31c575bcf13c2b85b65d652dd1b5b401f99be999/src/transformers/models/musicgen/modeling_musicgen.py#L1932-L1935 + common_inputs = { + "decoder_input_ids": {0: "total_batch_size_x_num_codebooks"}, + "encoder_outputs": {0: "total_batch_size", 1: "encoder_sequence_length"}, + # MusicgenForConditionalGeneration maps attention_mask to encoder_attention_mask. + "attention_mask": { + 0: "batch_size", + 1: "encoder_sequence_length", + }, + } + if self.use_past_in_inputs: + # TODO: validate the axis name for attention_mask + # common_inputs["attention_mask"][1] = "past_encoder_sequence_length + sequence_length" + self.add_past_key_values(common_inputs, direction="inputs") + else: + common_inputs["decoder_input_ids"] = { + 0: "total_batch_size_x_num_codebooks", + 1: "decoder_sequence_length", + } + else: + raise ValueError( + "This should not happen. Please open an issue at https://github.com/huggingface/optimum/issues." + ) + + return common_inputs + + @property + def outputs(self) -> Dict[str, Dict[int, str]]: + common_outputs = {} + + if self.model_part == "text_encoder": + common_outputs = super().outputs + elif self.model_part == "encodec_decode": + common_outputs["audio_values"] = {0: "batch_size", 2: "audio_length"} + elif self.model_part == "build_delay_pattern_mask": + common_outputs["input_ids_edited"] = {0: "total_batch_size_x_num_codebooks"} + common_outputs["delay_pattern_mask"] = {0: "total_batch_size_x_num_codebooks", 1: "max_length"} + elif self._behavior is ConfigBehavior.DECODER: + common_outputs = super().outputs + + # MusicgenForConditionalGeneration output is named logits, not last_hidden_state. + # Rename last_hidden_state -> logits while keeping the order. + common_outputs = { + "logits" if name == "last_hidden_state" else name: value for name, value in common_outputs.items() + } + else: + raise ValueError( + "This should not happen. Please open an issue at https://github.com/huggingface/optimum/issues." + ) + + return common_outputs + + def add_past_key_values(self, inputs_or_outputs: Dict[str, Dict[int, str]], direction: str): + if direction not in ["inputs", "outputs"]: + raise ValueError(f'direction must either be "inputs" or "outputs", but {direction} was given') + + if direction == "inputs": + decoder_sequence_name = "past_decoder_sequence_length" + name = "past_key_values" + else: + decoder_sequence_name = "past_decoder_sequence_length + 1" + name = "present" + + for i in range(self._normalized_config.decoder_num_layers): + inputs_or_outputs[f"{name}.{i}.decoder.key"] = {0: "total_batch_size", 2: decoder_sequence_name} + inputs_or_outputs[f"{name}.{i}.decoder.value"] = {0: "total_batch_size", 2: decoder_sequence_name} + + if ( + self.is_merged is True + or (self._behavior is ConfigBehavior.DECODER and not self.use_past_in_inputs) + or direction == "inputs" + ): + # TODO: we only need to call it encoder_sequence_length_out in the merge case - but at torch.onnx.export() + # time we have currently no case to check whether we will merge at a later step or not (self.is_merged is + # not yet set at this time) + inputs_or_outputs[f"{name}.{i}.encoder.key"] = { + 0: "total_batch_size", + 2: "encoder_sequence_length_out", + } + inputs_or_outputs[f"{name}.{i}.encoder.value"] = { + 0: "total_batch_size", + 2: "encoder_sequence_length_out", + } + + def patch_model_for_export( + self, model: Union["PreTrainedModel", "TFPreTrainedModel"], model_kwargs: Optional[Dict[str, Any]] = None + ) -> "ModelPatcher": + return MusicgenModelPatcher(self, model, model_kwargs=model_kwargs) + + @property + def torch_to_onnx_input_map(self) -> Dict[str, str]: + if self._behavior is ConfigBehavior.DECODER: + return { + "decoder_input_ids": "input_ids", + "encoder_outputs": "encoder_hidden_states", + "attention_mask": "encoder_attention_mask", + } + return {} + + def post_process_exported_models( + self, + path: Path, + models_and_onnx_configs: Dict[ + str, Tuple[Union["PreTrainedModel", "TFPreTrainedModel", "ModelMixin"], "OnnxConfig"] + ], + onnx_files_subpaths: List[str], + ): + # Attempt to merge only if the decoder was exported without/with past, and ignore seq2seq models exported with text-generation task + if "with-past" in self.variant: + decoder_path = Path(path, onnx_files_subpaths[2]) + decoder_with_past_path = Path(path, onnx_files_subpaths[3]) + decoder_merged_path = Path(path, ONNX_DECODER_MERGED_NAME + ".onnx") + try: + # The decoder with past does not output the cross attention past key values as they are constant, + # hence the need for strict=False + merge_decoders( + decoder=decoder_path, + decoder_with_past=decoder_with_past_path, + save_path=decoder_merged_path, + strict=False, + ) + except Exception as e: + raise Exception(f"Unable to merge decoders. Detailed error: {e}") + + # In order to do the validation of the two branches on the same file + text_encoder_path = onnx_files_subpaths[0] + encodec_decode_path = onnx_files_subpaths[1] + build_delay_pattern_mask_path = onnx_files_subpaths[4] + + onnx_files_subpaths_new = [ + text_encoder_path, + encodec_decode_path, + decoder_merged_path.name, + decoder_merged_path.name, + build_delay_pattern_mask_path, + ] + + # We validate the two branches of the decoder model then + models_and_onnx_configs[ONNX_DECODER_NAME][1].is_merged = True + models_and_onnx_configs[ONNX_DECODER_NAME][1].use_cache_branch = False + + # Past key values won't be generated by default, but added in the input + models_and_onnx_configs[ONNX_DECODER_NAME][1].use_past_in_inputs = True + + models_and_onnx_configs[ONNX_DECODER_WITH_PAST_NAME][1].use_cache_branch = True + models_and_onnx_configs[ONNX_DECODER_WITH_PAST_NAME][1].is_merged = True + else: + onnx_files_subpaths_new = onnx_files_subpaths + + return models_and_onnx_configs, onnx_files_subpaths_new + + def overwrite_shape_and_generate_input( + self, dummy_input_gen: "DummyInputGenerator", input_name: str, framework: str, input_shapes: Dict + ): + if self.model_part == "build_delay_pattern_mask" and input_name == "input_ids": + original_batch_size = dummy_input_gen.batch_size + dummy_input_gen.batch_size = ( + original_batch_size * dummy_input_gen.normalized_config.DECODER_NORMALIZED_CONFIG_CLASS.num_codebooks + ) + + dummy_input = dummy_input_gen.generate( + input_name, framework=framework, int_dtype=self.int_dtype, float_dtype=self.float_dtype + ) + + dummy_input_gen.batch_size = original_batch_size + + else: + dummy_input = super().overwrite_shape_and_generate_input( + dummy_input_gen, input_name, framework, input_shapes + ) + + return dummy_input + + class SpeechT5OnnxConfig(OnnxSeq2SeqConfigWithPast): # TODO: Transformers batched generation for Speecht5 is BROKEN (https://github.com/huggingface/transformers/pull/25943), # so we won't support for now. - NORMALIZED_CONFIG_CLASS = NormalizedConfig.with_args(decoder_num_layers="decoder_layers") NORMALIZED_CONFIG_CLASS = NormalizedSeq2SeqConfig.with_args( hidden_size="hidden_size", num_attention_heads="encoder_attention_heads", # TODO: bugged in case encoder and decoder have different number of heads diff --git a/optimum/exporters/onnx/model_patcher.py b/optimum/exporters/onnx/model_patcher.py index 523d1ae0ed..0a10534354 100644 --- a/optimum/exporters/onnx/model_patcher.py +++ b/optimum/exporters/onnx/model_patcher.py @@ -254,7 +254,6 @@ def patched_forward(*args, **kwargs): elif self.real_config._behavior == "decoder" and self.real_config.use_past_in_inputs: # The filtering happens here. The decoder with use_past_in_inputs=True corresponds to the autoregressive one. filterd_outputs[name] = tuple([v[:2] for v in value]) - return filterd_outputs self.patched_forward = patched_forward @@ -796,3 +795,139 @@ def patched_forward(input_ids, attention_mask, pixel_values): return {"text_embeds": text_embeds, "image_embeds": image_embeds} self.patched_forward = patched_forward + + +# Triu with possible dynamic `diagonal` argument. Not possible with torch.triu unfortunately. +def triu_onnx(x, diagonal=0): + l, w = x.shape + arange_rows = torch.arange(l, device=x.device) + + arange_cols = torch.arange(w, device=x.device) + mask = arange_cols.expand(l, w) + + arange_rows = arange_rows[:, None] + diagonal + mask = mask >= arange_rows + return x.masked_fill(mask == 0, 0) + + +def patched_build_delay_pattern_mask(self, input_ids: torch.Tensor, pad_token_id: int, max_length: int = None): + # (bsz * num_codebooks, seq_len) -> (bsz, num_codebooks, seq_len) + input_ids = input_ids.reshape(-1, self.num_codebooks, input_ids.shape[-1]) + bsz, num_codebooks, seq_len = input_ids.shape + + max_length = max_length if max_length is not None else self.generation_config.max_length + input_ids_shifted = torch.ones((bsz, num_codebooks, max_length), dtype=torch.long, device=input_ids.device) * -1 + + channel_codebooks = num_codebooks // 2 if self.config.audio_channels == 2 else num_codebooks + # we only apply the mask if we have a large enough seq len - otherwise we return as is + if max_length < 2 * channel_codebooks - 1: + raise NotImplementedError("Not supported in ONNX export. Please open an issue in Optimum repository.") + + # fill the shifted ids with the prompt entries, offset by the codebook idx + for codebook in range(channel_codebooks): + if self.config.audio_channels == 1: + # mono channel - loop over the codebooks one-by-one + input_ids_shifted[:, codebook, codebook : seq_len + codebook] = input_ids[:, codebook] + else: + # left/right channels are interleaved in the generated codebooks, so handle one then the other + input_ids_shifted[:, 2 * codebook, codebook : seq_len + codebook] = input_ids[:, 2 * codebook] + input_ids_shifted[:, 2 * codebook + 1, codebook : seq_len + codebook] = input_ids[:, 2 * codebook + 1] + + # construct a pattern mask that indicates the positions of padding tokens for each codebook + # first fill the upper triangular part (the EOS padding) + # NOTE: We could use torch.bool here, but PyTorch the complains with `The exported ONNX model failed ONNX shape inference.` + # Using int8 leads to `Could not find an implementation for Where` + delay_pattern = triu_onnx( + torch.ones((channel_codebooks, max_length), dtype=torch.int32), diagonal=max_length - channel_codebooks + 1 + ) + + # NOTE: We could use torch.bool here, but PyTorch the complains with `The exported ONNX model failed ONNX shape inference.` + # Using int32 leads to `Could not find an implementation for Trilu`, hence int64 here + + # then fill the lower triangular part (the BOS padding) + delay_pattern = delay_pattern + torch.tril(torch.ones((channel_codebooks, max_length), dtype=torch.int64)) + delay_pattern = delay_pattern.to(torch.bool) + + if self.config.audio_channels == 2: + # for left/right channel we need to duplicate every row of the pattern mask in an interleaved fashion + delay_pattern = delay_pattern.repeat_interleave(2, dim=0) + + mask = ~delay_pattern.to(input_ids.device) + input_ids = mask * input_ids_shifted + ~mask * pad_token_id + + # find the first position to start generating - this is the first place we have the -1 token + # and will always be in the first codebook (since it has no codebook offset) + first_codebook_ids = input_ids[:, 0, :] + start_ids = (first_codebook_ids == -1).nonzero()[:, 1] + + # TODO: Is this OK? + first_start_id = start_ids.min() + + # (bsz * num_codebooks, seq_len) -> (bsz, num_codebooks, seq_len) + pattern_mask = input_ids.reshape(bsz * num_codebooks, -1) + input_ids_edited = input_ids[..., :first_start_id].reshape(bsz * num_codebooks, -1) + return {"input_ids_edited": input_ids_edited, "delay_pattern_mask": pattern_mask} + + +class MusicgenModelPatcher(Seq2SeqModelPatcher): + def __enter__(self): + self.patch_ops() + if self.real_config.model_part == "build_delay_pattern_mask": + # For build_delay_pattern_mask, we need to override the signature too. + self._model.forward = types.MethodType(patched_build_delay_pattern_mask, self._model) + else: + setattr(self._model, self.orig_forward_name, self.patched_forward) + + def __exit__(self, exc_type, exc_value, traceback): + self.restore_ops() + if self.real_config.model_part == "build_delay_pattern_mask": + self._model.forward = self.original_decoder_forward + else: + setattr(self._model, self.orig_forward_name, self.orig_forward) + + def __init__( + self, + config: "OnnxConfig", + model: Union["PreTrainedModel", "TFPreTrainedModel"], + model_kwargs: Optional[Dict[str, Any]] = None, + ): + super().__init__(config, model, model_kwargs) + + if config.model_part == "build_delay_pattern_mask": + self.original_decoder_forward = self.orig_forward + elif config.model_part == "encodec_decode": + # EncodecModel.forward -> EncodecModel.decode + @functools.wraps(self.orig_forward) + def patched_forward( + input_values: Optional["torch.Tensor"] = None, + padding_mask: Optional["torch.Tensor"] = None, + audio_codes: Optional["torch.Tensor"] = None, + bandwidth: Optional[float] = None, + audio_scales: Optional["torch.Tensor"] = None, + return_dict: Optional[bool] = None, + ): + chunk_length = self.real_config._config.audio_encoder.chunk_length + if chunk_length is None: + if audio_scales is not None: + audio_scales = audio_scales[0] + + if len(audio_codes) != 1: + raise ValueError(f"Expected one frame, got {len(audio_codes)}") + audio_values = self._model._decode_frame(audio_codes[0], audio_scales) + else: + raise ValueError("Not supported, a meaningful error should have been raised ahead.") + decoded_frames = [] + + for frame, scale in zip(audio_codes, audio_scales): + frames = self._model._decode_frame(frame, scale) + decoded_frames.append(frames) + + audio_values = self._model._linear_overlap_add(decoded_frames, self.config.chunk_stride or 1) + + # truncate based on padding mask + if padding_mask is not None and padding_mask.shape[-1] < audio_values.shape[-1]: + audio_values = audio_values[..., : padding_mask.shape[-1]] + + return {"audio_values": audio_values} + + self.patched_forward = patched_forward diff --git a/optimum/exporters/tasks.py b/optimum/exporters/tasks.py index cb878db34b..6fd1690e4b 100644 --- a/optimum/exporters/tasks.py +++ b/optimum/exporters/tasks.py @@ -177,7 +177,7 @@ class TasksManager: "object-detection": "AutoModelForObjectDetection", "question-answering": "AutoModelForQuestionAnswering", "semantic-segmentation": "AutoModelForSemanticSegmentation", - "text-to-audio": "AutoModelForTextToSpectrogram", + "text-to-audio": ("AutoModelForTextToSpectrogram", "AutoModelForTextToWaveform"), "text-generation": "AutoModelForCausalLM", "text2text-generation": "AutoModelForSeq2SeqLM", "text-classification": "AutoModelForSequenceClassification", @@ -334,6 +334,7 @@ class TasksManager: # TODO: some models here support text-generation export but are not supported in ORTModelForCausalLM # Set of model topologies we support associated to the tasks supported by each topology and the factory + # TODO: remove `-with-past` tasks and rather rely on `variant`. _SUPPORTED_MODEL_TYPE = { "audio-spectrogram-transformer": supported_tasks_mapping( "feature-extraction", @@ -813,6 +814,10 @@ class TasksManager: "text2text-generation-with-past", onnx="MT5OnnxConfig", ), + "musicgen": supported_tasks_mapping( + "text-to-audio", # "variant" handles the "-with-past". We should generalize that. + onnx="MusicgenOnnxConfig", + ), "m2m-100": supported_tasks_mapping( "feature-extraction", "feature-extraction-with-past", diff --git a/optimum/exporters/utils.py b/optimum/exporters/utils.py index abec09ff5e..74d2d98385 100644 --- a/optimum/exporters/utils.py +++ b/optimum/exporters/utils.py @@ -345,6 +345,50 @@ def get_stable_diffusion_models_for_export( return models_for_export +def get_musicgen_models_for_export(model: Union["PreTrainedModel", "TFPreTrainedModel"], config: "ExportConfig"): + models_for_export = { + "text_encoder": model.text_encoder, + "encodec_decode": model.audio_encoder, + # For the decoder, we do not pass model.decoder because we may need to export model.enc_to_dec_proj + DECODER_NAME: model, + DECODER_WITH_PAST_NAME: model, + "build_delay_pattern_mask": model.decoder, + } + + text_encoder_config = config.__class__( + model.config, task=config.task, legacy=False, model_part="text_encoder", variant=config.variant + ) + models_for_export["text_encoder"] = (models_for_export["text_encoder"], text_encoder_config) + + audio_encoder_config = config.__class__( + model.config, task=config.task, legacy=False, model_part="encodec_decode", variant=config.variant + ) + models_for_export["encodec_decode"] = (models_for_export["encodec_decode"], audio_encoder_config) + + use_past = "with-past" in config.variant + decoder_export_config = config.with_behavior("decoder", use_past=use_past, use_past_in_inputs=False) + decoder_export_config.model_part = "decoder" + models_for_export[DECODER_NAME] = (models_for_export[DECODER_NAME], decoder_export_config) + + if "with-past" in config.variant: + decoder_export_config_with_past = config.with_behavior("decoder", use_past=True, use_past_in_inputs=True) + decoder_export_config_with_past.model_part = "decoder" + models_for_export[DECODER_WITH_PAST_NAME] = ( + models_for_export[DECODER_WITH_PAST_NAME], + decoder_export_config_with_past, + ) + + build_delay_pattern_mask_config = config.__class__( + model.config, task=config.task, legacy=False, model_part="build_delay_pattern_mask", variant=config.variant + ) + models_for_export["build_delay_pattern_mask"] = ( + models_for_export["build_delay_pattern_mask"], + build_delay_pattern_mask_config, + ) + + return models_for_export + + def _get_submodels_for_export_sam(model, variant): models_for_export = {} @@ -513,6 +557,8 @@ def _get_submodels_and_export_configs( models_and_export_configs = get_sam_models_for_export(model, export_config) elif model.config.model_type == "speecht5": models_and_export_configs = get_speecht5_models_for_export(model, export_config, model_kwargs) + elif model.config.model_type == "musicgen": + models_and_export_configs = get_musicgen_models_for_export(model, export_config) else: models_and_export_configs = {"model": (model, export_config)} diff --git a/optimum/onnx/transformations_utils.py b/optimum/onnx/transformations_utils.py index 05931753bf..1f0765112e 100644 --- a/optimum/onnx/transformations_utils.py +++ b/optimum/onnx/transformations_utils.py @@ -160,7 +160,7 @@ def _unify_onnx_outputs(model1: ModelProto, model2: ModelProto, strict: bool): else: logger.info( f"The two models proto have different outputs ({len(model1_outputs)} and {len(model2_outputs)} outputs)." - " Constant outputs will be added to unify the two models outputs." + " Constant outputs will be added to unify the two models outputs. This is expected for encoder-decoder models where cached cross-attention key/values are constant outputs, omitted in the model with KV cache." ) if model2_outputs.issubset(model1_outputs) is False: diff --git a/optimum/utils/__init__.py b/optimum/utils/__init__.py index 99ce8693d4..07be3f7e1a 100644 --- a/optimum/utils/__init__.py +++ b/optimum/utils/__init__.py @@ -48,8 +48,11 @@ BloomDummyPastKeyValuesGenerator, DummyAudioInputGenerator, DummyBboxInputGenerator, + DummyCodegenDecoderTextInputGenerator, DummyDecoderTextInputGenerator, + DummyEncodecInputGenerator, DummyInputGenerator, + DummyIntGenerator, DummyLabelsGenerator, DummyPastKeyValuesGenerator, DummyPix2StructInputGenerator, diff --git a/optimum/utils/input_generators.py b/optimum/utils/input_generators.py index 0fed01ce36..b23acd1fba 100644 --- a/optimum/utils/input_generators.py +++ b/optimum/utils/input_generators.py @@ -402,7 +402,12 @@ def __init__( **kwargs, ): self.task = task - self.vocab_size = normalized_config.vocab_size + + if isinstance(normalized_config, NormalizedEncoderDecoderConfig): + self.vocab_size = normalized_config.vocab_size + else: + self.vocab_size = normalized_config.vocab_size + if random_batch_size_range: low, high = random_batch_size_range self.batch_size = random.randint(low, high) @@ -419,6 +424,7 @@ def __init__( else: self.num_choices = num_choices self.padding_side = padding_side + self.normalized_config = normalized_config def generate( self, @@ -610,7 +616,7 @@ class DummySeq2SeqPastKeyValuesGenerator(DummyInputGenerator): def __init__( self, task: str, - normalized_config: NormalizedSeq2SeqConfig, + normalized_config: Union[NormalizedSeq2SeqConfig, NormalizedEncoderDecoderConfig], batch_size: int = DEFAULT_DUMMY_SHAPES["batch_size"], sequence_length: int = DEFAULT_DUMMY_SHAPES["sequence_length"], encoder_sequence_length: Optional[int] = None, @@ -633,18 +639,37 @@ def __init__( self.sequence_length if encoder_sequence_length is None else encoder_sequence_length ) + if isinstance(normalized_config, NormalizedEncoderDecoderConfig): + # encoder_num_attention_heads / decoder_num_attention_heads are bad names, they rather refer to cross / self attention num heads. + self.encoder_num_attention_heads = ( + self.normalized_config.DECODER_NORMALIZED_CONFIG_CLASS.encoder_num_attention_heads + ) + self.decoder_num_attention_heads = ( + self.normalized_config.DECODER_NORMALIZED_CONFIG_CLASS.decoder_num_attention_heads + ) + # Same, `encoder_hidden_size` and `decoder_hidden_size` are bad names. + self.encoder_hidden_size = self.normalized_config.DECODER_NORMALIZED_CONFIG_CLASS.hidden_size + self.decoder_hidden_size = self.normalized_config.DECODER_NORMALIZED_CONFIG_CLASS.hidden_size + self.decoder_num_layers = self.normalized_config.DECODER_NORMALIZED_CONFIG_CLASS.num_layers + else: + self.encoder_num_attention_heads = self.normalized_config.encoder_num_attention_heads + self.decoder_num_attention_heads = self.normalized_config.decoder_num_attention_heads + self.encoder_hidden_size = self.normalized_config.hidden_size + self.decoder_hidden_size = self.normalized_config.hidden_size + self.decoder_num_layers = self.normalized_config.decoder_num_layers + def generate(self, input_name: str, framework: str = "pt", int_dtype: str = "int64", float_dtype: str = "fp32"): encoder_shape = ( self.batch_size, - self.normalized_config.encoder_num_attention_heads, + self.encoder_num_attention_heads, self.encoder_sequence_length, - self.normalized_config.hidden_size // self.normalized_config.encoder_num_attention_heads, + self.encoder_hidden_size // self.encoder_num_attention_heads, ) decoder_shape = ( self.batch_size, - self.normalized_config.decoder_num_attention_heads, + self.decoder_num_attention_heads, self.sequence_length, - self.normalized_config.hidden_size // self.normalized_config.decoder_num_attention_heads, + self.decoder_hidden_size // self.decoder_num_attention_heads, ) return [ ( @@ -653,7 +678,7 @@ def generate(self, input_name: str, framework: str = "pt", int_dtype: str = "int self.random_float_tensor(encoder_shape, framework=framework, dtype=float_dtype), self.random_float_tensor(encoder_shape, framework=framework, dtype=float_dtype), ) - for _ in range(self.normalized_config.decoder_num_layers) + for _ in range(self.decoder_num_layers) ] @@ -1287,3 +1312,96 @@ def generate(self, input_name: str, framework: str = "pt", int_dtype: str = "int ) for _ in range(self.num_layers) ] + + +class DummyCodegenDecoderTextInputGenerator(DummySeq2SeqDecoderTextInputGenerator): + def __init__( + self, + task: str, + normalized_config: NormalizedTextConfig, + batch_size: int = DEFAULT_DUMMY_SHAPES["batch_size"], + sequence_length: int = DEFAULT_DUMMY_SHAPES["sequence_length"], + num_choices: int = DEFAULT_DUMMY_SHAPES["num_choices"], + random_batch_size_range: Optional[Tuple[int, int]] = None, + random_sequence_length_range: Optional[Tuple[int, int]] = None, + random_num_choices_range: Optional[Tuple[int, int]] = None, + **kwargs, + ): + super().__init__( + task, + normalized_config, + batch_size=batch_size, + sequence_length=sequence_length, + num_choices=num_choices, + random_batch_size_range=random_batch_size_range, + random_sequence_length_range=random_sequence_length_range, + random_num_choices_range=random_num_choices_range, + ) + self.num_codebooks = normalized_config.DECODER_NORMALIZED_CONFIG_CLASS.num_codebooks + + def generate(self, input_name: str, framework: str = "pt", int_dtype: str = "int64", float_dtype: str = "fp32"): + if input_name in ["decoder_input_ids"]: + min_value = 0 + max_value = 2 if input_name != "input_ids" else self.vocab_size + shape = [self.batch_size * self.num_codebooks, self.sequence_length] + return self.random_int_tensor(shape, max_value, min_value=min_value, framework=framework, dtype=int_dtype) + + return super().generate(input_name, framework=framework, int_dtype=int_dtype, float_dtype=float_dtype) + + +class DummyEncodecInputGenerator(DummyInputGenerator): + SUPPORTED_INPUT_NAMES = ("audio_codes",) + + def __init__( + self, + task: str, + normalized_config: NormalizedConfig, + sequence_length: int = DEFAULT_DUMMY_SHAPES["sequence_length"], + batch_size: int = DEFAULT_DUMMY_SHAPES["batch_size"], + **kwargs, + ): + self.task = task + self.batch_size = batch_size + + self.num_codebooks = normalized_config.decoder.num_codebooks + self.sequence_length = sequence_length + + def generate(self, input_name: str, framework: str = "pt", int_dtype: str = "int64", float_dtype: str = "fp32"): + if input_name == "audio_codes": + # Kind of a hack to use `self.sequence_length` here, for Musicgen pad tokens are filtered out, see + # https://github.com/huggingface/transformers/blob/31c575bcf13c2b85b65d652dd1b5b401f99be999/src/transformers/models/musicgen/modeling_musicgen.py#L2458 + shape = [1, self.batch_size, self.num_codebooks, self.sequence_length] + else: + raise ValueError(f"Unsupported input {input_name} for DummyEncodecInputGenerator") + + return self.random_int_tensor( + shape=shape, + min_value=0, + max_value=50, + framework=framework, + dtype=int_dtype, + ) + + +class DummyIntGenerator(DummyInputGenerator): + SUPPORTED_INPUT_NAMES = ( + "pad_token_id", + "max_length", + ) + + def __init__( + self, + task: str, + normalized_config: NormalizedTextConfig, + **kwargs, + ): + pass + + def generate( + self, + input_name: str, + framework: str = "pt", + int_dtype: str = "int64", + float_dtype: str = "fp32", + ): + return self.random_int_tensor(shape=(1,), min_value=20, max_value=22, framework=framework, dtype=int_dtype) diff --git a/tests/exporters/exporters_utils.py b/tests/exporters/exporters_utils.py index 4d987ed982..bc1d8a4a28 100644 --- a/tests/exporters/exporters_utils.py +++ b/tests/exporters/exporters_utils.py @@ -121,6 +121,7 @@ "mpnet": "hf-internal-testing/tiny-random-MPNetModel", "mpt": "hf-internal-testing/tiny-random-MptForCausalLM", "mt5": "lewtun/tiny-random-mt5", + "musicgen": "hf-internal-testing/tiny-random-MusicgenForConditionalGeneration", "nystromformer": "hf-internal-testing/tiny-random-NystromformerModel", "opt": "hf-internal-testing/tiny-random-OPTModel", "owlv2": "hf-internal-testing/tiny-random-Owlv2Model", @@ -246,6 +247,7 @@ "mobilevit": "apple/mobilevit-small", "mpt": "mosaicml/mpt-7b", "mt5": "lewtun/tiny-random-mt5", # Not using google/mt5-small because it takes too much time for testing. + "musicgen": "facebook/musicgen-small", "nystromformer": "hf-internal-testing/tiny-random-NystromformerModel", "owlv2": "google/owlv2-base-patch16", "owlvit": "google/owlvit-base-patch32",