diff --git a/.github/workflows/dictionary/google-cloud.txt b/.github/workflows/dictionary/google-cloud.txt index a10355408..55a3f33ba 100644 --- a/.github/workflows/dictionary/google-cloud.txt +++ b/.github/workflows/dictionary/google-cloud.txt @@ -95,11 +95,13 @@ servicemanagement servicemesh servicenetworking serviceusage +ssml stackdriver statefulset subnetwork subnetworks superblocks +texttospeech timepicker uefi ultragpu diff --git a/modules/python/src/custom_nodes/google_genmedia/__init__.py b/modules/python/src/custom_nodes/google_genmedia/__init__.py index 336febe3e..7b5fe83b0 100644 --- a/modules/python/src/custom_nodes/google_genmedia/__init__.py +++ b/modules/python/src/custom_nodes/google_genmedia/__init__.py @@ -87,7 +87,10 @@ def setup_custom_package_logger(): setup_custom_package_logger() - +from .chirp3hd_node import NODE_CLASS_MAPPINGS as CHIRP3_HD_NODE_CLASS_MAPPINGS +from .chirp3hd_node import ( + NODE_DISPLAY_NAME_MAPPINGS as CHIRP3_HD_NODE_DISPLAY_NAME_MAPPINGS, +) from .gemini_flash_image_node import ( NODE_CLASS_MAPPINGS as GEMINI_FLASH_25_IMAGE_NODE_CLASS_MAPPINGS, ) @@ -123,6 +126,7 @@ def setup_custom_package_logger(): # Combine all node class mappings NODE_CLASS_MAPPINGS = { + **CHIRP3_HD_NODE_CLASS_MAPPINGS, **IMAGEN3_NODE_CLASS_MAPPINGS, **IMAGEN4_NODE_CLASS_MAPPINGS, **LYRIA2_NODE_CLASS_MAPPINGS, @@ -136,6 +140,7 @@ def setup_custom_package_logger(): # Combine all node display name mappings NODE_DISPLAY_NAME_MAPPINGS = { + **CHIRP3_HD_NODE_DISPLAY_NAME_MAPPINGS, **IMAGEN3_NODE_DISPLAY_NAME_MAPPINGS, **IMAGEN4_NODE_DISPLAY_NAME_MAPPINGS, **LYRIA2_NODE_DISPLAY_NAME_MAPPINGS, diff --git a/modules/python/src/custom_nodes/google_genmedia/chirp3hd_api.py b/modules/python/src/custom_nodes/google_genmedia/chirp3hd_api.py new file mode 100644 index 000000000..281984a44 --- /dev/null +++ b/modules/python/src/custom_nodes/google_genmedia/chirp3hd_api.py @@ -0,0 +1,125 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, 
# A real XML/SSML tag: "<" or "</" immediately followed by a letter, e.g.
# <speak>, </prosody>, <break time="1s"/>. Requiring the letter keeps plain
# comparisons such as "1 < 2 and 3 > 2" from being mistaken for SSML, and
# "[^<>]*" (unlike ".") lets a tag span several lines.
_SSML_TAG_RE = re.compile(r"</?[A-Za-z][^<>]*>")
# A bracketed markup tag such as [pause] or [pause long].
_MARKUP_TAG_RE = re.compile(r"\[.*?\]")


def classify_string(text: str) -> dict:
    """
    Classifies a string as 'ssml', 'markup', or 'text'.

    The result is shaped so that it can be passed directly as keyword
    arguments to ``texttospeech.SynthesisInput``.

    Classification rules, checked in order:
      1. Contains at least one XML-style tag -> ``{"ssml": text}``.
      2. Contains a ``[...]`` bracket pair   -> ``{"markup": text}``.
      3. Anything else                       -> ``{"text": text}``.

    Args:
        text: The string to classify.

    Returns:
        A single-entry dictionary whose key is the detected input kind and
        whose value is the unmodified input string.
    """
    if _SSML_TAG_RE.search(text):
        return {"ssml": text}
    if _MARKUP_TAG_RE.search(text):
        return {"markup": text}
    return {"text": text}
class Chirp3API(VertexAIClient):
    """
    Handles all communication with the Google Cloud Text-to-Speech API for Chirp3 models.
    """

    def __init__(
        self,
        project_id: Optional[str] = None,
        region: Optional[str] = None,
        user_agent: Optional[str] = CHIRP3_USER_AGENT,
    ):
        """
        Initializes the TextToSpeechClient.

        Args:
            project_id: The GCP project ID. Overrides metadata lookup.
            region: The GCP region. Overrides metadata lookup.
            user_agent: The user agent to use for the client. When falsy, the
                client is created without custom client info.

        Raises:
            ConfigurationError: If client initialization fails. The original
                exception is attached as ``__cause__``.
        """
        super().__init__(gcp_project_id=project_id, gcp_region=region)

        try:
            client_info = ClientInfo(user_agent=user_agent) if user_agent else None
            self.client = texttospeech.TextToSpeechClient(client_info=client_info)
        except Exception as e:
            # Chain the cause so the underlying auth/transport error stays visible.
            raise ConfigurationError(
                "Failed to initialize Google TTS Client. "
                "Please ensure you are authenticated (e.g., `gcloud auth application-default login`). "
                f"Error: {e}"
            ) from e
        else:
            # Use the module logger rather than print so output follows the
            # package's logging configuration.
            logger.info(
                "Initialized Google TTS Client with User-Agent: %s", user_agent
            )

    def synthesize(
        self,
        language_code: str,
        sample_rate: int,
        speed: float,
        text: str,
        voice_name: str,
        volume_gain_db: float,
    ) -> tuple[bytes, int]:
        """
        Synthesizes speech and returns the raw audio binary content and sample rate.

        The input is classified by ``classify_string`` as SSML, markup or plain
        text, and the request is delegated to ``utils.generate_speech_from_text``.

        Args:
            language_code: The language code for the voice.
            sample_rate: The desired sample rate of the audio.
            speed: The speaking rate of the synthesized speech.
            text: The text (plain, markup or SSML) to synthesize.
            voice_name: The name of the voice to use.
            volume_gain_db: The volume gain in dB.

        Returns:
            A tuple containing the raw audio content (bytes) and the sample rate (int).

        Raises:
            Propagates whatever ``utils.generate_speech_from_text`` raises.
            NOTE(review): the exact exception types depend on that helper's
            retry decorator, which is defined elsewhere - confirm.
        """
        synthesis_input_params = classify_string(text)
        voice_params = {"language_code": language_code, "name": voice_name}

        logger.info(" - Synthesis Input: %s", synthesis_input_params)
        logger.info(" - Voice Params: %s", voice_params)

        return utils.generate_speech_from_text(
            client=self.client,
            sample_rate=sample_rate,
            speed=speed,
            synthesis_input_params=synthesis_input_params,
            voice_params=voice_params,
            volume_gain_db=volume_gain_db,
        )
class Chirp3Node:
    """
    A ComfyUI node for Google's Text-to-Speech API, specifically for the Chirp3 HD model.
    It synthesizes text into speech and returns an audio tensor.
    """

    # Maps a voice display name (e.g. "Achernar (Female)") to its short voice
    # id. Refreshed every time INPUT_TYPES runs; the empty default prevents an
    # AttributeError when execute_synthesis is reached before INPUT_TYPES.
    voice_id_map: Dict[str, str] = {}

    def __init__(self) -> None:
        """
        Initializes the Chirp3Node.
        """

    @classmethod
    def INPUT_TYPES(cls) -> Dict[str, Dict[str, Any]]:
        """
        Defines the input types and widgets for the ComfyUI node.

        Also stores the display-name -> short-voice-id map on the class so that
        ``execute_synthesis`` can resolve the selected voice.

        Returns:
            A dictionary specifying the required and optional input parameters.
        """
        dynamic_voices, dynamic_langs, voice_map = get_tts_voices_and_languages(
            model_to_include=CHIRP3_HD_MODEL
        )
        cls.voice_id_map = voice_map

        # Prefer "en-US" when it is selectable; otherwise fall back to the first
        # available language so the default is always one of the choices.
        if "en-US" in dynamic_langs or not dynamic_langs:
            default_language = "en-US"
        else:
            default_language = dynamic_langs[0]
        default_voice = dynamic_voices[0] if dynamic_voices else "Charon"

        return {
            "required": {
                "text": (
                    "STRING",
                    {
                        "label_on": True,
                        "multiline": True,
                        "default": "Hello! I am a generative voice, designed by Google.",
                        "placeholder": "Enter the text to synthesize...",
                    },
                ),
                "language_code": (
                    dynamic_langs,
                    {"default": default_language},
                ),
                "voice_name": (
                    dynamic_voices,
                    {"default": default_voice},
                ),
                "sample_rate": (
                    "INT",
                    {"default": 24000, "min": 8000, "max": 48000, "step": 10},
                ),
                "speed": (
                    "FLOAT",
                    {"default": 1.0, "min": 0.25, "max": 4.0, "step": 0.05},
                ),
                "volume_gain_db": (
                    "FLOAT",
                    {"default": 0.0, "min": -96.0, "max": 16.0, "step": 0.1},
                ),
            },
            "optional": {
                "gcp_project_id": (
                    "STRING",
                    {"default": "", "placeholder": "your-gcp-project-id"},
                ),
                "gcp_region": (
                    "STRING",
                    {"default": "", "placeholder": "us-central1"},
                ),
            },
        }

    RETURN_TYPES = ("AUDIO",)
    RETURN_NAMES = ("audio",)
    FUNCTION = "execute_synthesis"
    CATEGORY = "Google AI/Chirp3"

    def execute_synthesis(
        self,
        gcp_project_id: Optional[str],
        gcp_region: Optional[str],
        language_code: str,
        sample_rate: int,
        speed: float,
        text: str,
        voice_name: str,
        volume_gain_db: float,
    ) -> Tuple[Dict[str, Any],]:
        """
        Executes the text-to-speech synthesis process using the Chirp3 HD model.

        Args:
            gcp_project_id: The GCP project ID.
            gcp_region: The GCP region.
            language_code: The language and region of the voice.
            sample_rate: The desired sample rate for the audio.
            speed: The speaking rate of the audio.
            text: The text to be synthesized into speech.
            voice_name: The display name of the voice to use, as offered by
                ``INPUT_TYPES``.
            volume_gain_db: The volume gain in dB.

        Returns:
            A one-element tuple containing a dictionary with the audio
            ``waveform`` (with a leading batch dimension added) and the
            ``sample_rate`` read from the returned audio.

        Raises:
            ConfigurationError: If the voice cannot be resolved or the input
                text is empty.
            RuntimeError: For errors during API communication or audio processing.
        """
        voice_map = self.voice_id_map
        if not voice_map:
            # INPUT_TYPES has not populated the map in this process yet;
            # rebuild it from the helper before giving up.
            _, _, voice_map = get_tts_voices_and_languages(
                model_to_include=CHIRP3_HD_MODEL
            )

        short_voice_name = voice_map.get(voice_name)
        if not short_voice_name:
            raise ConfigurationError(
                f"Voice ID lookup failed for selected voice: {voice_name}. "
                "The voice list may be corrupt or outdated."
            )
        if not text or not text.strip():
            raise ConfigurationError("Text input cannot be empty.")

        # Full voice names follow "<language>-<model>-<short name>".
        full_voice_name = f"{language_code}-{CHIRP3_HD_MODEL}-{short_voice_name}"

        try:
            api_client = Chirp3API(project_id=gcp_project_id, region=gcp_region)

            audio_data_binary, _ = api_client.synthesize(
                language_code=language_code,
                sample_rate=sample_rate,
                speed=speed,
                text=text,
                voice_name=full_voice_name,
                volume_gain_db=volume_gain_db,
            )

            if not audio_data_binary:
                raise APIExecutionError("API call returned no audio data.")

            # The sample rate is taken from the decoded audio, not the request.
            waveform, r_sample_rate = load_audio_from_bytes(audio_data_binary)
            output_audio = {
                "waveform": waveform.unsqueeze(0),
                "sample_rate": r_sample_rate,
            }
            return (output_audio,)

        except (ConfigurationError, APIExecutionError) as e:
            raise RuntimeError(str(e)) from e
        except Exception as e:
            raise RuntimeError(f"An unexpected error occurred: {e}") from e


NODE_CLASS_MAPPINGS = {"Chirp3Node": Chirp3Node}
NODE_DISPLAY_NAME_MAPPINGS = {"Chirp3Node": "Chirp3 HD"}
b/modules/python/src/custom_nodes/google_genmedia/constants.py @@ -19,6 +19,8 @@ from google.genai import types AUDIO_MIME_TYPES = ["audio/mp3", "audio/wav", "audio/mpeg"] +CHIRP3_HD_MODEL = "Chirp3-HD" +CHIRP3_USER_AGENT = "cloud-solutions/comfyui-chirp3-custom-node-v1" GEMINI_USER_AGENT = "cloud-solutions/comfyui-gemini-custom-node-v1" GEMINI_25_FLASH_IMAGE_ASPECT_RATIO = [ "1:1", diff --git a/modules/python/src/custom_nodes/google_genmedia/lyria2_api.py b/modules/python/src/custom_nodes/google_genmedia/lyria2_api.py index d19cf5d5e..6a9f81197 100644 --- a/modules/python/src/custom_nodes/google_genmedia/lyria2_api.py +++ b/modules/python/src/custom_nodes/google_genmedia/lyria2_api.py @@ -12,10 +12,9 @@ # See the License for the specific language governing permissions and # limitations under the License. -# This is a preview version of veo2 custom node # Copyright 2025 Google LLC # (license header) - +import base64 from typing import Optional from google.api_core.gapic_v1.client_info import ClientInfo @@ -24,7 +23,7 @@ from . import utils from .base import VertexAIClient from .constants import LYRIA2_MODEL, LYRIA2_USER_AGENT -from .custom_exceptions import APIExecutionError, APIInputError, ConfigurationError +from .custom_exceptions import APIExecutionError, ConfigurationError from .logger import get_node_logger from .retry import api_error_retry @@ -89,19 +88,46 @@ def generate_music_from_text( APIInputError: If input parameters are invalid. APIExecutionError: If music generation fails due to API or unexpected issues. """ + waveforms = [] + sample_rate = None instance = {"prompt": str(prompt)} if negative_prompt: instance["negative_prompt"] = str(negative_prompt) if seed > 0: instance["seed"] = seed instance["sample_count"] = 1 - logger.info("Seed is greater than 0, setting sample_count to 1.") + logger.info( + "Lyria Node: Seed is greater than 0, setting sample_count to 1." 
+ ) else: instance["sample_count"] = sample_count - logger.info(f"Instance: {instance}") + logger.info(f"Lyria Node: Instance: {instance}") response = self.client.predict( endpoint=self.model_endpoint, instances=[instance] ) - logger.info(f"Response received from model: {response.model_display_name}") + logger.info( + f"Lyria Node: Response received from model: {response.model_display_name}" + ) + for n, prediction in enumerate(response.predictions): + prediction_dict = dict(prediction) + audio_bytes = base64.b64decode(prediction_dict["bytesBase64Encoded"]) + + try: + waveform, current_sample_rate = utils.load_audio_from_bytes(audio_bytes) + + if sample_rate is None: + sample_rate = current_sample_rate + elif sample_rate != current_sample_rate: + logger.warning( + f"Mismatch in sample rates at index {n}. Expected {sample_rate}, got {current_sample_rate}. " + "Proceeding, but this might cause playback issues." + ) + + waveforms.append(waveform) + + except Exception as e: + raise APIExecutionError( + f"An unexpected error occurred while processing audio sample {n}: {e}" + ) from e - return utils.process_audio_response(response) + return utils.process_audio_response(waveforms, sample_rate) diff --git a/modules/python/src/custom_nodes/google_genmedia/requirements.txt b/modules/python/src/custom_nodes/google_genmedia/requirements.txt index 79c4929f8..a864d321a 100644 --- a/modules/python/src/custom_nodes/google_genmedia/requirements.txt +++ b/modules/python/src/custom_nodes/google_genmedia/requirements.txt @@ -19,5 +19,6 @@ google-cloud-aiplatform==1.111.0 google-cloud-storage==2.19.0 google-genai==1.46.0 google-generativeai==0.8.5 +google-cloud-texttospeech==2.31.0 moviepy==2.2.1 opencv-python-headless==4.11.0.86 diff --git a/modules/python/src/custom_nodes/google_genmedia/utils.py b/modules/python/src/custom_nodes/google_genmedia/utils.py index 57f74ec45..ad8786fc3 100644 --- a/modules/python/src/custom_nodes/google_genmedia/utils.py +++ 
@api_error_retry
def generate_speech_from_text(
    client: texttospeech.TextToSpeechClient,
    sample_rate: int,
    speed: float,
    synthesis_input_params: dict,
    voice_params: dict,
    volume_gain_db: float,
) -> tuple[bytes, int]:
    """
    Synthesizes speech and returns the raw audio binary content and sample rate.

    Args:
        client: The TextToSpeechClient instance.
        sample_rate: The desired sample rate of the audio, in hertz.
        speed: The speaking rate of the synthesized speech.
        synthesis_input_params: Keyword arguments for
            ``texttospeech.SynthesisInput`` (for example ``{"text": "..."}``,
            ``{"ssml": "..."}`` or ``{"markup": "..."}``).
        voice_params: Keyword arguments for
            ``texttospeech.VoiceSelectionParams`` (for example
            ``{"language_code": ..., "name": ...}``).
        volume_gain_db: The volume gain in dB.

    Returns:
        A tuple containing the raw audio content (bytes) requested with
        LINEAR16 encoding, and the requested sample rate (int) echoed back.

    Raises:
        Errors from the Text-to-Speech client, as handled by the
        ``api_error_retry`` decorator.
        NOTE(review): the decorator is defined elsewhere; confirm which
        exception types it surfaces to callers.
    """
    synthesis_input = texttospeech.SynthesisInput(**synthesis_input_params)
    voice = texttospeech.VoiceSelectionParams(**voice_params)
    audio_config = texttospeech.AudioConfig(
        audio_encoding=texttospeech.AudioEncoding.LINEAR16,
        sample_rate_hertz=sample_rate,
        speaking_rate=speed,
        volume_gain_db=volume_gain_db,
    )
    # Log through the module logger instead of print so output follows the
    # package's logging configuration.
    logger.info(" - Audio Config: %s", audio_config)

    response = client.synthesize_speech(
        input=synthesis_input, voice=voice, audio_config=audio_config
    )
    logger.info(
        " - Received audio content of length: %d bytes", len(response.audio_content)
    )
    return response.audio_content, sample_rate
def get_tts_voices_and_languages(
    model_to_include: Optional[str] = None,
) -> Tuple[List[str], List[str], Dict[str, str]]:
    """
    Fetches and filters Google TTS voices, caching the master list in memory
    to ensure the API is only called once per session.

    The master list is stored as an attribute on this function. A failed
    fetch caches an empty list, so the API is not retried within the session.

    Args:
        model_to_include: If provided, only voices containing this string will be returned.
                          If None, voices containing CHIRP3_HD_MODEL will be excluded.

    Returns:
        A tuple of (voice_display_names, language_codes, voice_id_map).
        When no voice is available the tuple is
        ``(["(No voices found)"], ["en-US"], {})``.
    """
    if not hasattr(get_tts_voices_and_languages, "all_voices"):
        try:
            logger.info("Fetching all TTS voices from Google API for caching...")
            client = texttospeech.TextToSpeechClient()
            response = client.list_voices()
            get_tts_voices_and_languages.all_voices = response.voices
            logger.info(f"Successfully cached {len(response.voices)} voices from API.")
        except Exception as e:
            # Deliberate best-effort: the node list must still load without auth.
            logger.warning(f"Failed to fetch TTS lists from API (likely no auth): {e}")
            get_tts_voices_and_languages.all_voices = []  # Cache empty list on failure

    all_voices = get_tts_voices_and_languages.all_voices

    if not all_voices:
        return ["(No voices found)"], ["en-US"], {}

    # Explicitly filter the voices first
    if model_to_include:
        # If a model is specified, include only voices for that model.
        filtered_voices = [
            voice for voice in all_voices if model_to_include in voice.name
        ]
        # Report the model actually requested, not the Chirp3 HD constant.
        filter_log = f"including '{model_to_include}'"
    else:
        # Otherwise, exclude the Chirp3 HD voices by default.
        filtered_voices = [
            voice for voice in all_voices if CHIRP3_HD_MODEL not in voice.name
        ]
        filter_log = f"excluding '{CHIRP3_HD_MODEL}'"

    # Handle case where filter results in an empty list
    if not filtered_voices:
        logger.warning(f"TTS voice filter ({filter_log}) resulted in an empty list.")
        return ["(No voices found)"], ["en-US"], {}

    # Process the filtered list to create the final outputs
    voice_map: Dict[str, str] = {}
    lang_set = set()
    gender_map = {
        texttospeech.SsmlVoiceGender.MALE: "Male",
        texttospeech.SsmlVoiceGender.FEMALE: "Female",
        texttospeech.SsmlVoiceGender.NEUTRAL: "Neutral",
    }

    for voice in filtered_voices:
        gender_raw = getattr(voice, "ssml_gender", texttospeech.SsmlVoiceGender.NEUTRAL)
        gender = gender_map.get(gender_raw, "Neutral")

        if model_to_include:
            # For included models (like Chirp), use the last "-" separated
            # component of the voice name as the short name.
            voice_id = voice.name.rsplit("-", 1)[-1]
        else:
            # For other voices, use the full name.
            voice_id = voice.name
        voice_map[f"{voice_id} ({gender})"] = voice_id

        lang_set.update(voice.language_codes)

    voice_list = sorted(voice_map)
    lang_list = sorted(lang_set)

    logger.info(
        f"Filtered to {len(voice_list)} voices and {len(lang_list)} languages ({filter_log})."
    )
    return voice_list, lang_list, voice_map
def load_audio_from_bytes(audio_bytes: bytes) -> Tuple[torch.Tensor, int]:
    """
    Loads audio from WAV-formatted bytes into a normalized torch tensor.
    Supports 8-bit (unsigned), 16-bit (signed) and 32-bit (signed) PCM WAV data.

    Args:
        audio_bytes: The raw audio data in WAV format.

    Returns:
        A tuple containing:
            - waveform (torch.Tensor): float32 audio data with shape
              (Channels, Time), scaled to the range [-1.0, 1.0).
            - sample_rate (int): The sample rate of the audio.

    Raises:
        APIExecutionError: If the audio data cannot be read or has an unsupported format.
    """
    # sample width in bytes -> (dtype, zero offset, full-scale divisor).
    # WAV PCM is little-endian; 8-bit samples are unsigned and centred on 128.
    pcm_formats = {
        1: (np.dtype(np.uint8), 128.0, 128.0),
        2: (np.dtype("<i2"), 0.0, 32768.0),
        4: (np.dtype("<i4"), 0.0, 2147483648.0),
    }

    try:
        with wave.open(io.BytesIO(audio_bytes), "rb") as wf:
            sample_rate = wf.getframerate()
            n_channels = wf.getnchannels()
            sampwidth = wf.getsampwidth()
            frames = wf.readframes(wf.getnframes())
    except (wave.Error, EOFError) as e:
        raise APIExecutionError(
            "Failed to read audio data as WAV. The data might be in a different format."
        ) from e
    except Exception as e:
        raise APIExecutionError(
            f"An unexpected error occurred while loading audio bytes: {e}"
        ) from e

    # Raised outside the try blocks so the specific message is not re-wrapped
    # as an "unexpected error".
    if sampwidth not in pcm_formats:
        raise APIExecutionError(
            f"Unsupported sample width for WAV: {sampwidth} bytes. "
            "Only 8-bit, 16-bit and 32-bit are supported by this implementation."
        )
    dtype, offset, scale = pcm_formats[sampwidth]

    try:
        samples = np.frombuffer(frames, dtype=dtype).astype(np.float32)
        if offset:
            samples = samples - offset
        samples = samples / scale

        waveform_tensor = torch.from_numpy(samples)

        if n_channels > 1:
            # Frames are interleaved (T, C); transpose to (C, T).
            waveform_tensor = waveform_tensor.reshape(-1, n_channels).transpose(0, 1)
        else:
            waveform_tensor = waveform_tensor.reshape(1, -1)
    except Exception as e:
        raise APIExecutionError(
            f"An unexpected error occurred while loading audio bytes: {e}"
        ) from e

    return waveform_tensor, sample_rate
""" - if not response.predictions: - raise APIExecutionError("No predictions found in the API response.") - - waveforms = [] - sample_rate = None - - logger.info(f"Found {len(response.predictions)} audio clips to process.") - for n, prediction in enumerate(response.predictions): - prediction_dict = dict(prediction) - audio_bytes = base64.b64decode(prediction_dict["bytesBase64Encoded"]) - - buffer = io.BytesIO(audio_bytes) - - try: - with wave.open(buffer, "rb") as wf: - if sample_rate is None: - sample_rate = wf.getframerate() - elif sample_rate != wf.getframerate(): - logger.warning( - f"Mismatch in sample rates. Expected {sample_rate}, got {wf.getframerate()}. Using the first sample rate." - ) - - n_channels = wf.getnchannels() - sampwidth = wf.getsampwidth() - n_frames = wf.getnframes() - - frames = wf.readframes(n_frames) - - if sampwidth == 2: - dtype = np.int16 - elif sampwidth == 1: - dtype = np.uint8 # 8-bit is usually unsigned - else: - raise APIExecutionError( - f"Unsupported sample width for WAV: {sampwidth} bytes. Only 8-bit and 16-bit are supported by this implementation." - ) - - waveform_np = np.frombuffer(frames, dtype=dtype) - - # Normalize - if dtype == np.int16: - waveform_np = waveform_np.astype(np.float32) / 32768.0 - elif dtype == np.uint8: - waveform_np = (waveform_np.astype(np.float32) - 128.0) / 128.0 - - # Reshape to (T, C) and then convert to tensor - waveform_tensor = torch.from_numpy(waveform_np) - waveform_tensor = waveform_tensor.reshape(-1, n_channels) - - # Transpose to (C, T) - waveform_tensor = waveform_tensor.transpose(0, 1) - - waveforms.append(waveform_tensor) - - except wave.Error as e: - raise APIExecutionError( - "Failed to read audio data as WAV file. The API might have returned a different format, which requires ffmpeg." 
- ) from e - except Exception as e: - raise APIExecutionError( - f"An unexpected error occurred while processing audio sample {n}: {e}" - ) from e if not waveforms: raise APIExecutionError( diff --git a/platforms/gke/base/use-cases/inference-ref-arch/terraform/comfyui/project.tf b/platforms/gke/base/use-cases/inference-ref-arch/terraform/comfyui/project.tf index f82e0126a..9eb0e8961 100644 --- a/platforms/gke/base/use-cases/inference-ref-arch/terraform/comfyui/project.tf +++ b/platforms/gke/base/use-cases/inference-ref-arch/terraform/comfyui/project.tf @@ -25,6 +25,7 @@ resource "google_project_service" "cluster" { "aiplatform.googleapis.com", "cloudbuild.googleapis.com", "iap.googleapis.com", + "texttospeech.googleapis.com", ]) disable_dependent_services = false diff --git a/platforms/gke/base/use-cases/inference-ref-arch/terraform/comfyui/src/comfyui-workflows/chirp3-text-to-speech.json b/platforms/gke/base/use-cases/inference-ref-arch/terraform/comfyui/src/comfyui-workflows/chirp3-text-to-speech.json new file mode 100644 index 000000000..f339b21d6 --- /dev/null +++ b/platforms/gke/base/use-cases/inference-ref-arch/terraform/comfyui/src/comfyui-workflows/chirp3-text-to-speech.json @@ -0,0 +1,102 @@ +{ + "id": "ea4eda08-0e1b-4856-8417-e7ef5d3bafbb", + "revision": 0, + "last_node_id": 2, + "last_link_id": 1, + "nodes": [ + { + "id": 2, + "type": "PreviewAudio", + "pos": [ + 616.00390625, + 191.57421875 + ], + "size": [ + 270, + 88 + ], + "flags": {}, + "order": 1, + "mode": 0, + "inputs": [ + { + "name": "audio", + "type": "AUDIO", + "link": 1 + } + ], + "outputs": [], + "properties": { + "cnr_id": "comfy-core", + "ver": "0.3.66", + "Node name for S&R": "PreviewAudio" + }, + "widgets_values": [] + }, + { + "id": 1, + "type": "Chirp3Node", + "pos": [ + 97.59765625, + 177.49609375 + ], + "size": [ + 400, + 256 + ], + "flags": {}, + "order": 0, + "mode": 0, + "inputs": [], + "outputs": [ + { + "name": "audio", + "type": "AUDIO", + "links": [ + 1 + ] + } + ], + 
"properties": { + "Node name for S&R": "Chirp3Node" + }, + "widgets_values": [ + "Hello! I am a generative voice, designed by Google.", + "en-US", + "Achernar (Female)", + 24000, + 1, + 0, + "", + "" + ] + } + ], + "links": [ + [ + 1, + 1, + 0, + 2, + 0, + "AUDIO" + ] + ], + "groups": [], + "config": {}, + "extra": { + "ds": { + "scale": 1, + "offset": [ + 0, + 0 + ] + }, + "frontendVersion": "1.28.7", + "VHS_latentpreview": false, + "VHS_latentpreviewrate": 0, + "VHS_MetadataImage": true, + "VHS_KeepIntermediate": true + }, + "version": 0.4 +} diff --git a/test/ci-cd/scripts/comfyui/workflows/chirp3-text-to-speech.json b/test/ci-cd/scripts/comfyui/workflows/chirp3-text-to-speech.json new file mode 100644 index 000000000..87bdff047 --- /dev/null +++ b/test/ci-cd/scripts/comfyui/workflows/chirp3-text-to-speech.json @@ -0,0 +1,31 @@ +{ + "1": { + "inputs": { + "text": "Hello! I am a generative voice, designed by Google.", + "language_code": "ar-XA", + "voice_name": "Achernar (Female)", + "sample_rate": 24000, + "speed": 1, + "volume_gain_db": 0, + "gcp_project_id": "", + "gcp_region": "" + }, + "class_type": "Chirp3Node", + "_meta": { + "title": "Chirp3 HD" + } + }, + "2": { + "inputs": { + "audioUI": "", + "audio": [ + "1", + 0 + ] + }, + "class_type": "PreviewAudio", + "_meta": { + "title": "PreviewAudio" + } + } +}