From c7625cb8e72ce6fc21505d88b6be7863b790ae0f Mon Sep 17 00:00:00 2001
From: Dayoung Lee
Date: Fri, 12 Sep 2025 20:47:53 +0900
Subject: [PATCH 1/5] [passes] Add ConvertMatmulToLinear pass

Let's add a convert-matmul-to-linear pass.

This commit
- refactors mm serialization logic and adds a convert_matmul_to_linear pass
- introduces new CompileConfig attributes convert_lhs/rhs_const_mm_to_fc.

TICO-DCO-1.0-Signed-off-by: Dayoung Lee
---
 test/modules/op/mm.py                   |  64 +++++++
 tico/config/v1.py                       |   2 +
 tico/passes/convert_matmul_to_linear.py | 212 ++++++++++++++++++++++++
 tico/passes/convert_to_relu6.py         |   2 +-
 tico/serialize/operators/op_mm.py       | 110 +-----------
 tico/utils/convert.py                   |   5 +
 6 files changed, 286 insertions(+), 109 deletions(-)
 create mode 100644 tico/passes/convert_matmul_to_linear.py

diff --git a/test/modules/op/mm.py b/test/modules/op/mm.py
index c7253d57..6964e22e 100644
--- a/test/modules/op/mm.py
+++ b/test/modules/op/mm.py
@@ -14,7 +14,10 @@
 
 import torch
 
+from tico.config.v1 import CompileConfigV1
+
 from test.modules.base import TestModuleBase
+from test.utils.tag import test_negative, use_onert
 
 
 class SimpleMatmul(TestModuleBase):
@@ -27,3 +30,64 @@ def forward(self, lhs, rhs):
 
     def get_example_inputs(self):
         return (torch.randn(3, 4), torch.randn(4, 5)), {}
+
+
+class SimpleMatmulConstRhs(TestModuleBase):
+    def __init__(self):
+        super().__init__()
+        self.weight = torch.randn(4, 5)
+
+    def forward(self, lhs):
+        out = torch.mm(lhs, self.weight)
+        return out
+
+    def get_example_inputs(self):
+        return (torch.randn(3, 4),), {}
+
+
+@use_onert
+class SimpleMatmulConstRhsOnert(TestModuleBase):
+    def __init__(self):
+        super().__init__()
+        self.weight = torch.randn(4, 5)
+
+    def forward(self, lhs):
+        out = torch.mm(lhs, self.weight)
+        return out
+
+    def get_example_inputs(self):
+        return (torch.randn(3, 4),), {}
+
+
+@use_onert
+@test_negative(expected_err="NNFW_STATUS_ERROR")
+class SimpleMatmulConstLhsOnert(TestModuleBase):
+    """Matmul with a constant LHS is rejected by onert unless it is converted to Linear."""
+
+    def __init__(self):
+        super().__init__()
+        self.weight = torch.randn(3, 4)
+
+    def forward(self, rhs):
+        out = torch.mm(self.weight, rhs)
+        return out
+
+    def get_example_inputs(self):
+        return (torch.randn(4, 5),), {}
+
+
+@use_onert
+class SimpleMatmulConstLhsOnertWithLinearConversion(TestModuleBase):
+    def __init__(self):
+        super().__init__()
+        self.weight = torch.randn(3, 4)
+
+    def forward(self, rhs):
+        out = torch.mm(self.weight, rhs)
+        return out
+
+    def get_example_inputs(self):
+        return (torch.randn(4, 5),), {}
+
+    def get_compile_config(self):
+        return CompileConfigV1(convert_lhs_const_mm_to_fc=True)
diff --git a/tico/config/v1.py b/tico/config/v1.py
index 3ce340b3..7962a9c6 100644
--- a/tico/config/v1.py
+++ b/tico/config/v1.py
@@ -21,6 +21,8 @@ class CompileConfigV1(CompileConfigBase):
     legalize_causal_mask_value: bool = False
     remove_constant_input: bool = False
+    convert_lhs_const_mm_to_fc: bool = False
+    convert_rhs_const_mm_to_fc: bool = True
 
     def get(self, name: str):
         return super().get(name)
diff --git a/tico/passes/convert_matmul_to_linear.py b/tico/passes/convert_matmul_to_linear.py
new file mode 100644
index 00000000..373afb9e
--- /dev/null
+++ b/tico/passes/convert_matmul_to_linear.py
@@ -0,0 +1,212 @@
+# Copyright (c) 2025 Samsung Electronics Co., Ltd. All Rights Reserved
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import List, Optional, TYPE_CHECKING
+
+if TYPE_CHECKING:
+    import torch.fx
+import torch
+from torch._export.utils import is_buffer, is_lifted_tensor_constant, is_param
+from torch.export import ExportedProgram
+
+from tico.utils import logging
+from tico.utils.graph import create_node
+from tico.utils.passes import PassBase, PassResult
+from tico.utils.trace_decorators import trace_graph_diff_on_pass
+from tico.utils.validate_args_kwargs import MatmulArgs
+
+
+class Converter:  # type: ignore[empty-body]
+    def __init__(self):
+        super().__init__()
+
+    def match(self, exported_program, node) -> bool:  # type: ignore[empty-body]
+        return False
+
+    def convert(self, exported_program, node) -> torch.fx.Node:  # type: ignore[empty-body]
+        pass
+
+
+class ConvertRhsConstMatmulToLinear(Converter):
+    def __init__(self):
+        super().__init__()
+
+    def match(self, exported_program, node) -> bool:
+        if not node.target == torch.ops.aten.mm.default:
+            return False
+
+        mm_args = MatmulArgs(*node.args, **node.kwargs)  # type: ignore[arg-type]
+
+        rhs = mm_args.other
+        if isinstance(rhs, torch.fx.Node):
+            if is_lifted_tensor_constant(exported_program, rhs):
+                return True
+            elif is_param(exported_program, rhs):
+                return True
+            elif is_buffer(exported_program, rhs):
+                return True
+            else:
+                return False
+        return False
+
+    def convert(self, exported_program, node) -> torch.fx.Node:
+        graph_module = exported_program.graph_module
+        graph = graph_module.graph
+
+        mm_args = MatmulArgs(*node.args, **node.kwargs)  # type: ignore[arg-type]
+
+        lhs = mm_args.input
+        rhs = mm_args.other
+
+        with graph.inserting_before(node):
+            transpose_node = create_node(
+                graph,
+                torch.ops.aten.permute.default,
+                args=(rhs, [1, 0]),
+            )
+            fc_node = create_node(
+                graph,
+                torch.ops.aten.linear.default,
+                args=(lhs, transpose_node),
+            )
+            node.replace_all_uses_with(fc_node, propagate_meta=True)
+
+        return fc_node
+
+
+class ConvertLhsConstMatmulToLinear(Converter):
+    def __init__(self):
+        super().__init__()
+
+    def match(self, exported_program, node) -> bool:
+        if not node.target == torch.ops.aten.mm.default:
+            return False
+
+        mm_args = MatmulArgs(*node.args, **node.kwargs)
+        lhs = mm_args.input
+        if isinstance(lhs, torch.fx.Node):
+            if is_lifted_tensor_constant(exported_program, lhs):
+                return True
+            elif is_param(exported_program, lhs):
+                return True
+            elif is_buffer(exported_program, lhs):
+                return True
+            else:
+                return False
+
+    def convert(self, exported_program, node) -> torch.fx.Node:
+        graph_module = exported_program.graph_module
+        graph = graph_module.graph
+
+        mm_args = MatmulArgs(*node.args, **node.kwargs)  # type: ignore[arg-type]
+
+        lhs = mm_args.input
+        rhs = mm_args.other
+
+        with graph.inserting_before(node):
+            transpose_node = create_node(
+                graph,
+                torch.ops.aten.permute.default,
+                args=(rhs, [1, 0]),
+            )
+            fc_node = create_node(
+                graph,
+                torch.ops.aten.linear.default,
+                args=(lhs, transpose_node),
+            )
+            node.replace_all_uses_with(fc_node, propagate_meta=True)
+
+        return fc_node
+
+
+@trace_graph_diff_on_pass
+class ConvertMatmulToLinear(PassBase):
+    """
+    This pass converts matmul to linear selectively.
+
+    How to select between `matmul` and `linear`?
+
+    * Linear has better quantization accuracy (NPU backend).
+      Due to the ONE compiler's quantization policy:
+      FullyConnected(=Linear) uses per-channel quantization for weight and per-tensor for input.
+      BatchMatmul(=matmul) uses per-tensor quantization for both rhs and lhs.
+
+    * Matmul to Linear requires Transpose, which may harm latency.
+      When RHS is constant, the additional transpose can be folded.
+
+    [RHS non-const case]
+    Constant folding cannot be performed.
+
+       lhs     rhs (non-const)
+        |       |
+        |    transpose
+        |       |
+        -- linear --
+             |
+            out
+
+    [RHS const case]
+    Constant folding can be performed as below.
+
+       lhs     rhs (const)             lhs     rhs (folded const)
+        |       |                       |       |
+        |    transpose                  |       |
+        |       |                       |       |
+        -- linear --         -->        -- linear --
+             |                               |
+            out                             out
+
+
+    enable_lhs_const: If true, also convert matmul where LHS is constant tensor. Default is False.
+    enable_rhs_const: If true, also convert matmul where RHS is constant tensor. Default is True.
+    """
+
+    def __init__(
+        self,
+        enable_lhs_const: Optional[bool] = False,
+        enable_rhs_const: Optional[bool] = True,
+    ):
+        super().__init__()
+        self.converters: List[Converter] = []
+        if enable_lhs_const:
+            self.converters.append(ConvertLhsConstMatmulToLinear())
+        if enable_rhs_const:
+            self.converters.append(ConvertRhsConstMatmulToLinear())
+
+    def call(self, exported_program: ExportedProgram) -> PassResult:
+        logger = logging.getLogger(__name__)
+
+        graph_module = exported_program.graph_module
+        graph = graph_module.graph
+        modified = False
+        for node in graph.nodes:
+            if not node.op == "call_function":
+                continue
+
+            for converter in self.converters:
+                if not converter.match(exported_program, node):
+                    continue
+
+                new_node = converter.convert(exported_program, node)
+                modified = True
+                logger.debug(
+                    f"{node.name} is replaced with {new_node.name} operator (permute + linear)"
+                )
+                continue
+
+        graph.eliminate_dead_code()
+        graph.lint()
+        graph_module.recompile()
+
+        return PassResult(modified)
diff --git a/tico/passes/convert_to_relu6.py b/tico/passes/convert_to_relu6.py
index c155fb71..76d2a576 100644
--- a/tico/passes/convert_to_relu6.py
+++ b/tico/passes/convert_to_relu6.py
@@ -172,7 +172,7 @@ def call(self, exported_program: ExportedProgram) -> PassResult:
                 converter.convert(exported_program, node)
                 modified = True
                 logger.debug(f"{node.name} is replaced with ReLU6 operator")
-                break
+                continue
 
         graph.eliminate_dead_code()
         graph.lint()
diff --git a/tico/serialize/operators/op_mm.py b/tico/serialize/operators/op_mm.py
index 90e98872..00d1ea47 100644
--- a/tico/serialize/operators/op_mm.py
+++ b/tico/serialize/operators/op_mm.py
@@ -30,7 +30,7 @@
 @register_node_visitor
 class MatmulDefaultVisitor(NodeVisitor):
     """
-    Convert matmul to equavalent BatchMatMul or FullyConnected with Transpose.
+    Convert matmul to equivalent BatchMatMul
     """
 
     target: List[torch._ops.OpOverload] = [torch.ops.aten.mm.default]
@@ -57,112 +57,9 @@ def set_bmm_option(operator):
 
         return operator
 
-    def define_transpose_node(self, inputs, outputs) -> circle.Operator.OperatorT:
-        def set_transpose_option(operator):
-            operator.builtinOptionsType = (
-                circle.BuiltinOptions.BuiltinOptions.TransposeOptions
-            )
-            option = circle.TransposeOptions.TransposeOptionsT()
-            operator.builtinOptions = option
-
-        transpose_op_index = get_op_index(
-            circle.BuiltinOperator.BuiltinOperator.TRANSPOSE, self._op_codes
-        )
-        operator = create_builtin_operator(
-            self.graph, transpose_op_index, inputs, outputs
-        )
-        set_transpose_option(operator)
-        return operator
-
-    def define_fc_node(self, inputs, outputs) -> circle.Operator.OperatorT:
-        def set_fc_option(operator):
-            operator.builtinOptionsType = (
-                circle.BuiltinOptions.BuiltinOptions.FullyConnectedOptions
-            )
-            option = circle.FullyConnectedOptions.FullyConnectedOptionsT()
-
-            option.fusedActivationFunction = (
-                circle.ActivationFunctionType.ActivationFunctionType.NONE
-            )
-            option.weightsFormat = (
-                circle.FullyConnectedOptionsWeightsFormat.FullyConnectedOptionsWeightsFormat.DEFAULT
-            )
-            option.keepNumDims = False
-            option.asymmetricQuantizeInputs = False
-            option.quantizedBiasType = circle.TensorType.TensorType.FLOAT32
-
-            operator.builtinOptions = option
-
-        fc_op_index = get_op_index(
-            circle.BuiltinOperator.BuiltinOperator.FULLY_CONNECTED, self._op_codes
-        )
-        operator = create_builtin_operator(self.graph, fc_op_index, inputs, outputs)
-        set_fc_option(operator)
-        return operator
-
-    """
-    Define FullyConnnected with Tranpose operator.
-    Note that those sets of operators are equivalent.
-    (1) Matmul
-        matmul( lhs[H, K], rhs[K, W'] ) -> output(H, W')
-
-    (2) Transpose + FullyConneccted
-        transpose( rhs[K, W'] ) -> trs_output[W', K]
-        fullyconnected( lhs[H, K], trs_output[W', K] ) -> output(H, W')
-    """
-
-    def define_fc_with_transpose(
-        self, node, inputs, outputs
-    ) -> circle.Operator.OperatorT:
-        lhs, rhs = inputs
-
-        # get transpose shape
-        rhs_tid: int = self.graph.get_tid_registered(rhs)
-        rhs_tensor: circle.Tensor.TensorT = self.graph.tensors[rhs_tid]
-        rhs_name: str = rhs.name
-        rhs_type: int = rhs_tensor.type
-        rhs_shape: List[int] = rhs_tensor.shape
-        assert len(rhs_shape) == 2, len(rhs_shape)
-        rhs_shape_transpose = [rhs_shape[1], rhs_shape[0]]
-
-        # create transpose output tensor
-        trs_output = self.graph.add_tensor_from_scratch(
-            prefix=f"{rhs_name}_transposed_output",
-            shape=rhs_shape_transpose,
-            shape_signature=None,
-            dtype=rhs_type,
-            source_node=node,
-        )
-        trs_perm = self.graph.add_const_tensor(data=[1, 0], source_node=node)
-        trs_operator = self.define_transpose_node([rhs, trs_perm], [trs_output])
-        self.graph.add_operator(trs_operator)
-
-        # define fc node
-        fc_input = lhs
-        fc_weight = trs_output
-        fc_shape = [fc_weight.shape[0]]
-        fc_bias = self.graph.add_const_tensor(
-            data=[0.0] * fc_shape[0], source_node=node
-        )
-
-        operator = self.define_fc_node([fc_input, fc_weight, fc_bias], outputs)
-
-        return operator
-
     def define_node(
         self, node: torch.fx.Node, prior_latency=True
     ) -> circle.Operator.OperatorT:
-        """
-        NOTE: Possibility of accuracy-latency trade-off
-        From ONE compiler's perspective:
-        - BMM uses per-tensor quantization for both rhs and lhs.
-        - FC uses per-channel quantization for weight and per-tensor for input.
-        Thus, FC is better in terms of accuracy.
-        FC necessarily involves an additional transpose operation to be identical with mm.
-        If transposed operand is const, it can be optimized by constant folding.
-        Thus, convert FC only if tranpose can be folded.
-        TODO set prior_latency outside
-        """
         args = MatmulArgs(*node.args, **node.kwargs)  # type: ignore[arg-type]
         input = args.input
         other = args.other
@@ -170,9 +67,6 @@ def define_node(
         inputs = [input, other]
         outputs = [node]
 
-        if not is_const(other) and prior_latency:
-            operator = self.define_bmm_node(inputs, outputs)
-        else:
-            operator = self.define_fc_with_transpose(node, inputs, outputs)
+        operator = self.define_bmm_node(inputs, outputs)
 
         return operator
diff --git a/tico/utils/convert.py b/tico/utils/convert.py
index 1e61557f..8368986b 100644
--- a/tico/utils/convert.py
+++ b/tico/utils/convert.py
@@ -40,6 +40,7 @@
 from tico.passes.const_prop_pass import ConstPropPass
 from tico.passes.convert_conv1d_to_conv2d import ConvertConv1dToConv2d
 from tico.passes.convert_layout_op_to_reshape import ConvertLayoutOpToReshape
+from tico.passes.convert_matmul_to_linear import ConvertMatmulToLinear
 from tico.passes.convert_repeat_to_expand_copy import ConvertRepeatToExpandCopy
 from tico.passes.convert_to_relu6 import ConvertToReLU6
 from tico.passes.decompose_addmm import DecomposeAddmm
@@ -249,6 +250,10 @@ def convert_exported_module_to_circle(
         ConstPropPass(),
         SegmentIndexSelectConst(),
         LegalizeCausalMaskValue(enabled=config.get("legalize_causal_mask_value")),
+        ConvertMatmulToLinear(
+            enable_lhs_const=config.get("convert_lhs_const_mm_to_fc"),
+            enable_rhs_const=config.get("convert_rhs_const_mm_to_fc"),
+        ),
         LowerToResizeNearestNeighbor(),
         LegalizePreDefinedLayoutOperators(),
         LowerPow2ToMul(),

From 831aa031307236cb71cfb65addcc308897a50e4d Mon Sep 17 00:00:00 2001
From: Dayoung Lee
Date: Wed, 17 Sep 2025 16:39:15 +0900
Subject: [PATCH 2/5] Apply feedback

---
 tico/passes/convert_matmul_to_linear.py | 71 ++++++++++---------------
 1 file changed, 29 insertions(+), 42 deletions(-)

diff --git a/tico/passes/convert_matmul_to_linear.py b/tico/passes/convert_matmul_to_linear.py
index 373afb9e..676ef878 100644
--- a/tico/passes/convert_matmul_to_linear.py
+++ b/tico/passes/convert_matmul_to_linear.py
@@ -38,28 +38,10 @@ def convert(self, exported_program, node) -> torch.fx.Node:  # type: ignore[empt
         pass
 
 
-class ConvertRhsConstMatmulToLinear(Converter):
+class ConvertMatmulToLinear(Converter):
     def __init__(self):
         super().__init__()
 
-    def match(self, exported_program, node) -> bool:
-        if not node.target == torch.ops.aten.mm.default:
-            return False
-
-        mm_args = MatmulArgs(*node.args, **node.kwargs)  # type: ignore[arg-type]
-
-        rhs = mm_args.other
-        if isinstance(rhs, torch.fx.Node):
-            if is_lifted_tensor_constant(exported_program, rhs):
-                return True
-            elif is_param(exported_program, rhs):
-                return True
-            elif is_buffer(exported_program, rhs):
-                return True
-            else:
-                return False
-        return False
-
     def convert(self, exported_program, node) -> torch.fx.Node:
         graph_module = exported_program.graph_module
         graph = graph_module.graph
@@ -85,7 +67,33 @@ def convert(self, exported_program, node) -> torch.fx.Node:
         return fc_node
 
 
-class ConvertLhsConstMatmulToLinear(Converter):
+class ConvertRhsConstMatmulToLinear(ConvertMatmulToLinear):
+    def __init__(self):
+        super().__init__()
+
+    def match(self, exported_program, node) -> bool:
+        if not node.target == torch.ops.aten.mm.default:
+            return False
+
+        mm_args = MatmulArgs(*node.args, **node.kwargs)  # type: ignore[arg-type]
+
+        rhs = mm_args.other
+        if isinstance(rhs, torch.fx.Node):
+            if is_lifted_tensor_constant(exported_program, rhs):
+                return True
+            elif is_param(exported_program, rhs):
+                return True
+            elif is_buffer(exported_program, rhs):
+                return True
+            else:
+                return False
+        return False
+
+    def convert(self, exported_program, node) -> torch.fx.Node:
+        return super().convert(exported_program, node)
+
+
+class ConvertLhsConstMatmulToLinear(ConvertMatmulToLinear):
     def __init__(self):
         super().__init__()
 
@@ -106,28 +114,7 @@ def match(self, exported_program, node) -> bool:
         return False
 
     def convert(self, exported_program, node) -> torch.fx.Node:
-        graph_module = exported_program.graph_module
-        graph = graph_module.graph
-
-        mm_args = MatmulArgs(*node.args, **node.kwargs)  # type: ignore[arg-type]
-
-        lhs = mm_args.input
-        rhs = mm_args.other
-
-        with graph.inserting_before(node):
-            transpose_node = create_node(
-                graph,
-                torch.ops.aten.permute.default,
-                args=(rhs, [1, 0]),
-            )
-            fc_node = create_node(
-                graph,
-                torch.ops.aten.linear.default,
-                args=(lhs, transpose_node),
-            )
-            node.replace_all_uses_with(fc_node, propagate_meta=True)
-
-        return fc_node
+        return super().convert(exported_program, node)
 
 
 @trace_graph_diff_on_pass

From d6bc302d5e07d76cf19af18b5fd11c0a2f7ad1d7 Mon Sep 17 00:00:00 2001
From: Dayoung Lee
Date: Mon, 15 Sep 2025 14:50:36 +0900
Subject: [PATCH 3/5] Update tico/passes/convert_matmul_to_linear.py

Co-authored-by: Hyukjin Jeong
---
 tico/passes/convert_matmul_to_linear.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/tico/passes/convert_matmul_to_linear.py b/tico/passes/convert_matmul_to_linear.py
index 676ef878..0e18f19c 100644
--- a/tico/passes/convert_matmul_to_linear.py
+++ b/tico/passes/convert_matmul_to_linear.py
@@ -155,8 +155,8 @@ class ConvertMatmulToLinear(PassBase):
             out                             out
 
 
-    enable_lhs_const: If true, also convert matmul where LHS is constant tensor. Default is False.
-    enable_rhs_const: If true, also convert matmul where RHS is constant tensor. Default is True.
+    enable_lhs_const: If true, convert matmul where LHS is constant tensor. Default is False.
+    enable_rhs_const: If true, convert matmul where RHS is constant tensor. Default is True.
     """
 
     def __init__(

From ec94c768908ec437f7ca6378b81087e062cb163b Mon Sep 17 00:00:00 2001
From: Dayoung Lee
Date: Wed, 17 Sep 2025 16:52:17 +0900
Subject: [PATCH 4/5] Remove unused feature

---
 .../serialize_test/operator/test_op_mm.py | 45 -------------------
 tico/serialize/operators/op_mm.py         | 41 +++++++----------
 2 files changed, 15 insertions(+), 71 deletions(-)
 delete mode 100644 test/unit_test/serialize_test/operator/test_op_mm.py

diff --git a/test/unit_test/serialize_test/operator/test_op_mm.py b/test/unit_test/serialize_test/operator/test_op_mm.py
deleted file mode 100644
index 38dfe992..00000000
--- a/test/unit_test/serialize_test/operator/test_op_mm.py
+++ /dev/null
@@ -1,45 +0,0 @@
-# Copyright (c) 2025 Samsung Electronics Co., Ltd. All Rights Reserved
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-import unittest
-
-import torch
-from circle_schema import circle
-
-from test.modules.op.mm import SimpleMatmul
-from test.unit_test.serialize_test.operator.fixture import SingleOpGraphFixture
-
-
-class MatmulVisitorTest(unittest.TestCase):
-    def test_op_mm_to_fullyconnected(self):
-        bp = SingleOpGraphFixture(SimpleMatmul(), torch.ops.aten.mm.default)
-        mmVisitor = bp.target_visitor()
-        res = mmVisitor.define_node(bp.target_node(), prior_latency=False)
-
-        self.assertTrue(
-            isinstance(
-                res.builtinOptions, circle.FullyConnectedOptions.FullyConnectedOptionsT
-            )
-        )
-
-    def test_op_mm_to_bmm(self):
-        bp = SingleOpGraphFixture(SimpleMatmul(), torch.ops.aten.mm.default)
-        mmVisitor = bp.target_visitor()
-        res = mmVisitor.define_node(bp.target_node(), prior_latency=True)
-
-        self.assertTrue(
-            isinstance(
-                res.builtinOptions, circle.BatchMatMulOptions.BatchMatMulOptionsT
-            )
-        )
diff --git a/tico/serialize/operators/op_mm.py b/tico/serialize/operators/op_mm.py
index 00d1ea47..c4fe4fff 100644
--- a/tico/serialize/operators/op_mm.py
+++ b/tico/serialize/operators/op_mm.py
@@ -20,7 +20,7 @@
 import torch
 from circle_schema import circle
 
-from tico.serialize.circle_graph import CircleSubgraph, is_const
+from tico.serialize.circle_graph import CircleSubgraph
 from tico.serialize.operators.hashable_opcode import OpCode
 from tico.serialize.operators.node_visitor import NodeVisitor, register_node_visitor
 from tico.serialize.operators.utils import create_builtin_operator, get_op_index
@@ -28,9 +28,9 @@
 
 
 @register_node_visitor
-class MatmulDefaultVisitor(NodeVisitor):
+class MatmulVisitor(NodeVisitor):
     """
-    Convert matmul to equivalent BatchMatMul
+    Convert matmul to Circle BatchMatMul
     """
 
     target: List[torch._ops.OpOverload] = [torch.ops.aten.mm.default]
@@ -38,28 +38,7 @@ class MatmulDefaultVisitor(NodeVisitor):
     def __init__(self, op_codes: Dict[OpCode, int], graph: CircleSubgraph):
         super().__init__(op_codes, graph)
 
-    # NOTE: Matmul is equivalent to Batch MatMul (batch=1)
-    def define_bmm_node(self, inputs, outputs) -> circle.Operator.OperatorT:
-        def set_bmm_option(operator):
-            operator.builtinOptionsType = (
-                circle.BuiltinOptions.BuiltinOptions.BatchMatMulOptions
-            )
-            option = circle.BatchMatMulOptions.BatchMatMulOptionsT()
-            option.adjointLhs, option.adjointRhs = False, False
-            option.asymmetricQuantizeInputs = False
-            operator.builtinOptions = option
-
-        op_index = get_op_index(
-            circle.BuiltinOperator.BuiltinOperator.BATCH_MATMUL, self._op_codes
-        )
-        operator = create_builtin_operator(self.graph, op_index, inputs, outputs)
-        set_bmm_option(operator)
-
-        return operator
-
-    def define_node(
-        self, node: torch.fx.Node, prior_latency=True
-    ) -> circle.Operator.OperatorT:
+    def define_node(self, node: torch.fx.Node) -> circle.Operator.OperatorT:
         args = MatmulArgs(*node.args, **node.kwargs)  # type: ignore[arg-type]
         input = args.input
         other = args.other
@@ -67,6 +46,16 @@ def define_node(
         inputs = [input, other]
         outputs = [node]
 
-        operator = self.define_bmm_node(inputs, outputs)
+        op_index = get_op_index(
+            circle.BuiltinOperator.BuiltinOperator.BATCH_MATMUL, self._op_codes
+        )
+        operator = create_builtin_operator(self.graph, op_index, inputs, outputs)
+        operator.builtinOptionsType = (
+            circle.BuiltinOptions.BuiltinOptions.BatchMatMulOptions
+        )
+        option = circle.BatchMatMulOptions.BatchMatMulOptionsT()
+        option.adjointLhs, option.adjointRhs = False, False
+        option.asymmetricQuantizeInputs = False
+        operator.builtinOptions = option
 
         return operator

From 4588a1b7ee4c6897ede0754b5e344ae630bb9748 Mon Sep 17 00:00:00 2001
From: Dayoung Lee
Date: Wed, 17 Sep 2025 16:57:11 +0900
Subject: [PATCH 5/5] Fix

---
 tico/passes/convert_matmul_to_linear.py | 11 ++++++-----
 1 file changed, 6 insertions(+), 5 deletions(-)

diff --git a/tico/passes/convert_matmul_to_linear.py b/tico/passes/convert_matmul_to_linear.py
index 0e18f19c..4070fae6 100644
--- a/tico/passes/convert_matmul_to_linear.py
+++ b/tico/passes/convert_matmul_to_linear.py
@@ -38,7 +38,7 @@ def convert(self, exported_program, node) -> torch.fx.Node:  # type: ignore[empt
         pass
 
 
-class ConvertMatmulToLinear(Converter):
+class MatmulToLinearConverter(Converter):
     def __init__(self):
         super().__init__()
 
@@ -67,7 +67,7 @@ def convert(self, exported_program, node) -> torch.fx.Node:
         return fc_node
 
 
-class ConvertRhsConstMatmulToLinear(ConvertMatmulToLinear):
+class RhsConstMatmulToLinearConverter(MatmulToLinearConverter):
     def __init__(self):
         super().__init__()
 
@@ -93,7 +93,7 @@ def convert(self, exported_program, node) -> torch.fx.Node:
         return super().convert(exported_program, node)
 
 
-class ConvertLhsConstMatmulToLinear(ConvertMatmulToLinear):
+class LhsConstMatmulToLinearConverter(MatmulToLinearConverter):
     def __init__(self):
         super().__init__()
 
@@ -112,6 +112,7 @@ def match(self, exported_program, node) -> bool:
                 return True
             else:
                 return False
+        return False
 
     def convert(self, exported_program, node) -> torch.fx.Node:
         return super().convert(exported_program, node)
@@ -167,9 +168,9 @@ def __init__(
         super().__init__()
         self.converters: List[Converter] = []
         if enable_lhs_const:
-            self.converters.append(ConvertLhsConstMatmulToLinear())
+            self.converters.append(LhsConstMatmulToLinearConverter())
         if enable_rhs_const:
-            self.converters.append(ConvertRhsConstMatmulToLinear())
+            self.converters.append(RhsConstMatmulToLinearConverter())
 
     def call(self, exported_program: ExportedProgram) -> PassResult:
         logger = logging.getLogger(__name__)
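
Below is a minimal usage sketch of the two flags this series introduces; it is not part of the patch. It mirrors the new SimpleMatmulConstLhsOnertWithLinearConversion test but drives the pass directly through its call() entry point instead of the full pipeline. The toy ConstLhsMatmul module and the standalone pass invocation are illustrative assumptions; in normal use the flags are consumed inside convert_exported_module_to_circle via CompileConfigV1, as shown in tico/utils/convert.py above.

import torch
from torch.export import export

from tico.config.v1 import CompileConfigV1
from tico.passes.convert_matmul_to_linear import ConvertMatmulToLinear


class ConstLhsMatmul(torch.nn.Module):
    """Toy module (hypothetical): mm with a constant LHS, like the negative onert test."""

    def __init__(self):
        super().__init__()
        self.register_buffer("weight", torch.randn(3, 4))

    def forward(self, rhs):
        return torch.mm(self.weight, rhs)


# Enable the LHS-const conversion; the RHS-const conversion is on by default.
config = CompileConfigV1(convert_lhs_const_mm_to_fc=True)
ep = export(ConstLhsMatmul(), (torch.randn(4, 5),))

# Build the pass the same way convert_exported_module_to_circle wires it up.
matmul_to_linear = ConvertMatmulToLinear(
    enable_lhs_const=config.get("convert_lhs_const_mm_to_fc"),
    enable_rhs_const=config.get("convert_rhs_const_mm_to_fc"),
)
matmul_to_linear.call(ep)  # rewrites aten.mm into aten.permute + aten.linear

print(ep.graph_module.graph)  # the mm node should now be gone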