From 499c2946e098c71a39f5600df211580f8bdfbd67 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E6=BA=90=E6=96=87=E9=9B=A8?= <41315874+fumiama@users.noreply.github.com> Date: Sat, 20 Jul 2024 15:20:38 +0900 Subject: [PATCH] fix(gpt): stream mode dim mismatch (fix #606) --- ChatTTS/model/gpt.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/ChatTTS/model/gpt.py b/ChatTTS/model/gpt.py index 73dfbcf96..178c19446 100644 --- a/ChatTTS/model/gpt.py +++ b/ChatTTS/model/gpt.py @@ -416,6 +416,7 @@ def generate( ) inputs_ids_buf.narrow(1, 0, progress).copy_(inputs_ids) del inputs_ids + inputs_ids = inputs_ids_buf.narrow(1, 0, progress) pbar: Optional[tqdm] = None @@ -430,8 +431,6 @@ def generate( for i in range(max_new_token): - inputs_ids = inputs_ids_buf.narrow(1, 0, progress) - model_input = self._prepare_generation_inputs( inputs_ids, past_key_values, @@ -606,6 +605,7 @@ def generate( del idx_next progress += 1 + inputs_ids = inputs_ids_buf.narrow(1, 0, progress) not_finished = finish.logical_not().to(end_idx.device) end_idx.add_(not_finished.int())