Merged
64 changes: 64 additions & 0 deletions test/modules/op/mm.py
@@ -14,7 +14,10 @@

import torch

from tico.config.v1 import CompileConfigV1

from test.modules.base import TestModuleBase
from test.utils.tag import test_negative, use_onert


class SimpleMatmul(TestModuleBase):
@@ -27,3 +30,64 @@ def forward(self, lhs, rhs):

    def get_example_inputs(self):
        return (torch.randn(3, 4), torch.randn(4, 5)), {}


class SimpleMatmulConstRhs(TestModuleBase):
    def __init__(self):
        super().__init__()
        self.weight = torch.randn(4, 5)

    def forward(self, lhs):
        out = torch.mm(lhs, self.weight)
        return out

    def get_example_inputs(self):
        return (torch.randn(3, 4),), {}


@use_onert
class SimpleMatmulConstRhsOnert(TestModuleBase):
    def __init__(self):
        super().__init__()
        self.weight = torch.randn(4, 5)

    def forward(self, lhs):
        out = torch.mm(lhs, self.weight)
        return out

    def get_example_inputs(self):
        return (torch.randn(3, 4),), {}


@use_onert
@test_negative(expected_err="NNFW_STATUS_ERROR")
class SimpleMatmulConstLhsOnert(TestModuleBase):
    """ """

    def __init__(self):
Comment on lines +65 to +67

Contributor (suggested change): remove the empty docstring `""" """` so the class body starts directly with `def __init__(self):`.

Contributor: I think it would be good to describe what error is expected. NNFW_STATUS_ERROR is a bit ambiguous.

Author: That is how onert throws. It should match.

Contributor: Ah, I think using a docstring or comments is also enough.

        super().__init__()
        self.weight = torch.randn(3, 4)

    def forward(self, rhs):
        out = torch.mm(self.weight, rhs)
        return out

    def get_example_inputs(self):
        return (torch.randn(4, 5),), {}


@use_onert
class SimpleMatmulConstLhsOnertWithLinearConversion(TestModuleBase):
    def __init__(self):
        super().__init__()
        self.weight = torch.randn(3, 4)

    def forward(self, rhs):
        out = torch.mm(self.weight, rhs)
        return out

    def get_example_inputs(self):
        return (torch.randn(4, 5),), {}

    def get_compile_config(self):
        return CompileConfigV1(convert_lhs_const_mm_to_fc=True)
Comment on lines +92 to +93

Author: @glistening @seockho-kim Using this compile config will enable conversion of matmul ops whose LHS is a constant node.

Contributor: @dayo09 Could you improve it to handle bmm, too?

@tag.use_onert
class BmmTest(TestModuleBase):
    def __init__(self):
        super().__init__()
        self.weight = torch.randn(2, 3, 4)

    def forward(self, rhs):
        out = self.weight @ rhs
        return out

    def get_example_inputs(self):
        return (torch.randn(2, 4, 5),), {}

    def get_compile_config(self):
        return CompileConfigV1(convert_lhs_const_mm_to_fc=True)

Author: @seockho-kim The case above is not supported because matmul-to-FC conversion can be done only if the weight is 2-dim; the Circle FullyConnected operation assumes its weight to be rank 2.

Contributor: I'm sorry, I gave you the wrong example. I meant the bmm (batch=1) case.

Author: Let me add it in the next PR!
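For reference, a minimal pure-PyTorch sketch (illustrative only, not part of this PR) of the bmm(batch=1) case discussed above: a batch-1 bmm can be squeezed to a rank-2 mm, which is what would make it eligible for the rank-2 FullyConnected rewrite.

import torch

# bmm with batch == 1 is numerically equivalent to a rank-2 mm after
# squeezing the batch dimension, so it could in principle reuse the
# mm-to-FullyConnected conversion.
W = torch.randn(1, 3, 4)   # constant lhs with batch == 1
x = torch.randn(1, 4, 5)
bmm_out = torch.bmm(W, x)
mm_out = torch.mm(W.squeeze(0), x.squeeze(0)).unsqueeze(0)
assert torch.allclose(bmm_out, mm_out)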

45 changes: 0 additions & 45 deletions test/unit_test/serialize_test/operator/test_op_mm.py

This file was deleted.

2 changes: 2 additions & 0 deletions tico/config/v1.py
@@ -21,6 +21,8 @@
class CompileConfigV1(CompileConfigBase):
    legalize_causal_mask_value: bool = False
    remove_constant_input: bool = False
    convert_lhs_const_mm_to_fc: bool = False
    convert_rhs_const_mm_to_fc: bool = True
Comment on lines +24 to +25

Contributor: On second thought, just convert_const_mm_to_fc could be a simpler choice. Do you have any reason for choosing this design?

Author: rhs_const_mm_to_fc has no trade-off because the transpose is foldable into the constant, but lhs_const_mm_to_fc carries a potential latency trade-off. Therefore, the user needs to decide each case separately.
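To make the asymmetry concrete, here is a small pure-PyTorch sketch (illustrative, not part of the PR) of why the rhs-const rewrite is free while the lhs-const one is not:

import torch
import torch.nn.functional as F

x = torch.randn(3, 4)          # runtime activation
W = torch.randn(4, 5)          # constant rhs

# rhs-const: mm(x, W) -> linear(x, W.T). The inserted transpose acts on
# a constant, so the compiler can fold it into the weight offline.
W_folded = W.t().contiguous()  # computed once at compile time
assert torch.allclose(torch.mm(x, W), F.linear(x, W_folded))

# lhs-const: mm(W2, x2) -> linear(W2, x2.T). The transpose now acts on
# a runtime tensor and cannot be folded, hence the latency trade-off.
W2 = torch.randn(3, 4)         # constant lhs
x2 = torch.randn(4, 5)         # runtime activation
assert torch.allclose(torch.mm(W2, x2), F.linear(W2, x2.t()))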


    def get(self, name: str):
        return super().get(name)
200 changes: 200 additions & 0 deletions tico/passes/convert_matmul_to_linear.py
@@ -0,0 +1,200 @@
# Copyright (c) 2025 Samsung Electronics Co., Ltd. All Rights Reserved
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import List, Optional, TYPE_CHECKING

if TYPE_CHECKING:
    import torch.fx
import torch
from torch._export.utils import is_buffer, is_lifted_tensor_constant, is_param
from torch.export import ExportedProgram

from tico.utils import logging
from tico.utils.graph import create_node
from tico.utils.passes import PassBase, PassResult
from tico.utils.trace_decorators import trace_graph_diff_on_pass
from tico.utils.validate_args_kwargs import MatmulArgs


class Converter:  # type: ignore[empty-body]
    def __init__(self):
        super().__init__()

    def match(self, exported_program, node) -> bool:  # type: ignore[empty-body]
        return False

    def convert(self, exported_program, node) -> torch.fx.Node:  # type: ignore[empty-body]
        pass


class MatmulToLinearConverter(Converter):
    def __init__(self):
        super().__init__()

    def convert(self, exported_program, node) -> torch.fx.Node:
        graph_module = exported_program.graph_module
        graph = graph_module.graph

        mm_args = MatmulArgs(*node.args, **node.kwargs)  # type: ignore[arg-type]

        lhs = mm_args.input
        rhs = mm_args.other

        with graph.inserting_before(node):
            transpose_node = create_node(
                graph,
                torch.ops.aten.permute.default,
                args=(rhs, [1, 0]),
            )
            fc_node = create_node(
                graph,
                torch.ops.aten.linear.default,
                args=(lhs, transpose_node),
            )
        node.replace_all_uses_with(fc_node, propagate_meta=True)

        return fc_node


class RhsConstMatmulToLinearConverter(MatmulToLinearConverter):
    def __init__(self):
        super().__init__()

    def match(self, exported_program, node) -> bool:
        if not node.target == torch.ops.aten.mm.default:
            return False

        mm_args = MatmulArgs(*node.args, **node.kwargs)  # type: ignore[arg-type]

        rhs = mm_args.other
        if isinstance(rhs, torch.fx.Node):
            if is_lifted_tensor_constant(exported_program, rhs):
                return True
            elif is_param(exported_program, rhs):
                return True
            elif is_buffer(exported_program, rhs):
                return True
            else:
                return False
        return False

    def convert(self, exported_program, node) -> torch.fx.Node:
        return super().convert(exported_program, node)


class LhsConstMatmulToLinearConverter(MatmulToLinearConverter):
    def __init__(self):
        super().__init__()

    def match(self, exported_program, node) -> bool:
        if not node.target == torch.ops.aten.mm.default:
            return False

        mm_args = MatmulArgs(*node.args, **node.kwargs)
        lhs = mm_args.input
        if isinstance(lhs, torch.fx.Node):
            if is_lifted_tensor_constant(exported_program, lhs):
                return True
            elif is_param(exported_program, lhs):
                return True
            elif is_buffer(exported_program, lhs):
                return True
            else:
                return False
        return False

    def convert(self, exported_program, node) -> torch.fx.Node:
        return super().convert(exported_program, node)


@trace_graph_diff_on_pass
class ConvertMatmulToLinear(PassBase):
    """
    This pass converts matmul to linear selectively.

    How to select between `matmul` and `linear`?

    * Linear has better quantization accuracy (NPU backend)
      Due to the ONE compiler's quantization policy,
      FullyConnected (= Linear) uses per-channel quantization for the weight and per-tensor for the input.
      BatchMatmul (= matmul) uses per-tensor quantization for both rhs and lhs.
Inline comments on the quantization note:

Contributor: FYI, a new generation of NPU would support cwq for matmul.

Author: @jinevening Do you mean 3rd generation?

Contributor: Yes.

    * Matmul to Linear requires a Transpose, which may harm latency
      When RHS is constant, the additional transpose can be folded.

      [RHS non-const case]
      Constant folding cannot be performed.

        lhs   rhs (non-const)
         |     |
         |  transpose
         |     |
         -- linear --
              |
             out

      [RHS const case]
      Constant folding can be performed:

        lhs   rhs (const)           lhs   rhs (folded const)
         |     |                     |     |
         |  transpose                |     |
         |     |                     |     |
         -- linear --       -->      -- linear --
              |                           |
             out                         out


    enable_lhs_const: If true, convert matmul where LHS is a constant tensor. Default is False.
    enable_rhs_const: If true, convert matmul where RHS is a constant tensor. Default is True.
    """

    def __init__(
        self,
        enable_lhs_const: Optional[bool] = False,
        enable_rhs_const: Optional[bool] = True,
    ):
        super().__init__()
        self.converters: List[Converter] = []
        if enable_lhs_const:
            self.converters.append(LhsConstMatmulToLinearConverter())
        if enable_rhs_const:
            self.converters.append(RhsConstMatmulToLinearConverter())

    def call(self, exported_program: ExportedProgram) -> PassResult:
        logger = logging.getLogger(__name__)

        graph_module = exported_program.graph_module
        graph = graph_module.graph
        modified = False
        for node in graph.nodes:
            if not node.op == "call_function":
                continue

            for converter in self.converters:
                if not converter.match(exported_program, node):
                    continue

                new_node = converter.convert(exported_program, node)
                modified = True
                logger.debug(
                    f"{node.name} is replaced with {new_node.name} operator (permute + linear)"
                )
                continue

        graph.eliminate_dead_code()
        graph.lint()
        graph_module.recompile()

        return PassResult(modified)
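Below is a hedged usage sketch of applying the pass directly to an exported program; it assumes a tico checkout where this import path resolves, and that PassResult exposes the modified flag it is constructed with above.

import torch
from torch.export import export

from tico.passes.convert_matmul_to_linear import ConvertMatmulToLinear


class MmModule(torch.nn.Module):
    def __init__(self):
        super().__init__()
        # A registered parameter is a constant from the exported
        # program's point of view (is_param(...) returns True), so the
        # RhsConstMatmulToLinearConverter should match this mm node.
        self.weight = torch.nn.Parameter(torch.randn(4, 5))

    def forward(self, x):
        return torch.mm(x, self.weight)


ep = export(MmModule(), (torch.randn(3, 4),))
result = ConvertMatmulToLinear(enable_rhs_const=True).call(ep)
print(result.modified)  # expected True: mm rewritten to permute + linear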
2 changes: 1 addition & 1 deletion tico/passes/convert_to_relu6.py
@@ -172,7 +172,7 @@ def call(self, exported_program: ExportedProgram) -> PassResult:
                converter.convert(exported_program, node)
                modified = True
                logger.debug(f"{node.name} is replaced with ReLU6 operator")
-               break
+               continue

        graph.eliminate_dead_code()
        graph.lint()