Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions whisperx/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,10 +16,10 @@ def cli():
parser.add_argument("--model", default="small", help="name of the Whisper model to use")
parser.add_argument("--model_cache_only", type=str2bool, default=False, help="If True, will not attempt to download models, instead using cached models from --model_dir")
parser.add_argument("--model_dir", type=str, default=None, help="the path to save model files; uses ~/.cache/whisper by default")
- parser.add_argument("--device", default="cuda" if torch.cuda.is_available() else "cpu", help="device to use for PyTorch inference")
+ parser.add_argument("--device", default="cuda" if torch.cuda.is_available() else "cpu", help="device type to use for PyTorch inference (e.g. cpu, cuda)")
parser.add_argument("--device_index", default=0, type=int, help="device index to use for FasterWhisper inference")
parser.add_argument("--batch_size", default=8, type=int, help="the preferred batch size for inference")
- parser.add_argument("--compute_type", default="float16", type=str, choices=["float16", "float32", "int8"], help="compute type for computation")
+ parser.add_argument("--compute_type", default="default", type=str, choices=["default", "float16", "float32", "int8"], help="compute type for computation; 'default' uses float16 on GPU, float32 on CPU")

parser.add_argument("--output_dir", "-o", type=str, default=".", help="directory to save the outputs")
parser.add_argument("--output_format", "-f", type=str, default="all", choices=["all", "srt", "vtt", "txt", "tsv", "json", "aud"], help="format of the output file; if not specified, all available formats will be produced")
Expand Down
7 changes: 6 additions & 1 deletion whisperx/asr.py
Original file line number Diff line number Diff line change
Expand Up @@ -313,7 +313,7 @@ def load_model(
whisper_arch: str,
device: str,
device_index=0,
- compute_type="float16",
+ compute_type="default",
asr_options: Optional[dict] = None,
language: Optional[str] = None,
vad_model: Optional[Vad]= None,
Expand All @@ -331,6 +331,7 @@ def load_model(
whisper_arch - The name of the Whisper model to load.
device - The device to load the model on.
compute_type - The compute type to use for the model.
+ Use "default" to automatically select based on device (float16 for GPU, float32 for CPU).
vad_model - The vad model to manually assign.
vad_method - The vad method to use. vad_model has a higher priority if it is not None.
options - A dictionary of options to use for the model.
Expand All @@ -343,6 +344,10 @@ def load_model(
A Whisper pipeline.
"""

+ if compute_type == "default":
+ compute_type = "float16" if device == "cuda" else "float32"
+ logger.info(f"Compute type not specified, defaulting to {compute_type} for device {device}")
+
if whisper_arch.endswith(".en"):
language = "en"

Expand Down