diff --git a/whisperx/alignment.py b/whisperx/alignment.py index ce92d7a4..9e5b63a2 100644 --- a/whisperx/alignment.py +++ b/whisperx/alignment.py @@ -216,6 +216,7 @@ def align( t1 = segment["start"] t2 = segment["end"] text = segment["text"] + avg_logprob = segment.get("avg_logprob") aligned_seg: SingleAlignedSegment = { "start": t1, @@ -225,6 +226,9 @@ def align( "chars": None, } + if avg_logprob is not None: + aligned_seg["avg_logprob"] = avg_logprob + if return_char_alignments: aligned_seg["chars"] = [] @@ -353,12 +357,15 @@ def align( sentence_words.append(word_segment) - aligned_subsegments.append({ + subsegment = { "text": sentence_text, "start": sentence_start, "end": sentence_end, "words": sentence_words, - }) + } + if avg_logprob is not None: + subsegment["avg_logprob"] = avg_logprob + aligned_subsegments.append(subsegment) if return_char_alignments: curr_chars = curr_chars[["char", "start", "end", "score"]] @@ -376,6 +383,8 @@ def align( agg_dict["text"] = "".join if return_char_alignments: agg_dict["chars"] = "sum" + if avg_logprob is not None: + agg_dict["avg_logprob"] = "first" aligned_subsegments= aligned_subsegments.groupby(["start", "end"], as_index=False).agg(agg_dict) aligned_subsegments = aligned_subsegments.to_dict('records') aligned_segments += aligned_subsegments diff --git a/whisperx/asr.py b/whisperx/asr.py index 7540770f..ea29e56f 100644 --- a/whisperx/asr.py +++ b/whisperx/asr.py @@ -70,10 +70,17 @@ def generate_segment_batched( suppress_tokens=options.suppress_tokens, no_repeat_ngram_size=options.no_repeat_ngram_size, repetition_penalty=options.repetition_penalty, + return_scores=True, ) tokens_batch = [x.sequences_ids[0] for x in result] + avg_logprobs = [] + for res in result: + seq_len = len(res.sequences_ids[0]) + cum_logprob = res.scores[0] * (seq_len ** options.length_penalty) + avg_logprobs.append(cum_logprob / (seq_len + 1)) + def decode_batch(tokens: List[List[int]]) -> List[str]: res = [] for tk in tokens: @@ -83,7 +90,7 @@ def decode_batch(tokens: List[List[int]]) -> List[str]: text = decode_batch(tokens_batch) - return text + return {'text': text, 'avg_logprob': avg_logprobs} def encode(self, features: np.ndarray) -> ctranslate2.StorageView: # When the model is running on multiple GPUs, the encoder output should be moved @@ -161,7 +168,7 @@ def preprocess(self, audio): def _forward(self, model_inputs): outputs = self.model.generate_segment_batched(model_inputs['inputs'], self.tokenizer, self.options) - return {'text': outputs} + return outputs def postprocess(self, model_outputs): return model_outputs @@ -262,15 +269,18 @@ def data(audio, segments): percent_complete = base_progress / 2 if combined_progress else base_progress print(f"Progress: {percent_complete:.2f}%...") text = out['text'] + avg_logprob = out['avg_logprob'] if batch_size in [0, 1, None]: text = text[0] + avg_logprob = avg_logprob[0] if verbose: print(f"Transcript: [{round(vad_segments[idx]['start'], 3)} --> {round(vad_segments[idx]['end'], 3)}] {text}") segments.append( { "text": text, "start": round(vad_segments[idx]['start'], 3), - "end": round(vad_segments[idx]['end'], 3) + "end": round(vad_segments[idx]['end'], 3), + "avg_logprob": avg_logprob, } ) diff --git a/whisperx/schema.py b/whisperx/schema.py index 70b10a7b..83d9147f 100644 --- a/whisperx/schema.py +++ b/whisperx/schema.py @@ -1,5 +1,10 @@ from typing import TypedDict, Optional, List, Tuple +try: + from typing import NotRequired +except ImportError: + from typing_extensions import NotRequired + class SingleWordSegment(TypedDict): """ @@ -28,6 +33,7 @@ class SingleSegment(TypedDict): start: float end: float text: str + avg_logprob: NotRequired[float] class SegmentData(TypedDict): @@ -49,6 +55,7 @@ class SingleAlignedSegment(TypedDict): start: float end: float text: str + avg_logprob: NotRequired[float] words: List[SingleWordSegment] chars: Optional[List[SingleCharSegment]]