41 changes: 35 additions & 6 deletions optimum/exporters/executorch/utils.py
@@ -38,10 +38,26 @@ def save_config_to_constant_methods(
         and isinstance(config.num_attention_heads, int)
     ):
         head_dim = config.hidden_size / config.num_attention_heads
+
+    # Build EOS token ID list
+    eos_token_id = getattr(config, "eos_token_id", None)
+    if isinstance(eos_token_id, list):
+        eos_ids = list(eos_token_id)
+    elif eos_token_id is not None:
+        eos_ids = [eos_token_id]
+    else:
+        eos_ids = []
+
+    # Add Gemma's <end_of_turn> token if applicable
+    model_type = getattr(config, "model_type", "")
+    if "gemma" in model_type.lower() and 106 not in eos_ids:
+        eos_ids.append(106)
+
     metadata = {
         "get_dtype": 5 if config.torch_dtype == torch.float16 else 6,
         "get_bos_id": getattr(config, "bos_token_id", None),
-        "get_eos_id": getattr(config, "eos_token_id", None),
+        "get_eos_id": eos_ids[0] if eos_ids else None,
+        "get_eos_ids": eos_ids,
         "get_head_dim": head_dim,
         "get_n_kv_heads": getattr(config, "num_key_value_heads", None),
         "get_n_layers": getattr(config, "num_hidden_layers", None),
@@ -60,7 +76,9 @@ def save_config_to_constant_methods(
         metadata.update(processor_config)
 
     # Combine/override with any additional kwargs and filter out None values
-    combined_metadata = {k: v for k, v in {**metadata, **kwargs}.items() if v is not None}
+    combined_metadata = {
+        k: v for k, v in {**metadata, **kwargs}.items() if v is not None
+    }
     return combined_metadata
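For context, this merge-and-filter step gives caller-supplied kwargs the last word and then drops `None` values; a tiny sketch with made-up keys:

```python
metadata = {"get_bos_id": 2, "get_eos_id": None, "get_vocab_size": 32000}
kwargs = {"get_vocab_size": 32064}  # later dict wins on key collisions

combined = {k: v for k, v in {**metadata, **kwargs}.items() if v is not None}
print(combined)  # {'get_bos_id': 2, 'get_vocab_size': 32064}
```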


@@ -88,7 +106,9 @@ def apply_chat_template_with_fallback(processor, conversation, **kwargs):
         return processor.apply_chat_template(conversation)
 
 
-def verify_eos_tokens_in_pretrained_tokenizer(model_eos_ids: List[int], tokenizer: TokenizersBackend) -> bool:
+def verify_eos_tokens_in_pretrained_tokenizer(
+    model_eos_ids: List[int], tokenizer: TokenizersBackend
+) -> bool:
     """
     Verifies that the model's EOS token IDs are present in the tokenizer's
     set of potential end-of-sequence tokens.
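The function body is collapsed in this view, so the following is only an assumed sketch of the contract the docstring describes, not the actual implementation:

```python
# Assumption: verification passes iff every model EOS ID is among the
# tokenizer's candidate end-of-sequence token IDs.
from typing import List, Set

def verify_sketch(model_eos_ids: List[int], tokenizer_eos_candidates: Set[int]) -> bool:
    return all(eos_id in tokenizer_eos_candidates for eos_id in model_eos_ids)

print(verify_sketch([1, 106], {1, 106, 107}))  # True
print(verify_sketch([1, 999], {1, 106}))       # False
```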
@@ -154,15 +174,22 @@ def process_conversation_inputs(
     Returns:
         Processed inputs ready for model consumption
     """
-    if isinstance(processor, transformers.models.granite_speech.processing_granite_speech.GraniteSpeechProcessor):
+    if isinstance(
+        processor,
+        transformers.models.granite_speech.processing_granite_speech.GraniteSpeechProcessor,
+    ):
         import requests
         import torchaudio
 
         conversation = copy.deepcopy(input_conversation)
         audio_path = None
 
         # Extract audio content and remove from conversation
-        audio_items = [(i, item) for i, item in enumerate(conversation) if item.get("type") == "audio"]
+        audio_items = [
+            (i, item)
+            for i, item in enumerate(conversation)
+            if item.get("type") == "audio"
+        ]
         if audio_items:
             idx, audio_item = audio_items[0]
             audio_path = audio_item["content"]
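A runnable sketch of the extraction step above; the conversation shape and the removal of the audio item are assumptions based on the comment and the `item.get("type")` check:

```python
import copy

input_conversation = [
    {"type": "audio", "content": "https://example.com/sample.wav"},  # made-up URL
    {"type": "text", "content": "Transcribe this clip."},
]

conversation = copy.deepcopy(input_conversation)
audio_items = [
    (i, item) for i, item in enumerate(conversation) if item.get("type") == "audio"
]
if audio_items:
    idx, audio_item = audio_items[0]
    audio_path = audio_item["content"]
    del conversation[idx]  # assumed removal, per "remove from conversation"

print(audio_path)         # https://example.com/sample.wav
print(len(conversation))  # 1
```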
@@ -189,7 +216,9 @@ def process_conversation_inputs(
                 logging.warning(f"Resampled audio from {sampling_rate}Hz to 16000Hz")
 
         # Generate text prompt and process with audio
-        prompt = tokenizer.apply_chat_template(conversation, tokenize=False, add_generation_prompt=True)
+        prompt = tokenizer.apply_chat_template(
+            conversation, tokenize=False, add_generation_prompt=True
+        )
         inputs = processor(prompt, wav, return_tensors="pt")
     else:
         # Standard processing for other processors
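The hunk's context mentions resampling to 16000 Hz before handing audio to the processor; a hedged torchaudio sketch of that step, with the file path and original rate made up:

```python
import logging

import torchaudio
import torchaudio.functional as F

# Assumed input for illustration; GraniteSpeech expects 16 kHz audio.
wav, sampling_rate = torchaudio.load("sample.wav")

if sampling_rate != 16000:
    wav = F.resample(wav, orig_freq=sampling_rate, new_freq=16000)
    logging.warning(f"Resampled audio from {sampling_rate}Hz to 16000Hz")
```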