From df54fc321a28a63c8727206566a0a15d98d5d88d Mon Sep 17 00:00:00 2001 From: Matias Miranda Date: Fri, 18 Apr 2025 01:02:27 +0200 Subject: [PATCH 01/11] adding silero and audio in-memory --- requirements.txt | 2 ++ src/vad/pyannote_vad.py | 24 ++++++++++++++++-------- src/vad/silero_vad.py | 30 ++++++++++++++++++++++++++++++ src/vad/vad_factory.py | 3 +++ 4 files changed, 51 insertions(+), 8 deletions(-) create mode 100644 src/vad/silero_vad.py diff --git a/requirements.txt b/requirements.txt index 34aea8e..f4b0b25 100644 --- a/requirements.txt +++ b/requirements.txt @@ -7,3 +7,5 @@ transformers==4.40.2 faster-whisper==1.0.2 torchvision~=0.18.0 torch~=2.3.0 +soundfile==0.12.1 +silero-vad==5.1.2 diff --git a/src/vad/pyannote_vad.py b/src/vad/pyannote_vad.py index 4f65c67..52d0942 100644 --- a/src/vad/pyannote_vad.py +++ b/src/vad/pyannote_vad.py @@ -1,10 +1,10 @@ import os -from os import remove - +import io +import soundfile as sf from pyannote.audio import Model from pyannote.audio.pipelines import VoiceActivityDetection -from src.audio_utils import save_audio_to_file +from src.utils.audio_utils import convert_audio_bytes_to_numpy from .vad_interface import VADInterface @@ -51,11 +51,19 @@ def __init__(self, **kwargs): self.vad_pipeline.instantiate(pyannote_args) async def detect_activity(self, client): - audio_file_path = await save_audio_to_file( - client.scratch_buffer, client.get_file_name() - ) - vad_results = self.vad_pipeline(audio_file_path) - remove(audio_file_path) + audio_np = convert_audio_bytes_to_numpy(client.scratch_buffer) + + # Create an in-memory audio file + audio_buffer = io.BytesIO() + # Save as WAV at 16kHz sample rate + sf.write(audio_buffer, audio_np, 16000, format='WAV') + + # Reset buffer position for reading + audio_buffer.seek(0) + + # Process with Pyannote directly from the in-memory buffer + vad_results = self.vad_pipeline(audio_buffer) + vad_segments = [] if len(vad_results) > 0: vad_segments = [ diff --git a/src/vad/silero_vad.py 
b/src/vad/silero_vad.py new file mode 100644 index 0000000..08f14c1 --- /dev/null +++ b/src/vad/silero_vad.py @@ -0,0 +1,30 @@ +from silero_vad import load_silero_vad, get_speech_timestamps + +from src.utils.audio_utils import convert_audio_bytes_to_numpy + +from .vad_interface import VADInterface + +class SileroVAD(VADInterface): + """ + Silero-based implementation of the VADInterface that works with in-memory audio. + """ + + def __init__(self, **kwargs): + """ + Initializes Silero's VAD pipeline. + """ + self.model = load_silero_vad() + + async def detect_activity(self, buffer): + # Convert bytearray to numpy array + audio_np = convert_audio_bytes_to_numpy(buffer) + + speech_timestamps = get_speech_timestamps(audio_np, self.model) + + # It returns ms + new_timestamps = [ + {'starts': timestamp['starts'] / 10000, 'ends': timestamp['ends'] / 10000} + for timestamp in speech_timestamps + ] + + return new_timestamps \ No newline at end of file diff --git a/src/vad/vad_factory.py b/src/vad/vad_factory.py index e20d3f9..b7a8b97 100644 --- a/src/vad/vad_factory.py +++ b/src/vad/vad_factory.py @@ -1,4 +1,5 @@ from .pyannote_vad import PyannoteVAD +from .silero_vad import SileroVAD class VADFactory: @@ -20,5 +21,7 @@ def create_vad_pipeline(type, **kwargs): """ if type == "pyannote": return PyannoteVAD(**kwargs) + elif type == "silero": + return SileroVAD(**kwargs) else: raise ValueError(f"Unknown VAD pipeline type: {type}") From 0f4bcd96afba284dc869e35b48aeb84fc5513add Mon Sep 17 00:00:00 2001 From: Matias Miranda Date: Fri, 18 Apr 2025 01:02:55 +0200 Subject: [PATCH 02/11] fixing faster whisper language and adding audio in-memory --- src/asr/faster_whisper_asr.py | 128 +++------------------------------- src/asr/whisper_asr.py | 12 ++-- 2 files changed, 13 insertions(+), 127 deletions(-) diff --git a/src/asr/faster_whisper_asr.py b/src/asr/faster_whisper_asr.py index 96c475c..2e2ced1 100644 --- a/src/asr/faster_whisper_asr.py +++ b/src/asr/faster_whisper_asr.py @@
-1,139 +1,29 @@ import os - +import torch from faster_whisper import WhisperModel -from src.audio_utils import save_audio_to_file +from src.utils.audio_utils import convert_audio_bytes_to_numpy from .asr_interface import ASRInterface -language_codes = { - "afrikaans": "af", - "amharic": "am", - "arabic": "ar", - "assamese": "as", - "azerbaijani": "az", - "bashkir": "ba", - "belarusian": "be", - "bulgarian": "bg", - "bengali": "bn", - "tibetan": "bo", - "breton": "br", - "bosnian": "bs", - "catalan": "ca", - "czech": "cs", - "welsh": "cy", - "danish": "da", - "german": "de", - "greek": "el", - "english": "en", - "spanish": "es", - "estonian": "et", - "basque": "eu", - "persian": "fa", - "finnish": "fi", - "faroese": "fo", - "french": "fr", - "galician": "gl", - "gujarati": "gu", - "hausa": "ha", - "hawaiian": "haw", - "hebrew": "he", - "hindi": "hi", - "croatian": "hr", - "haitian": "ht", - "hungarian": "hu", - "armenian": "hy", - "indonesian": "id", - "icelandic": "is", - "italian": "it", - "japanese": "ja", - "javanese": "jw", - "georgian": "ka", - "kazakh": "kk", - "khmer": "km", - "kannada": "kn", - "korean": "ko", - "latin": "la", - "luxembourgish": "lb", - "lingala": "ln", - "lao": "lo", - "lithuanian": "lt", - "latvian": "lv", - "malagasy": "mg", - "maori": "mi", - "macedonian": "mk", - "malayalam": "ml", - "mongolian": "mn", - "marathi": "mr", - "malay": "ms", - "maltese": "mt", - "burmese": "my", - "nepali": "ne", - "dutch": "nl", - "norwegian nynorsk": "nn", - "norwegian": "no", - "occitan": "oc", - "punjabi": "pa", - "polish": "pl", - "pashto": "ps", - "portuguese": "pt", - "romanian": "ro", - "russian": "ru", - "sanskrit": "sa", - "sindhi": "sd", - "sinhalese": "si", - "slovak": "sk", - "slovenian": "sl", - "shona": "sn", - "somali": "so", - "albanian": "sq", - "serbian": "sr", - "sundanese": "su", - "swedish": "sv", - "swahili": "sw", - "tamil": "ta", - "telugu": "te", - "tajik": "tg", - "thai": "th", - "turkmen": "tk", - "tagalog": "tl", - "turkish": 
"tr", - "tatar": "tt", - "ukrainian": "uk", - "urdu": "ur", - "uzbek": "uz", - "vietnamese": "vi", - "yiddish": "yi", - "yoruba": "yo", - "chinese": "zh", - "cantonese": "yue", -} - - class FasterWhisperASR(ASRInterface): def __init__(self, **kwargs): - model_size = kwargs.get("model_size", "large-v3") - # Run on GPU with FP16 + model_size = kwargs.get("model_size", "tiny") + device = "cuda" if torch.cuda.is_available() else "cpu" + compute_type = "float16" if torch.cuda.is_available() else "float32" self.asr_pipeline = WhisperModel( - model_size, device="cuda", compute_type="float16" + model_size, device=device, compute_type=compute_type ) async def transcribe(self, client): - file_path = await save_audio_to_file( - client.scratch_buffer, client.get_file_name() - ) + audio_np = convert_audio_bytes_to_numpy(client.scratch_buffer) - language = ( - None - if client.config["language"] is None - else language_codes.get(client.config["language"].lower()) - ) + language = client.config["language"].lower() segments, info = self.asr_pipeline.transcribe( - file_path, word_timestamps=True, language=language + audio_np, language=language ) segments = list(segments) # The transcription will actually run here. 
- os.remove(file_path) flattened_words = [ word for segment in segments for word in segment.words diff --git a/src/asr/whisper_asr.py b/src/asr/whisper_asr.py index b472e43..f4a01d4 100644 --- a/src/asr/whisper_asr.py +++ b/src/asr/whisper_asr.py @@ -3,7 +3,7 @@ import torch from transformers import pipeline -from src.audio_utils import save_audio_to_file +from src.utils.audio_utils import convert_audio_bytes_to_numpy from .asr_interface import ASRInterface @@ -19,19 +19,15 @@ def __init__(self, **kwargs): ) async def transcribe(self, client): - file_path = await save_audio_to_file( - client.scratch_buffer, client.get_file_name() - ) + audio_np = convert_audio_bytes_to_numpy(client.scratch_buffer) if client.config["language"] is not None: to_return = self.asr_pipeline( - file_path, + audio_np, generate_kwargs={"language": client.config["language"]}, )["text"] else: - to_return = self.asr_pipeline(file_path)["text"] - - os.remove(file_path) + to_return = self.asr_pipeline(audio_np)["text"] to_return = { "language": "UNSUPPORTED_BY_HUGGINGFACE_WHISPER", From 040ba26e893893cf7bb5c62b9abfbadd7929cf16 Mon Sep 17 00:00:00 2001 From: Matias Miranda Date: Fri, 18 Apr 2025 01:03:11 +0200 Subject: [PATCH 03/11] fixing logging and adding callbacks --- src/audio_utils.py | 28 --------------- .../buffering_strategies.py | 36 ++++++++----------- src/callbacks.py | 35 ++++++++++++++++++ src/client.py | 7 ++-- src/main.py | 11 +++--- src/server.py | 31 +++++++++++----- src/utils/audio_utils.py | 18 ++++++++++ src/utils/base_logger.py | 23 ++++++++++++ 8 files changed, 122 insertions(+), 67 deletions(-) delete mode 100644 src/audio_utils.py create mode 100644 src/callbacks.py create mode 100644 src/utils/audio_utils.py create mode 100644 src/utils/base_logger.py diff --git a/src/audio_utils.py b/src/audio_utils.py deleted file mode 100644 index f9aa203..0000000 --- a/src/audio_utils.py +++ /dev/null @@ -1,28 +0,0 @@ -import os -import wave - - -async def save_audio_to_file( - 
audio_data, file_name, audio_dir="audio_files", audio_format="wav" -): - """ - Saves the audio data to a file. - - :param audio_data: The audio data to save. - :param file_name: The name of the file. - :param audio_dir: Directory where audio files will be saved. - :param audio_format: Format of the audio file. - :return: Path to the saved audio file. - """ - - os.makedirs(audio_dir, exist_ok=True) - - file_path = os.path.join(audio_dir, file_name) - - with wave.open(file_path, "wb") as wav_file: - wav_file.setnchannels(1) # Assuming mono audio - wav_file.setsampwidth(2) - wav_file.setframerate(16000) - wav_file.writeframes(audio_data) - - return file_path diff --git a/src/buffering_strategy/buffering_strategies.py b/src/buffering_strategy/buffering_strategies.py index bc1a0d4..1c75393 100644 --- a/src/buffering_strategy/buffering_strategies.py +++ b/src/buffering_strategy/buffering_strategies.py @@ -3,6 +3,8 @@ import os import time +from src.callbacks import AudioProcessingCallbacks + from .buffering_strategy_interface import BufferingStrategyInterface @@ -49,15 +51,9 @@ def __init__(self, client, **kwargs): self.chunk_offset_seconds = kwargs.get("chunk_offset_seconds") self.chunk_offset_seconds = float(self.chunk_offset_seconds) - self.error_if_not_realtime = os.environ.get("ERROR_IF_NOT_REALTIME") - if not self.error_if_not_realtime: - self.error_if_not_realtime = kwargs.get( - "error_if_not_realtime", False - ) - self.processing_flag = False - def process_audio(self, websocket, vad_pipeline, asr_pipeline): + def process_audio(self, callbacks: AudioProcessingCallbacks, vad_pipeline, asr_pipeline): """ Process audio chunks by checking their length and scheduling asynchronous processing. @@ -66,7 +62,7 @@ def process_audio(self, websocket, vad_pipeline, asr_pipeline): length and, if so, it schedules asynchronous processing of the audio. Args: - websocket: The WebSocket connection for sending transcriptions. + callbacks: Callbacks for audio processing events. 
vad_pipeline: The voice activity detection pipeline. asr_pipeline: The automatic speech recognition pipeline. """ @@ -77,35 +73,31 @@ ) if len(self.client.buffer) > chunk_length_in_bytes: if self.processing_flag: - exit( - "Error in realtime processing: tried processing a new " - "chunk while the previous one was still being processed" - ) + self.processing_flag.cancel() self.client.scratch_buffer += self.client.buffer self.client.buffer.clear() - self.processing_flag = True # Schedule the processing in a separate task - asyncio.create_task( - self.process_audio_async(websocket, vad_pipeline, asr_pipeline) + self.processing_flag = asyncio.create_task( + self.process_audio_async(callbacks, vad_pipeline, asr_pipeline) ) - async def process_audio_async(self, websocket, vad_pipeline, asr_pipeline): + async def process_audio_async(self, callbacks: AudioProcessingCallbacks, vad_pipeline, asr_pipeline): """ Asynchronously process audio for activity detection and transcription. This method performs heavy processing, including voice activity - detection and transcription of the audio data. It sends the - transcription results through the WebSocket connection. + detection and transcription of the audio data. If conditions are met, + it transcribes the audio and triggers the + transcription callbacks. Args: - websocket (Websocket): The WebSocket connection for sending - transcriptions. + callbacks: Callbacks for audio processing events. vad_pipeline: The voice activity detection pipeline. asr_pipeline: The automatic speech recognition pipeline.
""" start = time.time() - vad_results = await vad_pipeline.detect_activity(self.client) + vad_results = await vad_pipeline.detect_activity(self.client.scratch_buffer) if len(vad_results) == 0: self.client.scratch_buffer.clear() @@ -123,7 +115,7 @@ async def process_audio_async(self, websocket, vad_pipeline, asr_pipeline): end = time.time() transcription["processing_time"] = end - start json_transcription = json.dumps(transcription) - await websocket.send(json_transcription) + await callbacks.trigger_transcription_complete(json_transcription) self.client.scratch_buffer.clear() self.client.increment_file_counter() diff --git a/src/callbacks.py b/src/callbacks.py new file mode 100644 index 0000000..abcf9fd --- /dev/null +++ b/src/callbacks.py @@ -0,0 +1,35 @@ +from typing import Callable, Any, Optional, Awaitable + +class AudioProcessingCallbacks: + """ + Callback interface for audio processing events. + + This class defines callback functions that can be registered to handle + different events that occur during audio processing, such as: + - Transcription completion + + Callbacks are defined as async functions to support asynchronous operations. + """ + + def __init__( + self, + on_transcription_complete: Optional[Callable[[str], Awaitable[None]]] = None, + ): + """ + Initialize the callback interface. + + Args: + on_transcription_complete: Called when transcription is complete with the + transcribed text. + """ + self.on_transcription_complete = on_transcription_complete + + async def trigger_transcription_complete(self, text: str): + """ + Trigger the transcription complete callback. + + Args: + text: The transcribed text. 
+ """ + if self.on_transcription_complete: + await self.on_transcription_complete(text) \ No newline at end of file diff --git a/src/client.py b/src/client.py index 414a0fd..d674a20 100644 --- a/src/client.py +++ b/src/client.py @@ -3,6 +3,7 @@ from src.buffering_strategy.buffering_strategy_factory import ( BufferingStrategyFactory, ) +from src.callbacks import AudioProcessingCallbacks class Client: @@ -33,7 +34,7 @@ def __init__(self, client_id, sampling_rate, samples_width): "language": None, "processing_strategy": "silence_at_end_of_chunk", "processing_args": { - "chunk_length_seconds": 5, + "chunk_length_seconds": 3, "chunk_offset_seconds": 0.1, }, } @@ -72,7 +73,7 @@ def increment_file_counter(self): def get_file_name(self): return f"{self.client_id}_{self.file_counter}.wav" - def process_audio(self, websocket, vad_pipeline, asr_pipeline): + def process_audio(self, callbacks: AudioProcessingCallbacks, vad_pipeline, asr_pipeline): self.buffering_strategy.process_audio( - websocket, vad_pipeline, asr_pipeline + callbacks, vad_pipeline, asr_pipeline ) diff --git a/src/main.py b/src/main.py index fad9c2d..bf92c03 100644 --- a/src/main.py +++ b/src/main.py @@ -1,7 +1,7 @@ import argparse import asyncio import json -import logging +from src.utils.base_logger import logger, setLogger from src.asr.asr_factory import ASRFactory from src.vad.vad_factory import VADFactory @@ -17,8 +17,8 @@ def parse_args(): parser.add_argument( "--vad-type", type=str, - default="pyannote", - help="Type of VAD pipeline to use (e.g., 'pyannote')", + default="silero", + help="Type of VAD pipeline to use (e.g., 'silero')", ) parser.add_argument( "--vad-args", @@ -73,14 +73,13 @@ def parse_args(): def main(): args = parse_args() - logging.basicConfig() - logging.getLogger().setLevel(args.log_level.upper()) + setLogger("debug") try: vad_args = json.loads(args.vad_args) asr_args = json.loads(args.asr_args) except json.JSONDecodeError as e: - print(f"Error parsing JSON arguments: {e}") + 
logger.error(f"Error parsing JSON arguments: {e}") return vad_pipeline = VADFactory.create_vad_pipeline(args.vad_type, **vad_args) diff --git a/src/server.py b/src/server.py index 809c7ad..eedb40a 100644 --- a/src/server.py +++ b/src/server.py @@ -1,11 +1,12 @@ import json -import logging import ssl import uuid import websockets from src.client import Client +from src.utils.base_logger import logger +from .callbacks import AudioProcessingCallbacks class Server: @@ -49,6 +50,20 @@ def __init__( self.connected_clients = {} async def handle_audio(self, client, websocket): + + async def on_transcription_complete(message): + # Process the transcribed message + try: + await websocket.send_text(message) + except Exception as e: + logger.error(f"Error processing message: {e}") + # This could be enhanced with proper error handling + + # Initialize callbacks + callbacks = AudioProcessingCallbacks( + on_transcription_complete=on_transcription_complete, + ) + while True: message = await websocket.recv() @@ -58,14 +73,14 @@ async def handle_audio(self, client, websocket): config = json.loads(message) if config.get("type") == "config": client.update_config(config["data"]) - logging.debug(f"Updated config: {client.config}") + logger.debug(f"Updated config: {client.config}") continue else: - print(f"Unexpected message type from {client.client_id}") + logger.error(f"Unexpected message type from {client.client_id}") # this is synchronous, any async operation is in BufferingStrategy client.process_audio( - websocket, self.vad_pipeline, self.asr_pipeline + callbacks, self.vad_pipeline, self.asr_pipeline ) async def handle_websocket(self, websocket): @@ -73,12 +88,12 @@ async def handle_websocket(self, websocket): client = Client(client_id, self.sampling_rate, self.samples_width) self.connected_clients[client_id] = client - print(f"Client {client_id} connected") + logger.info(f"Client {client_id} connected") try: await self.handle_audio(client, websocket) except 
websockets.ConnectionClosed as e: - print(f"Connection with {client_id} closed: {e}") + logger.error(f"Connection with {client_id} closed: {e}") finally: del self.connected_clients[client_id] @@ -94,7 +109,7 @@ def start(self): certfile=self.certfile, keyfile=self.keyfile ) - print( + logger.info( f"WebSocket server ready to accept secure connections on " f"{self.host}:{self.port}" ) @@ -106,7 +121,7 @@ def start(self): self.handle_websocket, self.host, self.port, ssl=ssl_context ) else: - print( + logger.error( f"WebSocket server ready to accept secure connections on " f"{self.host}:{self.port}" ) diff --git a/src/utils/audio_utils.py b/src/utils/audio_utils.py new file mode 100644 index 0000000..8e8e98f --- /dev/null +++ b/src/utils/audio_utils.py @@ -0,0 +1,18 @@ +from numpy import frombuffer, int16, float32 + +def convert_audio_bytes_to_numpy(audio_bytes): + """ + Convert raw audio bytes from scratch_buffer (bytearray) + directly to the numpy format required by Whisper and VAD + + :param audio_bytes: Raw audio bytes as bytearray + :return: Numpy array with the audio data in the format expected by Whisper + """ + # Convert bytearray directly to numpy array + # Assuming 16-bit PCM audio format + audio_as_np_int16 = frombuffer(audio_bytes, dtype=int16) + + # Convert to float32 and normalize to [-1, 1] range as expected by Whisper + audio_as_np_float32 = audio_as_np_int16.astype(float32) / 32768.0 + + return audio_as_np_float32 \ No newline at end of file diff --git a/src/utils/base_logger.py b/src/utils/base_logger.py new file mode 100644 index 0000000..c8bf6b0 --- /dev/null +++ b/src/utils/base_logger.py @@ -0,0 +1,23 @@ +import logging + +logger = logging + + +class BinaryLogFilter(logging.Filter): + def filter(self, record): + return not (record.getMessage().startswith('< BINARY') or '< BINARY' in record.getMessage()) + + +def setLogger(level: str): + global logger + + levels = { + 'debug': logging.DEBUG, + 'info': logging.INFO, + 'warning': logging.WARNING, + 
'error': logging.ERROR, + } + + logger.basicConfig(format='%(asctime)s - %(message)s', level=levels[level]) + for handler in logger.getLogger().handlers: + handler.addFilter(BinaryLogFilter()) From d33b3ed99af35188dfac0b0ad4173db7a6848b81 Mon Sep 17 00:00:00 2001 From: Matias Miranda Date: Fri, 18 Apr 2025 11:25:52 +0200 Subject: [PATCH 04/11] Updating readme --- README.md | 15 ++++++++++++--- 1 file changed, 12 insertions(+), 3 deletions(-) diff --git a/README.md b/README.md index 657a47c..2975fbc 100644 --- a/README.md +++ b/README.md @@ -4,7 +4,7 @@ VoiceStreamAI is a Python 3 -based server and JavaScript client solution that enables near-realtime audio streaming and transcription using WebSocket. The -system employs Huggingface's Voice Activity Detection (VAD) and OpenAI's Whisper +system employs Silero Voice Activity Detection (VAD) and OpenAI's Whisper model ([faster-whisper](https://github.com/SYSTRAN/faster-whisper) being the default) for accurate speech recognition and processing. @@ -78,6 +78,8 @@ following packages: 5. `asyncio` 6. `sentence-transformers` 7. `faster-whisper` +8. `silero-vad` +9. `soundfile` Install these packages using pip: @@ -96,7 +98,7 @@ allowing you to specify components, host, and port settings according to your needs. - `--vad-type`: Specifies the type of Voice Activity Detection (VAD) pipeline to - use (default: `pyannote`) . + use (default: `silero`). The default Silero VAD doesn't require an authentication token. - `--vad-args`: A JSON string containing additional arguments for the VAD pipeline. (required for `pyannote`: `'{"auth_token": "VAD_AUTH_HERE"}'`) - `--asr-type`: Specifies the type of Automatic Speech Recognition (ASR) @@ -113,12 +115,19 @@ needs. For running the server with the standard configuration: +```bash +python3 -m src.main +``` + +Since the default VAD is Silero, which doesn't require an authentication token, +the above command is sufficient. If you want to use pyannote VAD: + 1.
Obtain the key to the Voice-Activity-Detection model at [https://huggingface.co/pyannote/segmentation](https://huggingface.co/pyannote/segmentation) 2. Run the server using Python 3.x, please add the VAD key in the command line: ```bash -python3 -m src.main --vad-args '{"auth_token": "vad token here"}' +python3 -m src.main --vad-type 'pyannote' --vad-args '{"auth_token": "vad token here"}' ``` You can see all the command line options with the command: From 08aa607d542841ec2dbaf5c67d095ad7da9b6542 Mon Sep 17 00:00:00 2001 From: Matias Miranda Date: Fri, 18 Apr 2025 11:26:02 +0200 Subject: [PATCH 05/11] updating default values --- src/main.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/main.py b/src/main.py index bf92c03..9161df5 100644 --- a/src/main.py +++ b/src/main.py @@ -23,7 +23,7 @@ def parse_args(): parser.add_argument( "--vad-args", type=str, - default='{"auth_token": "huggingface_token"}', + default=None, help="JSON string of additional arguments for VAD pipeline", ) parser.add_argument( @@ -35,7 +35,7 @@ def parse_args(): parser.add_argument( "--asr-args", type=str, - default='{"model_size": "large-v3"}', + default='{"model_size": "tiny"}', help="JSON string of additional arguments for ASR pipeline", ) parser.add_argument( From d70c9455da19e0dc749a5eb69e5d2102a67251db Mon Sep 17 00:00:00 2001 From: Matias Miranda Date: Fri, 18 Apr 2025 11:26:20 +0200 Subject: [PATCH 06/11] fixing callback --- src/server.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/server.py b/src/server.py index eedb40a..60006d8 100644 --- a/src/server.py +++ b/src/server.py @@ -54,7 +54,7 @@ async def handle_audio(self, client, websocket): async def on_transcription_complete(message): # Process the transcribed message try: - await websocket.send_text(message) + await websocket.send(message) except Exception as e: logger.error(f"Error processing message: {e}") # This could be enhanced with proper error handling From 
e189aab8f79e1df34d47efc188d2179acd658aa0 Mon Sep 17 00:00:00 2001 From: Matias Miranda Date: Fri, 18 Apr 2025 11:26:25 +0200 Subject: [PATCH 07/11] rollback languages --- src/asr/faster_whisper_asr.py | 111 +++++++++++++++++++++++++++++++++- 1 file changed, 109 insertions(+), 2 deletions(-) diff --git a/src/asr/faster_whisper_asr.py b/src/asr/faster_whisper_asr.py index 2e2ced1..c16b944 100644 --- a/src/asr/faster_whisper_asr.py +++ b/src/asr/faster_whisper_asr.py @@ -6,6 +6,109 @@ from .asr_interface import ASRInterface +language_codes = { + "afrikaans": "af", + "amharic": "am", + "arabic": "ar", + "assamese": "as", + "azerbaijani": "az", + "bashkir": "ba", + "belarusian": "be", + "bulgarian": "bg", + "bengali": "bn", + "tibetan": "bo", + "breton": "br", + "bosnian": "bs", + "catalan": "ca", + "czech": "cs", + "welsh": "cy", + "danish": "da", + "german": "de", + "greek": "el", + "english": "en", + "spanish": "es", + "estonian": "et", + "basque": "eu", + "persian": "fa", + "finnish": "fi", + "faroese": "fo", + "french": "fr", + "galician": "gl", + "gujarati": "gu", + "hausa": "ha", + "hawaiian": "haw", + "hebrew": "he", + "hindi": "hi", + "croatian": "hr", + "haitian": "ht", + "hungarian": "hu", + "armenian": "hy", + "indonesian": "id", + "icelandic": "is", + "italian": "it", + "japanese": "ja", + "javanese": "jw", + "georgian": "ka", + "kazakh": "kk", + "khmer": "km", + "kannada": "kn", + "korean": "ko", + "latin": "la", + "luxembourgish": "lb", + "lingala": "ln", + "lao": "lo", + "lithuanian": "lt", + "latvian": "lv", + "malagasy": "mg", + "maori": "mi", + "macedonian": "mk", + "malayalam": "ml", + "mongolian": "mn", + "marathi": "mr", + "malay": "ms", + "maltese": "mt", + "burmese": "my", + "nepali": "ne", + "dutch": "nl", + "norwegian nynorsk": "nn", + "norwegian": "no", + "occitan": "oc", + "punjabi": "pa", + "polish": "pl", + "pashto": "ps", + "portuguese": "pt", + "romanian": "ro", + "russian": "ru", + "sanskrit": "sa", + "sindhi": "sd", + "sinhalese": "si", 
+ "slovak": "sk", + "slovenian": "sl", + "shona": "sn", + "somali": "so", + "albanian": "sq", + "serbian": "sr", + "sundanese": "su", + "swedish": "sv", + "swahili": "sw", + "tamil": "ta", + "telugu": "te", + "tajik": "tg", + "thai": "th", + "turkmen": "tk", + "tagalog": "tl", + "turkish": "tr", + "tatar": "tt", + "ukrainian": "uk", + "urdu": "ur", + "uzbek": "uz", + "vietnamese": "vi", + "yiddish": "yi", + "yoruba": "yo", + "chinese": "zh", + "cantonese": "yue", +} + class FasterWhisperASR(ASRInterface): def __init__(self, **kwargs): model_size = kwargs.get("model_size", "tiny") @@ -18,9 +121,13 @@ def __init__(self, **kwargs): async def transcribe(self, client): audio_np = convert_audio_bytes_to_numpy(client.scratch_buffer) - language = client.config["language"].lower() + language = ( + None + if client.config["language"] is None + else language_codes.get(client.config["language"].lower()) + ) segments, info = self.asr_pipeline.transcribe( - audio_np, language=language + audio_np, word_timestamps=True, language=language ) segments = list(segments) # The transcription will actually run here. 
From 21773e5caad09c8cc523d683a6c8d1261c823214 Mon Sep 17 00:00:00 2001 From: Matias Miranda Date: Fri, 18 Apr 2025 11:27:03 +0200 Subject: [PATCH 08/11] removing decimal --- src/vad/silero_vad.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/src/vad/silero_vad.py b/src/vad/silero_vad.py index 08f14c1..9c78ca0 100644 --- a/src/vad/silero_vad.py +++ b/src/vad/silero_vad.py @@ -1,6 +1,9 @@ +from math import floor + from silero_vad import load_silero_vad, get_speech_timestamps from src.utils.audio_utils import convert_audio_bytes_to_numpy +from src.utils.base_logger import logger from .vad_interface import VADInterface @@ -23,7 +26,7 @@ async def detect_activity(self, buffer): # It returns ms new_timestamps = [ - {'starts': timestamp['starts'] / 10000, 'ends': timestamp['ends'] / 10000} + {'start': floor(timestamp['start'] / 10000), 'end': floor(timestamp['end'] / 10000)} for timestamp in speech_timestamps ] From 2cc38df2339803f1ab29ea130869fdbf481a9d96 Mon Sep 17 00:00:00 2001 From: Matias Miranda Date: Fri, 18 Apr 2025 11:27:20 +0200 Subject: [PATCH 09/11] fixing pyannote test --- test/vad/test_pyannote_vad.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/vad/test_pyannote_vad.py b/test/vad/test_pyannote_vad.py index a3b95d6..af711f3 100644 --- a/test/vad/test_pyannote_vad.py +++ b/test/vad/test_pyannote_vad.py @@ -41,7 +41,7 @@ def test_detect_activity(self): self.client.scratch_buffer = bytearray(audio_segment.raw_data) vad_results = asyncio.run( - self.vad.detect_activity(self.client) + self.vad.detect_activity(self.client.scratch_buffer) ) # Adjust VAD-detected times by adding the start time of the From 524819fc81fe189399be63bd4cf3739221fb875f Mon Sep 17 00:00:00 2001 From: Matias Miranda Date: Fri, 18 Apr 2025 11:27:29 +0200 Subject: [PATCH 10/11] adding silero vad test --- test/vad/test_silero_vad.py | 88 +++++++++++++++++++++++++++++++++++++ 1 file changed, 88 insertions(+) create mode 100644 
test/vad/test_silero_vad.py diff --git a/test/vad/test_silero_vad.py b/test/vad/test_silero_vad.py new file mode 100644 index 0000000..c87205f --- /dev/null +++ b/test/vad/test_silero_vad.py @@ -0,0 +1,88 @@ +# tests/vad/test_silero_vad.py + +import asyncio +import json +import os +import unittest + +from pydub import AudioSegment + +from src.client import Client +from src.vad.silero_vad import SileroVAD + + +class TestSileroVAD(unittest.TestCase): + def setUp(self): + self.vad = SileroVAD() + self.annotations_path = os.path.join( + os.path.dirname(__file__), "../audio_files/annotations.json" + ) + self.client = Client("test_client", 16000, 2) # Example client + + def load_annotations(self): + with open(self.annotations_path, "r") as file: + return json.load(file) + + def test_detect_activity(self): + annotations = self.load_annotations() + + for audio_file, data in annotations.items(): + audio_file_path = os.path.join( + os.path.dirname(__file__), f"../audio_files/{audio_file}" + ) + + for annotated_segment in data["segments"]: + print(annotated_segment['transcription']) + # Load the specific audio segment for VAD + audio_segment = self.get_audio_segment( + audio_file_path, + annotated_segment["start"], + annotated_segment["end"], + ) + self.client.scratch_buffer = bytearray(audio_segment.raw_data) + + vad_results = asyncio.run( + self.vad.detect_activity(self.client.scratch_buffer) + ) + + # Adjust VAD-detected times by adding the start time of the + # annotated segment + adjusted_vad_results = [ + { + "start": segment["start"] + annotated_segment["start"], + "end": segment["end"] + annotated_segment["start"], + } + for segment in vad_results + ] + + detected_segments = [ + segment + for segment in adjusted_vad_results + if segment["start"] <= annotated_segment["start"] + 1 + and segment["end"] <= annotated_segment["end"] + 4.2 + ] + + # Print formatted information about the test + print( + f"\nTesting segment from '{audio_file}': Annotated Start: " + 
f"{annotated_segment['start']}, Annotated End: " + f"{annotated_segment['end']}" + ) + print(f"VAD segments: {adjusted_vad_results}") + print(f"Overlapping, Detected segments: {detected_segments}") + + # Assert that at least one detected segment meets the condition + self.assertTrue( + len(detected_segments) > 0, + "No detected segment matches the annotated segment", + ) + + def get_audio_segment(self, file_path, start, end): + with open(file_path, "rb") as file: + audio = AudioSegment.from_file(file, format="wav") + # pydub works in milliseconds + return audio[start * 1000 : end * 1000] # noqa: E203 + + +if __name__ == "__main__": + unittest.main() From 2b6d3b6bf8f98c2221915cace7b11644e792e9ab Mon Sep 17 00:00:00 2001 From: Matias Miranda Date: Fri, 25 Apr 2025 10:21:57 +0200 Subject: [PATCH 11/11] fix silero sampling rate --- src/buffering_strategy/buffering_strategies.py | 2 +- src/vad/silero_vad.py | 6 +++--- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/src/buffering_strategy/buffering_strategies.py b/src/buffering_strategy/buffering_strategies.py index 1c75393..b4d214d 100644 --- a/src/buffering_strategy/buffering_strategies.py +++ b/src/buffering_strategy/buffering_strategies.py @@ -97,7 +97,7 @@ async def process_audio_async(self, callbacks: AudioProcessingCallbacks, vad_pip asr_pipeline: The automatic speech recognition pipeline. 
""" start = time.time() - vad_results = await vad_pipeline.detect_activity(self.client.scratch_buffer) + vad_results = await vad_pipeline.detect_activity(self.client) if len(vad_results) == 0: self.client.scratch_buffer.clear() diff --git a/src/vad/silero_vad.py b/src/vad/silero_vad.py index 9c78ca0..edd4840 100644 --- a/src/vad/silero_vad.py +++ b/src/vad/silero_vad.py @@ -18,15 +18,15 @@ def __init__(self, **kwargs): """ self.model = load_silero_vad() - async def detect_activity(self, buffer): + async def detect_activity(self, client): # Convert bytearray to numpy array - audio_np = convert_audio_bytes_to_numpy(buffer) + audio_np = convert_audio_bytes_to_numpy(client.scratch_buffer) speech_timestamps = get_speech_timestamps(audio_np, self.model) # It returns ms new_timestamps = [ - {'start': floor(timestamp['start'] / 10000), 'end': floor(timestamp['end'] / 10000)} + {'start': floor(timestamp['start'] / client.sampling_rate), 'end': floor(timestamp['end'] / client.sampling_rate)} for timestamp in speech_timestamps ]