diff --git a/install_dev.py b/install_dev.py
index ed831882..7007eb94 100644
--- a/install_dev.py
+++ b/install_dev.py
@@ -5,10 +5,10 @@
 def install_torch_nightly_deps():
     """Install torch related dependencies from pinned nightly"""
-    EXECUTORCH_NIGHTLY_VERSION = "dev20251104"
-    TORCHAO_NIGHTLY_VERSION = "dev20251104"
+    EXECUTORCH_NIGHTLY_VERSION = "dev20260104"
+    TORCHAO_NIGHTLY_VERSION = "dev20251222"
     # Torch nightly is aligned with pinned nightly in https://github.com/pytorch/executorch/blob/main/torch_pin.py#L2
-    TORCH_NIGHTLY_VERSION = "dev20251104"
+    TORCH_NIGHTLY_VERSION = "dev20251222"
     subprocess.check_call(
         [
             sys.executable,
@@ -17,10 +17,10 @@ def install_torch_nightly_deps():
             "install",
             "--no-cache-dir",  # Prevent cached CUDA packages
             f"executorch==1.1.0.{EXECUTORCH_NIGHTLY_VERSION}",
-            f"torch==2.10.0.{TORCH_NIGHTLY_VERSION}",
+            f"torch==2.11.0.{TORCH_NIGHTLY_VERSION}",
             f"torchvision==0.25.0.{TORCH_NIGHTLY_VERSION}",
             f"torchaudio==2.10.0.{TORCH_NIGHTLY_VERSION}",
-            f"torchao==0.15.0.{TORCHAO_NIGHTLY_VERSION}",
+            f"torchao==0.16.0.{TORCHAO_NIGHTLY_VERSION}",
             "--extra-index-url",
             "https://download.pytorch.org/whl/nightly/cpu",
         ]
diff --git a/optimum/executorch/attentions/custom_kv_cache.py b/optimum/executorch/attentions/custom_kv_cache.py
index cf0d4ad9..64b7322d 100644
--- a/optimum/executorch/attentions/custom_kv_cache.py
+++ b/optimum/executorch/attentions/custom_kv_cache.py
@@ -51,11 +51,19 @@ def __init__(
             batch_size=max_batch_size, num_heads=num_heads, head_dim=head_dim, dtype=dtype, device=device
         )
-        assert device is None or device in [
-            "cpu",
-            "cuda",
-            "mps",
-        ], "Device must be None or one of 'cpu', 'cuda' or 'mps'."
+        # Validate device - handle both string and torch.device types
+        if device is not None:
+            device_type = (
+                device if isinstance(device, str) else (device.type if isinstance(device, torch.device) else None)
+            )
+            # Extract just the device type (e.g., "cuda:0" -> "cuda")
+            if isinstance(device_type, str):
+                device_type = device_type.split(":")[0]
+            assert device_type in [
+                "cpu",
+                "cuda",
+                "mps",
+            ], f"Device must be None or one of 'cpu', 'cuda', 'mps' (with optional index like 'cuda:0'), got {device}"

         # Create a list of CustomKVCache instances derived from each layer of the original Transformers cache, one per layer.
         self.kv_cache = torch.nn.ModuleList()
@@ -99,8 +107,7 @@ def update(
         # Get cache position from cache_kwargs (used by StaticCache)
         cache_position = cache_kwargs.get("cache_position")
-        assert cache_position is not None
-        assert isinstance(cache_position, torch.Tensor)
+        torch._assert(cache_position is not None, "cache_position must be provided")

         # Get the CustomKVCache instance for this layer
         layer_cache = self.kv_cache[layer_idx]
@@ -212,11 +219,19 @@ def __init__(
             batch_size=max_batch_size, num_heads=num_heads, head_dim=head_dim, dtype=dtype, device=device
         )
-        assert device is None or device in [
-            "cpu",
-            "cuda",
-            "mps",
-        ], "Device must be None or one of 'cpu', 'cuda' or 'mps'."
+        # Validate device - handle both string and torch.device types
+        if device is not None:
+            device_type = (
+                device if isinstance(device, str) else (device.type if isinstance(device, torch.device) else None)
+            )
+            # Extract just the device type (e.g., "cuda:0" -> "cuda")
+            if isinstance(device_type, str):
+                device_type = device_type.split(":")[0]
+            assert device_type in [
+                "cpu",
+                "cuda",
+                "mps",
+            ], f"Device must be None or one of 'cpu', 'cuda', 'mps' (with optional index like 'cuda:0'), got {device}"

         self.cache_position = None
         # Create a list of cache instances, one per layer.
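
For reference, the new check normalizes whatever the caller passes (a string like "cuda:0" or a torch.device) down to a bare device type before validating it. A minimal standalone sketch of that normalization, assuming only torch; the helper name is illustrative and not part of this patch:

    # Illustrative helper; the patch inlines this logic in both __init__ methods.
    from typing import Optional, Union

    import torch


    def normalize_device_type(device: Optional[Union[str, torch.device]]) -> Optional[str]:
        """Map 'cuda:0', torch.device('cuda', 0), 'mps', etc. to a bare device type string."""
        if device is None:
            return None
        if isinstance(device, torch.device):
            return device.type  # torch.device already separates type and index
        if isinstance(device, str):
            return device.split(":")[0]  # strip an optional index such as ':0'
        return None


    assert normalize_device_type("cuda:0") == "cuda"
    assert normalize_device_type(torch.device("cpu")) == "cpu"
    assert normalize_device_type(None) is None
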
diff --git a/optimum/executorch/attentions/whisper_attention.py b/optimum/executorch/attentions/whisper_attention.py
new file mode 100644
index 00000000..e25a4b4b
--- /dev/null
+++ b/optimum/executorch/attentions/whisper_attention.py
@@ -0,0 +1,175 @@
+# Copyright 2025 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# Export-friendly cross attention implementation for Whisper. Adapted
+# from https://github.com/huggingface/transformers/blob/454c0a7ccf33f7fc13e3e2eb9b188a5c09ab708b/src/transformers/models/whisper/modeling_whisper.py#L241
+# Rewritten to replace if branches with torch.cond. Note that unlike
+# the original WhisperAttention, this implementation only works for
+# cross attention (where `key_value_states` is not None).
+
+from typing import Callable, Optional
+
+import torch
+from executorch.extension.llm.custom_ops import custom_ops  # noqa
+from torch import Tensor, nn
+from transformers.cache_utils import EncoderDecoderCache
+from transformers.modeling_flash_attention_utils import FlashAttentionKwargs
+from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS
+from transformers.models.whisper.configuration_whisper import WhisperConfig
+from transformers.models.whisper.modeling_whisper import eager_attention_forward
+from transformers.processing_utils import Unpack
+from transformers.utils import logging
+
+
+logger = logging.get_logger(__name__)
+
+
+class WhisperCrossAttention(nn.Module):
+    """Multi-headed cross attention from 'Attention Is All You Need' paper"""
+
+    def __init__(
+        self,
+        embed_dim: int,
+        num_heads: int,
+        dropout: float = 0.0,
+        is_decoder: bool = False,
+        bias: bool = True,
+        is_causal: bool = False,
+        layer_idx: Optional[int] = None,
+        config: Optional[WhisperConfig] = None,
+    ):
+        super().__init__()
+        self.embed_dim = embed_dim
+        self.num_heads = num_heads
+        self.dropout = dropout
+        self.head_dim = embed_dim // num_heads
+        self.config = config
+
+        if (self.head_dim * num_heads) != self.embed_dim:
+            raise ValueError(
+                f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim}"
+                f" and `num_heads`: {num_heads})."
+            )
+        self.scaling = self.head_dim**-0.5
+        self.is_decoder = is_decoder
+        self.is_causal = is_causal
+
+        if layer_idx is None and is_decoder:
+            logger.warning_once(
+                f"Instantiating a decoder {self.__class__.__name__} without passing `layer_idx` is not recommended and "
+                "will lead to errors during the forward call, if caching is used. Please make sure to provide a `layer_idx` "
+                "when creating this class."
+            )
+        self.layer_idx = layer_idx
+
+        self.k_proj = nn.Linear(embed_dim, embed_dim, bias=False)
+        self.v_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
+        self.q_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
+        self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
+
+        self.register_buffer("cache_initialized", torch.zeros(1, 1, dtype=torch.bool), persistent=False)
+
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        key_value_states: torch.Tensor,
+        past_key_values: EncoderDecoderCache,
+        attention_mask: Optional[torch.Tensor] = None,
+        output_attentions: bool = False,
+        cache_position: Optional[torch.Tensor] = None,
+        # TODO: we need a refactor so that the different attention modules can get their specific kwargs
+        # ATM, we have mixed things encoder, decoder, and encoder-decoder attn
+        **kwargs: Unpack[FlashAttentionKwargs],
+    ) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[tuple[torch.Tensor]]]:
+        """Input shape: Batch x Time x Channel"""
+        torch._assert(
+            isinstance(past_key_values, EncoderDecoderCache),
+            f"past_key_values must be an EncoderDecoderCache, got {type(past_key_values)}",
+        )
+        # determine input shapes
+        bsz, tgt_len = hidden_states.shape[:-1]
+        q_input_shape = (bsz, tgt_len, -1, self.head_dim)
+
+        # Scaling is susceptible to floating point arithmetic's imprecisions
+        # which can lead to different results (this is dependent from model
+        # to model, e.g. whisper is one such case). We therefore keep the
+        # original order of scaling to follow the original implementation
+        # and enforce no scaling (1.0) in the attention call below.
+        query_states = self.q_proj(hidden_states) * self.scaling
+        query_states = query_states.view(*q_input_shape)
+        query_states = query_states.transpose(1, 2).contiguous()
+
+        # Check if an encoder-decoder model is being used. Otherwise we'll get `DynamicCache`
+        if past_key_values is not None and isinstance(past_key_values, EncoderDecoderCache):
+            # after the first generated id, we can subsequently re-use all key/value_states from cache
+            past_key_values = past_key_values.cross_attention_cache
+
+        def use_cached_kv(
+            cached_keys: Tensor,
+            cached_values: Tensor,
+            key_value_states: Tensor,
+        ) -> tuple[Tensor, Tensor]:
+            # Just reuse cached K/V
+            return torch.ops.executorch.alias(cached_keys, cached_values)
+
+        def recompute_kv(
+            cached_keys: Tensor,  # unused
+            cached_values: Tensor,  # unused
+            key_value_states: Tensor,
+        ) -> tuple[Tensor, Tensor]:
+            # Compute fresh K/V (export-friendly: no cache mutation in here)
+            key_states = self.k_proj(key_value_states).view(bsz, -1, self.num_heads, self.head_dim)
+            value_states = self.v_proj(key_value_states).view(bsz, -1, self.num_heads, self.head_dim)
+            key_states = key_states.transpose(1, 2).contiguous()
+            value_states = value_states.transpose(1, 2).contiguous()
+            k = torch.ops.executorch.update_cross_attn_cache(key_states, cached_keys)
+            v = torch.ops.executorch.update_cross_attn_cache(value_states, cached_values)
+            return k, v
+
+        # Grab cached tensors (these are Tensors, so they are OK for export)
+        cached_keys = past_key_values.layers[self.layer_idx].keys
+        cached_values = past_key_values.layers[self.layer_idx].values
+
+        # Use torch.cond to select branch in a traceable way.
+        # All operands must be (nested) tensors or simple Python values.
+        key_states, value_states = torch.cond(
+            self.cache_initialized,
+            use_cached_kv,
+            recompute_kv,
+            operands=(cached_keys, cached_values, key_value_states),
+        )
+
+        # Update the cache_initialized flag to True after first use
+        self.cache_initialized.fill_(True)
+
+        attention_interface: Callable = eager_attention_forward
+        if self.config._attn_implementation != "eager":
+            attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
+
+        attn_output, attn_weights = attention_interface(
+            self,
+            query_states,
+            key_states,
+            value_states,
+            attention_mask,
+            dropout=0.0 if not self.training else self.dropout,
+            scaling=1.0,
+            output_attentions=output_attentions,
+            **kwargs,
+        )
+
+        attn_output = attn_output.reshape(bsz, tgt_len, -1).contiguous()
+        attn_output = self.out_proj(attn_output)
+
+        return attn_output, attn_weights
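
The core idea of the new module is to use a boolean buffer as the torch.cond predicate so that both the reuse-cached-K/V and recompute-K/V branches stay in the exported graph, with the flag flipped outside the branches. Below is a minimal, self-contained sketch of that pattern, assuming only stock PyTorch; plain tensor ops stand in for the torch.ops.executorch.alias and update_cross_attn_cache custom ops, and CachedProjection is an illustrative name, not part of this patch:

    # Sketch: branch on a boolean buffer with torch.cond (simplified stand-in for the
    # cached/recomputed cross-attention K/V selection above).
    import torch
    from torch import nn


    class CachedProjection(nn.Module):
        def __init__(self, dim: int = 8):
            super().__init__()
            self.proj = nn.Linear(dim, dim)
            # Predicate buffer: stays False until the cache has been filled once.
            self.register_buffer("cache_initialized", torch.zeros(1, dtype=torch.bool), persistent=False)
            self.register_buffer("cache", torch.zeros(2, dim), persistent=False)

        def forward(self, x: torch.Tensor) -> torch.Tensor:
            def use_cached(cache: torch.Tensor, x: torch.Tensor) -> torch.Tensor:
                return cache.clone()  # reuse the stored result

            def recompute(cache: torch.Tensor, x: torch.Tensor) -> torch.Tensor:
                return self.proj(x)  # compute a fresh result

            out = torch.cond(self.cache_initialized, use_cached, recompute, operands=(self.cache, x))
            # Cache mutation happens outside the cond branches, as in the patch.
            self.cache.copy_(out)
            self.cache_initialized.fill_(True)
            return out


    with torch.no_grad():
        module = CachedProjection()
        x = torch.randn(2, 8)
        first = module(x)   # takes the recompute branch
        second = module(x)  # takes the cached branch
        assert torch.allclose(first, second)
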
"num_attention_heads", None) + ) + if cross_attention_heads is None: + raise ValueError("Unable to determine decoder attention heads for cross-attention cache.") + hidden_size = getattr(self.config, "hidden_size", getattr(self.config, "d_model", None)) + if hidden_size is None: + raise ValueError("Unable to determine hidden size for cross-attention cache allocation.") + cross_head_dim = getattr(self.config, "head_dim", hidden_size // cross_attention_heads) + + self.cross_attention_cache = StaticCache( + config=self.config, + max_batch_size=batch_size, + max_cache_len=getattr( + self.config, "max_source_positions", max_static_cache_length + ), # This is fixed in whisper + device=model.device, + dtype=model.dtype, + ) + self.cross_attention_cache.early_initialization( + batch_size, cross_attention_heads, cross_head_dim, model.dtype, model.device + ) # Register cache buffers to make them exportable. - # Cross attention cache buffer is not registered since it's not actually being used atm. for i in range(len(self.self_attention_cache)): self.register_buffer( f"self_attention_key_cache_{i}", self.self_attention_cache.layers[i].keys, persistent=False @@ -690,6 +709,41 @@ def __init__(self, model, max_static_cache_length, batch_size): self.register_buffer( f"self_attention_value_cache_{i}", self.self_attention_cache.layers[i].values, persistent=False ) + for i in range(len(self.cross_attention_cache)): + self.register_buffer( + f"cross_attention_key_cache_{i}", self.cross_attention_cache.layers[i].keys, persistent=False + ) + self.register_buffer( + f"cross_attention_value_cache_{i}", self.cross_attention_cache.layers[i].values, persistent=False + ) + # self.register_buffer( + # "cross_attention_cache_initialized", torch.zeros(batch_size, 1, dtype=torch.bool), persistent=False + # ) + # Add a flag to indicate if the cache has been initialized. + # Initialize it as False on CPU so it can be used as a predicate in torch.cond. + # After the first forward pass, we'll set it to True to indicate the cache is populated. + # self.cross_attention_cache._initialized = self.cross_attention_cache_initialized + + self.cache = EncoderDecoderCache(self.self_attention_cache, self.cross_attention_cache) + # Use custom cross attention for Whisper. + # Only use WhisperCrossAttention if torch.ops.executorch.alias is available and device is CUDA. 
+ _has_et_alias = hasattr(torch.ops, "executorch") and hasattr(torch.ops.executorch, "alias") + _is_cuda = model.device.type == "cuda" + if isinstance(model, WhisperForConditionalGeneration) and _has_et_alias and _is_cuda: + for layer in self.decoder.layers: + cross_attn = WhisperCrossAttention( + embed_dim=layer.encoder_attn.embed_dim, + num_heads=layer.encoder_attn.num_heads, + dropout=layer.encoder_attn.dropout, + is_decoder=layer.encoder_attn.is_decoder, + layer_idx=layer.encoder_attn.layer_idx, + config=layer.encoder_attn.config, + ).to(dtype=model.dtype, device=model.device) + cross_attn.q_proj = layer.encoder_attn.q_proj + cross_attn.k_proj = layer.encoder_attn.k_proj + cross_attn.v_proj = layer.encoder_attn.v_proj + cross_attn.out_proj = layer.encoder_attn.out_proj + layer.encoder_attn = cross_attn def forward(self, decoder_input_ids, encoder_hidden_states, cache_position): # Get outputs from decoder @@ -700,6 +754,9 @@ def forward(self, decoder_input_ids, encoder_hidden_states, cache_position): use_cache=True, cache_position=cache_position, ) + # Set the cross attention cache as initialized after the first forward pass + # This allows torch.cond to branch differently on subsequent runs + # self.cross_attention_cache_initialized.fill_(True) # Apply linear projection (lm head) to obtain logits logits = self.proj_out(outputs[0]) diff --git a/optimum/exporters/executorch/quantization.py b/optimum/exporters/executorch/quantization.py index 6e48d9dc..7e32244a 100644 --- a/optimum/exporters/executorch/quantization.py +++ b/optimum/exporters/executorch/quantization.py @@ -16,8 +16,6 @@ from typing import Optional import torch -from packaging.version import parse -from torch import __version__ as torch_version def quantize_model_( @@ -31,7 +29,6 @@ def quantize_model_( if not (qlinear_config or qembedding_config): return - from torchao.experimental.quant_api import UIntxWeightOnlyConfig from torchao.quantization.granularity import PerAxis, PerGroup from torchao.quantization.quant_api import ( Int4WeightOnlyConfig, @@ -43,9 +40,9 @@ def quantize_model_( if qembedding_config: if qlinear_config == "8w": - assert qembedding_group_size == 0, ( - "8-bit embedding quantization only supports per-token at the moment, please use qembedding_group_size = 0." - ) + assert ( + qembedding_group_size == 0 + ), "8-bit embedding quantization only supports per-token at the moment, please use qembedding_group_size = 0." if qembedding_group_size == 0: embedding_weight_granularity = PerAxis(0) else: @@ -105,6 +102,8 @@ def build_linear_config(quant_config_key: str, granularity: str, packing_format: if quant_config_key == "fpa4w": # Need to import to load the ops import torchao.experimental.ops.mps # noqa: F401 + from torchao.experimental.quant_api import UIntxWeightOnlyConfig + return UIntxWeightOnlyConfig( group_size=qlinear_group_size, bitwidth=4, @@ -128,9 +127,9 @@ def build_linear_config(quant_config_key: str, granularity: str, packing_format: ) fallback_linear_config_key = None else: - assert qlinear_group_size % 2 == 0, ( - f"Linear quantization group size must be a multiple of 2, got {qlinear_group_size}." - ) + assert ( + qlinear_group_size % 2 == 0 + ), f"Linear quantization group size must be a multiple of 2, got {qlinear_group_size}." linear_weight_granularity = PerGroup(qlinear_group_size) logging.info("Quantizing linear layers.") @@ -172,6 +171,4 @@ def per_token_filter(module, fqn): filter_fn=per_token_filter, ) - # TODO: remove after ExecuTorch dep on Torch >= 2.10.0. 
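
For context, the getattr fallback chain above resolves the cross-attention cache geometry (head count, head dim, and the fixed encoder length) straight from the model config. A quick illustrative sketch of what it yields on a stock WhisperConfig, assuming transformers is installed; this snippet is not part of the patch:

    # Sketch: derive the cross-attention cache geometry the way the exportable module does.
    from transformers import WhisperConfig

    config = WhisperConfig()  # library defaults; a real export uses the checkpoint's config

    heads = getattr(config, "decoder_attention_heads", getattr(config, "num_attention_heads", None))
    hidden = getattr(config, "hidden_size", getattr(config, "d_model", None))  # hidden_size maps to d_model for Whisper
    head_dim = getattr(config, "head_dim", hidden // heads)
    cross_len = getattr(config, "max_source_positions", None)  # encoder sequence length, fixed for Whisper

    print(f"heads={heads}, head_dim={head_dim}, max_cache_len={cross_len}")
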
diff --git a/optimum/exporters/executorch/quantization.py b/optimum/exporters/executorch/quantization.py
index 6e48d9dc..7e32244a 100644
--- a/optimum/exporters/executorch/quantization.py
+++ b/optimum/exporters/executorch/quantization.py
@@ -16,8 +16,6 @@
 from typing import Optional

 import torch
-from packaging.version import parse
-from torch import __version__ as torch_version


 def quantize_model_(
@@ -31,7 +29,6 @@ def quantize_model_(
     if not (qlinear_config or qembedding_config):
         return

-    from torchao.experimental.quant_api import UIntxWeightOnlyConfig
     from torchao.quantization.granularity import PerAxis, PerGroup
     from torchao.quantization.quant_api import (
         Int4WeightOnlyConfig,
@@ -43,9 +40,9 @@ def quantize_model_(
     if qembedding_config:
         if qlinear_config == "8w":
-            assert qembedding_group_size == 0, (
-                "8-bit embedding quantization only supports per-token at the moment, please use qembedding_group_size = 0."
-            )
+            assert (
+                qembedding_group_size == 0
+            ), "8-bit embedding quantization only supports per-token at the moment, please use qembedding_group_size = 0."
         if qembedding_group_size == 0:
             embedding_weight_granularity = PerAxis(0)
         else:
@@ -105,6 +102,8 @@ def build_linear_config(quant_config_key: str, granularity: str, packing_format:
     if quant_config_key == "fpa4w":
         # Need to import to load the ops
         import torchao.experimental.ops.mps  # noqa: F401
+        from torchao.experimental.quant_api import UIntxWeightOnlyConfig
+
         return UIntxWeightOnlyConfig(
             group_size=qlinear_group_size,
             bitwidth=4,
@@ -128,9 +127,9 @@ def build_linear_config(quant_config_key: str, granularity: str, packing_format:
         )
         fallback_linear_config_key = None
     else:
-        assert qlinear_group_size % 2 == 0, (
-            f"Linear quantization group size must be a multiple of 2, got {qlinear_group_size}."
-        )
+        assert (
+            qlinear_group_size % 2 == 0
+        ), f"Linear quantization group size must be a multiple of 2, got {qlinear_group_size}."
         linear_weight_granularity = PerGroup(qlinear_group_size)

     logging.info("Quantizing linear layers.")
@@ -172,6 +171,4 @@ def per_token_filter(module, fqn):
         filter_fn=per_token_filter,
     )

-    # TODO: remove after ExecuTorch dep on Torch >= 2.10.0.
-    if parse(torch_version) < parse("2.10.0.dev20251104"):
-        unwrap_tensor_subclass(eager_model)
+    unwrap_tensor_subclass(eager_model)
diff --git a/optimum/exporters/executorch/tasks/multimodal_text_to_text.py b/optimum/exporters/executorch/tasks/multimodal_text_to_text.py
index cdb26b9f..7fc7811b 100644
--- a/optimum/exporters/executorch/tasks/multimodal_text_to_text.py
+++ b/optimum/exporters/executorch/tasks/multimodal_text_to_text.py
@@ -222,9 +222,9 @@ def load_multimodal_text_to_text_model(model_name_or_path: str, **kwargs):

     return MultiModalTextToTextExportableModule(
         model=eager_model,
-        modality="vision"
-        if modality == "image"
-        else modality,  # TODO: hack since downstream uses "vision" atm. Change this to match Transformers.
+        modality=(
+            "vision" if modality == "image" else modality
+        ),  # TODO: hack since downstream uses "vision" atm. Change this to match Transformers.
         encoder_model=eager_encoder,
         max_seq_len=max_length,
         processor_config=processor_config,