[test+operators] Fuse llama attention to circle attention #217
Closed
Commits (21, changes from all commits)
- a2057f4 [test+operators] Fuse attention to circle attention
- 375bcd7 Rename model.py to layer.py
- db5c2b5 make lint happy by making code ugly
- 4e578a6 Fix local-silent but CI-loud lint error
- c7c6b79 Add wq,wk,wv,wo and remove_unused_input pass
- abf1288 Use recording_input in layer.py
- 22fb522 Add prefill.py
- 711c60d Update layer.py
- 9fa791f Rename model.py to layers.py
- 6789671 Update input_to_remove comment for prefill.py
- bf53424 Factor out attention fuser to op_circle_attention.py
- 275f398 move op_circle_attention.py to onert/op_attention.py
- 3aaca06 remove unused import from op_attention.py
- a04ff6f Adjust input prompt size and kv_cache size = 12
- b190777 remove @torch.library.impl("circle::attention.llama", "CPU")
- e40fad5 Remove attention_mask and make kv_cache mandatory, not optional
- 8d9c2f7 add decode.py to export LlamaModel decode phase
- 72cd407 Restore attention_mask
- 22910c2 Fix wrong arg order and move layer_idx from inputs to params
- 8ac74b7 remove layer_idx
- 9efa26f Remove remove_unused_inputs pass
test/modules/model/LlamaDecoderLayerWithKVCacheAndFusedAttention/decode.py (new file, 69 additions)
```python
# User input
prompt = "Lily picked up a flower."
model_name = "Maykeye/TinyLLama-v0"

# Tokenizer
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained(model_name)
tokenizer.pad_token = tokenizer.eos_token
tokenizer.padding_side = "right"
inputs = tokenizer(
    prompt,
    return_tensors="pt",
    padding="max_length",
    max_length=30,
    truncation=True,
)

# Generator
import torch

from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained(model_name)
model.eval()

from tico.utils.record_input import RecordingInput

# past_key_values
# ---------------
# During prefill, "past_key_values" is not None but an empty Cache instance.
# Passing None makes torch.export happy.

input_to_remove = [
    "attention_mask",
    # For left pad:  [0, ⋯, 0, 1, ⋯, 1]
    # For right pad: [1, ⋯, 1, 0, ⋯, 0]
    # (0 marks the pad tokens)
    # This script uses right padding and passes an all-1 attention mask (including pad).
    # The NPU computes all positions, pad or not.
]
condition_fn = lambda args_dict: args_dict["past_key_values"].get_seq_length() != 0

with torch.no_grad(), RecordingInput(
    model, condition_fn, input_to_remove=input_to_remove
) as rec:
    outputs = model.generate(
        **inputs,
        max_new_tokens=32,
        do_sample=False,
        pad_token_id=tokenizer.eos_token_id,
    )
    captured_input = rec.captured_input

generated_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
print(generated_text)

# Tico
import tico
from tico.serialize.operators.onert.op_attention import llama_attention_forward_adapter
from transformers.models.llama.modeling_llama import LlamaAttention

LlamaAttention.forward = llama_attention_forward_adapter

model = AutoModelForCausalLM.from_pretrained(model_name)
model.eval()
circle_model = tico.convert(model, captured_input)
circle_model.save("tinyllama.decode.circle")
```
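The attention-mask comment in decode.py distinguishes left and right padding. As a hedged, standalone illustration of what a right-padded mask from this tokenizer setup looks like (the smaller `max_length` and the exact 1/0 split are assumptions made only to keep the printout short):

```python
from transformers import AutoTokenizer

tok = AutoTokenizer.from_pretrained("Maykeye/TinyLLama-v0")
tok.pad_token = tok.eos_token
tok.padding_side = "right"

enc = tok(
    "Lily picked up a flower.",
    return_tensors="pt",
    padding="max_length",
    max_length=12,  # shorter than the script's 30, purely for illustration
    truncation=True,
)
# Real tokens come first, pad positions (0) are appended on the right,
# e.g. tensor([[1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0]]) -- the exact split depends
# on how the tokenizer segments the prompt.
print(enc["attention_mask"])
```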
test/modules/model/LlamaDecoderLayerWithKVCacheAndFusedAttention/layer.py (new file, 56 additions)
```python
# User input
prompt = "Lily picked up a flower."
model_name = "Maykeye/TinyLLama-v0"

# Tokenizer
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained(model_name)
tokenizer.pad_token = tokenizer.eos_token
tokenizer.padding_side = "right"
inputs = tokenizer(
    prompt,
    return_tensors="pt",
    padding="max_length",
    max_length=31,
    truncation=True,
)

# Generator
import torch

from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained(model_name)
model.eval()

from tico.utils.record_input import RecordingInput

target_model = model.model.layers[0]
condition_fn = lambda args_dict: args_dict["past_key_value"].get_seq_length() != 0

with torch.no_grad(), RecordingInput(target_model, condition_fn) as rec:
    outputs = model.generate(
        **inputs,
        max_new_tokens=32,
        do_sample=False,
        pad_token_id=tokenizer.eos_token_id,
    )
    captured_input = rec.captured_input

generated_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
print(generated_text)

# Convert
import tico
from tico.serialize.operators.onert.op_attention import llama_attention_forward_adapter
from transformers.models.llama.modeling_llama import LlamaAttention

LlamaAttention.forward = llama_attention_forward_adapter

model = AutoModelForCausalLM.from_pretrained(model_name)
model.eval()
circle_model = tico.convert(model.model.layers[0], captured_input)
circle_model.save("tinyllama.layer.attn.circle")
```
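layer.py (and decode.py) gate recording on `get_seq_length() != 0`, i.e. they skip the prefill call and capture a decode-step call instead. A minimal sketch of that distinction, assuming the standard `DynamicCache` API from transformers; the tensor shapes below are made up purely for illustration:

```python
import torch
from transformers.cache_utils import DynamicCache

cache = DynamicCache()
# Empty cache -> seq length 0 -> condition_fn is False, the prefill call is not recorded.
print(cache.get_seq_length())  # 0

# Simulate a prefill step that stored keys/values for a 31-token prompt
# (batch=1, kv heads=4, seq=31, head_dim=16 are illustrative shapes only).
cache.update(torch.zeros(1, 4, 31, 16), torch.zeros(1, 4, 31, 16), layer_idx=0)
# Non-empty cache -> condition_fn is True, the decode-step call is recorded.
print(cache.get_seq_length())  # 31
```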
test/modules/model/LlamaDecoderLayerWithKVCacheAndFusedAttention/layers.py (new file, 110 additions)
```python
# User input
prompt = "Lily picked up a flower."
model_name = "Maykeye/TinyLLama-v0"

# Tokenizer
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained(model_name)
tokenizer.pad_token = tokenizer.eos_token
tokenizer.padding_side = "right"
inputs = tokenizer(
    prompt,
    return_tensors="pt",
    padding="max_length",
    max_length=32,
    truncation=True,
)

# Generator
import torch

from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained(model_name)
model.eval()

from tico.utils.record_input import RecordingInput

target_model = model.model.layers[0]
condition_fn = lambda args_dict: args_dict["past_key_value"].get_seq_length() != 0

with torch.no_grad(), RecordingInput(target_model, condition_fn) as rec:
    outputs = model.generate(
        **inputs,
        max_new_tokens=32,
        do_sample=False,
        pad_token_id=tokenizer.eos_token_id,
    )
    captured_input = rec.captured_input

generated_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
print(generated_text)

from typing import Any, Optional, Tuple

# Define DecoderLayers

from torch import nn
from transformers.cache_utils import Cache
from transformers.models.llama.modeling_llama import LlamaAttention, LlamaModel


# The decoder layers (model.layers) are an nn.ModuleList without a forward,
# so they are not torch.export-able as-is.
# Let's define the decoder layers as an nn.Module.


class LlamaDecoderLayers(nn.Module):
    def __init__(self, model: LlamaModel):
        super().__init__()
        self.config = model.config
        self.layers = model.layers

    # Make sure the signature matches the captured input.
    # Copied from LlamaDecoderLayer::forward.
    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_value: Optional[Cache] = None,
        output_attentions: Optional[bool] = False,
        use_cache: Optional[bool] = False,
        cache_position: Optional[torch.LongTensor] = None,
        position_embeddings: Optional[
            Tuple[torch.Tensor, torch.Tensor]
        ] = None,  # necessary, but kept here for BC
        **kwargs: Any,
    ) -> Tuple[
        torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]
    ]:
        for decoder_layer in self.layers[: self.config.num_hidden_layers]:
            layer_outputs = decoder_layer(
                hidden_states,
                attention_mask=attention_mask,
                past_key_value=past_key_value,
                cache_position=cache_position,
                position_embeddings=position_embeddings,
            )
            hidden_states = layer_outputs[0]

        return hidden_states


# Convert

import tico

# NOTE:
# If you want to restore the original forward, this could be implemented as a
# context manager. However, this is just a simple export script; nothing calls
# forward after the tico conversion.
from tico.serialize.operators.onert.op_attention import llama_attention_forward_adapter

LlamaAttention.forward = llama_attention_forward_adapter

model = AutoModelForCausalLM.from_pretrained(model_name)
layers = LlamaDecoderLayers(model.model)
layers.eval()
circle_model = tico.convert(layers, captured_input)
circle_model.save("tinyllama.layers.attn.circle")
```
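The NOTE in layers.py mentions that restoring the original forward could be done with a context manager. A sketch of that idea, under the assumption that `llama_attention_forward_adapter` is the adapter imported above (nothing in this PR requires this; it is only an illustration of the restore pattern):

```python
from contextlib import contextmanager

from transformers.models.llama.modeling_llama import LlamaAttention


@contextmanager
def patched_llama_attention(adapter):
    """Temporarily swap LlamaAttention.forward for the given adapter."""
    original = LlamaAttention.forward
    LlamaAttention.forward = adapter
    try:
        yield
    finally:
        LlamaAttention.forward = original  # restore the stock implementation


# Hypothetical usage mirroring the script above:
# with patched_llama_attention(llama_attention_forward_adapter):
#     circle_model = tico.convert(layers, captured_input)
```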
test/modules/model/LlamaDecoderLayerWithKVCacheAndFusedAttention/prefill.py (new file, 76 additions)
```python
# User input
prompt = "Lily picked up a flower."
model_name = "Maykeye/TinyLLama-v0"

# Tokenizer
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained(model_name)
tokenizer.pad_token = tokenizer.eos_token
tokenizer.padding_side = "right"
inputs = tokenizer(
    prompt,
    return_tensors="pt",
    padding="max_length",
    max_length=32,
    truncation=True,
)

# Generator
import torch

from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained(model_name)
model.eval()

from tico.utils.record_input import RecordingInput

# past_key_values
# ---------------
# During prefill, "past_key_values" is not None but an empty Cache instance.
# Passing None makes torch.export happy.

input_to_remove = [
    "past_key_values",
    # DynamicCache has been pytree-flattenable since transformers 4.50.
    # See _pytree.py > tree_flatten:
    #   SUPPORTED_NODES contains transformers.DynamicCache.
    # After flattening, DynamicCache becomes {"key_cache": [], "value_cache": []};
    # the dict values are returned and the dict keys are stored in the treespec.
    #
    # On prefill, DynamicCache is empty, so the dict is empty after flattening,
    # and PyTorch removes the empty dict!
    # If the number of args is 4 (including the cache), it becomes 3!
    # To avoid this error, don't pass an empty cache; just pass None.
    "attention_mask",
    # For left pad:  [0, ⋯, 0, 1, ⋯, 1]
    # For right pad: [1, ⋯, 1, 0, ⋯, 0]
    # (0 marks the pad tokens)
    # This script uses right padding and passes an all-1 attention mask (including pad).
    # The NPU computes all positions, pad or not.
    "cache_position",
    # This is the list of cache positions, e.g. [0, 1, ..., 11].
    # For the NPU, we always store all values (including pad).
]

with torch.no_grad(), RecordingInput(model, input_to_remove=input_to_remove) as rec:
    outputs = model.generate(
        **inputs,
        max_new_tokens=32,
        do_sample=False,
        pad_token_id=tokenizer.eos_token_id,
    )
    captured_input = rec.captured_input

generated_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
print(generated_text)

# Tico
import tico

model = AutoModelForCausalLM.from_pretrained(model_name)
model.eval()
circle_model = tico.convert(model, captured_input)
circle_model.save("tinyllama.prefill.circle")
```
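The `input_to_remove` comment in prefill.py explains why an empty `DynamicCache` is dropped: pytree flattening of an empty cache yields no leaves, so the exported call would lose one positional input. A hedged way to observe this directly, assuming transformers>=4.50 registers `DynamicCache` as a pytree node (as the comment states):

```python
import torch.utils._pytree as pytree
from transformers.cache_utils import DynamicCache

empty_cache = DynamicCache()
leaves, spec = pytree.tree_flatten(empty_cache)
# An empty cache contributes no leaves, so torch.export would see one fewer
# positional input than expected -- hence prefill.py passes None instead.
print(leaves)  # expected: []
```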
test/modules/model/LlamaDecoderLayerWithKVCacheAndFusedAttention/requirements.txt (new file, 1 addition)
```
transformers>=4.50.1
```