diff --git a/ax/generation_strategy/generation_strategy.py b/ax/generation_strategy/generation_strategy.py index 76d5206464f..86bc0b91a45 100644 --- a/ax/generation_strategy/generation_strategy.py +++ b/ax/generation_strategy/generation_strategy.py @@ -21,7 +21,12 @@ from ax.core.observation import ObservationFeatures from ax.core.trial_status import TrialStatus from ax.core.utils import extend_pending_observations, extract_pending_observations -from ax.exceptions.core import DataRequiredError, UnsupportedError, UserInputError +from ax.exceptions.core import ( + AxError, + DataRequiredError, + UnsupportedError, + UserInputError, +) from ax.exceptions.generation_strategy import ( GenerationStrategyCompleted, GenerationStrategyMisconfiguredException, @@ -286,23 +291,24 @@ def gen_single_trial( observations for that metric, used by some nodes to avoid resuggesting points that are currently being evaluated. """ - self.experiment = experiment - - gr = self._gen_with_multiple_nodes( + grs_for_trials = self.gen( experiment=experiment, data=data, - n=n, pending_observations=pending_observations, + n=n, fixed_features=fixed_features, + num_trials=1, ) - if len(gr) > 1: - raise UnsupportedError( - "By calling into GenerationStrategy.gen(), you are should be " - "expecting a single `Trial` with only one `GeneratorRun`. However, " - "the underlying GenerationStrategy produced multiple `GeneratorRuns` " - f"and returned the following list of `GeneratorRun`-s: {gr}" + # `gen` returns list[list[GeneratorRun]], so grs_for_trials[0] is the + # list of GeneratorRuns for the first (and only) trial. + if len(grs_for_trials) != 1 or len(grs := grs_for_trials[0]) != 1: + raise AxError( # Unexpected state of the GS; raise informatively. + "By calling into GenerationStrategy.gen_single_trial(), you are should" + " be expecting a single `Trial` with only one `GeneratorRun`. However," + "the underlying GenerationStrategy returned the following list " + f" of `GeneratorRun`-s: {grs_for_trials}." ) - return gr[0] + return grs[0] def gen( self, @@ -359,19 +365,14 @@ def gen( if pending_observations is None else deepcopy(pending_observations) ) - # TODO[@drfreund, @mgarrard]: Can we avoid having to check all TCs here? - # To do so, we would need: 1) another way to understand that there are - # no trial-counting TCs with a trial limit, 2) a way to, during `_gen_from - # multiple_nodes`, stop once we've generated (limit - pre-existing trials) - # new trials (just checking TCs won't work because it will look at the number - # of trials on the experiment but not at the would-be trials already produced - # in the loop). - new_trials_limit = self._curr.new_trial_limit(raise_generation_errors=False) - if new_trials_limit == -1: # There is no additional limit on new trials. - num_trials = max(num_trials, 1) - else: - num_trials = max(min(num_trials, new_trials_limit), 1) - for _i in range(num_trials): + # Only check trial limit when requesting multiple trials; when num_trials <= 1, + # the result is always 1 regardless of the limit. + if num_trials > 1: + new_trials_limit = self._curr.new_trial_limit(raise_generation_errors=False) + if new_trials_limit != -1: # There is an additional limit on new trials. + num_trials = min(num_trials, new_trials_limit) + num_trials = max(num_trials, 1) # Ensure at least 1 trial + for _ in range(num_trials): grs_for_multiple_trials.append( self._gen_with_multiple_nodes( experiment=experiment, diff --git a/ax/generation_strategy/tests/test_generation_strategy.py b/ax/generation_strategy/tests/test_generation_strategy.py index 654977da5e0..472e2d6fb57 100644 --- a/ax/generation_strategy/tests/test_generation_strategy.py +++ b/ax/generation_strategy/tests/test_generation_strategy.py @@ -33,6 +33,7 @@ get_pending_observation_features_based_on_trial_status as get_pending, ) from ax.exceptions.core import ( + AxError, DataRequiredError, SearchSpaceExhausted, UnsupportedError, @@ -694,6 +695,60 @@ def test_store_experiment(self) -> None: sobol_generation_strategy.gen_single_trial(exp) self.assertIsNotNone(sobol_generation_strategy._experiment) + def test_gen_single_trial_extracts_pending_observations(self) -> None: + """Test that gen_single_trial extracts pending_observations from the + experiment when none are passed in.""" + exp = get_branin_experiment() + gs = GenerationStrategy( + steps=[GenerationStep(generator=Generators.SOBOL, num_trials=5)] + ) + # Create a trial and mark it as running so it becomes a pending observation. + trial = exp.new_trial(generator_run=gs.gen_single_trial(exp)) + trial.mark_running(no_runner_required=True) + + # Now call gen_single_trial without passing pending_observations. + # It should extract them from the experiment (the running trial's arms). + with mock_patch_method_original( + mock_path=f"{GeneratorSpec.__module__}.GeneratorSpec.gen", + original_method=GeneratorSpec.gen, + ) as gen_spec_gen_mock: + gs.gen_single_trial(exp) + # Check that pending_observations was passed to the underlying gen call. + pending_obs = gen_spec_gen_mock.call_args.kwargs.get("pending_observations") + self.assertIsNotNone(pending_obs) + # The pending observations should contain the arm from the running trial. + expected_obs_feat = ObservationFeatures.from_arm( + arm=none_throws(trial.arm), trial_index=trial.index + ) + for metric_name in exp.metrics: + self.assertIn(metric_name, pending_obs) + self.assertIn(expected_obs_feat, pending_obs[metric_name]) + + def test_gen_single_trial_raises_error_for_multiple_trials(self) -> None: + """Test that gen_single_trial raises AxError if gen returns multiple trials.""" + exp = get_branin_experiment() + gs = GenerationStrategy( + steps=[GenerationStep(generator=Generators.SOBOL, num_trials=5)] + ) + gr = gs.gen_single_trial(exp) + # Mock gen to return multiple trials + with patch.object(gs, "gen", return_value=[[gr], [gr]]): + with self.assertRaisesRegex(AxError, "single `Trial`"): + gs.gen_single_trial(exp) + + def test_gen_single_trial_raises_error_for_multiple_generator_runs(self) -> None: + """Test that gen_single_trial raises AxError if gen returns multiple + GeneratorRuns for a single trial.""" + exp = get_branin_experiment() + gs = GenerationStrategy( + steps=[GenerationStep(generator=Generators.SOBOL, num_trials=5)] + ) + gr = gs.gen_single_trial(exp) + # Mock gen to return a single trial with multiple GeneratorRuns + with patch.object(gs, "gen", return_value=[[gr, gr]]): + with self.assertRaisesRegex(AxError, "only one `GeneratorRun`"): + gs.gen_single_trial(exp) + def test_max_parallelism_reached(self) -> None: exp = get_branin_experiment() sobol_generation_strategy = GenerationStrategy(