Skip to content

Commit

Permalink
[cli] mv all options to setting function (#219)
Browse files Browse the repository at this point in the history
* [cli] mv all options to setting function

* apply vad by default
  • Loading branch information
robin1001 authored Nov 8, 2023
1 parent e87bd0c commit a19dd0d
Showing 1 changed file with 19 additions and 10 deletions.
29 changes: 19 additions & 10 deletions wespeaker/cli/speaker.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,16 +26,23 @@


class Speaker:
def __init__(self, model_path: str, resample_rate: int = 16000):
def __init__(self, model_path: str):
self.session = ort.InferenceSession(model_path)
self.resample_rate = resample_rate
self.vad_model = vad.OnnxWrapper()
self.table = {}
self.resample_rate = 16000
self.apply_vad = True

# Set the sample rate used when loading audio: extract_embedding passes
# self.resample_rate to librosa.load as `sr`. __init__ sets it to 16000.
def set_resample_rate(self, resample_rate: int):
self.resample_rate = resample_rate

# Enable or disable VAD for later extract_embedding calls, which check
# self.apply_vad before running vad.get_speech_timestamps on the audio.
# __init__ sets it to True.
def set_vad(self, apply_vad: bool):
self.apply_vad = apply_vad

def extract_embedding(self, audio_path: str, apply_vad: bool = False):
def extract_embedding(self, audio_path: str):
pcm, sample_rate = librosa.load(audio_path, sr=self.resample_rate)
pcm = pcm * (1 << 15)
if apply_vad:
if self.apply_vad:
# TODO(Binbin Zhang): Refine the segment logic; for now we simply
# assume that silence occurs only at the start/end of the speech
segments = vad.get_speech_timestamps(self.vad_model,
Expand Down Expand Up @@ -64,8 +71,8 @@ def extract_embedding(self, audio_path: str, apply_vad: bool = False):
return embedding

def compute_similarity(self, audio_path1: str, audio_path2) -> float:
e1 = self.extract_embedding(audio_path1, True)
e2 = self.extract_embedding(audio_path2, True)
e1 = self.extract_embedding(audio_path1)
e2 = self.extract_embedding(audio_path2)
if e1 is None or e2 is None:
return 0.0
else:
Expand Down Expand Up @@ -95,9 +102,9 @@ def recognize(self, audio_path: str):
return result


def load_model(language: str, resample_rate: int) -> Speaker:
def load_model(language: str) -> Speaker:
model_path = Hub.get_model(language)
return Speaker(model_path, resample_rate)
return Speaker(model_path)


def get_args():
Expand Down Expand Up @@ -137,9 +144,11 @@ def get_args():

def main():
args = get_args()
model = load_model(args.language, args.resample_rate)
model = load_model(args.language)
model.set_resample_rate(args.resample_rate)
model.set_vad(args.vad)
if args.task == 'embedding':
embedding = model.extract_embedding(args.audio_file, args.vad)
embedding = model.extract_embedding(args.audio_file)
if embedding is not None:
np.savetxt(args.output_file, embedding)
print('Succeed, see {}'.format(args.output_file))
Expand Down

0 comments on commit a19dd0d

Please sign in to comment.