Musicgen ONNX export (text-conditional only) #1779
Conversation
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.
Need #1780 for the CI.
It would be great to do this for Musicgen Melody as well (where the music conditioning is different, so it could maybe work).
@ylacombe What are the main differences?
I've been testing this out in transformers.js and will report back when the output matches!
I upgraded to transformers main + this branch (latest onnx, onnxruntime, and optimum too), and running the export results in an error (from huggingface/transformers#29939; cc @ylacombe). Downgrading to a version before that commit works.
Can confirm this works for transformers.js (v3 branch)! 🚀
Example code:
```js
import { AutoTokenizer, MusicgenForConditionalGeneration } from '@xenova/transformers';
import wavefile from 'wavefile';
import fs from 'fs';

// Load tokenizer and model
const tokenizer = await AutoTokenizer.from_pretrained('Xenova/musicgen-small');
const model = await MusicgenForConditionalGeneration.from_pretrained(
  'Xenova/musicgen-small', { dtype: 'fp32' },
);

// Prepare text input
const prompt = '80s pop track with bassy drums and synth';
const inputs = tokenizer(prompt);

// Generate audio
const audio_values = await model.generate({
  ...inputs,
  max_new_tokens: 512,
  do_sample: true,
  guidance_scale: 3,
});

// (Optional) Write the output to a WAV file
const wav = new wavefile.WaveFile();
wav.fromScratch(1, model.config.audio_encoder.sampling_rate, '32f', audio_values.data);
fs.writeFileSync('musicgen_out.wav', wav.toBuffer());
```
Samples:
- sample_1.mp4
- sample_2.mp4
- sample_3.mp4
* WIP but need to work on encodec first
* musicgen onnx export
* better logs
* add tests
* rename audio_encoder_decode.onnx to encodec_decode.onnx
* fix num heads in pkv
* nits
* add build_delay_pattern_mask
* fix wrong hidden_size for cross attention pkv
* fix tests
* update doc
Exports Musicgen conditioned by a text prompt. If we want to condition on audio, it is trickier: we first need to be able to export EncodecModel.encode, which requires a combination of jit.script/jit.trace as it has some unrollable loops, unfortunately.

Only KV cache export is tested & supported.
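A minimal sketch of running the export through optimum's programmatic entry point (main_export, the Python equivalent of optimum-cli export onnx); the task name "text-to-audio" and the output directory are assumptions here, not taken from this PR:

```python
# Hedged sketch: export Musicgen to ONNX with optimum.
# The task name "text-to-audio" and the output path are assumptions.
from optimum.exporters.onnx import main_export

main_export(
    model_name_or_path="facebook/musicgen-small",
    output="musicgen_onnx",
    task="text-to-audio",
)
```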
The following subcomponents are exported:
- text_encoder.onnx: corresponds to the text encoder part in https://github.com/huggingface/transformers/blob/v4.39.1/src/transformers/models/musicgen/modeling_musicgen.py#L1457.
- encodec_decode.onnx: corresponds to the decode pass of the Encodec audio encoder in https://github.com/huggingface/transformers/blob/v4.39.1/src/transformers/models/musicgen/modeling_musicgen.py#L2472-L2480.
- decoder_model.onnx: the Musicgen decoder, without past key values input, computing cross attention.
- decoder_with_past_model.onnx: the Musicgen decoder, with past_key_values input (KV cache filled), not computing cross attention.
- decoder_model_merged.onnx: the two previous models fused into one, to avoid duplicating weights. A boolean input use_cache_branch selects which branch to use. On the first forward pass, where the KV cache is still empty, dummy past key values inputs need to be passed and are ignored thanks to use_cache_branch=False (see the inspection sketch below).
- build_delay_pattern_mask.onnx: corresponds to https://github.com/huggingface/transformers/blob/v4.39.1/src/transformers/models/musicgen/modeling_musicgen.py#L1054-L1125.

Partially fixes #1297
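To sanity-check the merged decoder, one option is to list the exported session's inputs and look for the use_cache_branch selector. A hedged sketch, assuming onnxruntime is installed and the "musicgen_onnx" output directory from the export sketch above:

```python
# Hedged sketch: inspect decoder_model_merged.onnx with onnxruntime.
import onnxruntime as ort

session = ort.InferenceSession(
    "musicgen_onnx/decoder_model_merged.onnx",
    providers=["CPUExecutionProvider"],
)
for inp in session.get_inputs():
    # Expect the decoder inputs, the past_key_values.* entries, and the
    # boolean use_cache_branch input described above.
    print(inp.name, inp.shape, inp.type)
```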