Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 4 additions & 7 deletions optimum/exporters/executorch/integrations.py
Original file line number Diff line number Diff line change
Expand Up @@ -176,7 +176,7 @@ class MultiModalTextToTextExportableModule(torch.nn.Module):
Args:
model (torch.nn.Module): The multimodal model to export.
modality (str): The input modality type ("audio" or "vision").
encoder_name (str): Name of the encoder attribute in the model.
encoder_model (str): The encoder model within the mutlimodal model.
processor_config (dict, optional): Preprocessor configuration loaded from preprocessor_config.json.
use_custom_kv_cache (bool): Whether to use custom key-value caching for optimization.
use_custom_sdpa (bool): Whether to use custom scaled dot-product attention.
Expand All @@ -186,21 +186,18 @@ def __init__(
self,
model: torch.nn.Module,
modality: str,
encoder_name: str,
encoder_model: torch.nn.Module,
max_seq_len: int,
processor_config: dict = None,
use_custom_kv_cache: bool = False,
use_custom_sdpa: bool = False,
):
super().__init__()

if not hasattr(model, encoder_name):
raise ValueError(f'Model does not contain encoder "{encoder_name}".')

self.model = model
self.config = model.config
self.modality = modality
self.encoder_name = encoder_name
self.encoder_model = encoder_model
self.processor_config = processor_config
self.use_custom_kv_cache = use_custom_kv_cache
self.use_custom_sdpa = use_custom_sdpa
Expand Down Expand Up @@ -370,7 +367,7 @@ def export(

# 3. Export encoder.
if self.use_custom_sdpa:
getattr(self.model, self.encoder_name).config._attn_implementation = "custom_sdpa"
self.encoder_model.config._attn_implementation = "custom_sdpa"

if self.modality == "audio":
encoder = AudioExportableModule(self.model)
Expand Down
99 changes: 25 additions & 74 deletions optimum/exporters/executorch/tasks/multimodal_text_to_text.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,68 +25,6 @@
from ..task_registry import register_task


def _validate_multimodal_components(model):
"""
Validates that the multimodal model has required decoder and encoder components.

Args:
model: The loaded model instance

Returns:
tuple: (decoder_name, audio_encoder_name, vision_encoder_name)
"""
POTENTIAL_DECODER_NAMES = [
"language_model",
"text_model",
]
POTENTIAL_AUDIO_ENCODER_NAMES = [
"encoder", # Here mainly for Granite Speech.
"audio_tower",
"audio_model",
]
POTENTIAL_VISION_ENCODER_NAMES = [
"vision_tower",
"vision_model",
]

# Find decoder component
decoder_name = None
for name in POTENTIAL_DECODER_NAMES:
if hasattr(model, name):
decoder_name = name
break

if decoder_name is None:
raise ValueError(
"The model does not have any of the expected decoder attributes: "
f"{POTENTIAL_DECODER_NAMES}. This is required for multimodal text-to-text models."
)

# Find encoder components
audio_encoder_name = None
for name in POTENTIAL_AUDIO_ENCODER_NAMES:
if hasattr(model, name):
audio_encoder_name = name
break

vision_encoder_name = None
for name in POTENTIAL_VISION_ENCODER_NAMES:
if hasattr(model, name):
vision_encoder_name = name
break

if (audio_encoder_name is None) == (vision_encoder_name is None):
raise ValueError(
"The model does not have one of the expected encoder attributes: "
f"{POTENTIAL_AUDIO_ENCODER_NAMES + POTENTIAL_VISION_ENCODER_NAMES}. "
"This is required for multimodal text-to-text models."
"Currently only a maximum of 1 modality is supported, so there can only be one of these"
"encoders in the model."
)

return decoder_name, audio_encoder_name, vision_encoder_name


# NOTE: It's important to map the registered task name to the pipeline name in https://github.com/huggingface/transformers/blob/main/utils/update_metadata.py.
# This will streamline using inferred task names and make exporting models to Hugging Face pipelines easier.
@register_task("image-text-to-text")
Expand Down Expand Up @@ -181,13 +119,22 @@ def load_multimodal_text_to_text_model(model_name_or_path: str, **kwargs):
"device": device,
},
)
decoder_name, audio_encoder_name, vision_encoder_name = _validate_multimodal_components(eager_model)
encoder_name = audio_encoder_name if audio_encoder_name else vision_encoder_name

# Find the modality (only one modality of either image/audio is supported at the moment).
if len(eager_model.input_modalities) != 2:
raise AttributeError(
"Only one modality is supported for multimodal models at the moment. The input modalities must be either ['text', 'image'] or ['text, 'audio']"
)
for input_modality in eager_model.input_modalities:
if input_modality == "text":
continue
modality = input_modality
eager_encoder = eager_model.get_encoder(modality)

# Need to do this since apparently when nested modules (e.g. model.language_model) access the .property
# config, it always comes from the generation_config.json file, not the `generation_config` override
# from from_pretrained().
getattr(eager_model, decoder_name).generation_config = eager_model.generation_config
eager_model.get_decoder().generation_config = eager_model.generation_config

# Must disable gradient when exporting a model with a prequantized checkpoint,
# e.g. "pytorch/Phi-4-mini-instruct-8da4w".
Expand All @@ -210,7 +157,7 @@ def load_multimodal_text_to_text_model(model_name_or_path: str, **kwargs):
if qlinear_config:
logging.info("Quantizing decoder linears...")
quantize_decoder_kwargs = {
"eager_model": getattr(eager_model, decoder_name),
"eager_model": eager_model.get_decoder(),
"qlinear_config": qlinear_config,
}
if qlinear_group_size is not None:
Expand All @@ -225,13 +172,15 @@ def load_multimodal_text_to_text_model(model_name_or_path: str, **kwargs):
# self.decoder = ...
# self.lm_head = ... # lm_head is not part of the decoder instance
# ...
if not hasattr(getattr(eager_model, decoder_name), "lm_head"):
if not hasattr(eager_model, "lm_head"):
if not hasattr(eager_model.get_decoder(), "lm_head"):
# Voxtral specifically is weird since you need to specifically do eager_model.language_model.lm_head.
lm_head = getattr(eager_model, "lm_head", None) or getattr(eager_model.language_model, "lm_head", None)
if not lm_head:
raise AttributeError(
f"Could not find `lm_head` for {model_name_or_path} has no `lm_head`, please double check if this is expected."
)
quantize_lm_head_kwargs = {
"eager_model": eager_model.lm_head,
"eager_model": lm_head,
"qlinear_config": qlinear_config,
}
quantize_model_(**quantize_lm_head_kwargs)
Expand All @@ -240,7 +189,7 @@ def load_multimodal_text_to_text_model(model_name_or_path: str, **kwargs):
if qlinear_encoder_config:
logging.info("Quantizing encoder linears...")
quantize_encoder_kwargs = {
"eager_model": getattr(eager_model, encoder_name),
"eager_model": eager_encoder,
"qlinear_config": qlinear_encoder_config,
}
if qlinear_encoder_group_size is not None:
Expand All @@ -253,7 +202,7 @@ def load_multimodal_text_to_text_model(model_name_or_path: str, **kwargs):
if qembedding_config:
logging.info("Quantizing embeddings...")
quantize_decoder_embedding_kwargs = {
"eager_model": getattr(eager_model, decoder_name),
"eager_model": eager_model.get_decoder(),
"qembedding_config": qembedding_config,
}
if qembedding_group_size is not None:
Expand All @@ -264,7 +213,7 @@ def load_multimodal_text_to_text_model(model_name_or_path: str, **kwargs):
if qembedding_encoder_config:
logging.info("Quantizing embeddings...")
quantize_encoder_embedding_kwargs = {
"eager_model": getattr(eager_model, encoder_name),
"eager_model": eager_encoder,
"qembedding_config": qembedding_encoder_config,
}
if qembedding_encoder_group_size is not None:
Expand All @@ -273,8 +222,10 @@ def load_multimodal_text_to_text_model(model_name_or_path: str, **kwargs):

return MultiModalTextToTextExportableModule(
model=eager_model,
modality="audio" if audio_encoder_name else "vision",
encoder_name=audio_encoder_name if audio_encoder_name else vision_encoder_name,
modality="vision"
if modality == "image"
else modality, # TODO: hack since downstream uses "vision" atm. Change this to match Transformers.
encoder_model=eager_encoder,
max_seq_len=max_length,
processor_config=processor_config,
use_custom_kv_cache=use_custom_kv_cache,
Expand Down
2 changes: 1 addition & 1 deletion tests/models/test_modeling_gemma3.py
Original file line number Diff line number Diff line change
Expand Up @@ -350,5 +350,5 @@ def test_gemma3_image_vision_with_custom_sdpa_kv_cache_8da4w_8we(self):
self.assertTrue("serene" in generated_text)
self.assertTrue("lake" in generated_text)
self.assertTrue(
check_multimodal_output_quality(model_id, generated_tokens, conversation, max_perplexity_threshold=5)
check_multimodal_output_quality(model_id, generated_tokens, conversation, max_perplexity_threshold=10)
)
2 changes: 1 addition & 1 deletion tests/models/test_modeling_voxtral.py
Original file line number Diff line number Diff line change
Expand Up @@ -326,7 +326,7 @@ def test_voxtral_audio_text_to_text_generation_with_custom_sdpa_kv_cache_8da4w_8
# likely friends or acquaintances, who are discussing tattoos.'
self.assertTrue("tattoo" in generated_text)
self.assertTrue(
check_multimodal_output_quality(model_id, generated_tokens, conversation, max_perplexity_threshold=5)
check_multimodal_output_quality(model_id, generated_tokens, conversation, max_perplexity_threshold=10)
)

@slow
Expand Down
Loading