Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
51 changes: 26 additions & 25 deletions ax/generation_strategy/generation_strategy.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,12 @@
from ax.core.observation import ObservationFeatures
from ax.core.trial_status import TrialStatus
from ax.core.utils import extend_pending_observations, extract_pending_observations
from ax.exceptions.core import DataRequiredError, UnsupportedError, UserInputError
from ax.exceptions.core import (
AxError,
DataRequiredError,
UnsupportedError,
UserInputError,
)
from ax.exceptions.generation_strategy import (
GenerationStrategyCompleted,
GenerationStrategyMisconfiguredException,
Expand Down Expand Up @@ -286,23 +291,24 @@ def gen_single_trial(
observations for that metric, used by some nodes to avoid
resuggesting points that are currently being evaluated.
"""
self.experiment = experiment

gr = self._gen_with_multiple_nodes(
grs_for_trials = self.gen(
experiment=experiment,
data=data,
n=n,
pending_observations=pending_observations,
n=n,
fixed_features=fixed_features,
num_trials=1,
)
if len(gr) > 1:
raise UnsupportedError(
"By calling into GenerationStrategy.gen(), you are should be "
"expecting a single `Trial` with only one `GeneratorRun`. However, "
"the underlying GenerationStrategy produced multiple `GeneratorRuns` "
f"and returned the following list of `GeneratorRun`-s: {gr}"
# `gen` returns list[list[GeneratorRun]], so grs_for_trials[0] is the
# list of GeneratorRuns for the first (and only) trial.
if len(grs_for_trials) != 1 or len(grs := grs_for_trials[0]) != 1:
raise AxError( # Unexpected state of the GS; raise informatively.
"By calling into GenerationStrategy.gen_single_trial(), you are should"
" be expecting a single `Trial` with only one `GeneratorRun`. However,"
"the underlying GenerationStrategy returned the following list "
f" of `GeneratorRun`-s: {grs_for_trials}."
)
return gr[0]
return grs[0]

def gen(
self,
Expand Down Expand Up @@ -359,19 +365,14 @@ def gen(
if pending_observations is None
else deepcopy(pending_observations)
)
# TODO[@drfreund, @mgarrard]: Can we avoid having to check all TCs here?
# To do so, we would need: 1) another way to understand that there are
# no trial-counting TCs with a trial limit, 2) a way to, during `_gen_from
# multiple_nodes`, stop once we've generated (limit - pre-existing trials)
# new trials (just checking TCs won't work because it will look at the number
# of trials on the experiment but not at the would-be trials already produced
# in the loop).
new_trials_limit = self._curr.new_trial_limit(raise_generation_errors=False)
if new_trials_limit == -1: # There is no additional limit on new trials.
num_trials = max(num_trials, 1)
else:
num_trials = max(min(num_trials, new_trials_limit), 1)
for _i in range(num_trials):
# Only check trial limit when requesting multiple trials; when num_trials <= 1,
# the result is always 1 regardless of the limit.
if num_trials > 1:
new_trials_limit = self._curr.new_trial_limit(raise_generation_errors=False)
if new_trials_limit != -1: # There is an additional limit on new trials.
num_trials = min(num_trials, new_trials_limit)
num_trials = max(num_trials, 1) # Ensure at least 1 trial
for _ in range(num_trials):
grs_for_multiple_trials.append(
self._gen_with_multiple_nodes(
experiment=experiment,
Expand Down
55 changes: 55 additions & 0 deletions ax/generation_strategy/tests/test_generation_strategy.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@
get_pending_observation_features_based_on_trial_status as get_pending,
)
from ax.exceptions.core import (
AxError,
DataRequiredError,
SearchSpaceExhausted,
UnsupportedError,
Expand Down Expand Up @@ -694,6 +695,60 @@ def test_store_experiment(self) -> None:
sobol_generation_strategy.gen_single_trial(exp)
self.assertIsNotNone(sobol_generation_strategy._experiment)

def test_gen_single_trial_extracts_pending_observations(self) -> None:
"""Test that gen_single_trial extracts pending_observations from the
experiment when none are passed in."""
exp = get_branin_experiment()
gs = GenerationStrategy(
steps=[GenerationStep(generator=Generators.SOBOL, num_trials=5)]
)
# Create a trial and mark it as running so it becomes a pending observation.
trial = exp.new_trial(generator_run=gs.gen_single_trial(exp))
trial.mark_running(no_runner_required=True)

# Now call gen_single_trial without passing pending_observations.
# It should extract them from the experiment (the running trial's arms).
with mock_patch_method_original(
mock_path=f"{GeneratorSpec.__module__}.GeneratorSpec.gen",
original_method=GeneratorSpec.gen,
) as gen_spec_gen_mock:
gs.gen_single_trial(exp)
# Check that pending_observations was passed to the underlying gen call.
pending_obs = gen_spec_gen_mock.call_args.kwargs.get("pending_observations")
self.assertIsNotNone(pending_obs)
# The pending observations should contain the arm from the running trial.
expected_obs_feat = ObservationFeatures.from_arm(
arm=none_throws(trial.arm), trial_index=trial.index
)
for metric_name in exp.metrics:
self.assertIn(metric_name, pending_obs)
self.assertIn(expected_obs_feat, pending_obs[metric_name])

def test_gen_single_trial_raises_error_for_multiple_trials(self) -> None:
"""Test that gen_single_trial raises AxError if gen returns multiple trials."""
exp = get_branin_experiment()
gs = GenerationStrategy(
steps=[GenerationStep(generator=Generators.SOBOL, num_trials=5)]
)
gr = gs.gen_single_trial(exp)
# Mock gen to return multiple trials
with patch.object(gs, "gen", return_value=[[gr], [gr]]):
with self.assertRaisesRegex(AxError, "single `Trial`"):
gs.gen_single_trial(exp)

def test_gen_single_trial_raises_error_for_multiple_generator_runs(self) -> None:
"""Test that gen_single_trial raises AxError if gen returns multiple
GeneratorRuns for a single trial."""
exp = get_branin_experiment()
gs = GenerationStrategy(
steps=[GenerationStep(generator=Generators.SOBOL, num_trials=5)]
)
gr = gs.gen_single_trial(exp)
# Mock gen to return a single trial with multiple GeneratorRuns
with patch.object(gs, "gen", return_value=[[gr, gr]]):
with self.assertRaisesRegex(AxError, "only one `GeneratorRun`"):
gs.gen_single_trial(exp)

def test_max_parallelism_reached(self) -> None:
exp = get_branin_experiment()
sobol_generation_strategy = GenerationStrategy(
Expand Down