Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: use str type spk_emb for easy recovery #463

Merged
merged 1 commit into from
Jun 26, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
40 changes: 33 additions & 7 deletions ChatTTS/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,13 +5,15 @@
from typing import Literal, Optional, List, Callable, Tuple, Dict
from json import load
from pathlib import Path
import lzma

import numpy as np
import torch
import torch.nn.functional as F
from omegaconf import OmegaConf
from vocos import Vocos
from huggingface_hub import snapshot_download
import pybase16384 as b14

from .model import DVAE, GPT, gen_logits
from .utils import (
Expand Down Expand Up @@ -151,10 +153,28 @@ def unload(self):
delattr(self, module)
self.__init__(logger)

def sample_random_speaker(self):
    """Draw a random speaker embedding from the pretrained speaker statistics.

    The embedding width is read from the input size of the first GPT
    layer's ``gate_proj``. ``pretrain_models["spk_stat"]`` is split into two
    chunks (std first, then mean), and standard-normal noise of that width
    is scaled by the std and shifted by the mean.

    Returns:
        The sampled embedding tensor, on the same device as the statistics.
    """
    hidden_size = self.gpt.gpt.layers[0].mlp.gate_proj.in_features
    spk_stat = self.pretrain_models["spk_stat"]
    std, mean = spk_stat.chunk(2)
    noise = torch.randn(hidden_size, device=std.device)
    return noise * std + mean
def sample_random_speaker(self) -> str:
    """Sample a random speaker embedding and return it as a printable string.

    The tensor produced by ``_sample_random_speaker`` is copied to the CPU,
    its raw bytes are compressed as a raw LZMA2 stream, and the compressed
    bytes are encoded to text with ``pybase16384`` so the speaker can be
    saved and passed back later as ``spk_emb``.

    Returns:
        The encoded speaker embedding.
    """
    # FORMAT_RAW streams carry no header, so the decoder must be given this
    # exact filter chain again to decompress the data.
    lzma2_filters = [{"id": lzma.FILTER_LZMA2, "preset": 9 | lzma.PRESET_EXTREME}]
    with torch.no_grad():
        embedding = self._sample_random_speaker()
        raw: bytes = embedding.cpu().numpy().tobytes()
    packed = lzma.compress(raw, format=lzma.FORMAT_RAW, filters=lzma2_filters)
    return b14.encode_to_string(packed)

def _sample_random_speaker(self) -> torch.Tensor:
    """Sample a half-precision speaker embedding tensor.

    Standard-normal noise is drawn in ``float16`` on the device of
    ``pretrain_models["spk_stat"]``, with a width taken from the input size
    of the first GPT layer's ``gate_proj``. It is then scaled in place by the
    first chunk of the statistics (std) and shifted by the second (mean).

    Returns:
        The sampled ``float16`` embedding; no autograd graph is recorded.
    """
    with torch.no_grad():
        hidden_size: int = self.gpt.gpt.layers[0].mlp.gate_proj.in_features
        std, mean = self.pretrain_models["spk_stat"].chunk(2)
        sample = torch.randn(hidden_size, device=std.device, dtype=torch.float16)
        # In-place ops keep the result in float16 whatever the stats' dtype is.
        sample.mul_(std)
        sample.add_(mean)
    return sample

@dataclass(repr=False, eq=False)
class RefineTextParams:
Expand All @@ -169,7 +189,7 @@ class RefineTextParams:
@dataclass(repr=False, eq=False)
class InferCodeParams:
prompt: str = "[speed_5]"
spk_emb: Optional[torch.Tensor] = None
spk_emb: Optional[str] = None
top_P: float = 0.7
top_K: int = 20
temperature: float = 0.3
Expand Down Expand Up @@ -426,12 +446,18 @@ def _text_to_token(self, text: str, device="cpu") -> Tuple[torch.Tensor, torch.T
def _apply_spk_emb(
self,
emb: torch.Tensor,
spk_emb: torch.Tensor,
spk_emb: str,
input_ids: torch.Tensor,
text_len: int,
):
n = F.normalize(
spk_emb.unsqueeze(0).expand(text_len, -1), p=2.0, dim=1, eps=1e-12
torch.from_numpy(
np.frombuffer(lzma.decompress(
b14.decode_from_string(spk_emb),
format=lzma.FORMAT_RAW,
filters=[{"id": lzma.FILTER_LZMA2, "preset": 9 | lzma.PRESET_EXTREME}],
), dtype=np.float16).copy(),
).unsqueeze(0).expand(text_len, -1), p=2.0, dim=1, eps=1e-12
).to(self.gpt.device_gpt).expand(emb.shape)
cond = input_ids.narrow(-1, 0, 1).eq(self.tokenizer_spk_emb_ids).expand(emb.shape)
torch.where(cond, n, emb, out=emb)
Expand Down
1 change: 1 addition & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -131,6 +131,7 @@ torchaudio.save("output1.wav", torch.from_numpy(wavs[0]), 24000)
# Sample a speaker from Gaussian.

rand_spk = chat.sample_random_speaker()
print(rand_spk) # save it for later timbre recovery

params_infer_code = ChatTTS.Chat.InferCodeParams(
spk_emb = rand_spk, # add sampled speaker
Expand Down
1 change: 1 addition & 0 deletions docs/cn/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -142,6 +142,7 @@ torchaudio.save("output1.wav", torch.from_numpy(wavs[0]), 24000)
# Sample a speaker from Gaussian.

rand_spk = chat.sample_random_speaker()
print(rand_spk) # save it for later timbre recovery

params_infer_code = {
'spk_emb': rand_spk, # add sampled speaker
Expand Down
1 change: 1 addition & 0 deletions docs/es/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -139,6 +139,7 @@ torchaudio.save("output1.wav", torch.from_numpy(wavs[0]), 24000)
# Sample a speaker from Gaussian.

rand_spk = chat.sample_random_speaker()
print(rand_spk) # save it for later timbre recovery

params_infer_code = ChatTTS.Chat.InferCodeParams(
spk_emb = rand_spk, # add sampled speaker
Expand Down
1 change: 1 addition & 0 deletions docs/jp/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,7 @@ torchaudio.save("output1.wav", torch.from_numpy(wavs[0]), 24000)
# ガウス分布から話者をサンプリングします。

rand_spk = chat.sample_random_speaker()
print(rand_spk) # save it for later timbre recovery

params_infer_code = {
'spk_emb': rand_spk, # サンプリングされた話者を追加
Expand Down
1 change: 1 addition & 0 deletions docs/ru/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,7 @@ torchaudio.save("output1.wav", torch.from_numpy(wavs[0]), 24000)
# Выборка говорящего из Гауссиана.

rand_spk = chat.sample_random_speaker()
print(rand_spk) # save it for later timbre recovery

params_infer_code = {
'spk_emb': rand_spk, # добавить выбранного говорящего
Expand Down
2 changes: 2 additions & 0 deletions examples/ipynb/colab.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -304,6 +304,8 @@
"outputs": [],
"source": [
"rand_spk = chat.sample_random_speaker()\n",
"print(rand_spk) # save it for later timbre recovery\n",
"\n",
"params_infer_code = ChatTTS.Chat.InferCodeParams(\n",
" spk_emb=rand_spk,\n",
")\n",
Expand Down
2 changes: 2 additions & 0 deletions examples/ipynb/example.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -247,6 +247,8 @@
"outputs": [],
"source": [
"rand_spk = chat.sample_random_speaker()\n",
"print(rand_spk) # save it for later timbre recovery\n",
"\n",
"params_infer_code = ChatTTS.Chat.InferCodeParams(\n",
" spk_emb=rand_spk,\n",
")\n",
Expand Down
Loading