diff --git a/test/modules/model/TinyLlamaWithFusedRMSNorm/model.py b/test/modules/model/TinyLlamaWithFusedRMSNorm/model.py
index c2af79e3..dd12a4c1 100644
--- a/test/modules/model/TinyLlamaWithFusedRMSNorm/model.py
+++ b/test/modules/model/TinyLlamaWithFusedRMSNorm/model.py
@@ -13,11 +13,13 @@
 # limitations under the License.
 
 import torch
+from tico.passes.module_fusion import llama_rmsnorm  # noqa: F401 (imported for its side effect: registers FusedLlamaRMSNorm)
 
-from tico.serialize.operators.adapters.llama_rmsnorm import patched_llama_rmsnorm
+from tico.passes.module_fusion.fusion_registry import replace_modules_with_fused
 from tico.utils.pytree_utils import register_dynamic_cache
 from transformers import AutoModelForCausalLM
+from transformers.models.llama.modeling_llama import LlamaRMSNorm
 
 from test.modules.base import TestModuleBase
 
 
@@ -25,12 +27,14 @@ class TinyLlamaWithFusedRMSNorm(TestModuleBase):
     def __init__(self):
         super().__init__()
 
-        with patched_llama_rmsnorm():
-            self.model = AutoModelForCausalLM.from_pretrained(
-                "Maykeye/TinyLLama-v0"
-            ).to("cpu")
+        self.model = AutoModelForCausalLM.from_pretrained("Maykeye/TinyLLama-v0").to(
+            "cpu"
+        )
+
         self.rtol = 1e-4
         self.atol = 1e-4
+
+        replace_modules_with_fused(self.model, [LlamaRMSNorm])
         register_dynamic_cache()
 
     def forward(self, x):
diff --git a/tico/serialize/operators/adapters/__init__.py b/tico/passes/module_fusion/__init__.py
similarity index 100%
rename from tico/serialize/operators/adapters/__init__.py
rename to tico/passes/module_fusion/__init__.py
diff --git a/tico/passes/module_fusion/fusion_registry.py b/tico/passes/module_fusion/fusion_registry.py
new file mode 100644
index 00000000..90ca7188
--- /dev/null
+++ b/tico/passes/module_fusion/fusion_registry.py
@@ -0,0 +1,81 @@
+# Copyright (c) 2025 Samsung Electronics Co., Ltd. All Rights Reserved
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import Callable, Dict, List, Optional, Type
+
+import torch.nn as nn
+
+# Dict with original module classes as keys and fused module factories as values.
+# The value may be the fused module class itself (classes are callable), or a
+# factory function that takes the original module and returns the fused instance.
+_FUSED_MODULE_MAPPING: Dict[Type[nn.Module], Callable[[nn.Module], nn.Module]] = {}
+
+
+def register_fused_module(original_module_class: Type[nn.Module]):
+    """
+    Decorator that registers a factory creating the fused module for the given original module class.
+    """
+
+    def decorator(fused_module_factory: Callable[[nn.Module], nn.Module]):
+        _FUSED_MODULE_MAPPING[original_module_class] = fused_module_factory
+        return fused_module_factory
+
+    return decorator
+
+
+def get_fused_module_factory(
+    original_module_class: Type[nn.Module],
+) -> Optional[Callable[[nn.Module], nn.Module]]:
+    """
+    Returns the fused module factory registered for the original module class, or None.
+    """
+    return _FUSED_MODULE_MAPPING.get(original_module_class)
+
+
+def replace_modules_with_fused(
+    model: nn.Module, target_module_classes: List[Type[nn.Module]]
+):
+    """
+    Replaces every instance of the target module classes within the model with
+    the fused version registered in the registry.
+    """
+    replaced_count = 0
+    for name, module in list(model.named_modules()):  # snapshot; the tree is mutated below
+        if type(module) in target_module_classes:
+            fused_module_factory = get_fused_module_factory(type(module))
+            if fused_module_factory:
+                parent_module_name = ".".join(name.split(".")[:-1])
+                module_short_name = name.split(".")[-1]
+
+                parent_module = model
+                if parent_module_name:
+                    for part in parent_module_name.split("."):
+                        parent_module = getattr(parent_module, part)
+
+                new_module = fused_module_factory(module)
+
+                setattr(parent_module, module_short_name, new_module)
+                replaced_count += 1
+                print(
+                    f"Replaced {name} ({type(module).__name__}) with {type(new_module).__name__}"
+                )
+            else:
+                print(
+                    f"Warning: No fused module factory registered for {type(module).__name__}. Skipping replacement of {name}."
+                )
+
+    if replaced_count > 0:
+        print(f"Successfully replaced {replaced_count} module instances.")
+    else:
+        print("No target module instances found to replace.")
diff --git a/tico/serialize/operators/adapters/llama_rmsnorm.py b/tico/passes/module_fusion/llama_rmsnorm.py
similarity index 51%
rename from tico/serialize/operators/adapters/llama_rmsnorm.py
rename to tico/passes/module_fusion/llama_rmsnorm.py
index 7806f1ea..77e6eaf2 100644
--- a/tico/serialize/operators/adapters/llama_rmsnorm.py
+++ b/tico/passes/module_fusion/llama_rmsnorm.py
@@ -12,24 +12,27 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from contextlib import contextmanager
-
 import torch
 from transformers.models.llama.modeling_llama import LlamaRMSNorm
 
+from .fusion_registry import register_fused_module
+
+
+class FusedLlamaRMSNorm(LlamaRMSNorm):
+    def __init__(self, original_rmsnorm: LlamaRMSNorm):
+        super().__init__(
+            original_rmsnorm.weight.shape[0], original_rmsnorm.variance_epsilon
+        )
+        # Clone so the fused weight keeps the original values, dtype, and device.
+        self.weight = torch.nn.Parameter(original_rmsnorm.weight.detach().clone())
 
-def llama_rmsnorm_forward_adapter(self: LlamaRMSNorm, hidden_states: torch.Tensor):
-    return torch.ops.circle_custom.rms_norm(
-        hidden_states, self.weight, self.variance_epsilon
-    )
+    def forward(self, hidden_states: torch.Tensor):
+        return torch.ops.circle_custom.rms_norm(
+            hidden_states, self.weight, self.variance_epsilon
+        )
 
 
-@contextmanager
-def patched_llama_rmsnorm():
-    orig = LlamaRMSNorm.forward
-    LlamaRMSNorm.forward = llama_rmsnorm_forward_adapter
-    try:
-        yield
-    finally:
-        LlamaRMSNorm.forward = orig
+@register_fused_module(LlamaRMSNorm)
+def create_fused_llama_rmsnorm(original_module: LlamaRMSNorm) -> FusedLlamaRMSNorm:
+    return FusedLlamaRMSNorm(original_module)
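For reviewers, a minimal, self-contained sketch of how the new registry is meant to be used. `MyNorm`, `FusedMyNorm`, `create_fused_my_norm`, and the toy `Sequential` model are hypothetical placeholders for illustration, not code from this PR:

```python
import torch
import torch.nn as nn

from tico.passes.module_fusion.fusion_registry import (
    register_fused_module,
    replace_modules_with_fused,
)


class MyNorm(nn.Module):
    """Hypothetical original module."""

    def __init__(self, dim: int):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(dim))

    def forward(self, x):
        return x * self.weight


class FusedMyNorm(MyNorm):
    """Hypothetical fused replacement, built from the original instance."""

    def __init__(self, original: MyNorm):
        super().__init__(original.weight.shape[0])
        # Carry over the trained parameter values.
        self.weight = nn.Parameter(original.weight.detach().clone())


# Registering a factory makes MyNorm eligible for replacement.
@register_fused_module(MyNorm)
def create_fused_my_norm(original: MyNorm) -> FusedMyNorm:
    return FusedMyNorm(original)


model = nn.Sequential(nn.Linear(4, 4), MyNorm(4))
replace_modules_with_fused(model, [MyNorm])  # prints "Replaced 1 (MyNorm) with FusedMyNorm"
assert isinstance(model[1], FusedMyNorm)
```

Note that the registry matches by exact `type(...)`, so subclasses of a registered class (including the fused replacement itself) are not replaced unless they are registered separately.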