Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Musicgen ONNX export (text-conditional only) #1779

Merged
merged 12 commits into from
Apr 10, 2024
1 change: 1 addition & 0 deletions docs/source/exporters/onnx/overview.mdx
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,7 @@ Supported architectures from [🤗 Transformers](https://huggingface.co/docs/tra
- MobileNet v2
- MPNet
- MT5
- Musicgen (text-conditional only)
- Nystromformer
- OWL-ViT
- Pegasus
Expand Down
7 changes: 7 additions & 0 deletions optimum/exporters/onnx/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -383,6 +383,13 @@ def __init__(
)

self._normalized_config.DECODER_NORMALIZED_CONFIG_CLASS = self._decoder_onnx_config._normalized_config
self._normalized_config.DECODER_NORMALIZED_CONFIG_CLASS = self._decoder_onnx_config._normalized_config
self._normalized_config.DECODER_NORMALIZED_CONFIG_CLASS.encoder_num_attention_heads = (
self._decoder_onnx_config._normalized_config.num_attention_heads
)
self._normalized_config.DECODER_NORMALIZED_CONFIG_CLASS.decoder_num_attention_heads = (
self._decoder_onnx_config._normalized_config.num_attention_heads
)

if isinstance(self._decoder_onnx_config, OnnxSeq2SeqConfigWithPast):
self._past_key_values_generator = (
Expand Down
1 change: 1 addition & 0 deletions optimum/exporters/onnx/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,5 +36,6 @@

SDPA_ARCHS_ONNX_EXPORT_NOT_SUPPORTED = [
"bart",
"musicgen",
"whisper",
]
5 changes: 4 additions & 1 deletion optimum/exporters/onnx/convert.py
Original file line number Diff line number Diff line change
Expand Up @@ -258,7 +258,7 @@ def _run_validation(

model_kwargs = model_kwargs if model_kwargs is not None else {}

logger.info(f"Validating ONNX model {onnx_model.as_posix()}...")
logger.info(f"\nValidating ONNX model {onnx_model.as_posix()}...")

if atol is None:
atol = config.ATOL_FOR_VALIDATION
Expand Down Expand Up @@ -764,6 +764,9 @@ def export_models(
output_path = output_dir / output_name
output_path.parent.mkdir(parents=True, exist_ok=True)

logger.info(
f"\n***** Exporting submodel {i + 1}/{len(models_and_onnx_configs)}: {submodel.__class__.__name__} *****"
)
outputs.append(
export(
model=submodel,
Expand Down
307 changes: 305 additions & 2 deletions optimum/exporters/onnx/model_configs.py

Large diffs are not rendered by default.

137 changes: 136 additions & 1 deletion optimum/exporters/onnx/model_patcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -254,7 +254,6 @@ def patched_forward(*args, **kwargs):
elif self.real_config._behavior == "decoder" and self.real_config.use_past_in_inputs:
# The filtering happens here. The decoder with use_past_in_inputs=True corresponds to the autoregressive one.
filterd_outputs[name] = tuple([v[:2] for v in value])

return filterd_outputs

self.patched_forward = patched_forward
Expand Down Expand Up @@ -796,3 +795,139 @@ def patched_forward(input_ids, attention_mask, pixel_values):
return {"text_embeds": text_embeds, "image_embeds": image_embeds}

self.patched_forward = patched_forward


# Triu with possible dynamic `diagonal` argument. Not possible with torch.triu unfortunately.
def triu_onnx(x, diagonal=0):
l, w = x.shape
arange_rows = torch.arange(l, device=x.device)

arange_cols = torch.arange(w, device=x.device)
mask = arange_cols.expand(l, w)

arange_rows = arange_rows[:, None] + diagonal
mask = mask >= arange_rows
return x.masked_fill(mask == 0, 0)


def patched_build_delay_pattern_mask(self, input_ids: torch.Tensor, pad_token_id: int, max_length: int = None):
# (bsz * num_codebooks, seq_len) -> (bsz, num_codebooks, seq_len)
input_ids = input_ids.reshape(-1, self.num_codebooks, input_ids.shape[-1])
bsz, num_codebooks, seq_len = input_ids.shape

max_length = max_length if max_length is not None else self.generation_config.max_length
input_ids_shifted = torch.ones((bsz, num_codebooks, max_length), dtype=torch.long, device=input_ids.device) * -1

channel_codebooks = num_codebooks // 2 if self.config.audio_channels == 2 else num_codebooks
# we only apply the mask if we have a large enough seq len - otherwise we return as is
if max_length < 2 * channel_codebooks - 1:
raise NotImplementedError("Not supported in ONNX export. Please open an issue in Optimum repository.")

# fill the shifted ids with the prompt entries, offset by the codebook idx
for codebook in range(channel_codebooks):
if self.config.audio_channels == 1:
# mono channel - loop over the codebooks one-by-one
input_ids_shifted[:, codebook, codebook : seq_len + codebook] = input_ids[:, codebook]
else:
# left/right channels are interleaved in the generated codebooks, so handle one then the other
input_ids_shifted[:, 2 * codebook, codebook : seq_len + codebook] = input_ids[:, 2 * codebook]
input_ids_shifted[:, 2 * codebook + 1, codebook : seq_len + codebook] = input_ids[:, 2 * codebook + 1]

# construct a pattern mask that indicates the positions of padding tokens for each codebook
# first fill the upper triangular part (the EOS padding)
# NOTE: We could use torch.bool here, but PyTorch the complains with `The exported ONNX model failed ONNX shape inference.`
# Using int8 leads to `Could not find an implementation for Where`
delay_pattern = triu_onnx(
torch.ones((channel_codebooks, max_length), dtype=torch.int32), diagonal=max_length - channel_codebooks + 1
)

# NOTE: We could use torch.bool here, but PyTorch the complains with `The exported ONNX model failed ONNX shape inference.`
# Using int32 leads to `Could not find an implementation for Trilu`, hence int64 here

# then fill the lower triangular part (the BOS padding)
delay_pattern = delay_pattern + torch.tril(torch.ones((channel_codebooks, max_length), dtype=torch.int64))
delay_pattern = delay_pattern.to(torch.bool)

if self.config.audio_channels == 2:
# for left/right channel we need to duplicate every row of the pattern mask in an interleaved fashion
delay_pattern = delay_pattern.repeat_interleave(2, dim=0)

mask = ~delay_pattern.to(input_ids.device)
input_ids = mask * input_ids_shifted + ~mask * pad_token_id

# find the first position to start generating - this is the first place we have the -1 token
# and will always be in the first codebook (since it has no codebook offset)
first_codebook_ids = input_ids[:, 0, :]
start_ids = (first_codebook_ids == -1).nonzero()[:, 1]

# TODO: Is this OK?
first_start_id = start_ids.min()

# (bsz * num_codebooks, seq_len) -> (bsz, num_codebooks, seq_len)
pattern_mask = input_ids.reshape(bsz * num_codebooks, -1)
input_ids_edited = input_ids[..., :first_start_id].reshape(bsz * num_codebooks, -1)
return {"input_ids_edited": input_ids_edited, "delay_pattern_mask": pattern_mask}


class MusicgenModelPatcher(Seq2SeqModelPatcher):
def __enter__(self):
self.patch_ops()
if self.real_config.model_part == "build_delay_pattern_mask":
# For build_delay_pattern_mask, we need to override the signature too.
self._model.forward = types.MethodType(patched_build_delay_pattern_mask, self._model)
else:
setattr(self._model, self.orig_forward_name, self.patched_forward)

def __exit__(self, exc_type, exc_value, traceback):
self.restore_ops()
if self.real_config.model_part == "build_delay_pattern_mask":
self._model.forward = self.original_decoder_forward
else:
setattr(self._model, self.orig_forward_name, self.orig_forward)

def __init__(
self,
config: "OnnxConfig",
model: Union["PreTrainedModel", "TFPreTrainedModel"],
model_kwargs: Optional[Dict[str, Any]] = None,
):
super().__init__(config, model, model_kwargs)

if config.model_part == "build_delay_pattern_mask":
self.original_decoder_forward = self.orig_forward
elif config.model_part == "encodec_decode":
# EncodecModel.forward -> EncodecModel.decode
@functools.wraps(self.orig_forward)
def patched_forward(
input_values: Optional["torch.Tensor"] = None,
padding_mask: Optional["torch.Tensor"] = None,
audio_codes: Optional["torch.Tensor"] = None,
bandwidth: Optional[float] = None,
audio_scales: Optional["torch.Tensor"] = None,
return_dict: Optional[bool] = None,
):
chunk_length = self.real_config._config.audio_encoder.chunk_length
if chunk_length is None:
if audio_scales is not None:
audio_scales = audio_scales[0]

if len(audio_codes) != 1:
raise ValueError(f"Expected one frame, got {len(audio_codes)}")
audio_values = self._model._decode_frame(audio_codes[0], audio_scales)
else:
raise ValueError("Not supported, a meaningful error should have been raised ahead.")
decoded_frames = []

for frame, scale in zip(audio_codes, audio_scales):
frames = self._model._decode_frame(frame, scale)
decoded_frames.append(frames)

audio_values = self._model._linear_overlap_add(decoded_frames, self.config.chunk_stride or 1)

# truncate based on padding mask
if padding_mask is not None and padding_mask.shape[-1] < audio_values.shape[-1]:
audio_values = audio_values[..., : padding_mask.shape[-1]]

return {"audio_values": audio_values}

self.patched_forward = patched_forward
7 changes: 6 additions & 1 deletion optimum/exporters/tasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -177,7 +177,7 @@ class TasksManager:
"object-detection": "AutoModelForObjectDetection",
"question-answering": "AutoModelForQuestionAnswering",
"semantic-segmentation": "AutoModelForSemanticSegmentation",
"text-to-audio": "AutoModelForTextToSpectrogram",
"text-to-audio": ("AutoModelForTextToSpectrogram", "AutoModelForTextToWaveform"),
"text-generation": "AutoModelForCausalLM",
"text2text-generation": "AutoModelForSeq2SeqLM",
"text-classification": "AutoModelForSequenceClassification",
Expand Down Expand Up @@ -334,6 +334,7 @@ class TasksManager:

# TODO: some models here support text-generation export but are not supported in ORTModelForCausalLM
# Set of model topologies we support associated to the tasks supported by each topology and the factory
# TODO: remove `-with-past` tasks and rather rely on `variant`.
_SUPPORTED_MODEL_TYPE = {
"audio-spectrogram-transformer": supported_tasks_mapping(
"feature-extraction",
Expand Down Expand Up @@ -806,6 +807,10 @@ class TasksManager:
"text2text-generation-with-past",
onnx="MT5OnnxConfig",
),
"musicgen": supported_tasks_mapping(
"text-to-audio", # "variant" handles the "-with-past". We should generalize that.
onnx="MusicgenOnnxConfig",
),
"m2m-100": supported_tasks_mapping(
"feature-extraction",
"feature-extraction-with-past",
Expand Down
46 changes: 46 additions & 0 deletions optimum/exporters/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -345,6 +345,50 @@ def get_stable_diffusion_models_for_export(
return models_for_export


def get_musicgen_models_for_export(model: Union["PreTrainedModel", "TFPreTrainedModel"], config: "ExportConfig"):
models_for_export = {
"text_encoder": model.text_encoder,
"encodec_decode": model.audio_encoder,
# For the decoder, we do not pass model.decoder because we may need to export model.enc_to_dec_proj
DECODER_NAME: model,
DECODER_WITH_PAST_NAME: model,
"build_delay_pattern_mask": model.decoder,
}

text_encoder_config = config.__class__(
model.config, task=config.task, legacy=False, model_part="text_encoder", variant=config.variant
)
models_for_export["text_encoder"] = (models_for_export["text_encoder"], text_encoder_config)

audio_encoder_config = config.__class__(
model.config, task=config.task, legacy=False, model_part="encodec_decode", variant=config.variant
)
models_for_export["encodec_decode"] = (models_for_export["encodec_decode"], audio_encoder_config)

use_past = "with-past" in config.variant
decoder_export_config = config.with_behavior("decoder", use_past=use_past, use_past_in_inputs=False)
decoder_export_config.model_part = "decoder"
models_for_export[DECODER_NAME] = (models_for_export[DECODER_NAME], decoder_export_config)

if "with-past" in config.variant:
decoder_export_config_with_past = config.with_behavior("decoder", use_past=True, use_past_in_inputs=True)
decoder_export_config_with_past.model_part = "decoder"
models_for_export[DECODER_WITH_PAST_NAME] = (
models_for_export[DECODER_WITH_PAST_NAME],
decoder_export_config_with_past,
)

build_delay_pattern_mask_config = config.__class__(
model.config, task=config.task, legacy=False, model_part="build_delay_pattern_mask", variant=config.variant
)
models_for_export["build_delay_pattern_mask"] = (
models_for_export["build_delay_pattern_mask"],
build_delay_pattern_mask_config,
)

return models_for_export


def _get_submodels_for_export_sam(model, variant):
models_for_export = {}

Expand Down Expand Up @@ -513,6 +557,8 @@ def _get_submodels_and_export_configs(
models_and_export_configs = get_sam_models_for_export(model, export_config)
elif model.config.model_type == "speecht5":
models_and_export_configs = get_speecht5_models_for_export(model, export_config, model_kwargs)
elif model.config.model_type == "musicgen":
models_and_export_configs = get_musicgen_models_for_export(model, export_config)
else:
models_and_export_configs = {"model": (model, export_config)}

Expand Down
2 changes: 1 addition & 1 deletion optimum/onnx/transformations_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -160,7 +160,7 @@ def _unify_onnx_outputs(model1: ModelProto, model2: ModelProto, strict: bool):
else:
logger.info(
f"The two models proto have different outputs ({len(model1_outputs)} and {len(model2_outputs)} outputs)."
" Constant outputs will be added to unify the two models outputs."
" Constant outputs will be added to unify the two models outputs. This is expected for encoder-decoder models where cached cross-attention key/values are constant outputs, omitted in the model with KV cache."
)

if model2_outputs.issubset(model1_outputs) is False:
Expand Down
3 changes: 3 additions & 0 deletions optimum/utils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,8 +48,11 @@
BloomDummyPastKeyValuesGenerator,
DummyAudioInputGenerator,
DummyBboxInputGenerator,
DummyCodegenDecoderTextInputGenerator,
DummyDecoderTextInputGenerator,
DummyEncodecInputGenerator,
DummyInputGenerator,
DummyIntGenerator,
DummyLabelsGenerator,
DummyPastKeyValuesGenerator,
DummyPix2StructInputGenerator,
Expand Down
Loading
Loading