[Feature request] Compatibility with transformers>=4.43.2 #65

Closed
SiriuslySirius opened this issue Jul 31, 2024 · 18 comments · Fixed by #109
Labels
bug Something isn't working

Comments

@SiriuslySirius

Hello, I am currently working with the new LLaMA 3.1 models by Meta, and they require newer versions of transformers, optimum, and accelerate. I ran into compatibility issues between XTTS and the newer transformers version.

I personally use the inference streaming feature, and that's where I am having issues.

Here is an error log I got:

Traceback (most recent call last):
  File "C:\Users\eyein\OneDrive\Desktop\Files\Discord Bots\JenEva-3.0\cogs\rt_tts_cog.py", line 501, in text_to_speech
    for j, chunk in enumerate(chunks):
  File "C:\Users\eyein\miniconda3\envs\JenEva\Lib\site-packages\torch\utils\_contextlib.py", line 35, in generator_context
    response = gen.send(None)
               ^^^^^^^^^^^^^^
  File "C:\Users\eyein\miniconda3\envs\JenEva\Lib\site-packages\TTS\tts\models\xtts.py", line 657, in inference_stream
    gpt_generator = self.gpt.get_generator(
                    ^^^^^^^^^^^^^^^^^^^^^^^
  File "C:\Users\eyein\miniconda3\envs\JenEva\Lib\site-packages\TTS\tts\layers\xtts\gpt.py", line 602, in get_generator
    return self.gpt_inference.generate_stream(
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "C:\Users\eyein\miniconda3\envs\JenEva\Lib\site-packages\torch\utils\_contextlib.py", line 115, in decorate_context
    return func(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^
  File "C:\Users\eyein\miniconda3\envs\JenEva\Lib\site-packages\TTS\tts\layers\xtts\stream_generator.py", line 117, in generate
    - [~generation.BeamSampleDecoderOnlyOutput]
                             ^^^^^^^^^^^^^^^^^^^^
  File "C:\Users\eyein\miniconda3\envs\JenEva\Lib\site-packages\transformers\generation\utils.py", line 489, in _prepare_attention_mask_for_generation
    torch.isin(elements=inputs, test_elements=pad_token_id).any()
    ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
TypeError: isin() received an invalid combination of arguments - got (elements=Tensor, test_elements=int, ), but expected one of:
 * (Tensor elements, Tensor test_elements, *, bool assume_unique, bool invert, Tensor out)
 * (Number element, Tensor test_elements, *, bool assume_unique, bool invert, Tensor out)
 * (Tensor elements, Number test_element, *, bool assume_unique, bool invert, Tensor out)

ERROR: None
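The last frame shows the actual failure: the `pad_token_id` that reaches `torch.isin` is a plain Python `int` (the error reports `test_elements=int`), and none of the listed overloads accept an int under the keyword `test_elements`. A minimal sketch of the mismatch outside of XTTS, assuming only PyTorch is installed; the token values are placeholders, not real XTTS ids. It illustrates the error only and is not the fix that was eventually merged.

```python
import torch

# Placeholder values for illustration only; not real XTTS token ids.
inputs = torch.tensor([[1, 1, 1, 5]])
pad_token_id = 5

# Fails as in the traceback: an int is passed under the keyword `test_elements`.
# torch.isin(elements=inputs, test_elements=pad_token_id)

# Works: wrap the id in a 0-dim tensor on the same device as the inputs,
# which matches the (Tensor elements, Tensor test_elements) overload.
pad_tensor = torch.tensor(pad_token_id, device=inputs.device)
has_pad = bool(torch.isin(elements=inputs, test_elements=pad_tensor).any())
print(has_pad)  # True
```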
@eginhard eginhard added the bug Something isn't working label Jul 31, 2024
@eginhard
Member

Yes, also reported in #59 (comment). The streaming code unfortunately relies heavily on internals of the transformers library, so it can break at any time. The best option would probably be to pin a specific version that works.
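One way to act on the pinning advice is to declare an upper bound in your requirements (for example `transformers<4.43`) and also fail fast at runtime before starting streaming inference. A minimal sketch; the 4.42 ceiling is an assumption drawn from reports in this thread (breakage at 4.43.x, and a later comment saying Coqui works up to about 4.42.4), not an officially tested bound, and `streaming_supported` is a hypothetical helper.

```python
from __future__ import annotations

from importlib.metadata import PackageNotFoundError, version

# Assumed ceiling taken from reports in this thread, not an official bound.
MAX_TESTED = (4, 42)


def major_minor(v: str) -> tuple[int, int]:
    """Return (major, minor) from strings like '4.42.4' or '4.43.0.dev0'."""
    major, minor = v.split(".")[:2]
    return int(major), int(minor)


def streaming_supported(v: str | None = None) -> bool:
    """True if the given (or installed) transformers version is within the assumed range."""
    if v is None:
        try:
            v = version("transformers")
        except PackageNotFoundError:
            return False
    return major_minor(v) <= MAX_TESTED
```

Usage: call `streaming_supported()` once at startup and raise a clear error (or fall back to non-streaming inference) when it returns `False`, instead of hitting the `TypeError` mid-generation.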

Could you share which exact package requires the latest transformers version?

@SiriuslySirius
Author

SiriuslySirius commented Jul 31, 2024

It's not a package as such: the latest version of Meta's LLaMA, LLaMA 3.1, runs on the transformers library, and it is recommended to use the latest version of transformers with it. Right now I am running the latest transformers version that Coqui TTS allows, which works fine, but I get a lot of deprecation warnings from the transformers library.

@SiriuslySirius
Author

Yeah, I'm currently trying out Google's Gemma 2 LLM, and this is going to be an issue for anyone doing LLM + XTTS. Gemma 2 requires a newer version of transformers because version 4.40.2 doesn't recognize it. So we're left with a choice: be less flexible about which LLMs we can use, or drop XTTS completely.

@eginhard
Member

eginhard commented Aug 1, 2024

It would be helpful if you shared what package/repo/code you're running, so that we're aware of how Coqui is used and how it is affected by external changes. But for this kind of use case, the best solution is probably to put the TTS and the LLM into separate environments, so that their dependencies don't affect each other.
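A minimal sketch of the separate-environment approach, assuming the TTS side runs under a different Python interpreter (for example a virtualenv with an older transformers). Keeping one long-lived worker process and exchanging line-delimited JSON over stdin/stdout means the model only has to be loaded once, which limits the added latency compared with spawning a process per request. `EnvWorker` and `ECHO_WORKER` are hypothetical names; the worker below only echoes its input, whereas a real one would load XTTS before its loop and return audio or a file path.

```python
import json
import subprocess
import sys


class EnvWorker:
    """Long-lived child process running under another Python interpreter."""

    def __init__(self, python_exe: str, worker_code: str) -> None:
        # -u: unbuffered child stdout so replies arrive as soon as they are printed
        self.proc = subprocess.Popen(
            [python_exe, "-u", "-c", worker_code],
            stdin=subprocess.PIPE,
            stdout=subprocess.PIPE,
            text=True,
        )

    def request(self, payload: dict) -> dict:
        """Send one JSON line and block until the worker answers with one JSON line."""
        assert self.proc.stdin is not None and self.proc.stdout is not None
        self.proc.stdin.write(json.dumps(payload) + "\n")
        self.proc.stdin.flush()
        line = self.proc.stdout.readline()
        if not line:
            raise RuntimeError("worker process exited unexpectedly")
        return json.loads(line)

    def close(self) -> None:
        # Closing stdin ends the worker's read loop, so the child exits normally
        if self.proc.stdin is not None:
            self.proc.stdin.close()
        self.proc.wait(timeout=10)


# Hypothetical stand-in for a TTS worker: a real one would load XTTS once here,
# before the loop, then synthesize the text of each request.
ECHO_WORKER = """
import json, sys
for line in sys.stdin:
    req = json.loads(line)
    print(json.dumps({"text": req["text"], "ok": True}), flush=True)
"""

if __name__ == "__main__":
    # Replace sys.executable with the path to the TTS environment's interpreter.
    worker = EnvWorker(sys.executable, ECHO_WORKER)
    print(worker.request({"text": "Hello world!"}))
    worker.close()
```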

@SiriuslySirius
Author

SiriuslySirius commented Aug 1, 2024

For my current use case, I am using Nextcord for my Discord bot, and I have TTS and the LLM running in the same "cog" (a way of isolating bot features into their own groups for the sake of modularity). Separating XTTS from my LLM would require an architectural change to my private codebase and would add latency between the two modules, which is not ideal for a real-time application. Everything runs locally on my machine.

The issue is mainly incompatibility between the versions of transformers required to run newer local open-source LLMs and XTTS.

I'm using inference streaming normally by passing text into the text input parameter as written in the docs for XTTS V2.

@SiriuslySirius
Author

SiriuslySirius commented Aug 3, 2024

Regarding #59: I tried the patch (https://github.com/h2oai/h2ogpt/blob/52923ac21a1532983c72b45a8e0785f6689dc770/docs/xtt.patch) mentioned in that thread, and it worked.

@timwillhack

Just throwing this in here because I ran into another set of models that relies on transformers 4.43: Microsoft Phi-3.5-mini-instruct, which is apparently very decent for how small it is. I spent a day trying to get GPT-4o to help me make Coqui streaming work with transformers 4.43, and it did: I got it to output voice from text! But it added changes that caused my VRAM usage to spike, and I'm not familiar enough with neural-net code to figure out what it did wrong. Python is also not my strong suit!

@SiriuslySirius
Copy link
Author

It would help to see your streaming implementation to check whether it's the problem. It could also be the LLM: if you are running it locally, some LLMs spike in VRAM usage as you use them, especially if you feed them context such as a chat history.

@timwillhack

I'm just using the xtts/stream_generator.py script. I haven't tried to use Phi-3.5 because it relies on transformers 4.43, but Coqui only works up to 4.42.4 or so right now. When I ran the GPT-modified script (with transformers 4.43 installed), no other models were loaded, so I'm guessing the VRAM spike was caused by the changes it made. The code was pretty ugly, to be honest.

@ajkessel

ajkessel commented Oct 8, 2024

+1 for this

There is some transformers code that breaks on the Mac M1 family, specifically this:

        if inputs.device.type == "mps":
            # mps does not support torch.isin (https://github.com/pytorch/pytorch/issues/77764)
            raise ValueError(
                "Can't infer missing attention mask on `mps` device. Please provide an `attention_mask` or use a different device."
            )

This appears to be fixed in more recent transformers releases but can't be leveraged by coqui-ai-tts due to incompatibility.

@DrewThomasson

I would also greatly appreciate the ability for mps Apple Silicon speedup on xtts inference 🥺

@eginhard
Copy link
Member

Thanks a lot to @JohnnyStreet for submitting a fix for this, which I just merged into the dev branch. Feel free to test this already, I'll wait a bit before releasing a new version with this. Also let me know if there are still breakages using MPS, I don't have a Mac to test.

@DrewThomasson

DrewThomasson commented Oct 23, 2024

@eginhard

Running on mps Still Results in BREAKAGE: 😞 ----->

🐍 python version = Python 3.12.7

📦 pip installed latest dev using the command:

pip install git+https://github.com/idiap/coqui-ai-TTS.git@dev

Hardware Used: 💻

Hardware Overview:
  Model Name: MacBook Pro
  Model Identifier: MacBookPro18,1
  Model Number: MK193LL/A
  Chip: Apple M1 Pro
  Total Number of Cores: 10 (8 performance and 2 efficiency)
  Memory: 16 GB
  System Firmware Version: 10151.101.3
  OS Loader Version: 10151.101.3
  Activation Lock Status: Disabled

Code chunk used to test: 👨‍💻

import os
import torch
from TTS.api import TTS
     
# Ensure you're using the MPS device
device = torch.device("mps" if torch.backends.mps.is_available() else "cpu")

# Print the selected device
print(f"device selected is {device}")

# Initialize the TTS model on the appropriate device
tts = TTS("tts_models/multilingual/multi-dataset/xtts_v2").to(device)

# Run TTS and save output to a file
tts.tts_to_file(text="Hello world!", speaker_wav="ref.wav", language="en", file_path="output.wav")

Error result from running code block above ⬆️ : ❌ 🚫 🌧️

(newtts_test) drew@wmughal-CN4D09397T Desktop % python --version
Python 3.12.7
(newtts_test) drew@wmughal-CN4D09397T Desktop % pip show coqui-tts
Name: coqui-tts
Version: 0.24.2
Summary: Deep learning for Text to Speech.
Home-page: https://github.com/idiap/coqui-ai-TTS
Author: 
Author-email: Eren Gölge <[email protected]>
License: MPL-2.0
Location: /Users/drew/miniconda3/envs/newtts_test/lib/python3.12/site-packages
Requires: anyascii, coqpit, coqui-tts-trainer, cython, einops, encodec, fsspec, gruut, inflect, librosa, matplotlib, num2words, numpy, packaging, pysbd, pyyaml, scipy, soundfile, spacy, torch, torchaudio, tqdm, transformers
Required-by: 
(newtts_test) drew@wmughal-CN4D09397T Desktop % python test.py
device selected is mps
Traceback (most recent call last):
  File "/Users/drew/Desktop/test.py", line 15, in <module>
    tts.tts_to_file(text="Hello world!", speaker_wav="ref.wav", language="en", file_path="output.wav")
  File "/Users/drew/miniconda3/envs/newtts_test/lib/python3.12/site-packages/TTS/api.py", line 334, in tts_to_file
    wav = self.tts(
          ^^^^^^^^^
  File "/Users/drew/miniconda3/envs/newtts_test/lib/python3.12/site-packages/TTS/api.py", line 276, in tts
    wav = self.synthesizer.tts(
          ^^^^^^^^^^^^^^^^^^^^^
  File "/Users/drew/miniconda3/envs/newtts_test/lib/python3.12/site-packages/TTS/utils/synthesizer.py", line 386, in tts
    outputs = self.tts_model.synthesize(
              ^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/drew/miniconda3/envs/newtts_test/lib/python3.12/site-packages/TTS/tts/models/xtts.py", line 425, in synthesize
    return self.full_inference(text, speaker_wav, language, **settings)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/drew/miniconda3/envs/newtts_test/lib/python3.12/site-packages/torch/utils/_contextlib.py", line 116, in decorate_context
    return func(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^
  File "/Users/drew/miniconda3/envs/newtts_test/lib/python3.12/site-packages/TTS/tts/models/xtts.py", line 486, in full_inference
    (gpt_cond_latent, speaker_embedding) = self.get_conditioning_latents(
                                           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/drew/miniconda3/envs/newtts_test/lib/python3.12/site-packages/torch/utils/_contextlib.py", line 116, in decorate_context
    return func(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^
  File "/Users/drew/miniconda3/envs/newtts_test/lib/python3.12/site-packages/TTS/tts/models/xtts.py", line 369, in get_conditioning_latents
    speaker_embedding = self.get_speaker_embedding(audio, load_sr)
                        ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/drew/miniconda3/envs/newtts_test/lib/python3.12/site-packages/torch/utils/_contextlib.py", line 116, in decorate_context
    return func(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^
  File "/Users/drew/miniconda3/envs/newtts_test/lib/python3.12/site-packages/TTS/tts/models/xtts.py", line 324, in get_speaker_embedding
    self.hifigan_decoder.speaker_encoder.forward(audio_16k.to(self.device), l2_norm=True)
  File "/Users/drew/miniconda3/envs/newtts_test/lib/python3.12/site-packages/TTS/tts/layers/xtts/hifigan_decoder.py", line 539, in forward
    x = self.torch_spec(x)
        ^^^^^^^^^^^^^^^^^^
  File "/Users/drew/miniconda3/envs/newtts_test/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1736, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/drew/miniconda3/envs/newtts_test/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1747, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/drew/miniconda3/envs/newtts_test/lib/python3.12/site-packages/torch/nn/modules/container.py", line 250, in forward
    input = module(input)
            ^^^^^^^^^^^^^
  File "/Users/drew/miniconda3/envs/newtts_test/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1736, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/drew/miniconda3/envs/newtts_test/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1747, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/drew/miniconda3/envs/newtts_test/lib/python3.12/site-packages/TTS/tts/layers/xtts/hifigan_decoder.py", line 419, in forward
    return torch.nn.functional.conv1d(x, self.filter).squeeze(1)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
NotImplementedError: Output channels > 65536 not supported at the MPS device. As a temporary fix, you can set the environment variable `PYTORCH_ENABLE_MPS_FALLBACK=1` to use the CPU as a fallback for this op. WARNING: this will be slower than running natively on MPS.
(newtts_test) drew@wmughal-CN4D09397T Desktop % 

@eginhard
Member

NotImplementedError: Output channels > 65536 not supported at the MPS device. As a temporary fix, you can set the environment variable PYTORCH_ENABLE_MPS_FALLBACK=1 to use the CPU as a fallback for this op. WARNING: this will be slower than running natively on MPS.

This looks like a limitation in PyTorch that will hopefully be fixed in future versions. You can set that environment variable to avoid the error.
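A minimal sketch of setting the variable from Python instead of the shell, assuming it has to be present before `torch` is first imported (exporting it in the shell before launching Python remains the safest option). `enable_mps_fallback` is a hypothetical helper. Note that the fallback only covers operators PyTorch actually routes to the CPU; later in this thread the same `conv1d` error is still raised with the variable set.

```python
import os
import sys
import warnings


def enable_mps_fallback() -> None:
    """Set PYTORCH_ENABLE_MPS_FALLBACK=1 unless the user already set it.

    Assumes the variable must be in place before torch is imported,
    so warn if torch has already been loaded in this process.
    """
    if "torch" in sys.modules:
        warnings.warn(
            "torch is already imported; PYTORCH_ENABLE_MPS_FALLBACK may not take effect. "
            "Export it in the shell instead.",
            RuntimeWarning,
        )
    os.environ.setdefault("PYTORCH_ENABLE_MPS_FALLBACK", "1")


enable_mps_fallback()
# Import torch and TTS only after the call above, e.g.:
# import torch
# from TTS.api import TTS
```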

@JohnnyStreet

JohnnyStreet commented Oct 29, 2024

I posted the original PR and I am not convinced this is a limitation of PyTorch, but I don't have an MPS device available to debug it. This is a very hacky shot in the dark, but as a workaround you might try installing accelerate and then doing:

from accelerate import Accelerator
accelerator = Accelerator()

and then reference accelerator.device as your device. I would at least try that even though it might make zero difference.

@DrewThomasson

DrewThomasson commented Oct 29, 2024

@eginhard
@JohnnyStreet

Updated testing and still: Running on mps Still Results in BREAKAGE: 😞 ----->

🐍 python version = Python 3.12.7

📦 pip installed latest dev using the command, modified per @JohnnyStreet's accelerate advice:

pip install git+https://github.com/idiap/coqui-ai-TTS.git@dev accelerate

🥶 pip freeze of my python env can be seen here ⬇️

my_environment_packages.txt

Hardware Used: 💻

Hardware Overview:
  Model Name: MacBook Pro
  Model Identifier: MacBookPro18,1
  Model Number: MK193LL/A
  Chip: Apple M1 Pro
  Total Number of Cores: 10 (8 performance and 2 efficiency)
  Memory: 16 GB
  System Firmware Version: 10151.101.3
  OS Loader Version: 10151.101.3
  Activation Lock Status: Disabled

🏃 Ran using this bash script

export PYTORCH_ENABLE_MPS_FALLBACK=1  # as @eginhard suggested
python test.py

Code chunk used to test: 👨‍💻

import os
import sys

# Ensure the environment variable is set before importing any torch-related modules as @eginhard suggested
os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1"

# Restart the script with the correct environment (only runs once)
if "restarted" not in os.environ:
    os.environ["restarted"] = "1"
    os.execv(sys.executable, [sys.executable] + sys.argv)

# Now import the required libraries
import torch
import time
from TTS.api import TTS
from accelerate import Accelerator

def tts_generate(tts, device, text, output_file):
    print(f"Generating on {device}...")
    start_time = time.time()

    tts.tts_to_file(
        text=text,
        speaker_wav="ref.wav",
        language="en",
        file_path=output_file,
    )

    elapsed_time = time.time() - start_time
    print(f"Generation time on {device}: {elapsed_time:.2f} seconds")
    return elapsed_time

# Initialize TTS model
tts = TTS("tts_models/multilingual/multi-dataset/xtts_v2")

# Test on CPU
cpu_device = torch.device("cpu")
tts.to(cpu_device)
cpu_time = tts_generate(tts, "CPU", "Hello world!", "output_cpu.wav")

# Test with `accelerate` (using MPS or CPU fallback) as suggested by @JohnnyStreet 
accelerator = Accelerator()
optimal_device = accelerator.device
tts.to(optimal_device)
accelerate_time = tts_generate(tts, optimal_device, "Hello world!", "output_accelerate.wav")

# Print comparison results
print(f"\nTime difference: {cpu_time - accelerate_time:.2f} seconds")

Error result from running code block above ⬆️ : ❌ 🚫 🌧️

(newtts_test) drew@wmughal-CN4D09397T Downloads % export PYTORCH_ENABLE_MPS_FALLBACK=1
python test.py

Generating on CPU...
The attention mask is not set and cannot be inferred from input because pad token is same as eos token.As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.
Generation time on CPU: 2.21 seconds
Generating on mps...
Traceback (most recent call last):
  File "/Users/drew/Downloads/test.py", line 45, in <module>
    accelerate_time = tts_generate(tts, optimal_device, "Hello world!", "output_accelerate.wav")
                      ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/drew/Downloads/test.py", line 22, in tts_generate
    tts.tts_to_file(
  File "/Users/drew/miniconda3/envs/newtts_test/lib/python3.12/site-packages/TTS/api.py", line 334, in tts_to_file
    wav = self.tts(
          ^^^^^^^^^
  File "/Users/drew/miniconda3/envs/newtts_test/lib/python3.12/site-packages/TTS/api.py", line 276, in tts
    wav = self.synthesizer.tts(
          ^^^^^^^^^^^^^^^^^^^^^
  File "/Users/drew/miniconda3/envs/newtts_test/lib/python3.12/site-packages/TTS/utils/synthesizer.py", line 386, in tts
    outputs = self.tts_model.synthesize(
              ^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/drew/miniconda3/envs/newtts_test/lib/python3.12/site-packages/TTS/tts/models/xtts.py", line 425, in synthesize
    return self.full_inference(text, speaker_wav, language, **settings)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/drew/miniconda3/envs/newtts_test/lib/python3.12/site-packages/torch/utils/_contextlib.py", line 116, in decorate_context
    return func(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^
  File "/Users/drew/miniconda3/envs/newtts_test/lib/python3.12/site-packages/TTS/tts/models/xtts.py", line 486, in full_inference
    (gpt_cond_latent, speaker_embedding) = self.get_conditioning_latents(
                                           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/drew/miniconda3/envs/newtts_test/lib/python3.12/site-packages/torch/utils/_contextlib.py", line 116, in decorate_context
    return func(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^
  File "/Users/drew/miniconda3/envs/newtts_test/lib/python3.12/site-packages/TTS/tts/models/xtts.py", line 369, in get_conditioning_latents
    speaker_embedding = self.get_speaker_embedding(audio, load_sr)
                        ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/drew/miniconda3/envs/newtts_test/lib/python3.12/site-packages/torch/utils/_contextlib.py", line 116, in decorate_context
    return func(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^
  File "/Users/drew/miniconda3/envs/newtts_test/lib/python3.12/site-packages/TTS/tts/models/xtts.py", line 324, in get_speaker_embedding
    self.hifigan_decoder.speaker_encoder.forward(audio_16k.to(self.device), l2_norm=True)
  File "/Users/drew/miniconda3/envs/newtts_test/lib/python3.12/site-packages/TTS/tts/layers/xtts/hifigan_decoder.py", line 539, in forward
    x = self.torch_spec(x)
        ^^^^^^^^^^^^^^^^^^
  File "/Users/drew/miniconda3/envs/newtts_test/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1736, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/drew/miniconda3/envs/newtts_test/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1747, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/drew/miniconda3/envs/newtts_test/lib/python3.12/site-packages/torch/nn/modules/container.py", line 250, in forward
    input = module(input)
            ^^^^^^^^^^^^^
  File "/Users/drew/miniconda3/envs/newtts_test/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1736, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/drew/miniconda3/envs/newtts_test/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1747, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/drew/miniconda3/envs/newtts_test/lib/python3.12/site-packages/TTS/tts/layers/xtts/hifigan_decoder.py", line 419, in forward
    return torch.nn.functional.conv1d(x, self.filter).squeeze(1)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
NotImplementedError: Output channels > 65536 not supported at the MPS device. As a temporary fix, you can set the environment variable `PYTORCH_ENABLE_MPS_FALLBACK=1` to use the CPU as a fallback for this op. WARNING: this will be slower than running natively on MPS.
(newtts_test) drew@wmughal-CN4D09397T Downloads % 

@DrewThomasson

DrewThomasson commented Oct 29, 2024

@eginhard
@JohnnyStreet

I attempted to fix the issue myself by moving the convolution operation in xtts.py to the CPU and adding an explicit attention mask in gpt.py to avoid the MPS errors, but found that the CPU outperformed MPS in inference speed.

I'm probably doing something wrong lol. 😞

Results in testing: Using testing env from ➡️ here

(newtts_test) drew@wmughal-CN4D09397T Downloads % export PYTORCH_ENABLE_MPS_FALLBACK=1
python test.py

Generating on CPU...
Generation time on CPU: 2.40 seconds
Generating on mps...
Generation time on mps: 5.57 seconds

Time difference: -3.17 seconds

Overview of all steps I went through and modifications made to which files ⬇️

Overview

I encountered performance issues and errors when trying to run Coqui TTS on MPS (Apple Silicon) with PyTorch on my MacBook Pro (M1). Specifically, I ran into a couple of major problems when attempting to use MPS, and I wanted to document the modifications I made to work around these issues.


Problem Description

  1. MPS-specific Error:

    • The convolution operation (conv1d) in the HiFi-GAN decoder caused a NotImplementedError because MPS doesn’t support output channels greater than 65,536.
    • Error Message:
      NotImplementedError: Output channels > 65536 not supported at the MPS device.
      
  2. Missing Attention Mask Error:

    • The TTS model required an attention mask when running on MPS. Without it, I encountered the following error:
      ValueError: Can't infer missing attention mask on `mps` device. Please provide an `attention_mask` or use a different device.
      

Modifications I Made

  1. Fix for Convolution Operation:

    • I modified the get_speaker_embedding function in xtts.py to move both input tensors and model weights to the CPU for the problematic operation.

    File Path:
    /Users/drew/miniconda3/envs/newtts_test/lib/python3.12/site-packages/TTS/tts/models/xtts.py

    Modified Code:

    def get_speaker_embedding(self, audio, sr):
        audio_16k = torchaudio.functional.resample(audio, sr, 16000)
    
        # Move both input and weights to CPU to avoid MPS issues
        audio_16k = audio_16k.to("cpu")
        self.hifigan_decoder.speaker_encoder.to("cpu")
    
        # Perform the operation on CPU
        speaker_embedding = (
            self.hifigan_decoder.speaker_encoder.forward(audio_16k, l2_norm=True)
            .unsqueeze(-1)
        )
    
        # Move the result back to the original device
        return speaker_embedding.to(self.device)
  2. Fix for Attention Mask Error:

    • I modified the generate function in gpt.py to explicitly pass an attention mask.

    File Path:
    /Users/drew/miniconda3/envs/newtts_test/lib/python3.12/site-packages/TTS/tts/layers/xtts/gpt.py

    Modified Code:

    def generate(self, cond_latents, text_inputs, **hf_generate_kwargs):
        gpt_inputs = self.compute_embeddings(cond_latents, text_inputs)
    
        # Create an attention mask with ones
        attention_mask = torch.ones_like(gpt_inputs, dtype=torch.long, device=gpt_inputs.device)
    
        # Ensure the attention mask is on the same device as inputs
        attention_mask = attention_mask.to(gpt_inputs.device)
    
        # Pass the attention mask during generation
        gen = self.gpt_inference.generate(
            gpt_inputs,
            bos_token_id=self.start_audio_token,
            pad_token_id=self.stop_audio_token,
            eos_token_id=self.stop_audio_token,
            max_length=self.max_gen_mel_tokens + gpt_inputs.shape[-1],
            attention_mask=attention_mask,
            **hf_generate_kwargs,
        )
    
        if "return_dict_in_generate" in hf_generate_kwargs:
            return gen.sequences[:, gpt_inputs.shape[1]:], gen
    
        return gen[:, gpt_inputs.shape[1]:]
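The `get_speaker_embedding` change above moves `speaker_encoder` to the CPU and leaves it there for the rest of the session. A minimal sketch of a variant that restores the module afterwards, using a hypothetical `temporarily_on` context manager (not part of Coqui TTS). In the patched method one would wrap the forward call in `with temporarily_on(self.hifigan_decoder.speaker_encoder, "cpu"):` and move only the result back to `self.device`. It assumes all of the module's parameters live on one device; a module without parameters is moved but not restored.

```python
import contextlib

import torch


@contextlib.contextmanager
def temporarily_on(module: torch.nn.Module, device):
    """Move `module` to `device` for the duration of the block, then move it back."""
    first_param = next(module.parameters(), None)
    # Assumes all parameters share one device; None if the module has no parameters
    original = first_param.device if first_param is not None else None
    module.to(device)
    try:
        yield module
    finally:
        if original is not None:
            module.to(original)


if __name__ == "__main__":
    layer = torch.nn.Linear(4, 2)
    with temporarily_on(layer, "cpu"):
        out = layer(torch.zeros(1, 4))  # runs on the CPU
    print(out.shape)  # torch.Size([1, 2])
```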

Results

After applying these changes, I was able to run the TTS model on both CPU and MPS. However, the results showed that CPU outperformed MPS in this scenario.

  • CPU Inference Time: 2.40 seconds
  • MPS Inference Time: 5.57 seconds
  • Time Difference: CPU was faster by 3.17 seconds.

Modified Files

my_modified_gpt.py.zip
my_modified_xtts.py.zip

Output Files

output_cpu.wav.zip
output_accelerate.wav.zip

@DrewThomasson

Update: I tried comparing raw MPS, Accelerate MPS, and CPU only. Lol, I got the same results anyway.

Results here ⬇️

Script used 📑
import torch
import time
from TTS.api import TTS
from accelerate import Accelerator

def tts_generate(tts, device, text, output_file):
    """Generate speech and measure inference time."""
    print(f"Generating on {device}...")
    start_time = time.time()

    tts.tts_to_file(
        text=text,
        speaker_wav="ref.wav",
        language="en",
        file_path=output_file,
    )

    elapsed_time = time.time() - start_time
    print(f"Generation time on {device}: {elapsed_time:.2f} seconds")
    return elapsed_time

def main():
    # Initialize the TTS model
    tts = TTS("tts_models/multilingual/multi-dataset/xtts_v2")

    # Test on CPU
    cpu_device = torch.device("cpu")
    tts.to(cpu_device)
    cpu_time = tts_generate(tts, "CPU", "Hello world!", "output_cpu.wav")

    # Test with Raw MPS
    if torch.backends.mps.is_available():
        raw_mps_device = torch.device("mps")
        tts.to(raw_mps_device)
        mps_time = tts_generate(tts, "Raw MPS", "Hello world!", "output_mps.wav")
    else:
        print("MPS is not available on this machine.")
        mps_time = None

    # Test with Accelerate MPS
    accelerator = Accelerator()
    accel_device = accelerator.device
    tts.to(accel_device)
    accelerate_time = tts_generate(tts, f"Accelerate ({accel_device})", "Hello world!", "output_accel_mps.wav")

    # Print comparison results
    print("\n--- Inference Time Comparison ---")
    print(f"CPU Time: {cpu_time:.2f} seconds")
    if mps_time is not None:
        print(f"Raw MPS Time: {mps_time:.2f} seconds")
        print(f"Time difference (CPU vs Raw MPS): {cpu_time - mps_time:.2f} seconds")
    print(f"Accelerate MPS Time: {accelerate_time:.2f} seconds")
    print(f"Time difference (CPU vs Accelerate MPS): {cpu_time - accelerate_time:.2f} seconds")

if __name__ == "__main__":
    main()

Results: 🧪

(newtts_test) drew@wmughal-CN4D09397T Downloads % export PYTORCH_ENABLE_MPS_FALLBACK=1
python test.py

Generating on CPU...
Generation time on CPU: 2.93 seconds
Generating on Raw MPS...
Generation time on Raw MPS: 3.04 seconds
Generating on Accelerate (mps)...
Generation time on Accelerate (mps): 3.16 seconds

--- Inference Time Comparison ---
CPU Time: 2.93 seconds
Raw MPS Time: 3.04 seconds
Time difference (CPU vs Raw MPS): -0.11 seconds
Accelerate MPS Time: 3.16 seconds
Time difference (CPU vs Accelerate MPS): -0.23 seconds
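Each number above comes from a single call per device, so it may include one-off costs (first-call kernel setup, moving weights, loading `ref.wav`). A minimal sketch of a steadier comparison: run untimed warm-up calls first, then report the median of several runs timed with `time.perf_counter`. `benchmark` is a hypothetical helper; in the script above, `fn` would be something like `lambda: tts.tts_to_file(...)` after `tts.to(device)`. Since `tts_to_file` has to bring the audio back to the CPU to write it, the measured time should include the device work.

```python
import statistics
import time
from typing import Callable


def benchmark(fn: Callable[[], object], warmup: int = 1, repeats: int = 5) -> float:
    """Return the median wall-clock time of `fn` in seconds, after `warmup` untimed calls."""
    if repeats < 1:
        raise ValueError("repeats must be at least 1")
    for _ in range(warmup):
        fn()  # untimed: absorbs first-call setup costs
    timings = []
    for _ in range(repeats):
        start = time.perf_counter()
        fn()
        timings.append(time.perf_counter() - start)
    return statistics.median(timings)


if __name__ == "__main__":
    # Stand-in workload; replace with e.g. lambda: tts.tts_to_file(...)
    print(f"{benchmark(lambda: sum(range(100_000))):.4f} s")
```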

Output Files ⬇️

output_files.zip

6 participants