Skip to content

Commit

Permalink
Clean up merge
Browse files Browse the repository at this point in the history
  • Loading branch information
synesthesiam committed Oct 6, 2023
1 parent 595536c commit 195e3bd
Show file tree
Hide file tree
Showing 11 changed files with 230 additions and 40 deletions.
6 changes: 6 additions & 0 deletions .isort.cfg
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
[settings]
multi_line_output=3
include_trailing_comma=True
force_grid_wrap=0
use_parentheses=True
line_length=88
129 changes: 90 additions & 39 deletions generate_samples.py
Original file line number Diff line number Diff line change
@@ -1,36 +1,36 @@
#!/usr/bin/env python3
import argparse
import gc
import itertools as it
import json
import os
import gc
import logging
import unicodedata
import os
import wave
from pathlib import Path
from tqdm import tqdm
from types import SimpleNamespace
from typing import Union, List
import webrtcvad
from typing import List, Union

import numpy as np
import torch
import torchaudio
import webrtcvad
from piper_phonemize import phonemize_espeak

from piper_phonemize import phonemize_espeak, phoneme_ids_espeak
from piper_train.vits import commons

_DIR = Path(__file__).parent
_LOGGER = logging.getLogger(__name__)
logging.basicConfig(level=logging.DEBUG)


# Main generation function
def generate_samples(
text: Union[List, str],
output_dir: str,
max_samples: int=None,
max_samples: int = None,
file_names: List[str] = [],
model: str = os.path.join(Path(__file__).parent, "models", "en-us-libritts-high.pt"),
model: str = os.path.join(
Path(__file__).parent, "models", "en_US-libritts_r-medium.pt"
),
batch_size: int = 1,
slerp_weights: List[float] = [0.5],
length_scales: List[float] = [0.75, 1, 1.25],
Expand All @@ -39,13 +39,13 @@ def generate_samples(
max_speakers: float = None,
verbose: bool = False,
auto_reduce_batch_size: bool = False,
**kwargs
) -> None:
**kwargs,
) -> None:
"""
Generate synthetic speech clips, saving the clips to the specified output directory.
Args:
text (List[str]): The text to convert into speech. Can be either a
text (List[str]): The text to convert into speech. Can be either a
a list of strings, or a path to a file with text on each line.
output_dir (str): The location to save the generated clips.
max_samples (int): The maximum number of samples to generate.
Expand Down Expand Up @@ -92,8 +92,6 @@ def generate_samples(
if max_speakers is not None:
num_speakers = min(num_speakers, max_speakers)

phonemizer = Phonemizer(voice)

max_len = None

sample_idx = 0
Expand All @@ -116,13 +114,19 @@ def generate_samples(
lowpass_filter_width=64,
rolloff=0.9475937167399596,
resampling_method="kaiser_window",
beta=14.769656459379492
beta=14.769656459379492,
)

speakers_iter = it.cycle(it.product(range(num_speakers), range(num_speakers)))
speakers_batch = list(it.islice(speakers_iter, 0, batch_size))
if isinstance(text, str) and os.path.exists(text):
texts = it.cycle([i.strip() for i in open(text, 'r').readlines() if len(i.strip()) > 0])
texts = it.cycle(
[
i.strip()
for i in open(text, "r", encoding="utf-8").readlines()
if len(i.strip()) > 0
]
)
elif isinstance(text, list):
texts = it.cycle(text)
else:
Expand All @@ -143,32 +147,56 @@ def generate_samples(
speaker_1 = torch.LongTensor([s[0] for s in speakers_batch])
speaker_2 = torch.LongTensor([s[1] for s in speakers_batch])

phoneme_ids = [get_phonemes(phonemizer, config, next(texts), verbose) for i in range(batch_size)]
phoneme_ids = [
get_phonemes(voice, config, next(texts), verbose)
for i in range(batch_size)
]

def right_pad_lists(lists):
max_length = max(len(l) for l in lists)
max_length = max(len(lst) for lst in lists)
padded_lists = []
for l in lists:
padded_l = l + [1] * (max_length - len(l)) # phoneme 1 (corresponding to '^' character seems to work best)
for lst in lists:
padded_l = lst + [1] * (
max_length - len(lst)
) # phoneme 1 (corresponding to '^' character seems to work best)
padded_lists.append(padded_l)
return padded_lists

phoneme_ids = right_pad_lists(phoneme_ids)

if auto_reduce_batch_size:
oom_error = True
counter = 1
while oom_error is True:
try:
audio = generate_audio(model, speaker_1[0:batch_size//counter], speaker_2[0:batch_size//counter], phoneme_ids[0:batch_size//counter],
slerp_weight, noise_scale, noise_scale_w, length_scale, max_len)
audio = generate_audio(
model,
speaker_1[0 : batch_size // counter],
speaker_2[0 : batch_size // counter],
phoneme_ids[0 : batch_size // counter],
slerp_weight,
noise_scale,
noise_scale_w,
length_scale,
max_len,
)
oom_error = False
except torch.cuda.OutOfMemoryError:
torch.cuda.empty_cache()
gc.collect()
counter += 1 # reduce batch size to avoid OOM errors
else:
audio = generate_audio(model, speaker_1, speaker_2, phoneme_ids, slerp_weight, noise_scale, noise_scale_w, length_scale, max_len)
audio = generate_audio(
model,
speaker_1,
speaker_2,
phoneme_ids,
slerp_weight,
noise_scale,
noise_scale_w,
length_scale,
max_len,
)

# Resample audio
audio = resampler(audio.cpu()).numpy()
Expand Down Expand Up @@ -202,20 +230,32 @@ def right_pad_lists(lists):

_LOGGER.info("Done")

def remove_silence(x, frame_duration=.030, sample_rate=16000, min_start = 2000):

def remove_silence(x, frame_duration=0.030, sample_rate=16000, min_start=2000):
"""Uses webrtc voice activity detection to remove silence from the clips"""
vad = webrtcvad.Vad(0)
if x.dtype == np.float32 or x.dtype == np.float64:
x = (x*32767).astype(np.int16)
x = (x * 32767).astype(np.int16)
x_new = x[0:min_start].tolist()
step_size = int(sample_rate*frame_duration)
step_size = int(sample_rate * frame_duration)
for i in range(min_start, x.shape[0] - step_size, step_size):
vad_res = vad.is_speech(x[i:i+step_size].tobytes(), sample_rate)
vad_res = vad.is_speech(x[i : i + step_size].tobytes(), sample_rate)
if vad_res:
x_new.extend(x[i:i+step_size].tolist())
x_new.extend(x[i : i + step_size].tolist())
return np.array(x_new).astype(np.int16)

def generate_audio(model, speaker_1, speaker_2, phoneme_ids, slerp_weight, noise_scale, noise_scale_w, length_scale, max_len):

def generate_audio(
model,
speaker_1,
speaker_2,
phoneme_ids,
slerp_weight,
noise_scale,
noise_scale_w,
length_scale,
max_len,
):
x = torch.LongTensor(phoneme_ids)
x_lengths = torch.LongTensor([len(i) for i in phoneme_ids])

Expand Down Expand Up @@ -246,9 +286,7 @@ def generate_audio(model, speaker_1, speaker_2, phoneme_ids, slerp_weight, noise
m_p = torch.matmul(attn.squeeze(1), m_p_orig.transpose(1, 2)).transpose(
1, 2
) # [b, t', t], [b, t, d] -> [b, d, t']
logs_p = torch.matmul(
attn.squeeze(1), logs_p_orig.transpose(1, 2)
).transpose(
logs_p = torch.matmul(attn.squeeze(1), logs_p_orig.transpose(1, 2)).transpose(
1, 2
) # [b, t', t], [b, t, d] -> [b, d, t']

Expand All @@ -259,9 +297,14 @@ def generate_audio(model, speaker_1, speaker_2, phoneme_ids, slerp_weight, noise
audio = o
return audio

def get_phonemes(phonemizer, config, text, verbose):
phonemes_str = phonemizer.phonemize(text)
phonemes = list(unicodedata.normalize("NFD", phonemes_str))

def get_phonemes(voice, config, text, verbose):
# Combine all sentences
phonemes = [
p
for sentence_phonemes in phonemize_espeak(text, voice)
for p in sentence_phonemes
]
if verbose is True:
_LOGGER.debug("Phonemes: %s", phonemes)

Expand All @@ -276,6 +319,7 @@ def get_phonemes(phonemizer, config, text, verbose):
phoneme_ids.extend(id_map["$"])
return phoneme_ids


def slerp(v1, v2, t, DOT_THR=0.9995, zdim=-1):
"""SLERP for pytorch tensors interpolating `v1` to `v2` with scale of `t`.
Expand Down Expand Up @@ -336,13 +380,20 @@ def audio_float_to_int16(
parser = argparse.ArgumentParser()
parser.add_argument("text")
parser.add_argument("--max-samples", required=True, type=int)
parser.add_argument("--model", default=_DIR / "models" / "en-us-libritts-high.pt")
parser.add_argument(
"--model", default=_DIR / "models" / "en_US-libritts_r-medium.pt"
)
parser.add_argument("--batch-size", type=int, default=1)
parser.add_argument("--slerp-weights", nargs="+", type=float, default=[0.5])
parser.add_argument(
"--length-scales", nargs="+", type=float, default=[1.0, 0.75, 1.25, 1.4]
)
parser.add_argument("--noise-scales", nargs="+", type=float, default=[0.667, .75, .85, 0.9, 1.0, 1.4])
parser.add_argument(
"--noise-scales",
nargs="+",
type=float,
default=[0.667, 0.75, 0.85, 0.9, 1.0, 1.4],
)
parser.add_argument("--noise-scale-ws", nargs="+", type=float, default=[0.8])
parser.add_argument("--output-dir", default="output")
parser.add_argument(
Expand Down
5 changes: 5 additions & 0 deletions mypy.ini
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
[mypy]
ignore_missing_imports = true

[mypy-setuptools.*]
ignore_missing_imports = True
37 changes: 37 additions & 0 deletions pylintrc
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
[MESSAGES CONTROL]
disable=
format,
abstract-method,
cyclic-import,
duplicate-code,
global-statement,
import-outside-toplevel,
inconsistent-return-statements,
locally-disabled,
not-context-manager,
too-few-public-methods,
too-many-arguments,
too-many-branches,
too-many-instance-attributes,
too-many-lines,
too-many-locals,
too-many-public-methods,
too-many-return-statements,
too-many-statements,
too-many-boolean-expressions,
unnecessary-pass,
unused-argument,
broad-except,
too-many-nested-blocks,
invalid-name,
unused-import,
fixme,
useless-super-delegation,
missing-module-docstring,
missing-class-docstring,
missing-function-docstring,
import-error,
consider-using-with

[FORMAT]
expected-line-ending-format=LF
3 changes: 2 additions & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -2,4 +2,5 @@ audiomentations==0.33.0
piper-phonemize==1.1.0
numpy<2
torch
webrtcvad
torchaudio
webrtcvad
5 changes: 5 additions & 0 deletions requirements_dev.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
black==22.12.0
flake8==6.0.0
isort==5.11.3
mypy==0.991
pylint==2.15.9
13 changes: 13 additions & 0 deletions script/format
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
#!/usr/bin/env python3
import subprocess
import venv
from pathlib import Path

_DIR = Path(__file__).parent
_PROGRAM_DIR = _DIR.parent
_VENV_DIR = _PROGRAM_DIR / ".venv"
_SCRIPT = _PROGRAM_DIR / "generate_samples.py"

context = venv.EnvBuilder().ensure_directories(_VENV_DIR)
subprocess.check_call([context.env_exe, "-m", "black", str(_SCRIPT)])
subprocess.check_call([context.env_exe, "-m", "isort", str(_SCRIPT)])
16 changes: 16 additions & 0 deletions script/lint
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
#!/usr/bin/env python3
import subprocess
import venv
from pathlib import Path

_DIR = Path(__file__).parent
_PROGRAM_DIR = _DIR.parent
_VENV_DIR = _PROGRAM_DIR / ".venv"
_SCRIPT = _PROGRAM_DIR / "generate_samples.py"

context = venv.EnvBuilder().ensure_directories(_VENV_DIR)
subprocess.check_call([context.env_exe, "-m", "black", str(_SCRIPT), "--check"])
subprocess.check_call([context.env_exe, "-m", "isort", str(_SCRIPT), "--check"])
subprocess.check_call([context.env_exe, "-m", "flake8", str(_SCRIPT)])
subprocess.check_call([context.env_exe, "-m", "pylint", str(_SCRIPT)])
subprocess.check_call([context.env_exe, "-m", "mypy", str(_SCRIPT)])
12 changes: 12 additions & 0 deletions script/run
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
#!/usr/bin/env python3
import sys
import subprocess
import venv
from pathlib import Path

_DIR = Path(__file__).parent
_PROGRAM_DIR = _DIR.parent
_VENV_DIR = _PROGRAM_DIR / ".venv"

context = venv.EnvBuilder().ensure_directories(_VENV_DIR)
subprocess.check_call([context.env_exe, "generate_samples.py"] + sys.argv[1:])
22 changes: 22 additions & 0 deletions script/setup
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
#!/usr/bin/env python3
import subprocess
import venv
from pathlib import Path

_DIR = Path(__file__).parent
_PROGRAM_DIR = _DIR.parent
_VENV_DIR = _PROGRAM_DIR / ".venv"


# Create virtual environment
builder = venv.EnvBuilder(with_pip=True)
context = builder.ensure_directories(_VENV_DIR)
builder.create(_VENV_DIR)

# Upgrade dependencies
pip = [context.env_exe, "-m", "pip"]
subprocess.check_call(pip + ["install", "--upgrade", "pip"])
subprocess.check_call(pip + ["install", "--upgrade", "setuptools", "wheel"])

# Install requirements
subprocess.check_call(pip + ["install", "-r", str(_PROGRAM_DIR / "requirements.txt")])
Loading

0 comments on commit 195e3bd

Please sign in to comment.