Merged
jackzhxng reviewed on Nov 19, 2025
jackzhxng (Collaborator) left a comment:
Oh also run make style for formatting
JacobSzwejbka approved these changes on Jan 6, 2026
To avoid excessive computation, we want to support a KV cache for cross-attention in Whisper.
Fundamentally, we only need to run `k_proj` and `v_proj` once on the encoder output hidden states, at the first token generation; after that we should keep the resulting `key_states` and `value_states` and reuse them for all subsequent token generations.
For whisper-large-v3-turbo, which has 4 decoder layers:
```
WhisperDecoder(
  (embed_tokens): Embedding(51866, 1280, padding_idx=50257)
  (embed_positions): WhisperPositionalEmbedding(448, 1280)
  (layers): ModuleList(
    (0-3): 4 x WhisperDecoderLayer(
      (self_attn): WhisperAttention(
        (k_proj): Linear(in_features=1280, out_features=1280, bias=False)
        (v_proj): Linear(in_features=1280, out_features=1280, bias=True)
        (q_proj): Linear(in_features=1280, out_features=1280, bias=True)
        (out_proj): Linear(in_features=1280, out_features=1280, bias=True)
      )
      (activation_fn): GELUActivation()
      (self_attn_layer_norm): LayerNorm((1280,), eps=1e-05, elementwise_affine=True)
      (encoder_attn): WhisperAttention(
        (k_proj): Linear(in_features=1280, out_features=1280, bias=False)
        (v_proj): Linear(in_features=1280, out_features=1280, bias=True)
        (q_proj): Linear(in_features=1280, out_features=1280, bias=True)
        (out_proj): Linear(in_features=1280, out_features=1280, bias=True)
      )
      (encoder_attn_layer_norm): LayerNorm((1280,), eps=1e-05, elementwise_affine=True)
      (fc1): Linear(in_features=1280, out_features=5120, bias=True)
      (fc2): Linear(in_features=5120, out_features=1280, bias=True)
      (final_layer_norm): LayerNorm((1280,), eps=1e-05, elementwise_affine=True)
    )
  )
  (layer_norm): LayerNorm((1280,), eps=1e-05, elementwise_affine=True)
)
```
Without a KV cache in `encoder_attn`, we run two 1280x1280 projections (`k_proj` and `v_proj` over the encoder output) per layer, i.e. 8 such matmuls for every generated token, which significantly hurts tokens/sec.
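As a rough back-of-the-envelope illustration of that cost (not from the PR; it assumes a 30 s input, which gives Whisper roughly 1500 encoder frames):

```python
# Rough cost of re-running the cross-attention K/V projections on every
# decode step. Assumes ~1500 encoder frames (30 s input); illustrative only.
encoder_frames = 1500
hidden_dim = 1280
decoder_layers = 4
projections_per_layer = 2  # k_proj and v_proj

flops_per_projection = 2 * encoder_frames * hidden_dim * hidden_dim  # ~4.9 GFLOPs
wasted_flops_per_token = decoder_layers * projections_per_layer * flops_per_projection
print(f"~{wasted_flops_per_token / 1e9:.1f} GFLOPs of redundant work per token")  # ~39.3
```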
This PR replaces `encoder_attn` with a `WhisperCrossAttention` class, where the Python `if` condition is replaced with `torch.cond`. The logic becomes:
- If `cache_initialized` is False:
  - Compute the KV projections and update the KV cache.
- Otherwise:
  - Return the cached KV states directly.
- After `torch.cond`:
  - Set `cache_initialized` to True.

Note that we still pay for one extra cache read and one extra write per token, but that should be much cheaper than the matmuls. A sketch of this pattern is shown below.
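A minimal sketch of the pattern, assuming static cache buffers sized to the encoder output and batch size 1; the class and buffer names here are illustrative, not the PR's actual implementation:

```python
import torch
import torch.nn as nn


class CrossAttentionKVCacheSketch(nn.Module):
    """Illustrative sketch (not the PR's code): cache cross-attention K/V
    so k_proj/v_proj run only once per generated sequence."""

    def __init__(self, embed_dim: int = 1280, max_source_len: int = 1500):
        super().__init__()
        self.k_proj = nn.Linear(embed_dim, embed_dim, bias=False)
        self.v_proj = nn.Linear(embed_dim, embed_dim, bias=True)
        # Static buffers for the cached projections plus an "initialized" flag
        # (batch size 1 and encoder length == max_source_len assumed).
        self.register_buffer("key_cache", torch.zeros(1, max_source_len, embed_dim))
        self.register_buffer("value_cache", torch.zeros(1, max_source_len, embed_dim))
        self.register_buffer("cache_initialized", torch.tensor(False))

    def forward(self, encoder_hidden_states: torch.Tensor):
        def compute_kv(hidden, key_cache, value_cache):
            # First generated token: run the two projections on the encoder output.
            return self.k_proj(hidden), self.v_proj(hidden)

        def reuse_kv(hidden, key_cache, value_cache):
            # Subsequent tokens: skip the matmuls and return the cached states.
            return key_cache, value_cache

        key_states, value_states = torch.cond(
            self.cache_initialized,
            reuse_kv,     # true branch: cache already filled
            compute_kv,   # false branch: first token
            (encoder_hidden_states, self.key_cache, self.value_cache),
        )
        # Do the writes outside the branches so torch.cond stays side-effect free;
        # this is the "1 extra read and 1 extra write" mentioned above.
        self.key_cache.copy_(key_states)
        self.value_cache.copy_(value_states)
        self.cache_initialized.fill_(True)
        return key_states, value_states
```

Keeping the cache writes outside the two branches is one way to avoid in-branch mutations, which `torch.cond` tracing is strict about.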
larryliu0820 added a commit to pytorch/executorch that referenced this pull request on Jan 8, 2026:

…16485) As titled, so that we can include huggingface/optimum-executorch#187
