diff --git a/src/asr/faster_whisper_asr.py b/src/asr/faster_whisper_asr.py index 96c475c..cdcc058 100644 --- a/src/asr/faster_whisper_asr.py +++ b/src/asr/faster_whisper_asr.py @@ -112,10 +112,15 @@ class FasterWhisperASR(ASRInterface): def __init__(self, **kwargs): - model_size = kwargs.get("model_size", "large-v3") - # Run on GPU with FP16 + model_size_or_path = kwargs.get("model_size_or_path", kwargs.get("model_size", "large-v3")) + device = kwargs.get("device", "cuda") + compute_type = kwargs.get("compute_type", "float16") + + # Run on GPU with FP16 by default, or use provided parameters self.asr_pipeline = WhisperModel( - model_size, device="cuda", compute_type="float16" + model_size_or_path, + device=device, + compute_type=compute_type ) async def transcribe(self, client): diff --git a/src/vad/vad_factory.py b/src/vad/vad_factory.py index e20d3f9..cd81a6b 100644 --- a/src/vad/vad_factory.py +++ b/src/vad/vad_factory.py @@ -1,6 +1,3 @@ -from .pyannote_vad import PyannoteVAD - - class VADFactory: """ Factory for creating instances of VAD systems. @@ -10,15 +7,8 @@ class VADFactory: def create_vad_pipeline(type, **kwargs): """ Creates a VAD pipeline based on the specified type. - - Args: - type (str): The type of VAD pipeline to create (e.g., 'pyannote'). - kwargs: Additional arguments for the VAD pipeline creation. - - Returns: - VADInterface: An instance of a class that implements VADInterface. """ - if type == "pyannote": - return PyannoteVAD(**kwargs) + if type == "none" or type is None: + return None else: - raise ValueError(f"Unknown VAD pipeline type: {type}") + raise ValueError(f"VAD type '{type}' not available. Use 'none' to disable VAD.") \ No newline at end of file