diff --git a/README.md b/README.md
index f323312..3558195 100644
--- a/README.md
+++ b/README.md
@@ -26,7 +26,7 @@ wget -O models/en-us-libritts-high.pt 'https://github.com/rhasspy/piper-sample-g
 
 ## Run
 
-Generate a small set of samples:
+Generate a small set of samples with the CLI:
 
 ``` sh
 python3 generate_samples.py 'okay, piper.' --max-samples 10 --output-dir okay_piper/
@@ -46,6 +46,31 @@ Setting `--max-speakers` to a value less than 904 (the number of speakers LibriT
 
 See `--help` for more options, including adjusting the `--length-scales` (speaking speeds) and `--slerp-weights` (speaker blending), which are cycled per batch.
 
+Alternatively, you can import the generate function into another Python script:
+
+```python
+from generate_samples import generate_samples  # make sure generate_samples.py is on your Python path
+
+generate_samples(text=["okay, piper"], max_samples=100, output_dir="okay_piper/", batch_size=10)
+```
+
+There are some additional arguments available when importing the function directly; see the docstring of `generate_samples` for more information.
+
 ### Augmentation
 
 Once you have samples generating, you can augment them using [audiomentations](https://iver56.github.io/audiomentations/):
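+
+For example, here is a minimal augmentation sketch (the transforms and parameters are illustrative, not prescriptive):
+
+```python
+import numpy as np
+from audiomentations import AddGaussianNoise, Compose, PitchShift
+
+augment = Compose([
+    AddGaussianNoise(min_amplitude=0.001, max_amplitude=0.015, p=0.5),
+    PitchShift(min_semitones=-4, max_semitones=4, p=0.5),
+])
+
+# "clip" is a placeholder: one generated sample loaded as a float32 numpy array
+augmented = augment(samples=clip, sample_rate=16000)
+```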
diff --git a/generate_samples.py b/generate_samples.py
index 1461bbe..8371422 100755
--- a/generate_samples.py
+++ b/generate_samples.py
@@ -2,55 +2,84 @@
 import argparse
 import itertools as it
 import json
+import os
+import gc
 import logging
 import unicodedata
 import wave
 from pathlib import Path
+from typing import List, Optional, Union
+import webrtcvad
 
 import numpy as np
 import torch
+import torchaudio
+from espeak_phonemizer import Phonemizer  # assumed: rhasspy's espeak-phonemizer package provides this class
 from piper_phonemize import phonemize_espeak, phoneme_ids_espeak
 from piper_train.vits import commons
 
 _DIR = Path(__file__).parent
 _LOGGER = logging.getLogger(__name__)
+logging.basicConfig(level=logging.DEBUG)
+
+# Main generation function
+def generate_samples(
+    text: Union[List, str],
+    output_dir: str,
+    max_samples: Optional[int] = None,
+    file_names: Optional[List[str]] = None,
+    model: str = os.path.join(_DIR, "models", "en-us-libritts-high.pt"),
+    batch_size: int = 1,
+    slerp_weights: List[float] = [0.5],
+    length_scales: List[float] = [0.75, 1, 1.25],
+    noise_scales: List[float] = [0.667],
+    noise_scale_ws: List[float] = [0.8],
+    max_speakers: Optional[int] = None,
+    verbose: bool = False,
+    auto_reduce_batch_size: bool = False,
+    **kwargs,  # tolerate unused extra keyword arguments
+) -> None:
+    """
+    Generate synthetic speech clips, saving the clips to the specified output directory.
+
+    Args:
+        text (Union[List[str], str]): The text to convert into speech. Can be either a
+            list of strings, or a path to a file with text on each line.
+        output_dir (str): The location to save the generated clips.
+        max_samples (int): The maximum number of samples to generate.
+        file_names (List[str]): The names to use when saving the files. Must be the same length
+            as the `text` argument, if a list.
+        model (str): The path to the TTS model to use for generation.
+        batch_size (int): The batch size to use when generating the clips.
+        slerp_weights (List[float]): The weights to use when mixing speakers via SLERP.
+        length_scales (List[float]): Controls the average duration/speed of the generated speech.
+        noise_scales (List[float]): A parameter for the overall variability of the generated speech.
+        noise_scale_ws (List[float]): A parameter for the stochastic duration of words/phonemes.
+        max_speakers (int): The maximum speaker number to use, if the model is multi-speaker.
+        verbose (bool): Enable or disable more detailed logging messages (default: False).
+        auto_reduce_batch_size (bool): Automatically and temporarily reduce the batch size
+            if CUDA OOM errors are detected, and try to resume generation.
+
+    Returns:
+        None
+    """
+    if max_samples is None:
+        # len(text) only makes sense for a list; a plain string yields one sample
+        max_samples = len(text) if isinstance(text, list) else 1
 
-def main() -> None:
-    parser = argparse.ArgumentParser()
-    parser.add_argument("text")
-    parser.add_argument("--max-samples", required=True, type=int)
-    parser.add_argument(
-        "--model", default=_DIR / "models" / "en_US-libritts_r-medium.pt"
-    )
-    parser.add_argument("--batch-size", type=int, default=1)
-    parser.add_argument("--slerp-weights", nargs="+", type=float, default=[0.5])
-    parser.add_argument(
-        "--length-scales", nargs="+", type=float, default=[1.0, 0.75, 1.25]
-    )
-    parser.add_argument("--noise-scales", nargs="+", type=float, default=[0.667])
-    parser.add_argument("--noise-scale-ws", nargs="+", type=float, default=[0.8])
-    parser.add_argument("--output-dir", default="output")
-    parser.add_argument(
-        "--max-speakers",
-        type=int,
-        help="Maximum number of speakers to use (default: all)",
-    )
-    args = parser.parse_args()
-    logging.basicConfig(level=logging.DEBUG)
-
-    _LOGGER.debug("Loading %s", args.model)
-    model_path = Path(args.model)
+    _LOGGER.debug("Loading %s", model)
+    model_path = Path(model)
     model = torch.load(model_path)
     model.eval()
-    _LOGGER.info("Successfully loaded %s", args.model)
+    _LOGGER.info("Successfully loaded the model")
 
     if torch.cuda.is_available():
         model.cuda()
         _LOGGER.debug("CUDA available, using GPU")
 
-    output_dir = Path(args.output_dir)
+    output_dir = Path(output_dir)
     output_dir.mkdir(parents=True, exist_ok=True)
 
     config_path = f"{model_path}.json"
@@ -60,27 +89,10 @@ def main() -> None:
     voice = config["espeak"]["voice"]
     sample_rate = config["audio"]["sample_rate"]
     num_speakers = config["num_speakers"]
-    if args.max_speakers is not None:
-        num_speakers = min(num_speakers, args.max_speakers)
+    if max_speakers is not None:
+        num_speakers = min(num_speakers, max_speakers)
 
-    # Combine all sentences
-    phonemes = [
-        p
-        for sentence_phonemes in phonemize_espeak(args.text, voice)
-        for p in sentence_phonemes
-    ]
-    _LOGGER.debug("Phonemes: %s", phonemes)
-
-    id_map = config["phoneme_id_map"]
-    phoneme_ids = list(id_map["^"])
-    for phoneme in phonemes:
-        p_ids = id_map.get(phoneme)
-        if p_ids is not None:
-            phoneme_ids.extend(p_ids)
-            phoneme_ids.extend(id_map["_"])
-
-    phoneme_ids.extend(id_map["$"])
-    _LOGGER.debug("Phonemes ids: %s", phoneme_ids)
+    phonemizer = Phonemizer(voice)
 
     max_len = None
@@ -88,15 +100,37 @@ def main() -> None:
     is_done = False
     settings_iter = it.cycle(
         it.product(
-            args.slerp_weights,
-            args.length_scales,
-            args.noise_scales,
-            args.noise_scale_ws,
+            slerp_weights,
+            length_scales,
+            noise_scales,
+            noise_scale_ws,
         )
     )
 
+    # Define resampler to get from the model's native rate (read from the config) to 16 kHz
+    # (https://pytorch.org/audio/stable/tutorials/audio_resampling_tutorial.html#kaiser-best)
+    resample_rate = 16000
+    resampler = torchaudio.transforms.Resample(
+        sample_rate,
+        resample_rate,
+        lowpass_filter_width=64,
+        rolloff=0.9475937167399596,
+        resampling_method="kaiser_window",
+        beta=14.769656459379492,
+    )
+
     speakers_iter = it.cycle(it.product(range(num_speakers), range(num_speakers)))
-    speakers_batch = list(it.islice(speakers_iter, 0, args.batch_size))
+    speakers_batch = list(it.islice(speakers_iter, 0, batch_size))
+
+    if isinstance(text, str) and os.path.exists(text):
+        with open(text, "r") as text_file:
+            texts = it.cycle([line.strip() for line in text_file if len(line.strip()) > 0])
+    elif isinstance(text, list):
+        texts = it.cycle(text)
+    else:
+        texts = it.cycle([text])
+
+    if file_names:
+        file_names = it.cycle(file_names)
+
     batch_idx = 0
     while speakers_batch:
         if is_done:
             break
 
@@ -109,71 +143,138 @@ def main() -> None:
         speaker_1 = torch.LongTensor([s[0] for s in speakers_batch])
         speaker_2 = torch.LongTensor([s[1] for s in speakers_batch])
 
-        x = torch.LongTensor(phoneme_ids).repeat((batch_size, 1))
-        x_lengths = torch.LongTensor([len(phoneme_ids)]).repeat(batch_size)
-
-        if torch.cuda.is_available():
-            speaker_1 = speaker_1.cuda()
-            speaker_2 = speaker_2.cuda()
-            x = x.cuda()
-            x_lengths = x_lengths.cuda()
-
-        x, m_p_orig, logs_p_orig, x_mask = model.enc_p(x, x_lengths)
-        emb0 = model.emb_g(speaker_1)
-        emb1 = model.emb_g(speaker_2)
-        g = slerp(emb0, emb1, slerp_weight).unsqueeze(-1)  # [b, h, 1]
-
-        if model.use_sdp:
-            logw = model.dp(x, x_mask, g=g, reverse=True, noise_scale=noise_scale_w)
+        phoneme_ids = [get_phonemes(phonemizer, config, next(texts), verbose) for _ in range(batch_size)]
+
+        def right_pad_lists(lists):
+            # Pad every phoneme id list in the batch to the length of the longest one
+            max_length = max(len(l) for l in lists)
+            padded_lists = []
+            for l in lists:
+                padded_l = l + [1] * (max_length - len(l))  # phoneme id 1 (the '^' character) seems to work best as padding
+                padded_lists.append(padded_l)
+            return padded_lists
+
+        phoneme_ids = right_pad_lists(phoneme_ids)
+
+        if auto_reduce_batch_size:
+            oom_error = True
+            counter = 1
+            while oom_error:
+                try:
+                    audio = generate_audio(model, speaker_1[0:batch_size//counter], speaker_2[0:batch_size//counter], phoneme_ids[0:batch_size//counter],
+                                           slerp_weight, noise_scale, noise_scale_w, length_scale, max_len)
+                    oom_error = False
+                except torch.cuda.OutOfMemoryError:
+                    torch.cuda.empty_cache()
+                    gc.collect()
+                    counter += 1  # reduce batch size to avoid OOM errors
         else:
-            logw = model.dp(x, x_mask, g=g)
-        w = torch.exp(logw) * x_mask * length_scale
-        w_ceil = torch.ceil(w)
-        y_lengths = torch.clamp_min(torch.sum(w_ceil, [1, 2]), 1).long()
-        y_mask = torch.unsqueeze(
-            commons.sequence_mask(y_lengths, y_lengths.max()), 1
-        ).type_as(x_mask)
-        attn_mask = torch.unsqueeze(x_mask, 2) * torch.unsqueeze(y_mask, -1)
-        attn = commons.generate_path(w_ceil, attn_mask)
-
-        m_p = torch.matmul(attn.squeeze(1), m_p_orig.transpose(1, 2)).transpose(
-            1, 2
-        )  # [b, t', t], [b, t, d] -> [b, d, t']
-        logs_p = torch.matmul(
-            attn.squeeze(1), logs_p_orig.transpose(1, 2)
-        ).transpose(
-            1, 2
-        )  # [b, t', t], [b, t, d] -> [b, d, t']
-
-        z_p = m_p + torch.randn_like(m_p) * torch.exp(logs_p) * noise_scale
-        z = model.flow(z_p, y_mask, g=g, reverse=True)
-        o = model.dec((z * y_mask)[:, :, :max_len], g=g)
-
-        audio = o.cpu().numpy()
+            audio = generate_audio(model, speaker_1, speaker_2, phoneme_ids, slerp_weight, noise_scale, noise_scale_w, length_scale, max_len)
+
+        # Resample audio
+        audio = resampler(audio.cpu()).numpy()
         audio_int16 = audio_float_to_int16(audio)
-        for audio_idx in range(batch_size):
-            wav_path = output_dir / f"{sample_idx}.wav"
+        for audio_idx in range(audio_int16.shape[0]):
+            # Use webrtcvad to trim silence from the clips
+            audio_data = remove_silence(audio_int16[audio_idx].flatten())[None,]
+
+            if isinstance(file_names, it.cycle):
+                wav_path = output_dir / next(file_names)
+            else:
+                wav_path = output_dir / f"{sample_idx}.wav"
 
             with wave.open(str(wav_path), "wb") as wav_file:
-                wav_file.setframerate(sample_rate)
+                wav_file.setframerate(resample_rate)
                 wav_file.setsampwidth(2)
                 wav_file.setnchannels(1)
-                wav_file.writeframes(audio_int16[audio_idx])
-
-            print(wav_path)
+                wav_file.writeframes(audio_data)
 
             sample_idx += 1
-            if sample_idx >= args.max_samples:
+            if sample_idx >= max_samples:
                 is_done = True
                 break
 
         # Next batch
-        _LOGGER.debug("Batch %s complete", batch_idx + 1)
-        speakers_batch = list(it.islice(speakers_iter, 0, args.batch_size))
+        _LOGGER.debug(f"Batch {batch_idx + 1}/{max_samples // batch_size} complete")
+        speakers_batch = list(it.islice(speakers_iter, 0, batch_size))
         batch_idx += 1
 
     _LOGGER.info("Done")
+
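+# Note: webrtcvad only accepts 10, 20, or 30 ms frames at a sample rate of
+# 8/16/32/48 kHz; the defaults below (30 ms at 16 kHz) give 480-sample frames.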
+def remove_silence(x, frame_duration=0.030, sample_rate=16000, min_start=2000):
+    """Uses webrtcvad voice activity detection to remove silence from the clips"""
+    vad = webrtcvad.Vad(0)
+    if x.dtype == np.float32 or x.dtype == np.float64:
+        x = (x * 32767).astype(np.int16)
+    x_new = x[0:min_start].tolist()
+    step_size = int(sample_rate * frame_duration)
+    for i in range(min_start, x.shape[0] - step_size, step_size):
+        vad_res = vad.is_speech(x[i:i + step_size].tobytes(), sample_rate)
+        if vad_res:
+            x_new.extend(x[i:i + step_size].tolist())
+    return np.array(x_new).astype(np.int16)
+
+def generate_audio(model, speaker_1, speaker_2, phoneme_ids, slerp_weight, noise_scale, noise_scale_w, length_scale, max_len):
+    x = torch.LongTensor(phoneme_ids)
+    x_lengths = torch.LongTensor([len(i) for i in phoneme_ids])
+
+    if torch.cuda.is_available():
+        speaker_1 = speaker_1.cuda()
+        speaker_2 = speaker_2.cuda()
+        x = x.cuda()
+        x_lengths = x_lengths.cuda()
+
+    x, m_p_orig, logs_p_orig, x_mask = model.enc_p(x, x_lengths)
+    emb0 = model.emb_g(speaker_1)
+    emb1 = model.emb_g(speaker_2)
+    g = slerp(emb0, emb1, slerp_weight).unsqueeze(-1)  # [b, h, 1]
+
+    if model.use_sdp:
+        logw = model.dp(x, x_mask, g=g, reverse=True, noise_scale=noise_scale_w)
+    else:
+        logw = model.dp(x, x_mask, g=g)
+    w = torch.exp(logw) * x_mask * length_scale
+    w_ceil = torch.ceil(w)
+    y_lengths = torch.clamp_min(torch.sum(w_ceil, [1, 2]), 1).long()
+    y_mask = torch.unsqueeze(
+        commons.sequence_mask(y_lengths, y_lengths.max()), 1
+    ).type_as(x_mask)
+    attn_mask = torch.unsqueeze(x_mask, 2) * torch.unsqueeze(y_mask, -1)
+    attn = commons.generate_path(w_ceil, attn_mask)
+
+    m_p = torch.matmul(attn.squeeze(1), m_p_orig.transpose(1, 2)).transpose(
+        1, 2
+    )  # [b, t', t], [b, t, d] -> [b, d, t']
+    logs_p = torch.matmul(
+        attn.squeeze(1), logs_p_orig.transpose(1, 2)
+    ).transpose(
+        1, 2
+    )  # [b, t', t], [b, t, d] -> [b, d, t']
+
+    z_p = m_p + torch.randn_like(m_p) * torch.exp(logs_p) * noise_scale
+    z = model.flow(z_p, y_mask, g=g, reverse=True)
+    o = model.dec((z * y_mask)[:, :, :max_len], g=g)
+
+    return o
+
+def get_phonemes(phonemizer, config, text, verbose):
+    phonemes_str = phonemizer.phonemize(text)
+    phonemes = list(unicodedata.normalize("NFD", phonemes_str))
+    if verbose:
+        _LOGGER.debug("Phonemes: %s", phonemes)
+
+    id_map = config["phoneme_id_map"]
+    phoneme_ids = list(id_map["^"])
+    for phoneme in phonemes:
+        p_ids = id_map.get(phoneme)
+        if p_ids is not None:
+            phoneme_ids.extend(p_ids)
+            phoneme_ids.extend(id_map["_"])
+
+    phoneme_ids.extend(id_map["$"])
+    return phoneme_ids
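+
+# Note on slerp() below: spherical linear interpolation blends the two speaker
+# embeddings along the arc between them on the hypersphere,
+#   slerp(v1, v2, t) = sin((1 - t) * theta) / sin(theta) * v1 + sin(t * theta) / sin(theta) * v2,
+# where theta is the angle between v1 and v2 (implementations typically fall back
+# to plain linear interpolation when the vectors are nearly parallel, i.e. dot > DOT_THR).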
 
 def slerp(v1, v2, t, DOT_THR=0.9995, zdim=-1):
     """SLERP for pytorch tensors interpolating `v1` to `v2` with scale of `t`.
@@ -231,4 +332,25 @@
 
 if __name__ == "__main__":
-    main()
+    # Get command line arguments
+    parser = argparse.ArgumentParser()
+    parser.add_argument("text")
+    parser.add_argument("--max-samples", required=True, type=int)
+    parser.add_argument("--model", default=_DIR / "models" / "en-us-libritts-high.pt")
+    parser.add_argument("--batch-size", type=int, default=1)
+    parser.add_argument("--slerp-weights", nargs="+", type=float, default=[0.5])
+    parser.add_argument(
+        "--length-scales", nargs="+", type=float, default=[1.0, 0.75, 1.25, 1.4]
+    )
+    parser.add_argument(
+        "--noise-scales", nargs="+", type=float, default=[0.667, 0.75, 0.85, 0.9, 1.0, 1.4]
+    )
+    parser.add_argument("--noise-scale-ws", nargs="+", type=float, default=[0.8])
+    parser.add_argument("--output-dir", default="output")
+    parser.add_argument(
+        "--max-speakers",
+        type=int,
+        help="Maximum number of speakers to use (default: all)",
+    )
+    args = vars(parser.parse_args())
+
+    # Generate speech
+    generate_samples(**args)
diff --git a/requirements.txt b/requirements.txt
index d5ff5ff..358d800 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -2,3 +2,5 @@ audiomentations==0.33.0
 piper-phonemize==1.1.0
 numpy<2
 torch
+webrtcvad
+espeak-phonemizer