Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions .github/workflows/dictionary/google-cloud.txt
Original file line number Diff line number Diff line change
Expand Up @@ -95,11 +95,13 @@ servicemanagement
servicemesh
servicenetworking
serviceusage
ssml
stackdriver
statefulset
subnetwork
subnetworks
superblocks
texttospeech
timepicker
uefi
ultragpu
Expand Down
7 changes: 6 additions & 1 deletion modules/python/src/custom_nodes/google_genmedia/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,10 @@ def setup_custom_package_logger():


setup_custom_package_logger()

from .chirp3hd_node import NODE_CLASS_MAPPINGS as CHIRP3_HD_NODE_CLASS_MAPPINGS
from .chirp3hd_node import (
NODE_DISPLAY_NAME_MAPPINGS as CHIRP3_HD_NODE_DISPLAY_NAME_MAPPINGS,
)
from .gemini_flash_image_node import (
NODE_CLASS_MAPPINGS as GEMINI_FLASH_25_IMAGE_NODE_CLASS_MAPPINGS,
)
Expand Down Expand Up @@ -123,6 +126,7 @@ def setup_custom_package_logger():

# Combine all node class mappings
NODE_CLASS_MAPPINGS = {
**CHIRP3_HD_NODE_CLASS_MAPPINGS,
**IMAGEN3_NODE_CLASS_MAPPINGS,
**IMAGEN4_NODE_CLASS_MAPPINGS,
**LYRIA2_NODE_CLASS_MAPPINGS,
Expand All @@ -136,6 +140,7 @@ def setup_custom_package_logger():

# Combine all node display name mappings
NODE_DISPLAY_NAME_MAPPINGS = {
**CHIRP3_HD_NODE_DISPLAY_NAME_MAPPINGS,
**IMAGEN3_NODE_DISPLAY_NAME_MAPPINGS,
**IMAGEN4_NODE_DISPLAY_NAME_MAPPINGS,
**LYRIA2_NODE_DISPLAY_NAME_MAPPINGS,
Expand Down
125 changes: 125 additions & 0 deletions modules/python/src/custom_nodes/google_genmedia/chirp3hd_api.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,125 @@
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import logging
import re
from typing import Optional

from google.api_core.gapic_v1.client_info import ClientInfo
from google.cloud import texttospeech

from . import utils
from .base import VertexAIClient
from .constants import CHIRP3_USER_AGENT
from .custom_exceptions import ConfigurationError

logger = logging.getLogger(__name__)


def classify_string(text: str) -> dict:
"""
Classifies a string as 'ssml', 'markup', or 'text'.

Args:
text: The string to classify.

Returns:
The dictionary with classification.
"""
if re.search(r"<.*?>", text):
return {"ssml": text}
elif re.search(r"\[.*?\]", text):
return {"markup": text}
else:
return {"text": text}


class Chirp3API(VertexAIClient):
"""
Handles all communication with the Google Cloud Text-to-Speech API for Chirp3 models.
"""

def __init__(
self,
project_id: Optional[str] = None,
region: Optional[str] = None,
user_agent: Optional[str] = CHIRP3_USER_AGENT,
):
"""
Initializes the TextToSpeechClient.

Args:
project_id: The GCP project ID. Overrides metadata lookup.
region: The GCP region. Overrides metadata lookup.
user_agent: The user agent to use for the client.

Raises:
ConfigurationError: If client initialization fails.
"""
super().__init__(gcp_project_id=project_id, gcp_region=region)

try:
client_info = ClientInfo(user_agent=user_agent) if user_agent else None
self.client = texttospeech.TextToSpeechClient(client_info=client_info)
print(f"Initialized Google TTS Client. with User-Agent: {user_agent}")
except Exception as e:
raise ConfigurationError(
"Failed to initialize Google TTS Client. "
"Please ensure you are authenticated (e.g., `gcloud auth application-default login`). "
f"Error: {e}"
)

def synthesize(
self,
language_code: str,
sample_rate: int,
speed: float,
text: str,
voice_name: str,
volume_gain_db: float,
) -> tuple[bytes, int]:
"""
Synthesizes speech and returns the raw audio binary content and sample rate.

Args:
text: The text to synthesize.
language_code: The language code for the voice.
voice_name: The name of the voice to use.
sample_rate: The desired sample rate of the audio.
speed: The speaking rate of the synthesized speech.
volume_gain_db: The volume gain in dB.

Returns:
A tuple containing the raw audio content (bytes) and the sample rate (int).


Raises:
APIInputError: If input parameters are invalid.
APIExecutionError: If music generation fails due to API or unexpected issues.
"""

synthesis_input_params = classify_string(text)
voice_params = {"language_code": language_code, "name": voice_name}

logger.info(f" - Synthesis Input: {synthesis_input_params}")
logger.info(f" - Voice Params: {voice_params}")

return utils.generate_speech_from_text(
client=self.client,
sample_rate=sample_rate,
speed=speed,
synthesis_input_params=synthesis_input_params,
voice_params=voice_params,
volume_gain_db=volume_gain_db,
)
171 changes: 171 additions & 0 deletions modules/python/src/custom_nodes/google_genmedia/chirp3hd_node.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,171 @@
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any, Dict, Optional, Tuple

from .chirp3hd_api import Chirp3API
from .constants import CHIRP3_HD_MODEL
from .custom_exceptions import APIExecutionError, ConfigurationError
from .utils import get_tts_voices_and_languages, load_audio_from_bytes


class Chirp3Node:
"""
A ComfyUI node for Google's Text-to-Speech API, specifically for the Chirp3 HD model.
It synthesizes text into speech and returns an audio tensor.
"""

def __init__(self) -> None:
"""
Initializes the Chirp3Node.
"""
pass

@classmethod
def INPUT_TYPES(cls) -> Dict[str, Dict[str, Any]]:
"""
Defines the input types and widgets for the ComfyUI node.

Returns:
A dictionary specifying the required and optional input parameters.
"""
dynamic_voices, dynamic_langs, voice_map = get_tts_voices_and_languages(
model_to_include=CHIRP3_HD_MODEL
)
cls.voice_id_map = voice_map
return {
"required": {
"text": (
"STRING",
{
"label_on": True,
"multiline": True,
"default": "Hello! I am a generative voice, designed by Google.",
"placeholder": "Enter the text to synthesize...",
},
),
"language_code": (
dynamic_langs,
{
"default": (
dynamic_langs[0] if len(dynamic_langs) > 16 else "en-US"
)
},
),
"voice_name": (
dynamic_voices,
{"default": (dynamic_voices[0] if dynamic_voices else "Charon")},
),
"sample_rate": (
"INT",
{"default": 24000, "min": 8000, "max": 48000, "step": 10},
),
"speed": (
"FLOAT",
{"default": 1.0, "min": 0.25, "max": 4.0, "step": 0.05},
),
"volume_gain_db": (
"FLOAT",
{"default": 0.0, "min": -96.0, "max": 16.0, "step": 0.1},
),
},
"optional": {
"gcp_project_id": (
"STRING",
{"default": "", "placeholder": "your-gcp-project-id"},
),
"gcp_region": (
"STRING",
{"default": "", "placeholder": "us-central1"},
),
},
}

RETURN_TYPES = ("AUDIO",)
RETURN_NAMES = ("audio",)
FUNCTION = "execute_synthesis"
CATEGORY = "Google AI/Chirp3"

def execute_synthesis(
self,
gcp_project_id: Optional[str],
gcp_region: Optional[str],
language_code: str,
sample_rate: int,
speed: float,
text: str,
voice_name: str,
volume_gain_db: float,
) -> Tuple[Dict[str, Any],]:
"""
Executes the text-to-speech synthesis process using the Chirp3 HD model.

Args:
gcp_project_id: The GCP project ID.
gcp_region: The GCP region.
language_code: The language and region of the voice.
sample_rate: The desired sample rate for the audio.
speed: The speaking rate of the audio.
text: The text to be synthesized into speech.
voice_name: The name of the voice to use.
volume_gain_db: The volume gain in dB.

Returns:
A tuple containing a dictionary with the audio waveform and sample rate.

Raises:
ConfigurationError: If the input text is empty.
RuntimeError: For errors during API communication or audio processing.
"""
short_voice_name = self.voice_id_map.get(voice_name)
if not short_voice_name:
raise ConfigurationError(
f"Voice ID lookup failed for selected voice: {voice_name}. "
"The voice list may be corrupt or outdated."
)
if not text or not text.strip():
raise ConfigurationError("Text input cannot be empty.")

# Reconstruct the full voice name as per user's specified logic
full_voice_name = f"{language_code}-{CHIRP3_HD_MODEL}-{short_voice_name}"

try:
api_client = Chirp3API(project_id=gcp_project_id, region=gcp_region)

audio_data_binary, _ = api_client.synthesize(
language_code=language_code,
sample_rate=sample_rate,
speed=speed,
text=text,
voice_name=full_voice_name,
volume_gain_db=volume_gain_db,
)

if audio_data_binary is None or len(audio_data_binary) == 0:
raise APIExecutionError("API call returned no audio data.")

waveform, r_sample_rate = load_audio_from_bytes(audio_data_binary)
output_audio = {
"waveform": waveform.unsqueeze(0),
"sample_rate": r_sample_rate,
}
return (output_audio,)

except (ConfigurationError, APIExecutionError) as e:
raise RuntimeError(str(e)) from e
except Exception as e:
raise RuntimeError(f"An unexpected error occurred: {e}") from e


NODE_CLASS_MAPPINGS = {"Chirp3Node": Chirp3Node}
NODE_DISPLAY_NAME_MAPPINGS = {"Chirp3Node": "Chirp3 HD"}
2 changes: 2 additions & 0 deletions modules/python/src/custom_nodes/google_genmedia/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,8 @@
from google.genai import types

AUDIO_MIME_TYPES = ["audio/mp3", "audio/wav", "audio/mpeg"]
CHIRP3_HD_MODEL = "Chirp3-HD"
CHIRP3_USER_AGENT = "cloud-solutions/comfyui-chirp3-custom-node-v1"
GEMINI_USER_AGENT = "cloud-solutions/comfyui-gemini-custom-node-v1"
GEMINI_25_FLASH_IMAGE_ASPECT_RATIO = [
"1:1",
Expand Down
Loading
Loading