diff --git a/faster_whisper/transcribe.py b/faster_whisper/transcribe.py index 8ce542b4..b768c3a2 100644 --- a/faster_whisper/transcribe.py +++ b/faster_whisper/transcribe.py @@ -137,7 +137,7 @@ def __init__( options: Optional[NamedTuple] = None, tokenizer=None, device: Union[int, str, "torch.device"] = -1, - vad_device: Union[int, str, "torch.device"] = "cuda", + vad_device: Union[int, str, "torch.device"] = "auto", framework="pt", language: Optional[str] = None, **kwargs, @@ -168,7 +168,6 @@ def __init__( self.device = device if self.use_vad_model: - # Separate vad_device from pipeline self.device self.vad_device = self.get_device(vad_device) # load vad model and perform VAD preprocessing if needed @@ -185,10 +184,24 @@ def _sanitize_parameters(self, **kwargs): preprocess_kwargs["maybe_arg"] = kwargs["maybe_arg"] return preprocess_kwargs, {}, {} - def get_device(self, device): + def get_device(self, device: Union[int, str, "torch.device"]): + """ + Converts the input device into a torch.device object. + + The input can be an integer, a string, or a `torch.device` object. + + The function handles a special case where the input device is "auto". + When "auto" is specified, the device will default to the + device of the model (self.model.device). If the model's device is also "auto", + it selects "cuda" if a CUDA-capable device is available; otherwise, it selects "cpu".
+ """ if isinstance(device, torch.device): return device elif isinstance(device, str): + if device == "auto" and self.model.device == "auto": + device = "cuda" if torch.cuda.is_available() else "cpu" + elif device == "auto": + device = self.model.device return torch.device(device) elif device < 0: return torch.device("cpu") @@ -683,12 +696,12 @@ def __init__( local_files_only=local_files_only, cache_dir=download_root, ) - + self.device = device # set the random seed to make sure consistency across runs ctranslate2.set_random_seed(42) self.model = ctranslate2.models.Whisper( model_path, - device=device, + device=self.device, device_index=device_index, compute_type=compute_type, intra_threads=cpu_threads,