Skip to content

Commit

Permalink
[doc]add gpu support for CLI
Browse files Browse the repository at this point in the history
  • Loading branch information
wsstriving committed Nov 13, 2023
1 parent c6a42a1 commit 2ab6d51
Show file tree
Hide file tree
Showing 3 changed files with 27 additions and 3 deletions.
3 changes: 2 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
__pycache__/
*.py[cod]
*$py.class
*.egg-info

# Visual Studio Code files
.vscode
Expand All @@ -24,7 +25,7 @@ venv
*.swo
*.swp
*.swm
*~
*~S

# IPython notebook checkpoints
.ipynb_checkpoints
Expand Down
6 changes: 6 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,10 @@ pip install git+https://github.com/wenet-e2e/wespeaker.git
$ wespeaker --task embedding --audio_file audio.wav --output_file embedding.txt
$ wespeaker --task similarity --audio_file audio.wav --audio_file2 audio2.wav
$ wespeaker --task diarization --audio_file audio.wav # TODO

# Add -g or --gpu to specify the gpu id to use, number < 0 means using CPU
$ wespeaker --task embedding --audio_file audio.wav --output_file embedding.txt -g 0
$ wespeaker --task similarity --audio_file audio.wav --audio_file2 audio2.wav --g 0
```

**Python programming usage**:
Expand All @@ -33,6 +37,8 @@ $ wespeaker --task diarization --audio_file audio.wav # TODO
import wespeaker

model = wespeaker.load_model('chinese')
# set_gpu to enable the cuda inference, number < 0 means using CPU
model.set_gpu(0)
embedding = model.extract_embedding('audio.wav')
similarity = model.compute_similarity('audio1.wav', 'audio2.wav')
diar_result = model.diarize('audio.wav') # TODO
Expand Down
21 changes: 19 additions & 2 deletions wespeaker/cli/speaker.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
# Copyright (c) 2023 Binbin Zhang ([email protected])
# Shuai Wang ([email protected])
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -40,14 +41,23 @@ def __init__(self, model_dir: str):
self.vad_model = vad.OnnxWrapper()
self.table = {}
self.resample_rate = 16000
self.apply_vad = True
self.apply_vad = False
self.device = torch.device('cpu')

def set_resample_rate(self, resample_rate: int):
self.resample_rate = resample_rate

def set_vad(self, apply_vad: bool):
self.apply_vad = apply_vad

def set_gpu(self, device_id: int):
if device_id >= 0:
device = 'cuda:{}'.format(device_id)
else:
device = 'cpu'
self.device = torch.device(device)
self.model = self.model.to(self.device)

def extract_embedding(self, audio_path: str):
pcm, sample_rate = torchaudio.load(audio_path, normalize=False)
if self.apply_vad:
Expand All @@ -71,10 +81,11 @@ def extract_embedding(self, audio_path: str):
sample_frequency=16000)
feats = feats - torch.mean(feats, 0) # CMN
feats = feats.unsqueeze(0)
feats = feats.to(self.device)
self.model.eval()
with torch.no_grad():
_, outputs = self.model(feats)
embedding = outputs[0]
embedding = outputs[0].to(torch.device('cpu'))
return embedding

def compute_similarity(self, audio_path1: str, audio_path2: str) -> float:
Expand Down Expand Up @@ -150,6 +161,11 @@ def get_args():
parser.add_argument('--vad',
action='store_true',
help='whether to do VAD or not')
parser.add_argument('-g',
'--gpu',
type=int,
default=-1,
help='which gpu to use (number <0 means using cpu)')
parser.add_argument('--output_file',
help='output file to save speaker embedding')
args = parser.parse_args()
Expand All @@ -161,6 +177,7 @@ def main():
model = load_model(args.language)
model.set_resample_rate(args.resample_rate)
model.set_vad(args.vad)
model.set_gpu(args.gpu)
if args.task == 'embedding':
embedding = model.extract_embedding(args.audio_file)
if embedding is not None:
Expand Down

0 comments on commit 2ab6d51

Please sign in to comment.