From ae6aba4824bba7a59405cdcaccb0cda1443b0027 Mon Sep 17 00:00:00 2001 From: Shruti Patel Date: Sun, 7 Dec 2025 10:21:37 -0800 Subject: [PATCH] Move parameter utility functions to OSS Differential Revision: D88594524 --- ax/utils/common/parameter_utils.py | 65 ++++++++++++++ ax/utils/common/tests/test_parameter_utils.py | 89 +++++++++++++++++++ 2 files changed, 154 insertions(+) create mode 100644 ax/utils/common/parameter_utils.py create mode 100644 ax/utils/common/tests/test_parameter_utils.py diff --git a/ax/utils/common/parameter_utils.py b/ax/utils/common/parameter_utils.py new file mode 100644 index 00000000000..316d095303e --- /dev/null +++ b/ax/utils/common/parameter_utils.py @@ -0,0 +1,65 @@ +#!/usr/bin/env python3 +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +# pyre-strict + +"""Utilities for working with Ax parameters.""" + +from ax.core.parameter import ChoiceParameter, Parameter, ParameterType, RangeParameter +from ax.exceptions.core import UserInputError + + +def is_unordered_choice( + p: Parameter, min_choices: int | None = None, max_choices: int | None = None +) -> bool: + """Returns whether a parameter is an unordered choice (categorical) parameter. + + You can also specify `min_choices` and `max_choices` to restrict how many + possible values the parameter can take on. + + Args: + p: Parameter. + min_choices: The minimum number of possible values for the parameter. + max_choices: The maximum number of possible values for the parameter. + + Returns: + A boolean indicating whether p is an unordered choice parameter or not. + """ + if min_choices is not None and min_choices < 0: + raise UserInputError("`min_choices` must be a non-negative integer.") + if max_choices is not None and max_choices < 0: + raise UserInputError("`max_choices` must be a non-negative integer.") + if ( + min_choices is not None + and max_choices is not None + and min_choices > max_choices + ): + raise UserInputError("`min_choices` cannot be larger than than `max_choices`.") + return ( + isinstance(p, ChoiceParameter) + and not p.is_ordered + and (min_choices is None or min_choices <= len(p.values)) + and (max_choices is None or max_choices >= len(p.values)) + ) + + +def can_map_to_binary(p: Parameter) -> bool: + """Returns whether a parameter can be transformed to a binary parameter. + + Any choice/range parameters with exactly two values can be transformed to a + binary parameter. + + Args: + p: Parameter. + + Returns + A boolean indicating whether p can be transformed to a binary parameter. + """ + return (isinstance(p, ChoiceParameter) and len(p.values) == 2) or ( + isinstance(p, RangeParameter) + and p.parameter_type == ParameterType.INT + and p.lower == p.upper - 1 + ) diff --git a/ax/utils/common/tests/test_parameter_utils.py b/ax/utils/common/tests/test_parameter_utils.py new file mode 100644 index 00000000000..04e082ccbc9 --- /dev/null +++ b/ax/utils/common/tests/test_parameter_utils.py @@ -0,0 +1,89 @@ +#!/usr/bin/env python3 +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +# pyre-strict + +from unittest import TestCase + +from ax.core.parameter import ChoiceParameter, ParameterType, RangeParameter +from ax.core.types import TParamValue +from ax.exceptions.core import UserInputError +from ax.utils.common.parameter_utils import can_map_to_binary, is_unordered_choice + + +def get_unordered_choice( + parameter_type: ParameterType, values: list[TParamValue] +) -> ChoiceParameter: + return ChoiceParameter( + "p", parameter_type=parameter_type, values=values, is_ordered=False + ) + + +def get_ordered_choice( + parameter_type: ParameterType, values: list[TParamValue] +) -> ChoiceParameter: + return ChoiceParameter( + "p", parameter_type=parameter_type, values=values, is_ordered=True + ) + + +class TestParameterUtils(TestCase): + def test_can_map_to_binary(self) -> None: + for p in [ + RangeParameter("p", parameter_type=ParameterType.INT, lower=0, upper=1), + RangeParameter("p", parameter_type=ParameterType.INT, lower=3, upper=4), + get_unordered_choice(parameter_type=ParameterType.INT, values=[0, 1]), + get_unordered_choice( + parameter_type=ParameterType.STRING, values=["a", "b"] + ), + ]: + self.assertTrue(can_map_to_binary(p)) + + for p in [ + RangeParameter("p", parameter_type=ParameterType.FLOAT, lower=0, upper=1), + get_unordered_choice(parameter_type=ParameterType.INT, values=[0, 1, 2]), + get_unordered_choice( + parameter_type=ParameterType.STRING, values=["a", "b", "c"] + ), + ]: + self.assertFalse(can_map_to_binary(p)) + + def test_is_unordered_choice_parameter(self) -> None: + for p in [ + get_unordered_choice(parameter_type=ParameterType.INT, values=[0, 1, 2]), + get_unordered_choice( + parameter_type=ParameterType.INT, values=[0, 1, 2, 4, 5] + ), + get_unordered_choice( + parameter_type=ParameterType.STRING, values=["a", "b", "c", "d"] + ), + ]: + self.assertTrue(is_unordered_choice(p, min_choices=3, max_choices=5)) + + for p in [ + get_unordered_choice(parameter_type=ParameterType.INT, values=[0, 1]), + get_ordered_choice(parameter_type=ParameterType.INT, values=[0, 1, 2, 4]), + RangeParameter("p", parameter_type=ParameterType.INT, lower=0, upper=3), + get_ordered_choice( + parameter_type=ParameterType.STRING, values=["0", "1", "2"] + ), + ]: + self.assertFalse(is_unordered_choice(p, min_choices=3, max_choices=5)) + + # Check exceptions + p = get_unordered_choice(parameter_type=ParameterType.INT, values=[0, 1, 2]) + with self.assertRaisesRegex( + UserInputError, "`min_choices` must be a non-negative integer." + ): + is_unordered_choice(p, min_choices=-3) + with self.assertRaisesRegex( + UserInputError, "`max_choices` must be a non-negative integer." + ): + is_unordered_choice(p, max_choices=-1) + with self.assertRaisesRegex( + UserInputError, "`min_choices` cannot be larger than than `max_choices`." + ): + is_unordered_choice(p, min_choices=3, max_choices=2)