Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
urls = { repository = "https://github.com/m-bain/whisperx" }
authors = [{ name = "Max Bain" }]
name = "whisperx"
version = "3.8.1"
version = "3.8.2"
description = "Time-Accurate Automatic Speech Recognition using Whisper."
readme = "README.md"
requires-python = ">=3.10, <3.14"
Expand Down
2 changes: 1 addition & 1 deletion uv.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

228 changes: 45 additions & 183 deletions whisperx/alignment.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,6 @@
Forced Alignment with Whisper
C. Max Bain
"""
import math

from dataclasses import dataclass
from typing import Iterable, Optional, Union, List

Expand Down Expand Up @@ -86,8 +84,8 @@ def load_align_model(language_code: str, device: str, model_name: Optional[str]
model_name = DEFAULT_ALIGN_MODELS_HF[language_code]
else:
logger.error(f"No default alignment model for language: {language_code}. "
f"Please find a wav2vec2.0 model finetuned on this language at https://huggingface.co/models, "
f"then pass the model name via --align_model [MODEL_NAME]")
f"Please find a wav2vec2.0 model finetuned on this language at https://huggingface.co/models, "
f"then pass the model name via --align_model [MODEL_NAME]")
raise ValueError(f"No default align-model for language: {language_code}")

if model_name in torchaudio.pipelines.__all__:
Expand Down Expand Up @@ -178,19 +176,11 @@ def align(
elif char_ in model_dictionary.keys():
clean_char.append(char_)
clean_cdx.append(cdx)
else:
# add placeholder
clean_char.append('*')
clean_cdx.append(cdx)

clean_wdx = []
for wdx, wrd in enumerate(per_word):
if any([c in model_dictionary.keys() for c in wrd.lower()]):
clean_wdx.append(wdx)
else:
# index for placeholder
clean_wdx.append(wdx)


# Use language-specific Punkt model if available otherwise we fallback to English.
punkt_lang = PUNKT_LANGUAGES.get(model_lang, 'english')
Expand Down Expand Up @@ -244,7 +234,7 @@ def align(
continue

text_clean = "".join(segment_data[sdx]["clean_char"])
tokens = [model_dictionary.get(c, -1) for c in text_clean]
tokens = [model_dictionary[c] for c in text_clean]

f1 = int(t1 * SAMPLE_RATE)
f2 = int(t2 * SAMPLE_RATE)
Expand Down Expand Up @@ -277,8 +267,7 @@ def align(
blank_id = code

trellis = get_trellis(emission, tokens, blank_id)
# path = backtrack(trellis, emission, tokens, blank_id)
path = backtrack_beam(trellis, emission, tokens, blank_id, beam_width=2)
path = backtrack(trellis, emission, tokens, blank_id)

if path is None:
logger.warning(f'Failed to align segment ("{segment["text"]}"): backtrack failed, resorting to original')
Expand Down Expand Up @@ -405,55 +394,25 @@ def get_trellis(emission, tokens, blank_id=0):
num_frame = emission.size(0)
num_tokens = len(tokens)

trellis = torch.zeros((num_frame, num_tokens))
trellis[1:, 0] = torch.cumsum(emission[1:, blank_id], 0)
trellis[0, 1:] = -float("inf")
trellis[-num_tokens + 1:, 0] = float("inf")
# Trellis has extra dimensions for both time axis and tokens.
# The extra dim for tokens represents <SoS> (start-of-sentence)
# The extra dim for time axis is for simplification of the code.
trellis = torch.empty((num_frame + 1, num_tokens + 1))
trellis[0, 0] = 0
trellis[1:, 0] = torch.cumsum(emission[:, blank_id], 0)
trellis[0, -num_tokens:] = -float("inf")
trellis[-num_tokens:, 0] = float("inf")

for t in range(num_frame - 1):
for t in range(num_frame):
trellis[t + 1, 1:] = torch.maximum(
# Score for staying at the same token
trellis[t, 1:] + emission[t, blank_id],
# Score for changing to the next token
# trellis[t, :-1] + emission[t, tokens[1:]],
trellis[t, :-1] + get_wildcard_emission(emission[t], tokens[1:], blank_id),
trellis[t, :-1] + emission[t, tokens],
)
return trellis


def get_wildcard_emission(frame_emission, tokens, blank_id):
    """Return per-token emission scores for one frame, resolving wildcards.

    A token index of ``-1`` is a wildcard: it receives the highest score in
    the frame among all classes except the blank. Every other token receives
    its own entry of ``frame_emission``.

    Args:
        frame_emission: Emission score vector for the current frame.
        tokens: Token indices as a list or tensor; ``-1`` marks a wildcard.
        blank_id: Index of the blank token within ``frame_emission``.

    Returns:
        tensor: One score per entry of ``tokens``.
    """
    assert 0 <= blank_id < len(frame_emission)

    # Build (or move) the indices on the emission's device so the mask used
    # by torch.where below lives on the same device as the scores.
    tokens = torch.as_tensor(tokens, device=frame_emission.device)

    # Positions that may match any non-blank class.
    wildcard_mask = tokens == -1

    # Scores for the regular positions; clamp so the -1 placeholders do not
    # index from the end of the vector (their value is discarded anyway).
    regular_scores = frame_emission[tokens.clamp(min=0).long()]

    # Best non-blank score, computed on a copy so the caller's tensor is
    # left untouched.
    non_blank = frame_emission.clone()
    non_blank[blank_id] = float('-inf')
    max_valid_score = non_blank.max()

    return torch.where(wildcard_mask, max_valid_score, regular_scores)


@dataclass
class Point:
token_index: int
Expand All @@ -462,138 +421,41 @@ class Point:


def backtrack(trellis, emission, tokens, blank_id=0):
t, j = trellis.size(0) - 1, trellis.size(1) - 1

path = [Point(j, t, emission[t, blank_id].exp().item())]
while j > 0:
# Should not happen but just in case
assert t > 0

# Note:
# j and t are indices for trellis, which has extra dimensions
# for time and tokens at the beginning.
# When referring to time frame index `T` in trellis,
# the corresponding index in emission is `T-1`.
# Similarly, when referring to token index `J` in trellis,
# the corresponding index in transcript is `J-1`.
j = trellis.size(1) - 1
t_start = torch.argmax(trellis[:, j]).item()

path = []
for t in range(t_start, 0, -1):
# 1. Figure out if the current position was stay or change
# Frame-wise score of stay vs change
p_stay = emission[t - 1, blank_id]
# p_change = emission[t - 1, tokens[j]]
p_change = get_wildcard_emission(emission[t - 1], [tokens[j]], blank_id)[0]

# Context-aware score for stay vs change
stayed = trellis[t - 1, j] + p_stay
changed = trellis[t - 1, j - 1] + p_change

# Update position
t -= 1
# Note (again):
# `emission[J-1]` is the emission at time frame `J` of trellis dimension.
# Score for token staying the same from time frame J-1 to T.
stayed = trellis[t - 1, j] + emission[t - 1, blank_id]
# Score for token changing from C-1 at T-1 to J at T.
changed = trellis[t - 1, j - 1] + emission[t - 1, tokens[j - 1]]

# 2. Store the path with frame-wise probability.
prob = emission[t - 1, tokens[j - 1] if changed > stayed else blank_id].exp().item()
# Return token index and time index in non-trellis coordinate.
path.append(Point(j - 1, t - 1, prob))

# 3. Update the token
if changed > stayed:
j -= 1

# Store the path with frame-wise probability.
prob = (p_change if changed > stayed else p_stay).exp().item()
path.append(Point(j, t, prob))

# Now j == 0, which means, it reached the SoS.
# Fill up the rest for the sake of visualization
while t > 0:
prob = emission[t - 1, blank_id].exp().item()
path.append(Point(j, t - 1, prob))
t -= 1

return path[::-1]



@dataclass
class Path:
    """A list of ``Point`` entries paired with a single score."""

    # Points that make up the path.
    points: List[Point]
    # Score associated with the path.
    score: float


@dataclass
class BeamState:
    """One hypothesis tracked during beam search.

    Attributes:
        token_index: Current token position.
        time_index: Current time step.
        score: Cumulative score.
        path: Path history.
    """

    token_index: int
    time_index: int
    score: float
    path: List[Point]


def backtrack_beam(trellis, emission, tokens, blank_id=0, beam_width=5):
    """Trace an alignment path backwards through the trellis with beam search.

    Starting from the last trellis cell, every surviving hypothesis is moved
    one frame back per iteration, either staying on its current token or
    changing to the previous one. Candidates are ranked by the trellis score
    of the cell they move to, and only the best ``beam_width`` are kept.

    Args:
        trellis (torch.Tensor): Trellis (lattice) indexed as ``[time, token]``.
        emission (torch.Tensor): Emission scores indexed as ``[time, class]``.
            ``.exp()`` is applied to produce each point's value, so these are
            presumably log-probabilities (confirm against the caller).
        tokens (List[int]): Token indices of the transcript.
        blank_id (int, optional): The ID of the blank token. Defaults to 0.
        beam_width (int, optional): Number of hypotheses kept after each
            step. Defaults to 5.

    Returns:
        Optional[List[Point]]: The best path in increasing time order, or
        ``None`` when no hypothesis survives.
    """
    T, J = trellis.size(0) - 1, trellis.size(1) - 1

    beams = [
        BeamState(
            token_index=J,
            time_index=T,
            score=trellis[T, J],
            path=[Point(J, T, emission[T, blank_id].exp().item())],
        )
    ]

    # The loop ends once the best-ranked hypothesis reaches token 0, or when
    # every hypothesis has been dropped (the list is then empty).
    while beams and beams[0].token_index > 0:
        next_beams = []

        for beam in beams:
            t, j = beam.time_index, beam.token_index

            # No frames left to step back into: drop this hypothesis.
            if t <= 0:
                continue

            p_stay = emission[t - 1, blank_id]
            p_change = get_wildcard_emission(emission[t - 1], [tokens[j]], blank_id)[0]

            stay_score = trellis[t - 1, j]
            change_score = trellis[t - 1, j - 1] if j > 0 else float('-inf')

            # Stay on the same token. Infinite trellis cells are skipped.
            if not math.isinf(stay_score):
                next_beams.append(BeamState(
                    token_index=j,
                    time_index=t - 1,
                    score=stay_score,
                    path=beam.path + [Point(j, t - 1, p_stay.exp().item())],
                ))

            # Change to the previous token.
            if j > 0 and not math.isinf(change_score):
                next_beams.append(BeamState(
                    token_index=j - 1,
                    time_index=t - 1,
                    score=change_score,
                    path=beam.path + [Point(j - 1, t - 1, p_change.exp().item())],
                ))

        # Keep only the highest-scoring candidates.
        beams = sorted(next_beams, key=lambda x: x.score, reverse=True)[:beam_width]

    if not beams:
        # failed
        return None

    # Pad the remaining frames down to time 0 with blank emissions.
    best_beam = beams[0]
    t = best_beam.time_index
    j = best_beam.token_index
    while t > 0:
        prob = emission[t - 1, blank_id].exp().item()
        best_beam.path.append(Point(j, t - 1, prob))
        t -= 1

    # The path was collected from the end backwards; return it in time order.
    return best_beam.path[::-1]
return path[::-1]


# Merge the labels
Expand Down Expand Up @@ -643,4 +505,4 @@ def merge_words(segments, separator="|"):
i2 = i1
else:
i2 += 1
return words
return words