feat(config): drop download of folder config
fumiama committed Jul 16, 2024
1 parent 9f402ba commit 27331c3
Showing 9 changed files with 200 additions and 116 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/checksum.yml
@@ -13,7 +13,7 @@ jobs:
 
       - name: Run RVC-Models-Downloader
         run: |
-          wget https://github.com/fumiama/RVC-Models-Downloader/releases/download/v0.2.5/rvcmd_linux_amd64.deb
+          wget https://github.com/fumiama/RVC-Models-Downloader/releases/download/v0.2.6/rvcmd_linux_amd64.deb
           sudo apt -y install ./rvcmd_linux_amd64.deb
           rm -f ./rvcmd_linux_amd64.deb
           rvcmd -notrs -w 1 -notui assets/chtts
1 change: 1 addition & 0 deletions ChatTTS/config/__init__.py
@@ -0,0 +1 @@
+from .config import Config
108 changes: 108 additions & 0 deletions ChatTTS/config/config.py
@@ -0,0 +1,108 @@
+from dataclasses import dataclass
+
+
+@dataclass(repr=False, eq=False)
+class Path():
+    vocos_ckpt_path: str = "asset/Vocos.pt"
+    dvae_ckpt_path: str = "asset/DVAE.pt"
+    gpt_ckpt_path: str = "asset/GPT.pt"
+    decoder_ckpt_path: str = "asset/Decoder.pt"
+    tokenizer_path: str = "asset/tokenizer.pt"
+
+
+@dataclass(repr=False, eq=False)
+class Decoder():
+    idim: int = 384
+    odim: int = 384
+    hidden: int = 512
+    n_layer: int = 12
+    bn_dim: int = 128
+
+
+@dataclass(repr=False, eq=False)
+class VQ():
+    dim: int = 1024
+    levels: tuple = (5,5,5,5)
+    G: int = 2
+    R: int = 2
+
+
+@dataclass(repr=False, eq=False)
+class DVAE():
+    decoder: Decoder = Decoder(
+        idim=512,
+        odim=512,
+        hidden=256,
+        n_layer=12,
+        bn_dim=128,
+    )
+    vq: VQ = VQ()
+
+@dataclass(repr=False, eq=False)
+class GPT():
+    hidden_size: int = 768
+    intermediate_size: int = 3072
+    num_attention_heads: int = 12
+    num_hidden_layers: int = 20
+    use_cache: bool = False
+    max_position_embeddings: int = 4096
+
+    spk_emb_dim: int = 192
+    spk_KL: bool = False
+    num_audio_tokens: int = 626
+    num_vq: int = 4
+
+
+@dataclass(repr=False, eq=False)
+class FeatureExtractorInitArgs():
+    sample_rate: int = 24000
+    n_fft: int = 1024
+    hop_length: int = 256
+    n_mels: int = 100
+    padding: str = "center"
+
+@dataclass(repr=False, eq=False)
+class FeatureExtractor():
+    class_path: str = "vocos.feature_extractors.MelSpectrogramFeatures"
+    init_args: FeatureExtractorInitArgs = FeatureExtractorInitArgs()
+
+
+@dataclass(repr=False, eq=False)
+class BackboneInitArgs():
+    input_channels: int = 100
+    dim: int = 512
+    intermediate_dim: int = 1536
+    num_layers: int = 8
+
+@dataclass(repr=False, eq=False)
+class Backbone():
+    class_path: str = "vocos.models.VocosBackbone"
+    init_args: BackboneInitArgs = BackboneInitArgs()
+
+
+@dataclass(repr=False, eq=False)
+class FourierHeadInitArgs():
+    dim: int = 512
+    n_fft: int = 1024
+    hop_length: int = 256
+    padding: str = "center"
+
+@dataclass(repr=False, eq=False)
+class FourierHead():
+    class_path: str = "vocos.heads.ISTFTHead"
+    init_args: FourierHeadInitArgs = FourierHeadInitArgs()
+
+
+@dataclass(repr=False, eq=False)
+class Vocos():
+    feature_extractor: FeatureExtractor = FeatureExtractor()
+    backbone: Backbone = Backbone()
+    head: FourierHead = FourierHead()
+
+@dataclass(repr=False, eq=False)
+class Config():
+    path: Path = Path()
+    decoder: Decoder = Decoder()
+    dvae: DVAE = DVAE()
+    gpt: GPT = GPT()
+    vocos: Vocos = Vocos()
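
A minimal sketch of how these dataclasses stand in for the YAML files that were previously downloaded into config/ (usage assumed; the asdict conversion mirrors how core.py consumes them below):

    from dataclasses import asdict

    from ChatTTS.config import Config

    cfg = Config()
    # Checkpoint locations formerly read from config/path.yaml:
    print(asdict(cfg.path))  # {'vocos_ckpt_path': 'asset/Vocos.pt', ...}
    # Nested sections become plain dicts for the model constructors:
    gpt_kwargs = asdict(cfg.gpt)
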
145 changes: 74 additions & 71 deletions ChatTTS/core.py
@@ -1,7 +1,7 @@
 import os
 import logging
 import tempfile
-from dataclasses import dataclass
+from dataclasses import dataclass, asdict
 from typing import Literal, Optional, List, Tuple, Dict
 from json import load
 from pathlib import Path
@@ -12,9 +12,11 @@
 import torch.nn.functional as F
 from omegaconf import OmegaConf
 from vocos import Vocos
+from vocos.pretrained import instantiate_class
 from huggingface_hub import snapshot_download
 import pybase16384 as b14
 
+from .config import Config
 from .model import DVAE, GPT, gen_logits, Tokenizer
 from .utils import (
     check_all_assets,
@@ -33,6 +35,8 @@ def __init__(self, logger=logging.getLogger(__name__)):
         self.logger = logger
         utils_logger.set_logger(logger)
 
+        self.config = Config()
+
         self.normalizer = Normalizer(
             os.path.join(os.path.dirname(__file__), "res", "homophones_map.json"),
             logger,
@@ -137,12 +141,7 @@ def load(
             compile=compile,
             coef=coef,
             use_flash_attn=use_flash_attn,
-            **{
-                k: os.path.join(download_path, v)
-                for k, v in OmegaConf.load(
-                    os.path.join(download_path, "config", "path.yaml")
-                ).items()
-            },
+            **asdict(self.config.path),
         )
 
     def unload(self):
@@ -243,13 +242,9 @@ def interrupt(self):
     @torch.no_grad()
     def _load(
         self,
-        vocos_config_path: str = None,
         vocos_ckpt_path: str = None,
-        dvae_config_path: str = None,
         dvae_ckpt_path: str = None,
-        gpt_config_path: str = None,
         gpt_ckpt_path: str = None,
-        decoder_config_path: str = None,
         decoder_ckpt_path: str = None,
         tokenizer_path: str = None,
         device: Optional[torch.device] = None,
@@ -263,67 +258,75 @@ def _load(
         self.device = device
         self.compile = compile
 
-        if vocos_config_path:
-            vocos = (
-                Vocos.from_hparams(vocos_config_path)
-                .to(
-                    # vocos on mps will crash, use cpu fallback
-                    "cpu"
-                    if "mps" in str(device)
-                    else device
-                )
-                .eval()
-            )
-            assert vocos_ckpt_path, "vocos_ckpt_path should not be None"
-            vocos.load_state_dict(
-                torch.load(vocos_ckpt_path, weights_only=True, mmap=True)
-            )
-            self.vocos = vocos
-            self.logger.log(logging.INFO, "vocos loaded.")
-
-        if dvae_config_path:
-            cfg = OmegaConf.load(dvae_config_path)
-            dvae = DVAE(**cfg, coef=coef).to(device).eval()
-            coef = str(dvae)
-            assert dvae_ckpt_path, "dvae_ckpt_path should not be None"
-            dvae.load_state_dict(
-                torch.load(dvae_ckpt_path, weights_only=True, mmap=True)
-            )
-            self.dvae = dvae
-            self.logger.log(logging.INFO, "dvae loaded.")
-
-        if gpt_config_path:
-            cfg = OmegaConf.load(gpt_config_path)
-            gpt = GPT(
-                **cfg, use_flash_attn=use_flash_attn, device=device, logger=self.logger
-            ).eval()
-            assert gpt_ckpt_path, "gpt_ckpt_path should not be None"
-            gpt.load_state_dict(torch.load(gpt_ckpt_path, weights_only=True, mmap=True))
-            gpt.prepare(compile=compile and "cuda" in str(device))
-            self.gpt = gpt
-            spk_stat_path = os.path.join(os.path.dirname(gpt_ckpt_path), "spk_stat.pt")
-            assert os.path.exists(
-                spk_stat_path
-            ), f"Missing spk_stat.pt: {spk_stat_path}"
-            spk_stat: torch.Tensor = torch.load(
-                spk_stat_path,
-                weights_only=True,
-                mmap=True,
-                map_location=device,
-            )
-            self.std, self.mean = spk_stat.requires_grad_(False).chunk(2)
-            self.logger.log(logging.INFO, "gpt loaded.")
-
-        if decoder_config_path:
-            cfg = OmegaConf.load(decoder_config_path)
-            decoder = DVAE(**cfg, coef=coef).to(device).eval()
-            coef = str(decoder)
-            assert decoder_ckpt_path, "decoder_ckpt_path should not be None"
-            decoder.load_state_dict(
-                torch.load(decoder_ckpt_path, weights_only=True, mmap=True)
-            )
-            self.decoder = decoder
-            self.logger.log(logging.INFO, "decoder loaded.")
+        feature_extractor = instantiate_class(args=(), init=asdict(self.config.vocos.feature_extractor))
+        backbone = instantiate_class(args=(), init=asdict(self.config.vocos.backbone))
+        head = instantiate_class(args=(), init=asdict(self.config.vocos.head))
+        vocos = (
+            Vocos(feature_extractor=feature_extractor, backbone=backbone, head=head)
+            .to(
+                # vocos on mps will crash, use cpu fallback
+                "cpu"
+                if "mps" in str(device)
+                else device
+            )
+            .eval()
+        )
+        assert vocos_ckpt_path, "vocos_ckpt_path should not be None"
+        vocos.load_state_dict(
+            torch.load(vocos_ckpt_path, weights_only=True, mmap=True)
+        )
+        self.vocos = vocos
+        self.logger.log(logging.INFO, "vocos loaded.")
+
+        dvae = DVAE(
+            decoder_config=asdict(self.config.dvae.decoder),
+            vq_config=asdict(self.config.dvae.vq),
+            dim=self.config.dvae.decoder.idim,
+            coef=coef,
+        ).to(device).eval()
+        coef = str(dvae)
+        assert dvae_ckpt_path, "dvae_ckpt_path should not be None"
+        dvae.load_state_dict(
+            torch.load(dvae_ckpt_path, weights_only=True, mmap=True)
+        )
+        self.dvae = dvae
+        self.logger.log(logging.INFO, "dvae loaded.")
+
+        gpt = GPT(
+            gpt_config=asdict(self.config.gpt),
+            use_flash_attn=use_flash_attn,
+            device=device,
+            logger=self.logger,
+        ).eval()
+        assert gpt_ckpt_path, "gpt_ckpt_path should not be None"
+        gpt.load_state_dict(torch.load(gpt_ckpt_path, weights_only=True, mmap=True))
+        gpt.prepare(compile=compile and "cuda" in str(device))
+        self.gpt = gpt
+        spk_stat_path = os.path.join(os.path.dirname(gpt_ckpt_path), "spk_stat.pt")
+        assert os.path.exists(
+            spk_stat_path
+        ), f"Missing spk_stat.pt: {spk_stat_path}"
+        spk_stat: torch.Tensor = torch.load(
+            spk_stat_path,
+            weights_only=True,
+            mmap=True,
+            map_location=device,
+        )
+        self.std, self.mean = spk_stat.requires_grad_(False).chunk(2)
+        self.logger.log(logging.INFO, "gpt loaded.")
+
+        decoder = DVAE(
+            decoder_config=asdict(self.config.decoder),
+            dim=self.config.decoder.idim,
+            coef=coef,
+        ).to(device).eval()
+        coef = str(decoder)
+        assert decoder_ckpt_path, "decoder_ckpt_path should not be None"
+        decoder.load_state_dict(
+            torch.load(decoder_ckpt_path, weights_only=True, mmap=True)
+        )
+        self.decoder = decoder
+        self.logger.log(logging.INFO, "decoder loaded.")
 
         if tokenizer_path:
             self.tokenizer = Tokenizer(tokenizer_path, device)
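
For context on the new construction path: vocos.pretrained.instantiate_class imports the class named by "class_path" and calls it with "init_args", which is exactly the shape of the Vocos sub-configs in ChatTTS/config/config.py. A minimal sketch of the same assembly outside core.py, assuming the default Config values:

    from dataclasses import asdict

    from vocos import Vocos
    from vocos.pretrained import instantiate_class

    from ChatTTS.config import Config

    cfg = Config()
    # Each sub-config serializes to {"class_path": ..., "init_args": {...}}.
    vocos = Vocos(
        feature_extractor=instantiate_class(args=(), init=asdict(cfg.vocos.feature_extractor)),
        backbone=instantiate_class(args=(), init=asdict(cfg.vocos.backbone)),
        head=instantiate_class(args=(), init=asdict(cfg.vocos.head)),
    ).eval()
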
6 changes: 3 additions & 3 deletions ChatTTS/model/dvae.py
@@ -72,7 +72,7 @@ def __init__(
         super(GFSQ, self).__init__()
         self.quantizer = GroupedResidualFSQ(
             dim=dim,
-            levels=levels,
+            levels=list(levels),
             num_quantizers=R,
             groups=G,
         )
@@ -169,8 +169,8 @@ def forward(self, x: torch.Tensor, conditioning=None) -> torch.Tensor:
 class DVAE(nn.Module):
     def __init__(
         self,
-        decoder_config,
-        vq_config,
+        decoder_config: dict,
+        vq_config: Optional[dict]=None,
         dim=512,
         coef: Optional[str] = None,
     ):
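
The new list(levels) cast accommodates the tuple default on VQ.levels in the dataclass config, and the Optional vq_config matches the two ways core.py now builds this class. A sketch under the default Config (assumed usage, mirroring the calls shown above):

    from dataclasses import asdict

    from ChatTTS.config import Config
    from ChatTTS.model import DVAE

    cfg = Config()
    # Full DVAE with quantizer:
    dvae = DVAE(
        decoder_config=asdict(cfg.dvae.decoder),
        vq_config=asdict(cfg.dvae.vq),
        dim=cfg.dvae.decoder.idim,
    )
    # Decoder-only variant; vq_config is omitted, hence Optional[dict] = None:
    decoder = DVAE(decoder_config=asdict(cfg.decoder), dim=cfg.decoder.idim)
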
17 changes: 5 additions & 12 deletions ChatTTS/model/gpt.py
@@ -1,15 +1,8 @@
-import os, platform
-
-os.environ["TOKENIZERS_PARALLELISM"] = "false"
-"""
-https://stackoverflow.com/questions/62691279/how-to-disable-tokenizers-parallelism-true-false-warning
-"""
-
+import platform
 from dataclasses import dataclass
 import logging
 from typing import Union, List, Optional, Tuple
 
-import omegaconf
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
@@ -28,9 +21,9 @@
 class GPT(nn.Module):
     def __init__(
         self,
-        gpt_config: omegaconf.DictConfig,
-        num_audio_tokens: int,
-        num_text_tokens: int,
+        gpt_config: dict,
+        num_audio_tokens: int = 626,
+        num_text_tokens: int = 21178,
         num_vq=4,
         use_flash_attn=False,
         device=torch.device("cpu"),
@@ -100,7 +93,7 @@ def get(self) -> bool:
 
     def _build_llama(
         self,
-        config: omegaconf.DictConfig,
+        config: dict,
         device: torch.device,
     ) -> LlamaModel:
 
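
With gpt_config now a plain dict, GPT no longer depends on omegaconf at all; a minimal construction sketch, with defaults assumed from the Config dataclass:

    from dataclasses import asdict

    from ChatTTS.config import Config
    from ChatTTS.model import GPT

    cfg = Config()
    gpt = GPT(
        gpt_config=asdict(cfg.gpt),  # plain dict replaces omegaconf.DictConfig
        use_flash_attn=False,
    ).eval()
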
7 changes: 7 additions & 0 deletions ChatTTS/model/tokenizer.py
@@ -1,3 +1,10 @@
+import os
+
+os.environ["TOKENIZERS_PARALLELISM"] = "false"
+"""
+https://stackoverflow.com/questions/62691279/how-to-disable-tokenizers-parallelism-true-false-warning
+"""
+
 from typing import List, Tuple
 
 import torch
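
Moving the TOKENIZERS_PARALLELISM guard here keeps it in the module that actually touches the Hugging Face tokenizer; the conventional pattern is to set the variable before the tokenizers runtime is imported, as in this generic sketch (the transformers import is illustrative, not part of this diff):

    import os

    # Set before any import that pulls in the tokenizers runtime.
    os.environ["TOKENIZERS_PARALLELISM"] = "false"

    from transformers import AutoTokenizer  # hypothetical downstream import
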
(The diffs of the remaining 2 changed files were not loaded on this page.)
