diff --git a/ChatTTS/config/config.py b/ChatTTS/config/config.py index 1a3eeff14..b58fc261d 100644 --- a/ChatTTS/config/config.py +++ b/ChatTTS/config/config.py @@ -58,6 +58,7 @@ class GPT: spk_emb_dim: int = 192 spk_KL: bool = False num_audio_tokens: int = 626 + num_text_tokens: int = 21178 num_vq: int = 4 diff --git a/ChatTTS/core.py b/ChatTTS/core.py index fb00ba00d..a4f4b6287 100644 --- a/ChatTTS/core.py +++ b/ChatTTS/core.py @@ -5,14 +5,12 @@ from typing import Literal, Optional, List, Tuple, Dict, Union from json import load from pathlib import Path -import lzma import numpy as np import torch from vocos import Vocos from vocos.pretrained import instantiate_class from huggingface_hub import snapshot_download -import pybase16384 as b14 from .config import Config from .model import DVAE, GPT, gen_logits, Tokenizer @@ -130,6 +128,7 @@ def load( device: Optional[torch.device] = None, coef: Optional[torch.Tensor] = None, use_flash_attn=False, + use_vllm=False, ) -> bool: download_path = self.download_models(source, force_redownload, custom_path) if download_path is None: @@ -139,6 +138,7 @@ def load( compile=compile, coef=coef, use_flash_attn=use_flash_attn, + use_vllm=use_vllm, **{ k: os.path.join(download_path, v) for k, v in asdict(self.config.path).items() @@ -167,7 +167,7 @@ def sample_audio_speaker(self, wav: Union[np.ndarray, torch.Tensor]) -> str: @torch.no_grad() def _sample_random_speaker(self) -> torch.Tensor: - dim: int = self.gpt.gpt.layers[0].mlp.gate_proj.in_features + dim: int = self.config.gpt.hidden_size spk = ( torch.randn(dim, device=self.std.device, dtype=self.std.dtype) .mul_(self.std) @@ -246,11 +246,13 @@ def _load( compile: bool = True, coef: Optional[str] = None, use_flash_attn=False, + use_vllm=False, ): if device is None: device = select_device() self.logger.info("use device %s", str(device)) self.device = device + self.device_gpt = device if "mps" not in str(device) else torch.device("cpu") self.compile = compile feature_extractor = instantiate_class( @@ -293,6 +295,7 @@ def _load( gpt = GPT( gpt_config=asdict(self.config.gpt), use_flash_attn=use_flash_attn, + use_vllm=use_vllm, device=device, logger=self.logger, ).eval() @@ -300,6 +303,7 @@ def _load( gpt.from_pretrained(gpt_ckpt_path) gpt.prepare(compile=compile and "cuda" in str(device)) self.gpt = gpt + spk_stat_path = os.path.join(os.path.dirname(gpt_ckpt_path), "spk_stat.pt") assert os.path.exists(spk_stat_path), f"Missing spk_stat.pt: {spk_stat_path}" spk_stat: torch.Tensor = torch.load( @@ -469,7 +473,7 @@ def _infer_code( assert len(text), "text should not be empty" if not isinstance(params.temperature, list): - temperature = [params.temperature] * gpt.num_vq + temperature = [params.temperature] * self.config.gpt.num_vq else: temperature = params.temperature @@ -495,11 +499,62 @@ def _infer_code( input_ids, attention_mask, text_mask = self.tokenizer.encode( text, - self.gpt.num_vq, + self.config.gpt.num_vq, prompt_str=params.spk_smp, - device=gpt.device_gpt, + device=self.device_gpt, + ) + start_idx = input_ids.shape[-2] + + num_code = self.config.gpt.num_audio_tokens - 1 + + logits_warpers, logits_processors = gen_logits( + num_code=num_code, + top_P=params.top_P, + top_K=params.top_K, + repetition_penalty=params.repetition_penalty, ) + if gpt.is_vllm: + from .model.velocity import SamplingParams + + sample_params = SamplingParams( + temperature=temperature, + max_new_token=params.max_new_token, + max_tokens=8192, + min_new_token=params.min_new_token, + logits_processors=(logits_processors, 
logits_warpers), + eos_token=num_code, + infer_text=False, + start_idx=start_idx, + ) + input_ids = [i.tolist() for i in input_ids] + + result = gpt.llm.generate( + None, + sample_params, + input_ids, + ) + + token_ids = [] + hidden_states = [] + for i in result: + token_ids.append(torch.tensor(i.outputs[0].token_ids)) + hidden_states.append( + i.outputs[0].hidden_states.to(torch.float32).to(self.device) + ) + + del text_mask, input_ids + del_all(logits_warpers) + del_all(logits_processors) + + return [ + GPT.GenerationOutputs( + ids=token_ids, + hiddens=hidden_states, + attentions=[], + ), + ] + emb = gpt(input_ids, text_mask) del text_mask @@ -509,15 +564,6 @@ def _infer_code( emb, params.spk_emb, input_ids, self.gpt.device_gpt ) - num_code = int(gpt.emb_code[0].num_embeddings - 1) - - logits_warpers, logits_processors = gen_logits( - num_code=num_code, - top_P=params.top_P, - top_K=params.top_K, - repetition_penalty=params.repetition_penalty, - ) - result = gpt.generate( emb, input_ids, @@ -526,8 +572,7 @@ def _infer_code( attention_mask=attention_mask, max_new_token=params.max_new_token, min_new_token=params.min_new_token, - logits_warpers=logits_warpers, - logits_processors=logits_processors, + logits_processors=(*logits_processors, *logits_warpers), infer_text=False, return_hidden=return_hidden, stream=stream, @@ -560,8 +605,8 @@ def _refine_text( input_ids, attention_mask, text_mask = self.tokenizer.encode( text, - self.gpt.num_vq, - device=gpt.device_gpt, + self.config.gpt.num_vq, + device=self.device_gpt, ) logits_warpers, logits_processors = gen_logits( @@ -571,6 +616,44 @@ def _refine_text( repetition_penalty=params.repetition_penalty, ) + if gpt.is_vllm: + from .model.velocity import SamplingParams + + sample_params = SamplingParams( + repetition_penalty=params.repetition_penalty, + temperature=params.temperature, + top_p=params.top_P, + top_k=params.top_K, + max_new_token=params.max_new_token, + max_tokens=8192, + min_new_token=params.min_new_token, + logits_processors=(logits_processors, logits_warpers), + eos_token=self.tokenizer.eos_token, + infer_text=True, + start_idx=input_ids.shape[-2], + ) + input_ids_list = [i.tolist() for i in input_ids] + del input_ids + + result = gpt.llm.generate( + None, sample_params, input_ids_list, params.show_tqdm + ) + token_ids = [] + hidden_states = [] + for i in result: + token_ids.append(torch.tensor(i.outputs[0].token_ids)) + hidden_states.append(i.outputs[0].hidden_states) + + del text_mask, input_ids_list, result + del_all(logits_warpers) + del_all(logits_processors) + + return GPT.GenerationOutputs( + ids=token_ids, + hiddens=hidden_states, + attentions=[], + ) + emb = gpt(input_ids, text_mask) del text_mask @@ -584,8 +667,7 @@ def _refine_text( attention_mask=attention_mask, max_new_token=params.max_new_token, min_new_token=params.min_new_token, - logits_warpers=logits_warpers, - logits_processors=logits_processors, + logits_processors=(*logits_processors, *logits_warpers), infer_text=True, stream=False, show_tqdm=params.show_tqdm, diff --git a/ChatTTS/model/gpt.py b/ChatTTS/model/gpt.py index 178c19446..413ab12e3 100644 --- a/ChatTTS/model/gpt.py +++ b/ChatTTS/model/gpt.py @@ -1,8 +1,9 @@ -import platform +import os, platform from dataclasses import dataclass import logging -from typing import Union, List, Optional, Tuple +from typing import Union, List, Optional, Tuple, Callable import gc +from pathlib import Path import torch import torch.nn as nn @@ -10,12 +11,11 @@ import torch.nn.utils.parametrize as P from 
torch.nn.utils.parametrizations import weight_norm from tqdm import tqdm -from transformers import LlamaModel, LlamaConfig, LogitsWarper +from transformers import LlamaModel, LlamaConfig from transformers.cache_utils import Cache from transformers.modeling_outputs import BaseModelOutputWithPast from transformers.utils import is_flash_attn_2_available -from .processors import CustomRepetitionPenaltyLogitsProcessorRepeat from ..utils import del_all @@ -23,11 +23,10 @@ class GPT(nn.Module): def __init__( self, gpt_config: dict, - num_audio_tokens: int = 626, - num_text_tokens: int = 21178, - num_vq=4, use_flash_attn=False, + use_vllm=False, device=torch.device("cpu"), + device_gpt=torch.device("cpu"), logger=logging.getLogger(__name__), ): super().__init__() @@ -35,34 +34,41 @@ def __init__( self.logger = logger self.device = device - self.device_gpt = device if "mps" not in str(device) else torch.device("cpu") + self.device_gpt = device_gpt - self.num_vq = num_vq - self.num_audio_tokens = num_audio_tokens + self.config = gpt_config + self.num_vq = int(gpt_config["num_vq"]) + self.num_audio_tokens = int(gpt_config["num_audio_tokens"]) + self.num_text_tokens = int(gpt_config["num_text_tokens"]) self.use_flash_attn = use_flash_attn + self.is_te_llama = False + self.is_vllm = use_vllm + + if self.is_vllm: + return self.gpt, self.llama_config = self._build_llama(gpt_config, self.device_gpt) - self.is_te_llama = False + self.model_dim = int(self.gpt.config.hidden_size) self.emb_code = nn.ModuleList( [ nn.Embedding( - num_audio_tokens, + self.num_audio_tokens, self.model_dim, device=self.device_gpt, ) - for _ in range(num_vq) + for _ in range(self.num_vq) ], ) self.emb_text = nn.Embedding( - num_text_tokens, self.model_dim, device=self.device_gpt + self.num_text_tokens, self.model_dim, device=self.device_gpt ) self.head_text = weight_norm( nn.Linear( self.model_dim, - num_text_tokens, + self.num_text_tokens, bias=False, device=device, ), @@ -73,7 +79,7 @@ def __init__( weight_norm( nn.Linear( self.model_dim, - num_audio_tokens, + self.num_audio_tokens, bias=False, device=device, ), @@ -84,6 +90,44 @@ def __init__( ) def from_pretrained(self, file_path: str): + if self.is_vllm and platform.system().lower() == "linux": + from safetensors.torch import save_file + + from .velocity import LLM, PostModel + + vllm_folder = Path(os.getcwd()) / "asset" / "vllm" + if not os.path.exists(vllm_folder): + self.logger.info("initializing vLLM model to %s", str(vllm_folder)) + vllm_folder.mkdir(mode=0o755, parents=True, exist_ok=True) + gpt = GPT(gpt_config=self.config) + gpt.from_pretrained(file_path) + gpt.gpt.save_pretrained(vllm_folder / "gpt") + post_model = ( + PostModel( + int(gpt.gpt.config.hidden_size), + self.num_audio_tokens, + self.num_text_tokens, + ) + .to(self.device) + .eval() + ) + post_model.emb_code = gpt.emb_code + post_model.emb_text = gpt.emb_text + post_model.head_text = gpt.head_text + post_model.head_code = gpt.head_code + save_file( + post_model.state_dict(), + vllm_folder / "post_model.safetensors", + ) + del post_model, gpt + self.llm = LLM( + model=str(vllm_folder / "gpt"), + num_audio_tokens=self.num_audio_tokens, + num_text_tokens=self.num_text_tokens, + post_model_path=vllm_folder / "post_model.safetensors", + ) + self.logger.info("vLLM model loaded") + return self.load_state_dict(torch.load(file_path, weights_only=True, mmap=True)) @@ -141,7 +185,7 @@ def _build_llama( def prepare(self, compile=False): if self.use_flash_attn and is_flash_attn_2_available(): self.gpt = 
self.gpt.to(dtype=torch.float16) - if compile and not self.is_te_llama: + if compile and not self.is_te_llama and not self.is_vllm: try: self.compile(backend="inductor", dynamic=True) self.gpt.compile(backend="inductor", dynamic=True) @@ -362,8 +406,9 @@ def generate( attention_mask: Optional[torch.Tensor] = None, max_new_token=2048, min_new_token=0, - logits_warpers: List[LogitsWarper] = [], - logits_processors: List[CustomRepetitionPenaltyLogitsProcessorRepeat] = [], + logits_processors: Tuple[ + Callable[[torch.LongTensor, torch.FloatTensor], torch.FloatTensor] + ] = (), infer_text=False, return_attn=False, return_hidden=False, @@ -527,9 +572,6 @@ def generate( for logitsProcessors in logits_processors: logits = logitsProcessors(logits_token, logits) - for logitsWarpers in logits_warpers: - logits = logitsWarpers(logits_token, logits) - del logits_token if i < min_new_token: @@ -587,7 +629,6 @@ def generate( attention_mask, max_new_token, min_new_token, - logits_warpers, logits_processors, infer_text, return_attn, diff --git a/ChatTTS/model/velocity/__init__.py b/ChatTTS/model/velocity/__init__.py new file mode 100644 index 000000000..866983506 --- /dev/null +++ b/ChatTTS/model/velocity/__init__.py @@ -0,0 +1,3 @@ +from .llm import LLM +from .post_model import PostModel +from .sampling_params import SamplingParams diff --git a/ChatTTS/model/velocity/block_manager.py b/ChatTTS/model/velocity/block_manager.py new file mode 100644 index 000000000..869a285c0 --- /dev/null +++ b/ChatTTS/model/velocity/block_manager.py @@ -0,0 +1,296 @@ +"""A block manager that manages token blocks.""" + +import enum +from typing import Dict, List, Optional, Set, Tuple + +from vllm.block import PhysicalTokenBlock +from .sequence import Sequence, SequenceGroup, SequenceStatus +from vllm.utils import Device + +# Mapping: logical block number -> physical block. +BlockTable = List[PhysicalTokenBlock] + + +class BlockAllocator: + """Manages free physical token blocks for a device. + + The allocator maintains a list of free blocks and allocates a block when + requested. When a block is freed, its reference count is decremented. If + the reference count becomes zero, the block is added back to the free list. + """ + + def __init__( + self, + device: Device, + block_size: int, + num_blocks: int, + ) -> None: + self.device = device + self.block_size = block_size + self.num_blocks = num_blocks + + # Initialize the free blocks. + self.free_blocks: BlockTable = [] + for i in range(num_blocks): + block = PhysicalTokenBlock( + device=device, block_number=i, block_size=block_size + ) + self.free_blocks.append(block) + + def allocate(self) -> PhysicalTokenBlock: + if not self.free_blocks: + raise ValueError("Out of memory! No free blocks are available.") + block = self.free_blocks.pop() + block.ref_count = 1 + return block + + def free(self, block: PhysicalTokenBlock) -> None: + if block.ref_count == 0: + raise ValueError(f"Double free! {block} is already freed.") + block.ref_count -= 1 + if block.ref_count == 0: + self.free_blocks.append(block) + + def get_num_free_blocks(self) -> int: + return len(self.free_blocks) + + +class AllocStatus(enum.Enum): + """Result for BlockSpaceManager.can_allocate + + 1. Ok: seq_group can be allocated now. + 2. Later: seq_group cannot be allocated. + The capacity of allocator is larger than seq_group required. + 3. Never: seq_group can never be allocated. + The seq_group is too large to allocated in GPU. 
+ """ + + OK = enum.auto() + LATER = enum.auto() + NEVER = enum.auto() + + +class BlockSpaceManager: + """Manages the mapping between logical and physical token blocks.""" + + def __init__( + self, + block_size: int, + num_gpu_blocks: int, + num_cpu_blocks: int, + watermark: float = 0.01, + sliding_window: Optional[int] = None, + ) -> None: + self.block_size = block_size + self.num_total_gpu_blocks = num_gpu_blocks + self.num_total_cpu_blocks = num_cpu_blocks + + self.block_sliding_window = None + if sliding_window is not None: + assert sliding_window % block_size == 0, (sliding_window, block_size) + self.block_sliding_window = sliding_window // block_size + + self.watermark = watermark + assert watermark >= 0.0 + + self.watermark_blocks = int(watermark * num_gpu_blocks) + self.gpu_allocator = BlockAllocator(Device.GPU, block_size, num_gpu_blocks) + self.cpu_allocator = BlockAllocator(Device.CPU, block_size, num_cpu_blocks) + # Mapping: seq_id -> BlockTable. + self.block_tables: Dict[int, BlockTable] = {} + + def can_allocate(self, seq_group: SequenceGroup) -> AllocStatus: + # FIXME(woosuk): Here we assume that all sequences in the group share + # the same prompt. This may not be true for preempted sequences. + seq = seq_group.get_seqs(status=SequenceStatus.WAITING)[0] + num_required_blocks = len(seq.logical_token_blocks) + if self.block_sliding_window is not None: + num_required_blocks = min(num_required_blocks, self.block_sliding_window) + num_free_gpu_blocks = self.gpu_allocator.get_num_free_blocks() + + # Use watermark to avoid frequent cache eviction. + if self.num_total_gpu_blocks - num_required_blocks < self.watermark_blocks: + return AllocStatus.NEVER + if num_free_gpu_blocks - num_required_blocks >= self.watermark_blocks: + return AllocStatus.OK + else: + return AllocStatus.LATER + + def allocate(self, seq_group: SequenceGroup) -> None: + # NOTE: Here we assume that all sequences in the group have the same + # prompt. + seq = seq_group.get_seqs(status=SequenceStatus.WAITING)[0] + + # Allocate new physical token blocks that will store the prompt tokens. + block_table: BlockTable = [] + for logical_idx in range(len(seq.logical_token_blocks)): + if ( + self.block_sliding_window is not None + and logical_idx >= self.block_sliding_window + ): + block = block_table[logical_idx % self.block_sliding_window] + else: + block = self.gpu_allocator.allocate() + # Set the reference counts of the token blocks. + block.ref_count = seq_group.num_seqs() + block_table.append(block) + + # Assign the block table for each sequence. + for seq in seq_group.get_seqs(status=SequenceStatus.WAITING): + self.block_tables[seq.seq_id] = block_table.copy() + + def can_append_slot(self, seq_group: SequenceGroup) -> bool: + # Simple heuristic: If there is at least one free block + # for each sequence, we can append. + num_free_gpu_blocks = self.gpu_allocator.get_num_free_blocks() + num_seqs = seq_group.num_seqs(status=SequenceStatus.RUNNING) + return num_seqs <= num_free_gpu_blocks + + def append_slot(self, seq: Sequence) -> Optional[Tuple[int, int]]: + """Allocate a physical slot for a new token.""" + logical_blocks = seq.logical_token_blocks + block_table = self.block_tables[seq.seq_id] + + if len(block_table) < len(logical_blocks): + if ( + self.block_sliding_window + and len(block_table) >= self.block_sliding_window + ): + # re-use a block + block_table.append( + block_table[len(block_table) % self.block_sliding_window] + ) + else: + # The sequence has a new logical block. + # Allocate a new physical block. 
+ block = self.gpu_allocator.allocate() + block_table.append(block) + return None + + # We want to append the token to the last physical block. + last_block = block_table[-1] + assert last_block.device == Device.GPU + if last_block.ref_count == 1: + # Not shared with other sequences. Appendable. + return None + else: + # The last block is shared with other sequences. + # Copy on Write: Allocate a new block and copy the tokens. + new_block = self.gpu_allocator.allocate() + block_table[-1] = new_block + self.gpu_allocator.free(last_block) + return last_block.block_number, new_block.block_number + + def fork(self, parent_seq: Sequence, child_seq: Sequence) -> None: + # NOTE: fork does not allocate a new physical block. + # Thus, it is always safe from OOM. + src_block_table = self.block_tables[parent_seq.seq_id] + self.block_tables[child_seq.seq_id] = src_block_table.copy() + for block in src_block_table: + block.ref_count += 1 + + def _get_physical_blocks( + self, seq_group: SequenceGroup + ) -> List[PhysicalTokenBlock]: + # NOTE: Here, we assume that the physical blocks are only shared by + # the sequences in the same group. + blocks: Set[PhysicalTokenBlock] = set() + for seq in seq_group.get_seqs(): + if seq.is_finished(): + continue + blocks.update(self.block_tables[seq.seq_id]) + return list(blocks) + + def can_swap_in(self, seq_group: SequenceGroup) -> bool: + blocks = self._get_physical_blocks(seq_group) + num_swapped_seqs = seq_group.num_seqs(status=SequenceStatus.SWAPPED) + num_free_blocks = self.gpu_allocator.get_num_free_blocks() + # NOTE: Conservatively, we assume that every sequence will allocate + # at least one free block right after the swap-in. + # NOTE: This should match the logic in can_append_slot(). + num_required_blocks = len(blocks) + num_swapped_seqs + return num_free_blocks - num_required_blocks >= self.watermark_blocks + + def swap_in(self, seq_group: SequenceGroup) -> Dict[int, int]: + # CPU block -> GPU block. + mapping: Dict[PhysicalTokenBlock, PhysicalTokenBlock] = {} + for seq in seq_group.get_seqs(status=SequenceStatus.SWAPPED): + new_block_table: BlockTable = [] + block_table = self.block_tables[seq.seq_id] + + for cpu_block in block_table: + if cpu_block in mapping: + gpu_block = mapping[cpu_block] + gpu_block.ref_count += 1 + else: + gpu_block = self.gpu_allocator.allocate() + mapping[cpu_block] = gpu_block + new_block_table.append(gpu_block) + # Free the CPU block swapped in to GPU. + self.cpu_allocator.free(cpu_block) + self.block_tables[seq.seq_id] = new_block_table + + block_number_mapping = { + cpu_block.block_number: gpu_block.block_number + for cpu_block, gpu_block in mapping.items() + } + return block_number_mapping + + def can_swap_out(self, seq_group: SequenceGroup) -> bool: + blocks = self._get_physical_blocks(seq_group) + return len(blocks) <= self.cpu_allocator.get_num_free_blocks() + + def swap_out(self, seq_group: SequenceGroup) -> Dict[int, int]: + # GPU block -> CPU block. + mapping: Dict[PhysicalTokenBlock, PhysicalTokenBlock] = {} + for seq in seq_group.get_seqs(status=SequenceStatus.RUNNING): + new_block_table: BlockTable = [] + block_table = self.block_tables[seq.seq_id] + + for gpu_block in block_table: + if gpu_block in mapping: + cpu_block = mapping[gpu_block] + cpu_block.ref_count += 1 + else: + cpu_block = self.cpu_allocator.allocate() + mapping[gpu_block] = cpu_block + new_block_table.append(cpu_block) + # Free the GPU block swapped out to CPU. 
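+                    # free() only decrements the reference count; the GPU block
+                    # rejoins the free list once no other sequence references it.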
+ self.gpu_allocator.free(gpu_block) + self.block_tables[seq.seq_id] = new_block_table + + block_number_mapping = { + gpu_block.block_number: cpu_block.block_number + for gpu_block, cpu_block in mapping.items() + } + return block_number_mapping + + def _free_block_table(self, block_table: BlockTable) -> None: + for block in set(block_table): + if block.device == Device.GPU: + self.gpu_allocator.free(block) + else: + self.cpu_allocator.free(block) + + def free(self, seq: Sequence) -> None: + if seq.seq_id not in self.block_tables: + # Already freed or haven't been scheduled yet. + return + block_table = self.block_tables[seq.seq_id] + self._free_block_table(block_table) + del self.block_tables[seq.seq_id] + + def reset(self) -> None: + for block_table in self.block_tables.values(): + self._free_block_table(block_table) + self.block_tables.clear() + + def get_block_table(self, seq: Sequence) -> List[int]: + block_table = self.block_tables[seq.seq_id] + return [block.block_number for block in block_table] + + def get_num_free_gpu_blocks(self) -> int: + return self.gpu_allocator.get_num_free_blocks() + + def get_num_free_cpu_blocks(self) -> int: + return self.cpu_allocator.get_num_free_blocks() diff --git a/ChatTTS/model/velocity/configs.py b/ChatTTS/model/velocity/configs.py new file mode 100644 index 000000000..c578f468a --- /dev/null +++ b/ChatTTS/model/velocity/configs.py @@ -0,0 +1,865 @@ +from typing import Optional, Union, Tuple +import os + +import torch +from transformers import PretrainedConfig + +from vllm.logger import init_logger +from vllm.transformers_utils.config import get_config +from vllm.utils import get_cpu_memory, is_hip + +import argparse +import dataclasses +from dataclasses import dataclass + + +logger = init_logger(__name__) + +_GB = 1 << 30 + + +class ModelConfig: + """Configuration for the model. + + Args: + model: Name or path of the huggingface model to use. + tokenizer: Name or path of the huggingface tokenizer to use. + tokenizer_mode: Tokenizer mode. "auto" will use the fast tokenizer if + available, and "slow" will always use the slow tokenizer. + trust_remote_code: Trust remote code (e.g., from HuggingFace) when + downloading the model and tokenizer. + download_dir: Directory to download and load the weights, default to the + default cache directory of huggingface. + load_format: The format of the model weights to load: + "auto" will try to load the weights in the safetensors format and + fall back to the pytorch bin format if safetensors format is + not available. + "pt" will load the weights in the pytorch bin format. + "safetensors" will load the weights in the safetensors format. + "npcache" will load the weights in pytorch format and store + a numpy cache to speed up the loading. + "dummy" will initialize the weights with random values, which is + mainly for profiling. + dtype: Data type for model weights and activations. The "auto" option + will use FP16 precision for FP32 and FP16 models, and BF16 precision + for BF16 models. + seed: Random seed for reproducibility. + revision: The specific model version to use. It can be a branch name, + a tag name, or a commit id. If unspecified, will use the default + version. + tokenizer_revision: The specific tokenizer version to use. It can be a + branch name, a tag name, or a commit id. If unspecified, will use + the default version. + max_model_len: Maximum length of a sequence (including prompt and + output). If None, will be derived from the model. 
+ quantization: Quantization method that was used to quantize the model + weights. If None, we assume the model weights are not quantized. + enforce_eager: Whether to enforce eager execution. If True, we will + disable CUDA graph and always execute the model in eager mode. + If False, we will use CUDA graph and eager execution in hybrid. + max_context_len_to_capture: Maximum context len covered by CUDA graphs. + When a sequence has context length larger than this, we fall back + to eager mode. + """ + + def __init__( + self, + model: str, + tokenizer: str, + tokenizer_mode: str, + trust_remote_code: bool, + download_dir: Optional[str], + load_format: str, + dtype: Union[str, torch.dtype], + seed: int, + revision: Optional[str] = None, + tokenizer_revision: Optional[str] = None, + max_model_len: Optional[int] = None, + quantization: Optional[str] = None, + enforce_eager: bool = False, + max_context_len_to_capture: Optional[int] = None, + num_audio_tokens: int = 1024, + num_text_tokens: int = 80, + ) -> None: + self.model = model + self.tokenizer = tokenizer + self.tokenizer_mode = tokenizer_mode + self.trust_remote_code = trust_remote_code + self.download_dir = download_dir + self.load_format = load_format + self.seed = seed + self.revision = revision + self.tokenizer_revision = tokenizer_revision + self.quantization = quantization + self.enforce_eager = enforce_eager + self.max_context_len_to_capture = max_context_len_to_capture + self.num_audio_tokens = num_audio_tokens + self.num_text_tokens = num_text_tokens + + if os.environ.get("VLLM_USE_MODELSCOPE", "False").lower() == "true": + # download model from ModelScope hub, + # lazy import so that modelscope is not required for normal use. + from modelscope.hub.snapshot_download import ( + snapshot_download, + ) # pylint: disable=C + + model_path = snapshot_download( + model_id=model, cache_dir=download_dir, revision=revision + ) + self.model = model_path + self.download_dir = model_path + self.tokenizer = model_path + + self.hf_config = get_config(self.model, trust_remote_code, revision) + self.dtype = _get_and_verify_dtype(self.hf_config, dtype) + self.max_model_len = _get_and_verify_max_len(self.hf_config, max_model_len) + self._verify_load_format() + self._verify_tokenizer_mode() + self._verify_quantization() + self._verify_cuda_graph() + + def _verify_load_format(self) -> None: + load_format = self.load_format.lower() + supported_load_format = ["auto", "pt", "safetensors", "npcache", "dummy"] + rocm_not_supported_load_format = [] + if load_format not in supported_load_format: + raise ValueError( + f"Unknown load format: {self.load_format}. Must be one of " + "'auto', 'pt', 'safetensors', 'npcache', or 'dummy'." + ) + if is_hip() and load_format in rocm_not_supported_load_format: + rocm_supported_load_format = [ + f + for f in supported_load_format + if (f not in rocm_not_supported_load_format) + ] + raise ValueError( + f"load format '{load_format}' is not supported in ROCm. " + f"Supported load format are " + f"{rocm_supported_load_format}" + ) + + # TODO: Remove this check once HF updates the pt weights of Mixtral. + architectures = getattr(self.hf_config, "architectures", []) + if "MixtralForCausalLM" in architectures and load_format == "pt": + raise ValueError( + "Currently, the 'pt' format is not supported for Mixtral. " + "Please use the 'safetensors' format instead. 
" + ) + self.load_format = load_format + + def _verify_tokenizer_mode(self) -> None: + tokenizer_mode = self.tokenizer_mode.lower() + if tokenizer_mode not in ["auto", "slow"]: + raise ValueError( + f"Unknown tokenizer mode: {self.tokenizer_mode}. Must be " + "either 'auto' or 'slow'." + ) + self.tokenizer_mode = tokenizer_mode + + def _verify_quantization(self) -> None: + supported_quantization = ["awq", "gptq", "squeezellm"] + rocm_not_supported_quantization = ["awq"] + if self.quantization is not None: + self.quantization = self.quantization.lower() + + # Parse quantization method from the HF model config, if available. + hf_quant_config = getattr(self.hf_config, "quantization_config", None) + if hf_quant_config is not None: + hf_quant_method = str(hf_quant_config["quant_method"]).lower() + if self.quantization is None: + self.quantization = hf_quant_method + elif self.quantization != hf_quant_method: + raise ValueError( + "Quantization method specified in the model config " + f"({hf_quant_method}) does not match the quantization " + f"method specified in the `quantization` argument " + f"({self.quantization})." + ) + + if self.quantization is not None: + if self.quantization not in supported_quantization: + raise ValueError( + f"Unknown quantization method: {self.quantization}. Must " + f"be one of {supported_quantization}." + ) + if is_hip() and self.quantization in rocm_not_supported_quantization: + raise ValueError( + f"{self.quantization} quantization is currently not supported " + f"in ROCm." + ) + logger.warning( + f"{self.quantization} quantization is not fully " + "optimized yet. The speed can be slower than " + "non-quantized models." + ) + + def _verify_cuda_graph(self) -> None: + if self.max_context_len_to_capture is None: + self.max_context_len_to_capture = self.max_model_len + self.max_context_len_to_capture = min( + self.max_context_len_to_capture, self.max_model_len + ) + + def verify_with_parallel_config( + self, + parallel_config: "ParallelConfig", + ) -> None: + total_num_attention_heads = self.hf_config.num_attention_heads + tensor_parallel_size = parallel_config.tensor_parallel_size + if total_num_attention_heads % tensor_parallel_size != 0: + raise ValueError( + f"Total number of attention heads ({total_num_attention_heads})" + " must be divisible by tensor parallel size " + f"({tensor_parallel_size})." + ) + + total_num_hidden_layers = self.hf_config.num_hidden_layers + pipeline_parallel_size = parallel_config.pipeline_parallel_size + if total_num_hidden_layers % pipeline_parallel_size != 0: + raise ValueError( + f"Total number of hidden layers ({total_num_hidden_layers}) " + "must be divisible by pipeline parallel size " + f"({pipeline_parallel_size})." + ) + + def get_sliding_window(self) -> Optional[int]: + return getattr(self.hf_config, "sliding_window", None) + + def get_vocab_size(self) -> int: + return self.hf_config.vocab_size + + def get_hidden_size(self) -> int: + return self.hf_config.hidden_size + + def get_head_size(self) -> int: + # FIXME(woosuk): This may not be true for all models. + return self.hf_config.hidden_size // self.hf_config.num_attention_heads + + def get_total_num_kv_heads(self) -> int: + """Returns the total number of KV heads.""" + # For GPTBigCode & Falcon: + # NOTE: for falcon, when new_decoder_architecture is True, the + # multi_query flag is ignored and we use n_head_kv for the number of + # KV heads. 
+ falcon_model_types = ["falcon", "RefinedWeb", "RefinedWebModel"] + new_decoder_arch_falcon = ( + self.hf_config.model_type in falcon_model_types + and getattr(self.hf_config, "new_decoder_architecture", False) + ) + if not new_decoder_arch_falcon and getattr( + self.hf_config, "multi_query", False + ): + # Multi-query attention, only one KV head. + # Currently, tensor parallelism is not supported in this case. + return 1 + + attributes = [ + # For Falcon: + "n_head_kv", + "num_kv_heads", + # For LLaMA-2: + "num_key_value_heads", + # For ChatGLM: + "multi_query_group_num", + ] + for attr in attributes: + num_kv_heads = getattr(self.hf_config, attr, None) + if num_kv_heads is not None: + return num_kv_heads + + # For non-grouped-query attention models, the number of KV heads is + # equal to the number of attention heads. + return self.hf_config.num_attention_heads + + def get_num_kv_heads(self, parallel_config: "ParallelConfig") -> int: + """Returns the number of KV heads per GPU.""" + total_num_kv_heads = self.get_total_num_kv_heads() + # If tensor parallelism is used, we divide the number of KV heads by + # the tensor parallel size. We will replicate the KV heads in the + # case where the number of KV heads is smaller than the tensor + # parallel size so each GPU has at least one KV head. + return max(1, total_num_kv_heads // parallel_config.tensor_parallel_size) + + def get_num_layers(self, parallel_config: "ParallelConfig") -> int: + total_num_hidden_layers = self.hf_config.num_hidden_layers + return total_num_hidden_layers // parallel_config.pipeline_parallel_size + + +class CacheConfig: + """Configuration for the KV cache. + + Args: + block_size: Size of a cache block in number of tokens. + gpu_memory_utilization: Fraction of GPU memory to use for the + vLLM execution. + swap_space: Size of the CPU swap space per GPU (in GiB). + """ + + def __init__( + self, + block_size: int, + gpu_memory_utilization: float, + swap_space: int, + sliding_window: Optional[int] = None, + ) -> None: + self.block_size = block_size + self.gpu_memory_utilization = gpu_memory_utilization + self.swap_space_bytes = swap_space * _GB + self.sliding_window = sliding_window + self._verify_args() + + # Will be set after profiling. + self.num_gpu_blocks = None + self.num_cpu_blocks = None + + def _verify_args(self) -> None: + if self.gpu_memory_utilization > 1.0: + raise ValueError( + "GPU memory utilization must be less than 1.0. Got " + f"{self.gpu_memory_utilization}." + ) + + def verify_with_parallel_config( + self, + parallel_config: "ParallelConfig", + ) -> None: + total_cpu_memory = get_cpu_memory() + # FIXME(woosuk): Here, it is assumed that the GPUs in a tensor parallel + # group are in the same node. However, the GPUs may span multiple nodes. + num_gpus_per_node = parallel_config.tensor_parallel_size + cpu_memory_usage = self.swap_space_bytes * num_gpus_per_node + + msg = ( + f"{cpu_memory_usage / _GB:.2f} GiB out of " + f"the {total_cpu_memory / _GB:.2f} GiB total CPU memory is " + "allocated for the swap space." + ) + if cpu_memory_usage > 0.7 * total_cpu_memory: + raise ValueError("Too large swap space. " + msg) + elif cpu_memory_usage > 0.4 * total_cpu_memory: + logger.warning("Possibly too large swap space. " + msg) + + +class ParallelConfig: + """Configuration for the distributed execution. + + Args: + pipeline_parallel_size: Number of pipeline parallel groups. + tensor_parallel_size: Number of tensor parallel groups. + worker_use_ray: Whether to use Ray for model workers. 
Will be set to + True if either pipeline_parallel_size or tensor_parallel_size is + greater than 1. + """ + + def __init__( + self, + pipeline_parallel_size: int, + tensor_parallel_size: int, + worker_use_ray: bool, + max_parallel_loading_workers: Optional[int] = None, + ) -> None: + self.pipeline_parallel_size = pipeline_parallel_size + self.tensor_parallel_size = tensor_parallel_size + self.worker_use_ray = worker_use_ray + self.max_parallel_loading_workers = max_parallel_loading_workers + + self.world_size = pipeline_parallel_size * tensor_parallel_size + if self.world_size > 1: + self.worker_use_ray = True + self._verify_args() + + def _verify_args(self) -> None: + if self.pipeline_parallel_size > 1: + raise NotImplementedError("Pipeline parallelism is not supported yet.") + + +class SchedulerConfig: + """Scheduler configuration. + + Args: + max_num_batched_tokens: Maximum number of tokens to be processed in + a single iteration. + max_num_seqs: Maximum number of sequences to be processed in a single + iteration. + max_model_len: Maximum length of a sequence (including prompt + and generated text). + max_paddings: Maximum number of paddings to be added to a batch. + """ + + def __init__( + self, + max_num_batched_tokens: Optional[int], + max_num_seqs: int, + max_model_len: int, + max_paddings: int, + ) -> None: + if max_num_batched_tokens is not None: + self.max_num_batched_tokens = max_num_batched_tokens + else: + # If max_model_len is too short, use 2048 as the default value for + # higher throughput. + self.max_num_batched_tokens = max(max_model_len, 2048) + self.max_num_seqs = max_num_seqs + self.max_model_len = max_model_len + self.max_paddings = max_paddings + self._verify_args() + + def _verify_args(self) -> None: + if self.max_num_batched_tokens < self.max_model_len: + raise ValueError( + f"max_num_batched_tokens ({self.max_num_batched_tokens}) is " + f"smaller than max_model_len ({self.max_model_len}). " + "This effectively limits the maximum sequence length to " + "max_num_batched_tokens and makes vLLM reject longer " + "sequences. Please increase max_num_batched_tokens or " + "decrease max_model_len." + ) + if self.max_num_batched_tokens < self.max_num_seqs: + raise ValueError( + f"max_num_batched_tokens ({self.max_num_batched_tokens}) must " + "be greater than or equal to max_num_seqs " + f"({self.max_num_seqs})." + ) + + +_STR_DTYPE_TO_TORCH_DTYPE = { + "half": torch.float16, + "float16": torch.float16, + "float": torch.float32, + "float32": torch.float32, + "bfloat16": torch.bfloat16, +} + +_ROCM_NOT_SUPPORTED_DTYPE = ["float", "float32"] + + +def _get_and_verify_dtype( + config: PretrainedConfig, + dtype: Union[str, torch.dtype], +) -> torch.dtype: + # NOTE: getattr(config, "torch_dtype", torch.float32) is not correct + # because config.torch_dtype can be None. + config_dtype = getattr(config, "torch_dtype", None) + if config_dtype is None: + config_dtype = torch.float32 + + if isinstance(dtype, str): + dtype = dtype.lower() + if dtype == "auto": + if config_dtype == torch.float32: + # Following the common practice, we use float16 for float32 + # models. 
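+                # Running fp32 checkpoints in fp16 halves weight and KV-cache
+                # memory, the usual trade-off for inference serving.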
+ torch_dtype = torch.float16 + else: + torch_dtype = config_dtype + else: + if dtype not in _STR_DTYPE_TO_TORCH_DTYPE: + raise ValueError(f"Unknown dtype: {dtype}") + torch_dtype = _STR_DTYPE_TO_TORCH_DTYPE[dtype] + elif isinstance(dtype, torch.dtype): + torch_dtype = dtype + else: + raise ValueError(f"Unknown dtype: {dtype}") + + if is_hip() and torch_dtype == torch.float32: + rocm_supported_dtypes = [ + k + for k, v in _STR_DTYPE_TO_TORCH_DTYPE.items() + if (k not in _ROCM_NOT_SUPPORTED_DTYPE) + ] + raise ValueError( + f"dtype '{dtype}' is not supported in ROCm. " + f"Supported dtypes are {rocm_supported_dtypes}" + ) + + # Verify the dtype. + if torch_dtype != config_dtype: + if torch_dtype == torch.float32: + # Upcasting to float32 is allowed. + pass + elif config_dtype == torch.float32: + # Downcasting from float32 to float16 or bfloat16 is allowed. + pass + else: + # Casting between float16 and bfloat16 is allowed with a warning. + logger.warning(f"Casting {config_dtype} to {torch_dtype}.") + + return torch_dtype + + +def _get_and_verify_max_len( + hf_config: PretrainedConfig, + max_model_len: Optional[int], +) -> int: + """Get and verify the model's maximum length.""" + derived_max_model_len = float("inf") + possible_keys = [ + # OPT + "max_position_embeddings", + # GPT-2 + "n_positions", + # MPT + "max_seq_len", + # ChatGLM2 + "seq_length", + # Others + "max_sequence_length", + "max_seq_length", + "seq_len", + ] + for key in possible_keys: + max_len_key = getattr(hf_config, key, None) + if max_len_key is not None: + derived_max_model_len = min(derived_max_model_len, max_len_key) + if derived_max_model_len == float("inf"): + if max_model_len is not None: + # If max_model_len is specified, we use it. + return max_model_len + + default_max_len = 2048 + logger.warning( + "The model's config.json does not contain any of the following " + "keys to determine the original maximum length of the model: " + f"{possible_keys}. Assuming the model's maximum length is " + f"{default_max_len}." + ) + derived_max_model_len = default_max_len + + rope_scaling = getattr(hf_config, "rope_scaling", None) + if rope_scaling is not None: + assert "factor" in rope_scaling + scaling_factor = rope_scaling["factor"] + if rope_scaling["type"] == "yarn": + derived_max_model_len = rope_scaling["original_max_position_embeddings"] + derived_max_model_len *= scaling_factor + + if max_model_len is None: + max_model_len = derived_max_model_len + elif max_model_len > derived_max_model_len: + raise ValueError( + f"User-specified max_model_len ({max_model_len}) is greater than " + f"the derived max_model_len ({max_len_key}={derived_max_model_len}" + " in model's config.json). This may lead to incorrect model " + "outputs or CUDA errors. Make sure the value is correct and " + "within the model context size." 
+ ) + return int(max_model_len) + + +@dataclass +class EngineArgs: + """Arguments for vLLM engine.""" + + model: str + tokenizer: Optional[str] = None + tokenizer_mode: str = "auto" + trust_remote_code: bool = False + download_dir: Optional[str] = None + load_format: str = "auto" + dtype: str = "auto" + seed: int = 0 + max_model_len: Optional[int] = None + worker_use_ray: bool = False + pipeline_parallel_size: int = 1 + tensor_parallel_size: int = 1 + max_parallel_loading_workers: Optional[int] = None + block_size: int = 16 + swap_space: int = 4 # GiB + gpu_memory_utilization: float = 0.90 + max_num_batched_tokens: Optional[int] = None + max_num_seqs: int = 256 + max_paddings: int = 256 + disable_log_stats: bool = False + revision: Optional[str] = None + tokenizer_revision: Optional[str] = None + quantization: Optional[str] = None + enforce_eager: bool = False + max_context_len_to_capture: int = 8192 + num_audio_tokens: int = 1024 + num_text_tokens: int = 80 + + def __post_init__(self): + if self.tokenizer is None: + self.tokenizer = self.model + + @staticmethod + def add_cli_args(parser: argparse.ArgumentParser) -> argparse.ArgumentParser: + """Shared CLI arguments for vLLM engine.""" + + # NOTE: If you update any of the arguments below, please also + # make sure to update docs/source/models/engine_args.rst + + # Model arguments + parser.add_argument( + "--model", + type=str, + default="facebook/opt-125m", + help="name or path of the huggingface model to use", + ) + parser.add_argument( + "--tokenizer", + type=str, + default=EngineArgs.tokenizer, + help="name or path of the huggingface tokenizer to use", + ) + parser.add_argument( + "--revision", + type=str, + default=None, + help="the specific model version to use. It can be a branch " + "name, a tag name, or a commit id. If unspecified, will use " + "the default version.", + ) + parser.add_argument( + "--tokenizer-revision", + type=str, + default=None, + help="the specific tokenizer version to use. It can be a branch " + "name, a tag name, or a commit id. If unspecified, will use " + "the default version.", + ) + parser.add_argument( + "--tokenizer-mode", + type=str, + default=EngineArgs.tokenizer_mode, + choices=["auto", "slow"], + help='tokenizer mode. "auto" will use the fast ' + 'tokenizer if available, and "slow" will ' + "always use the slow tokenizer.", + ) + parser.add_argument( + "--trust-remote-code", + action="store_true", + help="trust remote code from huggingface", + ) + parser.add_argument( + "--download-dir", + type=str, + default=EngineArgs.download_dir, + help="directory to download and load the weights, " + "default to the default cache dir of " + "huggingface", + ) + parser.add_argument( + "--load-format", + type=str, + default=EngineArgs.load_format, + choices=["auto", "pt", "safetensors", "npcache", "dummy"], + help="The format of the model weights to load. " + '"auto" will try to load the weights in the safetensors format ' + "and fall back to the pytorch bin format if safetensors format " + "is not available. " + '"pt" will load the weights in the pytorch bin format. ' + '"safetensors" will load the weights in the safetensors format. ' + '"npcache" will load the weights in pytorch format and store ' + "a numpy cache to speed up the loading. 
" + '"dummy" will initialize the weights with random values, ' + "which is mainly for profiling.", + ) + parser.add_argument( + "--dtype", + type=str, + default=EngineArgs.dtype, + choices=["auto", "half", "float16", "bfloat16", "float", "float32"], + help="data type for model weights and activations. " + 'The "auto" option will use FP16 precision ' + "for FP32 and FP16 models, and BF16 precision " + "for BF16 models.", + ) + parser.add_argument( + "--max-model-len", + type=int, + default=None, + help="model context length. If unspecified, " + "will be automatically derived from the model.", + ) + # Parallel arguments + parser.add_argument( + "--worker-use-ray", + action="store_true", + help="use Ray for distributed serving, will be " + "automatically set when using more than 1 GPU", + ) + parser.add_argument( + "--pipeline-parallel-size", + "-pp", + type=int, + default=EngineArgs.pipeline_parallel_size, + help="number of pipeline stages", + ) + parser.add_argument( + "--tensor-parallel-size", + "-tp", + type=int, + default=EngineArgs.tensor_parallel_size, + help="number of tensor parallel replicas", + ) + parser.add_argument( + "--max-parallel-loading-workers", + type=int, + help="load model sequentially in multiple batches, " + "to avoid RAM OOM when using tensor " + "parallel and large models", + ) + # KV cache arguments + parser.add_argument( + "--block-size", + type=int, + default=EngineArgs.block_size, + choices=[8, 16, 32], + help="token block size", + ) + # TODO(woosuk): Support fine-grained seeds (e.g., seed per request). + parser.add_argument( + "--seed", type=int, default=EngineArgs.seed, help="random seed" + ) + parser.add_argument( + "--swap-space", + type=int, + default=EngineArgs.swap_space, + help="CPU swap space size (GiB) per GPU", + ) + parser.add_argument( + "--gpu-memory-utilization", + type=float, + default=EngineArgs.gpu_memory_utilization, + help="the fraction of GPU memory to be used for " + "the model executor, which can range from 0 to 1." + "If unspecified, will use the default value of 0.9.", + ) + parser.add_argument( + "--max-num-batched-tokens", + type=int, + default=EngineArgs.max_num_batched_tokens, + help="maximum number of batched tokens per " "iteration", + ) + parser.add_argument( + "--max-num-seqs", + type=int, + default=EngineArgs.max_num_seqs, + help="maximum number of sequences per iteration", + ) + parser.add_argument( + "--max-paddings", + type=int, + default=EngineArgs.max_paddings, + help="maximum number of paddings in a batch", + ) + parser.add_argument( + "--disable-log-stats", + action="store_true", + help="disable logging statistics", + ) + # Quantization settings. + parser.add_argument( + "--quantization", + "-q", + type=str, + choices=["awq", "gptq", "squeezellm", None], + default=None, + help="Method used to quantize the weights. If " + "None, we first check the `quantization_config` " + "attribute in the model config file. If that is " + "None, we assume the model weights are not " + "quantized and use `dtype` to determine the data " + "type of the weights.", + ) + parser.add_argument( + "--enforce-eager", + action="store_true", + help="Always use eager-mode PyTorch. If False, " + "will use eager mode and CUDA graph in hybrid " + "for maximal performance and flexibility.", + ) + parser.add_argument( + "--max-context-len-to-capture", + type=int, + default=EngineArgs.max_context_len_to_capture, + help="maximum context length covered by CUDA " + "graphs. 
When a sequence has context length " + "larger than this, we fall back to eager mode.", + ) + return parser + + @classmethod + def from_cli_args(cls, args: argparse.Namespace) -> "EngineArgs": + # Get the list of attributes of this dataclass. + attrs = [attr.name for attr in dataclasses.fields(cls)] + # Set the attributes from the parsed arguments. + engine_args = cls(**{attr: getattr(args, attr) for attr in attrs}) + return engine_args + + def create_engine_configs( + self, + ) -> Tuple[ModelConfig, CacheConfig, ParallelConfig, SchedulerConfig]: + model_config = ModelConfig( + self.model, + self.tokenizer, + self.tokenizer_mode, + self.trust_remote_code, + self.download_dir, + self.load_format, + self.dtype, + self.seed, + self.revision, + self.tokenizer_revision, + self.max_model_len, + self.quantization, + self.enforce_eager, + self.max_context_len_to_capture, + self.num_audio_tokens, + self.num_text_tokens, + ) + cache_config = CacheConfig( + self.block_size, + self.gpu_memory_utilization, + self.swap_space, + model_config.get_sliding_window(), + ) + parallel_config = ParallelConfig( + self.pipeline_parallel_size, + self.tensor_parallel_size, + self.worker_use_ray, + self.max_parallel_loading_workers, + ) + scheduler_config = SchedulerConfig( + self.max_num_batched_tokens, + self.max_num_seqs, + model_config.max_model_len, + self.max_paddings, + ) + return model_config, cache_config, parallel_config, scheduler_config + + +@dataclass +class AsyncEngineArgs(EngineArgs): + """Arguments for asynchronous vLLM engine.""" + + engine_use_ray: bool = False + disable_log_requests: bool = False + max_log_len: Optional[int] = None + + @staticmethod + def add_cli_args(parser: argparse.ArgumentParser) -> argparse.ArgumentParser: + parser = EngineArgs.add_cli_args(parser) + parser.add_argument( + "--engine-use-ray", + action="store_true", + help="use Ray to start the LLM engine in a " + "separate process as the server process.", + ) + parser.add_argument( + "--disable-log-requests", + action="store_true", + help="disable logging requests", + ) + parser.add_argument( + "--max-log-len", + type=int, + default=None, + help="max number of prompt characters or prompt " + "ID numbers being printed in log. " + "Default: unlimited.", + ) + return parser diff --git a/ChatTTS/model/velocity/llama.py b/ChatTTS/model/velocity/llama.py new file mode 100644 index 000000000..8e6c8a896 --- /dev/null +++ b/ChatTTS/model/velocity/llama.py @@ -0,0 +1,393 @@ +# coding=utf-8 +# Adapted from +# https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/llama/modeling_llama.py +# Copyright 2023 The vLLM team. +# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved. +# +# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX +# and OPT implementations in this library. It has been modified from its +# original forms to accommodate minor architectural differences compared +# to GPT-NeoX and OPT used by the Meta AI team that trained the model. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. +"""Inference-only LLaMA model compatible with HuggingFace weights.""" +from typing import Any, Dict, List, Optional, Tuple + +import torch +from torch import nn +from transformers import LlamaConfig + +from vllm.model_executor.input_metadata import InputMetadata +from vllm.model_executor.layers.activation import SiluAndMul +from vllm.model_executor.layers.attention import PagedAttention +from vllm.model_executor.layers.layernorm import RMSNorm +from vllm.model_executor.layers.linear import ( + LinearMethodBase, + MergedColumnParallelLinear, + QKVParallelLinear, + RowParallelLinear, +) +from vllm.model_executor.layers.rotary_embedding import get_rope +from vllm.model_executor.layers.sampler import Sampler +from vllm.model_executor.layers.vocab_parallel_embedding import ( + VocabParallelEmbedding, + ParallelLMHead, +) +from vllm.model_executor.parallel_utils.parallel_state import ( + get_tensor_model_parallel_world_size, +) +from vllm.model_executor.sampling_metadata import SamplingMetadata +from vllm.model_executor.weight_utils import ( + default_weight_loader, + hf_model_weights_iterator, +) +from vllm.sequence import SamplerOutput + +KVCache = Tuple[torch.Tensor, torch.Tensor] + + +class LlamaMLP(nn.Module): + + def __init__( + self, + hidden_size: int, + intermediate_size: int, + hidden_act: str, + linear_method: Optional[LinearMethodBase] = None, + ) -> None: + super().__init__() + self.gate_up_proj = MergedColumnParallelLinear( + hidden_size, + [intermediate_size] * 2, + bias=False, + linear_method=linear_method, + ) + self.down_proj = RowParallelLinear( + intermediate_size, hidden_size, bias=False, linear_method=linear_method + ) + if hidden_act != "silu": + raise ValueError( + f"Unsupported activation: {hidden_act}. " + "Only silu is supported for now." + ) + self.act_fn = SiluAndMul() + + def forward(self, x): + gate_up, _ = self.gate_up_proj(x) + x = self.act_fn(gate_up) + x, _ = self.down_proj(x) + return x + + +class LlamaAttention(nn.Module): + + def __init__( + self, + hidden_size: int, + num_heads: int, + num_kv_heads: int, + rope_theta: float = 10000, + rope_scaling: Optional[Dict[str, Any]] = None, + max_position_embeddings: int = 8192, + linear_method: Optional[LinearMethodBase] = None, + ) -> None: + super().__init__() + self.hidden_size = hidden_size + tp_size = get_tensor_model_parallel_world_size() + self.total_num_heads = num_heads + assert self.total_num_heads % tp_size == 0 + self.num_heads = self.total_num_heads // tp_size + self.total_num_kv_heads = num_kv_heads + if self.total_num_kv_heads >= tp_size: + # Number of KV heads is greater than TP size, so we partition + # the KV heads across multiple tensor parallel GPUs. + assert self.total_num_kv_heads % tp_size == 0 + else: + # Number of KV heads is less than TP size, so we replicate + # the KV heads across multiple tensor parallel GPUs. 
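+            # Replication requires the tensor-parallel size to be a multiple
+            # of the total number of KV heads.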
+ assert tp_size % self.total_num_kv_heads == 0 + self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size) + self.head_dim = hidden_size // self.total_num_heads + self.q_size = self.num_heads * self.head_dim + self.kv_size = self.num_kv_heads * self.head_dim + self.scaling = self.head_dim**-0.5 + self.rope_theta = rope_theta + self.max_position_embeddings = max_position_embeddings + + self.qkv_proj = QKVParallelLinear( + hidden_size, + self.head_dim, + self.total_num_heads, + self.total_num_kv_heads, + bias=False, + linear_method=linear_method, + ) + self.o_proj = RowParallelLinear( + self.total_num_heads * self.head_dim, + hidden_size, + bias=False, + linear_method=linear_method, + ) + + self.rotary_emb = get_rope( + self.head_dim, + rotary_dim=self.head_dim, + max_position=max_position_embeddings, + base=rope_theta, + rope_scaling=rope_scaling, + ) + self.attn = PagedAttention( + self.num_heads, self.head_dim, self.scaling, num_kv_heads=self.num_kv_heads + ) + + def forward( + self, + positions: torch.Tensor, + hidden_states: torch.Tensor, + kv_cache: KVCache, + input_metadata: InputMetadata, + ) -> torch.Tensor: + qkv, _ = self.qkv_proj(hidden_states) + q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) + q, k = self.rotary_emb(positions, q, k) + k_cache, v_cache = kv_cache + attn_output = self.attn(q, k, v, k_cache, v_cache, input_metadata) + output, _ = self.o_proj(attn_output) + return output + + +class LlamaDecoderLayer(nn.Module): + + def __init__( + self, + config: LlamaConfig, + linear_method: Optional[LinearMethodBase] = None, + ) -> None: + super().__init__() + self.hidden_size = config.hidden_size + rope_theta = getattr(config, "rope_theta", 10000) + rope_scaling = getattr(config, "rope_scaling", None) + max_position_embeddings = getattr(config, "max_position_embeddings", 8192) + self.self_attn = LlamaAttention( + hidden_size=self.hidden_size, + num_heads=config.num_attention_heads, + num_kv_heads=config.num_key_value_heads, + rope_theta=rope_theta, + rope_scaling=rope_scaling, + max_position_embeddings=max_position_embeddings, + linear_method=linear_method, + ) + self.mlp = LlamaMLP( + hidden_size=self.hidden_size, + intermediate_size=config.intermediate_size, + hidden_act=config.hidden_act, + linear_method=linear_method, + ) + self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.post_attention_layernorm = RMSNorm( + config.hidden_size, eps=config.rms_norm_eps + ) + + def forward( + self, + positions: torch.Tensor, + hidden_states: torch.Tensor, + kv_cache: KVCache, + input_metadata: InputMetadata, + residual: Optional[torch.Tensor], + ) -> Tuple[torch.Tensor, torch.Tensor]: + # Self Attention + if residual is None: + residual = hidden_states + hidden_states = self.input_layernorm(hidden_states) + else: + hidden_states, residual = self.input_layernorm(hidden_states, residual) + hidden_states = self.self_attn( + positions=positions, + hidden_states=hidden_states, + kv_cache=kv_cache, + input_metadata=input_metadata, + ) + + # Fully Connected + hidden_states, residual = self.post_attention_layernorm(hidden_states, residual) + hidden_states = self.mlp(hidden_states) + return hidden_states, residual + + +class LlamaModel(nn.Module): + + def __init__( + self, + config: LlamaConfig, + linear_method: Optional[LinearMethodBase] = None, + ) -> None: + super().__init__() + self.config = config + self.padding_idx = config.pad_token_id + self.vocab_size = config.vocab_size + self.embed_tokens = VocabParallelEmbedding( + 
config.vocab_size, + config.hidden_size, + ) + self.layers = nn.ModuleList( + [ + LlamaDecoderLayer(config, linear_method) + for _ in range(config.num_hidden_layers) + ] + ) + self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + + def forward( + self, + input_emb: torch.Tensor, + positions: torch.Tensor, + kv_caches: List[KVCache], + input_metadata: InputMetadata, + ) -> torch.Tensor: + hidden_states = input_emb + residual = None + for i in range(len(self.layers)): + layer = self.layers[i] + hidden_states, residual = layer( + positions, + hidden_states, + kv_caches[i], + input_metadata, + residual, + ) + hidden_states, _ = self.norm(hidden_states, residual) + return hidden_states + + def load_weights( + self, + model_name_or_path: str, + cache_dir: Optional[str] = None, + load_format: str = "auto", + revision: Optional[str] = None, + ): + stacked_params_mapping = [ + # (param_name, shard_name, shard_id) + ("qkv_proj", "q_proj", "q"), + ("qkv_proj", "k_proj", "k"), + ("qkv_proj", "v_proj", "v"), + ("gate_up_proj", "gate_proj", 0), + ("gate_up_proj", "up_proj", 1), + ] + params_dict = dict(self.named_parameters()) + for name, loaded_weight in hf_model_weights_iterator( + model_name_or_path, cache_dir, load_format, revision + ): + if "rotary_emb.inv_freq" in name: + continue + if "rotary_emb.cos_cached" in name or "rotary_emb.sin_cached" in name: + # Models trained using ColossalAI may include these tensors in + # the checkpoint. Skip them. + continue + for param_name, weight_name, shard_id in stacked_params_mapping: + if weight_name not in name: + continue + name = name.replace(weight_name, param_name) + # Skip loading extra bias for GPTQ models. + if name.endswith(".bias") and name not in params_dict: + continue + param = params_dict[name] + weight_loader = param.weight_loader + weight_loader(param, loaded_weight, shard_id) + break + else: + # Skip loading extra bias for GPTQ models. 
+ if name.endswith(".bias") and name not in params_dict: + continue + param = params_dict[name] + weight_loader = getattr(param, "weight_loader", default_weight_loader) + weight_loader(param, loaded_weight) + + +class LlamaForCausalLM(nn.Module): + + def __init__( + self, + config: LlamaConfig, + linear_method: Optional[LinearMethodBase] = None, + ) -> None: + super().__init__() + self.config = config + self.linear_method = linear_method + self.model = LlamaModel(config, linear_method) + self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size) + self.sampler = Sampler(config.vocab_size) + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + kv_caches: List[KVCache], + input_metadata: InputMetadata, + ) -> torch.Tensor: + hidden_states = self.model(input_ids, positions, kv_caches, input_metadata) + return hidden_states + + def sample( + self, + hidden_states: torch.Tensor, + sampling_metadata: SamplingMetadata, + ) -> Optional[SamplerOutput]: + next_tokens = self.sampler( + self.lm_head.weight, hidden_states, sampling_metadata + ) + return next_tokens + + def load_weights( + self, + model_name_or_path: str, + cache_dir: Optional[str] = None, + load_format: str = "auto", + revision: Optional[str] = None, + ): + stacked_params_mapping = [ + # (param_name, shard_name, shard_id) + ("qkv_proj", "q_proj", "q"), + ("qkv_proj", "k_proj", "k"), + ("qkv_proj", "v_proj", "v"), + ("gate_up_proj", "gate_proj", 0), + ("gate_up_proj", "up_proj", 1), + ] + params_dict = dict(self.named_parameters()) + for name, loaded_weight in hf_model_weights_iterator( + model_name_or_path, cache_dir, load_format, revision + ): + if "rotary_emb.inv_freq" in name: + continue + if "rotary_emb.cos_cached" in name or "rotary_emb.sin_cached" in name: + # Models trained using ColossalAI may include these tensors in + # the checkpoint. Skip them. + continue + for param_name, weight_name, shard_id in stacked_params_mapping: + if weight_name not in name: + continue + name = name.replace(weight_name, param_name) + # Skip loading extra bias for GPTQ models. + if name.endswith(".bias") and name not in params_dict: + continue + param = params_dict[name] + weight_loader = param.weight_loader + weight_loader(param, loaded_weight, shard_id) + break + else: + # Skip loading extra bias for GPTQ models. + if name.endswith(".bias") and name not in params_dict: + continue + param = params_dict[name] + weight_loader = getattr(param, "weight_loader", default_weight_loader) + weight_loader(param, loaded_weight) diff --git a/ChatTTS/model/velocity/llm.py b/ChatTTS/model/velocity/llm.py new file mode 100644 index 000000000..b473b562c --- /dev/null +++ b/ChatTTS/model/velocity/llm.py @@ -0,0 +1,213 @@ +from typing import List, Optional, Union + +from tqdm import tqdm +from transformers import PreTrainedTokenizer, PreTrainedTokenizerFast + +from .configs import EngineArgs +from .llm_engine import LLMEngine +from .output import RequestOutput +from .sampling_params import SamplingParams +from vllm.utils import Counter + + +class LLM: + """An LLM for generating texts from given prompts and sampling parameters. + + This class includes a tokenizer, a language model (possibly distributed + across multiple GPUs), and GPU memory space allocated for intermediate + states (aka KV cache). Given a batch of prompts and sampling parameters, + this class generates texts from the model, using an intelligent batching + mechanism and efficient memory management. + + NOTE: This class is intended to be used for offline inference. 
For online + serving, use the `AsyncLLMEngine` class instead. + NOTE: For the comprehensive list of arguments, see `EngineArgs`. + + Args: + model: The name or path of a HuggingFace Transformers model. + tokenizer: The name or path of a HuggingFace Transformers tokenizer. + tokenizer_mode: The tokenizer mode. "auto" will use the fast tokenizer + if available, and "slow" will always use the slow tokenizer. + trust_remote_code: Trust remote code (e.g., from HuggingFace) when + downloading the model and tokenizer. + tensor_parallel_size: The number of GPUs to use for distributed + execution with tensor parallelism. + dtype: The data type for the model weights and activations. Currently, + we support `float32`, `float16`, and `bfloat16`. If `auto`, we use + the `torch_dtype` attribute specified in the model config file. + However, if the `torch_dtype` in the config is `float32`, we will + use `float16` instead. + quantization: The method used to quantize the model weights. Currently, + we support "awq", "gptq" and "squeezellm". If None, we first check + the `quantization_config` attribute in the model config file. If + that is None, we assume the model weights are not quantized and use + `dtype` to determine the data type of the weights. + revision: The specific model version to use. It can be a branch name, + a tag name, or a commit id. + tokenizer_revision: The specific tokenizer version to use. It can be a + branch name, a tag name, or a commit id. + seed: The seed to initialize the random number generator for sampling. + gpu_memory_utilization: The ratio (between 0 and 1) of GPU memory to + reserve for the model weights, activations, and KV cache. Higher + values will increase the KV cache size and thus improve the model's + throughput. However, if the value is too high, it may cause out-of- + memory (OOM) errors. + swap_space: The size (GiB) of CPU memory per GPU to use as swap space. + This can be used for temporarily storing the states of the requests + when their `best_of` sampling parameters are larger than 1. If all + requests will have `best_of=1`, you can safely set this to 0. + Otherwise, too small values may cause out-of-memory (OOM) errors. + enforce_eager: Whether to enforce eager execution. If True, we will + disable CUDA graph and always execute the model in eager mode. + If False, we will use CUDA graph and eager execution in hybrid. + max_context_len_to_capture: Maximum context len covered by CUDA graphs. + When a sequence has context length larger than this, we fall back + to eager mode. 
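The parameters documented above are forwarded to `EngineArgs` by the constructor that follows, which also adds ChatTTS-specific fields such as `post_model_path` and the audio/text vocabulary sizes. A hedged construction sketch; every path and number below is a placeholder, not a value shipped with this patch:

```python
from ChatTTS.model.velocity.llm import LLM

# Placeholder locations and sizes; substitute the real asset paths and take
# the token counts from the GPT config of the checkpoint being loaded.
GPT_DIR = "/path/to/asset/gpt"
POST_MODEL = "/path/to/asset/post_model.safetensors"
NUM_AUDIO_TOKENS = 626      # illustrative; must match the GPT config
NUM_TEXT_TOKENS = 21178     # illustrative; must match the GPT config

llm = LLM(
    model=GPT_DIR,
    post_model_path=POST_MODEL,
    num_audio_tokens=NUM_AUDIO_TOKENS,
    num_text_tokens=NUM_TEXT_TOKENS,
    gpu_memory_utilization=0.5,   # leave headroom for the other TTS stages
    enforce_eager=False,          # allow CUDA-graph capture for decode steps
)
```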
+ """ + + def __init__( + self, + model: str, + tokenizer: Optional[str] = None, + tokenizer_mode: str = "auto", + trust_remote_code: bool = False, + tensor_parallel_size: int = 1, + dtype: str = "auto", + quantization: Optional[str] = None, + revision: Optional[str] = None, + tokenizer_revision: Optional[str] = None, + seed: int = 0, + gpu_memory_utilization: float = 0.9, + swap_space: int = 4, + enforce_eager: bool = False, + max_context_len_to_capture: int = 8192, + post_model_path: str = None, + num_audio_tokens: int = 0, + num_text_tokens: int = 0, + **kwargs, + ) -> None: + if "disable_log_stats" not in kwargs: + kwargs["disable_log_stats"] = True + engine_args = EngineArgs( + model=model, + tokenizer=tokenizer, + tokenizer_mode=tokenizer_mode, + trust_remote_code=trust_remote_code, + tensor_parallel_size=tensor_parallel_size, + dtype=dtype, + quantization=quantization, + revision=revision, + tokenizer_revision=tokenizer_revision, + seed=seed, + gpu_memory_utilization=gpu_memory_utilization, + swap_space=swap_space, + enforce_eager=enforce_eager, + max_context_len_to_capture=max_context_len_to_capture, + num_audio_tokens=num_audio_tokens, + num_text_tokens=num_text_tokens, + **kwargs, + ) + self.llm_engine = LLMEngine.from_engine_args(engine_args, post_model_path) + self.request_counter = Counter() + + def get_tokenizer(self) -> Union[PreTrainedTokenizer, PreTrainedTokenizerFast]: + return self.llm_engine.tokenizer + + def set_tokenizer( + self, + tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast], + ) -> None: + self.llm_engine.tokenizer = tokenizer + + def generate( + self, + prompts: Optional[Union[str, List[str]]] = None, + sampling_params: Optional[SamplingParams] = None, + prompt_token_ids: Optional[List[List[int]]] = None, + use_tqdm: bool = True, + ) -> List[RequestOutput]: + """Generates the completions for the input prompts. + + NOTE: This class automatically batches the given prompts, considering + the memory constraint. For the best performance, put all of your prompts + into a single list and pass it to this method. + + Args: + prompts: A list of prompts to generate completions for. + sampling_params: The sampling parameters for text generation. If + None, we use the default sampling parameters. + prompt_token_ids: A list of token IDs for the prompts. If None, we + use the tokenizer to convert the prompts to token IDs. + use_tqdm: Whether to use tqdm to display the progress bar. + + Returns: + A list of `RequestOutput` objects containing the generated + completions in the same order as the input prompts. + """ + if prompts is None and prompt_token_ids is None: + raise ValueError("Either prompts or prompt_token_ids must be " "provided.") + if isinstance(prompts, str): + # Convert a single prompt to a list. + prompts = [prompts] + if ( + prompts is not None + and prompt_token_ids is not None + and len(prompts) != len(prompt_token_ids) + ): + raise ValueError( + "The lengths of prompts and prompt_token_ids " "must be the same." + ) + if sampling_params is None: + # Use default sampling params. + sampling_params = SamplingParams() + + # Add requests to the engine. 
+ num_requests = len(prompts) if prompts is not None else len(prompt_token_ids) + for i in range(num_requests): + prompt = prompts[i] if prompts is not None else None + token_ids = None if prompt_token_ids is None else prompt_token_ids[i] + self._add_request(prompt, sampling_params, token_ids) + + rtns = self._run_engine(use_tqdm) + for i, rtn in enumerate(rtns): + token_ids = rtn.outputs[0].token_ids + for j, token_id in enumerate(token_ids): + if len(token_id) == 1: + token_ids[j] = token_id[0] + else: + token_ids[j] = list(token_id) + + return rtns + + def _add_request( + self, + prompt: Optional[str], + sampling_params: SamplingParams, + prompt_token_ids: Optional[List[int]], + ) -> None: + request_id = str(next(self.request_counter)) + self.llm_engine.add_request( + request_id, prompt, sampling_params, prompt_token_ids + ) + + def _run_engine(self, use_tqdm: bool) -> List[RequestOutput]: + # Initialize tqdm. + if use_tqdm: + num_requests = self.llm_engine.get_num_unfinished_requests() + pbar = tqdm(total=num_requests, desc="Processed prompts") + # Run the engine. + outputs: List[RequestOutput] = [] + while self.llm_engine.has_unfinished_requests(): + step_outputs = self.llm_engine.step() + for output in step_outputs: + if output.finished: + outputs.append(output) + if use_tqdm: + pbar.update(1) + if use_tqdm: + pbar.close() + # Sort the outputs by request ID. + # This is necessary because some requests may be finished earlier than + # its previous requests. + outputs = sorted(outputs, key=lambda x: int(x.request_id)) + return outputs diff --git a/ChatTTS/model/velocity/llm_engine.py b/ChatTTS/model/velocity/llm_engine.py new file mode 100644 index 000000000..0d144d0fd --- /dev/null +++ b/ChatTTS/model/velocity/llm_engine.py @@ -0,0 +1,833 @@ +import copy +from collections import defaultdict +import os +import time +from typing import TYPE_CHECKING, Any, Dict, Iterable, List, Optional, Tuple, Union + +from vllm.config import CacheConfig, ModelConfig, ParallelConfig, SchedulerConfig +from .scheduler import Scheduler, SchedulerOutputs +from .configs import EngineArgs +from vllm.engine.metrics import record_metrics +from vllm.engine.ray_utils import RayWorkerVllm, initialize_cluster, ray +from vllm.logger import init_logger +from .output import RequestOutput +from .sampling_params import SamplingParams +from .sequence import ( + SamplerOutput, + Sequence, + SequenceGroup, + SequenceGroupOutput, + SequenceOutput, + SequenceStatus, +) +from vllm.transformers_utils.tokenizer import detokenize_incrementally, get_tokenizer +from vllm.utils import Counter, set_cuda_visible_devices, get_ip, get_open_port +import numpy as np + +if ray: + from ray.util.scheduling_strategies import PlacementGroupSchedulingStrategy + +if TYPE_CHECKING: + from ray.util.placement_group import PlacementGroup + +logger = init_logger(__name__) + +_LOGGING_INTERVAL_SEC = 5 + + +class LLMEngine: + """An LLM engine that receives requests and generates texts. + + This is the main class for the vLLM engine. It receives requests + from clients and generates texts from the LLM. It includes a tokenizer, a + language model (possibly distributed across multiple GPUs), and GPU memory + space allocated for intermediate states (aka KV cache). This class utilizes + iteration-level scheduling and efficient memory management to maximize the + serving throughput. + + The `LLM` class wraps this class for offline batched inference and the + `AsyncLLMEngine` class wraps this class for online serving. 
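The docstring continues below with the engine's configuration notes. For orientation, `_run_engine` above is essentially a drain loop over this engine; a minimal sketch of the same pattern written directly against an already-constructed `LLMEngine` (the `engine` variable is assumed):

```python
# Drain an LLMEngine after requests have been added via add_request().
finished = []
while engine.has_unfinished_requests():
    for request_output in engine.step():   # one schedule + decode iteration
        if request_output.finished:
            finished.append(request_output)
finished.sort(key=lambda r: int(r.request_id))  # restore submission order
```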
+ + NOTE: The config arguments are derived from the `EngineArgs` class. For the + comprehensive list of arguments, see `EngineArgs`. + + Args: + model_config: The configuration related to the LLM model. + cache_config: The configuration related to the KV cache memory + management. + parallel_config: The configuration related to distributed execution. + scheduler_config: The configuration related to the request scheduler. + placement_group: Ray placement group for distributed execution. + Required for distributed execution. + log_stats: Whether to log statistics. + """ + + def __init__( + self, + model_config: ModelConfig, + cache_config: CacheConfig, + parallel_config: ParallelConfig, + scheduler_config: SchedulerConfig, + placement_group: Optional["PlacementGroup"], + post_model_path: str, + log_stats: bool, + ) -> None: + logger.info( + "Initializing an LLM engine with config: " + f"model={model_config.model!r}, " + f"tokenizer={model_config.tokenizer!r}, " + f"tokenizer_mode={model_config.tokenizer_mode}, " + f"revision={model_config.revision}, " + f"tokenizer_revision={model_config.tokenizer_revision}, " + f"trust_remote_code={model_config.trust_remote_code}, " + f"dtype={model_config.dtype}, " + f"max_seq_len={model_config.max_model_len}, " + f"download_dir={model_config.download_dir!r}, " + f"load_format={model_config.load_format}, " + f"tensor_parallel_size={parallel_config.tensor_parallel_size}, " + f"quantization={model_config.quantization}, " + f"enforce_eager={model_config.enforce_eager}, " + f"seed={model_config.seed}), " + f"post_model_path={post_model_path!r}" + ) + # TODO(woosuk): Print more configs in debug mode. + + self.model_config = model_config + self.cache_config = cache_config + self.parallel_config = parallel_config + self.scheduler_config = scheduler_config + self.log_stats = log_stats + self._verify_args() + self.post_model_path = post_model_path + self.seq_counter = Counter() + + # Create the parallel GPU workers. + if self.parallel_config.worker_use_ray: + # Disable Ray usage stats collection. + ray_usage = os.environ.get("RAY_USAGE_STATS_ENABLED", "0") + if ray_usage != "1": + os.environ["RAY_USAGE_STATS_ENABLED"] = "0" + self._init_workers_ray(placement_group) + else: + self._init_workers() + + # Profile the memory usage and initialize the cache. + self._init_cache() + + # Create the scheduler. + self.scheduler = Scheduler(scheduler_config, cache_config) + + # Logging. + self.last_logging_time = 0.0 + # List of (timestamp, num_tokens) + self.num_prompt_tokens: List[Tuple[float, int]] = [] + # List of (timestamp, num_tokens) + self.num_generation_tokens: List[Tuple[float, int]] = [] + + def _init_workers(self): + # Lazy import the Worker to avoid importing torch.cuda/xformers + # before CUDA_VISIBLE_DEVICES is set in the Worker + from .worker import Worker + + assert ( + self.parallel_config.world_size == 1 + ), "Ray is required if parallel_config.world_size > 1." 
+ + self.workers: List[Worker] = [] + distributed_init_method = f"tcp://{get_ip()}:{get_open_port()}" + self.driver_worker = Worker( + self.model_config, + self.parallel_config, + self.scheduler_config, + local_rank=0, + rank=0, + distributed_init_method=distributed_init_method, + is_driver_worker=True, + post_model_path=self.post_model_path, + ) + self._run_workers("init_model") + self._run_workers("load_model") + + def _init_workers_ray(self, placement_group: "PlacementGroup", **ray_remote_kwargs): + if self.parallel_config.tensor_parallel_size == 1: + num_gpus = self.cache_config.gpu_memory_utilization + else: + num_gpus = 1 + + self.driver_dummy_worker: RayWorkerVllm = None + self.workers: List[RayWorkerVllm] = [] + + driver_ip = get_ip() + for bundle_id, bundle in enumerate(placement_group.bundle_specs): + if not bundle.get("GPU", 0): + continue + scheduling_strategy = PlacementGroupSchedulingStrategy( + placement_group=placement_group, + placement_group_capture_child_tasks=True, + placement_group_bundle_index=bundle_id, + ) + worker = ray.remote( + num_cpus=0, + num_gpus=num_gpus, + scheduling_strategy=scheduling_strategy, + **ray_remote_kwargs, + )(RayWorkerVllm).remote(self.model_config.trust_remote_code) + + worker_ip = ray.get(worker.get_node_ip.remote()) + if worker_ip == driver_ip and self.driver_dummy_worker is None: + # If the worker is on the same node as the driver, we use it + # as the resource holder for the driver process. + self.driver_dummy_worker = worker + else: + self.workers.append(worker) + + if self.driver_dummy_worker is None: + raise ValueError( + "Ray does not allocate any GPUs on the driver node. Consider " + "adjusting the Ray placement group or running the driver on a " + "GPU node." + ) + + driver_node_id, driver_gpu_ids = ray.get( + self.driver_dummy_worker.get_node_and_gpu_ids.remote() + ) + worker_node_and_gpu_ids = ray.get( + [worker.get_node_and_gpu_ids.remote() for worker in self.workers] + ) + + node_workers = defaultdict(list) + node_gpus = defaultdict(list) + + node_workers[driver_node_id].append(0) + node_gpus[driver_node_id].extend(driver_gpu_ids) + for i, (node_id, gpu_ids) in enumerate(worker_node_and_gpu_ids, start=1): + node_workers[node_id].append(i) + node_gpus[node_id].extend(gpu_ids) + for node_id, gpu_ids in node_gpus.items(): + node_gpus[node_id] = sorted(gpu_ids) + + # Set CUDA_VISIBLE_DEVICES for the driver. + set_cuda_visible_devices(node_gpus[driver_node_id]) + for worker, (node_id, _) in zip(self.workers, worker_node_and_gpu_ids): + worker.set_cuda_visible_devices.remote(node_gpus[node_id]) + + distributed_init_method = f"tcp://{driver_ip}:{get_open_port()}" + + # Lazy import the Worker to avoid importing torch.cuda/xformers + # before CUDA_VISIBLE_DEVICES is set in the Worker + from vllm.worker.worker import Worker + + # Initialize torch distributed process group for the workers. 
+ model_config = copy.deepcopy(self.model_config) + parallel_config = copy.deepcopy(self.parallel_config) + scheduler_config = copy.deepcopy(self.scheduler_config) + + for rank, (worker, (node_id, _)) in enumerate( + zip(self.workers, worker_node_and_gpu_ids), start=1 + ): + local_rank = node_workers[node_id].index(rank) + worker.init_worker.remote( + lambda rank=rank, local_rank=local_rank: Worker( + model_config, + parallel_config, + scheduler_config, + local_rank, + rank, + distributed_init_method, + ) + ) + + driver_rank = 0 + driver_local_rank = node_workers[driver_node_id].index(driver_rank) + self.driver_worker = Worker( + model_config, + parallel_config, + scheduler_config, + driver_local_rank, + driver_rank, + distributed_init_method, + is_driver_worker=True, + ) + + self._run_workers("init_model") + self._run_workers( + "load_model", + max_concurrent_workers=self.parallel_config.max_parallel_loading_workers, + ) + + def _verify_args(self) -> None: + self.model_config.verify_with_parallel_config(self.parallel_config) + self.cache_config.verify_with_parallel_config(self.parallel_config) + + def _init_cache(self) -> None: + """Profiles the memory usage and initializes the KV cache.""" + # Get the maximum number of blocks that can be allocated on GPU and CPU. + num_blocks = self._run_workers( + "profile_num_available_blocks", + block_size=self.cache_config.block_size, + gpu_memory_utilization=self.cache_config.gpu_memory_utilization, + cpu_swap_space=self.cache_config.swap_space_bytes, + ) + + # Since we use a shared centralized controller, we take the minimum + # number of blocks across all workers to make sure all the memory + # operators can be applied to all workers. + num_gpu_blocks = min(b[0] for b in num_blocks) + num_cpu_blocks = min(b[1] for b in num_blocks) + # FIXME(woosuk): Change to debug log. + logger.info( + f"# GPU blocks: {num_gpu_blocks}, " f"# CPU blocks: {num_cpu_blocks}" + ) + + if num_gpu_blocks <= 0: + raise ValueError( + "No available memory for the cache blocks. " + "Try increasing `gpu_memory_utilization` when " + "initializing the engine." + ) + max_seq_len = self.cache_config.block_size * num_gpu_blocks + if self.model_config.max_model_len > max_seq_len: + raise ValueError( + f"The model's max seq len ({self.model_config.max_model_len}) " + "is larger than the maximum number of tokens that can be " + f"stored in KV cache ({max_seq_len}). Try increasing " + "`gpu_memory_utilization` or decreasing `max_model_len` when " + "initializing the engine." + ) + + self.cache_config.num_gpu_blocks = num_gpu_blocks + self.cache_config.num_cpu_blocks = num_cpu_blocks + + # Initialize the cache. + self._run_workers("init_cache_engine", cache_config=self.cache_config) + # Warm up the model. This includes capturing the model into CUDA graph + # if enforce_eager is False. + self._run_workers("warm_up_model") + + @classmethod + def from_engine_args( + cls, engine_args: EngineArgs, post_model_path=None + ) -> "LLMEngine": + """Creates an LLM engine from the engine arguments.""" + # Create the engine configs. + engine_configs = engine_args.create_engine_configs() + parallel_config = engine_configs[2] + # Initialize the cluster. + placement_group = initialize_cluster(parallel_config) + # Create the LLM engine. 
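The engine construction continues just below. Before that, a worked instance of the capacity check `_init_cache` performs above; every number here is invented for illustration:

```python
block_size = 16          # tokens per KV-cache block
num_gpu_blocks = 2048    # min over workers, from profile_num_available_blocks
max_model_len = 4096

max_seq_len = block_size * num_gpu_blocks   # 32768 cacheable token slots
assert max_model_len <= max_seq_len, (
    "raise gpu_memory_utilization or lower max_model_len"
)
```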
+ engine = cls( + *engine_configs, + placement_group, + log_stats=not engine_args.disable_log_stats, + post_model_path=post_model_path, + ) + return engine + + def add_request( + self, + request_id: str, + prompt: Optional[str], + sampling_params: SamplingParams, + prompt_token_ids: Optional[List[int]] = None, + arrival_time: Optional[float] = None, + ) -> None: + """Add a request to the engine's request pool. + + The request is added to the request pool and will be processed by the + scheduler as `engine.step()` is called. The exact scheduling policy is + determined by the scheduler. + + Args: + request_id: The unique ID of the request. + prompt: The prompt string. Can be None if prompt_token_ids is + provided. + sampling_params: The sampling parameters for text generation. + prompt_token_ids: The token IDs of the prompt. If None, we + use the tokenizer to convert the prompts to token IDs. + arrival_time: The arrival time of the request. If None, we use + the current monotonic time. + """ + if arrival_time is None: + arrival_time = time.monotonic() + + assert prompt_token_ids is not None, "prompt_token_ids must be provided" + # Create the sequences. + block_size = self.cache_config.block_size + seq_id = next(self.seq_counter) + seq = Sequence(seq_id, prompt, prompt_token_ids, block_size) + + # Create the sequence group. + seq_group = SequenceGroup(request_id, [seq], sampling_params, arrival_time) + + # Add the sequence group to the scheduler. + self.scheduler.add_seq_group(seq_group) + + def abort_request(self, request_id: Union[str, Iterable[str]]) -> None: + """Aborts a request(s) with the given ID. + + Args: + request_id: The ID(s) of the request to abort. + """ + self.scheduler.abort_seq_group(request_id) + + def get_model_config(self) -> ModelConfig: + """Gets the model configuration.""" + return self.model_config + + def get_num_unfinished_requests(self) -> int: + """Gets the number of unfinished requests.""" + return self.scheduler.get_num_unfinished_seq_groups() + + def has_unfinished_requests(self) -> bool: + """Returns True if there are unfinished requests.""" + return self.scheduler.has_unfinished_seqs() + + def _check_beam_search_early_stopping( + self, + early_stopping: Union[bool, str], + sampling_params: SamplingParams, + best_running_seq: Sequence, + current_worst_seq: Sequence, + ) -> bool: + assert sampling_params.use_beam_search + length_penalty = sampling_params.length_penalty + if early_stopping is True: + return True + + current_worst_score = current_worst_seq.get_beam_search_score( + length_penalty=length_penalty, eos_token_id=self.tokenizer.eos_token_id + ) + if early_stopping is False: + highest_attainable_score = best_running_seq.get_beam_search_score( + length_penalty=length_penalty, eos_token_id=self.tokenizer.eos_token_id + ) + else: + assert early_stopping == "never" + if length_penalty > 0.0: + # If length_penalty > 0.0, beam search will prefer longer + # sequences. The highest attainable score calculation is + # based on the longest possible sequence length in this case. + max_possible_length = max( + best_running_seq.get_prompt_len() + sampling_params.max_tokens, + self.scheduler_config.max_model_len, + ) + highest_attainable_score = best_running_seq.get_beam_search_score( + length_penalty=length_penalty, + eos_token_id=self.tokenizer.eos_token_id, + seq_len=max_possible_length, + ) + else: + # Otherwise, beam search will prefer shorter sequences. The + # highest attainable score calculation is based on the current + # sequence length. 
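The call that follows computes that bound. For reference, a hedged sketch of the score being compared: in upstream vLLM, `Sequence.get_beam_search_score` is essentially a length-penalised cumulative log-probability (the real method can also override the length and discount a trailing EOS token):

```python
def beam_search_score(cumulative_logprob: float, seq_len: int,
                      length_penalty: float = 1.0) -> float:
    # Higher is better; length_penalty > 0 softens the per-token penalty for
    # longer hypotheses, which is why the upper bound above is evaluated at
    # the maximum possible length in that branch.
    return cumulative_logprob / (seq_len ** length_penalty)
```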
+ highest_attainable_score = best_running_seq.get_beam_search_score( + length_penalty=length_penalty, + eos_token_id=self.tokenizer.eos_token_id, + ) + return current_worst_score >= highest_attainable_score + + def _process_sequence_group_outputs( + self, seq_group: SequenceGroup, outputs: SequenceGroupOutput + ) -> None: + # Process prompt logprobs + prompt_logprobs = outputs.prompt_logprobs + if prompt_logprobs is not None: + seq_group.prompt_logprobs = prompt_logprobs + + # Process samples + samples = outputs.samples + parent_seqs = seq_group.get_seqs(status=SequenceStatus.RUNNING) + existing_finished_seqs = seq_group.get_finished_seqs() + parent_child_dict = {parent_seq.seq_id: [] for parent_seq in parent_seqs} + for sample in samples: + parent_child_dict[sample.parent_seq_id].append(sample) + # List of (child, parent) + child_seqs: List[Tuple[Sequence, Sequence]] = [] + + # Process the child samples for each parent sequence + for parent in parent_seqs: + child_samples: List[SequenceOutput] = parent_child_dict[parent.seq_id] + if len(child_samples) == 0: + # This parent sequence has no children samples. Remove + # the parent sequence from the sequence group since it will + # not be used in the future iterations. + parent.status = SequenceStatus.FINISHED_ABORTED + seq_group.remove(parent.seq_id) + self.scheduler.free_seq(parent) + continue + # Fork the parent sequence if there are multiple child samples. + for child_sample in child_samples[:-1]: + new_child_seq_id = next(self.seq_counter) + child = parent.fork(new_child_seq_id) + child.append_token_id( + child_sample.output_token, + child_sample.logprobs, + child_sample.hidden_states, + child_sample.finished, + ) + child_seqs.append((child, parent)) + # Continue the parent sequence for the last child sample. + # We reuse the parent sequence here to reduce redundant memory + # copies, especially when using non-beam search sampling methods. + last_child_sample = child_samples[-1] + parent.append_token_id( + last_child_sample.output_token, + last_child_sample.logprobs, + last_child_sample.hidden_states, + last_child_sample.finished, + ) + child_seqs.append((parent, parent)) + + for seq, _ in child_seqs: + # self._decode_sequence(seq, seq_group.sampling_params) + self._check_stop(seq, seq_group.sampling_params) + + # Non-beam search case + if not seq_group.sampling_params.use_beam_search: + # For newly created child sequences, add them to the sequence group + # and fork them in block manager if they are not finished. + for seq, parent in child_seqs: + if seq is not parent: + seq_group.add(seq) + if not seq.is_finished(): + self.scheduler.fork_seq(parent, seq) + + # Free the finished and selected parent sequences' memory in block + # manager. Keep them in the sequence group as candidate output. + # NOTE: we need to fork the new sequences before freeing the + # old sequences. + for seq, parent in child_seqs: + if seq is parent and seq.is_finished(): + self.scheduler.free_seq(seq) + return + + # Beam search case + # Select the child sequences to keep in the sequence group. + selected_child_seqs = [] + unselected_child_seqs = [] + beam_width = seq_group.sampling_params.best_of + length_penalty = seq_group.sampling_params.length_penalty + + # Select the newly finished sequences with the highest scores + # to replace existing finished sequences. 
+ # Tuple of (seq, parent, is_new) + existing_finished_seqs = [(seq, None, False) for seq in existing_finished_seqs] + new_finished_seqs = [ + (seq, parent, True) for seq, parent in child_seqs if seq.is_finished() + ] + all_finished_seqs = existing_finished_seqs + new_finished_seqs + # Sort the finished sequences by their scores. + all_finished_seqs.sort( + key=lambda x: x[0].get_beam_search_score( + length_penalty=length_penalty, eos_token_id=self.tokenizer.eos_token_id + ), + reverse=True, + ) + for seq, parent, is_new in all_finished_seqs[:beam_width]: + if is_new: + # A newly generated child sequence finishes and has a high + # score, so we will add it into the sequence group. + selected_child_seqs.append((seq, parent)) + for seq, parent, is_new in all_finished_seqs[beam_width:]: + if is_new: + # A newly generated child sequence finishes but has a low + # score, so we will not add it into the sequence group. + # Additionally, if this sequence is a continuation of a + # parent sequence, we will need remove the parent sequence + # from the sequence group. + unselected_child_seqs.append((seq, parent)) + else: + # An existing finished sequence has a low score, so we will + # remove it from the sequence group. + seq_group.remove(seq.seq_id) + + # select the top beam_width sequences from the running + # sequences for the next iteration to continue the beam + # search. + running_child_seqs = [ + (seq, parent) for seq, parent in child_seqs if not seq.is_finished() + ] + # Sort the running sequences by their scores. + running_child_seqs.sort( + key=lambda x: x[0].get_beam_search_score( + length_penalty=length_penalty, eos_token_id=self.tokenizer.eos_token_id + ), + reverse=True, + ) + + # Check if we can stop the beam search. + if len(running_child_seqs) == 0: + # No running sequences, stop the beam search. + stop_beam_search = True + elif len(all_finished_seqs) < beam_width: + # Not enough finished sequences, continue the beam search. + stop_beam_search = False + else: + # Check the early stopping criteria + best_running_seq = running_child_seqs[0][0] + current_worst_seq = all_finished_seqs[beam_width - 1][0] + stop_beam_search = self._check_beam_search_early_stopping( + seq_group.sampling_params.early_stopping, + seq_group.sampling_params, + best_running_seq, + current_worst_seq, + ) + + if stop_beam_search: + # Stop the beam search and remove all the running sequences from + # the sequence group. + unselected_child_seqs.extend(running_child_seqs) + else: + # Continue the beam search and select the top beam_width sequences + # to continue the beam search. + selected_child_seqs.extend(running_child_seqs[:beam_width]) + # The remaining running sequences will not be used in the next + # iteration. Again, if these sequences are continuations of + # parent sequences, we will need to remove the parent sequences + # from the sequence group. + unselected_child_seqs.extend(running_child_seqs[beam_width:]) + + # For newly created child sequences, add them to the sequence group + # and fork them in block manager if they are not finished. + for seq, parent in selected_child_seqs: + if seq is not parent: + seq_group.add(seq) + if not seq.is_finished(): + self.scheduler.fork_seq(parent, seq) + + # Free the finished and selected parent sequences' memory in block + # manager. Keep them in the sequence group as candidate output. 
+ for seq, parent in selected_child_seqs: + if seq is parent and seq.is_finished(): + self.scheduler.free_seq(seq) + + # Remove the unselected parent sequences from the sequence group and + # free their memory in block manager. + for seq, parent in unselected_child_seqs: + if seq is parent: + # Remove the parent sequence if it is not selected for next + # iteration + seq_group.remove(seq.seq_id) + self.scheduler.free_seq(seq) + + def _process_model_outputs( + self, output: SamplerOutput, scheduler_outputs: SchedulerOutputs + ) -> List[RequestOutput]: + # Update the scheduled sequence groups with the model outputs. + scheduled_seq_groups = scheduler_outputs.scheduled_seq_groups + for seq_group, outputs in zip(scheduled_seq_groups, output): + self._process_sequence_group_outputs(seq_group, outputs) + + # Free the finished sequence groups. + self.scheduler.free_finished_seq_groups() + + # Create the outputs. + request_outputs: List[RequestOutput] = [] + for seq_group in scheduled_seq_groups + scheduler_outputs.ignored_seq_groups: + request_output = RequestOutput.from_seq_group(seq_group) + request_outputs.append(request_output) + + if self.log_stats: + # Log the system stats. + self._log_system_stats( + scheduler_outputs.prompt_run, scheduler_outputs.num_batched_tokens + ) + return request_outputs + + def step(self) -> List[RequestOutput]: + """Performs one decoding iteration and returns newly generated results. + + This function performs one decoding iteration of the engine. It first + schedules the sequences to be executed in the next iteration and the + token blocks to be swapped in/out/copy. Then, it executes the model + and updates the scheduler with the model outputs. Finally, it decodes + the sequences and returns the newly generated results. + """ + seq_group_metadata_list, scheduler_outputs = self.scheduler.schedule() + + if not scheduler_outputs.is_empty(): + # Execute the model. + all_outputs = self._run_workers( + "execute_model", + driver_kwargs={ + "seq_group_metadata_list": seq_group_metadata_list, + "blocks_to_swap_in": scheduler_outputs.blocks_to_swap_in, + "blocks_to_swap_out": scheduler_outputs.blocks_to_swap_out, + "blocks_to_copy": scheduler_outputs.blocks_to_copy, + }, + ) + + # Only the driver worker returns the sampling results. + output = all_outputs[0] + else: + output = [] + + return self._process_model_outputs(output, scheduler_outputs) + + def _log_system_stats( + self, + prompt_run: bool, + num_batched_tokens: int, + ) -> None: + now = time.monotonic() + # Log the number of batched input tokens. + if prompt_run: + self.num_prompt_tokens.append((now, num_batched_tokens)) + else: + self.num_generation_tokens.append((now, num_batched_tokens)) + + should_log = now - self.last_logging_time >= _LOGGING_INTERVAL_SEC + if not should_log: + return + + # Discard the old stats. 
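The list comprehensions just below keep only the samples inside the logging window; a worked instance of the throughput math that follows them (timestamps and token counts are invented):

```python
now = 105.0
num_prompt_tokens = [(101.0, 512), (102.5, 256), (104.0, 384)]  # (timestamp, tokens)

total = sum(n for _, n in num_prompt_tokens[:-1])   # newest sample excluded
window = now - num_prompt_tokens[0][0]              # 4.0 seconds
avg_prompt_throughput = total / window              # (512 + 256) / 4.0 = 192.0 tok/s
```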
+ self.num_prompt_tokens = [ + (t, n) for t, n in self.num_prompt_tokens if now - t < _LOGGING_INTERVAL_SEC + ] + self.num_generation_tokens = [ + (t, n) + for t, n in self.num_generation_tokens + if now - t < _LOGGING_INTERVAL_SEC + ] + + if len(self.num_prompt_tokens) > 1: + total_num_tokens = sum(n for _, n in self.num_prompt_tokens[:-1]) + window = now - self.num_prompt_tokens[0][0] + avg_prompt_throughput = total_num_tokens / window + else: + avg_prompt_throughput = 0.0 + if len(self.num_generation_tokens) > 1: + total_num_tokens = sum(n for _, n in self.num_generation_tokens[:-1]) + window = now - self.num_generation_tokens[0][0] + avg_generation_throughput = total_num_tokens / window + else: + avg_generation_throughput = 0.0 + + total_num_gpu_blocks = self.cache_config.num_gpu_blocks + num_free_gpu_blocks = self.scheduler.block_manager.get_num_free_gpu_blocks() + num_used_gpu_blocks = total_num_gpu_blocks - num_free_gpu_blocks + gpu_cache_usage = num_used_gpu_blocks / total_num_gpu_blocks + + total_num_cpu_blocks = self.cache_config.num_cpu_blocks + if total_num_cpu_blocks > 0: + num_free_cpu_blocks = self.scheduler.block_manager.get_num_free_cpu_blocks() + num_used_cpu_blocks = total_num_cpu_blocks - num_free_cpu_blocks + cpu_cache_usage = num_used_cpu_blocks / total_num_cpu_blocks + else: + cpu_cache_usage = 0.0 + + record_metrics( + avg_prompt_throughput=avg_prompt_throughput, + avg_generation_throughput=avg_generation_throughput, + scheduler_running=len(self.scheduler.running), + scheduler_swapped=len(self.scheduler.swapped), + scheduler_waiting=len(self.scheduler.waiting), + gpu_cache_usage=gpu_cache_usage, + cpu_cache_usage=cpu_cache_usage, + ) + + logger.info( + "Avg prompt throughput: " + f"{avg_prompt_throughput:.1f} tokens/s, " + "Avg generation throughput: " + f"{avg_generation_throughput:.1f} tokens/s, " + f"Running: {len(self.scheduler.running)} reqs, " + f"Swapped: {len(self.scheduler.swapped)} reqs, " + f"Pending: {len(self.scheduler.waiting)} reqs, " + f"GPU KV cache usage: {gpu_cache_usage * 100:.1f}%, " + f"CPU KV cache usage: {cpu_cache_usage * 100:.1f}%" + ) + self.last_logging_time = now + + def _decode_sequence(self, seq: Sequence, prms: SamplingParams) -> None: + """Decodes the new token for a sequence.""" + (new_tokens, new_output_text, prefix_offset, read_offset) = ( + detokenize_incrementally( + self.tokenizer, + all_input_ids=seq.get_token_ids(), + prev_tokens=seq.tokens, + prefix_offset=seq.prefix_offset, + read_offset=seq.read_offset, + skip_special_tokens=prms.skip_special_tokens, + spaces_between_special_tokens=prms.spaces_between_special_tokens, + ) + ) + if seq.tokens is None: + seq.tokens = new_tokens + else: + seq.tokens.extend(new_tokens) + seq.prefix_offset = prefix_offset + seq.read_offset = read_offset + seq.output_text += new_output_text + + def _check_stop(self, seq: Sequence, sampling_params: SamplingParams) -> None: + """Stop the finished sequences.""" + for stop_str in sampling_params.stop: + if seq.output_text.endswith(stop_str): + if not sampling_params.include_stop_str_in_output: + # Truncate the output text so that the stop string is + # not included in the output. 
+ seq.output_text = seq.output_text[: -len(stop_str)] + seq.status = SequenceStatus.FINISHED_STOPPED + return + if seq.data.finished: + seq.status = SequenceStatus.FINISHED_STOPPED + return + + for token_id in seq.get_last_token_id(): + if token_id == sampling_params.eos_token: + seq.status = SequenceStatus.FINISHED_STOPPED + return + + # Check if the sequence has reached max_model_len. + if seq.get_len() > self.scheduler_config.max_model_len: + seq.status = SequenceStatus.FINISHED_LENGTH_CAPPED + return + + # Check if the sequence has reached max_tokens. + if seq.get_output_len() == sampling_params.max_tokens: + seq.status = SequenceStatus.FINISHED_LENGTH_CAPPED + return + + # Check if the sequence has generated the EOS token. + if (not sampling_params.ignore_eos) and seq.get_last_token_id()[ + 0 + ] == sampling_params.eos_token: + seq.status = SequenceStatus.FINISHED_STOPPED + return + + def _run_workers( + self, + method: str, + *args, + driver_args: Optional[List[Any]] = None, + driver_kwargs: Optional[Dict[str, Any]] = None, + max_concurrent_workers: Optional[int] = None, + **kwargs, + ) -> Any: + """Runs the given method on all workers.""" + + if max_concurrent_workers: + raise NotImplementedError("max_concurrent_workers is not supported yet.") + + # Start the ray workers first. + ray_worker_outputs = [ + worker.execute_method.remote(method, *args, **kwargs) + for worker in self.workers + ] + + if driver_args is None: + driver_args = args + if driver_kwargs is None: + driver_kwargs = kwargs + + # Start the driver worker after all the ray workers. + driver_worker_output = getattr(self.driver_worker, method)( + *driver_args, **driver_kwargs + ) + + # Get the results of the ray workers. + if self.workers: + ray_worker_outputs = ray.get(ray_worker_outputs) + + return [driver_worker_output] + ray_worker_outputs diff --git a/ChatTTS/model/velocity/model_loader.py b/ChatTTS/model/velocity/model_loader.py new file mode 100644 index 000000000..01c1883f5 --- /dev/null +++ b/ChatTTS/model/velocity/model_loader.py @@ -0,0 +1,69 @@ +"""Utilities for selecting and loading models.""" + +import contextlib + +import torch +import torch.nn as nn + +from vllm.config import ModelConfig +from vllm.model_executor.models import ModelRegistry +from vllm.model_executor.weight_utils import get_quant_config, initialize_dummy_weights + +from .llama import LlamaModel + + +@contextlib.contextmanager +def _set_default_torch_dtype(dtype: torch.dtype): + """Sets the default torch dtype to the given dtype.""" + old_dtype = torch.get_default_dtype() + torch.set_default_dtype(dtype) + yield + torch.set_default_dtype(old_dtype) + + +def get_model(model_config: ModelConfig) -> nn.Module: + # Get the (maybe quantized) linear method. + linear_method = None + if model_config.quantization is not None: + quant_config = get_quant_config( + model_config.quantization, + model_config.model, + model_config.hf_config, + model_config.download_dir, + ) + capability = torch.cuda.get_device_capability() + capability = capability[0] * 10 + capability[1] + if capability < quant_config.get_min_capability(): + raise ValueError( + f"The quantization method {model_config.quantization} is not " + "supported for the current GPU. " + f"Minimum capability: {quant_config.get_min_capability()}. " + f"Current capability: {capability}." 
+ ) + supported_dtypes = quant_config.get_supported_act_dtypes() + if model_config.dtype not in supported_dtypes: + raise ValueError( + f"{model_config.dtype} is not supported for quantization " + f"method {model_config.quantization}. Supported dtypes: " + f"{supported_dtypes}" + ) + linear_method = quant_config.get_linear_method() + + with _set_default_torch_dtype(model_config.dtype): + # Create a model instance. + # The weights will be initialized as empty tensors. + with torch.device("cuda"): + model = LlamaModel(model_config.hf_config, linear_method) + if model_config.load_format == "dummy": + # NOTE(woosuk): For accurate performance evaluation, we assign + # random values to the weights. + initialize_dummy_weights(model) + else: + # Load the weights from the cached or downloaded files. + model.load_weights( + model_config.model, + model_config.download_dir, + model_config.load_format, + model_config.revision, + ) + return model.eval() diff --git a/ChatTTS/model/velocity/model_runner.py b/ChatTTS/model/velocity/model_runner.py new file mode 100644 index 000000000..39d635866 --- /dev/null +++ b/ChatTTS/model/velocity/model_runner.py @@ -0,0 +1,811 @@ +import time +from typing import Dict, List, Optional, Tuple, Union + +import numpy as np +import torch +import torch.nn as nn + +from .configs import ModelConfig, ParallelConfig, SchedulerConfig +from vllm.logger import init_logger +from .model_loader import get_model +from vllm.model_executor import InputMetadata, SamplingMetadata +from vllm.model_executor.parallel_utils.communication_op import ( + broadcast, + broadcast_object_list, +) +from .sampling_params import SamplingParams, SamplingType +from .sequence import ( + SamplerOutput, + SequenceData, + SequenceGroupMetadata, + SequenceGroupOutput, + SequenceOutput, +) +from vllm.utils import in_wsl +from .post_model import PostModel, Sampler +from safetensors.torch import safe_open + +logger = init_logger(__name__) + +KVCache = Tuple[torch.Tensor, torch.Tensor] +_PAD_SLOT_ID = -1 +# Capture graphs for batch size 1, 2, 4, 8, 16, 24, 32, 40, ..., 256. +# NOTE: _get_graph_batch_size needs to be updated if this list is changed. +_BATCH_SIZES_TO_CAPTURE = [1, 2, 4] + [8 * i for i in range(1, 33)] + + +class ModelRunner: + + def __init__( + self, + model_config: ModelConfig, + parallel_config: ParallelConfig, + scheduler_config: SchedulerConfig, + is_driver_worker: bool = False, + post_model_path: str = None, + ): + self.model_config = model_config + self.parallel_config = parallel_config + self.scheduler_config = scheduler_config + self.is_driver_worker = is_driver_worker + self.post_model_path = post_model_path + + # model_config can be None in tests/samplers/test_sampler.py. + # FIXME(woosuk): This is a hack to make the tests work. Refactor this. + self.sliding_window = ( + model_config.get_sliding_window() if model_config is not None else None + ) + self.model = None + self.block_size = None # Set after initial profiling. + + self.graph_runners: Dict[int, CUDAGraphRunner] = {} + self.graph_memory_pool = None # Set during graph capture. + + self.max_context_len_to_capture = ( + self.model_config.max_context_len_to_capture + if self.model_config is not None + else 0 + ) + # When using CUDA graph, the input block tables must be padded to + # max_context_len_to_capture. However, creating the block table in + # Python can be expensive. To optimize this, we cache the block table + # in numpy and only copy the actual input content at every iteration. 
+ # The shape of the cached block table will be + # (max batch size to capture, max context len to capture / block size). + self.graph_block_tables = None # Set after initial profiling. + # cache in_wsl result + self.in_wsl = in_wsl() + + def load_model(self) -> None: + self.model = get_model(self.model_config) + self.post_model = PostModel( + self.model_config.get_hidden_size(), + self.model_config.num_audio_tokens, + self.model_config.num_text_tokens, + ) + state_dict_tensors = {} + with safe_open(self.post_model_path, framework="pt", device=0) as f: + for k in f.keys(): + state_dict_tensors[k] = f.get_tensor(k) + self.post_model.load_state_dict(state_dict_tensors) + self.post_model.to(next(self.model.parameters())).eval() + self.sampler = Sampler(self.post_model, self.model_config.num_audio_tokens, 4) + + def set_block_size(self, block_size: int) -> None: + self.block_size = block_size + + max_num_blocks = ( + self.max_context_len_to_capture + block_size - 1 + ) // block_size + self.graph_block_tables = np.zeros( + (max(_BATCH_SIZES_TO_CAPTURE), max_num_blocks), dtype=np.int32 + ) + + def _prepare_prompt( + self, + seq_group_metadata_list: List[SequenceGroupMetadata], + ) -> Tuple[torch.Tensor, torch.Tensor, InputMetadata, List[int]]: + assert len(seq_group_metadata_list) > 0 + input_tokens: List[List[int]] = [] + input_positions: List[List[int]] = [] + slot_mapping: List[List[int]] = [] + + prompt_lens: List[int] = [] + for seq_group_metadata in seq_group_metadata_list: + assert seq_group_metadata.is_prompt + seq_ids = list(seq_group_metadata.seq_data.keys()) + assert len(seq_ids) == 1 + seq_id = seq_ids[0] + + seq_data = seq_group_metadata.seq_data[seq_id] + prompt_tokens = seq_data.get_token_ids() + prompt_len = len(prompt_tokens) + prompt_lens.append(prompt_len) + + input_tokens.append(prompt_tokens) + # NOTE(woosuk): Here we assume that the first token in the prompt + # is always the first token in the sequence. + input_positions.append(list(range(prompt_len))) + + if seq_group_metadata.block_tables is None: + # During memory profiling, the block tables are not initialized + # yet. In this case, we just use a dummy slot mapping. + slot_mapping.append([_PAD_SLOT_ID] * prompt_len) + continue + + # Compute the slot mapping. + slot_mapping.append([]) + block_table = seq_group_metadata.block_tables[seq_id] + # Mask the [0, start_idx) tokens of the prompt with _PAD_SLOT_ID, + # where start_idx is max(0, prompt_len - sliding_window). + # For example, if the prompt len is 10, sliding window is 8, and + # block size is 4, the first two tokens are masked and the slot + # mapping will be [-1, -1, 2, 3, 4, 5, 6, 7, 0, 1]. 
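The loop that builds this mapping continues just below. A standalone sketch of the same arithmetic, reproducing the worked example from the comment above; the block table is an assumed allocation, not one produced by the scheduler:

```python
def prompt_slot_mapping(prompt_len, block_table, block_size,
                        sliding_window=None, pad_slot_id=-1):
    start_idx = max(0, prompt_len - sliding_window) if sliding_window else 0
    slots = []
    for i in range(prompt_len):
        if i < start_idx:
            slots.append(pad_slot_id)                       # masked by the window
        else:
            block_number = block_table[i // block_size]
            slots.append(block_number * block_size + i % block_size)
    return slots

# prompt_len=10, sliding_window=8, block_size=4, assumed block_table=[0, 1, 0]
print(prompt_slot_mapping(10, [0, 1, 0], 4, sliding_window=8))
# -> [-1, -1, 2, 3, 4, 5, 6, 7, 0, 1], matching the comment above
```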
+ start_idx = 0 + if self.sliding_window is not None: + start_idx = max(0, prompt_len - self.sliding_window) + for i in range(prompt_len): + if i < start_idx: + slot_mapping[-1].append(_PAD_SLOT_ID) + continue + + block_number = block_table[i // self.block_size] + block_offset = i % self.block_size + slot = block_number * self.block_size + block_offset + slot_mapping[-1].append(slot) + + max_prompt_len = max(prompt_lens) + input_tokens = _make_tensor_with_pad( + input_tokens, max_prompt_len, pad=0, dtype=torch.long + ) + input_positions = _make_tensor_with_pad( + input_positions, max_prompt_len, pad=0, dtype=torch.long + ) + slot_mapping = _make_tensor_with_pad( + slot_mapping, max_prompt_len, pad=_PAD_SLOT_ID, dtype=torch.long + ) + + input_metadata = InputMetadata( + is_prompt=True, + slot_mapping=slot_mapping, + max_context_len=None, + context_lens=None, + block_tables=None, + use_cuda_graph=False, + ) + return input_tokens, input_positions, input_metadata, prompt_lens + + def _prepare_decode( + self, + seq_group_metadata_list: List[SequenceGroupMetadata], + ) -> Tuple[torch.Tensor, torch.Tensor, InputMetadata]: + assert len(seq_group_metadata_list) > 0 + input_tokens: List[List[int]] = [] + input_positions: List[List[int]] = [] + slot_mapping: List[List[int]] = [] + context_lens: List[int] = [] + block_tables: List[List[int]] = [] + + for seq_group_metadata in seq_group_metadata_list: + assert not seq_group_metadata.is_prompt + + seq_ids = list(seq_group_metadata.seq_data.keys()) + for seq_id in seq_ids: + seq_data = seq_group_metadata.seq_data[seq_id] + generation_token = seq_data.get_last_token_id() + input_tokens.append([generation_token]) + + seq_len = seq_data.get_len() + position = seq_len - 1 + input_positions.append([position]) + + context_len = ( + seq_len + if self.sliding_window is None + else min(seq_len, self.sliding_window) + ) + context_lens.append(context_len) + + block_table = seq_group_metadata.block_tables[seq_id] + block_number = block_table[position // self.block_size] + block_offset = position % self.block_size + slot = block_number * self.block_size + block_offset + slot_mapping.append([slot]) + + if self.sliding_window is not None: + sliding_window_blocks = self.sliding_window // self.block_size + block_table = block_table[-sliding_window_blocks:] + block_tables.append(block_table) + + batch_size = len(input_tokens) + max_context_len = max(context_lens) + use_captured_graph = ( + not self.model_config.enforce_eager + and batch_size <= _BATCH_SIZES_TO_CAPTURE[-1] + and max_context_len <= self.max_context_len_to_capture + ) + if use_captured_graph: + # Pad the input tokens, positions, and slot mapping to match the + # batch size of the captured graph. 
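The padding itself follows below; its target size comes from `_get_graph_batch_size`, which is defined later in this file. A sketch of what that rounding presumably does, derived from the `_BATCH_SIZES_TO_CAPTURE` list declared near the top of this file (treat it as an assumption, not the actual helper):

```python
_BATCH_SIZES_TO_CAPTURE = [1, 2, 4] + [8 * i for i in range(1, 33)]

def graph_batch_size(batch_size: int) -> int:
    # Smallest captured size that can hold the current decode batch.
    return next(bs for bs in _BATCH_SIZES_TO_CAPTURE if bs >= batch_size)

assert graph_batch_size(3) == 4
assert graph_batch_size(13) == 16
```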
+ graph_batch_size = _get_graph_batch_size(batch_size) + assert graph_batch_size >= batch_size + for _ in range(graph_batch_size - batch_size): + input_tokens.append([]) + input_positions.append([]) + slot_mapping.append([]) + context_lens.append(1) + block_tables.append([]) + batch_size = graph_batch_size + + input_tokens = _make_tensor_with_pad( + input_tokens, max_len=1, pad=0, dtype=torch.long, device="cuda" + ) + input_positions = _make_tensor_with_pad( + input_positions, max_len=1, pad=0, dtype=torch.long, device="cuda" + ) + slot_mapping = _make_tensor_with_pad( + slot_mapping, max_len=1, pad=_PAD_SLOT_ID, dtype=torch.long, device="cuda" + ) + context_lens = torch.tensor(context_lens, dtype=torch.int, device="cuda") + + if use_captured_graph: + # The shape of graph_block_tables is + # [max batch size, max context len // block size]. + input_block_tables = self.graph_block_tables[:batch_size] + for i, block_table in enumerate(block_tables): + if block_table: + input_block_tables[i, : len(block_table)] = block_table + block_tables = torch.tensor(input_block_tables, device="cuda") + else: + block_tables = _make_tensor_with_pad( + block_tables, + max_len=max_context_len, + pad=0, + dtype=torch.int, + device="cuda", + ) + + input_metadata = InputMetadata( + is_prompt=False, + slot_mapping=slot_mapping, + max_context_len=max_context_len, + context_lens=context_lens, + block_tables=block_tables, + use_cuda_graph=use_captured_graph, + ) + return input_tokens, input_positions, input_metadata + + def _prepare_sample( + self, + seq_group_metadata_list: List[SequenceGroupMetadata], + prompt_lens: List[int], + ) -> SamplingMetadata: + seq_groups: List[Tuple[List[int], SamplingParams]] = [] + selected_token_indices: List[int] = [] + selected_token_start_idx = 0 + categorized_sample_indices = {t: [] for t in SamplingType} + categorized_sample_indices_start_idx = 0 + + max_prompt_len = max(prompt_lens) if prompt_lens else 1 + for i, seq_group_metadata in enumerate(seq_group_metadata_list): + seq_ids = list(seq_group_metadata.seq_data.keys()) + sampling_params = seq_group_metadata.sampling_params + seq_groups.append((seq_ids, sampling_params)) + + if seq_group_metadata.is_prompt: + assert len(seq_ids) == 1 + prompt_len = prompt_lens[i] + if sampling_params.prompt_logprobs is not None: + # NOTE: prompt token positions do not need sample, skip + categorized_sample_indices_start_idx += prompt_len - 1 + + categorized_sample_indices[sampling_params.sampling_type].append( + categorized_sample_indices_start_idx + ) + categorized_sample_indices_start_idx += 1 + + if sampling_params.prompt_logprobs is not None: + selected_token_indices.extend( + range( + selected_token_start_idx, + selected_token_start_idx + prompt_len - 1, + ) + ) + selected_token_indices.append(selected_token_start_idx + prompt_len - 1) + selected_token_start_idx += max_prompt_len + else: + num_seqs = len(seq_ids) + selected_token_indices.extend( + range(selected_token_start_idx, selected_token_start_idx + num_seqs) + ) + selected_token_start_idx += num_seqs + + categorized_sample_indices[sampling_params.sampling_type].extend( + range( + categorized_sample_indices_start_idx, + categorized_sample_indices_start_idx + num_seqs, + ) + ) + categorized_sample_indices_start_idx += num_seqs + + selected_token_indices = _async_h2d( + selected_token_indices, dtype=torch.long, pin_memory=not self.in_wsl + ) + categorized_sample_indices = { + t: _async_h2d(seq_ids, dtype=torch.int, pin_memory=not self.in_wsl) + for t, seq_ids in 
categorized_sample_indices.items() + } + + seq_data: Dict[int, SequenceData] = {} + for seq_group_metadata in seq_group_metadata_list: + seq_data.update(seq_group_metadata.seq_data) + + sampling_metadata = SamplingMetadata( + seq_groups=seq_groups, + seq_data=seq_data, + prompt_lens=prompt_lens, + selected_token_indices=selected_token_indices, + categorized_sample_indices=categorized_sample_indices, + ) + return sampling_metadata + + def prepare_input_tensors( + self, + seq_group_metadata_list: Optional[List[SequenceGroupMetadata]], + ) -> Tuple[torch.Tensor, torch.Tensor, InputMetadata, SamplingMetadata]: + if self.is_driver_worker: + # NOTE: We assume that all sequences in the group are all prompts or + # all decodes. + is_prompt = seq_group_metadata_list[0].is_prompt + # Prepare input tensors. + if is_prompt: + (input_tokens, input_positions, input_metadata, prompt_lens) = ( + self._prepare_prompt(seq_group_metadata_list) + ) + else: + (input_tokens, input_positions, input_metadata) = self._prepare_decode( + seq_group_metadata_list + ) + prompt_lens = [] + sampling_metadata = self._prepare_sample( + seq_group_metadata_list, prompt_lens + ) + + def get_size_or_none(x: Optional[torch.Tensor]): + return x.size() if x is not None else None + + # Broadcast the input data. For input tensors, we first broadcast + # its shape and then broadcast the tensor to avoid high + # serialization cost. + py_data = { + "input_tokens_size": input_tokens.size(), + "input_positions_size": input_positions.size(), + "is_prompt": input_metadata.is_prompt, + "slot_mapping_size": get_size_or_none(input_metadata.slot_mapping), + "max_context_len": input_metadata.max_context_len, + "context_lens_size": get_size_or_none(input_metadata.context_lens), + "block_tables_size": get_size_or_none(input_metadata.block_tables), + "use_cuda_graph": input_metadata.use_cuda_graph, + "selected_token_indices_size": sampling_metadata.selected_token_indices.size(), + } + broadcast_object_list([py_data], src=0) + # TODO(zhuohan): Combine the broadcasts or set async_op=True. 
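The tensor broadcasts follow immediately below. They all use the shape-first pattern described in the comment above; a hedged sketch of that pattern factored into a helper (the helper is illustrative; `broadcast` is the communication op imported at the top of this file):

```python
import torch
from vllm.model_executor.parallel_utils.communication_op import broadcast

def broadcast_optional(tensor, size_key, py_data, dtype, is_driver):
    # Driver: the tensor's size already went out with py_data, so just send it.
    if is_driver:
        if tensor is not None:
            broadcast(tensor, src=0)
        return tensor
    # Non-driver: allocate a buffer from the broadcast size, then receive.
    if py_data[size_key] is None:
        return None
    buf = torch.empty(*py_data[size_key], dtype=dtype, device="cuda")
    broadcast(buf, src=0)
    return buf
```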
+ broadcast(input_tokens, src=0) + broadcast(input_positions, src=0) + if input_metadata.slot_mapping is not None: + broadcast(input_metadata.slot_mapping, src=0) + if input_metadata.context_lens is not None: + broadcast(input_metadata.context_lens, src=0) + if input_metadata.block_tables is not None: + broadcast(input_metadata.block_tables, src=0) + broadcast(sampling_metadata.selected_token_indices, src=0) + else: + receving_list = [None] + broadcast_object_list(receving_list, src=0) + py_data = receving_list[0] + input_tokens = torch.empty( + *py_data["input_tokens_size"], dtype=torch.long, device="cuda" + ) + broadcast(input_tokens, src=0) + input_positions = torch.empty( + *py_data["input_positions_size"], dtype=torch.long, device="cuda" + ) + broadcast(input_positions, src=0) + if py_data["slot_mapping_size"] is not None: + slot_mapping = torch.empty( + *py_data["slot_mapping_size"], dtype=torch.long, device="cuda" + ) + broadcast(slot_mapping, src=0) + else: + slot_mapping = None + if py_data["context_lens_size"] is not None: + context_lens = torch.empty( + *py_data["context_lens_size"], dtype=torch.int, device="cuda" + ) + broadcast(context_lens, src=0) + else: + context_lens = None + if py_data["block_tables_size"] is not None: + block_tables = torch.empty( + *py_data["block_tables_size"], dtype=torch.int, device="cuda" + ) + broadcast(block_tables, src=0) + else: + block_tables = None + selected_token_indices = torch.empty( + *py_data["selected_token_indices_size"], dtype=torch.long, device="cuda" + ) + broadcast(selected_token_indices, src=0) + input_metadata = InputMetadata( + is_prompt=py_data["is_prompt"], + slot_mapping=slot_mapping, + max_context_len=py_data["max_context_len"], + context_lens=context_lens, + block_tables=block_tables, + use_cuda_graph=py_data["use_cuda_graph"], + ) + sampling_metadata = SamplingMetadata( + seq_groups=None, + seq_data=None, + prompt_lens=None, + selected_token_indices=selected_token_indices, + categorized_sample_indices=None, + perform_sampling=False, + ) + + return input_tokens, input_positions, input_metadata, sampling_metadata + + @torch.inference_mode() + def execute_model( + self, + seq_group_metadata_list: Optional[List[SequenceGroupMetadata]], + kv_caches: List[Tuple[torch.Tensor, torch.Tensor]], + ) -> Optional[SamplerOutput]: + input_tokens, input_positions, input_metadata, sampling_metadata = ( + self.prepare_input_tensors(seq_group_metadata_list) + ) + # print(sampling_metadata.seq_data) + seq_groups = [] + input_tokens_history = [] + for i, rtn in enumerate(sampling_metadata.seq_groups): + seq_groups.append(rtn[0][0]) + tokens_history = sampling_metadata.seq_data[rtn[0][0]].output_token_ids + if len(tokens_history) >= 1: + if len(tokens_history[0]) == 1: + tokens_history = [token[0] for token in tokens_history] + else: + tokens_history = [list(token) for token in tokens_history] + input_tokens_history.append(tokens_history) + input_tokens_history = torch.tensor(input_tokens_history).to( + input_tokens.device + ) + # token_ids = rtn.outputs[0].token_ids + # for j, token_id in enumerate(token_ids): + # if len(token_id) == 1: + # token_ids[j] = token_id[0] + # else: + # token_ids[j] = list(token_id) + + # Execute the model. 
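`execute_model` above re-packs each sequence's generation history into a `[batch, steps, num_vq]` tensor: audio steps carry one token id per codebook, while text steps carry a single id and are flattened. A toy illustration of that re-packing with invented token ids (`num_vq = 4` assumed, as elsewhere in this backend):

```python
import torch

# Hypothetical per-sequence histories: each decoded step holds a tuple of 4
# codebook ids (text inference would instead hold 1-element tuples).
output_token_ids = [
    [(11, 12, 13, 14), (21, 22, 23, 24)],   # sequence 0, two decoded steps
    [(31, 32, 33, 34), (41, 42, 43, 44)],   # sequence 1, two decoded steps
]

history = []
for tokens in output_token_ids:
    if len(tokens) >= 1 and len(tokens[0]) == 1:
        history.append([t[0] for t in tokens])       # text path: flatten to scalar ids
    else:
        history.append([list(t) for t in tokens])    # audio path: keep all num_vq ids
input_tokens_history = torch.tensor(history)

print(input_tokens_history.shape)  # torch.Size([2, 2, 4]) -> [batch, steps, num_vq]
```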
+ # print("it1",input_tokens) + if len(input_tokens.shape) == 2: + input_tokens = input_tokens.unsqueeze(2).repeat(1, 1, 4) + if len(input_tokens_history.shape) == 2: + input_tokens_history = input_tokens_history.unsqueeze(2).repeat(1, 1, 4) + # print(input_tokens_history.shape) + # print("it2",input_tokens.shape) + text_mask = input_tokens != 0 + text_mask = text_mask[:, :, 0] + + if input_metadata.use_cuda_graph: + graph_batch_size = input_tokens.shape[0] + model_executable = self.graph_runners[graph_batch_size] + else: + model_executable = self.model + + infer_text = sampling_metadata.seq_groups[0][1].infer_text + temperture = sampling_metadata.seq_groups[0][1].temperature + if not infer_text: + temperture = torch.tensor(temperture).to(input_tokens.device) + logits_processors, logits_warpers = sampling_metadata.seq_groups[0][ + 1 + ].logits_processors + # print(logits_processors, logits_warpers) + min_new_token = sampling_metadata.seq_groups[0][1].min_new_token + eos_token = sampling_metadata.seq_groups[0][1].eos_token + start_idx = sampling_metadata.seq_groups[0][1].start_idx + if input_tokens.shape[-2] == 1: + if infer_text: + input_emb: torch.Tensor = self.post_model.emb_text( + input_tokens[:, :, 0] + ) + else: + code_emb = [ + self.post_model.emb_code[i](input_tokens[:, :, i]) + for i in range(self.post_model.num_vq) + ] + input_emb = torch.stack(code_emb, 3).sum(3) + else: + input_emb = self.post_model(input_tokens, text_mask) + # print(input_emb.shape) + hidden_states = model_executable( + input_emb=input_emb, + positions=input_positions, + kv_caches=kv_caches, + input_metadata=input_metadata, + ) + # print(hidden_states.shape) + # print(input_tokens) + B_NO_PAD = input_tokens_history.shape[0] + input_tokens = input_tokens[:B_NO_PAD, :, :] + hidden_states = hidden_states[:B_NO_PAD, :, :] + idx_next, logprob, finish = self.sampler.sample( + inputs_ids=( + input_tokens + if input_tokens_history.shape[-2] == 0 + else input_tokens_history + ), + hidden_states=hidden_states, + infer_text=infer_text, + temperature=temperture, + logits_processors=logits_processors, + logits_warpers=logits_warpers, + min_new_token=min_new_token, + now_length=1, + eos_token=eos_token, + start_idx=start_idx, + ) + # print(logprob.shape, idx_next.shape) + if len(logprob.shape) == 2: + logprob = logprob[:, None, :] + logprob = torch.gather(logprob, -1, idx_next.transpose(-1, -2))[:, :, 0] + # print("测试",idx_next.shape, logprob.shape) + # Sample the next token. + # output = self.model.sample( + # hidden_states=hidden_states, + # sampling_metadata=sampling_metadata, + # ) + results = [] + for i in range(idx_next.shape[0]): + idx_next_i = idx_next[i, 0, :].cpu().tolist() + logprob_i = logprob[i].cpu().tolist() + tmp_hidden_states = hidden_states[i].cpu() + if input_tokens[i].shape[-2] != 1: + tmp_hidden_states = tmp_hidden_states[-1:, :] + result = SequenceGroupOutput( + samples=[ + SequenceOutput( + parent_seq_id=seq_groups[i], + logprobs={tuple(idx_next_i): logprob_i}, + output_token=tuple(idx_next_i), + hidden_states=tmp_hidden_states, + finished=finish[i].item(), + ), + ], + prompt_logprobs=None, + ) + results.append(result) + # print(results) + # print(idx_next, idx_next.shape, logprob.shape) + return results + + @torch.inference_mode() + def profile_run(self) -> None: + # Enable top-k sampling to reflect the accurate memory usage. 
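Stepping back to the decode path above: when only the newest token is fed (`input_tokens.shape[-2] == 1`), the audio branch builds the model input by looking up each of the `num_vq` codebook embeddings and summing them. A standalone sketch of that lookup-and-sum; the table sizes here are illustrative stand-ins, not the model's weights:

```python
import torch
import torch.nn as nn

num_vq, num_audio_tokens, hidden = 4, 512, 768  # illustrative sizes only
emb_code = nn.ModuleList(
    [nn.Embedding(num_audio_tokens, hidden) for _ in range(num_vq)]
)

# One decode step: [batch, 1, num_vq] token ids, one id per codebook.
input_tokens = torch.randint(0, num_audio_tokens, (2, 1, num_vq))

# Look up each codebook's embedding and sum them into one hidden vector per position.
code_emb = [emb_code[i](input_tokens[:, :, i]) for i in range(num_vq)]
input_emb = torch.stack(code_emb, dim=3).sum(dim=3)

print(input_emb.shape)  # torch.Size([2, 1, 768])
```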
+ vocab_size = self.model_config.get_vocab_size() + sampling_params = SamplingParams( + top_p=0.99, top_k=vocab_size - 1, infer_text=True + ) + max_num_batched_tokens = self.scheduler_config.max_num_batched_tokens + max_num_seqs = self.scheduler_config.max_num_seqs + + # Profile memory usage with max_num_sequences sequences and the total + # number of tokens equal to max_num_batched_tokens. + seqs: List[SequenceGroupMetadata] = [] + for group_id in range(max_num_seqs): + seq_len = max_num_batched_tokens // max_num_seqs + ( + group_id < max_num_batched_tokens % max_num_seqs + ) + seq_data = SequenceData([0] * seq_len) + seq = SequenceGroupMetadata( + request_id=str(group_id), + is_prompt=True, + seq_data={group_id: seq_data}, + sampling_params=sampling_params, + block_tables=None, + ) + seqs.append(seq) + + # Run the model with the dummy inputs. + num_layers = self.model_config.get_num_layers(self.parallel_config) + kv_caches = [(None, None)] * num_layers + self.execute_model(seqs, kv_caches) + torch.cuda.synchronize() + return + + @torch.inference_mode() + def capture_model(self, kv_caches: List[KVCache]) -> None: + assert not self.model_config.enforce_eager + logger.info( + "Capturing the model for CUDA graphs. This may lead to " + "unexpected consequences if the model is not static. To " + "run the model in eager mode, set 'enforce_eager=True' or " + "use '--enforce-eager' in the CLI." + ) + logger.info( + "CUDA graphs can take additional 1~3 GiB memory per GPU. " + "If you are running out of memory, consider decreasing " + "`gpu_memory_utilization` or enforcing eager mode." + ) + start_time = time.perf_counter() + + # Prepare dummy inputs. These will be reused for all batch sizes. + max_batch_size = max(_BATCH_SIZES_TO_CAPTURE) + input_emb = torch.zeros( + max_batch_size, + 1, + self.model_config.get_hidden_size(), + dtype=next(self.model.parameters()).dtype, + ).cuda() + input_positions = torch.zeros(max_batch_size, 1, dtype=torch.long).cuda() + slot_mapping = torch.empty(max_batch_size, 1, dtype=torch.long).cuda() + slot_mapping.fill_(_PAD_SLOT_ID) + context_lens = torch.ones(max_batch_size, dtype=torch.int32).cuda() + block_tables = torch.from_numpy(self.graph_block_tables).cuda() + + # NOTE: Capturing the largest batch size first may help reduce the + # memory usage of CUDA graph. + for batch_size in reversed(_BATCH_SIZES_TO_CAPTURE): + # Create dummy input_metadata. + input_metadata = InputMetadata( + is_prompt=False, + slot_mapping=slot_mapping[:batch_size], + max_context_len=self.max_context_len_to_capture, + context_lens=context_lens[:batch_size], + block_tables=block_tables[:batch_size], + use_cuda_graph=True, + ) + + graph_runner = CUDAGraphRunner(self.model) + graph_runner.capture( + input_emb[:batch_size], + input_positions[:batch_size], + kv_caches, + input_metadata, + memory_pool=self.graph_memory_pool, + ) + self.graph_memory_pool = graph_runner.graph.pool() + self.graph_runners[batch_size] = graph_runner + + end_time = time.perf_counter() + elapsed_time = end_time - start_time + # This usually takes < 10 seconds. 
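`capture_model` above and the `CUDAGraphRunner` that follows use the standard warm-up, capture, replay cycle: run once eagerly so one-off kernel selection is not recorded, capture the forward into a `torch.cuda.CUDAGraph`, then feed new inputs by copying into the captured buffers. A self-contained miniature of the same idiom (a placeholder `nn.Linear`, not the model runner's network; requires a CUDA device):

```python
import torch
import torch.nn as nn

# Placeholder network and shapes; illustrative only.
model = nn.Linear(16, 16).cuda().eval()
static_in = torch.zeros(8, 16, device="cuda")

# 1) Warm-up eagerly so autotuning/benchmark kernels are not captured.
with torch.inference_mode():
    model(static_in)
torch.cuda.synchronize()

# 2) Capture one forward pass; tensors used here become the graph's fixed buffers.
graph = torch.cuda.CUDAGraph()
with torch.inference_mode(), torch.cuda.graph(graph):
    static_out = model(static_in)

# 3) Replay: copy fresh data into the captured input buffer, then rerun the kernels.
static_in.copy_(torch.randn(8, 16, device="cuda"))
graph.replay()
print(static_out[0, :4])  # results for the data copied in at step 3
```

Replay reuses the recorded kernels and fixed memory addresses, which is why the runner below copies `input_emb`, `positions`, `slot_mapping`, `context_lens`, and the block tables into its saved input buffers instead of passing fresh tensors.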
+ logger.info(f"Graph capturing finished in {elapsed_time:.0f} secs.") + + +class CUDAGraphRunner: + + def __init__(self, model: nn.Module): + self.model = model + self.graph = None + self.input_buffers: Dict[str, torch.Tensor] = {} + self.output_buffers: Dict[str, torch.Tensor] = {} + + def capture( + self, + input_emb: torch.Tensor, + positions: torch.Tensor, + kv_caches: List[KVCache], + input_metadata: InputMetadata, + memory_pool, + ) -> None: + assert self.graph is None + # Run the model once without capturing the graph. + # This is to make sure that the captured graph does not include the + # kernel launches for initial benchmarking (e.g., Triton autotune). + self.model( + input_emb, + positions, + kv_caches, + input_metadata, + ) + torch.cuda.synchronize() + + # Capture the graph. + self.graph = torch.cuda.CUDAGraph() + with torch.cuda.graph(self.graph, pool=memory_pool): + hidden_states = self.model( + input_emb, + positions, + kv_caches, + input_metadata, + ) + torch.cuda.synchronize() + + # Save the input and output buffers. + self.input_buffers = { + "input_emb": input_emb, + "positions": positions, + "kv_caches": kv_caches, + "slot_mapping": input_metadata.slot_mapping, + "context_lens": input_metadata.context_lens, + "block_tables": input_metadata.block_tables, + } + self.output_buffers = {"hidden_states": hidden_states} + return + + def forward( + self, + input_emb: torch.Tensor, + positions: torch.Tensor, + kv_caches: List[Tuple[torch.Tensor, torch.Tensor]], + input_metadata: InputMetadata, + ) -> torch.Tensor: + # KV caches are fixed tensors, so we don't need to copy them. + del kv_caches + + # Copy the input tensors to the input buffers. + self.input_buffers["input_emb"].copy_(input_emb, non_blocking=True) + self.input_buffers["positions"].copy_(positions, non_blocking=True) + self.input_buffers["slot_mapping"].copy_( + input_metadata.slot_mapping, non_blocking=True + ) + self.input_buffers["context_lens"].copy_( + input_metadata.context_lens, non_blocking=True + ) + self.input_buffers["block_tables"].copy_( + input_metadata.block_tables, non_blocking=True + ) + + # Run the graph. + self.graph.replay() + + # Return the output tensor. 
+ return self.output_buffers["hidden_states"] + + def __call__(self, *args, **kwargs): + return self.forward(*args, **kwargs) + + +def _pad_to_max(x: List[int], max_len: int, pad: int) -> List[int]: + assert len(x) <= max_len + if len(x) == max_len: + return list(x) + return list(x) + [pad] * (max_len - len(x)) + + +def _make_tensor_with_pad( + x: List[List[int]], + max_len: int, + pad: int, + dtype: torch.dtype, + device: Union[str, torch.device] = "cuda", + pin_memory: bool = False, +) -> torch.Tensor: + padded_x = [] + for x_i in x: + pad_i = pad + if isinstance(x[0][0], tuple): + pad_i = (0,) * len(x[0][0]) + padded_x.append(_pad_to_max(x_i, max_len, pad_i)) + + return torch.tensor( + padded_x, + dtype=dtype, + device=device, + pin_memory=pin_memory and str(device) == "cpu", + ) + + +def _get_graph_batch_size(batch_size: int) -> int: + if batch_size <= 2: + return batch_size + elif batch_size <= 4: + return 4 + else: + return (batch_size + 7) // 8 * 8 + + +def _async_h2d(data: list, dtype, pin_memory): + t = torch.tensor(data, dtype=dtype, pin_memory=pin_memory) + return t.to(device="cuda", non_blocking=True) diff --git a/ChatTTS/model/velocity/output.py b/ChatTTS/model/velocity/output.py new file mode 100644 index 000000000..05cc54600 --- /dev/null +++ b/ChatTTS/model/velocity/output.py @@ -0,0 +1,144 @@ +from typing import List, Optional +import torch + +from .sequence import ( + PromptLogprobs, + SampleLogprobs, + SequenceGroup, + SequenceStatus, +) + + +class CompletionOutput: + """The output data of one completion output of a request. + + Args: + index: The index of the output in the request. + text: The generated output text. + token_ids: The token IDs of the generated output text. + cumulative_logprob: The cumulative log probability of the generated + output text. + logprobs: The log probabilities of the top probability words at each + position if the logprobs are requested. + finish_reason: The reason why the sequence is finished. + """ + + def __init__( + self, + index: int, + text: str, + token_ids: List[int], + cumulative_logprob: float, + logprobs: Optional[SampleLogprobs], + finish_reason: Optional[str] = None, + hidden_states: Optional[torch.Tensor] = None, + ) -> None: + self.index = index + self.text = text + self.token_ids = token_ids + self.cumulative_logprob = cumulative_logprob + self.logprobs = logprobs + self.finish_reason = finish_reason + self.hidden_states = hidden_states + + def finished(self) -> bool: + return self.finish_reason is not None + + def __repr__(self) -> str: + return ( + f"CompletionOutput(index={self.index}, " + f"text={self.text!r}, " + f"token_ids={self.token_ids}, " + f"cumulative_logprob={self.cumulative_logprob}, " + f"logprobs={self.logprobs}, " + f"finish_reason={self.finish_reason}, " + f"hidden_states={self.hidden_states.shape if self.hidden_states is not None else None})" + ) + + +class RequestOutput: + """The output data of a request to the LLM. + + Args: + request_id: The unique ID of the request. + prompt: The prompt string of the request. + prompt_token_ids: The token IDs of the prompt. + prompt_logprobs: The log probabilities to return per prompt token. + outputs: The output sequences of the request. + finished: Whether the whole request is finished. 
+ """ + + def __init__( + self, + request_id: str, + prompt: str, + prompt_token_ids: List[int], + prompt_logprobs: Optional[PromptLogprobs], + outputs: List[CompletionOutput], + finished: bool, + ) -> None: + self.request_id = request_id + self.prompt = prompt + self.prompt_token_ids = prompt_token_ids + self.prompt_logprobs = prompt_logprobs + self.outputs = outputs + self.finished = finished + + @classmethod + def from_seq_group(cls, seq_group: SequenceGroup) -> "RequestOutput": + # Get the top-n sequences. + n = seq_group.sampling_params.n + seqs = seq_group.get_seqs() + if seq_group.sampling_params.use_beam_search: + sorting_key = lambda seq: seq.get_beam_search_score( + seq_group.sampling_params.length_penalty + ) + else: + sorting_key = lambda seq: seq.get_cumulative_logprob() + sorted_seqs = sorted(seqs, key=sorting_key, reverse=True) + top_n_seqs = sorted_seqs[:n] + + # Create the outputs. + outputs: List[CompletionOutput] = [] + for seq in top_n_seqs: + logprobs = seq.output_logprobs + if seq_group.sampling_params.logprobs is None: + # NOTE: We need to take care of this case because the sequence + # always has the logprobs of the sampled tokens even if the + # logprobs are not requested. + logprobs = None + finshed_reason = SequenceStatus.get_finished_reason(seq.status) + output = CompletionOutput( + seqs.index(seq), + seq.output_text, + seq.get_output_token_ids(), + seq.get_cumulative_logprob(), + logprobs, + finshed_reason, + seq.data.hidden_states, + ) + outputs.append(output) + + # Every sequence in the sequence group should have the same prompt. + prompt = seq_group.prompt + prompt_token_ids = seq_group.prompt_token_ids + prompt_logprobs = seq_group.prompt_logprobs + finished = seq_group.is_finished() + return cls( + seq_group.request_id, + prompt, + prompt_token_ids, + prompt_logprobs, + outputs, + finished, + ) + + def __repr__(self) -> str: + return ( + f"RequestOutput(request_id={self.request_id}, " + f"prompt={self.prompt!r}, " + f"prompt_token_ids={self.prompt_token_ids}, " + f"prompt_logprobs={self.prompt_logprobs}, " + f"outputs={self.outputs}, " + f"finished={self.finished})" + ) diff --git a/ChatTTS/model/velocity/post_model.py b/ChatTTS/model/velocity/post_model.py new file mode 100644 index 000000000..79b8900a4 --- /dev/null +++ b/ChatTTS/model/velocity/post_model.py @@ -0,0 +1,186 @@ +import os + +os.environ["TOKENIZERS_PARALLELISM"] = "false" +""" +https://stackoverflow.com/questions/62691279/how-to-disable-tokenizers-parallelism-true-false-warning +""" + +import torch +import torch.nn as nn +from torch.functional import F +from torch.nn.utils.parametrizations import weight_norm +from typing import List, Callable + + +class PostModel(nn.Module): + def __init__( + self, hidden_size: int, num_audio_tokens: int, num_text_tokens: int, num_vq=4 + ): + super().__init__() + + self.num_vq = num_vq + self.num_audio_tokens = num_audio_tokens + + self.model_dim = hidden_size + self.emb_code = nn.ModuleList( + [nn.Embedding(num_audio_tokens, self.model_dim) for _ in range(num_vq)], + ) + self.emb_text = nn.Embedding(num_text_tokens, self.model_dim) + + self.head_text = weight_norm( + nn.Linear(self.model_dim, num_text_tokens, bias=False), + name="weight", + ) + self.head_code = nn.ModuleList( + [ + weight_norm( + nn.Linear(self.model_dim, num_audio_tokens, bias=False), + name="weight", + ) + for _ in range(self.num_vq) + ], + ) + + def forward(self, input_ids: torch.Tensor, text_mask: torch.Tensor) -> torch.Tensor: + """ + get_emb + """ + device = 
next(self.parameters()).device + emb_text: torch.Tensor = self.emb_text( + input_ids[text_mask].narrow(1, 0, 1).squeeze_(1).to(device) + ) + + text_mask_inv = text_mask.logical_not().to(device) + masked_input_ids: torch.Tensor = input_ids[text_mask_inv].to(device) + + emb_code = [ + self.emb_code[i](masked_input_ids[:, i]) for i in range(self.num_vq) + ] + emb_code = torch.stack(emb_code, 2).sum(2) + + emb = torch.zeros( + (input_ids.shape[:-1]) + (emb_text.shape[-1],), + device=emb_text.device, + dtype=emb_text.dtype, + ) + emb[text_mask] = emb_text + emb[text_mask_inv] = emb_code.to(emb.dtype) + + del emb_text, emb_code, text_mask_inv + + return emb + + +class Sampler: + def __init__(self, post_model: PostModel, num_audio_tokens: int, num_vq: int): + self.post_model = post_model + self.device = next(self.post_model.parameters()).device + self.num_audio_tokens = num_audio_tokens + self.num_vq = num_vq + + def sample( + self, + inputs_ids: torch.Tensor, + hidden_states: torch.Tensor, + infer_text: bool = False, + temperature: torch.Tensor = 1.0, + logits_processors: List[Callable] = [ + lambda logits_token, logits: logits, + ], + logits_warpers: List[Callable] = [ + lambda logits_token, logits: logits, + ], + min_new_token: int = 0, + now_length: int = 0, + eos_token: int = 0, + start_idx: int = 0, + ): + # print(inputs_ids.shape) + B = hidden_states.shape[0] + + end_idx = torch.zeros( + inputs_ids.shape[0], device=inputs_ids.device, dtype=torch.long + ) + finish = torch.zeros(inputs_ids.shape[0], device=inputs_ids.device).bool() + if not infer_text: + temperature = ( + temperature.unsqueeze(0) + .expand(inputs_ids.shape[0], -1) + .contiguous() + .view(-1, 1) + ) + + if infer_text: + logits: torch.Tensor = self.post_model.head_text(hidden_states) + else: + # logits = torch.stack([self.head_code[i](hidden_states) for i in range(self.num_vq)], 3) + logits = torch.empty( + hidden_states.size(0), + hidden_states.size(1), + self.num_audio_tokens, + self.num_vq, + dtype=torch.float, + device=self.device, + ) + for num_vq_iter in range(self.num_vq): + x: torch.Tensor = self.post_model.head_code[num_vq_iter](hidden_states) + logits[..., num_vq_iter] = x + del x + + del hidden_states + + # logits = logits[:, -1].float() + logits = logits.narrow(1, -1, 1).squeeze_(1).float() + + if not infer_text: + # logits = rearrange(logits, "b c n -> (b n) c") + logits = logits.permute(0, 2, 1) + logits = logits.reshape(-1, logits.size(2)) + # logits_token = rearrange(inputs_ids[:, start_idx:], "b c n -> (b n) c") + inputs_ids_sliced = inputs_ids[:, start_idx:].permute(0, 2, 1) + logits_token = inputs_ids_sliced.reshape( + inputs_ids_sliced.size(0) * inputs_ids_sliced.size(1), + -1, + ).to(self.device) + else: + logits_token = inputs_ids[:, start_idx:, 0].to(self.device) + + logits /= temperature + + for logitsProcessors in logits_processors: + logits = logitsProcessors(logits_token, logits) + + for logitsWarpers in logits_warpers: + logits = logitsWarpers(logits_token, logits) + + del logits_token + + if now_length < min_new_token: + logits[:, eos_token] = -torch.inf + + scores = F.softmax(logits, dim=-1) + idx_next = torch.multinomial(scores, num_samples=1).to(finish.device) + if not infer_text: + scores = scores.reshape(B, -1, scores.shape[-1]) + if not infer_text: + # idx_next = rearrange(idx_next, "(b n) 1 -> b n", n=self.num_vq) + idx_next = idx_next.view(-1, self.num_vq) + finish_or = idx_next.eq(eos_token).any(1) + finish.logical_or_(finish_or) + del finish_or + else: + finish_or = 
idx_next.eq(eos_token).any(1) + finish.logical_or_(finish_or) + del finish_or + + del inputs_ids + + not_finished = finish.logical_not().to(end_idx.device) + + end_idx.add_(not_finished.int()) + idx_next = idx_next[:, None, :] + return ( + idx_next, + torch.log(scores), + finish, + ) diff --git a/ChatTTS/model/velocity/sampling_params.py b/ChatTTS/model/velocity/sampling_params.py new file mode 100644 index 000000000..e650fc546 --- /dev/null +++ b/ChatTTS/model/velocity/sampling_params.py @@ -0,0 +1,296 @@ +"""Sampling parameters for text generation.""" + +from enum import IntEnum +from functools import cached_property +from typing import Callable, List, Optional, Union + +import torch + +_SAMPLING_EPS = 1e-5 + + +class SamplingType(IntEnum): + GREEDY = 0 + RANDOM = 1 + BEAM = 2 + + +LogitsProcessor = Callable[[List[int], torch.Tensor], torch.Tensor] +"""LogitsProcessor is a function that takes a list of previously generated +tokens and a tensor of the logits for the next token, and returns a modified +tensor of logits to sample from.""" + + +class SamplingParams: + """Sampling parameters for text generation. + + Overall, we follow the sampling parameters from the OpenAI text completion + API (https://platform.openai.com/docs/api-reference/completions/create). + In addition, we support beam search, which is not supported by OpenAI. + + Args: + n: Number of output sequences to return for the given prompt. + best_of: Number of output sequences that are generated from the prompt. + From these `best_of` sequences, the top `n` sequences are returned. + `best_of` must be greater than or equal to `n`. This is treated as + the beam width when `use_beam_search` is True. By default, `best_of` + is set to `n`. + presence_penalty: Float that penalizes new tokens based on whether they + appear in the generated text so far. Values > 0 encourage the model + to use new tokens, while values < 0 encourage the model to repeat + tokens. + frequency_penalty: Float that penalizes new tokens based on their + frequency in the generated text so far. Values > 0 encourage the + model to use new tokens, while values < 0 encourage the model to + repeat tokens. + repetition_penalty: Float that penalizes new tokens based on whether + they appear in the prompt and the generated text so far. Values > 1 + encourage the model to use new tokens, while values < 1 encourage + the model to repeat tokens. + temperature: Float that controls the randomness of the sampling. Lower + values make the model more deterministic, while higher values make + the model more random. Zero means greedy sampling. + top_p: Float that controls the cumulative probability of the top tokens + to consider. Must be in (0, 1]. Set to 1 to consider all tokens. + top_k: Integer that controls the number of top tokens to consider. Set + to -1 to consider all tokens. + min_p: Float that represents the minimum probability for a token to be + considered, relative to the probability of the most likely token. + Must be in [0, 1]. Set to 0 to disable this. + use_beam_search: Whether to use beam search instead of sampling. + length_penalty: Float that penalizes sequences based on their length. + Used in beam search. + early_stopping: Controls the stopping condition for beam search. 
It + accepts the following values: `True`, where the generation stops as + soon as there are `best_of` complete candidates; `False`, where an + heuristic is applied and the generation stops when is it very + unlikely to find better candidates; `"never"`, where the beam search + procedure only stops when there cannot be better candidates + (canonical beam search algorithm). + stop: List of strings that stop the generation when they are generated. + The returned output will not contain the stop strings. + stop_token_ids: List of tokens that stop the generation when they are + generated. The returned output will contain the stop tokens unless + the stop tokens are special tokens. + include_stop_str_in_output: Whether to include the stop strings in output + text. Defaults to False. + ignore_eos: Whether to ignore the EOS token and continue generating + tokens after the EOS token is generated. + max_tokens: Maximum number of tokens to generate per output sequence. + logprobs: Number of log probabilities to return per output token. + Note that the implementation follows the OpenAI API: The return + result includes the log probabilities on the `logprobs` most likely + tokens, as well the chosen tokens. The API will always return the + log probability of the sampled token, so there may be up to + `logprobs+1` elements in the response. + prompt_logprobs: Number of log probabilities to return per prompt token. + skip_special_tokens: Whether to skip special tokens in the output. + spaces_between_special_tokens: Whether to add spaces between special + tokens in the output. Defaults to True. + logits_processors: List of functions that modify logits based on + previously generated tokens. + """ + + def __init__( + self, + n: int = 1, + best_of: Optional[int] = None, + presence_penalty: float = 0.0, + frequency_penalty: float = 0.0, + repetition_penalty: float = 1.0, + temperature: float = 1.0, + top_p: float = 1.0, + top_k: int = -1, + min_p: float = 0.0, + use_beam_search: bool = False, + length_penalty: float = 1.0, + early_stopping: Union[bool, str] = False, + stop: Optional[Union[str, List[str]]] = None, + stop_token_ids: Optional[List[int]] = None, + include_stop_str_in_output: bool = False, + ignore_eos: bool = False, + max_tokens: int = 16, + logprobs: Optional[int] = None, + prompt_logprobs: Optional[int] = None, + skip_special_tokens: bool = True, + spaces_between_special_tokens: bool = True, + logits_processors: Optional[List[LogitsProcessor]] = ( + [ + lambda logits_token, logits: logits, + ], + [ + lambda logits_token, logits: logits, + ], + ), + min_new_token: int = 0, + max_new_token: int = 8192, + infer_text: bool = False, + eos_token: int = 0, + spk_emb: str = None, + start_idx: int = 0, + ) -> None: + self.n = n + self.best_of = best_of if best_of is not None else n + self.presence_penalty = presence_penalty + self.frequency_penalty = frequency_penalty + self.repetition_penalty = repetition_penalty + self.temperature = temperature + self.top_p = top_p + self.top_k = top_k + self.min_p = min_p + self.use_beam_search = use_beam_search + self.length_penalty = length_penalty + self.early_stopping = early_stopping + self.min_new_token = min_new_token + self.max_new_token = max_new_token + self.infer_text = infer_text + self.eos_token = eos_token + self.spk_emb = spk_emb + self.start_idx = start_idx + if stop is None: + self.stop = [] + elif isinstance(stop, str): + self.stop = [stop] + else: + self.stop = list(stop) + if stop_token_ids is None: + self.stop_token_ids = [] + else: + 
self.stop_token_ids = list(stop_token_ids) + self.ignore_eos = ignore_eos + self.max_tokens = max_tokens + self.logprobs = logprobs + self.prompt_logprobs = prompt_logprobs + self.skip_special_tokens = skip_special_tokens + self.spaces_between_special_tokens = spaces_between_special_tokens + self.logits_processors = logits_processors + self.include_stop_str_in_output = include_stop_str_in_output + self._verify_args() + if self.use_beam_search: + self._verify_beam_search() + else: + self._verify_non_beam_search() + # if self.temperature < _SAMPLING_EPS: + # # Zero temperature means greedy sampling. + # self.top_p = 1.0 + # self.top_k = -1 + # self.min_p = 0.0 + # self._verify_greedy_sampling() + + def _verify_args(self) -> None: + if self.n < 1: + raise ValueError(f"n must be at least 1, got {self.n}.") + if self.best_of < self.n: + raise ValueError( + f"best_of must be greater than or equal to n, " + f"got n={self.n} and best_of={self.best_of}." + ) + if not -2.0 <= self.presence_penalty <= 2.0: + raise ValueError( + "presence_penalty must be in [-2, 2], got " f"{self.presence_penalty}." + ) + if not -2.0 <= self.frequency_penalty <= 2.0: + raise ValueError( + "frequency_penalty must be in [-2, 2], got " + f"{self.frequency_penalty}." + ) + if not 0.0 < self.repetition_penalty <= 2.0: + raise ValueError( + "repetition_penalty must be in (0, 2], got " + f"{self.repetition_penalty}." + ) + # if self.temperature < 0.0: + # raise ValueError( + # f"temperature must be non-negative, got {self.temperature}.") + if not 0.0 < self.top_p <= 1.0: + raise ValueError(f"top_p must be in (0, 1], got {self.top_p}.") + if self.top_k < -1 or self.top_k == 0: + raise ValueError( + f"top_k must be -1 (disable), or at least 1, " f"got {self.top_k}." + ) + if not 0.0 <= self.min_p <= 1.0: + raise ValueError("min_p must be in [0, 1], got " f"{self.min_p}.") + if self.max_tokens < 1: + raise ValueError(f"max_tokens must be at least 1, got {self.max_tokens}.") + if self.logprobs is not None and self.logprobs < 0: + raise ValueError(f"logprobs must be non-negative, got {self.logprobs}.") + if self.prompt_logprobs is not None and self.prompt_logprobs < 0: + raise ValueError( + f"prompt_logprobs must be non-negative, got " f"{self.prompt_logprobs}." + ) + + def _verify_beam_search(self) -> None: + if self.best_of == 1: + raise ValueError( + "best_of must be greater than 1 when using beam " + f"search. Got {self.best_of}." + ) + if self.temperature > _SAMPLING_EPS: + raise ValueError("temperature must be 0 when using beam search.") + if self.top_p < 1.0 - _SAMPLING_EPS: + raise ValueError("top_p must be 1 when using beam search.") + if self.top_k != -1: + raise ValueError("top_k must be -1 when using beam search.") + if self.early_stopping not in [True, False, "never"]: + raise ValueError( + f"early_stopping must be True, False, or 'never', " + f"got {self.early_stopping}." + ) + + def _verify_non_beam_search(self) -> None: + if self.early_stopping is not False: + raise ValueError( + "early_stopping is not effective and must be " + "False when not using beam search." + ) + if ( + self.length_penalty < 1.0 - _SAMPLING_EPS + or self.length_penalty > 1.0 + _SAMPLING_EPS + ): + raise ValueError( + "length_penalty is not effective and must be the " + "default value of 1.0 when not using beam search." + ) + + def _verify_greedy_sampling(self) -> None: + if self.best_of > 1: + raise ValueError( + "best_of must be 1 when using greedy sampling." f"Got {self.best_of}." 
+ ) + + @cached_property + def sampling_type(self) -> SamplingType: + if self.use_beam_search: + return SamplingType.BEAM + # if self.temperature < _SAMPLING_EPS: + # return SamplingType.GREEDY + return SamplingType.RANDOM + + def __repr__(self) -> str: + return ( + f"SamplingParams(n={self.n}, " + f"best_of={self.best_of}, " + f"presence_penalty={self.presence_penalty}, " + f"frequency_penalty={self.frequency_penalty}, " + f"repetition_penalty={self.repetition_penalty}, " + f"temperature={self.temperature}, " + f"top_p={self.top_p}, " + f"top_k={self.top_k}, " + f"min_p={self.min_p}, " + f"use_beam_search={self.use_beam_search}, " + f"length_penalty={self.length_penalty}, " + f"early_stopping={self.early_stopping}, " + f"stop={self.stop}, " + f"stop_token_ids={self.stop_token_ids}, " + f"include_stop_str_in_output={self.include_stop_str_in_output}, " + f"ignore_eos={self.ignore_eos}, " + f"max_tokens={self.max_tokens}, " + f"logprobs={self.logprobs}, " + f"prompt_logprobs={self.prompt_logprobs}, " + f"skip_special_tokens={self.skip_special_tokens}, " + "spaces_between_special_tokens=" + f"{self.spaces_between_special_tokens}), " + f"max_new_token={self.max_new_token}), " + f"min_new_token={self.min_new_token}), " + f"infer_text={self.infer_text})" + ) diff --git a/ChatTTS/model/velocity/scheduler.py b/ChatTTS/model/velocity/scheduler.py new file mode 100644 index 000000000..97d9cb450 --- /dev/null +++ b/ChatTTS/model/velocity/scheduler.py @@ -0,0 +1,426 @@ +import enum +import time +from typing import Dict, Iterable, List, Optional, Tuple, Union + +from vllm.config import CacheConfig, SchedulerConfig +from .block_manager import AllocStatus, BlockSpaceManager +from vllm.core.policy import PolicyFactory +from vllm.logger import init_logger +from .sequence import ( + Sequence, + SequenceData, + SequenceGroup, + SequenceGroupMetadata, + SequenceStatus, +) + +logger = init_logger(__name__) + + +class PreemptionMode(enum.Enum): + """Preemption modes. + + 1. Swapping: Swap out the blocks of the preempted sequences to CPU memory + and swap them back in when the sequences are resumed. + 2. Recomputation: Discard the blocks of the preempted sequences and + recompute them when the sequences are resumed, treating the sequences as + new prompts. + """ + + SWAP = enum.auto() + RECOMPUTE = enum.auto() + + +class SchedulerOutputs: + + def __init__( + self, + scheduled_seq_groups: List[SequenceGroup], + prompt_run: bool, + num_batched_tokens: int, + blocks_to_swap_in: Dict[int, int], + blocks_to_swap_out: Dict[int, int], + blocks_to_copy: Dict[int, List[int]], + ignored_seq_groups: List[SequenceGroup], + ) -> None: + self.scheduled_seq_groups = scheduled_seq_groups + self.prompt_run = prompt_run + self.num_batched_tokens = num_batched_tokens + self.blocks_to_swap_in = blocks_to_swap_in + self.blocks_to_swap_out = blocks_to_swap_out + self.blocks_to_copy = blocks_to_copy + # Swap in and swap out should never happen at the same time. + assert not (blocks_to_swap_in and blocks_to_swap_out) + self.ignored_seq_groups = ignored_seq_groups + + def is_empty(self) -> bool: + # NOTE: We do not consider the ignored sequence groups. 
+ return ( + not self.scheduled_seq_groups + and not self.blocks_to_swap_in + and not self.blocks_to_swap_out + and not self.blocks_to_copy + ) + + +class Scheduler: + + def __init__( + self, + scheduler_config: SchedulerConfig, + cache_config: CacheConfig, + ) -> None: + self.scheduler_config = scheduler_config + self.cache_config = cache_config + + self.prompt_limit = min( + self.scheduler_config.max_model_len, + self.scheduler_config.max_num_batched_tokens, + ) + + # Instantiate the scheduling policy. + self.policy = PolicyFactory.get_policy(policy_name="fcfs") + # Create the block space manager. + self.block_manager = BlockSpaceManager( + block_size=self.cache_config.block_size, + num_gpu_blocks=self.cache_config.num_gpu_blocks, + num_cpu_blocks=self.cache_config.num_cpu_blocks, + sliding_window=self.cache_config.sliding_window, + ) + + # TODO(zhuohan): Use deque instead of list for better performance. + # Sequence groups in the WAITING state. + self.waiting: List[SequenceGroup] = [] + # Sequence groups in the RUNNING state. + self.running: List[SequenceGroup] = [] + # Sequence groups in the SWAPPED state. + self.swapped: List[SequenceGroup] = [] + + def add_seq_group(self, seq_group: SequenceGroup) -> None: + # Add sequence groups to the waiting queue. + self.waiting.append(seq_group) + + def abort_seq_group(self, request_id: Union[str, Iterable[str]]) -> None: + if isinstance(request_id, str): + request_id = (request_id,) + request_ids = set(request_id) + for state_queue in [self.waiting, self.running, self.swapped]: + # We need to reverse the list as we are removing elements + # from it as we iterate over it. If we don't do it, + # indices will get messed up and we will skip over elements. + for seq_group in reversed(state_queue): + if seq_group.request_id in request_ids: + # Remove the sequence group from the state queue. + state_queue.remove(seq_group) + for seq in seq_group.get_seqs(): + if seq.is_finished(): + continue + seq.status = SequenceStatus.FINISHED_ABORTED + self.free_seq(seq) + request_ids.remove(seq_group.request_id) + if not request_ids: + return + + def has_unfinished_seqs(self) -> bool: + return self.waiting or self.running or self.swapped + + def get_num_unfinished_seq_groups(self) -> int: + return len(self.waiting) + len(self.running) + len(self.swapped) + + def _schedule(self) -> SchedulerOutputs: + # Blocks that need to be swaped or copied before model execution. + blocks_to_swap_in: Dict[int, int] = {} + blocks_to_swap_out: Dict[int, int] = {} + blocks_to_copy: Dict[int, List[int]] = {} + + # Fix the current time. + now = time.monotonic() + + # Join waiting sequences if possible. + if not self.swapped: + ignored_seq_groups: List[SequenceGroup] = [] + scheduled: List[SequenceGroup] = [] + # The total number of sequences on the fly, including the + # requests in the generation phase. + num_curr_seqs = sum( + seq_group.get_max_num_running_seqs() for seq_group in self.running + ) + seq_lens: List[int] = [] + + # Optimization: We do not sort the waiting queue since the preempted + # sequence groups are added to the front and the new sequence groups + # are added to the back. + while self.waiting: + seq_group = self.waiting[0] + + waiting_seqs = seq_group.get_seqs(status=SequenceStatus.WAITING) + assert len(waiting_seqs) == 1, ( + "Waiting sequence group should have only one prompt " "sequence." 
+ ) + num_prompt_tokens = waiting_seqs[0].get_len() + if num_prompt_tokens > self.prompt_limit: + logger.warning( + f"Input prompt ({num_prompt_tokens} tokens) is too long" + f" and exceeds limit of {self.prompt_limit}" + ) + for seq in waiting_seqs: + seq.status = SequenceStatus.FINISHED_IGNORED + ignored_seq_groups.append(seq_group) + self.waiting.pop(0) + continue + + # If the sequence group cannot be allocated, stop. + can_allocate = self.block_manager.can_allocate(seq_group) + if can_allocate == AllocStatus.LATER: + break + elif can_allocate == AllocStatus.NEVER: + logger.warning( + f"Input prompt ({num_prompt_tokens} tokens) is too long" + f" and exceeds the capacity of block_manager" + ) + for seq in waiting_seqs: + seq.status = SequenceStatus.FINISHED_IGNORED + ignored_seq_groups.append(seq_group) + self.waiting.pop(0) + continue + + # If the number of batched tokens exceeds the limit, stop. + new_seq_lens = seq_lens + [num_prompt_tokens] + num_batched_tokens = len(new_seq_lens) * max(new_seq_lens) + if num_batched_tokens > self.scheduler_config.max_num_batched_tokens: + break + + # The total number of sequences in the RUNNING state should not + # exceed the maximum number of sequences. + num_new_seqs = seq_group.get_max_num_running_seqs() + if num_curr_seqs + num_new_seqs > self.scheduler_config.max_num_seqs: + break + + num_paddings = num_batched_tokens - sum(new_seq_lens) + if num_paddings > self.scheduler_config.max_paddings: + break + seq_lens = new_seq_lens + + seq_group = self.waiting.pop(0) + self._allocate(seq_group) + self.running.append(seq_group) + num_curr_seqs += num_new_seqs + scheduled.append(seq_group) + + if scheduled or ignored_seq_groups: + scheduler_outputs = SchedulerOutputs( + scheduled_seq_groups=scheduled, + prompt_run=True, + num_batched_tokens=len(seq_lens) * max(seq_lens) if seq_lens else 0, + blocks_to_swap_in=blocks_to_swap_in, + blocks_to_swap_out=blocks_to_swap_out, + blocks_to_copy=blocks_to_copy, + ignored_seq_groups=ignored_seq_groups, + ) + return scheduler_outputs + + # NOTE(woosuk): Preemption happens only when there is no available slot + # to keep all the sequence groups in the RUNNING state. + # In this case, the policy is responsible for deciding which sequence + # groups to preempt. + self.running = self.policy.sort_by_priority(now, self.running) + + # Reserve new token slots for the running sequence groups. + running: List[SequenceGroup] = [] + preempted: List[SequenceGroup] = [] + while self.running: + seq_group = self.running.pop(0) + while not self.block_manager.can_append_slot(seq_group): + if self.running: + # Preempt the lowest-priority sequence groups. + victim_seq_group = self.running.pop(-1) + self._preempt(victim_seq_group, blocks_to_swap_out) + preempted.append(victim_seq_group) + else: + # No other sequence groups can be preempted. + # Preempt the current sequence group. + self._preempt(seq_group, blocks_to_swap_out) + preempted.append(seq_group) + break + else: + # Append new slots to the sequence group. + self._append_slot(seq_group, blocks_to_copy) + running.append(seq_group) + self.running = running + + # Swap in the sequence groups in the SWAPPED state if possible. + self.swapped = self.policy.sort_by_priority(now, self.swapped) + if not preempted: + num_curr_seqs = sum( + seq_group.get_max_num_running_seqs() for seq_group in self.running + ) + + while self.swapped: + seq_group = self.swapped[0] + # If the sequence group cannot be swapped in, stop. 
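As an aside on the prompt-scheduling loop above: because a prompt batch is padded to its longest sequence, the scheduler budgets `len(seq_lens) * max(seq_lens)` token slots and separately caps the wasted padding. A worked example with made-up lengths and limits:

```python
# Toy numbers, not from this diff.
seq_lens = [100, 40, 25]

num_batched_tokens = len(seq_lens) * max(seq_lens)   # 3 * 100 = 300 padded slots
num_paddings = num_batched_tokens - sum(seq_lens)    # 300 - 165 = 135 wasted slots

# Admission checks mirror the ones above (illustrative limits):
max_num_batched_tokens, max_paddings = 2048, 256
assert num_batched_tokens <= max_num_batched_tokens
assert num_paddings <= max_paddings
```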
+ if not self.block_manager.can_swap_in(seq_group): + break + + # The total number of sequences in the RUNNING state should not + # exceed the maximum number of sequences. + num_new_seqs = seq_group.get_max_num_running_seqs() + if num_curr_seqs + num_new_seqs > self.scheduler_config.max_num_seqs: + break + + seq_group = self.swapped.pop(0) + self._swap_in(seq_group, blocks_to_swap_in) + self._append_slot(seq_group, blocks_to_copy) + num_curr_seqs += num_new_seqs + self.running.append(seq_group) + + # Each sequence in the generation phase only takes one token slot. + # Therefore, the number of batched tokens is equal to the number of + # sequences in the RUNNING state. + num_batched_tokens = sum( + seq_group.num_seqs(status=SequenceStatus.RUNNING) + for seq_group in self.running + ) + + scheduler_outputs = SchedulerOutputs( + scheduled_seq_groups=self.running, + prompt_run=False, + num_batched_tokens=num_batched_tokens, + blocks_to_swap_in=blocks_to_swap_in, + blocks_to_swap_out=blocks_to_swap_out, + blocks_to_copy=blocks_to_copy, + ignored_seq_groups=[], + ) + return scheduler_outputs + + def schedule(self) -> Tuple[List[SequenceGroupMetadata], SchedulerOutputs]: + # Schedule sequence groups. + # This function call changes the internal states of the scheduler + # such as self.running, self.swapped, and self.waiting. + scheduler_outputs = self._schedule() + + # Create input data structures. + seq_group_metadata_list: List[SequenceGroupMetadata] = [] + for seq_group in scheduler_outputs.scheduled_seq_groups: + seq_data: Dict[int, SequenceData] = {} + block_tables: Dict[int, List[int]] = {} + for seq in seq_group.get_seqs(status=SequenceStatus.RUNNING): + seq_id = seq.seq_id + seq_data[seq_id] = seq.data + block_tables[seq_id] = self.block_manager.get_block_table(seq) + + seq_group_metadata = SequenceGroupMetadata( + request_id=seq_group.request_id, + is_prompt=scheduler_outputs.prompt_run, + seq_data=seq_data, + sampling_params=seq_group.sampling_params, + block_tables=block_tables, + ) + seq_group_metadata_list.append(seq_group_metadata) + return seq_group_metadata_list, scheduler_outputs + + def fork_seq(self, parent_seq: Sequence, child_seq: Sequence) -> None: + self.block_manager.fork(parent_seq, child_seq) + + def free_seq(self, seq: Sequence) -> None: + self.block_manager.free(seq) + + def free_finished_seq_groups(self) -> None: + self.running = [ + seq_group for seq_group in self.running if not seq_group.is_finished() + ] + + def _allocate(self, seq_group: SequenceGroup) -> None: + self.block_manager.allocate(seq_group) + for seq in seq_group.get_seqs(status=SequenceStatus.WAITING): + seq.status = SequenceStatus.RUNNING + + def _append_slot( + self, + seq_group: SequenceGroup, + blocks_to_copy: Dict[int, List[int]], + ) -> None: + for seq in seq_group.get_seqs(status=SequenceStatus.RUNNING): + ret = self.block_manager.append_slot(seq) + if ret is not None: + src_block, dst_block = ret + if src_block in blocks_to_copy: + blocks_to_copy[src_block].append(dst_block) + else: + blocks_to_copy[src_block] = [dst_block] + + def _preempt( + self, + seq_group: SequenceGroup, + blocks_to_swap_out: Dict[int, int], + preemption_mode: Optional[PreemptionMode] = None, + ) -> None: + # If preemption mode is not specified, we determine the mode as follows: + # We use recomputation by default since it incurs lower overhead than + # swapping. However, when the sequence group has multiple sequences + # (e.g., beam search), recomputation is not currently supported. 
In + # such a case, we use swapping instead. + # FIXME(woosuk): This makes our scheduling policy a bit bizarre. + # As swapped sequences are prioritized over waiting sequences, + # sequence groups with multiple sequences are implicitly prioritized + # over sequence groups with a single sequence. + # TODO(woosuk): Support recomputation for sequence groups with multiple + # sequences. This may require a more sophisticated CUDA kernel. + if preemption_mode is None: + if seq_group.get_max_num_running_seqs() == 1: + preemption_mode = PreemptionMode.RECOMPUTE + else: + preemption_mode = PreemptionMode.SWAP + if preemption_mode == PreemptionMode.RECOMPUTE: + self._preempt_by_recompute(seq_group) + elif preemption_mode == PreemptionMode.SWAP: + self._preempt_by_swap(seq_group, blocks_to_swap_out) + else: + raise AssertionError("Invalid preemption mode.") + + def _preempt_by_recompute( + self, + seq_group: SequenceGroup, + ) -> None: + seqs = seq_group.get_seqs(status=SequenceStatus.RUNNING) + assert len(seqs) == 1 + for seq in seqs: + seq.status = SequenceStatus.WAITING + self.block_manager.free(seq) + # NOTE: For FCFS, we insert the preempted sequence group to the front + # of the waiting queue. + self.waiting.insert(0, seq_group) + + def _preempt_by_swap( + self, + seq_group: SequenceGroup, + blocks_to_swap_out: Dict[int, int], + ) -> None: + self._swap_out(seq_group, blocks_to_swap_out) + self.swapped.append(seq_group) + + def _swap_in( + self, + seq_group: SequenceGroup, + blocks_to_swap_in: Dict[int, int], + ) -> None: + mapping = self.block_manager.swap_in(seq_group) + blocks_to_swap_in.update(mapping) + for seq in seq_group.get_seqs(status=SequenceStatus.SWAPPED): + seq.status = SequenceStatus.RUNNING + + def _swap_out( + self, + seq_group: SequenceGroup, + blocks_to_swap_out: Dict[int, int], + ) -> None: + if not self.block_manager.can_swap_out(seq_group): + # FIXME(woosuk): Abort the sequence group instead of aborting the + # entire engine. + raise RuntimeError( + "Aborted due to the lack of CPU swap space. Please increase " + "the swap space to avoid this error." 
+ ) + mapping = self.block_manager.swap_out(seq_group) + blocks_to_swap_out.update(mapping) + for seq in seq_group.get_seqs(status=SequenceStatus.RUNNING): + seq.status = SequenceStatus.SWAPPED diff --git a/ChatTTS/model/velocity/sequence.py b/ChatTTS/model/velocity/sequence.py new file mode 100644 index 000000000..76f9cf4e7 --- /dev/null +++ b/ChatTTS/model/velocity/sequence.py @@ -0,0 +1,450 @@ +"""Sequence and its related classes.""" + +import copy +import enum +from typing import Dict, List, Optional, Union +import torch +from vllm.block import LogicalTokenBlock +from .sampling_params import SamplingParams + +PromptLogprobs = List[Optional[Dict[int, float]]] +SampleLogprobs = List[Dict[int, float]] + + +class SequenceStatus(enum.Enum): + """Status of a sequence.""" + + WAITING = enum.auto() + RUNNING = enum.auto() + SWAPPED = enum.auto() + FINISHED_STOPPED = enum.auto() + FINISHED_LENGTH_CAPPED = enum.auto() + FINISHED_ABORTED = enum.auto() + FINISHED_IGNORED = enum.auto() + + @staticmethod + def is_finished(status: "SequenceStatus") -> bool: + return status in [ + SequenceStatus.FINISHED_STOPPED, + SequenceStatus.FINISHED_LENGTH_CAPPED, + SequenceStatus.FINISHED_ABORTED, + SequenceStatus.FINISHED_IGNORED, + ] + + @staticmethod + def get_finished_reason(status: "SequenceStatus") -> Union[str, None]: + if status == SequenceStatus.FINISHED_STOPPED: + finish_reason = "stop" + elif status == SequenceStatus.FINISHED_LENGTH_CAPPED: + finish_reason = "length" + elif status == SequenceStatus.FINISHED_ABORTED: + finish_reason = "abort" + elif status == SequenceStatus.FINISHED_IGNORED: + # The ignored sequences are the sequences whose prompt lengths + # are longer than the model's length cap. Therefore, the stop + # reason should also be "length" as in OpenAI API. + finish_reason = "length" + else: + finish_reason = None + return finish_reason + + +class SequenceData: + """Data associated with a sequence. + + + Args: + prompt_token_ids: The token IDs of the prompt. + + Attributes: + prompt_token_ids: The token IDs of the prompt. + output_token_ids: The token IDs of the output. + cumulative_logprob: The cumulative log probability of the output. 
+ """ + + def __init__( + self, + prompt_token_ids: List[int], + ) -> None: + self.prompt_token_ids = prompt_token_ids + self.output_token_ids: List[int] = [] + self.cumulative_logprob = 0.0 + self.hidden_states: Optional[torch.Tensor] = None + self.finished = False + + def append_token_id(self, token_id: int, logprob: float) -> None: + if isinstance(self.cumulative_logprob, float): + self.cumulative_logprob = [ + 0.0, + ] * len(logprob) + self.output_token_ids.append(token_id) + for i in range(len(self.cumulative_logprob)): + self.cumulative_logprob[i] += logprob[i] + + def append_hidden_states(self, hidden_states: torch.Tensor) -> None: + if self.hidden_states is None: + self.hidden_states = hidden_states + else: + self.hidden_states = torch.cat([self.hidden_states, hidden_states], dim=0) + + def get_len(self) -> int: + return len(self.output_token_ids) + len(self.prompt_token_ids) + + def get_prompt_len(self) -> int: + return len(self.prompt_token_ids) + + def get_output_len(self) -> int: + return len(self.output_token_ids) + + def get_token_ids(self) -> List[int]: + return self.prompt_token_ids + self.output_token_ids + + def get_last_token_id(self) -> int: + if not self.output_token_ids: + return self.prompt_token_ids[-1] + return self.output_token_ids[-1] + + def __repr__(self) -> str: + return ( + f"SequenceData(" + f"prompt_token_ids={self.prompt_token_ids}, " + f"output_token_ids={self.output_token_ids}, " + f"cumulative_logprob={self.cumulative_logprob}), " + f"hidden_states={self.hidden_states.shape if self.hidden_states is not None else None}, " + f"finished={self.finished})" + ) + + +class Sequence: + """Stores the data, status, and block information of a sequence. + + Args: + seq_id: The ID of the sequence. + prompt: The prompt of the sequence. + prompt_token_ids: The token IDs of the prompt. + block_size: The block size of the sequence. Should be the same as the + block size used by the block manager and cache engine. + """ + + def __init__( + self, + seq_id: int, + prompt: str, + prompt_token_ids: List[int], + block_size: int, + ) -> None: + self.seq_id = seq_id + self.prompt = prompt + self.block_size = block_size + + self.data = SequenceData(prompt_token_ids) + self.output_logprobs: SampleLogprobs = [] + self.output_text = "" + + self.logical_token_blocks: List[LogicalTokenBlock] = [] + # Initialize the logical token blocks with the prompt token ids. 
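A quick usage note on `SequenceData` above: unlike upstream vLLM, it tracks one running log-probability per VQ codebook, lazily turning `cumulative_logprob` from a float into a list on the first appended token. A toy walk-through with invented ids and logprob values, assuming the class is importable from the `ChatTTS.model.velocity.sequence` module this diff creates:

```python
from ChatTTS.model.velocity.sequence import SequenceData

data = SequenceData(prompt_token_ids=[0, 1, 2])

# Each audio "token" is a tuple of 4 codebook ids; its logprob is a list of 4 values.
data.append_token_id((11, 12, 13, 14), [-0.5, -0.25, -0.125, -0.75])
data.append_token_id((21, 22, 23, 24), [-0.25, -0.25, -0.375, -0.25])

print(data.cumulative_logprob)  # [-0.75, -0.5, -0.5, -1.0] -> one running sum per codebook
print(data.get_len())           # 5: 3 prompt tokens + 2 decoded steps
```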
+ self._append_tokens_to_blocks(prompt_token_ids) + self.status = SequenceStatus.WAITING + + # Used for incremental detokenization + self.prefix_offset = 0 + self.read_offset = 0 + # Input + output tokens + self.tokens: Optional[List[str]] = None + + def _append_logical_block(self) -> None: + block = LogicalTokenBlock( + block_number=len(self.logical_token_blocks), + block_size=self.block_size, + ) + self.logical_token_blocks.append(block) + + def _append_tokens_to_blocks(self, token_ids: List[int]) -> None: + cursor = 0 + while cursor < len(token_ids): + if not self.logical_token_blocks: + self._append_logical_block() + + last_block = self.logical_token_blocks[-1] + if last_block.is_full(): + self._append_logical_block() + last_block = self.logical_token_blocks[-1] + + num_empty_slots = last_block.get_num_empty_slots() + last_block.append_tokens(token_ids[cursor : cursor + num_empty_slots]) + cursor += num_empty_slots + + def append_token_id( + self, + token_id: int, + logprobs: Dict[int, float], + hidden_states: Optional[torch.Tensor] = None, + finished: bool = False, + ) -> None: + assert token_id in logprobs + self._append_tokens_to_blocks([token_id]) + self.output_logprobs.append(logprobs) + self.data.append_token_id(token_id, logprobs[token_id]) + self.data.append_hidden_states(hidden_states) + self.data.finished = finished + + def get_len(self) -> int: + return self.data.get_len() + + def get_prompt_len(self) -> int: + return self.data.get_prompt_len() + + def get_output_len(self) -> int: + return self.data.get_output_len() + + def get_token_ids(self) -> List[int]: + return self.data.get_token_ids() + + def get_last_token_id(self) -> int: + return self.data.get_last_token_id() + + def get_output_token_ids(self) -> List[int]: + return self.data.output_token_ids + + def get_cumulative_logprob(self) -> float: + return self.data.cumulative_logprob + + def get_beam_search_score( + self, + length_penalty: float = 0.0, + seq_len: Optional[int] = None, + eos_token_id: Optional[int] = None, + ) -> float: + """Calculate the beam search score with length penalty. + + Adapted from + + https://github.com/huggingface/transformers/blob/ccb92be23def445f2afdea94c31286f84b89eb5b/src/transformers/generation/beam_search.py#L938 + """ + if seq_len is None: + seq_len = self.get_len() + # NOTE: HF implementation does not count the EOS token + # towards the length, we align with that here for testing. + if eos_token_id is not None and self.get_last_token_id() == eos_token_id: + seq_len -= 1 + return self.get_cumulative_logprob() / (seq_len**length_penalty) + + def is_finished(self) -> bool: + return SequenceStatus.is_finished(self.status) + + def fork(self, new_seq_id: int) -> "Sequence": + new_seq = copy.deepcopy(self) + new_seq.seq_id = new_seq_id + return new_seq + + def __repr__(self) -> str: + return ( + f"Sequence(seq_id={self.seq_id}, " + f"status={self.status.name}, " + f"num_blocks={len(self.logical_token_blocks)})" + ) + + +class SequenceGroup: + """A group of sequences that are generated from the same prompt. + + Args: + request_id: The ID of the request. + seqs: The list of sequences. + sampling_params: The sampling parameters used to generate the outputs. + arrival_time: The arrival time of the request. 
+ """ + + def __init__( + self, + request_id: str, + seqs: List[Sequence], + sampling_params: SamplingParams, + arrival_time: float, + ) -> None: + self.request_id = request_id + self.seqs_dict = {seq.seq_id: seq for seq in seqs} + self.sampling_params = sampling_params + self.arrival_time = arrival_time + self.prompt_logprobs: Optional[PromptLogprobs] = None + + @property + def prompt(self) -> str: + # All sequences in the group should have the same prompt. + # We use the prompt of an arbitrary sequence. + return next(iter(self.seqs_dict.values())).prompt + + @property + def prompt_token_ids(self) -> List[int]: + # All sequences in the group should have the same prompt. + # We use the prompt of an arbitrary sequence. + return next(iter(self.seqs_dict.values())).data.prompt_token_ids + + def get_max_num_running_seqs(self) -> int: + """The maximum number of sequences running in parallel in the remaining + lifetime of the request.""" + if self.sampling_params.use_beam_search: + # For beam search, maximally there will always be `best_of` beam + # candidates running in the future. + return self.sampling_params.best_of + else: + if self.sampling_params.best_of > self.num_seqs(): + # At prompt stage, the sequence group is not yet filled up + # and only have one sequence running. However, in the + # generation stage, we will have `best_of` sequences running. + return self.sampling_params.best_of + # At sampling stages, return the number of actual sequences + # that are not finished yet. + return self.num_unfinished_seqs() + + def get_seqs( + self, + status: Optional[SequenceStatus] = None, + ) -> List[Sequence]: + if status is None: + return list(self.seqs_dict.values()) + else: + return [seq for seq in self.seqs_dict.values() if seq.status == status] + + def get_unfinished_seqs(self) -> List[Sequence]: + return [seq for seq in self.seqs_dict.values() if not seq.is_finished()] + + def get_finished_seqs(self) -> List[Sequence]: + return [seq for seq in self.seqs_dict.values() if seq.is_finished()] + + def num_seqs(self, status: Optional[SequenceStatus] = None) -> int: + return len(self.get_seqs(status)) + + def num_unfinished_seqs(self) -> int: + return len(self.get_unfinished_seqs()) + + def num_finished_seqs(self) -> int: + return len(self.get_finished_seqs()) + + def find(self, seq_id: int) -> Sequence: + if seq_id not in self.seqs_dict: + raise ValueError(f"Sequence {seq_id} not found.") + return self.seqs_dict[seq_id] + + def add(self, seq: Sequence) -> None: + if seq.seq_id in self.seqs_dict: + raise ValueError(f"Sequence {seq.seq_id} already exists.") + self.seqs_dict[seq.seq_id] = seq + + def remove(self, seq_id: int) -> None: + if seq_id not in self.seqs_dict: + raise ValueError(f"Sequence {seq_id} not found.") + del self.seqs_dict[seq_id] + + def is_finished(self) -> bool: + return all(seq.is_finished() for seq in self.get_seqs()) + + def __repr__(self) -> str: + return ( + f"SequenceGroup(request_id={self.request_id}, " + f"sampling_params={self.sampling_params}, " + f"num_seqs={len(self.seqs_dict)})" + ) + + +class SequenceGroupMetadata: + """Metadata for a sequence group. Used to create `InputMetadata`. + + + Args: + request_id: The ID of the request. + is_prompt: Whether the request is at prompt stage. + seq_data: The sequence data. (Seq id -> sequence data) + sampling_params: The sampling parameters used to generate the outputs. + block_tables: The block tables. 
+            numbers)
+    """
+
+    def __init__(
+        self,
+        request_id: str,
+        is_prompt: bool,
+        seq_data: Dict[int, SequenceData],
+        sampling_params: SamplingParams,
+        block_tables: Dict[int, List[int]],
+    ) -> None:
+        self.request_id = request_id
+        self.is_prompt = is_prompt
+        self.seq_data = seq_data
+        self.sampling_params = sampling_params
+        self.block_tables = block_tables
+
+
+class SequenceOutput:
+    """The model output associated with a sequence.
+
+    Args:
+        parent_seq_id: The ID of the parent sequence (for forking in beam
+            search).
+        output_token: The output token ID.
+        logprobs: The logprobs of the output token.
+            (Token id -> logP(x_i+1 | x_0, ..., x_i))
+    """
+
+    def __init__(
+        self,
+        parent_seq_id: int,
+        output_token: int,
+        logprobs: Dict[int, float],
+        hidden_states: Optional[torch.Tensor] = None,
+        finished: bool = False,
+    ) -> None:
+        self.parent_seq_id = parent_seq_id
+        self.output_token = output_token
+        self.logprobs = logprobs
+        self.finished = finished
+        self.hidden_states = hidden_states
+
+    def __repr__(self) -> str:
+        return (
+            f"SequenceOutput(parent_seq_id={self.parent_seq_id}, "
+            f"output_token={self.output_token}, "
+            f"logprobs={self.logprobs}, "
+            f"finished={self.finished}, "
+            f"hidden_states={self.hidden_states.shape if self.hidden_states is not None else None})"
+        )
+
+    def __eq__(self, other: object) -> bool:
+        if not isinstance(other, SequenceOutput):
+            raise NotImplementedError()
+        return (
+            self.parent_seq_id == other.parent_seq_id
+            and self.output_token == other.output_token
+            and self.logprobs == other.logprobs
+        )
+
+
+class SequenceGroupOutput:
+    """The model output associated with a sequence group."""
+
+    def __init__(
+        self,
+        samples: List[SequenceOutput],
+        prompt_logprobs: Optional[PromptLogprobs],
+    ) -> None:
+        self.samples = samples
+        self.prompt_logprobs = prompt_logprobs
+
+    def __repr__(self) -> str:
+        return (
+            f"SequenceGroupOutput(samples={self.samples}, "
+            f"prompt_logprobs={self.prompt_logprobs})"
+        )
+
+    def __eq__(self, other: object) -> bool:
+        if not isinstance(other, SequenceGroupOutput):
+            raise NotImplementedError()
+        return (
+            self.samples == other.samples
+            and self.prompt_logprobs == other.prompt_logprobs
+        )
+
+
+# For each sequence group, we generate a list of SequenceOutput objects,
+# each of which contains one possible candidate for the next token.
+SamplerOutput = List[SequenceGroupOutput]
diff --git a/ChatTTS/model/velocity/worker.py b/ChatTTS/model/velocity/worker.py
new file mode 100644
index 000000000..90aca7f32
--- /dev/null
+++ b/ChatTTS/model/velocity/worker.py
@@ -0,0 +1,250 @@
+"""A GPU worker class."""
+
+import os
+from typing import Dict, List, Optional, Tuple
+
+import torch
+import torch.distributed
+
+from vllm.config import CacheConfig, ModelConfig, ParallelConfig, SchedulerConfig
+from vllm.model_executor import set_random_seed
+from vllm.model_executor.parallel_utils.communication_op import broadcast_object_list
+from vllm.model_executor.parallel_utils.parallel_state import initialize_model_parallel
+from vllm.sequence import SamplerOutput, SequenceGroupMetadata
+from vllm.worker.cache_engine import CacheEngine
+from .model_runner import ModelRunner
+
+
+class Worker:
+    """A worker class that executes (a partition of) the model on a GPU.
+
+    Each worker is associated with a single GPU. The worker is responsible for
+    maintaining the KV cache and executing the model on the GPU. In case of
+    distributed inference, each worker is assigned a partition of the model.
+ """ + + def __init__( + self, + model_config: ModelConfig, + parallel_config: ParallelConfig, + scheduler_config: SchedulerConfig, + local_rank: int, + rank: int, + distributed_init_method: str, + post_model_path: str, + is_driver_worker: bool = False, + ) -> None: + self.model_config = model_config + self.parallel_config = parallel_config + self.scheduler_config = scheduler_config + self.local_rank = local_rank + self.rank = rank + self.distributed_init_method = distributed_init_method + self.is_driver_worker = is_driver_worker + self.post_model_path = post_model_path + + if self.is_driver_worker: + assert self.rank == 0, "The driver worker must have rank 0." + + self.model_runner = ModelRunner( + model_config, + parallel_config, + scheduler_config, + is_driver_worker, + post_model_path, + ) + # Uninitialized cache engine. Will be initialized by + # self.init_cache_engine(). + self.cache_config = None + self.cache_engine = None + self.cache_events = None + self.gpu_cache = None + + def init_model(self) -> None: + # torch.distributed.all_reduce does not free the input tensor until + # the synchronization point. This causes the memory usage to grow + # as the number of all_reduce calls increases. This env var disables + # this behavior. + # Related issue: + # https://discuss.pytorch.org/t/cuda-allocation-lifetime-for-inputs-to-distributed-all-reduce/191573 + os.environ["TORCH_NCCL_AVOID_RECORD_STREAMS"] = "1" + + # This env var set by Ray causes exceptions with graph building. + os.environ.pop("NCCL_ASYNC_ERROR_HANDLING", None) + self.device = torch.device(f"cuda:{self.local_rank}") + torch.cuda.set_device(self.device) + + _check_if_gpu_supports_dtype(self.model_config.dtype) + + # Initialize the distributed environment. + _init_distributed_environment( + self.parallel_config, self.rank, self.distributed_init_method + ) + + # Initialize the model. + set_random_seed(self.model_config.seed) + + def load_model(self): + self.model_runner.load_model() + + @torch.inference_mode() + def profile_num_available_blocks( + self, + block_size: int, + gpu_memory_utilization: float, + cpu_swap_space: int, + ) -> Tuple[int, int]: + # Profile the memory usage of the model and get the maximum number of + # cache blocks that can be allocated with the remaining free memory. + torch.cuda.empty_cache() + + # Execute a forward pass with dummy inputs to profile the memory usage + # of the model. + self.model_runner.profile_run() + + # Calculate the number of blocks that can be allocated with the + # profiled peak memory. 
+ torch.cuda.synchronize() + free_gpu_memory, total_gpu_memory = torch.cuda.mem_get_info() + peak_memory = total_gpu_memory - free_gpu_memory + + cache_block_size = CacheEngine.get_cache_block_size( + block_size, self.model_config, self.parallel_config + ) + num_gpu_blocks = int( + (total_gpu_memory * gpu_memory_utilization - peak_memory) + // cache_block_size + ) + num_cpu_blocks = int(cpu_swap_space // cache_block_size) + num_gpu_blocks = max(num_gpu_blocks, 0) + num_cpu_blocks = max(num_cpu_blocks, 0) + torch.cuda.empty_cache() + return num_gpu_blocks, num_cpu_blocks + + def init_cache_engine(self, cache_config: CacheConfig) -> None: + self.cache_config = cache_config + self.cache_engine = CacheEngine( + self.cache_config, self.model_config, self.parallel_config + ) + self.cache_events = self.cache_engine.events + self.gpu_cache = self.cache_engine.gpu_cache + self.model_runner.set_block_size(self.cache_engine.block_size) + + def warm_up_model(self) -> None: + if not self.model_config.enforce_eager: + self.model_runner.capture_model(self.gpu_cache) + # Reset the seed to ensure that the random state is not affected by + # the model initialization and profiling. + set_random_seed(self.model_config.seed) + + def cache_swap( + self, + blocks_to_swap_in: Dict[int, int], + blocks_to_swap_out: Dict[int, int], + blocks_to_copy: Dict[int, List[int]], + ) -> None: + # Issue cache operations. + issued_cache_op = False + if blocks_to_swap_in: + self.cache_engine.swap_in(blocks_to_swap_in) + issued_cache_op = True + if blocks_to_swap_out: + self.cache_engine.swap_out(blocks_to_swap_out) + issued_cache_op = True + if blocks_to_copy: + self.cache_engine.copy(blocks_to_copy) + issued_cache_op = True + + cache_events = self.cache_events if issued_cache_op else None + + # Wait for cache operations to finish. + # TODO(woosuk): Profile swapping overhead and optimize if needed. + if cache_events is not None: + for event in cache_events: + event.wait() + + @torch.inference_mode() + def execute_model( + self, + seq_group_metadata_list: Optional[List[SequenceGroupMetadata]] = None, + blocks_to_swap_in: Optional[Dict[int, int]] = None, + blocks_to_swap_out: Optional[Dict[int, int]] = None, + blocks_to_copy: Optional[Dict[int, List[int]]] = None, + ) -> Optional[SamplerOutput]: + if self.is_driver_worker: + assert seq_group_metadata_list is not None + num_seq_groups = len(seq_group_metadata_list) + assert blocks_to_swap_in is not None + assert blocks_to_swap_out is not None + assert blocks_to_copy is not None + block_swapping_info = [ + blocks_to_swap_in, + blocks_to_swap_out, + blocks_to_copy, + ] + broadcast_object_list([num_seq_groups] + block_swapping_info, src=0) + else: + # num_seq_groups, blocks_to_swap_in, blocks_to_swap_out, + # blocks_to_copy (4 elements) + recv_data = [None] * 4 + broadcast_object_list(recv_data, src=0) + num_seq_groups = recv_data[0] + block_swapping_info = recv_data[1:] + + self.cache_swap(*block_swapping_info) + + # If there is no input, we don't need to execute the model. 
+        if num_seq_groups == 0:
+            return {}
+
+        output = self.model_runner.execute_model(
+            seq_group_metadata_list, self.gpu_cache
+        )
+        return output
+
+
+def _init_distributed_environment(
+    parallel_config: ParallelConfig,
+    rank: int,
+    distributed_init_method: Optional[str] = None,
+) -> None:
+    """Initialize the distributed environment."""
+    if torch.distributed.is_initialized():
+        torch_world_size = torch.distributed.get_world_size()
+        if torch_world_size != parallel_config.world_size:
+            raise RuntimeError(
+                "torch.distributed is already initialized but the torch world "
+                "size does not match parallel_config.world_size "
+                f"({torch_world_size} vs. {parallel_config.world_size})."
+            )
+    elif not distributed_init_method:
+        raise ValueError(
+            "distributed_init_method must be set if torch.distributed "
+            "is not already initialized"
+        )
+    else:
+        torch.distributed.init_process_group(
+            backend="nccl",
+            world_size=parallel_config.world_size,
+            rank=rank,
+            init_method=distributed_init_method,
+        )
+
+    # A small all_reduce for warmup.
+    torch.distributed.all_reduce(torch.zeros(1).cuda())
+    initialize_model_parallel(
+        parallel_config.tensor_parallel_size, parallel_config.pipeline_parallel_size
+    )
+
+
+def _check_if_gpu_supports_dtype(torch_dtype: torch.dtype):
+    # Check if the GPU supports the dtype.
+    if torch_dtype == torch.bfloat16:
+        compute_capability = torch.cuda.get_device_capability()
+        if compute_capability[0] < 8:
+            gpu_name = torch.cuda.get_device_name()
+            raise ValueError(
+                "Bfloat16 is only supported on GPUs with compute capability "
+                f"of at least 8.0. Your {gpu_name} GPU has compute capability "
+                f"{compute_capability[0]}.{compute_capability[1]}."
+            )
diff --git a/README.md b/README.md
index e9d0d2af3..e31408cc0 100644
--- a/README.md
+++ b/README.md
@@ -6,7 +6,7 @@ A generative speech model for daily dialogue.
 [![Licence](https://img.shields.io/github/license/2noise/ChatTTS?style=for-the-badge)](https://github.com/2noise/ChatTTS/blob/main/LICENSE)
-[![PyPI](https://img.shields.io/pypi/v/ChatTTS.svg?style=for-the-badge)](https://pypi.org/project/ChatTTS)
+[![PyPI](https://img.shields.io/pypi/v/ChatTTS.svg?style=for-the-badge&color=green)](https://pypi.org/project/ChatTTS)
 [![Huggingface](https://img.shields.io/badge/🤗%20-Models-yellow.svg?style=for-the-badge)](https://huggingface.co/2Noise/ChatTTS)
 [![Open In Colab](https://img.shields.io/badge/Colab-F9AB00?style=for-the-badge&logo=googlecolab&color=525252)](https://colab.research.google.com/github/2noise/ChatTTS/blob/main/examples/ipynb/colab.ipynb)
@@ -101,7 +101,12 @@ conda activate chattts
 pip install -r requirements.txt
 ```
-#### Optional: Install TransformerEngine if using NVIDIA GPU (Linux only)
+#### Optional: Install vLLM (Linux only)
+```bash
+pip install safetensors vllm==0.2.7 torchaudio
+```
+
+#### Optional (Not Recommended): Install TransformerEngine if using NVIDIA GPU (Linux only)
 > [!Note]
 > The installation process is very slow.
@@ -113,7 +118,7 @@ pip install -r requirements.txt
 pip install git+https://github.com/NVIDIA/TransformerEngine.git@stable
 ```
-#### Optional: Install FlashAttention-2 (mainly NVIDIA GPU)
+#### Optional (Not Recommended): Install FlashAttention-2 (mainly NVIDIA GPU)
 > [!Note]
 > See supported devices at the [Hugging Face Doc](https://huggingface.co/docs/transformers/perf_infer_gpu_one#flashattention-2).
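With the optional vLLM wheel installed, the backend is switched on at load time through the `use_vllm` flag that this patch threads into `Chat.load()`. The snippet below is a minimal usage sketch, not part of the patch itself: the input sentence is a placeholder, and it assumes a Linux machine with a CUDA GPU where the `pip install safetensors vllm==0.2.7 torchaudio` step above succeeded.

```python
# Minimal sketch: run ChatTTS on the experimental vLLM (velocity) backend.
import ChatTTS

chat = ChatTTS.Chat()
# use_vllm routes GPT inference through the new ChatTTS/model/velocity engine.
loaded = chat.load(compile=False, use_vllm=True)
assert loaded, "model download or initialization failed"

# infer() returns one 24 kHz waveform (numpy array) per input text.
wavs = chat.infer(["Hello, this is a quick vLLM smoke test."])
```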
diff --git a/docs/cn/README.md b/docs/cn/README.md index 8dbfa1244..c5dd6f753 100644 --- a/docs/cn/README.md +++ b/docs/cn/README.md @@ -6,7 +6,7 @@ 一款适用于日常对话的生成式语音模型。 [![Licence](https://img.shields.io/github/license/2noise/ChatTTS?style=for-the-badge)](https://github.com/2noise/ChatTTS/blob/main/LICENSE) -[![PyPI](https://img.shields.io/pypi/v/ChatTTS.svg?style=for-the-badge)](https://pypi.org/project/ChatTTS) +[![PyPI](https://img.shields.io/pypi/v/ChatTTS.svg?style=for-the-badge&color=green)](https://pypi.org/project/ChatTTS) [![Huggingface](https://img.shields.io/badge/🤗%20-Models-yellow.svg?style=for-the-badge)](https://huggingface.co/2Noise/ChatTTS) [![Open In Colab](https://img.shields.io/badge/Colab-F9AB00?style=for-the-badge&logo=googlecolab&color=525252)](https://colab.research.google.com/github/2noise/ChatTTS/blob/main/examples/ipynb/colab.ipynb) diff --git a/docs/fr/README.md b/docs/fr/README.md index 3addc9b73..d2fbb00d7 100644 --- a/docs/fr/README.md +++ b/docs/fr/README.md @@ -6,7 +6,7 @@ Un modèle de parole génératif pour le dialogue quotidien. [![Licence](https://img.shields.io/github/license/2noise/ChatTTS?style=for-the-badge)](https://github.com/2noise/ChatTTS/blob/main/LICENSE) -[![PyPI](https://img.shields.io/pypi/v/ChatTTS.svg?style=for-the-badge)](https://pypi.org/project/ChatTTS) +[![PyPI](https://img.shields.io/pypi/v/ChatTTS.svg?style=for-the-badge&color=green)](https://pypi.org/project/ChatTTS) [![Huggingface](https://img.shields.io/badge/🤗%20-Models-yellow.svg?style=for-the-badge)](https://huggingface.co/2Noise/ChatTTS) [![Open In Colab](https://img.shields.io/badge/Colab-F9AB00?style=for-the-badge&logo=googlecolab&color=525252)](https://colab.research.google.com/github/2noise/ChatTTS/blob/main/examples/ipynb/colab.ipynb)
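For reviewers of the new `velocity/worker.py`: the `Worker` is meant to be driven in the same order as vLLM's own engine, i.e. `init_model` → `load_model` → `profile_num_available_blocks` → `init_cache_engine` → `warm_up_model` → `execute_model`. The sketch below is hypothetical glue code, not part of this patch; `build_configs()` and `make_batch()` are placeholder helpers, the `post_model_path` value is only an assumed example, and the `CacheConfig` attribute names follow vLLM 0.2.7.

```python
# Hypothetical single-GPU driver for ChatTTS/model/velocity/worker.py.
# Only the Worker method calls come from this patch; config construction is
# assumed to happen elsewhere (normally inside the velocity engine).
from ChatTTS.model.velocity.worker import Worker

model_config, parallel_config, scheduler_config, cache_config = build_configs()

worker = Worker(
    model_config,
    parallel_config,
    scheduler_config,
    local_rank=0,
    rank=0,
    distributed_init_method="tcp://127.0.0.1:29500",
    post_model_path="post_model.safetensors",  # placeholder path
    is_driver_worker=True,
)

worker.init_model()   # selects cuda:0, sets NCCL env vars, inits the process group
worker.load_model()

# Size the KV cache from the profiled peak memory, then build and warm it up.
num_gpu_blocks, num_cpu_blocks = worker.profile_num_available_blocks(
    block_size=cache_config.block_size,
    gpu_memory_utilization=0.9,
    cpu_swap_space=cache_config.swap_space_bytes,
)
cache_config.num_gpu_blocks = num_gpu_blocks  # the engine normally fills these in
cache_config.num_cpu_blocks = num_cpu_blocks
worker.init_cache_engine(cache_config)
worker.warm_up_model()

# One decoding step: metadata comes from the scheduler; no swaps or copies here.
sampler_output = worker.execute_model(
    seq_group_metadata_list=make_batch(),
    blocks_to_swap_in={},
    blocks_to_swap_out={},
    blocks_to_copy={},
)
```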