From 3f5568ac2375ee74572440aac402ce2624bdb5ff Mon Sep 17 00:00:00 2001 From: Sanggyu Lee Date: Tue, 29 Jul 2025 06:35:45 +0900 Subject: [PATCH 1/4] [utils] Add forward's input recorder It adds forward's input recorder. TICO-DCO-1.0-Signed-off-by: Sanggyu Lee --- .../unit_test/utils_test/test_record_input.py | 85 ++++++++++++++++++ tico/utils/record_input.py | 89 +++++++++++++++++++ 2 files changed, 174 insertions(+) create mode 100644 test/unit_test/utils_test/test_record_input.py create mode 100644 tico/utils/record_input.py diff --git a/test/unit_test/utils_test/test_record_input.py b/test/unit_test/utils_test/test_record_input.py new file mode 100644 index 00000000..a602a8ba --- /dev/null +++ b/test/unit_test/utils_test/test_record_input.py @@ -0,0 +1,85 @@ +# Copyright (c) 2025 Samsung Electronics Co., Ltd. All Rights Reserved +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest + +import tico +import torch +from tico.utils.record_input import RecordingInput +from torch.export import export, save + +from test.modules.op.add import SimpleAdd + + +class RecordInputTest(unittest.TestCase): + def test_args(self): + m = SimpleAdd() + inputs = m.get_example_inputs() + with RecordingInput(m) as rec: + m.eval() + m(*inputs) + captured_input = rec.captured_input + + self.assertIsNotNone(captured_input) + self.assertEqual(captured_input, inputs) + tico.convert(m, captured_input) + + def test_kwargs(self): + m = SimpleAdd() + inputs = m.get_example_inputs() + kwargs = {"x": inputs[0], "y": inputs[1]} + with RecordingInput(m) as rec: + m.eval() + m(**kwargs) + captured_input = rec.captured_input + + self.assertIsNotNone(captured_input) + self.assertEqual(captured_input, inputs) + tico.convert(m, captured_input) + + def test_args_kwargs(self): + m = SimpleAdd() + inputs = m.get_example_inputs() + args = (inputs[0],) + kwargs = {"y": inputs[1]} + with RecordingInput(m) as rec: + m.eval() + m(*args, **kwargs) + captured_input = rec.captured_input + + self.assertIsNotNone(captured_input) + self.assertEqual(captured_input, inputs) + tico.convert(m, captured_input) + + def test_input_to_remove(self): + m = SimpleAdd() + inputs = m.get_example_inputs() + with RecordingInput(m, input_to_remove=["x"]) as rec: + m.eval() + m(*inputs) + captured_input = rec.captured_input + + self.assertIsNotNone(captured_input) + self.assertIsNone(captured_input[0]) # arg[0] = 'x' + + def test_condition(self): + m = SimpleAdd() + inputs = m.get_example_inputs() + condition = lambda arg_dict: False + with RecordingInput(m, condition) as rec: + m.eval() + m(*inputs) + captured_input = rec.captured_input + + self.assertEqual(captured_input, ()) diff --git a/tico/utils/record_input.py b/tico/utils/record_input.py new file mode 100644 index 00000000..4cf2b839 --- /dev/null +++ b/tico/utils/record_input.py @@ -0,0 +1,89 @@ +# Copyright (c) 2025 Samsung Electronics Co., Ltd. All Rights Reserved +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import copy + +import inspect +from contextlib import contextmanager +from typing import Callable, List, Optional + +import torch.nn as nn + + +class RecordingInput: + r"""Context-manager that records the input values of model::forward() + + Recording input is useful for preparing example input for torch.export + + Args: + condition: lambda to provide the condition whether to record or not + + For examples, if you want to capture only args["past_key_values"] is not None, + conditon = lambda args_dict: args_dict["past_key_value"] is not None + + input_to_remove: list of arg names to remove + + Sometimes you would like to remove some arg values to make exported graph tidy or correct + For example, "past_key_values" may be not None, but just an empty cache. Then, + input_to_remove = [ "past_key_values" ]; makes the life easy + + Example:: + >>> with RecordingInput(model, input_to_remove=input_to_remove) as rec: + ... outputs = model.generate( + ... **inputs, + ... ) + ... captured_input = rec.captured_input + >>> circle_model = tico.convert(model, captured_input) + """ + + def __init__( + self, + module: nn.Module, + condition: Callable[[dict], bool] = lambda args_dict: True, + *, + input_to_remove: Optional[List[str]] = [], + ): + self.module = module + self.forward_org = module.forward + self.condition = condition + self.input_to_remove = input_to_remove + sig = inspect.signature(self.forward_org) + self.args_names = [ + name for name in sig.parameters.keys() if name not in ("self", "kwargs") + ] + self.captured_input = () + + def __enter__(self): + def capture_and_forward(*args, **kwargs): + args_dict = dict(zip(self.args_names, args)) + args_dict.update(kwargs) + + def populate_args(args_dict, filter): + for key in filter: + args_dict.pop(key, None) + args_tuple = tuple( + args_dict.get(name, None) for name in self.args_names + ) + return copy.deepcopy(args_tuple) + + if self.condition(args_dict) and self.captured_input == (): + self.captured_input = populate_args(args_dict, self.input_to_remove) + + return self.forward_org(*args, **kwargs) + + self.module.forward = capture_and_forward + return self + + def __exit__(self, exc_type, exc_value, traceback): + self.module.forward = self.forward_org From 0e16661ef80ae6030d46e16847191550d598933a Mon Sep 17 00:00:00 2001 From: Sanggyu Lee Date: Tue, 29 Jul 2025 14:45:46 +0900 Subject: [PATCH 2/4] Update tico/utils/record_input.py Co-authored-by: Dayoung Lee --- tico/utils/record_input.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tico/utils/record_input.py b/tico/utils/record_input.py index 4cf2b839..8941f32b 100644 --- a/tico/utils/record_input.py +++ b/tico/utils/record_input.py @@ -69,7 +69,7 @@ def capture_and_forward(*args, **kwargs): args_dict = dict(zip(self.args_names, args)) args_dict.update(kwargs) - def populate_args(args_dict, filter): + def populate_args(args_dict, input_to_remove): for key in filter: args_dict.pop(key, None) args_tuple = tuple( From 459cd2df3145b59a183855337d7b463c573bbbb9 Mon Sep 17 00:00:00 2001 From: Sanggyu Lee Date: Tue, 29 Jul 2025 14:51:57 +0900 Subject: [PATCH 3/4] Rename filter to input_to_remove --- tico/utils/record_input.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tico/utils/record_input.py b/tico/utils/record_input.py index 8941f32b..4dbb0b76 100644 --- a/tico/utils/record_input.py +++ b/tico/utils/record_input.py @@ -70,7 +70,7 @@ def capture_and_forward(*args, **kwargs): args_dict.update(kwargs) def populate_args(args_dict, input_to_remove): - for key in filter: + for key in input_to_remove: args_dict.pop(key, None) args_tuple = tuple( args_dict.get(name, None) for name in self.args_names From 6afcaef4a5808cd8ebbc7f3a41329783decbbecd Mon Sep 17 00:00:00 2001 From: Sanggyu Lee Date: Tue, 29 Jul 2025 17:05:26 +0900 Subject: [PATCH 4/4] Update as requested ( () -> None, sig.bind ) --- test/unit_test/utils_test/test_record_input.py | 2 +- tico/utils/record_input.py | 15 +++++++++------ 2 files changed, 10 insertions(+), 7 deletions(-) diff --git a/test/unit_test/utils_test/test_record_input.py b/test/unit_test/utils_test/test_record_input.py index a602a8ba..4d1fb4fc 100644 --- a/test/unit_test/utils_test/test_record_input.py +++ b/test/unit_test/utils_test/test_record_input.py @@ -82,4 +82,4 @@ def test_condition(self): m(*inputs) captured_input = rec.captured_input - self.assertEqual(captured_input, ()) + self.assertEqual(captured_input, None) diff --git a/tico/utils/record_input.py b/tico/utils/record_input.py index 4dbb0b76..9cacd0e5 100644 --- a/tico/utils/record_input.py +++ b/tico/utils/record_input.py @@ -58,16 +58,19 @@ def __init__( self.forward_org = module.forward self.condition = condition self.input_to_remove = input_to_remove - sig = inspect.signature(self.forward_org) + self.sig = inspect.signature(self.forward_org) self.args_names = [ - name for name in sig.parameters.keys() if name not in ("self", "kwargs") + name + for name in self.sig.parameters.keys() + if name not in ("self", "kwargs") ] - self.captured_input = () + self.captured_input = None def __enter__(self): def capture_and_forward(*args, **kwargs): - args_dict = dict(zip(self.args_names, args)) - args_dict.update(kwargs) + bound = self.sig.bind(*args, **kwargs) + bound.apply_defaults() + args_dict = dict(bound.arguments) def populate_args(args_dict, input_to_remove): for key in input_to_remove: @@ -77,7 +80,7 @@ def populate_args(args_dict, input_to_remove): ) return copy.deepcopy(args_tuple) - if self.condition(args_dict) and self.captured_input == (): + if self.condition(args_dict) and self.captured_input is None: self.captured_input = populate_args(args_dict, self.input_to_remove) return self.forward_org(*args, **kwargs)