diff --git a/README.md b/README.md index 0a1f49e7..4da271ad 100644 --- a/README.md +++ b/README.md @@ -180,7 +180,7 @@ language_info = model.detect_language_multi_segment("audio.mp3") ### Batched faster-whisper -The batched version of faster-whisper is inspired by [whisper-x](https://github.com/m-bain/whisperX) licensed under the BSD-4 Clause license. This product includes software developed by Max Bain. We modify this implementation and also added kaldi-based feature extraction. It improves the speed upto 10-12x compared to openAI implementation and 3-4x compared to the sequential faster_whisper version. It works by transcribing semantically meaningful audio chunks as batches leading to faster inference. +The batched version of faster-whisper is inspired by [whisper-x](https://github.com/m-bain/whisperX), licensed under the BSD-4-Clause license, and integrates its VAD model into this library. This product includes software developed by Max Bain. We modified this implementation and added Kaldi-based feature extraction. It improves speed by up to 10-12x compared to the OpenAI implementation and 3-4x compared to the sequential faster_whisper version. It works by transcribing semantically meaningful audio chunks as batches, leading to faster inference. The following code snippet illustrates how to run inference with batched version on an example audio file. Please also refer to the test scripts of batched faster whisper. @@ -263,6 +263,7 @@ See more model and transcription options in the [`WhisperModel`](https://github. Here is a non exhaustive list of open-source projects using faster-whisper. Feel free to add your project to the list! +* [faster-whisper-server](https://github.com/fedirz/faster-whisper-server) is an OpenAI-compatible server using `faster-whisper`. It's easily deployable with Docker, works with OpenAI SDKs/CLI, and supports streaming and live transcription. * [WhisperX](https://github.com/m-bain/whisperX) is an award-winning Python library that offers speaker diarization and accurate word-level timestamps using wav2vec2 alignment * [whisper-ctranslate2](https://github.com/Softcatala/whisper-ctranslate2) is a command line client based on faster-whisper and compatible with the original client from openai/whisper. * [whisper-diarize](https://github.com/MahmoudAshraf97/whisper-diarization) is a speaker diarization tool that is based on faster-whisper and NVIDIA NeMo.
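The README hunk above refers to a usage snippet that falls outside this diff. For orientation, here is a minimal sketch of how the batched pipeline added in this PR could be driven. The `BatchedInferencePipeline` export name and the model size are assumptions not shown in this diff; the `batch_size` argument and the `(segment, info)` iteration pattern follow the updated test in `tests/test_transcribe.py`.

```python
from faster_whisper import WhisperModel

# Assumed export name for the batched wrapper added in this PR; check the
# package's __init__ for the actual class before relying on it.
from faster_whisper import BatchedInferencePipeline

model = WhisperModel("large-v3", device="cuda", compute_type="float16")
batched_model = BatchedInferencePipeline(model=model)

# Chunks produced by the bundled pyannote VAD model are transcribed as batches;
# iteration yields (segment, info) pairs, as in the updated test.
for segment, info in batched_model.transcribe("audio.mp3", batch_size=16):
    print("[%.2fs -> %.2fs] %s" % (segment.start, segment.end, segment.text))
```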
diff --git a/faster_whisper/assets/pyannote_vad_model.bin b/faster_whisper/assets/pyannote_vad_model.bin new file mode 100644 index 00000000..75c92f09 Binary files /dev/null and b/faster_whisper/assets/pyannote_vad_model.bin differ diff --git a/faster_whisper/audio.py b/faster_whisper/audio.py index a597fd83..959a4fd8 100644 --- a/faster_whisper/audio.py +++ b/faster_whisper/audio.py @@ -15,6 +15,27 @@ import av import numpy as np +# Audio Hyperparameters + +SAMPLE_RATE = 16000 +N_FFT = 400 +HOP_LENGTH = 160 +CHUNK_LENGTH = 30 + + +def exact_div(x, y): + assert x % y == 0 + return x // y + + +N_SAMPLES = CHUNK_LENGTH * SAMPLE_RATE # 480000 samples in a 30-second chunk +N_FRAMES = exact_div(N_SAMPLES, HOP_LENGTH) # 3000 frames in a mel spectrogram input + +N_SAMPLES_PER_TOKEN = HOP_LENGTH * 2 # the initial convolutions has stride 2 +FRAMES_PER_SECOND = exact_div(SAMPLE_RATE, HOP_LENGTH) # 10ms per audio frame +TOKENS_PER_SECOND = exact_div(SAMPLE_RATE, N_SAMPLES_PER_TOKEN) # 20ms per audio token +TIME_PRECISION = 1 / TOKENS_PER_SECOND + def decode_audio( input_file: Union[str, BinaryIO], diff --git a/faster_whisper/feature_extractor.py b/faster_whisper/feature_extractor.py index 75bf8372..cf5d95fa 100644 --- a/faster_whisper/feature_extractor.py +++ b/faster_whisper/feature_extractor.py @@ -163,7 +163,7 @@ def __call__(self, waveform, enable_ta=False, padding=True, chunk_length=None): waveform = np.pad(waveform, [(0, self.n_samples)]) if enable_ta: - audio = torch.from_numpy(waveform).unsqueeze(0) + audio = torch.from_numpy(waveform).unsqueeze(0).float() fbank = ta_kaldi.fbank( audio, sample_frequency=self.sampling_rate, @@ -177,7 +177,7 @@ def __call__(self, waveform, enable_ta=False, padding=True, chunk_length=None): # Audioset values as default mean and std for audio mean_val = -4.2677393 std_val = 4.5689974 - scaled_features = (log_spec - (mean_val)) / (std_val * 2) + scaled_features = (log_spec - mean_val) / (std_val * 2) log_spec = scaled_features else: diff --git a/faster_whisper/transcribe.py b/faster_whisper/transcribe.py index b768c3a2..2731d116 100644 --- a/faster_whisper/transcribe.py +++ b/faster_whisper/transcribe.py @@ -1,10 +1,8 @@ -import hashlib import itertools import json import logging import os import random -import urllib import zlib from collections import Counter, defaultdict @@ -18,14 +16,19 @@ import torch from pyannote.audio import Model -from tqdm import tqdm from transformers import Pipeline from transformers.pipelines.pt_utils import PipelineIterator -from faster_whisper.audio import decode_audio, pad_or_trim +from faster_whisper.audio import TIME_PRECISION, decode_audio, pad_or_trim from faster_whisper.feature_extractor import FeatureExtractor from faster_whisper.tokenizer import _LANGUAGE_CODES, Tokenizer -from faster_whisper.utils import download_model, format_timestamp, get_end, get_logger +from faster_whisper.utils import ( + download_model, + format_timestamp, + get_assets_path, + get_end, + get_logger, +) from faster_whisper.vad import ( SpeechTimestampsMap, VadOptions, @@ -59,16 +62,21 @@ class Segment(NamedTuple): class BatchedSegment(NamedTuple): """ - A single segment in batched transcription (up to multiple sentences) of a speech. + A single segment in batched transcription of a speech. start (float): Start time in seconds. end (float): End time in seconds. text (str): transcription of the segment. + avg_logprob (float): average log probability of the segment. + no_speech_prob (float): no speech probability of the segment. 
""" start: float end: float text: str + words: Optional[List[Word]] + no_speech_prob: float + avg_logprob: float # Added additional parameters for multilingual videos and fixes below @@ -80,7 +88,7 @@ class TranscriptionOptions(NamedTuple): repetition_penalty: float no_repeat_ngram_size: int log_prob_threshold: Optional[float] - log_prob_low_threshold: Optional[float] # New parameter + log_prob_low_threshold: Optional[float] no_speech_threshold: Optional[float] compression_ratio_threshold: Optional[float] condition_on_previous_text: bool @@ -95,8 +103,8 @@ class TranscriptionOptions(NamedTuple): word_timestamps: bool prepend_punctuations: str append_punctuations: str - multilingual: bool # New parameter - output_language: str # New parameter + multilingual: bool + output_language: Optional[str] max_new_tokens: Optional[int] clip_timestamps: Union[str, List[float]] hallucination_silence_threshold: Optional[float] @@ -137,6 +145,7 @@ def __init__( options: Optional[NamedTuple] = None, tokenizer=None, device: Union[int, str, "torch.device"] = -1, + chunk_size: int = 30, vad_device: Union[int, str, "torch.device"] = "auto", framework="pt", language: Optional[str] = None, @@ -151,10 +160,8 @@ def __init__( self.use_vad_model = use_vad_model self.vad_onset = 0.500 self.vad_offset = 0.363 - self.vad_model_url = ( - "https://whisperx.s3.eu-west-2.amazonaws.com/model_weights/segmentation" - "/0b5b3216d60a2d32fc086b47ea8c67589aaeb26b7e07fcbe620d6d0b83e209ea/pytorch_model.bin" - ) + self.vad_model_path = os.path.join(get_assets_path(), "pyannote_vad_model.bin") + ( self._preprocess_params, self._forward_params, @@ -174,7 +181,7 @@ def __init__( self.vad_model = self.load_vad_model( vad_onset=self.vad_onset, vad_offset=self.vad_offset ) - self.chunk_size = 30 # VAD merging size + self.chunk_size = chunk_size # VAD merging size super(Pipeline, self).__init__() @@ -208,20 +215,35 @@ def get_device(self, device: Union[int, str, "torch.device"]): else: return torch.device(f"cuda:{device}") - def preprocess(self, audio, enable_ta_fe=True): - audio = audio["inputs"] + def preprocess(self, inputs, enable_ta_fe=True): + audio = inputs["inputs"] features = torch.tensor( self.model.feature_extractor(audio, enable_ta=enable_ta_fe, padding=True)[ :, : self.model.feature_extractor.nb_max_frames ] ) - return {"inputs": features} + inputs["features"] = features + return inputs def _forward(self, model_inputs, **forward_params): - outputs = self.model.generate_segment_batched( - model_inputs["inputs"], self.tokenizer, forward_params + ( + encoder_output, + sot_seqs, + text_tokens, + output, + ) = self.model.generate_segment_batched( + model_inputs["features"], self.tokenizer, forward_params ) - return {"text": outputs} + + if forward_params["word_timestamps"]: + word_timings = self.align_words( + encoder_output, text_tokens, sot_seqs, model_inputs["seg_metadata"] + ) + + for _response, _word_timings in zip(output, word_timings): + _response["word_timestamps"] = _word_timings + + return {"output": output} def __call__( self, inputs, options, enable_ta_fe, num_workers=None, batch_size=None, **kwargs @@ -242,6 +264,7 @@ def __call__( forward_params, postprocess_params, ) = self._sanitize_parameters(**kwargs) + # Fuse __init__ params and __call__ params without modifying the __init__ ones. 
preprocess_params = { **self._preprocess_params, @@ -284,7 +307,11 @@ def get_iterator( postprocess_params=None, ): def stack(items): - return {"inputs": torch.stack([x["inputs"] for x in items])} + return { + "inputs": [x["inputs"] for x in items], + "seg_metadata": [x["seg_metadata"] for x in items], + "features": torch.stack([x["features"] for x in items]), + } if "TOKENIZERS_PARALLELISM" not in os.environ: os.environ["TOKENIZERS_PARALLELISM"] = "false" @@ -328,46 +355,19 @@ def get_language_and_tokenizer(self, audio, task=None, language=None): def audio_split(self, audio, segments, sampling_rate): "Returns splitted audio chunks as iterator" + for seg in segments: f1 = int(seg["start"] * sampling_rate) f2 = int(seg["end"] * sampling_rate) - yield {"inputs": audio[f1:f2]} - - # The code below is adapted from whisper-x - def load_vad_model(self, vad_onset=0.500, vad_offset=0.363, use_auth_token=None): - model_dir = torch.hub._get_torch_home() - os.makedirs(model_dir, exist_ok=True) - model_fp = os.path.join(model_dir, "whisperx-vad-segmentation.bin") - if os.path.exists(model_fp) and not os.path.isfile(model_fp): - raise RuntimeError(f"{model_fp} exists and is not a regular file") - - if not os.path.isfile(model_fp): - with urllib.request.urlopen(self.vad_model_url) as source, open( - model_fp, "wb" - ) as output: - with tqdm( - total=int(source.info().get("Content-Length")), - ncols=80, - unit="iB", - unit_scale=True, - unit_divisor=1024, - ) as loop: - while True: - buffer = source.read(8192) - if not buffer: - break - - output.write(buffer) - loop.update(len(buffer)) - - model_bytes = open(model_fp, "rb").read() - if hashlib.sha256(model_bytes).hexdigest() != self.vad_model_url.split("/")[-2]: - raise RuntimeError( - "Model SHA256 checksum does not not match. Please retry loading the model." 
- ) + seg_metadata = { + "start_time": seg["start"], + "end_time": seg["end"], + "stitched_seg": seg["segments"], + } + yield {"inputs": audio[f1:f2], "seg_metadata": seg_metadata} - # or use silero VAD - vad_model = Model.from_pretrained(model_fp, use_auth_token=use_auth_token) + def load_vad_model(self, vad_onset=0.500, vad_offset=0.363): + vad_model = Model.from_pretrained(self.vad_model_path) hyperparameters = { "onset": vad_onset, "offset": vad_offset, @@ -381,6 +381,141 @@ def load_vad_model(self, vad_onset=0.500, vad_offset=0.363, use_auth_token=None) vad_pipeline.instantiate(hyperparameters) return vad_pipeline + def align_words(self, features, text_tokens, sot_seqs, seg_metadata): + # Split text into word tokens using the tokenizer + word_tokens = [] + for tokens in text_tokens: + word_tokens.append(self.tokenizer.split_to_word_tokens(tokens)) + + # Group indices by start sequence + start_seq_wise_req = {} + for _idx, _sot_seq in enumerate(sot_seqs): + if _sot_seq not in start_seq_wise_req: + start_seq_wise_req[_sot_seq] = [] + start_seq_wise_req[_sot_seq].append(_idx) + + # Initialize token alignments for each segment metadata + token_alignments = [[] for _ in seg_metadata] + duration_list = [ + int( + (seg_meta["end_time"] - seg_meta["start_time"]) + / self.model.feature_extractor.time_per_frame + ) + for seg_meta in seg_metadata + ] + + # Perform alignment for each group of indices with the same start sequence + start_seq = list(start_seq_wise_req.items())[0] + + res = self.model.model.align( + features, + start_sequence=list(start_seq[0]), + text_tokens=text_tokens, + num_frames=duration_list, + median_filter_width=7, + ) + for start_seq, req_idx in start_seq_wise_req.items(): + for _res, _req_idx in zip(res, req_idx): + token_alignments[_req_idx] = _res + + # Process each segment's metadata to align word timings + word_timings = [] + for _idx, _seg_metadata in enumerate(seg_metadata): + _word_timings = self.model.assign_word_timings( + token_alignments[_idx].alignments, + token_alignments[_idx].text_token_probs, + word_tokens[_idx][0], + word_tokens[_idx][1], + ) + + stitched_seg = _seg_metadata["stitched_seg"] + current_seg_idx = 0 + current_offset = stitched_seg[0][0] + + for w in _word_timings: + w["start"] += current_offset + w["end"] += current_offset + + if ( + current_seg_idx < len(stitched_seg) + and (w["start"]) <= stitched_seg[current_seg_idx][1] + and (w["end"]) >= stitched_seg[current_seg_idx][1] + ): + w["end"] = stitched_seg[current_seg_idx][1] # replace by seg end + + while ( + current_seg_idx < len(stitched_seg) + and (w["start"]) >= stitched_seg[current_seg_idx][1] + ): + current_seg_idx += 1 + + word_timings.append(_word_timings) + + return word_timings + + def combine_words(self, metadata, response): + combined_segments = [] + + for meta, res in zip(metadata, response): + word_timestamps = res["word_timestamps"] + segment_texts = [] + segment_index = 0 + current_segment = meta["segments"][segment_index] + current_text = [] + current_word_timestamps = [] + current_start = current_segment[0] + + for idx, word_info in enumerate(word_timestamps): + word_start, word_end, word_text = ( + word_info["start"], + word_info["end"], + word_info["word"], + ) + + # Move to the next segment if the word is outside the current segment + while ( + word_start >= current_segment[1] + and segment_index < len(meta["segments"]) - 1 + ): + # Save the completed segment + if current_text: + segment_texts.append( + { + "start": current_start, + "end": current_segment[1], + "text": 
"".join(current_text), + "word_timestamps": current_word_timestamps, + "avg_logprob": res["avg_logprob"], + "no_speech_prob": res["no_speech_prob"], + } + ) + segment_index += 1 + current_segment = meta["segments"][segment_index] + current_start = current_segment[0] + current_text = [] + current_word_timestamps = [] + + # Add word to the current segment text + if word_start >= current_segment[0] and word_end <= current_segment[1]: + current_text.append(word_text) + current_word_timestamps.append(word_info) + + # Save the final segment + if current_text: + segment_texts.append( + { + "start": current_start, + "end": current_segment[1], + "text": "".join(current_text), + "word_timestamps": current_word_timestamps, + "avg_logprob": res["avg_logprob"], + "no_speech_prob": res["no_speech_prob"], + } + ) + + combined_segments.extend(segment_texts) + return combined_segments + def transcribe( self, audio: Union[str, np.ndarray], @@ -406,7 +541,7 @@ def transcribe( ], compression_ratio_threshold: Optional[float] = 2.4, log_prob_threshold: Optional[float] = -1.0, - log_prob_low_threshold: Optional[float] = -2.0, + log_prob_low_threshold: Optional[float] = None, no_speech_threshold: Optional[float] = 0.6, initial_prompt: Optional[Union[str, Iterable[int]]] = None, prefix: Optional[str] = None, @@ -418,6 +553,7 @@ def transcribe( max_new_tokens: Optional[int] = None, clip_timestamps: Union[str, List[float]] = "0", hotwords: Optional[str] = None, + word_timestamps: bool = False, ) -> Tuple[Iterable[BatchedSegment], TranscriptionInfo]: """transcribe audio in chunks in batched fashion and return with language info. @@ -447,7 +583,7 @@ def transcribe( log_prob_threshold: If the average log probability over sampled tokens is below this value, treat as failed. log_prob_low_threshold: This parameter alone is sufficient to skip an output text, - wheras log_prob_threshold also looks for appropriate no_speech_threshold value. + whereas log_prob_threshold also looks for appropriate no_speech_threshold value. This value should be less than log_prob_threshold. no_speech_threshold: If the no_speech probability is higher than this value AND the average log probability over sampled tokens is below `log_prob_threshold`, @@ -470,13 +606,13 @@ def transcribe( process. The last end timestamp defaults to the end of the file. hotwords: Hotwords/hint phrases to the model. Has no effect if prefix is not None. + word_timestamps: Extract word-level timestamps using the cross-attention pattern + and dynamic time warping, and include the timestamps for each word in each segment. + Set as False. Static params: (Fixed for batched version) without_timestamps: Only sample text tokens, set as True. max_initial_timestamp: The initial timestamp cannot be later than this, set at 0.0. - word_timestamps: Extract word-level timestamps using the cross-attention pattern - and dynamic time warping, and include the timestamps for each word in each segment. - Set as False. multilingual: If True, perform transcription on multilingual videos. Set as False. output_language: Valid only if multilingual is set to True. Specifies the string representing the output language. 
One of @@ -516,6 +652,7 @@ def transcribe( if isinstance(audio, str): audio = decode_audio(audio) + duration = audio.shape[0] / sampling_rate # if no segment split is provided, use vad_model and generate segments if not vad_segments: @@ -567,14 +704,14 @@ def transcribe( max_new_tokens=max_new_tokens, clip_timestamps=clip_timestamps, hotwords=hotwords, + word_timestamps=word_timestamps, hallucination_silence_threshold=None, condition_on_previous_text=False, prompt_reset_on_temperature=0.5, multilingual=False, - word_timestamps=False, output_language=None, without_timestamps=True, - max_initial_timestamp=0.0, + max_initial_timestamp=0.0, ) for idx, out in enumerate( @@ -586,31 +723,46 @@ def transcribe( options=batched_options, ) ): - # inputs, *args, num_workers=None, batch_size=None, **kwargs if log_progress: percent_complete = ((idx + 1) / total_segments) * 100 self.model.logger.info(f"Progress: {percent_complete:.2f}%...") - text = out["text"] - if batch_size in [0, 1, None]: - text = text[0] - - segments = BatchedSegment( - text=text, - start=round(vad_segments[idx]["start"], 3), - end=round(vad_segments[idx]["end"], 3), - ) + response = out["output"] info = TranscriptionInfo( language=language, language_probability=language_probability, - duration=0.0, - duration_after_vad=0.0, + duration=duration, + duration_after_vad=None, transcription_options=batched_options, vad_options=None, all_language_probs=None, ) - yield segments, info + + if not batched_options.word_timestamps: + segments = BatchedSegment( + text=response["text"], + start=round(vad_segments[idx]["start"], 3), + end=round(vad_segments[idx]["end"], 3), + words=None, + avg_logprob=response["avg_logprob"], + no_speech_prob=response["no_speech_prob"], + ) + yield segments, info + + else: + response = self.combine_words([vad_segments[idx]], [response]) + segments = [] + for res in response: + segments = BatchedSegment( + text=res["text"], + start=round(res["start"], 3), + end=round(res["end"], 3), + words=res["word_timestamps"], + avg_logprob=res["avg_logprob"], + no_speech_prob=res["no_speech_prob"], + ) + yield segments, info # revert the tokenizer if multilingual inference is enabled if self.preset_language is None: @@ -776,7 +928,7 @@ def transcribe( ], compression_ratio_threshold: Optional[float] = 2.4, log_prob_threshold: Optional[float] = -1.0, - log_prob_low_threshold: Optional[float] = -2.0, + log_prob_low_threshold: Optional[float] = None, no_speech_threshold: Optional[float] = 0.6, condition_on_previous_text: bool = True, prompt_reset_on_temperature: float = 0.5, @@ -1204,10 +1356,6 @@ def generate_segments( # don't skip if the logprob is high enough, despite the no_speech_prob should_skip = False - # Skip if the logprob is very low (below the threshold value), - # despite no_speech_prob being low (ex: Too ambiguous outputs) - if avg_logprob < options.log_prob_low_threshold: - should_skip = True if should_skip: self.logger.debug( "No speech threshold is met (%f > %f)", @@ -1215,9 +1363,20 @@ def generate_segments( options.no_speech_threshold, ) - # fast-forward to the next segment boundary - seek += segment_size - continue + # Skip if the logprob is very low (below the threshold value), + # despite no_speech_prob being low (ex: Too ambiguous outputs) + if options.log_prob_low_threshold: + if avg_logprob < options.log_prob_low_threshold: + should_skip = True + self.logger.debug( + "log prob low threshold is met (%f > %f)", + avg_logprob, + options.log_prob_low_threshold, + ) + + # fast-forward to the next segment 
boundary + seek += segment_size + continue tokens = result.sequences_ids[0] @@ -1791,12 +1950,38 @@ def encode_batch(self, features: torch.Tensor) -> ctranslate2.StorageView: features = get_ctranslate2_storage(features) return self.model.encode(features, to_cpu=to_cpu) + def assign_word_timings(self, alignments, text_token_probs, words, word_tokens): + text_indices = np.array([pair[0] for pair in alignments]) + time_indices = np.array([pair[1] for pair in alignments]) + + if len(word_tokens) <= 1: + return [] + + word_boundaries = np.pad(np.cumsum([len(t) for t in word_tokens[:-1]]), (1, 0)) + if len(word_boundaries) <= 1: + return [] + + jumps = np.pad(np.diff(text_indices), (1, 0), constant_values=1).astype(bool) + jump_times = time_indices[jumps] * TIME_PRECISION + start_times = jump_times[word_boundaries[:-1]] + end_times = jump_times[word_boundaries[1:]] + word_probs = [ + np.mean(text_token_probs[i:j]) + for i, j in zip(word_boundaries[:-1], word_boundaries[1:]) + ] + + return [ + dict( + word=word, start=round(start, 2), end=round(end, 2), prob=round(prob, 2) + ) + for word, start, end, prob in zip(words, start_times, end_times, word_probs) + ] + def generate_segment_batched( self, features: torch.Tensor, tokenizer: Tokenizer, options: dict, - encoder_output=None, ): batch_size = features.shape[0] all_tokens = [] @@ -1825,6 +2010,8 @@ def generate_segment_batched( max_length=self.max_length, suppress_blank=options["suppress_blank"], suppress_tokens=options["suppress_tokens"], + return_scores=True, + return_no_speech_prob=True, ) tokens_batch = [x.sequences_ids[0] for x in result] @@ -1836,7 +2023,22 @@ def decode_batch(tokens: List[List[int]]) -> str: return tokenizer.tokenizer.decode_batch(res) text = decode_batch(tokens_batch) - return text + output = [] + for idx, res in enumerate(result): + output.append({"text": text[idx].strip()}) + + # return scores + seq_len = len(res.sequences_ids[0]) + cum_logprob = res.scores[0] * (seq_len ** options["length_penalty"]) + output[-1]["avg_logprob"] = cum_logprob / (seq_len + 1) + + # return no speech prob + output[-1]["no_speech_prob"] = res.no_speech_prob + + text_tokens = [x.sequences_ids[0] + [tokenizer.eot] for x in result] + sot_seqs = [tuple(_[-4:]) for _ in [prompt] * batch_size] + + return encoder_output, sot_seqs, text_tokens, output def detect_language_multi_segment( self, audio: Union[str, BinaryIO, np.ndarray], params: Optional[dict] = None @@ -2073,7 +2275,7 @@ def key_func(language): "word_timestamps": False, "prepend_punctuations": "\"'“¿([{-", "append_punctuations": "\"'.。,,!!??::”)]}、", - "log_prob_low_threshold": -2.0, + "log_prob_low_threshold": None, "multilingual": False, "output_language": "en", "hotwords": None, diff --git a/faster_whisper/vad.py b/faster_whisper/vad.py index 9ad4220b..7f5ed551 100644 --- a/faster_whisper/vad.py +++ b/faster_whisper/vad.py @@ -7,7 +7,6 @@ from typing import List, NamedTuple, Optional, Union import numpy as np -import pandas as pd import torch from pyannote.audio.core.io import AudioFile @@ -511,36 +510,15 @@ def __call__(self, scores: SlidingWindowFeature) -> Annotation: return active -def merge_vad( - vad_arr, pad_onset=0.0, pad_offset=0.0, min_duration_off=0.0, min_duration_on=0.0 -): - active = Annotation() - for k, vad_t in enumerate(vad_arr): - region = Segment(vad_t[0] - pad_onset, vad_t[1] + pad_offset) - active[region, k] = 1 - - if pad_offset > 0.0 or pad_onset > 0.0 or min_duration_off > 0.0: - active = active.support(collar=min_duration_off) - - # remove tracks shorter than 
min_duration_on - if min_duration_on > 0: - for segment, track in list(active.itertracks()): - if segment.duration < min_duration_on: - del active[segment, track] - - active = active.for_json() - active_segs = pd.DataFrame([x["segment"] for x in active["content"]]) - return active_segs - - def merge_chunks( segments, chunk_size, onset: float = 0.5, offset: Optional[float] = None, + edge_padding: float = 0.1, ): """ - Merge operation described in paper + Merge operation described in the WhisperX paper """ curr_end = 0 merged_segments = [] @@ -554,18 +532,29 @@ def merge_chunks( for speech_turn in segments.get_timeline(): segments_list.append( SegmentX( - max(0.0, speech_turn.start - 0.1), speech_turn.end + 0.1, "UNKNOWN" + max(0.0, speech_turn.start - edge_padding), + speech_turn.end + edge_padding, + "UNKNOWN", ) - ) # 100ms padding to account for edge errors + ) # edge padding (100ms by default) to account for edge errors if len(segments_list) == 0: print("No active speech found in audio") return [] - # assert segments_list, "segments_list is empty." + # Make sure the starting point is the start of the segment. curr_start = segments_list[0].start - for seg in segments_list: + for idx, seg in enumerate(segments_list): + # if a segment's padded start overlaps the previous segment's end, + # undo the edge padding; do the same for the end timing. + if idx > 0: + if seg.start < segments_list[idx - 1].end: + seg.start = seg.start + edge_padding + if idx < len(segments_list) - 1: + if seg.end > segments_list[idx + 1].start: + seg.end = seg.end - edge_padding + if seg.end - curr_start > chunk_size and curr_end - curr_start > 0: merged_segments.append( { diff --git a/requirements.txt b/requirements.txt index ec741c89..e3b849c4 100644 --- a/requirements.txt +++ b/requirements.txt @@ -5,7 +5,6 @@ tokenizers>=0.13,<1 onnxruntime>=1.14,<2 transformers pyannote-audio>=3.1.1 -pandas torch>=2.1.1 torchaudio>=2.1.2 jsons>=1.6.3 \ No newline at end of file diff --git a/tests/test_transcribe.py b/tests/test_transcribe.py index 467d0dc1..11c56da8 100644 --- a/tests/test_transcribe.py +++ b/tests/test_transcribe.py @@ -50,9 +50,20 @@ def test_batched_transcribe(physcisworks_path): segments.append( {"start": segment.start, "end": segment.end, "text": segment.text} ) - assert len(segments) == 8 # number of near 30 sec segments + # number of ~30-second segments + assert len(segments) == 8 - segment = segments[0] + result = batched_model.transcribe( + physcisworks_path, batch_size=16, word_timestamps=True + ) + segments = [] + for segment, info in result: + assert segment.words is not None + segments.append( + {"start": segment.start, "end": segment.end, "text": segment.text} + ) + # more segments are expected with VAD-based alignment than with fixed 30-second segments + assert len(segments) > 8 def test_prefix_with_timestamps(jfk_path):
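Given the new `word_timestamps` path exercised by the test above, the following sketch shows how per-word output from the batched pipeline could be consumed. It reuses the assumed `BatchedInferencePipeline` name from the earlier sketch; the per-word fields (`word`, `start`, `end`, `prob`) mirror what `assign_word_timings` produces in this diff, and the words are stitched back onto VAD segments by `combine_words`.

```python
from faster_whisper import WhisperModel

# Assumed export name for the batched pipeline; adjust to the actual class in this PR.
from faster_whisper import BatchedInferencePipeline

model = WhisperModel("large-v3", device="cuda", compute_type="float16")
batched_model = BatchedInferencePipeline(model=model)

for segment, info in batched_model.transcribe(
    "audio.mp3", batch_size=16, word_timestamps=True
):
    # In this revision segment.words holds plain dicts (despite the
    # Optional[List[Word]] annotation on BatchedSegment), each with
    # "word", "start", "end", and "prob" keys from the DTW alignment.
    for word in segment.words:
        print(
            "%.2f -> %.2f  %s (p=%.2f)"
            % (word["start"], word["end"], word["word"], word["prob"])
        )
```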