diff --git a/test/modules/model/TinyLlamaWithFusedRMSNorm/__init__.py b/test/modules/model/TinyLlamaWithFusedRMSNorm/__init__.py
new file mode 100644
index 00000000..0c29109f
--- /dev/null
+++ b/test/modules/model/TinyLlamaWithFusedRMSNorm/__init__.py
@@ -0,0 +1 @@
+# DO NOT REMOVE THIS FILE
diff --git a/test/modules/model/TinyLlamaWithFusedRMSNorm/model.py b/test/modules/model/TinyLlamaWithFusedRMSNorm/model.py
new file mode 100644
index 00000000..c2af79e3
--- /dev/null
+++ b/test/modules/model/TinyLlamaWithFusedRMSNorm/model.py
@@ -0,0 +1,43 @@
+# Copyright (c) 2025 Samsung Electronics Co., Ltd. All Rights Reserved
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import torch
+
+from tico.serialize.operators.adapters.llama_rmsnorm import patched_llama_rmsnorm
+from tico.utils.pytree_utils import register_dynamic_cache
+
+from transformers import AutoModelForCausalLM
+
+from test.modules.base import TestModuleBase
+
+
+class TinyLlamaWithFusedRMSNorm(TestModuleBase):
+    def __init__(self):
+        super().__init__()
+        with patched_llama_rmsnorm():
+            self.model = AutoModelForCausalLM.from_pretrained(
+                "Maykeye/TinyLLama-v0"
+            ).to("cpu")
+        self.rtol = 1e-4
+        self.atol = 1e-4
+        register_dynamic_cache()
+
+    def forward(self, x):
+        return self.model(x)
+
+    def get_example_inputs(self):
+        # >>> tokenizer = LlamaTokenizerFast.from_pretrained("huggyllama/llama-7b", legacy=True, from_slow=True)
+        # >>> tokenizer.encode("Hello .") # 869 is '▁.'
+        # [1, 15043, 29871, 1, 869]
+        return (torch.Tensor([[1, 15043, 29871, 1, 869]]).to(dtype=torch.int32),), {}
diff --git a/test/modules/model/TinyLlamaWithFusedRMSNorm/requirements.txt b/test/modules/model/TinyLlamaWithFusedRMSNorm/requirements.txt
new file mode 100644
index 00000000..1e4043b8
--- /dev/null
+++ b/test/modules/model/TinyLlamaWithFusedRMSNorm/requirements.txt
@@ -0,0 +1 @@
+transformers==4.52.4
diff --git a/tico/serialize/operators/adapters/__init__.py b/tico/serialize/operators/adapters/__init__.py
new file mode 100644
index 00000000..0c29109f
--- /dev/null
+++ b/tico/serialize/operators/adapters/__init__.py
@@ -0,0 +1 @@
+# DO NOT REMOVE THIS FILE
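
The test module above relies on two behaviors of patched_llama_rmsnorm(): inside the context every LlamaRMSNorm call dispatches to the Circle custom op, and the stock forward is restored on exit. A minimal sketch of that contract (illustrative, not part of the patch; RegisterOps() is invoked explicitly here in case the custom ops have not already been registered on import):

    import torch
    from transformers.models.llama.modeling_llama import LlamaRMSNorm

    from tico.serialize.operators.adapters.llama_rmsnorm import patched_llama_rmsnorm
    from tico.utils.register_custom_op import RegisterOps

    RegisterOps()  # registers torch.ops.circle_custom.rms_norm, among others

    norm = LlamaRMSNorm(hidden_size=8, eps=1e-5)
    x = torch.randn(1, 3, 8)

    eager = norm(x)  # stock HF implementation
    with patched_llama_rmsnorm():
        patched = norm(x)  # routed through circle_custom.rms_norm
    restored = norm(x)  # the context manager put the original forward back

    torch.testing.assert_close(eager, patched, rtol=1e-4, atol=1e-4)
    torch.testing.assert_close(eager, restored)
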
diff --git a/tico/serialize/operators/adapters/llama_rmsnorm.py b/tico/serialize/operators/adapters/llama_rmsnorm.py
new file mode 100644
index 00000000..7806f1ea
--- /dev/null
+++ b/tico/serialize/operators/adapters/llama_rmsnorm.py
@@ -0,0 +1,35 @@
+# Copyright (c) 2025 Samsung Electronics Co., Ltd. All Rights Reserved
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from contextlib import contextmanager
+
+import torch
+
+from transformers.models.llama.modeling_llama import LlamaRMSNorm
+
+
+def llama_rmsnorm_forward_adapter(self: LlamaRMSNorm, hidden_states: torch.Tensor):
+    return torch.ops.circle_custom.rms_norm(
+        hidden_states, self.weight, self.variance_epsilon
+    )
+
+
+@contextmanager
+def patched_llama_rmsnorm():
+    orig = LlamaRMSNorm.forward
+    LlamaRMSNorm.forward = llama_rmsnorm_forward_adapter
+    try:
+        yield
+    finally:
+        LlamaRMSNorm.forward = orig
diff --git a/tico/serialize/operators/op_rmsnorm.py b/tico/serialize/operators/op_rmsnorm.py
new file mode 100644
index 00000000..835141e4
--- /dev/null
+++ b/tico/serialize/operators/op_rmsnorm.py
@@ -0,0 +1,65 @@
+# Copyright (c) 2025 Samsung Electronics Co., Ltd. All Rights Reserved
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import Dict, List, TYPE_CHECKING
+
+if TYPE_CHECKING:
+    import torch._ops
+    import torch.fx
+import torch
+from circle_schema import circle
+
+from tico.serialize.circle_graph import CircleSubgraph
+from tico.serialize.operators.hashable_opcode import OpCode
+from tico.serialize.operators.node_visitor import NodeVisitor, register_node_visitor
+from tico.serialize.operators.utils import create_builtin_operator, get_op_index
+from tico.utils.validate_args_kwargs import CircleRMSNormArgs
+
+
+@register_node_visitor
+class RMSNormVisitor(NodeVisitor):
+    target: List[torch._ops.OpOverload] = [
+        torch.ops.circle_custom.rms_norm.default,
+    ]
+
+    def __init__(self, op_codes: Dict[OpCode, int], graph: CircleSubgraph):
+        super().__init__(op_codes, graph)
+
+    def define_node(
+        self,
+        node: torch.fx.Node,
+    ) -> circle.Operator.OperatorT:
+        args = CircleRMSNormArgs(*node.args, **node.kwargs)  # type: ignore[arg-type]
+        input = args.input
+        weight = args.weight
+        eps = args.eps
+
+        op_index = get_op_index(
+            circle.BuiltinOperator.BuiltinOperator.RMS_NORM, self._op_codes
+        )
+
+        inputs = [input, weight]
+        outputs = [node]
+        operator = create_builtin_operator(self.graph, op_index, inputs, outputs)
+
+        # Op-specific option
+        operator.builtinOptionsType = (
+            circle.BuiltinOptions.BuiltinOptions.RmsNormOptions
+        )
+        option = circle.RmsNormOptions.RmsNormOptionsT()
+        option.epsilon = eps
+
+        operator.builtinOptions = option
+
+        return operator
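
What RMSNormVisitor actually receives: after export, each patched RMSNorm call shows up as a call_function node whose target is torch.ops.circle_custom.rms_norm.default — the entry in RMSNormVisitor.target. A standalone sketch (illustrative, not part of the patch):

    import torch

    from tico.utils.register_custom_op import RegisterOps

    RegisterOps()

    class M(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.weight = torch.nn.Parameter(torch.ones(8))

        def forward(self, x):
            return torch.ops.circle_custom.rms_norm(x, self.weight, 1e-5)

    # The fake kernel registered below (@register_fake) is what lets
    # torch.export trace through the op without executing it.
    ep = torch.export.export(M(), (torch.randn(1, 3, 8),))
    targets = [n.target for n in ep.graph.nodes if n.op == "call_function"]
    assert torch.ops.circle_custom.rms_norm.default in targets
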
diff --git a/tico/utils/register_custom_op.py b/tico/utils/register_custom_op.py
index 23dbc3d9..48372e0c 100644
--- a/tico/utils/register_custom_op.py
+++ b/tico/utils/register_custom_op.py
@@ -703,6 +703,30 @@ def _(
     return input_
 
 
+def CircleRMSNorm():
+    @custom_op("circle_custom::rms_norm", mutates_args=())
+    def rms_norm(
+        hidden_states: torch.Tensor,
+        weight: Optional[torch.Tensor] = None,
+        eps: float = 1e-05,
+    ) -> torch.Tensor:
+        input_dtype = hidden_states.dtype
+        hidden_states = hidden_states.to(torch.float32)
+        variance = hidden_states.pow(2).mean(-1, keepdim=True)
+        hidden_states = hidden_states * torch.rsqrt(variance + eps)
+        hidden_states = hidden_states.to(input_dtype)
+        # `weight` is Optional in the schema; skip the affine scale when absent.
+        return hidden_states if weight is None else weight * hidden_states
+
+    @register_fake("circle_custom::rms_norm")
+    def _(
+        hidden_states: torch.Tensor,
+        weight: Optional[torch.Tensor] = None,
+        eps: float = 1e-05,
+    ) -> torch.Tensor:
+        return hidden_states.new_empty(hidden_states.size())
+
+
 # Add custom ops to the torch namespace
 def RegisterOps():
     CircleResizeNearestNeighbor()
@@ -715,3 +739,4 @@ def RegisterOps():
     CircleAvgPool2D()
     CircleInstanceNorm()
     CircleQuantizeMX()
+    CircleRMSNorm()
diff --git a/tico/utils/validate_args_kwargs.py b/tico/utils/validate_args_kwargs.py
index 03f5e5b4..f6badfff 100644
--- a/tico/utils/validate_args_kwargs.py
+++ b/tico/utils/validate_args_kwargs.py
@@ -171,6 +171,19 @@ class CatArgs:
     dim: int = 0
 
 
+@enforce_type
+@dataclass
+class CircleRMSNormArgs:
+    """
+    This is not an aten op but a custom op for RMSNorm.
+    circle_custom.rms_norm(Tensor input, Tensor? weight=None, float eps=1e-05) -> Tensor
+    """
+
+    input: torch.fx.Node
+    weight: Optional[torch.fx.Node] = None
+    eps: float = 1e-05
+
+
 @enforce_type
 @dataclass
 class ClampArgs:
@@ -931,6 +944,19 @@ class ResizeNearestNeighborArgs:
     size: List[int]
 
 
+@enforce_type
+@dataclass
+class RMSNormArgs:
+    """
+    rms_norm(Tensor input, SymInt[] normalized_shape, Tensor? weight=None, float? eps=None) -> Tensor
+    """
+
+    input: torch.fx.Node
+    normalized_shape: List[int]
+    weight: Optional[torch.fx.Node] = None
+    eps: Optional[float] = None
+
+
 @enforce_type
 @dataclass
 class RoundArgs:
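
RMSNormArgs mirrors the schema of aten's rms_norm, so the numerics of the new custom op can be cross-checked against PyTorch's own implementation. A sketch (illustrative, not part of the patch; assumes a PyTorch version that ships torch.nn.functional.rms_norm):

    import torch
    import torch.nn.functional as F

    from tico.utils.register_custom_op import RegisterOps

    RegisterOps()

    x = torch.randn(2, 4, 16)
    w = torch.randn(16)

    custom = torch.ops.circle_custom.rms_norm(x, w, 1e-5)
    reference = F.rms_norm(x, (16,), weight=w, eps=1e-5)
    torch.testing.assert_close(custom, reference, rtol=1e-4, atol=1e-4)
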