Torch compile problem and some more ideas #107

Closed
sang-nguyen-ts opened this issue Aug 15, 2024 · 3 comments

@sang-nguyen-ts
Contributor

@ylacombe
Hi there! I've tried the latest version (mini-v1) with torch.compile and torch SDPA on an A100. The results are very good: about 80 ms to generate 200 ms of audio, which is excellent for streaming. However, I believe there's still room for improvement:

  • Torch compile warm-up time: Warm-up takes a long time. For example, warming up generation for a maximum of 2564 tokens takes around 15 minutes (the second run, which captures the CUDA graph); see the warm-up sketch below.
  • High VRAM usage: The mini model takes around 4 GB to store the weights but uses around 20 GB of VRAM once the warm-up is done.
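
For context, this is roughly the warm-up pattern I'm describing (just a sketch; model and warmup_inputs are placeholders rather than the exact Parler-TTS API):

import torch

# Sketch of a typical torch.compile warm-up; mode="reduce-overhead" is what
# triggers CUDA graph capture, which is why the second run is the slow one.
model.forward = torch.compile(model.forward, mode="reduce-overhead", fullgraph=True)

with torch.no_grad():
    for _ in range(2):  # first run compiles, second run captures the CUDA graph
        _ = model.generate(
            **warmup_inputs,        # dummy inputs padded to the maximum shapes
            min_new_tokens=2564,    # force the longest sequence we plan to serve
            max_new_tokens=2564,
        )
torch.cuda.synchronize()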

Here are some ideas I have for optimization:

  1. Quantization of the KV cache: This could help reduce the KV cache size (see the sketch after this list).

  2. Export to ONNX using Optimum: I found an implementation for MusicGen, and I think it will be similar for this model.

  3. Implement PagedAttention: This could help reduce wasted VRAM. I found a vLLM implementation for Whisper, and I think it will be similar to the current static cache implementation, which is based on Whisper's. Maybe someday we can serve Parler-TTS like other LLMs in vLLM; a study from our team showed we can serve an 8B LLM with sub-second latency for ~20 CCUs.
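
On idea 1, here is a rough sketch of what I mean by KV cache quantization (plain PyTorch, symmetric int8 with one scale per (batch, head, position) over the head dimension; the tensor shape is just an illustration, not the actual Parler-TTS cache layout):

import torch

def quantize_kv(t: torch.Tensor):
    """Symmetric int8 quantization with one scale over the head dimension."""
    scale = t.abs().amax(dim=-1, keepdim=True).clamp(min=1e-8) / 127.0
    q = torch.round(t / scale).clamp(-127, 127).to(torch.int8)
    return q, scale

def dequantize_kv(q: torch.Tensor, scale: torch.Tensor) -> torch.Tensor:
    return q.to(scale.dtype) * scale

# Illustrative key tensor shaped like a static-cache entry: (batch, heads, max_len, head_dim)
k = torch.randn(1, 16, 2564, 64, dtype=torch.float16, device="cuda")
k_q, k_scale = quantize_kv(k)
print(k.nelement() * k.element_size(), "bytes ->", k_q.nelement() * k_q.element_size(), "bytes")

Keeping the scales symmetric means dequantization is a single multiply, which matters if it has to run inside the attention step for every token.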

Please discuss and help determine a feasible approach we can take. Of course, I'm willing to contribute in any way I can.

@dgm3333

dgm3333 commented Aug 16, 2024

What about a C++ implementation similar to llama.cpp (which has a server implementation) or whisper.cpp? Because it's precompiled, there should be no warm-up effect (the warm-up is presumably due to the JIT interpreter stabilising?).
https://github.com/ggerganov/llama.cpp
https://github.com/ggerganov/whisper.cpp
It also handles quants and other optimisations.

@sang-nguyen-ts
Contributor Author

> What about a C++ implementation similar to llama.cpp (which has a server implementation) or whisper.cpp? Because it's precompiled, there should be no warm-up effect (presumably due to the JIT interpreter stabilising?). https://github.com/ggerganov/llama.cpp https://github.com/ggerganov/whisper.cpp It also handles quants and other optimisations.

Yeah, this would be a good one, but I'm not familiar with ggml yet. Maybe I'll try it someday, or we can do it together :3

@sang-nguyen-ts
Contributor Author

Following up on the torch compile problem: I found that the shape of the decoder_attention_mask tensor changes over time based on sequence length, which may cause each generation step to get its own CUDA graph:

generated_length = past_key_values_length - prompt_attention_mask.shape[1] + 1
decoder_attention_mask = torch.ones(
    (input_shape[0], generated_length), device=self.device, dtype=prompt_attention_mask.dtype
)
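
A possible fix (just a sketch, assuming the maximum generation length is known up front, as it is for the static cache) would be to allocate the mask at a fixed maximum length and only fill the valid prefix, so its shape stays constant across steps:

# Sketch: fixed-shape mask so torch.compile / CUDA graphs see one shape at every step.
# max_generated_length is an assumed upper bound taken from the generation config.
generated_length = past_key_values_length - prompt_attention_mask.shape[1] + 1
decoder_attention_mask = torch.zeros(
    (input_shape[0], max_generated_length), device=self.device, dtype=prompt_attention_mask.dtype
)
decoder_attention_mask[:, :generated_length] = 1

Whether the downstream masking logic handles the padded tail correctly would still need checking.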
