Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[cli] mv all options to setting function #219

Merged
merged 2 commits into from
Nov 8, 2023
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
29 changes: 19 additions & 10 deletions wespeaker/cli/speaker.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,16 +26,23 @@


class Speaker:
def __init__(self, model_path: str, resample_rate: int = 16000):
def __init__(self, model_path: str):
self.session = ort.InferenceSession(model_path)
self.resample_rate = resample_rate
self.vad_model = vad.OnnxWrapper()
self.table = {}
self.resample_rate = 16000
self.apply_vad = True

def set_resample_rate(self, resample_rate: int):
self.resample_rate = resample_rate

def set_vad(self, apply_vad: bool):
self.apply_vad = apply_vad

def extract_embedding(self, audio_path: str, apply_vad: bool = False):
def extract_embedding(self, audio_path: str):
pcm, sample_rate = librosa.load(audio_path, sr=self.resample_rate)
pcm = pcm * (1 << 15)
if apply_vad:
if self.apply_vad:
# TODO(Binbin Zhang): Refine the segments logic, here we just
# suppose there is only silence at the start/end of the speech
segments = vad.get_speech_timestamps(self.vad_model,
Expand Down Expand Up @@ -64,8 +71,8 @@ def extract_embedding(self, audio_path: str, apply_vad: bool = False):
return embedding

def compute_similarity(self, audio_path1: str, audio_path2) -> float:
e1 = self.extract_embedding(audio_path1, True)
e2 = self.extract_embedding(audio_path2, True)
e1 = self.extract_embedding(audio_path1)
e2 = self.extract_embedding(audio_path2)
if e1 is None or e2 is None:
return 0.0
else:
Expand Down Expand Up @@ -95,9 +102,9 @@ def recognize(self, audio_path: str):
return result


def load_model(language: str, resample_rate: int) -> Speaker:
def load_model(language: str) -> Speaker:
model_path = Hub.get_model(language)
return Speaker(model_path, resample_rate)
return Speaker(model_path)


def get_args():
Expand Down Expand Up @@ -137,9 +144,11 @@ def get_args():

def main():
args = get_args()
model = load_model(args.language, args.resample_rate)
model = load_model(args.language)
model.set_resample_rate(args.resample_rate)
model.set_vad(args.vad)
if args.task == 'embedding':
embedding = model.extract_embedding(args.audio_file, args.vad)
embedding = model.extract_embedding(args.audio_file)
if embedding is not None:
np.savetxt(args.output_file, embedding)
print('Succeed, see {}'.format(args.output_file))
Expand Down
Loading