diff --git a/whisperx/__main__.py b/whisperx/__main__.py
index dbb92fc4..cd281125 100644
--- a/whisperx/__main__.py
+++ b/whisperx/__main__.py
@@ -16,10 +16,10 @@ def cli():
     parser.add_argument("--model", default="small", help="name of the Whisper model to use")
     parser.add_argument("--model_cache_only", type=str2bool, default=False, help="If True, will not attempt to download models, instead using cached models from --model_dir")
     parser.add_argument("--model_dir", type=str, default=None, help="the path to save model files; uses ~/.cache/whisper by default")
-    parser.add_argument("--device", default="cuda" if torch.cuda.is_available() else "cpu", help="device to use for PyTorch inference")
+    parser.add_argument("--device", default="cuda" if torch.cuda.is_available() else "cpu", help="device type to use for PyTorch inference (e.g. cpu, cuda)")
     parser.add_argument("--device_index", default=0, type=int, help="device index to use for FasterWhisper inference")
     parser.add_argument("--batch_size", default=8, type=int, help="the preferred batch size for inference")
-    parser.add_argument("--compute_type", default="float16", type=str, choices=["float16", "float32", "int8"], help="compute type for computation")
+    parser.add_argument("--compute_type", default="default", type=str, choices=["default", "float16", "float32", "int8"], help="compute type for computation; 'default' uses float16 on GPU, float32 on CPU")
     parser.add_argument("--output_dir", "-o", type=str, default=".", help="directory to save the outputs")
     parser.add_argument("--output_format", "-f", type=str, default="all", choices=["all", "srt", "vtt", "txt", "tsv", "json", "aud"], help="format of the output file; if not specified, all available formats will be produced")
 
diff --git a/whisperx/asr.py b/whisperx/asr.py
index ea29e56f..a5c80711 100644
--- a/whisperx/asr.py
+++ b/whisperx/asr.py
@@ -313,7 +313,7 @@ def load_model(
     whisper_arch: str,
     device: str,
     device_index=0,
-    compute_type="float16",
+    compute_type="default",
     asr_options: Optional[dict] = None,
     language: Optional[str] = None,
     vad_model: Optional[Vad]= None,
@@ -331,6 +331,7 @@ def load_model(
         whisper_arch - The name of the Whisper model to load.
         device - The device to load the model on.
         compute_type - The compute type to use for the model.
+            Use "default" to automatically select based on device (float16 for GPU, float32 for CPU).
         vad_model - The vad model to manually assign.
         vad_method - The vad method to use. vad_model has a higher priority if it is not None.
         options - A dictionary of options to use for the model.
@@ -343,6 +344,10 @@ def load_model(
         A Whisper pipeline.
     """
 
+    if compute_type == "default":
+        compute_type = "float16" if device == "cuda" else "float32"
+        logger.info(f"Compute type not specified, defaulting to {compute_type} for device {device}")
+
     if whisper_arch.endswith(".en"):
         language = "en"
 