From 8e6184e1c2b82b556b6d3c8e19447cf876bc25dd Mon Sep 17 00:00:00 2001 From: ylzz1997 Date: Sun, 21 Jul 2024 03:36:40 +0800 Subject: [PATCH 01/27] add feat: Add ChatTTS vLLM Wrapper --- ChatTTS/core.py | 297 +++++---- ChatTTS/vllm_engine/__init__.py | 0 ChatTTS/vllm_engine/block_manager.py | 293 +++++++++ ChatTTS/vllm_engine/configs.py | 787 ++++++++++++++++++++++++ ChatTTS/vllm_engine/llama.py | 382 ++++++++++++ ChatTTS/vllm_engine/llm.py | 212 +++++++ ChatTTS/vllm_engine/llm_engine.py | 811 +++++++++++++++++++++++++ ChatTTS/vllm_engine/model_loader.py | 67 ++ ChatTTS/vllm_engine/model_runner.py | 769 +++++++++++++++++++++++ ChatTTS/vllm_engine/output.py | 127 ++++ ChatTTS/vllm_engine/post_model.py | 201 ++++++ ChatTTS/vllm_engine/sampling_params.py | 273 +++++++++ ChatTTS/vllm_engine/scheduler.py | 413 +++++++++++++ ChatTTS/vllm_engine/sequence.py | 436 +++++++++++++ ChatTTS/vllm_engine/worker.py | 237 ++++++++ test.py | 27 + 16 files changed, 5208 insertions(+), 124 deletions(-) create mode 100644 ChatTTS/vllm_engine/__init__.py create mode 100644 ChatTTS/vllm_engine/block_manager.py create mode 100644 ChatTTS/vllm_engine/configs.py create mode 100644 ChatTTS/vllm_engine/llama.py create mode 100644 ChatTTS/vllm_engine/llm.py create mode 100644 ChatTTS/vllm_engine/llm_engine.py create mode 100644 ChatTTS/vllm_engine/model_loader.py create mode 100644 ChatTTS/vllm_engine/model_runner.py create mode 100644 ChatTTS/vllm_engine/output.py create mode 100644 ChatTTS/vllm_engine/post_model.py create mode 100644 ChatTTS/vllm_engine/sampling_params.py create mode 100644 ChatTTS/vllm_engine/scheduler.py create mode 100644 ChatTTS/vllm_engine/sequence.py create mode 100644 ChatTTS/vllm_engine/worker.py create mode 100644 test.py diff --git a/ChatTTS/core.py b/ChatTTS/core.py index fb00ba00d..ca16f0af1 100644 --- a/ChatTTS/core.py +++ b/ChatTTS/core.py @@ -6,15 +6,18 @@ from json import load from pathlib import Path import lzma - +import pathlib +from ChatTTS.vllm_engine.post_model import Post_model +from safetensors.torch import save_file, safe_open import numpy as np import torch from vocos import Vocos from vocos.pretrained import instantiate_class from huggingface_hub import snapshot_download import pybase16384 as b14 - -from .config import Config +from ChatTTS.vllm_engine.llm import LLM +from ChatTTS.vllm_engine.sampling_params import SamplingParams +import yaml from .model import DVAE, GPT, gen_logits, Tokenizer from .utils import ( check_all_assets, @@ -167,7 +170,7 @@ def sample_audio_speaker(self, wav: Union[np.ndarray, torch.Tensor]) -> str: @torch.no_grad() def _sample_random_speaker(self) -> torch.Tensor: - dim: int = self.gpt.gpt.layers[0].mlp.gate_proj.in_features + dim: int = self.hidden_size spk = ( torch.randn(dim, device=self.std.device, dtype=self.std.dtype) .mul_(self.std) @@ -266,56 +269,64 @@ def _load( if "mps" in str(device) else device ) - .eval() - ) - assert vocos_ckpt_path, "vocos_ckpt_path should not be None" - vocos.load_state_dict(torch.load(vocos_ckpt_path, weights_only=True, mmap=True)) - self.vocos = vocos - self.logger.log(logging.INFO, "vocos loaded.") - - dvae = ( - DVAE( - decoder_config=asdict(self.config.dvae.decoder), - encoder_config=asdict(self.config.dvae.encoder), - vq_config=asdict(self.config.dvae.vq), - dim=self.config.dvae.decoder.idim, - coef=coef, + self.dvae = dvae + self.logger.log(logging.INFO, "dvae loaded.") + + if gpt_config_path: + cfg = OmegaConf.load(gpt_config_path) + self.num_vq = 4 + if not os.path.exists("asset/vllm_model"): + gpt = 
GPT( + **cfg, use_flash_attn=use_flash_attn, device=device, logger=self.logger + ).eval() + assert gpt_ckpt_path, "gpt_ckpt_path should not be None" + gpt.load_state_dict(torch.load(gpt_ckpt_path, weights_only=True, mmap=True)) + gpt.prepare(compile=compile and "cuda" in str(device)) + self.gpt = gpt + pathlib.Path("asset/vllm_model").mkdir(parents=True, exist_ok=True) + self.gpt.gpt.save_pretrained("asset/vllm_model/gpt") + self.post_model = Post_model( + cfg.gpt_config.hidden_size, + cfg.num_audio_tokens, + cfg.num_text_tokens, + device = device + ).to(device).eval() + + self.post_model.emb_code = self.gpt.emb_code + self.post_model.emb_text = self.gpt.emb_text + self.post_model.head_text = self.gpt.head_text + self.post_model.head_code = self.gpt.head_code + save_file(self.post_model.state_dict(), "asset/vllm_model/post_model.safetensors") + + self.num_audio_tokens = cfg.num_audio_tokens + spk_stat_path = os.path.join(os.path.dirname(gpt_ckpt_path), "spk_stat.pt") + assert os.path.exists( + spk_stat_path + ), f"Missing spk_stat.pt: {spk_stat_path}" + spk_stat: torch.Tensor = torch.load( + spk_stat_path, + weights_only=True, + mmap=True, + map_location=device, ) - .to(device) - .eval() - ) - coef = str(dvae) - assert dvae_ckpt_path, "dvae_ckpt_path should not be None" - dvae.load_state_dict(torch.load(dvae_ckpt_path, weights_only=True, mmap=True)) - self.dvae = dvae - self.logger.log(logging.INFO, "dvae loaded.") - - gpt = GPT( - gpt_config=asdict(self.config.gpt), - use_flash_attn=use_flash_attn, - device=device, - logger=self.logger, - ).eval() - assert gpt_ckpt_path, "gpt_ckpt_path should not be None" - gpt.from_pretrained(gpt_ckpt_path) - gpt.prepare(compile=compile and "cuda" in str(device)) - self.gpt = gpt - spk_stat_path = os.path.join(os.path.dirname(gpt_ckpt_path), "spk_stat.pt") - assert os.path.exists(spk_stat_path), f"Missing spk_stat.pt: {spk_stat_path}" - spk_stat: torch.Tensor = torch.load( - spk_stat_path, - weights_only=True, - mmap=True, - map_location=device, - ) - self.std, self.mean = spk_stat.requires_grad_(False).chunk(2) - self.logger.log(logging.INFO, "gpt loaded.") - - decoder = ( - DVAE( - decoder_config=asdict(self.config.decoder), - dim=self.config.decoder.idim, - coef=coef, + self.std, self.mean = spk_stat.requires_grad_(False).chunk(2) + self.logger.log(logging.INFO, "gpt loaded.") + + self.hidden_size = cfg.gpt_config.hidden_size + self.gpt = LLM( + model="asset/vllm_model/gpt", + num_audio_tokens = cfg.num_audio_tokens, + num_text_tokens = cfg.num_text_tokens, + post_model_path="asset/vllm_model/post_model.safetensors", + ) + + if decoder_config_path: + cfg = OmegaConf.load(decoder_config_path) + decoder = DVAE(**cfg, coef=coef).to(device).eval() + coef = str(decoder) + assert decoder_ckpt_path, "decoder_ckpt_path should not be None" + decoder.load_state_dict( + torch.load(decoder_ckpt_path, weights_only=True, mmap=True) ) .to(device) .eval() @@ -335,7 +346,7 @@ def _load( self.coef = coef return self.has_loaded() - + def _infer( self, text, @@ -451,6 +462,55 @@ def _decode_to_wavs( del mel_specs return wavs + @staticmethod + def _decode_spk_emb(spk_emb: str) -> np.ndarray: + return np.frombuffer( + lzma.decompress( + b14.decode_from_string(spk_emb), + format=lzma.FORMAT_RAW, + filters=[{"id": lzma.FILTER_LZMA2, "preset": 9 | lzma.PRESET_EXTREME}], + ), + dtype=np.float16, + ).copy() + + @torch.no_grad() + def _apply_spk_emb( + self, + emb: torch.Tensor, + spk_emb: str, + input_ids: torch.Tensor, + ): + n = ( + F.normalize( + torch.from_numpy( + 
self._decode_spk_emb(spk_emb), + ), + p=2.0, + dim=0, + eps=1e-12, + ) + .to(self.gpt.device_gpt) + .unsqueeze_(0) + .expand(emb.size(0), -1) + .unsqueeze_(1) + .expand(emb.shape) + ) + cond = ( + input_ids.narrow(-1, 0, 1).eq(self.tokenizer.spk_emb_ids).expand(emb.shape) + ) + torch.where(cond, n, emb, out=emb) + del cond, n + @dataclass(repr=False, eq=False) + class GenerationOutputs: + ids: List[torch.Tensor] + # attentions: List[Optional[Tuple[torch.FloatTensor, ...]]] + hiddens: List[torch.Tensor] + + def destroy(self): + del_all(self.ids) + # del_all(self.attentions) + # del_all(self.hiddens) + @torch.no_grad() def _infer_code( self, @@ -461,7 +521,7 @@ def _infer_code( params: InferCodeParams, ): - gpt = self.gpt + gpt: LLM = self.gpt if not isinstance(text, list): text = [text] @@ -469,7 +529,7 @@ def _infer_code( assert len(text), "text should not be empty" if not isinstance(params.temperature, list): - temperature = [params.temperature] * gpt.num_vq + temperature = [params.temperature] * self.num_vq else: temperature = params.temperature @@ -494,22 +554,11 @@ def _infer_code( text = [f"[Stts][empty_spk]{txt_smp}{i}[Ptts]" for i in text] input_ids, attention_mask, text_mask = self.tokenizer.encode( - text, - self.gpt.num_vq, - prompt_str=params.spk_smp, - device=gpt.device_gpt, + text, self.num_vq, self.device ) - - emb = gpt(input_ids, text_mask) - - del text_mask - - if params.spk_emb is not None: - self.tokenizer.apply_spk_emb( - emb, params.spk_emb, input_ids, self.gpt.device_gpt - ) - - num_code = int(gpt.emb_code[0].num_embeddings - 1) + start_idx = input_ids.shape[-2] + + num_code = self.num_audio_tokens - 1 logits_warpers, logits_processors = gen_logits( num_code=num_code, @@ -517,31 +566,34 @@ def _infer_code( top_K=params.top_K, repetition_penalty=params.repetition_penalty, ) - - result = gpt.generate( - emb, - input_ids, - temperature=torch.tensor(temperature, device=device), - eos_token=num_code, - attention_mask=attention_mask, + + sample_params = SamplingParams( + temperature=temperature, max_new_token=params.max_new_token, + max_tokens = 8192, min_new_token=params.min_new_token, - logits_warpers=logits_warpers, - logits_processors=logits_processors, + logits_processors=(logits_warpers, logits_processors), + eos_token = num_code, infer_text=False, - return_hidden=return_hidden, - stream=stream, - show_tqdm=params.show_tqdm, - ensure_non_empty=params.ensure_non_empty, - stream_batch=params.stream_batch, - context=self.context, + start_idx=start_idx ) - - del emb, input_ids - del_all(logits_warpers) - del_all(logits_processors) - - return result + input_ids = [i.tolist() for i in input_ids] + + result = gpt.generate( + None, + sample_params, + input_ids, + ) + + token_ids = [] + hidden_states = [] + for i in result: + token_ids.append(torch.tensor(i.outputs[0].token_ids)) + hidden_states.append(i.outputs[0].hidden_states.to(torch.float32).to(self.device)) + return [self.GenerationOutputs( + ids=token_ids, + hiddens=hidden_states + ),] @torch.no_grad() def _refine_text( @@ -551,7 +603,7 @@ def _refine_text( params: RefineTextParams, ): - gpt = self.gpt + gpt:LLM = self.gpt if not isinstance(text, list): text = [text] @@ -559,11 +611,10 @@ def _refine_text( text = [f"[Sbreak]{i}[Pbreak]{params.prompt}" for i in text] input_ids, attention_mask, text_mask = self.tokenizer.encode( - text, - self.gpt.num_vq, - device=gpt.device_gpt, + text, self.num_vq, self.device ) - + start_idx = input_ids.shape[-2] + # print(start_idx) logits_warpers, logits_processors = gen_logits( 
num_code=self.tokenizer.len, top_P=params.top_P, @@ -571,31 +622,29 @@ def _refine_text( repetition_penalty=params.repetition_penalty, ) - emb = gpt(input_ids, text_mask) - - del text_mask - - result = next( - gpt.generate( - emb, - input_ids, - temperature=torch.tensor([params.temperature], device=device), - eos_token=self.tokenizer.eos_token, - attention_mask=attention_mask, - max_new_token=params.max_new_token, - min_new_token=params.min_new_token, - logits_warpers=logits_warpers, - logits_processors=logits_processors, - infer_text=True, - stream=False, - show_tqdm=params.show_tqdm, - ensure_non_empty=params.ensure_non_empty, - context=self.context, - ) + sample_params = SamplingParams( + temperature=params.temperature, + max_new_token=params.max_new_token, + max_tokens = 8192, + min_new_token=params.min_new_token, + logits_processors=(logits_warpers, logits_processors), + eos_token = self.tokenizer.eos_token, + infer_text=True, + start_idx=start_idx + ) + input_ids = [i.tolist() for i in input_ids] + + result = gpt.generate( + None, + sample_params, + input_ids + ) + token_ids = [] + hidden_states = [] + for i in result: + token_ids.append(torch.tensor(i.outputs[0].token_ids)) + hidden_states.append(i.outputs[0].hidden_states) + return self.GenerationOutputs( + ids=token_ids, + hiddens=hidden_states ) - - del emb, input_ids - del_all(logits_warpers) - del_all(logits_processors) - - return result diff --git a/ChatTTS/vllm_engine/__init__.py b/ChatTTS/vllm_engine/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/ChatTTS/vllm_engine/block_manager.py b/ChatTTS/vllm_engine/block_manager.py new file mode 100644 index 000000000..b95cdc840 --- /dev/null +++ b/ChatTTS/vllm_engine/block_manager.py @@ -0,0 +1,293 @@ +"""A block manager that manages token blocks.""" +import enum +from typing import Dict, List, Optional, Set, Tuple + +from vllm.block import PhysicalTokenBlock +from ChatTTS.vllm_engine.sequence import Sequence, SequenceGroup, SequenceStatus +from vllm.utils import Device + +# Mapping: logical block number -> physical block. +BlockTable = List[PhysicalTokenBlock] + + +class BlockAllocator: + """Manages free physical token blocks for a device. + + The allocator maintains a list of free blocks and allocates a block when + requested. When a block is freed, its reference count is decremented. If + the reference count becomes zero, the block is added back to the free list. + """ + + def __init__( + self, + device: Device, + block_size: int, + num_blocks: int, + ) -> None: + self.device = device + self.block_size = block_size + self.num_blocks = num_blocks + + # Initialize the free blocks. + self.free_blocks: BlockTable = [] + for i in range(num_blocks): + block = PhysicalTokenBlock(device=device, + block_number=i, + block_size=block_size) + self.free_blocks.append(block) + + def allocate(self) -> PhysicalTokenBlock: + if not self.free_blocks: + raise ValueError("Out of memory! No free blocks are available.") + block = self.free_blocks.pop() + block.ref_count = 1 + return block + + def free(self, block: PhysicalTokenBlock) -> None: + if block.ref_count == 0: + raise ValueError(f"Double free! {block} is already freed.") + block.ref_count -= 1 + if block.ref_count == 0: + self.free_blocks.append(block) + + def get_num_free_blocks(self) -> int: + return len(self.free_blocks) + + +class AllocStatus(enum.Enum): + """Result for BlockSpaceManager.can_allocate + + 1. Ok: seq_group can be allocated now. + 2. Later: seq_group cannot be allocated. 
+ The capacity of allocator is larger than seq_group required. + 3. Never: seq_group can never be allocated. + The seq_group is too large to allocated in GPU. + """ + OK = enum.auto() + LATER = enum.auto() + NEVER = enum.auto() + + +class BlockSpaceManager: + """Manages the mapping between logical and physical token blocks.""" + + def __init__( + self, + block_size: int, + num_gpu_blocks: int, + num_cpu_blocks: int, + watermark: float = 0.01, + sliding_window: Optional[int] = None, + ) -> None: + self.block_size = block_size + self.num_total_gpu_blocks = num_gpu_blocks + self.num_total_cpu_blocks = num_cpu_blocks + + self.block_sliding_window = None + if sliding_window is not None: + assert sliding_window % block_size == 0, (sliding_window, + block_size) + self.block_sliding_window = sliding_window // block_size + + self.watermark = watermark + assert watermark >= 0.0 + + self.watermark_blocks = int(watermark * num_gpu_blocks) + self.gpu_allocator = BlockAllocator(Device.GPU, block_size, + num_gpu_blocks) + self.cpu_allocator = BlockAllocator(Device.CPU, block_size, + num_cpu_blocks) + # Mapping: seq_id -> BlockTable. + self.block_tables: Dict[int, BlockTable] = {} + + def can_allocate(self, seq_group: SequenceGroup) -> AllocStatus: + # FIXME(woosuk): Here we assume that all sequences in the group share + # the same prompt. This may not be true for preempted sequences. + seq = seq_group.get_seqs(status=SequenceStatus.WAITING)[0] + num_required_blocks = len(seq.logical_token_blocks) + if self.block_sliding_window is not None: + num_required_blocks = min(num_required_blocks, + self.block_sliding_window) + num_free_gpu_blocks = self.gpu_allocator.get_num_free_blocks() + + # Use watermark to avoid frequent cache eviction. + if (self.num_total_gpu_blocks - num_required_blocks < + self.watermark_blocks): + return AllocStatus.NEVER + if num_free_gpu_blocks - num_required_blocks >= self.watermark_blocks: + return AllocStatus.OK + else: + return AllocStatus.LATER + + def allocate(self, seq_group: SequenceGroup) -> None: + # NOTE: Here we assume that all sequences in the group have the same + # prompt. + seq = seq_group.get_seqs(status=SequenceStatus.WAITING)[0] + + # Allocate new physical token blocks that will store the prompt tokens. + block_table: BlockTable = [] + for logical_idx in range(len(seq.logical_token_blocks)): + if (self.block_sliding_window is not None + and logical_idx >= self.block_sliding_window): + block = block_table[logical_idx % self.block_sliding_window] + else: + block = self.gpu_allocator.allocate() + # Set the reference counts of the token blocks. + block.ref_count = seq_group.num_seqs() + block_table.append(block) + + # Assign the block table for each sequence. + for seq in seq_group.get_seqs(status=SequenceStatus.WAITING): + self.block_tables[seq.seq_id] = block_table.copy() + + def can_append_slot(self, seq_group: SequenceGroup) -> bool: + # Simple heuristic: If there is at least one free block + # for each sequence, we can append. 
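        # Illustrative example (not in the original code): with 4 RUNNING
        # sequences and 6 free GPU blocks this returns True (4 <= 6); with
        # 8 RUNNING sequences and 6 free blocks it returns False.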
+ num_free_gpu_blocks = self.gpu_allocator.get_num_free_blocks() + num_seqs = seq_group.num_seqs(status=SequenceStatus.RUNNING) + return num_seqs <= num_free_gpu_blocks + + def append_slot(self, seq: Sequence) -> Optional[Tuple[int, int]]: + """Allocate a physical slot for a new token.""" + logical_blocks = seq.logical_token_blocks + block_table = self.block_tables[seq.seq_id] + + if len(block_table) < len(logical_blocks): + if (self.block_sliding_window + and len(block_table) >= self.block_sliding_window): + # re-use a block + block_table.append(block_table[len(block_table) % + self.block_sliding_window]) + else: + # The sequence has a new logical block. + # Allocate a new physical block. + block = self.gpu_allocator.allocate() + block_table.append(block) + return None + + # We want to append the token to the last physical block. + last_block = block_table[-1] + assert last_block.device == Device.GPU + if last_block.ref_count == 1: + # Not shared with other sequences. Appendable. + return None + else: + # The last block is shared with other sequences. + # Copy on Write: Allocate a new block and copy the tokens. + new_block = self.gpu_allocator.allocate() + block_table[-1] = new_block + self.gpu_allocator.free(last_block) + return last_block.block_number, new_block.block_number + + def fork(self, parent_seq: Sequence, child_seq: Sequence) -> None: + # NOTE: fork does not allocate a new physical block. + # Thus, it is always safe from OOM. + src_block_table = self.block_tables[parent_seq.seq_id] + self.block_tables[child_seq.seq_id] = src_block_table.copy() + for block in src_block_table: + block.ref_count += 1 + + def _get_physical_blocks( + self, seq_group: SequenceGroup) -> List[PhysicalTokenBlock]: + # NOTE: Here, we assume that the physical blocks are only shared by + # the sequences in the same group. + blocks: Set[PhysicalTokenBlock] = set() + for seq in seq_group.get_seqs(): + if seq.is_finished(): + continue + blocks.update(self.block_tables[seq.seq_id]) + return list(blocks) + + def can_swap_in(self, seq_group: SequenceGroup) -> bool: + blocks = self._get_physical_blocks(seq_group) + num_swapped_seqs = seq_group.num_seqs(status=SequenceStatus.SWAPPED) + num_free_blocks = self.gpu_allocator.get_num_free_blocks() + # NOTE: Conservatively, we assume that every sequence will allocate + # at least one free block right after the swap-in. + # NOTE: This should match the logic in can_append_slot(). + num_required_blocks = len(blocks) + num_swapped_seqs + return num_free_blocks - num_required_blocks >= self.watermark_blocks + + def swap_in(self, seq_group: SequenceGroup) -> Dict[int, int]: + # CPU block -> GPU block. + mapping: Dict[PhysicalTokenBlock, PhysicalTokenBlock] = {} + for seq in seq_group.get_seqs(status=SequenceStatus.SWAPPED): + new_block_table: BlockTable = [] + block_table = self.block_tables[seq.seq_id] + + for cpu_block in block_table: + if cpu_block in mapping: + gpu_block = mapping[cpu_block] + gpu_block.ref_count += 1 + else: + gpu_block = self.gpu_allocator.allocate() + mapping[cpu_block] = gpu_block + new_block_table.append(gpu_block) + # Free the CPU block swapped in to GPU. 
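                # (Note: free() only decrements the block's ref count; a CPU
                # block shared by several sequences in the group is returned to
                # the free list once the last sequence referencing it has been
                # swapped in.)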
+ self.cpu_allocator.free(cpu_block) + self.block_tables[seq.seq_id] = new_block_table + + block_number_mapping = { + cpu_block.block_number: gpu_block.block_number + for cpu_block, gpu_block in mapping.items() + } + return block_number_mapping + + def can_swap_out(self, seq_group: SequenceGroup) -> bool: + blocks = self._get_physical_blocks(seq_group) + return len(blocks) <= self.cpu_allocator.get_num_free_blocks() + + def swap_out(self, seq_group: SequenceGroup) -> Dict[int, int]: + # GPU block -> CPU block. + mapping: Dict[PhysicalTokenBlock, PhysicalTokenBlock] = {} + for seq in seq_group.get_seqs(status=SequenceStatus.RUNNING): + new_block_table: BlockTable = [] + block_table = self.block_tables[seq.seq_id] + + for gpu_block in block_table: + if gpu_block in mapping: + cpu_block = mapping[gpu_block] + cpu_block.ref_count += 1 + else: + cpu_block = self.cpu_allocator.allocate() + mapping[gpu_block] = cpu_block + new_block_table.append(cpu_block) + # Free the GPU block swapped out to CPU. + self.gpu_allocator.free(gpu_block) + self.block_tables[seq.seq_id] = new_block_table + + block_number_mapping = { + gpu_block.block_number: cpu_block.block_number + for gpu_block, cpu_block in mapping.items() + } + return block_number_mapping + + def _free_block_table(self, block_table: BlockTable) -> None: + for block in set(block_table): + if block.device == Device.GPU: + self.gpu_allocator.free(block) + else: + self.cpu_allocator.free(block) + + def free(self, seq: Sequence) -> None: + if seq.seq_id not in self.block_tables: + # Already freed or haven't been scheduled yet. + return + block_table = self.block_tables[seq.seq_id] + self._free_block_table(block_table) + del self.block_tables[seq.seq_id] + + def reset(self) -> None: + for block_table in self.block_tables.values(): + self._free_block_table(block_table) + self.block_tables.clear() + + def get_block_table(self, seq: Sequence) -> List[int]: + block_table = self.block_tables[seq.seq_id] + return [block.block_number for block in block_table] + + def get_num_free_gpu_blocks(self) -> int: + return self.gpu_allocator.get_num_free_blocks() + + def get_num_free_cpu_blocks(self) -> int: + return self.cpu_allocator.get_num_free_blocks() diff --git a/ChatTTS/vllm_engine/configs.py b/ChatTTS/vllm_engine/configs.py new file mode 100644 index 000000000..30d6c9afa --- /dev/null +++ b/ChatTTS/vllm_engine/configs.py @@ -0,0 +1,787 @@ +from typing import Optional, Union, Tuple +import os + +import torch +from transformers import PretrainedConfig + +from vllm.logger import init_logger +from vllm.transformers_utils.config import get_config +from vllm.utils import get_cpu_memory, is_hip + +import argparse +import dataclasses +from dataclasses import dataclass + + +logger = init_logger(__name__) + +_GB = 1 << 30 + + +class ModelConfig: + """Configuration for the model. + + Args: + model: Name or path of the huggingface model to use. + tokenizer: Name or path of the huggingface tokenizer to use. + tokenizer_mode: Tokenizer mode. "auto" will use the fast tokenizer if + available, and "slow" will always use the slow tokenizer. + trust_remote_code: Trust remote code (e.g., from HuggingFace) when + downloading the model and tokenizer. + download_dir: Directory to download and load the weights, default to the + default cache directory of huggingface. + load_format: The format of the model weights to load: + "auto" will try to load the weights in the safetensors format and + fall back to the pytorch bin format if safetensors format is + not available. 
+ "pt" will load the weights in the pytorch bin format. + "safetensors" will load the weights in the safetensors format. + "npcache" will load the weights in pytorch format and store + a numpy cache to speed up the loading. + "dummy" will initialize the weights with random values, which is + mainly for profiling. + dtype: Data type for model weights and activations. The "auto" option + will use FP16 precision for FP32 and FP16 models, and BF16 precision + for BF16 models. + seed: Random seed for reproducibility. + revision: The specific model version to use. It can be a branch name, + a tag name, or a commit id. If unspecified, will use the default + version. + tokenizer_revision: The specific tokenizer version to use. It can be a + branch name, a tag name, or a commit id. If unspecified, will use + the default version. + max_model_len: Maximum length of a sequence (including prompt and + output). If None, will be derived from the model. + quantization: Quantization method that was used to quantize the model + weights. If None, we assume the model weights are not quantized. + enforce_eager: Whether to enforce eager execution. If True, we will + disable CUDA graph and always execute the model in eager mode. + If False, we will use CUDA graph and eager execution in hybrid. + max_context_len_to_capture: Maximum context len covered by CUDA graphs. + When a sequence has context length larger than this, we fall back + to eager mode. + """ + + def __init__( + self, + model: str, + tokenizer: str, + tokenizer_mode: str, + trust_remote_code: bool, + download_dir: Optional[str], + load_format: str, + dtype: Union[str, torch.dtype], + seed: int, + revision: Optional[str] = None, + tokenizer_revision: Optional[str] = None, + max_model_len: Optional[int] = None, + quantization: Optional[str] = None, + enforce_eager: bool = False, + max_context_len_to_capture: Optional[int] = None, + num_audio_tokens: int = 1024, + num_text_tokens: int = 80 + ) -> None: + self.model = model + self.tokenizer = tokenizer + self.tokenizer_mode = tokenizer_mode + self.trust_remote_code = trust_remote_code + self.download_dir = download_dir + self.load_format = load_format + self.seed = seed + self.revision = revision + self.tokenizer_revision = tokenizer_revision + self.quantization = quantization + self.enforce_eager = enforce_eager + self.max_context_len_to_capture = max_context_len_to_capture + self.num_audio_tokens = num_audio_tokens + self.num_text_tokens = num_text_tokens + + if os.environ.get("VLLM_USE_MODELSCOPE", "False").lower() == "true": + # download model from ModelScope hub, + # lazy import so that modelscope is not required for normal use. 
+ from modelscope.hub.snapshot_download import snapshot_download # pylint: disable=C + model_path = snapshot_download(model_id=model, + cache_dir=download_dir, + revision=revision) + self.model = model_path + self.download_dir = model_path + self.tokenizer = model_path + + self.hf_config = get_config(self.model, trust_remote_code, revision) + self.dtype = _get_and_verify_dtype(self.hf_config, dtype) + self.max_model_len = _get_and_verify_max_len(self.hf_config, + max_model_len) + self._verify_load_format() + self._verify_tokenizer_mode() + self._verify_quantization() + self._verify_cuda_graph() + + def _verify_load_format(self) -> None: + load_format = self.load_format.lower() + supported_load_format = [ + "auto", "pt", "safetensors", "npcache", "dummy" + ] + rocm_not_supported_load_format = [] + if load_format not in supported_load_format: + raise ValueError( + f"Unknown load format: {self.load_format}. Must be one of " + "'auto', 'pt', 'safetensors', 'npcache', or 'dummy'.") + if is_hip() and load_format in rocm_not_supported_load_format: + rocm_supported_load_format = [ + f for f in supported_load_format + if (f not in rocm_not_supported_load_format) + ] + raise ValueError( + f"load format \'{load_format}\' is not supported in ROCm. " + f"Supported load format are " + f"{rocm_supported_load_format}") + + # TODO: Remove this check once HF updates the pt weights of Mixtral. + architectures = getattr(self.hf_config, "architectures", []) + if "MixtralForCausalLM" in architectures and load_format == "pt": + raise ValueError( + "Currently, the 'pt' format is not supported for Mixtral. " + "Please use the 'safetensors' format instead. ") + self.load_format = load_format + + def _verify_tokenizer_mode(self) -> None: + tokenizer_mode = self.tokenizer_mode.lower() + if tokenizer_mode not in ["auto", "slow"]: + raise ValueError( + f"Unknown tokenizer mode: {self.tokenizer_mode}. Must be " + "either 'auto' or 'slow'.") + self.tokenizer_mode = tokenizer_mode + + def _verify_quantization(self) -> None: + supported_quantization = ["awq", "gptq", "squeezellm"] + rocm_not_supported_quantization = ["awq"] + if self.quantization is not None: + self.quantization = self.quantization.lower() + + # Parse quantization method from the HF model config, if available. + hf_quant_config = getattr(self.hf_config, "quantization_config", None) + if hf_quant_config is not None: + hf_quant_method = str(hf_quant_config["quant_method"]).lower() + if self.quantization is None: + self.quantization = hf_quant_method + elif self.quantization != hf_quant_method: + raise ValueError( + "Quantization method specified in the model config " + f"({hf_quant_method}) does not match the quantization " + f"method specified in the `quantization` argument " + f"({self.quantization}).") + + if self.quantization is not None: + if self.quantization not in supported_quantization: + raise ValueError( + f"Unknown quantization method: {self.quantization}. Must " + f"be one of {supported_quantization}.") + if is_hip( + ) and self.quantization in rocm_not_supported_quantization: + raise ValueError( + f"{self.quantization} quantization is currently not supported " + f"in ROCm.") + logger.warning(f"{self.quantization} quantization is not fully " + "optimized yet. 
The speed can be slower than " + "non-quantized models.") + + def _verify_cuda_graph(self) -> None: + if self.max_context_len_to_capture is None: + self.max_context_len_to_capture = self.max_model_len + self.max_context_len_to_capture = min(self.max_context_len_to_capture, + self.max_model_len) + + def verify_with_parallel_config( + self, + parallel_config: "ParallelConfig", + ) -> None: + total_num_attention_heads = self.hf_config.num_attention_heads + tensor_parallel_size = parallel_config.tensor_parallel_size + if total_num_attention_heads % tensor_parallel_size != 0: + raise ValueError( + f"Total number of attention heads ({total_num_attention_heads})" + " must be divisible by tensor parallel size " + f"({tensor_parallel_size}).") + + total_num_hidden_layers = self.hf_config.num_hidden_layers + pipeline_parallel_size = parallel_config.pipeline_parallel_size + if total_num_hidden_layers % pipeline_parallel_size != 0: + raise ValueError( + f"Total number of hidden layers ({total_num_hidden_layers}) " + "must be divisible by pipeline parallel size " + f"({pipeline_parallel_size}).") + + def get_sliding_window(self) -> Optional[int]: + return getattr(self.hf_config, "sliding_window", None) + + def get_vocab_size(self) -> int: + return self.hf_config.vocab_size + + def get_hidden_size(self) -> int: + return self.hf_config.hidden_size + + def get_head_size(self) -> int: + # FIXME(woosuk): This may not be true for all models. + return self.hf_config.hidden_size // self.hf_config.num_attention_heads + + def get_total_num_kv_heads(self) -> int: + """Returns the total number of KV heads.""" + # For GPTBigCode & Falcon: + # NOTE: for falcon, when new_decoder_architecture is True, the + # multi_query flag is ignored and we use n_head_kv for the number of + # KV heads. + falcon_model_types = ["falcon", "RefinedWeb", "RefinedWebModel"] + new_decoder_arch_falcon = ( + self.hf_config.model_type in falcon_model_types + and getattr(self.hf_config, "new_decoder_architecture", False)) + if not new_decoder_arch_falcon and getattr(self.hf_config, + "multi_query", False): + # Multi-query attention, only one KV head. + # Currently, tensor parallelism is not supported in this case. + return 1 + + attributes = [ + # For Falcon: + "n_head_kv", + "num_kv_heads", + # For LLaMA-2: + "num_key_value_heads", + # For ChatGLM: + "multi_query_group_num", + ] + for attr in attributes: + num_kv_heads = getattr(self.hf_config, attr, None) + if num_kv_heads is not None: + return num_kv_heads + + # For non-grouped-query attention models, the number of KV heads is + # equal to the number of attention heads. + return self.hf_config.num_attention_heads + + def get_num_kv_heads(self, parallel_config: "ParallelConfig") -> int: + """Returns the number of KV heads per GPU.""" + total_num_kv_heads = self.get_total_num_kv_heads() + # If tensor parallelism is used, we divide the number of KV heads by + # the tensor parallel size. We will replicate the KV heads in the + # case where the number of KV heads is smaller than the tensor + # parallel size so each GPU has at least one KV head. + return max(1, + total_num_kv_heads // parallel_config.tensor_parallel_size) + + def get_num_layers(self, parallel_config: "ParallelConfig") -> int: + total_num_hidden_layers = self.hf_config.num_hidden_layers + return total_num_hidden_layers // parallel_config.pipeline_parallel_size + + +class CacheConfig: + """Configuration for the KV cache. + + Args: + block_size: Size of a cache block in number of tokens. 
+ gpu_memory_utilization: Fraction of GPU memory to use for the + vLLM execution. + swap_space: Size of the CPU swap space per GPU (in GiB). + """ + + def __init__( + self, + block_size: int, + gpu_memory_utilization: float, + swap_space: int, + sliding_window: Optional[int] = None, + ) -> None: + self.block_size = block_size + self.gpu_memory_utilization = gpu_memory_utilization + self.swap_space_bytes = swap_space * _GB + self.sliding_window = sliding_window + self._verify_args() + + # Will be set after profiling. + self.num_gpu_blocks = None + self.num_cpu_blocks = None + + def _verify_args(self) -> None: + if self.gpu_memory_utilization > 1.0: + raise ValueError( + "GPU memory utilization must be less than 1.0. Got " + f"{self.gpu_memory_utilization}.") + + def verify_with_parallel_config( + self, + parallel_config: "ParallelConfig", + ) -> None: + total_cpu_memory = get_cpu_memory() + # FIXME(woosuk): Here, it is assumed that the GPUs in a tensor parallel + # group are in the same node. However, the GPUs may span multiple nodes. + num_gpus_per_node = parallel_config.tensor_parallel_size + cpu_memory_usage = self.swap_space_bytes * num_gpus_per_node + + msg = (f"{cpu_memory_usage / _GB:.2f} GiB out of " + f"the {total_cpu_memory / _GB:.2f} GiB total CPU memory is " + "allocated for the swap space.") + if cpu_memory_usage > 0.7 * total_cpu_memory: + raise ValueError("Too large swap space. " + msg) + elif cpu_memory_usage > 0.4 * total_cpu_memory: + logger.warning("Possibly too large swap space. " + msg) + + +class ParallelConfig: + """Configuration for the distributed execution. + + Args: + pipeline_parallel_size: Number of pipeline parallel groups. + tensor_parallel_size: Number of tensor parallel groups. + worker_use_ray: Whether to use Ray for model workers. Will be set to + True if either pipeline_parallel_size or tensor_parallel_size is + greater than 1. + """ + + def __init__( + self, + pipeline_parallel_size: int, + tensor_parallel_size: int, + worker_use_ray: bool, + max_parallel_loading_workers: Optional[int] = None, + ) -> None: + self.pipeline_parallel_size = pipeline_parallel_size + self.tensor_parallel_size = tensor_parallel_size + self.worker_use_ray = worker_use_ray + self.max_parallel_loading_workers = max_parallel_loading_workers + + self.world_size = pipeline_parallel_size * tensor_parallel_size + if self.world_size > 1: + self.worker_use_ray = True + self._verify_args() + + def _verify_args(self) -> None: + if self.pipeline_parallel_size > 1: + raise NotImplementedError( + "Pipeline parallelism is not supported yet.") + + +class SchedulerConfig: + """Scheduler configuration. + + Args: + max_num_batched_tokens: Maximum number of tokens to be processed in + a single iteration. + max_num_seqs: Maximum number of sequences to be processed in a single + iteration. + max_model_len: Maximum length of a sequence (including prompt + and generated text). + max_paddings: Maximum number of paddings to be added to a batch. + """ + + def __init__( + self, + max_num_batched_tokens: Optional[int], + max_num_seqs: int, + max_model_len: int, + max_paddings: int, + ) -> None: + if max_num_batched_tokens is not None: + self.max_num_batched_tokens = max_num_batched_tokens + else: + # If max_model_len is too short, use 2048 as the default value for + # higher throughput. 
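            # For example: max_model_len=512 gives max_num_batched_tokens=2048,
            # while max_model_len=4096 gives max_num_batched_tokens=4096.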
+ self.max_num_batched_tokens = max(max_model_len, 2048) + self.max_num_seqs = max_num_seqs + self.max_model_len = max_model_len + self.max_paddings = max_paddings + self._verify_args() + + def _verify_args(self) -> None: + if self.max_num_batched_tokens < self.max_model_len: + raise ValueError( + f"max_num_batched_tokens ({self.max_num_batched_tokens}) is " + f"smaller than max_model_len ({self.max_model_len}). " + "This effectively limits the maximum sequence length to " + "max_num_batched_tokens and makes vLLM reject longer " + "sequences. Please increase max_num_batched_tokens or " + "decrease max_model_len.") + if self.max_num_batched_tokens < self.max_num_seqs: + raise ValueError( + f"max_num_batched_tokens ({self.max_num_batched_tokens}) must " + "be greater than or equal to max_num_seqs " + f"({self.max_num_seqs}).") + + +_STR_DTYPE_TO_TORCH_DTYPE = { + "half": torch.float16, + "float16": torch.float16, + "float": torch.float32, + "float32": torch.float32, + "bfloat16": torch.bfloat16, +} + +_ROCM_NOT_SUPPORTED_DTYPE = ["float", "float32"] + + +def _get_and_verify_dtype( + config: PretrainedConfig, + dtype: Union[str, torch.dtype], +) -> torch.dtype: + # NOTE: getattr(config, "torch_dtype", torch.float32) is not correct + # because config.torch_dtype can be None. + config_dtype = getattr(config, "torch_dtype", None) + if config_dtype is None: + config_dtype = torch.float32 + + if isinstance(dtype, str): + dtype = dtype.lower() + if dtype == "auto": + if config_dtype == torch.float32: + # Following the common practice, we use float16 for float32 + # models. + torch_dtype = torch.float16 + else: + torch_dtype = config_dtype + else: + if dtype not in _STR_DTYPE_TO_TORCH_DTYPE: + raise ValueError(f"Unknown dtype: {dtype}") + torch_dtype = _STR_DTYPE_TO_TORCH_DTYPE[dtype] + elif isinstance(dtype, torch.dtype): + torch_dtype = dtype + else: + raise ValueError(f"Unknown dtype: {dtype}") + + if is_hip() and torch_dtype == torch.float32: + rocm_supported_dtypes = [ + k for k, v in _STR_DTYPE_TO_TORCH_DTYPE.items() + if (k not in _ROCM_NOT_SUPPORTED_DTYPE) + ] + raise ValueError(f"dtype \'{dtype}\' is not supported in ROCm. " + f"Supported dtypes are {rocm_supported_dtypes}") + + # Verify the dtype. + if torch_dtype != config_dtype: + if torch_dtype == torch.float32: + # Upcasting to float32 is allowed. + pass + elif config_dtype == torch.float32: + # Downcasting from float32 to float16 or bfloat16 is allowed. + pass + else: + # Casting between float16 and bfloat16 is allowed with a warning. + logger.warning(f"Casting {config_dtype} to {torch_dtype}.") + + return torch_dtype + + +def _get_and_verify_max_len( + hf_config: PretrainedConfig, + max_model_len: Optional[int], +) -> int: + """Get and verify the model's maximum length.""" + derived_max_model_len = float("inf") + possible_keys = [ + # OPT + "max_position_embeddings", + # GPT-2 + "n_positions", + # MPT + "max_seq_len", + # ChatGLM2 + "seq_length", + # Others + "max_sequence_length", + "max_seq_length", + "seq_len", + ] + for key in possible_keys: + max_len_key = getattr(hf_config, key, None) + if max_len_key is not None: + derived_max_model_len = min(derived_max_model_len, max_len_key) + if derived_max_model_len == float("inf"): + if max_model_len is not None: + # If max_model_len is specified, we use it. + return max_model_len + + default_max_len = 2048 + logger.warning( + "The model's config.json does not contain any of the following " + "keys to determine the original maximum length of the model: " + f"{possible_keys}. 
Assuming the model's maximum length is " + f"{default_max_len}.") + derived_max_model_len = default_max_len + + rope_scaling = getattr(hf_config, "rope_scaling", None) + if rope_scaling is not None: + assert "factor" in rope_scaling + scaling_factor = rope_scaling["factor"] + if rope_scaling["type"] == "yarn": + derived_max_model_len = rope_scaling[ + "original_max_position_embeddings"] + derived_max_model_len *= scaling_factor + + if max_model_len is None: + max_model_len = derived_max_model_len + elif max_model_len > derived_max_model_len: + raise ValueError( + f"User-specified max_model_len ({max_model_len}) is greater than " + f"the derived max_model_len ({max_len_key}={derived_max_model_len}" + " in model's config.json). This may lead to incorrect model " + "outputs or CUDA errors. Make sure the value is correct and " + "within the model context size.") + return int(max_model_len) + + +@dataclass +class EngineArgs: + """Arguments for vLLM engine.""" + model: str + tokenizer: Optional[str] = None + tokenizer_mode: str = 'auto' + trust_remote_code: bool = False + download_dir: Optional[str] = None + load_format: str = 'auto' + dtype: str = 'auto' + seed: int = 0 + max_model_len: Optional[int] = None + worker_use_ray: bool = False + pipeline_parallel_size: int = 1 + tensor_parallel_size: int = 1 + max_parallel_loading_workers: Optional[int] = None + block_size: int = 16 + swap_space: int = 4 # GiB + gpu_memory_utilization: float = 0.90 + max_num_batched_tokens: Optional[int] = None + max_num_seqs: int = 256 + max_paddings: int = 256 + disable_log_stats: bool = False + revision: Optional[str] = None + tokenizer_revision: Optional[str] = None + quantization: Optional[str] = None + enforce_eager: bool = False + max_context_len_to_capture: int = 8192 + num_audio_tokens: int = 1024 + num_text_tokens: int = 80 + + def __post_init__(self): + if self.tokenizer is None: + self.tokenizer = self.model + + @staticmethod + def add_cli_args( + parser: argparse.ArgumentParser) -> argparse.ArgumentParser: + """Shared CLI arguments for vLLM engine.""" + + # NOTE: If you update any of the arguments below, please also + # make sure to update docs/source/models/engine_args.rst + + # Model arguments + parser.add_argument( + '--model', + type=str, + default='facebook/opt-125m', + help='name or path of the huggingface model to use') + parser.add_argument( + '--tokenizer', + type=str, + default=EngineArgs.tokenizer, + help='name or path of the huggingface tokenizer to use') + parser.add_argument( + '--revision', + type=str, + default=None, + help='the specific model version to use. It can be a branch ' + 'name, a tag name, or a commit id. If unspecified, will use ' + 'the default version.') + parser.add_argument( + '--tokenizer-revision', + type=str, + default=None, + help='the specific tokenizer version to use. It can be a branch ' + 'name, a tag name, or a commit id. If unspecified, will use ' + 'the default version.') + parser.add_argument('--tokenizer-mode', + type=str, + default=EngineArgs.tokenizer_mode, + choices=['auto', 'slow'], + help='tokenizer mode. 
"auto" will use the fast ' + 'tokenizer if available, and "slow" will ' + 'always use the slow tokenizer.') + parser.add_argument('--trust-remote-code', + action='store_true', + help='trust remote code from huggingface') + parser.add_argument('--download-dir', + type=str, + default=EngineArgs.download_dir, + help='directory to download and load the weights, ' + 'default to the default cache dir of ' + 'huggingface') + parser.add_argument( + '--load-format', + type=str, + default=EngineArgs.load_format, + choices=['auto', 'pt', 'safetensors', 'npcache', 'dummy'], + help='The format of the model weights to load. ' + '"auto" will try to load the weights in the safetensors format ' + 'and fall back to the pytorch bin format if safetensors format ' + 'is not available. ' + '"pt" will load the weights in the pytorch bin format. ' + '"safetensors" will load the weights in the safetensors format. ' + '"npcache" will load the weights in pytorch format and store ' + 'a numpy cache to speed up the loading. ' + '"dummy" will initialize the weights with random values, ' + 'which is mainly for profiling.') + parser.add_argument( + '--dtype', + type=str, + default=EngineArgs.dtype, + choices=[ + 'auto', 'half', 'float16', 'bfloat16', 'float', 'float32' + ], + help='data type for model weights and activations. ' + 'The "auto" option will use FP16 precision ' + 'for FP32 and FP16 models, and BF16 precision ' + 'for BF16 models.') + parser.add_argument('--max-model-len', + type=int, + default=None, + help='model context length. If unspecified, ' + 'will be automatically derived from the model.') + # Parallel arguments + parser.add_argument('--worker-use-ray', + action='store_true', + help='use Ray for distributed serving, will be ' + 'automatically set when using more than 1 GPU') + parser.add_argument('--pipeline-parallel-size', + '-pp', + type=int, + default=EngineArgs.pipeline_parallel_size, + help='number of pipeline stages') + parser.add_argument('--tensor-parallel-size', + '-tp', + type=int, + default=EngineArgs.tensor_parallel_size, + help='number of tensor parallel replicas') + parser.add_argument( + '--max-parallel-loading-workers', + type=int, + help='load model sequentially in multiple batches, ' + 'to avoid RAM OOM when using tensor ' + 'parallel and large models') + # KV cache arguments + parser.add_argument('--block-size', + type=int, + default=EngineArgs.block_size, + choices=[8, 16, 32], + help='token block size') + # TODO(woosuk): Support fine-grained seeds (e.g., seed per request). + parser.add_argument('--seed', + type=int, + default=EngineArgs.seed, + help='random seed') + parser.add_argument('--swap-space', + type=int, + default=EngineArgs.swap_space, + help='CPU swap space size (GiB) per GPU') + parser.add_argument( + '--gpu-memory-utilization', + type=float, + default=EngineArgs.gpu_memory_utilization, + help='the fraction of GPU memory to be used for ' + 'the model executor, which can range from 0 to 1.' 
+ 'If unspecified, will use the default value of 0.9.') + parser.add_argument('--max-num-batched-tokens', + type=int, + default=EngineArgs.max_num_batched_tokens, + help='maximum number of batched tokens per ' + 'iteration') + parser.add_argument('--max-num-seqs', + type=int, + default=EngineArgs.max_num_seqs, + help='maximum number of sequences per iteration') + parser.add_argument('--max-paddings', + type=int, + default=EngineArgs.max_paddings, + help='maximum number of paddings in a batch') + parser.add_argument('--disable-log-stats', + action='store_true', + help='disable logging statistics') + # Quantization settings. + parser.add_argument('--quantization', + '-q', + type=str, + choices=['awq', 'gptq', 'squeezellm', None], + default=None, + help='Method used to quantize the weights. If ' + 'None, we first check the `quantization_config` ' + 'attribute in the model config file. If that is ' + 'None, we assume the model weights are not ' + 'quantized and use `dtype` to determine the data ' + 'type of the weights.') + parser.add_argument('--enforce-eager', + action='store_true', + help='Always use eager-mode PyTorch. If False, ' + 'will use eager mode and CUDA graph in hybrid ' + 'for maximal performance and flexibility.') + parser.add_argument('--max-context-len-to-capture', + type=int, + default=EngineArgs.max_context_len_to_capture, + help='maximum context length covered by CUDA ' + 'graphs. When a sequence has context length ' + 'larger than this, we fall back to eager mode.') + return parser + + @classmethod + def from_cli_args(cls, args: argparse.Namespace) -> 'EngineArgs': + # Get the list of attributes of this dataclass. + attrs = [attr.name for attr in dataclasses.fields(cls)] + # Set the attributes from the parsed arguments. + engine_args = cls(**{attr: getattr(args, attr) for attr in attrs}) + return engine_args + + def create_engine_configs( + self, + ) -> Tuple[ModelConfig, CacheConfig, ParallelConfig, SchedulerConfig]: + model_config = ModelConfig(self.model, self.tokenizer, + self.tokenizer_mode, self.trust_remote_code, + self.download_dir, self.load_format, + self.dtype, self.seed, self.revision, + self.tokenizer_revision, self.max_model_len, + self.quantization, self.enforce_eager, + self.max_context_len_to_capture, + self.num_audio_tokens, self.num_text_tokens, + ) + cache_config = CacheConfig(self.block_size, + self.gpu_memory_utilization, + self.swap_space, + model_config.get_sliding_window()) + parallel_config = ParallelConfig(self.pipeline_parallel_size, + self.tensor_parallel_size, + self.worker_use_ray, + self.max_parallel_loading_workers) + scheduler_config = SchedulerConfig(self.max_num_batched_tokens, + self.max_num_seqs, + model_config.max_model_len, + self.max_paddings) + return model_config, cache_config, parallel_config, scheduler_config + + +@dataclass +class AsyncEngineArgs(EngineArgs): + """Arguments for asynchronous vLLM engine.""" + engine_use_ray: bool = False + disable_log_requests: bool = False + max_log_len: Optional[int] = None + + @staticmethod + def add_cli_args( + parser: argparse.ArgumentParser) -> argparse.ArgumentParser: + parser = EngineArgs.add_cli_args(parser) + parser.add_argument('--engine-use-ray', + action='store_true', + help='use Ray to start the LLM engine in a ' + 'separate process as the server process.') + parser.add_argument('--disable-log-requests', + action='store_true', + help='disable logging requests') + parser.add_argument('--max-log-len', + type=int, + default=None, + help='max number of prompt characters or prompt ' + 
'ID numbers being printed in log. ' + 'Default: unlimited.') + return parser diff --git a/ChatTTS/vllm_engine/llama.py b/ChatTTS/vllm_engine/llama.py new file mode 100644 index 000000000..415b09d86 --- /dev/null +++ b/ChatTTS/vllm_engine/llama.py @@ -0,0 +1,382 @@ +# coding=utf-8 +# Adapted from +# https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/llama/modeling_llama.py +# Copyright 2023 The vLLM team. +# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved. +# +# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX +# and OPT implementations in this library. It has been modified from its +# original forms to accommodate minor architectural differences compared +# to GPT-NeoX and OPT used by the Meta AI team that trained the model. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Inference-only LLaMA model compatible with HuggingFace weights.""" +from typing import Any, Dict, List, Optional, Tuple + +import torch +from torch import nn +from transformers import LlamaConfig + +from vllm.model_executor.input_metadata import InputMetadata +from vllm.model_executor.layers.activation import SiluAndMul +from vllm.model_executor.layers.attention import PagedAttention +from vllm.model_executor.layers.layernorm import RMSNorm +from vllm.model_executor.layers.linear import (LinearMethodBase, + MergedColumnParallelLinear, + QKVParallelLinear, + RowParallelLinear) +from vllm.model_executor.layers.rotary_embedding import get_rope +from vllm.model_executor.layers.sampler import Sampler +from vllm.model_executor.layers.vocab_parallel_embedding import ( + VocabParallelEmbedding, ParallelLMHead) +from vllm.model_executor.parallel_utils.parallel_state import ( + get_tensor_model_parallel_world_size) +from vllm.model_executor.sampling_metadata import SamplingMetadata +from vllm.model_executor.weight_utils import (default_weight_loader, + hf_model_weights_iterator) +from vllm.sequence import SamplerOutput + +KVCache = Tuple[torch.Tensor, torch.Tensor] + + +class LlamaMLP(nn.Module): + + def __init__( + self, + hidden_size: int, + intermediate_size: int, + hidden_act: str, + linear_method: Optional[LinearMethodBase] = None, + ) -> None: + super().__init__() + self.gate_up_proj = MergedColumnParallelLinear( + hidden_size, [intermediate_size] * 2, + bias=False, + linear_method=linear_method) + self.down_proj = RowParallelLinear(intermediate_size, + hidden_size, + bias=False, + linear_method=linear_method) + if hidden_act != "silu": + raise ValueError(f"Unsupported activation: {hidden_act}. 
" + "Only silu is supported for now.") + self.act_fn = SiluAndMul() + + def forward(self, x): + gate_up, _ = self.gate_up_proj(x) + x = self.act_fn(gate_up) + x, _ = self.down_proj(x) + return x + + +class LlamaAttention(nn.Module): + + def __init__( + self, + hidden_size: int, + num_heads: int, + num_kv_heads: int, + rope_theta: float = 10000, + rope_scaling: Optional[Dict[str, Any]] = None, + max_position_embeddings: int = 8192, + linear_method: Optional[LinearMethodBase] = None, + ) -> None: + super().__init__() + self.hidden_size = hidden_size + tp_size = get_tensor_model_parallel_world_size() + self.total_num_heads = num_heads + assert self.total_num_heads % tp_size == 0 + self.num_heads = self.total_num_heads // tp_size + self.total_num_kv_heads = num_kv_heads + if self.total_num_kv_heads >= tp_size: + # Number of KV heads is greater than TP size, so we partition + # the KV heads across multiple tensor parallel GPUs. + assert self.total_num_kv_heads % tp_size == 0 + else: + # Number of KV heads is less than TP size, so we replicate + # the KV heads across multiple tensor parallel GPUs. + assert tp_size % self.total_num_kv_heads == 0 + self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size) + self.head_dim = hidden_size // self.total_num_heads + self.q_size = self.num_heads * self.head_dim + self.kv_size = self.num_kv_heads * self.head_dim + self.scaling = self.head_dim**-0.5 + self.rope_theta = rope_theta + self.max_position_embeddings = max_position_embeddings + + self.qkv_proj = QKVParallelLinear( + hidden_size, + self.head_dim, + self.total_num_heads, + self.total_num_kv_heads, + bias=False, + linear_method=linear_method, + ) + self.o_proj = RowParallelLinear( + self.total_num_heads * self.head_dim, + hidden_size, + bias=False, + linear_method=linear_method, + ) + + self.rotary_emb = get_rope( + self.head_dim, + rotary_dim=self.head_dim, + max_position=max_position_embeddings, + base=rope_theta, + rope_scaling=rope_scaling, + ) + self.attn = PagedAttention(self.num_heads, + self.head_dim, + self.scaling, + num_kv_heads=self.num_kv_heads) + + def forward( + self, + positions: torch.Tensor, + hidden_states: torch.Tensor, + kv_cache: KVCache, + input_metadata: InputMetadata, + ) -> torch.Tensor: + qkv, _ = self.qkv_proj(hidden_states) + q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) + q, k = self.rotary_emb(positions, q, k) + k_cache, v_cache = kv_cache + attn_output = self.attn(q, k, v, k_cache, v_cache, input_metadata) + output, _ = self.o_proj(attn_output) + return output + + +class LlamaDecoderLayer(nn.Module): + + def __init__( + self, + config: LlamaConfig, + linear_method: Optional[LinearMethodBase] = None, + ) -> None: + super().__init__() + self.hidden_size = config.hidden_size + rope_theta = getattr(config, "rope_theta", 10000) + rope_scaling = getattr(config, "rope_scaling", None) + max_position_embeddings = getattr(config, "max_position_embeddings", + 8192) + self.self_attn = LlamaAttention( + hidden_size=self.hidden_size, + num_heads=config.num_attention_heads, + num_kv_heads=config.num_key_value_heads, + rope_theta=rope_theta, + rope_scaling=rope_scaling, + max_position_embeddings=max_position_embeddings, + linear_method=linear_method, + ) + self.mlp = LlamaMLP( + hidden_size=self.hidden_size, + intermediate_size=config.intermediate_size, + hidden_act=config.hidden_act, + linear_method=linear_method, + ) + self.input_layernorm = RMSNorm(config.hidden_size, + eps=config.rms_norm_eps) + self.post_attention_layernorm = 
RMSNorm(config.hidden_size, + eps=config.rms_norm_eps) + + def forward( + self, + positions: torch.Tensor, + hidden_states: torch.Tensor, + kv_cache: KVCache, + input_metadata: InputMetadata, + residual: Optional[torch.Tensor], + ) -> Tuple[torch.Tensor, torch.Tensor]: + # Self Attention + if residual is None: + residual = hidden_states + hidden_states = self.input_layernorm(hidden_states) + else: + hidden_states, residual = self.input_layernorm( + hidden_states, residual) + hidden_states = self.self_attn( + positions=positions, + hidden_states=hidden_states, + kv_cache=kv_cache, + input_metadata=input_metadata, + ) + + # Fully Connected + hidden_states, residual = self.post_attention_layernorm( + hidden_states, residual) + hidden_states = self.mlp(hidden_states) + return hidden_states, residual + + +class LlamaModel(nn.Module): + + def __init__( + self, + config: LlamaConfig, + linear_method: Optional[LinearMethodBase] = None, + ) -> None: + super().__init__() + self.config = config + self.padding_idx = config.pad_token_id + self.vocab_size = config.vocab_size + self.embed_tokens = VocabParallelEmbedding( + config.vocab_size, + config.hidden_size, + ) + self.layers = nn.ModuleList([ + LlamaDecoderLayer(config, linear_method) + for _ in range(config.num_hidden_layers) + ]) + self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + + def forward( + self, + input_emb: torch.Tensor, + positions: torch.Tensor, + kv_caches: List[KVCache], + input_metadata: InputMetadata, + ) -> torch.Tensor: + hidden_states = input_emb + residual = None + for i in range(len(self.layers)): + layer = self.layers[i] + hidden_states, residual = layer( + positions, + hidden_states, + kv_caches[i], + input_metadata, + residual, + ) + hidden_states, _ = self.norm(hidden_states, residual) + return hidden_states + + def load_weights(self, + model_name_or_path: str, + cache_dir: Optional[str] = None, + load_format: str = "auto", + revision: Optional[str] = None): + stacked_params_mapping = [ + # (param_name, shard_name, shard_id) + ("qkv_proj", "q_proj", "q"), + ("qkv_proj", "k_proj", "k"), + ("qkv_proj", "v_proj", "v"), + ("gate_up_proj", "gate_proj", 0), + ("gate_up_proj", "up_proj", 1), + ] + params_dict = dict(self.named_parameters()) + for name, loaded_weight in hf_model_weights_iterator( + model_name_or_path, cache_dir, load_format, revision): + if "rotary_emb.inv_freq" in name: + continue + if ("rotary_emb.cos_cached" in name + or "rotary_emb.sin_cached" in name): + # Models trained using ColossalAI may include these tensors in + # the checkpoint. Skip them. + continue + for (param_name, weight_name, shard_id) in stacked_params_mapping: + if weight_name not in name: + continue + name = name.replace(weight_name, param_name) + # Skip loading extra bias for GPTQ models. + if name.endswith(".bias") and name not in params_dict: + continue + param = params_dict[name] + weight_loader = param.weight_loader + weight_loader(param, loaded_weight, shard_id) + break + else: + # Skip loading extra bias for GPTQ models. 
+ if name.endswith(".bias") and name not in params_dict: + continue + param = params_dict[name] + weight_loader = getattr(param, "weight_loader", + default_weight_loader) + weight_loader(param, loaded_weight) + +class LlamaForCausalLM(nn.Module): + + def __init__( + self, + config: LlamaConfig, + linear_method: Optional[LinearMethodBase] = None, + ) -> None: + super().__init__() + self.config = config + self.linear_method = linear_method + self.model = LlamaModel(config, linear_method) + self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size) + self.sampler = Sampler(config.vocab_size) + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + kv_caches: List[KVCache], + input_metadata: InputMetadata, + ) -> torch.Tensor: + hidden_states = self.model(input_ids, positions, kv_caches, + input_metadata) + return hidden_states + + def sample( + self, + hidden_states: torch.Tensor, + sampling_metadata: SamplingMetadata, + ) -> Optional[SamplerOutput]: + next_tokens = self.sampler(self.lm_head.weight, hidden_states, + sampling_metadata) + return next_tokens + + def load_weights(self, + model_name_or_path: str, + cache_dir: Optional[str] = None, + load_format: str = "auto", + revision: Optional[str] = None): + stacked_params_mapping = [ + # (param_name, shard_name, shard_id) + ("qkv_proj", "q_proj", "q"), + ("qkv_proj", "k_proj", "k"), + ("qkv_proj", "v_proj", "v"), + ("gate_up_proj", "gate_proj", 0), + ("gate_up_proj", "up_proj", 1), + ] + params_dict = dict(self.named_parameters()) + for name, loaded_weight in hf_model_weights_iterator( + model_name_or_path, cache_dir, load_format, revision): + if "rotary_emb.inv_freq" in name: + continue + if ("rotary_emb.cos_cached" in name + or "rotary_emb.sin_cached" in name): + # Models trained using ColossalAI may include these tensors in + # the checkpoint. Skip them. + continue + for (param_name, weight_name, shard_id) in stacked_params_mapping: + if weight_name not in name: + continue + name = name.replace(weight_name, param_name) + # Skip loading extra bias for GPTQ models. + if name.endswith(".bias") and name not in params_dict: + continue + param = params_dict[name] + weight_loader = param.weight_loader + weight_loader(param, loaded_weight, shard_id) + break + else: + # Skip loading extra bias for GPTQ models. + if name.endswith(".bias") and name not in params_dict: + continue + param = params_dict[name] + weight_loader = getattr(param, "weight_loader", + default_weight_loader) + weight_loader(param, loaded_weight) diff --git a/ChatTTS/vllm_engine/llm.py b/ChatTTS/vllm_engine/llm.py new file mode 100644 index 000000000..b2bce9f6b --- /dev/null +++ b/ChatTTS/vllm_engine/llm.py @@ -0,0 +1,212 @@ +from typing import List, Optional, Union + +from tqdm import tqdm +from transformers import PreTrainedTokenizer, PreTrainedTokenizerFast + +from ChatTTS.vllm_engine.configs import EngineArgs +from ChatTTS.vllm_engine.llm_engine import LLMEngine +from ChatTTS.vllm_engine.output import RequestOutput +from ChatTTS.vllm_engine.sampling_params import SamplingParams +from vllm.utils import Counter + + +class LLM: + """An LLM for generating texts from given prompts and sampling parameters. + + This class includes a tokenizer, a language model (possibly distributed + across multiple GPUs), and GPU memory space allocated for intermediate + states (aka KV cache). Given a batch of prompts and sampling parameters, + this class generates texts from the model, using an intelligent batching + mechanism and efficient memory management. 
+ + NOTE: This class is intended to be used for offline inference. For online + serving, use the `AsyncLLMEngine` class instead. + NOTE: For the comprehensive list of arguments, see `EngineArgs`. + + Args: + model: The name or path of a HuggingFace Transformers model. + tokenizer: The name or path of a HuggingFace Transformers tokenizer. + tokenizer_mode: The tokenizer mode. "auto" will use the fast tokenizer + if available, and "slow" will always use the slow tokenizer. + trust_remote_code: Trust remote code (e.g., from HuggingFace) when + downloading the model and tokenizer. + tensor_parallel_size: The number of GPUs to use for distributed + execution with tensor parallelism. + dtype: The data type for the model weights and activations. Currently, + we support `float32`, `float16`, and `bfloat16`. If `auto`, we use + the `torch_dtype` attribute specified in the model config file. + However, if the `torch_dtype` in the config is `float32`, we will + use `float16` instead. + quantization: The method used to quantize the model weights. Currently, + we support "awq", "gptq" and "squeezellm". If None, we first check + the `quantization_config` attribute in the model config file. If + that is None, we assume the model weights are not quantized and use + `dtype` to determine the data type of the weights. + revision: The specific model version to use. It can be a branch name, + a tag name, or a commit id. + tokenizer_revision: The specific tokenizer version to use. It can be a + branch name, a tag name, or a commit id. + seed: The seed to initialize the random number generator for sampling. + gpu_memory_utilization: The ratio (between 0 and 1) of GPU memory to + reserve for the model weights, activations, and KV cache. Higher + values will increase the KV cache size and thus improve the model's + throughput. However, if the value is too high, it may cause out-of- + memory (OOM) errors. + swap_space: The size (GiB) of CPU memory per GPU to use as swap space. + This can be used for temporarily storing the states of the requests + when their `best_of` sampling parameters are larger than 1. If all + requests will have `best_of=1`, you can safely set this to 0. + Otherwise, too small values may cause out-of-memory (OOM) errors. + enforce_eager: Whether to enforce eager execution. If True, we will + disable CUDA graph and always execute the model in eager mode. + If False, we will use CUDA graph and eager execution in hybrid. + max_context_len_to_capture: Maximum context len covered by CUDA graphs. + When a sequence has context length larger than this, we fall back + to eager mode. 
+ """ + + def __init__( + self, + model: str, + tokenizer: Optional[str] = None, + tokenizer_mode: str = "auto", + trust_remote_code: bool = False, + tensor_parallel_size: int = 1, + dtype: str = "auto", + quantization: Optional[str] = None, + revision: Optional[str] = None, + tokenizer_revision: Optional[str] = None, + seed: int = 0, + gpu_memory_utilization: float = 0.9, + swap_space: int = 4, + enforce_eager: bool = False, + max_context_len_to_capture: int = 8192, + post_model_path: str = None, + num_audio_tokens: int = 0, + num_text_tokens: int = 0, + **kwargs, + ) -> None: + if "disable_log_stats" not in kwargs: + kwargs["disable_log_stats"] = True + engine_args = EngineArgs( + model=model, + tokenizer=tokenizer, + tokenizer_mode=tokenizer_mode, + trust_remote_code=trust_remote_code, + tensor_parallel_size=tensor_parallel_size, + dtype=dtype, + quantization=quantization, + revision=revision, + tokenizer_revision=tokenizer_revision, + seed=seed, + gpu_memory_utilization=gpu_memory_utilization, + swap_space=swap_space, + enforce_eager=enforce_eager, + max_context_len_to_capture=max_context_len_to_capture, + num_audio_tokens = num_audio_tokens, + num_text_tokens = num_text_tokens, + **kwargs, + ) + self.llm_engine = LLMEngine.from_engine_args(engine_args, post_model_path) + self.request_counter = Counter() + + def get_tokenizer( + self) -> Union[PreTrainedTokenizer, PreTrainedTokenizerFast]: + return self.llm_engine.tokenizer + + def set_tokenizer( + self, + tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast], + ) -> None: + self.llm_engine.tokenizer = tokenizer + + def generate( + self, + prompts: Optional[Union[str, List[str]]] = None, + sampling_params: Optional[SamplingParams] = None, + prompt_token_ids: Optional[List[List[int]]] = None, + use_tqdm: bool = True, + ) -> List[RequestOutput]: + """Generates the completions for the input prompts. + + NOTE: This class automatically batches the given prompts, considering + the memory constraint. For the best performance, put all of your prompts + into a single list and pass it to this method. + + Args: + prompts: A list of prompts to generate completions for. + sampling_params: The sampling parameters for text generation. If + None, we use the default sampling parameters. + prompt_token_ids: A list of token IDs for the prompts. If None, we + use the tokenizer to convert the prompts to token IDs. + use_tqdm: Whether to use tqdm to display the progress bar. + + Returns: + A list of `RequestOutput` objects containing the generated + completions in the same order as the input prompts. + """ + if prompts is None and prompt_token_ids is None: + raise ValueError("Either prompts or prompt_token_ids must be " + "provided.") + if isinstance(prompts, str): + # Convert a single prompt to a list. + prompts = [prompts] + if (prompts is not None and prompt_token_ids is not None + and len(prompts) != len(prompt_token_ids)): + raise ValueError("The lengths of prompts and prompt_token_ids " + "must be the same.") + if sampling_params is None: + # Use default sampling params. + sampling_params = SamplingParams() + + # Add requests to the engine. 
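+        # Each prompt becomes one engine request; _run_engine below then steps
+        # the scheduler until every queued request has finished.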
+ num_requests = len(prompts) if prompts is not None else len( + prompt_token_ids) + for i in range(num_requests): + prompt = prompts[i] if prompts is not None else None + token_ids = None if prompt_token_ids is None else prompt_token_ids[ + i] + self._add_request(prompt, sampling_params, token_ids) + + rtns = self._run_engine(use_tqdm) + for i, rtn in enumerate(rtns): + token_ids = rtn.outputs[0].token_ids + for j, token_id in enumerate(token_ids): + if len(token_id) == 1: + token_ids[j] = token_id[0] + else: + token_ids[j] = list(token_id) + + return rtns + + def _add_request( + self, + prompt: Optional[str], + sampling_params: SamplingParams, + prompt_token_ids: Optional[List[int]], + ) -> None: + request_id = str(next(self.request_counter)) + self.llm_engine.add_request(request_id, prompt, sampling_params, + prompt_token_ids) + + def _run_engine(self, use_tqdm: bool) -> List[RequestOutput]: + # Initialize tqdm. + if use_tqdm: + num_requests = self.llm_engine.get_num_unfinished_requests() + pbar = tqdm(total=num_requests, desc="Processed prompts") + # Run the engine. + outputs: List[RequestOutput] = [] + while self.llm_engine.has_unfinished_requests(): + step_outputs = self.llm_engine.step() + for output in step_outputs: + if output.finished: + outputs.append(output) + if use_tqdm: + pbar.update(1) + if use_tqdm: + pbar.close() + # Sort the outputs by request ID. + # This is necessary because some requests may be finished earlier than + # its previous requests. + outputs = sorted(outputs, key=lambda x: int(x.request_id)) + return outputs diff --git a/ChatTTS/vllm_engine/llm_engine.py b/ChatTTS/vllm_engine/llm_engine.py new file mode 100644 index 000000000..a89bb87cd --- /dev/null +++ b/ChatTTS/vllm_engine/llm_engine.py @@ -0,0 +1,811 @@ +import copy +from collections import defaultdict +import os +import time +from typing import (TYPE_CHECKING, Any, Dict, Iterable, List, Optional, Tuple, + Union) + +from vllm.config import (CacheConfig, ModelConfig, ParallelConfig, + SchedulerConfig) +from ChatTTS.vllm_engine.scheduler import Scheduler, SchedulerOutputs +from ChatTTS.vllm_engine.configs import EngineArgs +from vllm.engine.metrics import record_metrics +from vllm.engine.ray_utils import RayWorkerVllm, initialize_cluster, ray +from vllm.logger import init_logger +from ChatTTS.vllm_engine.output import RequestOutput +from ChatTTS.vllm_engine.sampling_params import SamplingParams +from ChatTTS.vllm_engine.sequence import (SamplerOutput, Sequence, SequenceGroup, + SequenceGroupOutput, SequenceOutput, SequenceStatus) +from vllm.transformers_utils.tokenizer import (detokenize_incrementally, + get_tokenizer) +from vllm.utils import Counter, set_cuda_visible_devices, get_ip, get_open_port +import numpy as np +if ray: + from ray.util.scheduling_strategies import PlacementGroupSchedulingStrategy + +if TYPE_CHECKING: + from ray.util.placement_group import PlacementGroup + +logger = init_logger(__name__) + +_LOGGING_INTERVAL_SEC = 5 + + +class LLMEngine: + """An LLM engine that receives requests and generates texts. + + This is the main class for the vLLM engine. It receives requests + from clients and generates texts from the LLM. It includes a tokenizer, a + language model (possibly distributed across multiple GPUs), and GPU memory + space allocated for intermediate states (aka KV cache). This class utilizes + iteration-level scheduling and efficient memory management to maximize the + serving throughput. 
+ + The `LLM` class wraps this class for offline batched inference and the + `AsyncLLMEngine` class wraps this class for online serving. + + NOTE: The config arguments are derived from the `EngineArgs` class. For the + comprehensive list of arguments, see `EngineArgs`. + + Args: + model_config: The configuration related to the LLM model. + cache_config: The configuration related to the KV cache memory + management. + parallel_config: The configuration related to distributed execution. + scheduler_config: The configuration related to the request scheduler. + placement_group: Ray placement group for distributed execution. + Required for distributed execution. + log_stats: Whether to log statistics. + """ + + def __init__( + self, + model_config: ModelConfig, + cache_config: CacheConfig, + parallel_config: ParallelConfig, + scheduler_config: SchedulerConfig, + placement_group: Optional["PlacementGroup"], + post_model_path: str, + log_stats: bool, + ) -> None: + logger.info( + "Initializing an LLM engine with config: " + f"model={model_config.model!r}, " + f"tokenizer={model_config.tokenizer!r}, " + f"tokenizer_mode={model_config.tokenizer_mode}, " + f"revision={model_config.revision}, " + f"tokenizer_revision={model_config.tokenizer_revision}, " + f"trust_remote_code={model_config.trust_remote_code}, " + f"dtype={model_config.dtype}, " + f"max_seq_len={model_config.max_model_len}, " + f"download_dir={model_config.download_dir!r}, " + f"load_format={model_config.load_format}, " + f"tensor_parallel_size={parallel_config.tensor_parallel_size}, " + f"quantization={model_config.quantization}, " + f"enforce_eager={model_config.enforce_eager}, " + f"seed={model_config.seed}), " + f"post_model_path={post_model_path!r}" + ) + # TODO(woosuk): Print more configs in debug mode. + + self.model_config = model_config + self.cache_config = cache_config + self.parallel_config = parallel_config + self.scheduler_config = scheduler_config + self.log_stats = log_stats + self._verify_args() + self.post_model_path = post_model_path + self.seq_counter = Counter() + + # Create the parallel GPU workers. + if self.parallel_config.worker_use_ray: + # Disable Ray usage stats collection. + ray_usage = os.environ.get("RAY_USAGE_STATS_ENABLED", "0") + if ray_usage != "1": + os.environ["RAY_USAGE_STATS_ENABLED"] = "0" + self._init_workers_ray(placement_group) + else: + self._init_workers() + + # Profile the memory usage and initialize the cache. + self._init_cache() + + # Create the scheduler. + self.scheduler = Scheduler(scheduler_config, cache_config) + + # Logging. 
+ self.last_logging_time = 0.0 + # List of (timestamp, num_tokens) + self.num_prompt_tokens: List[Tuple[float, int]] = [] + # List of (timestamp, num_tokens) + self.num_generation_tokens: List[Tuple[float, int]] = [] + + def _init_workers(self): + # Lazy import the Worker to avoid importing torch.cuda/xformers + # before CUDA_VISIBLE_DEVICES is set in the Worker + from ChatTTS.vllm_engine.worker import Worker + + assert self.parallel_config.world_size == 1, ( + "Ray is required if parallel_config.world_size > 1.") + + self.workers: List[Worker] = [] + distributed_init_method = f"tcp://{get_ip()}:{get_open_port()}" + self.driver_worker = Worker( + self.model_config, + self.parallel_config, + self.scheduler_config, + local_rank=0, + rank=0, + distributed_init_method=distributed_init_method, + is_driver_worker=True, + post_model_path = self.post_model_path + ) + self._run_workers("init_model") + self._run_workers("load_model") + + def _init_workers_ray(self, placement_group: "PlacementGroup", + **ray_remote_kwargs): + if self.parallel_config.tensor_parallel_size == 1: + num_gpus = self.cache_config.gpu_memory_utilization + else: + num_gpus = 1 + + self.driver_dummy_worker: RayWorkerVllm = None + self.workers: List[RayWorkerVllm] = [] + + driver_ip = get_ip() + for bundle_id, bundle in enumerate(placement_group.bundle_specs): + if not bundle.get("GPU", 0): + continue + scheduling_strategy = PlacementGroupSchedulingStrategy( + placement_group=placement_group, + placement_group_capture_child_tasks=True, + placement_group_bundle_index=bundle_id, + ) + worker = ray.remote( + num_cpus=0, + num_gpus=num_gpus, + scheduling_strategy=scheduling_strategy, + **ray_remote_kwargs, + )(RayWorkerVllm).remote(self.model_config.trust_remote_code) + + worker_ip = ray.get(worker.get_node_ip.remote()) + if worker_ip == driver_ip and self.driver_dummy_worker is None: + # If the worker is on the same node as the driver, we use it + # as the resource holder for the driver process. + self.driver_dummy_worker = worker + else: + self.workers.append(worker) + + if self.driver_dummy_worker is None: + raise ValueError( + "Ray does not allocate any GPUs on the driver node. Consider " + "adjusting the Ray placement group or running the driver on a " + "GPU node.") + + driver_node_id, driver_gpu_ids = ray.get( + self.driver_dummy_worker.get_node_and_gpu_ids.remote()) + worker_node_and_gpu_ids = ray.get( + [worker.get_node_and_gpu_ids.remote() for worker in self.workers]) + + node_workers = defaultdict(list) + node_gpus = defaultdict(list) + + node_workers[driver_node_id].append(0) + node_gpus[driver_node_id].extend(driver_gpu_ids) + for i, (node_id, gpu_ids) in enumerate(worker_node_and_gpu_ids, + start=1): + node_workers[node_id].append(i) + node_gpus[node_id].extend(gpu_ids) + for node_id, gpu_ids in node_gpus.items(): + node_gpus[node_id] = sorted(gpu_ids) + + # Set CUDA_VISIBLE_DEVICES for the driver. + set_cuda_visible_devices(node_gpus[driver_node_id]) + for worker, (node_id, _) in zip(self.workers, worker_node_and_gpu_ids): + worker.set_cuda_visible_devices.remote(node_gpus[node_id]) + + distributed_init_method = f"tcp://{driver_ip}:{get_open_port()}" + + # Lazy import the Worker to avoid importing torch.cuda/xformers + # before CUDA_VISIBLE_DEVICES is set in the Worker + from vllm.worker.worker import Worker + + # Initialize torch distributed process group for the workers. 
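+        # The remote workers are constructed from deep copies of the driver's
+        # model/parallel/scheduler configs.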
+ model_config = copy.deepcopy(self.model_config) + parallel_config = copy.deepcopy(self.parallel_config) + scheduler_config = copy.deepcopy(self.scheduler_config) + + for rank, (worker, (node_id, + _)) in enumerate(zip(self.workers, + worker_node_and_gpu_ids), + start=1): + local_rank = node_workers[node_id].index(rank) + worker.init_worker.remote( + lambda rank=rank, local_rank=local_rank: Worker( + model_config, + parallel_config, + scheduler_config, + local_rank, + rank, + distributed_init_method, + )) + + driver_rank = 0 + driver_local_rank = node_workers[driver_node_id].index(driver_rank) + self.driver_worker = Worker( + model_config, + parallel_config, + scheduler_config, + driver_local_rank, + driver_rank, + distributed_init_method, + is_driver_worker=True, + ) + + self._run_workers("init_model") + self._run_workers( + "load_model", + max_concurrent_workers=self.parallel_config. + max_parallel_loading_workers, + ) + + def _verify_args(self) -> None: + self.model_config.verify_with_parallel_config(self.parallel_config) + self.cache_config.verify_with_parallel_config(self.parallel_config) + + def _init_cache(self) -> None: + """Profiles the memory usage and initializes the KV cache.""" + # Get the maximum number of blocks that can be allocated on GPU and CPU. + num_blocks = self._run_workers( + "profile_num_available_blocks", + block_size=self.cache_config.block_size, + gpu_memory_utilization=self.cache_config.gpu_memory_utilization, + cpu_swap_space=self.cache_config.swap_space_bytes, + ) + + # Since we use a shared centralized controller, we take the minimum + # number of blocks across all workers to make sure all the memory + # operators can be applied to all workers. + num_gpu_blocks = min(b[0] for b in num_blocks) + num_cpu_blocks = min(b[1] for b in num_blocks) + # FIXME(woosuk): Change to debug log. + logger.info(f"# GPU blocks: {num_gpu_blocks}, " + f"# CPU blocks: {num_cpu_blocks}") + + if num_gpu_blocks <= 0: + raise ValueError("No available memory for the cache blocks. " + "Try increasing `gpu_memory_utilization` when " + "initializing the engine.") + max_seq_len = self.cache_config.block_size * num_gpu_blocks + if self.model_config.max_model_len > max_seq_len: + raise ValueError( + f"The model's max seq len ({self.model_config.max_model_len}) " + "is larger than the maximum number of tokens that can be " + f"stored in KV cache ({max_seq_len}). Try increasing " + "`gpu_memory_utilization` or decreasing `max_model_len` when " + "initializing the engine.") + + self.cache_config.num_gpu_blocks = num_gpu_blocks + self.cache_config.num_cpu_blocks = num_cpu_blocks + + # Initialize the cache. + self._run_workers("init_cache_engine", cache_config=self.cache_config) + # Warm up the model. This includes capturing the model into CUDA graph + # if enforce_eager is False. + self._run_workers("warm_up_model") + + @classmethod + def from_engine_args(cls, engine_args: EngineArgs, post_model_path=None) -> "LLMEngine": + """Creates an LLM engine from the engine arguments.""" + # Create the engine configs. + engine_configs = engine_args.create_engine_configs() + parallel_config = engine_configs[2] + # Initialize the cluster. + placement_group = initialize_cluster(parallel_config) + # Create the LLM engine. 
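+        # engine_configs unpacks to (model_config, cache_config,
+        # parallel_config, scheduler_config), matching the first positional
+        # parameters of __init__; placement_group is passed after them.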
+ engine = cls(*engine_configs, + placement_group, + log_stats=not engine_args.disable_log_stats, + post_model_path = post_model_path + ) + return engine + + def add_request( + self, + request_id: str, + prompt: Optional[str], + sampling_params: SamplingParams, + prompt_token_ids: Optional[List[int]] = None, + arrival_time: Optional[float] = None, + ) -> None: + """Add a request to the engine's request pool. + + The request is added to the request pool and will be processed by the + scheduler as `engine.step()` is called. The exact scheduling policy is + determined by the scheduler. + + Args: + request_id: The unique ID of the request. + prompt: The prompt string. Can be None if prompt_token_ids is + provided. + sampling_params: The sampling parameters for text generation. + prompt_token_ids: The token IDs of the prompt. If None, we + use the tokenizer to convert the prompts to token IDs. + arrival_time: The arrival time of the request. If None, we use + the current monotonic time. + """ + if arrival_time is None: + arrival_time = time.monotonic() + + assert prompt_token_ids is not None, "prompt_token_ids must be provided" + # Create the sequences. + block_size = self.cache_config.block_size + seq_id = next(self.seq_counter) + seq = Sequence(seq_id, prompt, prompt_token_ids, block_size) + + # Create the sequence group. + seq_group = SequenceGroup(request_id, [seq], sampling_params, + arrival_time) + + # Add the sequence group to the scheduler. + self.scheduler.add_seq_group(seq_group) + + def abort_request(self, request_id: Union[str, Iterable[str]]) -> None: + """Aborts a request(s) with the given ID. + + Args: + request_id: The ID(s) of the request to abort. + """ + self.scheduler.abort_seq_group(request_id) + + def get_model_config(self) -> ModelConfig: + """Gets the model configuration.""" + return self.model_config + + def get_num_unfinished_requests(self) -> int: + """Gets the number of unfinished requests.""" + return self.scheduler.get_num_unfinished_seq_groups() + + def has_unfinished_requests(self) -> bool: + """Returns True if there are unfinished requests.""" + return self.scheduler.has_unfinished_seqs() + + def _check_beam_search_early_stopping( + self, + early_stopping: Union[bool, str], + sampling_params: SamplingParams, + best_running_seq: Sequence, + current_worst_seq: Sequence, + ) -> bool: + assert sampling_params.use_beam_search + length_penalty = sampling_params.length_penalty + if early_stopping is True: + return True + + current_worst_score = (current_worst_seq.get_beam_search_score( + length_penalty=length_penalty, + eos_token_id=self.tokenizer.eos_token_id)) + if early_stopping is False: + highest_attainable_score = (best_running_seq.get_beam_search_score( + length_penalty=length_penalty, + eos_token_id=self.tokenizer.eos_token_id)) + else: + assert early_stopping == "never" + if length_penalty > 0.0: + # If length_penalty > 0.0, beam search will prefer longer + # sequences. The highest attainable score calculation is + # based on the longest possible sequence length in this case. + max_possible_length = max( + best_running_seq.get_prompt_len() + + sampling_params.max_tokens, + self.scheduler_config.max_model_len) + highest_attainable_score = ( + best_running_seq.get_beam_search_score( + length_penalty=length_penalty, + eos_token_id=self.tokenizer.eos_token_id, + seq_len=max_possible_length)) + else: + # Otherwise, beam search will prefer shorter sequences. The + # highest attainable score calculation is based on the current + # sequence length. 
+ highest_attainable_score = ( + best_running_seq.get_beam_search_score( + length_penalty=length_penalty, + eos_token_id=self.tokenizer.eos_token_id)) + return current_worst_score >= highest_attainable_score + + def _process_sequence_group_outputs(self, seq_group: SequenceGroup, + outputs: SequenceGroupOutput) -> None: + # Process prompt logprobs + prompt_logprobs = outputs.prompt_logprobs + if prompt_logprobs is not None: + seq_group.prompt_logprobs = prompt_logprobs + + # Process samples + samples = outputs.samples + parent_seqs = seq_group.get_seqs(status=SequenceStatus.RUNNING) + existing_finished_seqs = seq_group.get_finished_seqs() + parent_child_dict = { + parent_seq.seq_id: [] + for parent_seq in parent_seqs + } + for sample in samples: + parent_child_dict[sample.parent_seq_id].append(sample) + # List of (child, parent) + child_seqs: List[Tuple[Sequence, Sequence]] = [] + + # Process the child samples for each parent sequence + for parent in parent_seqs: + child_samples: List[SequenceOutput] = parent_child_dict[ + parent.seq_id] + if len(child_samples) == 0: + # This parent sequence has no children samples. Remove + # the parent sequence from the sequence group since it will + # not be used in the future iterations. + parent.status = SequenceStatus.FINISHED_ABORTED + seq_group.remove(parent.seq_id) + self.scheduler.free_seq(parent) + continue + # Fork the parent sequence if there are multiple child samples. + for child_sample in child_samples[:-1]: + new_child_seq_id = next(self.seq_counter) + child = parent.fork(new_child_seq_id) + child.append_token_id(child_sample.output_token, + child_sample.logprobs, + child_sample.hidden_states, + child_sample.finished + ) + child_seqs.append((child, parent)) + # Continue the parent sequence for the last child sample. + # We reuse the parent sequence here to reduce redundant memory + # copies, especially when using non-beam search sampling methods. + last_child_sample = child_samples[-1] + parent.append_token_id(last_child_sample.output_token, + last_child_sample.logprobs, + last_child_sample.hidden_states, + last_child_sample.finished + ) + child_seqs.append((parent, parent)) + + for seq, _ in child_seqs: + # self._decode_sequence(seq, seq_group.sampling_params) + self._check_stop(seq, seq_group.sampling_params) + + # Non-beam search case + if not seq_group.sampling_params.use_beam_search: + # For newly created child sequences, add them to the sequence group + # and fork them in block manager if they are not finished. + for seq, parent in child_seqs: + if seq is not parent: + seq_group.add(seq) + if not seq.is_finished(): + self.scheduler.fork_seq(parent, seq) + + # Free the finished and selected parent sequences' memory in block + # manager. Keep them in the sequence group as candidate output. + # NOTE: we need to fork the new sequences before freeing the + # old sequences. + for seq, parent in child_seqs: + if seq is parent and seq.is_finished(): + self.scheduler.free_seq(seq) + return + + # Beam search case + # Select the child sequences to keep in the sequence group. + selected_child_seqs = [] + unselected_child_seqs = [] + beam_width = seq_group.sampling_params.best_of + length_penalty = seq_group.sampling_params.length_penalty + + # Select the newly finished sequences with the highest scores + # to replace existing finished sequences. 
+ # Tuple of (seq, parent, is_new) + existing_finished_seqs = [(seq, None, False) + for seq in existing_finished_seqs] + new_finished_seqs = [(seq, parent, True) for seq, parent in child_seqs + if seq.is_finished()] + all_finished_seqs = existing_finished_seqs + new_finished_seqs + # Sort the finished sequences by their scores. + all_finished_seqs.sort(key=lambda x: x[0].get_beam_search_score( + length_penalty=length_penalty, + eos_token_id=self.tokenizer.eos_token_id), + reverse=True) + for seq, parent, is_new in all_finished_seqs[:beam_width]: + if is_new: + # A newly generated child sequence finishes and has a high + # score, so we will add it into the sequence group. + selected_child_seqs.append((seq, parent)) + for seq, parent, is_new in all_finished_seqs[beam_width:]: + if is_new: + # A newly generated child sequence finishes but has a low + # score, so we will not add it into the sequence group. + # Additionally, if this sequence is a continuation of a + # parent sequence, we will need remove the parent sequence + # from the sequence group. + unselected_child_seqs.append((seq, parent)) + else: + # An existing finished sequence has a low score, so we will + # remove it from the sequence group. + seq_group.remove(seq.seq_id) + + # select the top beam_width sequences from the running + # sequences for the next iteration to continue the beam + # search. + running_child_seqs = [(seq, parent) for seq, parent in child_seqs + if not seq.is_finished()] + # Sort the running sequences by their scores. + running_child_seqs.sort(key=lambda x: x[0].get_beam_search_score( + length_penalty=length_penalty, + eos_token_id=self.tokenizer.eos_token_id), + reverse=True) + + # Check if we can stop the beam search. + if len(running_child_seqs) == 0: + # No running sequences, stop the beam search. + stop_beam_search = True + elif len(all_finished_seqs) < beam_width: + # Not enough finished sequences, continue the beam search. + stop_beam_search = False + else: + # Check the early stopping criteria + best_running_seq = running_child_seqs[0][0] + current_worst_seq = all_finished_seqs[beam_width - 1][0] + stop_beam_search = self._check_beam_search_early_stopping( + seq_group.sampling_params.early_stopping, + seq_group.sampling_params, best_running_seq, current_worst_seq) + + if stop_beam_search: + # Stop the beam search and remove all the running sequences from + # the sequence group. + unselected_child_seqs.extend(running_child_seqs) + else: + # Continue the beam search and select the top beam_width sequences + # to continue the beam search. + selected_child_seqs.extend(running_child_seqs[:beam_width]) + # The remaining running sequences will not be used in the next + # iteration. Again, if these sequences are continuations of + # parent sequences, we will need to remove the parent sequences + # from the sequence group. + unselected_child_seqs.extend(running_child_seqs[beam_width:]) + + # For newly created child sequences, add them to the sequence group + # and fork them in block manager if they are not finished. + for seq, parent in selected_child_seqs: + if seq is not parent: + seq_group.add(seq) + if not seq.is_finished(): + self.scheduler.fork_seq(parent, seq) + + # Free the finished and selected parent sequences' memory in block + # manager. Keep them in the sequence group as candidate output. 
+ for seq, parent in selected_child_seqs: + if seq is parent and seq.is_finished(): + self.scheduler.free_seq(seq) + + # Remove the unselected parent sequences from the sequence group and + # free their memory in block manager. + for seq, parent in unselected_child_seqs: + if seq is parent: + # Remove the parent sequence if it is not selected for next + # iteration + seq_group.remove(seq.seq_id) + self.scheduler.free_seq(seq) + + def _process_model_outputs( + self, output: SamplerOutput, + scheduler_outputs: SchedulerOutputs) -> List[RequestOutput]: + # Update the scheduled sequence groups with the model outputs. + scheduled_seq_groups = scheduler_outputs.scheduled_seq_groups + for seq_group, outputs in zip(scheduled_seq_groups, output): + self._process_sequence_group_outputs(seq_group, outputs) + + # Free the finished sequence groups. + self.scheduler.free_finished_seq_groups() + + # Create the outputs. + request_outputs: List[RequestOutput] = [] + for seq_group in (scheduled_seq_groups + + scheduler_outputs.ignored_seq_groups): + request_output = RequestOutput.from_seq_group(seq_group) + request_outputs.append(request_output) + + if self.log_stats: + # Log the system stats. + self._log_system_stats(scheduler_outputs.prompt_run, + scheduler_outputs.num_batched_tokens) + return request_outputs + + def step(self) -> List[RequestOutput]: + """Performs one decoding iteration and returns newly generated results. + + This function performs one decoding iteration of the engine. It first + schedules the sequences to be executed in the next iteration and the + token blocks to be swapped in/out/copy. Then, it executes the model + and updates the scheduler with the model outputs. Finally, it decodes + the sequences and returns the newly generated results. + """ + seq_group_metadata_list, scheduler_outputs = self.scheduler.schedule() + + if not scheduler_outputs.is_empty(): + # Execute the model. + all_outputs = self._run_workers( + "execute_model", + driver_kwargs={ + "seq_group_metadata_list": seq_group_metadata_list, + "blocks_to_swap_in": scheduler_outputs.blocks_to_swap_in, + "blocks_to_swap_out": scheduler_outputs.blocks_to_swap_out, + "blocks_to_copy": scheduler_outputs.blocks_to_copy, + }) + + # Only the driver worker returns the sampling results. + output = all_outputs[0] + else: + output = [] + + return self._process_model_outputs(output, scheduler_outputs) + + def _log_system_stats( + self, + prompt_run: bool, + num_batched_tokens: int, + ) -> None: + now = time.monotonic() + # Log the number of batched input tokens. + if prompt_run: + self.num_prompt_tokens.append((now, num_batched_tokens)) + else: + self.num_generation_tokens.append((now, num_batched_tokens)) + + should_log = now - self.last_logging_time >= _LOGGING_INTERVAL_SEC + if not should_log: + return + + # Discard the old stats. 
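+        # Keep only samples from the last _LOGGING_INTERVAL_SEC seconds so the
+        # throughput numbers computed below reflect a sliding window.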
+ self.num_prompt_tokens = [(t, n) for t, n in self.num_prompt_tokens + if now - t < _LOGGING_INTERVAL_SEC] + self.num_generation_tokens = [(t, n) + for t, n in self.num_generation_tokens + if now - t < _LOGGING_INTERVAL_SEC] + + if len(self.num_prompt_tokens) > 1: + total_num_tokens = sum(n for _, n in self.num_prompt_tokens[:-1]) + window = now - self.num_prompt_tokens[0][0] + avg_prompt_throughput = total_num_tokens / window + else: + avg_prompt_throughput = 0.0 + if len(self.num_generation_tokens) > 1: + total_num_tokens = sum(n + for _, n in self.num_generation_tokens[:-1]) + window = now - self.num_generation_tokens[0][0] + avg_generation_throughput = total_num_tokens / window + else: + avg_generation_throughput = 0.0 + + total_num_gpu_blocks = self.cache_config.num_gpu_blocks + num_free_gpu_blocks = ( + self.scheduler.block_manager.get_num_free_gpu_blocks()) + num_used_gpu_blocks = total_num_gpu_blocks - num_free_gpu_blocks + gpu_cache_usage = num_used_gpu_blocks / total_num_gpu_blocks + + total_num_cpu_blocks = self.cache_config.num_cpu_blocks + if total_num_cpu_blocks > 0: + num_free_cpu_blocks = ( + self.scheduler.block_manager.get_num_free_cpu_blocks()) + num_used_cpu_blocks = total_num_cpu_blocks - num_free_cpu_blocks + cpu_cache_usage = num_used_cpu_blocks / total_num_cpu_blocks + else: + cpu_cache_usage = 0.0 + + record_metrics( + avg_prompt_throughput=avg_prompt_throughput, + avg_generation_throughput=avg_generation_throughput, + scheduler_running=len(self.scheduler.running), + scheduler_swapped=len(self.scheduler.swapped), + scheduler_waiting=len(self.scheduler.waiting), + gpu_cache_usage=gpu_cache_usage, + cpu_cache_usage=cpu_cache_usage, + ) + + logger.info("Avg prompt throughput: " + f"{avg_prompt_throughput:.1f} tokens/s, " + "Avg generation throughput: " + f"{avg_generation_throughput:.1f} tokens/s, " + f"Running: {len(self.scheduler.running)} reqs, " + f"Swapped: {len(self.scheduler.swapped)} reqs, " + f"Pending: {len(self.scheduler.waiting)} reqs, " + f"GPU KV cache usage: {gpu_cache_usage * 100:.1f}%, " + f"CPU KV cache usage: {cpu_cache_usage * 100:.1f}%") + self.last_logging_time = now + + def _decode_sequence(self, seq: Sequence, prms: SamplingParams) -> None: + """Decodes the new token for a sequence.""" + (new_tokens, new_output_text, prefix_offset, + read_offset) = detokenize_incrementally( + self.tokenizer, + all_input_ids=seq.get_token_ids(), + prev_tokens=seq.tokens, + prefix_offset=seq.prefix_offset, + read_offset=seq.read_offset, + skip_special_tokens=prms.skip_special_tokens, + spaces_between_special_tokens=prms.spaces_between_special_tokens, + ) + if seq.tokens is None: + seq.tokens = new_tokens + else: + seq.tokens.extend(new_tokens) + seq.prefix_offset = prefix_offset + seq.read_offset = read_offset + seq.output_text += new_output_text + + def _check_stop(self, seq: Sequence, + sampling_params: SamplingParams) -> None: + """Stop the finished sequences.""" + for stop_str in sampling_params.stop: + if seq.output_text.endswith(stop_str): + if not sampling_params.include_stop_str_in_output: + # Truncate the output text so that the stop string is + # not included in the output. 
+ seq.output_text = seq.output_text[:-len(stop_str)] + seq.status = SequenceStatus.FINISHED_STOPPED + return + if seq.data.finished: + seq.status = SequenceStatus.FINISHED_STOPPED + return + + for token_id in seq.get_last_token_id(): + if token_id == sampling_params.eos_token: + seq.status = SequenceStatus.FINISHED_STOPPED + return + + # Check if the sequence has reached max_model_len. + if seq.get_len() > self.scheduler_config.max_model_len: + seq.status = SequenceStatus.FINISHED_LENGTH_CAPPED + return + + # Check if the sequence has reached max_tokens. + if seq.get_output_len() == sampling_params.max_tokens: + seq.status = SequenceStatus.FINISHED_LENGTH_CAPPED + return + + # Check if the sequence has generated the EOS token. + if ((not sampling_params.ignore_eos) + and seq.get_last_token_id()[0] == sampling_params.eos_token): + seq.status = SequenceStatus.FINISHED_STOPPED + return + + def _run_workers( + self, + method: str, + *args, + driver_args: Optional[List[Any]] = None, + driver_kwargs: Optional[Dict[str, Any]] = None, + max_concurrent_workers: Optional[int] = None, + **kwargs, + ) -> Any: + """Runs the given method on all workers.""" + + if max_concurrent_workers: + raise NotImplementedError( + "max_concurrent_workers is not supported yet.") + + # Start the ray workers first. + ray_worker_outputs = [ + worker.execute_method.remote(method, *args, **kwargs) + for worker in self.workers + ] + + if driver_args is None: + driver_args = args + if driver_kwargs is None: + driver_kwargs = kwargs + + # Start the driver worker after all the ray workers. + driver_worker_output = getattr(self.driver_worker, + method)(*driver_args, **driver_kwargs) + + # Get the results of the ray workers. + if self.workers: + ray_worker_outputs = ray.get(ray_worker_outputs) + + return [driver_worker_output] + ray_worker_outputs diff --git a/ChatTTS/vllm_engine/model_loader.py b/ChatTTS/vllm_engine/model_loader.py new file mode 100644 index 000000000..cb189482c --- /dev/null +++ b/ChatTTS/vllm_engine/model_loader.py @@ -0,0 +1,67 @@ +"""Utilities for selecting and loading models.""" +import contextlib +from typing import Type + +import torch +import torch.nn as nn +from transformers import PretrainedConfig + +from vllm.config import ModelConfig +from vllm.model_executor.models import ModelRegistry +from vllm.model_executor.weight_utils import (get_quant_config, + initialize_dummy_weights) +import importlib + +@contextlib.contextmanager +def _set_default_torch_dtype(dtype: torch.dtype): + """Sets the default torch dtype to the given dtype.""" + old_dtype = torch.get_default_dtype() + torch.set_default_dtype(dtype) + yield + torch.set_default_dtype(old_dtype) + + +def _get_model_architecture(config: PretrainedConfig) -> Type[nn.Module]: + model_cls = getattr(importlib.import_module("ChatTTS.vllm_engine.llama"), "LlamaModel", None) + return model_cls + +def get_model(model_config: ModelConfig) -> nn.Module: + model_class = _get_model_architecture(model_config.hf_config) + + # Get the (maybe quantized) linear method. + linear_method = None + if model_config.quantization is not None: + quant_config = get_quant_config(model_config.quantization, + model_config.model, + model_config.hf_config, + model_config.download_dir) + capability = torch.cuda.get_device_capability() + capability = capability[0] * 10 + capability[1] + if capability < quant_config.get_min_capability(): + raise ValueError( + f"The quantization method {model_config.quantization} is not " + "supported for the current GPU. 
" + f"Minimum capability: {quant_config.get_min_capability()}. " + f"Current capability: {capability}.") + supported_dtypes = quant_config.get_supported_act_dtypes() + if model_config.dtype not in supported_dtypes: + raise ValueError( + f"{model_config.dtype} is not supported for quantization " + f"method {model_config.quantization}. Supported dtypes: " + f"{supported_dtypes}") + linear_method = quant_config.get_linear_method() + + with _set_default_torch_dtype(model_config.dtype): + # Create a model instance. + # The weights will be initialized as empty tensors. + with torch.device("cuda"): + model = model_class(model_config.hf_config, linear_method) + if model_config.load_format == "dummy": + # NOTE(woosuk): For accurate performance evaluation, we assign + # random values to the weights. + initialize_dummy_weights(model) + else: + # Load the weights from the cached or downloaded files. + model.load_weights(model_config.model, model_config.download_dir, + model_config.load_format, model_config.revision) + return model.eval() diff --git a/ChatTTS/vllm_engine/model_runner.py b/ChatTTS/vllm_engine/model_runner.py new file mode 100644 index 000000000..c3db34cc5 --- /dev/null +++ b/ChatTTS/vllm_engine/model_runner.py @@ -0,0 +1,769 @@ +import time +from typing import Dict, List, Optional, Tuple, Union + +import numpy as np +import torch +import torch.nn as nn + +from ChatTTS.vllm_engine.configs import ModelConfig, ParallelConfig, SchedulerConfig +from vllm.logger import init_logger +from ChatTTS.vllm_engine.model_loader import get_model +from vllm.model_executor import InputMetadata, SamplingMetadata +from vllm.model_executor.parallel_utils.communication_op import ( + broadcast, broadcast_object_list) +from ChatTTS.vllm_engine.sampling_params import SamplingParams, SamplingType +from ChatTTS.vllm_engine.sequence import SamplerOutput, SequenceData, SequenceGroupMetadata, SequenceGroupOutput, SequenceOutput +from vllm.utils import in_wsl +from ChatTTS.vllm_engine.post_model import Post_model, Sampler +from safetensors.torch import safe_open + +logger = init_logger(__name__) + +KVCache = Tuple[torch.Tensor, torch.Tensor] +_PAD_SLOT_ID = -1 +# Capture graphs for batch size 1, 2, 4, 8, 16, 24, 32, 40, ..., 256. +# NOTE: _get_graph_batch_size needs to be updated if this list is changed. +_BATCH_SIZES_TO_CAPTURE = [1, 2, 4] + [8 * i for i in range(1, 33)] + + +class ModelRunner: + + def __init__( + self, + model_config: ModelConfig, + parallel_config: ParallelConfig, + scheduler_config: SchedulerConfig, + is_driver_worker: bool = False, + post_model_path: str = None + ): + self.model_config = model_config + self.parallel_config = parallel_config + self.scheduler_config = scheduler_config + self.is_driver_worker = is_driver_worker + self.post_model_path = post_model_path + + # model_config can be None in tests/samplers/test_sampler.py. + # FIXME(woosuk): This is a hack to make the tests work. Refactor this. + self.sliding_window = (model_config.get_sliding_window() + if model_config is not None else None) + self.model = None + self.block_size = None # Set after initial profiling. + + self.graph_runners: Dict[int, CUDAGraphRunner] = {} + self.graph_memory_pool = None # Set during graph capture. + + self.max_context_len_to_capture = ( + self.model_config.max_context_len_to_capture + if self.model_config is not None else 0) + # When using CUDA graph, the input block tables must be padded to + # max_context_len_to_capture. However, creating the block table in + # Python can be expensive. 
To optimize this, we cache the block table + # in numpy and only copy the actual input content at every iteration. + # The shape of the cached block table will be + # (max batch size to capture, max context len to capture / block size). + self.graph_block_tables = None # Set after initial profiling. + # cache in_wsl result + self.in_wsl = in_wsl() + + def load_model(self) -> None: + self.model = get_model(self.model_config) + self.post_model = Post_model( + self.model_config.get_hidden_size(), + self.model_config.num_audio_tokens, + self.model_config.num_text_tokens + ) + state_dict_tensors = {} + with safe_open(self.post_model_path, framework="pt", device=0) as f: + for k in f.keys(): + state_dict_tensors[k] = f.get_tensor(k) + self.post_model.load_state_dict(state_dict_tensors) + self.post_model.to(next(self.model.parameters())).eval() + self.sampler = Sampler( + self.post_model, + self.model_config.num_audio_tokens, + 4 + ) + def set_block_size(self, block_size: int) -> None: + self.block_size = block_size + + max_num_blocks = (self.max_context_len_to_capture + block_size - + 1) // block_size + self.graph_block_tables = np.zeros( + (max(_BATCH_SIZES_TO_CAPTURE), max_num_blocks), dtype=np.int32) + + def _prepare_prompt( + self, + seq_group_metadata_list: List[SequenceGroupMetadata], + ) -> Tuple[torch.Tensor, torch.Tensor, InputMetadata, List[int]]: + assert len(seq_group_metadata_list) > 0 + input_tokens: List[List[int]] = [] + input_positions: List[List[int]] = [] + slot_mapping: List[List[int]] = [] + + prompt_lens: List[int] = [] + for seq_group_metadata in seq_group_metadata_list: + assert seq_group_metadata.is_prompt + seq_ids = list(seq_group_metadata.seq_data.keys()) + assert len(seq_ids) == 1 + seq_id = seq_ids[0] + + seq_data = seq_group_metadata.seq_data[seq_id] + prompt_tokens = seq_data.get_token_ids() + prompt_len = len(prompt_tokens) + prompt_lens.append(prompt_len) + + input_tokens.append(prompt_tokens) + # NOTE(woosuk): Here we assume that the first token in the prompt + # is always the first token in the sequence. + input_positions.append(list(range(prompt_len))) + + if seq_group_metadata.block_tables is None: + # During memory profiling, the block tables are not initialized + # yet. In this case, we just use a dummy slot mapping. + slot_mapping.append([_PAD_SLOT_ID] * prompt_len) + continue + + # Compute the slot mapping. + slot_mapping.append([]) + block_table = seq_group_metadata.block_tables[seq_id] + # Mask the [0, start_idx) tokens of the prompt with _PAD_SLOT_ID, + # where start_idx is max(0, prompt_len - sliding_window). + # For example, if the prompt len is 10, sliding window is 8, and + # block size is 4, the first two tokens are masked and the slot + # mapping will be [-1, -1, 2, 3, 4, 5, 6, 7, 0, 1]. 
+ start_idx = 0 + if self.sliding_window is not None: + start_idx = max(0, prompt_len - self.sliding_window) + for i in range(prompt_len): + if i < start_idx: + slot_mapping[-1].append(_PAD_SLOT_ID) + continue + + block_number = block_table[i // self.block_size] + block_offset = i % self.block_size + slot = block_number * self.block_size + block_offset + slot_mapping[-1].append(slot) + + max_prompt_len = max(prompt_lens) + input_tokens = _make_tensor_with_pad(input_tokens, + max_prompt_len, + pad=0, + dtype=torch.long) + input_positions = _make_tensor_with_pad(input_positions, + max_prompt_len, + pad=0, + dtype=torch.long) + slot_mapping = _make_tensor_with_pad(slot_mapping, + max_prompt_len, + pad=_PAD_SLOT_ID, + dtype=torch.long) + + input_metadata = InputMetadata( + is_prompt=True, + slot_mapping=slot_mapping, + max_context_len=None, + context_lens=None, + block_tables=None, + use_cuda_graph=False, + ) + return input_tokens, input_positions, input_metadata, prompt_lens + + def _prepare_decode( + self, + seq_group_metadata_list: List[SequenceGroupMetadata], + ) -> Tuple[torch.Tensor, torch.Tensor, InputMetadata]: + assert len(seq_group_metadata_list) > 0 + input_tokens: List[List[int]] = [] + input_positions: List[List[int]] = [] + slot_mapping: List[List[int]] = [] + context_lens: List[int] = [] + block_tables: List[List[int]] = [] + + for seq_group_metadata in seq_group_metadata_list: + assert not seq_group_metadata.is_prompt + + seq_ids = list(seq_group_metadata.seq_data.keys()) + for seq_id in seq_ids: + seq_data = seq_group_metadata.seq_data[seq_id] + generation_token = seq_data.get_last_token_id() + input_tokens.append([generation_token]) + + seq_len = seq_data.get_len() + position = seq_len - 1 + input_positions.append([position]) + + context_len = seq_len if self.sliding_window is None else min( + seq_len, self.sliding_window) + context_lens.append(context_len) + + block_table = seq_group_metadata.block_tables[seq_id] + block_number = block_table[position // self.block_size] + block_offset = position % self.block_size + slot = block_number * self.block_size + block_offset + slot_mapping.append([slot]) + + if self.sliding_window is not None: + sliding_window_blocks = (self.sliding_window // + self.block_size) + block_table = block_table[-sliding_window_blocks:] + block_tables.append(block_table) + + batch_size = len(input_tokens) + max_context_len = max(context_lens) + use_captured_graph = ( + not self.model_config.enforce_eager + and batch_size <= _BATCH_SIZES_TO_CAPTURE[-1] + and max_context_len <= self.max_context_len_to_capture) + if use_captured_graph: + # Pad the input tokens, positions, and slot mapping to match the + # batch size of the captured graph. 
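+            # The rows appended here are placeholders whose only purpose is to
+            # pad the batch up to a size that has a captured CUDA graph.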
+ graph_batch_size = _get_graph_batch_size(batch_size) + assert graph_batch_size >= batch_size + for _ in range(graph_batch_size - batch_size): + input_tokens.append([]) + input_positions.append([]) + slot_mapping.append([]) + context_lens.append(1) + block_tables.append([]) + batch_size = graph_batch_size + + input_tokens = _make_tensor_with_pad(input_tokens, + max_len=1, + pad=0, + dtype=torch.long, + device="cuda") + input_positions = _make_tensor_with_pad(input_positions, + max_len=1, + pad=0, + dtype=torch.long, + device="cuda") + slot_mapping = _make_tensor_with_pad(slot_mapping, + max_len=1, + pad=_PAD_SLOT_ID, + dtype=torch.long, + device="cuda") + context_lens = torch.tensor(context_lens, + dtype=torch.int, + device="cuda") + + if use_captured_graph: + # The shape of graph_block_tables is + # [max batch size, max context len // block size]. + input_block_tables = self.graph_block_tables[:batch_size] + for i, block_table in enumerate(block_tables): + if block_table: + input_block_tables[i, :len(block_table)] = block_table + block_tables = torch.tensor(input_block_tables, device="cuda") + else: + block_tables = _make_tensor_with_pad( + block_tables, + max_len=max_context_len, + pad=0, + dtype=torch.int, + device="cuda", + ) + + input_metadata = InputMetadata( + is_prompt=False, + slot_mapping=slot_mapping, + max_context_len=max_context_len, + context_lens=context_lens, + block_tables=block_tables, + use_cuda_graph=use_captured_graph, + ) + return input_tokens, input_positions, input_metadata + + def _prepare_sample( + self, + seq_group_metadata_list: List[SequenceGroupMetadata], + prompt_lens: List[int], + ) -> SamplingMetadata: + seq_groups: List[Tuple[List[int], SamplingParams]] = [] + selected_token_indices: List[int] = [] + selected_token_start_idx = 0 + categorized_sample_indices = {t: [] for t in SamplingType} + categorized_sample_indices_start_idx = 0 + + max_prompt_len = max(prompt_lens) if prompt_lens else 1 + for i, seq_group_metadata in enumerate(seq_group_metadata_list): + seq_ids = list(seq_group_metadata.seq_data.keys()) + sampling_params = seq_group_metadata.sampling_params + seq_groups.append((seq_ids, sampling_params)) + + if seq_group_metadata.is_prompt: + assert len(seq_ids) == 1 + prompt_len = prompt_lens[i] + if sampling_params.prompt_logprobs is not None: + # NOTE: prompt token positions do not need sample, skip + categorized_sample_indices_start_idx += prompt_len - 1 + + categorized_sample_indices[ + sampling_params.sampling_type].append( + categorized_sample_indices_start_idx) + categorized_sample_indices_start_idx += 1 + + if sampling_params.prompt_logprobs is not None: + selected_token_indices.extend( + range(selected_token_start_idx, + selected_token_start_idx + prompt_len - 1)) + selected_token_indices.append(selected_token_start_idx + + prompt_len - 1) + selected_token_start_idx += max_prompt_len + else: + num_seqs = len(seq_ids) + selected_token_indices.extend( + range(selected_token_start_idx, + selected_token_start_idx + num_seqs)) + selected_token_start_idx += num_seqs + + categorized_sample_indices[ + sampling_params.sampling_type].extend( + range(categorized_sample_indices_start_idx, + categorized_sample_indices_start_idx + num_seqs)) + categorized_sample_indices_start_idx += num_seqs + + selected_token_indices = _async_h2d(selected_token_indices, + dtype=torch.long, + pin_memory=not self.in_wsl) + categorized_sample_indices = { + t: _async_h2d(seq_ids, dtype=torch.int, pin_memory=not self.in_wsl) + for t, seq_ids in 
categorized_sample_indices.items() + } + + seq_data: Dict[int, SequenceData] = {} + for seq_group_metadata in seq_group_metadata_list: + seq_data.update(seq_group_metadata.seq_data) + + sampling_metadata = SamplingMetadata( + seq_groups=seq_groups, + seq_data=seq_data, + prompt_lens=prompt_lens, + selected_token_indices=selected_token_indices, + categorized_sample_indices=categorized_sample_indices, + ) + return sampling_metadata + + def prepare_input_tensors( + self, + seq_group_metadata_list: Optional[List[SequenceGroupMetadata]], + ) -> Tuple[torch.Tensor, torch.Tensor, InputMetadata, SamplingMetadata]: + if self.is_driver_worker: + # NOTE: We assume that all sequences in the group are all prompts or + # all decodes. + is_prompt = seq_group_metadata_list[0].is_prompt + # Prepare input tensors. + if is_prompt: + (input_tokens, input_positions, input_metadata, + prompt_lens) = self._prepare_prompt(seq_group_metadata_list) + else: + (input_tokens, input_positions, input_metadata + ) = self._prepare_decode(seq_group_metadata_list) + prompt_lens = [] + sampling_metadata = self._prepare_sample(seq_group_metadata_list, + prompt_lens) + + def get_size_or_none(x: Optional[torch.Tensor]): + return x.size() if x is not None else None + + # Broadcast the input data. For input tensors, we first broadcast + # its shape and then broadcast the tensor to avoid high + # serialization cost. + py_data = { + "input_tokens_size": + input_tokens.size(), + "input_positions_size": + input_positions.size(), + "is_prompt": + input_metadata.is_prompt, + "slot_mapping_size": + get_size_or_none(input_metadata.slot_mapping), + "max_context_len": + input_metadata.max_context_len, + "context_lens_size": + get_size_or_none(input_metadata.context_lens), + "block_tables_size": + get_size_or_none(input_metadata.block_tables), + "use_cuda_graph": + input_metadata.use_cuda_graph, + "selected_token_indices_size": + sampling_metadata.selected_token_indices.size(), + } + broadcast_object_list([py_data], src=0) + # TODO(zhuohan): Combine the broadcasts or set async_op=True. 
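+            # The tensor shapes were just sent via broadcast_object_list; the
+            # tensors themselves are now broadcast one by one from the driver
+            # (src=0).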
+ broadcast(input_tokens, src=0) + broadcast(input_positions, src=0) + if input_metadata.slot_mapping is not None: + broadcast(input_metadata.slot_mapping, src=0) + if input_metadata.context_lens is not None: + broadcast(input_metadata.context_lens, src=0) + if input_metadata.block_tables is not None: + broadcast(input_metadata.block_tables, src=0) + broadcast(sampling_metadata.selected_token_indices, src=0) + else: + receving_list = [None] + broadcast_object_list(receving_list, src=0) + py_data = receving_list[0] + input_tokens = torch.empty(*py_data["input_tokens_size"], + dtype=torch.long, + device="cuda") + broadcast(input_tokens, src=0) + input_positions = torch.empty(*py_data["input_positions_size"], + dtype=torch.long, + device="cuda") + broadcast(input_positions, src=0) + if py_data["slot_mapping_size"] is not None: + slot_mapping = torch.empty(*py_data["slot_mapping_size"], + dtype=torch.long, + device="cuda") + broadcast(slot_mapping, src=0) + else: + slot_mapping = None + if py_data["context_lens_size"] is not None: + context_lens = torch.empty(*py_data["context_lens_size"], + dtype=torch.int, + device="cuda") + broadcast(context_lens, src=0) + else: + context_lens = None + if py_data["block_tables_size"] is not None: + block_tables = torch.empty(*py_data["block_tables_size"], + dtype=torch.int, + device="cuda") + broadcast(block_tables, src=0) + else: + block_tables = None + selected_token_indices = torch.empty( + *py_data["selected_token_indices_size"], + dtype=torch.long, + device="cuda") + broadcast(selected_token_indices, src=0) + input_metadata = InputMetadata( + is_prompt=py_data["is_prompt"], + slot_mapping=slot_mapping, + max_context_len=py_data["max_context_len"], + context_lens=context_lens, + block_tables=block_tables, + use_cuda_graph=py_data["use_cuda_graph"], + ) + sampling_metadata = SamplingMetadata( + seq_groups=None, + seq_data=None, + prompt_lens=None, + selected_token_indices=selected_token_indices, + categorized_sample_indices=None, + perform_sampling=False, + ) + + return input_tokens, input_positions, input_metadata, sampling_metadata + + @torch.inference_mode() + def execute_model( + self, + seq_group_metadata_list: Optional[List[SequenceGroupMetadata]], + kv_caches: List[Tuple[torch.Tensor, torch.Tensor]], + ) -> Optional[SamplerOutput]: + input_tokens, input_positions, input_metadata, sampling_metadata = ( + self.prepare_input_tensors(seq_group_metadata_list)) + # print(sampling_metadata.seq_data) + seq_groups = [] + input_tokens_history = [] + for i, rtn in enumerate(sampling_metadata.seq_groups): + seq_groups.append(rtn[0][0]) + tokens_history = sampling_metadata.seq_data[rtn[0][0]].output_token_ids + if len(tokens_history) >= 1: + if len(tokens_history[0]) == 1: + tokens_history = [token[0] for token in tokens_history] + else: + tokens_history = [list(token) for token in tokens_history] + input_tokens_history.append(tokens_history) + input_tokens_history = torch.tensor(input_tokens_history).to(input_tokens.device) + # token_ids = rtn.outputs[0].token_ids + # for j, token_id in enumerate(token_ids): + # if len(token_id) == 1: + # token_ids[j] = token_id[0] + # else: + # token_ids[j] = list(token_id) + + # Execute the model. 
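+        # 2-D token tensors are expanded to (batch, seq_len, 4) below so text
+        # tokens and the 4 audio VQ codebooks share one layout before the post
+        # model turns them into embeddings.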
+ # print("it1",input_tokens) + if len(input_tokens.shape) == 2: + input_tokens = input_tokens.unsqueeze(2).repeat(1, 1, 4) + if len(input_tokens_history.shape) == 2: + input_tokens_history = input_tokens_history.unsqueeze(2).repeat(1, 1, 4) + # print(input_tokens_history.shape) + # print("it2",input_tokens.shape) + text_mask = input_tokens != 0 + text_mask = text_mask[:, :, 0] + + if input_metadata.use_cuda_graph: + graph_batch_size = input_tokens.shape[0] + model_executable = self.graph_runners[graph_batch_size] + else: + model_executable = self.model + + infer_text = sampling_metadata.seq_groups[0][1].infer_text + temperture = sampling_metadata.seq_groups[0][1].temperature + if not infer_text: + temperture = torch.tensor(temperture).to(input_tokens.device) + logits_processors, logits_warpers = sampling_metadata.seq_groups[0][1].logits_processors + # print(logits_processors, logits_warpers) + min_new_token = sampling_metadata.seq_groups[0][1].min_new_token + eos_token = sampling_metadata.seq_groups[0][1].eos_token + start_idx = sampling_metadata.seq_groups[0][1].start_idx + if input_tokens.shape[-2] == 1: + if infer_text: + input_emb: torch.Tensor = self.post_model.emb_text(input_tokens[:, :, 0]) + else: + code_emb = [ + self.post_model.emb_code[i](input_tokens[:, :, i]) + for i in range(self.post_model.num_vq) + ] + input_emb = torch.stack(code_emb, 3).sum(3) + else: + input_emb = self.post_model(input_tokens, text_mask) + # print(input_emb.shape) + hidden_states = model_executable( + input_emb=input_emb, + positions=input_positions, + kv_caches=kv_caches, + input_metadata=input_metadata, + ) + # print(hidden_states.shape) + # print(input_tokens) + idx_next, logprob, finish = self.sampler.sample( + inputs_ids=input_tokens if input_tokens_history.shape[-2] == 0 else input_tokens_history, + hidden_states=hidden_states, + infer_text=infer_text, + temperature=temperture, + logits_processors=logits_processors, + logits_warpers=logits_warpers, + min_new_token=min_new_token, + now_length=1, + eos_token=eos_token, + start_idx=start_idx + ) + # print(logprob.shape, idx_next.shape) + if len(logprob.shape) == 2: + logprob = logprob[:,None,:] + logprob = torch.gather(logprob, -1, idx_next.transpose(-1, -2))[:, :, 0] + # print("测试",idx_next.shape, logprob.shape) + # Sample the next token. + # output = self.model.sample( + # hidden_states=hidden_states, + # sampling_metadata=sampling_metadata, + # ) + results = [] + for i in range(idx_next.shape[0]): + idx_next_i = idx_next[i, 0, :].cpu().tolist() + logprob_i = logprob[i].cpu().tolist() + result = SequenceGroupOutput( + samples = [SequenceOutput( + parent_seq_id=seq_groups[i], + logprobs={tuple(idx_next_i):logprob_i}, + output_token=tuple(idx_next_i), + hidden_states=hidden_states[i].cpu(), + finished=finish[i].item(), + ),], + prompt_logprobs = None + ) + results.append(result) + # print(results) + # print(idx_next, idx_next.shape, logprob.shape) + return results + + @torch.inference_mode() + def profile_run(self) -> None: + # Enable top-k sampling to reflect the accurate memory usage. + vocab_size = self.model_config.get_vocab_size() + sampling_params = SamplingParams(top_p=0.99, top_k=vocab_size - 1, infer_text=True) + max_num_batched_tokens = self.scheduler_config.max_num_batched_tokens + max_num_seqs = self.scheduler_config.max_num_seqs + + # Profile memory usage with max_num_sequences sequences and the total + # number of tokens equal to max_num_batched_tokens. 
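+        # The dummy batch below spreads max_num_batched_tokens as evenly as
+        # possible over max_num_seqs sequences; the first
+        # (max_num_batched_tokens % max_num_seqs) sequences get one extra token.
+        # For example, with the (assumed) defaults max_num_batched_tokens=8192
+        # and max_num_seqs=256, every dummy sequence is 8192 // 256 = 32 tokens
+        # long, which upper-bounds activation memory during profiling.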
+ seqs: List[SequenceGroupMetadata] = [] + for group_id in range(max_num_seqs): + seq_len = (max_num_batched_tokens // max_num_seqs + + (group_id < max_num_batched_tokens % max_num_seqs)) + seq_data = SequenceData([0] * seq_len) + seq = SequenceGroupMetadata( + request_id=str(group_id), + is_prompt=True, + seq_data={group_id: seq_data}, + sampling_params=sampling_params, + block_tables=None, + ) + seqs.append(seq) + + # Run the model with the dummy inputs. + num_layers = self.model_config.get_num_layers(self.parallel_config) + kv_caches = [(None, None)] * num_layers + self.execute_model(seqs, kv_caches) + torch.cuda.synchronize() + return + + @torch.inference_mode() + def capture_model(self, kv_caches: List[KVCache]) -> None: + assert not self.model_config.enforce_eager + logger.info("Capturing the model for CUDA graphs. This may lead to " + "unexpected consequences if the model is not static. To " + "run the model in eager mode, set 'enforce_eager=True' or " + "use '--enforce-eager' in the CLI.") + logger.info("CUDA graphs can take additional 1~3 GiB memory per GPU. " + "If you are running out of memory, consider decreasing " + "`gpu_memory_utilization` or enforcing eager mode.") + start_time = time.perf_counter() + + # Prepare dummy inputs. These will be reused for all batch sizes. + max_batch_size = max(_BATCH_SIZES_TO_CAPTURE) + input_emb = torch.zeros(max_batch_size, 1, self.model_config.get_hidden_size(), dtype=next(self.model.parameters()).dtype).cuda() + input_positions = torch.zeros(max_batch_size, 1, + dtype=torch.long).cuda() + slot_mapping = torch.empty(max_batch_size, 1, dtype=torch.long).cuda() + slot_mapping.fill_(_PAD_SLOT_ID) + context_lens = torch.ones(max_batch_size, dtype=torch.int32).cuda() + block_tables = torch.from_numpy(self.graph_block_tables).cuda() + + # NOTE: Capturing the largest batch size first may help reduce the + # memory usage of CUDA graph. + for batch_size in reversed(_BATCH_SIZES_TO_CAPTURE): + # Create dummy input_metadata. + input_metadata = InputMetadata( + is_prompt=False, + slot_mapping=slot_mapping[:batch_size], + max_context_len=self.max_context_len_to_capture, + context_lens=context_lens[:batch_size], + block_tables=block_tables[:batch_size], + use_cuda_graph=True, + ) + + graph_runner = CUDAGraphRunner(self.model) + graph_runner.capture( + input_emb[:batch_size], + input_positions[:batch_size], + kv_caches, + input_metadata, + memory_pool=self.graph_memory_pool, + ) + self.graph_memory_pool = graph_runner.graph.pool() + self.graph_runners[batch_size] = graph_runner + + end_time = time.perf_counter() + elapsed_time = end_time - start_time + # This usually takes < 10 seconds. + logger.info(f"Graph capturing finished in {elapsed_time:.0f} secs.") + + +class CUDAGraphRunner: + + def __init__(self, model: nn.Module): + self.model = model + self.graph = None + self.input_buffers: Dict[str, torch.Tensor] = {} + self.output_buffers: Dict[str, torch.Tensor] = {} + + def capture( + self, + input_emb: torch.Tensor, + positions: torch.Tensor, + kv_caches: List[KVCache], + input_metadata: InputMetadata, + memory_pool, + ) -> None: + assert self.graph is None + # Run the model once without capturing the graph. + # This is to make sure that the captured graph does not include the + # kernel launches for initial benchmarking (e.g., Triton autotune). + self.model( + input_emb, + positions, + kv_caches, + input_metadata, + ) + torch.cuda.synchronize() + + # Capture the graph. 
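+        # torch.cuda.CUDAGraph records the kernel launches issued inside the
+        # capture context; replaying the graph later re-runs those kernels on
+        # whatever data currently sits in the captured input buffers, which is
+        # why forward() below only copies new inputs into the saved buffers and
+        # calls graph.replay(). Minimal stand-alone sketch of the pattern
+        # (static_in/static_out are illustrative names):
+        #
+        #   static_in = torch.zeros(8, 16, device="cuda")
+        #   static_out = static_in * 2.0              # warm-up before capture
+        #   g = torch.cuda.CUDAGraph()
+        #   with torch.cuda.graph(g):
+        #       static_out = static_in * 2.0
+        #   static_in.copy_(torch.randn(8, 16, device="cuda"))
+        #   g.replay()                                # static_out is refreshed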
+ self.graph = torch.cuda.CUDAGraph() + with torch.cuda.graph(self.graph, pool=memory_pool): + hidden_states = self.model( + input_emb, + positions, + kv_caches, + input_metadata, + ) + torch.cuda.synchronize() + + # Save the input and output buffers. + self.input_buffers = { + "input_emb": input_emb, + "positions": positions, + "kv_caches": kv_caches, + "slot_mapping": input_metadata.slot_mapping, + "context_lens": input_metadata.context_lens, + "block_tables": input_metadata.block_tables, + } + self.output_buffers = {"hidden_states": hidden_states} + return + + def forward( + self, + input_emb: torch.Tensor, + positions: torch.Tensor, + kv_caches: List[Tuple[torch.Tensor, torch.Tensor]], + input_metadata: InputMetadata, + ) -> torch.Tensor: + # KV caches are fixed tensors, so we don't need to copy them. + del kv_caches + + # Copy the input tensors to the input buffers. + self.input_buffers["input_emb"].copy_(input_emb, non_blocking=True) + self.input_buffers["positions"].copy_(positions, non_blocking=True) + self.input_buffers["slot_mapping"].copy_(input_metadata.slot_mapping, + non_blocking=True) + self.input_buffers["context_lens"].copy_(input_metadata.context_lens, + non_blocking=True) + self.input_buffers["block_tables"].copy_(input_metadata.block_tables, + non_blocking=True) + + # Run the graph. + self.graph.replay() + + # Return the output tensor. + return self.output_buffers["hidden_states"] + + def __call__(self, *args, **kwargs): + return self.forward(*args, **kwargs) + + +def _pad_to_max(x: List[int], max_len: int, pad: int) -> List[int]: + assert len(x) <= max_len + return x + [pad] * (max_len - len(x)) + + +def _make_tensor_with_pad( + x: List[List[int]], + max_len: int, + pad: int, + dtype: torch.dtype, + device: Union[str, torch.device] = "cuda", + pin_memory: bool = False, +) -> torch.Tensor: + padded_x = [_pad_to_max(x_i, max_len, pad) for x_i in x] + return torch.tensor(padded_x, + dtype=dtype, + device=device, + pin_memory=pin_memory and str(device) == "cpu") + + +def _get_graph_batch_size(batch_size: int) -> int: + if batch_size <= 2: + return batch_size + elif batch_size <= 4: + return 4 + else: + return (batch_size + 7) // 8 * 8 + + +def _async_h2d(data: list, dtype, pin_memory): + t = torch.tensor(data, dtype=dtype, pin_memory=pin_memory) + return t.to(device="cuda", non_blocking=True) diff --git a/ChatTTS/vllm_engine/output.py b/ChatTTS/vllm_engine/output.py new file mode 100644 index 000000000..c08edde70 --- /dev/null +++ b/ChatTTS/vllm_engine/output.py @@ -0,0 +1,127 @@ +from typing import List, Optional +import torch + +from ChatTTS.vllm_engine.sequence import (PromptLogprobs, SampleLogprobs, SequenceGroup, + SequenceStatus) + + +class CompletionOutput: + """The output data of one completion output of a request. + + Args: + index: The index of the output in the request. + text: The generated output text. + token_ids: The token IDs of the generated output text. + cumulative_logprob: The cumulative log probability of the generated + output text. + logprobs: The log probabilities of the top probability words at each + position if the logprobs are requested. + finish_reason: The reason why the sequence is finished. 
+ """ + + def __init__( + self, + index: int, + text: str, + token_ids: List[int], + cumulative_logprob: float, + logprobs: Optional[SampleLogprobs], + finish_reason: Optional[str] = None, + hidden_states: Optional[torch.Tensor] = None, + ) -> None: + self.index = index + self.text = text + self.token_ids = token_ids + self.cumulative_logprob = cumulative_logprob + self.logprobs = logprobs + self.finish_reason = finish_reason + self.hidden_states = hidden_states + + def finished(self) -> bool: + return self.finish_reason is not None + + def __repr__(self) -> str: + return (f"CompletionOutput(index={self.index}, " + f"text={self.text!r}, " + f"token_ids={self.token_ids}, " + f"cumulative_logprob={self.cumulative_logprob}, " + f"logprobs={self.logprobs}, " + f"finish_reason={self.finish_reason}, " + f"hidden_states={self.hidden_states.shape if self.hidden_states is not None else None})") + + +class RequestOutput: + """The output data of a request to the LLM. + + Args: + request_id: The unique ID of the request. + prompt: The prompt string of the request. + prompt_token_ids: The token IDs of the prompt. + prompt_logprobs: The log probabilities to return per prompt token. + outputs: The output sequences of the request. + finished: Whether the whole request is finished. + """ + + def __init__( + self, + request_id: str, + prompt: str, + prompt_token_ids: List[int], + prompt_logprobs: Optional[PromptLogprobs], + outputs: List[CompletionOutput], + finished: bool, + ) -> None: + self.request_id = request_id + self.prompt = prompt + self.prompt_token_ids = prompt_token_ids + self.prompt_logprobs = prompt_logprobs + self.outputs = outputs + self.finished = finished + + @classmethod + def from_seq_group(cls, seq_group: SequenceGroup) -> "RequestOutput": + # Get the top-n sequences. + n = seq_group.sampling_params.n + seqs = seq_group.get_seqs() + if seq_group.sampling_params.use_beam_search: + sorting_key = lambda seq: seq.get_beam_search_score( + seq_group.sampling_params.length_penalty) + else: + sorting_key = lambda seq: seq.get_cumulative_logprob() + sorted_seqs = sorted(seqs, key=sorting_key, reverse=True) + top_n_seqs = sorted_seqs[:n] + + # Create the outputs. + outputs: List[CompletionOutput] = [] + for seq in top_n_seqs: + logprobs = seq.output_logprobs + if seq_group.sampling_params.logprobs is None: + # NOTE: We need to take care of this case because the sequence + # always has the logprobs of the sampled tokens even if the + # logprobs are not requested. + logprobs = None + finshed_reason = SequenceStatus.get_finished_reason(seq.status) + output = CompletionOutput(seqs.index(seq), seq.output_text, + seq.get_output_token_ids(), + seq.get_cumulative_logprob(), logprobs, + finshed_reason, + seq.data.hidden_states + ) + outputs.append(output) + + # Every sequence in the sequence group should have the same prompt. 
+ prompt = seq_group.prompt + prompt_token_ids = seq_group.prompt_token_ids + prompt_logprobs = seq_group.prompt_logprobs + finished = seq_group.is_finished() + return cls(seq_group.request_id, prompt, prompt_token_ids, + prompt_logprobs, outputs, finished) + + def __repr__(self) -> str: + return (f"RequestOutput(request_id={self.request_id}, " + f"prompt={self.prompt!r}, " + f"prompt_token_ids={self.prompt_token_ids}, " + f"prompt_logprobs={self.prompt_logprobs}, " + f"outputs={self.outputs}, " + f"finished={self.finished})" + ) diff --git a/ChatTTS/vllm_engine/post_model.py b/ChatTTS/vllm_engine/post_model.py new file mode 100644 index 000000000..c38853b3a --- /dev/null +++ b/ChatTTS/vllm_engine/post_model.py @@ -0,0 +1,201 @@ +import os, platform + +os.environ["TOKENIZERS_PARALLELISM"] = "false" +""" +https://stackoverflow.com/questions/62691279/how-to-disable-tokenizers-parallelism-true-false-warning +""" + +import logging + +import torch +import torch.nn as nn +from torch.functional import F +from torch.nn.utils.parametrizations import weight_norm +from typing import List, Callable +class Post_model(nn.Module): + def __init__( + self, + hidden_size: int, + num_audio_tokens: int, + num_text_tokens: int, + num_vq=4 + ): + super().__init__() + + self.num_vq = num_vq + self.num_audio_tokens = num_audio_tokens + + self.model_dim = hidden_size + self.emb_code = nn.ModuleList( + [ + nn.Embedding( + num_audio_tokens, + self.model_dim + ) + for _ in range(num_vq) + ], + ) + self.emb_text = nn.Embedding( + num_text_tokens, self.model_dim + ) + + self.head_text = weight_norm( + nn.Linear( + self.model_dim, + num_text_tokens, + bias=False + ), + name="weight", + ) + self.head_code = nn.ModuleList( + [ + weight_norm( + nn.Linear( + self.model_dim, + num_audio_tokens, + bias=False + ), + name="weight", + ) + for _ in range(self.num_vq) + ], + ) + + + def forward(self, input_ids: torch.Tensor, text_mask: torch.Tensor) -> torch.Tensor: + """ + get_emb + """ + device = next(self.parameters()).device + emb_text: torch.Tensor = self.emb_text( + input_ids[text_mask].narrow(1, 0, 1).squeeze_(1).to(device) + ) + + text_mask_inv = text_mask.logical_not().to(device) + masked_input_ids: torch.Tensor = input_ids[text_mask_inv].to(device) + + emb_code = [ + self.emb_code[i](masked_input_ids[:, i]) for i in range(self.num_vq) + ] + emb_code = torch.stack(emb_code, 2).sum(2) + + emb = torch.zeros( + (input_ids.shape[:-1]) + (emb_text.shape[-1],), + device=emb_text.device, + dtype=emb_text.dtype, + ) + emb[text_mask] = emb_text + emb[text_mask_inv] = emb_code.to(emb.dtype) + + del emb_text, emb_code, text_mask_inv + + return emb + +class Sampler: + def __init__(self, + post_model: Post_model, + num_audio_tokens: int, + num_vq: int + ): + self.post_model = post_model + self.device = next(self.post_model.parameters()).device + self.num_audio_tokens = num_audio_tokens + self.num_vq = num_vq + + def sample(self, + inputs_ids: torch.Tensor, + hidden_states: torch.Tensor, + infer_text: bool = False, + temperature: torch.Tensor = 1.0, + logits_processors: List[Callable] = [lambda logits_token, logits: logits,], + logits_warpers: List[Callable] = [lambda logits_token, logits: logits,], + min_new_token: int = 0, + now_length: int = 0, + eos_token: int = 0, + start_idx: int = 0, + ): + # print(inputs_ids.shape) + B = hidden_states.shape[0] + + end_idx = torch.zeros( + inputs_ids.shape[0], device=inputs_ids.device, dtype=torch.long + ) + finish = torch.zeros(inputs_ids.shape[0], device=inputs_ids.device).bool() + if not 
infer_text: + temperature = ( + temperature.unsqueeze(0) + .expand(inputs_ids.shape[0], -1) + .contiguous() + .view(-1, 1) + ) + + if infer_text: + logits: torch.Tensor = self.post_model.head_text(hidden_states) + else: + # logits = torch.stack([self.head_code[i](hidden_states) for i in range(self.num_vq)], 3) + logits = torch.empty( + hidden_states.size(0), + hidden_states.size(1), + self.num_audio_tokens, + self.num_vq, + dtype=torch.float, + device=self.device, + ) + for num_vq_iter in range(self.num_vq): + x: torch.Tensor = self.post_model.head_code[num_vq_iter](hidden_states) + logits[..., num_vq_iter] = x + del x + + del hidden_states + + # logits = logits[:, -1].float() + logits = logits.narrow(1, -1, 1).squeeze_(1).float() + + if not infer_text: + # logits = rearrange(logits, "b c n -> (b n) c") + logits = logits.permute(0, 2, 1) + logits = logits.reshape(-1, logits.size(2)) + # logits_token = rearrange(inputs_ids[:, start_idx:], "b c n -> (b n) c") + inputs_ids_sliced = inputs_ids[:, start_idx:].permute(0, 2, 1) + logits_token = inputs_ids_sliced.reshape( + inputs_ids_sliced.size(0) * inputs_ids_sliced.size(1), + -1, + ).to(self.device) + else: + logits_token = inputs_ids[:, start_idx:, 0].to(self.device) + + logits /= temperature + + for logitsProcessors in logits_processors: + logits = logitsProcessors(logits_token, logits) + + for logitsWarpers in logits_warpers: + logits = logitsWarpers(logits_token, logits) + + del logits_token + + if now_length < min_new_token: + logits[:, eos_token] = -torch.inf + + scores = F.softmax(logits, dim=-1) + idx_next = torch.multinomial(scores, num_samples=1).to(finish.device) + if not infer_text: + scores = scores.reshape(B, -1, scores.shape[-1]) + if not infer_text: + # idx_next = rearrange(idx_next, "(b n) 1 -> b n", n=self.num_vq) + idx_next = idx_next.view(-1, self.num_vq) + finish_or = idx_next.eq(eos_token).any(1) + finish.logical_or_(finish_or) + del finish_or + else: + finish_or = idx_next.eq(eos_token).any(1) + finish.logical_or_(finish_or) + del finish_or + + del inputs_ids + + not_finished = finish.logical_not().to(end_idx.device) + + end_idx.add_(not_finished.int()) + idx_next = idx_next[:, None, :] + return idx_next, torch.log(scores), finish, \ No newline at end of file diff --git a/ChatTTS/vllm_engine/sampling_params.py b/ChatTTS/vllm_engine/sampling_params.py new file mode 100644 index 000000000..be3f9bf7f --- /dev/null +++ b/ChatTTS/vllm_engine/sampling_params.py @@ -0,0 +1,273 @@ +"""Sampling parameters for text generation.""" +from enum import IntEnum +from functools import cached_property +from typing import Callable, List, Optional, Union + +import torch + +_SAMPLING_EPS = 1e-5 + + +class SamplingType(IntEnum): + GREEDY = 0 + RANDOM = 1 + BEAM = 2 + + +LogitsProcessor = Callable[[List[int], torch.Tensor], torch.Tensor] +"""LogitsProcessor is a function that takes a list of previously generated +tokens and a tensor of the logits for the next token, and returns a modified +tensor of logits to sample from.""" + + +class SamplingParams: + """Sampling parameters for text generation. + + Overall, we follow the sampling parameters from the OpenAI text completion + API (https://platform.openai.com/docs/api-reference/completions/create). + In addition, we support beam search, which is not supported by OpenAI. + + Args: + n: Number of output sequences to return for the given prompt. + best_of: Number of output sequences that are generated from the prompt. + From these `best_of` sequences, the top `n` sequences are returned. 
+            `best_of` must be greater than or equal to `n`. This is treated as
+            the beam width when `use_beam_search` is True. By default, `best_of`
+            is set to `n`.
+        presence_penalty: Float that penalizes new tokens based on whether they
+            appear in the generated text so far. Values > 0 encourage the model
+            to use new tokens, while values < 0 encourage the model to repeat
+            tokens.
+        frequency_penalty: Float that penalizes new tokens based on their
+            frequency in the generated text so far. Values > 0 encourage the
+            model to use new tokens, while values < 0 encourage the model to
+            repeat tokens.
+        repetition_penalty: Float that penalizes new tokens based on whether
+            they appear in the prompt and the generated text so far. Values > 1
+            encourage the model to use new tokens, while values < 1 encourage
+            the model to repeat tokens.
+        temperature: Float that controls the randomness of the sampling. Lower
+            values make the model more deterministic, while higher values make
+            the model more random. Zero means greedy sampling.
+        top_p: Float that controls the cumulative probability of the top tokens
+            to consider. Must be in (0, 1]. Set to 1 to consider all tokens.
+        top_k: Integer that controls the number of top tokens to consider. Set
+            to -1 to consider all tokens.
+        min_p: Float that represents the minimum probability for a token to be
+            considered, relative to the probability of the most likely token.
+            Must be in [0, 1]. Set to 0 to disable this.
+        use_beam_search: Whether to use beam search instead of sampling.
+        length_penalty: Float that penalizes sequences based on their length.
+            Used in beam search.
+        early_stopping: Controls the stopping condition for beam search. It
+            accepts the following values: `True`, where the generation stops as
+            soon as there are `best_of` complete candidates; `False`, where a
+            heuristic is applied and the generation stops when it is very
+            unlikely to find better candidates; `"never"`, where the beam search
+            procedure only stops when there cannot be better candidates
+            (canonical beam search algorithm).
+        stop: List of strings that stop the generation when they are generated.
+            The returned output will not contain the stop strings.
+        stop_token_ids: List of tokens that stop the generation when they are
+            generated. The returned output will contain the stop tokens unless
+            the stop tokens are special tokens.
+        include_stop_str_in_output: Whether to include the stop strings in
+            output text. Defaults to False.
+        ignore_eos: Whether to ignore the EOS token and continue generating
+            tokens after the EOS token is generated.
+        max_tokens: Maximum number of tokens to generate per output sequence.
+        logprobs: Number of log probabilities to return per output token.
+            Note that the implementation follows the OpenAI API: The return
+            result includes the log probabilities on the `logprobs` most likely
+            tokens, as well as the chosen tokens. The API will always return the
+            log probability of the sampled token, so there may be up to
+            `logprobs+1` elements in the response.
+        prompt_logprobs: Number of log probabilities to return per prompt token.
+        skip_special_tokens: Whether to skip special tokens in the output.
+        spaces_between_special_tokens: Whether to add spaces between special
+            tokens in the output. Defaults to True.
+        logits_processors: List of functions that modify logits based on
+            previously generated tokens.
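+        min_new_token: Minimum number of new tokens that must be generated
+            before the EOS token may be sampled (the sampler masks the EOS
+            logit until then). ChatTTS-specific.
+        max_new_token: Maximum number of new tokens to generate for this
+            request. ChatTTS-specific.
+        infer_text: Whether the request decodes text tokens (True) or audio
+            code tokens over the `num_vq` codebooks (False). ChatTTS-specific.
+        eos_token: Token id treated as end-of-sequence for this request.
+            ChatTTS-specific.
+        spk_emb: Optional speaker-embedding string attached to the request.
+            ChatTTS-specific.
+        start_idx: Index of the first position whose tokens are fed to the
+            logits processors/warpers. ChatTTS-specific.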
+ """ + + def __init__( + self, + n: int = 1, + best_of: Optional[int] = None, + presence_penalty: float = 0.0, + frequency_penalty: float = 0.0, + repetition_penalty: float = 1.0, + temperature: float = 1.0, + top_p: float = 1.0, + top_k: int = -1, + min_p: float = 0.0, + use_beam_search: bool = False, + length_penalty: float = 1.0, + early_stopping: Union[bool, str] = False, + stop: Optional[Union[str, List[str]]] = None, + stop_token_ids: Optional[List[int]] = None, + include_stop_str_in_output: bool = False, + ignore_eos: bool = False, + max_tokens: int = 16, + logprobs: Optional[int] = None, + prompt_logprobs: Optional[int] = None, + skip_special_tokens: bool = True, + spaces_between_special_tokens: bool = True, + logits_processors: Optional[List[LogitsProcessor]] = ([lambda logits_token, logits: logits,],[lambda logits_token, logits: logits,]), + min_new_token: int = 0, + max_new_token: int = 8192, + infer_text: bool = False, + eos_token: int = 0, + spk_emb:str = None, + start_idx:int = 0, + ) -> None: + self.n = n + self.best_of = best_of if best_of is not None else n + self.presence_penalty = presence_penalty + self.frequency_penalty = frequency_penalty + self.repetition_penalty = repetition_penalty + self.temperature = temperature + self.top_p = top_p + self.top_k = top_k + self.min_p = min_p + self.use_beam_search = use_beam_search + self.length_penalty = length_penalty + self.early_stopping = early_stopping + self.min_new_token = min_new_token + self.max_new_token = max_new_token + self.infer_text = infer_text + self.eos_token = eos_token + self.spk_emb = spk_emb + self.start_idx = start_idx + if stop is None: + self.stop = [] + elif isinstance(stop, str): + self.stop = [stop] + else: + self.stop = list(stop) + if stop_token_ids is None: + self.stop_token_ids = [] + else: + self.stop_token_ids = list(stop_token_ids) + self.ignore_eos = ignore_eos + self.max_tokens = max_tokens + self.logprobs = logprobs + self.prompt_logprobs = prompt_logprobs + self.skip_special_tokens = skip_special_tokens + self.spaces_between_special_tokens = spaces_between_special_tokens + self.logits_processors = logits_processors + self.include_stop_str_in_output = include_stop_str_in_output + self._verify_args() + if self.use_beam_search: + self._verify_beam_search() + else: + self._verify_non_beam_search() + # if self.temperature < _SAMPLING_EPS: + # # Zero temperature means greedy sampling. 
+ # self.top_p = 1.0 + # self.top_k = -1 + # self.min_p = 0.0 + # self._verify_greedy_sampling() + + def _verify_args(self) -> None: + if self.n < 1: + raise ValueError(f"n must be at least 1, got {self.n}.") + if self.best_of < self.n: + raise ValueError(f"best_of must be greater than or equal to n, " + f"got n={self.n} and best_of={self.best_of}.") + if not -2.0 <= self.presence_penalty <= 2.0: + raise ValueError("presence_penalty must be in [-2, 2], got " + f"{self.presence_penalty}.") + if not -2.0 <= self.frequency_penalty <= 2.0: + raise ValueError("frequency_penalty must be in [-2, 2], got " + f"{self.frequency_penalty}.") + if not 0.0 < self.repetition_penalty <= 2.0: + raise ValueError("repetition_penalty must be in (0, 2], got " + f"{self.repetition_penalty}.") + # if self.temperature < 0.0: + # raise ValueError( + # f"temperature must be non-negative, got {self.temperature}.") + if not 0.0 < self.top_p <= 1.0: + raise ValueError(f"top_p must be in (0, 1], got {self.top_p}.") + if self.top_k < -1 or self.top_k == 0: + raise ValueError(f"top_k must be -1 (disable), or at least 1, " + f"got {self.top_k}.") + if not 0.0 <= self.min_p <= 1.0: + raise ValueError("min_p must be in [0, 1], got " + f"{self.min_p}.") + if self.max_tokens < 1: + raise ValueError( + f"max_tokens must be at least 1, got {self.max_tokens}.") + if self.logprobs is not None and self.logprobs < 0: + raise ValueError( + f"logprobs must be non-negative, got {self.logprobs}.") + if self.prompt_logprobs is not None and self.prompt_logprobs < 0: + raise ValueError(f"prompt_logprobs must be non-negative, got " + f"{self.prompt_logprobs}.") + + def _verify_beam_search(self) -> None: + if self.best_of == 1: + raise ValueError("best_of must be greater than 1 when using beam " + f"search. Got {self.best_of}.") + if self.temperature > _SAMPLING_EPS: + raise ValueError("temperature must be 0 when using beam search.") + if self.top_p < 1.0 - _SAMPLING_EPS: + raise ValueError("top_p must be 1 when using beam search.") + if self.top_k != -1: + raise ValueError("top_k must be -1 when using beam search.") + if self.early_stopping not in [True, False, "never"]: + raise ValueError( + f"early_stopping must be True, False, or 'never', " + f"got {self.early_stopping}.") + + def _verify_non_beam_search(self) -> None: + if self.early_stopping is not False: + raise ValueError("early_stopping is not effective and must be " + "False when not using beam search.") + if (self.length_penalty < 1.0 - _SAMPLING_EPS + or self.length_penalty > 1.0 + _SAMPLING_EPS): + raise ValueError( + "length_penalty is not effective and must be the " + "default value of 1.0 when not using beam search.") + + def _verify_greedy_sampling(self) -> None: + if self.best_of > 1: + raise ValueError("best_of must be 1 when using greedy sampling." 
+ f"Got {self.best_of}.") + + @cached_property + def sampling_type(self) -> SamplingType: + if self.use_beam_search: + return SamplingType.BEAM + # if self.temperature < _SAMPLING_EPS: + # return SamplingType.GREEDY + return SamplingType.RANDOM + + def __repr__(self) -> str: + return ( + f"SamplingParams(n={self.n}, " + f"best_of={self.best_of}, " + f"presence_penalty={self.presence_penalty}, " + f"frequency_penalty={self.frequency_penalty}, " + f"repetition_penalty={self.repetition_penalty}, " + f"temperature={self.temperature}, " + f"top_p={self.top_p}, " + f"top_k={self.top_k}, " + f"min_p={self.min_p}, " + f"use_beam_search={self.use_beam_search}, " + f"length_penalty={self.length_penalty}, " + f"early_stopping={self.early_stopping}, " + f"stop={self.stop}, " + f"stop_token_ids={self.stop_token_ids}, " + f"include_stop_str_in_output={self.include_stop_str_in_output}, " + f"ignore_eos={self.ignore_eos}, " + f"max_tokens={self.max_tokens}, " + f"logprobs={self.logprobs}, " + f"prompt_logprobs={self.prompt_logprobs}, " + f"skip_special_tokens={self.skip_special_tokens}, " + "spaces_between_special_tokens=" + f"{self.spaces_between_special_tokens}), " + f"max_new_token={self.max_new_token}), " + f"min_new_token={self.min_new_token}), " + f"infer_text={self.infer_text})" + ) diff --git a/ChatTTS/vllm_engine/scheduler.py b/ChatTTS/vllm_engine/scheduler.py new file mode 100644 index 000000000..27f5752a7 --- /dev/null +++ b/ChatTTS/vllm_engine/scheduler.py @@ -0,0 +1,413 @@ +import enum +import time +from typing import Dict, Iterable, List, Optional, Tuple, Union + +from vllm.config import CacheConfig, SchedulerConfig +from ChatTTS.vllm_engine.block_manager import AllocStatus, BlockSpaceManager +from vllm.core.policy import PolicyFactory +from vllm.logger import init_logger +from ChatTTS.vllm_engine.sequence import (Sequence, SequenceData, SequenceGroup, + SequenceGroupMetadata, SequenceStatus) + +logger = init_logger(__name__) + + +class PreemptionMode(enum.Enum): + """Preemption modes. + + 1. Swapping: Swap out the blocks of the preempted sequences to CPU memory + and swap them back in when the sequences are resumed. + 2. Recomputation: Discard the blocks of the preempted sequences and + recompute them when the sequences are resumed, treating the sequences as + new prompts. + """ + SWAP = enum.auto() + RECOMPUTE = enum.auto() + + +class SchedulerOutputs: + + def __init__( + self, + scheduled_seq_groups: List[SequenceGroup], + prompt_run: bool, + num_batched_tokens: int, + blocks_to_swap_in: Dict[int, int], + blocks_to_swap_out: Dict[int, int], + blocks_to_copy: Dict[int, List[int]], + ignored_seq_groups: List[SequenceGroup], + ) -> None: + self.scheduled_seq_groups = scheduled_seq_groups + self.prompt_run = prompt_run + self.num_batched_tokens = num_batched_tokens + self.blocks_to_swap_in = blocks_to_swap_in + self.blocks_to_swap_out = blocks_to_swap_out + self.blocks_to_copy = blocks_to_copy + # Swap in and swap out should never happen at the same time. + assert not (blocks_to_swap_in and blocks_to_swap_out) + self.ignored_seq_groups = ignored_seq_groups + + def is_empty(self) -> bool: + # NOTE: We do not consider the ignored sequence groups. 
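+        # Ignored groups were rejected before any blocks were allocated (e.g.
+        # the prompt exceeded the length limit), so they imply no swap/copy
+        # work and no model execution; an output containing only ignored
+        # groups therefore still counts as "empty" for this step.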
+ return (not self.scheduled_seq_groups and not self.blocks_to_swap_in + and not self.blocks_to_swap_out and not self.blocks_to_copy) + + +class Scheduler: + + def __init__( + self, + scheduler_config: SchedulerConfig, + cache_config: CacheConfig, + ) -> None: + self.scheduler_config = scheduler_config + self.cache_config = cache_config + + self.prompt_limit = min(self.scheduler_config.max_model_len, + self.scheduler_config.max_num_batched_tokens) + + # Instantiate the scheduling policy. + self.policy = PolicyFactory.get_policy(policy_name="fcfs") + # Create the block space manager. + self.block_manager = BlockSpaceManager( + block_size=self.cache_config.block_size, + num_gpu_blocks=self.cache_config.num_gpu_blocks, + num_cpu_blocks=self.cache_config.num_cpu_blocks, + sliding_window=self.cache_config.sliding_window) + + # TODO(zhuohan): Use deque instead of list for better performance. + # Sequence groups in the WAITING state. + self.waiting: List[SequenceGroup] = [] + # Sequence groups in the RUNNING state. + self.running: List[SequenceGroup] = [] + # Sequence groups in the SWAPPED state. + self.swapped: List[SequenceGroup] = [] + + def add_seq_group(self, seq_group: SequenceGroup) -> None: + # Add sequence groups to the waiting queue. + self.waiting.append(seq_group) + + def abort_seq_group(self, request_id: Union[str, Iterable[str]]) -> None: + if isinstance(request_id, str): + request_id = (request_id, ) + request_ids = set(request_id) + for state_queue in [self.waiting, self.running, self.swapped]: + # We need to reverse the list as we are removing elements + # from it as we iterate over it. If we don't do it, + # indices will get messed up and we will skip over elements. + for seq_group in reversed(state_queue): + if seq_group.request_id in request_ids: + # Remove the sequence group from the state queue. + state_queue.remove(seq_group) + for seq in seq_group.get_seqs(): + if seq.is_finished(): + continue + seq.status = SequenceStatus.FINISHED_ABORTED + self.free_seq(seq) + request_ids.remove(seq_group.request_id) + if not request_ids: + return + + def has_unfinished_seqs(self) -> bool: + return self.waiting or self.running or self.swapped + + def get_num_unfinished_seq_groups(self) -> int: + return len(self.waiting) + len(self.running) + len(self.swapped) + + def _schedule(self) -> SchedulerOutputs: + # Blocks that need to be swaped or copied before model execution. + blocks_to_swap_in: Dict[int, int] = {} + blocks_to_swap_out: Dict[int, int] = {} + blocks_to_copy: Dict[int, List[int]] = {} + + # Fix the current time. + now = time.monotonic() + + # Join waiting sequences if possible. + if not self.swapped: + ignored_seq_groups: List[SequenceGroup] = [] + scheduled: List[SequenceGroup] = [] + # The total number of sequences on the fly, including the + # requests in the generation phase. + num_curr_seqs = sum(seq_group.get_max_num_running_seqs() + for seq_group in self.running) + seq_lens: List[int] = [] + + # Optimization: We do not sort the waiting queue since the preempted + # sequence groups are added to the front and the new sequence groups + # are added to the back. 
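+            # The admission loop below costs a prompt batch as if every
+            # sequence were padded to the longest prompt:
+            #   num_batched_tokens = len(seq_lens) * max(seq_lens)
+            #   num_paddings       = num_batched_tokens - sum(seq_lens)
+            # e.g. prompts of 100, 60 and 40 tokens batch as 3 * 100 = 300
+            # padded tokens with 300 - 200 = 100 padding tokens, and admission
+            # stops once either figure exceeds its configured limit.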
+ while self.waiting: + seq_group = self.waiting[0] + + waiting_seqs = seq_group.get_seqs( + status=SequenceStatus.WAITING) + assert len(waiting_seqs) == 1, ( + "Waiting sequence group should have only one prompt " + "sequence.") + num_prompt_tokens = waiting_seqs[0].get_len() + if num_prompt_tokens > self.prompt_limit: + logger.warning( + f"Input prompt ({num_prompt_tokens} tokens) is too long" + f" and exceeds limit of {self.prompt_limit}") + for seq in waiting_seqs: + seq.status = SequenceStatus.FINISHED_IGNORED + ignored_seq_groups.append(seq_group) + self.waiting.pop(0) + continue + + # If the sequence group cannot be allocated, stop. + can_allocate = self.block_manager.can_allocate(seq_group) + if can_allocate == AllocStatus.LATER: + break + elif can_allocate == AllocStatus.NEVER: + logger.warning( + f"Input prompt ({num_prompt_tokens} tokens) is too long" + f" and exceeds the capacity of block_manager") + for seq in waiting_seqs: + seq.status = SequenceStatus.FINISHED_IGNORED + ignored_seq_groups.append(seq_group) + self.waiting.pop(0) + continue + + # If the number of batched tokens exceeds the limit, stop. + new_seq_lens = seq_lens + [num_prompt_tokens] + num_batched_tokens = len(new_seq_lens) * max(new_seq_lens) + if (num_batched_tokens > + self.scheduler_config.max_num_batched_tokens): + break + + # The total number of sequences in the RUNNING state should not + # exceed the maximum number of sequences. + num_new_seqs = seq_group.get_max_num_running_seqs() + if (num_curr_seqs + num_new_seqs > + self.scheduler_config.max_num_seqs): + break + + num_paddings = num_batched_tokens - sum(new_seq_lens) + if num_paddings > self.scheduler_config.max_paddings: + break + seq_lens = new_seq_lens + + seq_group = self.waiting.pop(0) + self._allocate(seq_group) + self.running.append(seq_group) + num_curr_seqs += num_new_seqs + scheduled.append(seq_group) + + if scheduled or ignored_seq_groups: + scheduler_outputs = SchedulerOutputs( + scheduled_seq_groups=scheduled, + prompt_run=True, + num_batched_tokens=len(seq_lens) * + max(seq_lens) if seq_lens else 0, + blocks_to_swap_in=blocks_to_swap_in, + blocks_to_swap_out=blocks_to_swap_out, + blocks_to_copy=blocks_to_copy, + ignored_seq_groups=ignored_seq_groups, + ) + return scheduler_outputs + + # NOTE(woosuk): Preemption happens only when there is no available slot + # to keep all the sequence groups in the RUNNING state. + # In this case, the policy is responsible for deciding which sequence + # groups to preempt. + self.running = self.policy.sort_by_priority(now, self.running) + + # Reserve new token slots for the running sequence groups. + running: List[SequenceGroup] = [] + preempted: List[SequenceGroup] = [] + while self.running: + seq_group = self.running.pop(0) + while not self.block_manager.can_append_slot(seq_group): + if self.running: + # Preempt the lowest-priority sequence groups. + victim_seq_group = self.running.pop(-1) + self._preempt(victim_seq_group, blocks_to_swap_out) + preempted.append(victim_seq_group) + else: + # No other sequence groups can be preempted. + # Preempt the current sequence group. + self._preempt(seq_group, blocks_to_swap_out) + preempted.append(seq_group) + break + else: + # Append new slots to the sequence group. + self._append_slot(seq_group, blocks_to_copy) + running.append(seq_group) + self.running = running + + # Swap in the sequence groups in the SWAPPED state if possible. 
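+        # Swap-in is only attempted when nothing was preempted in this step
+        # (see the `if not preempted` guard below); bringing groups back in
+        # while others are being pushed out would just thrash the KV cache.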
+ self.swapped = self.policy.sort_by_priority(now, self.swapped) + if not preempted: + num_curr_seqs = sum(seq_group.get_max_num_running_seqs() + for seq_group in self.running) + + while self.swapped: + seq_group = self.swapped[0] + # If the sequence group cannot be swapped in, stop. + if not self.block_manager.can_swap_in(seq_group): + break + + # The total number of sequences in the RUNNING state should not + # exceed the maximum number of sequences. + num_new_seqs = seq_group.get_max_num_running_seqs() + if (num_curr_seqs + num_new_seqs > + self.scheduler_config.max_num_seqs): + break + + seq_group = self.swapped.pop(0) + self._swap_in(seq_group, blocks_to_swap_in) + self._append_slot(seq_group, blocks_to_copy) + num_curr_seqs += num_new_seqs + self.running.append(seq_group) + + # Each sequence in the generation phase only takes one token slot. + # Therefore, the number of batched tokens is equal to the number of + # sequences in the RUNNING state. + num_batched_tokens = sum( + seq_group.num_seqs(status=SequenceStatus.RUNNING) + for seq_group in self.running) + + scheduler_outputs = SchedulerOutputs( + scheduled_seq_groups=self.running, + prompt_run=False, + num_batched_tokens=num_batched_tokens, + blocks_to_swap_in=blocks_to_swap_in, + blocks_to_swap_out=blocks_to_swap_out, + blocks_to_copy=blocks_to_copy, + ignored_seq_groups=[], + ) + return scheduler_outputs + + def schedule(self) -> Tuple[List[SequenceGroupMetadata], SchedulerOutputs]: + # Schedule sequence groups. + # This function call changes the internal states of the scheduler + # such as self.running, self.swapped, and self.waiting. + scheduler_outputs = self._schedule() + + # Create input data structures. + seq_group_metadata_list: List[SequenceGroupMetadata] = [] + for seq_group in scheduler_outputs.scheduled_seq_groups: + seq_data: Dict[int, SequenceData] = {} + block_tables: Dict[int, List[int]] = {} + for seq in seq_group.get_seqs(status=SequenceStatus.RUNNING): + seq_id = seq.seq_id + seq_data[seq_id] = seq.data + block_tables[seq_id] = self.block_manager.get_block_table(seq) + + seq_group_metadata = SequenceGroupMetadata( + request_id=seq_group.request_id, + is_prompt=scheduler_outputs.prompt_run, + seq_data=seq_data, + sampling_params=seq_group.sampling_params, + block_tables=block_tables, + ) + seq_group_metadata_list.append(seq_group_metadata) + return seq_group_metadata_list, scheduler_outputs + + def fork_seq(self, parent_seq: Sequence, child_seq: Sequence) -> None: + self.block_manager.fork(parent_seq, child_seq) + + def free_seq(self, seq: Sequence) -> None: + self.block_manager.free(seq) + + def free_finished_seq_groups(self) -> None: + self.running = [ + seq_group for seq_group in self.running + if not seq_group.is_finished() + ] + + def _allocate(self, seq_group: SequenceGroup) -> None: + self.block_manager.allocate(seq_group) + for seq in seq_group.get_seqs(status=SequenceStatus.WAITING): + seq.status = SequenceStatus.RUNNING + + def _append_slot( + self, + seq_group: SequenceGroup, + blocks_to_copy: Dict[int, List[int]], + ) -> None: + for seq in seq_group.get_seqs(status=SequenceStatus.RUNNING): + ret = self.block_manager.append_slot(seq) + if ret is not None: + src_block, dst_block = ret + if src_block in blocks_to_copy: + blocks_to_copy[src_block].append(dst_block) + else: + blocks_to_copy[src_block] = [dst_block] + + def _preempt( + self, + seq_group: SequenceGroup, + blocks_to_swap_out: Dict[int, int], + preemption_mode: Optional[PreemptionMode] = None, + ) -> None: + # If preemption mode is not 
specified, we determine the mode as follows: + # We use recomputation by default since it incurs lower overhead than + # swapping. However, when the sequence group has multiple sequences + # (e.g., beam search), recomputation is not currently supported. In + # such a case, we use swapping instead. + # FIXME(woosuk): This makes our scheduling policy a bit bizarre. + # As swapped sequences are prioritized over waiting sequences, + # sequence groups with multiple sequences are implicitly prioritized + # over sequence groups with a single sequence. + # TODO(woosuk): Support recomputation for sequence groups with multiple + # sequences. This may require a more sophisticated CUDA kernel. + if preemption_mode is None: + if seq_group.get_max_num_running_seqs() == 1: + preemption_mode = PreemptionMode.RECOMPUTE + else: + preemption_mode = PreemptionMode.SWAP + if preemption_mode == PreemptionMode.RECOMPUTE: + self._preempt_by_recompute(seq_group) + elif preemption_mode == PreemptionMode.SWAP: + self._preempt_by_swap(seq_group, blocks_to_swap_out) + else: + raise AssertionError("Invalid preemption mode.") + + def _preempt_by_recompute( + self, + seq_group: SequenceGroup, + ) -> None: + seqs = seq_group.get_seqs(status=SequenceStatus.RUNNING) + assert len(seqs) == 1 + for seq in seqs: + seq.status = SequenceStatus.WAITING + self.block_manager.free(seq) + # NOTE: For FCFS, we insert the preempted sequence group to the front + # of the waiting queue. + self.waiting.insert(0, seq_group) + + def _preempt_by_swap( + self, + seq_group: SequenceGroup, + blocks_to_swap_out: Dict[int, int], + ) -> None: + self._swap_out(seq_group, blocks_to_swap_out) + self.swapped.append(seq_group) + + def _swap_in( + self, + seq_group: SequenceGroup, + blocks_to_swap_in: Dict[int, int], + ) -> None: + mapping = self.block_manager.swap_in(seq_group) + blocks_to_swap_in.update(mapping) + for seq in seq_group.get_seqs(status=SequenceStatus.SWAPPED): + seq.status = SequenceStatus.RUNNING + + def _swap_out( + self, + seq_group: SequenceGroup, + blocks_to_swap_out: Dict[int, int], + ) -> None: + if not self.block_manager.can_swap_out(seq_group): + # FIXME(woosuk): Abort the sequence group instead of aborting the + # entire engine. + raise RuntimeError( + "Aborted due to the lack of CPU swap space. 
Please increase " + "the swap space to avoid this error.") + mapping = self.block_manager.swap_out(seq_group) + blocks_to_swap_out.update(mapping) + for seq in seq_group.get_seqs(status=SequenceStatus.RUNNING): + seq.status = SequenceStatus.SWAPPED diff --git a/ChatTTS/vllm_engine/sequence.py b/ChatTTS/vllm_engine/sequence.py new file mode 100644 index 000000000..a417fe7f5 --- /dev/null +++ b/ChatTTS/vllm_engine/sequence.py @@ -0,0 +1,436 @@ +"""Sequence and its related classes.""" +import copy +import enum +from typing import Dict, List, Optional, Union +import torch +from vllm.block import LogicalTokenBlock +from ChatTTS.vllm_engine.sampling_params import SamplingParams + +PromptLogprobs = List[Optional[Dict[int, float]]] +SampleLogprobs = List[Dict[int, float]] + + +class SequenceStatus(enum.Enum): + """Status of a sequence.""" + WAITING = enum.auto() + RUNNING = enum.auto() + SWAPPED = enum.auto() + FINISHED_STOPPED = enum.auto() + FINISHED_LENGTH_CAPPED = enum.auto() + FINISHED_ABORTED = enum.auto() + FINISHED_IGNORED = enum.auto() + + @staticmethod + def is_finished(status: "SequenceStatus") -> bool: + return status in [ + SequenceStatus.FINISHED_STOPPED, + SequenceStatus.FINISHED_LENGTH_CAPPED, + SequenceStatus.FINISHED_ABORTED, + SequenceStatus.FINISHED_IGNORED, + ] + + @staticmethod + def get_finished_reason(status: "SequenceStatus") -> Union[str, None]: + if status == SequenceStatus.FINISHED_STOPPED: + finish_reason = "stop" + elif status == SequenceStatus.FINISHED_LENGTH_CAPPED: + finish_reason = "length" + elif status == SequenceStatus.FINISHED_ABORTED: + finish_reason = "abort" + elif status == SequenceStatus.FINISHED_IGNORED: + # The ignored sequences are the sequences whose prompt lengths + # are longer than the model's length cap. Therefore, the stop + # reason should also be "length" as in OpenAI API. + finish_reason = "length" + else: + finish_reason = None + return finish_reason + + +class SequenceData: + """Data associated with a sequence. + + + Args: + prompt_token_ids: The token IDs of the prompt. + + Attributes: + prompt_token_ids: The token IDs of the prompt. + output_token_ids: The token IDs of the output. + cumulative_logprob: The cumulative log probability of the output. 
+ """ + + def __init__( + self, + prompt_token_ids: List[int], + ) -> None: + self.prompt_token_ids = prompt_token_ids + self.output_token_ids: List[int] = [] + self.cumulative_logprob = 0.0 + self.hidden_states: Optional[torch.Tensor] = None + self.finished = False + + def append_token_id(self, token_id: int, logprob: float) -> None: + if isinstance(self.cumulative_logprob, float): + self.cumulative_logprob = [0.0, ] * len(logprob) + self.output_token_ids.append(token_id) + for i in range(len(self.cumulative_logprob)): + self.cumulative_logprob[i] += logprob[i] + + def append_hidden_states(self, hidden_states: torch.Tensor) -> None: + if self.hidden_states is None: + self.hidden_states = hidden_states + else: + self.hidden_states = torch.cat([self.hidden_states, hidden_states], dim=0) + + def get_len(self) -> int: + return len(self.output_token_ids) + len(self.prompt_token_ids) + + def get_prompt_len(self) -> int: + return len(self.prompt_token_ids) + + def get_output_len(self) -> int: + return len(self.output_token_ids) + + def get_token_ids(self) -> List[int]: + return self.prompt_token_ids + self.output_token_ids + + def get_last_token_id(self) -> int: + if not self.output_token_ids: + return self.prompt_token_ids[-1] + return self.output_token_ids[-1] + + def __repr__(self) -> str: + return (f"SequenceData(" + f"prompt_token_ids={self.prompt_token_ids}, " + f"output_token_ids={self.output_token_ids}, " + f"cumulative_logprob={self.cumulative_logprob}), " + f"hidden_states={self.hidden_states.shape if self.hidden_states is not None else None}, " + f"finished={self.finished})") + + +class Sequence: + """Stores the data, status, and block information of a sequence. + + Args: + seq_id: The ID of the sequence. + prompt: The prompt of the sequence. + prompt_token_ids: The token IDs of the prompt. + block_size: The block size of the sequence. Should be the same as the + block size used by the block manager and cache engine. + """ + + def __init__( + self, + seq_id: int, + prompt: str, + prompt_token_ids: List[int], + block_size: int, + ) -> None: + self.seq_id = seq_id + self.prompt = prompt + self.block_size = block_size + + self.data = SequenceData(prompt_token_ids) + self.output_logprobs: SampleLogprobs = [] + self.output_text = "" + + self.logical_token_blocks: List[LogicalTokenBlock] = [] + # Initialize the logical token blocks with the prompt token ids. 
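+        # _append_tokens_to_blocks (defined below) fills fixed-size logical
+        # blocks left to right, opening a new block whenever the last one is
+        # full: e.g. a 40-token prompt with block_size=16 occupies three
+        # logical blocks holding 16, 16 and 8 tokens.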
+ self._append_tokens_to_blocks(prompt_token_ids) + self.status = SequenceStatus.WAITING + + # Used for incremental detokenization + self.prefix_offset = 0 + self.read_offset = 0 + # Input + output tokens + self.tokens: Optional[List[str]] = None + + def _append_logical_block(self) -> None: + block = LogicalTokenBlock( + block_number=len(self.logical_token_blocks), + block_size=self.block_size, + ) + self.logical_token_blocks.append(block) + + def _append_tokens_to_blocks(self, token_ids: List[int]) -> None: + cursor = 0 + while cursor < len(token_ids): + if not self.logical_token_blocks: + self._append_logical_block() + + last_block = self.logical_token_blocks[-1] + if last_block.is_full(): + self._append_logical_block() + last_block = self.logical_token_blocks[-1] + + num_empty_slots = last_block.get_num_empty_slots() + last_block.append_tokens(token_ids[cursor:cursor + + num_empty_slots]) + cursor += num_empty_slots + + def append_token_id( + self, + token_id: int, + logprobs: Dict[int, float], + hidden_states: Optional[torch.Tensor] = None, + finished: bool = False + ) -> None: + assert token_id in logprobs + self._append_tokens_to_blocks([token_id]) + self.output_logprobs.append(logprobs) + self.data.append_token_id(token_id, logprobs[token_id]) + self.data.append_hidden_states(hidden_states) + self.data.finished = finished + + def get_len(self) -> int: + return self.data.get_len() + + def get_prompt_len(self) -> int: + return self.data.get_prompt_len() + + def get_output_len(self) -> int: + return self.data.get_output_len() + + def get_token_ids(self) -> List[int]: + return self.data.get_token_ids() + + def get_last_token_id(self) -> int: + return self.data.get_last_token_id() + + def get_output_token_ids(self) -> List[int]: + return self.data.output_token_ids + + def get_cumulative_logprob(self) -> float: + return self.data.cumulative_logprob + + def get_beam_search_score(self, + length_penalty: float = 0.0, + seq_len: Optional[int] = None, + eos_token_id: Optional[int] = None) -> float: + """Calculate the beam search score with length penalty. + + Adapted from + + https://github.com/huggingface/transformers/blob/ccb92be23def445f2afdea94c31286f84b89eb5b/src/transformers/generation/beam_search.py#L938 + """ + if seq_len is None: + seq_len = self.get_len() + # NOTE: HF implementation does not count the EOS token + # towards the length, we align with that here for testing. + if (eos_token_id is not None + and self.get_last_token_id() == eos_token_id): + seq_len -= 1 + return self.get_cumulative_logprob() / (seq_len**length_penalty) + + def is_finished(self) -> bool: + return SequenceStatus.is_finished(self.status) + + def fork(self, new_seq_id: int) -> "Sequence": + new_seq = copy.deepcopy(self) + new_seq.seq_id = new_seq_id + return new_seq + + def __repr__(self) -> str: + return (f"Sequence(seq_id={self.seq_id}, " + f"status={self.status.name}, " + f"num_blocks={len(self.logical_token_blocks)})") + + +class SequenceGroup: + """A group of sequences that are generated from the same prompt. + + Args: + request_id: The ID of the request. + seqs: The list of sequences. + sampling_params: The sampling parameters used to generate the outputs. + arrival_time: The arrival time of the request. 
+ """ + + def __init__( + self, + request_id: str, + seqs: List[Sequence], + sampling_params: SamplingParams, + arrival_time: float, + ) -> None: + self.request_id = request_id + self.seqs_dict = {seq.seq_id: seq for seq in seqs} + self.sampling_params = sampling_params + self.arrival_time = arrival_time + self.prompt_logprobs: Optional[PromptLogprobs] = None + + @property + def prompt(self) -> str: + # All sequences in the group should have the same prompt. + # We use the prompt of an arbitrary sequence. + return next(iter(self.seqs_dict.values())).prompt + + @property + def prompt_token_ids(self) -> List[int]: + # All sequences in the group should have the same prompt. + # We use the prompt of an arbitrary sequence. + return next(iter(self.seqs_dict.values())).data.prompt_token_ids + + def get_max_num_running_seqs(self) -> int: + """The maximum number of sequences running in parallel in the remaining + lifetime of the request.""" + if self.sampling_params.use_beam_search: + # For beam search, maximally there will always be `best_of` beam + # candidates running in the future. + return self.sampling_params.best_of + else: + if self.sampling_params.best_of > self.num_seqs(): + # At prompt stage, the sequence group is not yet filled up + # and only have one sequence running. However, in the + # generation stage, we will have `best_of` sequences running. + return self.sampling_params.best_of + # At sampling stages, return the number of actual sequences + # that are not finished yet. + return self.num_unfinished_seqs() + + def get_seqs( + self, + status: Optional[SequenceStatus] = None, + ) -> List[Sequence]: + if status is None: + return list(self.seqs_dict.values()) + else: + return [ + seq for seq in self.seqs_dict.values() if seq.status == status + ] + + def get_unfinished_seqs(self) -> List[Sequence]: + return [ + seq for seq in self.seqs_dict.values() if not seq.is_finished() + ] + + def get_finished_seqs(self) -> List[Sequence]: + return [seq for seq in self.seqs_dict.values() if seq.is_finished()] + + def num_seqs(self, status: Optional[SequenceStatus] = None) -> int: + return len(self.get_seqs(status)) + + def num_unfinished_seqs(self) -> int: + return len(self.get_unfinished_seqs()) + + def num_finished_seqs(self) -> int: + return len(self.get_finished_seqs()) + + def find(self, seq_id: int) -> Sequence: + if seq_id not in self.seqs_dict: + raise ValueError(f"Sequence {seq_id} not found.") + return self.seqs_dict[seq_id] + + def add(self, seq: Sequence) -> None: + if seq.seq_id in self.seqs_dict: + raise ValueError(f"Sequence {seq.seq_id} already exists.") + self.seqs_dict[seq.seq_id] = seq + + def remove(self, seq_id: int) -> None: + if seq_id not in self.seqs_dict: + raise ValueError(f"Sequence {seq_id} not found.") + del self.seqs_dict[seq_id] + + def is_finished(self) -> bool: + return all(seq.is_finished() for seq in self.get_seqs()) + + def __repr__(self) -> str: + return (f"SequenceGroup(request_id={self.request_id}, " + f"sampling_params={self.sampling_params}, " + f"num_seqs={len(self.seqs_dict)})") + + +class SequenceGroupMetadata: + """Metadata for a sequence group. Used to create `InputMetadata`. + + + Args: + request_id: The ID of the request. + is_prompt: Whether the request is at prompt stage. + seq_data: The sequence data. (Seq id -> sequence data) + sampling_params: The sampling parameters used to generate the outputs. + block_tables: The block tables. 
(Seq id -> list of physical block + numbers) + """ + + def __init__( + self, + request_id: str, + is_prompt: bool, + seq_data: Dict[int, SequenceData], + sampling_params: SamplingParams, + block_tables: Dict[int, List[int]], + ) -> None: + self.request_id = request_id + self.is_prompt = is_prompt + self.seq_data = seq_data + self.sampling_params = sampling_params + self.block_tables = block_tables + + +class SequenceOutput: + """The model output associated with a sequence. + + Args: + parent_seq_id: The ID of the parent sequence (for forking in beam + search). + output_token: The output token ID. + logprobs: The logprobs of the output token. + (Token id -> logP(x_i+1 | x_0, ..., x_i)) + """ + + def __init__( + self, + parent_seq_id: int, + output_token: int, + logprobs: Dict[int, float], + hidden_states: Optional[torch.Tensor] = None, + finished: bool = False + ) -> None: + self.parent_seq_id = parent_seq_id + self.output_token = output_token + self.logprobs = logprobs + self.finished = finished + self.hidden_states = hidden_states + def __repr__(self) -> str: + return (f"SequenceOutput(parent_seq_id={self.parent_seq_id}, " + f"output_token={self.output_token}, " + f"logprobs={self.logprobs})," + f"finished={self.finished})," + f"hidden_states={self.hidden_states.shape if self.hidden_states is not None else None}" + ) + + def __eq__(self, other: object) -> bool: + if not isinstance(other, SequenceOutput): + raise NotImplementedError() + return (self.parent_seq_id == other.parent_seq_id + and self.output_token == other.output_token + and self.logprobs == other.logprobs) + + +class SequenceGroupOutput: + """The model output associated with a sequence group.""" + + def __init__( + self, + samples: List[SequenceOutput], + prompt_logprobs: Optional[PromptLogprobs], + ) -> None: + self.samples = samples + self.prompt_logprobs = prompt_logprobs + + def __repr__(self) -> str: + return (f"SequenceGroupOutput(samples={self.samples}, " + f"prompt_logprobs={self.prompt_logprobs})") + + def __eq__(self, other: object) -> bool: + if not isinstance(other, SequenceGroupOutput): + raise NotImplementedError() + return (self.samples == other.samples + and self.prompt_logprobs == other.prompt_logprobs) + + +# For each sequence group, we generate a list of SequenceOutput object, +# each of which contains one possible candidate for the next token. +SamplerOutput = List[SequenceGroupOutput] diff --git a/ChatTTS/vllm_engine/worker.py b/ChatTTS/vllm_engine/worker.py new file mode 100644 index 000000000..84e5c85d8 --- /dev/null +++ b/ChatTTS/vllm_engine/worker.py @@ -0,0 +1,237 @@ +"""A GPU worker class.""" +import os +from typing import Dict, List, Optional, Tuple + +import torch +import torch.distributed + +from vllm.config import (CacheConfig, ModelConfig, ParallelConfig, + SchedulerConfig) +from vllm.model_executor import set_random_seed +from vllm.model_executor.parallel_utils.communication_op import ( + broadcast_object_list) +from vllm.model_executor.parallel_utils.parallel_state import ( + initialize_model_parallel) +from vllm.sequence import SamplerOutput, SequenceGroupMetadata +from vllm.worker.cache_engine import CacheEngine +from ChatTTS.vllm_engine.model_runner import ModelRunner + + +class Worker: + """A worker class that executes (a partition of) the model on a GPU. + + Each worker is associated with a single GPU. The worker is responsible for + maintaining the KV cache and executing the model on the GPU. In case of + distributed inference, each worker is assigned a partition of the model. 
+ """ + + def __init__( + self, + model_config: ModelConfig, + parallel_config: ParallelConfig, + scheduler_config: SchedulerConfig, + local_rank: int, + rank: int, + distributed_init_method: str, + post_model_path:str, + is_driver_worker: bool = False, + ) -> None: + self.model_config = model_config + self.parallel_config = parallel_config + self.scheduler_config = scheduler_config + self.local_rank = local_rank + self.rank = rank + self.distributed_init_method = distributed_init_method + self.is_driver_worker = is_driver_worker + self.post_model_path = post_model_path + + if self.is_driver_worker: + assert self.rank == 0, "The driver worker must have rank 0." + + self.model_runner = ModelRunner(model_config, parallel_config, + scheduler_config, is_driver_worker, post_model_path) + # Uninitialized cache engine. Will be initialized by + # self.init_cache_engine(). + self.cache_config = None + self.cache_engine = None + self.cache_events = None + self.gpu_cache = None + + def init_model(self) -> None: + # torch.distributed.all_reduce does not free the input tensor until + # the synchronization point. This causes the memory usage to grow + # as the number of all_reduce calls increases. This env var disables + # this behavior. + # Related issue: + # https://discuss.pytorch.org/t/cuda-allocation-lifetime-for-inputs-to-distributed-all-reduce/191573 + os.environ["TORCH_NCCL_AVOID_RECORD_STREAMS"] = "1" + + # This env var set by Ray causes exceptions with graph building. + os.environ.pop("NCCL_ASYNC_ERROR_HANDLING", None) + self.device = torch.device(f"cuda:{self.local_rank}") + torch.cuda.set_device(self.device) + + _check_if_gpu_supports_dtype(self.model_config.dtype) + + # Initialize the distributed environment. + _init_distributed_environment(self.parallel_config, self.rank, + self.distributed_init_method) + + # Initialize the model. + set_random_seed(self.model_config.seed) + + def load_model(self): + self.model_runner.load_model() + + @torch.inference_mode() + def profile_num_available_blocks( + self, + block_size: int, + gpu_memory_utilization: float, + cpu_swap_space: int, + ) -> Tuple[int, int]: + # Profile the memory usage of the model and get the maximum number of + # cache blocks that can be allocated with the remaining free memory. + torch.cuda.empty_cache() + + # Execute a forward pass with dummy inputs to profile the memory usage + # of the model. + self.model_runner.profile_run() + + # Calculate the number of blocks that can be allocated with the + # profiled peak memory. 
+ torch.cuda.synchronize() + free_gpu_memory, total_gpu_memory = torch.cuda.mem_get_info() + peak_memory = total_gpu_memory - free_gpu_memory + + cache_block_size = CacheEngine.get_cache_block_size( + block_size, self.model_config, self.parallel_config) + num_gpu_blocks = int( + (total_gpu_memory * gpu_memory_utilization - peak_memory) // + cache_block_size) + num_cpu_blocks = int(cpu_swap_space // cache_block_size) + num_gpu_blocks = max(num_gpu_blocks, 0) + num_cpu_blocks = max(num_cpu_blocks, 0) + torch.cuda.empty_cache() + return num_gpu_blocks, num_cpu_blocks + + def init_cache_engine(self, cache_config: CacheConfig) -> None: + self.cache_config = cache_config + self.cache_engine = CacheEngine(self.cache_config, self.model_config, + self.parallel_config) + self.cache_events = self.cache_engine.events + self.gpu_cache = self.cache_engine.gpu_cache + self.model_runner.set_block_size(self.cache_engine.block_size) + + def warm_up_model(self) -> None: + if not self.model_config.enforce_eager: + self.model_runner.capture_model(self.gpu_cache) + # Reset the seed to ensure that the random state is not affected by + # the model initialization and profiling. + set_random_seed(self.model_config.seed) + + def cache_swap( + self, + blocks_to_swap_in: Dict[int, int], + blocks_to_swap_out: Dict[int, int], + blocks_to_copy: Dict[int, List[int]], + ) -> None: + # Issue cache operations. + issued_cache_op = False + if blocks_to_swap_in: + self.cache_engine.swap_in(blocks_to_swap_in) + issued_cache_op = True + if blocks_to_swap_out: + self.cache_engine.swap_out(blocks_to_swap_out) + issued_cache_op = True + if blocks_to_copy: + self.cache_engine.copy(blocks_to_copy) + issued_cache_op = True + + cache_events = self.cache_events if issued_cache_op else None + + # Wait for cache operations to finish. + # TODO(woosuk): Profile swapping overhead and optimize if needed. + if cache_events is not None: + for event in cache_events: + event.wait() + + @torch.inference_mode() + def execute_model( + self, + seq_group_metadata_list: Optional[List[SequenceGroupMetadata]] = None, + blocks_to_swap_in: Optional[Dict[int, int]] = None, + blocks_to_swap_out: Optional[Dict[int, int]] = None, + blocks_to_copy: Optional[Dict[int, List[int]]] = None, + ) -> Optional[SamplerOutput]: + if self.is_driver_worker: + assert seq_group_metadata_list is not None + num_seq_groups = len(seq_group_metadata_list) + assert blocks_to_swap_in is not None + assert blocks_to_swap_out is not None + assert blocks_to_copy is not None + block_swapping_info = [ + blocks_to_swap_in, blocks_to_swap_out, blocks_to_copy + ] + broadcast_object_list([num_seq_groups] + block_swapping_info, + src=0) + else: + # num_seq_groups, blocks_to_swap_in, blocks_to_swap_out, + # blocks_to_copy (4 elements) + recv_data = [None] * 4 + broadcast_object_list(recv_data, src=0) + num_seq_groups = recv_data[0] + block_swapping_info = recv_data[1:] + + self.cache_swap(*block_swapping_info) + + # If there is no input, we don't need to execute the model. 
+ if num_seq_groups == 0: + return {} + + output = self.model_runner.execute_model(seq_group_metadata_list, + self.gpu_cache) + return output + + +def _init_distributed_environment( + parallel_config: ParallelConfig, + rank: int, + distributed_init_method: Optional[str] = None, +) -> None: + """Initialize the distributed environment.""" + if torch.distributed.is_initialized(): + torch_world_size = torch.distributed.get_world_size() + if torch_world_size != parallel_config.world_size: + raise RuntimeError( + "torch.distributed is already initialized but the torch world " + "size does not match parallel_config.world_size " + f"({torch_world_size} vs. {parallel_config.world_size}).") + elif not distributed_init_method: + raise ValueError( + "distributed_init_method must be set if torch.distributed " + "is not already initialized") + else: + torch.distributed.init_process_group( + backend="nccl", + world_size=parallel_config.world_size, + rank=rank, + init_method=distributed_init_method, + ) + + # A small all_reduce for warmup. + torch.distributed.all_reduce(torch.zeros(1).cuda()) + initialize_model_parallel(parallel_config.tensor_parallel_size, + parallel_config.pipeline_parallel_size) + + +def _check_if_gpu_supports_dtype(torch_dtype: torch.dtype): + # Check if the GPU supports the dtype. + if torch_dtype == torch.bfloat16: + compute_capability = torch.cuda.get_device_capability() + if compute_capability[0] < 8: + gpu_name = torch.cuda.get_device_name() + raise ValueError( + "Bfloat16 is only supported on GPUs with compute capability " + f"of at least 8.0. Your {gpu_name} GPU has compute capability " + f"{compute_capability[0]}.{compute_capability[1]}.") diff --git a/test.py b/test.py new file mode 100644 index 000000000..07e9fae5d --- /dev/null +++ b/test.py @@ -0,0 +1,27 @@ +import ChatTTS as ChatTTS +import torch +import torchaudio +import soundfile as sf +chat = ChatTTS.Chat() +chat.load(compile=False) # Set to True for better performance +rand_spk = chat.sample_random_speaker() +print(rand_spk) # save it for later timbre recovery + +params_infer_code = ChatTTS.Chat.InferCodeParams( + spk_emb = rand_spk, # add sampled speaker + temperature = .3, # using custom temperature + top_P = 0.7, # top P decode + top_K = 20, # top K decode +) +params_refine_text = ChatTTS.Chat.RefineTextParams( + prompt='[oral_2][laugh_0][break_6]', +) +texts = ["PUT YOUR 1st TEXT HERE", "PUT YOUR 2nd TEXT HERE"] + +wavs = chat.infer(texts, + params_refine_text=params_refine_text, + params_infer_code=params_infer_code, + ) + +# torchaudio.save("output1.wav", torch.from_numpy(wavs[0]), 24000) +sf.write("output1.wav", wavs[1], 24000) \ No newline at end of file From 9de0a5385735883687b6c8044bb706708b14e89c Mon Sep 17 00:00:00 2001 From: ylzz1997 Date: Sun, 21 Jul 2024 04:06:52 +0800 Subject: [PATCH 02/27] sync --- ChatTTS/core.py | 112 +++++++++++++++++++------------------ ChatTTS/model/tokenizer.py | 2 +- 2 files changed, 58 insertions(+), 56 deletions(-) diff --git a/ChatTTS/core.py b/ChatTTS/core.py index ca16f0af1..22865472d 100644 --- a/ChatTTS/core.py +++ b/ChatTTS/core.py @@ -9,6 +9,7 @@ import pathlib from ChatTTS.vllm_engine.post_model import Post_model from safetensors.torch import save_file, safe_open +from omegaconf import OmegaConf import numpy as np import torch from vocos import Vocos @@ -29,6 +30,7 @@ from .utils import logger as utils_logger from .norm import Normalizer +import pybase16384 as b14 class Chat: @@ -36,8 +38,6 @@ def __init__(self, logger=logging.getLogger(__name__)): self.logger = 
logger utils_logger.set_logger(logger) - self.config = Config() - self.normalizer = Normalizer( os.path.join(os.path.dirname(__file__), "res", "homophones_map.json"), logger, @@ -144,7 +144,9 @@ def load( use_flash_attn=use_flash_attn, **{ k: os.path.join(download_path, v) - for k, v in asdict(self.config.path).items() + for k, v in OmegaConf.load( + os.path.join(download_path, "config", "path.yaml") + ).items() }, ) @@ -240,9 +242,13 @@ def interrupt(self): @torch.no_grad() def _load( self, + vocos_config_path: str = None, vocos_ckpt_path: str = None, + dvae_config_path: str = None, dvae_ckpt_path: str = None, + gpt_config_path: str = None, gpt_ckpt_path: str = None, + decoder_config_path: str = None, decoder_ckpt_path: str = None, tokenizer_path: str = None, device: Optional[torch.device] = None, @@ -256,22 +262,36 @@ def _load( self.device = device self.compile = compile - feature_extractor = instantiate_class( - args=(), init=asdict(self.config.vocos.feature_extractor) - ) - backbone = instantiate_class(args=(), init=asdict(self.config.vocos.backbone)) - head = instantiate_class(args=(), init=asdict(self.config.vocos.head)) - vocos = ( - Vocos(feature_extractor=feature_extractor, backbone=backbone, head=head) - .to( - # vocos on mps will crash, use cpu fallback - "cpu" - if "mps" in str(device) - else device + if vocos_config_path: + vocos = ( + Vocos.from_hparams(vocos_config_path) + .to( + # vocos on mps will crash, use cpu fallback + "cpu" + if "mps" in str(device) + else device + ) + .eval() + ) + assert vocos_ckpt_path, "vocos_ckpt_path should not be None" + vocos.load_state_dict( + torch.load(vocos_ckpt_path, weights_only=True, mmap=True) + ) + self.vocos = vocos + self.logger.log(logging.INFO, "vocos loaded.") + + if dvae_config_path: + cfg = OmegaConf.load(dvae_config_path) + dvae = DVAE(**cfg, coef=coef).to(device).eval() + coef = str(dvae) + assert dvae_ckpt_path, "dvae_ckpt_path should not be None" + dvae.load_state_dict( + torch.load(dvae_ckpt_path, weights_only=True, mmap=True) ) self.dvae = dvae self.logger.log(logging.INFO, "dvae loaded.") + if gpt_config_path: cfg = OmegaConf.load(gpt_config_path) self.num_vq = 4 @@ -320,6 +340,17 @@ def _load( post_model_path="asset/vllm_model/post_model.safetensors", ) + if dvae_config_path: + cfg = OmegaConf.load(dvae_config_path) + dvae = DVAE(**cfg, coef=coef).to(device).eval() + coef = str(dvae) + assert dvae_ckpt_path, "dvae_ckpt_path should not be None" + dvae.load_state_dict( + torch.load(dvae_ckpt_path, weights_only=True, mmap=True) + ) + self.dvae = dvae + self.logger.log(logging.INFO, "dvae loaded.") + if decoder_config_path: cfg = OmegaConf.load(decoder_config_path) decoder = DVAE(**cfg, coef=coef).to(device).eval() @@ -328,16 +359,8 @@ def _load( decoder.load_state_dict( torch.load(decoder_ckpt_path, weights_only=True, mmap=True) ) - .to(device) - .eval() - ) - coef = str(decoder) - assert decoder_ckpt_path, "decoder_ckpt_path should not be None" - decoder.load_state_dict( - torch.load(decoder_ckpt_path, weights_only=True, mmap=True) - ) - self.decoder = decoder - self.logger.log(logging.INFO, "decoder loaded.") + self.decoder = decoder + self.logger.log(logging.INFO, "decoder loaded.") if tokenizer_path: self.tokenizer = Tokenizer(tokenizer_path, device) @@ -473,33 +496,6 @@ def _decode_spk_emb(spk_emb: str) -> np.ndarray: dtype=np.float16, ).copy() - @torch.no_grad() - def _apply_spk_emb( - self, - emb: torch.Tensor, - spk_emb: str, - input_ids: torch.Tensor, - ): - n = ( - F.normalize( - torch.from_numpy( - 
self._decode_spk_emb(spk_emb), - ), - p=2.0, - dim=0, - eps=1e-12, - ) - .to(self.gpt.device_gpt) - .unsqueeze_(0) - .expand(emb.size(0), -1) - .unsqueeze_(1) - .expand(emb.shape) - ) - cond = ( - input_ids.narrow(-1, 0, 1).eq(self.tokenizer.spk_emb_ids).expand(emb.shape) - ) - torch.where(cond, n, emb, out=emb) - del cond, n @dataclass(repr=False, eq=False) class GenerationOutputs: ids: List[torch.Tensor] @@ -552,9 +548,12 @@ def _infer_code( text = [f"[Stts][spk_emb]{txt_smp}{i}[Ptts]" for i in text] else: text = [f"[Stts][empty_spk]{txt_smp}{i}[Ptts]" for i in text] - + print(params.spk_smp, txt_smp, text) input_ids, attention_mask, text_mask = self.tokenizer.encode( - text, self.num_vq, self.device + text, + self.num_vq, + prompt_str=params.spk_smp, + device=self.device, ) start_idx = input_ids.shape[-2] @@ -611,8 +610,11 @@ def _refine_text( text = [f"[Sbreak]{i}[Pbreak]{params.prompt}" for i in text] input_ids, attention_mask, text_mask = self.tokenizer.encode( - text, self.num_vq, self.device + text, + self.num_vq, + device=self.device, ) + start_idx = input_ids.shape[-2] # print(start_idx) logits_warpers, logits_processors = gen_logits( diff --git a/ChatTTS/model/tokenizer.py b/ChatTTS/model/tokenizer.py index 0ee4ca706..4110a185b 100644 --- a/ChatTTS/model/tokenizer.py +++ b/ChatTTS/model/tokenizer.py @@ -211,4 +211,4 @@ def _encode_spk_emb(spk_emb: torch.Tensor) -> str: ), ) del arr - return s + return s \ No newline at end of file From 8d6bc30dd77a59329c71aa7c3735eadc4356e8e7 Mon Sep 17 00:00:00 2001 From: ylzz1997 Date: Sun, 21 Jul 2024 04:07:16 +0800 Subject: [PATCH 03/27] remove debug print --- ChatTTS/core.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ChatTTS/core.py b/ChatTTS/core.py index 22865472d..6c449c0b1 100644 --- a/ChatTTS/core.py +++ b/ChatTTS/core.py @@ -548,7 +548,7 @@ def _infer_code( text = [f"[Stts][spk_emb]{txt_smp}{i}[Ptts]" for i in text] else: text = [f"[Stts][empty_spk]{txt_smp}{i}[Ptts]" for i in text] - print(params.spk_smp, txt_smp, text) + input_ids, attention_mask, text_mask = self.tokenizer.encode( text, self.num_vq, From c817ae9815ca48695d9bdadb46f40ac82157d778 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E6=BA=90=E6=96=87=E9=9B=A8?= <41315874+fumiama@users.noreply.github.com> Date: Sun, 21 Jul 2024 23:16:10 +0900 Subject: [PATCH 04/27] chore: move vllm_engine to model/velocity --- ChatTTS/core.py | 6 +++--- ChatTTS/model/tokenizer.py | 2 +- ChatTTS/{vllm_engine => model/velocity}/__init__.py | 0 .../{vllm_engine => model/velocity}/block_manager.py | 2 +- ChatTTS/{vllm_engine => model/velocity}/configs.py | 0 ChatTTS/{vllm_engine => model/velocity}/llama.py | 0 ChatTTS/{vllm_engine => model/velocity}/llm.py | 8 ++++---- .../{vllm_engine => model/velocity}/llm_engine.py | 12 ++++++------ .../{vllm_engine => model/velocity}/model_loader.py | 2 +- .../{vllm_engine => model/velocity}/model_runner.py | 10 +++++----- ChatTTS/{vllm_engine => model/velocity}/output.py | 2 +- .../{vllm_engine => model/velocity}/post_model.py | 0 .../velocity}/sampling_params.py | 0 ChatTTS/{vllm_engine => model/velocity}/scheduler.py | 4 ++-- ChatTTS/{vllm_engine => model/velocity}/sequence.py | 2 +- ChatTTS/{vllm_engine => model/velocity}/worker.py | 2 +- 16 files changed, 26 insertions(+), 26 deletions(-) rename ChatTTS/{vllm_engine => model/velocity}/__init__.py (100%) rename ChatTTS/{vllm_engine => model/velocity}/block_manager.py (99%) rename ChatTTS/{vllm_engine => model/velocity}/configs.py (100%) rename ChatTTS/{vllm_engine => 
model/velocity}/llama.py (100%) rename ChatTTS/{vllm_engine => model/velocity}/llm.py (97%) rename ChatTTS/{vllm_engine => model/velocity}/llm_engine.py (98%) rename ChatTTS/{vllm_engine => model/velocity}/model_loader.py (96%) rename ChatTTS/{vllm_engine => model/velocity}/model_runner.py (98%) rename ChatTTS/{vllm_engine => model/velocity}/output.py (98%) rename ChatTTS/{vllm_engine => model/velocity}/post_model.py (100%) rename ChatTTS/{vllm_engine => model/velocity}/sampling_params.py (100%) rename ChatTTS/{vllm_engine => model/velocity}/scheduler.py (99%) rename ChatTTS/{vllm_engine => model/velocity}/sequence.py (99%) rename ChatTTS/{vllm_engine => model/velocity}/worker.py (99%) diff --git a/ChatTTS/core.py b/ChatTTS/core.py index 6c449c0b1..94659b52a 100644 --- a/ChatTTS/core.py +++ b/ChatTTS/core.py @@ -7,7 +7,7 @@ from pathlib import Path import lzma import pathlib -from ChatTTS.vllm_engine.post_model import Post_model +from ChatTTS.model.velocity.post_model import Post_model from safetensors.torch import save_file, safe_open from omegaconf import OmegaConf import numpy as np @@ -16,8 +16,8 @@ from vocos.pretrained import instantiate_class from huggingface_hub import snapshot_download import pybase16384 as b14 -from ChatTTS.vllm_engine.llm import LLM -from ChatTTS.vllm_engine.sampling_params import SamplingParams +from ChatTTS.model.velocity.llm import LLM +from ChatTTS.model.velocity.sampling_params import SamplingParams import yaml from .model import DVAE, GPT, gen_logits, Tokenizer from .utils import ( diff --git a/ChatTTS/model/tokenizer.py b/ChatTTS/model/tokenizer.py index 4110a185b..0ee4ca706 100644 --- a/ChatTTS/model/tokenizer.py +++ b/ChatTTS/model/tokenizer.py @@ -211,4 +211,4 @@ def _encode_spk_emb(spk_emb: torch.Tensor) -> str: ), ) del arr - return s \ No newline at end of file + return s diff --git a/ChatTTS/vllm_engine/__init__.py b/ChatTTS/model/velocity/__init__.py similarity index 100% rename from ChatTTS/vllm_engine/__init__.py rename to ChatTTS/model/velocity/__init__.py diff --git a/ChatTTS/vllm_engine/block_manager.py b/ChatTTS/model/velocity/block_manager.py similarity index 99% rename from ChatTTS/vllm_engine/block_manager.py rename to ChatTTS/model/velocity/block_manager.py index b95cdc840..199a3a278 100644 --- a/ChatTTS/vllm_engine/block_manager.py +++ b/ChatTTS/model/velocity/block_manager.py @@ -3,7 +3,7 @@ from typing import Dict, List, Optional, Set, Tuple from vllm.block import PhysicalTokenBlock -from ChatTTS.vllm_engine.sequence import Sequence, SequenceGroup, SequenceStatus +from ChatTTS.model.velocity.sequence import Sequence, SequenceGroup, SequenceStatus from vllm.utils import Device # Mapping: logical block number -> physical block. 
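PATCH 04 is a pure relocation: every module keeps its public classes and signatures, and only the import prefix changes from ChatTTS.vllm_engine to ChatTTS.model.velocity. A minimal sketch of driving the relocated wrapper directly, mirroring the calls ChatTTS/core.py makes after the rename, follows; the token counts, eos id, and prompt ids are illustrative assumptions (taken from the EngineArgs defaults in this patch), not values fixed by core.py, and the sketch assumes the asset/vllm_model directory exported by Chat._load() already exists.

# Hedged usage sketch only -- not part of the patch.
from ChatTTS.model.velocity.llm import LLM                      # formerly ChatTTS.vllm_engine.llm
from ChatTTS.model.velocity.sampling_params import SamplingParams

llm = LLM(
    model="asset/vllm_model/gpt",                               # directory written by Chat._load()
    num_audio_tokens=1024,                                      # assumption: EngineArgs default
    num_text_tokens=80,                                         # assumption: EngineArgs default
    post_model_path="asset/vllm_model/post_model.safetensors",
)

params = SamplingParams(
    temperature=0.3,
    max_new_token=2048,                                         # assumption; core.py takes this from its params object
    max_tokens=8192,
    min_new_token=0,
    logits_processors=([], []),                                 # core.py passes (logits_warpers, logits_processors)
    eos_token=1023,                                             # assumption: num_audio_tokens - 1, as in _infer_code
    infer_text=False,
    start_idx=4,                                                # assumption; core.py uses input_ids.shape[-2]
)

prompt_token_ids = [[0, 1, 2, 3]]                               # placeholder ids; core.py encodes real text here
for r in llm.generate(None, params, prompt_token_ids):
    print(r.outputs[0].token_ids)                               # generated audio code ids
    print(r.outputs[0].hidden_states.shape)                     # hidden states core.py feeds to the decoder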
diff --git a/ChatTTS/vllm_engine/configs.py b/ChatTTS/model/velocity/configs.py similarity index 100% rename from ChatTTS/vllm_engine/configs.py rename to ChatTTS/model/velocity/configs.py diff --git a/ChatTTS/vllm_engine/llama.py b/ChatTTS/model/velocity/llama.py similarity index 100% rename from ChatTTS/vllm_engine/llama.py rename to ChatTTS/model/velocity/llama.py diff --git a/ChatTTS/vllm_engine/llm.py b/ChatTTS/model/velocity/llm.py similarity index 97% rename from ChatTTS/vllm_engine/llm.py rename to ChatTTS/model/velocity/llm.py index b2bce9f6b..9668c87cf 100644 --- a/ChatTTS/vllm_engine/llm.py +++ b/ChatTTS/model/velocity/llm.py @@ -3,10 +3,10 @@ from tqdm import tqdm from transformers import PreTrainedTokenizer, PreTrainedTokenizerFast -from ChatTTS.vllm_engine.configs import EngineArgs -from ChatTTS.vllm_engine.llm_engine import LLMEngine -from ChatTTS.vllm_engine.output import RequestOutput -from ChatTTS.vllm_engine.sampling_params import SamplingParams +from ChatTTS.model.velocity.configs import EngineArgs +from ChatTTS.model.velocity.llm_engine import LLMEngine +from ChatTTS.model.velocity.output import RequestOutput +from ChatTTS.model.velocity.sampling_params import SamplingParams from vllm.utils import Counter diff --git a/ChatTTS/vllm_engine/llm_engine.py b/ChatTTS/model/velocity/llm_engine.py similarity index 98% rename from ChatTTS/vllm_engine/llm_engine.py rename to ChatTTS/model/velocity/llm_engine.py index a89bb87cd..4a72c0c3f 100644 --- a/ChatTTS/vllm_engine/llm_engine.py +++ b/ChatTTS/model/velocity/llm_engine.py @@ -7,14 +7,14 @@ from vllm.config import (CacheConfig, ModelConfig, ParallelConfig, SchedulerConfig) -from ChatTTS.vllm_engine.scheduler import Scheduler, SchedulerOutputs -from ChatTTS.vllm_engine.configs import EngineArgs +from ChatTTS.model.velocity.scheduler import Scheduler, SchedulerOutputs +from ChatTTS.model.velocity.configs import EngineArgs from vllm.engine.metrics import record_metrics from vllm.engine.ray_utils import RayWorkerVllm, initialize_cluster, ray from vllm.logger import init_logger -from ChatTTS.vllm_engine.output import RequestOutput -from ChatTTS.vllm_engine.sampling_params import SamplingParams -from ChatTTS.vllm_engine.sequence import (SamplerOutput, Sequence, SequenceGroup, +from ChatTTS.model.velocity.output import RequestOutput +from ChatTTS.model.velocity.sampling_params import SamplingParams +from ChatTTS.model.velocity.sequence import (SamplerOutput, Sequence, SequenceGroup, SequenceGroupOutput, SequenceOutput, SequenceStatus) from vllm.transformers_utils.tokenizer import (detokenize_incrementally, get_tokenizer) @@ -123,7 +123,7 @@ def __init__( def _init_workers(self): # Lazy import the Worker to avoid importing torch.cuda/xformers # before CUDA_VISIBLE_DEVICES is set in the Worker - from ChatTTS.vllm_engine.worker import Worker + from ChatTTS.model.velocity.worker import Worker assert self.parallel_config.world_size == 1, ( "Ray is required if parallel_config.world_size > 1.") diff --git a/ChatTTS/vllm_engine/model_loader.py b/ChatTTS/model/velocity/model_loader.py similarity index 96% rename from ChatTTS/vllm_engine/model_loader.py rename to ChatTTS/model/velocity/model_loader.py index cb189482c..bb4605875 100644 --- a/ChatTTS/vllm_engine/model_loader.py +++ b/ChatTTS/model/velocity/model_loader.py @@ -22,7 +22,7 @@ def _set_default_torch_dtype(dtype: torch.dtype): def _get_model_architecture(config: PretrainedConfig) -> Type[nn.Module]: - model_cls = getattr(importlib.import_module("ChatTTS.vllm_engine.llama"), 
"LlamaModel", None) + model_cls = getattr(importlib.import_module("ChatTTS.model.velocity.llama"), "LlamaModel", None) return model_cls def get_model(model_config: ModelConfig) -> nn.Module: diff --git a/ChatTTS/vllm_engine/model_runner.py b/ChatTTS/model/velocity/model_runner.py similarity index 98% rename from ChatTTS/vllm_engine/model_runner.py rename to ChatTTS/model/velocity/model_runner.py index c3db34cc5..86ed4a730 100644 --- a/ChatTTS/vllm_engine/model_runner.py +++ b/ChatTTS/model/velocity/model_runner.py @@ -5,16 +5,16 @@ import torch import torch.nn as nn -from ChatTTS.vllm_engine.configs import ModelConfig, ParallelConfig, SchedulerConfig +from ChatTTS.model.velocity.configs import ModelConfig, ParallelConfig, SchedulerConfig from vllm.logger import init_logger -from ChatTTS.vllm_engine.model_loader import get_model +from ChatTTS.model.velocity.model_loader import get_model from vllm.model_executor import InputMetadata, SamplingMetadata from vllm.model_executor.parallel_utils.communication_op import ( broadcast, broadcast_object_list) -from ChatTTS.vllm_engine.sampling_params import SamplingParams, SamplingType -from ChatTTS.vllm_engine.sequence import SamplerOutput, SequenceData, SequenceGroupMetadata, SequenceGroupOutput, SequenceOutput +from ChatTTS.model.velocity.sampling_params import SamplingParams, SamplingType +from ChatTTS.model.velocity.sequence import SamplerOutput, SequenceData, SequenceGroupMetadata, SequenceGroupOutput, SequenceOutput from vllm.utils import in_wsl -from ChatTTS.vllm_engine.post_model import Post_model, Sampler +from ChatTTS.model.velocity.post_model import Post_model, Sampler from safetensors.torch import safe_open logger = init_logger(__name__) diff --git a/ChatTTS/vllm_engine/output.py b/ChatTTS/model/velocity/output.py similarity index 98% rename from ChatTTS/vllm_engine/output.py rename to ChatTTS/model/velocity/output.py index c08edde70..ea3c81d80 100644 --- a/ChatTTS/vllm_engine/output.py +++ b/ChatTTS/model/velocity/output.py @@ -1,7 +1,7 @@ from typing import List, Optional import torch -from ChatTTS.vllm_engine.sequence import (PromptLogprobs, SampleLogprobs, SequenceGroup, +from ChatTTS.model.velocity.sequence import (PromptLogprobs, SampleLogprobs, SequenceGroup, SequenceStatus) diff --git a/ChatTTS/vllm_engine/post_model.py b/ChatTTS/model/velocity/post_model.py similarity index 100% rename from ChatTTS/vllm_engine/post_model.py rename to ChatTTS/model/velocity/post_model.py diff --git a/ChatTTS/vllm_engine/sampling_params.py b/ChatTTS/model/velocity/sampling_params.py similarity index 100% rename from ChatTTS/vllm_engine/sampling_params.py rename to ChatTTS/model/velocity/sampling_params.py diff --git a/ChatTTS/vllm_engine/scheduler.py b/ChatTTS/model/velocity/scheduler.py similarity index 99% rename from ChatTTS/vllm_engine/scheduler.py rename to ChatTTS/model/velocity/scheduler.py index 27f5752a7..e93ca7fd6 100644 --- a/ChatTTS/vllm_engine/scheduler.py +++ b/ChatTTS/model/velocity/scheduler.py @@ -3,10 +3,10 @@ from typing import Dict, Iterable, List, Optional, Tuple, Union from vllm.config import CacheConfig, SchedulerConfig -from ChatTTS.vllm_engine.block_manager import AllocStatus, BlockSpaceManager +from ChatTTS.model.velocity.block_manager import AllocStatus, BlockSpaceManager from vllm.core.policy import PolicyFactory from vllm.logger import init_logger -from ChatTTS.vllm_engine.sequence import (Sequence, SequenceData, SequenceGroup, +from ChatTTS.model.velocity.sequence import (Sequence, SequenceData, SequenceGroup, 
SequenceGroupMetadata, SequenceStatus) logger = init_logger(__name__) diff --git a/ChatTTS/vllm_engine/sequence.py b/ChatTTS/model/velocity/sequence.py similarity index 99% rename from ChatTTS/vllm_engine/sequence.py rename to ChatTTS/model/velocity/sequence.py index a417fe7f5..f5c2a09ad 100644 --- a/ChatTTS/vllm_engine/sequence.py +++ b/ChatTTS/model/velocity/sequence.py @@ -4,7 +4,7 @@ from typing import Dict, List, Optional, Union import torch from vllm.block import LogicalTokenBlock -from ChatTTS.vllm_engine.sampling_params import SamplingParams +from ChatTTS.model.velocity.sampling_params import SamplingParams PromptLogprobs = List[Optional[Dict[int, float]]] SampleLogprobs = List[Dict[int, float]] diff --git a/ChatTTS/vllm_engine/worker.py b/ChatTTS/model/velocity/worker.py similarity index 99% rename from ChatTTS/vllm_engine/worker.py rename to ChatTTS/model/velocity/worker.py index 84e5c85d8..0162302bf 100644 --- a/ChatTTS/vllm_engine/worker.py +++ b/ChatTTS/model/velocity/worker.py @@ -14,7 +14,7 @@ initialize_model_parallel) from vllm.sequence import SamplerOutput, SequenceGroupMetadata from vllm.worker.cache_engine import CacheEngine -from ChatTTS.vllm_engine.model_runner import ModelRunner +from ChatTTS.model.velocity.model_runner import ModelRunner class Worker: From 776f2c44106d0b7fede218427c622e9916afc633 Mon Sep 17 00:00:00 2001 From: "github-actions[bot]" <41898282+github-actions[bot]@users.noreply.github.com> Date: Sun, 21 Jul 2024 23:20:14 +0900 Subject: [PATCH 05/27] chore(format): run black on dev (#614) Co-authored-by: github-actions[bot] --- ChatTTS/core.py | 97 +++-- ChatTTS/model/velocity/block_manager.py | 43 +- ChatTTS/model/velocity/configs.py | 504 +++++++++++++--------- ChatTTS/model/velocity/llama.py | 131 +++--- ChatTTS/model/velocity/llm.py | 37 +- ChatTTS/model/velocity/llm_engine.py | 306 +++++++------ ChatTTS/model/velocity/model_loader.py | 34 +- ChatTTS/model/velocity/model_runner.py | 348 ++++++++------- ChatTTS/model/velocity/output.py | 67 +-- ChatTTS/model/velocity/post_model.py | 245 +++++------ ChatTTS/model/velocity/sampling_params.py | 87 ++-- ChatTTS/model/velocity/scheduler.py | 73 ++-- ChatTTS/model/velocity/sequence.py | 104 +++-- ChatTTS/model/velocity/worker.py | 67 +-- test.py | 22 +- 15 files changed, 1197 insertions(+), 968 deletions(-) diff --git a/ChatTTS/core.py b/ChatTTS/core.py index 94659b52a..bfc728437 100644 --- a/ChatTTS/core.py +++ b/ChatTTS/core.py @@ -291,33 +291,44 @@ def _load( self.dvae = dvae self.logger.log(logging.INFO, "dvae loaded.") - if gpt_config_path: cfg = OmegaConf.load(gpt_config_path) self.num_vq = 4 if not os.path.exists("asset/vllm_model"): gpt = GPT( - **cfg, use_flash_attn=use_flash_attn, device=device, logger=self.logger + **cfg, + use_flash_attn=use_flash_attn, + device=device, + logger=self.logger, ).eval() assert gpt_ckpt_path, "gpt_ckpt_path should not be None" - gpt.load_state_dict(torch.load(gpt_ckpt_path, weights_only=True, mmap=True)) + gpt.load_state_dict( + torch.load(gpt_ckpt_path, weights_only=True, mmap=True) + ) gpt.prepare(compile=compile and "cuda" in str(device)) self.gpt = gpt pathlib.Path("asset/vllm_model").mkdir(parents=True, exist_ok=True) self.gpt.gpt.save_pretrained("asset/vllm_model/gpt") - self.post_model = Post_model( - cfg.gpt_config.hidden_size, - cfg.num_audio_tokens, - cfg.num_text_tokens, - device = device - ).to(device).eval() - + self.post_model = ( + Post_model( + cfg.gpt_config.hidden_size, + cfg.num_audio_tokens, + cfg.num_text_tokens, + device=device, + ) + 
.to(device) + .eval() + ) + self.post_model.emb_code = self.gpt.emb_code self.post_model.emb_text = self.gpt.emb_text self.post_model.head_text = self.gpt.head_text self.post_model.head_code = self.gpt.head_code - save_file(self.post_model.state_dict(), "asset/vllm_model/post_model.safetensors") - + save_file( + self.post_model.state_dict(), + "asset/vllm_model/post_model.safetensors", + ) + self.num_audio_tokens = cfg.num_audio_tokens spk_stat_path = os.path.join(os.path.dirname(gpt_ckpt_path), "spk_stat.pt") assert os.path.exists( @@ -331,15 +342,15 @@ def _load( ) self.std, self.mean = spk_stat.requires_grad_(False).chunk(2) self.logger.log(logging.INFO, "gpt loaded.") - + self.hidden_size = cfg.gpt_config.hidden_size self.gpt = LLM( model="asset/vllm_model/gpt", - num_audio_tokens = cfg.num_audio_tokens, - num_text_tokens = cfg.num_text_tokens, + num_audio_tokens=cfg.num_audio_tokens, + num_text_tokens=cfg.num_text_tokens, post_model_path="asset/vllm_model/post_model.safetensors", ) - + if dvae_config_path: cfg = OmegaConf.load(dvae_config_path) dvae = DVAE(**cfg, coef=coef).to(device).eval() @@ -369,7 +380,7 @@ def _load( self.coef = coef return self.has_loaded() - + def _infer( self, text, @@ -506,7 +517,7 @@ def destroy(self): del_all(self.ids) # del_all(self.attentions) # del_all(self.hiddens) - + @torch.no_grad() def _infer_code( self, @@ -548,7 +559,7 @@ def _infer_code( text = [f"[Stts][spk_emb]{txt_smp}{i}[Ptts]" for i in text] else: text = [f"[Stts][empty_spk]{txt_smp}{i}[Ptts]" for i in text] - + input_ids, attention_mask, text_mask = self.tokenizer.encode( text, self.num_vq, @@ -556,7 +567,7 @@ def _infer_code( device=self.device, ) start_idx = input_ids.shape[-2] - + num_code = self.num_audio_tokens - 1 logits_warpers, logits_processors = gen_logits( @@ -565,34 +576,35 @@ def _infer_code( top_K=params.top_K, repetition_penalty=params.repetition_penalty, ) - + sample_params = SamplingParams( temperature=temperature, max_new_token=params.max_new_token, - max_tokens = 8192, + max_tokens=8192, min_new_token=params.min_new_token, logits_processors=(logits_warpers, logits_processors), - eos_token = num_code, + eos_token=num_code, infer_text=False, - start_idx=start_idx + start_idx=start_idx, ) input_ids = [i.tolist() for i in input_ids] - + result = gpt.generate( None, sample_params, input_ids, ) - + token_ids = [] hidden_states = [] for i in result: token_ids.append(torch.tensor(i.outputs[0].token_ids)) - hidden_states.append(i.outputs[0].hidden_states.to(torch.float32).to(self.device)) - return [self.GenerationOutputs( - ids=token_ids, - hiddens=hidden_states - ),] + hidden_states.append( + i.outputs[0].hidden_states.to(torch.float32).to(self.device) + ) + return [ + self.GenerationOutputs(ids=token_ids, hiddens=hidden_states), + ] @torch.no_grad() def _refine_text( @@ -602,7 +614,7 @@ def _refine_text( params: RefineTextParams, ): - gpt:LLM = self.gpt + gpt: LLM = self.gpt if not isinstance(text, list): text = [text] @@ -614,7 +626,7 @@ def _refine_text( self.num_vq, device=self.device, ) - + start_idx = input_ids.shape[-2] # print(start_idx) logits_warpers, logits_processors = gen_logits( @@ -627,26 +639,19 @@ def _refine_text( sample_params = SamplingParams( temperature=params.temperature, max_new_token=params.max_new_token, - max_tokens = 8192, + max_tokens=8192, min_new_token=params.min_new_token, logits_processors=(logits_warpers, logits_processors), - eos_token = self.tokenizer.eos_token, + eos_token=self.tokenizer.eos_token, infer_text=True, - start_idx=start_idx + 
start_idx=start_idx, ) input_ids = [i.tolist() for i in input_ids] - - result = gpt.generate( - None, - sample_params, - input_ids - ) + + result = gpt.generate(None, sample_params, input_ids) token_ids = [] hidden_states = [] for i in result: token_ids.append(torch.tensor(i.outputs[0].token_ids)) hidden_states.append(i.outputs[0].hidden_states) - return self.GenerationOutputs( - ids=token_ids, - hiddens=hidden_states - ) + return self.GenerationOutputs(ids=token_ids, hiddens=hidden_states) diff --git a/ChatTTS/model/velocity/block_manager.py b/ChatTTS/model/velocity/block_manager.py index 199a3a278..ad69aa1b9 100644 --- a/ChatTTS/model/velocity/block_manager.py +++ b/ChatTTS/model/velocity/block_manager.py @@ -1,4 +1,5 @@ """A block manager that manages token blocks.""" + import enum from typing import Dict, List, Optional, Set, Tuple @@ -31,9 +32,9 @@ def __init__( # Initialize the free blocks. self.free_blocks: BlockTable = [] for i in range(num_blocks): - block = PhysicalTokenBlock(device=device, - block_number=i, - block_size=block_size) + block = PhysicalTokenBlock( + device=device, block_number=i, block_size=block_size + ) self.free_blocks.append(block) def allocate(self) -> PhysicalTokenBlock: @@ -63,6 +64,7 @@ class AllocStatus(enum.Enum): 3. Never: seq_group can never be allocated. The seq_group is too large to allocated in GPU. """ + OK = enum.auto() LATER = enum.auto() NEVER = enum.auto() @@ -85,18 +87,15 @@ def __init__( self.block_sliding_window = None if sliding_window is not None: - assert sliding_window % block_size == 0, (sliding_window, - block_size) + assert sliding_window % block_size == 0, (sliding_window, block_size) self.block_sliding_window = sliding_window // block_size self.watermark = watermark assert watermark >= 0.0 self.watermark_blocks = int(watermark * num_gpu_blocks) - self.gpu_allocator = BlockAllocator(Device.GPU, block_size, - num_gpu_blocks) - self.cpu_allocator = BlockAllocator(Device.CPU, block_size, - num_cpu_blocks) + self.gpu_allocator = BlockAllocator(Device.GPU, block_size, num_gpu_blocks) + self.cpu_allocator = BlockAllocator(Device.CPU, block_size, num_cpu_blocks) # Mapping: seq_id -> BlockTable. self.block_tables: Dict[int, BlockTable] = {} @@ -106,13 +105,11 @@ def can_allocate(self, seq_group: SequenceGroup) -> AllocStatus: seq = seq_group.get_seqs(status=SequenceStatus.WAITING)[0] num_required_blocks = len(seq.logical_token_blocks) if self.block_sliding_window is not None: - num_required_blocks = min(num_required_blocks, - self.block_sliding_window) + num_required_blocks = min(num_required_blocks, self.block_sliding_window) num_free_gpu_blocks = self.gpu_allocator.get_num_free_blocks() # Use watermark to avoid frequent cache eviction. - if (self.num_total_gpu_blocks - num_required_blocks < - self.watermark_blocks): + if self.num_total_gpu_blocks - num_required_blocks < self.watermark_blocks: return AllocStatus.NEVER if num_free_gpu_blocks - num_required_blocks >= self.watermark_blocks: return AllocStatus.OK @@ -127,8 +124,10 @@ def allocate(self, seq_group: SequenceGroup) -> None: # Allocate new physical token blocks that will store the prompt tokens. 
block_table: BlockTable = [] for logical_idx in range(len(seq.logical_token_blocks)): - if (self.block_sliding_window is not None - and logical_idx >= self.block_sliding_window): + if ( + self.block_sliding_window is not None + and logical_idx >= self.block_sliding_window + ): block = block_table[logical_idx % self.block_sliding_window] else: block = self.gpu_allocator.allocate() @@ -153,11 +152,14 @@ def append_slot(self, seq: Sequence) -> Optional[Tuple[int, int]]: block_table = self.block_tables[seq.seq_id] if len(block_table) < len(logical_blocks): - if (self.block_sliding_window - and len(block_table) >= self.block_sliding_window): + if ( + self.block_sliding_window + and len(block_table) >= self.block_sliding_window + ): # re-use a block - block_table.append(block_table[len(block_table) % - self.block_sliding_window]) + block_table.append( + block_table[len(block_table) % self.block_sliding_window] + ) else: # The sequence has a new logical block. # Allocate a new physical block. @@ -188,7 +190,8 @@ def fork(self, parent_seq: Sequence, child_seq: Sequence) -> None: block.ref_count += 1 def _get_physical_blocks( - self, seq_group: SequenceGroup) -> List[PhysicalTokenBlock]: + self, seq_group: SequenceGroup + ) -> List[PhysicalTokenBlock]: # NOTE: Here, we assume that the physical blocks are only shared by # the sequences in the same group. blocks: Set[PhysicalTokenBlock] = set() diff --git a/ChatTTS/model/velocity/configs.py b/ChatTTS/model/velocity/configs.py index 30d6c9afa..c578f468a 100644 --- a/ChatTTS/model/velocity/configs.py +++ b/ChatTTS/model/velocity/configs.py @@ -79,7 +79,7 @@ def __init__( enforce_eager: bool = False, max_context_len_to_capture: Optional[int] = None, num_audio_tokens: int = 1024, - num_text_tokens: int = 80 + num_text_tokens: int = 80, ) -> None: self.model = model self.tokenizer = tokenizer @@ -95,22 +95,24 @@ def __init__( self.max_context_len_to_capture = max_context_len_to_capture self.num_audio_tokens = num_audio_tokens self.num_text_tokens = num_text_tokens - + if os.environ.get("VLLM_USE_MODELSCOPE", "False").lower() == "true": # download model from ModelScope hub, # lazy import so that modelscope is not required for normal use. - from modelscope.hub.snapshot_download import snapshot_download # pylint: disable=C - model_path = snapshot_download(model_id=model, - cache_dir=download_dir, - revision=revision) + from modelscope.hub.snapshot_download import ( + snapshot_download, + ) # pylint: disable=C + + model_path = snapshot_download( + model_id=model, cache_dir=download_dir, revision=revision + ) self.model = model_path self.download_dir = model_path self.tokenizer = model_path self.hf_config = get_config(self.model, trust_remote_code, revision) self.dtype = _get_and_verify_dtype(self.hf_config, dtype) - self.max_model_len = _get_and_verify_max_len(self.hf_config, - max_model_len) + self.max_model_len = _get_and_verify_max_len(self.hf_config, max_model_len) self._verify_load_format() self._verify_tokenizer_mode() self._verify_quantization() @@ -118,30 +120,32 @@ def __init__( def _verify_load_format(self) -> None: load_format = self.load_format.lower() - supported_load_format = [ - "auto", "pt", "safetensors", "npcache", "dummy" - ] + supported_load_format = ["auto", "pt", "safetensors", "npcache", "dummy"] rocm_not_supported_load_format = [] if load_format not in supported_load_format: raise ValueError( f"Unknown load format: {self.load_format}. 
Must be one of " - "'auto', 'pt', 'safetensors', 'npcache', or 'dummy'.") + "'auto', 'pt', 'safetensors', 'npcache', or 'dummy'." + ) if is_hip() and load_format in rocm_not_supported_load_format: rocm_supported_load_format = [ - f for f in supported_load_format + f + for f in supported_load_format if (f not in rocm_not_supported_load_format) ] raise ValueError( - f"load format \'{load_format}\' is not supported in ROCm. " + f"load format '{load_format}' is not supported in ROCm. " f"Supported load format are " - f"{rocm_supported_load_format}") + f"{rocm_supported_load_format}" + ) # TODO: Remove this check once HF updates the pt weights of Mixtral. architectures = getattr(self.hf_config, "architectures", []) if "MixtralForCausalLM" in architectures and load_format == "pt": raise ValueError( "Currently, the 'pt' format is not supported for Mixtral. " - "Please use the 'safetensors' format instead. ") + "Please use the 'safetensors' format instead. " + ) self.load_format = load_format def _verify_tokenizer_mode(self) -> None: @@ -149,7 +153,8 @@ def _verify_tokenizer_mode(self) -> None: if tokenizer_mode not in ["auto", "slow"]: raise ValueError( f"Unknown tokenizer mode: {self.tokenizer_mode}. Must be " - "either 'auto' or 'slow'.") + "either 'auto' or 'slow'." + ) self.tokenizer_mode = tokenizer_mode def _verify_quantization(self) -> None: @@ -169,27 +174,32 @@ def _verify_quantization(self) -> None: "Quantization method specified in the model config " f"({hf_quant_method}) does not match the quantization " f"method specified in the `quantization` argument " - f"({self.quantization}).") + f"({self.quantization})." + ) if self.quantization is not None: if self.quantization not in supported_quantization: raise ValueError( f"Unknown quantization method: {self.quantization}. Must " - f"be one of {supported_quantization}.") - if is_hip( - ) and self.quantization in rocm_not_supported_quantization: + f"be one of {supported_quantization}." + ) + if is_hip() and self.quantization in rocm_not_supported_quantization: raise ValueError( f"{self.quantization} quantization is currently not supported " - f"in ROCm.") - logger.warning(f"{self.quantization} quantization is not fully " - "optimized yet. The speed can be slower than " - "non-quantized models.") + f"in ROCm." + ) + logger.warning( + f"{self.quantization} quantization is not fully " + "optimized yet. The speed can be slower than " + "non-quantized models." + ) def _verify_cuda_graph(self) -> None: if self.max_context_len_to_capture is None: self.max_context_len_to_capture = self.max_model_len - self.max_context_len_to_capture = min(self.max_context_len_to_capture, - self.max_model_len) + self.max_context_len_to_capture = min( + self.max_context_len_to_capture, self.max_model_len + ) def verify_with_parallel_config( self, @@ -201,7 +211,8 @@ def verify_with_parallel_config( raise ValueError( f"Total number of attention heads ({total_num_attention_heads})" " must be divisible by tensor parallel size " - f"({tensor_parallel_size}).") + f"({tensor_parallel_size})." + ) total_num_hidden_layers = self.hf_config.num_hidden_layers pipeline_parallel_size = parallel_config.pipeline_parallel_size @@ -209,7 +220,8 @@ def verify_with_parallel_config( raise ValueError( f"Total number of hidden layers ({total_num_hidden_layers}) " "must be divisible by pipeline parallel size " - f"({pipeline_parallel_size}).") + f"({pipeline_parallel_size})." 
+ ) def get_sliding_window(self) -> Optional[int]: return getattr(self.hf_config, "sliding_window", None) @@ -233,9 +245,11 @@ def get_total_num_kv_heads(self) -> int: falcon_model_types = ["falcon", "RefinedWeb", "RefinedWebModel"] new_decoder_arch_falcon = ( self.hf_config.model_type in falcon_model_types - and getattr(self.hf_config, "new_decoder_architecture", False)) - if not new_decoder_arch_falcon and getattr(self.hf_config, - "multi_query", False): + and getattr(self.hf_config, "new_decoder_architecture", False) + ) + if not new_decoder_arch_falcon and getattr( + self.hf_config, "multi_query", False + ): # Multi-query attention, only one KV head. # Currently, tensor parallelism is not supported in this case. return 1 @@ -265,8 +279,7 @@ def get_num_kv_heads(self, parallel_config: "ParallelConfig") -> int: # the tensor parallel size. We will replicate the KV heads in the # case where the number of KV heads is smaller than the tensor # parallel size so each GPU has at least one KV head. - return max(1, - total_num_kv_heads // parallel_config.tensor_parallel_size) + return max(1, total_num_kv_heads // parallel_config.tensor_parallel_size) def get_num_layers(self, parallel_config: "ParallelConfig") -> int: total_num_hidden_layers = self.hf_config.num_hidden_layers @@ -304,7 +317,8 @@ def _verify_args(self) -> None: if self.gpu_memory_utilization > 1.0: raise ValueError( "GPU memory utilization must be less than 1.0. Got " - f"{self.gpu_memory_utilization}.") + f"{self.gpu_memory_utilization}." + ) def verify_with_parallel_config( self, @@ -316,9 +330,11 @@ def verify_with_parallel_config( num_gpus_per_node = parallel_config.tensor_parallel_size cpu_memory_usage = self.swap_space_bytes * num_gpus_per_node - msg = (f"{cpu_memory_usage / _GB:.2f} GiB out of " - f"the {total_cpu_memory / _GB:.2f} GiB total CPU memory is " - "allocated for the swap space.") + msg = ( + f"{cpu_memory_usage / _GB:.2f} GiB out of " + f"the {total_cpu_memory / _GB:.2f} GiB total CPU memory is " + "allocated for the swap space." + ) if cpu_memory_usage > 0.7 * total_cpu_memory: raise ValueError("Too large swap space. " + msg) elif cpu_memory_usage > 0.4 * total_cpu_memory: @@ -355,8 +371,7 @@ def __init__( def _verify_args(self) -> None: if self.pipeline_parallel_size > 1: - raise NotImplementedError( - "Pipeline parallelism is not supported yet.") + raise NotImplementedError("Pipeline parallelism is not supported yet.") class SchedulerConfig: @@ -398,12 +413,14 @@ def _verify_args(self) -> None: "This effectively limits the maximum sequence length to " "max_num_batched_tokens and makes vLLM reject longer " "sequences. Please increase max_num_batched_tokens or " - "decrease max_model_len.") + "decrease max_model_len." + ) if self.max_num_batched_tokens < self.max_num_seqs: raise ValueError( f"max_num_batched_tokens ({self.max_num_batched_tokens}) must " "be greater than or equal to max_num_seqs " - f"({self.max_num_seqs}).") + f"({self.max_num_seqs})." + ) _STR_DTYPE_TO_TORCH_DTYPE = { @@ -447,11 +464,14 @@ def _get_and_verify_dtype( if is_hip() and torch_dtype == torch.float32: rocm_supported_dtypes = [ - k for k, v in _STR_DTYPE_TO_TORCH_DTYPE.items() + k + for k, v in _STR_DTYPE_TO_TORCH_DTYPE.items() if (k not in _ROCM_NOT_SUPPORTED_DTYPE) ] - raise ValueError(f"dtype \'{dtype}\' is not supported in ROCm. " - f"Supported dtypes are {rocm_supported_dtypes}") + raise ValueError( + f"dtype '{dtype}' is not supported in ROCm. " + f"Supported dtypes are {rocm_supported_dtypes}" + ) # Verify the dtype. 
if torch_dtype != config_dtype: @@ -502,7 +522,8 @@ def _get_and_verify_max_len( "The model's config.json does not contain any of the following " "keys to determine the original maximum length of the model: " f"{possible_keys}. Assuming the model's maximum length is " - f"{default_max_len}.") + f"{default_max_len}." + ) derived_max_model_len = default_max_len rope_scaling = getattr(hf_config, "rope_scaling", None) @@ -510,8 +531,7 @@ def _get_and_verify_max_len( assert "factor" in rope_scaling scaling_factor = rope_scaling["factor"] if rope_scaling["type"] == "yarn": - derived_max_model_len = rope_scaling[ - "original_max_position_embeddings"] + derived_max_model_len = rope_scaling["original_max_position_embeddings"] derived_max_model_len *= scaling_factor if max_model_len is None: @@ -522,20 +542,22 @@ def _get_and_verify_max_len( f"the derived max_model_len ({max_len_key}={derived_max_model_len}" " in model's config.json). This may lead to incorrect model " "outputs or CUDA errors. Make sure the value is correct and " - "within the model context size.") + "within the model context size." + ) return int(max_model_len) @dataclass class EngineArgs: """Arguments for vLLM engine.""" + model: str tokenizer: Optional[str] = None - tokenizer_mode: str = 'auto' + tokenizer_mode: str = "auto" trust_remote_code: bool = False download_dir: Optional[str] = None - load_format: str = 'auto' - dtype: str = 'auto' + load_format: str = "auto" + dtype: str = "auto" seed: int = 0 max_model_len: Optional[int] = None worker_use_ray: bool = False @@ -556,14 +578,13 @@ class EngineArgs: max_context_len_to_capture: int = 8192 num_audio_tokens: int = 1024 num_text_tokens: int = 80 - + def __post_init__(self): if self.tokenizer is None: self.tokenizer = self.model @staticmethod - def add_cli_args( - parser: argparse.ArgumentParser) -> argparse.ArgumentParser: + def add_cli_args(parser: argparse.ArgumentParser) -> argparse.ArgumentParser: """Shared CLI arguments for vLLM engine.""" # NOTE: If you update any of the arguments below, please also @@ -571,162 +592,198 @@ def add_cli_args( # Model arguments parser.add_argument( - '--model', + "--model", type=str, - default='facebook/opt-125m', - help='name or path of the huggingface model to use') + default="facebook/opt-125m", + help="name or path of the huggingface model to use", + ) parser.add_argument( - '--tokenizer', + "--tokenizer", type=str, default=EngineArgs.tokenizer, - help='name or path of the huggingface tokenizer to use') + help="name or path of the huggingface tokenizer to use", + ) parser.add_argument( - '--revision', + "--revision", type=str, default=None, - help='the specific model version to use. It can be a branch ' - 'name, a tag name, or a commit id. If unspecified, will use ' - 'the default version.') + help="the specific model version to use. It can be a branch " + "name, a tag name, or a commit id. If unspecified, will use " + "the default version.", + ) parser.add_argument( - '--tokenizer-revision', + "--tokenizer-revision", type=str, default=None, - help='the specific tokenizer version to use. It can be a branch ' - 'name, a tag name, or a commit id. If unspecified, will use ' - 'the default version.') - parser.add_argument('--tokenizer-mode', - type=str, - default=EngineArgs.tokenizer_mode, - choices=['auto', 'slow'], - help='tokenizer mode. 
"auto" will use the fast ' - 'tokenizer if available, and "slow" will ' - 'always use the slow tokenizer.') - parser.add_argument('--trust-remote-code', - action='store_true', - help='trust remote code from huggingface') - parser.add_argument('--download-dir', - type=str, - default=EngineArgs.download_dir, - help='directory to download and load the weights, ' - 'default to the default cache dir of ' - 'huggingface') + help="the specific tokenizer version to use. It can be a branch " + "name, a tag name, or a commit id. If unspecified, will use " + "the default version.", + ) + parser.add_argument( + "--tokenizer-mode", + type=str, + default=EngineArgs.tokenizer_mode, + choices=["auto", "slow"], + help='tokenizer mode. "auto" will use the fast ' + 'tokenizer if available, and "slow" will ' + "always use the slow tokenizer.", + ) + parser.add_argument( + "--trust-remote-code", + action="store_true", + help="trust remote code from huggingface", + ) + parser.add_argument( + "--download-dir", + type=str, + default=EngineArgs.download_dir, + help="directory to download and load the weights, " + "default to the default cache dir of " + "huggingface", + ) parser.add_argument( - '--load-format', + "--load-format", type=str, default=EngineArgs.load_format, - choices=['auto', 'pt', 'safetensors', 'npcache', 'dummy'], - help='The format of the model weights to load. ' + choices=["auto", "pt", "safetensors", "npcache", "dummy"], + help="The format of the model weights to load. " '"auto" will try to load the weights in the safetensors format ' - 'and fall back to the pytorch bin format if safetensors format ' - 'is not available. ' + "and fall back to the pytorch bin format if safetensors format " + "is not available. " '"pt" will load the weights in the pytorch bin format. ' '"safetensors" will load the weights in the safetensors format. ' '"npcache" will load the weights in pytorch format and store ' - 'a numpy cache to speed up the loading. ' + "a numpy cache to speed up the loading. " '"dummy" will initialize the weights with random values, ' - 'which is mainly for profiling.') + "which is mainly for profiling.", + ) parser.add_argument( - '--dtype', + "--dtype", type=str, default=EngineArgs.dtype, - choices=[ - 'auto', 'half', 'float16', 'bfloat16', 'float', 'float32' - ], - help='data type for model weights and activations. ' + choices=["auto", "half", "float16", "bfloat16", "float", "float32"], + help="data type for model weights and activations. " 'The "auto" option will use FP16 precision ' - 'for FP32 and FP16 models, and BF16 precision ' - 'for BF16 models.') - parser.add_argument('--max-model-len', - type=int, - default=None, - help='model context length. If unspecified, ' - 'will be automatically derived from the model.') + "for FP32 and FP16 models, and BF16 precision " + "for BF16 models.", + ) + parser.add_argument( + "--max-model-len", + type=int, + default=None, + help="model context length. 
If unspecified, " + "will be automatically derived from the model.", + ) # Parallel arguments - parser.add_argument('--worker-use-ray', - action='store_true', - help='use Ray for distributed serving, will be ' - 'automatically set when using more than 1 GPU') - parser.add_argument('--pipeline-parallel-size', - '-pp', - type=int, - default=EngineArgs.pipeline_parallel_size, - help='number of pipeline stages') - parser.add_argument('--tensor-parallel-size', - '-tp', - type=int, - default=EngineArgs.tensor_parallel_size, - help='number of tensor parallel replicas') parser.add_argument( - '--max-parallel-loading-workers', + "--worker-use-ray", + action="store_true", + help="use Ray for distributed serving, will be " + "automatically set when using more than 1 GPU", + ) + parser.add_argument( + "--pipeline-parallel-size", + "-pp", + type=int, + default=EngineArgs.pipeline_parallel_size, + help="number of pipeline stages", + ) + parser.add_argument( + "--tensor-parallel-size", + "-tp", type=int, - help='load model sequentially in multiple batches, ' - 'to avoid RAM OOM when using tensor ' - 'parallel and large models') + default=EngineArgs.tensor_parallel_size, + help="number of tensor parallel replicas", + ) + parser.add_argument( + "--max-parallel-loading-workers", + type=int, + help="load model sequentially in multiple batches, " + "to avoid RAM OOM when using tensor " + "parallel and large models", + ) # KV cache arguments - parser.add_argument('--block-size', - type=int, - default=EngineArgs.block_size, - choices=[8, 16, 32], - help='token block size') + parser.add_argument( + "--block-size", + type=int, + default=EngineArgs.block_size, + choices=[8, 16, 32], + help="token block size", + ) # TODO(woosuk): Support fine-grained seeds (e.g., seed per request). - parser.add_argument('--seed', - type=int, - default=EngineArgs.seed, - help='random seed') - parser.add_argument('--swap-space', - type=int, - default=EngineArgs.swap_space, - help='CPU swap space size (GiB) per GPU') parser.add_argument( - '--gpu-memory-utilization', + "--seed", type=int, default=EngineArgs.seed, help="random seed" + ) + parser.add_argument( + "--swap-space", + type=int, + default=EngineArgs.swap_space, + help="CPU swap space size (GiB) per GPU", + ) + parser.add_argument( + "--gpu-memory-utilization", type=float, default=EngineArgs.gpu_memory_utilization, - help='the fraction of GPU memory to be used for ' - 'the model executor, which can range from 0 to 1.' - 'If unspecified, will use the default value of 0.9.') - parser.add_argument('--max-num-batched-tokens', - type=int, - default=EngineArgs.max_num_batched_tokens, - help='maximum number of batched tokens per ' - 'iteration') - parser.add_argument('--max-num-seqs', - type=int, - default=EngineArgs.max_num_seqs, - help='maximum number of sequences per iteration') - parser.add_argument('--max-paddings', - type=int, - default=EngineArgs.max_paddings, - help='maximum number of paddings in a batch') - parser.add_argument('--disable-log-stats', - action='store_true', - help='disable logging statistics') + help="the fraction of GPU memory to be used for " + "the model executor, which can range from 0 to 1." 
+ "If unspecified, will use the default value of 0.9.", + ) + parser.add_argument( + "--max-num-batched-tokens", + type=int, + default=EngineArgs.max_num_batched_tokens, + help="maximum number of batched tokens per " "iteration", + ) + parser.add_argument( + "--max-num-seqs", + type=int, + default=EngineArgs.max_num_seqs, + help="maximum number of sequences per iteration", + ) + parser.add_argument( + "--max-paddings", + type=int, + default=EngineArgs.max_paddings, + help="maximum number of paddings in a batch", + ) + parser.add_argument( + "--disable-log-stats", + action="store_true", + help="disable logging statistics", + ) # Quantization settings. - parser.add_argument('--quantization', - '-q', - type=str, - choices=['awq', 'gptq', 'squeezellm', None], - default=None, - help='Method used to quantize the weights. If ' - 'None, we first check the `quantization_config` ' - 'attribute in the model config file. If that is ' - 'None, we assume the model weights are not ' - 'quantized and use `dtype` to determine the data ' - 'type of the weights.') - parser.add_argument('--enforce-eager', - action='store_true', - help='Always use eager-mode PyTorch. If False, ' - 'will use eager mode and CUDA graph in hybrid ' - 'for maximal performance and flexibility.') - parser.add_argument('--max-context-len-to-capture', - type=int, - default=EngineArgs.max_context_len_to_capture, - help='maximum context length covered by CUDA ' - 'graphs. When a sequence has context length ' - 'larger than this, we fall back to eager mode.') + parser.add_argument( + "--quantization", + "-q", + type=str, + choices=["awq", "gptq", "squeezellm", None], + default=None, + help="Method used to quantize the weights. If " + "None, we first check the `quantization_config` " + "attribute in the model config file. If that is " + "None, we assume the model weights are not " + "quantized and use `dtype` to determine the data " + "type of the weights.", + ) + parser.add_argument( + "--enforce-eager", + action="store_true", + help="Always use eager-mode PyTorch. If False, " + "will use eager mode and CUDA graph in hybrid " + "for maximal performance and flexibility.", + ) + parser.add_argument( + "--max-context-len-to-capture", + type=int, + default=EngineArgs.max_context_len_to_capture, + help="maximum context length covered by CUDA " + "graphs. When a sequence has context length " + "larger than this, we fall back to eager mode.", + ) return parser @classmethod - def from_cli_args(cls, args: argparse.Namespace) -> 'EngineArgs': + def from_cli_args(cls, args: argparse.Namespace) -> "EngineArgs": # Get the list of attributes of this dataclass. attrs = [attr.name for attr in dataclasses.fields(cls)] # Set the attributes from the parsed arguments. 
@@ -736,52 +793,73 @@ def from_cli_args(cls, args: argparse.Namespace) -> 'EngineArgs': def create_engine_configs( self, ) -> Tuple[ModelConfig, CacheConfig, ParallelConfig, SchedulerConfig]: - model_config = ModelConfig(self.model, self.tokenizer, - self.tokenizer_mode, self.trust_remote_code, - self.download_dir, self.load_format, - self.dtype, self.seed, self.revision, - self.tokenizer_revision, self.max_model_len, - self.quantization, self.enforce_eager, - self.max_context_len_to_capture, - self.num_audio_tokens, self.num_text_tokens, - ) - cache_config = CacheConfig(self.block_size, - self.gpu_memory_utilization, - self.swap_space, - model_config.get_sliding_window()) - parallel_config = ParallelConfig(self.pipeline_parallel_size, - self.tensor_parallel_size, - self.worker_use_ray, - self.max_parallel_loading_workers) - scheduler_config = SchedulerConfig(self.max_num_batched_tokens, - self.max_num_seqs, - model_config.max_model_len, - self.max_paddings) + model_config = ModelConfig( + self.model, + self.tokenizer, + self.tokenizer_mode, + self.trust_remote_code, + self.download_dir, + self.load_format, + self.dtype, + self.seed, + self.revision, + self.tokenizer_revision, + self.max_model_len, + self.quantization, + self.enforce_eager, + self.max_context_len_to_capture, + self.num_audio_tokens, + self.num_text_tokens, + ) + cache_config = CacheConfig( + self.block_size, + self.gpu_memory_utilization, + self.swap_space, + model_config.get_sliding_window(), + ) + parallel_config = ParallelConfig( + self.pipeline_parallel_size, + self.tensor_parallel_size, + self.worker_use_ray, + self.max_parallel_loading_workers, + ) + scheduler_config = SchedulerConfig( + self.max_num_batched_tokens, + self.max_num_seqs, + model_config.max_model_len, + self.max_paddings, + ) return model_config, cache_config, parallel_config, scheduler_config @dataclass class AsyncEngineArgs(EngineArgs): """Arguments for asynchronous vLLM engine.""" + engine_use_ray: bool = False disable_log_requests: bool = False max_log_len: Optional[int] = None @staticmethod - def add_cli_args( - parser: argparse.ArgumentParser) -> argparse.ArgumentParser: + def add_cli_args(parser: argparse.ArgumentParser) -> argparse.ArgumentParser: parser = EngineArgs.add_cli_args(parser) - parser.add_argument('--engine-use-ray', - action='store_true', - help='use Ray to start the LLM engine in a ' - 'separate process as the server process.') - parser.add_argument('--disable-log-requests', - action='store_true', - help='disable logging requests') - parser.add_argument('--max-log-len', - type=int, - default=None, - help='max number of prompt characters or prompt ' - 'ID numbers being printed in log. ' - 'Default: unlimited.') + parser.add_argument( + "--engine-use-ray", + action="store_true", + help="use Ray to start the LLM engine in a " + "separate process as the server process.", + ) + parser.add_argument( + "--disable-log-requests", + action="store_true", + help="disable logging requests", + ) + parser.add_argument( + "--max-log-len", + type=int, + default=None, + help="max number of prompt characters or prompt " + "ID numbers being printed in log. 
" + "Default: unlimited.", + ) return parser diff --git a/ChatTTS/model/velocity/llama.py b/ChatTTS/model/velocity/llama.py index 415b09d86..8e6c8a896 100644 --- a/ChatTTS/model/velocity/llama.py +++ b/ChatTTS/model/velocity/llama.py @@ -31,19 +31,26 @@ from vllm.model_executor.layers.activation import SiluAndMul from vllm.model_executor.layers.attention import PagedAttention from vllm.model_executor.layers.layernorm import RMSNorm -from vllm.model_executor.layers.linear import (LinearMethodBase, - MergedColumnParallelLinear, - QKVParallelLinear, - RowParallelLinear) +from vllm.model_executor.layers.linear import ( + LinearMethodBase, + MergedColumnParallelLinear, + QKVParallelLinear, + RowParallelLinear, +) from vllm.model_executor.layers.rotary_embedding import get_rope from vllm.model_executor.layers.sampler import Sampler from vllm.model_executor.layers.vocab_parallel_embedding import ( - VocabParallelEmbedding, ParallelLMHead) + VocabParallelEmbedding, + ParallelLMHead, +) from vllm.model_executor.parallel_utils.parallel_state import ( - get_tensor_model_parallel_world_size) + get_tensor_model_parallel_world_size, +) from vllm.model_executor.sampling_metadata import SamplingMetadata -from vllm.model_executor.weight_utils import (default_weight_loader, - hf_model_weights_iterator) +from vllm.model_executor.weight_utils import ( + default_weight_loader, + hf_model_weights_iterator, +) from vllm.sequence import SamplerOutput KVCache = Tuple[torch.Tensor, torch.Tensor] @@ -60,16 +67,19 @@ def __init__( ) -> None: super().__init__() self.gate_up_proj = MergedColumnParallelLinear( - hidden_size, [intermediate_size] * 2, + hidden_size, + [intermediate_size] * 2, bias=False, - linear_method=linear_method) - self.down_proj = RowParallelLinear(intermediate_size, - hidden_size, - bias=False, - linear_method=linear_method) + linear_method=linear_method, + ) + self.down_proj = RowParallelLinear( + intermediate_size, hidden_size, bias=False, linear_method=linear_method + ) if hidden_act != "silu": - raise ValueError(f"Unsupported activation: {hidden_act}. " - "Only silu is supported for now.") + raise ValueError( + f"Unsupported activation: {hidden_act}. " + "Only silu is supported for now." 
+ ) self.act_fn = SiluAndMul() def forward(self, x): @@ -136,10 +146,9 @@ def __init__( base=rope_theta, rope_scaling=rope_scaling, ) - self.attn = PagedAttention(self.num_heads, - self.head_dim, - self.scaling, - num_kv_heads=self.num_kv_heads) + self.attn = PagedAttention( + self.num_heads, self.head_dim, self.scaling, num_kv_heads=self.num_kv_heads + ) def forward( self, @@ -168,8 +177,7 @@ def __init__( self.hidden_size = config.hidden_size rope_theta = getattr(config, "rope_theta", 10000) rope_scaling = getattr(config, "rope_scaling", None) - max_position_embeddings = getattr(config, "max_position_embeddings", - 8192) + max_position_embeddings = getattr(config, "max_position_embeddings", 8192) self.self_attn = LlamaAttention( hidden_size=self.hidden_size, num_heads=config.num_attention_heads, @@ -185,10 +193,10 @@ def __init__( hidden_act=config.hidden_act, linear_method=linear_method, ) - self.input_layernorm = RMSNorm(config.hidden_size, - eps=config.rms_norm_eps) - self.post_attention_layernorm = RMSNorm(config.hidden_size, - eps=config.rms_norm_eps) + self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.post_attention_layernorm = RMSNorm( + config.hidden_size, eps=config.rms_norm_eps + ) def forward( self, @@ -203,8 +211,7 @@ def forward( residual = hidden_states hidden_states = self.input_layernorm(hidden_states) else: - hidden_states, residual = self.input_layernorm( - hidden_states, residual) + hidden_states, residual = self.input_layernorm(hidden_states, residual) hidden_states = self.self_attn( positions=positions, hidden_states=hidden_states, @@ -213,8 +220,7 @@ def forward( ) # Fully Connected - hidden_states, residual = self.post_attention_layernorm( - hidden_states, residual) + hidden_states, residual = self.post_attention_layernorm(hidden_states, residual) hidden_states = self.mlp(hidden_states) return hidden_states, residual @@ -234,10 +240,12 @@ def __init__( config.vocab_size, config.hidden_size, ) - self.layers = nn.ModuleList([ - LlamaDecoderLayer(config, linear_method) - for _ in range(config.num_hidden_layers) - ]) + self.layers = nn.ModuleList( + [ + LlamaDecoderLayer(config, linear_method) + for _ in range(config.num_hidden_layers) + ] + ) self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) def forward( @@ -261,11 +269,13 @@ def forward( hidden_states, _ = self.norm(hidden_states, residual) return hidden_states - def load_weights(self, - model_name_or_path: str, - cache_dir: Optional[str] = None, - load_format: str = "auto", - revision: Optional[str] = None): + def load_weights( + self, + model_name_or_path: str, + cache_dir: Optional[str] = None, + load_format: str = "auto", + revision: Optional[str] = None, + ): stacked_params_mapping = [ # (param_name, shard_name, shard_id) ("qkv_proj", "q_proj", "q"), @@ -276,15 +286,15 @@ def load_weights(self, ] params_dict = dict(self.named_parameters()) for name, loaded_weight in hf_model_weights_iterator( - model_name_or_path, cache_dir, load_format, revision): + model_name_or_path, cache_dir, load_format, revision + ): if "rotary_emb.inv_freq" in name: continue - if ("rotary_emb.cos_cached" in name - or "rotary_emb.sin_cached" in name): + if "rotary_emb.cos_cached" in name or "rotary_emb.sin_cached" in name: # Models trained using ColossalAI may include these tensors in # the checkpoint. Skip them. 
continue - for (param_name, weight_name, shard_id) in stacked_params_mapping: + for param_name, weight_name, shard_id in stacked_params_mapping: if weight_name not in name: continue name = name.replace(weight_name, param_name) @@ -300,10 +310,10 @@ def load_weights(self, if name.endswith(".bias") and name not in params_dict: continue param = params_dict[name] - weight_loader = getattr(param, "weight_loader", - default_weight_loader) + weight_loader = getattr(param, "weight_loader", default_weight_loader) weight_loader(param, loaded_weight) + class LlamaForCausalLM(nn.Module): def __init__( @@ -325,8 +335,7 @@ def forward( kv_caches: List[KVCache], input_metadata: InputMetadata, ) -> torch.Tensor: - hidden_states = self.model(input_ids, positions, kv_caches, - input_metadata) + hidden_states = self.model(input_ids, positions, kv_caches, input_metadata) return hidden_states def sample( @@ -334,15 +343,18 @@ def sample( hidden_states: torch.Tensor, sampling_metadata: SamplingMetadata, ) -> Optional[SamplerOutput]: - next_tokens = self.sampler(self.lm_head.weight, hidden_states, - sampling_metadata) + next_tokens = self.sampler( + self.lm_head.weight, hidden_states, sampling_metadata + ) return next_tokens - def load_weights(self, - model_name_or_path: str, - cache_dir: Optional[str] = None, - load_format: str = "auto", - revision: Optional[str] = None): + def load_weights( + self, + model_name_or_path: str, + cache_dir: Optional[str] = None, + load_format: str = "auto", + revision: Optional[str] = None, + ): stacked_params_mapping = [ # (param_name, shard_name, shard_id) ("qkv_proj", "q_proj", "q"), @@ -353,15 +365,15 @@ def load_weights(self, ] params_dict = dict(self.named_parameters()) for name, loaded_weight in hf_model_weights_iterator( - model_name_or_path, cache_dir, load_format, revision): + model_name_or_path, cache_dir, load_format, revision + ): if "rotary_emb.inv_freq" in name: continue - if ("rotary_emb.cos_cached" in name - or "rotary_emb.sin_cached" in name): + if "rotary_emb.cos_cached" in name or "rotary_emb.sin_cached" in name: # Models trained using ColossalAI may include these tensors in # the checkpoint. Skip them. 
continue - for (param_name, weight_name, shard_id) in stacked_params_mapping: + for param_name, weight_name, shard_id in stacked_params_mapping: if weight_name not in name: continue name = name.replace(weight_name, param_name) @@ -377,6 +389,5 @@ def load_weights(self, if name.endswith(".bias") and name not in params_dict: continue param = params_dict[name] - weight_loader = getattr(param, "weight_loader", - default_weight_loader) + weight_loader = getattr(param, "weight_loader", default_weight_loader) weight_loader(param, loaded_weight) diff --git a/ChatTTS/model/velocity/llm.py b/ChatTTS/model/velocity/llm.py index 9668c87cf..98a90af26 100644 --- a/ChatTTS/model/velocity/llm.py +++ b/ChatTTS/model/velocity/llm.py @@ -103,15 +103,14 @@ def __init__( swap_space=swap_space, enforce_eager=enforce_eager, max_context_len_to_capture=max_context_len_to_capture, - num_audio_tokens = num_audio_tokens, - num_text_tokens = num_text_tokens, + num_audio_tokens=num_audio_tokens, + num_text_tokens=num_text_tokens, **kwargs, ) self.llm_engine = LLMEngine.from_engine_args(engine_args, post_model_path) self.request_counter = Counter() - def get_tokenizer( - self) -> Union[PreTrainedTokenizer, PreTrainedTokenizerFast]: + def get_tokenizer(self) -> Union[PreTrainedTokenizer, PreTrainedTokenizerFast]: return self.llm_engine.tokenizer def set_tokenizer( @@ -146,28 +145,29 @@ def generate( completions in the same order as the input prompts. """ if prompts is None and prompt_token_ids is None: - raise ValueError("Either prompts or prompt_token_ids must be " - "provided.") + raise ValueError("Either prompts or prompt_token_ids must be " "provided.") if isinstance(prompts, str): # Convert a single prompt to a list. prompts = [prompts] - if (prompts is not None and prompt_token_ids is not None - and len(prompts) != len(prompt_token_ids)): - raise ValueError("The lengths of prompts and prompt_token_ids " - "must be the same.") + if ( + prompts is not None + and prompt_token_ids is not None + and len(prompts) != len(prompt_token_ids) + ): + raise ValueError( + "The lengths of prompts and prompt_token_ids " "must be the same." + ) if sampling_params is None: # Use default sampling params. sampling_params = SamplingParams() # Add requests to the engine. - num_requests = len(prompts) if prompts is not None else len( - prompt_token_ids) + num_requests = len(prompts) if prompts is not None else len(prompt_token_ids) for i in range(num_requests): prompt = prompts[i] if prompts is not None else None - token_ids = None if prompt_token_ids is None else prompt_token_ids[ - i] + token_ids = None if prompt_token_ids is None else prompt_token_ids[i] self._add_request(prompt, sampling_params, token_ids) - + rtns = self._run_engine(use_tqdm) for i, rtn in enumerate(rtns): token_ids = rtn.outputs[0].token_ids @@ -176,7 +176,7 @@ def generate( token_ids[j] = token_id[0] else: token_ids[j] = list(token_id) - + return rtns def _add_request( @@ -186,8 +186,9 @@ def _add_request( prompt_token_ids: Optional[List[int]], ) -> None: request_id = str(next(self.request_counter)) - self.llm_engine.add_request(request_id, prompt, sampling_params, - prompt_token_ids) + self.llm_engine.add_request( + request_id, prompt, sampling_params, prompt_token_ids + ) def _run_engine(self, use_tqdm: bool) -> List[RequestOutput]: # Initialize tqdm. 
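The generate() method above post-processes each returned step: a one-element token tuple (text inference) is collapsed to a plain int, while a multi-element tuple (one id per VQ codebook during audio inference) becomes a list. A standalone sketch of that normalization, with the tuple layout inferred from this diff:

from typing import List, Tuple, Union


def flatten_token_ids(
    token_ids: List[Tuple[int, ...]]
) -> List[Union[int, List[int]]]:
    # Mirrors the loop in LLM.generate: single-id steps become ints,
    # multi-codebook steps become lists of ints.
    out: List[Union[int, List[int]]] = []
    for tok in token_ids:
        out.append(tok[0] if len(tok) == 1 else list(tok))
    return out


print(flatten_token_ids([(625,), (12, 87, 3, 40)]))  # -> [625, [12, 87, 3, 40]]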
diff --git a/ChatTTS/model/velocity/llm_engine.py b/ChatTTS/model/velocity/llm_engine.py index 4a72c0c3f..66dd205ff 100644 --- a/ChatTTS/model/velocity/llm_engine.py +++ b/ChatTTS/model/velocity/llm_engine.py @@ -2,11 +2,9 @@ from collections import defaultdict import os import time -from typing import (TYPE_CHECKING, Any, Dict, Iterable, List, Optional, Tuple, - Union) +from typing import TYPE_CHECKING, Any, Dict, Iterable, List, Optional, Tuple, Union -from vllm.config import (CacheConfig, ModelConfig, ParallelConfig, - SchedulerConfig) +from vllm.config import CacheConfig, ModelConfig, ParallelConfig, SchedulerConfig from ChatTTS.model.velocity.scheduler import Scheduler, SchedulerOutputs from ChatTTS.model.velocity.configs import EngineArgs from vllm.engine.metrics import record_metrics @@ -14,12 +12,18 @@ from vllm.logger import init_logger from ChatTTS.model.velocity.output import RequestOutput from ChatTTS.model.velocity.sampling_params import SamplingParams -from ChatTTS.model.velocity.sequence import (SamplerOutput, Sequence, SequenceGroup, - SequenceGroupOutput, SequenceOutput, SequenceStatus) -from vllm.transformers_utils.tokenizer import (detokenize_incrementally, - get_tokenizer) +from ChatTTS.model.velocity.sequence import ( + SamplerOutput, + Sequence, + SequenceGroup, + SequenceGroupOutput, + SequenceOutput, + SequenceStatus, +) +from vllm.transformers_utils.tokenizer import detokenize_incrementally, get_tokenizer from vllm.utils import Counter, set_cuda_visible_devices, get_ip, get_open_port import numpy as np + if ray: from ray.util.scheduling_strategies import PlacementGroupSchedulingStrategy @@ -85,7 +89,7 @@ def __init__( f"enforce_eager={model_config.enforce_eager}, " f"seed={model_config.seed}), " f"post_model_path={post_model_path!r}" - ) + ) # TODO(woosuk): Print more configs in debug mode. self.model_config = model_config @@ -125,8 +129,9 @@ def _init_workers(self): # before CUDA_VISIBLE_DEVICES is set in the Worker from ChatTTS.model.velocity.worker import Worker - assert self.parallel_config.world_size == 1, ( - "Ray is required if parallel_config.world_size > 1.") + assert ( + self.parallel_config.world_size == 1 + ), "Ray is required if parallel_config.world_size > 1." self.workers: List[Worker] = [] distributed_init_method = f"tcp://{get_ip()}:{get_open_port()}" @@ -138,13 +143,12 @@ def _init_workers(self): rank=0, distributed_init_method=distributed_init_method, is_driver_worker=True, - post_model_path = self.post_model_path + post_model_path=self.post_model_path, ) self._run_workers("init_model") self._run_workers("load_model") - def _init_workers_ray(self, placement_group: "PlacementGroup", - **ray_remote_kwargs): + def _init_workers_ray(self, placement_group: "PlacementGroup", **ray_remote_kwargs): if self.parallel_config.tensor_parallel_size == 1: num_gpus = self.cache_config.gpu_memory_utilization else: @@ -181,20 +185,22 @@ def _init_workers_ray(self, placement_group: "PlacementGroup", raise ValueError( "Ray does not allocate any GPUs on the driver node. Consider " "adjusting the Ray placement group or running the driver on a " - "GPU node.") + "GPU node." 
+ ) driver_node_id, driver_gpu_ids = ray.get( - self.driver_dummy_worker.get_node_and_gpu_ids.remote()) + self.driver_dummy_worker.get_node_and_gpu_ids.remote() + ) worker_node_and_gpu_ids = ray.get( - [worker.get_node_and_gpu_ids.remote() for worker in self.workers]) + [worker.get_node_and_gpu_ids.remote() for worker in self.workers] + ) node_workers = defaultdict(list) node_gpus = defaultdict(list) node_workers[driver_node_id].append(0) node_gpus[driver_node_id].extend(driver_gpu_ids) - for i, (node_id, gpu_ids) in enumerate(worker_node_and_gpu_ids, - start=1): + for i, (node_id, gpu_ids) in enumerate(worker_node_and_gpu_ids, start=1): node_workers[node_id].append(i) node_gpus[node_id].extend(gpu_ids) for node_id, gpu_ids in node_gpus.items(): @@ -216,10 +222,9 @@ def _init_workers_ray(self, placement_group: "PlacementGroup", parallel_config = copy.deepcopy(self.parallel_config) scheduler_config = copy.deepcopy(self.scheduler_config) - for rank, (worker, (node_id, - _)) in enumerate(zip(self.workers, - worker_node_and_gpu_ids), - start=1): + for rank, (worker, (node_id, _)) in enumerate( + zip(self.workers, worker_node_and_gpu_ids), start=1 + ): local_rank = node_workers[node_id].index(rank) worker.init_worker.remote( lambda rank=rank, local_rank=local_rank: Worker( @@ -229,7 +234,8 @@ def _init_workers_ray(self, placement_group: "PlacementGroup", local_rank, rank, distributed_init_method, - )) + ) + ) driver_rank = 0 driver_local_rank = node_workers[driver_node_id].index(driver_rank) @@ -246,8 +252,7 @@ def _init_workers_ray(self, placement_group: "PlacementGroup", self._run_workers("init_model") self._run_workers( "load_model", - max_concurrent_workers=self.parallel_config. - max_parallel_loading_workers, + max_concurrent_workers=self.parallel_config.max_parallel_loading_workers, ) def _verify_args(self) -> None: @@ -270,13 +275,16 @@ def _init_cache(self) -> None: num_gpu_blocks = min(b[0] for b in num_blocks) num_cpu_blocks = min(b[1] for b in num_blocks) # FIXME(woosuk): Change to debug log. - logger.info(f"# GPU blocks: {num_gpu_blocks}, " - f"# CPU blocks: {num_cpu_blocks}") + logger.info( + f"# GPU blocks: {num_gpu_blocks}, " f"# CPU blocks: {num_cpu_blocks}" + ) if num_gpu_blocks <= 0: - raise ValueError("No available memory for the cache blocks. " - "Try increasing `gpu_memory_utilization` when " - "initializing the engine.") + raise ValueError( + "No available memory for the cache blocks. " + "Try increasing `gpu_memory_utilization` when " + "initializing the engine." + ) max_seq_len = self.cache_config.block_size * num_gpu_blocks if self.model_config.max_model_len > max_seq_len: raise ValueError( @@ -284,7 +292,8 @@ def _init_cache(self) -> None: "is larger than the maximum number of tokens that can be " f"stored in KV cache ({max_seq_len}). Try increasing " "`gpu_memory_utilization` or decreasing `max_model_len` when " - "initializing the engine.") + "initializing the engine." + ) self.cache_config.num_gpu_blocks = num_gpu_blocks self.cache_config.num_cpu_blocks = num_cpu_blocks @@ -296,7 +305,9 @@ def _init_cache(self) -> None: self._run_workers("warm_up_model") @classmethod - def from_engine_args(cls, engine_args: EngineArgs, post_model_path=None) -> "LLMEngine": + def from_engine_args( + cls, engine_args: EngineArgs, post_model_path=None + ) -> "LLMEngine": """Creates an LLM engine from the engine arguments.""" # Create the engine configs. 
engine_configs = engine_args.create_engine_configs() @@ -304,11 +315,12 @@ def from_engine_args(cls, engine_args: EngineArgs, post_model_path=None) -> "LLM # Initialize the cluster. placement_group = initialize_cluster(parallel_config) # Create the LLM engine. - engine = cls(*engine_configs, - placement_group, - log_stats=not engine_args.disable_log_stats, - post_model_path = post_model_path - ) + engine = cls( + *engine_configs, + placement_group, + log_stats=not engine_args.disable_log_stats, + post_model_path=post_model_path, + ) return engine def add_request( @@ -337,7 +349,7 @@ def add_request( """ if arrival_time is None: arrival_time = time.monotonic() - + assert prompt_token_ids is not None, "prompt_token_ids must be provided" # Create the sequences. block_size = self.cache_config.block_size @@ -345,8 +357,7 @@ def add_request( seq = Sequence(seq_id, prompt, prompt_token_ids, block_size) # Create the sequence group. - seq_group = SequenceGroup(request_id, [seq], sampling_params, - arrival_time) + seq_group = SequenceGroup(request_id, [seq], sampling_params, arrival_time) # Add the sequence group to the scheduler. self.scheduler.add_seq_group(seq_group) @@ -383,13 +394,13 @@ def _check_beam_search_early_stopping( if early_stopping is True: return True - current_worst_score = (current_worst_seq.get_beam_search_score( - length_penalty=length_penalty, - eos_token_id=self.tokenizer.eos_token_id)) + current_worst_score = current_worst_seq.get_beam_search_score( + length_penalty=length_penalty, eos_token_id=self.tokenizer.eos_token_id + ) if early_stopping is False: - highest_attainable_score = (best_running_seq.get_beam_search_score( - length_penalty=length_penalty, - eos_token_id=self.tokenizer.eos_token_id)) + highest_attainable_score = best_running_seq.get_beam_search_score( + length_penalty=length_penalty, eos_token_id=self.tokenizer.eos_token_id + ) else: assert early_stopping == "never" if length_penalty > 0.0: @@ -397,26 +408,27 @@ def _check_beam_search_early_stopping( # sequences. The highest attainable score calculation is # based on the longest possible sequence length in this case. max_possible_length = max( - best_running_seq.get_prompt_len() + - sampling_params.max_tokens, - self.scheduler_config.max_model_len) - highest_attainable_score = ( - best_running_seq.get_beam_search_score( - length_penalty=length_penalty, - eos_token_id=self.tokenizer.eos_token_id, - seq_len=max_possible_length)) + best_running_seq.get_prompt_len() + sampling_params.max_tokens, + self.scheduler_config.max_model_len, + ) + highest_attainable_score = best_running_seq.get_beam_search_score( + length_penalty=length_penalty, + eos_token_id=self.tokenizer.eos_token_id, + seq_len=max_possible_length, + ) else: # Otherwise, beam search will prefer shorter sequences. The # highest attainable score calculation is based on the current # sequence length. 
- highest_attainable_score = ( - best_running_seq.get_beam_search_score( - length_penalty=length_penalty, - eos_token_id=self.tokenizer.eos_token_id)) + highest_attainable_score = best_running_seq.get_beam_search_score( + length_penalty=length_penalty, + eos_token_id=self.tokenizer.eos_token_id, + ) return current_worst_score >= highest_attainable_score - def _process_sequence_group_outputs(self, seq_group: SequenceGroup, - outputs: SequenceGroupOutput) -> None: + def _process_sequence_group_outputs( + self, seq_group: SequenceGroup, outputs: SequenceGroupOutput + ) -> None: # Process prompt logprobs prompt_logprobs = outputs.prompt_logprobs if prompt_logprobs is not None: @@ -426,10 +438,7 @@ def _process_sequence_group_outputs(self, seq_group: SequenceGroup, samples = outputs.samples parent_seqs = seq_group.get_seqs(status=SequenceStatus.RUNNING) existing_finished_seqs = seq_group.get_finished_seqs() - parent_child_dict = { - parent_seq.seq_id: [] - for parent_seq in parent_seqs - } + parent_child_dict = {parent_seq.seq_id: [] for parent_seq in parent_seqs} for sample in samples: parent_child_dict[sample.parent_seq_id].append(sample) # List of (child, parent) @@ -437,8 +446,7 @@ def _process_sequence_group_outputs(self, seq_group: SequenceGroup, # Process the child samples for each parent sequence for parent in parent_seqs: - child_samples: List[SequenceOutput] = parent_child_dict[ - parent.seq_id] + child_samples: List[SequenceOutput] = parent_child_dict[parent.seq_id] if len(child_samples) == 0: # This parent sequence has no children samples. Remove # the parent sequence from the sequence group since it will @@ -451,27 +459,29 @@ def _process_sequence_group_outputs(self, seq_group: SequenceGroup, for child_sample in child_samples[:-1]: new_child_seq_id = next(self.seq_counter) child = parent.fork(new_child_seq_id) - child.append_token_id(child_sample.output_token, - child_sample.logprobs, - child_sample.hidden_states, - child_sample.finished - ) + child.append_token_id( + child_sample.output_token, + child_sample.logprobs, + child_sample.hidden_states, + child_sample.finished, + ) child_seqs.append((child, parent)) # Continue the parent sequence for the last child sample. # We reuse the parent sequence here to reduce redundant memory # copies, especially when using non-beam search sampling methods. last_child_sample = child_samples[-1] - parent.append_token_id(last_child_sample.output_token, - last_child_sample.logprobs, - last_child_sample.hidden_states, - last_child_sample.finished - ) + parent.append_token_id( + last_child_sample.output_token, + last_child_sample.logprobs, + last_child_sample.hidden_states, + last_child_sample.finished, + ) child_seqs.append((parent, parent)) for seq, _ in child_seqs: # self._decode_sequence(seq, seq_group.sampling_params) self._check_stop(seq, seq_group.sampling_params) - + # Non-beam search case if not seq_group.sampling_params.use_beam_search: # For newly created child sequences, add them to the sequence group @@ -501,16 +511,18 @@ def _process_sequence_group_outputs(self, seq_group: SequenceGroup, # Select the newly finished sequences with the highest scores # to replace existing finished sequences. 
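# The sorting below keys on Sequence.get_beam_search_score, which lives in
# sequence.py and is not shown in this patch. A minimal sketch of the usual
# length-penalty normalization such a score is assumed to apply:
def _beam_search_score(cumulative_logprob: float, seq_len: int,
                       length_penalty: float = 1.0) -> float:
    # Higher is better; dividing by seq_len**length_penalty keeps beams of
    # different lengths comparable (GNMT-style length normalization).
    return cumulative_logprob / (seq_len ** length_penalty)

# Without normalization the shorter beam (-4.0) would always win; with it,
# the longer beam with the better per-token logprob is preferred:
assert _beam_search_score(-6.0, 10) > _beam_search_score(-4.0, 4)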
# Tuple of (seq, parent, is_new) - existing_finished_seqs = [(seq, None, False) - for seq in existing_finished_seqs] - new_finished_seqs = [(seq, parent, True) for seq, parent in child_seqs - if seq.is_finished()] + existing_finished_seqs = [(seq, None, False) for seq in existing_finished_seqs] + new_finished_seqs = [ + (seq, parent, True) for seq, parent in child_seqs if seq.is_finished() + ] all_finished_seqs = existing_finished_seqs + new_finished_seqs # Sort the finished sequences by their scores. - all_finished_seqs.sort(key=lambda x: x[0].get_beam_search_score( - length_penalty=length_penalty, - eos_token_id=self.tokenizer.eos_token_id), - reverse=True) + all_finished_seqs.sort( + key=lambda x: x[0].get_beam_search_score( + length_penalty=length_penalty, eos_token_id=self.tokenizer.eos_token_id + ), + reverse=True, + ) for seq, parent, is_new in all_finished_seqs[:beam_width]: if is_new: # A newly generated child sequence finishes and has a high @@ -532,13 +544,16 @@ def _process_sequence_group_outputs(self, seq_group: SequenceGroup, # select the top beam_width sequences from the running # sequences for the next iteration to continue the beam # search. - running_child_seqs = [(seq, parent) for seq, parent in child_seqs - if not seq.is_finished()] + running_child_seqs = [ + (seq, parent) for seq, parent in child_seqs if not seq.is_finished() + ] # Sort the running sequences by their scores. - running_child_seqs.sort(key=lambda x: x[0].get_beam_search_score( - length_penalty=length_penalty, - eos_token_id=self.tokenizer.eos_token_id), - reverse=True) + running_child_seqs.sort( + key=lambda x: x[0].get_beam_search_score( + length_penalty=length_penalty, eos_token_id=self.tokenizer.eos_token_id + ), + reverse=True, + ) # Check if we can stop the beam search. if len(running_child_seqs) == 0: @@ -553,7 +568,10 @@ def _process_sequence_group_outputs(self, seq_group: SequenceGroup, current_worst_seq = all_finished_seqs[beam_width - 1][0] stop_beam_search = self._check_beam_search_early_stopping( seq_group.sampling_params.early_stopping, - seq_group.sampling_params, best_running_seq, current_worst_seq) + seq_group.sampling_params, + best_running_seq, + current_worst_seq, + ) if stop_beam_search: # Stop the beam search and remove all the running sequences from @@ -593,8 +611,8 @@ def _process_sequence_group_outputs(self, seq_group: SequenceGroup, self.scheduler.free_seq(seq) def _process_model_outputs( - self, output: SamplerOutput, - scheduler_outputs: SchedulerOutputs) -> List[RequestOutput]: + self, output: SamplerOutput, scheduler_outputs: SchedulerOutputs + ) -> List[RequestOutput]: # Update the scheduled sequence groups with the model outputs. scheduled_seq_groups = scheduler_outputs.scheduled_seq_groups for seq_group, outputs in zip(scheduled_seq_groups, output): @@ -605,15 +623,15 @@ def _process_model_outputs( # Create the outputs. request_outputs: List[RequestOutput] = [] - for seq_group in (scheduled_seq_groups + - scheduler_outputs.ignored_seq_groups): + for seq_group in scheduled_seq_groups + scheduler_outputs.ignored_seq_groups: request_output = RequestOutput.from_seq_group(seq_group) request_outputs.append(request_output) if self.log_stats: # Log the system stats. 
- self._log_system_stats(scheduler_outputs.prompt_run, - scheduler_outputs.num_batched_tokens) + self._log_system_stats( + scheduler_outputs.prompt_run, scheduler_outputs.num_batched_tokens + ) return request_outputs def step(self) -> List[RequestOutput]: @@ -626,7 +644,7 @@ def step(self) -> List[RequestOutput]: the sequences and returns the newly generated results. """ seq_group_metadata_list, scheduler_outputs = self.scheduler.schedule() - + if not scheduler_outputs.is_empty(): # Execute the model. all_outputs = self._run_workers( @@ -636,7 +654,8 @@ def step(self) -> List[RequestOutput]: "blocks_to_swap_in": scheduler_outputs.blocks_to_swap_in, "blocks_to_swap_out": scheduler_outputs.blocks_to_swap_out, "blocks_to_copy": scheduler_outputs.blocks_to_copy, - }) + }, + ) # Only the driver worker returns the sampling results. output = all_outputs[0] @@ -662,11 +681,14 @@ def _log_system_stats( return # Discard the old stats. - self.num_prompt_tokens = [(t, n) for t, n in self.num_prompt_tokens - if now - t < _LOGGING_INTERVAL_SEC] - self.num_generation_tokens = [(t, n) - for t, n in self.num_generation_tokens - if now - t < _LOGGING_INTERVAL_SEC] + self.num_prompt_tokens = [ + (t, n) for t, n in self.num_prompt_tokens if now - t < _LOGGING_INTERVAL_SEC + ] + self.num_generation_tokens = [ + (t, n) + for t, n in self.num_generation_tokens + if now - t < _LOGGING_INTERVAL_SEC + ] if len(self.num_prompt_tokens) > 1: total_num_tokens = sum(n for _, n in self.num_prompt_tokens[:-1]) @@ -675,23 +697,20 @@ def _log_system_stats( else: avg_prompt_throughput = 0.0 if len(self.num_generation_tokens) > 1: - total_num_tokens = sum(n - for _, n in self.num_generation_tokens[:-1]) + total_num_tokens = sum(n for _, n in self.num_generation_tokens[:-1]) window = now - self.num_generation_tokens[0][0] avg_generation_throughput = total_num_tokens / window else: avg_generation_throughput = 0.0 total_num_gpu_blocks = self.cache_config.num_gpu_blocks - num_free_gpu_blocks = ( - self.scheduler.block_manager.get_num_free_gpu_blocks()) + num_free_gpu_blocks = self.scheduler.block_manager.get_num_free_gpu_blocks() num_used_gpu_blocks = total_num_gpu_blocks - num_free_gpu_blocks gpu_cache_usage = num_used_gpu_blocks / total_num_gpu_blocks total_num_cpu_blocks = self.cache_config.num_cpu_blocks if total_num_cpu_blocks > 0: - num_free_cpu_blocks = ( - self.scheduler.block_manager.get_num_free_cpu_blocks()) + num_free_cpu_blocks = self.scheduler.block_manager.get_num_free_cpu_blocks() num_used_cpu_blocks = total_num_cpu_blocks - num_free_cpu_blocks cpu_cache_usage = num_used_cpu_blocks / total_num_cpu_blocks else: @@ -707,29 +726,32 @@ def _log_system_stats( cpu_cache_usage=cpu_cache_usage, ) - logger.info("Avg prompt throughput: " - f"{avg_prompt_throughput:.1f} tokens/s, " - "Avg generation throughput: " - f"{avg_generation_throughput:.1f} tokens/s, " - f"Running: {len(self.scheduler.running)} reqs, " - f"Swapped: {len(self.scheduler.swapped)} reqs, " - f"Pending: {len(self.scheduler.waiting)} reqs, " - f"GPU KV cache usage: {gpu_cache_usage * 100:.1f}%, " - f"CPU KV cache usage: {cpu_cache_usage * 100:.1f}%") + logger.info( + "Avg prompt throughput: " + f"{avg_prompt_throughput:.1f} tokens/s, " + "Avg generation throughput: " + f"{avg_generation_throughput:.1f} tokens/s, " + f"Running: {len(self.scheduler.running)} reqs, " + f"Swapped: {len(self.scheduler.swapped)} reqs, " + f"Pending: {len(self.scheduler.waiting)} reqs, " + f"GPU KV cache usage: {gpu_cache_usage * 100:.1f}%, " + f"CPU KV cache usage: 
{cpu_cache_usage * 100:.1f}%" + ) self.last_logging_time = now def _decode_sequence(self, seq: Sequence, prms: SamplingParams) -> None: """Decodes the new token for a sequence.""" - (new_tokens, new_output_text, prefix_offset, - read_offset) = detokenize_incrementally( - self.tokenizer, - all_input_ids=seq.get_token_ids(), - prev_tokens=seq.tokens, - prefix_offset=seq.prefix_offset, - read_offset=seq.read_offset, - skip_special_tokens=prms.skip_special_tokens, - spaces_between_special_tokens=prms.spaces_between_special_tokens, - ) + (new_tokens, new_output_text, prefix_offset, read_offset) = ( + detokenize_incrementally( + self.tokenizer, + all_input_ids=seq.get_token_ids(), + prev_tokens=seq.tokens, + prefix_offset=seq.prefix_offset, + read_offset=seq.read_offset, + skip_special_tokens=prms.skip_special_tokens, + spaces_between_special_tokens=prms.spaces_between_special_tokens, + ) + ) if seq.tokens is None: seq.tokens = new_tokens else: @@ -738,21 +760,20 @@ def _decode_sequence(self, seq: Sequence, prms: SamplingParams) -> None: seq.read_offset = read_offset seq.output_text += new_output_text - def _check_stop(self, seq: Sequence, - sampling_params: SamplingParams) -> None: + def _check_stop(self, seq: Sequence, sampling_params: SamplingParams) -> None: """Stop the finished sequences.""" for stop_str in sampling_params.stop: if seq.output_text.endswith(stop_str): if not sampling_params.include_stop_str_in_output: # Truncate the output text so that the stop string is # not included in the output. - seq.output_text = seq.output_text[:-len(stop_str)] + seq.output_text = seq.output_text[: -len(stop_str)] seq.status = SequenceStatus.FINISHED_STOPPED return if seq.data.finished: seq.status = SequenceStatus.FINISHED_STOPPED return - + for token_id in seq.get_last_token_id(): if token_id == sampling_params.eos_token: seq.status = SequenceStatus.FINISHED_STOPPED @@ -769,11 +790,12 @@ def _check_stop(self, seq: Sequence, return # Check if the sequence has generated the EOS token. - if ((not sampling_params.ignore_eos) - and seq.get_last_token_id()[0] == sampling_params.eos_token): + if (not sampling_params.ignore_eos) and seq.get_last_token_id()[ + 0 + ] == sampling_params.eos_token: seq.status = SequenceStatus.FINISHED_STOPPED return - + def _run_workers( self, method: str, @@ -786,8 +808,7 @@ def _run_workers( """Runs the given method on all workers.""" if max_concurrent_workers: - raise NotImplementedError( - "max_concurrent_workers is not supported yet.") + raise NotImplementedError("max_concurrent_workers is not supported yet.") # Start the ray workers first. ray_worker_outputs = [ @@ -801,8 +822,9 @@ def _run_workers( driver_kwargs = kwargs # Start the driver worker after all the ray workers. - driver_worker_output = getattr(self.driver_worker, - method)(*driver_args, **driver_kwargs) + driver_worker_output = getattr(self.driver_worker, method)( + *driver_args, **driver_kwargs + ) # Get the results of the ray workers. 
if self.workers: diff --git a/ChatTTS/model/velocity/model_loader.py b/ChatTTS/model/velocity/model_loader.py index bb4605875..40de6d960 100644 --- a/ChatTTS/model/velocity/model_loader.py +++ b/ChatTTS/model/velocity/model_loader.py @@ -1,4 +1,5 @@ """Utilities for selecting and loading models.""" + import contextlib from typing import Type @@ -8,10 +9,10 @@ from vllm.config import ModelConfig from vllm.model_executor.models import ModelRegistry -from vllm.model_executor.weight_utils import (get_quant_config, - initialize_dummy_weights) +from vllm.model_executor.weight_utils import get_quant_config, initialize_dummy_weights import importlib + @contextlib.contextmanager def _set_default_torch_dtype(dtype: torch.dtype): """Sets the default torch dtype to the given dtype.""" @@ -22,19 +23,24 @@ def _set_default_torch_dtype(dtype: torch.dtype): def _get_model_architecture(config: PretrainedConfig) -> Type[nn.Module]: - model_cls = getattr(importlib.import_module("ChatTTS.model.velocity.llama"), "LlamaModel", None) + model_cls = getattr( + importlib.import_module("ChatTTS.model.velocity.llama"), "LlamaModel", None + ) return model_cls + def get_model(model_config: ModelConfig) -> nn.Module: model_class = _get_model_architecture(model_config.hf_config) # Get the (maybe quantized) linear method. linear_method = None if model_config.quantization is not None: - quant_config = get_quant_config(model_config.quantization, - model_config.model, - model_config.hf_config, - model_config.download_dir) + quant_config = get_quant_config( + model_config.quantization, + model_config.model, + model_config.hf_config, + model_config.download_dir, + ) capability = torch.cuda.get_device_capability() capability = capability[0] * 10 + capability[1] if capability < quant_config.get_min_capability(): @@ -42,13 +48,15 @@ def get_model(model_config: ModelConfig) -> nn.Module: f"The quantization method {model_config.quantization} is not " "supported for the current GPU. " f"Minimum capability: {quant_config.get_min_capability()}. " - f"Current capability: {capability}.") + f"Current capability: {capability}." + ) supported_dtypes = quant_config.get_supported_act_dtypes() if model_config.dtype not in supported_dtypes: raise ValueError( f"{model_config.dtype} is not supported for quantization " f"method {model_config.quantization}. Supported dtypes: " - f"{supported_dtypes}") + f"{supported_dtypes}" + ) linear_method = quant_config.get_linear_method() with _set_default_torch_dtype(model_config.dtype): @@ -62,6 +70,10 @@ def get_model(model_config: ModelConfig) -> nn.Module: initialize_dummy_weights(model) else: # Load the weights from the cached or downloaded files. 
- model.load_weights(model_config.model, model_config.download_dir, - model_config.load_format, model_config.revision) + model.load_weights( + model_config.model, + model_config.download_dir, + model_config.load_format, + model_config.revision, + ) return model.eval() diff --git a/ChatTTS/model/velocity/model_runner.py b/ChatTTS/model/velocity/model_runner.py index 86ed4a730..5b0f2c2d8 100644 --- a/ChatTTS/model/velocity/model_runner.py +++ b/ChatTTS/model/velocity/model_runner.py @@ -10,9 +10,17 @@ from ChatTTS.model.velocity.model_loader import get_model from vllm.model_executor import InputMetadata, SamplingMetadata from vllm.model_executor.parallel_utils.communication_op import ( - broadcast, broadcast_object_list) + broadcast, + broadcast_object_list, +) from ChatTTS.model.velocity.sampling_params import SamplingParams, SamplingType -from ChatTTS.model.velocity.sequence import SamplerOutput, SequenceData, SequenceGroupMetadata, SequenceGroupOutput, SequenceOutput +from ChatTTS.model.velocity.sequence import ( + SamplerOutput, + SequenceData, + SequenceGroupMetadata, + SequenceGroupOutput, + SequenceOutput, +) from vllm.utils import in_wsl from ChatTTS.model.velocity.post_model import Post_model, Sampler from safetensors.torch import safe_open @@ -34,18 +42,19 @@ def __init__( parallel_config: ParallelConfig, scheduler_config: SchedulerConfig, is_driver_worker: bool = False, - post_model_path: str = None + post_model_path: str = None, ): self.model_config = model_config self.parallel_config = parallel_config self.scheduler_config = scheduler_config self.is_driver_worker = is_driver_worker self.post_model_path = post_model_path - + # model_config can be None in tests/samplers/test_sampler.py. # FIXME(woosuk): This is a hack to make the tests work. Refactor this. - self.sliding_window = (model_config.get_sliding_window() - if model_config is not None else None) + self.sliding_window = ( + model_config.get_sliding_window() if model_config is not None else None + ) self.model = None self.block_size = None # Set after initial profiling. @@ -54,7 +63,9 @@ def __init__( self.max_context_len_to_capture = ( self.model_config.max_context_len_to_capture - if self.model_config is not None else 0) + if self.model_config is not None + else 0 + ) # When using CUDA graph, the input block tables must be padded to # max_context_len_to_capture. However, creating the block table in # Python can be expensive. 
To optimize this, we cache the block table @@ -68,28 +79,27 @@ def __init__( def load_model(self) -> None: self.model = get_model(self.model_config) self.post_model = Post_model( - self.model_config.get_hidden_size(), - self.model_config.num_audio_tokens, - self.model_config.num_text_tokens - ) + self.model_config.get_hidden_size(), + self.model_config.num_audio_tokens, + self.model_config.num_text_tokens, + ) state_dict_tensors = {} with safe_open(self.post_model_path, framework="pt", device=0) as f: for k in f.keys(): state_dict_tensors[k] = f.get_tensor(k) self.post_model.load_state_dict(state_dict_tensors) self.post_model.to(next(self.model.parameters())).eval() - self.sampler = Sampler( - self.post_model, - self.model_config.num_audio_tokens, - 4 - ) + self.sampler = Sampler(self.post_model, self.model_config.num_audio_tokens, 4) + def set_block_size(self, block_size: int) -> None: self.block_size = block_size - max_num_blocks = (self.max_context_len_to_capture + block_size - - 1) // block_size + max_num_blocks = ( + self.max_context_len_to_capture + block_size - 1 + ) // block_size self.graph_block_tables = np.zeros( - (max(_BATCH_SIZES_TO_CAPTURE), max_num_blocks), dtype=np.int32) + (max(_BATCH_SIZES_TO_CAPTURE), max_num_blocks), dtype=np.int32 + ) def _prepare_prompt( self, @@ -145,18 +155,15 @@ def _prepare_prompt( slot_mapping[-1].append(slot) max_prompt_len = max(prompt_lens) - input_tokens = _make_tensor_with_pad(input_tokens, - max_prompt_len, - pad=0, - dtype=torch.long) - input_positions = _make_tensor_with_pad(input_positions, - max_prompt_len, - pad=0, - dtype=torch.long) - slot_mapping = _make_tensor_with_pad(slot_mapping, - max_prompt_len, - pad=_PAD_SLOT_ID, - dtype=torch.long) + input_tokens = _make_tensor_with_pad( + input_tokens, max_prompt_len, pad=0, dtype=torch.long + ) + input_positions = _make_tensor_with_pad( + input_positions, max_prompt_len, pad=0, dtype=torch.long + ) + slot_mapping = _make_tensor_with_pad( + slot_mapping, max_prompt_len, pad=_PAD_SLOT_ID, dtype=torch.long + ) input_metadata = InputMetadata( is_prompt=True, @@ -192,8 +199,11 @@ def _prepare_decode( position = seq_len - 1 input_positions.append([position]) - context_len = seq_len if self.sliding_window is None else min( - seq_len, self.sliding_window) + context_len = ( + seq_len + if self.sliding_window is None + else min(seq_len, self.sliding_window) + ) context_lens.append(context_len) block_table = seq_group_metadata.block_tables[seq_id] @@ -203,8 +213,7 @@ def _prepare_decode( slot_mapping.append([slot]) if self.sliding_window is not None: - sliding_window_blocks = (self.sliding_window // - self.block_size) + sliding_window_blocks = self.sliding_window // self.block_size block_table = block_table[-sliding_window_blocks:] block_tables.append(block_table) @@ -213,7 +222,8 @@ def _prepare_decode( use_captured_graph = ( not self.model_config.enforce_eager and batch_size <= _BATCH_SIZES_TO_CAPTURE[-1] - and max_context_len <= self.max_context_len_to_capture) + and max_context_len <= self.max_context_len_to_capture + ) if use_captured_graph: # Pad the input tokens, positions, and slot mapping to match the # batch size of the captured graph. 
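When a captured CUDA graph is used for decoding, the next hunk pads the batch with dummy entries up to the batch size the graph was captured for, since the graph's input shapes are fixed. A standalone sketch of that padding step (function and variable names here are illustrative):

from typing import List


def pad_to_graph_batch(input_tokens: List[List[int]],
                       block_tables: List[List[int]],
                       graph_batch_size: int) -> None:
    # Append dummy rows in place until the batch matches the captured size;
    # the extra rows only exist to satisfy the graph's fixed shapes.
    while len(input_tokens) < graph_batch_size:
        input_tokens.append([0])
        block_tables.append([])


tokens, tables = [[101], [102], [103]], [[0], [1], [2]]
pad_to_graph_batch(tokens, tables, graph_batch_size=4)
assert len(tokens) == len(tables) == 4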
@@ -227,24 +237,16 @@ def _prepare_decode( block_tables.append([]) batch_size = graph_batch_size - input_tokens = _make_tensor_with_pad(input_tokens, - max_len=1, - pad=0, - dtype=torch.long, - device="cuda") - input_positions = _make_tensor_with_pad(input_positions, - max_len=1, - pad=0, - dtype=torch.long, - device="cuda") - slot_mapping = _make_tensor_with_pad(slot_mapping, - max_len=1, - pad=_PAD_SLOT_ID, - dtype=torch.long, - device="cuda") - context_lens = torch.tensor(context_lens, - dtype=torch.int, - device="cuda") + input_tokens = _make_tensor_with_pad( + input_tokens, max_len=1, pad=0, dtype=torch.long, device="cuda" + ) + input_positions = _make_tensor_with_pad( + input_positions, max_len=1, pad=0, dtype=torch.long, device="cuda" + ) + slot_mapping = _make_tensor_with_pad( + slot_mapping, max_len=1, pad=_PAD_SLOT_ID, dtype=torch.long, device="cuda" + ) + context_lens = torch.tensor(context_lens, dtype=torch.int, device="cuda") if use_captured_graph: # The shape of graph_block_tables is @@ -252,7 +254,7 @@ def _prepare_decode( input_block_tables = self.graph_block_tables[:batch_size] for i, block_table in enumerate(block_tables): if block_table: - input_block_tables[i, :len(block_table)] = block_table + input_block_tables[i, : len(block_table)] = block_table block_tables = torch.tensor(input_block_tables, device="cuda") else: block_tables = _make_tensor_with_pad( @@ -297,34 +299,38 @@ def _prepare_sample( # NOTE: prompt token positions do not need sample, skip categorized_sample_indices_start_idx += prompt_len - 1 - categorized_sample_indices[ - sampling_params.sampling_type].append( - categorized_sample_indices_start_idx) + categorized_sample_indices[sampling_params.sampling_type].append( + categorized_sample_indices_start_idx + ) categorized_sample_indices_start_idx += 1 if sampling_params.prompt_logprobs is not None: selected_token_indices.extend( - range(selected_token_start_idx, - selected_token_start_idx + prompt_len - 1)) - selected_token_indices.append(selected_token_start_idx + - prompt_len - 1) + range( + selected_token_start_idx, + selected_token_start_idx + prompt_len - 1, + ) + ) + selected_token_indices.append(selected_token_start_idx + prompt_len - 1) selected_token_start_idx += max_prompt_len else: num_seqs = len(seq_ids) selected_token_indices.extend( - range(selected_token_start_idx, - selected_token_start_idx + num_seqs)) + range(selected_token_start_idx, selected_token_start_idx + num_seqs) + ) selected_token_start_idx += num_seqs - categorized_sample_indices[ - sampling_params.sampling_type].extend( - range(categorized_sample_indices_start_idx, - categorized_sample_indices_start_idx + num_seqs)) + categorized_sample_indices[sampling_params.sampling_type].extend( + range( + categorized_sample_indices_start_idx, + categorized_sample_indices_start_idx + num_seqs, + ) + ) categorized_sample_indices_start_idx += num_seqs - selected_token_indices = _async_h2d(selected_token_indices, - dtype=torch.long, - pin_memory=not self.in_wsl) + selected_token_indices = _async_h2d( + selected_token_indices, dtype=torch.long, pin_memory=not self.in_wsl + ) categorized_sample_indices = { t: _async_h2d(seq_ids, dtype=torch.int, pin_memory=not self.in_wsl) for t, seq_ids in categorized_sample_indices.items() @@ -353,14 +359,17 @@ def prepare_input_tensors( is_prompt = seq_group_metadata_list[0].is_prompt # Prepare input tensors. 
if is_prompt: - (input_tokens, input_positions, input_metadata, - prompt_lens) = self._prepare_prompt(seq_group_metadata_list) + (input_tokens, input_positions, input_metadata, prompt_lens) = ( + self._prepare_prompt(seq_group_metadata_list) + ) else: - (input_tokens, input_positions, input_metadata - ) = self._prepare_decode(seq_group_metadata_list) + (input_tokens, input_positions, input_metadata) = self._prepare_decode( + seq_group_metadata_list + ) prompt_lens = [] - sampling_metadata = self._prepare_sample(seq_group_metadata_list, - prompt_lens) + sampling_metadata = self._prepare_sample( + seq_group_metadata_list, prompt_lens + ) def get_size_or_none(x: Optional[torch.Tensor]): return x.size() if x is not None else None @@ -369,24 +378,15 @@ def get_size_or_none(x: Optional[torch.Tensor]): # its shape and then broadcast the tensor to avoid high # serialization cost. py_data = { - "input_tokens_size": - input_tokens.size(), - "input_positions_size": - input_positions.size(), - "is_prompt": - input_metadata.is_prompt, - "slot_mapping_size": - get_size_or_none(input_metadata.slot_mapping), - "max_context_len": - input_metadata.max_context_len, - "context_lens_size": - get_size_or_none(input_metadata.context_lens), - "block_tables_size": - get_size_or_none(input_metadata.block_tables), - "use_cuda_graph": - input_metadata.use_cuda_graph, - "selected_token_indices_size": - sampling_metadata.selected_token_indices.size(), + "input_tokens_size": input_tokens.size(), + "input_positions_size": input_positions.size(), + "is_prompt": input_metadata.is_prompt, + "slot_mapping_size": get_size_or_none(input_metadata.slot_mapping), + "max_context_len": input_metadata.max_context_len, + "context_lens_size": get_size_or_none(input_metadata.context_lens), + "block_tables_size": get_size_or_none(input_metadata.block_tables), + "use_cuda_graph": input_metadata.use_cuda_graph, + "selected_token_indices_size": sampling_metadata.selected_token_indices.size(), } broadcast_object_list([py_data], src=0) # TODO(zhuohan): Combine the broadcasts or set async_op=True. 
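The driver above broadcasts the Python-level tensor sizes first; the next hunk shows the non-driver side allocating empty tensors of those sizes and receiving the data. A minimal sketch of that two-step pattern with torch.distributed (it assumes an already initialized process group; the function name, dtype and device are illustrative):

import torch
import torch.distributed as dist


def broadcast_with_shape(t: torch.Tensor = None, src: int = 0,
                         dtype: torch.dtype = torch.long) -> torch.Tensor:
    # Step 1: ship the shape as a cheap Python object so receivers know
    # how much memory to allocate.
    meta = [tuple(t.shape) if dist.get_rank() == src else None]
    dist.broadcast_object_list(meta, src=src)
    # Step 2: receivers allocate an empty buffer and the tensor data is
    # broadcast into it.
    if dist.get_rank() != src:
        t = torch.empty(*meta[0], dtype=dtype, device="cuda")
    dist.broadcast(t, src=src)
    return t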
@@ -403,39 +403,38 @@ def get_size_or_none(x: Optional[torch.Tensor]): receving_list = [None] broadcast_object_list(receving_list, src=0) py_data = receving_list[0] - input_tokens = torch.empty(*py_data["input_tokens_size"], - dtype=torch.long, - device="cuda") + input_tokens = torch.empty( + *py_data["input_tokens_size"], dtype=torch.long, device="cuda" + ) broadcast(input_tokens, src=0) - input_positions = torch.empty(*py_data["input_positions_size"], - dtype=torch.long, - device="cuda") + input_positions = torch.empty( + *py_data["input_positions_size"], dtype=torch.long, device="cuda" + ) broadcast(input_positions, src=0) if py_data["slot_mapping_size"] is not None: - slot_mapping = torch.empty(*py_data["slot_mapping_size"], - dtype=torch.long, - device="cuda") + slot_mapping = torch.empty( + *py_data["slot_mapping_size"], dtype=torch.long, device="cuda" + ) broadcast(slot_mapping, src=0) else: slot_mapping = None if py_data["context_lens_size"] is not None: - context_lens = torch.empty(*py_data["context_lens_size"], - dtype=torch.int, - device="cuda") + context_lens = torch.empty( + *py_data["context_lens_size"], dtype=torch.int, device="cuda" + ) broadcast(context_lens, src=0) else: context_lens = None if py_data["block_tables_size"] is not None: - block_tables = torch.empty(*py_data["block_tables_size"], - dtype=torch.int, - device="cuda") + block_tables = torch.empty( + *py_data["block_tables_size"], dtype=torch.int, device="cuda" + ) broadcast(block_tables, src=0) else: block_tables = None selected_token_indices = torch.empty( - *py_data["selected_token_indices_size"], - dtype=torch.long, - device="cuda") + *py_data["selected_token_indices_size"], dtype=torch.long, device="cuda" + ) broadcast(selected_token_indices, src=0) input_metadata = InputMetadata( is_prompt=py_data["is_prompt"], @@ -463,7 +462,8 @@ def execute_model( kv_caches: List[Tuple[torch.Tensor, torch.Tensor]], ) -> Optional[SamplerOutput]: input_tokens, input_positions, input_metadata, sampling_metadata = ( - self.prepare_input_tensors(seq_group_metadata_list)) + self.prepare_input_tensors(seq_group_metadata_list) + ) # print(sampling_metadata.seq_data) seq_groups = [] input_tokens_history = [] @@ -476,14 +476,16 @@ def execute_model( else: tokens_history = [list(token) for token in tokens_history] input_tokens_history.append(tokens_history) - input_tokens_history = torch.tensor(input_tokens_history).to(input_tokens.device) - # token_ids = rtn.outputs[0].token_ids - # for j, token_id in enumerate(token_ids): - # if len(token_id) == 1: - # token_ids[j] = token_id[0] - # else: - # token_ids[j] = list(token_id) - + input_tokens_history = torch.tensor(input_tokens_history).to( + input_tokens.device + ) + # token_ids = rtn.outputs[0].token_ids + # for j, token_id in enumerate(token_ids): + # if len(token_id) == 1: + # token_ids[j] = token_id[0] + # else: + # token_ids[j] = list(token_id) + # Execute the model. 
# print("it1",input_tokens) if len(input_tokens.shape) == 2: @@ -494,25 +496,29 @@ def execute_model( # print("it2",input_tokens.shape) text_mask = input_tokens != 0 text_mask = text_mask[:, :, 0] - + if input_metadata.use_cuda_graph: graph_batch_size = input_tokens.shape[0] model_executable = self.graph_runners[graph_batch_size] else: model_executable = self.model - + infer_text = sampling_metadata.seq_groups[0][1].infer_text temperture = sampling_metadata.seq_groups[0][1].temperature if not infer_text: temperture = torch.tensor(temperture).to(input_tokens.device) - logits_processors, logits_warpers = sampling_metadata.seq_groups[0][1].logits_processors + logits_processors, logits_warpers = sampling_metadata.seq_groups[0][ + 1 + ].logits_processors # print(logits_processors, logits_warpers) min_new_token = sampling_metadata.seq_groups[0][1].min_new_token eos_token = sampling_metadata.seq_groups[0][1].eos_token start_idx = sampling_metadata.seq_groups[0][1].start_idx if input_tokens.shape[-2] == 1: if infer_text: - input_emb: torch.Tensor = self.post_model.emb_text(input_tokens[:, :, 0]) + input_emb: torch.Tensor = self.post_model.emb_text( + input_tokens[:, :, 0] + ) else: code_emb = [ self.post_model.emb_code[i](input_tokens[:, :, i]) @@ -531,7 +537,11 @@ def execute_model( # print(hidden_states.shape) # print(input_tokens) idx_next, logprob, finish = self.sampler.sample( - inputs_ids=input_tokens if input_tokens_history.shape[-2] == 0 else input_tokens_history, + inputs_ids=( + input_tokens + if input_tokens_history.shape[-2] == 0 + else input_tokens_history + ), hidden_states=hidden_states, infer_text=infer_text, temperature=temperture, @@ -540,11 +550,11 @@ def execute_model( min_new_token=min_new_token, now_length=1, eos_token=eos_token, - start_idx=start_idx + start_idx=start_idx, ) # print(logprob.shape, idx_next.shape) if len(logprob.shape) == 2: - logprob = logprob[:,None,:] + logprob = logprob[:, None, :] logprob = torch.gather(logprob, -1, idx_next.transpose(-1, -2))[:, :, 0] # print("测试",idx_next.shape, logprob.shape) # Sample the next token. @@ -557,14 +567,16 @@ def execute_model( idx_next_i = idx_next[i, 0, :].cpu().tolist() logprob_i = logprob[i].cpu().tolist() result = SequenceGroupOutput( - samples = [SequenceOutput( - parent_seq_id=seq_groups[i], - logprobs={tuple(idx_next_i):logprob_i}, - output_token=tuple(idx_next_i), - hidden_states=hidden_states[i].cpu(), - finished=finish[i].item(), - ),], - prompt_logprobs = None + samples=[ + SequenceOutput( + parent_seq_id=seq_groups[i], + logprobs={tuple(idx_next_i): logprob_i}, + output_token=tuple(idx_next_i), + hidden_states=hidden_states[i].cpu(), + finished=finish[i].item(), + ), + ], + prompt_logprobs=None, ) results.append(result) # print(results) @@ -575,7 +587,9 @@ def execute_model( def profile_run(self) -> None: # Enable top-k sampling to reflect the accurate memory usage. vocab_size = self.model_config.get_vocab_size() - sampling_params = SamplingParams(top_p=0.99, top_k=vocab_size - 1, infer_text=True) + sampling_params = SamplingParams( + top_p=0.99, top_k=vocab_size - 1, infer_text=True + ) max_num_batched_tokens = self.scheduler_config.max_num_batched_tokens max_num_seqs = self.scheduler_config.max_num_seqs @@ -583,8 +597,9 @@ def profile_run(self) -> None: # number of tokens equal to max_num_batched_tokens. 
seqs: List[SequenceGroupMetadata] = [] for group_id in range(max_num_seqs): - seq_len = (max_num_batched_tokens // max_num_seqs + - (group_id < max_num_batched_tokens % max_num_seqs)) + seq_len = max_num_batched_tokens // max_num_seqs + ( + group_id < max_num_batched_tokens % max_num_seqs + ) seq_data = SequenceData([0] * seq_len) seq = SequenceGroupMetadata( request_id=str(group_id), @@ -605,20 +620,28 @@ def profile_run(self) -> None: @torch.inference_mode() def capture_model(self, kv_caches: List[KVCache]) -> None: assert not self.model_config.enforce_eager - logger.info("Capturing the model for CUDA graphs. This may lead to " - "unexpected consequences if the model is not static. To " - "run the model in eager mode, set 'enforce_eager=True' or " - "use '--enforce-eager' in the CLI.") - logger.info("CUDA graphs can take additional 1~3 GiB memory per GPU. " - "If you are running out of memory, consider decreasing " - "`gpu_memory_utilization` or enforcing eager mode.") + logger.info( + "Capturing the model for CUDA graphs. This may lead to " + "unexpected consequences if the model is not static. To " + "run the model in eager mode, set 'enforce_eager=True' or " + "use '--enforce-eager' in the CLI." + ) + logger.info( + "CUDA graphs can take additional 1~3 GiB memory per GPU. " + "If you are running out of memory, consider decreasing " + "`gpu_memory_utilization` or enforcing eager mode." + ) start_time = time.perf_counter() # Prepare dummy inputs. These will be reused for all batch sizes. max_batch_size = max(_BATCH_SIZES_TO_CAPTURE) - input_emb = torch.zeros(max_batch_size, 1, self.model_config.get_hidden_size(), dtype=next(self.model.parameters()).dtype).cuda() - input_positions = torch.zeros(max_batch_size, 1, - dtype=torch.long).cuda() + input_emb = torch.zeros( + max_batch_size, + 1, + self.model_config.get_hidden_size(), + dtype=next(self.model.parameters()).dtype, + ).cuda() + input_positions = torch.zeros(max_batch_size, 1, dtype=torch.long).cuda() slot_mapping = torch.empty(max_batch_size, 1, dtype=torch.long).cuda() slot_mapping.fill_(_PAD_SLOT_ID) context_lens = torch.ones(max_batch_size, dtype=torch.int32).cuda() @@ -718,12 +741,15 @@ def forward( # Copy the input tensors to the input buffers. self.input_buffers["input_emb"].copy_(input_emb, non_blocking=True) self.input_buffers["positions"].copy_(positions, non_blocking=True) - self.input_buffers["slot_mapping"].copy_(input_metadata.slot_mapping, - non_blocking=True) - self.input_buffers["context_lens"].copy_(input_metadata.context_lens, - non_blocking=True) - self.input_buffers["block_tables"].copy_(input_metadata.block_tables, - non_blocking=True) + self.input_buffers["slot_mapping"].copy_( + input_metadata.slot_mapping, non_blocking=True + ) + self.input_buffers["context_lens"].copy_( + input_metadata.context_lens, non_blocking=True + ) + self.input_buffers["block_tables"].copy_( + input_metadata.block_tables, non_blocking=True + ) # Run the graph. 
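# Replay re-executes the kernels recorded at capture time on whatever is in
# the static input buffers, which is why the copy_() calls above must run
# first. The generic PyTorch pattern (independent of this patch) looks like:
#
#     g = torch.cuda.CUDAGraph()
#     with torch.cuda.graph(g):
#         static_out = model(static_in)     # capture once
#     static_in.copy_(new_in)               # refresh inputs in place
#     g.replay()                            # cheap re-execution
#     result = static_out.clone()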
self.graph.replay() @@ -749,10 +775,12 @@ def _make_tensor_with_pad( pin_memory: bool = False, ) -> torch.Tensor: padded_x = [_pad_to_max(x_i, max_len, pad) for x_i in x] - return torch.tensor(padded_x, - dtype=dtype, - device=device, - pin_memory=pin_memory and str(device) == "cpu") + return torch.tensor( + padded_x, + dtype=dtype, + device=device, + pin_memory=pin_memory and str(device) == "cpu", + ) def _get_graph_batch_size(batch_size: int) -> int: diff --git a/ChatTTS/model/velocity/output.py b/ChatTTS/model/velocity/output.py index ea3c81d80..3413a3e2b 100644 --- a/ChatTTS/model/velocity/output.py +++ b/ChatTTS/model/velocity/output.py @@ -1,8 +1,12 @@ from typing import List, Optional import torch -from ChatTTS.model.velocity.sequence import (PromptLogprobs, SampleLogprobs, SequenceGroup, - SequenceStatus) +from ChatTTS.model.velocity.sequence import ( + PromptLogprobs, + SampleLogprobs, + SequenceGroup, + SequenceStatus, +) class CompletionOutput: @@ -41,13 +45,15 @@ def finished(self) -> bool: return self.finish_reason is not None def __repr__(self) -> str: - return (f"CompletionOutput(index={self.index}, " - f"text={self.text!r}, " - f"token_ids={self.token_ids}, " - f"cumulative_logprob={self.cumulative_logprob}, " - f"logprobs={self.logprobs}, " - f"finish_reason={self.finish_reason}, " - f"hidden_states={self.hidden_states.shape if self.hidden_states is not None else None})") + return ( + f"CompletionOutput(index={self.index}, " + f"text={self.text!r}, " + f"token_ids={self.token_ids}, " + f"cumulative_logprob={self.cumulative_logprob}, " + f"logprobs={self.logprobs}, " + f"finish_reason={self.finish_reason}, " + f"hidden_states={self.hidden_states.shape if self.hidden_states is not None else None})" + ) class RequestOutput: @@ -85,7 +91,8 @@ def from_seq_group(cls, seq_group: SequenceGroup) -> "RequestOutput": seqs = seq_group.get_seqs() if seq_group.sampling_params.use_beam_search: sorting_key = lambda seq: seq.get_beam_search_score( - seq_group.sampling_params.length_penalty) + seq_group.sampling_params.length_penalty + ) else: sorting_key = lambda seq: seq.get_cumulative_logprob() sorted_seqs = sorted(seqs, key=sorting_key, reverse=True) @@ -101,12 +108,15 @@ def from_seq_group(cls, seq_group: SequenceGroup) -> "RequestOutput": # logprobs are not requested. logprobs = None finshed_reason = SequenceStatus.get_finished_reason(seq.status) - output = CompletionOutput(seqs.index(seq), seq.output_text, - seq.get_output_token_ids(), - seq.get_cumulative_logprob(), logprobs, - finshed_reason, - seq.data.hidden_states - ) + output = CompletionOutput( + seqs.index(seq), + seq.output_text, + seq.get_output_token_ids(), + seq.get_cumulative_logprob(), + logprobs, + finshed_reason, + seq.data.hidden_states, + ) outputs.append(output) # Every sequence in the sequence group should have the same prompt. 
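The hidden_states field threaded through CompletionOutput above is what carries the GPT's per-step activations back out of the velocity engine for audio decoding. Below is a minimal sketch of how a caller might read it from a finished request; the helper name best_hidden_states and the result variable are illustrative, not part of this patch:

import torch

from ChatTTS.model.velocity.output import RequestOutput


def best_hidden_states(result: RequestOutput) -> torch.Tensor:
    # from_seq_group sorts completions by cumulative logprob (or beam score),
    # so outputs[0] is the top-ranked sequence; its hidden_states tensor holds
    # the per-step hidden states the model runner moved to CPU.
    best = result.outputs[0]
    if best.hidden_states is None:
        raise ValueError("request finished without hidden states")
    return best.hidden_states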
@@ -114,14 +124,21 @@ def from_seq_group(cls, seq_group: SequenceGroup) -> "RequestOutput": prompt_token_ids = seq_group.prompt_token_ids prompt_logprobs = seq_group.prompt_logprobs finished = seq_group.is_finished() - return cls(seq_group.request_id, prompt, prompt_token_ids, - prompt_logprobs, outputs, finished) + return cls( + seq_group.request_id, + prompt, + prompt_token_ids, + prompt_logprobs, + outputs, + finished, + ) def __repr__(self) -> str: - return (f"RequestOutput(request_id={self.request_id}, " - f"prompt={self.prompt!r}, " - f"prompt_token_ids={self.prompt_token_ids}, " - f"prompt_logprobs={self.prompt_logprobs}, " - f"outputs={self.outputs}, " - f"finished={self.finished})" - ) + return ( + f"RequestOutput(request_id={self.request_id}, " + f"prompt={self.prompt!r}, " + f"prompt_token_ids={self.prompt_token_ids}, " + f"prompt_logprobs={self.prompt_logprobs}, " + f"outputs={self.outputs}, " + f"finished={self.finished})" + ) diff --git a/ChatTTS/model/velocity/post_model.py b/ChatTTS/model/velocity/post_model.py index c38853b3a..89bc79dcc 100644 --- a/ChatTTS/model/velocity/post_model.py +++ b/ChatTTS/model/velocity/post_model.py @@ -12,13 +12,11 @@ from torch.functional import F from torch.nn.utils.parametrizations import weight_norm from typing import List, Callable + + class Post_model(nn.Module): def __init__( - self, - hidden_size: int, - num_audio_tokens: int, - num_text_tokens: int, - num_vq=4 + self, hidden_size: int, num_audio_tokens: int, num_text_tokens: int, num_vq=4 ): super().__init__() @@ -27,41 +25,24 @@ def __init__( self.model_dim = hidden_size self.emb_code = nn.ModuleList( - [ - nn.Embedding( - num_audio_tokens, - self.model_dim - ) - for _ in range(num_vq) - ], - ) - self.emb_text = nn.Embedding( - num_text_tokens, self.model_dim + [nn.Embedding(num_audio_tokens, self.model_dim) for _ in range(num_vq)], ) + self.emb_text = nn.Embedding(num_text_tokens, self.model_dim) self.head_text = weight_norm( - nn.Linear( - self.model_dim, - num_text_tokens, - bias=False - ), + nn.Linear(self.model_dim, num_text_tokens, bias=False), name="weight", ) self.head_code = nn.ModuleList( [ weight_norm( - nn.Linear( - self.model_dim, - num_audio_tokens, - bias=False - ), + nn.Linear(self.model_dim, num_audio_tokens, bias=False), name="weight", ) for _ in range(self.num_vq) ], ) - def forward(self, input_ids: torch.Tensor, text_mask: torch.Tensor) -> torch.Tensor: """ get_emb @@ -90,112 +71,118 @@ def forward(self, input_ids: torch.Tensor, text_mask: torch.Tensor) -> torch.Ten del emb_text, emb_code, text_mask_inv return emb - + + class Sampler: - def __init__(self, - post_model: Post_model, - num_audio_tokens: int, - num_vq: int - ): + def __init__(self, post_model: Post_model, num_audio_tokens: int, num_vq: int): self.post_model = post_model self.device = next(self.post_model.parameters()).device self.num_audio_tokens = num_audio_tokens self.num_vq = num_vq - - def sample(self, - inputs_ids: torch.Tensor, - hidden_states: torch.Tensor, - infer_text: bool = False, - temperature: torch.Tensor = 1.0, - logits_processors: List[Callable] = [lambda logits_token, logits: logits,], - logits_warpers: List[Callable] = [lambda logits_token, logits: logits,], - min_new_token: int = 0, - now_length: int = 0, - eos_token: int = 0, - start_idx: int = 0, - ): - # print(inputs_ids.shape) - B = hidden_states.shape[0] - - end_idx = torch.zeros( - inputs_ids.shape[0], device=inputs_ids.device, dtype=torch.long + + def sample( + self, + inputs_ids: torch.Tensor, + hidden_states: torch.Tensor, 
+ infer_text: bool = False, + temperature: torch.Tensor = 1.0, + logits_processors: List[Callable] = [ + lambda logits_token, logits: logits, + ], + logits_warpers: List[Callable] = [ + lambda logits_token, logits: logits, + ], + min_new_token: int = 0, + now_length: int = 0, + eos_token: int = 0, + start_idx: int = 0, + ): + # print(inputs_ids.shape) + B = hidden_states.shape[0] + + end_idx = torch.zeros( + inputs_ids.shape[0], device=inputs_ids.device, dtype=torch.long + ) + finish = torch.zeros(inputs_ids.shape[0], device=inputs_ids.device).bool() + if not infer_text: + temperature = ( + temperature.unsqueeze(0) + .expand(inputs_ids.shape[0], -1) + .contiguous() + .view(-1, 1) ) - finish = torch.zeros(inputs_ids.shape[0], device=inputs_ids.device).bool() - if not infer_text: - temperature = ( - temperature.unsqueeze(0) - .expand(inputs_ids.shape[0], -1) - .contiguous() - .view(-1, 1) - ) - - if infer_text: - logits: torch.Tensor = self.post_model.head_text(hidden_states) - else: - # logits = torch.stack([self.head_code[i](hidden_states) for i in range(self.num_vq)], 3) - logits = torch.empty( - hidden_states.size(0), - hidden_states.size(1), - self.num_audio_tokens, - self.num_vq, - dtype=torch.float, - device=self.device, - ) - for num_vq_iter in range(self.num_vq): - x: torch.Tensor = self.post_model.head_code[num_vq_iter](hidden_states) - logits[..., num_vq_iter] = x - del x - - del hidden_states - - # logits = logits[:, -1].float() - logits = logits.narrow(1, -1, 1).squeeze_(1).float() - - if not infer_text: - # logits = rearrange(logits, "b c n -> (b n) c") - logits = logits.permute(0, 2, 1) - logits = logits.reshape(-1, logits.size(2)) - # logits_token = rearrange(inputs_ids[:, start_idx:], "b c n -> (b n) c") - inputs_ids_sliced = inputs_ids[:, start_idx:].permute(0, 2, 1) - logits_token = inputs_ids_sliced.reshape( - inputs_ids_sliced.size(0) * inputs_ids_sliced.size(1), - -1, - ).to(self.device) - else: - logits_token = inputs_ids[:, start_idx:, 0].to(self.device) - - logits /= temperature - - for logitsProcessors in logits_processors: - logits = logitsProcessors(logits_token, logits) - - for logitsWarpers in logits_warpers: - logits = logitsWarpers(logits_token, logits) - - del logits_token - - if now_length < min_new_token: - logits[:, eos_token] = -torch.inf - - scores = F.softmax(logits, dim=-1) - idx_next = torch.multinomial(scores, num_samples=1).to(finish.device) - if not infer_text: - scores = scores.reshape(B, -1, scores.shape[-1]) - if not infer_text: - # idx_next = rearrange(idx_next, "(b n) 1 -> b n", n=self.num_vq) - idx_next = idx_next.view(-1, self.num_vq) - finish_or = idx_next.eq(eos_token).any(1) - finish.logical_or_(finish_or) - del finish_or - else: - finish_or = idx_next.eq(eos_token).any(1) - finish.logical_or_(finish_or) - del finish_or - - del inputs_ids - - not_finished = finish.logical_not().to(end_idx.device) - - end_idx.add_(not_finished.int()) - idx_next = idx_next[:, None, :] - return idx_next, torch.log(scores), finish, \ No newline at end of file + + if infer_text: + logits: torch.Tensor = self.post_model.head_text(hidden_states) + else: + # logits = torch.stack([self.head_code[i](hidden_states) for i in range(self.num_vq)], 3) + logits = torch.empty( + hidden_states.size(0), + hidden_states.size(1), + self.num_audio_tokens, + self.num_vq, + dtype=torch.float, + device=self.device, + ) + for num_vq_iter in range(self.num_vq): + x: torch.Tensor = self.post_model.head_code[num_vq_iter](hidden_states) + logits[..., num_vq_iter] = x + del x + + del 
hidden_states + + # logits = logits[:, -1].float() + logits = logits.narrow(1, -1, 1).squeeze_(1).float() + + if not infer_text: + # logits = rearrange(logits, "b c n -> (b n) c") + logits = logits.permute(0, 2, 1) + logits = logits.reshape(-1, logits.size(2)) + # logits_token = rearrange(inputs_ids[:, start_idx:], "b c n -> (b n) c") + inputs_ids_sliced = inputs_ids[:, start_idx:].permute(0, 2, 1) + logits_token = inputs_ids_sliced.reshape( + inputs_ids_sliced.size(0) * inputs_ids_sliced.size(1), + -1, + ).to(self.device) + else: + logits_token = inputs_ids[:, start_idx:, 0].to(self.device) + + logits /= temperature + + for logitsProcessors in logits_processors: + logits = logitsProcessors(logits_token, logits) + + for logitsWarpers in logits_warpers: + logits = logitsWarpers(logits_token, logits) + + del logits_token + + if now_length < min_new_token: + logits[:, eos_token] = -torch.inf + + scores = F.softmax(logits, dim=-1) + idx_next = torch.multinomial(scores, num_samples=1).to(finish.device) + if not infer_text: + scores = scores.reshape(B, -1, scores.shape[-1]) + if not infer_text: + # idx_next = rearrange(idx_next, "(b n) 1 -> b n", n=self.num_vq) + idx_next = idx_next.view(-1, self.num_vq) + finish_or = idx_next.eq(eos_token).any(1) + finish.logical_or_(finish_or) + del finish_or + else: + finish_or = idx_next.eq(eos_token).any(1) + finish.logical_or_(finish_or) + del finish_or + + del inputs_ids + + not_finished = finish.logical_not().to(end_idx.device) + + end_idx.add_(not_finished.int()) + idx_next = idx_next[:, None, :] + return ( + idx_next, + torch.log(scores), + finish, + ) diff --git a/ChatTTS/model/velocity/sampling_params.py b/ChatTTS/model/velocity/sampling_params.py index be3f9bf7f..e650fc546 100644 --- a/ChatTTS/model/velocity/sampling_params.py +++ b/ChatTTS/model/velocity/sampling_params.py @@ -1,4 +1,5 @@ """Sampling parameters for text generation.""" + from enum import IntEnum from functools import cached_property from typing import Callable, List, Optional, Union @@ -113,13 +114,20 @@ def __init__( prompt_logprobs: Optional[int] = None, skip_special_tokens: bool = True, spaces_between_special_tokens: bool = True, - logits_processors: Optional[List[LogitsProcessor]] = ([lambda logits_token, logits: logits,],[lambda logits_token, logits: logits,]), + logits_processors: Optional[List[LogitsProcessor]] = ( + [ + lambda logits_token, logits: logits, + ], + [ + lambda logits_token, logits: logits, + ], + ), min_new_token: int = 0, max_new_token: int = 8192, infer_text: bool = False, eos_token: int = 0, - spk_emb:str = None, - start_idx:int = 0, + spk_emb: str = None, + start_idx: int = 0, ) -> None: self.n = n self.best_of = best_of if best_of is not None else n @@ -173,42 +181,50 @@ def _verify_args(self) -> None: if self.n < 1: raise ValueError(f"n must be at least 1, got {self.n}.") if self.best_of < self.n: - raise ValueError(f"best_of must be greater than or equal to n, " - f"got n={self.n} and best_of={self.best_of}.") + raise ValueError( + f"best_of must be greater than or equal to n, " + f"got n={self.n} and best_of={self.best_of}." + ) if not -2.0 <= self.presence_penalty <= 2.0: - raise ValueError("presence_penalty must be in [-2, 2], got " - f"{self.presence_penalty}.") + raise ValueError( + "presence_penalty must be in [-2, 2], got " f"{self.presence_penalty}." 
+ ) if not -2.0 <= self.frequency_penalty <= 2.0: - raise ValueError("frequency_penalty must be in [-2, 2], got " - f"{self.frequency_penalty}.") + raise ValueError( + "frequency_penalty must be in [-2, 2], got " + f"{self.frequency_penalty}." + ) if not 0.0 < self.repetition_penalty <= 2.0: - raise ValueError("repetition_penalty must be in (0, 2], got " - f"{self.repetition_penalty}.") + raise ValueError( + "repetition_penalty must be in (0, 2], got " + f"{self.repetition_penalty}." + ) # if self.temperature < 0.0: # raise ValueError( # f"temperature must be non-negative, got {self.temperature}.") if not 0.0 < self.top_p <= 1.0: raise ValueError(f"top_p must be in (0, 1], got {self.top_p}.") if self.top_k < -1 or self.top_k == 0: - raise ValueError(f"top_k must be -1 (disable), or at least 1, " - f"got {self.top_k}.") + raise ValueError( + f"top_k must be -1 (disable), or at least 1, " f"got {self.top_k}." + ) if not 0.0 <= self.min_p <= 1.0: - raise ValueError("min_p must be in [0, 1], got " - f"{self.min_p}.") + raise ValueError("min_p must be in [0, 1], got " f"{self.min_p}.") if self.max_tokens < 1: - raise ValueError( - f"max_tokens must be at least 1, got {self.max_tokens}.") + raise ValueError(f"max_tokens must be at least 1, got {self.max_tokens}.") if self.logprobs is not None and self.logprobs < 0: - raise ValueError( - f"logprobs must be non-negative, got {self.logprobs}.") + raise ValueError(f"logprobs must be non-negative, got {self.logprobs}.") if self.prompt_logprobs is not None and self.prompt_logprobs < 0: - raise ValueError(f"prompt_logprobs must be non-negative, got " - f"{self.prompt_logprobs}.") + raise ValueError( + f"prompt_logprobs must be non-negative, got " f"{self.prompt_logprobs}." + ) def _verify_beam_search(self) -> None: if self.best_of == 1: - raise ValueError("best_of must be greater than 1 when using beam " - f"search. Got {self.best_of}.") + raise ValueError( + "best_of must be greater than 1 when using beam " + f"search. Got {self.best_of}." + ) if self.temperature > _SAMPLING_EPS: raise ValueError("temperature must be 0 when using beam search.") if self.top_p < 1.0 - _SAMPLING_EPS: @@ -218,22 +234,29 @@ def _verify_beam_search(self) -> None: if self.early_stopping not in [True, False, "never"]: raise ValueError( f"early_stopping must be True, False, or 'never', " - f"got {self.early_stopping}.") + f"got {self.early_stopping}." + ) def _verify_non_beam_search(self) -> None: if self.early_stopping is not False: - raise ValueError("early_stopping is not effective and must be " - "False when not using beam search.") - if (self.length_penalty < 1.0 - _SAMPLING_EPS - or self.length_penalty > 1.0 + _SAMPLING_EPS): + raise ValueError( + "early_stopping is not effective and must be " + "False when not using beam search." + ) + if ( + self.length_penalty < 1.0 - _SAMPLING_EPS + or self.length_penalty > 1.0 + _SAMPLING_EPS + ): raise ValueError( "length_penalty is not effective and must be the " - "default value of 1.0 when not using beam search.") + "default value of 1.0 when not using beam search." + ) def _verify_greedy_sampling(self) -> None: if self.best_of > 1: - raise ValueError("best_of must be 1 when using greedy sampling." - f"Got {self.best_of}.") + raise ValueError( + "best_of must be 1 when using greedy sampling." f"Got {self.best_of}." 
+ ) @cached_property def sampling_type(self) -> SamplingType: @@ -270,4 +293,4 @@ def __repr__(self) -> str: f"max_new_token={self.max_new_token}), " f"min_new_token={self.min_new_token}), " f"infer_text={self.infer_text})" - ) + ) diff --git a/ChatTTS/model/velocity/scheduler.py b/ChatTTS/model/velocity/scheduler.py index e93ca7fd6..4eb38d278 100644 --- a/ChatTTS/model/velocity/scheduler.py +++ b/ChatTTS/model/velocity/scheduler.py @@ -6,8 +6,13 @@ from ChatTTS.model.velocity.block_manager import AllocStatus, BlockSpaceManager from vllm.core.policy import PolicyFactory from vllm.logger import init_logger -from ChatTTS.model.velocity.sequence import (Sequence, SequenceData, SequenceGroup, - SequenceGroupMetadata, SequenceStatus) +from ChatTTS.model.velocity.sequence import ( + Sequence, + SequenceData, + SequenceGroup, + SequenceGroupMetadata, + SequenceStatus, +) logger = init_logger(__name__) @@ -21,6 +26,7 @@ class PreemptionMode(enum.Enum): recompute them when the sequences are resumed, treating the sequences as new prompts. """ + SWAP = enum.auto() RECOMPUTE = enum.auto() @@ -49,8 +55,12 @@ def __init__( def is_empty(self) -> bool: # NOTE: We do not consider the ignored sequence groups. - return (not self.scheduled_seq_groups and not self.blocks_to_swap_in - and not self.blocks_to_swap_out and not self.blocks_to_copy) + return ( + not self.scheduled_seq_groups + and not self.blocks_to_swap_in + and not self.blocks_to_swap_out + and not self.blocks_to_copy + ) class Scheduler: @@ -63,8 +73,10 @@ def __init__( self.scheduler_config = scheduler_config self.cache_config = cache_config - self.prompt_limit = min(self.scheduler_config.max_model_len, - self.scheduler_config.max_num_batched_tokens) + self.prompt_limit = min( + self.scheduler_config.max_model_len, + self.scheduler_config.max_num_batched_tokens, + ) # Instantiate the scheduling policy. self.policy = PolicyFactory.get_policy(policy_name="fcfs") @@ -73,7 +85,8 @@ def __init__( block_size=self.cache_config.block_size, num_gpu_blocks=self.cache_config.num_gpu_blocks, num_cpu_blocks=self.cache_config.num_cpu_blocks, - sliding_window=self.cache_config.sliding_window) + sliding_window=self.cache_config.sliding_window, + ) # TODO(zhuohan): Use deque instead of list for better performance. # Sequence groups in the WAITING state. @@ -89,7 +102,7 @@ def add_seq_group(self, seq_group: SequenceGroup) -> None: def abort_seq_group(self, request_id: Union[str, Iterable[str]]) -> None: if isinstance(request_id, str): - request_id = (request_id, ) + request_id = (request_id,) request_ids = set(request_id) for state_queue in [self.waiting, self.running, self.swapped]: # We need to reverse the list as we are removing elements @@ -129,8 +142,9 @@ def _schedule(self) -> SchedulerOutputs: scheduled: List[SequenceGroup] = [] # The total number of sequences on the fly, including the # requests in the generation phase. 
- num_curr_seqs = sum(seq_group.get_max_num_running_seqs() - for seq_group in self.running) + num_curr_seqs = sum( + seq_group.get_max_num_running_seqs() for seq_group in self.running + ) seq_lens: List[int] = [] # Optimization: We do not sort the waiting queue since the preempted @@ -139,16 +153,16 @@ def _schedule(self) -> SchedulerOutputs: while self.waiting: seq_group = self.waiting[0] - waiting_seqs = seq_group.get_seqs( - status=SequenceStatus.WAITING) + waiting_seqs = seq_group.get_seqs(status=SequenceStatus.WAITING) assert len(waiting_seqs) == 1, ( - "Waiting sequence group should have only one prompt " - "sequence.") + "Waiting sequence group should have only one prompt " "sequence." + ) num_prompt_tokens = waiting_seqs[0].get_len() if num_prompt_tokens > self.prompt_limit: logger.warning( f"Input prompt ({num_prompt_tokens} tokens) is too long" - f" and exceeds limit of {self.prompt_limit}") + f" and exceeds limit of {self.prompt_limit}" + ) for seq in waiting_seqs: seq.status = SequenceStatus.FINISHED_IGNORED ignored_seq_groups.append(seq_group) @@ -162,7 +176,8 @@ def _schedule(self) -> SchedulerOutputs: elif can_allocate == AllocStatus.NEVER: logger.warning( f"Input prompt ({num_prompt_tokens} tokens) is too long" - f" and exceeds the capacity of block_manager") + f" and exceeds the capacity of block_manager" + ) for seq in waiting_seqs: seq.status = SequenceStatus.FINISHED_IGNORED ignored_seq_groups.append(seq_group) @@ -172,15 +187,13 @@ def _schedule(self) -> SchedulerOutputs: # If the number of batched tokens exceeds the limit, stop. new_seq_lens = seq_lens + [num_prompt_tokens] num_batched_tokens = len(new_seq_lens) * max(new_seq_lens) - if (num_batched_tokens > - self.scheduler_config.max_num_batched_tokens): + if num_batched_tokens > self.scheduler_config.max_num_batched_tokens: break # The total number of sequences in the RUNNING state should not # exceed the maximum number of sequences. num_new_seqs = seq_group.get_max_num_running_seqs() - if (num_curr_seqs + num_new_seqs > - self.scheduler_config.max_num_seqs): + if num_curr_seqs + num_new_seqs > self.scheduler_config.max_num_seqs: break num_paddings = num_batched_tokens - sum(new_seq_lens) @@ -198,8 +211,7 @@ def _schedule(self) -> SchedulerOutputs: scheduler_outputs = SchedulerOutputs( scheduled_seq_groups=scheduled, prompt_run=True, - num_batched_tokens=len(seq_lens) * - max(seq_lens) if seq_lens else 0, + num_batched_tokens=len(seq_lens) * max(seq_lens) if seq_lens else 0, blocks_to_swap_in=blocks_to_swap_in, blocks_to_swap_out=blocks_to_swap_out, blocks_to_copy=blocks_to_copy, @@ -239,8 +251,9 @@ def _schedule(self) -> SchedulerOutputs: # Swap in the sequence groups in the SWAPPED state if possible. self.swapped = self.policy.sort_by_priority(now, self.swapped) if not preempted: - num_curr_seqs = sum(seq_group.get_max_num_running_seqs() - for seq_group in self.running) + num_curr_seqs = sum( + seq_group.get_max_num_running_seqs() for seq_group in self.running + ) while self.swapped: seq_group = self.swapped[0] @@ -251,8 +264,7 @@ def _schedule(self) -> SchedulerOutputs: # The total number of sequences in the RUNNING state should not # exceed the maximum number of sequences. num_new_seqs = seq_group.get_max_num_running_seqs() - if (num_curr_seqs + num_new_seqs > - self.scheduler_config.max_num_seqs): + if num_curr_seqs + num_new_seqs > self.scheduler_config.max_num_seqs: break seq_group = self.swapped.pop(0) @@ -266,7 +278,8 @@ def _schedule(self) -> SchedulerOutputs: # sequences in the RUNNING state. 
num_batched_tokens = sum( seq_group.num_seqs(status=SequenceStatus.RUNNING) - for seq_group in self.running) + for seq_group in self.running + ) scheduler_outputs = SchedulerOutputs( scheduled_seq_groups=self.running, @@ -313,8 +326,7 @@ def free_seq(self, seq: Sequence) -> None: def free_finished_seq_groups(self) -> None: self.running = [ - seq_group for seq_group in self.running - if not seq_group.is_finished() + seq_group for seq_group in self.running if not seq_group.is_finished() ] def _allocate(self, seq_group: SequenceGroup) -> None: @@ -406,7 +418,8 @@ def _swap_out( # entire engine. raise RuntimeError( "Aborted due to the lack of CPU swap space. Please increase " - "the swap space to avoid this error.") + "the swap space to avoid this error." + ) mapping = self.block_manager.swap_out(seq_group) blocks_to_swap_out.update(mapping) for seq in seq_group.get_seqs(status=SequenceStatus.RUNNING): diff --git a/ChatTTS/model/velocity/sequence.py b/ChatTTS/model/velocity/sequence.py index f5c2a09ad..4bc3f354d 100644 --- a/ChatTTS/model/velocity/sequence.py +++ b/ChatTTS/model/velocity/sequence.py @@ -1,4 +1,5 @@ """Sequence and its related classes.""" + import copy import enum from typing import Dict, List, Optional, Union @@ -12,6 +13,7 @@ class SequenceStatus(enum.Enum): """Status of a sequence.""" + WAITING = enum.auto() RUNNING = enum.auto() SWAPPED = enum.auto() @@ -69,10 +71,12 @@ def __init__( self.cumulative_logprob = 0.0 self.hidden_states: Optional[torch.Tensor] = None self.finished = False - + def append_token_id(self, token_id: int, logprob: float) -> None: if isinstance(self.cumulative_logprob, float): - self.cumulative_logprob = [0.0, ] * len(logprob) + self.cumulative_logprob = [ + 0.0, + ] * len(logprob) self.output_token_ids.append(token_id) for i in range(len(self.cumulative_logprob)): self.cumulative_logprob[i] += logprob[i] @@ -82,7 +86,7 @@ def append_hidden_states(self, hidden_states: torch.Tensor) -> None: self.hidden_states = hidden_states else: self.hidden_states = torch.cat([self.hidden_states, hidden_states], dim=0) - + def get_len(self) -> int: return len(self.output_token_ids) + len(self.prompt_token_ids) @@ -101,12 +105,14 @@ def get_last_token_id(self) -> int: return self.output_token_ids[-1] def __repr__(self) -> str: - return (f"SequenceData(" - f"prompt_token_ids={self.prompt_token_ids}, " - f"output_token_ids={self.output_token_ids}, " - f"cumulative_logprob={self.cumulative_logprob}), " - f"hidden_states={self.hidden_states.shape if self.hidden_states is not None else None}, " - f"finished={self.finished})") + return ( + f"SequenceData(" + f"prompt_token_ids={self.prompt_token_ids}, " + f"output_token_ids={self.output_token_ids}, " + f"cumulative_logprob={self.cumulative_logprob}), " + f"hidden_states={self.hidden_states.shape if self.hidden_states is not None else None}, " + f"finished={self.finished})" + ) class Sequence: @@ -165,8 +171,7 @@ def _append_tokens_to_blocks(self, token_ids: List[int]) -> None: last_block = self.logical_token_blocks[-1] num_empty_slots = last_block.get_num_empty_slots() - last_block.append_tokens(token_ids[cursor:cursor + - num_empty_slots]) + last_block.append_tokens(token_ids[cursor : cursor + num_empty_slots]) cursor += num_empty_slots def append_token_id( @@ -174,7 +179,7 @@ def append_token_id( token_id: int, logprobs: Dict[int, float], hidden_states: Optional[torch.Tensor] = None, - finished: bool = False + finished: bool = False, ) -> None: assert token_id in logprobs self._append_tokens_to_blocks([token_id]) @@ 
-182,7 +187,7 @@ def append_token_id( self.data.append_token_id(token_id, logprobs[token_id]) self.data.append_hidden_states(hidden_states) self.data.finished = finished - + def get_len(self) -> int: return self.data.get_len() @@ -204,10 +209,12 @@ def get_output_token_ids(self) -> List[int]: def get_cumulative_logprob(self) -> float: return self.data.cumulative_logprob - def get_beam_search_score(self, - length_penalty: float = 0.0, - seq_len: Optional[int] = None, - eos_token_id: Optional[int] = None) -> float: + def get_beam_search_score( + self, + length_penalty: float = 0.0, + seq_len: Optional[int] = None, + eos_token_id: Optional[int] = None, + ) -> float: """Calculate the beam search score with length penalty. Adapted from @@ -218,8 +225,7 @@ def get_beam_search_score(self, seq_len = self.get_len() # NOTE: HF implementation does not count the EOS token # towards the length, we align with that here for testing. - if (eos_token_id is not None - and self.get_last_token_id() == eos_token_id): + if eos_token_id is not None and self.get_last_token_id() == eos_token_id: seq_len -= 1 return self.get_cumulative_logprob() / (seq_len**length_penalty) @@ -232,9 +238,11 @@ def fork(self, new_seq_id: int) -> "Sequence": return new_seq def __repr__(self) -> str: - return (f"Sequence(seq_id={self.seq_id}, " - f"status={self.status.name}, " - f"num_blocks={len(self.logical_token_blocks)})") + return ( + f"Sequence(seq_id={self.seq_id}, " + f"status={self.status.name}, " + f"num_blocks={len(self.logical_token_blocks)})" + ) class SequenceGroup: @@ -296,14 +304,10 @@ def get_seqs( if status is None: return list(self.seqs_dict.values()) else: - return [ - seq for seq in self.seqs_dict.values() if seq.status == status - ] + return [seq for seq in self.seqs_dict.values() if seq.status == status] def get_unfinished_seqs(self) -> List[Sequence]: - return [ - seq for seq in self.seqs_dict.values() if not seq.is_finished() - ] + return [seq for seq in self.seqs_dict.values() if not seq.is_finished()] def get_finished_seqs(self) -> List[Sequence]: return [seq for seq in self.seqs_dict.values() if seq.is_finished()] @@ -336,9 +340,11 @@ def is_finished(self) -> bool: return all(seq.is_finished() for seq in self.get_seqs()) def __repr__(self) -> str: - return (f"SequenceGroup(request_id={self.request_id}, " - f"sampling_params={self.sampling_params}, " - f"num_seqs={len(self.seqs_dict)})") + return ( + f"SequenceGroup(request_id={self.request_id}, " + f"sampling_params={self.sampling_params}, " + f"num_seqs={len(self.seqs_dict)})" + ) class SequenceGroupMetadata: @@ -386,27 +392,31 @@ def __init__( output_token: int, logprobs: Dict[int, float], hidden_states: Optional[torch.Tensor] = None, - finished: bool = False + finished: bool = False, ) -> None: self.parent_seq_id = parent_seq_id self.output_token = output_token self.logprobs = logprobs self.finished = finished self.hidden_states = hidden_states + def __repr__(self) -> str: - return (f"SequenceOutput(parent_seq_id={self.parent_seq_id}, " - f"output_token={self.output_token}, " - f"logprobs={self.logprobs})," - f"finished={self.finished})," - f"hidden_states={self.hidden_states.shape if self.hidden_states is not None else None}" - ) + return ( + f"SequenceOutput(parent_seq_id={self.parent_seq_id}, " + f"output_token={self.output_token}, " + f"logprobs={self.logprobs})," + f"finished={self.finished})," + f"hidden_states={self.hidden_states.shape if self.hidden_states is not None else None}" + ) def __eq__(self, other: object) -> bool: if not 
isinstance(other, SequenceOutput): raise NotImplementedError() - return (self.parent_seq_id == other.parent_seq_id - and self.output_token == other.output_token - and self.logprobs == other.logprobs) + return ( + self.parent_seq_id == other.parent_seq_id + and self.output_token == other.output_token + and self.logprobs == other.logprobs + ) class SequenceGroupOutput: @@ -421,14 +431,18 @@ def __init__( self.prompt_logprobs = prompt_logprobs def __repr__(self) -> str: - return (f"SequenceGroupOutput(samples={self.samples}, " - f"prompt_logprobs={self.prompt_logprobs})") + return ( + f"SequenceGroupOutput(samples={self.samples}, " + f"prompt_logprobs={self.prompt_logprobs})" + ) def __eq__(self, other: object) -> bool: if not isinstance(other, SequenceGroupOutput): raise NotImplementedError() - return (self.samples == other.samples - and self.prompt_logprobs == other.prompt_logprobs) + return ( + self.samples == other.samples + and self.prompt_logprobs == other.prompt_logprobs + ) # For each sequence group, we generate a list of SequenceOutput object, diff --git a/ChatTTS/model/velocity/worker.py b/ChatTTS/model/velocity/worker.py index 0162302bf..9578551d9 100644 --- a/ChatTTS/model/velocity/worker.py +++ b/ChatTTS/model/velocity/worker.py @@ -1,17 +1,15 @@ """A GPU worker class.""" + import os from typing import Dict, List, Optional, Tuple import torch import torch.distributed -from vllm.config import (CacheConfig, ModelConfig, ParallelConfig, - SchedulerConfig) +from vllm.config import CacheConfig, ModelConfig, ParallelConfig, SchedulerConfig from vllm.model_executor import set_random_seed -from vllm.model_executor.parallel_utils.communication_op import ( - broadcast_object_list) -from vllm.model_executor.parallel_utils.parallel_state import ( - initialize_model_parallel) +from vllm.model_executor.parallel_utils.communication_op import broadcast_object_list +from vllm.model_executor.parallel_utils.parallel_state import initialize_model_parallel from vllm.sequence import SamplerOutput, SequenceGroupMetadata from vllm.worker.cache_engine import CacheEngine from ChatTTS.model.velocity.model_runner import ModelRunner @@ -33,7 +31,7 @@ def __init__( local_rank: int, rank: int, distributed_init_method: str, - post_model_path:str, + post_model_path: str, is_driver_worker: bool = False, ) -> None: self.model_config = model_config @@ -44,12 +42,17 @@ def __init__( self.distributed_init_method = distributed_init_method self.is_driver_worker = is_driver_worker self.post_model_path = post_model_path - + if self.is_driver_worker: assert self.rank == 0, "The driver worker must have rank 0." - self.model_runner = ModelRunner(model_config, parallel_config, - scheduler_config, is_driver_worker, post_model_path) + self.model_runner = ModelRunner( + model_config, + parallel_config, + scheduler_config, + is_driver_worker, + post_model_path, + ) # Uninitialized cache engine. Will be initialized by # self.init_cache_engine(). self.cache_config = None @@ -74,8 +77,9 @@ def init_model(self) -> None: _check_if_gpu_supports_dtype(self.model_config.dtype) # Initialize the distributed environment. - _init_distributed_environment(self.parallel_config, self.rank, - self.distributed_init_method) + _init_distributed_environment( + self.parallel_config, self.rank, self.distributed_init_method + ) # Initialize the model. 
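        # set_random_seed seeds the Python, NumPy and torch (including CUDA) RNGs,
        # so model initialization and sampling stay reproducible for a fixed
        # model_config.seed across workers.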
set_random_seed(self.model_config.seed) @@ -105,10 +109,12 @@ def profile_num_available_blocks( peak_memory = total_gpu_memory - free_gpu_memory cache_block_size = CacheEngine.get_cache_block_size( - block_size, self.model_config, self.parallel_config) + block_size, self.model_config, self.parallel_config + ) num_gpu_blocks = int( - (total_gpu_memory * gpu_memory_utilization - peak_memory) // - cache_block_size) + (total_gpu_memory * gpu_memory_utilization - peak_memory) + // cache_block_size + ) num_cpu_blocks = int(cpu_swap_space // cache_block_size) num_gpu_blocks = max(num_gpu_blocks, 0) num_cpu_blocks = max(num_cpu_blocks, 0) @@ -117,8 +123,9 @@ def profile_num_available_blocks( def init_cache_engine(self, cache_config: CacheConfig) -> None: self.cache_config = cache_config - self.cache_engine = CacheEngine(self.cache_config, self.model_config, - self.parallel_config) + self.cache_engine = CacheEngine( + self.cache_config, self.model_config, self.parallel_config + ) self.cache_events = self.cache_engine.events self.gpu_cache = self.cache_engine.gpu_cache self.model_runner.set_block_size(self.cache_engine.block_size) @@ -171,10 +178,11 @@ def execute_model( assert blocks_to_swap_out is not None assert blocks_to_copy is not None block_swapping_info = [ - blocks_to_swap_in, blocks_to_swap_out, blocks_to_copy + blocks_to_swap_in, + blocks_to_swap_out, + blocks_to_copy, ] - broadcast_object_list([num_seq_groups] + block_swapping_info, - src=0) + broadcast_object_list([num_seq_groups] + block_swapping_info, src=0) else: # num_seq_groups, blocks_to_swap_in, blocks_to_swap_out, # blocks_to_copy (4 elements) @@ -189,8 +197,9 @@ def execute_model( if num_seq_groups == 0: return {} - output = self.model_runner.execute_model(seq_group_metadata_list, - self.gpu_cache) + output = self.model_runner.execute_model( + seq_group_metadata_list, self.gpu_cache + ) return output @@ -206,11 +215,13 @@ def _init_distributed_environment( raise RuntimeError( "torch.distributed is already initialized but the torch world " "size does not match parallel_config.world_size " - f"({torch_world_size} vs. {parallel_config.world_size}).") + f"({torch_world_size} vs. {parallel_config.world_size})." + ) elif not distributed_init_method: raise ValueError( "distributed_init_method must be set if torch.distributed " - "is not already initialized") + "is not already initialized" + ) else: torch.distributed.init_process_group( backend="nccl", @@ -221,8 +232,9 @@ def _init_distributed_environment( # A small all_reduce for warmup. torch.distributed.all_reduce(torch.zeros(1).cuda()) - initialize_model_parallel(parallel_config.tensor_parallel_size, - parallel_config.pipeline_parallel_size) + initialize_model_parallel( + parallel_config.tensor_parallel_size, parallel_config.pipeline_parallel_size + ) def _check_if_gpu_supports_dtype(torch_dtype: torch.dtype): @@ -234,4 +246,5 @@ def _check_if_gpu_supports_dtype(torch_dtype: torch.dtype): raise ValueError( "Bfloat16 is only supported on GPUs with compute capability " f"of at least 8.0. Your {gpu_name} GPU has compute capability " - f"{compute_capability[0]}.{compute_capability[1]}.") + f"{compute_capability[0]}.{compute_capability[1]}." 
+ ) diff --git a/test.py b/test.py index 07e9fae5d..2690b28fd 100644 --- a/test.py +++ b/test.py @@ -2,26 +2,28 @@ import torch import torchaudio import soundfile as sf + chat = ChatTTS.Chat() -chat.load(compile=False) # Set to True for better performance +chat.load(compile=False) # Set to True for better performance rand_spk = chat.sample_random_speaker() -print(rand_spk) # save it for later timbre recovery +print(rand_spk) # save it for later timbre recovery params_infer_code = ChatTTS.Chat.InferCodeParams( - spk_emb = rand_spk, # add sampled speaker - temperature = .3, # using custom temperature - top_P = 0.7, # top P decode - top_K = 20, # top K decode + spk_emb=rand_spk, # add sampled speaker + temperature=0.3, # using custom temperature + top_P=0.7, # top P decode + top_K=20, # top K decode ) params_refine_text = ChatTTS.Chat.RefineTextParams( - prompt='[oral_2][laugh_0][break_6]', + prompt="[oral_2][laugh_0][break_6]", ) texts = ["PUT YOUR 1st TEXT HERE", "PUT YOUR 2nd TEXT HERE"] -wavs = chat.infer(texts, +wavs = chat.infer( + texts, params_refine_text=params_refine_text, params_infer_code=params_infer_code, - ) +) # torchaudio.save("output1.wav", torch.from_numpy(wavs[0]), 24000) -sf.write("output1.wav", wavs[1], 24000) \ No newline at end of file +sf.write("output1.wav", wavs[1], 24000) From 72f2ba2f547bce52cc3d481a21a3714f4e5d2b9a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E6=BA=90=E6=96=87=E9=9B=A8?= <41315874+fumiama@users.noreply.github.com> Date: Sun, 21 Jul 2024 23:37:14 +0900 Subject: [PATCH 06/27] chore: restore some latest changes --- ChatTTS/core.py | 147 +++++++++++++++++++---------------------------- requirements.txt | 1 + setup.py | 1 + 3 files changed, 62 insertions(+), 87 deletions(-) diff --git a/ChatTTS/core.py b/ChatTTS/core.py index bfc728437..fa86b5018 100644 --- a/ChatTTS/core.py +++ b/ChatTTS/core.py @@ -5,20 +5,18 @@ from typing import Literal, Optional, List, Tuple, Dict, Union from json import load from pathlib import Path -import lzma -import pathlib -from ChatTTS.model.velocity.post_model import Post_model -from safetensors.torch import save_file, safe_open -from omegaconf import OmegaConf + +from safetensors.torch import save_file import numpy as np import torch from vocos import Vocos from vocos.pretrained import instantiate_class from huggingface_hub import snapshot_download -import pybase16384 as b14 -from ChatTTS.model.velocity.llm import LLM -from ChatTTS.model.velocity.sampling_params import SamplingParams -import yaml + +from .config import Config +from .model.velocity.llm import LLM +from .model.velocity.post_model import Post_model +from .model.velocity.sampling_params import SamplingParams from .model import DVAE, GPT, gen_logits, Tokenizer from .utils import ( check_all_assets, @@ -30,7 +28,6 @@ from .utils import logger as utils_logger from .norm import Normalizer -import pybase16384 as b14 class Chat: @@ -38,6 +35,8 @@ def __init__(self, logger=logging.getLogger(__name__)): self.logger = logger utils_logger.set_logger(logger) + self.config = Config() + self.normalizer = Normalizer( os.path.join(os.path.dirname(__file__), "res", "homophones_map.json"), logger, @@ -144,9 +143,7 @@ def load( use_flash_attn=use_flash_attn, **{ k: os.path.join(download_path, v) - for k, v in OmegaConf.load( - os.path.join(download_path, "config", "path.yaml") - ).items() + for k, v in asdict(self.config.path).items() }, ) @@ -172,7 +169,7 @@ def sample_audio_speaker(self, wav: Union[np.ndarray, torch.Tensor]) -> str: @torch.no_grad() def 
_sample_random_speaker(self) -> torch.Tensor: - dim: int = self.hidden_size + dim: int = self.config.gpt.hidden_size spk = ( torch.randn(dim, device=self.std.device, dtype=self.std.dtype) .mul_(self.std) @@ -242,13 +239,9 @@ def interrupt(self): @torch.no_grad() def _load( self, - vocos_config_path: str = None, vocos_ckpt_path: str = None, - dvae_config_path: str = None, dvae_ckpt_path: str = None, - gpt_config_path: str = None, gpt_ckpt_path: str = None, - decoder_config_path: str = None, decoder_ckpt_path: str = None, tokenizer_path: str = None, device: Optional[torch.device] = None, @@ -262,34 +255,42 @@ def _load( self.device = device self.compile = compile - if vocos_config_path: - vocos = ( - Vocos.from_hparams(vocos_config_path) - .to( - # vocos on mps will crash, use cpu fallback - "cpu" - if "mps" in str(device) - else device - ) - .eval() - ) - assert vocos_ckpt_path, "vocos_ckpt_path should not be None" - vocos.load_state_dict( - torch.load(vocos_ckpt_path, weights_only=True, mmap=True) + feature_extractor = instantiate_class( + args=(), init=asdict(self.config.vocos.feature_extractor) + ) + backbone = instantiate_class(args=(), init=asdict(self.config.vocos.backbone)) + head = instantiate_class(args=(), init=asdict(self.config.vocos.head)) + vocos = ( + Vocos(feature_extractor=feature_extractor, backbone=backbone, head=head) + .to( + # vocos on mps will crash, use cpu fallback + "cpu" + if "mps" in str(device) + else device ) - self.vocos = vocos - self.logger.log(logging.INFO, "vocos loaded.") - - if dvae_config_path: - cfg = OmegaConf.load(dvae_config_path) - dvae = DVAE(**cfg, coef=coef).to(device).eval() - coef = str(dvae) - assert dvae_ckpt_path, "dvae_ckpt_path should not be None" - dvae.load_state_dict( - torch.load(dvae_ckpt_path, weights_only=True, mmap=True) + .eval() + ) + assert vocos_ckpt_path, "vocos_ckpt_path should not be None" + vocos.load_state_dict(torch.load(vocos_ckpt_path, weights_only=True, mmap=True)) + self.vocos = vocos + self.logger.log(logging.INFO, "vocos loaded.") + + dvae = ( + DVAE( + decoder_config=asdict(self.config.dvae.decoder), + encoder_config=asdict(self.config.dvae.encoder), + vq_config=asdict(self.config.dvae.vq), + dim=self.config.dvae.decoder.idim, + coef=coef, ) - self.dvae = dvae - self.logger.log(logging.INFO, "dvae loaded.") + .to(device) + .eval() + ) + coef = str(dvae) + assert dvae_ckpt_path, "dvae_ckpt_path should not be None" + dvae.load_state_dict(torch.load(dvae_ckpt_path, weights_only=True, mmap=True)) + self.dvae = dvae + self.logger.log(logging.INFO, "dvae loaded.") if gpt_config_path: cfg = OmegaConf.load(gpt_config_path) @@ -343,7 +344,6 @@ def _load( self.std, self.mean = spk_stat.requires_grad_(False).chunk(2) self.logger.log(logging.INFO, "gpt loaded.") - self.hidden_size = cfg.gpt_config.hidden_size self.gpt = LLM( model="asset/vllm_model/gpt", num_audio_tokens=cfg.num_audio_tokens, @@ -351,27 +351,22 @@ def _load( post_model_path="asset/vllm_model/post_model.safetensors", ) - if dvae_config_path: - cfg = OmegaConf.load(dvae_config_path) - dvae = DVAE(**cfg, coef=coef).to(device).eval() - coef = str(dvae) - assert dvae_ckpt_path, "dvae_ckpt_path should not be None" - dvae.load_state_dict( - torch.load(dvae_ckpt_path, weights_only=True, mmap=True) - ) - self.dvae = dvae - self.logger.log(logging.INFO, "dvae loaded.") - - if decoder_config_path: - cfg = OmegaConf.load(decoder_config_path) - decoder = DVAE(**cfg, coef=coef).to(device).eval() - coef = str(decoder) - assert decoder_ckpt_path, "decoder_ckpt_path should 
not be None" - decoder.load_state_dict( - torch.load(decoder_ckpt_path, weights_only=True, mmap=True) + decoder = ( + DVAE( + decoder_config=asdict(self.config.decoder), + dim=self.config.decoder.idim, + coef=coef, ) - self.decoder = decoder - self.logger.log(logging.INFO, "decoder loaded.") + .to(device) + .eval() + ) + coef = str(decoder) + assert decoder_ckpt_path, "decoder_ckpt_path should not be None" + decoder.load_state_dict( + torch.load(decoder_ckpt_path, weights_only=True, mmap=True) + ) + self.decoder = decoder + self.logger.log(logging.INFO, "decoder loaded.") if tokenizer_path: self.tokenizer = Tokenizer(tokenizer_path, device) @@ -496,28 +491,6 @@ def _decode_to_wavs( del mel_specs return wavs - @staticmethod - def _decode_spk_emb(spk_emb: str) -> np.ndarray: - return np.frombuffer( - lzma.decompress( - b14.decode_from_string(spk_emb), - format=lzma.FORMAT_RAW, - filters=[{"id": lzma.FILTER_LZMA2, "preset": 9 | lzma.PRESET_EXTREME}], - ), - dtype=np.float16, - ).copy() - - @dataclass(repr=False, eq=False) - class GenerationOutputs: - ids: List[torch.Tensor] - # attentions: List[Optional[Tuple[torch.FloatTensor, ...]]] - hiddens: List[torch.Tensor] - - def destroy(self): - del_all(self.ids) - # del_all(self.attentions) - # del_all(self.hiddens) - @torch.no_grad() def _infer_code( self, diff --git a/requirements.txt b/requirements.txt index 75066bb96..a75c42f3c 100644 --- a/requirements.txt +++ b/requirements.txt @@ -14,3 +14,4 @@ WeTextProcessing; sys_platform == 'linux' nemo_text_processing; sys_platform == 'linux' av pydub +safetensors diff --git a/setup.py b/setup.py index da7b609e6..c5fe0a69e 100644 --- a/setup.py +++ b/setup.py @@ -28,6 +28,7 @@ "transformers>=4.41.1", "vector_quantize_pytorch", "vocos", + "safetensors", ], platforms="any", classifiers=[ From fe68af9eec15c977e291f7560aa8559d92fc87f8 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E6=BA=90=E6=96=87=E9=9B=A8?= <41315874+fumiama@users.noreply.github.com> Date: Sun, 21 Jul 2024 23:43:23 +0900 Subject: [PATCH 07/27] chore: restore some latest changes --- ChatTTS/core.py | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/ChatTTS/core.py b/ChatTTS/core.py index fa86b5018..afcc4672f 100644 --- a/ChatTTS/core.py +++ b/ChatTTS/core.py @@ -6,8 +6,8 @@ from json import load from pathlib import Path -from safetensors.torch import save_file import numpy as np +from safetensors.torch import save_file import torch from vocos import Vocos from vocos.pretrained import instantiate_class @@ -294,7 +294,7 @@ def _load( if gpt_config_path: cfg = OmegaConf.load(gpt_config_path) - self.num_vq = 4 + self.config.gpt.num_vq = 4 if not os.path.exists("asset/vllm_model"): gpt = GPT( **cfg, @@ -501,7 +501,7 @@ def _infer_code( params: InferCodeParams, ): - gpt: LLM = self.gpt + gpt = self.gpt if not isinstance(text, list): text = [text] @@ -509,7 +509,7 @@ def _infer_code( assert len(text), "text should not be empty" if not isinstance(params.temperature, list): - temperature = [params.temperature] * self.num_vq + temperature = [params.temperature] * self.config.gpt.num_vq else: temperature = params.temperature @@ -535,7 +535,7 @@ def _infer_code( input_ids, attention_mask, text_mask = self.tokenizer.encode( text, - self.num_vq, + self.config.gpt.num_vq, prompt_str=params.spk_smp, device=self.device, ) @@ -587,7 +587,7 @@ def _refine_text( params: RefineTextParams, ): - gpt: LLM = self.gpt + gpt = self.gpt if not isinstance(text, list): text = [text] @@ -596,7 +596,7 @@ def _refine_text( input_ids, 
attention_mask, text_mask = self.tokenizer.encode( text, - self.num_vq, + self.config.gpt.num_vq, device=self.device, ) From 9c0a1df31c1ab5ef92d7774681cac16bdbba2280 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E6=BA=90=E6=96=87=E9=9B=A8?= <41315874+fumiama@users.noreply.github.com> Date: Sun, 21 Jul 2024 23:47:39 +0900 Subject: [PATCH 08/27] chore: restore some latest changes --- ChatTTS/core.py | 8 +++++--- ChatTTS/model/gpt.py | 3 ++- 2 files changed, 7 insertions(+), 4 deletions(-) diff --git a/ChatTTS/core.py b/ChatTTS/core.py index afcc4672f..3aa8ce912 100644 --- a/ChatTTS/core.py +++ b/ChatTTS/core.py @@ -253,6 +253,7 @@ def _load( device = select_device() self.logger.info("use device %s", str(device)) self.device = device + self.device_gpt = device if "mps" not in str(device) else torch.device("cpu") self.compile = compile feature_extractor = instantiate_class( @@ -299,7 +300,8 @@ def _load( gpt = GPT( **cfg, use_flash_attn=use_flash_attn, - device=device, + device=self.device, + device_gpt=self.device_gpt, logger=self.logger, ).eval() assert gpt_ckpt_path, "gpt_ckpt_path should not be None" @@ -537,7 +539,7 @@ def _infer_code( text, self.config.gpt.num_vq, prompt_str=params.spk_smp, - device=self.device, + device=self.device_gpt, ) start_idx = input_ids.shape[-2] @@ -597,7 +599,7 @@ def _refine_text( input_ids, attention_mask, text_mask = self.tokenizer.encode( text, self.config.gpt.num_vq, - device=self.device, + device=self.device_gpt, ) start_idx = input_ids.shape[-2] diff --git a/ChatTTS/model/gpt.py b/ChatTTS/model/gpt.py index 178c19446..eb4f638cf 100644 --- a/ChatTTS/model/gpt.py +++ b/ChatTTS/model/gpt.py @@ -28,6 +28,7 @@ def __init__( num_vq=4, use_flash_attn=False, device=torch.device("cpu"), + device_gpt=torch.device("cpu"), logger=logging.getLogger(__name__), ): super().__init__() @@ -35,7 +36,7 @@ def __init__( self.logger = logger self.device = device - self.device_gpt = device if "mps" not in str(device) else torch.device("cpu") + self.device_gpt = device_gpt self.num_vq = num_vq self.num_audio_tokens = num_audio_tokens From 5a29d8e51c6dfb4716144a9d39a923d4f11a7e6e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E6=BA=90=E6=96=87=E9=9B=A8?= <41315874+fumiama@users.noreply.github.com> Date: Wed, 24 Jul 2024 00:26:43 +0800 Subject: [PATCH 09/27] chore: restore some latest changes --- ChatTTS/config/config.py | 1 + ChatTTS/core.py | 106 ++++++++++++++++++--------------------- ChatTTS/model/gpt.py | 13 +++-- 3 files changed, 56 insertions(+), 64 deletions(-) diff --git a/ChatTTS/config/config.py b/ChatTTS/config/config.py index 1a3eeff14..b58fc261d 100644 --- a/ChatTTS/config/config.py +++ b/ChatTTS/config/config.py @@ -58,6 +58,7 @@ class GPT: spk_emb_dim: int = 192 spk_KL: bool = False num_audio_tokens: int = 626 + num_text_tokens: int = 21178 num_vq: int = 4 diff --git a/ChatTTS/core.py b/ChatTTS/core.py index 3aa8ce912..d71665d37 100644 --- a/ChatTTS/core.py +++ b/ChatTTS/core.py @@ -293,66 +293,58 @@ def _load( self.dvae = dvae self.logger.log(logging.INFO, "dvae loaded.") - if gpt_config_path: - cfg = OmegaConf.load(gpt_config_path) - self.config.gpt.num_vq = 4 - if not os.path.exists("asset/vllm_model"): - gpt = GPT( - **cfg, - use_flash_attn=use_flash_attn, - device=self.device, - device_gpt=self.device_gpt, - logger=self.logger, - ).eval() - assert gpt_ckpt_path, "gpt_ckpt_path should not be None" - gpt.load_state_dict( - torch.load(gpt_ckpt_path, weights_only=True, mmap=True) + if not os.path.exists("asset/vllm_model"): + gpt = GPT( + 
gpt_config=asdict(self.config.gpt), + use_flash_attn=use_flash_attn, + device=device, + logger=self.logger, + ).eval() + assert gpt_ckpt_path, "gpt_ckpt_path should not be None" + gpt.from_pretrained(gpt_ckpt_path) + gpt.prepare(compile=compile and "cuda" in str(device)) + self.gpt = gpt + + pathlib.Path("asset/vllm_model").mkdir(parents=True, exist_ok=True) + self.gpt.gpt.save_pretrained("asset/vllm_model/gpt") + self.post_model = ( + Post_model( + self.config.gpt.hidden_size, + self.config.gpt.num_audio_tokens, + self.config.gpt.num_text_tokens, + device=device, ) - gpt.prepare(compile=compile and "cuda" in str(device)) - self.gpt = gpt - pathlib.Path("asset/vllm_model").mkdir(parents=True, exist_ok=True) - self.gpt.gpt.save_pretrained("asset/vllm_model/gpt") - self.post_model = ( - Post_model( - cfg.gpt_config.hidden_size, - cfg.num_audio_tokens, - cfg.num_text_tokens, - device=device, - ) - .to(device) - .eval() - ) - - self.post_model.emb_code = self.gpt.emb_code - self.post_model.emb_text = self.gpt.emb_text - self.post_model.head_text = self.gpt.head_text - self.post_model.head_code = self.gpt.head_code - save_file( - self.post_model.state_dict(), - "asset/vllm_model/post_model.safetensors", - ) - - self.num_audio_tokens = cfg.num_audio_tokens - spk_stat_path = os.path.join(os.path.dirname(gpt_ckpt_path), "spk_stat.pt") - assert os.path.exists( - spk_stat_path - ), f"Missing spk_stat.pt: {spk_stat_path}" - spk_stat: torch.Tensor = torch.load( - spk_stat_path, - weights_only=True, - mmap=True, - map_location=device, + .to(device) + .eval() ) - self.std, self.mean = spk_stat.requires_grad_(False).chunk(2) - self.logger.log(logging.INFO, "gpt loaded.") - - self.gpt = LLM( - model="asset/vllm_model/gpt", - num_audio_tokens=cfg.num_audio_tokens, - num_text_tokens=cfg.num_text_tokens, - post_model_path="asset/vllm_model/post_model.safetensors", + + self.post_model.emb_code = self.gpt.emb_code + self.post_model.emb_text = self.gpt.emb_text + self.post_model.head_text = self.gpt.head_text + self.post_model.head_code = self.gpt.head_code + save_file( + self.post_model.state_dict(), + "asset/vllm_model/post_model.safetensors", ) + self.gpt = LLM( + model="asset/vllm_model/gpt", + num_audio_tokens=self.config.gpt.num_audio_tokens, + num_text_tokens=self.config.gpt.num_text_tokens, + post_model_path="asset/vllm_model/post_model.safetensors", + ) + + spk_stat_path = os.path.join(os.path.dirname(gpt_ckpt_path), "spk_stat.pt") + assert os.path.exists(spk_stat_path), f"Missing spk_stat.pt: {spk_stat_path}" + spk_stat: torch.Tensor = torch.load( + spk_stat_path, + weights_only=True, + mmap=True, + map_location=device, + ) + self.std, self.mean = spk_stat.requires_grad_(False).chunk(2) + self.logger.log(logging.INFO, "gpt loaded.") + decoder = ( DVAE( decoder_config=asdict(self.config.decoder), @@ -543,7 +535,7 @@ def _infer_code( ) start_idx = input_ids.shape[-2] - num_code = self.num_audio_tokens - 1 + num_code = self.config.gpt.num_audio_tokens - 1 logits_warpers, logits_processors = gen_logits( num_code=num_code, diff --git a/ChatTTS/model/gpt.py b/ChatTTS/model/gpt.py index eb4f638cf..7c8e6f965 100644 --- a/ChatTTS/model/gpt.py +++ b/ChatTTS/model/gpt.py @@ -23,8 +23,6 @@ class GPT(nn.Module): def __init__( self, gpt_config: dict, - num_audio_tokens: int = 626, - num_text_tokens: int = 21178, num_vq=4, use_flash_attn=False, device=torch.device("cpu"), @@ -39,7 +37,8 @@ def __init__( self.device_gpt = device_gpt self.num_vq = num_vq - self.num_audio_tokens = num_audio_tokens + 
self.num_audio_tokens = int(gpt_config["num_audio_tokens"]) + self.num_text_tokens = int(gpt_config["num_text_tokens"]) self.use_flash_attn = use_flash_attn @@ -49,7 +48,7 @@ def __init__( self.emb_code = nn.ModuleList( [ nn.Embedding( - num_audio_tokens, + self.num_audio_tokens, self.model_dim, device=self.device_gpt, ) @@ -57,13 +56,13 @@ def __init__( ], ) self.emb_text = nn.Embedding( - num_text_tokens, self.model_dim, device=self.device_gpt + self.num_text_tokens, self.model_dim, device=self.device_gpt ) self.head_text = weight_norm( nn.Linear( self.model_dim, - num_text_tokens, + self.num_text_tokens, bias=False, device=device, ), @@ -74,7 +73,7 @@ def __init__( weight_norm( nn.Linear( self.model_dim, - num_audio_tokens, + self.num_audio_tokens, bias=False, device=device, ), From cc8024a399c637612cd69de0734749bdc94e7c37 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E6=BA=90=E6=96=87=E9=9B=A8?= <41315874+fumiama@users.noreply.github.com> Date: Wed, 24 Jul 2024 00:36:01 +0800 Subject: [PATCH 10/27] chore: restore some latest changes --- ChatTTS/core.py | 16 ++++++++++++++++ ChatTTS/model/gpt.py | 1 + 2 files changed, 17 insertions(+) diff --git a/ChatTTS/core.py b/ChatTTS/core.py index d71665d37..63953f273 100644 --- a/ChatTTS/core.py +++ b/ChatTTS/core.py @@ -535,6 +535,16 @@ def _infer_code( ) start_idx = input_ids.shape[-2] + if not gpt.is_vllm: + emb = gpt(input_ids, text_mask) + + del text_mask + + if params.spk_emb is not None: + self.tokenizer.apply_spk_emb( + emb, params.spk_emb, input_ids, self.gpt.device_gpt + ) + num_code = self.config.gpt.num_audio_tokens - 1 logits_warpers, logits_processors = gen_logits( @@ -603,6 +613,12 @@ def _refine_text( repetition_penalty=params.repetition_penalty, ) + if not gpt.is_vllm: + + emb = gpt(input_ids, text_mask) + + del text_mask + sample_params = SamplingParams( temperature=params.temperature, max_new_token=params.max_new_token, diff --git a/ChatTTS/model/gpt.py b/ChatTTS/model/gpt.py index 7c8e6f965..dec4151ee 100644 --- a/ChatTTS/model/gpt.py +++ b/ChatTTS/model/gpt.py @@ -44,6 +44,7 @@ def __init__( self.gpt, self.llama_config = self._build_llama(gpt_config, self.device_gpt) self.is_te_llama = False + self.is_vllm = False self.model_dim = int(self.gpt.config.hidden_size) self.emb_code = nn.ModuleList( [ From 9cd5620276409dafd4d5eb4ed4b7ad70d2575394 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E6=BA=90=E6=96=87=E9=9B=A8?= <41315874+fumiama@users.noreply.github.com> Date: Wed, 24 Jul 2024 16:31:51 +0800 Subject: [PATCH 11/27] doc: use green pypi badge --- README.md | 2 +- docs/cn/README.md | 2 +- docs/fr/README.md | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/README.md b/README.md index e9d0d2af3..41bfc830b 100644 --- a/README.md +++ b/README.md @@ -6,7 +6,7 @@ A generative speech model for daily dialogue. 
[![Licence](https://img.shields.io/github/license/2noise/ChatTTS?style=for-the-badge)](https://github.com/2noise/ChatTTS/blob/main/LICENSE) -[![PyPI](https://img.shields.io/pypi/v/ChatTTS.svg?style=for-the-badge)](https://pypi.org/project/ChatTTS) +[![PyPI](https://img.shields.io/pypi/v/ChatTTS.svg?style=for-the-badge&color=green)](https://pypi.org/project/ChatTTS) [![Huggingface](https://img.shields.io/badge/🤗%20-Models-yellow.svg?style=for-the-badge)](https://huggingface.co/2Noise/ChatTTS) [![Open In Colab](https://img.shields.io/badge/Colab-F9AB00?style=for-the-badge&logo=googlecolab&color=525252)](https://colab.research.google.com/github/2noise/ChatTTS/blob/main/examples/ipynb/colab.ipynb) diff --git a/docs/cn/README.md b/docs/cn/README.md index 8dbfa1244..c5dd6f753 100644 --- a/docs/cn/README.md +++ b/docs/cn/README.md @@ -6,7 +6,7 @@ 一款适用于日常对话的生成式语音模型。 [![Licence](https://img.shields.io/github/license/2noise/ChatTTS?style=for-the-badge)](https://github.com/2noise/ChatTTS/blob/main/LICENSE) -[![PyPI](https://img.shields.io/pypi/v/ChatTTS.svg?style=for-the-badge)](https://pypi.org/project/ChatTTS) +[![PyPI](https://img.shields.io/pypi/v/ChatTTS.svg?style=for-the-badge&color=green)](https://pypi.org/project/ChatTTS) [![Huggingface](https://img.shields.io/badge/🤗%20-Models-yellow.svg?style=for-the-badge)](https://huggingface.co/2Noise/ChatTTS) [![Open In Colab](https://img.shields.io/badge/Colab-F9AB00?style=for-the-badge&logo=googlecolab&color=525252)](https://colab.research.google.com/github/2noise/ChatTTS/blob/main/examples/ipynb/colab.ipynb) diff --git a/docs/fr/README.md b/docs/fr/README.md index 3addc9b73..d2fbb00d7 100644 --- a/docs/fr/README.md +++ b/docs/fr/README.md @@ -6,7 +6,7 @@ Un modèle de parole génératif pour le dialogue quotidien. 
[![Licence](https://img.shields.io/github/license/2noise/ChatTTS?style=for-the-badge)](https://github.com/2noise/ChatTTS/blob/main/LICENSE) -[![PyPI](https://img.shields.io/pypi/v/ChatTTS.svg?style=for-the-badge)](https://pypi.org/project/ChatTTS) +[![PyPI](https://img.shields.io/pypi/v/ChatTTS.svg?style=for-the-badge&color=green)](https://pypi.org/project/ChatTTS) [![Huggingface](https://img.shields.io/badge/🤗%20-Models-yellow.svg?style=for-the-badge)](https://huggingface.co/2Noise/ChatTTS) [![Open In Colab](https://img.shields.io/badge/Colab-F9AB00?style=for-the-badge&logo=googlecolab&color=525252)](https://colab.research.google.com/github/2noise/ChatTTS/blob/main/examples/ipynb/colab.ipynb) From f54ddaad149f570b7d7f9e78f22907573a08f3f5 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E6=BA=90=E6=96=87=E9=9B=A8?= <41315874+fumiama@users.noreply.github.com> Date: Wed, 24 Jul 2024 21:52:04 +0800 Subject: [PATCH 12/27] chore(vllm): rename Post_model to PostModel --- ChatTTS/core.py | 4 ++-- ChatTTS/model/velocity/model_runner.py | 4 ++-- ChatTTS/model/velocity/post_model.py | 8 +++----- 3 files changed, 7 insertions(+), 9 deletions(-) diff --git a/ChatTTS/core.py b/ChatTTS/core.py index 63953f273..995a59163 100644 --- a/ChatTTS/core.py +++ b/ChatTTS/core.py @@ -15,7 +15,7 @@ from .config import Config from .model.velocity.llm import LLM -from .model.velocity.post_model import Post_model +from .model.velocity.post_model import PostModel from .model.velocity.sampling_params import SamplingParams from .model import DVAE, GPT, gen_logits, Tokenizer from .utils import ( @@ -308,7 +308,7 @@ def _load( pathlib.Path("asset/vllm_model").mkdir(parents=True, exist_ok=True) self.gpt.gpt.save_pretrained("asset/vllm_model/gpt") self.post_model = ( - Post_model( + PostModel( self.config.gpt.hidden_size, self.config.gpt.num_audio_tokens, self.config.gpt.num_text_tokens, diff --git a/ChatTTS/model/velocity/model_runner.py b/ChatTTS/model/velocity/model_runner.py index 5b0f2c2d8..e59b3fffc 100644 --- a/ChatTTS/model/velocity/model_runner.py +++ b/ChatTTS/model/velocity/model_runner.py @@ -22,7 +22,7 @@ SequenceOutput, ) from vllm.utils import in_wsl -from ChatTTS.model.velocity.post_model import Post_model, Sampler +from ChatTTS.model.velocity.post_model import PostModel, Sampler from safetensors.torch import safe_open logger = init_logger(__name__) @@ -78,7 +78,7 @@ def __init__( def load_model(self) -> None: self.model = get_model(self.model_config) - self.post_model = Post_model( + self.post_model = PostModel( self.model_config.get_hidden_size(), self.model_config.num_audio_tokens, self.model_config.num_text_tokens, diff --git a/ChatTTS/model/velocity/post_model.py b/ChatTTS/model/velocity/post_model.py index 89bc79dcc..79b8900a4 100644 --- a/ChatTTS/model/velocity/post_model.py +++ b/ChatTTS/model/velocity/post_model.py @@ -1,12 +1,10 @@ -import os, platform +import os os.environ["TOKENIZERS_PARALLELISM"] = "false" """ https://stackoverflow.com/questions/62691279/how-to-disable-tokenizers-parallelism-true-false-warning """ -import logging - import torch import torch.nn as nn from torch.functional import F @@ -14,7 +12,7 @@ from typing import List, Callable -class Post_model(nn.Module): +class PostModel(nn.Module): def __init__( self, hidden_size: int, num_audio_tokens: int, num_text_tokens: int, num_vq=4 ): @@ -74,7 +72,7 @@ def forward(self, input_ids: torch.Tensor, text_mask: torch.Tensor) -> torch.Ten class Sampler: - def __init__(self, post_model: Post_model, num_audio_tokens: int, num_vq: int): + def 
__init__(self, post_model: PostModel, num_audio_tokens: int, num_vq: int): self.post_model = post_model self.device = next(self.post_model.parameters()).device self.num_audio_tokens = num_audio_tokens From f6ffdca74f53772342eb70e86731f2a915014687 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E6=BA=90=E6=96=87=E9=9B=A8?= <41315874+fumiama@users.noreply.github.com> Date: Wed, 24 Jul 2024 22:24:18 +0800 Subject: [PATCH 13/27] chore(vLLM): move load logic to gpt --- ChatTTS/core.py | 53 ++++++++-------------------------------- ChatTTS/model/gpt.py | 57 ++++++++++++++++++++++++++++++++++++++------ 2 files changed, 60 insertions(+), 50 deletions(-) diff --git a/ChatTTS/core.py b/ChatTTS/core.py index 995a59163..2e2333a29 100644 --- a/ChatTTS/core.py +++ b/ChatTTS/core.py @@ -7,15 +7,12 @@ from pathlib import Path import numpy as np -from safetensors.torch import save_file import torch from vocos import Vocos from vocos.pretrained import instantiate_class from huggingface_hub import snapshot_download from .config import Config -from .model.velocity.llm import LLM -from .model.velocity.post_model import PostModel from .model.velocity.sampling_params import SamplingParams from .model import DVAE, GPT, gen_logits, Tokenizer from .utils import ( @@ -293,46 +290,16 @@ def _load( self.dvae = dvae self.logger.log(logging.INFO, "dvae loaded.") - if not os.path.exists("asset/vllm_model"): - gpt = GPT( - gpt_config=asdict(self.config.gpt), - use_flash_attn=use_flash_attn, - device=device, - logger=self.logger, - ).eval() - assert gpt_ckpt_path, "gpt_ckpt_path should not be None" - gpt.from_pretrained(gpt_ckpt_path) - gpt.prepare(compile=compile and "cuda" in str(device)) - self.gpt = gpt - - pathlib.Path("asset/vllm_model").mkdir(parents=True, exist_ok=True) - self.gpt.gpt.save_pretrained("asset/vllm_model/gpt") - self.post_model = ( - PostModel( - self.config.gpt.hidden_size, - self.config.gpt.num_audio_tokens, - self.config.gpt.num_text_tokens, - device=device, - ) - .to(device) - .eval() - ) - - self.post_model.emb_code = self.gpt.emb_code - self.post_model.emb_text = self.gpt.emb_text - self.post_model.head_text = self.gpt.head_text - self.post_model.head_code = self.gpt.head_code - save_file( - self.post_model.state_dict(), - "asset/vllm_model/post_model.safetensors", - ) - - self.gpt = LLM( - model="asset/vllm_model/gpt", - num_audio_tokens=self.config.gpt.num_audio_tokens, - num_text_tokens=self.config.gpt.num_text_tokens, - post_model_path="asset/vllm_model/post_model.safetensors", - ) + gpt = GPT( + gpt_config=asdict(self.config.gpt), + use_flash_attn=use_flash_attn, + device=device, + logger=self.logger, + ).eval() + assert gpt_ckpt_path, "gpt_ckpt_path should not be None" + gpt.from_pretrained(gpt_ckpt_path) + gpt.prepare(compile=compile and "cuda" in str(device)) + self.gpt = gpt spk_stat_path = os.path.join(os.path.dirname(gpt_ckpt_path), "spk_stat.pt") assert os.path.exists(spk_stat_path), f"Missing spk_stat.pt: {spk_stat_path}" diff --git a/ChatTTS/model/gpt.py b/ChatTTS/model/gpt.py index dec4151ee..c0fd769a0 100644 --- a/ChatTTS/model/gpt.py +++ b/ChatTTS/model/gpt.py @@ -1,9 +1,11 @@ -import platform +import os, platform from dataclasses import dataclass import logging from typing import Union, List, Optional, Tuple import gc +from pathlib import Path +from safetensors.torch import save_file import torch import torch.nn as nn import torch.nn.functional as F @@ -17,14 +19,16 @@ from .processors import CustomRepetitionPenaltyLogitsProcessorRepeat from ..utils import del_all +from .velocity.llm 
import LLM +from .velocity.post_model import PostModel class GPT(nn.Module): def __init__( self, gpt_config: dict, - num_vq=4, use_flash_attn=False, + use_vllm=False, device=torch.device("cpu"), device_gpt=torch.device("cpu"), logger=logging.getLogger(__name__), @@ -36,15 +40,20 @@ def __init__( self.device = device self.device_gpt = device_gpt - self.num_vq = num_vq + self.config = gpt_config + self.num_vq = int(gpt_config["num_vq"]) self.num_audio_tokens = int(gpt_config["num_audio_tokens"]) self.num_text_tokens = int(gpt_config["num_text_tokens"]) self.use_flash_attn = use_flash_attn + self.is_te_llama = False + self.is_vllm = use_vllm + + if self.is_vllm: + return self.gpt, self.llama_config = self._build_llama(gpt_config, self.device_gpt) - self.is_te_llama = False - self.is_vllm = False + self.model_dim = int(self.gpt.config.hidden_size) self.emb_code = nn.ModuleList( [ @@ -53,7 +62,7 @@ def __init__( self.model_dim, device=self.device_gpt, ) - for _ in range(num_vq) + for _ in range(self.num_vq) ], ) self.emb_text = nn.Embedding( @@ -85,6 +94,40 @@ def __init__( ) def from_pretrained(self, file_path: str): + if self.is_vllm and platform.system().lower() == "linux": + vllm_folder = Path(os.getcwd()) / "asset" / "vllm" + if not os.path.exists(vllm_folder): + self.logger.info("initializing vLLM model to %s", str(vllm_folder)) + vllm_folder.mkdir(mode=0o755, parents=True, exist_ok=True) + gpt = GPT(gpt_config=self.config) + gpt.from_pretrained(file_path) + gpt.gpt.save_pretrained(vllm_folder / "gpt") + post_model = ( + PostModel( + int(self.gpt.config.hidden_size), + self.num_audio_tokens, + self.num_text_tokens, + ) + .to(self.device) + .eval() + ) + post_model.emb_code = gpt.emb_code + post_model.emb_text = gpt.emb_text + post_model.head_text = gpt.head_text + post_model.head_code = gpt.head_code + save_file( + post_model.state_dict(), + vllm_folder / "post_model.safetensors", + ) + del post_model, gpt + self.llm = LLM( + model=str(vllm_folder / "gpt"), + num_audio_tokens=self.num_audio_tokens, + num_text_tokens=self.num_text_tokens, + post_model_path=vllm_folder / "post_model.safetensors", + ) + self.logger.info("vLLM model loaded") + return self.load_state_dict(torch.load(file_path, weights_only=True, mmap=True)) @@ -142,7 +185,7 @@ def _build_llama( def prepare(self, compile=False): if self.use_flash_attn and is_flash_attn_2_available(): self.gpt = self.gpt.to(dtype=torch.float16) - if compile and not self.is_te_llama: + if compile and not self.is_te_llama and not self.is_vllm: try: self.compile(backend="inductor", dynamic=True) self.gpt.compile(backend="inductor", dynamic=True) From 041d803712a649d62607f29777dd59244e39444d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E6=BA=90=E6=96=87=E9=9B=A8?= <41315874+fumiama@users.noreply.github.com> Date: Thu, 25 Jul 2024 01:10:49 +0800 Subject: [PATCH 14/27] chore(core): restore all normal infers --- ChatTTS/core.py | 174 +++++++++++++++++++++++++++++-------------- ChatTTS/model/gpt.py | 16 ++-- 2 files changed, 123 insertions(+), 67 deletions(-) diff --git a/ChatTTS/core.py b/ChatTTS/core.py index 2e2333a29..9b3b18d67 100644 --- a/ChatTTS/core.py +++ b/ChatTTS/core.py @@ -13,7 +13,6 @@ from huggingface_hub import snapshot_download from .config import Config -from .model.velocity.sampling_params import SamplingParams from .model import DVAE, GPT, gen_logits, Tokenizer from .utils import ( check_all_assets, @@ -129,6 +128,7 @@ def load( device: Optional[torch.device] = None, coef: Optional[torch.Tensor] = None, use_flash_attn=False, + 
use_vllm=False, ) -> bool: download_path = self.download_models(source, force_redownload, custom_path) if download_path is None: @@ -138,6 +138,7 @@ def load( compile=compile, coef=coef, use_flash_attn=use_flash_attn, + use_vllm=use_vllm, **{ k: os.path.join(download_path, v) for k, v in asdict(self.config.path).items() @@ -245,6 +246,7 @@ def _load( compile: bool = True, coef: Optional[str] = None, use_flash_attn=False, + use_vllm=False, ): if device is None: device = select_device() @@ -293,6 +295,7 @@ def _load( gpt = GPT( gpt_config=asdict(self.config.gpt), use_flash_attn=use_flash_attn, + use_vllm=use_vllm, device=device, logger=self.logger, ).eval() @@ -502,16 +505,6 @@ def _infer_code( ) start_idx = input_ids.shape[-2] - if not gpt.is_vllm: - emb = gpt(input_ids, text_mask) - - del text_mask - - if params.spk_emb is not None: - self.tokenizer.apply_spk_emb( - emb, params.spk_emb, input_ids, self.gpt.device_gpt - ) - num_code = self.config.gpt.num_audio_tokens - 1 logits_warpers, logits_processors = gen_logits( @@ -521,34 +514,74 @@ def _infer_code( repetition_penalty=params.repetition_penalty, ) - sample_params = SamplingParams( - temperature=temperature, + if gpt.is_vllm: + from .model.velocity.sampling_params import SamplingParams + sample_params = SamplingParams( + temperature=temperature, + max_new_token=params.max_new_token, + max_tokens=8192, + min_new_token=params.min_new_token, + logits_processors=(logits_processors, logits_warpers), + eos_token=num_code, + infer_text=False, + start_idx=start_idx, + ) + input_ids = [i.tolist() for i in input_ids] + + result = gpt.llm.generate( + None, + sample_params, + input_ids, + ) + + token_ids = [] + hidden_states = [] + for i in result: + token_ids.append(torch.tensor(i.outputs[0].token_ids)) + hidden_states.append( + i.outputs[0].hidden_states.to(torch.float32).to(self.device) + ) + + del text_mask, input_ids + del_all(logits_warpers) + del_all(logits_processors) + + return [ + self.GenerationOutputs(ids=token_ids, hiddens=hidden_states), + ] + + emb = gpt(input_ids, text_mask) + + del text_mask + + if params.spk_emb is not None: + self.tokenizer.apply_spk_emb( + emb, params.spk_emb, input_ids, self.gpt.device_gpt + ) + + result = gpt.generate( + emb, + input_ids, + temperature=torch.tensor(temperature, device=device), + eos_token=num_code, + attention_mask=attention_mask, max_new_token=params.max_new_token, - max_tokens=8192, min_new_token=params.min_new_token, - logits_processors=(logits_warpers, logits_processors), - eos_token=num_code, + logits_processors=(*logits_processors, *logits_warpers), infer_text=False, - start_idx=start_idx, + return_hidden=return_hidden, + stream=stream, + show_tqdm=params.show_tqdm, + ensure_non_empty=params.ensure_non_empty, + stream_batch=params.stream_batch, + context=self.context, ) - input_ids = [i.tolist() for i in input_ids] - result = gpt.generate( - None, - sample_params, - input_ids, - ) + del emb, input_ids + del_all(logits_warpers) + del_all(logits_processors) - token_ids = [] - hidden_states = [] - for i in result: - token_ids.append(torch.tensor(i.outputs[0].token_ids)) - hidden_states.append( - i.outputs[0].hidden_states.to(torch.float32).to(self.device) - ) - return [ - self.GenerationOutputs(ids=token_ids, hiddens=hidden_states), - ] + return result @torch.no_grad() def _refine_text( @@ -580,28 +613,57 @@ def _refine_text( repetition_penalty=params.repetition_penalty, ) - if not gpt.is_vllm: - - emb = gpt(input_ids, text_mask) + if gpt.is_vllm: + from .model.velocity.sampling_params 
import SamplingParams + sample_params = SamplingParams( + temperature=params.temperature, + max_new_token=params.max_new_token, + max_tokens=8192, + min_new_token=params.min_new_token, + logits_processors=(logits_processors, logits_warpers), + eos_token=self.tokenizer.eos_token, + infer_text=True, + start_idx=start_idx, + ) + input_ids = [i.tolist() for i in input_ids] + + result = gpt.llm.generate(None, sample_params, input_ids) + token_ids = [] + hidden_states = [] + for i in result: + token_ids.append(torch.tensor(i.outputs[0].token_ids)) + hidden_states.append(i.outputs[0].hidden_states) + + del text_mask, input_ids + del_all(logits_warpers) + del_all(logits_processors) + + return self.GenerationOutputs(ids=token_ids, hiddens=hidden_states) + + emb = gpt(input_ids, text_mask) + + del text_mask + + result = next( + gpt.generate( + emb, + input_ids, + temperature=torch.tensor([params.temperature], device=device), + eos_token=self.tokenizer.eos_token, + attention_mask=attention_mask, + max_new_token=params.max_new_token, + min_new_token=params.min_new_token, + logits_processors=(*logits_processors, *logits_warpers), + infer_text=True, + stream=False, + show_tqdm=params.show_tqdm, + ensure_non_empty=params.ensure_non_empty, + context=self.context, + ) + ) - del text_mask + del emb, input_ids + del_all(logits_warpers) + del_all(logits_processors) - sample_params = SamplingParams( - temperature=params.temperature, - max_new_token=params.max_new_token, - max_tokens=8192, - min_new_token=params.min_new_token, - logits_processors=(logits_warpers, logits_processors), - eos_token=self.tokenizer.eos_token, - infer_text=True, - start_idx=start_idx, - ) - input_ids = [i.tolist() for i in input_ids] - - result = gpt.generate(None, sample_params, input_ids) - token_ids = [] - hidden_states = [] - for i in result: - token_ids.append(torch.tensor(i.outputs[0].token_ids)) - hidden_states.append(i.outputs[0].hidden_states) - return self.GenerationOutputs(ids=token_ids, hiddens=hidden_states) + return result diff --git a/ChatTTS/model/gpt.py b/ChatTTS/model/gpt.py index c0fd769a0..133a8f0dc 100644 --- a/ChatTTS/model/gpt.py +++ b/ChatTTS/model/gpt.py @@ -1,7 +1,7 @@ import os, platform from dataclasses import dataclass import logging -from typing import Union, List, Optional, Tuple +from typing import Union, List, Optional, Tuple, Callable import gc from pathlib import Path @@ -12,15 +12,12 @@ import torch.nn.utils.parametrize as P from torch.nn.utils.parametrizations import weight_norm from tqdm import tqdm -from transformers import LlamaModel, LlamaConfig, LogitsWarper +from transformers import LlamaModel, LlamaConfig from transformers.cache_utils import Cache from transformers.modeling_outputs import BaseModelOutputWithPast from transformers.utils import is_flash_attn_2_available -from .processors import CustomRepetitionPenaltyLogitsProcessorRepeat from ..utils import del_all -from .velocity.llm import LLM -from .velocity.post_model import PostModel class GPT(nn.Module): @@ -95,6 +92,8 @@ def __init__( def from_pretrained(self, file_path: str): if self.is_vllm and platform.system().lower() == "linux": + from .velocity.llm import LLM + from .velocity.post_model import PostModel vllm_folder = Path(os.getcwd()) / "asset" / "vllm" if not os.path.exists(vllm_folder): self.logger.info("initializing vLLM model to %s", str(vllm_folder)) @@ -406,8 +405,7 @@ def generate( attention_mask: Optional[torch.Tensor] = None, max_new_token=2048, min_new_token=0, - logits_warpers: List[LogitsWarper] = [], - 
logits_processors: List[CustomRepetitionPenaltyLogitsProcessorRepeat] = [], + logits_processors: Tuple[Callable[[torch.LongTensor, torch.FloatTensor], torch.FloatTensor]] = (), infer_text=False, return_attn=False, return_hidden=False, @@ -571,9 +569,6 @@ def generate( for logitsProcessors in logits_processors: logits = logitsProcessors(logits_token, logits) - for logitsWarpers in logits_warpers: - logits = logitsWarpers(logits_token, logits) - del logits_token if i < min_new_token: @@ -631,7 +626,6 @@ def generate( attention_mask, max_new_token, min_new_token, - logits_warpers, logits_processors, infer_text, return_attn, From b165532bdc51bda62ce4f09be19576cc1589909a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E6=BA=90=E6=96=87=E9=9B=A8?= <41315874+fumiama@users.noreply.github.com> Date: Thu, 25 Jul 2024 01:14:32 +0800 Subject: [PATCH 15/27] chore: remove unnecessary files --- ChatTTS/core.py | 4 +--- test.py | 29 ----------------------------- 2 files changed, 1 insertion(+), 32 deletions(-) delete mode 100644 test.py diff --git a/ChatTTS/core.py b/ChatTTS/core.py index 9b3b18d67..713896201 100644 --- a/ChatTTS/core.py +++ b/ChatTTS/core.py @@ -604,8 +604,6 @@ def _refine_text( device=self.device_gpt, ) - start_idx = input_ids.shape[-2] - # print(start_idx) logits_warpers, logits_processors = gen_logits( num_code=self.tokenizer.len, top_P=params.top_P, @@ -623,7 +621,7 @@ def _refine_text( logits_processors=(logits_processors, logits_warpers), eos_token=self.tokenizer.eos_token, infer_text=True, - start_idx=start_idx, + start_idx=input_ids.shape[-2], ) input_ids = [i.tolist() for i in input_ids] diff --git a/test.py b/test.py deleted file mode 100644 index 2690b28fd..000000000 --- a/test.py +++ /dev/null @@ -1,29 +0,0 @@ -import ChatTTS as ChatTTS -import torch -import torchaudio -import soundfile as sf - -chat = ChatTTS.Chat() -chat.load(compile=False) # Set to True for better performance -rand_spk = chat.sample_random_speaker() -print(rand_spk) # save it for later timbre recovery - -params_infer_code = ChatTTS.Chat.InferCodeParams( - spk_emb=rand_spk, # add sampled speaker - temperature=0.3, # using custom temperature - top_P=0.7, # top P decode - top_K=20, # top K decode -) -params_refine_text = ChatTTS.Chat.RefineTextParams( - prompt="[oral_2][laugh_0][break_6]", -) -texts = ["PUT YOUR 1st TEXT HERE", "PUT YOUR 2nd TEXT HERE"] - -wavs = chat.infer( - texts, - params_refine_text=params_refine_text, - params_infer_code=params_infer_code, -) - -# torchaudio.save("output1.wav", torch.from_numpy(wavs[0]), 24000) -sf.write("output1.wav", wavs[1], 24000) From e508fee76c48d8e4ffdcb14a1d78a9ffeb0efb85 Mon Sep 17 00:00:00 2001 From: "github-actions[bot]" <41898282+github-actions[bot]@users.noreply.github.com> Date: Thu, 25 Jul 2024 01:16:56 +0800 Subject: [PATCH 16/27] chore(format): run black on dev (#626) Co-authored-by: github-actions[bot] --- ChatTTS/core.py | 2 ++ ChatTTS/model/gpt.py | 5 ++++- 2 files changed, 6 insertions(+), 1 deletion(-) diff --git a/ChatTTS/core.py b/ChatTTS/core.py index 713896201..6a29f049f 100644 --- a/ChatTTS/core.py +++ b/ChatTTS/core.py @@ -516,6 +516,7 @@ def _infer_code( if gpt.is_vllm: from .model.velocity.sampling_params import SamplingParams + sample_params = SamplingParams( temperature=temperature, max_new_token=params.max_new_token, @@ -613,6 +614,7 @@ def _refine_text( if gpt.is_vllm: from .model.velocity.sampling_params import SamplingParams + sample_params = SamplingParams( temperature=params.temperature, max_new_token=params.max_new_token, diff --git 
a/ChatTTS/model/gpt.py b/ChatTTS/model/gpt.py index 133a8f0dc..f7fdbe616 100644 --- a/ChatTTS/model/gpt.py +++ b/ChatTTS/model/gpt.py @@ -94,6 +94,7 @@ def from_pretrained(self, file_path: str): if self.is_vllm and platform.system().lower() == "linux": from .velocity.llm import LLM from .velocity.post_model import PostModel + vllm_folder = Path(os.getcwd()) / "asset" / "vllm" if not os.path.exists(vllm_folder): self.logger.info("initializing vLLM model to %s", str(vllm_folder)) @@ -405,7 +406,9 @@ def generate( attention_mask: Optional[torch.Tensor] = None, max_new_token=2048, min_new_token=0, - logits_processors: Tuple[Callable[[torch.LongTensor, torch.FloatTensor], torch.FloatTensor]] = (), + logits_processors: Tuple[ + Callable[[torch.LongTensor, torch.FloatTensor], torch.FloatTensor] + ] = (), infer_text=False, return_attn=False, return_hidden=False, From a9af30ba847a4eeff7c1ee0624b5659aa5988711 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E6=BA=90=E6=96=87=E9=9B=A8?= <41315874+fumiama@users.noreply.github.com> Date: Thu, 25 Jul 2024 01:24:47 +0800 Subject: [PATCH 17/27] fix(vllm): No module named 'ChatTTS.model' --- ChatTTS/model/velocity/block_manager.py | 2 +- ChatTTS/model/velocity/llm.py | 8 ++++---- ChatTTS/model/velocity/llm_engine.py | 12 ++++++------ ChatTTS/model/velocity/model_loader.py | 2 +- ChatTTS/model/velocity/model_runner.py | 10 +++++----- ChatTTS/model/velocity/output.py | 2 +- ChatTTS/model/velocity/scheduler.py | 4 ++-- ChatTTS/model/velocity/sequence.py | 2 +- ChatTTS/model/velocity/worker.py | 2 +- requirements.txt | 1 + 10 files changed, 23 insertions(+), 22 deletions(-) diff --git a/ChatTTS/model/velocity/block_manager.py b/ChatTTS/model/velocity/block_manager.py index ad69aa1b9..869a285c0 100644 --- a/ChatTTS/model/velocity/block_manager.py +++ b/ChatTTS/model/velocity/block_manager.py @@ -4,7 +4,7 @@ from typing import Dict, List, Optional, Set, Tuple from vllm.block import PhysicalTokenBlock -from ChatTTS.model.velocity.sequence import Sequence, SequenceGroup, SequenceStatus +from .sequence import Sequence, SequenceGroup, SequenceStatus from vllm.utils import Device # Mapping: logical block number -> physical block. 
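
The import rewrite above (and in the remaining velocity modules below) replaces absolute `ChatTTS.model.velocity.*` paths with package-relative imports, which is what resolves the `No module named 'ChatTTS.model'` failure this patch targets. A minimal sketch of the pattern, for illustration only and assuming the file lives inside the `ChatTTS/model/velocity/` package (module and class names are taken from the diff above, nothing else is implied about the patch):

```python
# Illustrative sketch, not part of the patch.
# Inside ChatTTS/model/velocity/block_manager.py:

# Before: resolving through the top-level package name breaks whenever the
# project is not importable as exactly "ChatTTS" (vendored checkouts, renamed
# roots, editable installs from a different working directory).
# from ChatTTS.model.velocity.sequence import Sequence, SequenceGroup, SequenceStatus

# After: sibling modules are resolved relative to the current package, so the
# velocity engine keeps working regardless of how the enclosing project is imported.
from .sequence import Sequence, SequenceGroup, SequenceStatus
```

The same one-line substitution is applied to `llm.py`, `llm_engine.py`, `model_loader.py`, `model_runner.py`, `output.py`, `scheduler.py`, `sequence.py`, and `worker.py` in the hunks that follow.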
diff --git a/ChatTTS/model/velocity/llm.py b/ChatTTS/model/velocity/llm.py index 98a90af26..b473b562c 100644 --- a/ChatTTS/model/velocity/llm.py +++ b/ChatTTS/model/velocity/llm.py @@ -3,10 +3,10 @@ from tqdm import tqdm from transformers import PreTrainedTokenizer, PreTrainedTokenizerFast -from ChatTTS.model.velocity.configs import EngineArgs -from ChatTTS.model.velocity.llm_engine import LLMEngine -from ChatTTS.model.velocity.output import RequestOutput -from ChatTTS.model.velocity.sampling_params import SamplingParams +from .configs import EngineArgs +from .llm_engine import LLMEngine +from .output import RequestOutput +from .sampling_params import SamplingParams from vllm.utils import Counter diff --git a/ChatTTS/model/velocity/llm_engine.py b/ChatTTS/model/velocity/llm_engine.py index 66dd205ff..0d144d0fd 100644 --- a/ChatTTS/model/velocity/llm_engine.py +++ b/ChatTTS/model/velocity/llm_engine.py @@ -5,14 +5,14 @@ from typing import TYPE_CHECKING, Any, Dict, Iterable, List, Optional, Tuple, Union from vllm.config import CacheConfig, ModelConfig, ParallelConfig, SchedulerConfig -from ChatTTS.model.velocity.scheduler import Scheduler, SchedulerOutputs -from ChatTTS.model.velocity.configs import EngineArgs +from .scheduler import Scheduler, SchedulerOutputs +from .configs import EngineArgs from vllm.engine.metrics import record_metrics from vllm.engine.ray_utils import RayWorkerVllm, initialize_cluster, ray from vllm.logger import init_logger -from ChatTTS.model.velocity.output import RequestOutput -from ChatTTS.model.velocity.sampling_params import SamplingParams -from ChatTTS.model.velocity.sequence import ( +from .output import RequestOutput +from .sampling_params import SamplingParams +from .sequence import ( SamplerOutput, Sequence, SequenceGroup, @@ -127,7 +127,7 @@ def __init__( def _init_workers(self): # Lazy import the Worker to avoid importing torch.cuda/xformers # before CUDA_VISIBLE_DEVICES is set in the Worker - from ChatTTS.model.velocity.worker import Worker + from .worker import Worker assert ( self.parallel_config.world_size == 1 diff --git a/ChatTTS/model/velocity/model_loader.py b/ChatTTS/model/velocity/model_loader.py index 40de6d960..2007a96a2 100644 --- a/ChatTTS/model/velocity/model_loader.py +++ b/ChatTTS/model/velocity/model_loader.py @@ -24,7 +24,7 @@ def _set_default_torch_dtype(dtype: torch.dtype): def _get_model_architecture(config: PretrainedConfig) -> Type[nn.Module]: model_cls = getattr( - importlib.import_module("ChatTTS.model.velocity.llama"), "LlamaModel", None + importlib.import_module(".llama"), "LlamaModel", None ) return model_cls diff --git a/ChatTTS/model/velocity/model_runner.py b/ChatTTS/model/velocity/model_runner.py index e59b3fffc..5becd0aa6 100644 --- a/ChatTTS/model/velocity/model_runner.py +++ b/ChatTTS/model/velocity/model_runner.py @@ -5,16 +5,16 @@ import torch import torch.nn as nn -from ChatTTS.model.velocity.configs import ModelConfig, ParallelConfig, SchedulerConfig +from .configs import ModelConfig, ParallelConfig, SchedulerConfig from vllm.logger import init_logger -from ChatTTS.model.velocity.model_loader import get_model +from .model_loader import get_model from vllm.model_executor import InputMetadata, SamplingMetadata from vllm.model_executor.parallel_utils.communication_op import ( broadcast, broadcast_object_list, ) -from ChatTTS.model.velocity.sampling_params import SamplingParams, SamplingType -from ChatTTS.model.velocity.sequence import ( +from .sampling_params import SamplingParams, SamplingType +from .sequence import ( 
SamplerOutput, SequenceData, SequenceGroupMetadata, @@ -22,7 +22,7 @@ SequenceOutput, ) from vllm.utils import in_wsl -from ChatTTS.model.velocity.post_model import PostModel, Sampler +from .post_model import PostModel, Sampler from safetensors.torch import safe_open logger = init_logger(__name__) diff --git a/ChatTTS/model/velocity/output.py b/ChatTTS/model/velocity/output.py index 3413a3e2b..05cc54600 100644 --- a/ChatTTS/model/velocity/output.py +++ b/ChatTTS/model/velocity/output.py @@ -1,7 +1,7 @@ from typing import List, Optional import torch -from ChatTTS.model.velocity.sequence import ( +from .sequence import ( PromptLogprobs, SampleLogprobs, SequenceGroup, diff --git a/ChatTTS/model/velocity/scheduler.py b/ChatTTS/model/velocity/scheduler.py index 4eb38d278..97d9cb450 100644 --- a/ChatTTS/model/velocity/scheduler.py +++ b/ChatTTS/model/velocity/scheduler.py @@ -3,10 +3,10 @@ from typing import Dict, Iterable, List, Optional, Tuple, Union from vllm.config import CacheConfig, SchedulerConfig -from ChatTTS.model.velocity.block_manager import AllocStatus, BlockSpaceManager +from .block_manager import AllocStatus, BlockSpaceManager from vllm.core.policy import PolicyFactory from vllm.logger import init_logger -from ChatTTS.model.velocity.sequence import ( +from .sequence import ( Sequence, SequenceData, SequenceGroup, diff --git a/ChatTTS/model/velocity/sequence.py b/ChatTTS/model/velocity/sequence.py index 4bc3f354d..76f9cf4e7 100644 --- a/ChatTTS/model/velocity/sequence.py +++ b/ChatTTS/model/velocity/sequence.py @@ -5,7 +5,7 @@ from typing import Dict, List, Optional, Union import torch from vllm.block import LogicalTokenBlock -from ChatTTS.model.velocity.sampling_params import SamplingParams +from .sampling_params import SamplingParams PromptLogprobs = List[Optional[Dict[int, float]]] SampleLogprobs = List[Dict[int, float]] diff --git a/ChatTTS/model/velocity/worker.py b/ChatTTS/model/velocity/worker.py index 9578551d9..90aca7f32 100644 --- a/ChatTTS/model/velocity/worker.py +++ b/ChatTTS/model/velocity/worker.py @@ -12,7 +12,7 @@ from vllm.model_executor.parallel_utils.parallel_state import initialize_model_parallel from vllm.sequence import SamplerOutput, SequenceGroupMetadata from vllm.worker.cache_engine import CacheEngine -from ChatTTS.model.velocity.model_runner import ModelRunner +from .model_runner import ModelRunner class Worker: diff --git a/requirements.txt b/requirements.txt index a75c42f3c..8d8066224 100644 --- a/requirements.txt +++ b/requirements.txt @@ -15,3 +15,4 @@ nemo_text_processing; sys_platform == 'linux' av pydub safetensors +vllm>=0.2.7; sys_platform == 'linux' From 9d7c437de454afab4d4959c8f320617f8a514046 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E6=BA=90=E6=96=87=E9=9B=A8?= <41315874+fumiama@users.noreply.github.com> Date: Thu, 25 Jul 2024 01:35:15 +0800 Subject: [PATCH 18/27] doc: add vLLM instruction --- ChatTTS/model/gpt.py | 5 +++-- README.md | 9 +++++++-- requirements.txt | 2 -- setup.py | 1 - 4 files changed, 10 insertions(+), 7 deletions(-) diff --git a/ChatTTS/model/gpt.py b/ChatTTS/model/gpt.py index f7fdbe616..5c9f63283 100644 --- a/ChatTTS/model/gpt.py +++ b/ChatTTS/model/gpt.py @@ -5,7 +5,6 @@ import gc from pathlib import Path -from safetensors.torch import save_file import torch import torch.nn as nn import torch.nn.functional as F @@ -92,6 +91,8 @@ def __init__( def from_pretrained(self, file_path: str): if self.is_vllm and platform.system().lower() == "linux": + from safetensors.torch import save_file + from .velocity.llm import LLM from 
.velocity.post_model import PostModel @@ -104,7 +105,7 @@ def from_pretrained(self, file_path: str): gpt.gpt.save_pretrained(vllm_folder / "gpt") post_model = ( PostModel( - int(self.gpt.config.hidden_size), + int(gpt.gpt.config.hidden_size), self.num_audio_tokens, self.num_text_tokens, ) diff --git a/README.md b/README.md index 41bfc830b..e31408cc0 100644 --- a/README.md +++ b/README.md @@ -101,7 +101,12 @@ conda activate chattts pip install -r requirements.txt ``` -#### Optional: Install TransformerEngine if using NVIDIA GPU (Linux only) +#### Optional: Install vLLM (Linux only) +```bash +pip install safetensors vllm==0.2.7 torchaudio +``` + +#### Unrecommended Optional: Install TransformerEngine if using NVIDIA GPU (Linux only) > [!Note] > The installation process is very slow. @@ -113,7 +118,7 @@ pip install -r requirements.txt pip install git+https://github.com/NVIDIA/TransformerEngine.git@stable ``` -#### Optional: Install FlashAttention-2 (mainly NVIDIA GPU) +#### Unrecommended Optional: Install FlashAttention-2 (mainly NVIDIA GPU) > [!Note] > See supported devices at the [Hugging Face Doc](https://huggingface.co/docs/transformers/perf_infer_gpu_one#flashattention-2). diff --git a/requirements.txt b/requirements.txt index 8d8066224..75066bb96 100644 --- a/requirements.txt +++ b/requirements.txt @@ -14,5 +14,3 @@ WeTextProcessing; sys_platform == 'linux' nemo_text_processing; sys_platform == 'linux' av pydub -safetensors -vllm>=0.2.7; sys_platform == 'linux' diff --git a/setup.py b/setup.py index c5fe0a69e..da7b609e6 100644 --- a/setup.py +++ b/setup.py @@ -28,7 +28,6 @@ "transformers>=4.41.1", "vector_quantize_pytorch", "vocos", - "safetensors", ], platforms="any", classifiers=[ From 4f72f4a23da353bd780133a83241d6dc9f694861 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E6=BA=90=E6=96=87=E9=9B=A8?= <41315874+fumiama@users.noreply.github.com> Date: Thu, 25 Jul 2024 01:39:40 +0800 Subject: [PATCH 19/27] fix(vLLM): importlib relative import --- ChatTTS/model/velocity/model_loader.py | 16 +++------------- 1 file changed, 3 insertions(+), 13 deletions(-) diff --git a/ChatTTS/model/velocity/model_loader.py b/ChatTTS/model/velocity/model_loader.py index 2007a96a2..01c1883f5 100644 --- a/ChatTTS/model/velocity/model_loader.py +++ b/ChatTTS/model/velocity/model_loader.py @@ -1,16 +1,15 @@ """Utilities for selecting and loading models.""" import contextlib -from typing import Type import torch import torch.nn as nn -from transformers import PretrainedConfig from vllm.config import ModelConfig from vllm.model_executor.models import ModelRegistry from vllm.model_executor.weight_utils import get_quant_config, initialize_dummy_weights -import importlib + +from .llama import LlamaModel @contextlib.contextmanager @@ -22,16 +21,7 @@ def _set_default_torch_dtype(dtype: torch.dtype): torch.set_default_dtype(old_dtype) -def _get_model_architecture(config: PretrainedConfig) -> Type[nn.Module]: - model_cls = getattr( - importlib.import_module(".llama"), "LlamaModel", None - ) - return model_cls - - def get_model(model_config: ModelConfig) -> nn.Module: - model_class = _get_model_architecture(model_config.hf_config) - # Get the (maybe quantized) linear method. linear_method = None if model_config.quantization is not None: @@ -63,7 +53,7 @@ def get_model(model_config: ModelConfig) -> nn.Module: # Create a model instance. # The weights will be initialized as empty tensors. 
with torch.device("cuda"): - model = model_class(model_config.hf_config, linear_method) + model = LlamaModel(model_config.hf_config, linear_method) if model_config.load_format == "dummy": # NOTE(woosuk): For accurate performance evaluation, we assign # random values to the weights. From 00cd9436eff71c4f12f0539f4571d7a8ac19778c Mon Sep 17 00:00:00 2001 From: YuriHead Date: Thu, 25 Jul 2024 03:15:15 +0800 Subject: [PATCH 20/27] chore: optimize tensor padding in model_runner.py (#628) --- ChatTTS/model/velocity/model_runner.py | 21 ++++++++++++++------- 1 file changed, 14 insertions(+), 7 deletions(-) diff --git a/ChatTTS/model/velocity/model_runner.py b/ChatTTS/model/velocity/model_runner.py index 5becd0aa6..e1e8d038e 100644 --- a/ChatTTS/model/velocity/model_runner.py +++ b/ChatTTS/model/velocity/model_runner.py @@ -536,6 +536,9 @@ def execute_model( ) # print(hidden_states.shape) # print(input_tokens) + B_NO_PAD = input_tokens_history.shape[0] + input_tokens = input_tokens[:B_NO_PAD, :, :] + hidden_states = hidden_states[:B_NO_PAD, :, :] idx_next, logprob, finish = self.sampler.sample( inputs_ids=( input_tokens @@ -774,13 +777,17 @@ def _make_tensor_with_pad( device: Union[str, torch.device] = "cuda", pin_memory: bool = False, ) -> torch.Tensor: - padded_x = [_pad_to_max(x_i, max_len, pad) for x_i in x] - return torch.tensor( - padded_x, - dtype=dtype, - device=device, - pin_memory=pin_memory and str(device) == "cpu", - ) + padded_x = [] + for x_i in x: + pad_i = pad + if isinstance(x[0][0],tuple): + pad_i = (0,) * len(x[0][0]) + padded_x.append(_pad_to_max(x_i, max_len, pad_i)) + + return torch.tensor(padded_x, + dtype=dtype, + device=device, + pin_memory=pin_memory and str(device) == "cpu") def _get_graph_batch_size(batch_size: int) -> int: From 6a4c97b80a6275f6ebd1f71a51b763e78c5b873f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E6=BA=90=E6=96=87=E9=9B=A8?= <41315874+fumiama@users.noreply.github.com> Date: Thu, 25 Jul 2024 03:24:41 +0800 Subject: [PATCH 21/27] fix(core): no attribute 'GenerationOutputs' --- ChatTTS/core.py | 4 ++-- ChatTTS/model/velocity/model_runner.py | 4 +++- 2 files changed, 5 insertions(+), 3 deletions(-) diff --git a/ChatTTS/core.py b/ChatTTS/core.py index 6a29f049f..26b780d95 100644 --- a/ChatTTS/core.py +++ b/ChatTTS/core.py @@ -548,7 +548,7 @@ def _infer_code( del_all(logits_processors) return [ - self.GenerationOutputs(ids=token_ids, hiddens=hidden_states), + GPT.GenerationOutputs(ids=token_ids, hiddens=hidden_states), ] emb = gpt(input_ids, text_mask) @@ -638,7 +638,7 @@ def _refine_text( del_all(logits_warpers) del_all(logits_processors) - return self.GenerationOutputs(ids=token_ids, hiddens=hidden_states) + return GPT.GenerationOutputs(ids=token_ids, hiddens=hidden_states) emb = gpt(input_ids, text_mask) diff --git a/ChatTTS/model/velocity/model_runner.py b/ChatTTS/model/velocity/model_runner.py index e1e8d038e..073f59982 100644 --- a/ChatTTS/model/velocity/model_runner.py +++ b/ChatTTS/model/velocity/model_runner.py @@ -766,7 +766,9 @@ def __call__(self, *args, **kwargs): def _pad_to_max(x: List[int], max_len: int, pad: int) -> List[int]: assert len(x) <= max_len - return x + [pad] * (max_len - len(x)) + if len(x) == max_len: + return list(x) + return list(x) + [pad] * (max_len - len(x)) def _make_tensor_with_pad( From a21bafaf743b44455be815bea45cd66458da6d26 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E6=BA=90=E6=96=87=E9=9B=A8?= <41315874+fumiama@users.noreply.github.com> Date: Thu, 25 Jul 2024 03:26:37 +0800 Subject: [PATCH 22/27] fix(core): missig param --- 
ChatTTS/core.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/ChatTTS/core.py b/ChatTTS/core.py index 26b780d95..5824e8f96 100644 --- a/ChatTTS/core.py +++ b/ChatTTS/core.py @@ -548,7 +548,9 @@ def _infer_code( del_all(logits_processors) return [ - GPT.GenerationOutputs(ids=token_ids, hiddens=hidden_states), + GPT.GenerationOutputs( + ids=token_ids, hiddens=hidden_states, attentions=[], + ), ] emb = gpt(input_ids, text_mask) @@ -638,7 +640,8 @@ def _refine_text( del_all(logits_warpers) del_all(logits_processors) - return GPT.GenerationOutputs(ids=token_ids, hiddens=hidden_states) + return GPT.GenerationOutputs(ids=token_ids, hiddens=hidden_states, attentions=[], + ) emb = gpt(input_ids, text_mask) From 4991dfd93baeac7f55e9028eeff74b83e381072c Mon Sep 17 00:00:00 2001 From: "github-actions[bot]" <41898282+github-actions[bot]@users.noreply.github.com> Date: Thu, 25 Jul 2024 03:32:15 +0800 Subject: [PATCH 23/27] chore(format): run black on dev (#627) Co-authored-by: github-actions[bot] --- ChatTTS/core.py | 11 ++++++++--- ChatTTS/model/velocity/model_runner.py | 14 ++++++++------ 2 files changed, 16 insertions(+), 9 deletions(-) diff --git a/ChatTTS/core.py b/ChatTTS/core.py index 5824e8f96..bcbb804d1 100644 --- a/ChatTTS/core.py +++ b/ChatTTS/core.py @@ -549,7 +549,9 @@ def _infer_code( return [ GPT.GenerationOutputs( - ids=token_ids, hiddens=hidden_states, attentions=[], + ids=token_ids, + hiddens=hidden_states, + attentions=[], ), ] @@ -640,8 +642,11 @@ def _refine_text( del_all(logits_warpers) del_all(logits_processors) - return GPT.GenerationOutputs(ids=token_ids, hiddens=hidden_states, attentions=[], - ) + return GPT.GenerationOutputs( + ids=token_ids, + hiddens=hidden_states, + attentions=[], + ) emb = gpt(input_ids, text_mask) diff --git a/ChatTTS/model/velocity/model_runner.py b/ChatTTS/model/velocity/model_runner.py index 073f59982..a13df5990 100644 --- a/ChatTTS/model/velocity/model_runner.py +++ b/ChatTTS/model/velocity/model_runner.py @@ -782,14 +782,16 @@ def _make_tensor_with_pad( padded_x = [] for x_i in x: pad_i = pad - if isinstance(x[0][0],tuple): + if isinstance(x[0][0], tuple): pad_i = (0,) * len(x[0][0]) padded_x.append(_pad_to_max(x_i, max_len, pad_i)) - - return torch.tensor(padded_x, - dtype=dtype, - device=device, - pin_memory=pin_memory and str(device) == "cpu") + + return torch.tensor( + padded_x, + dtype=dtype, + device=device, + pin_memory=pin_memory and str(device) == "cpu", + ) def _get_graph_batch_size(batch_size: int) -> int: From 4a1962be95248f9db3f70779fb72dfed5bec4454 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E6=BA=90=E6=96=87=E9=9B=A8?= <41315874+fumiama@users.noreply.github.com> Date: Thu, 25 Jul 2024 12:00:29 +0800 Subject: [PATCH 24/27] feat(vLLM): add missing params in refine_text --- ChatTTS/core.py | 14 +++++++++----- ChatTTS/model/gpt.py | 3 +-- ChatTTS/model/velocity/__init__.py | 3 +++ 3 files changed, 13 insertions(+), 7 deletions(-) diff --git a/ChatTTS/core.py b/ChatTTS/core.py index bcbb804d1..411769c53 100644 --- a/ChatTTS/core.py +++ b/ChatTTS/core.py @@ -515,7 +515,7 @@ def _infer_code( ) if gpt.is_vllm: - from .model.velocity.sampling_params import SamplingParams + from .model.velocity import SamplingParams sample_params = SamplingParams( temperature=temperature, @@ -617,10 +617,13 @@ def _refine_text( ) if gpt.is_vllm: - from .model.velocity.sampling_params import SamplingParams + from .model.velocity import SamplingParams sample_params = SamplingParams( + repetition_penalty=params.repetition_penalty, 
temperature=params.temperature, + top_p=params.top_P, + top_k=params.top_K, max_new_token=params.max_new_token, max_tokens=8192, min_new_token=params.min_new_token, @@ -629,16 +632,17 @@ def _refine_text( infer_text=True, start_idx=input_ids.shape[-2], ) - input_ids = [i.tolist() for i in input_ids] + input_ids_list = [i.tolist() for i in input_ids] + del input_ids - result = gpt.llm.generate(None, sample_params, input_ids) + result = gpt.llm.generate(None, sample_params, input_ids_list) token_ids = [] hidden_states = [] for i in result: token_ids.append(torch.tensor(i.outputs[0].token_ids)) hidden_states.append(i.outputs[0].hidden_states) - del text_mask, input_ids + del text_mask, input_ids_list, result del_all(logits_warpers) del_all(logits_processors) diff --git a/ChatTTS/model/gpt.py b/ChatTTS/model/gpt.py index 5c9f63283..413ab12e3 100644 --- a/ChatTTS/model/gpt.py +++ b/ChatTTS/model/gpt.py @@ -93,8 +93,7 @@ def from_pretrained(self, file_path: str): if self.is_vllm and platform.system().lower() == "linux": from safetensors.torch import save_file - from .velocity.llm import LLM - from .velocity.post_model import PostModel + from .velocity import LLM, PostModel vllm_folder = Path(os.getcwd()) / "asset" / "vllm" if not os.path.exists(vllm_folder): diff --git a/ChatTTS/model/velocity/__init__.py b/ChatTTS/model/velocity/__init__.py index e69de29bb..866983506 100644 --- a/ChatTTS/model/velocity/__init__.py +++ b/ChatTTS/model/velocity/__init__.py @@ -0,0 +1,3 @@ +from .llm import LLM +from .post_model import PostModel +from .sampling_params import SamplingParams From 319d037ae84376f2250090593ba0161a3f06ab39 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E6=BA=90=E6=96=87=E9=9B=A8?= <41315874+fumiama@users.noreply.github.com> Date: Thu, 25 Jul 2024 12:01:20 +0800 Subject: [PATCH 25/27] feat(vLLM): add missing params in refine_text --- ChatTTS/core.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ChatTTS/core.py b/ChatTTS/core.py index 411769c53..75bcdec12 100644 --- a/ChatTTS/core.py +++ b/ChatTTS/core.py @@ -635,7 +635,7 @@ def _refine_text( input_ids_list = [i.tolist() for i in input_ids] del input_ids - result = gpt.llm.generate(None, sample_params, input_ids_list) + result = gpt.llm.generate(None, sample_params, input_ids_list, params.show_tqdm) token_ids = [] hidden_states = [] for i in result: From 2bfb0977d2e0f458b3ede2cab7f01fa673458893 Mon Sep 17 00:00:00 2001 From: YuriHead Date: Sat, 27 Jul 2024 18:01:19 +0800 Subject: [PATCH 26/27] chore: optimize tensor padding in model_runner.py (#639) --- ChatTTS/model/velocity/model_runner.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/ChatTTS/model/velocity/model_runner.py b/ChatTTS/model/velocity/model_runner.py index a13df5990..a850f1075 100644 --- a/ChatTTS/model/velocity/model_runner.py +++ b/ChatTTS/model/velocity/model_runner.py @@ -569,13 +569,16 @@ def execute_model( for i in range(idx_next.shape[0]): idx_next_i = idx_next[i, 0, :].cpu().tolist() logprob_i = logprob[i].cpu().tolist() + tmp_hidden_states = hidden_states[i].cpu() + if input_tokens[i].shape[-2] != 1: + tmp_hidden_states = tmp_hidden_states[-1:,:] result = SequenceGroupOutput( samples=[ SequenceOutput( parent_seq_id=seq_groups[i], logprobs={tuple(idx_next_i): logprob_i}, output_token=tuple(idx_next_i), - hidden_states=hidden_states[i].cpu(), + hidden_states=tmp_hidden_states, finished=finish[i].item(), ), ], From 06b823be8b80ddd66fb1d42fd2a0fc6bf19c0203 Mon Sep 17 00:00:00 2001 From: "github-actions[bot]" 
<41898282+github-actions[bot]@users.noreply.github.com> Date: Sun, 28 Jul 2024 00:23:11 +0800 Subject: [PATCH 27/27] chore(format): run black on dev (#629) Co-authored-by: github-actions[bot] --- ChatTTS/core.py | 4 +++- ChatTTS/model/velocity/model_runner.py | 2 +- 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/ChatTTS/core.py b/ChatTTS/core.py index 75bcdec12..a4f4b6287 100644 --- a/ChatTTS/core.py +++ b/ChatTTS/core.py @@ -635,7 +635,9 @@ def _refine_text( input_ids_list = [i.tolist() for i in input_ids] del input_ids - result = gpt.llm.generate(None, sample_params, input_ids_list, params.show_tqdm) + result = gpt.llm.generate( + None, sample_params, input_ids_list, params.show_tqdm + ) token_ids = [] hidden_states = [] for i in result: diff --git a/ChatTTS/model/velocity/model_runner.py b/ChatTTS/model/velocity/model_runner.py index a850f1075..39d635866 100644 --- a/ChatTTS/model/velocity/model_runner.py +++ b/ChatTTS/model/velocity/model_runner.py @@ -571,7 +571,7 @@ def execute_model( logprob_i = logprob[i].cpu().tolist() tmp_hidden_states = hidden_states[i].cpu() if input_tokens[i].shape[-2] != 1: - tmp_hidden_states = tmp_hidden_states[-1:,:] + tmp_hidden_states = tmp_hidden_states[-1:, :] result = SequenceGroupOutput( samples=[ SequenceOutput(