
Commit

fix(gpt): add flash_attention_2 import
fumiama committed Jul 2, 2024
1 parent c109089 commit 7515712
Showing 3 changed files with 25 additions and 18 deletions.
13 changes: 1 addition & 12 deletions ChatTTS/core.py
@@ -295,18 +295,7 @@ def _load(
        gpt = GPT(**cfg, device=device, logger=self.logger).eval()
        assert gpt_ckpt_path, "gpt_ckpt_path should not be None"
        gpt.load_state_dict(torch.load(gpt_ckpt_path, weights_only=True, mmap=True))
        if compile and "cuda" in str(device):
            try:
                gpt.forward = torch.compile(
                    gpt.forward, backend="inductor", dynamic=True
                )
                gpt.gpt.forward = torch.compile(
                    gpt.gpt.forward, backend="inductor", dynamic=True
                )
            except RuntimeError as e:
                self.logger.warning(
                    f"compile failed: {e}. fallback to normal mode."
                )
        gpt.prepare(compile=compile and "cuda" in str(device))
        self.gpt = gpt
        spk_stat_path = os.path.join(os.path.dirname(gpt_ckpt_path), "spk_stat.pt")
        assert os.path.exists(
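The torch.compile fallback removed from `_load` above is not dropped; it moves into the new `GPT.prepare()` method shown in the next file. A minimal sketch of the behaviour being relocated, with `maybe_compile` as a hypothetical helper name (the inductor backend and the dynamic-shapes flag are taken from the removed block):

```python
import torch

def maybe_compile(module: torch.nn.Module, enable: bool, logger=None) -> torch.nn.Module:
    # Hypothetical helper mirroring the logic this commit moves out of Chat._load:
    # compile the forward pass with the inductor backend only when requested,
    # and fall back to eager execution if compilation is rejected.
    if not enable:
        return module
    try:
        module.forward = torch.compile(module.forward, backend="inductor", dynamic=True)
    except RuntimeError as e:
        if logger is not None:
            logger.warning(f"compile failed: {e}. fallback to normal mode.")
    return module
```

After the refactor, `_load` only decides whether compilation should happen (`compile and "cuda" in str(device)`) and leaves the how to the model class via `gpt.prepare()`.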
26 changes: 22 additions & 4 deletions ChatTTS/model/gpt.py
@@ -98,7 +98,9 @@ def get(self) -> bool:
    def _build_llama(
        self, config: omegaconf.DictConfig, device: torch.device
    ) -> LlamaModel:

        model = None

        if "cuda" in str(device) and platform.system().lower() == "linux":
            try:
                from .cuda import TELlamaModel
@@ -109,16 +111,32 @@ def _build_llama(
                self.logger.warn(
                    f"use default LlamaModel for importing TELlamaModel error: {e}"
                )
        if is_flash_attn_2_available():
            llama_config = LlamaConfig(**config, attn_implementation="flash_attention_2")
        else:
            llama_config = LlamaConfig(**config)

        if model is None:
            if is_flash_attn_2_available():
                llama_config = LlamaConfig(
                    **config,
                    attn_implementation="flash_attention_2",
                )
            else:
                llama_config = LlamaConfig(**config)
            model = LlamaModel(llama_config)
            del model.embed_tokens

        return model.to(device)

    def prepare(self, compile=False):
        if is_flash_attn_2_available():
            self.gpt = self.gpt.to(dtype=torch.float16)
        if compile:
            try:
                self.compile(backend="inductor", dynamic=True)
                self.gpt.compile(backend="inductor", dynamic=True)
            except RuntimeError as e:
                self.logger.warning(
                    f"compile failed: {e}. fallback to normal mode."
                )

    def __call__(
        self, input_ids: torch.Tensor, text_mask: torch.Tensor
    ) -> torch.Tensor:
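The `is_flash_attn_2_available()` guard above is what the commit title refers to; the hunk that adds the actual import is collapsed here, but in recent `transformers` releases the helper lives in `transformers.utils`. A self-contained sketch of the same selection pattern, treating that import path as an assumption:

```python
import torch
from transformers import LlamaConfig, LlamaModel
from transformers.utils import is_flash_attn_2_available  # assumed source of the import added by this commit

def build_llama(config: dict, device: torch.device) -> LlamaModel:
    # Use the FlashAttention-2 kernels only when the flash-attn package is installed;
    # otherwise let transformers pick its default attention implementation.
    if is_flash_attn_2_available():
        llama_config = LlamaConfig(**config, attn_implementation="flash_attention_2")
    else:
        llama_config = LlamaConfig(**config)
    model = LlamaModel(llama_config)
    # The diff drops the stock embedding table; ChatTTS supplies its own embeddings.
    del model.embed_tokens
    return model.to(device)
```

Note that `prepare()` also casts the backbone to float16 when the check passes, since the FlashAttention-2 kernels only operate on half-precision tensors.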
4 changes: 2 additions & 2 deletions README.md
@@ -96,8 +96,8 @@ pip install -r requirements.txt
> The installation process is very slow.
> [!Warning]
> The TransformerEngine adaption is currently developing and CANNOT run properly now.
> Only install it in developing purpose.
> The adaptation of TransformerEngine is currently under development and CANNOT run properly now.
> Only install it on developing purpose.
```bash
pip install git+https://github.com/NVIDIA/TransformerEngine.git@stable
