diff --git a/optimum/exporters/executorch/utils.py b/optimum/exporters/executorch/utils.py
index d6c41f0..72f7095 100644
--- a/optimum/exporters/executorch/utils.py
+++ b/optimum/exporters/executorch/utils.py
@@ -38,10 +38,26 @@ def save_config_to_constant_methods(
         and isinstance(config.num_attention_heads, int)
     ):
         head_dim = config.hidden_size / config.num_attention_heads
+
+    # Build EOS token ID list
+    eos_token_id = getattr(config, "eos_token_id", None)
+    if isinstance(eos_token_id, list):
+        eos_ids = list(eos_token_id)
+    elif eos_token_id is not None:
+        eos_ids = [eos_token_id]
+    else:
+        eos_ids = []
+
+    # Add Gemma's token if applicable
+    model_type = getattr(config, "model_type", "")
+    if "gemma" in model_type.lower() and 106 not in eos_ids:
+        eos_ids.append(106)
+
     metadata = {
         "get_dtype": 5 if config.torch_dtype == torch.float16 else 6,
         "get_bos_id": getattr(config, "bos_token_id", None),
-        "get_eos_id": getattr(config, "eos_token_id", None),
+        "get_eos_id": eos_ids[0] if eos_ids else None,
+        "get_eos_ids": eos_ids,
         "get_head_dim": head_dim,
         "get_n_kv_heads": getattr(config, "num_key_value_heads", None),
         "get_n_layers": getattr(config, "num_hidden_layers", None),
@@ -60,7 +76,9 @@ def save_config_to_constant_methods(
         metadata.update(processor_config)
 
     # Combine/override with any additional kwargs and filter out None values
-    combined_metadata = {k: v for k, v in {**metadata, **kwargs}.items() if v is not None}
+    combined_metadata = {
+        k: v for k, v in {**metadata, **kwargs}.items() if v is not None
+    }
 
     return combined_metadata
 
@@ -88,7 +106,9 @@ def apply_chat_template_with_fallback(processor, conversation, **kwargs):
     return processor.apply_chat_template(conversation)
 
 
-def verify_eos_tokens_in_pretrained_tokenizer(model_eos_ids: List[int], tokenizer: TokenizersBackend) -> bool:
+def verify_eos_tokens_in_pretrained_tokenizer(
+    model_eos_ids: List[int], tokenizer: TokenizersBackend
+) -> bool:
     """
     Verifies that the model's EOS token IDs are present in the tokenizer's
     set of potential end-of-sequence tokens.
@@ -154,7 +174,10 @@ def process_conversation_inputs(
     Returns:
         Processed inputs ready for model consumption
     """
-    if isinstance(processor, transformers.models.granite_speech.processing_granite_speech.GraniteSpeechProcessor):
+    if isinstance(
+        processor,
+        transformers.models.granite_speech.processing_granite_speech.GraniteSpeechProcessor,
+    ):
         import requests
         import torchaudio
 
@@ -162,7 +185,11 @@ def process_conversation_inputs(
         audio_path = None
 
         # Extract audio content and remove from conversation
-        audio_items = [(i, item) for i, item in enumerate(conversation) if item.get("type") == "audio"]
+        audio_items = [
+            (i, item)
+            for i, item in enumerate(conversation)
+            if item.get("type") == "audio"
+        ]
         if audio_items:
             idx, audio_item = audio_items[0]
             audio_path = audio_item["content"]
@@ -189,7 +216,9 @@ def process_conversation_inputs(
             logging.warning(f"Resampled audio from {sampling_rate}Hz to 16000Hz")
 
         # Generate text prompt and process with audio
-        prompt = tokenizer.apply_chat_template(conversation, tokenize=False, add_generation_prompt=True)
+        prompt = tokenizer.apply_chat_template(
+            conversation, tokenize=False, add_generation_prompt=True
+        )
         inputs = processor(prompt, wav, return_tensors="pt")
     else:
         # Standard processing for other processors
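
Reviewer note, not part of the patch: a minimal standalone sketch of what the new EOS-handling block in `save_config_to_constant_methods` computes. `build_eos_ids` is a hypothetical extraction of the patched lines for illustration (it is not a function in the module), and `SimpleNamespace` stands in for a `transformers` config object, modeling only the attributes the block reads.

```python
from types import SimpleNamespace


def build_eos_ids(config):
    # Normalize eos_token_id (int, list, or None) into a flat list.
    eos_token_id = getattr(config, "eos_token_id", None)
    if isinstance(eos_token_id, list):
        eos_ids = list(eos_token_id)
    elif eos_token_id is not None:
        eos_ids = [eos_token_id]
    else:
        eos_ids = []
    # Per the patch, Gemma-family models additionally stop on token id 106.
    model_type = getattr(config, "model_type", "")
    if "gemma" in model_type.lower() and 106 not in eos_ids:
        eos_ids.append(106)
    return eos_ids


# Scalar, list, and missing eos_token_id all yield a list; the "get_eos_id"
# metadata entry then exposes the first element (or None) so existing
# consumers of the single-ID key keep working alongside "get_eos_ids".
assert build_eos_ids(SimpleNamespace(eos_token_id=2, model_type="llama")) == [2]
assert build_eos_ids(SimpleNamespace(eos_token_id=[1, 106], model_type="gemma3")) == [1, 106]
assert build_eos_ids(SimpleNamespace(eos_token_id=1, model_type="gemma3")) == [1, 106]
assert build_eos_ids(SimpleNamespace(eos_token_id=None, model_type="gemma3")) == [106]
```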