[passes] Add ConvertMatmulToLinear pass #341
Changes to the matmul test modules (unchanged context is elided with `# ...`):

```python
import torch

from tico.config.v1 import CompileConfigV1
from test.modules.base import TestModuleBase
from test.utils.tag import test_negative, use_onert


class SimpleMatmul(TestModuleBase):
    # ... (unchanged: forward(self, lhs, rhs))

    def get_example_inputs(self):
        return (torch.randn(3, 4), torch.randn(4, 5)), {}


class SimpleMatmulConstRhs(TestModuleBase):
    def __init__(self):
        super().__init__()
        self.weight = torch.randn(4, 5)

    def forward(self, lhs):
        out = torch.mm(lhs, self.weight)
        return out

    def get_example_inputs(self):
        return (torch.randn(3, 4),), {}


@use_onert
class SimpleMatmulConstRhsOnert(TestModuleBase):
    def __init__(self):
        super().__init__()
        self.weight = torch.randn(4, 5)

    def forward(self, lhs):
        out = torch.mm(lhs, self.weight)
        return out

    def get_example_inputs(self):
        return (torch.randn(3, 4),), {}


@use_onert
@test_negative(expected_err="NNFW_STATUS_ERROR")
class SimpleMatmulConstLhsOnert(TestModuleBase):
    """ """

    def __init__(self):
        super().__init__()
        self.weight = torch.randn(3, 4)

    def forward(self, rhs):
        out = torch.mm(self.weight, rhs)
        return out

    def get_example_inputs(self):
        return (torch.randn(4, 5),), {}


@use_onert
class SimpleMatmulConstLhsOnertWithLinearConversion(TestModuleBase):
    def __init__(self):
        super().__init__()
        self.weight = torch.randn(3, 4)

    def forward(self, rhs):
        out = torch.mm(self.weight, rhs)
        return out

    def get_example_inputs(self):
        return (torch.randn(4, 5),), {}

    def get_compile_config(self):
        return CompileConfigV1(convert_lhs_const_mm_to_fc=True)
```
Comment on lines +92 to +93 (`get_compile_config` in `SimpleMatmulConstLhsOnertWithLinearConversion`):

**Contributor (author):** @glistening @seockho-kim Using this compile config enables conversion of matmul ops whose LHS is a constant node.

**Contributor:** @dayo09 Could you improve it to handle bmm, too?

```python
@tag.use_onert
class BmmTest(TestModuleBase):
    def __init__(self):
        super().__init__()
        self.weight = torch.randn(2, 3, 4)

    def forward(self, rhs):
        out = self.weight @ rhs
        return out

    def get_example_inputs(self):
        return (torch.randn(2, 4, 5),), {}

    def get_compile_config(self):
        return CompileConfigV1(convert_lhs_const_mm_to_fc=True)
```

**Contributor (author):** @seockho-kim The case above is not supported, because matmul-to-FC conversion is only possible when the weight is 2-dimensional: the Circle FullyConnected operation assumes its weight has rank 2.

**Contributor:** I'm sorry, I gave you a wrong example.

**Contributor (author):** Let me add it in the next PR!
This file was deleted.
Changes to `CompileConfigV1`:

```python
class CompileConfigV1(CompileConfigBase):
    legalize_causal_mask_value: bool = False
    remove_constant_input: bool = False
    convert_lhs_const_mm_to_fc: bool = False
    convert_rhs_const_mm_to_fc: bool = True

    def get(self, name: str):
        return super().get(name)
```
Comment on lines +24 to +25 (the two new flags):

**Contributor:** On second thought, just …

**Contributor (author):** `rhs_const_mm_to_fc` doesn't have a trade-off, because the transpose is foldable into the constant, but `lhs_const_mm_to_fc` carries a potential latency trade-off. Therefore, the user needs to make a separate decision for each case.
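A minimal sketch of toggling the two flags independently, assuming the constructor defaults and the `get` accessor behave as the diff above suggests:

```python
from tico.config.v1 import CompileConfigV1

# Defaults: RHS-const conversion on (transpose folds away), LHS-const off.
config = CompileConfigV1()
assert config.get("convert_rhs_const_mm_to_fc") is True
assert config.get("convert_lhs_const_mm_to_fc") is False

# Opt in to the LHS-const conversion, accepting the possible latency cost
# of the unfoldable transpose.
config = CompileConfigV1(convert_lhs_const_mm_to_fc=True)
```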
The new pass file, added in full:

```python
# Copyright (c) 2025 Samsung Electronics Co., Ltd. All Rights Reserved
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import List, Optional, TYPE_CHECKING

if TYPE_CHECKING:
    import torch.fx
import torch
from torch._export.utils import is_buffer, is_lifted_tensor_constant, is_param
from torch.export import ExportedProgram

from tico.utils import logging
from tico.utils.graph import create_node
from tico.utils.passes import PassBase, PassResult
from tico.utils.trace_decorators import trace_graph_diff_on_pass
from tico.utils.validate_args_kwargs import MatmulArgs


class Converter:  # type: ignore[empty-body]
    def __init__(self):
        super().__init__()

    def match(self, exported_program, node) -> bool:  # type: ignore[empty-body]
        return False

    def convert(self, exported_program, node) -> torch.fx.Node:  # type: ignore[empty-body]
        pass


class MatmulToLinearConverter(Converter):
    def __init__(self):
        super().__init__()

    def convert(self, exported_program, node) -> torch.fx.Node:
        graph_module = exported_program.graph_module
        graph = graph_module.graph

        mm_args = MatmulArgs(*node.args, **node.kwargs)  # type: ignore[arg-type]

        lhs = mm_args.input
        rhs = mm_args.other

        with graph.inserting_before(node):
            transpose_node = create_node(
                graph,
                torch.ops.aten.permute.default,
                args=(rhs, [1, 0]),
            )
            fc_node = create_node(
                graph,
                torch.ops.aten.linear.default,
                args=(lhs, transpose_node),
            )
            node.replace_all_uses_with(fc_node, propagate_meta=True)

        return fc_node


class RhsConstMatmulToLinearConverter(MatmulToLinearConverter):
    def __init__(self):
        super().__init__()

    def match(self, exported_program, node) -> bool:
        if not node.target == torch.ops.aten.mm.default:
            return False

        mm_args = MatmulArgs(*node.args, **node.kwargs)  # type: ignore[arg-type]

        rhs = mm_args.other
        if isinstance(rhs, torch.fx.Node):
            if is_lifted_tensor_constant(exported_program, rhs):
                return True
            elif is_param(exported_program, rhs):
                return True
            elif is_buffer(exported_program, rhs):
                return True
            else:
                return False
        return False

    def convert(self, exported_program, node) -> torch.fx.Node:
        return super().convert(exported_program, node)


class LhsConstMatmulToLinearConverter(MatmulToLinearConverter):
    def __init__(self):
        super().__init__()

    def match(self, exported_program, node) -> bool:
        if not node.target == torch.ops.aten.mm.default:
            return False

        mm_args = MatmulArgs(*node.args, **node.kwargs)
        lhs = mm_args.input
        if isinstance(lhs, torch.fx.Node):
            if is_lifted_tensor_constant(exported_program, lhs):
                return True
            elif is_param(exported_program, lhs):
                return True
            elif is_buffer(exported_program, lhs):
                return True
            else:
                return False
        return False

    def convert(self, exported_program, node) -> torch.fx.Node:
        return super().convert(exported_program, node)


@trace_graph_diff_on_pass
class ConvertMatmulToLinear(PassBase):
    """
    This pass selectively converts matmul to linear.

    How to select between `matmul` and `linear`?

    * Linear has better quantization accuracy (NPU backend).
        Due to the ONE compiler's quantization policy,
        FullyConnected (= Linear) uses per-channel quantization for the weight
        and per-tensor for the input, while BatchMatmul (= matmul) uses
        per-tensor quantization for both lhs and rhs.

    * Matmul-to-Linear requires a Transpose, which may harm latency.
        When the RHS is constant, the additional transpose can be folded.

        [RHS non-const case]
        Constant folding cannot be performed.

        lhs   rhs (non-const)
         |     |
         |  transpose
         |     |
         -- linear --
              |
             out

        [RHS const case]
        Constant folding can be performed:

        lhs   rhs (const)          lhs   rhs (folded const)
         |     |                    |     |
         |  transpose      -->      |     |
         |     |                    |     |
         -- linear --               -- linear --
              |                          |
             out                        out

    enable_lhs_const: If True, convert matmul whose LHS is a constant tensor. Default: False.
    enable_rhs_const: If True, convert matmul whose RHS is a constant tensor. Default: True.
    """

    def __init__(
        self,
        enable_lhs_const: Optional[bool] = False,
        enable_rhs_const: Optional[bool] = True,
    ):
        super().__init__()
        self.converters: List[Converter] = []
        if enable_lhs_const:
            self.converters.append(LhsConstMatmulToLinearConverter())
        if enable_rhs_const:
            self.converters.append(RhsConstMatmulToLinearConverter())

    def call(self, exported_program: ExportedProgram) -> PassResult:
        logger = logging.getLogger(__name__)

        graph_module = exported_program.graph_module
        graph = graph_module.graph
        modified = False
        for node in graph.nodes:
            if not node.op == "call_function":
                continue

            for converter in self.converters:
                if not converter.match(exported_program, node):
                    continue

                new_node = converter.convert(exported_program, node)
                modified = True
                logger.debug(
                    f"{node.name} is replaced with {new_node.name} operator (permute + linear)"
                )
                continue

        graph.eliminate_dead_code()
        graph.lint()
        graph_module.recompile()

        return PassResult(modified)
```

Comment on the quantization note in the docstring:

**Contributor:** FYI, a new generation of NPU would support cwq (channel-wise quantization) for matmul.

**Contributor (author):** @jinevening Do you mean the 3rd generation?

**Contributor:** Yes.
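To make the rewrite concrete, here is a minimal end-to-end sketch; `ConstLhsMM` is a hypothetical toy module, and the final assertion shows why the rewrite preserves semantics (`linear(a, b)` computes `a @ b.T`, so `mm(lhs, rhs) == linear(lhs, rhs.T)`):

```python
import torch
from torch.export import export

class ConstLhsMM(torch.nn.Module):
    # Hypothetical toy module: constant LHS (a buffer), runtime RHS.
    def __init__(self):
        super().__init__()
        self.register_buffer("weight", torch.randn(3, 4))

    def forward(self, rhs):
        return torch.mm(self.weight, rhs)

ep = export(ConstLhsMM(), (torch.randn(4, 5),))
# Opt in to the LHS-const conversion; the pass rewrites aten.mm into
# aten.permute + aten.linear on the exported graph.
result = ConvertMatmulToLinear(enable_lhs_const=True).call(ep)

# Numeric sanity check of the underlying identity:
lhs, rhs = torch.randn(3, 4), torch.randn(4, 5)
assert torch.allclose(torch.mm(lhs, rhs),
                      torch.nn.functional.linear(lhs, rhs.t()))
```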
Review comments on `@test_negative(expected_err="NNFW_STATUS_ERROR")`:

**Contributor:** I think it would be good to describe what error is expected. `NNFW_STATUS_ERROR` is a bit ambiguous.

**Contributor (author):** That is how onert throws. It should match.

**Contributor:** Ah, I think using a docstring or comments is also enough.
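A hedged sketch of the reviewer's suggestion; the docstring wording is hypothetical, inferred from the positive counterpart test that enables `convert_lhs_const_mm_to_fc`:

```python
@use_onert
@test_negative(expected_err="NNFW_STATUS_ERROR")
class SimpleMatmulConstLhsOnert(TestModuleBase):
    """
    Negative test: with `convert_lhs_const_mm_to_fc` left at its default
    (False), the constant-LHS matmul is not converted to Linear, and onert
    fails to execute the compiled model, reporting NNFW_STATUS_ERROR.
    """
```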