From f2609a64b6f8a2f325f991494c24a408160d0c75 Mon Sep 17 00:00:00 2001 From: Claude-Assistant Date: Tue, 10 Mar 2026 09:01:22 +0100 Subject: [PATCH 1/3] fix: revert #986 wildcard alignment that broke word-level timestamps PR #986 ("support timestamps for numbers") introduced three changes that together broke CTC forced alignment: 1. Unknown chars (numbers, punctuation) were replaced with '*' wildcards mapped to token -1. get_wildcard_emission() scored these using torch.max() over all non-blank emissions, so wildcards greedily matched any speech-like signal in the segment window. 2. get_trellis() was rewritten with a different shape (num_frame, num_tokens) and incompatible initialization, discarding the original SoS-offset design from the PyTorch forced alignment tutorial. 3. backtrack() was replaced with backtrack_beam(), which always starts backtracking from the last frame of the segment window. The original backtrack() used torch.argmax() on the last token column to determine the starting frame. With padded segment boundaries (silence before/after speech), the new implementation spread all tokens across the full window, placing the first word at the start of the silence instead of the speech. This commit restores the original PyTorch tutorial implementation: - Unknown chars are skipped; words with only unknown chars become unalignable and get no timestamps (handled by interpolate_nans) - get_trellis: restored (num_frame+1, num_tokens+1) shape with SoS offset - backtrack: restored torch.argmax-based starting frame - Removed backtrack_beam, get_wildcard_emission, BeamState, Path Verified: v3.3.0 (pre-#986) produced correct timestamps with padded segment boundaries; this fix reproduces that behavior. Fixes #1220 Co-Authored-By: Claude Sonnet 4.6 --- whisperx/alignment.py | 228 +++++++++--------------------------------- 1 file changed, 45 insertions(+), 183 deletions(-) diff --git a/whisperx/alignment.py b/whisperx/alignment.py index 9e5b63a2..74f9d350 100644 --- a/whisperx/alignment.py +++ b/whisperx/alignment.py @@ -2,8 +2,6 @@ Forced Alignment with Whisper C. Max Bain """ -import math - from dataclasses import dataclass from typing import Iterable, Optional, Union, List @@ -86,8 +84,8 @@ def load_align_model(language_code: str, device: str, model_name: Optional[str] model_name = DEFAULT_ALIGN_MODELS_HF[language_code] else: logger.error(f"No default alignment model for language: {language_code}. " - f"Please find a wav2vec2.0 model finetuned on this language at https://huggingface.co/models, " - f"then pass the model name via --align_model [MODEL_NAME]") + f"Please find a wav2vec2.0 model finetuned on this language at https://huggingface.co/models, " + f"then pass the model name via --align_model [MODEL_NAME]") raise ValueError(f"No default align-model for language: {language_code}") if model_name in torchaudio.pipelines.__all__: @@ -178,19 +176,11 @@ def align( elif char_ in model_dictionary.keys(): clean_char.append(char_) clean_cdx.append(cdx) - else: - # add placeholder - clean_char.append('*') - clean_cdx.append(cdx) clean_wdx = [] for wdx, wrd in enumerate(per_word): if any([c in model_dictionary.keys() for c in wrd.lower()]): clean_wdx.append(wdx) - else: - # index for placeholder - clean_wdx.append(wdx) - # Use language-specific Punkt model if available otherwise we fallback to English. punkt_lang = PUNKT_LANGUAGES.get(model_lang, 'english') @@ -244,7 +234,7 @@ def align( continue text_clean = "".join(segment_data[sdx]["clean_char"]) - tokens = [model_dictionary.get(c, -1) for c in text_clean] + tokens = [model_dictionary[c] for c in text_clean] f1 = int(t1 * SAMPLE_RATE) f2 = int(t2 * SAMPLE_RATE) @@ -277,8 +267,7 @@ def align( blank_id = code trellis = get_trellis(emission, tokens, blank_id) - # path = backtrack(trellis, emission, tokens, blank_id) - path = backtrack_beam(trellis, emission, tokens, blank_id, beam_width=2) + path = backtrack(trellis, emission, tokens, blank_id) if path is None: logger.warning(f'Failed to align segment ("{segment["text"]}"): backtrack failed, resorting to original') @@ -405,55 +394,25 @@ def get_trellis(emission, tokens, blank_id=0): num_frame = emission.size(0) num_tokens = len(tokens) - trellis = torch.zeros((num_frame, num_tokens)) - trellis[1:, 0] = torch.cumsum(emission[1:, blank_id], 0) - trellis[0, 1:] = -float("inf") - trellis[-num_tokens + 1:, 0] = float("inf") + # Trellis has extra dimensions for both time axis and tokens. + # The extra dim for tokens represents (start-of-sentence) + # The extra dim for time axis is for simplification of the code. + trellis = torch.empty((num_frame + 1, num_tokens + 1)) + trellis[0, 0] = 0 + trellis[1:, 0] = torch.cumsum(emission[:, 0], 0) + trellis[0, -num_tokens:] = -float("inf") + trellis[-num_tokens:, 0] = float("inf") - for t in range(num_frame - 1): + for t in range(num_frame): trellis[t + 1, 1:] = torch.maximum( # Score for staying at the same token trellis[t, 1:] + emission[t, blank_id], # Score for changing to the next token - # trellis[t, :-1] + emission[t, tokens[1:]], - trellis[t, :-1] + get_wildcard_emission(emission[t], tokens[1:], blank_id), + trellis[t, :-1] + emission[t, tokens], ) return trellis -def get_wildcard_emission(frame_emission, tokens, blank_id): - """Processing token emission scores containing wildcards (vectorized version) - - Args: - frame_emission: Emission probability vector for the current frame - tokens: List of token indices - blank_id: ID of the blank token - - Returns: - tensor: Maximum probability score for each token position - """ - assert 0 <= blank_id < len(frame_emission) - - # Convert tokens to a tensor if they are not already - tokens = torch.tensor(tokens) if not isinstance(tokens, torch.Tensor) else tokens - - # Create a mask to identify wildcard positions - wildcard_mask = (tokens == -1) - - # Get scores for non-wildcard positions - regular_scores = frame_emission[tokens.clamp(min=0).long()] # clamp to avoid -1 index - - # Create a mask and compute the maximum value without modifying frame_emission - max_valid_score = frame_emission.clone() # Create a copy - max_valid_score[blank_id] = float('-inf') # Modify the copy to exclude the blank token - max_valid_score = max_valid_score.max() - - # Use where operation to combine results - result = torch.where(wildcard_mask, max_valid_score, regular_scores) - - return result - - @dataclass class Point: token_index: int @@ -462,138 +421,41 @@ class Point: def backtrack(trellis, emission, tokens, blank_id=0): - t, j = trellis.size(0) - 1, trellis.size(1) - 1 - - path = [Point(j, t, emission[t, blank_id].exp().item())] - while j > 0: - # Should not happen but just in case - assert t > 0 - + # Note: + # j and t are indices for trellis, which has extra dimensions + # for time and tokens at the beginning. + # When referring to time frame index `T` in trellis, + # the corresponding index in emission is `T-1`. + # Similarly, when referring to token index `J` in trellis, + # the corresponding index in transcript is `J-1`. + j = trellis.size(1) - 1 + t_start = torch.argmax(trellis[:, j]).item() + + path = [] + for t in range(t_start, 0, -1): # 1. Figure out if the current position was stay or change - # Frame-wise score of stay vs change - p_stay = emission[t - 1, blank_id] - # p_change = emission[t - 1, tokens[j]] - p_change = get_wildcard_emission(emission[t - 1], [tokens[j]], blank_id)[0] - - # Context-aware score for stay vs change - stayed = trellis[t - 1, j] + p_stay - changed = trellis[t - 1, j - 1] + p_change - - # Update position - t -= 1 + # Note (again): + # `emission[J-1]` is the emission at time frame `J` of trellis dimension. + # Score for token staying the same from time frame J-1 to T. + stayed = trellis[t - 1, j] + emission[t - 1, blank_id] + # Score for token changing from C-1 at T-1 to J at T. + changed = trellis[t - 1, j - 1] + emission[t - 1, tokens[j - 1]] + + # 2. Store the path with frame-wise probability. + prob = emission[t - 1, tokens[j - 1] if changed > stayed else 0].exp().item() + # Return token index and time index in non-trellis coordinate. + path.append(Point(j - 1, t - 1, prob)) + + # 3. Update the token if changed > stayed: j -= 1 - - # Store the path with frame-wise probability. - prob = (p_change if changed > stayed else p_stay).exp().item() - path.append(Point(j, t, prob)) - - # Now j == 0, which means, it reached the SoS. - # Fill up the rest for the sake of visualization - while t > 0: - prob = emission[t - 1, blank_id].exp().item() - path.append(Point(j, t - 1, prob)) - t -= 1 - - return path[::-1] - - - -@dataclass -class Path: - points: List[Point] - score: float - - -@dataclass -class BeamState: - """State in beam search.""" - token_index: int # Current token position - time_index: int # Current time step - score: float # Cumulative score - path: List[Point] # Path history - - -def backtrack_beam(trellis, emission, tokens, blank_id=0, beam_width=5): - """Standard CTC beam search backtracking implementation. - - Args: - trellis (torch.Tensor): The trellis (or lattice) of shape (T, N), where T is the number of time steps - and N is the number of tokens (including the blank token). - emission (torch.Tensor): The emission probabilities of shape (T, N). - tokens (List[int]): List of token indices (excluding the blank token). - blank_id (int, optional): The ID of the blank token. Defaults to 0. - beam_width (int, optional): The number of top paths to keep during beam search. Defaults to 5. - - Returns: - List[Point]: the best path - """ - T, J = trellis.size(0) - 1, trellis.size(1) - 1 - - init_state = BeamState( - token_index=J, - time_index=T, - score=trellis[T, J], - path=[Point(J, T, emission[T, blank_id].exp().item())] - ) - - beams = [init_state] - - while beams and beams[0].token_index > 0: - next_beams = [] - - for beam in beams: - t, j = beam.time_index, beam.token_index - - if t <= 0: - continue - - p_stay = emission[t - 1, blank_id] - p_change = get_wildcard_emission(emission[t - 1], [tokens[j]], blank_id)[0] - - stay_score = trellis[t - 1, j] - change_score = trellis[t - 1, j - 1] if j > 0 else float('-inf') - - # Stay - if not math.isinf(stay_score): - new_path = beam.path.copy() - new_path.append(Point(j, t - 1, p_stay.exp().item())) - next_beams.append(BeamState( - token_index=j, - time_index=t - 1, - score=stay_score, - path=new_path - )) - - # Change - if j > 0 and not math.isinf(change_score): - new_path = beam.path.copy() - new_path.append(Point(j - 1, t - 1, p_change.exp().item())) - next_beams.append(BeamState( - token_index=j - 1, - time_index=t - 1, - score=change_score, - path=new_path - )) - - # sort by score - beams = sorted(next_beams, key=lambda x: x.score, reverse=True)[:beam_width] - - if not beams: - break - - if not beams: + if j == 0: + break + else: + # failed return None - best_beam = beams[0] - t = best_beam.time_index - j = best_beam.token_index - while t > 0: - prob = emission[t - 1, blank_id].exp().item() - best_beam.path.append(Point(j, t - 1, prob)) - t -= 1 - - return best_beam.path[::-1] + return path[::-1] # Merge the labels @@ -643,4 +505,4 @@ def merge_words(segments, separator="|"): i2 = i1 else: i2 += 1 - return words + return words \ No newline at end of file From 636f2988f82bcb3ad2450c4f85865a8d435a0b7c Mon Sep 17 00:00:00 2001 From: Claude-Assistant Date: Tue, 10 Mar 2026 09:04:45 +0100 Subject: [PATCH 2/3] fix: use blank_id parameter instead of hardcoded 0 in trellis and backtrack The original code accepted blank_id as a parameter but used hardcoded 0 in two places, breaking alignment for HuggingFace models where the blank token is [pad] (not index 0). Co-Authored-By: Claude Sonnet 4.6 --- whisperx/alignment.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/whisperx/alignment.py b/whisperx/alignment.py index 74f9d350..f3095ba8 100644 --- a/whisperx/alignment.py +++ b/whisperx/alignment.py @@ -399,7 +399,7 @@ def get_trellis(emission, tokens, blank_id=0): # The extra dim for time axis is for simplification of the code. trellis = torch.empty((num_frame + 1, num_tokens + 1)) trellis[0, 0] = 0 - trellis[1:, 0] = torch.cumsum(emission[:, 0], 0) + trellis[1:, 0] = torch.cumsum(emission[:, blank_id], 0) trellis[0, -num_tokens:] = -float("inf") trellis[-num_tokens:, 0] = float("inf") @@ -442,7 +442,7 @@ def backtrack(trellis, emission, tokens, blank_id=0): changed = trellis[t - 1, j - 1] + emission[t - 1, tokens[j - 1]] # 2. Store the path with frame-wise probability. - prob = emission[t - 1, tokens[j - 1] if changed > stayed else 0].exp().item() + prob = emission[t - 1, tokens[j - 1] if changed > stayed else blank_id].exp().item() # Return token index and time index in non-trellis coordinate. path.append(Point(j - 1, t - 1, prob)) From 6d3edb1c0b33d9b1bef3c7ab889b4c56162b4ff6 Mon Sep 17 00:00:00 2001 From: Barabazs <31799121+Barabazs@users.noreply.github.com> Date: Tue, 10 Mar 2026 15:30:43 +0100 Subject: [PATCH 3/3] chore: bump version --- pyproject.toml | 2 +- uv.lock | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 15d392da..94626d60 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -2,7 +2,7 @@ urls = { repository = "https://github.com/m-bain/whisperx" } authors = [{ name = "Max Bain" }] name = "whisperx" -version = "3.8.1" +version = "3.8.2" description = "Time-Accurate Automatic Speech Recognition using Whisper." readme = "README.md" requires-python = ">=3.10, <3.14" diff --git a/uv.lock b/uv.lock index 421c5044..22d61222 100644 --- a/uv.lock +++ b/uv.lock @@ -3026,7 +3026,7 @@ wheels = [ [[package]] name = "whisperx" -version = "3.8.1" +version = "3.8.2" source = { editable = "." } dependencies = [ { name = "ctranslate2" },