Skip to content

Commit

Permalink
feat(vLLM): add missing params in refine_text
Browse files Browse the repository at this point in the history
  • Loading branch information
fumiama committed Jul 25, 2024
1 parent 4991dfd commit 4a1962b
Show file tree
Hide file tree
Showing 3 changed files with 13 additions and 7 deletions.
14 changes: 9 additions & 5 deletions ChatTTS/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -515,7 +515,7 @@ def _infer_code(
)

if gpt.is_vllm:
from .model.velocity.sampling_params import SamplingParams
from .model.velocity import SamplingParams

sample_params = SamplingParams(
temperature=temperature,
Expand Down Expand Up @@ -617,10 +617,13 @@ def _refine_text(
)

if gpt.is_vllm:
from .model.velocity.sampling_params import SamplingParams
from .model.velocity import SamplingParams

sample_params = SamplingParams(
repetition_penalty=params.repetition_penalty,
temperature=params.temperature,
top_p=params.top_P,
top_k=params.top_K,
max_new_token=params.max_new_token,
max_tokens=8192,
min_new_token=params.min_new_token,
Expand All @@ -629,16 +632,17 @@ def _refine_text(
infer_text=True,
start_idx=input_ids.shape[-2],
)
input_ids = [i.tolist() for i in input_ids]
input_ids_list = [i.tolist() for i in input_ids]
del input_ids

result = gpt.llm.generate(None, sample_params, input_ids)
result = gpt.llm.generate(None, sample_params, input_ids_list)
token_ids = []
hidden_states = []
for i in result:
token_ids.append(torch.tensor(i.outputs[0].token_ids))
hidden_states.append(i.outputs[0].hidden_states)

del text_mask, input_ids
del text_mask, input_ids_list, result
del_all(logits_warpers)
del_all(logits_processors)

Expand Down
3 changes: 1 addition & 2 deletions ChatTTS/model/gpt.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,8 +93,7 @@ def from_pretrained(self, file_path: str):
if self.is_vllm and platform.system().lower() == "linux":
from safetensors.torch import save_file

from .velocity.llm import LLM
from .velocity.post_model import PostModel
from .velocity import LLM, PostModel

vllm_folder = Path(os.getcwd()) / "asset" / "vllm"
if not os.path.exists(vllm_folder):
Expand Down
3 changes: 3 additions & 0 deletions ChatTTS/model/velocity/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
from .llm import LLM
from .post_model import PostModel
from .sampling_params import SamplingParams

0 comments on commit 4a1962b

Please sign in to comment.