
Musicgen ONNX export (text-conditional only) #1779

Merged — 12 commits merged into huggingface:main on Apr 10, 2024

Conversation

fxmarty (Contributor) commented on Mar 27, 2024

This PR exports Musicgen conditioned on a text prompt.

optimum-cli export onnx --model facebook/musicgen-small musicgen_onnx &> export.log

If we want to condition on audio, it is trickier: we first need to be able to export EncodecModel.encode, which unfortunately requires a combination of jit.script/jit.trace because it contains loops that cannot be unrolled.
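For context, the script/trace constraint looks like this: torch.jit.trace records one execution and freezes loop counts, so a loop whose length depends on the input has to go through torch.jit.script. A toy sketch (not Encodec's actual code) of mixing the two:

import torch


class ChunkedEncoder(torch.nn.Module):
    # Hypothetical stand-in for Encodec's loop over audio chunks.
    def forward(self, audio: torch.Tensor) -> torch.Tensor:
        frames = []
        # The loop count depends on the input length, so torch.jit.trace
        # would silently freeze it for one example input.
        for start in range(0, audio.shape[-1], 4):
            frames.append(audio[:, start : start + 4].mean(dim=-1, keepdim=True))
        return torch.cat(frames, dim=-1)


class Wrapper(torch.nn.Module):
    def __init__(self):
        super().__init__()
        # Script the loop-carrying part; it stays symbolic inside the trace.
        self.encoder = torch.jit.script(ChunkedEncoder())

    def forward(self, audio):
        return self.encoder(audio) * 2.0  # trace-friendly part


traced = torch.jit.trace(Wrapper(), torch.randn(1, 16))
print(traced(torch.randn(1, 32)).shape)  # the scripted loop adapts: (1, 8)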

Only KV cache export is tested & supported.

The following subcomponents are exported: the text encoder, the decoder (with KV cache), EnCodec's decode (encodec_decode.onnx), and the build_delay_pattern_mask helper.
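A quick way to sanity-check what came out of the export (a sketch; the filenames are assumptions based on this PR's commit messages, e.g. encodec_decode.onnx):

import onnxruntime as ort

for name in [
    "text_encoder.onnx",
    "decoder_model.onnx",
    "decoder_with_past_model.onnx",
    "encodec_decode.onnx",
    "build_delay_pattern_mask.onnx",
]:
    # Load each exported subcomponent and print its I/O signature.
    sess = ort.InferenceSession(f"musicgen_onnx/{name}", providers=["CPUExecutionProvider"])
    print(name)
    for inp in sess.get_inputs():
        print("  input :", inp.name, inp.shape, inp.type)
    for out in sess.get_outputs():
        print("  output:", out.name, out.shape, out.type)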

Partially fixes #1297 ([ONNX export] Musicgen for text-to-audio).

⚠️ Depends on huggingface/transformers#29913; please check out that branch for now.

HuggingFaceDocBuilderDev commented:

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

fxmarty (Contributor, Author) commented on Mar 27, 2024

Needs #1780 for the CI.

ylacombe (Contributor) commented on Apr 3, 2024

It would be great to do this with Musicgen Melody as well (its music conditioning is different, so it could maybe work).

fxmarty (Contributor, Author) commented on Apr 4, 2024

@ylacombe What are the main differences?

xenova (Contributor) commented on Apr 5, 2024

I've been testing this out in transformers.js and will report back when the output matches!

xenova (Contributor) commented on Apr 5, 2024

I upgraded to transformers main + this branch (latest onnx, onnxruntime, and optimum too), and running

optimum-cli export onnx -m facebook/musicgen-small output/facebook/musicgen-small

results in:

 File "/usr/local/lib/python3.10/dist-packages/transformers/models/musicgen/modeling_musicgen.py", line 1261, in forward
    decoder_outputs = self.decoder(
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1520, in _call_impl
    return forward_call(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1501, in _slow_forward
    result = self.forward(*input, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/transformers/models/musicgen/modeling_musicgen.py", line 1089, in forward
    attention_mask = _prepare_4d_causal_attention_mask_for_sdpa(
  File "/usr/local/lib/python3.10/dist-packages/transformers/modeling_attn_mask_utils.py", line 351, in _prepare_4d_causal_attention_mask_for_sdpa
    raise ValueError(
ValueError: Attention using SDPA can not be traced with torch.jit.trace when no attention_mask is provided. To solve this issue, please either load your model with the argument `attn_implementation="eager"` or pass an attention_mask input when tracing the model.

(from huggingface/transformers#29939; cc @ylacombe)

Downgrading to a version before that commit works.
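For anyone hitting this before the upstream fix lands, a sketch of the workaround the error message itself suggests:

from transformers import MusicgenForConditionalGeneration

# Force the eager attention implementation so torch.jit.trace does not hit
# the SDPA attention-mask check raised in the traceback above.
model = MusicgenForConditionalGeneration.from_pretrained(
    "facebook/musicgen-small",
    attn_implementation="eager",
)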

xenova (Contributor) left a comment:

Can confirm this works for transformers.js (v3 branch)! 🚀

Example code:

import { AutoTokenizer, MusicgenForConditionalGeneration } from '@xenova/transformers';

// Load tokenizer and model
const tokenizer = await AutoTokenizer.from_pretrained('Xenova/musicgen-small');
const model = await MusicgenForConditionalGeneration.from_pretrained(
  'Xenova/musicgen-small', { dtype: 'fp32' }
);

// Prepare text input
const prompt = '80s pop track with bassy drums and synth';
const inputs = tokenizer(prompt);

// Generate audio
const audio_values = await model.generate({
  ...inputs,
  max_new_tokens: 512,
  do_sample: true,
  guidance_scale: 3,
});

// (Optional) Write the output to a WAV file
import wavefile from 'wavefile';
import fs from 'fs';

const wav = new wavefile.WaveFile();
wav.fromScratch(1, model.config.audio_encoder.sampling_rate, '32f', audio_values.data);
fs.writeFileSync('musicgen_out.wav', wav.toBuffer());
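For reference, the PyTorch side this output is matched against can be reproduced with the documented transformers API (a sketch, using the same prompt and sampling settings as above):

from transformers import AutoProcessor, MusicgenForConditionalGeneration
import scipy.io.wavfile

processor = AutoProcessor.from_pretrained("facebook/musicgen-small")
model = MusicgenForConditionalGeneration.from_pretrained("facebook/musicgen-small")

inputs = processor(
    text=["80s pop track with bassy drums and synth"],
    padding=True,
    return_tensors="pt",
)
# Same sampling settings as the transformers.js example above.
audio_values = model.generate(**inputs, do_sample=True, guidance_scale=3, max_new_tokens=512)

scipy.io.wavfile.write(
    "musicgen_out_pt.wav",
    rate=model.config.audio_encoder.sampling_rate,
    data=audio_values[0, 0].numpy(),
)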

Samples:

sample_1.mp4
sample_2.mp4
sample_3.mp4

fxmarty merged commit 2f75b0d into huggingface:main on Apr 10, 2024 (57 of 61 checks passed).
young-developer pushed a commit to young-developer/optimum that referenced this pull request on May 10, 2024:
* WIP but need to work on encodec first

* musicgen onnx export

* better logs

* add tests

* rename audio_encoder_decode.onnx to encodec_decode.onnx

* fix num heads in pkv

* nits

* add build_delay_pattern_mask

* fix wrong hidden_size for cross attention pkv

* fix tests

* update doc
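One commit above adds build_delay_pattern_mask as its own exported graph; a toy illustration of the delay pattern it relates to (codebook k shifted right by k steps — illustrative only, not the exported graph):

import torch

def toy_delay_pattern(input_ids: torch.Tensor, pad_token_id: int) -> torch.Tensor:
    # input_ids: (num_codebooks, seq_len). Codebook k is delayed by k steps;
    # the shifted-in slots are filled with the pad token.
    num_codebooks, seq_len = input_ids.shape
    out = torch.full(
        (num_codebooks, seq_len + num_codebooks - 1),
        pad_token_id,
        dtype=input_ids.dtype,
    )
    for k in range(num_codebooks):
        out[k, k : k + seq_len] = input_ids[k]
    return out

ids = torch.arange(1, 9).reshape(4, 2)  # 4 codebooks, 2 timesteps
print(toy_delay_pattern(ids, pad_token_id=0))
# row k starts k steps later; zeros mark not-yet-valid positions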