103 changes: 90 additions & 13 deletions tico/serialize/operators/op_bmm.py
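Summary: when the lhs of aten.bmm is a compile-time constant (a get_attr node or a non-trainable placeholder), the visitor now emits a FULLY_CONNECTED operator (constant lhs as weight, zero bias) followed by a TRANSPOSE that swaps the last two dimensions; all other cases still emit BATCH_MATMUL.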
@@ -34,6 +34,32 @@ class BatchMatmulVisitor(NodeVisitor):
     def __init__(self, op_codes: Dict[OpCode, int], graph: CircleSubgraph):
         super().__init__(op_codes, graph)

+    def define_fc_node(self, inputs, outputs) -> circle.Operator.OperatorT:
+        def set_fc_option(operator):
+            operator.builtinOptionsType = (
+                circle.BuiltinOptions.BuiltinOptions.FullyConnectedOptions
+            )
+            option = circle.FullyConnectedOptions.FullyConnectedOptionsT()
+
+            option.fusedActivationFunction = (
+                circle.ActivationFunctionType.ActivationFunctionType.NONE
+            )
+            option.weightsFormat = (
+                circle.FullyConnectedOptionsWeightsFormat.FullyConnectedOptionsWeightsFormat.DEFAULT
+            )
+            option.keepNumDims = False
+            option.asymmetricQuantizeInputs = False
+            option.quantizedBiasType = circle.TensorType.TensorType.FLOAT32
+
+            operator.builtinOptions = option
+
+        fc_op_index = get_op_index(
+            circle.BuiltinOperator.BuiltinOperator.FULLY_CONNECTED, self._op_codes
+        )
+        operator = create_builtin_operator(self.graph, fc_op_index, inputs, outputs)
+        set_fc_option(operator)
+        return operator
+
     def define_node(
         self,
         node: torch.fx.Node,
@@ -42,21 +68,72 @@ def define_node(
         input = args.input
         mat2 = args.mat2

-        op_index = get_op_index(
-            circle.BuiltinOperator.BuiltinOperator.BATCH_MATMUL, self._op_codes
+        is_const_tensor = lambda n: (
+            n.op == "get_attr"
+            or (
+                n.op == "placeholder"
+                and isinstance(n.meta.get("val", None), torch.Tensor)
+                and not n.meta["val"].requires_grad
+            )
         )

-        inputs = [input, mat2]
-        outputs = [node]
+        lhs, rhs = input, mat2
+        is_const_lhs = is_const_tensor(lhs)

-        operator = create_builtin_operator(self.graph, op_index, inputs, outputs)
+        if is_const_lhs:
+            fc_index = get_op_index(
+                circle.BuiltinOperator.BuiltinOperator.FULLY_CONNECTED,
+                self._op_codes,
+            )

-        # Op-specific option
-        operator.builtinOptionsType = (
-            circle.BuiltinOptions.BuiltinOptions.BatchMatMulOptions
-        )
-        option = circle.BatchMatMulOptions.BatchMatMulOptionsT()
-        option.adjointLhs, option.adjointRhs = False, False
-        operator.builtinOptions = option
+            rhs_tid = self.graph.get_tid_registered(rhs)
+            rhs_tensor: circle.Tensor.TensorT = self.graph.tensors[rhs_tid]
+            rhs_shape = list(rhs_tensor.shape)  # [..., batch, in_features]
+            rhs_dtype = rhs_tensor.type

-        return operator
+            # lhs : weight, shape = [..., out_features, in_features]
+            lhs_tid = self.graph.get_tid_registered(lhs)
+            lhs_tensor: circle.Tensor.TensorT = self.graph.tensors[lhs_tid]
+            lhs_shape = list(lhs_tensor.shape)
+            out_features = lhs_shape[-2]
+            fc_out_shape = rhs_shape[:-1] + [out_features]
+            fc_bias = self.graph.add_const_tensor(data=[0.0], source_node=node)
+            fc_out = self.graph.add_tensor_from_scratch(
+                prefix=f"{node.name}_fc_out",
+                shape=fc_out_shape,
+                shape_signature=fc_out_shape,
+                dtype=rhs_dtype,
+            )
+
+            fc_inputs = [rhs, lhs, fc_bias]  # order: [input, weight, bias]
+            fc_outputs = [fc_out]
+            fc_op = self.define_fc_node(fc_inputs, fc_outputs)
+            self.graph.add_operator(fc_op)
+
+            trs_index = get_op_index(
+                circle.BuiltinOperator.BuiltinOperator.TRANSPOSE,
+                self._op_codes,
+            )
+
+            perm = list(range(len(fc_out.shape)))
+            perm[-2], perm[-1] = perm[-1], perm[-2]
+            perm_tensor = self.graph.add_const_tensor(
+                data=torch.tensor(perm, dtype=torch.int32),  # to prevent int64
+            )
+
+            trs_inputs = [fc_out, perm_tensor]
+            trs_outputs = [node]
+            trs_op = create_builtin_operator(
+                self.graph, trs_index, trs_inputs, trs_outputs
+            )
+
+            return trs_op
+
+        bmm_index = get_op_index(
+            circle.BuiltinOperator.BuiltinOperator.BATCH_MATMUL,
+            self._op_codes,
+        )
+        inputs = [lhs, rhs]
+        outputs = [node]
+        op = create_builtin_operator(self.graph, bmm_index, inputs, outputs)
+        return op
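
For reference, a minimal PyTorch sketch (not part of the diff) of the shape bookkeeping the new const-lhs path performs, assuming Circle's FULLY_CONNECTED computes y = x @ W^T over the last axis with the weight laid out as [out_features, in_features]; the concrete tensor shapes below are illustrative:

import torch

lhs = torch.randn(4, 8, 16)  # constant weight: [..., out_features, in_features]
rhs = torch.randn(4, 5, 16)  # FC input: [..., batch, in_features]

out_features = lhs.shape[-2]                          # 8
fc_out_shape = list(rhs.shape[:-1]) + [out_features]  # [4, 5, 8], as in the diff

# FULLY_CONNECTED contracts over in_features (the last axis of both operands).
fc_out = torch.matmul(rhs, lhs.transpose(-1, -2))
assert list(fc_out.shape) == fc_out_shape

# TRANSPOSE with the emitted perm swaps the last two dimensions.
perm = list(range(fc_out.dim()))
perm[-2], perm[-1] = perm[-1], perm[-2]
result = fc_out.permute(perm)
assert list(result.shape) == [4, 8, 5]

When the lhs is not constant, the fallback is unchanged in effect: a single BATCH_MATMUL operator is emitted.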