From 36a189c5c3aa48ac336371bfd152fd1b7b361d39 Mon Sep 17 00:00:00 2001 From: Ameenah Burhan Date: Thu, 13 Nov 2025 07:20:07 +0000 Subject: [PATCH 1/8] initial chirp3hd node commit --- .../custom_nodes/google_genmedia/__init__.py | 9 +- .../google_genmedia/chirp3hd_api.py | 125 +++++++++ .../google_genmedia/chirp3hd_node.py | 177 ++++++++++++ .../custom_nodes/google_genmedia/constants.py | 2 + .../google_genmedia/lyria2_api.py | 40 ++- .../google_genmedia/requirements.txt | 1 + .../src/custom_nodes/google_genmedia/utils.py | 265 +++++++++++++----- 7 files changed, 544 insertions(+), 75 deletions(-) create mode 100644 modules/python/src/custom_nodes/google_genmedia/chirp3hd_api.py create mode 100644 modules/python/src/custom_nodes/google_genmedia/chirp3hd_node.py diff --git a/modules/python/src/custom_nodes/google_genmedia/__init__.py b/modules/python/src/custom_nodes/google_genmedia/__init__.py index 336febe3e..56d4c2c05 100644 --- a/modules/python/src/custom_nodes/google_genmedia/__init__.py +++ b/modules/python/src/custom_nodes/google_genmedia/__init__.py @@ -87,7 +87,12 @@ def setup_custom_package_logger(): setup_custom_package_logger() - +from .chirp3hd_node import ( + NODE_CLASS_MAPPINGS as CHIRP3_HD_NODE_CLASS_MAPPINGS, +) +from .chirp3hd_node import ( + NODE_DISPLAY_NAME_MAPPINGS as CHIRP3_HD_NODE_DISPLAY_NAME_MAPPINGS, +) from .gemini_flash_image_node import ( NODE_CLASS_MAPPINGS as GEMINI_FLASH_25_IMAGE_NODE_CLASS_MAPPINGS, ) @@ -123,6 +128,7 @@ def setup_custom_package_logger(): # Combine all node class mappings NODE_CLASS_MAPPINGS = { + **CHIRP3_HD_NODE_CLASS_MAPPINGS, **IMAGEN3_NODE_CLASS_MAPPINGS, **IMAGEN4_NODE_CLASS_MAPPINGS, **LYRIA2_NODE_CLASS_MAPPINGS, @@ -136,6 +142,7 @@ def setup_custom_package_logger(): # Combine all node display name mappings NODE_DISPLAY_NAME_MAPPINGS = { + **CHIRP3_HD_NODE_DISPLAY_NAME_MAPPINGS, **IMAGEN3_NODE_DISPLAY_NAME_MAPPINGS, **IMAGEN4_NODE_DISPLAY_NAME_MAPPINGS, **LYRIA2_NODE_DISPLAY_NAME_MAPPINGS, diff --git a/modules/python/src/custom_nodes/google_genmedia/chirp3hd_api.py b/modules/python/src/custom_nodes/google_genmedia/chirp3hd_api.py new file mode 100644 index 000000000..281984a44 --- /dev/null +++ b/modules/python/src/custom_nodes/google_genmedia/chirp3hd_api.py @@ -0,0 +1,125 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import logging +import re +from typing import Optional + +from google.api_core.gapic_v1.client_info import ClientInfo +from google.cloud import texttospeech + +from . import utils +from .base import VertexAIClient +from .constants import CHIRP3_USER_AGENT +from .custom_exceptions import ConfigurationError + +logger = logging.getLogger(__name__) + + +def classify_string(text: str) -> dict: + """ + Classifies a string as 'ssml', 'markup', or 'text'. + + Args: + text: The string to classify. + + Returns: + The dictionary with classification. + """ + if re.search(r"<.*?>", text): + return {"ssml": text} + elif re.search(r"\[.*?\]", text): + return {"markup": text} + else: + return {"text": text} + + +class Chirp3API(VertexAIClient): + """ + Handles all communication with the Google Cloud Text-to-Speech API for Chirp3 models. + """ + + def __init__( + self, + project_id: Optional[str] = None, + region: Optional[str] = None, + user_agent: Optional[str] = CHIRP3_USER_AGENT, + ): + """ + Initializes the TextToSpeechClient. + + Args: + project_id: The GCP project ID. Overrides metadata lookup. + region: The GCP region. Overrides metadata lookup. + user_agent: The user agent to use for the client. + + Raises: + ConfigurationError: If client initialization fails. + """ + super().__init__(gcp_project_id=project_id, gcp_region=region) + + try: + client_info = ClientInfo(user_agent=user_agent) if user_agent else None + self.client = texttospeech.TextToSpeechClient(client_info=client_info) + print(f"Initialized Google TTS Client. with User-Agent: {user_agent}") + except Exception as e: + raise ConfigurationError( + "Failed to initialize Google TTS Client. " + "Please ensure you are authenticated (e.g., `gcloud auth application-default login`). " + f"Error: {e}" + ) + + def synthesize( + self, + language_code: str, + sample_rate: int, + speed: float, + text: str, + voice_name: str, + volume_gain_db: float, + ) -> tuple[bytes, int]: + """ + Synthesizes speech and returns the raw audio binary content and sample rate. + + Args: + text: The text to synthesize. + language_code: The language code for the voice. + voice_name: The name of the voice to use. + sample_rate: The desired sample rate of the audio. + speed: The speaking rate of the synthesized speech. + volume_gain_db: The volume gain in dB. + + Returns: + A tuple containing the raw audio content (bytes) and the sample rate (int). + + + Raises: + APIInputError: If input parameters are invalid. + APIExecutionError: If music generation fails due to API or unexpected issues. + """ + + synthesis_input_params = classify_string(text) + voice_params = {"language_code": language_code, "name": voice_name} + + logger.info(f" - Synthesis Input: {synthesis_input_params}") + logger.info(f" - Voice Params: {voice_params}") + + return utils.generate_speech_from_text( + client=self.client, + sample_rate=sample_rate, + speed=speed, + synthesis_input_params=synthesis_input_params, + voice_params=voice_params, + volume_gain_db=volume_gain_db, + ) diff --git a/modules/python/src/custom_nodes/google_genmedia/chirp3hd_node.py b/modules/python/src/custom_nodes/google_genmedia/chirp3hd_node.py new file mode 100644 index 000000000..17df4a025 --- /dev/null +++ b/modules/python/src/custom_nodes/google_genmedia/chirp3hd_node.py @@ -0,0 +1,177 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import Any, Dict, Optional, Tuple + +from .chirp3hd_api import Chirp3API +from .constants import CHIRP3_HD_MODEL +from .custom_exceptions import APIExecutionError, ConfigurationError +from .utils import get_tts_voices_and_languages, load_audio_from_bytes + + +class Chirp3Node: + """ + A ComfyUI node for Google's Text-to-Speech API, specifically for the Chirp3 HD model. + It synthesizes text into speech and returns an audio tensor. + """ + + def __init__(self) -> None: + """ + Initializes the Chirp3Node. + """ + pass + + @classmethod + def INPUT_TYPES(cls) -> Dict[str, Dict[str, Any]]: + """ + Defines the input types and widgets for the ComfyUI node. + + Returns: + A dictionary specifying the required and optional input parameters. + """ + dynamic_voices, dynamic_langs, voice_map = get_tts_voices_and_languages( + model_to_include=CHIRP3_HD_MODEL + ) + cls.voice_id_map = voice_map + return { + "required": { + "text": ( + "STRING", + { + "label_on": True, + "multiline": True, + "default": "Hello! I am a generative voice, designed by Google.", + "placeholder": "Enter the text to synthesize...", + }, + ), + "language_code": ( + dynamic_langs, + { + "default": ( + dynamic_langs[16] if len(dynamic_langs) > 16 else "en-US" + ) + }, + ), + "voice_name": ( + dynamic_voices, + { + "default": ( + dynamic_voices[0] + if dynamic_voices + else "en-US-Chirp3-HD-Charon" + ) + }, + ), + "sample_rate": ( + "INT", + {"default": 24000, "min": 8000, "max": 48000, "step": 10}, + ), + "speed": ( + "FLOAT", + {"default": 1.0, "min": 0.25, "max": 4.0, "step": 0.05}, + ), + "volume_gain_db": ( + "FLOAT", + {"default": 0.0, "min": -96.0, "max": 16.0, "step": 0.1}, + ), + }, + "optional": { + "gcp_project_id": ( + "STRING", + {"default": "", "placeholder": "your-gcp-project-id"}, + ), + "gcp_region": ( + "STRING", + {"default": "", "placeholder": "us-central1"}, + ), + }, + } + + RETURN_TYPES = ("AUDIO",) + RETURN_NAMES = ("audio",) + FUNCTION = "execute_synthesis" + CATEGORY = "Google AI/Text-to-Speech" + + def execute_synthesis( + self, + gcp_project_id: Optional[str], + gcp_region: Optional[str], + language_code: str, + sample_rate: int, + speed: float, + text: str, + voice_name: str, + volume_gain_db: float, + ) -> Tuple[Dict[str, Any],]: + """ + Executes the text-to-speech synthesis process using the Chirp3 HD model. + + Args: + gcp_project_id: The GCP project ID. + gcp_region: The GCP region. + language_code: The language and region of the voice. + sample_rate: The desired sample rate for the audio. + speed: The speaking rate of the audio. + text: The text to be synthesized into speech. + voice_name: The name of the voice to use. + volume_gain_db: The volume gain in dB. + + Returns: + A tuple containing a dictionary with the audio waveform and sample rate. + + Raises: + ConfigurationError: If the input text is empty. + RuntimeError: For errors during API communication or audio processing. + """ + short_voice_name = self.voice_id_map.get(voice_name) + if not short_voice_name: + raise ConfigurationError( + f"Voice ID lookup failed for selected voice: {voice_name}. " + "The voice list may be corrupt or outdated." + ) + if not text or not text.strip(): + raise ConfigurationError("Text input cannot be empty.") + + # Reconstruct the full voice name as per user's specified logic + full_voice_name = f"{language_code}-{CHIRP3_HD_MODEL}-{short_voice_name}" + + try: + api_client = Chirp3API(project_id=gcp_project_id, region=gcp_region) + + audio_data_binary, _ = api_client.synthesize( + language_code=language_code, + sample_rate=sample_rate, + speed=speed, + text=text, + voice_name=full_voice_name, + volume_gain_db=volume_gain_db, + ) + + if audio_data_binary is None or len(audio_data_binary) == 0: + raise APIExecutionError("API call returned no audio data.") + + waveform, r_sample_rate = load_audio_from_bytes(audio_data_binary) + output_audio = { + "waveform": waveform.unsqueeze(0), + "sample_rate": r_sample_rate, + } + return (output_audio,) + + except (ConfigurationError, APIExecutionError) as e: + raise RuntimeError(str(e)) from e + except Exception as e: + raise RuntimeError(f"An unexpected error occurred: {e}") from e + + +NODE_CLASS_MAPPINGS = {"Chirp3Node": Chirp3Node} +NODE_DISPLAY_NAME_MAPPINGS = {"Chirp3Node": "Chirp3 HD"} diff --git a/modules/python/src/custom_nodes/google_genmedia/constants.py b/modules/python/src/custom_nodes/google_genmedia/constants.py index 6ebf9f498..f8ace5bd6 100644 --- a/modules/python/src/custom_nodes/google_genmedia/constants.py +++ b/modules/python/src/custom_nodes/google_genmedia/constants.py @@ -19,6 +19,8 @@ from google.genai import types AUDIO_MIME_TYPES = ["audio/mp3", "audio/wav", "audio/mpeg"] +CHIRP3_HD_MODEL = "Chirp3-HD" +CHIRP3_USER_AGENT = "cloud-solutions/comfyui-chirp3-custom-node-v1" GEMINI_USER_AGENT = "cloud-solutions/comfyui-gemini-custom-node-v1" GEMINI_25_FLASH_IMAGE_ASPECT_RATIO = [ "1:1", diff --git a/modules/python/src/custom_nodes/google_genmedia/lyria2_api.py b/modules/python/src/custom_nodes/google_genmedia/lyria2_api.py index d19cf5d5e..6a9f81197 100644 --- a/modules/python/src/custom_nodes/google_genmedia/lyria2_api.py +++ b/modules/python/src/custom_nodes/google_genmedia/lyria2_api.py @@ -12,10 +12,9 @@ # See the License for the specific language governing permissions and # limitations under the License. -# This is a preview version of veo2 custom node # Copyright 2025 Google LLC # (license header) - +import base64 from typing import Optional from google.api_core.gapic_v1.client_info import ClientInfo @@ -24,7 +23,7 @@ from . import utils from .base import VertexAIClient from .constants import LYRIA2_MODEL, LYRIA2_USER_AGENT -from .custom_exceptions import APIExecutionError, APIInputError, ConfigurationError +from .custom_exceptions import APIExecutionError, ConfigurationError from .logger import get_node_logger from .retry import api_error_retry @@ -89,19 +88,46 @@ def generate_music_from_text( APIInputError: If input parameters are invalid. APIExecutionError: If music generation fails due to API or unexpected issues. """ + waveforms = [] + sample_rate = None instance = {"prompt": str(prompt)} if negative_prompt: instance["negative_prompt"] = str(negative_prompt) if seed > 0: instance["seed"] = seed instance["sample_count"] = 1 - logger.info("Seed is greater than 0, setting sample_count to 1.") + logger.info( + "Lyria Node: Seed is greater than 0, setting sample_count to 1." + ) else: instance["sample_count"] = sample_count - logger.info(f"Instance: {instance}") + logger.info(f"Lyria Node: Instance: {instance}") response = self.client.predict( endpoint=self.model_endpoint, instances=[instance] ) - logger.info(f"Response received from model: {response.model_display_name}") + logger.info( + f"Lyria Node: Response received from model: {response.model_display_name}" + ) + for n, prediction in enumerate(response.predictions): + prediction_dict = dict(prediction) + audio_bytes = base64.b64decode(prediction_dict["bytesBase64Encoded"]) + + try: + waveform, current_sample_rate = utils.load_audio_from_bytes(audio_bytes) + + if sample_rate is None: + sample_rate = current_sample_rate + elif sample_rate != current_sample_rate: + logger.warning( + f"Mismatch in sample rates at index {n}. Expected {sample_rate}, got {current_sample_rate}. " + "Proceeding, but this might cause playback issues." + ) + + waveforms.append(waveform) + + except Exception as e: + raise APIExecutionError( + f"An unexpected error occurred while processing audio sample {n}: {e}" + ) from e - return utils.process_audio_response(response) + return utils.process_audio_response(waveforms, sample_rate) diff --git a/modules/python/src/custom_nodes/google_genmedia/requirements.txt b/modules/python/src/custom_nodes/google_genmedia/requirements.txt index 79c4929f8..a864d321a 100644 --- a/modules/python/src/custom_nodes/google_genmedia/requirements.txt +++ b/modules/python/src/custom_nodes/google_genmedia/requirements.txt @@ -19,5 +19,6 @@ google-cloud-aiplatform==1.111.0 google-cloud-storage==2.19.0 google-genai==1.46.0 google-generativeai==0.8.5 +google-cloud-texttospeech==2.31.0 moviepy==2.2.1 opencv-python-headless==4.11.0.86 diff --git a/modules/python/src/custom_nodes/google_genmedia/utils.py b/modules/python/src/custom_nodes/google_genmedia/utils.py index 57f74ec45..2ee6c20bf 100644 --- a/modules/python/src/custom_nodes/google_genmedia/utils.py +++ b/modules/python/src/custom_nodes/google_genmedia/utils.py @@ -21,7 +21,7 @@ import time import wave from io import BytesIO -from typing import Any, List, Optional, Tuple +from typing import Any, Dict, List, Optional, Tuple import folder_paths import numpy as np @@ -29,14 +29,14 @@ from google import genai from google.api_core import exceptions as api_core_exceptions from google.api_core.client_info import ClientInfo -from google.cloud import storage +from google.cloud import storage, texttospeech from google.genai import errors as genai_errors from google.genai import types from google.genai.types import GenerateVideosConfig, Image from grpc import StatusCode from PIL import Image as PIL_Image -from .constants import STORAGE_USER_AGENT +from .constants import CHIRP3_HD_MODEL, STORAGE_USER_AGENT from .custom_exceptions import APIExecutionError, APIInputError from .logger import get_node_logger from .retry import api_error_retry @@ -195,6 +195,51 @@ def generate_image_from_text( return generated_pil_images +@api_error_retry +def generate_speech_from_text( + client: texttospeech.TextToSpeechClient, + sample_rate: int, + speed: float, + synthesis_input_params: dict, + voice_params: dict, + volume_gain_db: float, +) -> tuple[bytes, int]: + """ + Synthesizes speech and returns the raw audio binary content and sample rate. + + Args: + client: The TextToSpeechClient instance. + language_code: The language code for the voice. + sample_rate: The desired sample rate of the audio. + speed: The speaking rate of the synthesized speech. + synthesize: Synthesis input. + voice_params: Voice selection params. + volume_gain_db: The volume gain in dB. + + Returns: + A tuple containing the raw audio content (bytes) and the sample rate (int). + + Raises: + APIInputError: If the specified `file_path` does not exist. + APIExecutionError: If an error occurs during the file reading or conversion process. + """ + synthesis_input = texttospeech.SynthesisInput(**synthesis_input_params) + voice = texttospeech.VoiceSelectionParams(**voice_params) + audio_config = texttospeech.AudioConfig( + audio_encoding=texttospeech.AudioEncoding.LINEAR16, + sample_rate_hertz=sample_rate, + speaking_rate=speed, + volume_gain_db=volume_gain_db, + ) + print(f" - Audio Config: {audio_config}") + + response = client.synthesize_speech( + input=synthesis_input, voice=voice, audio_config=audio_config + ) + print(f" - Received audio content of length: {len(response.audio_content)} bytes") + return response.audio_content, sample_rate + + @api_error_retry def generate_video_from_gcsuri_image( client: genai.Client, @@ -670,6 +715,153 @@ def generate_video_from_text( return process_video_response(operation) +def get_tts_voices_and_languages( + model_to_include: Optional[str] = None, +) -> Tuple[List[str], List[str], Dict[str, str]]: + """ + Fetches and filters Google TTS voices, caching the master list in memory + to ensure the API is only called once per session. + + Args: + model_to_include: If provided, only voices containing this string will be returned. + If None, voices containing CHIRP3_HD_MODEL will be excluded. + + Returns: + A tuple of (voice_display_names, language_codes, voice_id_map). + """ + if not hasattr(get_tts_voices_and_languages, "all_voices"): + try: + logger.info("Fetching all TTS voices from Google API for caching...") + client = texttospeech.TextToSpeechClient() + response = client.list_voices() + get_tts_voices_and_languages.all_voices = response.voices + logger.info(f"Successfully cached {len(response.voices)} voices from API.") + except Exception as e: + logger.warning(f"Failed to fetch TTS lists from API (likely no auth): {e}") + get_tts_voices_and_languages.all_voices = [] # Cache empty list on failure + + all_voices = get_tts_voices_and_languages.all_voices + + if not all_voices: + return ["Error"], ["en-US"], {"Error": "en-US-Standard-A"} + + voice_map = {} + lang_set = set() + gender_map = { + texttospeech.SsmlVoiceGender.MALE: "Male", + texttospeech.SsmlVoiceGender.FEMALE: "Female", + texttospeech.SsmlVoiceGender.NEUTRAL: "Neutral", + } + + filter_log = ( + f"including '{model_to_include}'" + if model_to_include + else f"excluding '{CHIRP3_HD_MODEL}'" + ) + + for voice in all_voices: + if model_to_include: + if model_to_include not in voice.name: + continue + short_name = voice.name.split("-")[-1] + gender_raw = getattr( + voice, "ssml_gender", texttospeech.SsmlVoiceGender.NEUTRAL + ) + gender = gender_map.get(gender_raw, "Neutral") + display_name = f"{short_name} ({gender})" + voice_map[display_name] = short_name + else: # Otherwise, exclude Chirp and use the full name + if CHIRP3_HD_MODEL in voice.name: + continue + # For Gemini, use the full name and map to the full name. + # e.g., "gemini... (Female)" -> "gemini..." + gender_raw = getattr( + voice, "ssml_gender", texttospeech.SsmlVoiceGender.NEUTRAL + ) + gender = gender_map.get(gender_raw, "Neutral") + display_name = f"{voice.name} ({gender})" + voice_map[display_name] = voice.name + + lang_set.update(voice.language_codes) + + voice_list = sorted(list(voice_map.keys())) + lang_list = sorted(list(lang_set)) + + # Handle case where filter results in an empty list + if not voice_list: + logger.warning(f"TTS voice filter ({filter_log}) resulted in an empty list.") + return ["(No voices found)"], ["en-US"], {} + + logger.info( + f"Filtered to {len(voice_list)} voices and {len(lang_list)} languages ({filter_log})." + ) + return voice_list, lang_list, voice_map + + +def load_audio_from_bytes(audio_bytes: bytes) -> Tuple[torch.Tensor, int]: + """ + Loads audio from WAV-formatted bytes into a normalized torch tensor. + Supports 8-bit (unsigned) and 16-bit (signed) WAV data. + + Args: + audio_bytes: The raw audio data in WAV format. + + Returns: + A tuple containing: + - waveform (torch.Tensor): Audio data with shape (Channels, Time). + - sample_rate (int): The sample rate of the audio. + + Raises: + APIExecutionError: If the audio data cannot be read or has an unsupported format. + """ + buffer = io.BytesIO(audio_bytes) + + try: + with wave.open(buffer, "rb") as wf: + sample_rate = wf.getframerate() + n_channels = wf.getnchannels() + sampwidth = wf.getsampwidth() + n_frames = wf.getnframes() + + frames = wf.readframes(n_frames) + + if sampwidth == 2: + dtype = np.int16 + elif sampwidth == 1: + dtype = np.uint8 # 8-bit WAV is typically unsigned + else: + raise APIExecutionError( + f"Unsupported sample width for WAV: {sampwidth} bytes. Only 8-bit and 16-bit are supported by this implementation." + ) + + waveform_np = np.frombuffer(frames, dtype=dtype) + + if dtype == np.int16: + waveform_np = waveform_np.astype(np.float32) / 32768.0 + elif dtype == np.uint8: + waveform_np = (waveform_np.astype(np.float32) - 128.0) / 128.0 + + waveform_tensor = torch.from_numpy(waveform_np) + + if n_channels > 1: + waveform_tensor = waveform_tensor.reshape(-1, n_channels).transpose( + 0, 1 + ) + else: + waveform_tensor = waveform_tensor.reshape(1, -1) + + return waveform_tensor, sample_rate + + except wave.Error as e: + raise APIExecutionError( + "Failed to read audio data as WAV. The data might be in a different format." + ) from e + except Exception as e: + raise APIExecutionError( + f"An unexpected error occurred while loading audio bytes: {e}" + ) from e + + def media_file_to_genai_part(file_path: str, mime_type: str) -> types.Part: """Reads a media file (image, audio, or video) and converts it to a genai.types.Part. @@ -731,7 +923,7 @@ def prep_for_media_conversion(file_path: str, mime_type: str) -> Optional[types. return None # Return None if file not found -def process_audio_response(response: Any) -> dict: +def process_audio_response(waveforms: Any, sample_rate: int) -> dict: """ Processes the audio generation response, loads the audio into a tensor, and returns it in the format expected by ComfyUI's AUDIO output. @@ -739,7 +931,8 @@ def process_audio_response(response: Any) -> dict: It assumes the audio from the API is in WAV format. Args: - response: The completed response object from the Lyria API. + waveforms: The completed response object from the Lyria API. + sample_rate: The sample rate of the audio. Returns: A dictionary containing the audio waveform as a torch.Tensor @@ -748,68 +941,6 @@ def process_audio_response(response: Any) -> dict: Raises: APIExecutionError: If no audio data is found or if loading fails. """ - if not response.predictions: - raise APIExecutionError("No predictions found in the API response.") - - waveforms = [] - sample_rate = None - - logger.info(f"Found {len(response.predictions)} audio clips to process.") - for n, prediction in enumerate(response.predictions): - prediction_dict = dict(prediction) - audio_bytes = base64.b64decode(prediction_dict["bytesBase64Encoded"]) - - buffer = io.BytesIO(audio_bytes) - - try: - with wave.open(buffer, "rb") as wf: - if sample_rate is None: - sample_rate = wf.getframerate() - elif sample_rate != wf.getframerate(): - logger.warning( - f"Mismatch in sample rates. Expected {sample_rate}, got {wf.getframerate()}. Using the first sample rate." - ) - - n_channels = wf.getnchannels() - sampwidth = wf.getsampwidth() - n_frames = wf.getnframes() - - frames = wf.readframes(n_frames) - - if sampwidth == 2: - dtype = np.int16 - elif sampwidth == 1: - dtype = np.uint8 # 8-bit is usually unsigned - else: - raise APIExecutionError( - f"Unsupported sample width for WAV: {sampwidth} bytes. Only 8-bit and 16-bit are supported by this implementation." - ) - - waveform_np = np.frombuffer(frames, dtype=dtype) - - # Normalize - if dtype == np.int16: - waveform_np = waveform_np.astype(np.float32) / 32768.0 - elif dtype == np.uint8: - waveform_np = (waveform_np.astype(np.float32) - 128.0) / 128.0 - - # Reshape to (T, C) and then convert to tensor - waveform_tensor = torch.from_numpy(waveform_np) - waveform_tensor = waveform_tensor.reshape(-1, n_channels) - - # Transpose to (C, T) - waveform_tensor = waveform_tensor.transpose(0, 1) - - waveforms.append(waveform_tensor) - - except wave.Error as e: - raise APIExecutionError( - "Failed to read audio data as WAV file. The API might have returned a different format, which requires ffmpeg." - ) from e - except Exception as e: - raise APIExecutionError( - f"An unexpected error occurred while processing audio sample {n}: {e}" - ) from e if not waveforms: raise APIExecutionError( From a1d88e5cd8cf8404d818d017fdf09fd9acd8c6c5 Mon Sep 17 00:00:00 2001 From: Ameenah Burhan Date: Thu, 13 Nov 2025 07:28:44 +0000 Subject: [PATCH 2/8] format --- modules/python/src/custom_nodes/google_genmedia/__init__.py | 4 +--- .../python/src/custom_nodes/google_genmedia/chirp3hd_node.py | 2 +- 2 files changed, 2 insertions(+), 4 deletions(-) diff --git a/modules/python/src/custom_nodes/google_genmedia/__init__.py b/modules/python/src/custom_nodes/google_genmedia/__init__.py index 56d4c2c05..7b5fe83b0 100644 --- a/modules/python/src/custom_nodes/google_genmedia/__init__.py +++ b/modules/python/src/custom_nodes/google_genmedia/__init__.py @@ -87,9 +87,7 @@ def setup_custom_package_logger(): setup_custom_package_logger() -from .chirp3hd_node import ( - NODE_CLASS_MAPPINGS as CHIRP3_HD_NODE_CLASS_MAPPINGS, -) +from .chirp3hd_node import NODE_CLASS_MAPPINGS as CHIRP3_HD_NODE_CLASS_MAPPINGS from .chirp3hd_node import ( NODE_DISPLAY_NAME_MAPPINGS as CHIRP3_HD_NODE_DISPLAY_NAME_MAPPINGS, ) diff --git a/modules/python/src/custom_nodes/google_genmedia/chirp3hd_node.py b/modules/python/src/custom_nodes/google_genmedia/chirp3hd_node.py index 17df4a025..d60dd7e21 100644 --- a/modules/python/src/custom_nodes/google_genmedia/chirp3hd_node.py +++ b/modules/python/src/custom_nodes/google_genmedia/chirp3hd_node.py @@ -100,7 +100,7 @@ def INPUT_TYPES(cls) -> Dict[str, Dict[str, Any]]: RETURN_TYPES = ("AUDIO",) RETURN_NAMES = ("audio",) FUNCTION = "execute_synthesis" - CATEGORY = "Google AI/Text-to-Speech" + CATEGORY = "Google AI/Chirp3" def execute_synthesis( self, From 338fb9da1ceba0f3e2a357cb57db571ba23fbdfb Mon Sep 17 00:00:00 2001 From: Ameenah Burhan Date: Thu, 13 Nov 2025 02:37:19 -0500 Subject: [PATCH 3/8] Add 'ssml' and 'texttospeech' to Google Cloud dictionary --- .github/workflows/dictionary/google-cloud.txt | 2 ++ 1 file changed, 2 insertions(+) diff --git a/.github/workflows/dictionary/google-cloud.txt b/.github/workflows/dictionary/google-cloud.txt index a10355408..55a3f33ba 100644 --- a/.github/workflows/dictionary/google-cloud.txt +++ b/.github/workflows/dictionary/google-cloud.txt @@ -95,11 +95,13 @@ servicemanagement servicemesh servicenetworking serviceusage +ssml stackdriver statefulset subnetwork subnetworks superblocks +texttospeech timepicker uefi ultragpu From 4065920ed751e94fb619451e19ee88affd0f7184 Mon Sep 17 00:00:00 2001 From: Ameenah Burhan Date: Thu, 13 Nov 2025 05:11:24 -0500 Subject: [PATCH 4/8] Add texttospeech API to project configuration --- .../use-cases/inference-ref-arch/terraform/comfyui/project.tf | 1 + 1 file changed, 1 insertion(+) diff --git a/platforms/gke/base/use-cases/inference-ref-arch/terraform/comfyui/project.tf b/platforms/gke/base/use-cases/inference-ref-arch/terraform/comfyui/project.tf index f82e0126a..9eb0e8961 100644 --- a/platforms/gke/base/use-cases/inference-ref-arch/terraform/comfyui/project.tf +++ b/platforms/gke/base/use-cases/inference-ref-arch/terraform/comfyui/project.tf @@ -25,6 +25,7 @@ resource "google_project_service" "cluster" { "aiplatform.googleapis.com", "cloudbuild.googleapis.com", "iap.googleapis.com", + "texttospeech.googleapis.com", ]) disable_dependent_services = false From 744636d05338c38d02b8cf94074a5b2d8a1f8409 Mon Sep 17 00:00:00 2001 From: Ameenah Burhan Date: Thu, 13 Nov 2025 05:11:24 -0500 Subject: [PATCH 5/8] Add texttospeech API to project configuration --- .../python/src/custom_nodes/google_genmedia/chirp3hd_node.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/modules/python/src/custom_nodes/google_genmedia/chirp3hd_node.py b/modules/python/src/custom_nodes/google_genmedia/chirp3hd_node.py index d60dd7e21..29184c6b8 100644 --- a/modules/python/src/custom_nodes/google_genmedia/chirp3hd_node.py +++ b/modules/python/src/custom_nodes/google_genmedia/chirp3hd_node.py @@ -58,7 +58,7 @@ def INPUT_TYPES(cls) -> Dict[str, Dict[str, Any]]: dynamic_langs, { "default": ( - dynamic_langs[16] if len(dynamic_langs) > 16 else "en-US" + dynamic_langs[0] if len(dynamic_langs) > 16 else "en-US" ) }, ), @@ -68,7 +68,7 @@ def INPUT_TYPES(cls) -> Dict[str, Dict[str, Any]]: "default": ( dynamic_voices[0] if dynamic_voices - else "en-US-Chirp3-HD-Charon" + else "Charon" ) }, ), From 49928588414c8dc1aa8421f451fa456d6efb0ffb Mon Sep 17 00:00:00 2001 From: Ameenah Burhan Date: Thu, 13 Nov 2025 05:11:24 -0500 Subject: [PATCH 6/8] Add texttospeech API to project configuration --- .../src/custom_nodes/google_genmedia/utils.py | 54 +++++++++---------- 1 file changed, 27 insertions(+), 27 deletions(-) diff --git a/modules/python/src/custom_nodes/google_genmedia/utils.py b/modules/python/src/custom_nodes/google_genmedia/utils.py index 2ee6c20bf..ad8786fc3 100644 --- a/modules/python/src/custom_nodes/google_genmedia/utils.py +++ b/modules/python/src/custom_nodes/google_genmedia/utils.py @@ -743,8 +743,28 @@ def get_tts_voices_and_languages( all_voices = get_tts_voices_and_languages.all_voices if not all_voices: - return ["Error"], ["en-US"], {"Error": "en-US-Standard-A"} + return ["(No voices found)"], ["en-US"], {} + + # Explicitly filter the voices first + if model_to_include: + # If a model is specified, include only voices for that model. + filtered_voices = [ + voice for voice in all_voices if model_to_include in voice.name + ] + filter_log = f"including '{CHIRP3_HD_MODEL}'" + else: + # Otherwise, exclude the Chirp3 HD voices by default. + filtered_voices = [ + voice for voice in all_voices if CHIRP3_HD_MODEL not in voice.name + ] + filter_log = f"excluding '{CHIRP3_HD_MODEL}'" + + # Handle case where filter results in an empty list + if not filtered_voices: + logger.warning(f"TTS voice filter ({filter_log}) resulted in an empty list.") + return ["(No voices found)"], ["en-US"], {} + # Process the filtered list to create the final outputs voice_map = {} lang_set = set() gender_map = { @@ -753,32 +773,17 @@ def get_tts_voices_and_languages( texttospeech.SsmlVoiceGender.NEUTRAL: "Neutral", } - filter_log = ( - f"including '{model_to_include}'" - if model_to_include - else f"excluding '{CHIRP3_HD_MODEL}'" - ) + for voice in filtered_voices: + gender_raw = getattr(voice, "ssml_gender", texttospeech.SsmlVoiceGender.NEUTRAL) + gender = gender_map.get(gender_raw, "Neutral") - for voice in all_voices: if model_to_include: - if model_to_include not in voice.name: - continue + # For included models (like Chirp), use a short name. short_name = voice.name.split("-")[-1] - gender_raw = getattr( - voice, "ssml_gender", texttospeech.SsmlVoiceGender.NEUTRAL - ) - gender = gender_map.get(gender_raw, "Neutral") display_name = f"{short_name} ({gender})" voice_map[display_name] = short_name - else: # Otherwise, exclude Chirp and use the full name - if CHIRP3_HD_MODEL in voice.name: - continue - # For Gemini, use the full name and map to the full name. - # e.g., "gemini... (Female)" -> "gemini..." - gender_raw = getattr( - voice, "ssml_gender", texttospeech.SsmlVoiceGender.NEUTRAL - ) - gender = gender_map.get(gender_raw, "Neutral") + else: + # For other voices, use the full name. display_name = f"{voice.name} ({gender})" voice_map[display_name] = voice.name @@ -787,11 +792,6 @@ def get_tts_voices_and_languages( voice_list = sorted(list(voice_map.keys())) lang_list = sorted(list(lang_set)) - # Handle case where filter results in an empty list - if not voice_list: - logger.warning(f"TTS voice filter ({filter_log}) resulted in an empty list.") - return ["(No voices found)"], ["en-US"], {} - logger.info( f"Filtered to {len(voice_list)} voices and {len(lang_list)} languages ({filter_log})." ) From 5a069ec31d503139cb202b351735e6a676c24e1e Mon Sep 17 00:00:00 2001 From: Ameenah Burhan Date: Fri, 14 Nov 2025 23:30:03 +0000 Subject: [PATCH 7/8] Add workflow --- .../chirp3-text-to-speech.json | 102 ++++++++++++++++++ .../workflows/chirp3-text-to-speech.json | 31 ++++++ 2 files changed, 133 insertions(+) create mode 100644 platforms/gke/base/use-cases/inference-ref-arch/terraform/comfyui/src/comfyui-workflows/chirp3-text-to-speech.json create mode 100644 test/ci-cd/scripts/comfyui/workflows/chirp3-text-to-speech.json diff --git a/platforms/gke/base/use-cases/inference-ref-arch/terraform/comfyui/src/comfyui-workflows/chirp3-text-to-speech.json b/platforms/gke/base/use-cases/inference-ref-arch/terraform/comfyui/src/comfyui-workflows/chirp3-text-to-speech.json new file mode 100644 index 000000000..f339b21d6 --- /dev/null +++ b/platforms/gke/base/use-cases/inference-ref-arch/terraform/comfyui/src/comfyui-workflows/chirp3-text-to-speech.json @@ -0,0 +1,102 @@ +{ + "id": "ea4eda08-0e1b-4856-8417-e7ef5d3bafbb", + "revision": 0, + "last_node_id": 2, + "last_link_id": 1, + "nodes": [ + { + "id": 2, + "type": "PreviewAudio", + "pos": [ + 616.00390625, + 191.57421875 + ], + "size": [ + 270, + 88 + ], + "flags": {}, + "order": 1, + "mode": 0, + "inputs": [ + { + "name": "audio", + "type": "AUDIO", + "link": 1 + } + ], + "outputs": [], + "properties": { + "cnr_id": "comfy-core", + "ver": "0.3.66", + "Node name for S&R": "PreviewAudio" + }, + "widgets_values": [] + }, + { + "id": 1, + "type": "Chirp3Node", + "pos": [ + 97.59765625, + 177.49609375 + ], + "size": [ + 400, + 256 + ], + "flags": {}, + "order": 0, + "mode": 0, + "inputs": [], + "outputs": [ + { + "name": "audio", + "type": "AUDIO", + "links": [ + 1 + ] + } + ], + "properties": { + "Node name for S&R": "Chirp3Node" + }, + "widgets_values": [ + "Hello! I am a generative voice, designed by Google.", + "en-US", + "Achernar (Female)", + 24000, + 1, + 0, + "", + "" + ] + } + ], + "links": [ + [ + 1, + 1, + 0, + 2, + 0, + "AUDIO" + ] + ], + "groups": [], + "config": {}, + "extra": { + "ds": { + "scale": 1, + "offset": [ + 0, + 0 + ] + }, + "frontendVersion": "1.28.7", + "VHS_latentpreview": false, + "VHS_latentpreviewrate": 0, + "VHS_MetadataImage": true, + "VHS_KeepIntermediate": true + }, + "version": 0.4 +} diff --git a/test/ci-cd/scripts/comfyui/workflows/chirp3-text-to-speech.json b/test/ci-cd/scripts/comfyui/workflows/chirp3-text-to-speech.json new file mode 100644 index 000000000..87bdff047 --- /dev/null +++ b/test/ci-cd/scripts/comfyui/workflows/chirp3-text-to-speech.json @@ -0,0 +1,31 @@ +{ + "1": { + "inputs": { + "text": "Hello! I am a generative voice, designed by Google.", + "language_code": "ar-XA", + "voice_name": "Achernar (Female)", + "sample_rate": 24000, + "speed": 1, + "volume_gain_db": 0, + "gcp_project_id": "", + "gcp_region": "" + }, + "class_type": "Chirp3Node", + "_meta": { + "title": "Chirp3 HD" + } + }, + "2": { + "inputs": { + "audioUI": "", + "audio": [ + "1", + 0 + ] + }, + "class_type": "PreviewAudio", + "_meta": { + "title": "PreviewAudio" + } + } +} From 231f2264416973e868b685442227de5d3df3f4e3 Mon Sep 17 00:00:00 2001 From: Ameenah Burhan Date: Fri, 14 Nov 2025 23:34:45 +0000 Subject: [PATCH 8/8] format --- .../src/custom_nodes/google_genmedia/chirp3hd_node.py | 8 +------- 1 file changed, 1 insertion(+), 7 deletions(-) diff --git a/modules/python/src/custom_nodes/google_genmedia/chirp3hd_node.py b/modules/python/src/custom_nodes/google_genmedia/chirp3hd_node.py index 29184c6b8..f8ce37f20 100644 --- a/modules/python/src/custom_nodes/google_genmedia/chirp3hd_node.py +++ b/modules/python/src/custom_nodes/google_genmedia/chirp3hd_node.py @@ -64,13 +64,7 @@ def INPUT_TYPES(cls) -> Dict[str, Dict[str, Any]]: ), "voice_name": ( dynamic_voices, - { - "default": ( - dynamic_voices[0] - if dynamic_voices - else "Charon" - ) - }, + {"default": (dynamic_voices[0] if dynamic_voices else "Charon")}, ), "sample_rate": ( "INT",