From d8a078eed46b139f90d7f33b472119e9be9ff969 Mon Sep 17 00:00:00 2001 From: Claude-Assistant Date: Sat, 14 Feb 2026 09:58:45 +0100 Subject: [PATCH] feat: expose avg_logprob per segment from ctranslate2 beam search Derive the average log probability (transcription confidence score) of each segment from the beam-search score returned by ctranslate2 and carry it through to the final segment output, including aligned segments. generate_segment_batched now returns a dict with 'text' and 'avg_logprob' lists instead of a bare list of strings. The field is NotRequired, so existing code constructing segments without it remains valid. Co-Authored-By: Claude Opus 4.6 --- whisperx/alignment.py | 13 +++++++++++-- whisperx/asr.py | 16 +++++++++++++--- whisperx/schema.py | 7 +++++++ 3 files changed, 31 insertions(+), 5 deletions(-) diff --git a/whisperx/alignment.py b/whisperx/alignment.py index ce92d7a4f..9e5b63a2d 100644 --- a/whisperx/alignment.py +++ b/whisperx/alignment.py @@ -216,6 +216,7 @@ def align( t1 = segment["start"] t2 = segment["end"] text = segment["text"] + avg_logprob = segment.get("avg_logprob") aligned_seg: SingleAlignedSegment = { "start": t1, @@ -225,6 +226,9 @@ def align( "chars": None, } + if avg_logprob is not None: + aligned_seg["avg_logprob"] = avg_logprob + if return_char_alignments: aligned_seg["chars"] = [] @@ -353,12 +357,15 @@ def align( sentence_words.append(word_segment) - aligned_subsegments.append({ + subsegment = { "text": sentence_text, "start": sentence_start, "end": sentence_end, "words": sentence_words, - }) + } + if avg_logprob is not None: + subsegment["avg_logprob"] = avg_logprob + aligned_subsegments.append(subsegment) if return_char_alignments: curr_chars = curr_chars[["char", "start", "end", "score"]] @@ -376,6 +383,8 @@ def align( agg_dict["text"] = "".join if return_char_alignments: agg_dict["chars"] = "sum" + if avg_logprob is not None: + agg_dict["avg_logprob"] = "first" aligned_subsegments= aligned_subsegments.groupby(["start", "end"], as_index=False).agg(agg_dict) aligned_subsegments = aligned_subsegments.to_dict('records') aligned_segments += aligned_subsegments diff --git a/whisperx/asr.py b/whisperx/asr.py index 
7540770f4..ea29e56f5 100644 --- a/whisperx/asr.py +++ b/whisperx/asr.py @@ -70,10 +70,17 @@ def generate_segment_batched( suppress_tokens=options.suppress_tokens, no_repeat_ngram_size=options.no_repeat_ngram_size, repetition_penalty=options.repetition_penalty, + return_scores=True, ) tokens_batch = [x.sequences_ids[0] for x in result] + avg_logprobs = [] + for res in result: + seq_len = len(res.sequences_ids[0]) + cum_logprob = res.scores[0] * (seq_len ** options.length_penalty) + avg_logprobs.append(cum_logprob / (seq_len + 1)) + def decode_batch(tokens: List[List[int]]) -> List[str]: res = [] for tk in tokens: @@ -83,7 +90,7 @@ def decode_batch(tokens: List[List[int]]) -> List[str]: text = decode_batch(tokens_batch) - return text + return {'text': text, 'avg_logprob': avg_logprobs} def encode(self, features: np.ndarray) -> ctranslate2.StorageView: # When the model is running on multiple GPUs, the encoder output should be moved @@ -161,7 +168,7 @@ def preprocess(self, audio): def _forward(self, model_inputs): outputs = self.model.generate_segment_batched(model_inputs['inputs'], self.tokenizer, self.options) - return {'text': outputs} + return outputs def postprocess(self, model_outputs): return model_outputs @@ -262,15 +269,18 @@ def data(audio, segments): percent_complete = base_progress / 2 if combined_progress else base_progress print(f"Progress: {percent_complete:.2f}%...") text = out['text'] + avg_logprob = out['avg_logprob'] if batch_size in [0, 1, None]: text = text[0] + avg_logprob = avg_logprob[0] if verbose: print(f"Transcript: [{round(vad_segments[idx]['start'], 3)} --> {round(vad_segments[idx]['end'], 3)}] {text}") segments.append( { "text": text, "start": round(vad_segments[idx]['start'], 3), - "end": round(vad_segments[idx]['end'], 3) + "end": round(vad_segments[idx]['end'], 3), + "avg_logprob": avg_logprob, } ) diff --git a/whisperx/schema.py b/whisperx/schema.py index 70b10a7b0..83d9147fd 100644 --- a/whisperx/schema.py +++ b/whisperx/schema.py @@ 
-1,5 +1,10 @@ from typing import TypedDict, Optional, List, Tuple +try: + from typing import NotRequired +except ImportError: + from typing_extensions import NotRequired + class SingleWordSegment(TypedDict): """ @@ -28,6 +33,7 @@ class SingleSegment(TypedDict): start: float end: float text: str + avg_logprob: NotRequired[float] class SegmentData(TypedDict): @@ -49,6 +55,7 @@ class SingleAlignedSegment(TypedDict): start: float end: float text: str + avg_logprob: NotRequired[float] words: List[SingleWordSegment] chars: Optional[List[SingleCharSegment]]