Merge pull request #13 from mobiusml/fw_compliance
making default vad_device same as asr model device
Jiltseb authored May 24, 2024
2 parents d263cbd + 18bdaa8 commit b10b8cb
Showing 1 changed file with 18 additions and 5 deletions.
faster_whisper/transcribe.py: 23 changes (18 additions, 5 deletions)
@@ -137,7 +137,7 @@ def __init__(
         options: Optional[NamedTuple] = None,
         tokenizer=None,
         device: Union[int, str, "torch.device"] = -1,
-        vad_device: Union[int, str, "torch.device"] = "cuda",
+        vad_device: Union[int, str, "torch.device"] = "auto",
         framework="pt",
         language: Optional[str] = None,
         **kwargs,
@@ -168,7 +168,6 @@ def __init__(
         self.device = device

         if self.use_vad_model:
-            # Separate vad_device from pipeline self.device
             self.vad_device = self.get_device(vad_device)

             # load vad model and perform VAD preprocessing if needed
@@ -185,10 +184,24 @@ def _sanitize_parameters(self, **kwargs):
             preprocess_kwargs["maybe_arg"] = kwargs["maybe_arg"]
         return preprocess_kwargs, {}, {}

-    def get_device(self, device):
+    def get_device(self, device: Union[int, str, "torch.device"]):
+        """
+        Converts the input device into a torch.device object.
+        The input can be an integer, a string, or a `torch.device` object.
+        The function handles a special case where the input device is "auto".
+        When "auto" is specified, the device will default to the
+        device of the model (self.model.device). If the model's device is also "auto",
+        it selects "cuda" if a CUDA-capable device is available; otherwise, it selects "cpu".
+        """
         if isinstance(device, torch.device):
             return device
         elif isinstance(device, str):
+            if device == "auto" and self.model.device == "auto":
+                device = "cuda" if torch.cuda.is_available() else "cpu"
+            elif device == "auto":
+                device = self.model.device
             return torch.device(device)
         elif device < 0:
             return torch.device("cpu")
@@ -683,12 +696,12 @@ def __init__(
                 local_files_only=local_files_only,
                 cache_dir=download_root,
             )

+        self.device = device
         # set the random seed to make sure consistency across runs
         ctranslate2.set_random_seed(42)
         self.model = ctranslate2.models.Whisper(
             model_path,
-            device=device,
+            device=self.device,
             device_index=device_index,
             compute_type=compute_type,
             intra_threads=cpu_threads,
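For reference, below is a minimal standalone sketch of the device-resolution order this commit introduces. It is not part of faster_whisper's API: `resolve_device` and the `model_device` parameter are hypothetical names used only for illustration, standing in for `get_device` and `self.model.device` in the diff above. The non-negative integer branch is not visible in the diff, so its handling here (CUDA device index) is an assumption. Requires PyTorch.

```python
from typing import Union

import torch


def resolve_device(
    device: Union[int, str, "torch.device"],
    model_device: Union[str, "torch.device"] = "auto",
) -> torch.device:
    # torch.device objects pass through unchanged, as in get_device.
    if isinstance(device, torch.device):
        return device
    if isinstance(device, str):
        # "auto" inherits the ASR model's device; if that is also "auto",
        # fall back to CUDA when available, otherwise CPU.
        if device == "auto" and model_device == "auto":
            device = "cuda" if torch.cuda.is_available() else "cpu"
        elif device == "auto":
            device = model_device
        return torch.device(device)
    # Integer input: negative means CPU (matching the branch shown in the diff);
    # a non-negative index is treated here as a CUDA device index (assumption).
    return torch.device("cpu") if device < 0 else torch.device(f"cuda:{device}")


if __name__ == "__main__":
    print(resolve_device("auto", model_device="cuda"))  # cuda
    print(resolve_device("auto"))  # cuda if available, otherwise cpu
    print(resolve_device(-1))  # cpu
```

With the new default of `vad_device="auto"`, the VAD model therefore ends up on the same device as the ASR model, which is what the commit message describes; passing an explicit device string or index still overrides this.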
