From 1430e4330010888b42e6cdd95bb384bd13564801 Mon Sep 17 00:00:00 2001 From: carlos Date: Sat, 14 Feb 2026 11:27:17 +0000 Subject: [PATCH 1/5] [fix] Batch context is updated each time. It works with the initial prompt added. Ran pdb to make sure and check output. Long audio works. Existing Logic is correct without flag. --- whisperx/__main__.py | 1 + whisperx/asr.py | 140 +++++++++++++++++++++++++++++++++++------ whisperx/transcribe.py | 2 + 3 files changed, 123 insertions(+), 20 deletions(-) diff --git a/whisperx/__main__.py b/whisperx/__main__.py index dbb92fc4f..0e523910a 100644 --- a/whisperx/__main__.py +++ b/whisperx/__main__.py @@ -60,6 +60,7 @@ def cli(): parser.add_argument("--initial_prompt", type=str, default=None, help="optional text to provide as a prompt for the first window.") parser.add_argument("--hotwords", type=str, default=None, help="hotwords/hint phrases to the model (e.g. \"WhisperX, PyAnnote, GPU\"); improves recognition of rare/technical terms") parser.add_argument("--condition_on_previous_text", type=str2bool, default=False, help="if True, provide the previous output of the model as a prompt for the next window; disabling may make the text inconsistent across windows, but the model becomes less prone to getting stuck in a failure loop") + parser.add_argument("--batch_context", action="store_true", help="use previous batch's transcription as context for the next batch (slower but more coherent across batches)") parser.add_argument("--fp16", type=str2bool, default=True, help="whether to perform inference in fp16; True by default") parser.add_argument("--temperature_increment_on_fallback", type=optional_float, default=0.2, help="temperature to increase when falling back when the decoding fails to meet either of the thresholds below") diff --git a/whisperx/asr.py b/whisperx/asr.py index 7540770f4..0f9a08a5b 100644 --- a/whisperx/asr.py +++ b/whisperx/asr.py @@ -40,28 +40,61 @@ def generate_segment_batched( tokenizer: Tokenizer, options: 
TranscriptionOptions, encoder_output=None, + use_batch_context: bool = False, + previous_batch_context_tokens: List[List[int]] = None, ): batch_size = features.shape[0] - all_tokens = [] - prompt_reset_since = 0 - if options.initial_prompt is not None: - initial_prompt = " " + options.initial_prompt.strip() - initial_prompt_tokens = tokenizer.encode(initial_prompt) - all_tokens.extend(initial_prompt_tokens) - previous_tokens = all_tokens[prompt_reset_since:] - prompt = self.get_prompt( - tokenizer, - previous_tokens, - without_timestamps=options.without_timestamps, - prefix=options.prefix, - hotwords=options.hotwords - ) + if previous_batch_context_tokens is None: + previous_batch_context_tokens = [[] for _ in range(batch_size)] + + prompts = [] + batch_tokens = [] + for i in range(batch_size): + all_tokens = [] + if options.initial_prompt is not None: + initial_prompt = " " + options.initial_prompt.strip() + initial_prompt_tokens = tokenizer.encode(initial_prompt) + all_tokens.extend(initial_prompt_tokens) + + if use_batch_context: + # previous_batch_context_tokens is now List[List[int]] + # verify we have enough context history lists + if i < len(previous_batch_context_tokens): + ctx = previous_batch_context_tokens[i] + if ctx: + max_prompt_tokens = 223 + current_len = len(all_tokens) + available = max_prompt_tokens - current_len + if available > 0: + all_tokens.extend(ctx[-available:]) + batch_tokens.append(all_tokens) + + # Calculate max length in the batch + max_batch_tokens = max([len(t) for t in batch_tokens] + [0]) + + # Pad tokens to ensure consistent length across batch + # We use left-padding with EOT to preserve the immediate context before SOT + for i in range(batch_size): + current_tokens = batch_tokens[i] + if len(current_tokens) < max_batch_tokens: + padding_len = max_batch_tokens - len(current_tokens) + # Pad with EOT (End of Transcript) which is usually ignored or treated as break + current_tokens = [tokenizer.eot] * padding_len + current_tokens + 
+ prompt = self.get_prompt( + tokenizer, + current_tokens, + without_timestamps=options.without_timestamps, + prefix=options.prefix, + hotwords=options.hotwords + ) + prompts.append(prompt) encoder_output = self.encode(features) result = self.model.generate( encoder_output, - [prompt] * batch_size, + prompts, beam_size=options.beam_size, patience=options.patience, length_penalty=options.length_penalty, @@ -82,9 +115,9 @@ def decode_batch(tokens: List[List[int]]) -> List[str]: return tokenizer.tokenizer.decode_batch(res) text = decode_batch(tokens_batch) - return text + def encode(self, features: np.ndarray) -> ctranslate2.StorageView: # When the model is running on multiple GPUs, the encoder output should be moved # to the CPU since we don't know which GPU will handle the next job. @@ -115,6 +148,7 @@ def __init__( framework="pt", language: Optional[str] = None, suppress_numerals: bool = False, + use_batch_context: bool = False, **kwargs, ): self.model = model @@ -122,6 +156,7 @@ def __init__( self.options = options self.preset_language = language self.suppress_numerals = suppress_numerals + self.use_batch_context = use_batch_context self._batch_size = kwargs.pop("batch_size", None) self._num_workers = 1 self._preprocess_params, self._forward_params, self._postprocess_params = self._sanitize_parameters(**kwargs) @@ -142,6 +177,8 @@ def __init__( super(Pipeline, self).__init__() self.vad_model = vad self._vad_params = vad_params + self.previous_batch_context_tokens = [] + def _sanitize_parameters(self, **kwargs): preprocess_kwargs = {} @@ -160,7 +197,35 @@ def preprocess(self, audio): return {'inputs': features} def _forward(self, model_inputs): - outputs = self.model.generate_segment_batched(model_inputs['inputs'], self.tokenizer, self.options) + current_batch_size = model_inputs['inputs'].shape[0] + # Ideally, batch[i] corresponds to stream[i]. + # This holds if batch_size == number of streams. 
+ valid_contexts = self.previous_batch_context_tokens[:current_batch_size] + + outputs = self.model.generate_segment_batched( + model_inputs['inputs'], + self.tokenizer, + self.options, + use_batch_context=self.use_batch_context, + previous_batch_context_tokens=valid_contexts, + ) + if self.use_batch_context: + initial_prompt_length = 0 + if self.options.initial_prompt is not None: + initial_prompt = " " + self.options.initial_prompt.strip() + initial_prompt_length = len(self.tokenizer.encode(initial_prompt)) + + # Use 220 instead of 224 to be safe + max_context_window = max(0, 220 - initial_prompt_length) + + for i, text in enumerate(outputs): + if i < len(self.previous_batch_context_tokens): + # Filter out special tokens (timestamps, SOT, EOT, etc.) + # We only want the text content for context. + tokens = [t for t in self.tokenizer.encode(text) if t < self.tokenizer.eot] + self.previous_batch_context_tokens[i].extend(tokens) + self.previous_batch_context_tokens[i] = self.previous_batch_context_tokens[i][-max_context_window:] + return {'text': outputs} def postprocess(self, model_outputs): @@ -201,6 +266,14 @@ def transcribe( ) -> TranscriptionResult: if isinstance(audio, str): audio = load_audio(audio) + + batch_size = batch_size or self._batch_size + # Initialize context for each stream. + # We have 'batch_size' concurrent streams. 
+ if batch_size is None or batch_size < 1: + batch_size = 1 + + self.previous_batch_context_tokens = [[] for _ in range(batch_size)] def data(audio, segments): for seg in segments: @@ -252,10 +325,33 @@ def data(audio, segments): new_suppressed_tokens = numeral_symbol_tokens + self.options.suppress_tokens new_suppressed_tokens = list(set(new_suppressed_tokens)) self.options = replace(self.options, suppress_tokens=new_suppressed_tokens) - + segments: List[SingleSegment] = [] batch_size = batch_size or self._batch_size total_segments = len(vad_segments) + + if batch_size > 1: + num_streams = batch_size + # Distribute segments into streams + # Manual split + k, m = divmod(len(vad_segments), num_streams) + # lengths of each part: first m parts have k+1, rest have k + stream_segments = [] + start_idx = 0 + for i in range(num_streams): + part_len = k + 1 if i < m else k + stream_segments.append(vad_segments[start_idx : start_idx + part_len]) + start_idx += part_len + # Interleave + # We need to pick [s0[0], s1[0], s2[0]... s0[1], s1[1]...] 
+ interleaved_segments = [] + max_len = max(len(s) for s in stream_segments) + for i in range(max_len): + for stream in stream_segments: + if i < len(stream): + interleaved_segments.append(stream[i]) + vad_segments = interleaved_segments + for idx, out in enumerate(self.__call__(data(audio, vad_segments), batch_size=batch_size, num_workers=num_workers)): if print_progress: base_progress = ((idx + 1) / total_segments) * 100 @@ -273,6 +369,8 @@ def data(audio, segments): "end": round(vad_segments[idx]['end'], 3) } ) + # Sort segments by start time to restore original order + segments.sort(key=lambda x: x['start']) # revert the tokenizer if multilingual inference is enabled if self.preset_language is None: @@ -289,8 +387,8 @@ def detect_language(self, audio: np.ndarray) -> str: logger.warning("Audio is shorter than 30s, language detection may be inaccurate") model_n_mels = self.model.feat_kwargs.get("feature_size") segment = log_mel_spectrogram(audio[: N_SAMPLES], - n_mels=model_n_mels if model_n_mels is not None else 80, - padding=0 if audio.shape[0] >= N_SAMPLES else N_SAMPLES - audio.shape[0]) + n_mels=model_n_mels if model_n_mels is not None else 80, + padding=0 if audio.shape[0] >= N_SAMPLES else N_SAMPLES - audio.shape[0]) encoder_output = self.model.encode(segment) results = self.model.model.detect_language(encoder_output) language_token, language_probability = results[0][0] @@ -315,6 +413,7 @@ def load_model( local_files_only=False, threads=4, use_auth_token: Optional[Union[str, bool]] = None, + use_batch_context: bool = False, ) -> FasterWhisperPipeline: """Load a Whisper model for inference. 
Args: @@ -421,4 +520,5 @@ def load_model( language=language, suppress_numerals=suppress_numerals, vad_params=default_vad_options, + use_batch_context=use_batch_context, ) diff --git a/whisperx/transcribe.py b/whisperx/transcribe.py index 7c8be6794..92a634dd1 100644 --- a/whisperx/transcribe.py +++ b/whisperx/transcribe.py @@ -63,6 +63,7 @@ def transcribe_task(args: dict, parser: argparse.ArgumentParser): diarize_model_name: str = args.pop("diarize_model") print_progress: bool = args.pop("print_progress") return_speaker_embeddings: bool = args.pop("speaker_embeddings") + batch_context: bool = args.pop("batch_context", False) if return_speaker_embeddings and not diarize: warnings.warn("--speaker_embeddings has no effect without --diarize") @@ -142,6 +143,7 @@ def transcribe_task(args: dict, parser: argparse.ArgumentParser): local_files_only=model_cache_only, threads=faster_whisper_threads, use_auth_token=hf_token, + use_batch_context=batch_context, ) for audio_path in args.pop("audio"): From e33bb1ec25bd743899ff6de2a36d7f4abd1ba3a3 Mon Sep 17 00:00:00 2001 From: carlos Date: Sat, 14 Feb 2026 11:54:07 +0000 Subject: [PATCH 2/5] Although the existing commit worked, the initial prompt was in the loop; no need.
--- whisperx/asr.py | 41 ++++++++++++++--------------------------- 1 file changed, 14 insertions(+), 27 deletions(-) diff --git a/whisperx/asr.py b/whisperx/asr.py index 0f9a08a5b..0931d88bf 100644 --- a/whisperx/asr.py +++ b/whisperx/asr.py @@ -46,49 +46,36 @@ def generate_segment_batched( batch_size = features.shape[0] if previous_batch_context_tokens is None: previous_batch_context_tokens = [[] for _ in range(batch_size)] - - prompts = [] + + initial_prompt_tokens = [] + if options.initial_prompt is not None: + initial_prompt = " " + options.initial_prompt.strip() + initial_prompt_tokens = tokenizer.encode(initial_prompt) + batch_tokens = [] for i in range(batch_size): - all_tokens = [] - if options.initial_prompt is not None: - initial_prompt = " " + options.initial_prompt.strip() - initial_prompt_tokens = tokenizer.encode(initial_prompt) - all_tokens.extend(initial_prompt_tokens) - + all_tokens = list(initial_prompt_tokens) if use_batch_context: - # previous_batch_context_tokens is now List[List[int]] - # verify we have enough context history lists if i < len(previous_batch_context_tokens): ctx = previous_batch_context_tokens[i] if ctx: - max_prompt_tokens = 223 - current_len = len(all_tokens) - available = max_prompt_tokens - current_len + # 223 is max prompt tokens + available = 223 - len(all_tokens) if available > 0: all_tokens.extend(ctx[-available:]) batch_tokens.append(all_tokens) - # Calculate max length in the batch max_batch_tokens = max([len(t) for t in batch_tokens] + [0]) - # Pad tokens to ensure consistent length across batch - # We use left-padding with EOT to preserve the immediate context before SOT - for i in range(batch_size): - current_tokens = batch_tokens[i] - if len(current_tokens) < max_batch_tokens: - padding_len = max_batch_tokens - len(current_tokens) - # Pad with EOT (End of Transcript) which is usually ignored or treated as break - current_tokens = [tokenizer.eot] * padding_len + current_tokens - - prompt = self.get_prompt( + 
prompts = [ + self.get_prompt( tokenizer, - current_tokens, + [tokenizer.eot] * (max_batch_tokens - len(t)) + t, without_timestamps=options.without_timestamps, prefix=options.prefix, hotwords=options.hotwords - ) - prompts.append(prompt) + ) for t in batch_tokens + ] encoder_output = self.encode(features) From de5fa652a1866f7d936967fce6c33a0aaca3de53 Mon Sep 17 00:00:00 2001 From: carlos Date: Sat, 14 Feb 2026 20:22:37 +0000 Subject: [PATCH 3/5] [modification] added an 'and' condition before streams; existing logic is not changed. --- whisperx/asr.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/whisperx/asr.py b/whisperx/asr.py index 0931d88bf..768cddb61 100644 --- a/whisperx/asr.py +++ b/whisperx/asr.py @@ -317,7 +317,7 @@ def data(audio, segments): batch_size = batch_size or self._batch_size total_segments = len(vad_segments) - if batch_size > 1: + if batch_size > 1 and self.use_batch_context: num_streams = batch_size # Distribute segments into streams # Manual split From 1b6a3b78a6c7a23ac147c59939b89c11ce7358f8 Mon Sep 17 00:00:00 2001 From: carlos Date: Tue, 17 Feb 2026 13:30:03 +0000 Subject: [PATCH 4/5] [feat] First batch wrap around [New File] benchmark testing --- whisperx/asr.py | 18 +++++- whisperx/benchmark.py | 144 ++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 161 insertions(+), 1 deletion(-) create mode 100644 whisperx/benchmark.py diff --git a/whisperx/asr.py b/whisperx/asr.py index 768cddb61..342db0372 100644 --- a/whisperx/asr.py +++ b/whisperx/asr.py @@ -18,7 +18,6 @@ logger = get_logger(__name__) - def find_numeral_symbol_tokens(tokenizer): numeral_symbol_tokens = [] for i in range(tokenizer.eot): @@ -356,6 +355,23 @@ def data(audio, segments): "end": round(vad_segments[idx]['end'], 3) } ) + + if self.use_batch_context and batch_size > 1: + last_stream_index = (total_segments - 1) % batch_size + final_context = self.previous_batch_context_tokens[last_stream_index] + # Prepare context for the wrap-around re-run + #
ONLY Stream 0 (which processes the start of the file) should get the context (which comes from the end of the file). + # All other streams should have EMPTY context for this re-run to avoid self-referencing loops (feeding Segment N to Segment N). + new_rerun_context = [[] for _ in range(batch_size)] + new_rerun_context[0] = final_context + # Temporarily overwrite previous_batch_context_tokens for the re-run + self.previous_batch_context_tokens = new_rerun_context + first_batch_segments = vad_segments[:batch_size] + # Runs the model again just on 'first_batch_segments' + for i, out in enumerate(self.__call__(data(audio, first_batch_segments), batch_size=batch_size, num_workers=num_workers)): + text = out['text'] + # L398: Overwrite the existing text with the new wrap-around text + segments[i]['text'] = text # Sort segments by start time to restore original order segments.sort(key=lambda x: x['start']) diff --git a/whisperx/benchmark.py b/whisperx/benchmark.py new file mode 100644 index 000000000..7e4dbaf07 --- /dev/null +++ b/whisperx/benchmark.py @@ -0,0 +1,144 @@ +import argparse +import os +import time +import torch +import torchaudio +import jiwer +import whisperx +import numpy as np +from typing import Tuple + +def load_tedlium(root: str, download: bool = False, subset: str = "test"): + print(f"Loading TEDLIUM dataset ({subset}) from {root}...") + try: + dataset = torchaudio.datasets.TEDLIUM( + root=root, + release="release3", + subset=subset, + download=download + ) + return dataset + except Exception as e: + print(f"Error loading dataset: {e}") + return None + +def normalize_text(text: str) -> str: + """ + Simple normalization: lower case, remove punctuation. 
+ """ + import string + text = text.lower() + text = text.translate(str.maketrans('', '', string.punctuation)) + return " ".join(text.split()) + +def benchmark(dataset, model_size="large-v2", device="cuda", compute_type="float16", batch_size=4, limit=None): + print(f"Loading WhisperX model: {model_size} on {device} ({compute_type})...") + + try: + model = whisperx.load_model(model_size, device, compute_type=compute_type) + except Exception as e: + print(f"Failed to load model: {e}") + return + + print("Model loaded.") + + total_wer = 0 + total_cer = 0 + total_latency = 0 + total_audio_duration = 0 + count = 0 + + print(f"\nBenchmarking on {limit if limit else len(dataset)} samples...") + + # Clear CUDA cache for accurate VRAM measurement + if torch.cuda.is_available(): + torch.cuda.empty_cache() + torch.cuda.reset_peak_memory_stats() + initial_vram = torch.cuda.memory_allocated() / 1024**3 + print(f"Initial VRAM usage: {initial_vram:.2f} GB") + + for i, item in enumerate(dataset): + if limit and i >= limit: + break + + waveform, sample_rate, transcript, talk_id, speaker_id, identifier = item + + # WhisperX expects audio as a numpy array, float32, mono, 16kHz + # TEDLIUM is likely 16kHz, but let's verify/resample if needed + # waveform is (channels, time) + + if sample_rate != 16000: + resampler = torchaudio.transforms.Resample(orig_freq=sample_rate, new_freq=16000) + waveform = resampler(waveform) + + audio_np = waveform.squeeze().numpy() + + duration = len(audio_np) / 16000 + total_audio_duration += duration + + # Measure Latency + start_time = time.time() + result = model.transcribe(audio_np, batch_size=batch_size) + end_time = time.time() + + latency = end_time - start_time + total_latency += latency + + # Combine segments for full transcript + hypothesis = " ".join([seg['text'] for seg in result['segments']]) + + # Normalize + ref_norm = normalize_text(transcript) + hyp_norm = normalize_text(hypothesis) + + if not ref_norm.strip(): + # Skip empty references to 
avoid division by zero in WER + continue + + # Measure WER/CER + wer = jiwer.wer(ref_norm, hyp_norm) + cer = jiwer.cer(ref_norm, hyp_norm) + + total_wer += wer + total_cer += cer + count += 1 + + print(f"Sample {i}: WER={wer:.2f}, CER={cer:.2f}, Latency={latency:.2f}s, Dur={duration:.2f}s, RTF={latency/duration:.2f}") + + if count == 0: + print("No samples processed.") + return + + avg_wer = total_wer / count + avg_cer = total_cer / count + avg_rtf = total_latency / total_audio_duration + + print("\n--- Benchmark Results ---") + print(f"Average WER: {avg_wer:.4f}") + print(f"Average CER: {avg_cer:.4f}") + print(f"Average RTF (Real Time Factor): {avg_rtf:.4f}") + print(f"Total Latency: {total_latency:.2f}s for {total_audio_duration:.2f}s audio") + + if torch.cuda.is_available(): + peak_vram = torch.cuda.max_memory_allocated() / 1024**3 + print(f"Peak VRAM Usage: {peak_vram:.2f} GB") + else: + print("VRAM Usage: N/A (CPU only)") + +if __name__ == "__main__": + parser = argparse.ArgumentParser(description="Benchmark WhisperX on TEDLIUM") + parser.add_argument("--root", type=str, default="./data", help="Root directory for dataset") + parser.add_argument("--download", action="store_true", help="Download dataset if not found") + parser.add_argument("--limit", type=int, default=None, help="Limit number of samples") + parser.add_argument("--model", type=str, default="large-v2", help="Whisper model size") + parser.add_argument("--device", type=str, default="cuda" if torch.cuda.is_available() else "cpu", help="Device") + parser.add_argument("--batch_size", type=int, default=4, help="Batch size") + + args = parser.parse_args() + + # Create data dir + os.makedirs(args.root, exist_ok=True) + + ds = load_tedlium(args.root, download=args.download) + if ds: + benchmark(ds, model_size=args.model, device=args.device, batch_size=args.batch_size, limit=args.limit) \ No newline at end of file From 0e073d418c7ab0b07fce401c9e0464436920d5ed Mon Sep 17 00:00:00 2001 From: carlito Date: 
Tue, 17 Feb 2026 13:33:06 +0000 Subject: [PATCH 5/5] Revert "Batch wrap" --- whisperx/__main__.py | 1 - whisperx/asr.py | 139 ++++++--------------------------------- whisperx/benchmark.py | 144 ----------------------------------------- whisperx/transcribe.py | 2 - 4 files changed, 18 insertions(+), 268 deletions(-) delete mode 100644 whisperx/benchmark.py diff --git a/whisperx/__main__.py b/whisperx/__main__.py index 0e523910a..dbb92fc4f 100644 --- a/whisperx/__main__.py +++ b/whisperx/__main__.py @@ -60,7 +60,6 @@ def cli(): parser.add_argument("--initial_prompt", type=str, default=None, help="optional text to provide as a prompt for the first window.") parser.add_argument("--hotwords", type=str, default=None, help="hotwords/hint phrases to the model (e.g. \"WhisperX, PyAnnote, GPU\"); improves recognition of rare/technical terms") parser.add_argument("--condition_on_previous_text", type=str2bool, default=False, help="if True, provide the previous output of the model as a prompt for the next window; disabling may make the text inconsistent across windows, but the model becomes less prone to getting stuck in a failure loop") - parser.add_argument("--batch_context", action="store_true", help="use previous batch's transcription as context for the next batch (slower but more coherent across batches)") parser.add_argument("--fp16", type=str2bool, default=True, help="whether to perform inference in fp16; True by default") parser.add_argument("--temperature_increment_on_fallback", type=optional_float, default=0.2, help="temperature to increase when falling back when the decoding fails to meet either of the thresholds below") diff --git a/whisperx/asr.py b/whisperx/asr.py index 342db0372..7540770f4 100644 --- a/whisperx/asr.py +++ b/whisperx/asr.py @@ -18,6 +18,7 @@ logger = get_logger(__name__) + def find_numeral_symbol_tokens(tokenizer): numeral_symbol_tokens = [] for i in range(tokenizer.eot): @@ -39,48 +40,28 @@ def generate_segment_batched( tokenizer: Tokenizer, 
options: TranscriptionOptions, encoder_output=None, - use_batch_context: bool = False, - previous_batch_context_tokens: List[List[int]] = None, ): batch_size = features.shape[0] - if previous_batch_context_tokens is None: - previous_batch_context_tokens = [[] for _ in range(batch_size)] - - initial_prompt_tokens = [] + all_tokens = [] + prompt_reset_since = 0 if options.initial_prompt is not None: initial_prompt = " " + options.initial_prompt.strip() initial_prompt_tokens = tokenizer.encode(initial_prompt) - - batch_tokens = [] - for i in range(batch_size): - all_tokens = list(initial_prompt_tokens) - if use_batch_context: - if i < len(previous_batch_context_tokens): - ctx = previous_batch_context_tokens[i] - if ctx: - # 223 is max prompt tokens - available = 223 - len(all_tokens) - if available > 0: - all_tokens.extend(ctx[-available:]) - batch_tokens.append(all_tokens) - - max_batch_tokens = max([len(t) for t in batch_tokens] + [0]) - - prompts = [ - self.get_prompt( - tokenizer, - [tokenizer.eot] * (max_batch_tokens - len(t)) + t, - without_timestamps=options.without_timestamps, - prefix=options.prefix, - hotwords=options.hotwords - ) for t in batch_tokens - ] + all_tokens.extend(initial_prompt_tokens) + previous_tokens = all_tokens[prompt_reset_since:] + prompt = self.get_prompt( + tokenizer, + previous_tokens, + without_timestamps=options.without_timestamps, + prefix=options.prefix, + hotwords=options.hotwords + ) encoder_output = self.encode(features) result = self.model.generate( encoder_output, - prompts, + [prompt] * batch_size, beam_size=options.beam_size, patience=options.patience, length_penalty=options.length_penalty, @@ -101,9 +82,9 @@ def decode_batch(tokens: List[List[int]]) -> List[str]: return tokenizer.tokenizer.decode_batch(res) text = decode_batch(tokens_batch) + return text - def encode(self, features: np.ndarray) -> ctranslate2.StorageView: # When the model is running on multiple GPUs, the encoder output should be moved # to the CPU since we 
don't know which GPU will handle the next job. @@ -134,7 +115,6 @@ def __init__( framework="pt", language: Optional[str] = None, suppress_numerals: bool = False, - use_batch_context: bool = False, **kwargs, ): self.model = model @@ -142,7 +122,6 @@ def __init__( self.options = options self.preset_language = language self.suppress_numerals = suppress_numerals - self.use_batch_context = use_batch_context self._batch_size = kwargs.pop("batch_size", None) self._num_workers = 1 self._preprocess_params, self._forward_params, self._postprocess_params = self._sanitize_parameters(**kwargs) @@ -163,8 +142,6 @@ def __init__( super(Pipeline, self).__init__() self.vad_model = vad self._vad_params = vad_params - self.previous_batch_context_tokens = [] - def _sanitize_parameters(self, **kwargs): preprocess_kwargs = {} @@ -183,35 +160,7 @@ def preprocess(self, audio): return {'inputs': features} def _forward(self, model_inputs): - current_batch_size = model_inputs['inputs'].shape[0] - # Ideally, batch[i] corresponds to stream[i]. - # This holds if batch_size == number of streams. - valid_contexts = self.previous_batch_context_tokens[:current_batch_size] - - outputs = self.model.generate_segment_batched( - model_inputs['inputs'], - self.tokenizer, - self.options, - use_batch_context=self.use_batch_context, - previous_batch_context_tokens=valid_contexts, - ) - if self.use_batch_context: - initial_prompt_length = 0 - if self.options.initial_prompt is not None: - initial_prompt = " " + self.options.initial_prompt.strip() - initial_prompt_length = len(self.tokenizer.encode(initial_prompt)) - - # Use 220 instead of 224 to be safe - max_context_window = max(0, 220 - initial_prompt_length) - - for i, text in enumerate(outputs): - if i < len(self.previous_batch_context_tokens): - # Filter out special tokens (timestamps, SOT, EOT, etc.) - # We only want the text content for context. 
- tokens = [t for t in self.tokenizer.encode(text) if t < self.tokenizer.eot] - self.previous_batch_context_tokens[i].extend(tokens) - self.previous_batch_context_tokens[i] = self.previous_batch_context_tokens[i][-max_context_window:] - + outputs = self.model.generate_segment_batched(model_inputs['inputs'], self.tokenizer, self.options) return {'text': outputs} def postprocess(self, model_outputs): @@ -252,14 +201,6 @@ def transcribe( ) -> TranscriptionResult: if isinstance(audio, str): audio = load_audio(audio) - - batch_size = batch_size or self._batch_size - # Initialize context for each stream. - # We have 'batch_size' concurrent streams. - if batch_size is None or batch_size < 1: - batch_size = 1 - - self.previous_batch_context_tokens = [[] for _ in range(batch_size)] def data(audio, segments): for seg in segments: @@ -311,33 +252,10 @@ def data(audio, segments): new_suppressed_tokens = numeral_symbol_tokens + self.options.suppress_tokens new_suppressed_tokens = list(set(new_suppressed_tokens)) self.options = replace(self.options, suppress_tokens=new_suppressed_tokens) - + segments: List[SingleSegment] = [] batch_size = batch_size or self._batch_size total_segments = len(vad_segments) - - if batch_size > 1 and self.use_batch_context: - num_streams = batch_size - # Distribute segments into streams - # Manual split - k, m = divmod(len(vad_segments), num_streams) - # lengths of each part: first m parts have k+1, rest have k - stream_segments = [] - start_idx = 0 - for i in range(num_streams): - part_len = k + 1 if i < m else k - stream_segments.append(vad_segments[start_idx : start_idx + part_len]) - start_idx += part_len - # Interleave - # We need to pick [s0[0], s1[0], s2[0]... s0[1], s1[1]...] 
- interleaved_segments = [] - max_len = max(len(s) for s in stream_segments) - for i in range(max_len): - for stream in stream_segments: - if i < len(stream): - interleaved_segments.append(stream[i]) - vad_segments = interleaved_segments - for idx, out in enumerate(self.__call__(data(audio, vad_segments), batch_size=batch_size, num_workers=num_workers)): if print_progress: base_progress = ((idx + 1) / total_segments) * 100 @@ -356,25 +274,6 @@ def data(audio, segments): } ) - if self.use_batch_context and batch_size > 1: - last_stream_index = (total_segments - 1) % batch_size - final_context = self.previous_batch_context_tokens[last_stream_index] - # Prepare context for the wrap-around re-run - # ONLY Stream 0 (which processes the start of the file) should get the context (which comes from the end of the file). - # All other streams should have EMPTY context for this re-run to avoid self-referencing loops (feeding Segment N to Segment N). - new_rerun_context = [[] for _ in range(batch_size)] - new_rerun_context[0] = final_context - # Temporarily overwrite previous_batch_context_tokens for the re-run - self.previous_batch_context_tokens = new_rerun_context - first_batch_segments = vad_segments[:batch_size] - # Runs the model again just on 'first_batch_segments' - for i, out in enumerate(self.__call__(data(audio, first_batch_segments), batch_size=batch_size, num_workers=num_workers)): - text = out['text'] - # L398: Overwrite the existing text with the new wrap-around text - segments[i]['text'] = text - # Sort segments by start time to restore original order - segments.sort(key=lambda x: x['start']) - # revert the tokenizer if multilingual inference is enabled if self.preset_language is None: self.tokenizer = None @@ -390,8 +289,8 @@ def detect_language(self, audio: np.ndarray) -> str: logger.warning("Audio is shorter than 30s, language detection may be inaccurate") model_n_mels = self.model.feat_kwargs.get("feature_size") segment = log_mel_spectrogram(audio[: 
N_SAMPLES], - n_mels=model_n_mels if model_n_mels is not None else 80, - padding=0 if audio.shape[0] >= N_SAMPLES else N_SAMPLES - audio.shape[0]) + n_mels=model_n_mels if model_n_mels is not None else 80, + padding=0 if audio.shape[0] >= N_SAMPLES else N_SAMPLES - audio.shape[0]) encoder_output = self.model.encode(segment) results = self.model.model.detect_language(encoder_output) language_token, language_probability = results[0][0] @@ -416,7 +315,6 @@ def load_model( local_files_only=False, threads=4, use_auth_token: Optional[Union[str, bool]] = None, - use_batch_context: bool = False, ) -> FasterWhisperPipeline: """Load a Whisper model for inference. Args: @@ -523,5 +421,4 @@ def load_model( language=language, suppress_numerals=suppress_numerals, vad_params=default_vad_options, - use_batch_context=use_batch_context, ) diff --git a/whisperx/benchmark.py b/whisperx/benchmark.py deleted file mode 100644 index 7e4dbaf07..000000000 --- a/whisperx/benchmark.py +++ /dev/null @@ -1,144 +0,0 @@ -import argparse -import os -import time -import torch -import torchaudio -import jiwer -import whisperx -import numpy as np -from typing import Tuple - -def load_tedlium(root: str, download: bool = False, subset: str = "test"): - print(f"Loading TEDLIUM dataset ({subset}) from {root}...") - try: - dataset = torchaudio.datasets.TEDLIUM( - root=root, - release="release3", - subset=subset, - download=download - ) - return dataset - except Exception as e: - print(f"Error loading dataset: {e}") - return None - -def normalize_text(text: str) -> str: - """ - Simple normalization: lower case, remove punctuation. 
- """ - import string - text = text.lower() - text = text.translate(str.maketrans('', '', string.punctuation)) - return " ".join(text.split()) - -def benchmark(dataset, model_size="large-v2", device="cuda", compute_type="float16", batch_size=4, limit=None): - print(f"Loading WhisperX model: {model_size} on {device} ({compute_type})...") - - try: - model = whisperx.load_model(model_size, device, compute_type=compute_type) - except Exception as e: - print(f"Failed to load model: {e}") - return - - print("Model loaded.") - - total_wer = 0 - total_cer = 0 - total_latency = 0 - total_audio_duration = 0 - count = 0 - - print(f"\nBenchmarking on {limit if limit else len(dataset)} samples...") - - # Clear CUDA cache for accurate VRAM measurement - if torch.cuda.is_available(): - torch.cuda.empty_cache() - torch.cuda.reset_peak_memory_stats() - initial_vram = torch.cuda.memory_allocated() / 1024**3 - print(f"Initial VRAM usage: {initial_vram:.2f} GB") - - for i, item in enumerate(dataset): - if limit and i >= limit: - break - - waveform, sample_rate, transcript, talk_id, speaker_id, identifier = item - - # WhisperX expects audio as a numpy array, float32, mono, 16kHz - # TEDLIUM is likely 16kHz, but let's verify/resample if needed - # waveform is (channels, time) - - if sample_rate != 16000: - resampler = torchaudio.transforms.Resample(orig_freq=sample_rate, new_freq=16000) - waveform = resampler(waveform) - - audio_np = waveform.squeeze().numpy() - - duration = len(audio_np) / 16000 - total_audio_duration += duration - - # Measure Latency - start_time = time.time() - result = model.transcribe(audio_np, batch_size=batch_size) - end_time = time.time() - - latency = end_time - start_time - total_latency += latency - - # Combine segments for full transcript - hypothesis = " ".join([seg['text'] for seg in result['segments']]) - - # Normalize - ref_norm = normalize_text(transcript) - hyp_norm = normalize_text(hypothesis) - - if not ref_norm.strip(): - # Skip empty references to 
avoid division by zero in WER - continue - - # Measure WER/CER - wer = jiwer.wer(ref_norm, hyp_norm) - cer = jiwer.cer(ref_norm, hyp_norm) - - total_wer += wer - total_cer += cer - count += 1 - - print(f"Sample {i}: WER={wer:.2f}, CER={cer:.2f}, Latency={latency:.2f}s, Dur={duration:.2f}s, RTF={latency/duration:.2f}") - - if count == 0: - print("No samples processed.") - return - - avg_wer = total_wer / count - avg_cer = total_cer / count - avg_rtf = total_latency / total_audio_duration - - print("\n--- Benchmark Results ---") - print(f"Average WER: {avg_wer:.4f}") - print(f"Average CER: {avg_cer:.4f}") - print(f"Average RTF (Real Time Factor): {avg_rtf:.4f}") - print(f"Total Latency: {total_latency:.2f}s for {total_audio_duration:.2f}s audio") - - if torch.cuda.is_available(): - peak_vram = torch.cuda.max_memory_allocated() / 1024**3 - print(f"Peak VRAM Usage: {peak_vram:.2f} GB") - else: - print("VRAM Usage: N/A (CPU only)") - -if __name__ == "__main__": - parser = argparse.ArgumentParser(description="Benchmark WhisperX on TEDLIUM") - parser.add_argument("--root", type=str, default="./data", help="Root directory for dataset") - parser.add_argument("--download", action="store_true", help="Download dataset if not found") - parser.add_argument("--limit", type=int, default=None, help="Limit number of samples") - parser.add_argument("--model", type=str, default="large-v2", help="Whisper model size") - parser.add_argument("--device", type=str, default="cuda" if torch.cuda.is_available() else "cpu", help="Device") - parser.add_argument("--batch_size", type=int, default=4, help="Batch size") - - args = parser.parse_args() - - # Create data dir - os.makedirs(args.root, exist_ok=True) - - ds = load_tedlium(args.root, download=args.download) - if ds: - benchmark(ds, model_size=args.model, device=args.device, batch_size=args.batch_size, limit=args.limit) \ No newline at end of file diff --git a/whisperx/transcribe.py b/whisperx/transcribe.py index 92a634dd1..7c8be6794 100644 
--- a/whisperx/transcribe.py +++ b/whisperx/transcribe.py @@ -63,7 +63,6 @@ def transcribe_task(args: dict, parser: argparse.ArgumentParser): diarize_model_name: str = args.pop("diarize_model") print_progress: bool = args.pop("print_progress") return_speaker_embeddings: bool = args.pop("speaker_embeddings") - batch_context: bool = args.pop("batch_context", False) if return_speaker_embeddings and not diarize: warnings.warn("--speaker_embeddings has no effect without --diarize") @@ -143,7 +142,6 @@ def transcribe_task(args: dict, parser: argparse.ArgumentParser): local_files_only=model_cache_only, threads=faster_whisper_threads, use_auth_token=hf_token, - use_batch_context=batch_context, ) for audio_path in args.pop("audio"):