Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 4 additions & 7 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -113,10 +113,7 @@ You may also need to install ffmpeg, rust etc. Follow openAI instructions here h

### Speaker Diarization

To **enable Speaker Diarization**, include your Hugging Face access token (read) that you can generate from [Here](https://huggingface.co/settings/tokens) after the `--hf_token` argument and accept the user agreement for the following models: [Segmentation](https://huggingface.co/pyannote/segmentation-3.0) and [Speaker-Diarization-3.1](https://huggingface.co/pyannote/speaker-diarization-3.1) (if you choose to use Speaker-Diarization 2.x, follow requirements [here](https://huggingface.co/pyannote/speaker-diarization) instead.)

> **Note**<br>
> As of Oct 11, 2023, there is a known issue regarding slow performance with pyannote/Speaker-Diarization-3.0 in whisperX. It is due to dependency conflicts between faster-whisper and pyannote-audio 3.0.0. Please see [this issue](https://github.com/m-bain/whisperX/issues/499) for more details and potential workarounds.
To **enable Speaker Diarization**, pass your Hugging Face access token (read) — which you can generate [here](https://huggingface.co/settings/tokens) — via the `--hf_token` argument, and accept the user agreement for the [speaker-diarization-community-1](https://huggingface.co/pyannote/speaker-diarization-community-1) model.

<h2 align="left" id="example">Usage 💬 (command line)</h2>

Expand Down Expand Up @@ -197,7 +194,7 @@ print(result["segments"]) # after alignment
# import gc; import torch; gc.collect(); torch.cuda.empty_cache(); del model_a

# 3. Assign speaker labels
diarize_model = DiarizationPipeline(use_auth_token=YOUR_HF_TOKEN, device=device)
diarize_model = DiarizationPipeline(token=YOUR_HF_TOKEN, device=device)

# add min/max number of speakers if known
diarize_segments = diarize_model(audio)
Expand Down Expand Up @@ -291,8 +288,8 @@ And uses the wonderful pyannote VAD / Diarization https://github.com/pyannote/py

Valuable VAD & Diarization Models from:

- [pyannote audio][https://github.com/pyannote/pyannote-audio]
- [silero vad][https://github.com/snakers4/silero-vad]
- [pyannote-audio](https://github.com/pyannote/pyannote-audio) — Speaker diarization powered by the [speaker-diarization-community-1](https://huggingface.co/pyannote/speaker-diarization-community-1) model, licensed under [CC-BY-4.0](https://creativecommons.org/licenses/by/4.0/) by [pyannoteAI](https://www.pyannote.ai)
- [silero-vad](https://github.com/snakers4/silero-vad)

Great backend from [faster-whisper](https://github.com/guillaumekln/faster-whisper) and [CTranslate2](https://github.com/OpenNMT/CTranslate2)

Expand Down
11 changes: 9 additions & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
urls = { repository = "https://github.com/m-bain/whisperx" }
authors = [{ name = "Max Bain" }]
name = "whisperx"
version = "3.7.7"
version = "3.8.0"
description = "Time-Accurate Automatic Speech Recognition using Whisper."
readme = "README.md"
requires-python = ">=3.10, <3.14"
Expand All @@ -13,8 +13,9 @@ dependencies = [
"faster-whisper>=1.1.1",
"nltk>=3.9.1",
"numpy>=2.1.0",
"omegaconf>=2.3.0",
"pandas>=2.2.3",
"pyannote-audio>=3.3.2,<4.0.0",
"pyannote-audio>=4.0.0",
"huggingface-hub<1.0.0",
"torch~=2.8.0",
"torchaudio~=2.8.0",
Expand All @@ -36,6 +37,12 @@ include-package-data = true
where = ["."]
include = ["whisperx*"]

# torchcodec (transitive dep of pyannote-audio >=4) has no wheels for Linux aarch64
[tool.uv]
override-dependencies = [
"torchcodec>=0.6.0; (sys_platform == 'linux' and platform_machine == 'x86_64') or sys_platform == 'darwin' or sys_platform == 'win32'",
]

[tool.uv.sources]
torch = [
{ index = "pytorch-cpu", marker = "sys_platform == 'darwin'" },
Expand Down
575 changes: 252 additions & 323 deletions uv.lock

Large diffs are not rendered by default.

2 changes: 1 addition & 1 deletion whisperx/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ def cli():
parser.add_argument("--diarize", action="store_true", help="Apply diarization to assign speaker labels to each segment/word")
parser.add_argument("--min_speakers", default=None, type=int, help="Minimum number of speakers to in audio file")
parser.add_argument("--max_speakers", default=None, type=int, help="Maximum number of speakers to in audio file")
parser.add_argument("--diarize_model", default="pyannote/speaker-diarization-3.1", type=str, help="Name of the speaker diarization model to use")
parser.add_argument("--diarize_model", default="pyannote/speaker-diarization-community-1", type=str, help="Name of the speaker diarization model to use")
parser.add_argument("--speaker_embeddings", action="store_true", help="Include speaker embeddings in JSON output (only works with --diarize)")

parser.add_argument("--temperature", type=float, default=0, help="temperature to use for sampling")
Expand Down
2 changes: 1 addition & 1 deletion whisperx/asr.py
Original file line number Diff line number Diff line change
Expand Up @@ -407,7 +407,7 @@ def load_model(
device_vad = f'cuda:{device_index}'
else:
device_vad = device
vad_model = Pyannote(torch.device(device_vad), use_auth_token=None, **default_vad_options)
vad_model = Pyannote(torch.device(device_vad), token=None, **default_vad_options)
else:
raise ValueError(f"Invalid vad_method: {vad_method}")

Expand Down
31 changes: 12 additions & 19 deletions whisperx/diarize.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,14 +92,14 @@ class DiarizationPipeline:
def __init__(
self,
model_name=None,
use_auth_token=None,
token=None,
device: Optional[Union[str, torch.device]] = "cpu",
):
if isinstance(device, str):
device = torch.device(device)
model_config = model_name or "pyannote/speaker-diarization-3.1"
model_config = model_name or "pyannote/speaker-diarization-community-1"
logger.info(f"Loading diarization model: {model_config}")
self.model = Pipeline.from_pretrained(model_config, use_auth_token=use_auth_token).to(device)
self.model = Pipeline.from_pretrained(model_config, token=token).to(device)

def __call__(
self,
Expand Down Expand Up @@ -132,22 +132,15 @@ def __call__(
'sample_rate': SAMPLE_RATE
}

if return_embeddings:
diarization, embeddings = self.model(
audio_data,
num_speakers=num_speakers,
min_speakers=min_speakers,
max_speakers=max_speakers,
return_embeddings=True,
)
else:
diarization = self.model(
audio_data,
num_speakers=num_speakers,
min_speakers=min_speakers,
max_speakers=max_speakers,
)
embeddings = None
output = self.model(
audio_data,
num_speakers=num_speakers,
min_speakers=min_speakers,
max_speakers=max_speakers,
)

diarization = output.speaker_diarization
embeddings = output.speaker_embeddings if return_embeddings else None

diarize_df = pd.DataFrame(diarization.itertracks(yield_label=True), columns=['segment', 'label', 'speaker'])
diarize_df['start'] = diarize_df['segment'].apply(lambda x: x.start)
Expand Down
2 changes: 1 addition & 1 deletion whisperx/transcribe.py
Original file line number Diff line number Diff line change
Expand Up @@ -214,7 +214,7 @@ def transcribe_task(args: dict, parser: argparse.ArgumentParser):
logger.info("Performing diarization...")
logger.info(f"Using model: {diarize_model_name}")
results = []
diarize_model = DiarizationPipeline(model_name=diarize_model_name, use_auth_token=hf_token, device=device)
diarize_model = DiarizationPipeline(model_name=diarize_model_name, token=hf_token, device=device)
for result, input_audio_path in tmp_results:
diarize_result = diarize_model(
input_audio_path,
Expand Down
12 changes: 6 additions & 6 deletions whisperx/vads/pyannote.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
logger = get_logger(__name__)


def load_vad_model(device, vad_onset=0.500, vad_offset=0.363, use_auth_token=None, model_fp=None):
def load_vad_model(device, vad_onset=0.500, vad_offset=0.363, token=None, model_fp=None):
model_dir = torch.hub._get_torch_home()

main_dir = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
Expand All @@ -40,7 +40,7 @@ def load_vad_model(device, vad_onset=0.500, vad_offset=0.363, use_auth_token=Non

model_bytes = open(model_fp, "rb").read()

vad_model = Model.from_pretrained(model_fp, use_auth_token=use_auth_token)
vad_model = Model.from_pretrained(model_fp, token=token)
hyperparameters = {"onset": vad_onset,
"offset": vad_offset,
"min_duration_on": 0.1,
Expand Down Expand Up @@ -192,11 +192,11 @@ def __init__(
self,
segmentation: PipelineModel = "pyannote/segmentation",
fscore: bool = False,
use_auth_token: Union[Text, None] = None,
token: Union[Text, None] = None,
**inference_kwargs,
):

super().__init__(segmentation=segmentation, fscore=fscore, use_auth_token=use_auth_token, **inference_kwargs)
super().__init__(segmentation=segmentation, fscore=fscore, token=token, **inference_kwargs)

def apply(self, file: AudioFile, hook: Optional[Callable] = None) -> Annotation:
"""Apply voice activity detection
Expand Down Expand Up @@ -234,10 +234,10 @@ def apply(self, file: AudioFile, hook: Optional[Callable] = None) -> Annotation:

class Pyannote(Vad):

def __init__(self, device, use_auth_token=None, model_fp=None, **kwargs):
def __init__(self, device, token=None, model_fp=None, **kwargs):
logger.info("Performing voice activity detection using Pyannote...")
super().__init__(kwargs['vad_onset'])
self.vad_pipeline = load_vad_model(device, use_auth_token=use_auth_token, model_fp=model_fp)
self.vad_pipeline = load_vad_model(device, token=token, model_fp=model_fp)

def __call__(self, audio: AudioFile, **kwargs):
return self.vad_pipeline(audio)
Expand Down