1 change: 1 addition & 0 deletions tico/serialize/operators/adapters/onert/__init__.py
@@ -0,0 +1 @@
# DO NOT REMOVE THIS FILE
51 changes: 51 additions & 0 deletions tico/serialize/operators/adapters/onert/llama_attention.py
@@ -0,0 +1,51 @@
# Copyright (c) 2025 Samsung Electronics Co., Ltd. All Rights Reserved
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import List

import torch

from transformers.cache_utils import DynamicCache
from transformers.models.llama.modeling_llama import LlamaAttention


def llama_attention_forward_adapter(
self: LlamaAttention,
hidden_states: torch.Tensor,
position_embeddings: List[torch.Tensor],
attention_mask: torch.Tensor,
past_key_value: DynamicCache,
cache_position: torch.Tensor,
**kwargs,
):
# past_key_value is a DynamicCache holding per-layer key_cache and value_cache lists.
# Decompose it into plain tensors here, since tico and circle do not support dict-like inputs.
key_cache = past_key_value.key_cache # type: ignore[union-attr]
value_cache = past_key_value.value_cache # type: ignore[union-attr]
return (
torch.ops.circle_custom.attention(
hidden_states,
self.q_proj.weight,
self.k_proj.weight,
self.v_proj.weight,
self.o_proj.weight,
position_embeddings[0], # cos
position_embeddings[1], # sin
attention_mask,
key_cache[self.layer_idx],  # key cache for this layer
value_cache[self.layer_idx],  # value cache for this layer
cache_position,
),
None,
)
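Note (not part of this diff): a minimal usage sketch of the adapter. It assumes the adapter is bound in place of LlamaAttention.forward before export so that traced graphs emit circle_custom::attention; the monkey-patching step itself is an assumption for illustration, not something this PR adds.

# Sketch only: patch LlamaAttention.forward with the adapter before exporting the model.
from transformers.models.llama.modeling_llama import LlamaAttention

from tico.serialize.operators.adapters.onert.llama_attention import llama_attention_forward_adapter

# Every LlamaAttention layer now routes through the single custom attention op.
LlamaAttention.forward = llama_attention_forward_adapter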
58 changes: 58 additions & 0 deletions tico/serialize/operators/op_attention.py
@@ -0,0 +1,58 @@
# Copyright (c) 2025 Samsung Electronics Co., Ltd. All Rights Reserved
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Dict, List, TYPE_CHECKING

if TYPE_CHECKING:
import torch._ops
import torch.fx
import torch
from circle_schema import circle

from tico.serialize.circle_graph import CircleSubgraph
from tico.serialize.operators.hashable_opcode import OpCode
from tico.serialize.operators.node_visitor import NodeVisitor, register_node_visitor
from tico.serialize.operators.utils import create_builtin_operator, get_op_index
from tico.utils.validate_args_kwargs import CircleAttentionArgs


@register_node_visitor
class AttentionVisitor(NodeVisitor):
target: List[torch._ops.OpOverload] = [
torch.ops.circle_custom.attention.default,
]

def __init__(self, op_codes: Dict[OpCode, int], graph: CircleSubgraph):
super().__init__(op_codes, graph)

def define_node(
self,
node: torch.fx.Node,
) -> circle.Operator.OperatorT:
# Constructing CircleAttentionArgs validates the node's args/kwargs against the expected signature.
args = CircleAttentionArgs(*node.args, **node.kwargs)  # type: ignore[arg-type]
op_index = get_op_index(
circle.BuiltinOperator.BuiltinOperator.ATTENTION, self._op_codes
)

inputs = list(node.args)
outputs = [node]
operator = create_builtin_operator(self.graph, op_index, inputs, outputs)

# Op-specific option
operator.builtinOptionsType = (
circle.BuiltinOptions.BuiltinOptions.AttentionOptions
)
operator.builtinOptions = circle.AttentionOptions.AttentionOptionsT()

return operator
36 changes: 36 additions & 0 deletions tico/utils/register_custom_op.py
@@ -727,6 +727,41 @@ def _(
return hidden_states.new_empty(hidden_states.size())


def CircleAttention():
@custom_op("circle_custom::attention", mutates_args=())
def attention(
hidden_states: torch.Tensor,
wq: torch.Tensor,
wk: torch.Tensor,
wv: torch.Tensor,
wo: torch.Tensor,
position_cos: torch.Tensor,
position_sin: torch.Tensor,
attention_mask: torch.Tensor,
past_key: torch.Tensor,
past_value: torch.Tensor,
cache_position: torch.Tensor,
) -> torch.Tensor:
# TODO Implement the eager computation and add corresponding tests; this is a placeholder.
return None  # type: ignore[return-value]

@register_fake("circle_custom::attention")
def _(
hidden_states: torch.Tensor,
wq: torch.Tensor,
wk: torch.Tensor,
wv: torch.Tensor,
wo: torch.Tensor,
position_cos: torch.Tensor,
position_sin: torch.Tensor,
attention_mask: torch.Tensor,
past_key: torch.Tensor,
past_value: torch.Tensor,
cache_position: torch.Tensor,
) -> torch.Tensor:
# Fake kernels should not alias their inputs; return a fresh tensor of the same shape.
return hidden_states.new_empty(hidden_states.size())


# Add custom ops to the torch namespace
def RegisterOps():
CircleResizeNearestNeighbor()
@@ -740,3 +775,4 @@ def RegisterOps():
CircleInstanceNorm()
CircleQuantizeMX()
CircleRMSNorm()
CircleAttention()
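Note (not part of this diff): a minimal export sketch for the newly registered op. The TinyAttention wrapper, tensor shapes, and head layout are made up for illustration; during torch.export tracing only the register_fake kernel runs, so the unimplemented eager body is never executed.

# Sketch only: register the custom ops, then export a module that calls the attention op.
import torch

from tico.utils.register_custom_op import RegisterOps

RegisterOps()  # registers circle_custom::attention among the other circle ops


class TinyAttention(torch.nn.Module):
    def forward(self, hs, wq, wk, wv, wo, cos, sin, mask, pk, pv, pos):
        return torch.ops.circle_custom.attention(
            hs, wq, wk, wv, wo, cos, sin, mask, pk, pv, pos
        )


# Made-up shapes: batch 1, sequence 8, hidden 64, 4 KV heads, head dim 16.
example = (
    torch.randn(1, 8, 64),                              # hidden_states
    torch.randn(64, 64), torch.randn(64, 64),            # wq, wk
    torch.randn(64, 64), torch.randn(64, 64),            # wv, wo
    torch.randn(1, 8, 16), torch.randn(1, 8, 16),        # position cos / sin
    torch.zeros(1, 1, 8, 8),                             # attention_mask
    torch.randn(1, 4, 8, 16), torch.randn(1, 4, 8, 16),  # past_key, past_value
    torch.arange(8),                                     # cache_position
)
exported = torch.export.export(TinyAttention(), example)  # traces via the fake kernel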
20 changes: 20 additions & 0 deletions tico/utils/validate_args_kwargs.py
@@ -171,6 +171,26 @@ class CatArgs:
dim: int = 0


@enforce_type
@dataclass
class CircleAttentionArgs:
"""
For circle.BuiltinOperator.BuiltinOperator.ATTENTION
"""

hidden_states: torch.fx.Node
wq: torch.fx.Node
wk: torch.fx.Node
wv: torch.fx.Node
wo: torch.fx.Node
position_cos: torch.fx.Node
position_sin: torch.fx.Node
attention_mask: torch.fx.Node
past_key: torch.fx.Node
past_value: torch.fx.Node
cache_position: torch.fx.Node


@enforce_type
@dataclass
class CircleRMSNormArgs: