From b59177edf6d7c27ab43c6cf1b66ed4836ffbd7e5 Mon Sep 17 00:00:00 2001 From: Sanggyu Lee Date: Thu, 17 Jul 2025 15:59:27 +0900 Subject: [PATCH 01/13] Fuse LlamaAttention to attention (onert) It fuses LlamaAttention from TinyLlama model. Fused attention works as onert attention op. TICO-DCO-1.0-Signed-off-by: Sanggyu Lee --- .../TinyLlamaWithFusedAttention/__init__.py | 1 + .../TinyLlamaWithFusedAttention/decode.py | 71 ++++++++ .../requirements.txt | 1 + .../operators/adapters/onert/__init__.py | 1 + .../operators/adapters/onert/op_attention.py | 154 ++++++++++++++++++ 5 files changed, 228 insertions(+) create mode 100644 test/modules/model/TinyLlamaWithFusedAttention/__init__.py create mode 100644 test/modules/model/TinyLlamaWithFusedAttention/decode.py create mode 100644 test/modules/model/TinyLlamaWithFusedAttention/requirements.txt create mode 100644 tico/serialize/operators/adapters/onert/__init__.py create mode 100644 tico/serialize/operators/adapters/onert/op_attention.py diff --git a/test/modules/model/TinyLlamaWithFusedAttention/__init__.py b/test/modules/model/TinyLlamaWithFusedAttention/__init__.py new file mode 100644 index 00000000..0c29109f --- /dev/null +++ b/test/modules/model/TinyLlamaWithFusedAttention/__init__.py @@ -0,0 +1 @@ +# DO NOT REMOVE THIS FILE diff --git a/test/modules/model/TinyLlamaWithFusedAttention/decode.py b/test/modules/model/TinyLlamaWithFusedAttention/decode.py new file mode 100644 index 00000000..d70953cc --- /dev/null +++ b/test/modules/model/TinyLlamaWithFusedAttention/decode.py @@ -0,0 +1,71 @@ +# User input +prompt = "Lily picked up a flower." +model_name = "Maykeye/TinyLLama-v0" + +# Tokenizer +from transformers import AutoTokenizer + +tokenizer = AutoTokenizer.from_pretrained(model_name) +tokenizer.pad_token = tokenizer.eos_token +tokenizer.padding_side = "right" +inputs = tokenizer( + prompt, + return_tensors="pt", + padding="max_length", + max_length=30, + truncation=True, +) + +# Generator +import torch + +from transformers import AutoModelForCausalLM + +model = AutoModelForCausalLM.from_pretrained(model_name) +model.eval() + +from tico.utils.record_input import RecordingInput + +# past_key_values +# --------------- +# During prefill, "past_key_values" not None, but an empty Cache instance. +# Passing None makes torch.export happy. + + +input_to_remove = [ + "attention_mask", + # For left pad, [0, ⋯, 0, 1, ⋯, 1] + # For right right pad, [1, ⋯, 1, 0, ⋯, 0] + # ( 0 is pad-token ) + # This script uses right pad and pass all-1 attention mask (including pad). + # Npu computes all positions whether it is pad or not. 
+] +condition_fn = lambda args_dict: args_dict["past_key_values"].get_seq_length() != 0 + +with torch.no_grad(), RecordingInput( + model, condition_fn, input_to_remove=input_to_remove +) as rec: + outputs = model.generate( + **inputs, + max_new_tokens=32, + do_sample=False, + pad_token_id=tokenizer.eos_token_id, + ) + captured_input = rec.captured_input + +generated_text = tokenizer.decode(outputs[0], skip_special_tokens=True) +print(generated_text) + +# Tico +import tico +from tico.serialize.operators.adapters.onert.op_attention import ( + llama_attention_forward_adapter, +) +from transformers.models.llama.modeling_llama import LlamaAttention + +LlamaAttention.forward = llama_attention_forward_adapter + +model = AutoModelForCausalLM.from_pretrained(model_name) +model.eval() +circle_model = tico.convert(model, captured_input) +circle_model.save(f"tinyllama.decode.circle") diff --git a/test/modules/model/TinyLlamaWithFusedAttention/requirements.txt b/test/modules/model/TinyLlamaWithFusedAttention/requirements.txt new file mode 100644 index 00000000..5393938f --- /dev/null +++ b/test/modules/model/TinyLlamaWithFusedAttention/requirements.txt @@ -0,0 +1 @@ +transformers>=4.50.1 diff --git a/tico/serialize/operators/adapters/onert/__init__.py b/tico/serialize/operators/adapters/onert/__init__.py new file mode 100644 index 00000000..0c29109f --- /dev/null +++ b/tico/serialize/operators/adapters/onert/__init__.py @@ -0,0 +1 @@ +# DO NOT REMOVE THIS FILE diff --git a/tico/serialize/operators/adapters/onert/op_attention.py b/tico/serialize/operators/adapters/onert/op_attention.py new file mode 100644 index 00000000..0a04d5b3 --- /dev/null +++ b/tico/serialize/operators/adapters/onert/op_attention.py @@ -0,0 +1,154 @@ +# Copyright (c) 2025 Samsung Electronics Co., Ltd. All Rights Reserved +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
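[Editorial note] The module body below declares a `circle::attention.llama` operator schema and registers only a fake (meta) kernel for it. That is enough for `torch.export`: tracing runs on fake tensors, so the fused call shows up as a single opaque graph node that the serializer can later map to a Circle builtin, and no eager CPU kernel is ever needed. The following is a minimal, self-contained sketch of that general pattern; the `mylib::fused_add` op and the toy module are hypothetical and exist only for illustration, they are not part of TICO.

```python
import torch
from torch.library import Library

# Hypothetical library/op, used only to illustrate the schema + fake-kernel pattern.
_lib = Library("mylib", "DEF")
_lib.define("fused_add(Tensor a, Tensor b) -> Tensor")

@torch.library.register_fake("mylib::fused_add")
def _fused_add_fake(a, b):
    # Only shape/dtype propagation is needed while exporting with fake tensors.
    return torch.empty_like(a)

class TinyModule(torch.nn.Module):
    def forward(self, a, b):
        return torch.ops.mylib.fused_add(a, b)

# The exported graph contains one mylib.fused_add node; a NodeVisitor-style
# serializer can then translate that node into a backend builtin operator.
ep = torch.export.export(TinyModule(), (torch.randn(2, 3), torch.randn(2, 3)))
print(ep.graph)
```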
+ +from typing import Dict, List, TYPE_CHECKING + +if TYPE_CHECKING: + import torch._ops + import torch.fx +import torch +from circle_schema import circle + +from torch.library import Library + +from tico.serialize.circle_graph import CircleSubgraph +from tico.serialize.operators.hashable_opcode import OpCode +from tico.serialize.operators.node_visitor import NodeVisitor, register_node_visitor +from tico.serialize.operators.utils import create_builtin_operator, get_op_index + +lib = Library("circle", "DEF") +lib.define( + """ +attention.llama( + Tensor hidden_states, + Tensor wq, + Tensor wk, + Tensor wv, + Tensor wo, + Tensor position_cos, + Tensor position_sin, + Tensor attention_mask, + Tensor past_key, + Tensor past_value, + Tensor cache_position +) -> Tensor +""" +) + +# ATTENTION FUSER + + +@torch.library.register_fake("circle::attention.llama") +def attention_llama(*args, **kwargs): + ( + hidden_states, + q_proj, + k_proj, + v_proj, + o_proj, + position_cos, + position_sin, + attention_mask, + past_key, + past_value, + cache_position, + ) = args + return hidden_states + + +from typing import List, Optional + +from transformers.cache_utils import DynamicCache +from transformers.models.llama.modeling_llama import LlamaAttention + + +def llama_attention_forward_adapter( + self: LlamaAttention, + hidden_states: torch.Tensor, + position_embeddings: List[torch.Tensor], + attention_mask: torch.Tensor, + past_key_value: DynamicCache, + cache_position: torch.Tensor, + **kwargs, +): + # past_key_value is a dict with key_cache and value_cache. + # It needs to be decomposed for tico and circle which does not know dict. + key_cache = past_key_value.key_cache # type: ignore[union-attr] + value_cache = past_key_value.value_cache # type: ignore[union-attr] + return ( + torch.ops.circle.attention.llama( + hidden_states, + self.q_proj.weight, + self.k_proj.weight, + self.v_proj.weight, + self.o_proj.weight, + position_embeddings[0], # cos + position_embeddings[1], # sin + attention_mask, + # key_cache is a list of cache for each decoder layer. + # Assumtion: key cache is continuous + # + # k_cache[0] | k_cache[1] | ... | k_cache[n] + key_cache[self.layer_idx], + value_cache[self.layer_idx], # Same to value_cache + cache_position, + ), + None, + ) + + +@register_node_visitor +class AttentionVisitor(NodeVisitor): + target: List[torch._ops.OpOverload] = [ + torch.ops.circle.attention.llama, + ] + + def __init__(self, op_codes: Dict[OpCode, int], graph: CircleSubgraph): + super().__init__(op_codes, graph) + + def define_node( + self, + node: torch.fx.Node, + ) -> circle.Operator.OperatorT: + ( + hidden_states, + wq, + wk, + wv, + wo, + position_cos, + position_sin, + attention_mask, + past_key, + past_value, + cache_position, + ) = node.args + + op_index = get_op_index( + circle.BuiltinOperator.BuiltinOperator.ATTENTION, self._op_codes + ) + + # remove last arg (= layer_idx) from inputs. + # layer_idx is attention op's param, not input. 
+ inputs = node.args[:-1] + outputs = [node] + operator = create_builtin_operator(self.graph, op_index, inputs, outputs) + + # Op-specific option + operator.builtinOptionsType = ( + circle.BuiltinOptions.BuiltinOptions.AttentionOptions + ) + operator.builtinOptions = circle.AttentionOptions.AttentionOptionsT() + + return operator From fa45293c159bda226ac6d5e51c4751c2251c7979 Mon Sep 17 00:00:00 2001 From: Sanggyu Lee Date: Fri, 7 Nov 2025 15:28:54 +0900 Subject: [PATCH 02/13] Update requirements.txt --- test/modules/model/TinyLlamaWithFusedAttention/requirements.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/modules/model/TinyLlamaWithFusedAttention/requirements.txt b/test/modules/model/TinyLlamaWithFusedAttention/requirements.txt index 5393938f..bc13c47a 100644 --- a/test/modules/model/TinyLlamaWithFusedAttention/requirements.txt +++ b/test/modules/model/TinyLlamaWithFusedAttention/requirements.txt @@ -1 +1 @@ -transformers>=4.50.1 +transformers==4.50.3 From e94388d14b2405e596a00ddac2c15007bebc6863 Mon Sep 17 00:00:00 2001 From: Sanggyu Lee Date: Wed, 12 Nov 2025 08:27:11 +0900 Subject: [PATCH 03/13] register circle_custom::attention in TICO's way --- .../operators/adapters/onert/op_attention.py | 52 ++----------------- tico/utils/register_custom_op.py | 35 +++++++++++++ 2 files changed, 39 insertions(+), 48 deletions(-) diff --git a/tico/serialize/operators/adapters/onert/op_attention.py b/tico/serialize/operators/adapters/onert/op_attention.py index 0a04d5b3..1ff65a6c 100644 --- a/tico/serialize/operators/adapters/onert/op_attention.py +++ b/tico/serialize/operators/adapters/onert/op_attention.py @@ -20,58 +20,14 @@ import torch from circle_schema import circle -from torch.library import Library +from transformers.cache_utils import DynamicCache +from transformers.models.llama.modeling_llama import LlamaAttention from tico.serialize.circle_graph import CircleSubgraph from tico.serialize.operators.hashable_opcode import OpCode from tico.serialize.operators.node_visitor import NodeVisitor, register_node_visitor from tico.serialize.operators.utils import create_builtin_operator, get_op_index -lib = Library("circle", "DEF") -lib.define( - """ -attention.llama( - Tensor hidden_states, - Tensor wq, - Tensor wk, - Tensor wv, - Tensor wo, - Tensor position_cos, - Tensor position_sin, - Tensor attention_mask, - Tensor past_key, - Tensor past_value, - Tensor cache_position -) -> Tensor -""" -) - -# ATTENTION FUSER - - -@torch.library.register_fake("circle::attention.llama") -def attention_llama(*args, **kwargs): - ( - hidden_states, - q_proj, - k_proj, - v_proj, - o_proj, - position_cos, - position_sin, - attention_mask, - past_key, - past_value, - cache_position, - ) = args - return hidden_states - - -from typing import List, Optional - -from transformers.cache_utils import DynamicCache -from transformers.models.llama.modeling_llama import LlamaAttention - def llama_attention_forward_adapter( self: LlamaAttention, @@ -87,7 +43,7 @@ def llama_attention_forward_adapter( key_cache = past_key_value.key_cache # type: ignore[union-attr] value_cache = past_key_value.value_cache # type: ignore[union-attr] return ( - torch.ops.circle.attention.llama( + torch.ops.circle_custom.attention( hidden_states, self.q_proj.weight, self.k_proj.weight, @@ -111,7 +67,7 @@ def llama_attention_forward_adapter( @register_node_visitor class AttentionVisitor(NodeVisitor): target: List[torch._ops.OpOverload] = [ - torch.ops.circle.attention.llama, + torch.ops.circle_custom.attention.default, ] 
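[Editorial note] Patch 03 retires the hand-rolled `Library("circle", "DEF")` registration in favor of TICO's existing `circle_custom` namespace, using the `@custom_op`/`@register_fake` decorators added to `tico/utils/register_custom_op.py` further below. For orientation, here is a minimal sketch of that decorator pairing; the `circle_custom::toy_scale` op is hypothetical and shown only to illustrate the pattern.

```python
import torch
from torch.library import custom_op, register_fake

@custom_op("circle_custom::toy_scale", mutates_args=())
def toy_scale(x: torch.Tensor, factor: float) -> torch.Tensor:
    # Eager fallback; the real attention op in this series has no usable eager body.
    return x * factor

@register_fake("circle_custom::toy_scale")
def _(x: torch.Tensor, factor: float) -> torch.Tensor:
    # Fake kernel used while torch.export traces with fake tensors.
    return torch.empty_like(x)
```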
def __init__(self, op_codes: Dict[OpCode, int], graph: CircleSubgraph): diff --git a/tico/utils/register_custom_op.py b/tico/utils/register_custom_op.py index 095cb6a5..7f400047 100644 --- a/tico/utils/register_custom_op.py +++ b/tico/utils/register_custom_op.py @@ -727,6 +727,40 @@ def _( return hidden_states.new_empty(hidden_states.size()) +def CircleAttention(): + @custom_op("circle_custom::attention", mutates_args=()) + def attention( + hidden_states: torch.Tensor, + wq: torch.Tensor, + wk: torch.Tensor, + wv: torch.Tensor, + wo: torch.Tensor, + position_cos: torch.Tensor, + position_sin: torch.Tensor, + attention_mask: torch.Tensor, + past_key: torch.Tensor, + past_value: torch.Tensor, + cache_position: torch.Tensor, + ) -> torch.Tensor: + return None + + @register_fake("circle_custom::attention") + def _( + hidden_states: torch.Tensor, + wq: torch.Tensor, + wk: torch.Tensor, + wv: torch.Tensor, + wo: torch.Tensor, + position_cos: torch.Tensor, + position_sin: torch.Tensor, + attention_mask: torch.Tensor, + past_key: torch.Tensor, + past_value: torch.Tensor, + cache_position: torch.Tensor, + ) -> torch.Tensor: + return hidden_states + + # Add custom ops to the torch namespace def RegisterOps(): CircleResizeNearestNeighbor() @@ -740,3 +774,4 @@ def RegisterOps(): CircleInstanceNorm() CircleQuantizeMX() CircleRMSNorm() + CircleAttention() From 2c49c53cc8bb80de16b1bae29eec2bbf80d85209 Mon Sep 17 00:00:00 2001 From: Sanggyu Lee Date: Wed, 12 Nov 2025 08:42:10 +0900 Subject: [PATCH 04/13] Extract attention adapter code into adapters/onert --- .../TinyLlamaWithFusedAttention/decode.py | 2 +- .../adapters/onert/llama_attention.py | 51 +++++++++++++++++++ .../{adapters/onert => }/op_attention.py | 37 -------------- 3 files changed, 52 insertions(+), 38 deletions(-) create mode 100644 tico/serialize/operators/adapters/onert/llama_attention.py rename tico/serialize/operators/{adapters/onert => }/op_attention.py (63%) diff --git a/test/modules/model/TinyLlamaWithFusedAttention/decode.py b/test/modules/model/TinyLlamaWithFusedAttention/decode.py index d70953cc..16d12d2f 100644 --- a/test/modules/model/TinyLlamaWithFusedAttention/decode.py +++ b/test/modules/model/TinyLlamaWithFusedAttention/decode.py @@ -58,7 +58,7 @@ # Tico import tico -from tico.serialize.operators.adapters.onert.op_attention import ( +from tico.serialize.operators.adapters.onert.llama_attention import ( llama_attention_forward_adapter, ) from transformers.models.llama.modeling_llama import LlamaAttention diff --git a/tico/serialize/operators/adapters/onert/llama_attention.py b/tico/serialize/operators/adapters/onert/llama_attention.py new file mode 100644 index 00000000..2700117d --- /dev/null +++ b/tico/serialize/operators/adapters/onert/llama_attention.py @@ -0,0 +1,51 @@ +# Copyright (c) 2025 Samsung Electronics Co., Ltd. All Rights Reserved +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
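[Editorial note] The adapter defined below replaces `LlamaAttention.forward` so that, under `torch.export`, each decoder layer's attention collapses into a single `circle_custom::attention` call. The series installs it in two ways, and the sketch below simply restates both side by side (model name as in the test scripts).

```python
from transformers import AutoModelForCausalLM
from transformers.models.llama.modeling_llama import LlamaAttention

from tico.serialize.operators.adapters.onert.llama_attention import (
    llama_attention_forward_adapter,
)

model = AutoModelForCausalLM.from_pretrained("Maykeye/TinyLLama-v0").eval()

# Style 1 (decode.py): patch the class, so every LlamaAttention instance is fused.
LlamaAttention.forward = llama_attention_forward_adapter

# Style 2 (model.py): bind per instance, leaving the class itself untouched.
for layer in model.model.layers:
    layer.self_attn.forward = llama_attention_forward_adapter.__get__(layer.self_attn)
```

Per-instance binding is what the later test module uses, since it keeps an unfused reference copy of the model available for output comparison.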
+ +from typing import Dict, List, TYPE_CHECKING + +import torch + +from transformers.cache_utils import DynamicCache +from transformers.models.llama.modeling_llama import LlamaAttention + + +def llama_attention_forward_adapter( + self: LlamaAttention, + hidden_states: torch.Tensor, + position_embeddings: List[torch.Tensor], + attention_mask: torch.Tensor, + past_key_value: DynamicCache, + cache_position: torch.Tensor, + **kwargs, +): + # past_key_value is a dict with key_cache and value_cache. + # It needs to be decomposed for tico and circle which does not know dict. + key_cache = past_key_value.key_cache # type: ignore[union-attr] + value_cache = past_key_value.value_cache # type: ignore[union-attr] + return ( + torch.ops.circle_custom.attention( + hidden_states, + self.q_proj.weight, + self.k_proj.weight, + self.v_proj.weight, + self.o_proj.weight, + position_embeddings[0], # cos + position_embeddings[1], # sin + attention_mask, + key_cache[self.layer_idx], + value_cache[self.layer_idx], # Same to value_cache + cache_position, + ), + None, + ) diff --git a/tico/serialize/operators/adapters/onert/op_attention.py b/tico/serialize/operators/op_attention.py similarity index 63% rename from tico/serialize/operators/adapters/onert/op_attention.py rename to tico/serialize/operators/op_attention.py index 1ff65a6c..160f2228 100644 --- a/tico/serialize/operators/adapters/onert/op_attention.py +++ b/tico/serialize/operators/op_attention.py @@ -20,49 +20,12 @@ import torch from circle_schema import circle -from transformers.cache_utils import DynamicCache -from transformers.models.llama.modeling_llama import LlamaAttention - from tico.serialize.circle_graph import CircleSubgraph from tico.serialize.operators.hashable_opcode import OpCode from tico.serialize.operators.node_visitor import NodeVisitor, register_node_visitor from tico.serialize.operators.utils import create_builtin_operator, get_op_index -def llama_attention_forward_adapter( - self: LlamaAttention, - hidden_states: torch.Tensor, - position_embeddings: List[torch.Tensor], - attention_mask: torch.Tensor, - past_key_value: DynamicCache, - cache_position: torch.Tensor, - **kwargs, -): - # past_key_value is a dict with key_cache and value_cache. - # It needs to be decomposed for tico and circle which does not know dict. - key_cache = past_key_value.key_cache # type: ignore[union-attr] - value_cache = past_key_value.value_cache # type: ignore[union-attr] - return ( - torch.ops.circle_custom.attention( - hidden_states, - self.q_proj.weight, - self.k_proj.weight, - self.v_proj.weight, - self.o_proj.weight, - position_embeddings[0], # cos - position_embeddings[1], # sin - attention_mask, - # key_cache is a list of cache for each decoder layer. - # Assumtion: key cache is continuous - # - # k_cache[0] | k_cache[1] | ... 
| k_cache[n] - key_cache[self.layer_idx], - value_cache[self.layer_idx], # Same to value_cache - cache_position, - ), - None, - ) - @register_node_visitor class AttentionVisitor(NodeVisitor): From 275b6cfd11fa35175020e6fc224d172fb4bb87d9 Mon Sep 17 00:00:00 2001 From: Sanggyu Lee Date: Wed, 12 Nov 2025 09:01:45 +0900 Subject: [PATCH 05/13] Update tico/utils/validate_args_kwargs.py --- tico/serialize/operators/op_attention.py | 17 ++--------------- tico/utils/validate_args_kwargs.py | 20 ++++++++++++++++++++ 2 files changed, 22 insertions(+), 15 deletions(-) diff --git a/tico/serialize/operators/op_attention.py b/tico/serialize/operators/op_attention.py index 160f2228..79f72129 100644 --- a/tico/serialize/operators/op_attention.py +++ b/tico/serialize/operators/op_attention.py @@ -24,7 +24,7 @@ from tico.serialize.operators.hashable_opcode import OpCode from tico.serialize.operators.node_visitor import NodeVisitor, register_node_visitor from tico.serialize.operators.utils import create_builtin_operator, get_op_index - +from tico.utils.validate_args_kwargs import CircleAttentionArgs @register_node_visitor @@ -40,20 +40,7 @@ def define_node( self, node: torch.fx.Node, ) -> circle.Operator.OperatorT: - ( - hidden_states, - wq, - wk, - wv, - wo, - position_cos, - position_sin, - attention_mask, - past_key, - past_value, - cache_position, - ) = node.args - + args = CircleAttentionArgs(*node.args, **node.kwargs) # type: ignore[arg-type] op_index = get_op_index( circle.BuiltinOperator.BuiltinOperator.ATTENTION, self._op_codes ) diff --git a/tico/utils/validate_args_kwargs.py b/tico/utils/validate_args_kwargs.py index 8a5feb21..60fad7b8 100644 --- a/tico/utils/validate_args_kwargs.py +++ b/tico/utils/validate_args_kwargs.py @@ -171,6 +171,26 @@ class CatArgs: dim: int = 0 +@enforce_type +@dataclass +class CircleAttentionArgs: + """ + For circle.BuiltinOperator.BuiltinOperator.ATTENTION + """ + + hidden_states: torch.fx.Node + wq: torch.fx.Node + wk: torch.fx.Node + wv: torch.fx.Node + wo: torch.fx.Node + position_cos: torch.fx.Node + position_sin: torch.fx.Node + attention_mask: torch.fx.Node + past_key: torch.fx.Node + past_value: torch.fx.Node + cache_position: torch.fx.Node + + @enforce_type @dataclass class CircleRMSNormArgs: From c01470d7cfe8317ee56d669d9114d6b2bff333f5 Mon Sep 17 00:00:00 2001 From: Sanggyu Lee Date: Wed, 12 Nov 2025 11:03:43 +0900 Subject: [PATCH 06/13] rename decode.py to model.py and support test --- .../TinyLlamaWithFusedAttention/decode.py | 71 --------------- .../TinyLlamaWithFusedAttention/model.py | 87 +++++++++++++++++++ 2 files changed, 87 insertions(+), 71 deletions(-) delete mode 100644 test/modules/model/TinyLlamaWithFusedAttention/decode.py create mode 100644 test/modules/model/TinyLlamaWithFusedAttention/model.py diff --git a/test/modules/model/TinyLlamaWithFusedAttention/decode.py b/test/modules/model/TinyLlamaWithFusedAttention/decode.py deleted file mode 100644 index 16d12d2f..00000000 --- a/test/modules/model/TinyLlamaWithFusedAttention/decode.py +++ /dev/null @@ -1,71 +0,0 @@ -# User input -prompt = "Lily picked up a flower." 
-model_name = "Maykeye/TinyLLama-v0" - -# Tokenizer -from transformers import AutoTokenizer - -tokenizer = AutoTokenizer.from_pretrained(model_name) -tokenizer.pad_token = tokenizer.eos_token -tokenizer.padding_side = "right" -inputs = tokenizer( - prompt, - return_tensors="pt", - padding="max_length", - max_length=30, - truncation=True, -) - -# Generator -import torch - -from transformers import AutoModelForCausalLM - -model = AutoModelForCausalLM.from_pretrained(model_name) -model.eval() - -from tico.utils.record_input import RecordingInput - -# past_key_values -# --------------- -# During prefill, "past_key_values" not None, but an empty Cache instance. -# Passing None makes torch.export happy. - - -input_to_remove = [ - "attention_mask", - # For left pad, [0, ⋯, 0, 1, ⋯, 1] - # For right right pad, [1, ⋯, 1, 0, ⋯, 0] - # ( 0 is pad-token ) - # This script uses right pad and pass all-1 attention mask (including pad). - # Npu computes all positions whether it is pad or not. -] -condition_fn = lambda args_dict: args_dict["past_key_values"].get_seq_length() != 0 - -with torch.no_grad(), RecordingInput( - model, condition_fn, input_to_remove=input_to_remove -) as rec: - outputs = model.generate( - **inputs, - max_new_tokens=32, - do_sample=False, - pad_token_id=tokenizer.eos_token_id, - ) - captured_input = rec.captured_input - -generated_text = tokenizer.decode(outputs[0], skip_special_tokens=True) -print(generated_text) - -# Tico -import tico -from tico.serialize.operators.adapters.onert.llama_attention import ( - llama_attention_forward_adapter, -) -from transformers.models.llama.modeling_llama import LlamaAttention - -LlamaAttention.forward = llama_attention_forward_adapter - -model = AutoModelForCausalLM.from_pretrained(model_name) -model.eval() -circle_model = tico.convert(model, captured_input) -circle_model.save(f"tinyllama.decode.circle") diff --git a/test/modules/model/TinyLlamaWithFusedAttention/model.py b/test/modules/model/TinyLlamaWithFusedAttention/model.py new file mode 100644 index 00000000..f4b0fbfa --- /dev/null +++ b/test/modules/model/TinyLlamaWithFusedAttention/model.py @@ -0,0 +1,87 @@ +import torch + +from tico.serialize.operators.adapters.llama_rmsnorm import patched_llama_rmsnorm +from tico.serialize.operators.adapters.onert.llama_attention import ( + llama_attention_forward_adapter, +) +from tico.utils.pytree_utils import register_dynamic_cache +from tico.utils.record_input import RecordingInput +from transformers import AutoModelForCausalLM, AutoTokenizer + +from transformers.models.llama.modeling_llama import LlamaAttention + +from test.modules.base import TestModuleBase +from test.utils import tag + + +@tag.use_onert +class TinyLlamaWithFusedAttention(TestModuleBase): + def __init__(self): + super().__init__() + self.model_name = "Maykeye/TinyLLama-v0" + self._call_count = 0 + self.original_model = AutoModelForCausalLM.from_pretrained( + self.model_name + ).eval() + self.fused_model = AutoModelForCausalLM.from_pretrained(self.model_name).eval() + for layer in self.fused_model.model.layers: + layer.self_attn.forward = llama_attention_forward_adapter.__get__( + layer.self_attn + ) + self.rtol = 1e-4 + self.atol = 1e-4 + + def forward(self, *args, **kwargs): + self._call_count += 1 + + if self._call_count == 2: + return self.fused_model(*args, **kwargs) + else: + return self.original_model(*args, **kwargs) + + def get_example_inputs(self): + prompt = "Lily picked up a flower." 
+ tokenizer = AutoTokenizer.from_pretrained(self.model_name) + tokenizer.pad_token = tokenizer.eos_token + tokenizer.padding_side = "right" + inputs = tokenizer( + prompt, + return_tensors="pt", + padding="max_length", + max_length=30, + truncation=True, + ) + model = self.original_model + model.eval() + + # past_key_values + # --------------- + # During prefill, "past_key_values" not None, but an empty Cache instance. + # Passing None makes torch.export happy. + + input_to_remove = [ + "attention_mask", + # For left pad, [0, ⋯, 0, 1, ⋯, 1] + # For right right pad, [1, ⋯, 1, 0, ⋯, 0] + # ( 0 is pad-token ) + # This script uses right pad and pass all-1 attention mask (including pad). + # Npu computes all positions whether it is pad or not. + ] + condition_fn = ( + lambda args_dict: args_dict["past_key_values"].get_seq_length() != 0 + ) + + with torch.no_grad(), RecordingInput( + model, condition_fn, input_to_remove=input_to_remove + ) as rec: + outputs = model.generate( + **inputs, + max_new_tokens=32, + do_sample=False, + pad_token_id=tokenizer.eos_token_id, + ) + captured_input = rec.captured_input + + generated_text = tokenizer.decode(outputs[0], skip_special_tokens=True) + print(generated_text) + return captured_input, {} From 6a5b1f53e662994dde4b3790c053b19f063ee35a Mon Sep 17 00:00:00 2001 From: Sanggyu Lee Date: Wed, 12 Nov 2025 11:18:51 +0900 Subject: [PATCH 07/13] Remove layer_idx --- tico/serialize/operators/op_attention.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/tico/serialize/operators/op_attention.py b/tico/serialize/operators/op_attention.py index 79f72129..6f4c02e4 100644 --- a/tico/serialize/operators/op_attention.py +++ b/tico/serialize/operators/op_attention.py @@ -45,9 +45,7 @@ def define_node( circle.BuiltinOperator.BuiltinOperator.ATTENTION, self._op_codes ) - # remove last arg (= layer_idx) from inputs. - # layer_idx is attention op's param, not input. - inputs = node.args[:-1] + inputs = node.args outputs = [node] operator = create_builtin_operator(self.graph, op_index, inputs, outputs) From be5723210f9db39ac72974627cc03f1f4f03078c Mon Sep 17 00:00:00 2001 From: Sanggyu Lee Date: Wed, 12 Nov 2025 11:26:06 +0900 Subject: [PATCH 08/13] Disable circle2circle --- test/pt2_to_circle_test/builder.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/pt2_to_circle_test/builder.py b/test/pt2_to_circle_test/builder.py index 3f793067..c80df3e9 100644 --- a/test/pt2_to_circle_test/builder.py +++ b/test/pt2_to_circle_test/builder.py @@ -172,7 +172,7 @@ def _run( config=compile_config, ) - verify_circle(circle_model_path, opt_circle_model_path) + # verify_circle(circle_model_path, opt_circle_model_path) if dynamic_shapes: From 9281efced16a0d2c79a962b10b33e9c5042eaa53 Mon Sep 17 00:00:00 2001 From: Sanggyu Lee Date: Wed, 12 Nov 2025 11:47:25 +0900 Subject: [PATCH 09/13] Update comment why attention_mask is removed from inputs --- .../model/TinyLlamaWithFusedAttention/model.py | 11 +---------- 1 file changed, 1 insertion(+), 10 deletions(-) diff --git a/test/modules/model/TinyLlamaWithFusedAttention/model.py b/test/modules/model/TinyLlamaWithFusedAttention/model.py index f4b0fbfa..96d1aa64 100644 --- a/test/modules/model/TinyLlamaWithFusedAttention/model.py +++ b/test/modules/model/TinyLlamaWithFusedAttention/model.py @@ -54,18 +54,9 @@ def get_example_inputs(self): model = self.original_model model.eval() - # past_key_values - # --------------- - # During prefill, "past_key_values" not None, but an empty Cache instance. 
- # Passing None makes torch.export happy. - input_to_remove = [ "attention_mask", - # For left pad, [0, ⋯, 0, 1, ⋯, 1] - # For right right pad, [1, ⋯, 1, 0, ⋯, 0] - # ( 0 is pad-token ) - # This script uses right pad and pass all-1 attention mask (including pad). - # Npu computes all positions whether it is pad or not. + # attention mask will be manually provided, not from example input ] condition_fn = ( lambda args_dict: args_dict["past_key_values"].get_seq_length() != 0 From ad53700802dc70050be6cd9cb3f88143db96f6bf Mon Sep 17 00:00:00 2001 From: Sanggyu Lee Date: Wed, 12 Nov 2025 12:20:34 +0900 Subject: [PATCH 10/13] Add get_compile_config --- test/modules/model/TinyLlamaWithFusedAttention/model.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/test/modules/model/TinyLlamaWithFusedAttention/model.py b/test/modules/model/TinyLlamaWithFusedAttention/model.py index 96d1aa64..c45e2388 100644 --- a/test/modules/model/TinyLlamaWithFusedAttention/model.py +++ b/test/modules/model/TinyLlamaWithFusedAttention/model.py @@ -1,5 +1,5 @@ import torch - +from tico.config.v1 import CompileConfigV1 from tico.serialize.operators.adapters.llama_rmsnorm import patched_llama_rmsnorm from tico.serialize.operators.adapters.onert.llama_attention import ( llama_attention_forward_adapter, @@ -76,3 +76,6 @@ def get_example_inputs(self): generated_text = tokenizer.decode(outputs[0], skip_special_tokens=True) print(generated_text) return captured_input, {} + + def get_compile_config(self): + return CompileConfigV1(convert_single_batch_lhs_const_bmm_to_fc=True) From c882fe68ce3bd7954a4898b7f1509446a2230246 Mon Sep 17 00:00:00 2001 From: Sanggyu Lee Date: Wed, 12 Nov 2025 13:44:31 +0900 Subject: [PATCH 11/13] let it work in more environment --- test/modules/model/TinyLlamaWithFusedAttention/model.py | 1 + 1 file changed, 1 insertion(+) diff --git a/test/modules/model/TinyLlamaWithFusedAttention/model.py b/test/modules/model/TinyLlamaWithFusedAttention/model.py index c45e2388..73e41c5e 100644 --- a/test/modules/model/TinyLlamaWithFusedAttention/model.py +++ b/test/modules/model/TinyLlamaWithFusedAttention/model.py @@ -30,6 +30,7 @@ def __init__(self): ) self.rtol = 1e-4 self.atol = 1e-4 + register_dynamic_cache() def forward(self, *args, **kwargs): self._call_count += 1 From 3233267f3f52e06effa5a5bf9db7fd048f6b160f Mon Sep 17 00:00:00 2001 From: Sanggyu Lee Date: Wed, 12 Nov 2025 15:52:25 +0900 Subject: [PATCH 12/13] Update requirements.txt (torch <= 2.8.0) --- test/modules/model/LlamaAttentionWithKVCache/requirements.txt | 1 + 1 file changed, 1 insertion(+) diff --git a/test/modules/model/LlamaAttentionWithKVCache/requirements.txt b/test/modules/model/LlamaAttentionWithKVCache/requirements.txt index fcceb41a..956286d9 100644 --- a/test/modules/model/LlamaAttentionWithKVCache/requirements.txt +++ b/test/modules/model/LlamaAttentionWithKVCache/requirements.txt @@ -1,2 +1,3 @@ numpy==1.24.1 transformers==4.49.0 +torch<=2.8.0 From c8179711eab4dabfa6c93a93852c7a5f3137593a Mon Sep 17 00:00:00 2001 From: Sanggyu Lee Date: Mon, 24 Nov 2025 18:44:22 +0900 Subject: [PATCH 13/13] Remove test --- .../requirements.txt | 1 - .../TinyLlamaWithFusedAttention/__init__.py | 1 - .../TinyLlamaWithFusedAttention/model.py | 82 ------------------- .../requirements.txt | 1 - test/pt2_to_circle_test/builder.py | 2 +- 5 files changed, 1 insertion(+), 86 deletions(-) delete mode 100644 test/modules/model/TinyLlamaWithFusedAttention/__init__.py delete mode 100644 test/modules/model/TinyLlamaWithFusedAttention/model.py 
delete mode 100644 test/modules/model/TinyLlamaWithFusedAttention/requirements.txt diff --git a/test/modules/model/LlamaAttentionWithKVCache/requirements.txt b/test/modules/model/LlamaAttentionWithKVCache/requirements.txt index 956286d9..fcceb41a 100644 --- a/test/modules/model/LlamaAttentionWithKVCache/requirements.txt +++ b/test/modules/model/LlamaAttentionWithKVCache/requirements.txt @@ -1,3 +1,2 @@ numpy==1.24.1 transformers==4.49.0 -torch<=2.8.0 diff --git a/test/modules/model/TinyLlamaWithFusedAttention/__init__.py b/test/modules/model/TinyLlamaWithFusedAttention/__init__.py deleted file mode 100644 index 0c29109f..00000000 --- a/test/modules/model/TinyLlamaWithFusedAttention/__init__.py +++ /dev/null @@ -1 +0,0 @@ -# DO NOT REMOVE THIS FILE diff --git a/test/modules/model/TinyLlamaWithFusedAttention/model.py b/test/modules/model/TinyLlamaWithFusedAttention/model.py deleted file mode 100644 index 73e41c5e..00000000 --- a/test/modules/model/TinyLlamaWithFusedAttention/model.py +++ /dev/null @@ -1,82 +0,0 @@ -import torch -from tico.config.v1 import CompileConfigV1 -from tico.serialize.operators.adapters.llama_rmsnorm import patched_llama_rmsnorm -from tico.serialize.operators.adapters.onert.llama_attention import ( - llama_attention_forward_adapter, -) -from tico.utils.pytree_utils import register_dynamic_cache -from tico.utils.record_input import RecordingInput -from transformers import AutoModelForCausalLM, AutoTokenizer - -from transformers.models.llama.modeling_llama import LlamaAttention - -from test.modules.base import TestModuleBase -from test.utils import tag - - -@tag.use_onert -class TinyLlamaWithFusedAttention(TestModuleBase): - def __init__(self): - super().__init__() - self.model_name = "Maykeye/TinyLLama-v0" - self._call_count = 0 - self.original_model = AutoModelForCausalLM.from_pretrained( - self.model_name - ).eval() - self.fused_model = AutoModelForCausalLM.from_pretrained(self.model_name).eval() - for layer in self.fused_model.model.layers: - layer.self_attn.forward = llama_attention_forward_adapter.__get__( - layer.self_attn - ) - self.rtol = 1e-4 - self.atol = 1e-4 - register_dynamic_cache() - - def forward(self, *args, **kwargs): - self._call_count += 1 - - if self._call_count == 2: - return self.fused_model(*args, **kwargs) - else: - return self.original_model(*args, **kwargs) - - def get_example_inputs(self): - prompt = "Lily picked up a flower." 
- tokenizer = AutoTokenizer.from_pretrained(self.model_name) - tokenizer.pad_token = tokenizer.eos_token - tokenizer.padding_side = "right" - inputs = tokenizer( - prompt, - return_tensors="pt", - padding="max_length", - max_length=30, - truncation=True, - ) - model = self.original_model - model.eval() - - input_to_remove = [ - "attention_mask", - # attention mask will be manually provided, not from example input - ] - condition_fn = ( - lambda args_dict: args_dict["past_key_values"].get_seq_length() != 0 - ) - - with torch.no_grad(), RecordingInput( - model, condition_fn, input_to_remove=input_to_remove - ) as rec: - outputs = model.generate( - **inputs, - max_new_tokens=32, - do_sample=False, - pad_token_id=tokenizer.eos_token_id, - ) - captured_input = rec.captured_input - - generated_text = tokenizer.decode(outputs[0], skip_special_tokens=True) - print(generated_text) - return captured_input, {} - - def get_compile_config(self): - return CompileConfigV1(convert_single_batch_lhs_const_bmm_to_fc=True) diff --git a/test/modules/model/TinyLlamaWithFusedAttention/requirements.txt b/test/modules/model/TinyLlamaWithFusedAttention/requirements.txt deleted file mode 100644 index bc13c47a..00000000 --- a/test/modules/model/TinyLlamaWithFusedAttention/requirements.txt +++ /dev/null @@ -1 +0,0 @@ -transformers==4.50.3 diff --git a/test/pt2_to_circle_test/builder.py b/test/pt2_to_circle_test/builder.py index c80df3e9..3f793067 100644 --- a/test/pt2_to_circle_test/builder.py +++ b/test/pt2_to_circle_test/builder.py @@ -172,7 +172,7 @@ def _run( config=compile_config, ) - # verify_circle(circle_model_path, opt_circle_model_path) + verify_circle(circle_model_path, opt_circle_model_path) if dynamic_shapes:
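[Editorial note] For readers who want the semantics behind the fused call: the `circle_custom::attention` inputs (projection weights, RoPE cos/sin, attention mask, per-layer KV-cache tensors, cache positions) correspond to ordinary Llama attention with rotary embeddings and an externally managed cache. The sketch below is an illustrative PyTorch reference under assumed shapes and conventions (no grouped-query head repetition, an already additive/broadcastable mask, cache laid out as [batch, heads, max_seq, head_dim]); it is not the onert kernel and is not part of this patch series.

```python
import torch
import torch.nn.functional as F

def fused_attention_reference(hidden_states, wq, wk, wv, wo,
                              position_cos, position_sin, attention_mask,
                              past_key, past_value, cache_position,
                              num_heads, head_dim):
    # Illustrative only: shapes and conventions here are assumptions.
    b, s, _ = hidden_states.shape

    def split_heads(x):
        return x.view(b, s, num_heads, head_dim).transpose(1, 2)

    q = split_heads(hidden_states @ wq.T)
    k = split_heads(hidden_states @ wk.T)
    v = split_heads(hidden_states @ wv.T)

    # Rotary position embedding applied to q and k.
    def rotate_half(x):
        x1, x2 = x.chunk(2, dim=-1)
        return torch.cat((-x2, x1), dim=-1)

    cos, sin = position_cos.unsqueeze(1), position_sin.unsqueeze(1)
    q = q * cos + rotate_half(q) * sin
    k = k * cos + rotate_half(k) * sin

    # Scatter the new keys/values into the externally managed cache slots.
    key = past_key.index_copy(2, cache_position, k)
    value = past_value.index_copy(2, cache_position, v)

    out = F.scaled_dot_product_attention(q, key, value, attn_mask=attention_mask)
    out = out.transpose(1, 2).reshape(b, s, num_heads * head_dim)
    return out @ wo.T
```

Passing the cache as plain `past_key`/`past_value` tensors rather than a `DynamicCache` object matches the adapter's own comment that the cache must be decomposed because Circle has no notion of a dict-like container.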