From 9624bd79039ca6a0caf4a5d120e053187205274b Mon Sep 17 00:00:00 2001 From: Dayoung Lee Date: Wed, 23 Jul 2025 21:52:40 +0900 Subject: [PATCH] Refactor test.utils.tag --- test/pt2_to_circle_test/builder.py | 10 ++- test/utils/base_builders.py | 28 +++--- test/utils/tag.py | 138 ++++++++++++++++------------- 3 files changed, 101 insertions(+), 75 deletions(-) diff --git a/test/pt2_to_circle_test/builder.py b/test/pt2_to_circle_test/builder.py index 4618a35e..f5ef4f1e 100644 --- a/test/pt2_to_circle_test/builder.py +++ b/test/pt2_to_circle_test/builder.py @@ -32,7 +32,7 @@ ) from test.utils.base_builders import TestDictBuilderBase, TestRunnerBase -from test.utils.tag import is_tagged +from test.utils.tag import TestTag class NNModuleTest(TestRunnerBase): @@ -41,9 +41,11 @@ def __init__(self, test_name: str, nnmodule: torch.nn.Module): self.test_dir = Path(os.path.dirname(os.path.abspath(__file__))) / "artifacts" # Get tags - self.test_without_pt2: bool = is_tagged(self.nnmodule, "test_without_pt2") - self.test_without_inference: bool = is_tagged( - self.nnmodule, "test_without_inference" + self.test_without_pt2: bool = TestTag.get( + self.nnmodule, "test_without_pt2", False + ) + self.test_without_inference: bool = TestTag.get( + self.nnmodule, "test_without_inference", False ) # Set tolerance diff --git a/test/utils/base_builders.py b/test/utils/base_builders.py index 6037df99..633c4a05 100644 --- a/test/utils/base_builders.py +++ b/test/utils/base_builders.py @@ -16,10 +16,10 @@ import inspect import pkgutil from abc import abstractmethod +from typing import Optional import torch - -from test.utils.tag import is_tagged +from test.utils.tag import get_tag, has_tag, TestTag class TestRunnerBase: @@ -31,12 +31,17 @@ def __init__(self, test_name: str, nnmodule: torch.nn.Module): self.nnmodule = nnmodule self.example_inputs = nnmodule.get_example_inputs() # type: ignore[operator] - # Get tags - self.skip: bool = is_tagged(self.nnmodule, "skip") - self.skip_reason: str = getattr(self.nnmodule, "__tag_skip_reason", "") - self.test_negative: bool = is_tagged(self.nnmodule, "test_negative") - self.expected_err: str = getattr(self.nnmodule, "__tag_expected_err", "") - self.use_onert: bool = is_tagged(self.nnmodule, "use_onert") + skip: Optional[object] = TestTag.get(type(self.nnmodule), "skip") + self.skip: bool = skip is not None + self.skip_reason: str = skip.get("reason") if skip else "" + + test_negative: Optional[object] = TestTag.get( + type(self.nnmodule), "test_negative" + ) + self.test_negative: bool = test_negative is not None + self.expected_err: str = test_negative.get("reason") if test_negative else "" + + self.use_onert: bool = TestTag.get(type(self.nnmodule), "use_onert", False) @abstractmethod def make(self): @@ -79,16 +84,17 @@ def _get_nnmodules(self, submodule: str): ) ) - # If any of the nnmodule_classes has a tag `__tag_target`, only those nnmodule_classes will be added + # If any of the nnmodule_classes is marked as target, only those will be added target_only: bool = any( - hasattr(nnmodule_cls, "__tag_target") for nnmodule_cls in nnmodule_classes + TestTag.get(nnmodule_cls, "target", False) + for nnmodule_cls in nnmodule_classes ) if target_only: nnmodule_classes = [ nnmodule_cls for nnmodule_cls in nnmodule_classes - if hasattr(nnmodule_cls, "__tag_target") + if TestTag.get(nnmodule_cls, "target", False) ] return nnmodule_classes diff --git a/test/utils/tag.py b/test/utils/tag.py index a3014ece..e7ae7392 100644 --- a/test/utils/tag.py +++ b/test/utils/tag.py @@ -12,95 +12,113 @@ # See the License for the specific language governing permissions and # limitations under the License. +from typing import Any, Dict, Type -def skip(reason): - def __inner_skip(orig_class): - setattr(orig_class, "__tag_skip", True) - setattr(orig_class, "__tag_skip_reason", reason) - def __init__(self, *args_, **kwargs_): - pass +class TestTag: + """Central registry for managing test tag""" - # Ignore initialization of skipped modules - orig_class.__init__ = __init__ + _registry: Dict[Type, Dict[str, Any]] = {} - return orig_class + @classmethod + def add(cls, test_class: Type, tag_key: str, tag_value: Any = None) -> None: + """Add test tag object to a class - return __inner_skip + Args: + test_class: The test class to add tag to + tag_key: Name of Tag object to add + tag_value: Tag object to add + """ + if test_class not in cls._registry: + cls._registry[test_class] = {} + cls._registry[test_class][tag_key] = tag_value -def skip_if(predicate, reason): - def __inner_skip(orig_class): - setattr(orig_class, "__tag_skip", True) - setattr(orig_class, "__tag_skip_reason", reason) + @classmethod + def has(cls, test_class: Type, tag_key: str) -> bool: + """Check if a class has specific tag type - def __init__(self, *args_, **kwargs_): - pass + Args: + test_class: The test class to check + tag_key: Type of tag object to check for - # Ignore initialization of skipped modules - orig_class.__init__ = __init__ + Returns: + bool: True if the tag exists, False otherwise + """ + return test_class in cls._registry and tag_key in cls._registry[test_class] - return orig_class + @classmethod + def get(cls, test_class: Type, tag_key: str, default: Any = None) -> Any: + """Get tag object for a class - if predicate: - return __inner_skip - else: - return lambda x: x + Args: + test_class: The test class to get tag from + tag_key: Type of tag object to retrieve + default: Default value to return if tag not found + Returns: + The tag object or default if not found + """ + return cls._registry.get(test_class, {}).get(tag_key, default) -def test_without_inference(orig_class): - setattr(orig_class, "__tag_test_without_inference", True) - return orig_class +#################################################################### +################## Add tag here ################## +#################################################################### -def test_without_pt2(orig_class): - setattr(orig_class, "__tag_test_without_pt2", True) - return orig_class +def skip(reason): + """ + Mark a test class to be skipped with a reason -def test_negative(expected_err): - def __inner_test_negative(orig_class): - setattr(orig_class, "__tag_test_negative", True) - setattr(orig_class, "__tag_expected_err", expected_err) + e.g. + @skip(reason="Not implemented yet") + class MyTest(unittest.TestCase): # <-- This test will be skipped + """ - return orig_class + def decorator(cls): + TestTag.add(cls, "skip", {"reason": reason}) + return cls - return __inner_test_negative + return decorator -def target(orig_class): - setattr(orig_class, "__tag_target", True) - return orig_class +def skip_if(predicate, reason): + """Conditionally mark a test class to be skipped with a reason""" + if predicate: + return skip(reason) + return lambda cls: cls -def use_onert(orig_class): - """ - Decorator to mark a test class so that Circle models are executed - with the 'onert' runtime. +def test_negative(expected_err): + """Mark a test class as negative test case with expected error""" - Useful when the default 'circle-interpreter' cannot run the model - under test. - """ - setattr(orig_class, "__tag_use_onert", True) - return orig_class + def decorator(cls): + TestTag.add(cls, "test_negative", {"expected_err": expected_err}) + return cls + + return decorator -def init_args(*args, **kwargs): - def __inner_init_args(orig_class): - orig_init = orig_class.__init__ - # Make copy of original __init__, so we can call it without recursion +def target(cls): + """Mark a test class as target test case""" + TestTag.add(cls, "target") + return cls - def __init__(self, *args_, **kwargs_): - args_ = (*args, *args_) - kwargs_ = {**kwargs, **kwargs_} - orig_init(self, *args_, **kwargs_) # Call the original __init__ +def use_onert(cls): + """Mark a test class to use ONERT runtime""" + TestTag.add(cls, "use_onert") + return cls - orig_class.__init__ = __init__ - return orig_class - return __inner_init_args +def test_without_pt2(cls): + """Mark a test class to not convert along pt2 during test execution""" + TestTag.add(cls, "test_without_pt2") + return cls -def is_tagged(cls, tag: str): - return hasattr(cls, f"__tag_{tag}") +def test_without_inference(cls): + """Mark a test class to not run inference during test execution""" + TestTag.add(cls, "test_without_inference") + return cls