Commit

Merge branch 'dev' into main
ZillaRU authored Jul 28, 2024
2 parents c8ca35c + 06b823b commit 6cbe15f
Showing 20 changed files with 5,413 additions and 49 deletions.
1 change: 1 addition & 0 deletions ChatTTS/config/config.py
@@ -58,6 +58,7 @@ class GPT:
     spk_emb_dim: int = 192
     spk_KL: bool = False
     num_audio_tokens: int = 626
+    num_text_tokens: int = 21178
     num_vq: int = 4


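Note: the new num_text_tokens field plausibly sizes the GPT's text-token embedding table, alongside the existing num_audio_tokens/num_vq fields (an inference from the field names; the embedding construction itself is not part of this diff). A minimal sketch of how such config fields are typically consumed:

import torch.nn as nn

hidden_size = 768          # assumed for illustration; not shown in this hunk
num_text_tokens = 21178    # value added by this commit
num_audio_tokens = 626
num_vq = 4

# one embedding table for text tokens, one per VQ codebook for audio tokens
emb_text = nn.Embedding(num_text_tokens, hidden_size)
emb_code = nn.ModuleList(
    nn.Embedding(num_audio_tokens, hidden_size) for _ in range(num_vq)
)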
124 changes: 103 additions & 21 deletions ChatTTS/core.py
@@ -5,14 +5,12 @@
 from typing import Literal, Optional, List, Tuple, Dict, Union
 from json import load
 from pathlib import Path
-import lzma

 import numpy as np
 import torch
 from vocos import Vocos
 from vocos.pretrained import instantiate_class
 from huggingface_hub import snapshot_download
-import pybase16384 as b14

 from .config import Config
 from .model import DVAE, GPT, gen_logits, Tokenizer
@@ -130,6 +128,7 @@ def load(
         device: Optional[torch.device] = None,
         coef: Optional[torch.Tensor] = None,
         use_flash_attn=False,
+        use_vllm=False,
     ) -> bool:
         download_path = self.download_models(source, force_redownload, custom_path)
         if download_path is None:
@@ -139,6 +138,7 @@ def load(
             compile=compile,
             coef=coef,
             use_flash_attn=use_flash_attn,
+            use_vllm=use_vllm,
             **{
                 k: os.path.join(download_path, v)
                 for k, v in asdict(self.config.path).items()
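For context, a hedged usage sketch of the new flag from the caller's side (Chat and infer are part of ChatTTS's public API, but their exact signatures are not shown in this diff):

import ChatTTS

chat = ChatTTS.Chat()
# use_vllm=True routes GPT decoding through the bundled vLLM-based engine;
# it defaults to False, so existing callers are unaffected.
if chat.load(use_vllm=True):
    wavs = chat.infer(["Hello from the vLLM-backed decoder."])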
@@ -167,7 +167,7 @@ def sample_audio_speaker(self, wav: Union[np.ndarray, torch.Tensor]) -> str:

     @torch.no_grad()
     def _sample_random_speaker(self) -> torch.Tensor:
-        dim: int = self.gpt.gpt.layers[0].mlp.gate_proj.in_features
+        dim: int = self.config.gpt.hidden_size
         spk = (
             torch.randn(dim, device=self.std.device, dtype=self.std.dtype)
             .mul_(self.std)
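Reading the width from config.gpt.hidden_size instead of probing gpt.gpt.layers[0] decouples speaker sampling from the HF module layout, which a vLLM backend does not expose. The tail of this expression is collapsed out of the diff; assuming it ends with .add_(self.mean), the full computation is:

import torch

hidden_size = 768                   # config.gpt.hidden_size (assumed value)
mean = torch.zeros(hidden_size)     # per-dim stats loaded from spk_stat.pt
std = torch.ones(hidden_size)
spk = torch.randn(hidden_size).mul_(std).add_(mean)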
@@ -246,11 +246,13 @@ def _load(
         compile: bool = True,
         coef: Optional[str] = None,
         use_flash_attn=False,
+        use_vllm=False,
     ):
         if device is None:
             device = select_device()
         self.logger.info("use device %s", str(device))
         self.device = device
+        self.device_gpt = device if "mps" not in str(device) else torch.device("cpu")
         self.compile = compile

         feature_extractor = instantiate_class(
@@ -293,13 +295,15 @@ def _load(
         gpt = GPT(
             gpt_config=asdict(self.config.gpt),
             use_flash_attn=use_flash_attn,
+            use_vllm=use_vllm,
             device=device,
             logger=self.logger,
         ).eval()
         assert gpt_ckpt_path, "gpt_ckpt_path should not be None"
         gpt.from_pretrained(gpt_ckpt_path)
         gpt.prepare(compile=compile and "cuda" in str(device))
         self.gpt = gpt
+
         spk_stat_path = os.path.join(os.path.dirname(gpt_ckpt_path), "spk_stat.pt")
         assert os.path.exists(spk_stat_path), f"Missing spk_stat.pt: {spk_stat_path}"
         spk_stat: torch.Tensor = torch.load(
@@ -469,7 +473,7 @@ def _infer_code(
         assert len(text), "text should not be empty"

         if not isinstance(params.temperature, list):
-            temperature = [params.temperature] * gpt.num_vq
+            temperature = [params.temperature] * self.config.gpt.num_vq
         else:
             temperature = params.temperature

@@ -495,11 +499,62 @@ def _infer_code(

         input_ids, attention_mask, text_mask = self.tokenizer.encode(
             text,
-            self.gpt.num_vq,
+            self.config.gpt.num_vq,
             prompt_str=params.spk_smp,
-            device=gpt.device_gpt,
+            device=self.device_gpt,
         )
+        start_idx = input_ids.shape[-2]
+
+        num_code = self.config.gpt.num_audio_tokens - 1
+
+        logits_warpers, logits_processors = gen_logits(
+            num_code=num_code,
+            top_P=params.top_P,
+            top_K=params.top_K,
+            repetition_penalty=params.repetition_penalty,
+        )
+
+        if gpt.is_vllm:
+            from .model.velocity import SamplingParams
+
+            sample_params = SamplingParams(
+                temperature=temperature,
+                max_new_token=params.max_new_token,
+                max_tokens=8192,
+                min_new_token=params.min_new_token,
+                logits_processors=(logits_processors, logits_warpers),
+                eos_token=num_code,
+                infer_text=False,
+                start_idx=start_idx,
+            )
+            input_ids = [i.tolist() for i in input_ids]
+
+            result = gpt.llm.generate(
+                None,
+                sample_params,
+                input_ids,
+            )
+
+            token_ids = []
+            hidden_states = []
+            for i in result:
+                token_ids.append(torch.tensor(i.outputs[0].token_ids))
+                hidden_states.append(
+                    i.outputs[0].hidden_states.to(torch.float32).to(self.device)
+                )
+
+            del text_mask, input_ids
+            del_all(logits_warpers)
+            del_all(logits_processors)
+
+            return [
+                GPT.GenerationOutputs(
+                    ids=token_ids,
+                    hiddens=hidden_states,
+                    attentions=[],
+                ),
+            ]

         emb = gpt(input_ids, text_mask)

         del text_mask
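One detail worth flagging: the vLLM branch hands over the two chains as a pair (logits_processors, logits_warpers), while the native path below flattens them into a single tuple. A generic sketch of how a flattened chain is applied per decoding step (the standard sampling-loop convention, not code from this diff):

import torch

def apply_chain(processors, input_ids: torch.Tensor, logits: torch.Tensor) -> torch.Tensor:
    # each processor rewrites the next-token distribution in order,
    # e.g. repetition penalty first, then top-K/top-P warping
    for proc in processors:
        logits = proc(input_ids, logits)
    return logits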
@@ -509,15 +564,6 @@ def _infer_code(
             emb, params.spk_emb, input_ids, self.gpt.device_gpt
         )

-        num_code = int(gpt.emb_code[0].num_embeddings - 1)
-
-        logits_warpers, logits_processors = gen_logits(
-            num_code=num_code,
-            top_P=params.top_P,
-            top_K=params.top_K,
-            repetition_penalty=params.repetition_penalty,
-        )
-
         result = gpt.generate(
             emb,
             input_ids,
@@ -526,8 +572,7 @@ def _infer_code(
             attention_mask=attention_mask,
             max_new_token=params.max_new_token,
             min_new_token=params.min_new_token,
-            logits_warpers=logits_warpers,
-            logits_processors=logits_processors,
+            logits_processors=(*logits_processors, *logits_warpers),
             infer_text=False,
             return_hidden=return_hidden,
             stream=stream,
@@ -560,8 +605,8 @@ def _refine_text(

         input_ids, attention_mask, text_mask = self.tokenizer.encode(
             text,
-            self.gpt.num_vq,
-            device=gpt.device_gpt,
+            self.config.gpt.num_vq,
+            device=self.device_gpt,
         )

         logits_warpers, logits_processors = gen_logits(
@@ -571,6 +616,44 @@ def _refine_text(
             repetition_penalty=params.repetition_penalty,
         )

+        if gpt.is_vllm:
+            from .model.velocity import SamplingParams
+
+            sample_params = SamplingParams(
+                repetition_penalty=params.repetition_penalty,
+                temperature=params.temperature,
+                top_p=params.top_P,
+                top_k=params.top_K,
+                max_new_token=params.max_new_token,
+                max_tokens=8192,
+                min_new_token=params.min_new_token,
+                logits_processors=(logits_processors, logits_warpers),
+                eos_token=self.tokenizer.eos_token,
+                infer_text=True,
+                start_idx=input_ids.shape[-2],
+            )
+            input_ids_list = [i.tolist() for i in input_ids]
+            del input_ids
+
+            result = gpt.llm.generate(
+                None, sample_params, input_ids_list, params.show_tqdm
+            )
+            token_ids = []
+            hidden_states = []
+            for i in result:
+                token_ids.append(torch.tensor(i.outputs[0].token_ids))
+                hidden_states.append(i.outputs[0].hidden_states)
+
+            del text_mask, input_ids_list, result
+            del_all(logits_warpers)
+            del_all(logits_processors)
+
+            return GPT.GenerationOutputs(
+                ids=token_ids,
+                hiddens=hidden_states,
+                attentions=[],
+            )
+
         emb = gpt(input_ids, text_mask)

         del text_mask
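Taken together, the two vLLM call sites above pin down the interface that .model.velocity.SamplingParams must offer. An illustrative reconstruction, inferred from usage only (the field defaults here are assumptions, not taken from this diff):

from dataclasses import dataclass
from typing import List, Optional, Tuple, Union

@dataclass
class SamplingParamsSketch:
    temperature: Union[float, List[float]] = 1.0   # list form for per-VQ temps
    top_p: float = 1.0
    top_k: int = -1
    repetition_penalty: float = 1.0
    max_tokens: int = 8192
    max_new_token: int = 2048
    min_new_token: int = 0
    logits_processors: Optional[Tuple[list, list]] = None
    eos_token: int = 0
    infer_text: bool = False
    start_idx: int = 0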
@@ -584,8 +667,7 @@ def _refine_text(
             attention_mask=attention_mask,
             max_new_token=params.max_new_token,
             min_new_token=params.min_new_token,
-            logits_warpers=logits_warpers,
-            logits_processors=logits_processors,
+            logits_processors=(*logits_processors, *logits_warpers),
             infer_text=True,
             stream=False,
             show_tqdm=params.show_tqdm,