Support cross attention kv cache #187

Merged
larryliu0820 merged 1 commit into main from whisper_cond
Jan 7, 2026

Conversation

@larryliu0820
Collaborator

To avoid excessive computation, we want to support a KV cache for cross attention in Whisper.

Fundamentally, we only need to run `k_proj` and `v_proj` on the encoder output hidden states once, at the first token generation; after that we should keep the resulting `key_states` and `value_states` and reuse them for all subsequent token generations.
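
For illustration, here is a minimal sketch of what a per-layer cross-attention KV cache buffer could look like (the class and buffer names are hypothetical, not the PR's actual code); for whisper-large-v3-turbo the encoder emits 1500 positions of width 1280, split across 20 heads of size 64:

```python
import torch
from torch import nn


class CrossAttentionKVCache(nn.Module):
    """Holds the cross-attention key/value states for one decoder layer.

    Shapes assume whisper-large-v3-turbo: 20 heads x head_dim 64 (= 1280)
    over 1500 encoder positions.
    """

    def __init__(self, batch_size=1, num_heads=20, encoder_len=1500, head_dim=64):
        super().__init__()
        # Zero-initialized buffers; the first decoding step fills them and
        # every later step reuses them.
        self.register_buffer(
            "key_cache", torch.zeros(batch_size, num_heads, encoder_len, head_dim)
        )
        self.register_buffer(
            "value_cache", torch.zeros(batch_size, num_heads, encoder_len, head_dim)
        )
```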

For whisper-large-v3-turbo, which has 4 decoder layers:

```
WhisperDecoder(
  (embed_tokens): Embedding(51866, 1280, padding_idx=50257)
  (embed_positions): WhisperPositionalEmbedding(448, 1280)
  (layers): ModuleList(
    (0-3): 4 x WhisperDecoderLayer(
      (self_attn): WhisperAttention(
        (k_proj): Linear(in_features=1280, out_features=1280, bias=False)
        (v_proj): Linear(in_features=1280, out_features=1280, bias=True)
        (q_proj): Linear(in_features=1280, out_features=1280, bias=True)
        (out_proj): Linear(in_features=1280, out_features=1280, bias=True)
      )
      (activation_fn): GELUActivation()
      (self_attn_layer_norm): LayerNorm((1280,), eps=1e-05, elementwise_affine=True)
      (encoder_attn): WhisperAttention(
        (k_proj): Linear(in_features=1280, out_features=1280, bias=False)
        (v_proj): Linear(in_features=1280, out_features=1280, bias=True)
        (q_proj): Linear(in_features=1280, out_features=1280, bias=True)
        (out_proj): Linear(in_features=1280, out_features=1280, bias=True)
      )
      (encoder_attn_layer_norm): LayerNorm((1280,), eps=1e-05, elementwise_affine=True)
      (fc1): Linear(in_features=1280, out_features=5120, bias=True)
      (fc2): Linear(in_features=5120, out_features=1280, bias=True)
      (final_layer_norm): LayerNorm((1280,), eps=1e-05, elementwise_affine=True)
    )
  )
  (layer_norm): LayerNorm((1280,), eps=1e-05, elementwise_affine=True)
)
```

Without a KV cache in `encoder_attn`, we do 2 matmuls against 1280x1280 projection weights in each layer, i.e. 8 such matmuls for every generated token across the 4 layers. This significantly hurts the tokens/sec number.
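
To put a rough number on it (assuming the full 1500-position encoder output Whisper produces for 30 s of audio; this estimate is ours, not from the PR):

```python
# Back-of-the-envelope FLOPs for the redundant cross-attention projections
# recomputed for every generated token.
enc_len, d_model, num_layers = 1500, 1280, 4
flops_per_matmul = 2 * enc_len * d_model * d_model  # 2 FLOPs per multiply-add
flops_per_token = flops_per_matmul * 2 * num_layers  # k_proj + v_proj, all layers
print(f"{flops_per_token / 1e9:.1f} GFLOP per token")  # ~39.3 GFLOP
```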

This PR replaces `encoder_attn` with a `WhisperCrossAttention` class, where we replace the Python `if` condition with `torch.cond` (a sketch follows below). The logic becomes:

- If the KV cache values are all zero:
  - Compute the KV projections.
- Otherwise:
  - Clone from the KV cache. Note that we can't directly return the KV cache here, due to the non-aliasing requirement on `torch.cond` branches.
- After `torch.cond`:
  - Write the values from either branch back to the KV cache.

Notice that we still pay 1 extra cache read and 1 extra cache write, but that should be much faster than the matmuls.
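
A minimal sketch of this `torch.cond` pattern (not the PR's actual `WhisperCrossAttention` code; the function signature and cache handling below are a hand-written approximation of the logic above):

```python
import torch


def cross_attention_kv(k_proj, v_proj, encoder_hidden_states,
                       k_cache, v_cache, num_heads, head_dim):
    """Compute cross-attention K/V once, then serve them from the cache."""
    bsz, enc_len, _ = encoder_hidden_states.shape

    def compute_kv(hidden, k_cache, v_cache):
        # First generated token: project the encoder output into K/V.
        k = k_proj(hidden).view(bsz, enc_len, num_heads, head_dim).transpose(1, 2)
        v = v_proj(hidden).view(bsz, enc_len, num_heads, head_dim).transpose(1, 2)
        return k, v

    def reuse_kv(hidden, k_cache, v_cache):
        # Later tokens: reuse the cache. torch.cond branches must not return
        # aliases of their inputs, hence the clone().
        return k_cache.clone(), v_cache.clone()

    cache_is_empty = (k_cache == 0).all()  # predicate: cache not yet filled
    key_states, value_states = torch.cond(
        cache_is_empty, compute_kv, reuse_kv,
        (encoder_hidden_states, k_cache, v_cache),
    )

    # Write the result of whichever branch ran back into the cache.
    k_cache.copy_(key_states)
    v_cache.copy_(value_states)
    return key_states, value_states
```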

Collaborator

@jackzhxng jackzhxng left a comment

Oh also run make style for formatting

@larryliu0820
Collaborator Author

This works and gives correct output, but we still eventually need to copy data from the GPU to the CPU just for the predicate. There's no way to work around it.

For whisper-large-v3-turbo, there are 4 decoder layers, so we see 4 cudaMemcpyAsync blocks in each token generation:
[profiler trace screenshot: four cudaMemcpyAsync blocks per generated token]

This is too expensive to be a good solution.
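
For context, a tiny illustration (ours, not from the PR) of why the predicate forces that copy: the branch decision has to be made on the host, so a GPU-resident boolean predicate must be transferred back before either branch can run.

```python
import torch


def true_fn(x):
    return x + 1


def false_fn(x):
    return x - 1


device = "cuda" if torch.cuda.is_available() else "cpu"
x = torch.randn(4, device=device)
pred = x.sum() > 0  # 0-dim bool tensor; lives on the GPU when device="cuda"

# Choosing a branch requires reading `pred` on the CPU, so a GPU predicate
# shows up in a profile as a device-to-host cudaMemcpyAsync plus a sync for
# every torch.cond call -- one per decoder layer in the Whisper case above.
out = torch.cond(pred, true_fn, false_fn, (x,))
```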

larryliu0820 merged commit 732b113 into main Jan 7, 2026
69 of 83 checks passed
larryliu0820 deleted the whisper_cond branch January 7, 2026 01:13
larryliu0820 added a commit to pytorch/executorch that referenced this pull request Jan 8, 2026