Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
65 changes: 65 additions & 0 deletions ax/utils/common/parameter_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,65 @@
#!/usr/bin/env python3
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

# pyre-strict

"""Utilities for working with Ax parameters."""

from ax.core.parameter import ChoiceParameter, Parameter, ParameterType, RangeParameter
from ax.exceptions.core import UserInputError


def is_unordered_choice(
p: Parameter, min_choices: int | None = None, max_choices: int | None = None
) -> bool:
"""Returns whether a parameter is an unordered choice (categorical) parameter.

You can also specify `min_choices` and `max_choices` to restrict how many
possible values the parameter can take on.

Args:
p: Parameter.
min_choices: The minimum number of possible values for the parameter.
max_choices: The maximum number of possible values for the parameter.

Returns:
A boolean indicating whether p is an unordered choice parameter or not.
"""
if min_choices is not None and min_choices < 0:
raise UserInputError("`min_choices` must be a non-negative integer.")
if max_choices is not None and max_choices < 0:
raise UserInputError("`max_choices` must be a non-negative integer.")
if (
min_choices is not None
and max_choices is not None
and min_choices > max_choices
):
raise UserInputError("`min_choices` cannot be larger than than `max_choices`.")
return (
isinstance(p, ChoiceParameter)
and not p.is_ordered
and (min_choices is None or min_choices <= len(p.values))
and (max_choices is None or max_choices >= len(p.values))
)


def can_map_to_binary(p: Parameter) -> bool:
"""Returns whether a parameter can be transformed to a binary parameter.

Any choice/range parameters with exactly two values can be transformed to a
binary parameter.

Args:
p: Parameter.

Returns
A boolean indicating whether p can be transformed to a binary parameter.
"""
return (isinstance(p, ChoiceParameter) and len(p.values) == 2) or (
isinstance(p, RangeParameter)
and p.parameter_type == ParameterType.INT
and p.lower == p.upper - 1
)
89 changes: 89 additions & 0 deletions ax/utils/common/tests/test_parameter_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,89 @@
#!/usr/bin/env python3
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

# pyre-strict

from unittest import TestCase

from ax.core.parameter import ChoiceParameter, ParameterType, RangeParameter
from ax.core.types import TParamValue
from ax.exceptions.core import UserInputError
from ax.utils.common.parameter_utils import can_map_to_binary, is_unordered_choice


def get_unordered_choice(
parameter_type: ParameterType, values: list[TParamValue]
) -> ChoiceParameter:
return ChoiceParameter(
"p", parameter_type=parameter_type, values=values, is_ordered=False
)


def get_ordered_choice(
parameter_type: ParameterType, values: list[TParamValue]
) -> ChoiceParameter:
return ChoiceParameter(
"p", parameter_type=parameter_type, values=values, is_ordered=True
)


class TestParameterUtils(TestCase):
def test_can_map_to_binary(self) -> None:
for p in [
RangeParameter("p", parameter_type=ParameterType.INT, lower=0, upper=1),
RangeParameter("p", parameter_type=ParameterType.INT, lower=3, upper=4),
get_unordered_choice(parameter_type=ParameterType.INT, values=[0, 1]),
get_unordered_choice(
parameter_type=ParameterType.STRING, values=["a", "b"]
),
]:
self.assertTrue(can_map_to_binary(p))

for p in [
RangeParameter("p", parameter_type=ParameterType.FLOAT, lower=0, upper=1),
get_unordered_choice(parameter_type=ParameterType.INT, values=[0, 1, 2]),
get_unordered_choice(
parameter_type=ParameterType.STRING, values=["a", "b", "c"]
),
]:
self.assertFalse(can_map_to_binary(p))

def test_is_unordered_choice_parameter(self) -> None:
for p in [
get_unordered_choice(parameter_type=ParameterType.INT, values=[0, 1, 2]),
get_unordered_choice(
parameter_type=ParameterType.INT, values=[0, 1, 2, 4, 5]
),
get_unordered_choice(
parameter_type=ParameterType.STRING, values=["a", "b", "c", "d"]
),
]:
self.assertTrue(is_unordered_choice(p, min_choices=3, max_choices=5))

for p in [
get_unordered_choice(parameter_type=ParameterType.INT, values=[0, 1]),
get_ordered_choice(parameter_type=ParameterType.INT, values=[0, 1, 2, 4]),
RangeParameter("p", parameter_type=ParameterType.INT, lower=0, upper=3),
get_ordered_choice(
parameter_type=ParameterType.STRING, values=["0", "1", "2"]
),
]:
self.assertFalse(is_unordered_choice(p, min_choices=3, max_choices=5))

# Check exceptions
p = get_unordered_choice(parameter_type=ParameterType.INT, values=[0, 1, 2])
with self.assertRaisesRegex(
UserInputError, "`min_choices` must be a non-negative integer."
):
is_unordered_choice(p, min_choices=-3)
with self.assertRaisesRegex(
UserInputError, "`max_choices` must be a non-negative integer."
):
is_unordered_choice(p, max_choices=-1)
with self.assertRaisesRegex(
UserInputError, "`min_choices` cannot be larger than than `max_choices`."
):
is_unordered_choice(p, min_choices=3, max_choices=2)