diff --git a/.github/workflows/checksum.yml b/.github/workflows/checksum.yml
index 55e6843d5..162c43e12 100644
--- a/.github/workflows/checksum.yml
+++ b/.github/workflows/checksum.yml
@@ -13,7 +13,7 @@ jobs:
       - name: Run RVC-Models-Downloader
         run: |
-          wget https://github.com/fumiama/RVC-Models-Downloader/releases/download/v0.2.7/rvcmd_linux_amd64.deb
+          wget https://github.com/fumiama/RVC-Models-Downloader/releases/download/v0.2.8/rvcmd_linux_amd64.deb
           sudo apt -y install ./rvcmd_linux_amd64.deb
           rm -f ./rvcmd_linux_amd64.deb
           rvcmd -notrs -w 1 -notui assets/chtts
diff --git a/ChatTTS/config/config.py b/ChatTTS/config/config.py
index 904fee5cb..1f3e76280 100644
--- a/ChatTTS/config/config.py
+++ b/ChatTTS/config/config.py
@@ -5,9 +5,10 @@ class Path:
     vocos_ckpt_path: str = "asset/Vocos.pt"
     dvae_ckpt_path: str = "asset/DVAE_full.pt"
-    gpt_ckpt_path: str = "asset/GPT.pt"
+    gpt_ckpt_path: str = "asset/gpt"
     decoder_ckpt_path: str = "asset/Decoder.pt"
     tokenizer_path: str = "asset/tokenizer"
+    embed_path: str = "asset/Embed.safetensors"


 @dataclass(repr=False, eq=False)
@@ -62,6 +63,14 @@ class GPT:
     num_vq: int = 4


+@dataclass(repr=False, eq=False)
+class Embed:
+    hidden_size: int = 768
+    num_audio_tokens: int = 626
+    num_text_tokens: int = 21178
+    num_vq: int = 4
+
+
 @dataclass(repr=False, eq=False)
 class FeatureExtractorInitArgs:
     sample_rate: int = 24000
@@ -118,6 +127,7 @@ class Config:
     decoder: Decoder = Decoder()
     dvae: DVAE = DVAE()
     gpt: GPT = GPT()
+    embed: Embed = Embed()
     vocos: Vocos = Vocos()
     spk_stat: str = (
         "愐穤巩噅廷戇笉屈癐媄垹垧帶爲漈塀殐慄亅倴庲舴猂瑈圐狴夥圓帍戛挠腉耐劤坽喳幾战謇聀崒栄呥倸庭燡欈杁襐褄乭埗幺爃弔摁斐捔兕佖廐舏竾豃磐姓趡佄幒爚欄豄讐皳訵仩帆投謌荃蝐叄圝伆幦抂茁呄掑斃讹傮庞爣蜀橁偐祄亥兡常爂欍扉丐浔佱僈強払伅扂蛐徴憍傞巀戺欀艂琐嗴啥値彷刂權穈扒卤俔贲庛初笂卄贐枴仭亁庛剎猢扃缐趤刁偵幪舏伌煁婐潤晍位弾舙茥穁葏蠣訑企庤刊笍橁溑僔云偁庯戚伍潉膐脴僵噔廃艅匊祂唐憴壝嗙席爥欁虁谐牴帽势弿牳蜁兀蛐傄喩丿帔刔圆衁廐罤庁促帙劢伈汄樐檄勵伴弝舑欍罅虐昴劭勅帜刼朊蕁虐蓴樑伫幨扑謪剀堐稴丵伱弐舮諸赁習俔容厱幫牶謃孄糐答嗝僊帜燲笄終瀒判久僤帘爴茇千孑冄凕佳引扐蜁歁缏裄剽儺恘爋朏眿廐呄塍嘇幻爱茠詁訐剴唭俐幾戊欀硁菐贄楕偒巡爀弎屄莐睳賙凶彎刅漄區唐溴剑劋庽舽猄煃跐夔惥伾庮舎伈罁垑坄怅业怯刁朇獁嶏覔坩俳巶爜朐潁崐萄俹凛常爺笌穀聐此夡倛帡刀匉終窏舣販侽怿扉伥贿憐忓謩姆幌犊漂慆癒却甝兎帼戏欅詂浐朔仹壭帰臷弎恇菐獤帡偖帘爞伅腂皐纤囅充幓戠伥灂丐訤戱倱弋爮嬌癁恐孄侥劬忶刓國詀桒古偩嘄庬戚茝赂监燤嘑勌幦舽持呂諐棤姑再底舡笍艃瀐孴倉傔弋爔猠乁濑塄偽嘧恂舛缇襃厐窴仡刱忕別漇穁岏缴廽价庌爊謈硄讑惤倁儂庭爋伇蝂嶐莔摝傠库刞茄歃戏薤伍伯廮创笠塄熐兴勽俄帅剉最腀砐敤卝侍弆戺朒虃旐蚄梕亖幔牻朣扅贐玔堝噅帡剌圅摀崐彤流僳庙爖嬇啁渐悤堁丛幆刧挜彃悐幤刹嚟恕芁看聀摐焔向乁帖爭欁癃糒圄弙佱廜戤謍婀咐昴焍亩廦艏拼謿芐癤怹兽幸舳朇畁喐稔毝丼弈懲挀譂勑哴啁伎常舭笯晁堑俄叩剔廟爍欦絁夒伤休傑廳戌蜅潆癐彴摑勯床刽欅艁砐忄搉从廡舊猥潂唐委仱僜廼爤朄呃弐礔滵垓幩爄挂筁乐籤刕凟幵爠弉癅乑吴勥伖帪舩茆婁碐幤叭乢巜艳猁桀桐啄唩俊幍舮猀艅焐螔琽亀帋爜缅噃咐斤喩予幩爛笆摀浐猴依侹幃刕園慄蛐栤澹仑座爼謉桃慐浔斕偻幛懰嬓衁愐氄悅仿应芔漄衃敐謤傁匩幹抃圉癄廐裄屵噉幍利謍聂搐蛔嚙坍怗舁圐畃膐栄刵东巆戤諾呃偑媤嗨跞忶爝眄祂朒嶔僭劉忾刐匋癄袐翴珅僷廲芄茈恈皐擄崑伄廉牍匃剃犏澤唑丄庺戃伃煀某杄偙亽帴切缌罄挐尴噙倰带舞漄橄塐糴俩僯帀般漀坂栐更両俇廱舌猁慂拐偤嶱卶应刪眉獁茐伔嘅偺帟舊漂恀栐暄喡乞庙舆匂敀潑恔劑侖延戦盽怶唯慳蝘蟃孫娎益袰玍屃痶翮笪儚裀倹椌玻翀詵筽舘惯堿某侰晈藏缮詗廦夸妎瑻瀒裔媀憞唃冶璭狻渠荑奬熹茅愺氰菣滠翦岓褌泣崲嚭欓湒聙宺爄蛅愸庍匃帆誔穮懌蓪玷澌氋抌訙屌臞廛玸听屺希疭孝凂紋新煎彃膲跱尪懁眆窴珏卓揨菸紭概囥显壌榄垫嘮嬭覤媸侵佮烒耸觌婀秋狃帹葯訤桜糨笾腢伀肶悍炂艤禖岅臺惘梷瞍友盁佨岧憳瓧嘴汬藊愌蘤嶠硴绤蜲襏括勾谂縨妥蓪澭竭萢藜纞糲煮愆瀯孯琓罂諺塿燗狟弙衯揻縷丱糅臄梱瀮杰巳猙亊符胠匃泀廏圃膂蒃籏礩岈簹缌劺燲褡孓膜拔蠿觮呋煣厌尷熜論弲牭紫寊誃紀橴賬傸箍弚窃侫簲慯烣渽祌壓媥噜夽夛諛玹疮禄冪謇媽衤盰缺繑薫兾萧嵱打滽箺嚯凣狢蠜崼覽烸簶盯籓摀苶峸懗泲涻凮愳緗剋笔懆廡瞿椏礤惐藥崍腈烄伹亯昣翬褍絋桫僨吨莌丛矄蜞娈憊苆塁蓏嚢嫼绻崱婋囱蠸篯晣芀繼索兓僖誹岯圪褰蠇唓妷胅巁渮砛傈蝷嵚冃購赁峍裋荂舾符熻岳墩寮粃凲袑彚太绲头摯繳狁俥籌冝諝註坎幫擤詒宒凕賐唶梎噔弼課屿覍囨焬櫱撪蝮蝬簸懰櫫涺嵍睻屪翔峞慘滟熲昱军烊舿尦舄糖奁溏凂彆蝲糴禍困皻灏牋睒诙嶱臀开蓈眎腼丢纻廏憤嫖暭袭崲肸螛妒榗紉谨窮袃瑠聍绊腆亿冲葐喋縔詖岑兾给堸赏旻桀蛨媆訂峦紷敯囬偐筨岸焸拭笵殒哜墒萍屓娓諙械臮望摰芑寭准僞谹氍旋憢菮屃划欣瘫谎蘻哐繁籥禦僿誵皯墓燀縿笞熦绗稹榎矻綞蓓帡戓沺区才畃洊詪糐裶盰窶耎偌劂誐庩惝滜沺哮呃煐譠崄槀猄肼蔐擋湌蠺篃恥諌瞦宍堫挪裕崑慩狲悠煋仛愞砈粵八棁害楐妋萔貨尵奂苰怫誎傫岆蕯屇脉夈仆茎刓繸芺壸碗曛汁戭炻獻凉媁兎狜爴怰賃纎袏娷禃蓥膹薪渻罸窿粫凾褄舺窮墫干苊繁冏僮訸夯绛蓪虛羽慲烏憷趎睊蠰莍塞成廎盁欏喓蜮譤崆楁囘矇薭伣艘虝帴奮苢渶虎暣翐蝃尾稈糶瀴罐嵚氮葯笫慐棌悶炯竻爅们媡姢嫺窷刮歫劈裩屬椕賑蜹薊刲義哯尗褦瓀稾礋揣窼舫尋姁椄侸嗫珺修纘媃腽蛛稹梭呛瀈蘟縀礉論夵售主梮蠉娅娭裀誼嶭観枳倊簈褃擞綿催瞃溶苊笛襹櫲盅六囫獩佃粨慯瓢眸旱荃婨蔞岋祗墼焻网牻琖詆峋秉胳媴袭澓賢経稟壩胫碯偏囫嶎纆窈槊賐撹璬莃缘誾宭愊眗喷监劋萘訯總槿棭戾墮犄恌縈簍樥蛔杁袭嫛憫倆篏墵賈羯茎觳蒜致娢慄勒覸蘍曲栂葭宆妋皽缽免盳猼蔂糥觧烳檸佯憓煶蔐筼种繷琲膌塄剰讎対腕棥渽忲俛浪譬秛惛壒嘸淫冻曄睻砃奫貯庴爅粓脮脡娎妖峵蘲討惋泊蠀㴆"
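Editor's note on the config change above: `gpt_ckpt_path` now names a directory (`asset/gpt`, holding `config.json` plus `model.safetensors`) rather than the single `GPT.pt`, and the new `Embed` dataclass carries the dimensions of the embedding/head weights split out into `Embed.safetensors`. A quick sketch of reading the new defaults; the `ChatTTS.config` import path is an assumption, not shown in this diff:

```python
# Hypothetical usage sketch: inspect the Embed defaults added above.
from ChatTTS.config import Config  # assumed re-export of config.config.Config

cfg = Config()
print(cfg.embed.hidden_size)      # 768 -- must match the GPT hidden size
print(cfg.embed.num_text_tokens)  # 21178 -- same vocabulary as the tokenizer
```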
diff --git a/ChatTTS/core.py b/ChatTTS/core.py
index 9f730bdf4..15099dded 100644
--- a/ChatTTS/core.py
+++ b/ChatTTS/core.py
@@ -13,7 +13,7 @@
 from huggingface_hub import snapshot_download

 from .config import Config
-from .model import DVAE, GPT, gen_logits, Tokenizer, Speaker
+from .model import DVAE, Embed, GPT, gen_logits, Tokenizer, Speaker
 from .utils import (
     check_all_assets,
     download_all_assets,
@@ -46,7 +46,7 @@ def __init__(self, logger=logging.getLogger(__name__)):
     def has_loaded(self, use_decoder=False):
         not_finish = False
-        check_list = ["vocos", "gpt", "tokenizer"]
+        check_list = ["vocos", "gpt", "tokenizer", "embed"]
         if use_decoder:
             check_list.append("decoder")
@@ -97,7 +97,7 @@ def download_models(
         try:
             download_path = snapshot_download(
                 repo_id="2Noise/ChatTTS",
-                allow_patterns=["*.pt", "*.yaml", "*.json"],
+                allow_patterns=["*.pt", "*.yaml", "*.json", "*.safetensors"],
             )
         except:
             download_path = None
@@ -150,7 +150,7 @@ def unload(self):
         self.normalizer.destroy()
         del self.normalizer
         del self.sha256_map
-        del_list = ["vocos", "gpt", "decoder", "dvae", "tokenizer"]
+        del_list = ["vocos", "gpt", "decoder", "dvae", "tokenizer", "embed"]
         for module in del_list:
             if hasattr(self, module):
                 delattr(self, module)
@@ -228,6 +228,7 @@ def _load(
         vocos_ckpt_path: str = None,
         dvae_ckpt_path: str = None,
         gpt_ckpt_path: str = None,
+        embed_path: str = None,
         decoder_ckpt_path: str = None,
         tokenizer_path: str = None,
         device: Optional[torch.device] = None,
@@ -281,8 +282,19 @@ def _load(
         self.dvae = dvae
         self.logger.log(logging.INFO, "dvae loaded.")

+        embed = Embed(
+            self.config.embed.hidden_size,
+            self.config.embed.num_audio_tokens,
+            self.config.embed.num_text_tokens,
+            self.config.embed.num_vq,
+        )
+        embed.from_pretrained(embed_path)
+        self.embed = embed
+        self.logger.log(logging.INFO, "embed loaded.")
+
         gpt = GPT(
             gpt_config=asdict(self.config.gpt),
+            embed=self.embed,
             use_flash_attn=use_flash_attn,
             use_vllm=use_vllm,
             device=device,
@@ -290,14 +302,15 @@
             logger=self.logger,
         ).eval()
         assert gpt_ckpt_path, "gpt_ckpt_path should not be None"
-        gpt.from_pretrained(gpt_ckpt_path, experimental=experimental)
+        gpt.from_pretrained(gpt_ckpt_path, embed_path, experimental=experimental)
         gpt.prepare(compile=compile and "cuda" in str(device))
         self.gpt = gpt
+        self.logger.log(logging.INFO, "gpt loaded.")

         self.speaker = Speaker(
             self.config.gpt.hidden_size, self.config.spk_stat, device
         )
-        self.logger.log(logging.INFO, "gpt loaded.")
+        self.logger.log(logging.INFO, "speaker loaded.")

         decoder = (
             DVAE(
@@ -528,7 +541,7 @@ def _infer_code(
                 ),
             ]

-        emb = gpt(input_ids, text_mask)
+        emb = self.embed(input_ids, text_mask)

         del text_mask
@@ -626,7 +639,7 @@ def _refine_text(
             attentions=[],
         )

-        emb = gpt(input_ids, text_mask)
+        emb = self.embed(input_ids, text_mask)

         del text_mask
diff --git a/ChatTTS/model/__init__.py b/ChatTTS/model/__init__.py
index 4a8bcde8f..e14b41646 100644
--- a/ChatTTS/model/__init__.py
+++ b/ChatTTS/model/__init__.py
@@ -1,4 +1,5 @@
 from .dvae import DVAE
+from .embed import Embed
 from .gpt import GPT
 from .processors import gen_logits
 from .speaker import Speaker
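The `allow_patterns` change in `download_models` above is what actually pulls the two new safetensors files from the Hub. A minimal standalone sketch of that call (this is the real `huggingface_hub` API; the printed path is wherever your local cache lives):

```python
# Fetch only the asset types ChatTTS needs, now including *.safetensors
# (Embed.safetensors and asset/gpt/model.safetensors).
from huggingface_hub import snapshot_download

download_path = snapshot_download(
    repo_id="2Noise/ChatTTS",
    allow_patterns=["*.pt", "*.yaml", "*.json", "*.safetensors"],
)
print(download_path)  # local snapshot directory containing the matched files
```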
diff --git a/ChatTTS/model/embed.py b/ChatTTS/model/embed.py
new file mode 100644
index 000000000..02aa5525e
--- /dev/null
+++ b/ChatTTS/model/embed.py
@@ -0,0 +1,80 @@
+from safetensors.torch import safe_open
+import torch
+import torch.nn as nn
+from torch.nn.utils.parametrizations import weight_norm
+
+
+class Embed(nn.Module):
+    def __init__(
+        self, hidden_size: int, num_audio_tokens: int, num_text_tokens: int, num_vq=4
+    ):
+        super().__init__()
+
+        self.num_vq = num_vq
+        self.num_audio_tokens = num_audio_tokens
+
+        self.model_dim = hidden_size
+        self.emb_code = nn.ModuleList(
+            [nn.Embedding(num_audio_tokens, self.model_dim) for _ in range(num_vq)],
+        )
+        self.emb_text = nn.Embedding(num_text_tokens, self.model_dim)
+
+        self.head_text = weight_norm(
+            nn.Linear(self.model_dim, num_text_tokens, bias=False),
+            name="weight",
+        )
+        self.head_code = nn.ModuleList(
+            [
+                weight_norm(
+                    nn.Linear(self.model_dim, num_audio_tokens, bias=False),
+                    name="weight",
+                )
+                for _ in range(self.num_vq)
+            ],
+        )
+
+    @torch.inference_mode()
+    def from_pretrained(self, filename: str):
+        state_dict_tensors = {}
+        with safe_open(filename, framework="pt") as f:
+            for k in f.keys():
+                state_dict_tensors[k] = f.get_tensor(k)
+        self.load_state_dict(state_dict_tensors)
+
+    def __call__(
+        self, input_ids: torch.Tensor, text_mask: torch.Tensor
+    ) -> torch.Tensor:
+        """
+        get_emb
+        """
+        return super().__call__(input_ids, text_mask)
+
+    @torch.inference_mode()
+    def forward(self, input_ids: torch.Tensor, text_mask: torch.Tensor) -> torch.Tensor:
+        """
+        get_emb
+        """
+        device = next(self.parameters()).device
+        emb_text: torch.Tensor = self.emb_text(
+            input_ids[text_mask].narrow(1, 0, 1).squeeze_(1).to(device)
+        )
+
+        text_mask_inv = text_mask.logical_not().to(device)
+        masked_input_ids: torch.Tensor = input_ids[text_mask_inv].to(device)
+
+        emb_code = [
+            self.emb_code[i](masked_input_ids[:, i]) for i in range(self.num_vq)
+        ]
+        emb_code = torch.stack(emb_code, 2).sum(2)
+
+        emb = torch.zeros(
+            (input_ids.shape[:-1]) + (emb_text.shape[-1],),
+            device=emb_text.device,
+            dtype=emb_text.dtype,
+        )
+        emb[text_mask] = emb_text
+        emb[text_mask_inv] = emb_code.to(emb.dtype)
+
+        del emb_text, emb_code, text_mask_inv
+
+        return emb
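`Embed.from_pretrained` above reads the checkpoint tensor-by-tensor through `safe_open`; `safetensors.torch.load_file` is the one-call equivalent. A sketch, assuming `asset/Embed.safetensors` has already been downloaded:

```python
# Equivalent loading path for the Embed module defined above.
from safetensors.torch import load_file

from ChatTTS.model import Embed

embed = Embed(hidden_size=768, num_audio_tokens=626, num_text_tokens=21178, num_vq=4)
embed.load_state_dict(load_file("asset/Embed.safetensors"))
embed.eval()
```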
diff --git a/ChatTTS/model/gpt.py b/ChatTTS/model/gpt.py
index 5b918b621..576ecdfc4 100644
--- a/ChatTTS/model/gpt.py
+++ b/ChatTTS/model/gpt.py
@@ -1,15 +1,13 @@
-import os, platform
+import platform
 from dataclasses import dataclass
 import logging
 from typing import Union, List, Optional, Tuple, Callable
 import gc
-from pathlib import Path

 import torch
 import torch.nn as nn
 import torch.nn.functional as F
 import torch.nn.utils.parametrize as P
-from torch.nn.utils.parametrizations import weight_norm
 from tqdm import tqdm
 from transformers import LlamaModel, LlamaConfig
 from transformers.cache_utils import Cache
@@ -17,12 +15,14 @@
 from transformers.utils import is_flash_attn_2_available

 from ..utils import del_all
+from .embed import Embed


 class GPT(nn.Module):
     def __init__(
         self,
         gpt_config: dict,
+        embed: Embed,
         use_flash_attn=False,
         use_vllm=False,
         device=torch.device("cpu"),
@@ -38,7 +38,6 @@ def __init__(

         self.generator = torch.Generator(device=device)

-        self.config = gpt_config
         self.num_vq = int(gpt_config["num_vq"])
         self.num_audio_tokens = int(gpt_config["num_audio_tokens"])
         self.num_text_tokens = int(gpt_config["num_text_tokens"])
@@ -50,88 +49,33 @@ def __init__(
         if self.is_vllm:
             return

-        self.gpt, self.llama_config = self._build_llama(gpt_config, self.device_gpt)
+        self.llama_config = self._build_llama_config(gpt_config)

-        self.model_dim = int(self.gpt.config.hidden_size)
-        self.emb_code = nn.ModuleList(
-            [
-                nn.Embedding(
-                    self.num_audio_tokens,
-                    self.model_dim,
-                    device=self.device_gpt,
-                )
-                for _ in range(self.num_vq)
-            ],
-        )
-        self.emb_text = nn.Embedding(
-            self.num_text_tokens, self.model_dim, device=self.device_gpt
-        )
-
-        self.head_text = weight_norm(
-            nn.Linear(
-                self.model_dim,
-                self.num_text_tokens,
-                bias=False,
-                device=device,
-            ),
-            name="weight",
-        )
-        self.head_code = nn.ModuleList(
-            [
-                weight_norm(
-                    nn.Linear(
-                        self.model_dim,
-                        self.num_audio_tokens,
-                        bias=False,
-                        device=device,
-                    ),
-                    name="weight",
-                )
-                for _ in range(self.num_vq)
-            ],
-        )
+        self.emb_code = [ec.__call__ for ec in embed.emb_code]
+        self.emb_text = embed.emb_text.__call__
+        self.head_text = embed.head_text.__call__
+        self.head_code = [hc.__call__ for hc in embed.head_code]

-    def from_pretrained(self, file_path: str, experimental=False):
+    def from_pretrained(
+        self, gpt_folder: str, embed_file_path: str, experimental=False
+    ):
         if self.is_vllm and platform.system().lower() == "linux":
-            from safetensors.torch import save_file
-
-            from .velocity import LLM, PostModel
-
-            vllm_folder = Path(os.getcwd()) / "asset" / "vllm"
-            if not os.path.exists(vllm_folder):
-                self.logger.info("initializing vLLM model to %s", str(vllm_folder))
-                vllm_folder.mkdir(mode=0o755, parents=True, exist_ok=True)
-                gpt = GPT(gpt_config=self.config)
-                gpt.from_pretrained(file_path)
-                gpt.gpt.save_pretrained(vllm_folder / "gpt")
-                post_model = (
-                    PostModel(
-                        int(gpt.gpt.config.hidden_size),
-                        self.num_audio_tokens,
-                        self.num_text_tokens,
-                    )
-                    .to(self.device)
-                    .eval()
-                )
-                post_model.emb_code = gpt.emb_code
-                post_model.emb_text = gpt.emb_text
-                post_model.head_text = gpt.head_text
-                post_model.head_code = gpt.head_code
-                save_file(
-                    post_model.state_dict(),
-                    vllm_folder / "post_model.safetensors",
-                )
-                del post_model, gpt
+
+            from .velocity import LLM
+
             self.llm = LLM(
-                model=str(vllm_folder / "gpt"),
+                model=gpt_folder,
                 num_audio_tokens=self.num_audio_tokens,
                 num_text_tokens=self.num_text_tokens,
-                post_model_path=vllm_folder / "post_model.safetensors",
+                post_model_path=embed_file_path,
             )
             self.logger.info("vLLM model loaded")
             return

-        self.load_state_dict(torch.load(file_path, weights_only=True, mmap=True))
+        self.gpt: LlamaModel = LlamaModel.from_pretrained(gpt_folder).to(
+            self.device_gpt
+        )
+        del self.gpt.embed_tokens

         if (
             experimental
@@ -166,10 +110,9 @@ def set(self, v: bool):
         def get(self) -> bool:
             return self._interrupt

-    def _build_llama(
+    def _build_llama_config(
         self,
         config: dict,
-        device: torch.device,
     ) -> Tuple[LlamaModel, LlamaConfig]:

         if self.use_flash_attn and is_flash_attn_2_available():
@@ -183,10 +126,7 @@ def _build_llama(
         else:
             llama_config = LlamaConfig(**config)

-        model = LlamaModel(llama_config)
-        del model.embed_tokens
-
-        return model.to(device), llama_config
+        return llama_config

     def prepare(self, compile=False):
         if self.use_flash_attn and is_flash_attn_2_available():
@@ -198,43 +138,6 @@ def prepare(self, compile=False):
             except RuntimeError as e:
                 self.logger.warning(f"compile failed: {e}. fallback to normal mode.")

-    def __call__(
-        self, input_ids: torch.Tensor, text_mask: torch.Tensor
-    ) -> torch.Tensor:
-        """
-        get_emb
-        """
-        return super().__call__(input_ids, text_mask)
-
-    def forward(self, input_ids: torch.Tensor, text_mask: torch.Tensor) -> torch.Tensor:
-        """
-        get_emb
-        """
-
-        emb_text: torch.Tensor = self.emb_text(
-            input_ids[text_mask].narrow(1, 0, 1).squeeze_(1).to(self.device_gpt)
-        )
-
-        text_mask_inv = text_mask.logical_not().to(self.device_gpt)
-        masked_input_ids: torch.Tensor = input_ids[text_mask_inv].to(self.device_gpt)
-
-        emb_code = [
-            self.emb_code[i](masked_input_ids[:, i]) for i in range(self.num_vq)
-        ]
-        emb_code = torch.stack(emb_code, 2).sum(2)
-
-        emb = torch.zeros(
-            (input_ids.shape[:-1]) + (emb_text.shape[-1],),
-            device=emb_text.device,
-            dtype=emb_text.dtype,
-        )
-        emb[text_mask] = emb_text
-        emb[text_mask_inv] = emb_code.to(emb.dtype)
-
-        del emb_text, emb_code, text_mask_inv
-
-        return emb
-

 @dataclass(repr=False, eq=False)
 class _GenerationInputs:
     position_ids: torch.Tensor
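The interesting move in `gpt.py` above is that `GPT` no longer owns the embedding and head weights: it stores the bound `__call__` of each layer of the shared `Embed` module as plain attributes. Because a bound method is not an `nn.Module`, the parameters are registered (and checkpointed) only once, in `Embed`, while `GPT` can still invoke them. A minimal illustration with hypothetical modules:

```python
import torch
import torch.nn as nn


class Inner(nn.Module):
    def __init__(self):
        super().__init__()
        self.emb = nn.Embedding(10, 4)


class Outer(nn.Module):
    def __init__(self, inner: Inner):
        super().__init__()
        # A bound method is a plain attribute, not a registered submodule,
        # so Outer's parameters/state_dict stay free of Inner's weights.
        self.emb = inner.emb.__call__


inner = Inner()
outer = Outer(inner)
print(len(list(outer.parameters())))          # 0 -- weights live in Inner only
print(outer.emb(torch.tensor([1, 2])).shape)  # torch.Size([2, 4])
```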
diff --git a/ChatTTS/model/speaker.py b/ChatTTS/model/speaker.py
index 07001db0b..5435922ab 100644
--- a/ChatTTS/model/speaker.py
+++ b/ChatTTS/model/speaker.py
@@ -18,7 +18,7 @@ def __init__(self, dim: int, spk_cfg: str, device=torch.device("cpu")) -> None:
     def sample_random(self) -> str:
         return self._encode(self._sample_random())

-    @torch.no_grad()
+    @torch.inference_mode()
     def apply(
         self,
         emb: torch.Tensor,
diff --git a/ChatTTS/model/velocity/__init__.py b/ChatTTS/model/velocity/__init__.py
index 866983506..c798a0439 100644
--- a/ChatTTS/model/velocity/__init__.py
+++ b/ChatTTS/model/velocity/__init__.py
@@ -1,3 +1,2 @@
 from .llm import LLM
-from .post_model import PostModel
 from .sampling_params import SamplingParams
diff --git a/ChatTTS/model/velocity/llm.py b/ChatTTS/model/velocity/llm.py
index b473b562c..a37f5cb34 100644
--- a/ChatTTS/model/velocity/llm.py
+++ b/ChatTTS/model/velocity/llm.py
@@ -2,12 +2,12 @@
 from tqdm import tqdm
 from transformers import PreTrainedTokenizer, PreTrainedTokenizerFast
+from vllm.utils import Counter

 from .configs import EngineArgs
 from .llm_engine import LLMEngine
 from .output import RequestOutput
 from .sampling_params import SamplingParams
-from vllm.utils import Counter


 class LLM:
diff --git a/ChatTTS/model/velocity/model_runner.py b/ChatTTS/model/velocity/model_runner.py
index 86a3dfcb6..2f9d7e428 100644
--- a/ChatTTS/model/velocity/model_runner.py
+++ b/ChatTTS/model/velocity/model_runner.py
@@ -22,7 +22,8 @@
     SequenceOutput,
 )
 from vllm.utils import in_wsl
-from .post_model import PostModel, Sampler
+from ..embed import Embed
+from .sampler import Sampler
 from safetensors.torch import safe_open

 logger = init_logger(__name__)
@@ -78,7 +79,7 @@ def __init__(

     def load_model(self) -> None:
         self.model = get_model(self.model_config)
-        self.post_model = PostModel(
+        self.post_model = Embed(
             self.model_config.get_hidden_size(),
             self.model_config.num_audio_tokens,
             self.model_config.num_text_tokens,
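One small but deliberate change above: `Speaker.apply` switches from `@torch.no_grad()` to `@torch.inference_mode()`. Inference mode is the stricter variant: tensors created under it skip autograd version tracking entirely and can never be re-entered into a graph later, which saves a bit of overhead on hot inference paths. A sketch of the observable difference:

```python
import torch


@torch.inference_mode()
def apply_scale(emb: torch.Tensor) -> torch.Tensor:
    return emb * 2.0


out = apply_scale(torch.ones(3, requires_grad=True))
print(out.requires_grad)        # False
print(torch.is_inference(out))  # True -- under plain no_grad this is False
```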
diff --git a/ChatTTS/model/velocity/post_model.py b/ChatTTS/model/velocity/sampler.py
similarity index 64%
rename from ChatTTS/model/velocity/post_model.py
rename to ChatTTS/model/velocity/sampler.py
index 79b8900a4..ec8706975 100644
--- a/ChatTTS/model/velocity/post_model.py
+++ b/ChatTTS/model/velocity/sampler.py
@@ -1,78 +1,12 @@
-import os
-
-os.environ["TOKENIZERS_PARALLELISM"] = "false"
-"""
-https://stackoverflow.com/questions/62691279/how-to-disable-tokenizers-parallelism-true-false-warning
-"""
-
 import torch
-import torch.nn as nn
 from torch.functional import F
-from torch.nn.utils.parametrizations import weight_norm
 from typing import List, Callable

-
-class PostModel(nn.Module):
-    def __init__(
-        self, hidden_size: int, num_audio_tokens: int, num_text_tokens: int, num_vq=4
-    ):
-        super().__init__()
-
-        self.num_vq = num_vq
-        self.num_audio_tokens = num_audio_tokens
-
-        self.model_dim = hidden_size
-        self.emb_code = nn.ModuleList(
-            [nn.Embedding(num_audio_tokens, self.model_dim) for _ in range(num_vq)],
-        )
-        self.emb_text = nn.Embedding(num_text_tokens, self.model_dim)
-
-        self.head_text = weight_norm(
-            nn.Linear(self.model_dim, num_text_tokens, bias=False),
-            name="weight",
-        )
-        self.head_code = nn.ModuleList(
-            [
-                weight_norm(
-                    nn.Linear(self.model_dim, num_audio_tokens, bias=False),
-                    name="weight",
-                )
-                for _ in range(self.num_vq)
-            ],
-        )
-
-    def forward(self, input_ids: torch.Tensor, text_mask: torch.Tensor) -> torch.Tensor:
-        """
-        get_emb
-        """
-        device = next(self.parameters()).device
-        emb_text: torch.Tensor = self.emb_text(
-            input_ids[text_mask].narrow(1, 0, 1).squeeze_(1).to(device)
-        )
-
-        text_mask_inv = text_mask.logical_not().to(device)
-        masked_input_ids: torch.Tensor = input_ids[text_mask_inv].to(device)
-
-        emb_code = [
-            self.emb_code[i](masked_input_ids[:, i]) for i in range(self.num_vq)
-        ]
-        emb_code = torch.stack(emb_code, 2).sum(2)
-
-        emb = torch.zeros(
-            (input_ids.shape[:-1]) + (emb_text.shape[-1],),
-            device=emb_text.device,
-            dtype=emb_text.dtype,
-        )
-        emb[text_mask] = emb_text
-        emb[text_mask_inv] = emb_code.to(emb.dtype)
-
-        del emb_text, emb_code, text_mask_inv
-
-        return emb
+from ..embed import Embed


 class Sampler:
-    def __init__(self, post_model: PostModel, num_audio_tokens: int, num_vq: int):
+    def __init__(self, post_model: Embed, num_audio_tokens: int, num_vq: int):
         self.post_model = post_model
         self.device = next(self.post_model.parameters()).device
         self.num_audio_tokens = num_audio_tokens
diff --git a/ChatTTS/model/velocity/worker.py b/ChatTTS/model/velocity/worker.py
index 90aca7f32..294c77d37 100644
--- a/ChatTTS/model/velocity/worker.py
+++ b/ChatTTS/model/velocity/worker.py
@@ -12,6 +12,7 @@
 from vllm.model_executor.parallel_utils.parallel_state import initialize_model_parallel
 from vllm.sequence import SamplerOutput, SequenceGroupMetadata
 from vllm.worker.cache_engine import CacheEngine
+
 from .model_runner import ModelRunner
diff --git a/ChatTTS/res/sha256_map.json b/ChatTTS/res/sha256_map.json
index ef47afb3d..b58273eaa 100644
--- a/ChatTTS/res/sha256_map.json
+++ b/ChatTTS/res/sha256_map.json
@@ -1,8 +1,11 @@
 {
-    "sha256_asset_Decoder_pt" : "9964e36e840f0e3a748c5f716fe6de6490d2135a5f5155f4a642d51860e2ec38",
+    "sha256_asset_Decoder_pt"  : "9964e36e840f0e3a748c5f716fe6de6490d2135a5f5155f4a642d51860e2ec38",
     "sha256_asset_DVAE_full_pt" : "553eb75763511e23f3e5f86303e2163c5ca775489d637fb635d979c8ae58bbe5",
-    "sha256_asset_GPT_pt" : "d7d4ee6461ea097a2be23eb40d73fb94ad3b3d39cb64fbb50cb3357fd466cadb",
-    "sha256_asset_Vocos_pt" : "09a670eda1c08b740013679c7a90ebb7f1a97646ea7673069a6838e6b51d6c58",
+    "sha256_asset_Embed_safetensors" : "2ff0be7134934155741b643b74e32fb6bf3eec41257984459b2ed60cdb4c48b0",
+    "sha256_asset_Vocos_pt" : "09a670eda1c08b740013679c7a90ebb7f1a97646ea7673069a6838e6b51d6c58",
+
+    "sha256_asset_gpt_config_json" : "0aaa1ecd96c49ad4f473459eb1982fa7ad79fa5de08cde2781bf6ad1f9a0c236",
+    "sha256_asset_gpt_model_safetensors" : "cd0806fd971f52f6a22c923ec64982b305e817bcc41ca83417fcf9141b984a0f",

     "sha256_asset_tokenizer_special_tokens_map_json": "bd0ac9d9bb1657996b5c5fbcaa7d80f8de530d01a283da97f89deae5b1b8d011",
     "sha256_asset_tokenizer_tokenizer_config_json"  : "43e9d658b554fa5ee8d8e1d763349323bfef1ed7a89c0794220ab8861387d421",
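Each entry in `sha256_map.json` above pairs an asset path (dots and slashes flattened to underscores) with its SHA-256 digest. Verifying a file against the map takes a few lines of `hashlib`; this is a sketch in the spirit of `check_model` in `ChatTTS/utils/dl.py`, not its exact code:

```python
import hashlib
from pathlib import Path


def sha256_of(path: Path) -> str:
    h = hashlib.sha256()
    with path.open("rb") as f:
        for chunk in iter(lambda: f.read(1 << 20), b""):  # hash 1 MiB at a time
            h.update(chunk)
    return h.hexdigest()


expected = "2ff0be7134934155741b643b74e32fb6bf3eec41257984459b2ed60cdb4c48b0"
print(sha256_of(Path("asset/Embed.safetensors")) == expected)
```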
"bd0ac9d9bb1657996b5c5fbcaa7d80f8de530d01a283da97f89deae5b1b8d011", "sha256_asset_tokenizer_tokenizer_config_json" : "43e9d658b554fa5ee8d8e1d763349323bfef1ed7a89c0794220ab8861387d421", diff --git a/ChatTTS/utils/dl.py b/ChatTTS/utils/dl.py index da21daa0b..cd67f6532 100644 --- a/ChatTTS/utils/dl.py +++ b/ChatTTS/utils/dl.py @@ -3,7 +3,7 @@ import hashlib import requests from io import BytesIO -from typing import Dict +from typing import Dict, Tuple, Optional from mmap import mmap, ACCESS_READ from .log import logger @@ -43,45 +43,81 @@ def check_model( return True -def check_all_assets(base_dir: Path, sha256_map: Dict[str, str], update=False) -> bool: - logger.get_logger().info("checking assets...") +def check_folder( + base_dir: Path, + *innder_dirs: str, + names: Tuple[str], + sha256_map: Dict[str, str], + update=False, +) -> bool: + key = "sha256_" + current_dir = base_dir + for d in innder_dirs: + current_dir /= d + key += f"{d}_" - current_dir = base_dir / "asset" - names = [ - "Decoder.pt", - "DVAE_full.pt", - "GPT.pt", - "Vocos.pt", - ] for model in names: menv = model.replace(".", "_") - if not check_model( - current_dir, model, sha256_map[f"sha256_asset_{menv}"], update - ): + if not check_model(current_dir, model, sha256_map[f"{key}{menv}"], update): return False + return True - current_dir = base_dir / "asset" / "tokenizer" - names = [ - "special_tokens_map.json", - "tokenizer_config.json", - "tokenizer.json", - ] - for model in names: - menv = model.replace(".", "_") - if not check_model( - current_dir, model, sha256_map[f"sha256_asset_tokenizer_{menv}"], update - ): - return False + +def check_all_assets(base_dir: Path, sha256_map: Dict[str, str], update=False) -> bool: + logger.get_logger().info("checking assets...") + + if not check_folder( + base_dir, + "asset", + names=( + "Decoder.pt", + "DVAE_full.pt", + "Embed.safetensors", + "Vocos.pt", + ), + sha256_map=sha256_map, + update=update, + ): + return False + + if not check_folder( + base_dir, + "asset", + "gpt", + names=( + "config.json", + "model.safetensors", + ), + sha256_map=sha256_map, + update=update, + ): + return False + + if not check_folder( + base_dir, + "asset", + "tokenizer", + names=( + "special_tokens_map.json", + "tokenizer_config.json", + "tokenizer.json", + ), + sha256_map=sha256_map, + update=update, + ): + return False logger.get_logger().info("all assets are already latest.") return True -def download_and_extract_tar_gz(url: str, folder: str): +def download_and_extract_tar_gz( + url: str, folder: str, headers: Optional[Dict[str, str]] = None +): import tarfile logger.get_logger().info(f"downloading {url}") - response = requests.get(url, stream=True, timeout=(5, 10)) + response = requests.get(url, headers=headers, stream=True, timeout=(10, 3)) with BytesIO() as out_file: out_file.write(response.content) out_file.seek(0) @@ -91,11 +127,13 @@ def download_and_extract_tar_gz(url: str, folder: str): logger.get_logger().info(f"extracted into {folder}") -def download_and_extract_zip(url: str, folder: str): +def download_and_extract_zip( + url: str, folder: str, headers: Optional[Dict[str, str]] = None +): import zipfile logger.get_logger().info(f"downloading {url}") - response = requests.get(url, stream=True, timeout=(5, 10)) + response = requests.get(url, headers=headers, stream=True, timeout=(10, 3)) with BytesIO() as out_file: out_file.write(response.content) out_file.seek(0) @@ -105,15 +143,15 @@ def download_and_extract_zip(url: str, folder: str): logger.get_logger().info(f"extracted into {folder}") -def 
diff --git a/examples/cmd/run.py b/examples/cmd/run.py
index f11a62f5f..389890e80 100644
--- a/examples/cmd/run.py
+++ b/examples/cmd/run.py
@@ -1,23 +1,23 @@
+import os, sys
+
+if sys.platform == "darwin":
+    os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1"
+
+now_dir = os.getcwd()
+sys.path.append(now_dir)
+
 from typing import Optional, List
 import argparse
-import os
-import sys

 import numpy as np
-import torch

 import ChatTTS
+
 from tools.logger import get_logger
 from tools.audio import pcm_arr_to_mp3_view
 from tools.normalizer.en import normalizer_en_nemo_text
 from tools.normalizer.zh import normalizer_zh_tn

-if sys.platform == "darwin":
-    os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1"
-
-now_dir = os.getcwd()
-sys.path.append(now_dir)
-
 logger = get_logger("Command")

@@ -66,9 +66,9 @@ def main(
     is_load = False
     if os.path.isdir(custom_path) and source == "custom":
-        is_load = chat.load(compile=True, source="custom", custom_path=custom_path)
+        is_load = chat.load(source="custom", custom_path=custom_path)
     else:
-        is_load = chat.load(compile=True, source=source)
+        is_load = chat.load(source=source)

     if is_load:
         logger.info("Models loaded successfully.")
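The reshuffled prologue in `run.py` matters because `import ChatTTS` and the `tools.*` imports only resolve once the repository root is on `sys.path`, and the environment variable must be set before torch initializes MPS. Condensed, the pattern the example scripts now share (run from the repo root, e.g. `python examples/cmd/run.py ...`):

```python
import os, sys

if sys.platform == "darwin":
    # route unsupported MPS ops to CPU instead of failing on Apple Silicon
    os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1"

sys.path.append(os.getcwd())  # make ChatTTS/ and tools/ importable

import ChatTTS  # noqa: E402 -- deliberately after the path setup
```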
diff --git a/examples/onnx/exporter.py b/examples/onnx/exporter.py
index 2276a7c98..351fed1de 100644
--- a/examples/onnx/exporter.py
+++ b/examples/onnx/exporter.py
@@ -1,6 +1,13 @@
+import os, sys
+
+if sys.platform == "darwin":
+    os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1"
+
+now_dir = os.getcwd()
+sys.path.append(now_dir)
+
 from dataclasses import asdict
 import argparse
-import os

 import torch
 from tqdm import tqdm
 from ChatTTS.model.dvae import DVAE
@@ -8,7 +15,8 @@
 from vocos import Vocos
 from vocos.pretrained import instantiate_class
 import torch.jit as jit
-from examples.onnx.gpt import GPT
+
+from gpt import GPT

 # disable cuda
 torch.cuda.is_available = lambda: False
diff --git a/examples/onnx/gpt.py b/examples/onnx/gpt.py
index c026d1a79..42b6a83da 100644
--- a/examples/onnx/gpt.py
+++ b/examples/onnx/gpt.py
@@ -4,7 +4,8 @@
 import torch
 import torch.nn as nn
 from torch.nn.utils.parametrizations import weight_norm
-from .modeling_llama import LlamaModel, LlamaConfig
+
+from modeling_llama import LlamaModel, LlamaConfig


 class GPT(nn.Module):
diff --git a/examples/web/funcs.py b/examples/web/funcs.py
index 4885e5331..b8fcd4f4a 100644
--- a/examples/web/funcs.py
+++ b/examples/web/funcs.py
@@ -1,4 +1,3 @@
-import sys
 import random
 from typing import Optional
 from time import sleep
@@ -61,12 +60,10 @@ def on_audio_seed_change(audio_seed_input):

 def load_chat(cust_path: Optional[str], coef: Optional[str]) -> bool:
     if cust_path == None:
-        ret = chat.load(coef=coef, compile=sys.platform != "win32")
+        ret = chat.load(coef=coef)
     else:
         logger.info("local model path: %s", cust_path)
-        ret = chat.load(
-            "custom", custom_path=cust_path, coef=coef, compile=sys.platform != "win32"
-        )
+        ret = chat.load("custom", custom_path=cust_path, coef=coef)
         global custom_path
         custom_path = cust_path
     if ret:
diff --git a/examples/web/webui.py b/examples/web/webui.py
index 267fa6c33..5a556b050 100644
--- a/examples/web/webui.py
+++ b/examples/web/webui.py
@@ -10,8 +10,8 @@

 import gradio as gr

-from examples.web.funcs import *
-from examples.web.ex import ex
+from funcs import *
+from ex import ex


 def main():
diff --git a/tools/checksum/tmpl.go b/tools/checksum/tmpl.go
index cfee203ed..984c7706c 100644
--- a/tools/checksum/tmpl.go
+++ b/tools/checksum/tmpl.go
@@ -3,19 +3,25 @@ package main
 var files = [...]string{
 	"asset/Decoder.pt",
 	"asset/DVAE_full.pt",
-	"asset/GPT.pt",
+	"asset/Embed.safetensors",
 	"asset/Vocos.pt",
+	"asset/gpt/config.json",
+	"asset/gpt/model.safetensors",
+
 	"asset/tokenizer/special_tokens_map.json",
 	"asset/tokenizer/tokenizer_config.json",
 	"asset/tokenizer/tokenizer.json",
 }

 const jsontmpl = `{
-	"sha256_asset_Decoder_pt" : "%s",
+	"sha256_asset_Decoder_pt"  : "%s",
 	"sha256_asset_DVAE_full_pt" : "%s",
-	"sha256_asset_GPT_pt" : "%s",
-	"sha256_asset_Vocos_pt" : "%s",
+	"sha256_asset_Embed_safetensors" : "%s",
+	"sha256_asset_Vocos_pt" : "%s",
+
+	"sha256_asset_gpt_config_json" : "%s",
+	"sha256_asset_gpt_model_safetensors" : "%s",

 	"sha256_asset_tokenizer_special_tokens_map_json": "%s",
 	"sha256_asset_tokenizer_tokenizer_config_json"  : "%s",