From a2057f44ba6816f5658ac1dbc6f29e4e5693fd51 Mon Sep 17 00:00:00 2001 From: Sanggyu Lee Date: Thu, 17 Jul 2025 15:59:27 +0900 Subject: [PATCH 01/21] [test+operators] Fuse attention to circle attention It introduces circle_attention op, and add tests which fuse attention from LlamaDecoderLayers. TICO-DCO-1.0-Signed-off-by: Sanggyu Lee --- .../model.py | 156 ++++++++++++++++++ .../requirements.txt | 1 + .../operators/op_circle_attention.py | 88 ++++++++++ 3 files changed, 245 insertions(+) create mode 100644 test/modules/model/LlamaDecoderLayerWithKVCacheAndFusedAttention/model.py create mode 100644 test/modules/model/LlamaDecoderLayerWithKVCacheAndFusedAttention/requirements.txt create mode 100644 tico/serialize/operators/op_circle_attention.py diff --git a/test/modules/model/LlamaDecoderLayerWithKVCacheAndFusedAttention/model.py b/test/modules/model/LlamaDecoderLayerWithKVCacheAndFusedAttention/model.py new file mode 100644 index 00000000..7c0bbb77 --- /dev/null +++ b/test/modules/model/LlamaDecoderLayerWithKVCacheAndFusedAttention/model.py @@ -0,0 +1,156 @@ +# User input +prompt = "Lily picked up a flower." +model_name = "Maykeye/TinyLLama-v0" + +captured_input = () + +import copy, inspect, types + +from transformers.models.llama.modeling_llama import LlamaDecoderLayer + +forward_org = LlamaDecoderLayer.forward + + +def capture_and_forward(self, *args, **kwargs): + global captured_input + + # Prepare args tuple for TICO.convert() + # Get arg_names in positional args order using inspect + sig = inspect.signature(forward_org) + args_names = [ + # signature includes `self`` and `kwargs``. + # Just retrieve the ordinary positional inputs only + name + for name in sig.parameters.keys() + if name not in ("self", "kwargs") + ] + + args_dict = dict(zip(args_names, args)) + args_dict.update(kwargs) + + def populate_args(args_dict, filter): + for key in filter: + args_dict.pop(key, None) + args_tuple = tuple(args_dict.get(name, None) for name in args_names) + return copy.deepcopy(args_tuple) + + if len(args_dict["past_key_value"].key_cache) != 0: + input_to_remove = ["use_cache"] + captured_input = populate_args(args_dict, input_to_remove) + + return forward_org(self, *args, **kwargs) + + +# Tokenizer +from transformers import AutoTokenizer + +tokenizer = AutoTokenizer.from_pretrained(model_name) +tokenizer.pad_token = tokenizer.eos_token +tokenizer.padding_side = "right" +inputs = tokenizer( + prompt, + return_tensors="pt", + padding="max_length", + max_length=32, + truncation=True, +) + + +# Generator +import torch + +from transformers import AutoModelForCausalLM + +model = AutoModelForCausalLM.from_pretrained(model_name) +model.eval() +model.model.layers[0].forward = types.MethodType( + capture_and_forward, model.model.layers[0] +) +with torch.no_grad(): + outputs = model.generate( + **inputs, + max_new_tokens=32, + do_sample=False, + pad_token_id=tokenizer.eos_token_id, + ) +generated_text = tokenizer.decode(outputs[0], skip_special_tokens=True) +print(generated_text) + + +# ATTENTION FUSER + +from typing import Optional, List + + +@torch.library.impl("circle::attention.llama", "CPU") +def attention_llama_cpu( + hidden_states, + position_cos, + position_sin, + attention_mask, + past_key, + past_value, + layer_idx, + cache_position, +): + return hidden_states + + +@torch.library.register_fake("circle::attention.llama") +def attention_llama(*args, **kwargs): + ( + hidden_states, + position_cos, + position_sin, + attention_mask, + past_key, + past_value, + layer_idx, + cache_position, + ) = args 
+ return hidden_states + + +from transformers.models.llama.modeling_llama import LlamaAttention +from transformers.cache_utils import DynamicCache + +def forward_adapter( + self: LlamaAttention, + hidden_states: torch.Tensor, + position_embeddings: List[torch.Tensor], + attention_mask: Optional[torch.Tensor], + past_key_value: Optional[DynamicCache], + cache_position: torch.Tensor, + **kwargs +): + # past_key_value is a dict with key_cache and value_cache. + # It needs to be decomposed for tico and circle which does not know dict. + key_cache = past_key_value.key_cache + value_cache = past_key_value.value_cache + return ( + torch.ops.circle.attention.llama( + hidden_states, + position_embeddings[0], # cos + position_embeddings[1], # sin + attention_mask, + # key_cache is a list of cache for each decoder layer. + # Assumtion: key cache is continuous + # + # k_cache[0] | k_cache[1] | ... | k_cache[n] + key_cache[0], + value_cache[0], # Same to value_cache + self.layer_idx, + cache_position, + ), + None, + ) + +LlamaAttention.forward = forward_adapter + +# Tico +import tico + +model = AutoModelForCausalLM.from_pretrained(model_name) +model.eval() +circle_model = tico.convert(model.model.layers[0], captured_input) +circle_model.save(f"tinyllama.attn.circle") diff --git a/test/modules/model/LlamaDecoderLayerWithKVCacheAndFusedAttention/requirements.txt b/test/modules/model/LlamaDecoderLayerWithKVCacheAndFusedAttention/requirements.txt new file mode 100644 index 00000000..5393938f --- /dev/null +++ b/test/modules/model/LlamaDecoderLayerWithKVCacheAndFusedAttention/requirements.txt @@ -0,0 +1 @@ +transformers>=4.50.1 diff --git a/tico/serialize/operators/op_circle_attention.py b/tico/serialize/operators/op_circle_attention.py new file mode 100644 index 00000000..076784ef --- /dev/null +++ b/tico/serialize/operators/op_circle_attention.py @@ -0,0 +1,88 @@ +# Copyright (c) 2025 Samsung Electronics Co., Ltd. All Rights Reserved +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Dict, List, TYPE_CHECKING + +if TYPE_CHECKING: + import torch._ops + import torch.fx +import torch +from circle_schema import circle + +from tico.serialize.circle_graph import CircleSubgraph, extract_shape +from tico.serialize.operators.hashable_opcode import OpCode +from tico.serialize.operators.node_visitor import NodeVisitor, register_node_visitor +from tico.serialize.operators.utils import create_builtin_operator, get_op_index + +from torch.library import Library + +lib = Library("circle", "DEF") +lib.define(""" +attention.llama( + Tensor hidden_states, + Tensor position_cos, + Tensor position_sin, + Tensor? 
attention_mask, + Tensor past_key, + Tensor past_value, + int layer_idx, + Tensor cache_position +) -> Tensor +""") + +@register_node_visitor +class AttentionVisitor(NodeVisitor): + target: List[torch._ops.OpOverload] = [ + torch.ops.circle.attention.llama, + ] + + def __init__(self, op_codes: Dict[OpCode, int], graph: CircleSubgraph): + super().__init__(op_codes, graph) + + def define_node( + self, + node: torch.fx.Node, + ) -> circle.Operator.OperatorT: + ( + hidden_states, + position_cos, + position_sin, + attention_mask, + past_key, + past_value, + cache_position, + layer_idx, + ) = node.args + + inputs = node.args + outputs = [node] + + op_index = get_op_index( + circle.BuiltinOperator.BuiltinOperator.ATTENTION, self._op_codes + ) + + inputs = node.args + outputs = [node] + operator = create_builtin_operator(self.graph, op_index, inputs, outputs) + + # Op-specific option + operator.builtinOptionsType = ( + circle.BuiltinOptions.BuiltinOptions.AttentionOptions + ) + option = circle.AttentionOptions.AttentionOptionsT() + option.layer_idx = layer_idx + + operator.builtinOptions = option + + return operator From 375bcd7fb6d9b6f56354c4aeef7ee93ba7461205 Mon Sep 17 00:00:00 2001 From: Sanggyu Lee Date: Thu, 17 Jul 2025 18:16:09 +0900 Subject: [PATCH 02/21] Rename model.py to layer.py model.py for LlamaModel layer.py for LlamaDecoderLayer --- .../layer.py | 158 ++++++++++++++++++ .../model.py | 63 +++++-- .../operators/op_circle_attention.py | 11 +- 3 files changed, 214 insertions(+), 18 deletions(-) create mode 100644 test/modules/model/LlamaDecoderLayerWithKVCacheAndFusedAttention/layer.py diff --git a/test/modules/model/LlamaDecoderLayerWithKVCacheAndFusedAttention/layer.py b/test/modules/model/LlamaDecoderLayerWithKVCacheAndFusedAttention/layer.py new file mode 100644 index 00000000..19b2d6ab --- /dev/null +++ b/test/modules/model/LlamaDecoderLayerWithKVCacheAndFusedAttention/layer.py @@ -0,0 +1,158 @@ +# User input +prompt = "Lily picked up a flower." +model_name = "Maykeye/TinyLLama-v0" + +captured_input = () + +import copy, inspect, types + +from transformers.models.llama.modeling_llama import LlamaDecoderLayer + +forward_org = LlamaDecoderLayer.forward + + +def capture_and_forward(self, *args, **kwargs): + global captured_input + + # Prepare args tuple for TICO.convert() + # Get arg_names in positional args order using inspect + sig = inspect.signature(forward_org) + args_names = [ + # signature includes `self`` and `kwargs``. 
+ # Just retrieve the ordinary positional inputs only + name + for name in sig.parameters.keys() + if name not in ("self", "kwargs") + ] + + args_dict = dict(zip(args_names, args)) + args_dict.update(kwargs) + + def populate_args(args_dict, filter): + for key in filter: + args_dict.pop(key, None) + args_tuple = tuple(args_dict.get(name, None) for name in args_names) + return copy.deepcopy(args_tuple) + + if len(args_dict["past_key_value"].key_cache) != 0: + input_to_remove = ["use_cache"] + captured_input = populate_args(args_dict, input_to_remove) + + return forward_org(self, *args, **kwargs) + + +# Tokenizer +from transformers import AutoTokenizer + +tokenizer = AutoTokenizer.from_pretrained(model_name) +tokenizer.pad_token = tokenizer.eos_token +tokenizer.padding_side = "right" +inputs = tokenizer( + prompt, + return_tensors="pt", + padding="max_length", + max_length=32, + truncation=True, +) + + +# Generator +import torch + +from transformers import AutoModelForCausalLM + +model = AutoModelForCausalLM.from_pretrained(model_name) +model.eval() +model.model.layers[0].forward = types.MethodType( + capture_and_forward, model.model.layers[0] +) +with torch.no_grad(): + outputs = model.generate( + **inputs, + max_new_tokens=32, + do_sample=False, + pad_token_id=tokenizer.eos_token_id, + ) +generated_text = tokenizer.decode(outputs[0], skip_special_tokens=True) +print(generated_text) + + +# ATTENTION FUSER + +from typing import List, Optional + + +@torch.library.impl("circle::attention.llama", "CPU") +def attention_llama_cpu( + hidden_states, + position_cos, + position_sin, + attention_mask, + past_key, + past_value, + layer_idx, + cache_position, +): + return hidden_states + + +@torch.library.register_fake("circle::attention.llama") +def attention_llama(*args, **kwargs): + ( + hidden_states, + position_cos, + position_sin, + attention_mask, + past_key, + past_value, + layer_idx, + cache_position, + ) = args + return hidden_states + + +from transformers.cache_utils import DynamicCache +from transformers.models.llama.modeling_llama import LlamaAttention + + +def forward_adapter( + self: LlamaAttention, + hidden_states: torch.Tensor, + position_embeddings: List[torch.Tensor], + attention_mask: Optional[torch.Tensor], + past_key_value: Optional[DynamicCache], + cache_position: torch.Tensor, + **kwargs, +): + # past_key_value is a dict with key_cache and value_cache. + # It needs to be decomposed for tico and circle which does not know dict. + key_cache = past_key_value.key_cache + value_cache = past_key_value.value_cache + return ( + torch.ops.circle.attention.llama( + hidden_states, + position_embeddings[0], # cos + position_embeddings[1], # sin + attention_mask, + # key_cache is a list of cache for each decoder layer. + # Assumtion: key cache is continuous + # + # k_cache[0] | k_cache[1] | ... 
| k_cache[n] + key_cache[0], + value_cache[0], # Same to value_cache + self.layer_idx, + cache_position, + ), + None, + ) + + +LlamaAttention.forward = forward_adapter + +# Tico +import tico + +model = AutoModelForCausalLM.from_pretrained(model_name) +model.eval() +circle_model = tico.convert(model.model.layers[0], captured_input) +circle_model.save(f"tinyllama.attn.circle") diff --git a/test/modules/model/LlamaDecoderLayerWithKVCacheAndFusedAttention/model.py b/test/modules/model/LlamaDecoderLayerWithKVCacheAndFusedAttention/model.py index 7c0bbb77..1d4127dd 100644 --- a/test/modules/model/LlamaDecoderLayerWithKVCacheAndFusedAttention/model.py +++ b/test/modules/model/LlamaDecoderLayerWithKVCacheAndFusedAttention/model.py @@ -22,7 +22,7 @@ def capture_and_forward(self, *args, **kwargs): # Just retrieve the ordinary positional inputs only name for name in sig.parameters.keys() - if name not in ("self", "kwargs") + if name not in ("self", "kwargs", "use_cache", "position_ids", "output_attentions") ] args_dict = dict(zip(args_names, args)) @@ -34,8 +34,10 @@ def populate_args(args_dict, filter): args_tuple = tuple(args_dict.get(name, None) for name in args_names) return copy.deepcopy(args_tuple) - if len(args_dict["past_key_value"].key_cache) != 0: - input_to_remove = ["use_cache"] + if args_dict["past_key_value"].get_seq_length() != 0 and captured_input == (): + input_to_remove = [ + "use_cache", + ] captured_input = populate_args(args_dict, input_to_remove) return forward_org(self, *args, **kwargs) @@ -76,10 +78,9 @@ def populate_args(args_dict, filter): generated_text = tokenizer.decode(outputs[0], skip_special_tokens=True) print(generated_text) - # ATTENTION FUSER -from typing import Optional, List +from typing import List, Optional, Tuple @torch.library.impl("circle::attention.llama", "CPU") @@ -111,8 +112,9 @@ def attention_llama(*args, **kwargs): return hidden_states +from transformers.cache_utils import Cache, DynamicCache from transformers.models.llama.modeling_llama import LlamaAttention -from transformers.cache_utils import DynamicCache + def forward_adapter( self: LlamaAttention, @@ -121,7 +123,7 @@ def forward_adapter( attention_mask: Optional[torch.Tensor], past_key_value: Optional[DynamicCache], cache_position: torch.Tensor, - **kwargs + **kwargs, ): # past_key_value is a dict with key_cache and value_cache. # It needs to be decomposed for tico and circle which does not know dict. @@ -130,27 +132,60 @@ def forward_adapter( return ( torch.ops.circle.attention.llama( hidden_states, - position_embeddings[0], # cos - position_embeddings[1], # sin + position_embeddings[0], # cos + position_embeddings[1], # sin attention_mask, # key_cache is a list of cache for each decoder layer. # Assumtion: key cache is continuous # # k_cache[0] | k_cache[1] | ... 
| k_cache[n] key_cache[0], - value_cache[0], # Same to value_cache + value_cache[0], # Same to value_cache self.layer_idx, cache_position, ), None, ) -LlamaAttention.forward = forward_adapter # Tico + import tico +from torch import nn +from transformers.models.llama.modeling_llama import LlamaModel + model = AutoModelForCausalLM.from_pretrained(model_name) -model.eval() -circle_model = tico.convert(model.model.layers[0], captured_input) -circle_model.save(f"tinyllama.attn.circle") + +class LlamaDecoderLayers(nn.Module): + def __init__(self, model: LlamaModel): + super().__init__() + self.config = model.config + self.layers = model.layers + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + past_key_values: Optional[Cache] = None, + cache_position: Optional[torch.LongTensor] = None, + position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC + ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: + + for decoder_layer in self.layers[: self.config.num_hidden_layers]: + layer_outputs = decoder_layer( + hidden_states, + attention_mask=attention_mask, + past_key_value=past_key_values, + cache_position=cache_position, + position_embeddings=position_embeddings, + ) + hidden_states = layer_outputs[0] + + return hidden_states + +layers = LlamaDecoderLayers(model.model) +LlamaAttention.forward = forward_adapter +layers.eval() +circle_model = tico.convert(layers, captured_input) +circle_model.save(f"tinyllama.model.attn.circle") diff --git a/tico/serialize/operators/op_circle_attention.py b/tico/serialize/operators/op_circle_attention.py index 076784ef..b83d7c74 100644 --- a/tico/serialize/operators/op_circle_attention.py +++ b/tico/serialize/operators/op_circle_attention.py @@ -20,15 +20,16 @@ import torch from circle_schema import circle +from torch.library import Library + from tico.serialize.circle_graph import CircleSubgraph, extract_shape from tico.serialize.operators.hashable_opcode import OpCode from tico.serialize.operators.node_visitor import NodeVisitor, register_node_visitor from tico.serialize.operators.utils import create_builtin_operator, get_op_index -from torch.library import Library - lib = Library("circle", "DEF") -lib.define(""" +lib.define( + """ attention.llama( Tensor hidden_states, Tensor position_cos, @@ -39,7 +40,9 @@ int layer_idx, Tensor cache_position ) -> Tensor -""") +""" +) + @register_node_visitor class AttentionVisitor(NodeVisitor): From db5c2b5d49f1e93c30db37f52c37c1b8171ffd95 Mon Sep 17 00:00:00 2001 From: Sanggyu Lee Date: Thu, 17 Jul 2025 19:53:13 +0900 Subject: [PATCH 03/21] make lint happy by making code ugly --- .../model.py | 13 ++++++++++--- 1 file changed, 10 insertions(+), 3 deletions(-) diff --git a/test/modules/model/LlamaDecoderLayerWithKVCacheAndFusedAttention/model.py b/test/modules/model/LlamaDecoderLayerWithKVCacheAndFusedAttention/model.py index 1d4127dd..6a91ddc6 100644 --- a/test/modules/model/LlamaDecoderLayerWithKVCacheAndFusedAttention/model.py +++ b/test/modules/model/LlamaDecoderLayerWithKVCacheAndFusedAttention/model.py @@ -22,7 +22,8 @@ def capture_and_forward(self, *args, **kwargs): # Just retrieve the ordinary positional inputs only name for name in sig.parameters.keys() - if name not in ("self", "kwargs", "use_cache", "position_ids", "output_attentions") + if name + not in ("self", "kwargs", "use_cache", "position_ids", "output_attentions") ] args_dict = dict(zip(args_names, args)) @@ -157,6 +158,7 @@ 
def forward_adapter( model = AutoModelForCausalLM.from_pretrained(model_name) + class LlamaDecoderLayers(nn.Module): def __init__(self, model: LlamaModel): super().__init__() @@ -169,8 +171,12 @@ def forward( attention_mask: Optional[torch.Tensor] = None, past_key_values: Optional[Cache] = None, cache_position: Optional[torch.LongTensor] = None, - position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC - ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: + position_embeddings: Optional[ + Tuple[torch.Tensor, torch.Tensor] + ] = None, # necessary, but kept here for BC + ) -> Tuple[ + torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]] + ]: for decoder_layer in self.layers[: self.config.num_hidden_layers]: layer_outputs = decoder_layer( @@ -184,6 +190,7 @@ def forward( return hidden_states + layers = LlamaDecoderLayers(model.model) LlamaAttention.forward = forward_adapter layers.eval() From 4e578a69d9a3f2bb4583f6b80cd223518c26309c Mon Sep 17 00:00:00 2001 From: Sanggyu Lee Date: Fri, 18 Jul 2025 11:17:53 +0900 Subject: [PATCH 04/21] Fix local-silent but CI-loud lint error --- .../LlamaDecoderLayerWithKVCacheAndFusedAttention/layer.py | 4 ++-- .../LlamaDecoderLayerWithKVCacheAndFusedAttention/model.py | 4 ++-- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/test/modules/model/LlamaDecoderLayerWithKVCacheAndFusedAttention/layer.py b/test/modules/model/LlamaDecoderLayerWithKVCacheAndFusedAttention/layer.py index 19b2d6ab..32778526 100644 --- a/test/modules/model/LlamaDecoderLayerWithKVCacheAndFusedAttention/layer.py +++ b/test/modules/model/LlamaDecoderLayerWithKVCacheAndFusedAttention/layer.py @@ -126,8 +126,8 @@ def forward_adapter( ): # past_key_value is a dict with key_cache and value_cache. # It needs to be decomposed for tico and circle which does not know dict. - key_cache = past_key_value.key_cache - value_cache = past_key_value.value_cache + key_cache = past_key_value.key_cache # type: ignore[union-attr] + value_cache = past_key_value.value_cache # type: ignore[union-attr] return ( torch.ops.circle.attention.llama( hidden_states, diff --git a/test/modules/model/LlamaDecoderLayerWithKVCacheAndFusedAttention/model.py b/test/modules/model/LlamaDecoderLayerWithKVCacheAndFusedAttention/model.py index 6a91ddc6..679c3e5d 100644 --- a/test/modules/model/LlamaDecoderLayerWithKVCacheAndFusedAttention/model.py +++ b/test/modules/model/LlamaDecoderLayerWithKVCacheAndFusedAttention/model.py @@ -128,8 +128,8 @@ def forward_adapter( ): # past_key_value is a dict with key_cache and value_cache. # It needs to be decomposed for tico and circle which does not know dict. 
- key_cache = past_key_value.key_cache - value_cache = past_key_value.value_cache + key_cache = past_key_value.key_cache # type: ignore[union-attr] + value_cache = past_key_value.value_cache # type: ignore[union-attr] return ( torch.ops.circle.attention.llama( hidden_states, From c7c6b79dc0f4ebbcca867cfb62c4f785a4960955 Mon Sep 17 00:00:00 2001 From: Sanggyu Lee Date: Fri, 18 Jul 2025 13:43:26 +0900 Subject: [PATCH 05/21] Add wq,wk,wv,wo and remove_unused_input pass --- .../layer.py | 14 ++++- .../model.py | 12 +++++ tico/passes/remove_unused_inputs.py | 51 +++++++++++++++++++ .../operators/op_circle_attention.py | 11 ++-- tico/utils/convert.py | 2 + 5 files changed, 86 insertions(+), 4 deletions(-) create mode 100644 tico/passes/remove_unused_inputs.py diff --git a/test/modules/model/LlamaDecoderLayerWithKVCacheAndFusedAttention/layer.py b/test/modules/model/LlamaDecoderLayerWithKVCacheAndFusedAttention/layer.py index 32778526..2297a7f1 100644 --- a/test/modules/model/LlamaDecoderLayerWithKVCacheAndFusedAttention/layer.py +++ b/test/modules/model/LlamaDecoderLayerWithKVCacheAndFusedAttention/layer.py @@ -85,6 +85,10 @@ def populate_args(args_dict, filter): @torch.library.impl("circle::attention.llama", "CPU") def attention_llama_cpu( hidden_states, + q_proj, + k_proj, + v_proj, + o_proj, position_cos, position_sin, attention_mask, @@ -100,6 +104,10 @@ def attention_llama_cpu( def attention_llama(*args, **kwargs): ( hidden_states, + q_proj, + k_proj, + v_proj, + o_proj, position_cos, position_sin, attention_mask, @@ -131,6 +139,10 @@ def forward_adapter( return ( torch.ops.circle.attention.llama( hidden_states, + self.q_proj.weight, + self.k_proj.weight, + self.v_proj.weight, + self.o_proj.weight, position_embeddings[0], # cos position_embeddings[1], # sin attention_mask, @@ -155,4 +167,4 @@ def forward_adapter( model = AutoModelForCausalLM.from_pretrained(model_name) model.eval() circle_model = tico.convert(model.model.layers[0], captured_input) -circle_model.save(f"tinyllama.attn.circle") +circle_model.save(f"tinyllama.layer.attn.circle") diff --git a/test/modules/model/LlamaDecoderLayerWithKVCacheAndFusedAttention/model.py b/test/modules/model/LlamaDecoderLayerWithKVCacheAndFusedAttention/model.py index 679c3e5d..6019bb3e 100644 --- a/test/modules/model/LlamaDecoderLayerWithKVCacheAndFusedAttention/model.py +++ b/test/modules/model/LlamaDecoderLayerWithKVCacheAndFusedAttention/model.py @@ -87,6 +87,10 @@ def populate_args(args_dict, filter): @torch.library.impl("circle::attention.llama", "CPU") def attention_llama_cpu( hidden_states, + q_proj, + k_proj, + v_proj, + o_proj, position_cos, position_sin, attention_mask, @@ -102,6 +106,10 @@ def attention_llama_cpu( def attention_llama(*args, **kwargs): ( hidden_states, + q_proj, + k_proj, + v_proj, + o_proj, position_cos, position_sin, attention_mask, @@ -133,6 +141,10 @@ def forward_adapter( return ( torch.ops.circle.attention.llama( hidden_states, + self.q_proj.weight, + self.k_proj.weight, + self.v_proj.weight, + self.o_proj.weight, position_embeddings[0], # cos position_embeddings[1], # sin attention_mask, diff --git a/tico/passes/remove_unused_inputs.py b/tico/passes/remove_unused_inputs.py new file mode 100644 index 00000000..14bbcbc7 --- /dev/null +++ b/tico/passes/remove_unused_inputs.py @@ -0,0 +1,51 @@ +# Copyright (c) 2025 Samsung Electronics Co., Ltd. All Rights Reserved +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import TYPE_CHECKING + +if TYPE_CHECKING: + import torch.fx +import torch +from torch.export import ExportedProgram + +from tico.passes import ops +from tico.utils import logging +from tico.utils.passes import PassBase, PassResult +from tico.utils.trace_decorators import trace_graph_diff_on_pass + + +@trace_graph_diff_on_pass +class RemoveUnusedInput(PassBase): + """ + Let's remove dead inputs + """ + + def __init__(self): + super().__init__() + + def call(self, exported_program: ExportedProgram) -> PassResult: + logger = logging.getLogger(__name__) + + graph_module = exported_program.graph_module + graph = graph_module.graph + modified = False + for node in graph.nodes: + if node.op == "placeholder" and len(node.users) == 0: + graph.erase_node(node) + modified = True + + graph.lint() + graph_module.recompile() + + return PassResult(modified) diff --git a/tico/serialize/operators/op_circle_attention.py b/tico/serialize/operators/op_circle_attention.py index b83d7c74..3d6d73ff 100644 --- a/tico/serialize/operators/op_circle_attention.py +++ b/tico/serialize/operators/op_circle_attention.py @@ -32,6 +32,10 @@ """ attention.llama( Tensor hidden_states, + Tensor wq, + Tensor wk, + Tensor wv, + Tensor wo, Tensor position_cos, Tensor position_sin, Tensor? attention_mask, @@ -59,6 +63,10 @@ def define_node( ) -> circle.Operator.OperatorT: ( hidden_states, + wq, + wk, + wv, + wo, position_cos, position_sin, attention_mask, @@ -68,9 +76,6 @@ def define_node( layer_idx, ) = node.args - inputs = node.args - outputs = [node] - op_index = get_op_index( circle.BuiltinOperator.BuiltinOperator.ATTENTION, self._op_codes ) diff --git a/tico/utils/convert.py b/tico/utils/convert.py index 7ac47f25..d248edd7 100644 --- a/tico/utils/convert.py +++ b/tico/utils/convert.py @@ -58,6 +58,7 @@ from tico.passes.remove_redundant_reshape import passes as RemoveRedundantViewPasses from tico.passes.remove_redundant_slice import RemoveRedundantSlice from tico.passes.remove_redundant_to_copy import RemoveRedundantToCopy +from tico.passes.remove_unused_inputs import RemoveUnusedInput from tico.passes.restore_linear import RestoreLinear from tico.passes.segment_index_select import SegmentIndexSelectConst from tico.quantization.passes.fold_quant_ops import FoldQuantOps @@ -261,6 +262,7 @@ def convert_exported_module_to_circle( *LowerToSlicePasses(), FuseLeadingUnsqueezeReshape(), CastClampMixedTypeArgs(), + RemoveUnusedInput(), ] ) circle_legalize.run(exported_program) From abf1288205caa37064539c64165e7b8c0a3a074d Mon Sep 17 00:00:00 2001 From: Sanggyu Lee Date: Tue, 22 Jul 2025 19:45:34 +0900 Subject: [PATCH 06/21] Use recording_input in layer.py --- .../layer.py | 52 +++---------------- .../model.py | 4 +- 2 files changed, 9 insertions(+), 47 deletions(-) diff --git a/test/modules/model/LlamaDecoderLayerWithKVCacheAndFusedAttention/layer.py b/test/modules/model/LlamaDecoderLayerWithKVCacheAndFusedAttention/layer.py index 2297a7f1..b4972827 100644 --- a/test/modules/model/LlamaDecoderLayerWithKVCacheAndFusedAttention/layer.py +++ 
b/test/modules/model/LlamaDecoderLayerWithKVCacheAndFusedAttention/layer.py @@ -2,45 +2,6 @@ prompt = "Lily picked up a flower." model_name = "Maykeye/TinyLLama-v0" -captured_input = () - -import copy, inspect, types - -from transformers.models.llama.modeling_llama import LlamaDecoderLayer - -forward_org = LlamaDecoderLayer.forward - - -def capture_and_forward(self, *args, **kwargs): - global captured_input - - # Prepare args tuple for TICO.convert() - # Get arg_names in positional args order using inspect - sig = inspect.signature(forward_org) - args_names = [ - # signature includes `self`` and `kwargs``. - # Just retrieve the ordinary positional inputs only - name - for name in sig.parameters.keys() - if name not in ("self", "kwargs") - ] - - args_dict = dict(zip(args_names, args)) - args_dict.update(kwargs) - - def populate_args(args_dict, filter): - for key in filter: - args_dict.pop(key, None) - args_tuple = tuple(args_dict.get(name, None) for name in args_names) - return copy.deepcopy(args_tuple) - - if len(args_dict["past_key_value"].key_cache) != 0: - input_to_remove = ["use_cache"] - captured_input = populate_args(args_dict, input_to_remove) - - return forward_org(self, *args, **kwargs) - - # Tokenizer from transformers import AutoTokenizer @@ -55,7 +16,6 @@ def populate_args(args_dict, filter): truncation=True, ) - # Generator import torch @@ -63,16 +23,20 @@ def populate_args(args_dict, filter): model = AutoModelForCausalLM.from_pretrained(model_name) model.eval() -model.model.layers[0].forward = types.MethodType( - capture_and_forward, model.model.layers[0] -) -with torch.no_grad(): + +from tico.utils.record_input import RecordingInput + +condition_fn = lambda args_dict: args_dict["past_key_value"].get_seq_length() != 0 + +with torch.no_grad(), RecordingInput(model.model.layers[0], condition_fn) as rec: outputs = model.generate( **inputs, max_new_tokens=32, do_sample=False, pad_token_id=tokenizer.eos_token_id, ) + captured_input = rec.captured_input + generated_text = tokenizer.decode(outputs[0], skip_special_tokens=True) print(generated_text) diff --git a/test/modules/model/LlamaDecoderLayerWithKVCacheAndFusedAttention/model.py b/test/modules/model/LlamaDecoderLayerWithKVCacheAndFusedAttention/model.py index 6019bb3e..32d7128a 100644 --- a/test/modules/model/LlamaDecoderLayerWithKVCacheAndFusedAttention/model.py +++ b/test/modules/model/LlamaDecoderLayerWithKVCacheAndFusedAttention/model.py @@ -36,9 +36,7 @@ def populate_args(args_dict, filter): return copy.deepcopy(args_tuple) if args_dict["past_key_value"].get_seq_length() != 0 and captured_input == (): - input_to_remove = [ - "use_cache", - ] + input_to_remove = [] captured_input = populate_args(args_dict, input_to_remove) return forward_org(self, *args, **kwargs) From 22fb52239dfde7f1c876f8b9adb40b94182af550 Mon Sep 17 00:00:00 2001 From: Sanggyu Lee Date: Wed, 23 Jul 2025 09:47:31 +0900 Subject: [PATCH 07/21] Add prefill.py --- .../prefill.py | 58 +++++++++++++++++++ 1 file changed, 58 insertions(+) create mode 100644 test/modules/model/LlamaDecoderLayerWithKVCacheAndFusedAttention/prefill.py diff --git a/test/modules/model/LlamaDecoderLayerWithKVCacheAndFusedAttention/prefill.py b/test/modules/model/LlamaDecoderLayerWithKVCacheAndFusedAttention/prefill.py new file mode 100644 index 00000000..cd7f7f65 --- /dev/null +++ b/test/modules/model/LlamaDecoderLayerWithKVCacheAndFusedAttention/prefill.py @@ -0,0 +1,58 @@ +# User input +prompt = "Lily picked up a flower." 
+model_name = "Maykeye/TinyLLama-v0" + +# Tokenizer +from transformers import AutoTokenizer + +tokenizer = AutoTokenizer.from_pretrained(model_name) +tokenizer.pad_token = tokenizer.eos_token +tokenizer.padding_side = "right" +inputs = tokenizer( + prompt, + return_tensors="pt", + padding="max_length", + max_length=32, + truncation=True, +) + +# Generator +import torch + +from transformers import AutoModelForCausalLM + +model = AutoModelForCausalLM.from_pretrained(model_name) +model.eval() + +from tico.utils.record_input import RecordingInput + +# past_key_values +# --------------- +# During prefill, "past_key_values" not None, but an empty Cache instance. +# Passing None makes torch.export happy. + +# attention_mask, cache_position +# ------------------------------ +# For npu, ignore captured values generated from example prompt. + +input_to_remove = ["past_key_values", "attention_mask", "cache_position"] + +with torch.no_grad(), RecordingInput(model, input_to_remove=input_to_remove) as rec: + outputs = model.generate( + **inputs, + max_new_tokens=32, + do_sample=False, + pad_token_id=tokenizer.eos_token_id, + ) + captured_input = rec.captured_input + +generated_text = tokenizer.decode(outputs[0], skip_special_tokens=True) +print(generated_text) + +# Tico +import tico + +model = AutoModelForCausalLM.from_pretrained(model_name) +model.eval() +circle_model = tico.convert(model, captured_input) +circle_model.save(f"tinyllama.prefill.circle") From 711c60d250e85ff63e300548c10cb2960f356e59 Mon Sep 17 00:00:00 2001 From: Sanggyu Lee Date: Wed, 23 Jul 2025 10:12:34 +0900 Subject: [PATCH 08/21] Update layer.py --- .../layer.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/test/modules/model/LlamaDecoderLayerWithKVCacheAndFusedAttention/layer.py b/test/modules/model/LlamaDecoderLayerWithKVCacheAndFusedAttention/layer.py index b4972827..4971fa92 100644 --- a/test/modules/model/LlamaDecoderLayerWithKVCacheAndFusedAttention/layer.py +++ b/test/modules/model/LlamaDecoderLayerWithKVCacheAndFusedAttention/layer.py @@ -26,9 +26,10 @@ from tico.utils.record_input import RecordingInput +target_model = model.model.layers[0] condition_fn = lambda args_dict: args_dict["past_key_value"].get_seq_length() != 0 -with torch.no_grad(), RecordingInput(model.model.layers[0], condition_fn) as rec: +with torch.no_grad(), RecordingInput(target_model, condition_fn) as rec: outputs = model.generate( **inputs, max_new_tokens=32, @@ -123,12 +124,13 @@ def forward_adapter( ) -LlamaAttention.forward = forward_adapter - # Tico import tico model = AutoModelForCausalLM.from_pretrained(model_name) + +LlamaAttention.forward = forward_adapter + model.eval() circle_model = tico.convert(model.model.layers[0], captured_input) circle_model.save(f"tinyllama.layer.attn.circle") From 9fa791f37b2ef5ad1ac8f34a42151d78a0a50bf9 Mon Sep 17 00:00:00 2001 From: Sanggyu Lee Date: Wed, 23 Jul 2025 11:37:11 +0900 Subject: [PATCH 09/21] Rename model.py to layers.py --- .../{model.py => layers.py} | 68 +++++-------------- 1 file changed, 18 insertions(+), 50 deletions(-) rename test/modules/model/LlamaDecoderLayerWithKVCacheAndFusedAttention/{model.py => layers.py} (73%) diff --git a/test/modules/model/LlamaDecoderLayerWithKVCacheAndFusedAttention/model.py b/test/modules/model/LlamaDecoderLayerWithKVCacheAndFusedAttention/layers.py similarity index 73% rename from test/modules/model/LlamaDecoderLayerWithKVCacheAndFusedAttention/model.py rename to test/modules/model/LlamaDecoderLayerWithKVCacheAndFusedAttention/layers.py 
index 32d7128a..7bdb2360 100644 --- a/test/modules/model/LlamaDecoderLayerWithKVCacheAndFusedAttention/model.py +++ b/test/modules/model/LlamaDecoderLayerWithKVCacheAndFusedAttention/layers.py @@ -2,46 +2,6 @@ prompt = "Lily picked up a flower." model_name = "Maykeye/TinyLLama-v0" -captured_input = () - -import copy, inspect, types - -from transformers.models.llama.modeling_llama import LlamaDecoderLayer - -forward_org = LlamaDecoderLayer.forward - - -def capture_and_forward(self, *args, **kwargs): - global captured_input - - # Prepare args tuple for TICO.convert() - # Get arg_names in positional args order using inspect - sig = inspect.signature(forward_org) - args_names = [ - # signature includes `self`` and `kwargs``. - # Just retrieve the ordinary positional inputs only - name - for name in sig.parameters.keys() - if name - not in ("self", "kwargs", "use_cache", "position_ids", "output_attentions") - ] - - args_dict = dict(zip(args_names, args)) - args_dict.update(kwargs) - - def populate_args(args_dict, filter): - for key in filter: - args_dict.pop(key, None) - args_tuple = tuple(args_dict.get(name, None) for name in args_names) - return copy.deepcopy(args_tuple) - - if args_dict["past_key_value"].get_seq_length() != 0 and captured_input == (): - input_to_remove = [] - captured_input = populate_args(args_dict, input_to_remove) - - return forward_org(self, *args, **kwargs) - - # Tokenizer from transformers import AutoTokenizer @@ -56,7 +16,6 @@ def populate_args(args_dict, filter): truncation=True, ) - # Generator import torch @@ -64,22 +23,28 @@ def populate_args(args_dict, filter): model = AutoModelForCausalLM.from_pretrained(model_name) model.eval() -model.model.layers[0].forward = types.MethodType( - capture_and_forward, model.model.layers[0] -) -with torch.no_grad(): + +from tico.utils.record_input import RecordingInput + +target_model = model.model.layers[0] +condition_fn = lambda args_dict: args_dict["past_key_value"].get_seq_length() != 0 + +with torch.no_grad(), RecordingInput(target_model, condition_fn) as rec: outputs = model.generate( **inputs, max_new_tokens=32, do_sample=False, pad_token_id=tokenizer.eos_token_id, ) + captured_input = rec.captured_input + generated_text = tokenizer.decode(outputs[0], skip_special_tokens=True) print(generated_text) + # ATTENTION FUSER -from typing import List, Optional, Tuple +from typing import Any, List, Optional, Tuple @torch.library.impl("circle::attention.llama", "CPU") @@ -160,7 +125,6 @@ def forward_adapter( # Tico - import tico from torch import nn @@ -179,11 +143,15 @@ def forward( self, hidden_states: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, - past_key_values: Optional[Cache] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Cache] = None, + output_attentions: Optional[bool] = False, + use_cache: Optional[bool] = False, cache_position: Optional[torch.LongTensor] = None, position_embeddings: Optional[ Tuple[torch.Tensor, torch.Tensor] ] = None, # necessary, but kept here for BC + **kwargs: Any, ) -> Tuple[ torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]] ]: @@ -192,7 +160,7 @@ def forward( layer_outputs = decoder_layer( hidden_states, attention_mask=attention_mask, - past_key_value=past_key_values, + past_key_value=past_key_value, cache_position=cache_position, position_embeddings=position_embeddings, ) @@ -205,4 +173,4 @@ def forward( LlamaAttention.forward = forward_adapter layers.eval() circle_model = tico.convert(layers, captured_input) 
-circle_model.save(f"tinyllama.model.attn.circle") +circle_model.save(f"tinyllama.layers.attn.circle") From 6789671ca67eac14f7daaeda75680319a5bfd0ba Mon Sep 17 00:00:00 2001 From: Sanggyu Lee Date: Wed, 23 Jul 2025 12:23:06 +0900 Subject: [PATCH 10/21] Update input_to_remove comment for prefill.py --- .../prefill.py | 26 ++++++++++++++++--- 1 file changed, 22 insertions(+), 4 deletions(-) diff --git a/test/modules/model/LlamaDecoderLayerWithKVCacheAndFusedAttention/prefill.py b/test/modules/model/LlamaDecoderLayerWithKVCacheAndFusedAttention/prefill.py index cd7f7f65..942223d5 100644 --- a/test/modules/model/LlamaDecoderLayerWithKVCacheAndFusedAttention/prefill.py +++ b/test/modules/model/LlamaDecoderLayerWithKVCacheAndFusedAttention/prefill.py @@ -31,11 +31,29 @@ # During prefill, "past_key_values" not None, but an empty Cache instance. # Passing None makes torch.export happy. -# attention_mask, cache_position -# ------------------------------ -# For npu, ignore captured values generated from example prompt. -input_to_remove = ["past_key_values", "attention_mask", "cache_position"] +input_to_remove = [ + "past_key_values", + # DynamicCache is flatten-able operator since 4.50. + # See _pytree.py > tree_flatten + # SUPPORTED_NODES has *transformers.DynamicCache* + # After flattening, DynamicCache becomes { "key_cache": [] , "value_cache": [ ] } + # dict.value is returne. dict.key is stored in treespec. + # + # On prefill, DynamicCache is empty, and dict is empty after flattening. + # PyTorch removes empty dict! + # If number of args is 4 (including cache), it becomes 3! + # To avoid this error, don't pass empty cache, just pass None. + "attention_mask", + # For left pad, [0, ⋯, 0, 1, ⋯, 1] + # For right right pad, [1, ⋯, 1, 0, ⋯, 0] + # ( 0 is pad-token ) + # This script uses right pad and pass all-1 attention mask (including pad). + # Npu computes all positions whether it is pad or not. + "cache_position" + # It is the list of cache position like [0, 1, ..., 11]. + # For npu, we always store all values (including pad). 
+] with torch.no_grad(), RecordingInput(model, input_to_remove=input_to_remove) as rec: outputs = model.generate( From bf534244b0b297bd5c0486cbebf5e2f5bf7c0305 Mon Sep 17 00:00:00 2001 From: Sanggyu Lee Date: Thu, 24 Jul 2025 12:06:02 +0900 Subject: [PATCH 11/21] Factor out attention fuser to op_circle_attention.py --- .../layer.py | 88 +------------- .../layers.py | 108 ++++-------------- .../operators/op_circle_attention.py | 87 +++++++++++++- 3 files changed, 110 insertions(+), 173 deletions(-) diff --git a/test/modules/model/LlamaDecoderLayerWithKVCacheAndFusedAttention/layer.py b/test/modules/model/LlamaDecoderLayerWithKVCacheAndFusedAttention/layer.py index 4971fa92..079b2219 100644 --- a/test/modules/model/LlamaDecoderLayerWithKVCacheAndFusedAttention/layer.py +++ b/test/modules/model/LlamaDecoderLayerWithKVCacheAndFusedAttention/layer.py @@ -42,95 +42,15 @@ print(generated_text) -# ATTENTION FUSER +# Convert -from typing import List, Optional - - -@torch.library.impl("circle::attention.llama", "CPU") -def attention_llama_cpu( - hidden_states, - q_proj, - k_proj, - v_proj, - o_proj, - position_cos, - position_sin, - attention_mask, - past_key, - past_value, - layer_idx, - cache_position, -): - return hidden_states - - -@torch.library.register_fake("circle::attention.llama") -def attention_llama(*args, **kwargs): - ( - hidden_states, - q_proj, - k_proj, - v_proj, - o_proj, - position_cos, - position_sin, - attention_mask, - past_key, - past_value, - layer_idx, - cache_position, - ) = args - return hidden_states - - -from transformers.cache_utils import DynamicCache +import tico +from tico.serialize.operators.op_circle_attention import llama_attention_forward_adapter from transformers.models.llama.modeling_llama import LlamaAttention - -def forward_adapter( - self: LlamaAttention, - hidden_states: torch.Tensor, - position_embeddings: List[torch.Tensor], - attention_mask: Optional[torch.Tensor], - past_key_value: Optional[DynamicCache], - cache_position: torch.Tensor, - **kwargs, -): - # past_key_value is a dict with key_cache and value_cache. - # It needs to be decomposed for tico and circle which does not know dict. - key_cache = past_key_value.key_cache # type: ignore[union-attr] - value_cache = past_key_value.value_cache # type: ignore[union-attr] - return ( - torch.ops.circle.attention.llama( - hidden_states, - self.q_proj.weight, - self.k_proj.weight, - self.v_proj.weight, - self.o_proj.weight, - position_embeddings[0], # cos - position_embeddings[1], # sin - attention_mask, - # key_cache is a list of cache for each decoder layer. - # Assumtion: key cache is continuous - # - # k_cache[0] | k_cache[1] | ... 
| k_cache[n] - key_cache[0], - value_cache[0], # Same to value_cache - self.layer_idx, - cache_position, - ), - None, - ) - - -# Tico -import tico +LlamaAttention.forward = llama_attention_forward_adapter model = AutoModelForCausalLM.from_pretrained(model_name) - -LlamaAttention.forward = forward_adapter - model.eval() circle_model = tico.convert(model.model.layers[0], captured_input) circle_model.save(f"tinyllama.layer.attn.circle") diff --git a/test/modules/model/LlamaDecoderLayerWithKVCacheAndFusedAttention/layers.py b/test/modules/model/LlamaDecoderLayerWithKVCacheAndFusedAttention/layers.py index 7bdb2360..da9c4b01 100644 --- a/test/modules/model/LlamaDecoderLayerWithKVCacheAndFusedAttention/layers.py +++ b/test/modules/model/LlamaDecoderLayerWithKVCacheAndFusedAttention/layers.py @@ -41,96 +41,17 @@ generated_text = tokenizer.decode(outputs[0], skip_special_tokens=True) print(generated_text) +from typing import Any, Optional, Tuple -# ATTENTION FUSER - -from typing import Any, List, Optional, Tuple - - -@torch.library.impl("circle::attention.llama", "CPU") -def attention_llama_cpu( - hidden_states, - q_proj, - k_proj, - v_proj, - o_proj, - position_cos, - position_sin, - attention_mask, - past_key, - past_value, - layer_idx, - cache_position, -): - return hidden_states - - -@torch.library.register_fake("circle::attention.llama") -def attention_llama(*args, **kwargs): - ( - hidden_states, - q_proj, - k_proj, - v_proj, - o_proj, - position_cos, - position_sin, - attention_mask, - past_key, - past_value, - layer_idx, - cache_position, - ) = args - return hidden_states - - -from transformers.cache_utils import Cache, DynamicCache -from transformers.models.llama.modeling_llama import LlamaAttention - - -def forward_adapter( - self: LlamaAttention, - hidden_states: torch.Tensor, - position_embeddings: List[torch.Tensor], - attention_mask: Optional[torch.Tensor], - past_key_value: Optional[DynamicCache], - cache_position: torch.Tensor, - **kwargs, -): - # past_key_value is a dict with key_cache and value_cache. - # It needs to be decomposed for tico and circle which does not know dict. - key_cache = past_key_value.key_cache # type: ignore[union-attr] - value_cache = past_key_value.value_cache # type: ignore[union-attr] - return ( - torch.ops.circle.attention.llama( - hidden_states, - self.q_proj.weight, - self.k_proj.weight, - self.v_proj.weight, - self.o_proj.weight, - position_embeddings[0], # cos - position_embeddings[1], # sin - attention_mask, - # key_cache is a list of cache for each decoder layer. - # Assumtion: key cache is continuous - # - # k_cache[0] | k_cache[1] | ... | k_cache[n] - key_cache[0], - value_cache[0], # Same to value_cache - self.layer_idx, - cache_position, - ), - None, - ) - - -# Tico -import tico +# Define DecoderLayers from torch import nn -from transformers.models.llama.modeling_llama import LlamaModel +from transformers.cache_utils import Cache +from transformers.models.llama.modeling_llama import LlamaAttention, LlamaModel -model = AutoModelForCausalLM.from_pretrained(model_name) + +# DecoderLayers is not nn.Module. Not torch.export-able. +# Let's define decoder layers as nn.Module. class LlamaDecoderLayers(nn.Module): @@ -139,6 +60,8 @@ def __init__(self, model: LlamaModel): self.config = model.config self.layers = model.layers + # Make sure signature is same to capturing input. 
+ # Just copy and Paste from LlamaDecoderLayer::forward def forward( self, hidden_states: torch.Tensor, @@ -169,8 +92,19 @@ def forward( return hidden_states +# Convert + +import tico + +# NOTE: +# If you want to restore forward, it may be implemented as context manager. +# However, it is just a simple script to export. No one uses forward after tico conversion. +from tico.serialize.operators.op_circle_attention import llama_attention_forward_adapter + +LlamaAttention.forward = llama_attention_forward_adapter + +model = AutoModelForCausalLM.from_pretrained(model_name) layers = LlamaDecoderLayers(model.model) -LlamaAttention.forward = forward_adapter layers.eval() circle_model = tico.convert(layers, captured_input) circle_model.save(f"tinyllama.layers.attn.circle") diff --git a/tico/serialize/operators/op_circle_attention.py b/tico/serialize/operators/op_circle_attention.py index 3d6d73ff..b8147940 100644 --- a/tico/serialize/operators/op_circle_attention.py +++ b/tico/serialize/operators/op_circle_attention.py @@ -20,13 +20,13 @@ import torch from circle_schema import circle -from torch.library import Library - from tico.serialize.circle_graph import CircleSubgraph, extract_shape from tico.serialize.operators.hashable_opcode import OpCode from tico.serialize.operators.node_visitor import NodeVisitor, register_node_visitor from tico.serialize.operators.utils import create_builtin_operator, get_op_index +from torch.library import Library + lib = Library("circle", "DEF") lib.define( """ @@ -47,6 +47,89 @@ """ ) +# ATTENTION FUSER + +from typing import List, Optional + + +@torch.library.impl("circle::attention.llama", "CPU") +def attention_llama_cpu( + hidden_states, + q_proj, + k_proj, + v_proj, + o_proj, + position_cos, + position_sin, + attention_mask, + past_key, + past_value, + layer_idx, + cache_position, +): + return hidden_states + + +@torch.library.register_fake("circle::attention.llama") +def attention_llama(*args, **kwargs): + ( + hidden_states, + q_proj, + k_proj, + v_proj, + o_proj, + position_cos, + position_sin, + attention_mask, + past_key, + past_value, + layer_idx, + cache_position, + ) = args + return hidden_states + + +from typing import List, Optional + +from transformers.cache_utils import DynamicCache +from transformers.models.llama.modeling_llama import LlamaAttention + + +def llama_attention_forward_adapter( + self: LlamaAttention, + hidden_states: torch.Tensor, + position_embeddings: List[torch.Tensor], + attention_mask: Optional[torch.Tensor], + past_key_value: Optional[DynamicCache], + cache_position: torch.Tensor, + **kwargs, +): + # past_key_value is a dict with key_cache and value_cache. + # It needs to be decomposed for tico and circle which does not know dict. + key_cache = past_key_value.key_cache # type: ignore[union-attr] + value_cache = past_key_value.value_cache # type: ignore[union-attr] + return ( + torch.ops.circle.attention.llama( + hidden_states, + self.q_proj.weight, + self.k_proj.weight, + self.v_proj.weight, + self.o_proj.weight, + position_embeddings[0], # cos + position_embeddings[1], # sin + attention_mask, + # key_cache is a list of cache for each decoder layer. + # Assumtion: key cache is continuous + # + # k_cache[0] | k_cache[1] | ... 
| k_cache[n] + key_cache[0], + value_cache[0], # Same to value_cache + self.layer_idx, + cache_position, + ), + None, + ) + @register_node_visitor class AttentionVisitor(NodeVisitor): From 275f3988248ef9287808c6c163f2ae5f717695b3 Mon Sep 17 00:00:00 2001 From: Sanggyu Lee Date: Tue, 29 Jul 2025 06:33:24 +0900 Subject: [PATCH 12/21] move op_circle_attention.py to onert/op_attention.py --- .../LlamaDecoderLayerWithKVCacheAndFusedAttention/layer.py | 2 +- .../LlamaDecoderLayerWithKVCacheAndFusedAttention/layers.py | 2 +- .../operators/{op_circle_attention.py => onert/op_attention.py} | 0 3 files changed, 2 insertions(+), 2 deletions(-) rename tico/serialize/operators/{op_circle_attention.py => onert/op_attention.py} (100%) diff --git a/test/modules/model/LlamaDecoderLayerWithKVCacheAndFusedAttention/layer.py b/test/modules/model/LlamaDecoderLayerWithKVCacheAndFusedAttention/layer.py index 079b2219..9a84e171 100644 --- a/test/modules/model/LlamaDecoderLayerWithKVCacheAndFusedAttention/layer.py +++ b/test/modules/model/LlamaDecoderLayerWithKVCacheAndFusedAttention/layer.py @@ -45,7 +45,7 @@ # Convert import tico -from tico.serialize.operators.op_circle_attention import llama_attention_forward_adapter +from tico.serialize.operators.onert.op_attention import llama_attention_forward_adapter from transformers.models.llama.modeling_llama import LlamaAttention LlamaAttention.forward = llama_attention_forward_adapter diff --git a/test/modules/model/LlamaDecoderLayerWithKVCacheAndFusedAttention/layers.py b/test/modules/model/LlamaDecoderLayerWithKVCacheAndFusedAttention/layers.py index da9c4b01..74e3db4e 100644 --- a/test/modules/model/LlamaDecoderLayerWithKVCacheAndFusedAttention/layers.py +++ b/test/modules/model/LlamaDecoderLayerWithKVCacheAndFusedAttention/layers.py @@ -99,7 +99,7 @@ def forward( # NOTE: # If you want to restore forward, it may be implemented as context manager. # However, it is just a simple script to export. No one uses forward after tico conversion. 
-from tico.serialize.operators.op_circle_attention import llama_attention_forward_adapter
+from tico.serialize.operators.onert.op_attention import llama_attention_forward_adapter
 
 LlamaAttention.forward = llama_attention_forward_adapter
 
diff --git a/tico/serialize/operators/op_circle_attention.py b/tico/serialize/operators/onert/op_attention.py
similarity index 100%
rename from tico/serialize/operators/op_circle_attention.py
rename to tico/serialize/operators/onert/op_attention.py

From 3aaca064e45bfe4b6c2de6e79134e0050dccc818 Mon Sep 17 00:00:00 2001
From: Sanggyu Lee
Date: Thu, 31 Jul 2025 09:37:52 +0900
Subject: [PATCH 13/21] remove unused import from op_attention.py

---
 tico/serialize/operators/onert/op_attention.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/tico/serialize/operators/onert/op_attention.py b/tico/serialize/operators/onert/op_attention.py
index b8147940..4b7c126e 100644
--- a/tico/serialize/operators/onert/op_attention.py
+++ b/tico/serialize/operators/onert/op_attention.py
@@ -20,7 +20,7 @@
 import torch
 from circle_schema import circle
 
-from tico.serialize.circle_graph import CircleSubgraph, extract_shape
+from tico.serialize.circle_graph import CircleSubgraph
 from tico.serialize.operators.hashable_opcode import OpCode
 from tico.serialize.operators.node_visitor import NodeVisitor, register_node_visitor
 from tico.serialize.operators.utils import create_builtin_operator, get_op_index

From a04ff6f3e26fd4b6ac498773b6758e67f0eaed33 Mon Sep 17 00:00:00 2001
From: Sanggyu Lee
Date: Tue, 5 Aug 2025 10:57:13 +0900
Subject: [PATCH 14/21] Adjust input prompt size and kv_cache size = 12

---
 .../LlamaDecoderLayerWithKVCacheAndFusedAttention/layer.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/test/modules/model/LlamaDecoderLayerWithKVCacheAndFusedAttention/layer.py b/test/modules/model/LlamaDecoderLayerWithKVCacheAndFusedAttention/layer.py
index 9a84e171..4b035adc 100644
--- a/test/modules/model/LlamaDecoderLayerWithKVCacheAndFusedAttention/layer.py
+++ b/test/modules/model/LlamaDecoderLayerWithKVCacheAndFusedAttention/layer.py
@@ -12,7 +12,7 @@
     prompt,
     return_tensors="pt",
     padding="max_length",
-    max_length=32,
+    max_length=31,
     truncation=True,
 )

From b190777b77ea3cd96247b9da7641f00e6020f73f Mon Sep 17 00:00:00 2001
From: Sanggyu Lee
Date: Tue, 5 Aug 2025 10:58:25 +0900
Subject: [PATCH 15/21] remove @torch.library.impl("circle::attention.llama", "CPU")

torch.library.impl for "CPU" turned out to be unnecessary.
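
A fake (meta) kernel alone is enough for torch.export: export only propagates
shapes and dtypes through the custom op, so its body is never executed with
real data during conversion. A minimal sketch of the pattern (illustrative
names only, assuming PyTorch >= 2.4 where register_fake is available):

    import torch
    from torch.library import Library

    _lib = Library("example", "DEF")
    _lib.define("identity(Tensor x) -> Tensor")

    @torch.library.register_fake("example::identity")
    def _identity_fake(x):
        # Shape/dtype propagation only; no real computation happens here.
        return torch.empty_like(x)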
--- .../serialize/operators/onert/op_attention.py | 20 ------------------- 1 file changed, 20 deletions(-) diff --git a/tico/serialize/operators/onert/op_attention.py b/tico/serialize/operators/onert/op_attention.py index 4b7c126e..7ef18acb 100644 --- a/tico/serialize/operators/onert/op_attention.py +++ b/tico/serialize/operators/onert/op_attention.py @@ -49,26 +49,6 @@ # ATTENTION FUSER -from typing import List, Optional - - -@torch.library.impl("circle::attention.llama", "CPU") -def attention_llama_cpu( - hidden_states, - q_proj, - k_proj, - v_proj, - o_proj, - position_cos, - position_sin, - attention_mask, - past_key, - past_value, - layer_idx, - cache_position, -): - return hidden_states - @torch.library.register_fake("circle::attention.llama") def attention_llama(*args, **kwargs): From e40fad593b4fb8cb774e544e77ff193fcb3cdd6f Mon Sep 17 00:00:00 2001 From: Sanggyu Lee Date: Wed, 13 Aug 2025 17:50:06 +0900 Subject: [PATCH 16/21] Remove attention_mask and make kv_cache mandatory, not optional --- tico/serialize/operators/onert/op_attention.py | 13 ++++--------- 1 file changed, 4 insertions(+), 9 deletions(-) diff --git a/tico/serialize/operators/onert/op_attention.py b/tico/serialize/operators/onert/op_attention.py index 7ef18acb..989e3db8 100644 --- a/tico/serialize/operators/onert/op_attention.py +++ b/tico/serialize/operators/onert/op_attention.py @@ -38,7 +38,6 @@ Tensor wo, Tensor position_cos, Tensor position_sin, - Tensor? attention_mask, Tensor past_key, Tensor past_value, int layer_idx, @@ -60,7 +59,6 @@ def attention_llama(*args, **kwargs): o_proj, position_cos, position_sin, - attention_mask, past_key, past_value, layer_idx, @@ -69,7 +67,7 @@ def attention_llama(*args, **kwargs): return hidden_states -from typing import List, Optional +from typing import List from transformers.cache_utils import DynamicCache from transformers.models.llama.modeling_llama import LlamaAttention @@ -79,8 +77,7 @@ def llama_attention_forward_adapter( self: LlamaAttention, hidden_states: torch.Tensor, position_embeddings: List[torch.Tensor], - attention_mask: Optional[torch.Tensor], - past_key_value: Optional[DynamicCache], + past_key_value: DynamicCache, cache_position: torch.Tensor, **kwargs, ): @@ -97,13 +94,12 @@ def llama_attention_forward_adapter( self.o_proj.weight, position_embeddings[0], # cos position_embeddings[1], # sin - attention_mask, # key_cache is a list of cache for each decoder layer. # Assumtion: key cache is continuous # # k_cache[0] | k_cache[1] | ... 
-            key_cache[0],
-            value_cache[0],  # Same to value_cache
+            key_cache[self.layer_idx],
+            value_cache[self.layer_idx],  # Same for value_cache
             self.layer_idx,
             cache_position,
         ),
@@ -132,7 +128,6 @@ def define_node(
             wo,
             position_cos,
             position_sin,
-            attention_mask,
             past_key,
             past_value,
             cache_position,

From 8d9c2f73ebe3587aa0215dcdcee84122a09857a9 Mon Sep 17 00:00:00 2001
From: Sanggyu Lee
Date: Wed, 13 Aug 2025 17:51:53 +0900
Subject: [PATCH 17/21] add decode.py to export LlamaModel decode phase

---
 .../decode.py | 69 +++++++++++++++++++
 1 file changed, 69 insertions(+)
 create mode 100644 test/modules/model/LlamaDecoderLayerWithKVCacheAndFusedAttention/decode.py

diff --git a/test/modules/model/LlamaDecoderLayerWithKVCacheAndFusedAttention/decode.py b/test/modules/model/LlamaDecoderLayerWithKVCacheAndFusedAttention/decode.py
new file mode 100644
index 00000000..7ec412d9
--- /dev/null
+++ b/test/modules/model/LlamaDecoderLayerWithKVCacheAndFusedAttention/decode.py
@@ -0,0 +1,69 @@
+# User input
+prompt = "Lily picked up a flower."
+model_name = "Maykeye/TinyLLama-v0"
+
+# Tokenizer
+from transformers import AutoTokenizer
+
+tokenizer = AutoTokenizer.from_pretrained(model_name)
+tokenizer.pad_token = tokenizer.eos_token
+tokenizer.padding_side = "right"
+inputs = tokenizer(
+    prompt,
+    return_tensors="pt",
+    padding="max_length",
+    max_length=30,
+    truncation=True,
+)
+
+# Generator
+import torch
+
+from transformers import AutoModelForCausalLM
+
+model = AutoModelForCausalLM.from_pretrained(model_name)
+model.eval()
+
+from tico.utils.record_input import RecordingInput
+
+# past_key_values
+# ---------------
+# During prefill, "past_key_values" is not None, but an empty Cache instance.
+# Passing None makes torch.export happy.
+
+
+input_to_remove = [
+    "attention_mask",
+    # For left pad, [0, ⋯, 0, 1, ⋯, 1]
+    # For right pad, [1, ⋯, 1, 0, ⋯, 0]
+    # (0 is the pad token)
+    # This script uses right pad and passes an all-1 attention mask (including pad).
+    # The NPU computes all positions whether they are pad or not.
+]
+condition_fn = lambda args_dict: args_dict["past_key_values"].get_seq_length() != 0
+
+with torch.no_grad(), RecordingInput(
+    model, condition_fn, input_to_remove=input_to_remove
+) as rec:
+    outputs = model.generate(
+        **inputs,
+        max_new_tokens=32,
+        do_sample=False,
+        pad_token_id=tokenizer.eos_token_id,
+    )
+    captured_input = rec.captured_input
+
+generated_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
+print(generated_text)
+
+# Tico
+import tico
+from tico.serialize.operators.onert.op_attention import llama_attention_forward_adapter
+from transformers.models.llama.modeling_llama import LlamaAttention
+
+LlamaAttention.forward = llama_attention_forward_adapter
+
+model = AutoModelForCausalLM.from_pretrained(model_name)
+model.eval()
+circle_model = tico.convert(model, captured_input)
+circle_model.save("tinyllama.decode.circle")

From 72cd40752d6c43fb51905b1a3a4c3e676ec214ff Mon Sep 17 00:00:00 2001
From: Sanggyu Lee
Date: Thu, 14 Aug 2025 11:36:30 +0900
Subject: [PATCH 18/21] Restore attention_mask

The causal attention_mask could be computed inside op_attention. However,
op_attention is not the proper place: just like the cos and sin tables, the
mask is sharable across the decoder layers, and computing it inside
op_attention would recalculate it again and again in every attention layer.
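The intended split is therefore to build the mask and the RoPE cos/sin tables once per decode step outside the attention op and hand the same tensors to every layer. A rough sketch of that precomputation follows (illustrative names and shapes, not the API introduced by this patch):

import torch


def build_shared_tables(cache_position, kv_len, head_dim, base=10000.0):
    # Computed once per forward pass and shared by all decoder layers.
    # Causal mask: a query at position p may attend to kv positions <= p.
    kv_positions = torch.arange(kv_len)
    causal_mask = (kv_positions[None, :] <= cache_position[:, None]).float()

    # RoPE cos/sin tables for the current query positions.
    inv_freq = 1.0 / (base ** (torch.arange(0, head_dim, 2).float() / head_dim))
    angles = cache_position[:, None].float() * inv_freq[None, :]
    emb = torch.cat((angles, angles), dim=-1)
    return causal_mask, emb.cos(), emb.sin()


# Decode step with one query token, 12 kv slots, head_dim 64.
mask, cos, sin = build_shared_tables(torch.tensor([5]), kv_len=12, head_dim=64)
# Every decoder layer then receives the same mask/cos/sin instead of
# recomputing them inside its attention op.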
---
 tico/serialize/operators/onert/op_attention.py | 7 ++++++-
 1 file changed, 6 insertions(+), 1 deletion(-)

diff --git a/tico/serialize/operators/onert/op_attention.py b/tico/serialize/operators/onert/op_attention.py
index 989e3db8..4ad46424 100644
--- a/tico/serialize/operators/onert/op_attention.py
+++ b/tico/serialize/operators/onert/op_attention.py
@@ -38,6 +38,7 @@
         Tensor wo,
         Tensor position_cos,
         Tensor position_sin,
+        Tensor attention_mask,
         Tensor past_key,
         Tensor past_value,
         int layer_idx,
@@ -59,6 +60,7 @@ def attention_llama(*args, **kwargs):
         o_proj,
         position_cos,
         position_sin,
+        attention_mask,
         past_key,
         past_value,
         layer_idx,
@@ -67,7 +69,7 @@ def attention_llama(*args, **kwargs):
     return hidden_states
 
 
-from typing import List
+from typing import List, Optional
 
 from transformers.cache_utils import DynamicCache
 from transformers.models.llama.modeling_llama import LlamaAttention
@@ -77,6 +79,7 @@ def llama_attention_forward_adapter(
     self: LlamaAttention,
     hidden_states: torch.Tensor,
     position_embeddings: List[torch.Tensor],
+    attention_mask: torch.Tensor,
     past_key_value: DynamicCache,
     cache_position: torch.Tensor,
     **kwargs,
@@ -94,6 +97,7 @@ def llama_attention_forward_adapter(
             self.o_proj.weight,
             position_embeddings[0],  # cos
             position_embeddings[1],  # sin
+            attention_mask,
             # key_cache is a list of caches, one per decoder layer.
             # Assumption: the key cache is contiguous
             #
@@ -128,6 +132,7 @@ def define_node(
             wo,
             position_cos,
             position_sin,
+            attention_mask,
             past_key,
             past_value,
             cache_position,

From 22910c2e3182b0c8e2e7bdcaaeff53676b2eec4a Mon Sep 17 00:00:00 2001
From: Sanggyu Lee
Date: Fri, 12 Sep 2025 10:25:58 +0900
Subject: [PATCH 19/21] Fix wrong arg order and move layer_idx from inputs to params

---
 tico/serialize/operators/onert/op_attention.py | 12 +++++++-----
 1 file changed, 7 insertions(+), 5 deletions(-)

diff --git a/tico/serialize/operators/onert/op_attention.py b/tico/serialize/operators/onert/op_attention.py
index 4ad46424..aa565a84 100644
--- a/tico/serialize/operators/onert/op_attention.py
+++ b/tico/serialize/operators/onert/op_attention.py
@@ -41,8 +41,8 @@
         Tensor attention_mask,
         Tensor past_key,
         Tensor past_value,
-        int layer_idx,
-        Tensor cache_position
+        Tensor cache_position,
+        int layer_idx
     ) -> Tensor
     """
 )
@@ -63,8 +63,8 @@ def attention_llama(*args, **kwargs):
         attention_mask,
         past_key,
         past_value,
-        layer_idx,
         cache_position,
+        layer_idx,
     ) = args
 
     return hidden_states
@@ -104,8 +104,8 @@ def llama_attention_forward_adapter(
             # k_cache[0] | k_cache[1] | ... | k_cache[n]
             key_cache[self.layer_idx],
             value_cache[self.layer_idx],  # Same for value_cache
-            self.layer_idx,
             cache_position,
+            self.layer_idx,
         ),
         None,
     )
@@ -143,7 +143,9 @@ def define_node(
             circle.BuiltinOperator.BuiltinOperator.ATTENTION, self._op_codes
         )
 
-        inputs = node.args
+        # Remove the last arg (= layer_idx) from inputs:
+        # layer_idx is the attention op's param, not an input.
+        inputs = node.args[:-1]
         outputs = [node]
 
         operator = create_builtin_operator(self.graph, op_index, inputs, outputs)

From 8ac74b76cc44a8bd6ae86976bcc59a6f5388ea93 Mon Sep 17 00:00:00 2001
From: Sanggyu Lee
Date: Tue, 4 Nov 2025 14:09:04 +0900
Subject: [PATCH 20/21] remove layer_idx

---
 tico/serialize/operators/onert/op_attention.py | 11 ++---------
 1 file changed, 2 insertions(+), 9 deletions(-)

diff --git a/tico/serialize/operators/onert/op_attention.py b/tico/serialize/operators/onert/op_attention.py
index aa565a84..c52f5ee3 100644
--- a/tico/serialize/operators/onert/op_attention.py
+++ b/tico/serialize/operators/onert/op_attention.py
@@ -41,8 +41,7 @@
         Tensor attention_mask,
         Tensor past_key,
         Tensor past_value,
-        Tensor cache_position,
-        int layer_idx
+        Tensor cache_position
     ) -> Tensor
     """
 )
@@ -64,7 +63,6 @@ def attention_llama(*args, **kwargs):
         past_key,
         past_value,
         cache_position,
-        layer_idx,
     ) = args
 
     return hidden_states
@@ -105,7 +103,6 @@ def llama_attention_forward_adapter(
             key_cache[self.layer_idx],
             value_cache[self.layer_idx],  # Same for value_cache
             cache_position,
-            self.layer_idx,
         ),
         None,
     )
@@ -136,7 +133,6 @@ def define_node(
             past_key,
             past_value,
             cache_position,
-            layer_idx,
         ) = node.args
 
         op_index = get_op_index(
@@ -153,9 +149,6 @@ def define_node(
         operator.builtinOptionsType = (
            circle.BuiltinOptions.BuiltinOptions.AttentionOptions
         )
-        option = circle.AttentionOptions.AttentionOptionsT()
-        option.layer_idx = layer_idx
-
-        operator.builtinOptions = option
+        operator.builtinOptions = circle.AttentionOptions.AttentionOptionsT()
 
         return operator

From 9efa26fb3bdb7f7b35515e3db3e04a759b684c38 Mon Sep 17 00:00:00 2001
From: Sanggyu Lee
Date: Tue, 4 Nov 2025 14:11:16 +0900
Subject: [PATCH 21/21] Remove remove_unused_inputs pass

---
 tico/passes/remove_unused_inputs.py | 51 -----------------------------
 tico/utils/convert.py               |  2 --
 2 files changed, 53 deletions(-)
 delete mode 100644 tico/passes/remove_unused_inputs.py

diff --git a/tico/passes/remove_unused_inputs.py b/tico/passes/remove_unused_inputs.py
deleted file mode 100644
index 14bbcbc7..00000000
--- a/tico/passes/remove_unused_inputs.py
+++ /dev/null
@@ -1,51 +0,0 @@
-# Copyright (c) 2025 Samsung Electronics Co., Ltd. All Rights Reserved
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-from typing import TYPE_CHECKING
-
-if TYPE_CHECKING:
-    import torch.fx
-import torch
-from torch.export import ExportedProgram
-
-from tico.passes import ops
-from tico.utils import logging
-from tico.utils.passes import PassBase, PassResult
-from tico.utils.trace_decorators import trace_graph_diff_on_pass
-
-
-@trace_graph_diff_on_pass
-class RemoveUnusedInput(PassBase):
-    """
-    Let's remove dead inputs
-    """
-
-    def __init__(self):
-        super().__init__()
-
-    def call(self, exported_program: ExportedProgram) -> PassResult:
-        logger = logging.getLogger(__name__)
-
-        graph_module = exported_program.graph_module
-        graph = graph_module.graph
-        modified = False
-        for node in graph.nodes:
-            if node.op == "placeholder" and len(node.users) == 0:
-                graph.erase_node(node)
-                modified = True
-
-        graph.lint()
-        graph_module.recompile()
-
-        return PassResult(modified)
diff --git a/tico/utils/convert.py b/tico/utils/convert.py
index d248edd7..7ac47f25 100644
--- a/tico/utils/convert.py
+++ b/tico/utils/convert.py
@@ -58,7 +58,6 @@
 from tico.passes.remove_redundant_reshape import passes as RemoveRedundantViewPasses
 from tico.passes.remove_redundant_slice import RemoveRedundantSlice
 from tico.passes.remove_redundant_to_copy import RemoveRedundantToCopy
-from tico.passes.remove_unused_inputs import RemoveUnusedInput
 from tico.passes.restore_linear import RestoreLinear
 from tico.passes.segment_index_select import SegmentIndexSelectConst
 from tico.quantization.passes.fold_quant_ops import FoldQuantOps
@@ -262,7 +261,6 @@ def convert_exported_module_to_circle(
             *LowerToSlicePasses(),
             FuseLeadingUnsqueezeReshape(),
             CastClampMixedTypeArgs(),
-            RemoveUnusedInput(),
         ]
     )
     circle_legalize.run(exported_program)
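With the pass removed, the pipeline no longer prunes dead placeholders on its own, so any unused input captured during tracing will survive into conversion. A small check in the spirit of the deleted pass can flag this before tico.convert is called; it reuses the same FX calls the pass used (node.op == "placeholder", node.users), and the helper name is illustrative only.

import torch
from torch.export import ExportedProgram


def unused_placeholders(exported_program: ExportedProgram) -> list:
    # Names of graph inputs that no node consumes.
    graph = exported_program.graph_module.graph
    return [
        node.name
        for node in graph.nodes
        if node.op == "placeholder" and len(node.users) == 0
    ]


# Usage sketch, assuming `ep` came from torch.export.export(...):
# dead = unused_placeholders(ep)
# assert not dead, f"unexpected unused inputs: {dead}"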