diff --git a/ChatTTS/core.py b/ChatTTS/core.py
index 5824e8f96..bcbb804d1 100644
--- a/ChatTTS/core.py
+++ b/ChatTTS/core.py
@@ -549,7 +549,9 @@ def _infer_code(
 
             return [
                 GPT.GenerationOutputs(
-                    ids=token_ids, hiddens=hidden_states, attentions=[],
+                    ids=token_ids,
+                    hiddens=hidden_states,
+                    attentions=[],
                 ),
             ]
 
@@ -640,8 +642,11 @@ def _refine_text(
 
             del_all(logits_warpers)
             del_all(logits_processors)
 
-            return GPT.GenerationOutputs(ids=token_ids, hiddens=hidden_states, attentions=[],
-            )
+            return GPT.GenerationOutputs(
+                ids=token_ids,
+                hiddens=hidden_states,
+                attentions=[],
+            )
 
         emb = gpt(input_ids, text_mask)
diff --git a/ChatTTS/model/velocity/model_runner.py b/ChatTTS/model/velocity/model_runner.py
index 073f59982..a13df5990 100644
--- a/ChatTTS/model/velocity/model_runner.py
+++ b/ChatTTS/model/velocity/model_runner.py
@@ -782,14 +782,16 @@ def _make_tensor_with_pad(
     padded_x = []
     for x_i in x:
         pad_i = pad
-        if isinstance(x[0][0],tuple):
+        if isinstance(x[0][0], tuple):
             pad_i = (0,) * len(x[0][0])
         padded_x.append(_pad_to_max(x_i, max_len, pad_i))
-
-    return torch.tensor(padded_x,
-                        dtype=dtype,
-                        device=device,
-                        pin_memory=pin_memory and str(device) == "cpu")
+
+    return torch.tensor(
+        padded_x,
+        dtype=dtype,
+        device=device,
+        pin_memory=pin_memory and str(device) == "cpu",
+    )
 
 
 def _get_graph_batch_size(batch_size: int) -> int:
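
For reviewers: a minimal, self-contained sketch of the tuple-padding path that the `model_runner.py` hunk reformats. `_pad_to_max` is reproduced here assuming its usual one-line definition (right-pad a row to `max_len`); the simplified signature and the `__main__` demo are illustrative additions, not part of the patch:

```python
from typing import Any, List

import torch


def _pad_to_max(x: List[Any], max_len: int, pad: Any) -> List[Any]:
    # Right-pad a single row out to max_len with the given pad value.
    return x + [pad] * (max_len - len(x))


def _make_tensor_with_pad(
    x: List[List[Any]],
    max_len: int,
    pad: Any,
    dtype: torch.dtype,
    device: str = "cpu",
    pin_memory: bool = False,
) -> torch.Tensor:
    padded_x = []
    for x_i in x:
        pad_i = pad
        # When rows hold tuples, pad with a zero tuple of matching arity
        # instead of the scalar pad value, so every element has the same
        # shape before tensor conversion.
        if isinstance(x[0][0], tuple):
            pad_i = (0,) * len(x[0][0])
        padded_x.append(_pad_to_max(x_i, max_len, pad_i))

    return torch.tensor(
        padded_x,
        dtype=dtype,
        device=device,
        pin_memory=pin_memory and str(device) == "cpu",
    )


if __name__ == "__main__":
    # Scalar rows: padded with the scalar pad value -> [[1, 2, 0], [3, 0, 0]].
    print(_make_tensor_with_pad([[1, 2], [3]], 3, 0, torch.long))
    # Tuple rows: padded with (0, 0) regardless of the scalar pad argument.
    print(_make_tensor_with_pad([[(1, 2)], [(3, 4), (5, 6)]], 2, -1, torch.long))
```

The takeaway is that the change is formatting-only: `pad_i = (0,) * len(x[0][0])` still substitutes a zero tuple of matching arity when rows contain tuples, so `torch.tensor` receives uniformly shaped elements, and `pin_memory` is still only requested for CPU tensors.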