diff --git a/test/unit_test/utils_test/test_record_input.py b/test/unit_test/utils_test/test_record_input.py new file mode 100644 index 00000000..4d1fb4fc --- /dev/null +++ b/test/unit_test/utils_test/test_record_input.py @@ -0,0 +1,85 @@ +# Copyright (c) 2025 Samsung Electronics Co., Ltd. All Rights Reserved +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest + +import tico +import torch +from tico.utils.record_input import RecordingInput +from torch.export import export, save + +from test.modules.op.add import SimpleAdd + + +class RecordInputTest(unittest.TestCase): + def test_args(self): + m = SimpleAdd() + inputs = m.get_example_inputs() + with RecordingInput(m) as rec: + m.eval() + m(*inputs) + captured_input = rec.captured_input + + self.assertIsNotNone(captured_input) + self.assertEqual(captured_input, inputs) + tico.convert(m, captured_input) + + def test_kwargs(self): + m = SimpleAdd() + inputs = m.get_example_inputs() + kwargs = {"x": inputs[0], "y": inputs[1]} + with RecordingInput(m) as rec: + m.eval() + m(**kwargs) + captured_input = rec.captured_input + + self.assertIsNotNone(captured_input) + self.assertEqual(captured_input, inputs) + tico.convert(m, captured_input) + + def test_args_kwargs(self): + m = SimpleAdd() + inputs = m.get_example_inputs() + args = (inputs[0],) + kwargs = {"y": inputs[1]} + with RecordingInput(m) as rec: + m.eval() + m(*args, **kwargs) + captured_input = rec.captured_input + + self.assertIsNotNone(captured_input) + self.assertEqual(captured_input, inputs) + tico.convert(m, captured_input) + + def test_input_to_remove(self): + m = SimpleAdd() + inputs = m.get_example_inputs() + with RecordingInput(m, input_to_remove=["x"]) as rec: + m.eval() + m(*inputs) + captured_input = rec.captured_input + + self.assertIsNotNone(captured_input) + self.assertIsNone(captured_input[0]) # arg[0] = 'x' + + def test_condition(self): + m = SimpleAdd() + inputs = m.get_example_inputs() + condition = lambda arg_dict: False + with RecordingInput(m, condition) as rec: + m.eval() + m(*inputs) + captured_input = rec.captured_input + + self.assertEqual(captured_input, None) diff --git a/tico/utils/record_input.py b/tico/utils/record_input.py new file mode 100644 index 00000000..9cacd0e5 --- /dev/null +++ b/tico/utils/record_input.py @@ -0,0 +1,92 @@ +# Copyright (c) 2025 Samsung Electronics Co., Ltd. All Rights Reserved +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import copy + +import inspect +from contextlib import contextmanager +from typing import Callable, List, Optional + +import torch.nn as nn + + +class RecordingInput: + r"""Context-manager that records the input values of model::forward() + + Recording input is useful for preparing example input for torch.export + + Args: + condition: lambda to provide the condition whether to record or not + + For examples, if you want to capture only args["past_key_values"] is not None, + conditon = lambda args_dict: args_dict["past_key_value"] is not None + + input_to_remove: list of arg names to remove + + Sometimes you would like to remove some arg values to make exported graph tidy or correct + For example, "past_key_values" may be not None, but just an empty cache. Then, + input_to_remove = [ "past_key_values" ]; makes the life easy + + Example:: + >>> with RecordingInput(model, input_to_remove=input_to_remove) as rec: + ... outputs = model.generate( + ... **inputs, + ... ) + ... captured_input = rec.captured_input + >>> circle_model = tico.convert(model, captured_input) + """ + + def __init__( + self, + module: nn.Module, + condition: Callable[[dict], bool] = lambda args_dict: True, + *, + input_to_remove: Optional[List[str]] = [], + ): + self.module = module + self.forward_org = module.forward + self.condition = condition + self.input_to_remove = input_to_remove + self.sig = inspect.signature(self.forward_org) + self.args_names = [ + name + for name in self.sig.parameters.keys() + if name not in ("self", "kwargs") + ] + self.captured_input = None + + def __enter__(self): + def capture_and_forward(*args, **kwargs): + bound = self.sig.bind(*args, **kwargs) + bound.apply_defaults() + args_dict = dict(bound.arguments) + + def populate_args(args_dict, input_to_remove): + for key in input_to_remove: + args_dict.pop(key, None) + args_tuple = tuple( + args_dict.get(name, None) for name in self.args_names + ) + return copy.deepcopy(args_tuple) + + if self.condition(args_dict) and self.captured_input is None: + self.captured_input = populate_args(args_dict, self.input_to_remove) + + return self.forward_org(*args, **kwargs) + + self.module.forward = capture_and_forward + return self + + def __exit__(self, exc_type, exc_value, traceback): + self.module.forward = self.forward_org