Torch compile problem and some more ideas #107
What about a C++ implementation, similar to llama.cpp?

Yeah, that would be a good one, but I'm not familiar with ggml yet. Maybe I'll try it someday, or we can do it together :3
Following up on the torch compile problem, I found that the shape of this tensor changes at every decoding step:

```python
generated_length = past_key_values_length - prompt_attention_mask.shape[1] + 1
decoder_attention_mask = torch.ones(
    (input_shape[0], generated_length), device=self.device, dtype=prompt_attention_mask.dtype
)
```
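Since `generated_length` grows every step, a compiled graph sees a new shape each time and recompiles. A common workaround is to allocate a static, maximum-length mask once and only flip the valid prefix. Here is a minimal NumPy sketch of that shape logic, assuming a hypothetical `max_new_tokens` budget (the real fix would use `torch.ones` / static-cache machinery, but the shapes behave the same):

```python
import numpy as np

def build_decoder_attention_mask(batch_size: int, generated_length: int,
                                 max_new_tokens: int) -> np.ndarray:
    """Build a fixed-shape attention mask so a compiled graph never
    observes a new tensor shape.

    Instead of allocating a (batch, generated_length) tensor each step
    (a dynamic shape that triggers recompilation), allocate the full
    (batch, max_new_tokens) mask and mark only the first
    `generated_length` positions as valid.
    """
    mask = np.zeros((batch_size, max_new_tokens), dtype=np.int64)
    mask[:, :generated_length] = 1
    return mask

# The shape stays (2, 16) no matter how many tokens were generated so far.
m3 = build_decoder_attention_mask(2, 3, 16)
m7 = build_decoder_attention_mask(2, 7, 16)
assert m3.shape == m7.shape == (2, 16)
assert m3.sum() == 2 * 3 and m7.sum() == 2 * 7
```

The trade-off is that attention kernels must then respect the mask rather than the tensor's length, which is exactly what static-cache implementations do.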
@ylacombe
Hi there! I've tried the latest version (mini-v1) with `torch.compile` and `torch.sdpa` on an A100. The results are very good, with a speed of about 80 ms for generating 200 ms of audio, which is excellent for streaming. However, I believe there's still room for improvement:
- Warm-up takes about 15 minutes (on the second run, to capture the CUDA graph).
- The model takes about 4 GB to store, but around 20 GB of VRAM once the warm-up is done.

Here are some ideas I have for optimization:
- Quantization of the KV cache: this could help reduce the KV-cache size.
- Export to ONNX using Optimum: I found an implementation for MusicGen, and I think it will be similar for this model.
- Implement PagedAttention: this could help reduce wasted VRAM. I found a vLLM implementation for Whisper, and I think it will be similar to the current static-cache implementation, which is based on Whisper. Maybe someday we can serve Parler-TTS like other LLMs in vLLM; in a study from our team, we served an 8B LLM with sub-second latency for ~20 CCUs.
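The KV-cache quantization idea above can be illustrated with simple per-tensor affine int8 quantization. This is a toy NumPy sketch, not Parler-TTS's actual cache code (in practice a library-provided quantized cache backend would be used); the tensor shapes are made up:

```python
import numpy as np

def quantize_kv(x: np.ndarray):
    """Per-tensor affine uint8 quantization of a float32 KV tensor (4x smaller)."""
    scale = max(float(x.max() - x.min()) / 255.0, 1e-8)
    zero_point = float(x.min())
    q = np.round((x - zero_point) / scale).astype(np.uint8)
    return q, scale, zero_point

def dequantize_kv(q: np.ndarray, scale: float, zero_point: float) -> np.ndarray:
    """Recover an approximate float32 tensor from the quantized cache."""
    return q.astype(np.float32) * scale + zero_point

rng = np.random.default_rng(0)
kv = rng.standard_normal((2, 8, 64)).astype(np.float32)  # (heads, seq, head_dim), hypothetical
q, s, z = quantize_kv(kv)
restored = dequantize_kv(q, s, z)

assert q.nbytes == kv.nbytes // 4           # 4x memory reduction vs float32
assert np.max(np.abs(restored - kv)) < s    # rounding error bounded by one step
```

Real implementations usually quantize per-channel or per-group rather than per-tensor to keep the error small, but the memory arithmetic is the same.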
Please discuss and help determine a feasible approach we can take. Of course, I'm willing to contribute in any way I can.
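The core of the PagedAttention idea above is a block table that maps each sequence's logical token positions onto fixed-size physical blocks, so KV memory is allocated on demand instead of reserved at maximum length. A toy pure-Python sketch (block and pool sizes are made-up numbers, and the real vLLM allocator is far more involved):

```python
class PagedKVCache:
    """Toy block-table allocator in the spirit of vLLM's PagedAttention.

    Physical KV memory is split into fixed-size blocks; each sequence
    acquires blocks lazily as it grows, so short sequences don't reserve
    max-length VRAM up front.
    """
    def __init__(self, num_blocks: int, block_size: int):
        self.block_size = block_size
        self.free_blocks = list(range(num_blocks))
        self.block_tables = {}   # seq_id -> list of physical block ids
        self.lengths = {}        # seq_id -> number of tokens stored

    def append_token(self, seq_id: int):
        """Reserve a slot for one new token; returns (physical_block, offset)."""
        table = self.block_tables.setdefault(seq_id, [])
        n = self.lengths.get(seq_id, 0)
        if n % self.block_size == 0:        # first token, or current block is full
            table.append(self.free_blocks.pop())
        self.lengths[seq_id] = n + 1
        return table[n // self.block_size], n % self.block_size

    def free(self, seq_id: int):
        """Return a finished sequence's blocks to the pool."""
        self.free_blocks.extend(self.block_tables.pop(seq_id, []))
        self.lengths.pop(seq_id, None)

cache = PagedKVCache(num_blocks=8, block_size=4)
for _ in range(5):                          # sequence 0 grows to 5 tokens
    cache.append_token(0)
assert len(cache.block_tables[0]) == 2      # 5 tokens -> 2 blocks, not a max-length slab
cache.free(0)
assert len(cache.free_blocks) == 8          # all blocks back in the pool
```

The VRAM saving comes from the gap between average and maximum sequence length: a contiguous static cache pays for the maximum, while the block table pays only for tokens actually generated, rounded up to the block size.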