Skip to content

Commit

Permalink
chore(format): run black on dev (#556)
Browse files Browse the repository at this point in the history
Co-authored-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com>
  • Loading branch information
github-actions[bot] and github-actions[bot] authored Jul 9, 2024
1 parent 6e18575 commit 00b2dfd
Show file tree
Hide file tree
Showing 2 changed files with 11 additions and 7 deletions.
6 changes: 3 additions & 3 deletions ChatTTS/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -401,7 +401,7 @@ def _infer(
for i in range(wavs.shape[0]):
if b[i] > len(wavs[i]):
b[i] = len(wavs[i])
new_wavs[i, :b[i]-a[i]] = wavs[i, a[i]:b[i]]
new_wavs[i, : b[i] - a[i]] = wavs[i, a[i] : b[i]]
length = b
yield new_wavs
else:
Expand All @@ -410,8 +410,8 @@ def _infer(
for i in range(wavs.shape[0]):
a = length[i]
b = len(wavs[i])
wavs[i, :b-a] = wavs[i, a:]
wavs[i, b-a:] = 0
wavs[i, : b - a] = wavs[i, a:]
wavs[i, b - a :] = 0
yield wavs

@torch.inference_mode()
Expand Down
12 changes: 8 additions & 4 deletions examples/cmd/run.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,8 @@ def main(texts: List[str], spk: Optional[str] = None, stream=False):

logger.info("Start inference.")
wavs = chat.infer(
texts, stream,
texts,
stream,
params_infer_code=ChatTTS.Chat.InferCodeParams(
spk_emb=spk,
),
Expand All @@ -57,7 +58,7 @@ def main(texts: List[str], spk: Optional[str] = None, stream=False):
for index, wav in enumerate(wavs):
if stream:
for i, w in enumerate(wav):
save_mp3_file(w, (i+1)*1000+index)
save_mp3_file(w, (i + 1) * 1000 + index)
wavs_list.append(wav)
else:
save_mp3_file(wav, index)
Expand All @@ -82,10 +83,13 @@ def main(texts: List[str], spk: Optional[str] = None, stream=False):
parser.add_argument(
"--stream",
help="Use stream mode",
action='store_true',
action="store_true",
)
parser.add_argument(
"texts", help="Original text", default=["YOUR TEXT HERE"], nargs=argparse.REMAINDER,
"texts",
help="Original text",
default=["YOUR TEXT HERE"],
nargs=argparse.REMAINDER,
)
args = parser.parse_args()
main(args.texts, args.spk, args.stream)
Expand Down

0 comments on commit 00b2dfd

Please sign in to comment.