10 changes: 5 additions & 5 deletions install_dev.py
@@ -5,10 +5,10 @@

def install_torch_nightly_deps():
"""Install torch related dependencies from pinned nightly"""
EXECUTORCH_NIGHTLY_VERSION = "dev20251104"
TORCHAO_NIGHTLY_VERSION = "dev20251104"
EXECUTORCH_NIGHTLY_VERSION = "dev20260104"
TORCHAO_NIGHTLY_VERSION = "dev20251222"
# Torch nightly is aligned with pinned nightly in https://github.com/pytorch/executorch/blob/main/torch_pin.py#L2
TORCH_NIGHTLY_VERSION = "dev20251104"
TORCH_NIGHTLY_VERSION = "dev20251222"
subprocess.check_call(
[
sys.executable,
@@ -17,10 +17,10 @@ def install_torch_nightly_deps():
"install",
"--no-cache-dir", # Prevent cached CUDA packages
f"executorch==1.1.0.{EXECUTORCH_NIGHTLY_VERSION}",
f"torch==2.10.0.{TORCH_NIGHTLY_VERSION}",
f"torch==2.11.0.{TORCH_NIGHTLY_VERSION}",
f"torchvision==0.25.0.{TORCH_NIGHTLY_VERSION}",
f"torchaudio==2.10.0.{TORCH_NIGHTLY_VERSION}",
f"torchao==0.15.0.{TORCHAO_NIGHTLY_VERSION}",
f"torchao==0.16.0.{TORCHAO_NIGHTLY_VERSION}",
"--extra-index-url",
"https://download.pytorch.org/whl/nightly/cpu",
]
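As a quick sanity check after running `install_dev.py`, a minimal sketch (not part of this diff) can confirm that the pinned nightly builds actually resolved; the package names simply mirror the pins above.

```python
import sys
from importlib.metadata import PackageNotFoundError, version

# Hypothetical post-install check; package names mirror the pins in install_dev.py.
for pkg in ("executorch", "torch", "torchvision", "torchaudio", "torchao"):
    try:
        print(f"{pkg}=={version(pkg)}")
    except PackageNotFoundError:
        print(f"{pkg} is not installed", file=sys.stderr)
```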
39 changes: 27 additions & 12 deletions optimum/executorch/attentions/custom_kv_cache.py
@@ -51,11 +51,19 @@ def __init__(
batch_size=max_batch_size, num_heads=num_heads, head_dim=head_dim, dtype=dtype, device=device
)

assert device is None or device in [
"cpu",
"cuda",
"mps",
], "Device must be None or one of 'cpu', 'cuda' or 'mps'."
# Validate device - handle both string and torch.device types
if device is not None:
device_type = (
device if isinstance(device, str) else (device.type if isinstance(device, torch.device) else None)
)
# Extract just the device type (e.g., "cuda:0" -> "cuda")
if isinstance(device_type, str):
device_type = device_type.split(":")[0]
assert device_type in [
"cpu",
"cuda",
"mps",
], f"Device must be None or one of 'cpu', 'cuda', 'mps' (with optional index like 'cuda:0'), got {device}"

# Create a list of CustomKVCache instances derived from each layer of the original Transformers cache, one per layer.
self.kv_cache = torch.nn.ModuleList()
@@ -99,8 +107,7 @@ def update(

# Get cache position from cache_kwargs (used by StaticCache)
cache_position = cache_kwargs.get("cache_position")
assert cache_position is not None
assert isinstance(cache_position, torch.Tensor)
torch._assert(cache_position is not None, "cache_position must be provided")

# Get the CustomKVCache instance for this layer
layer_cache = self.kv_cache[layer_idx]
@@ -212,11 +219,19 @@ def __init__(
batch_size=max_batch_size, num_heads=num_heads, head_dim=head_dim, dtype=dtype, device=device
)

assert device is None or device in [
"cpu",
"cuda",
"mps",
], "Device must be None or one of 'cpu', 'cuda' or 'mps'."
# Validate device - handle both string and torch.device types
if device is not None:
device_type = (
device if isinstance(device, str) else (device.type if isinstance(device, torch.device) else None)
)
# Extract just the device type (e.g., "cuda:0" -> "cuda")
if isinstance(device_type, str):
device_type = device_type.split(":")[0]
assert device_type in [
"cpu",
"cuda",
"mps",
], f"Device must be None or one of 'cpu', 'cuda', 'mps' (with optional index like 'cuda:0'), got {device}"

self.cache_position = None
# Create a list of cache instances, one per layer.
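Both hunks above replace a string-only `assert` with validation that accepts either a string or a `torch.device` and strips an optional device index. A standalone sketch of that normalization; the `normalize_device_type` helper is hypothetical and not part of the PR.

```python
from typing import Optional, Union

import torch


def normalize_device_type(device: Optional[Union[str, torch.device]]) -> Optional[str]:
    """Hypothetical helper mirroring the validation above: accept a str or
    torch.device, strip an optional index ("cuda:0" -> "cuda"), pass None through."""
    if device is None:
        return None
    device_type = device if isinstance(device, str) else device.type
    return device_type.split(":")[0]


assert normalize_device_type("cuda:0") == "cuda"
assert normalize_device_type(torch.device("mps")) == "mps"
assert normalize_device_type(None) is None
```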
175 changes: 175 additions & 0 deletions optimum/executorch/attentions/whisper_attention.py
@@ -0,0 +1,175 @@
# Copyright 2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# Export-friendly cross-attention implementation for Whisper. Adapted
# from https://github.com/huggingface/transformers/blob/454c0a7ccf33f7fc13e3e2eb9b188a5c09ab708b/src/transformers/models/whisper/modeling_whisper.py#L241
# Rewritten to replace if branches with torch.cond. Note that unlike
# the original WhisperAttention, this implementation only works for
# cross attention (where `key_value_states` is not None).

from typing import Callable, Optional

import torch
from executorch.extension.llm.custom_ops import custom_ops # noqa
from torch import Tensor, nn
from transformers.cache_utils import EncoderDecoderCache
from transformers.modeling_flash_attention_utils import FlashAttentionKwargs
from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS
from transformers.models.whisper.configuration_whisper import WhisperConfig
from transformers.models.whisper.modeling_whisper import eager_attention_forward
from transformers.processing_utils import Unpack
from transformers.utils import logging


logger = logging.get_logger(__name__)


class WhisperCrossAttention(nn.Module):
"""Multi-headed cross attention from 'Attention Is All You Need' paper"""

def __init__(
self,
embed_dim: int,
num_heads: int,
dropout: float = 0.0,
is_decoder: bool = False,
bias: bool = True,
is_causal: bool = False,
layer_idx: Optional[int] = None,
config: Optional[WhisperConfig] = None,
):
super().__init__()
self.embed_dim = embed_dim
self.num_heads = num_heads
self.dropout = dropout
self.head_dim = embed_dim // num_heads
self.config = config

if (self.head_dim * num_heads) != self.embed_dim:
raise ValueError(
f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim}"
f" and `num_heads`: {num_heads})."
)
self.scaling = self.head_dim**-0.5
self.is_decoder = is_decoder
self.is_causal = is_causal

if layer_idx is None and is_decoder:
logger.warning_once(
f"Instantiating a decoder {self.__class__.__name__} without passing `layer_idx` is not recommended and "
"will to errors during the forward call, if caching is used. Please make sure to provide a `layer_idx` "
"when creating this class."
)
self.layer_idx = layer_idx

self.k_proj = nn.Linear(embed_dim, embed_dim, bias=False)
self.v_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
self.q_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias)

self.register_buffer("cache_initialized", torch.zeros(1, 1, dtype=torch.bool), persistent=False)

def forward(
self,
hidden_states: torch.Tensor,
key_value_states: torch.Tensor,
past_key_values: EncoderDecoderCache,
attention_mask: Optional[torch.Tensor] = None,
output_attentions: bool = False,
cache_position: Optional[torch.Tensor] = None,
# TODO: we need a refactor so that the different attention modules can get their specific kwargs
# ATM, encoder, decoder, and encoder-decoder attention kwargs are mixed together
**kwargs: Unpack[FlashAttentionKwargs],
) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[tuple[torch.Tensor]]]:
"""Input shape: Batch x Time x Channel"""
torch._assert(
isinstance(past_key_values, EncoderDecoderCache),
f"past_key_values must be an EncoderDecoderCache, got {type(past_key_values)}",
)
# determine input shapes
bsz, tgt_len = hidden_states.shape[:-1]
q_input_shape = (bsz, tgt_len, -1, self.head_dim)

# Scaling is susceptible to floating-point imprecision, which can lead
# to different results (this varies from model to model; Whisper is one
# such case). We therefore keep the original order of scaling to follow
# the original implementation and enforce no scaling (1.0) in the
# attention call below.
query_states = self.q_proj(hidden_states) * self.scaling
query_states = query_states.view(*q_input_shape)
query_states = query_states.transpose(1, 2).contiguous()

# Check that an encoder-decoder model is being used; otherwise we would get a `DynamicCache`
if past_key_values is not None and isinstance(past_key_values, EncoderDecoderCache):
# after the first generated id, we can subsequently re-use all key/value_states from cache
past_key_values = past_key_values.cross_attention_cache

def use_cached_kv(
cached_keys: Tensor,
cached_values: Tensor,
key_value_states: Tensor,
) -> tuple[Tensor, Tensor]:
# Just reuse cached K/V
return torch.ops.executorch.alias(cached_keys, cached_values)

def recompute_kv(
cached_keys: Tensor, # unused
cached_values: Tensor, # unused
key_value_states: Tensor,
) -> tuple[Tensor, Tensor]:
# Compute fresh K/V (export-friendly: no cache mutation in here)
key_states = self.k_proj(key_value_states).view(bsz, -1, self.num_heads, self.head_dim)
value_states = self.v_proj(key_value_states).view(bsz, -1, self.num_heads, self.head_dim)
key_states = key_states.transpose(1, 2).contiguous()
value_states = value_states.transpose(1, 2).contiguous()
k = torch.ops.executorch.update_cross_attn_cache(key_states, cached_keys)
v = torch.ops.executorch.update_cross_attn_cache(value_states, cached_values)
return k, v

# Grab cached tensors (these are Tensors, so they are OK for export)
cached_keys = past_key_values.layers[self.layer_idx].keys
cached_values = past_key_values.layers[self.layer_idx].values

# Use torch.cond to select branch in a traceable way.
# All operands must be (nested) tensors or simple Python values.
key_states, value_states = torch.cond(
self.cache_initialized,
use_cached_kv,
recompute_kv,
operands=(cached_keys, cached_values, key_value_states),
)

# Update the cache_initialized flag to True after first use
self.cache_initialized.fill_(True)

attention_interface: Callable = eager_attention_forward
if self.config._attn_implementation != "eager":
attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]

attn_output, attn_weights = attention_interface(
self,
query_states,
key_states,
value_states,
attention_mask,
dropout=0.0 if not self.training else self.dropout,
scaling=1.0,
output_attentions=output_attentions,
**kwargs,
)

attn_output = attn_output.reshape(bsz, tgt_len, -1).contiguous()
attn_output = self.out_proj(attn_output)

return attn_output, attn_weights
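The `torch.cond` call is what makes the reuse-vs-recompute decision export-friendly: both branches receive identical operands and must return tensors with matching shapes and dtypes, and branch outputs may not alias branch inputs, which is presumably why the cached tensors are routed through the `torch.ops.executorch.alias` custom op. A toy sketch of the same pattern using only core PyTorch, with no ExecuTorch custom ops:

```python
import torch


def pick_kv(initialized: torch.Tensor, cached: torch.Tensor, fresh_src: torch.Tensor) -> torch.Tensor:
    # Both branches take the same operands and return a tensor of the same
    # shape/dtype, which is what torch.cond requires to stay traceable.
    def use_cached(cached, fresh_src):
        return cached.clone()  # clone instead of aliasing the branch input

    def recompute(cached, fresh_src):
        return fresh_src * 2.0  # stand-in for the k_proj/v_proj + cache update

    return torch.cond(initialized, use_cached, recompute, operands=(cached, fresh_src))


flag = torch.zeros(1, dtype=torch.bool)
cached, fresh = torch.zeros(2, 3), torch.ones(2, 3)
print(pick_kv(flag, cached, fresh))   # flag is False -> recompute branch
print(pick_kv(~flag, cached, fresh))  # flag is True  -> cached branch
```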
65 changes: 61 additions & 4 deletions optimum/exporters/executorch/integrations.py
Original file line number Diff line number Diff line change
@@ -22,7 +22,6 @@
from transformers import (
AutoConfig,
AutoProcessor,
DynamicCache,
EncoderDecoderCache,
PreTrainedModel,
StaticCache,
@@ -36,6 +35,7 @@
from transformers.modeling_utils import AttentionInterface

from optimum.executorch.attentions.custom_sdpa import get_custom_sdpa_for_ring_kv_cache, sdpa_mask_passthrough
from optimum.executorch.attentions.whisper_attention import WhisperCrossAttention

from .utils import apply_chat_template_with_fallback, save_config_to_constant_methods

@@ -678,18 +678,72 @@ def __init__(self, model, max_static_cache_length, batch_size):
self.self_attention_cache.early_initialization(batch_size, num_heads, head_dim, model.dtype, model.device)

# Initialize cross attention cache
self.dynamic_cache = DynamicCache(config=self.config)
self.cache = EncoderDecoderCache(self.self_attention_cache, self.dynamic_cache)
cross_attention_heads = getattr(
self.config, "decoder_attention_heads", getattr(self.config, "num_attention_heads", None)
)
if cross_attention_heads is None:
raise ValueError("Unable to determine decoder attention heads for cross-attention cache.")
hidden_size = getattr(self.config, "hidden_size", getattr(self.config, "d_model", None))
if hidden_size is None:
raise ValueError("Unable to determine hidden size for cross-attention cache allocation.")
cross_head_dim = getattr(self.config, "head_dim", hidden_size // cross_attention_heads)

self.cross_attention_cache = StaticCache(
config=self.config,
max_batch_size=batch_size,
max_cache_len=getattr(
self.config, "max_source_positions", max_static_cache_length
), # This is fixed in whisper
device=model.device,
dtype=model.dtype,
)
self.cross_attention_cache.early_initialization(
batch_size, cross_attention_heads, cross_head_dim, model.dtype, model.device
)

# Register the self- and cross-attention cache buffers so they are exportable.
for i in range(len(self.self_attention_cache)):
self.register_buffer(
f"self_attention_key_cache_{i}", self.self_attention_cache.layers[i].keys, persistent=False
)
self.register_buffer(
f"self_attention_value_cache_{i}", self.self_attention_cache.layers[i].values, persistent=False
)
for i in range(len(self.cross_attention_cache)):
self.register_buffer(
f"cross_attention_key_cache_{i}", self.cross_attention_cache.layers[i].keys, persistent=False
)
self.register_buffer(
f"cross_attention_value_cache_{i}", self.cross_attention_cache.layers[i].values, persistent=False
)
# self.register_buffer(
# "cross_attention_cache_initialized", torch.zeros(batch_size, 1, dtype=torch.bool), persistent=False
# )
# Add a flag to indicate if the cache has been initialized.
# Initialize it as False on CPU so it can be used as a predicate in torch.cond.
# After the first forward pass, we'll set it to True to indicate the cache is populated.
# self.cross_attention_cache._initialized = self.cross_attention_cache_initialized

self.cache = EncoderDecoderCache(self.self_attention_cache, self.cross_attention_cache)
# Use custom cross attention for Whisper.
# Only use WhisperCrossAttention if torch.ops.executorch.alias is available and device is CUDA.
_has_et_alias = hasattr(torch.ops, "executorch") and hasattr(torch.ops.executorch, "alias")
_is_cuda = model.device.type == "cuda"
if isinstance(model, WhisperForConditionalGeneration) and _has_et_alias and _is_cuda:
for layer in self.decoder.layers:
cross_attn = WhisperCrossAttention(
embed_dim=layer.encoder_attn.embed_dim,
num_heads=layer.encoder_attn.num_heads,
dropout=layer.encoder_attn.dropout,
is_decoder=layer.encoder_attn.is_decoder,
layer_idx=layer.encoder_attn.layer_idx,
config=layer.encoder_attn.config,
).to(dtype=model.dtype, device=model.device)
cross_attn.q_proj = layer.encoder_attn.q_proj
cross_attn.k_proj = layer.encoder_attn.k_proj
cross_attn.v_proj = layer.encoder_attn.v_proj
cross_attn.out_proj = layer.encoder_attn.out_proj
layer.encoder_attn = cross_attn

def forward(self, decoder_input_ids, encoder_hidden_states, cache_position):
# Get outputs from decoder
@@ -700,6 +754,9 @@ def forward(self, decoder_input_ids, encoder_hidden_states, cache_position):
use_cache=True,
cache_position=cache_position,
)
# Set the cross attention cache as initialized after the first forward pass
# This allows torch.cond to branch differently on subsequent runs
# self.cross_attention_cache_initialized.fill_(True)

# Apply linear projection (lm head) to obtain logits
logits = self.proj_out(outputs[0])
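The buffer registration in `__init__` above is what exposes the static cache tensors as module state. As a rough, simplified analogue of why that matters for export: a tensor registered as a non-persistent buffer can be updated in place inside `forward`, and `torch.export` records that update as a buffer mutation. The `TinyCache` module below is illustrative only and not the PR's wrapper.

```python
import torch
from torch import nn


class TinyCache(nn.Module):
    """Toy analogue of the cache-buffer registration above (not the PR's wrapper)."""

    def __init__(self, max_len: int = 8, dim: int = 4):
        super().__init__()
        # Non-persistent, like the self/cross-attention cache buffers in the diff.
        self.register_buffer("key_cache", torch.zeros(max_len, dim), persistent=False)

    def forward(self, x: torch.Tensor, pos: torch.Tensor) -> torch.Tensor:
        # In-place update of the buffer, analogous to writing new K/V at cache_position.
        self.key_cache[pos] = x
        return self.key_cache.sum(dim=0)


m = TinyCache()
ep = torch.export.export(m, (torch.ones(4), torch.tensor([0])))
print(ep.graph_signature)  # the mutated buffer shows up in the export signature
```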