1 change: 1 addition & 0 deletions test/modules/model/TinyLlamaWithFusedRMSNorm/__init__.py
@@ -0,0 +1 @@
# DO NOT REMOVE THIS FILE
43 changes: 43 additions & 0 deletions test/modules/model/TinyLlamaWithFusedRMSNorm/model.py
@@ -0,0 +1,43 @@
# Copyright (c) 2025 Samsung Electronics Co., Ltd. All Rights Reserved
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import torch

from tico.serialize.operators.adapters.llama_rmsnorm import patched_llama_rmsnorm
from tico.utils.pytree_utils import register_dynamic_cache

from transformers import AutoModelForCausalLM

from test.modules.base import TestModuleBase


class TinyLlamaWithFusedRMSNorm(TestModuleBase):
    def __init__(self):
        super().__init__()
        with patched_llama_rmsnorm():

Contributor:
How can we patch multiple modules? For example, how can we patch both LlamaRMSNorm and LlamaAttention?

Contributor Author:
Well, I think we can use the same approach (not tested):

@contextmanager
def patched_llama_modules():
    with patched_llama_rmsnorm(), patched_llama_attention():
        yield


class TinyLlamaWithFusedRMSNorm(TestModuleBase):
    def __init__(self):
        super().__init__()
        with patched_llama_modules():
            self.model = AutoModelForCausalLM.from_pretrained(
                "Maykeye/TinyLLama-v0"
            ).to("cpu")

Contributor:
@seockho-kim Yes, I think the same way.

Contributor:
Hmm, it seems this code doesn't work well, because the with statement ends before the module is exported. I'll patch this code soon.

Contributor Author:
Hmm, it really doesn't work, but I'm curious how it passed before.
FYI, #304 is another way to fuse RMSNorm, and it works.
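
A minimal sketch of a workaround, assuming the patch only needs to stay active while torch.export traces the model (the helper name and the plain torch.export.export call below are illustrative assumptions, not the PR's final fix):

import torch
from transformers import AutoModelForCausalLM

from tico.serialize.operators.adapters.llama_rmsnorm import patched_llama_rmsnorm


def export_tiny_llama_with_fused_rmsnorm():
    # Hypothetical helper: keep the RMSNorm monkey patch in effect during export,
    # so the traced graph records circle_custom.rms_norm instead of the original
    # LlamaRMSNorm arithmetic.
    model = AutoModelForCausalLM.from_pretrained("Maykeye/TinyLLama-v0").to("cpu")
    example_inputs = (torch.tensor([[1, 15043, 29871, 1, 869]], dtype=torch.int32),)
    with patched_llama_rmsnorm():
        exported = torch.export.export(model, example_inputs)
    return exported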

            self.model = AutoModelForCausalLM.from_pretrained(
                "Maykeye/TinyLLama-v0"
            ).to("cpu")
        self.rtol = 1e-4
        self.atol = 1e-4
        register_dynamic_cache()

    def forward(self, x):
        return self.model(x)

    def get_example_inputs(self):
        # >>> tokenizer = LlamaTokenizerFast.from_pretrained("huggyllama/llama-7b", legacy=True, from_slow=True)
        # >>> tokenizer.encode("Hello <s>.") # 869 is '▁.'
        # [1, 15043, 29871, 1, 869]
        return (torch.Tensor([[1, 15043, 29871, 1, 869]]).to(dtype=torch.int32),), {}
@@ -0,0 +1 @@
transformers==4.52.4
1 change: 1 addition & 0 deletions tico/serialize/operators/adapters/__init__.py
@@ -0,0 +1 @@
# DO NOT REMOVE THIS FILE
35 changes: 35 additions & 0 deletions tico/serialize/operators/adapters/llama_rmsnorm.py
@@ -0,0 +1,35 @@
# Copyright (c) 2025 Samsung Electronics Co., Ltd. All Rights Reserved
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from contextlib import contextmanager

import torch

from transformers.models.llama.modeling_llama import LlamaRMSNorm


def llama_rmsnorm_forward_adapter(self: LlamaRMSNorm, hidden_states: torch.Tensor):
    return torch.ops.circle_custom.rms_norm(
        hidden_states, self.weight, self.variance_epsilon
    )


@contextmanager
def patched_llama_rmsnorm():
    orig = LlamaRMSNorm.forward
    LlamaRMSNorm.forward = llama_rmsnorm_forward_adapter
    try:
        yield
    finally:
        LlamaRMSNorm.forward = orig
65 changes: 65 additions & 0 deletions tico/serialize/operators/op_rmsnorm.py
@@ -0,0 +1,65 @@
# Copyright (c) 2025 Samsung Electronics Co., Ltd. All Rights Reserved
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Dict, List, TYPE_CHECKING

if TYPE_CHECKING:
    import torch._ops
    import torch.fx
import torch
from circle_schema import circle

from tico.serialize.circle_graph import CircleSubgraph
from tico.serialize.operators.hashable_opcode import OpCode
from tico.serialize.operators.node_visitor import NodeVisitor, register_node_visitor
from tico.serialize.operators.utils import create_builtin_operator, get_op_index
from tico.utils.validate_args_kwargs import CircleRMSNormArgs


@register_node_visitor
class RMSNormVisitor(NodeVisitor):
    target: List[torch._ops.OpOverload] = [
        torch.ops.circle_custom.rms_norm.default,
    ]

    def __init__(self, op_codes: Dict[OpCode, int], graph: CircleSubgraph):
        super().__init__(op_codes, graph)

    def define_node(
        self,
        node: torch.fx.Node,
    ) -> circle.Operator.OperatorT:
        args = CircleRMSNormArgs(*node.args, **node.kwargs)  # type: ignore[arg-type]
        input = args.input
        weight = args.weight
        eps = args.eps

        op_index = get_op_index(
            circle.BuiltinOperator.BuiltinOperator.RMS_NORM, self._op_codes
        )

        inputs = [input, weight]
        outputs = [node]
        operator = create_builtin_operator(self.graph, op_index, inputs, outputs)

        # Op-specific option
        operator.builtinOptionsType = (
            circle.BuiltinOptions.BuiltinOptions.RmsNormOptions
        )
        option = circle.RmsNormOptions.RmsNormOptionsT()
        option.epsilon = eps

        operator.builtinOptions = option

        return operator
23 changes: 23 additions & 0 deletions tico/utils/register_custom_op.py
@@ -703,6 +703,28 @@ def _(
        return input_


def CircleRMSNorm():
    @custom_op("circle_custom::rms_norm", mutates_args=())

glistening (Contributor), Aug 6, 2025:
rms_norm is a circle builtin op, so I think circle::rms_norm is enough. In my op_attention case, @jinevening preferred the onert prefix. I don't know the clear rule; maybe a new op that did not exist in tflite, and that will run on the cpu backend (not the triv npu), gets the onert prefix. @jinevening Is that right? Which prefix do you prefer for rms_norm?

Contributor Author:
I've followed the naming of other custom ops like instance_norm, which is also a circle builtin op.

glistening (Contributor), Aug 6, 2025:
Again, instance norm is not a custom op. I guess someone wanted to distinguish circle-only ops from tflite-circle-common ops. (Why? 🤔)

seockho-kim (Contributor Author), Aug 6, 2025:
In register_custom_op.py:

def CircleInstanceNorm():
    @custom_op("circle_custom::instance_norm", mutates_args=())
    def instance_norm(
        input_: torch.Tensor,
        weight: Optional[torch.Tensor] = None,
        bias: Optional[torch.Tensor] = None,
        running_mean: Optional[torch.Tensor] = None,
        running_var: Optional[torch.Tensor] = None,
        use_input_stats: bool = False,
        momentum: float = 0.1,
        eps: float = 1e-05,
        cudnn_enabled: bool = False,
    ) -> torch.Tensor:
        NHWC_to_NCHW = [0, 3, 1, 2]
        NCHW_input = torch.ops.aten.permute.default(input_, NHWC_to_NCHW)

        args = [NCHW_input, weight, bias, None, None, False, momentum, eps, False]
        NCHW_output = torch.ops.aten.instance_norm.default(*args)
        NCHW_to_NHWC = [0, 2, 3, 1]
        NHWC_output = torch.ops.aten.permute.default(NCHW_output, NCHW_to_NHWC)

        return NHWC_output
......

Contributor:
@seockho-kim I already understood: some TICO developer wanted to define the circle builtin op InstanceNorm as a custom op from TICO's point of view. I am wondering why. If there is any reason to distinguish them (though I don't see one), circle_ext would be better in my personal view, since it is not confused with the other custom_op in circle_schema.

Contributor Author:
Well, I have no idea why it is named like that. :)
I agree circle_custom can be confused with custom_op in circle_schema.

jinevening (Contributor), Aug 8, 2025:
> Again, instance norm is not a custom op. I guess someone wanted to distinguish circle-only ops from tflite-circle-common ops. (Why? 🤔)

There are tflite-circle-common ops too (circle_custom.conv2d, circle_custom.maxpool2d, ...).

circle_custom is just a namespace for circle ops. It would be OK to change the namespace to circle as you suggested (not in this PR). @mhs4670go AFAIK you made circle_custom. Is it OK to change?

Contributor:
Sure. I added the _custom prefix because this is related to torch "custom" operator creation. Just torch.ops.circle looks good as well. Feel free to change them in another PR.

    def rms_norm(
        hidden_states: torch.Tensor,
        weight: Optional[torch.Tensor] = None,
        eps: float = 1e-05,
    ) -> torch.Tensor:
        input_dtype = hidden_states.dtype
        hidden_states = hidden_states.to(torch.float32)
        variance = hidden_states.pow(2).mean(-1, keepdim=True)
        hidden_states = hidden_states * torch.rsqrt(variance + eps)
        return weight * hidden_states.to(input_dtype)

    @register_fake("circle_custom::rms_norm")
    def _(
        hidden_states: torch.Tensor,
        weight: Optional[torch.Tensor] = None,
        eps: float = 1e-05,
    ) -> torch.Tensor:
        return hidden_states.new_empty(hidden_states.size())


# Add custom ops to the torch namespace
def RegisterOps():
    CircleResizeNearestNeighbor()
@@ -715,3 +737,4 @@ def RegisterOps():
    CircleAvgPool2D()
    CircleInstanceNorm()
    CircleQuantizeMX()
    CircleRMSNorm()
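
For reference, a small sanity-check sketch of the new op (assumptions: RegisterOps() can be called explicitly like this, and torch.ops.circle_custom.rms_norm dispatches to the eager rms_norm kernel defined above):

import torch

from tico.utils.register_custom_op import RegisterOps

RegisterOps()  # registers circle_custom::rms_norm along with the other circle custom ops

x = torch.randn(2, 5, 64)
weight = torch.ones(64)
eps = 1e-5

# Compare the registered op against RMSNorm computed by hand, mirroring the
# eager kernel in CircleRMSNorm().
out = torch.ops.circle_custom.rms_norm(x, weight, eps)
ref = weight * (x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + eps))
assert torch.allclose(out, ref)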
26 changes: 26 additions & 0 deletions tico/utils/validate_args_kwargs.py
@@ -171,6 +171,19 @@ class CatArgs:
    dim: int = 0


@enforce_type
@dataclass
class CircleRMSNormArgs:
"""
This is not aten ops but custom op for RMSNorm.
circle_custom.rms_norm(Tensor input, Tensor? weight=None, float? eps=None) -> Tensor
"""

input: torch.fx.Node
weight: Optional[torch.fx.Node]
eps: Optional[float]


@enforce_type
@dataclass
class ClampArgs:
@@ -931,6 +944,19 @@ class ResizeNearestNeighborArgs:
    size: List[int]


@enforce_type
@dataclass
class RMSNormArgs:
"""
rms_norm(Tensor input, SymInt[] normalized_shape, Tensor? weight=None, float? eps=None) -> Tensor
"""

input: torch.fx.Node
normalized_shape: List[int]
weight: Optional[torch.fx.Node]
eps: Optional[float]


@enforce_type
@dataclass
class RoundArgs: