From 2ab6d5124e012b2b4b1379eb60b3b243b7f03541 Mon Sep 17 00:00:00 2001 From: Shuai Wang Date: Mon, 13 Nov 2023 16:08:47 +0800 Subject: [PATCH] [doc]add gpu support for CLI --- .gitignore | 3 ++- README.md | 6 ++++++ wespeaker/cli/speaker.py | 21 +++++++++++++++++++-- 3 files changed, 27 insertions(+), 3 deletions(-) diff --git a/.gitignore b/.gitignore index a2785a68..a35a4506 100644 --- a/.gitignore +++ b/.gitignore @@ -2,6 +2,7 @@ __pycache__/ *.py[cod] *$py.class +*.egg-info # Visual Studio Code files .vscode @@ -24,7 +25,7 @@ venv *.swo *.swp *.swm -*~ +*~ # IPython notebook checkpoints .ipynb_checkpoints diff --git a/README.md b/README.md index 7c64dcdb..2ce91123 100644 --- a/README.md +++ b/README.md @@ -25,6 +25,10 @@ pip install git+https://github.com/wenet-e2e/wespeaker.git $ wespeaker --task embedding --audio_file audio.wav --output_file embedding.txt $ wespeaker --task similarity --audio_file audio.wav --audio_file2 audio2.wav $ wespeaker --task diarization --audio_file audio.wav # TODO + +# Add -g or --gpu to specify the gpu id to use, number < 0 means using CPU +$ wespeaker --task embedding --audio_file audio.wav --output_file embedding.txt -g 0 +$ wespeaker --task similarity --audio_file audio.wav --audio_file2 audio2.wav -g 0 ``` **Python programming usage**: @@ -33,6 +37,8 @@ $ wespeaker --task diarization --audio_file audio.wav # TODO import wespeaker model = wespeaker.load_model('chinese') +# set_gpu to enable the cuda inference, number < 0 means using CPU +model.set_gpu(0) embedding = model.extract_embedding('audio.wav') similarity = model.compute_similarity('audio1.wav', 'audio2.wav') diar_result = model.diarize('audio.wav') # TODO diff --git a/wespeaker/cli/speaker.py b/wespeaker/cli/speaker.py index 08d81d5e..a9caa921 100644 --- a/wespeaker/cli/speaker.py +++ b/wespeaker/cli/speaker.py @@ -1,4 +1,5 @@ # Copyright (c) 2023 Binbin Zhang (binbzha@qq.com) +# Shuai Wang (wsstriving@gmail.com) # # Licensed under the Apache License, Version 2.0 
(the "License"); # you may not use this file except in compliance with the License. @@ -40,7 +41,8 @@ def __init__(self, model_dir: str): self.vad_model = vad.OnnxWrapper() self.table = {} self.resample_rate = 16000 - self.apply_vad = True + self.apply_vad = False + self.device = torch.device('cpu') def set_resample_rate(self, resample_rate: int): self.resample_rate = resample_rate @@ -48,6 +50,14 @@ def set_resample_rate(self, resample_rate: int): def set_vad(self, apply_vad: bool): self.apply_vad = apply_vad + def set_gpu(self, device_id: int): + if device_id >= 0: + device = 'cuda:{}'.format(device_id) + else: + device = 'cpu' + self.device = torch.device(device) + self.model = self.model.to(self.device) + def extract_embedding(self, audio_path: str): pcm, sample_rate = torchaudio.load(audio_path, normalize=False) if self.apply_vad: @@ -71,10 +81,11 @@ def extract_embedding(self, audio_path: str): sample_frequency=16000) feats = feats - torch.mean(feats, 0) # CMN feats = feats.unsqueeze(0) + feats = feats.to(self.device) self.model.eval() with torch.no_grad(): _, outputs = self.model(feats) - embedding = outputs[0] + embedding = outputs[0].to(torch.device('cpu')) return embedding def compute_similarity(self, audio_path1: str, audio_path2: str) -> float: @@ -150,6 +161,11 @@ def get_args(): parser.add_argument('--vad', action='store_true', help='whether to do VAD or not') + parser.add_argument('-g', + '--gpu', + type=int, + default=-1, + help='which gpu to use (number <0 means using cpu)') parser.add_argument('--output_file', help='output file to save speaker embedding') args = parser.parse_args() @@ -161,6 +177,7 @@ def main(): model = load_model(args.language) model.set_resample_rate(args.resample_rate) model.set_vad(args.vad) + model.set_gpu(args.gpu) if args.task == 'embedding': embedding = model.extract_embedding(args.audio_file) if embedding is not None: