Fuse LlamaRMSNorm class to Circle RMSNorm op #266

Merged
jinevening merged 9 commits into Samsung:main from seockho-kim:fuse_rmsnorm on Aug 11, 2025

Conversation

@seockho-kim (Contributor)

This shows how to fuse the LlamaRMSNorm class to the Circle RMSNorm operation.

TICO-DCO-1.0-Signed-off-by: Seockho Kim seockho.kim@samsung.com

Like #217, it does not match patterns; it simply replaces LlamaRMSNorm with a custom op.

This commit fuses the LlamaRMSNorm class to the Circle RMSNorm op.

TICO-DCO-1.0-Signed-off-by: Seockho Kim seockho.kim@samsung.com
This commit fixes formatting with lint.

TICO-DCO-1.0-Signed-off-by: Seockho Kim seockho.kim@samsung.com
@seockho-kim requested a review from a team on August 6, 2025 at 00:41.


def CircleRMSNorm():
    @custom_op("circle_custom::rms_norm", mutates_args=())
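For orientation, the registration presumably continues with an op body and a fake kernel so torch.export can trace through it. A minimal sketch, assuming torch.library.custom_op and the standard RMSNorm formula; the parameter names and the register_fake part are illustrative guesses, not the PR's actual code:

import torch
from torch.library import custom_op


def CircleRMSNorm():
    @custom_op("circle_custom::rms_norm", mutates_args=())
    def rms_norm(
        hidden_states: torch.Tensor,
        weight: torch.Tensor,
        eps: float = 1e-06,
    ) -> torch.Tensor:
        # Standard RMSNorm: normalize by the root mean square, then scale.
        input_dtype = hidden_states.dtype
        hidden_states = hidden_states.to(torch.float32)
        variance = hidden_states.pow(2).mean(-1, keepdim=True)
        hidden_states = hidden_states * torch.rsqrt(variance + eps)
        return weight * hidden_states.to(input_dtype)

    @rms_norm.register_fake
    def _(hidden_states, weight, eps=1e-06):
        # Shape/dtype propagation only; no real computation during tracing.
        return torch.empty_like(hidden_states)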
@glistening (Contributor) commented on Aug 6, 2025:

rms_norm is a circle builtin op, so I think circle::rms_norm is enough. In my op_attention case, @jinevening preferred the onert prefix. I don't know the exact rule; maybe a new op that did not exist in tflite, and that will run on the cpu backend (not the triv npu), gets the onert prefix. @jinevening Is that right? What prefix do you prefer for rms_norm?

@seockho-kim (Author) replied:

I've followed the naming of other custom ops like instance_norm, which is also a circle builtin op.

@glistening (Contributor) commented on Aug 6, 2025:

Again, instance norm is not a custom op. I guess someone wanted to distinguish circle-only ops from tflite-circle-common ops. (Why? 🤔)

@seockho-kim (Author) replied on Aug 6, 2025:

In register_custom_op.py:

def CircleInstanceNorm():
    @custom_op("circle_custom::instance_norm", mutates_args=())
    def instance_norm(
        input_: torch.Tensor,
        weight: Optional[torch.Tensor] = None,
        bias: Optional[torch.Tensor] = None,
        running_mean: Optional[torch.Tensor] = None,
        running_var: Optional[torch.Tensor] = None,
        use_input_stats: bool = False,
        momentum: float = 0.1,
        eps: float = 1e-05,
        cudnn_enabled: bool = False,
    ) -> torch.Tensor:
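        # Circle tensors are NHWC while aten.instance_norm expects NCHW,
        # hence the permute round-trip below.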
        NHWC_to_NCHW = [0, 3, 1, 2]
        NCHW_input = torch.ops.aten.permute.default(input_, NHWC_to_NCHW)

        args = [NCHW_input, weight, bias, None, None, False, momentum, eps, False]
        NCHW_output = torch.ops.aten.instance_norm.default(*args)
        NCHW_to_NHWC = [0, 2, 3, 1]
        NHWC_output = torch.ops.aten.permute.default(NCHW_output, NCHW_to_NHWC)

        return NHWC_output
    # ...

@glistening (Contributor) replied:

@seockho-kim I already understood: some TICO developer wanted to define the circle builtin op InstanceNorm as custom from TICO's point of view. I am wondering why. If there is any reason to distinguish them (though I don't see one), circle_ext would be a better prefix in my personal view, since it would not be confused with the custom_op in circle_schema.

@seockho-kim (Author) replied:

Well, I have no idea why it is named like that. :)
I agree circle_custom is a little confusing alongside custom_op in circle_schema.

@jinevening (Contributor) commented on Aug 8, 2025:

> Again, instance norm is not a custom op. I guess someone wanted to distinguish circle-only ops from tflite-circle-common ops. (Why? 🤔)

There are tflite-circle-common ops too (circle_custom.conv2d, circle_custom.maxpool2d, ...).

circle_custom is just a namespace for circle ops. It would be fine to change the namespace to circle as you suggested (not in this PR). @mhs4670go AFAIK, you created circle_custom. Is it OK to change?

@mhs4670go replied:

Sure. I added the _custom prefix because this is related to torch "custom" operator creation. Just torch.ops.circle looks good as well. Feel free to change it in another PR.
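For mechanics (an aside, not part of this PR): the namespace string passed to @custom_op determines the Python call path, so "circle_custom::rms_norm" is invoked as torch.ops.circle_custom.rms_norm(...), and a rename to circle would change call sites to torch.ops.circle.rms_norm(...).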

@glistening (Contributor) commented:

@seockho-kim Why do you want to fuse rmsnorm? For the npu compiler, for onert, or for something else?

@seockho-kim (Author) replied:

> @seockho-kim Why do you want to fuse rmsnorm? For the npu compiler, for onert, or for something else?

For the npu compiler. I'm also trying to find a way to fuse rmsnorm for quantized models.

This applies review comments:
- Use a contextmanager
- Remove useless format changes
- Define custom RMSNorm args
- Add register_dynamic_cache()

TICO-DCO-1.0-Signed-off-by: Seockho Kim seockho.kim@samsung.com
Comment on lines 27 to 40
A reviewer (Contributor) commented:

How about moving this patcher under the tico project proper, not inside the test directory?

This commit applies the review comments:
- RMSNormCustomArgs is renamed to CircleRMSNormArgs
- The patcher is moved from test to tico utils

TICO-DCO-1.0-Signed-off-by: Seockho Kim seockho.kim@samsung.com
This commit fixes a format error.

TICO-DCO-1.0-Signed-off-by: Seockho Kim seockho.kim@samsung.com
@seockho-kim (Author) commented on Aug 6, 2025:

• Without fusing: [image]
• Fused RMSNorm: [image]

@seockho-kim marked this pull request as ready for review on August 8, 2025 at 00:32.
This commit pins an exact transformers version.
It does not work with the latest version (e.g. 4.53).

TICO-DCO-1.0-Signed-off-by: Seockho Kim seockho.kim@samsung.com
try:
    yield
finally:
    LlamaRMSNorm.forward = orig
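For context, the snippet above is presumably the tail of a context manager along these lines (a sketch reconstructed from the visible lines; patched_rms_norm_forward and its exact arguments are assumptions, though LlamaRMSNorm does expose weight and variance_epsilon):

import torch
from contextlib import contextmanager

from transformers.models.llama.modeling_llama import LlamaRMSNorm


def patched_rms_norm_forward(self, hidden_states):
    # Route through the registered custom op instead of the decomposed math,
    # so the exporter sees a single rms_norm node.
    return torch.ops.circle_custom.rms_norm(
        hidden_states, self.weight, self.variance_epsilon
    )


@contextmanager
def patched_llama_rmsnorm():
    orig = LlamaRMSNorm.forward
    LlamaRMSNorm.forward = patched_rms_norm_forward
    try:
        yield
    finally:
        # Always restore the original forward, even if the body raises.
        LlamaRMSNorm.forward = orig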
@glistening (Contributor) commented on Aug 8, 2025:

I am not sure it is a good idea to put patched_llama_rms_norm and related things in utils/patcher.py.

As the operators to fuse grow, patcher.py will accumulate more and more dependencies:

First, on models (not only modeling_llama, but also modeling_florence, modeling_something_else, ...).

Second, within the same model (e.g. llama), there will be multiple ops to fuse (e.g. attention and so on).

It would be better to break these up by operator. Thus, in my implementation of attention fusing (#217), I put the attention-related adapters in op_attention.py.

@seockho-kim (Author) replied:

Yes, it may get complicated if we need to support other ops.
I referred to your attention implementation, but I'm not sure including an adapter in the op code is a good approach: the op code then depends on the model code, so I thought the two should be separated.

@seockho-kim (Author) commented:

As we discussed offline, I'm going to move the patcher next to each operator, but in separate files, e.g. tico/serialize/operators/adapters/adapter_rmsnorm.py.

- Move patcher.py to serialize/operators/adapters/rmsnorm.py

TICO-DCO-1.0-Signed-off-by: Seockho Kim seockho.kim@samsung.com
This adds __init__.py to the adapters folder to make it a package.

TICO-DCO-1.0-Signed-off-by: Seockho Kim seockho.kim@samsung.com
@glistening previously approved these changes on Aug 8, 2025 and left a comment:

LGTM



@contextmanager
def patched_llama_rmsnorm():
@jinevening (Contributor) commented:

This is specific to the llama model, so I think it would be better to rename the file to llama_rmsnorm.py.

@seockho-kim (Author) replied:

Okay, I'll update it.

@glistening (Contributor) commented on Aug 11, 2025:

I thought rmsnorm.py would be used as a collection of adapters for several rmsnorms (including llama's).

@seockho-kim (Author) replied:

At first, I thought of it the way @glistening explained. But @jinevening's suggestion seems better in terms of SRP (the single-responsibility principle).

class TinyLlamaWithFusedRMSNorm(TestModuleBase):
    def __init__(self):
        super().__init__()
        with patched_llama_rmsnorm():
A reviewer (Contributor) commented:

How can we patch multiple modules? For example, how can we patch both LlamaRMSNorm and LlamaAttention?

@seockho-kim (Author) replied:

Well, I think we can use the same approach (not tested):

@contextmanager
def patched_llama_modules():
    with patched_llama_rmsnorm(), patched_llama_attention():
        yield


class TinyLlamaWithFusedRMSNorm(TestModuleBase):
    def __init__(self):
        super().__init__()
        with patched_llama_modules():
            self.model = AutoModelForCausalLM.from_pretrained(
                "Maykeye/TinyLLama-v0"
            ).to("cpu")

A reviewer (Contributor) replied:

@seockho-kim Yes, I think the same way.

This commit renames the rmsnorm adapter file to llama_rmsnorm.py.

TICO-DCO-1.0-Signed-off-by: Seockho Kim seockho.kim@samsung.com
@jinevening (Contributor) left a comment:

LGTM

@glistening (Contributor) left a comment:

LGTM

@dayo09 (Contributor) left a comment:

LGTM

@jinevening merged commit d0f37b6 into Samsung:main on Aug 11, 2025.
6 checks passed
@seockho-kim deleted the fuse_rmsnorm branch on August 11, 2025 at 06:43.
class TinyLlamaWithFusedRMSNorm(TestModuleBase):
    def __init__(self):
        super().__init__()
        with patched_llama_rmsnorm():
A reviewer (Contributor) commented:

Hmm, it seems this code doesn't work after all, because the with statement ends before the module is exported. I'll patch this code soon.

@seockho-kim (Author) replied:

Hmm, it really doesn't work, though I'm curious how it passed before.
FYI, #304 is another way to fuse rmsnorm, and it works.
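To make the failure mode concrete, a hypothetical sketch (not code from the PR; example_inputs is a stand-in tuple of example inputs):

module = TinyLlamaWithFusedRMSNorm()  # __init__ patches forward, then restores it on exit

# By the time the module is exported, LlamaRMSNorm.forward has already been
# restored, so no circle_custom::rms_norm node appears in the traced graph.
exported = torch.export.export(module, example_inputs)

# One fix would be to keep the patch active across the export itself:
with patched_llama_rmsnorm():
    exported = torch.export.export(module, example_inputs)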
