From a19dd0dcd0c8176a4dc8b16d7ad070dfa0f7b2a6 Mon Sep 17 00:00:00 2001
From: Binbin Zhang
Date: Wed, 8 Nov 2023 16:28:18 +0800
Subject: [PATCH] [cli] mv all options to setting function (#219)

* [cli] mv all options to setting function

* apply vad by default
---
 wespeaker/cli/speaker.py | 29 +++++++++++++++++++----------
 1 file changed, 19 insertions(+), 10 deletions(-)

diff --git a/wespeaker/cli/speaker.py b/wespeaker/cli/speaker.py
index 308fef50..c3a98d61 100644
--- a/wespeaker/cli/speaker.py
+++ b/wespeaker/cli/speaker.py
@@ -26,16 +26,23 @@ class Speaker:
 
-    def __init__(self, model_path: str, resample_rate: int = 16000):
+    def __init__(self, model_path: str):
         self.session = ort.InferenceSession(model_path)
-        self.resample_rate = resample_rate
         self.vad_model = vad.OnnxWrapper()
         self.table = {}
+        self.resample_rate = 16000
+        self.apply_vad = True
+
+    def set_resample_rate(self, resample_rate: int):
+        self.resample_rate = resample_rate
+
+    def set_vad(self, apply_vad: bool):
+        self.apply_vad = apply_vad
 
-    def extract_embedding(self, audio_path: str, apply_vad: bool = False):
+    def extract_embedding(self, audio_path: str):
         pcm, sample_rate = librosa.load(audio_path, sr=self.resample_rate)
         pcm = pcm * (1 << 15)
-        if apply_vad:
+        if self.apply_vad:
             # TODO(Binbin Zhang): Refine the segments logic, here we just
             # suppose there is only silence at the start/end of the speech
             segments = vad.get_speech_timestamps(self.vad_model,
                                                  audio_path,
                                                  return_seconds=True)
@@ -64,8 +71,8 @@ def extract_embedding(self, audio_path: str, apply_vad: bool = False):
         return embedding
 
     def compute_similarity(self, audio_path1: str, audio_path2) -> float:
-        e1 = self.extract_embedding(audio_path1, True)
-        e2 = self.extract_embedding(audio_path2, True)
+        e1 = self.extract_embedding(audio_path1)
+        e2 = self.extract_embedding(audio_path2)
         if e1 is None or e2 is None:
             return 0.0
         else:
@@ -95,9 +102,9 @@ def recognize(self, audio_path: str):
         return result
 
 
-def load_model(language: str, resample_rate: int) -> Speaker:
+def load_model(language: str) -> Speaker:
     model_path = Hub.get_model(language)
-    return Speaker(model_path, resample_rate)
+    return Speaker(model_path)
 
 
 def get_args():
@@ -137,9 +144,11 @@ def get_args():
 
 def main():
     args = get_args()
-    model = load_model(args.language, args.resample_rate)
+    model = load_model(args.language)
+    model.set_resample_rate(args.resample_rate)
+    model.set_vad(args.vad)
     if args.task == 'embedding':
-        embedding = model.extract_embedding(args.audio_file, args.vad)
+        embedding = model.extract_embedding(args.audio_file)
         if embedding is not None:
            np.savetxt(args.output_file, embedding)
            print('Succeed, see {}'.format(args.output_file))
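
For context, a minimal sketch of how the reworked API is driven after this
patch: options move out of __init__/extract_embedding and into the new
setter methods, with VAD enabled by default. The language tag 'english' and
the wav filenames below are placeholder assumptions, not taken from the
patch itself.

    from wespeaker.cli.speaker import load_model

    # Options are no longer constructor arguments.
    model = load_model('english')
    model.set_resample_rate(16000)   # was an __init__ argument before
    model.set_vad(False)             # VAD now defaults to True; opt out here

    # extract_embedding() no longer takes an apply_vad flag; both it and
    # compute_similarity() read the options set above.
    embedding = model.extract_embedding('utt1.wav')
    score = model.compute_similarity('utt1.wav', 'utt2.wav')

One consequence of this layout is that load_model's signature stays stable
as further options are added, at the cost of a two-step, mutable setup.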