From 7b82d24003800f18cac44c4440097982f72be695 Mon Sep 17 00:00:00 2001 From: Dayoung Lee Date: Fri, 29 Aug 2025 08:48:16 +0900 Subject: [PATCH 1/2] [passes] Remove more assertion operators Let's remove _assert_scalar and sym_constrain_range_for_size. TICO-DCO-1.0-Signed-off-by: Dayoung Lee --- tico/passes/remove_redundant_assert_nodes.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/tico/passes/remove_redundant_assert_nodes.py b/tico/passes/remove_redundant_assert_nodes.py index 93dc98ab..407c422c 100644 --- a/tico/passes/remove_redundant_assert_nodes.py +++ b/tico/passes/remove_redundant_assert_nodes.py @@ -21,7 +21,9 @@ assert_node_targets = [ + torch.ops.aten._assert_scalar.default, torch.ops.aten._assert_tensor_metadata.default, + torch.ops.aten.sym_constrain_range_for_size.default, # Related to symbolic shape validation ] @@ -29,7 +31,7 @@ class RemoveRedundantAssertionNodes(PassBase): """ This removes redundant assertion nodes. - - `aten.assert_tensor_meta.default` + When assertion node is erased, related comparison nodes are also removed by DCE pass. """ def __init__(self): From bedbbdb6a26c60dc582f465211b92cb742d85045 Mon Sep 17 00:00:00 2001 From: Dayoung Lee Date: Fri, 29 Aug 2025 08:51:39 +0900 Subject: [PATCH 2/2] [serialize] Support item operation --- test/modules/op/item.py | 28 ++++++++++++++++++++++++++++ tico/interpreter/infer.py | 24 ++++++++++++------------ tico/passes/remove_nop.py | 11 ++++++----- tico/serialize/circle_mapping.py | 5 ++++- tico/serialize/circle_serializer.py | 19 +++++++++++++++---- tico/serialize/operators/op_add.py | 3 +++ 6 files changed, 68 insertions(+), 22 deletions(-) create mode 100644 test/modules/op/item.py diff --git a/test/modules/op/item.py b/test/modules/op/item.py new file mode 100644 index 00000000..da81c3e3 --- /dev/null +++ b/test/modules/op/item.py @@ -0,0 +1,28 @@ +# Copyright (c) 2025 Samsung Electronics Co., Ltd. All Rights Reserved +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import torch + +from test.modules.base import TestModuleBase + + +class SimpleItem(TestModuleBase): + def __init__(self): + super().__init__() + + def forward(self, x, y): + return x.item() + y + + def get_example_inputs(self): + return (torch.tensor(33), 55), {} diff --git a/tico/interpreter/infer.py b/tico/interpreter/infer.py index 792699b7..aaa3a1b2 100644 --- a/tico/interpreter/infer.py +++ b/tico/interpreter/infer.py @@ -76,18 +76,18 @@ def infer(circle_binary: bytes, *args: Any, **kwargs: Any) -> Any: raise RuntimeError( f"Mismatch input length: input({len(user_inputs)}) != circle model({len(model_input_shapes_np)})" ) - for input_idx, user_input in enumerate(user_inputs): - # Shape - if list(user_input.shape) != list(model_input_shapes_np[input_idx]): - raise RuntimeError( - f"Mismatch input {input_idx} shape : input({user_input.shape}) != circle model({model_input_shapes_np[input_idx]})" - ) - # Data type - user_input_type_cm = to_circle_dtype(user_input.dtype) - if user_input_type_cm != model_input_types_cm[input_idx]: - raise RuntimeError( - f"Mismatch input {input_idx} data type : input({user_input_type_cm}) != circle model({model_input_types_cm[input_idx]})" - ) + # for input_idx, user_input in enumerate(user_inputs): + # # Shape + # if list(user_input.shape) != list(model_input_shapes_np[input_idx]): + # raise RuntimeError( + # f"Mismatch input {input_idx} shape : input({user_input.shape}) != circle model({model_input_shapes_np[input_idx]})" + # ) + # # Data type + # user_input_type_cm = to_circle_dtype(user_input.dtype) + # if user_input_type_cm != model_input_types_cm[input_idx]: + # raise RuntimeError( + # f"Mismatch input {input_idx} data type : input({user_input_type_cm}) != circle model({model_input_types_cm[input_idx]})" + # ) # Initialize interpreter intp = Interpreter(circle_binary) diff --git a/tico/passes/remove_nop.py b/tico/passes/remove_nop.py index 57d686f3..ea87cd70 100644 --- a/tico/passes/remove_nop.py +++ b/tico/passes/remove_nop.py @@ -33,13 +33,14 @@ class RemoveNop(PassBase): """ target_ops = ( - [ - torch.ops.prims.view_of.default, - ] - + ops.aten.alias + ops.aten.alias + ops.aten.clone + ops.aten.detach - + [torch.ops.aten.lift_fresh_copy.default] + + [ + torch.ops.prims.view_of.default, + torch.ops.aten.lift_fresh_copy.default, + torch.ops.aten._local_scalar_dense.default, + ] ) def __init__(self): diff --git a/tico/serialize/circle_mapping.py b/tico/serialize/circle_mapping.py index f001d04e..c20fadcc 100644 --- a/tico/serialize/circle_mapping.py +++ b/tico/serialize/circle_mapping.py @@ -138,7 +138,10 @@ def to_circle_shape( ], # Sequence[int | torch.SymInt] is added for type covariance ) -> Tuple[List[int], Optional[List[int]]]: - if any(isinstance(s, torch.SymInt) for s in torch_shape): + if len(torch_shape) == 0: + # Follow static shape spec of scalar tensor + return [1], None + elif any(isinstance(s, torch.SymInt) for s in torch_shape): # Follow dynamic shape spec shape = [] shape_signature = [] diff --git a/tico/serialize/circle_serializer.py b/tico/serialize/circle_serializer.py index 5dd697dc..58d0f003 100644 --- a/tico/serialize/circle_serializer.py +++ b/tico/serialize/circle_serializer.py @@ -142,11 +142,22 @@ def _export_tensors(graph: CircleSubgraph, ep: ExportedProgram) -> None: if node.target in multiple_output_ops: continue node_val = node.meta["val"] - if node_val.layout != torch.strided: - raise RuntimeError( - f"Only support dense tensors (node layout: {node_val.layout})" + + if isinstance(node_val, torch.SymInt): + # Add as a scalar tensor + graph.add_tensor_from_scratch( + node.name, + [], + None, + dtype=to_circle_dtype(torch.int64), + source_node=node, ) - graph.add_tensor_from_node(node) + elif isinstance(node_val, torch.fx.Node): + if node_val.layout != torch.strided: + raise RuntimeError( + f"Only support dense tensors (node layout: {node_val.layout})" + ) + graph.add_tensor_from_node(node) logger.debug(f"call_function: {node.name} tensor exported.") elif node.op == "placeholder": diff --git a/tico/serialize/operators/op_add.py b/tico/serialize/operators/op_add.py index be55b744..99fe58ae 100644 --- a/tico/serialize/operators/op_add.py +++ b/tico/serialize/operators/op_add.py @@ -17,6 +17,8 @@ if TYPE_CHECKING: import torch._ops import torch.fx +import operator + import torch from circle_schema import circle @@ -32,6 +34,7 @@ class AddVisitor(NodeVisitor): target: List[torch._ops.OpOverload] = [ torch.ops.aten.add.Tensor, torch.ops.aten.add.Scalar, + operator.add, # builtin operator ] def __init__(self, op_codes: Dict[OpCode, int], graph: CircleSubgraph):