diff --git a/tico/serialize/operators/adapters/onert/__init__.py b/tico/serialize/operators/adapters/onert/__init__.py
new file mode 100644
index 00000000..0c29109f
--- /dev/null
+++ b/tico/serialize/operators/adapters/onert/__init__.py
@@ -0,0 +1 @@
+# DO NOT REMOVE THIS FILE
diff --git a/tico/serialize/operators/adapters/onert/llama_attention.py b/tico/serialize/operators/adapters/onert/llama_attention.py
new file mode 100644
index 00000000..2700117d
--- /dev/null
+++ b/tico/serialize/operators/adapters/onert/llama_attention.py
@@ -0,0 +1,51 @@
+# Copyright (c) 2025 Samsung Electronics Co., Ltd. All Rights Reserved
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import List
+
+import torch
+
+from transformers.cache_utils import DynamicCache
+from transformers.models.llama.modeling_llama import LlamaAttention
+
+
+def llama_attention_forward_adapter(
+    self: LlamaAttention,
+    hidden_states: torch.Tensor,
+    position_embeddings: List[torch.Tensor],
+    attention_mask: torch.Tensor,
+    past_key_value: DynamicCache,
+    cache_position: torch.Tensor,
+    **kwargs,
+):
+    # past_key_value is a DynamicCache holding per-layer key_cache and value_cache lists.
+    # It must be decomposed here because tico and circle cannot represent such containers.
+    key_cache = past_key_value.key_cache  # type: ignore[union-attr]
+    value_cache = past_key_value.value_cache  # type: ignore[union-attr]
+    return (
+        torch.ops.circle_custom.attention(
+            hidden_states,
+            self.q_proj.weight,
+            self.k_proj.weight,
+            self.v_proj.weight,
+            self.o_proj.weight,
+            position_embeddings[0],  # cos
+            position_embeddings[1],  # sin
+            attention_mask,
+            key_cache[self.layer_idx],
+            value_cache[self.layer_idx],
+            cache_position,
+        ),
+        None,
+    )
diff --git a/tico/serialize/operators/op_attention.py b/tico/serialize/operators/op_attention.py
new file mode 100644
index 00000000..6f4c02e4
--- /dev/null
+++ b/tico/serialize/operators/op_attention.py
@@ -0,0 +1,58 @@
+# Copyright (c) 2025 Samsung Electronics Co., Ltd. All Rights Reserved
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import Dict, List, TYPE_CHECKING
+
+if TYPE_CHECKING:
+    import torch._ops
+    import torch.fx
+import torch
+from circle_schema import circle
+
+from tico.serialize.circle_graph import CircleSubgraph
+from tico.serialize.operators.hashable_opcode import OpCode
+from tico.serialize.operators.node_visitor import NodeVisitor, register_node_visitor
+from tico.serialize.operators.utils import create_builtin_operator, get_op_index
+from tico.utils.validate_args_kwargs import CircleAttentionArgs
+
+
+@register_node_visitor
+class AttentionVisitor(NodeVisitor):
+    target: List[torch._ops.OpOverload] = [
+        torch.ops.circle_custom.attention.default,
+    ]
+
+    def __init__(self, op_codes: Dict[OpCode, int], graph: CircleSubgraph):
+        super().__init__(op_codes, graph)
+
+    def define_node(
+        self,
+        node: torch.fx.Node,
+    ) -> circle.Operator.OperatorT:
+        args = CircleAttentionArgs(*node.args, **node.kwargs)  # type: ignore[arg-type]
+        op_index = get_op_index(
+            circle.BuiltinOperator.BuiltinOperator.ATTENTION, self._op_codes
+        )
+
+        inputs = node.args
+        outputs = [node]
+        operator = create_builtin_operator(self.graph, op_index, inputs, outputs)
+
+        # Op-specific option
+        operator.builtinOptionsType = (
+            circle.BuiltinOptions.BuiltinOptions.AttentionOptions
+        )
+        operator.builtinOptions = circle.AttentionOptions.AttentionOptionsT()
+
+        return operator
diff --git a/tico/utils/register_custom_op.py b/tico/utils/register_custom_op.py
index 095cb6a5..7f400047 100644
--- a/tico/utils/register_custom_op.py
+++ b/tico/utils/register_custom_op.py
@@ -727,6 +727,40 @@ def _(
     return hidden_states.new_empty(hidden_states.size())
 
 
+def CircleAttention():
+    @custom_op("circle_custom::attention", mutates_args=())
+    def attention(
+        hidden_states: torch.Tensor,
+        wq: torch.Tensor,
+        wk: torch.Tensor,
+        wv: torch.Tensor,
+        wo: torch.Tensor,
+        position_cos: torch.Tensor,
+        position_sin: torch.Tensor,
+        attention_mask: torch.Tensor,
+        past_key: torch.Tensor,
+        past_value: torch.Tensor,
+        cache_position: torch.Tensor,
+    ) -> torch.Tensor:
+        return None
+
+    @register_fake("circle_custom::attention")
+    def _(
+        hidden_states: torch.Tensor,
+        wq: torch.Tensor,
+        wk: torch.Tensor,
+        wv: torch.Tensor,
+        wo: torch.Tensor,
+        position_cos: torch.Tensor,
+        position_sin: torch.Tensor,
+        attention_mask: torch.Tensor,
+        past_key: torch.Tensor,
+        past_value: torch.Tensor,
+        cache_position: torch.Tensor,
+    ) -> torch.Tensor:
+        return hidden_states.new_empty(hidden_states.size())
+
+
 # Add custom ops to the torch namespace
 def RegisterOps():
     CircleResizeNearestNeighbor()
@@ -740,3 +774,4 @@ def RegisterOps():
     CircleInstanceNorm()
     CircleQuantizeMX()
     CircleRMSNorm()
+    CircleAttention()
diff --git a/tico/utils/validate_args_kwargs.py b/tico/utils/validate_args_kwargs.py
index 8a5feb21..60fad7b8 100644
--- a/tico/utils/validate_args_kwargs.py
+++ b/tico/utils/validate_args_kwargs.py
@@ -171,6 +171,26 @@ class CatArgs:
     dim: int = 0
 
 
+@enforce_type
+@dataclass
+class CircleAttentionArgs:
+    """
+    For circle.BuiltinOperator.BuiltinOperator.ATTENTION
+    """
+
+    hidden_states: torch.fx.Node
+    wq: torch.fx.Node
+    wk: torch.fx.Node
+    wv: torch.fx.Node
+    wo: torch.fx.Node
+    position_cos: torch.fx.Node
+    position_sin: torch.fx.Node
+    attention_mask: torch.fx.Node
+    past_key: torch.fx.Node
+    past_value: torch.fx.Node
+    cache_position: torch.fx.Node
+
+
 @enforce_type
 @dataclass
 class CircleRMSNormArgs:
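
Usage note (not part of the patch): the sketch below is one way these pieces could be wired together before export. The patch itself only defines the adapter, the custom op, and the visitor; the explicit RegisterOps() call and the class-level monkey-patch shown here are illustrative assumptions, not something the diff performs on its own.

    from transformers.models.llama.modeling_llama import LlamaAttention

    from tico.utils.register_custom_op import RegisterOps
    from tico.serialize.operators.adapters.onert.llama_attention import (
        llama_attention_forward_adapter,
    )

    # Ensure the circle_custom ops (including circle_custom::attention) are
    # registered in the torch namespace. Skip this if your tico build already
    # registers them at import time.
    RegisterOps()

    # Route every LlamaAttention.forward call through the adapter so the export
    # sees one circle_custom.attention call per decoder layer instead of the
    # decomposed projection / rotary / softmax math.
    LlamaAttention.forward = llama_attention_forward_adapter

With the patch applied, each such call is then serialized by AttentionVisitor as a single Circle ATTENTION operator whose inputs are the Q/K/V/O projection weights, the rotary cos/sin embeddings, the attention mask, the per-layer key/value cache tensors, and the cache positions.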