"""Built-in default configuration for the ChatTTS models.

Each section is a plain dataclass so that it can be turned into a ``dict``
with ``dataclasses.asdict`` and handed to the model constructors.

Nested sections are created through ``field(default_factory=...)`` so that
every ``Config()`` owns its own sub-objects; a plain instance default would
be created once at class-definition time and silently shared (and mutated)
across all ``Config`` instances.
"""

from dataclasses import dataclass, field
from typing import Tuple


@dataclass(repr=False, eq=False)
class Path:
    """Default relative locations of the checkpoint files."""

    vocos_ckpt_path: str = "asset/Vocos.pt"
    dvae_ckpt_path: str = "asset/DVAE.pt"
    gpt_ckpt_path: str = "asset/GPT.pt"
    decoder_ckpt_path: str = "asset/Decoder.pt"
    tokenizer_path: str = "asset/tokenizer.pt"


@dataclass(repr=False, eq=False)
class Decoder:
    """Hyper-parameters of a decoder network."""

    idim: int = 384
    odim: int = 384
    hidden: int = 512
    n_layer: int = 12
    bn_dim: int = 128


@dataclass(repr=False, eq=False)
class VQ:
    """Hyper-parameters of the vector quantizer."""

    dim: int = 1024
    # A tuple is immutable, so a literal default is safe to share.
    levels: Tuple[int, ...] = (5, 5, 5, 5)
    G: int = 2
    R: int = 2


def _default_dvae_decoder() -> Decoder:
    """Build the decoder settings specific to the DVAE section."""
    return Decoder(
        idim=512,
        odim=512,
        hidden=256,
        n_layer=12,
        bn_dim=128,
    )


@dataclass(repr=False, eq=False)
class DVAE:
    """DVAE section: its own decoder settings plus the quantizer settings."""

    decoder: Decoder = field(default_factory=_default_dvae_decoder)
    vq: VQ = field(default_factory=VQ)


@dataclass(repr=False, eq=False)
class GPT:
    """Hyper-parameters of the GPT model."""

    hidden_size: int = 768
    intermediate_size: int = 3072
    num_attention_heads: int = 12
    num_hidden_layers: int = 20
    use_cache: bool = False
    max_position_embeddings: int = 4096

    spk_emb_dim: int = 192
    spk_KL: bool = False
    num_audio_tokens: int = 626
    num_vq: int = 4


@dataclass(repr=False, eq=False)
class FeatureExtractorInitArgs:
    """Constructor arguments for the Vocos feature extractor."""

    sample_rate: int = 24000
    n_fft: int = 1024
    hop_length: int = 256
    n_mels: int = 100
    padding: str = "center"


@dataclass(repr=False, eq=False)
class FeatureExtractor:
    """Class path and constructor arguments of the Vocos feature extractor."""

    class_path: str = "vocos.feature_extractors.MelSpectrogramFeatures"
    init_args: FeatureExtractorInitArgs = field(
        default_factory=FeatureExtractorInitArgs
    )


@dataclass(repr=False, eq=False)
class BackboneInitArgs:
    """Constructor arguments for the Vocos backbone."""

    input_channels: int = 100
    dim: int = 512
    intermediate_dim: int = 1536
    num_layers: int = 8


@dataclass(repr=False, eq=False)
class Backbone:
    """Class path and constructor arguments of the Vocos backbone."""

    class_path: str = "vocos.models.VocosBackbone"
    init_args: BackboneInitArgs = field(default_factory=BackboneInitArgs)


@dataclass(repr=False, eq=False)
class FourierHeadInitArgs:
    """Constructor arguments for the Vocos ISTFT head."""

    dim: int = 512
    n_fft: int = 1024
    hop_length: int = 256
    padding: str = "center"


@dataclass(repr=False, eq=False)
class FourierHead:
    """Class path and constructor arguments of the Vocos head."""

    class_path: str = "vocos.heads.ISTFTHead"
    init_args: FourierHeadInitArgs = field(default_factory=FourierHeadInitArgs)


@dataclass(repr=False, eq=False)
class Vocos:
    """Vocos section: feature extractor, backbone and head descriptions."""

    feature_extractor: FeatureExtractor = field(default_factory=FeatureExtractor)
    backbone: Backbone = field(default_factory=Backbone)
    head: FourierHead = field(default_factory=FourierHead)


@dataclass(repr=False, eq=False)
class Config:
    """Top-level configuration aggregating every section.

    Each ``Config()`` gets fresh, independent section objects, so changing
    a value on one instance never leaks into another.
    """

    path: Path = field(default_factory=Path)
    decoder: Decoder = field(default_factory=Decoder)
    dvae: DVAE = field(default_factory=DVAE)
    gpt: GPT = field(default_factory=GPT)
    vocos: Vocos = field(default_factory=Vocos)
.config import Config from .model import DVAE, GPT, gen_logits, Tokenizer from .utils import ( check_all_assets, @@ -33,6 +35,8 @@ def __init__(self, logger=logging.getLogger(__name__)): self.logger = logger utils_logger.set_logger(logger) + self.config = Config() + self.normalizer = Normalizer( os.path.join(os.path.dirname(__file__), "res", "homophones_map.json"), logger, @@ -137,12 +141,7 @@ def load( compile=compile, coef=coef, use_flash_attn=use_flash_attn, - **{ - k: os.path.join(download_path, v) - for k, v in OmegaConf.load( - os.path.join(download_path, "config", "path.yaml") - ).items() - }, + **asdict(self.config.path), ) def unload(self): @@ -243,13 +242,9 @@ def interrupt(self): @torch.no_grad() def _load( self, - vocos_config_path: str = None, vocos_ckpt_path: str = None, - dvae_config_path: str = None, dvae_ckpt_path: str = None, - gpt_config_path: str = None, gpt_ckpt_path: str = None, - decoder_config_path: str = None, decoder_ckpt_path: str = None, tokenizer_path: str = None, device: Optional[torch.device] = None, @@ -263,67 +258,75 @@ def _load( self.device = device self.compile = compile - if vocos_config_path: - vocos = ( - Vocos.from_hparams(vocos_config_path) - .to( - # vocos on mps will crash, use cpu fallback - "cpu" - if "mps" in str(device) - else device - ) - .eval() - ) - assert vocos_ckpt_path, "vocos_ckpt_path should not be None" - vocos.load_state_dict( - torch.load(vocos_ckpt_path, weights_only=True, mmap=True) - ) - self.vocos = vocos - self.logger.log(logging.INFO, "vocos loaded.") - - if dvae_config_path: - cfg = OmegaConf.load(dvae_config_path) - dvae = DVAE(**cfg, coef=coef).to(device).eval() - coef = str(dvae) - assert dvae_ckpt_path, "dvae_ckpt_path should not be None" - dvae.load_state_dict( - torch.load(dvae_ckpt_path, weights_only=True, mmap=True) - ) - self.dvae = dvae - self.logger.log(logging.INFO, "dvae loaded.") - - if gpt_config_path: - cfg = OmegaConf.load(gpt_config_path) - gpt = GPT( - **cfg, 
use_flash_attn=use_flash_attn, device=device, logger=self.logger - ).eval() - assert gpt_ckpt_path, "gpt_ckpt_path should not be None" - gpt.load_state_dict(torch.load(gpt_ckpt_path, weights_only=True, mmap=True)) - gpt.prepare(compile=compile and "cuda" in str(device)) - self.gpt = gpt - spk_stat_path = os.path.join(os.path.dirname(gpt_ckpt_path), "spk_stat.pt") - assert os.path.exists( - spk_stat_path - ), f"Missing spk_stat.pt: {spk_stat_path}" - spk_stat: torch.Tensor = torch.load( - spk_stat_path, - weights_only=True, - mmap=True, - map_location=device, + feature_extractor = instantiate_class(args=(), init=asdict(self.config.vocos.feature_extractor)) + backbone = instantiate_class(args=(), init=asdict(self.config.vocos.backbone)) + head = instantiate_class(args=(), init=asdict(self.config.vocos.head)) + vocos = ( + Vocos(feature_extractor=feature_extractor, backbone=backbone, head=head) + .to( + # vocos on mps will crash, use cpu fallback + "cpu" + if "mps" in str(device) + else device ) - self.std, self.mean = spk_stat.requires_grad_(False).chunk(2) - self.logger.log(logging.INFO, "gpt loaded.") - - if decoder_config_path: - cfg = OmegaConf.load(decoder_config_path) - decoder = DVAE(**cfg, coef=coef).to(device).eval() - coef = str(decoder) - assert decoder_ckpt_path, "decoder_ckpt_path should not be None" - decoder.load_state_dict( - torch.load(decoder_ckpt_path, weights_only=True, mmap=True) - ) - self.decoder = decoder - self.logger.log(logging.INFO, "decoder loaded.") + .eval() + ) + assert vocos_ckpt_path, "vocos_ckpt_path should not be None" + vocos.load_state_dict( + torch.load(vocos_ckpt_path, weights_only=True, mmap=True) + ) + self.vocos = vocos + self.logger.log(logging.INFO, "vocos loaded.") + + dvae = DVAE( + decoder_config=asdict(self.config.dvae.decoder), + vq_config=asdict(self.config.dvae.vq), + dim=self.config.dvae.decoder.idim, + coef=coef, + ).to(device).eval() + coef = str(dvae) + assert dvae_ckpt_path, "dvae_ckpt_path should not be None" 
+ dvae.load_state_dict( + torch.load(dvae_ckpt_path, weights_only=True, mmap=True) + ) + self.dvae = dvae + self.logger.log(logging.INFO, "dvae loaded.") + + gpt = GPT( + gpt_config=asdict(self.config.gpt), + use_flash_attn=use_flash_attn, + device=device, + logger=self.logger, + ).eval() + assert gpt_ckpt_path, "gpt_ckpt_path should not be None" + gpt.load_state_dict(torch.load(gpt_ckpt_path, weights_only=True, mmap=True)) + gpt.prepare(compile=compile and "cuda" in str(device)) + self.gpt = gpt + spk_stat_path = os.path.join(os.path.dirname(gpt_ckpt_path), "spk_stat.pt") + assert os.path.exists( + spk_stat_path + ), f"Missing spk_stat.pt: {spk_stat_path}" + spk_stat: torch.Tensor = torch.load( + spk_stat_path, + weights_only=True, + mmap=True, + map_location=device, + ) + self.std, self.mean = spk_stat.requires_grad_(False).chunk(2) + self.logger.log(logging.INFO, "gpt loaded.") + + decoder = DVAE( + decoder_config=asdict(self.config.decoder), + dim=self.config.decoder.idim, + coef=coef, + ).to(device).eval() + coef = str(decoder) + assert decoder_ckpt_path, "decoder_ckpt_path should not be None" + decoder.load_state_dict( + torch.load(decoder_ckpt_path, weights_only=True, mmap=True) + ) + self.decoder = decoder + self.logger.log(logging.INFO, "decoder loaded.") if tokenizer_path: self.tokenizer = Tokenizer(tokenizer_path, device) diff --git a/ChatTTS/model/dvae.py b/ChatTTS/model/dvae.py index 9260cf38d..9132f8384 100644 --- a/ChatTTS/model/dvae.py +++ b/ChatTTS/model/dvae.py @@ -72,7 +72,7 @@ def __init__( super(GFSQ, self).__init__() self.quantizer = GroupedResidualFSQ( dim=dim, - levels=levels, + levels=list(levels), num_quantizers=R, groups=G, ) @@ -169,8 +169,8 @@ def forward(self, x: torch.Tensor, conditioning=None) -> torch.Tensor: class DVAE(nn.Module): def __init__( self, - decoder_config, - vq_config, + decoder_config: dict, + vq_config: Optional[dict]=None, dim=512, coef: Optional[str] = None, ): diff --git a/ChatTTS/model/gpt.py 
b/ChatTTS/model/gpt.py index 6228aeec9..9aa33ef61 100644 --- a/ChatTTS/model/gpt.py +++ b/ChatTTS/model/gpt.py @@ -1,15 +1,8 @@ -import os, platform - -os.environ["TOKENIZERS_PARALLELISM"] = "false" -""" -https://stackoverflow.com/questions/62691279/how-to-disable-tokenizers-parallelism-true-false-warning -""" - +import platform from dataclasses import dataclass import logging from typing import Union, List, Optional, Tuple -import omegaconf import torch import torch.nn as nn import torch.nn.functional as F @@ -28,9 +21,9 @@ class GPT(nn.Module): def __init__( self, - gpt_config: omegaconf.DictConfig, - num_audio_tokens: int, - num_text_tokens: int, + gpt_config: dict, + num_audio_tokens: int = 626, + num_text_tokens: int = 21178, num_vq=4, use_flash_attn=False, device=torch.device("cpu"), @@ -100,7 +93,7 @@ def get(self) -> bool: def _build_llama( self, - config: omegaconf.DictConfig, + config: dict, device: torch.device, ) -> LlamaModel: diff --git a/ChatTTS/model/tokenizer.py b/ChatTTS/model/tokenizer.py index b32964353..c90e71236 100644 --- a/ChatTTS/model/tokenizer.py +++ b/ChatTTS/model/tokenizer.py @@ -1,3 +1,10 @@ +import os + +os.environ["TOKENIZERS_PARALLELISM"] = "false" +""" +https://stackoverflow.com/questions/62691279/how-to-disable-tokenizers-parallelism-true-false-warning +""" + from typing import List, Tuple import torch diff --git a/ChatTTS/utils/dl.py b/ChatTTS/utils/dl.py index da8984b44..1df57eb5d 100644 --- a/ChatTTS/utils/dl.py +++ b/ChatTTS/utils/dl.py @@ -61,22 +61,6 @@ def check_all_assets(base_dir: Path, sha256_map: Dict[str, str], update=False) - ): return False - logger.get_logger().info("checking configs...") - current_dir = base_dir / "config" - names = [ - "decoder.yaml", - "dvae.yaml", - "gpt.yaml", - "path.yaml", - "vocos.yaml", - ] - for model in names: - menv = model.replace(".", "_") - if not check_model( - current_dir, model, sha256_map[f"sha256_config_{menv}"], update - ): - return False - logger.get_logger().info("all assets 
are already latest.") return True @@ -117,7 +101,7 @@ def download_dns_yaml(url: str, folder: str): logger.get_logger().info(f"downloaded into {folder}") -def download_all_assets(tmpdir: str, version="0.2.5"): +def download_all_assets(tmpdir: str, version="0.2.6"): import subprocess import platform diff --git a/tools/checksum/tmpl.go b/tools/checksum/tmpl.go index 6bf08c59d..47b216058 100644 --- a/tools/checksum/tmpl.go +++ b/tools/checksum/tmpl.go @@ -7,12 +7,6 @@ var files = [...]string{ "asset/spk_stat.pt", "asset/tokenizer.pt", "asset/Vocos.pt", - - "config/decoder.yaml", - "config/dvae.yaml", - "config/gpt.yaml", - "config/path.yaml", - "config/vocos.yaml", } const jsontmpl = `{ @@ -22,11 +16,5 @@ const jsontmpl = `{ "sha256_asset_spk_stat_pt" : "%s", "sha256_asset_tokenizer_pt" : "%s", "sha256_asset_Vocos_pt" : "%s", - - "sha256_config_decoder_yaml": "%s", - "sha256_config_dvae_yaml" : "%s", - "sha256_config_gpt_yaml" : "%s", - "sha256_config_path_yaml" : "%s", - "sha256_config_vocos_yaml" : "%s" } `