From 9573413129a32622abb2e0ee6e9b0cd386012076 Mon Sep 17 00:00:00 2001
From: Sanggyu Lee
Date: Tue, 16 Sep 2025 09:42:48 +0900
Subject: [PATCH] [serialize] prefer fc over bmm on const lhs

If the lhs of a matmul is constant, emit FullyConnected instead of
BatchMatMul: the constant lhs becomes the FC weight, rhs becomes the
FC input, and a Transpose of the last two dimensions of the FC output
produces the final result.

TICO-DCO-1.0-Signed-off-by: Sanggyu Lee
---
 tico/serialize/operators/op_bmm.py | 105 ++++++++++++++++++++++++++++++++----
 1 file changed, 92 insertions(+), 13 deletions(-)

diff --git a/tico/serialize/operators/op_bmm.py b/tico/serialize/operators/op_bmm.py
index 7ab7e480..c0564d79 100644
--- a/tico/serialize/operators/op_bmm.py
+++ b/tico/serialize/operators/op_bmm.py
@@ -34,6 +34,32 @@ class BatchMatmulVisitor(NodeVisitor):
     def __init__(self, op_codes: Dict[OpCode, int], graph: CircleSubgraph):
         super().__init__(op_codes, graph)
 
+    def define_fc_node(self, inputs, outputs) -> circle.Operator.OperatorT:
+        def set_fc_option(operator):
+            operator.builtinOptionsType = (
+                circle.BuiltinOptions.BuiltinOptions.FullyConnectedOptions
+            )
+            option = circle.FullyConnectedOptions.FullyConnectedOptionsT()
+
+            option.fusedActivationFunction = (
+                circle.ActivationFunctionType.ActivationFunctionType.NONE
+            )
+            option.weightsFormat = (
+                circle.FullyConnectedOptionsWeightsFormat.FullyConnectedOptionsWeightsFormat.DEFAULT
+            )
+            option.keepNumDims = False
+            option.asymmetricQuantizeInputs = False
+            option.quantizedBiasType = circle.TensorType.TensorType.FLOAT32
+
+            operator.builtinOptions = option
+
+        fc_op_index = get_op_index(
+            circle.BuiltinOperator.BuiltinOperator.FULLY_CONNECTED, self._op_codes
+        )
+        operator = create_builtin_operator(self.graph, fc_op_index, inputs, outputs)
+        set_fc_option(operator)
+        return operator
+
     def define_node(
         self,
         node: torch.fx.Node,
@@ -42,21 +68,74 @@ def define_node(
         input = args.input
         mat2 = args.mat2
 
-        op_index = get_op_index(
-            circle.BuiltinOperator.BuiltinOperator.BATCH_MATMUL, self._op_codes
+        is_const_tensor = lambda n: (
+            n.op == "get_attr"
+            or (
+                n.op == "placeholder"
+                and isinstance(n.meta.get("val", None), torch.Tensor)
+                and not n.meta["val"].requires_grad
+            )
         )
-        inputs = [input, mat2]
-        outputs = [node]
+        lhs, rhs = input, mat2
+        is_const_lhs = is_const_tensor(lhs)
 
-        operator = create_builtin_operator(self.graph, op_index, inputs, outputs)
+        if is_const_lhs:
+            fc_index = get_op_index(
+                circle.BuiltinOperator.BuiltinOperator.FULLY_CONNECTED,
+                self._op_codes,
+            )
 
-        # Op-specific option
-        operator.builtinOptionsType = (
-            circle.BuiltinOptions.BuiltinOptions.BatchMatMulOptions
-        )
-        option = circle.BatchMatMulOptions.BatchMatMulOptionsT()
-        option.adjointLhs, option.adjointRhs = False, False
-        operator.builtinOptions = option
+            rhs_tid = self.graph.get_tid_registered(rhs)
+            rhs_tensor: circle.Tensor.TensorT = self.graph.tensors[rhs_tid]
+            rhs_shape = list(rhs_tensor.shape)  # [..., batch, in_features]
+            rhs_dtype = rhs_tensor.type
 
-        return operator
+            # lhs : weight, shape = [..., out_features, in_features]
+            lhs_tid = self.graph.get_tid_registered(lhs)
+            lhs_tensor: circle.Tensor.TensorT = self.graph.tensors[lhs_tid]
+            lhs_shape = list(lhs_tensor.shape)
+            out_features = lhs_shape[-2]
+            fc_out_shape = rhs_shape[:-1] + [out_features]
+            fc_bias = self.graph.add_const_tensor(data=[0.0], source_node=node)
+            fc_out = self.graph.add_tensor_from_scratch(
+                prefix=f"{node.name}_fc_out",
+                shape=fc_out_shape,
+                shape_signature=fc_out_shape,
+                dtype=rhs_dtype,
+            )
+
+            fc_inputs = [rhs, lhs, fc_bias]  # order: [input, weight, bias]
+            fc_outputs = [fc_out]
+            fc_op = self.define_fc_node(fc_inputs, fc_outputs)
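+            # Circle's FULLY_CONNECTED computes input @ weight^T, so fc_out is
+            # rhs @ lhs^T; the TRANSPOSE emitted below swaps its last two dims.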
+            self.graph.add_operator(fc_op)
+
+            trs_index = get_op_index(
+                circle.BuiltinOperator.BuiltinOperator.TRANSPOSE,
+                self._op_codes,
+            )
+
+            perm = list(range(len(fc_out.shape)))
+            perm[-2], perm[-1] = perm[-1], perm[-2]
+            perm_tensor = self.graph.add_const_tensor(
+                data=torch.tensor(perm, dtype=torch.int32),  # to prevent int64
+            )
+
+            trs_inputs = [fc_out, perm_tensor]
+            trs_outputs = [node]
+            trs_op = create_builtin_operator(
+                self.graph, trs_index, trs_inputs, trs_outputs
+            )
+
+            return trs_op
+
+        bmm_index = get_op_index(
+            circle.BuiltinOperator.BuiltinOperator.BATCH_MATMUL,
+            self._op_codes,
+        )
+        inputs = [lhs, rhs]
+        outputs = [node]
+        op = create_builtin_operator(self.graph, bmm_index, inputs, outputs)
+        return op