From 16608c80978fb7eee882f385afb17ff1c21df5dc Mon Sep 17 00:00:00 2001 From: Shruti Patel Date: Fri, 26 Dec 2025 09:38:30 -0800 Subject: [PATCH 1/6] Show Pareto frontier on MOO objective scatter plots (#4708) Summary: When `ResultsAnalysis` creates scatter plots for multi-objective optimization (MOO) experiments, the Pareto frontier line is not currently being displayed. This diff enables the Pareto frontier visualization by passing `show_pareto_frontier=True` to `ScatterPlot` when creating scatter plots for objective pairs in MOO experiments. The Pareto frontier is rendered as a dashed gold line connecting the non-dominated points, helping users identify the set of optimal solutions. Differential Revision: D89775987 Privacy Context Container: L1307644 --- ax/analysis/plotly/scatter.py | 6 +++++- ax/analysis/plotly/tests/test_scatter.py | 20 ++++++++++++++++++++ ax/analysis/results.py | 3 ++- 3 files changed, 27 insertions(+), 2 deletions(-) diff --git a/ax/analysis/plotly/scatter.py b/ax/analysis/plotly/scatter.py index a8a10cbc62c..83ebe14f346 100644 --- a/ax/analysis/plotly/scatter.py +++ b/ax/analysis/plotly/scatter.py @@ -571,7 +571,11 @@ def _prepare_figure( pareto_x.append(sorted_df[f"{x_metric_name}_mean"].iloc[i]) pareto_y.append(sorted_df[f"{y_metric_name}_mean"].iloc[i]) - pareto_trace = go.Scatter(x=pareto_x, y=pareto_y, **BEST_LINE_SETTINGS) + pareto_trace = go.Scatter( + x=pareto_x, + y=pareto_y, + **{**BEST_LINE_SETTINGS, "showlegend": True, "name": "Pareto Frontier"}, + ) figure.add_trace(pareto_trace) diff --git a/ax/analysis/plotly/tests/test_scatter.py b/ax/analysis/plotly/tests/test_scatter.py index 59b5e4a5568..71dfefb59c2 100644 --- a/ax/analysis/plotly/tests/test_scatter.py +++ b/ax/analysis/plotly/tests/test_scatter.py @@ -142,6 +142,26 @@ def test_compute_raw(self) -> None: self.assertTrue(card.df["foo_sem"].isna().all()) self.assertTrue(card.df["bar_sem"].isna().all()) + def test_show_pareto_frontier(self) -> None: + analysis = ScatterPlot( + x_metric_name="foo", + y_metric_name="bar", + show_pareto_frontier=True, + use_model_predictions=False, + ) + card = analysis.compute( + experiment=self.client._experiment, + generation_strategy=self.client._generation_strategy, + ) + fig_data = json.loads(none_throws(card.blob)) + pareto_traces = [ + trace + for trace in fig_data.get("data", []) + if trace.get("name") == "Pareto Frontier" + ] + self.assertEqual(len(pareto_traces), 1) + self.assertTrue(pareto_traces[0].get("showlegend")) + def test_compute_with_modeled(self) -> None: default_analysis = ScatterPlot( x_metric_name="foo", y_metric_name="bar", use_model_predictions=True diff --git a/ax/analysis/results.py b/ax/analysis/results.py index cd61e820c8f..63b5dd791fe 100644 --- a/ax/analysis/results.py +++ b/ax/analysis/results.py @@ -120,7 +120,7 @@ def compute( ) # If there are multiple objectives, compute scatter plots of each combination - # of two objectives. + # of two objectives. For MOO experiments, show the Pareto frontier line. 
objective_scatter_group = ( AnalysisCardGroup( name="Objective Scatter Plots", @@ -131,6 +131,7 @@ def compute( x_metric_name=x, y_metric_name=y, relativize=relativize, + show_pareto_frontier=True, ).compute_or_error_card( experiment=experiment, generation_strategy=generation_strategy, From 893ca42b27bd29beba10ec3bbe5ebdb480db1efa Mon Sep 17 00:00:00 2001 From: Shruti Patel Date: Fri, 26 Dec 2025 09:38:30 -0800 Subject: [PATCH 2/6] Add Progression Plots for MapMetric experiments to ResultsAnalysis (#4705) Summary: This diff adds learning curve visualization (progression plots) to ResultsAnalysis for experiments with MapData and MapMetrics. Differential Revision: D89776181 Privacy Context Container: L1307644 --- ax/analysis/plotly/progression.py | 9 ++++++++ ax/analysis/results.py | 35 +++++++++++++++++++++++++++++++ ax/analysis/tests/test_results.py | 34 ++++++++++++++++++++++++++++++ 3 files changed, 78 insertions(+) diff --git a/ax/analysis/plotly/progression.py b/ax/analysis/plotly/progression.py index 53ad42174f9..b2a6e4d8253 100644 --- a/ax/analysis/plotly/progression.py +++ b/ax/analysis/plotly/progression.py @@ -26,6 +26,15 @@ from plotly import graph_objects as go from pyre_extensions import none_throws, override +PROGRESSION_CARDGROUP_TITLE = "Learning Curves: Metric progression over trials" +PROGRESSION_CARDGROUP_SUBTITLE = ( + "These plots show curve metrics (learning curves) that track the evolution of " + "each metric over the course of the experiment. The plots display how metrics " + "change during trial execution, both by progression (e.g., epochs or steps) " + "and by wallclock time. This is useful for monitoring optimization progress and " + "informing early stopping decisions." +) + @final class ProgressionPlot(Analysis): diff --git a/ax/analysis/results.py b/ax/analysis/results.py index 63b5dd791fe..49fec7c2dfa 100644 --- a/ax/analysis/results.py +++ b/ax/analysis/results.py @@ -13,6 +13,11 @@ from ax.analysis.best_trials import BestTrials from ax.analysis.plotly.arm_effects import ArmEffectsPlot from ax.analysis.plotly.bandit_rollout import BanditRollout +from ax.analysis.plotly.progression import ( + PROGRESSION_CARDGROUP_SUBTITLE, + PROGRESSION_CARDGROUP_TITLE, + ProgressionPlot, +) from ax.analysis.plotly.scatter import ( SCATTER_CARDGROUP_SUBTITLE, SCATTER_CARDGROUP_TITLE, @@ -25,6 +30,8 @@ from ax.core.arm import Arm from ax.core.batch_trial import BatchTrial from ax.core.experiment import Experiment +from ax.core.map_data import MapData +from ax.core.map_metric import MapMetric from ax.core.outcome_constraint import ScalarizedOutcomeConstraint from ax.core.trial_status import TrialStatus from ax.core.utils import is_bandit_experiment @@ -240,6 +247,33 @@ def compute( adapter=adapter, ) + # Compute progression plots for MapMetrics (learning curves) + progression_group = None + data = experiment.lookup_data() + has_map_data = isinstance(data, MapData) + metrics = experiment.metrics.values() + map_metrics = [m for m in metrics if isinstance(m, MapMetric)] + if has_map_data and len(map_metrics) > 0: + map_metric_names = [m.name for m in map_metrics] + progression_cards = [ + ProgressionPlot( + metric_name=metric_name, by_wallclock_time=by_wallclock_time + ).compute_or_error_card( + experiment=experiment, + generation_strategy=generation_strategy, + adapter=adapter, + ) + for metric_name in map_metric_names + for by_wallclock_time in (False, True) + ] + if progression_cards: + progression_group = AnalysisCardGroup( + name="ProgressionAnalysis", + 
title=PROGRESSION_CARDGROUP_TITLE, + subtitle=PROGRESSION_CARDGROUP_SUBTITLE, + children=progression_cards, + ) + return self._create_analysis_card_group( title=RESULTS_CARDGROUP_TITLE, subtitle=RESULTS_CARDGROUP_SUBTITLE, @@ -252,6 +286,7 @@ def compute( bandit_rollout_card, best_trials_card, utility_progression_card, + progression_group, summary, ) if child is not None diff --git a/ax/analysis/tests/test_results.py b/ax/analysis/tests/test_results.py index cd3dfa296ef..2f99cf1bffa 100644 --- a/ax/analysis/tests/test_results.py +++ b/ax/analysis/tests/test_results.py @@ -35,6 +35,7 @@ get_experiment_with_scalarized_objective_and_outcome_constraint, get_offline_experiments, get_online_experiments, + get_test_map_data_experiment, ) from ax.utils.testing.mock import mock_botorch_optimize from ax.utils.testing.modeling_stubs import get_default_generation_strategy_at_MBM_node @@ -499,6 +500,39 @@ def test_offline_experiments(self) -> None: self.assertIsNotNone(card_group) self.assertGreater(len(card_group.children), 0) + @mock_botorch_optimize + def test_compute_with_map_data_includes_progression_plots(self) -> None: + # Setup: Create experiment with MapData and MapMetrics + experiment = get_test_map_data_experiment( + num_trials=3, num_fetches=2, num_complete=2 + ) + generation_strategy = get_default_generation_strategy_at_MBM_node( + experiment=experiment + ) + + # Execute: Compute ResultsAnalysis + card_group = ResultsAnalysis().compute( + experiment=experiment, + generation_strategy=generation_strategy, + ) + + # Assert: ProgressionAnalysis group exists with children + progression_group = None + for child in card_group.children: + if child.name == "ProgressionAnalysis": + progression_group = child + break + + self.assertIsNotNone( + progression_group, + "ProgressionAnalysis group should be present for MapMetric experiments", + ) + self.assertGreater( + len(assert_is_instance(progression_group, AnalysisCardGroup).children), + 0, + "ProgressionAnalysis group should have at least one progression plot", + ) + class TestArmEffectsPair(TestCase): @mock_botorch_optimize From b9207ff72168ab05b269438c9e0841e81b9f75e5 Mon Sep 17 00:00:00 2001 From: Shruti Patel Date: Fri, 26 Dec 2025 09:38:30 -0800 Subject: [PATCH 3/6] Decouple summarize_ax_optimization_complexity from OrchestratorOptions (#4706) Summary: This change decouples `summarize_ax_optimization_complexity` from requiring an `OrchestratorOptions` instance by accepting individual optional fields instead. 
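
As an illustrative sketch (not part of this diff), a caller that previously built a full `OrchestratorOptions` just to pass it through can now forward only the fields the summary actually reads; anything omitted falls back to the same defaults `OrchestratorOptions` uses (0.5 / 10 / 5):

```python
# Hypothetical call site, assuming an existing `experiment`:
summary = summarize_ax_optimization_complexity(
    experiment=experiment,
    tier_metadata={"user_supplied_max_trials": 100, "uses_standard_api": True},
    # Only pass the settings that differ from the defaults.
    tolerated_trial_failure_rate=0.25,
)
```
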
Differential Revision: D89778530 --- ax/analysis/healthcheck/complexity_rating.py | 8 +++- ax/service/utils/orchestrator_options.py | 13 ++++-- ax/utils/common/complexity_utils.py | 46 +++++++++++++++---- .../common/tests/test_complexity_utils.py | 28 ----------- 4 files changed, 53 insertions(+), 42 deletions(-) diff --git a/ax/analysis/healthcheck/complexity_rating.py b/ax/analysis/healthcheck/complexity_rating.py index b4af7cdce2b..7484cb6f9a9 100644 --- a/ax/analysis/healthcheck/complexity_rating.py +++ b/ax/analysis/healthcheck/complexity_rating.py @@ -138,8 +138,14 @@ def compute( options = none_throws(self.options) optimization_summary = summarize_ax_optimization_complexity( experiment=experiment, - options=options, tier_metadata=self.tier_metadata, + early_stopping_strategy=options.early_stopping_strategy, + global_stopping_strategy=options.global_stopping_strategy, + tolerated_trial_failure_rate=options.tolerated_trial_failure_rate, + max_pending_trials=options.max_pending_trials, + min_failed_trials_for_failure_rate_check=( + options.min_failed_trials_for_failure_rate_check + ), ) # Determine tier diff --git a/ax/service/utils/orchestrator_options.py b/ax/service/utils/orchestrator_options.py index 70ab3b04cf3..2534e276bf4 100644 --- a/ax/service/utils/orchestrator_options.py +++ b/ax/service/utils/orchestrator_options.py @@ -13,6 +13,11 @@ from ax.early_stopping.strategies import BaseEarlyStoppingStrategy from ax.global_stopping.strategies.base import BaseGlobalStoppingStrategy +# Default values for OrchestratorOptions fields +DEFAULT_MAX_PENDING_TRIALS: int = 10 +DEFAULT_TOLERATED_TRIAL_FAILURE_RATE: float = 0.5 +DEFAULT_MIN_FAILED_TRIALS_FOR_FAILURE_RATE_CHECK: int = 5 + class TrialType(Enum): TRIAL = 0 @@ -125,12 +130,14 @@ class OrchestratorOptions: Default to False. """ - max_pending_trials: int = 10 + max_pending_trials: int = DEFAULT_MAX_PENDING_TRIALS trial_type: TrialType = TrialType.TRIAL batch_size: int | None = None total_trials: int | None = None - tolerated_trial_failure_rate: float = 0.5 - min_failed_trials_for_failure_rate_check: int = 5 + tolerated_trial_failure_rate: float = DEFAULT_TOLERATED_TRIAL_FAILURE_RATE + min_failed_trials_for_failure_rate_check: int = ( + DEFAULT_MIN_FAILED_TRIALS_FOR_FAILURE_RATE_CHECK + ) log_filepath: str | None = None logging_level: int = INFO ttl_seconds_for_trials: int | None = None diff --git a/ax/utils/common/complexity_utils.py b/ax/utils/common/complexity_utils.py index 795599de7c2..75a863343d2 100644 --- a/ax/utils/common/complexity_utils.py +++ b/ax/utils/common/complexity_utils.py @@ -12,8 +12,14 @@ from ax.adapter.adapter_utils import can_map_to_binary, is_unordered_choice from ax.core.experiment import Experiment from ax.core.objective import MultiObjective +from ax.early_stopping.strategies import BaseEarlyStoppingStrategy from ax.exceptions.core import OptimizationNotConfiguredError, UserInputError -from ax.service.orchestrator import OrchestratorOptions +from ax.global_stopping.strategies.base import BaseGlobalStoppingStrategy +from ax.service.utils.orchestrator_options import ( + DEFAULT_MAX_PENDING_TRIALS, + DEFAULT_MIN_FAILED_TRIALS_FOR_FAILURE_RATE_CHECK, + DEFAULT_TOLERATED_TRIAL_FAILURE_RATE, +) STANDARD_TIER_MESSAGE = """This experiment is in tier 'Standard'. 
@@ -141,8 +147,14 @@ class OptimizationSummary: def summarize_ax_optimization_complexity( experiment: Experiment, - options: OrchestratorOptions, tier_metadata: dict[str, Any], + early_stopping_strategy: BaseEarlyStoppingStrategy | None = None, + global_stopping_strategy: BaseGlobalStoppingStrategy | None = None, + tolerated_trial_failure_rate: float = DEFAULT_TOLERATED_TRIAL_FAILURE_RATE, + max_pending_trials: int = DEFAULT_MAX_PENDING_TRIALS, + min_failed_trials_for_failure_rate_check: int = ( + DEFAULT_MIN_FAILED_TRIALS_FOR_FAILURE_RATE_CHECK + ), ) -> OptimizationSummary: """Summarize the experiment's optimization complexity. @@ -151,11 +163,25 @@ def summarize_ax_optimization_complexity( Args: experiment: The Ax Experiment. - options: The orchestrator options. - tier_metadata: tier-related meta-data from the orchestrator. + tier_metadata: Tier-related metadata from the orchestrator. Supported keys: + - 'user_supplied_max_trials': Maximum number of trials. + - 'uses_standard_api': Whether standard api is used, ensuring the full + experiment configuration is known upfront. + - 'all_inputs_are_configs': Whether high-level configs are used (as + opposed to low-level Ax abstractions), ensuring the full + experiment configuration is known upfront. + early_stopping_strategy: The early stopping strategy, if any. Used to + determine if early stopping is enabled. Defaults to None. + global_stopping_strategy: The global stopping strategy, if any. Used to + determine if global stopping is enabled. Defaults to None. + tolerated_trial_failure_rate: Fraction of trials allowed to fail without + the whole optimization ending. Defaults to 0.5. + max_pending_trials: Maximum number of pending trials. Defaults to 10. + min_failed_trials_for_failure_rate_check: Minimum failed trials before + failure rate is checked. Defaults to 5. Returns: - A dictionary summarizing the experiment. + An OptimizationSummary containing experiment complexity metrics. 
""" search_space = experiment.search_space optimization_config = experiment.optimization_config @@ -179,8 +205,8 @@ def summarize_ax_optimization_complexity( else 1 ) num_outcome_constraints = len(optimization_config.outcome_constraints) - uses_early_stopping = options.early_stopping_strategy is not None - uses_global_stopping = options.global_stopping_strategy is not None + uses_early_stopping = early_stopping_strategy is not None + uses_global_stopping = global_stopping_strategy is not None # Check if any metrics use merge_multiple_curves uses_merge_multiple_curves = False @@ -210,10 +236,10 @@ def summarize_ax_optimization_complexity( uses_global_stopping=uses_global_stopping, uses_merge_multiple_curves=uses_merge_multiple_curves, uses_standard_api=uses_standard_api, - tolerated_trial_failure_rate=options.tolerated_trial_failure_rate, - max_pending_trials=options.max_pending_trials, + tolerated_trial_failure_rate=tolerated_trial_failure_rate, + max_pending_trials=max_pending_trials, min_failed_trials_for_failure_rate_check=( - options.min_failed_trials_for_failure_rate_check + min_failed_trials_for_failure_rate_check ), ) diff --git a/ax/utils/common/tests/test_complexity_utils.py b/ax/utils/common/tests/test_complexity_utils.py index 2cbe221662d..8578ac86084 100644 --- a/ax/utils/common/tests/test_complexity_utils.py +++ b/ax/utils/common/tests/test_complexity_utils.py @@ -8,7 +8,6 @@ from ax.core.metric import Metric from ax.exceptions.core import OptimizationNotConfiguredError, UserInputError -from ax.service.orchestrator import OrchestratorOptions from ax.utils.common.complexity_utils import ( check_if_in_standard, DEFAULT_TIER_MESSAGES, @@ -30,7 +29,6 @@ class TestSummarizeAxOptimizationComplexity(TestCase): def setUp(self) -> None: super().setUp() self.experiment = get_experiment() - self.options = OrchestratorOptions() self.tier_metadata: dict[str, object] = {} def test_basic_experiment_summary(self) -> None: @@ -39,7 +37,6 @@ def test_basic_experiment_summary(self) -> None: # WHEN we summarize the experiment summary = summarize_ax_optimization_complexity( experiment=self.experiment, - options=self.options, tier_metadata=self.tier_metadata, ) @@ -58,7 +55,6 @@ def test_multi_objective_experiment(self) -> None: # WHEN we summarize the experiment summary = summarize_ax_optimization_complexity( experiment=experiment, - options=self.options, tier_metadata=self.tier_metadata, ) @@ -76,7 +72,6 @@ def test_experiment_without_optimization_config_raises(self) -> None: ): summarize_ax_optimization_complexity( experiment=self.experiment, - options=self.options, tier_metadata=self.tier_metadata, ) @@ -107,7 +102,6 @@ def test_tier_metadata_extraction(self) -> None: # WHEN we summarize the experiment summary = summarize_ax_optimization_complexity( experiment=self.experiment, - options=self.options, tier_metadata=tier_metadata, ) @@ -115,26 +109,6 @@ def test_tier_metadata_extraction(self) -> None: self.assertEqual(summary.max_trials, expected_max_trials) self.assertEqual(summary.uses_standard_api, expected_all_configs) - def test_orchestrator_options_extraction(self) -> None: - # GIVEN custom orchestrator options - options = OrchestratorOptions( - tolerated_trial_failure_rate=0.25, - max_pending_trials=5, - min_failed_trials_for_failure_rate_check=10, - ) - - # WHEN we summarize the experiment - summary = summarize_ax_optimization_complexity( - experiment=self.experiment, - options=options, - tier_metadata=self.tier_metadata, - ) - - # THEN the summary should reflect orchestrator options - 
self.assertEqual(summary.tolerated_trial_failure_rate, 0.25) - self.assertEqual(summary.max_pending_trials, 5) - self.assertEqual(summary.min_failed_trials_for_failure_rate_check, 10) - def test_parameter_constraints_counted(self) -> None: # GIVEN an experiment with parameter constraints experiment = get_experiment(constrain_search_space=True) @@ -142,7 +116,6 @@ def test_parameter_constraints_counted(self) -> None: # WHEN we summarize the experiment summary = summarize_ax_optimization_complexity( experiment=experiment, - options=self.options, tier_metadata=self.tier_metadata, ) @@ -159,7 +132,6 @@ def test_merge_multiple_curves_detection(self) -> None: # WHEN we summarize the experiment summary = summarize_ax_optimization_complexity( experiment=self.experiment, - options=self.options, tier_metadata=self.tier_metadata, ) From 94a61d46cba461bdd37d3b8e4a86f28409d3b402 Mon Sep 17 00:00:00 2001 From: Shruti Patel Date: Fri, 26 Dec 2025 09:38:30 -0800 Subject: [PATCH 4/6] Update Complexity Rating Healthcheck (#4707) Summary: This diff refactors ComplexityRatingAnalysis to decouple it from OrchestratorOptions by accepting individual configuration parameters directly instead of the entire options object. Differential Revision: D89778632 Privacy Context Container: L1307644 --- ax/analysis/healthcheck/complexity_rating.py | 53 +++--- .../tests/test_complexity_rating.py | 169 ++++++++---------- ax/analysis/utils.py | 18 +- 3 files changed, 127 insertions(+), 113 deletions(-) diff --git a/ax/analysis/healthcheck/complexity_rating.py b/ax/analysis/healthcheck/complexity_rating.py index 7484cb6f9a9..fb015309294 100644 --- a/ax/analysis/healthcheck/complexity_rating.py +++ b/ax/analysis/healthcheck/complexity_rating.py @@ -15,9 +15,11 @@ HealthcheckAnalysisCard, HealthcheckStatus, ) +from ax.analysis.utils import filter_none from ax.core.experiment import Experiment +from ax.early_stopping.strategies import BaseEarlyStoppingStrategy from ax.generation_strategy.generation_strategy import GenerationStrategy -from ax.service.orchestrator import OrchestratorOptions +from ax.global_stopping.strategies.base import BaseGlobalStoppingStrategy from ax.utils.common.complexity_utils import ( check_if_in_standard, DEFAULT_TIER_MESSAGES, @@ -61,16 +63,17 @@ class ComplexityRatingAnalysis(Analysis): def __init__( self, - options: OrchestratorOptions | None = None, tier_metadata: dict[str, Any] | None = None, tier_messages: TierMessages = DEFAULT_TIER_MESSAGES, + early_stopping_strategy: BaseEarlyStoppingStrategy | None = None, + global_stopping_strategy: BaseGlobalStoppingStrategy | None = None, + tolerated_trial_failure_rate: float | None = None, + max_pending_trials: int | None = None, + min_failed_trials_for_failure_rate_check: int | None = None, ) -> None: """Initialize the ComplexityRatingAnalysis. Args: - options: The orchestrator options used for the optimization. - Required to evaluate early stopping, global stopping, and - failure rate settings. tier_metadata: Additional tier-related metadata from the orchestrator. Supported keys: - 'user_supplied_max_trials': Maximum number of trials. @@ -82,12 +85,27 @@ def __init__( generic messages suitable for most users. Pass a custom TierMessages instance to provide tool-specific descriptions, support SLAs, links to docs, or contact information. + early_stopping_strategy: The early stopping strategy, if any. Used to + determine if early stopping is enabled. Defaults to None. + global_stopping_strategy: The global stopping strategy, if any. 
Used to + determine if global stopping is enabled. Defaults to None. + tolerated_trial_failure_rate: Fraction of trials allowed to fail without + the whole optimization ending. Default value used is 0.5. + max_pending_trials: Maximum number of pending trials. Default used is 10. + min_failed_trials_for_failure_rate_check: Minimum failed trials before + failure rate is checked. Default value used is 5. """ - self.options = options self.tier_metadata: dict[str, Any] = ( tier_metadata if tier_metadata is not None else {} ) self.tier_messages = tier_messages + self.early_stopping_strategy = early_stopping_strategy + self.global_stopping_strategy = global_stopping_strategy + self.tolerated_trial_failure_rate = tolerated_trial_failure_rate + self.max_pending_trials = max_pending_trials + self.min_failed_trials_for_failure_rate_check = ( + min_failed_trials_for_failure_rate_check + ) @override def validate_applicable_state( @@ -98,11 +116,6 @@ def validate_applicable_state( ) -> str | None: if experiment is None: return "Experiment is required for ComplexityRatingAnalysis." - if self.options is None: - return ( - "OrchestratorOptions is required for ComplexityRatingAnalysis. " - "Please pass options to the constructor." - ) return None @override @@ -120,8 +133,7 @@ def compute( Note: This method assumes ``validate_applicable_state`` has been called - and returned None, ensuring ``experiment`` and ``self.options`` - are not None. + and returned None, ensuring ``experiment`` is not None. Args: experiment: The Ax Experiment to analyze. Must not be None. @@ -135,16 +147,17 @@ def compute( with key experiment metrics. """ experiment = none_throws(experiment) - options = none_throws(self.options) optimization_summary = summarize_ax_optimization_complexity( experiment=experiment, tier_metadata=self.tier_metadata, - early_stopping_strategy=options.early_stopping_strategy, - global_stopping_strategy=options.global_stopping_strategy, - tolerated_trial_failure_rate=options.tolerated_trial_failure_rate, - max_pending_trials=options.max_pending_trials, - min_failed_trials_for_failure_rate_check=( - options.min_failed_trials_for_failure_rate_check + early_stopping_strategy=self.early_stopping_strategy, + global_stopping_strategy=self.global_stopping_strategy, + **filter_none( + tolerated_trial_failure_rate=self.tolerated_trial_failure_rate, + max_pending_trials=self.max_pending_trials, + min_failed_trials_for_failure_rate_check=( + self.min_failed_trials_for_failure_rate_check + ), ), ) diff --git a/ax/analysis/healthcheck/tests/test_complexity_rating.py b/ax/analysis/healthcheck/tests/test_complexity_rating.py index d32ed272d42..f9b189395f1 100644 --- a/ax/analysis/healthcheck/tests/test_complexity_rating.py +++ b/ax/analysis/healthcheck/tests/test_complexity_rating.py @@ -20,48 +20,36 @@ from ax.core.parameter_constraint import ParameterConstraint from ax.core.search_space import SearchSpace from ax.core.types import ComparisonOp -from ax.service.orchestrator import OrchestratorOptions from ax.utils.common.testutils import TestCase -from ax.utils.testing.core_stubs import get_branin_experiment +from ax.utils.testing.core_stubs import ( + get_branin_experiment, + get_improvement_global_stopping_strategy, + get_percentile_early_stopping_strategy, +) class TestComplexityRatingAnalysis(TestCase): def setUp(self) -> None: super().setUp() self.experiment = get_branin_experiment() - self.options = OrchestratorOptions() self.tier_metadata: dict[str, object] = { "user_supplied_max_trials": 100, "uses_standard_api": 
True, } def test_validate_applicable_state_requires_experiment(self) -> None: - healthcheck = ComplexityRatingAnalysis( - options=self.options, tier_metadata=self.tier_metadata - ) + healthcheck = ComplexityRatingAnalysis(tier_metadata=self.tier_metadata) result = healthcheck.validate_applicable_state(experiment=None) self.assertIsNotNone(result) self.assertIn("Experiment is required", result) - def test_validate_applicable_state_requires_options(self) -> None: - healthcheck = ComplexityRatingAnalysis( - options=None, tier_metadata=self.tier_metadata - ) - result = healthcheck.validate_applicable_state(experiment=self.experiment) - self.assertIsNotNone(result) - self.assertIn("OrchestratorOptions is required", result) - def test_validate_applicable_state_passes_with_valid_inputs(self) -> None: - healthcheck = ComplexityRatingAnalysis( - options=self.options, tier_metadata=self.tier_metadata - ) + healthcheck = ComplexityRatingAnalysis(tier_metadata=self.tier_metadata) result = healthcheck.validate_applicable_state(experiment=self.experiment) self.assertIsNone(result) def test_standard_configuration(self) -> None: - healthcheck = ComplexityRatingAnalysis( - options=self.options, tier_metadata=self.tier_metadata - ) + healthcheck = ComplexityRatingAnalysis(tier_metadata=self.tier_metadata) card = healthcheck.compute(experiment=self.experiment) self.assertEqual(card.name, "ComplexityRatingAnalysis") @@ -90,7 +78,7 @@ def test_parameter_counts(self) -> None: ] self.experiment._search_space = SearchSpace(parameters=params) card = ComplexityRatingAnalysis( - options=self.options, tier_metadata=self.tier_metadata + tier_metadata=self.tier_metadata ).compute(experiment=self.experiment) self.assertEqual(card.get_status(), expected_status) @@ -115,7 +103,7 @@ def test_objectives_count(self) -> None: ) ) card = ComplexityRatingAnalysis( - options=self.options, tier_metadata=self.tier_metadata + tier_metadata=self.tier_metadata ).compute(experiment=self.experiment) self.assertEqual(card.get_status(), expected_status) @@ -132,9 +120,9 @@ def test_constraints(self) -> None: for m in metrics ], ) - card = ComplexityRatingAnalysis( - options=self.options, tier_metadata=self.tier_metadata - ).compute(experiment=self.experiment) + card = ComplexityRatingAnalysis(tier_metadata=self.tier_metadata).compute( + experiment=self.experiment + ) self.assertEqual(card.get_status(), HealthcheckStatus.WARNING) self.assertIn("Advanced", card.subtitle) @@ -156,35 +144,34 @@ def test_constraints(self) -> None: self.experiment._search_space = SearchSpace( parameters=params, parameter_constraints=parameter_constraints ) - card = ComplexityRatingAnalysis( - options=self.options, tier_metadata=self.tier_metadata - ).compute(experiment=self.experiment) + card = ComplexityRatingAnalysis(tier_metadata=self.tier_metadata).compute( + experiment=self.experiment + ) self.assertEqual(card.get_status(), HealthcheckStatus.WARNING) self.assertIn("Advanced", card.subtitle) self.assertIn("3 parameter constraints", card.subtitle) def test_stopping_strategies(self) -> None: - test_cases = [ - ("early_stopping", True, False, "Early stopping"), - ("global_stopping", False, True, "Global stopping"), - ] + with self.subTest(strategy="early_stopping"): + card = ComplexityRatingAnalysis( + tier_metadata=self.tier_metadata, + early_stopping_strategy=get_percentile_early_stopping_strategy(), + ).compute(experiment=self.experiment) - for name, uses_early, uses_global, expected_msg in test_cases: - with self.subTest(strategy=name): - options = 
OrchestratorOptions( - # pyre-fixme[6]: Using a mock value for testing - early_stopping_strategy="mock" if uses_early else None, - # pyre-fixme[6]: Using a mock value for testing - global_stopping_strategy="mock" if uses_global else None, - ) - card = ComplexityRatingAnalysis( - options=options, tier_metadata=self.tier_metadata - ).compute(experiment=self.experiment) + self.assertEqual(card.get_status(), HealthcheckStatus.WARNING) + self.assertIn("Advanced", card.subtitle) + self.assertIn("Early stopping", card.subtitle) - self.assertEqual(card.get_status(), HealthcheckStatus.WARNING) - self.assertIn("Advanced", card.subtitle) - self.assertIn(expected_msg, card.subtitle) + with self.subTest(strategy="global_stopping"): + card = ComplexityRatingAnalysis( + tier_metadata=self.tier_metadata, + global_stopping_strategy=get_improvement_global_stopping_strategy(), + ).compute(experiment=self.experiment) + + self.assertEqual(card.get_status(), HealthcheckStatus.WARNING) + self.assertIn("Advanced", card.subtitle) + self.assertIn("Global stopping", card.subtitle) def test_trial_counts(self) -> None: test_cases = [ @@ -198,48 +185,47 @@ def test_trial_counts(self) -> None: "user_supplied_max_trials": max_trials, "uses_standard_api": True, } - card = ComplexityRatingAnalysis( - options=self.options, tier_metadata=tier_metadata - ).compute(experiment=self.experiment) + card = ComplexityRatingAnalysis(tier_metadata=tier_metadata).compute( + experiment=self.experiment + ) self.assertEqual(card.get_status(), expected_status) self.assertIn(expected_tier, card.subtitle) self.assertIn(expected_msg, card.subtitle) def test_unsupported_configurations(self) -> None: - test_cases = [ - ( - "not_using_standard_api", - OrchestratorOptions(), - {"user_supplied_max_trials": 100, "uses_standard_api": False}, - "uses_standard_api=False", - ), - ( - "high_failure_rate", - OrchestratorOptions(tolerated_trial_failure_rate=0.95), - {"user_supplied_max_trials": 100, "uses_standard_api": True}, - "0.95", - ), - ( - "invalid_failure_rate_check", - OrchestratorOptions( - max_pending_trials=10, - min_failed_trials_for_failure_rate_check=50, - ), - {"user_supplied_max_trials": 100, "uses_standard_api": True}, - "min_failed_trials_for_failure_rate_check", - ), - ] + with self.subTest(config="not_using_standard_api"): + tier_metadata = { + "user_supplied_max_trials": 100, + "uses_standard_api": False, + } + card = ComplexityRatingAnalysis(tier_metadata=tier_metadata).compute( + experiment=self.experiment + ) + self.assertEqual(card.get_status(), HealthcheckStatus.FAIL) + self.assertIn("Unsupported", card.subtitle) + self.assertIn("uses_standard_api=False", card.subtitle) - for name, options, tier_metadata, expected_msg in test_cases: - with self.subTest(config=name): - card = ComplexityRatingAnalysis( - options=options, tier_metadata=tier_metadata - ).compute(experiment=self.experiment) + with self.subTest(config="high_failure_rate"): + tier_metadata = {"user_supplied_max_trials": 100, "uses_standard_api": True} + card = ComplexityRatingAnalysis( + tier_metadata=tier_metadata, + tolerated_trial_failure_rate=0.95, + ).compute(experiment=self.experiment) + self.assertEqual(card.get_status(), HealthcheckStatus.FAIL) + self.assertIn("Unsupported", card.subtitle) + self.assertIn("0.95", card.subtitle) - self.assertEqual(card.get_status(), HealthcheckStatus.FAIL) - self.assertIn("Unsupported", card.subtitle) - self.assertIn(expected_msg, card.subtitle) + with self.subTest(config="invalid_failure_rate_check"): + tier_metadata = 
{"user_supplied_max_trials": 100, "uses_standard_api": True} + card = ComplexityRatingAnalysis( + tier_metadata=tier_metadata, + max_pending_trials=10, + min_failed_trials_for_failure_rate_check=50, + ).compute(experiment=self.experiment) + self.assertEqual(card.get_status(), HealthcheckStatus.FAIL) + self.assertIn("Unsupported", card.subtitle) + self.assertIn("min_failed_trials_for_failure_rate_check", card.subtitle) def test_unordered_choice_parameters(self) -> None: params = [ @@ -257,9 +243,9 @@ def test_unordered_choice_parameters(self) -> None: self.assertTrue(is_unordered_choice(params[1], min_choices=3, max_choices=5)) - card = ComplexityRatingAnalysis( - options=self.options, tier_metadata=self.tier_metadata - ).compute(experiment=self.experiment) + card = ComplexityRatingAnalysis(tier_metadata=self.tier_metadata).compute( + experiment=self.experiment + ) self.assertEqual(card.get_status(), HealthcheckStatus.WARNING) self.assertIn("Advanced", card.subtitle) @@ -279,9 +265,9 @@ def test_binary_parameters_count(self) -> None: for p in params: self.assertTrue(can_map_to_binary(p)) - card = ComplexityRatingAnalysis( - options=self.options, tier_metadata=self.tier_metadata - ).compute(experiment=self.experiment) + card = ComplexityRatingAnalysis(tier_metadata=self.tier_metadata).compute( + experiment=self.experiment + ) self.assertEqual(card.get_status(), HealthcheckStatus.WARNING) self.assertIn("Advanced", card.subtitle) @@ -300,10 +286,9 @@ def test_multiple_violations(self) -> None: experiment = self.experiment experiment._search_space = SearchSpace(parameters=params) tier_metadata = {"user_supplied_max_trials": 300, "uses_standard_api": True} - # pyre-ignore[6]: Using a mock value for testing - options = OrchestratorOptions(early_stopping_strategy="mock") card = ComplexityRatingAnalysis( - options=options, tier_metadata=tier_metadata + tier_metadata=tier_metadata, + early_stopping_strategy=get_percentile_early_stopping_strategy(), ).compute(experiment=experiment) self.assertEqual(card.get_status(), HealthcheckStatus.WARNING) @@ -313,9 +298,9 @@ def test_multiple_violations(self) -> None: self.assertIn("Early stopping is enabled", card.subtitle) def test_dataframe_summary(self) -> None: - card = ComplexityRatingAnalysis( - options=self.options, tier_metadata=self.tier_metadata - ).compute(experiment=self.experiment) + card = ComplexityRatingAnalysis(tier_metadata=self.tier_metadata).compute( + experiment=self.experiment + ) df = card.df self.assertIsNotNone(df) diff --git a/ax/analysis/utils.py b/ax/analysis/utils.py index fedb9af79c6..12a82a634f5 100644 --- a/ax/analysis/utils.py +++ b/ax/analysis/utils.py @@ -6,7 +6,7 @@ # pyre-strict from logging import Logger -from typing import Sequence +from typing import Any, Sequence import numpy as np @@ -1071,3 +1071,19 @@ def validate_outcome_constraints( ) return None + + +def filter_none(**kwargs: Any) -> dict[str, Any]: + """Return a dict with only non-None values. + + Useful for conditionally passing optional keyword arguments to functions + that have defaults. Only non-None values are included in the returned dict, + allowing the called function to use its own defaults for omitted parameters. + + Args: + **kwargs: Keyword arguments to filter. + + Returns: + A dict containing only the key-value pairs where value is not None. 
+ """ + return {k: v for k, v in kwargs.items() if v is not None} From c683572a9217c58dc8541def92184dbb76d9302f Mon Sep 17 00:00:00 2001 From: Shruti Patel Date: Fri, 26 Dec 2025 09:38:30 -0800 Subject: [PATCH 5/6] Add Complexity Rating Healthcheck to OverviewAnalysis (#4709) Summary: Enable ComplexityRatingAnalysis in OverviewAnalysis by passing orchestrator stopping strategies and trial failure settings as parameters. Differential Revision: D89778765 Privacy Context Container: L1307644 --- ax/analysis/healthcheck/complexity_rating.py | 12 -------- ax/analysis/overview.py | 30 +++++++++++++++++++- 2 files changed, 29 insertions(+), 13 deletions(-) diff --git a/ax/analysis/healthcheck/complexity_rating.py b/ax/analysis/healthcheck/complexity_rating.py index fb015309294..5ccaeb95c2d 100644 --- a/ax/analysis/healthcheck/complexity_rating.py +++ b/ax/analysis/healthcheck/complexity_rating.py @@ -29,18 +29,6 @@ ) from pyre_extensions import none_throws, override -# TODO: Enable ComplexityRatingAnalysis in OverviewAnalysis. Currently, this -# analysis depends on OrchestratorOptions to evaluate early stopping, global -# stopping, and other orchestrator-level settings. To enable it in -# OverviewAnalysis, we need to either: -# 1. Extract the relevant settings from OrchestratorOptions into a smaller, -# analysis-specific data structure that can be passed independently, OR -# 2. Modify summarize_ax_optimization_complexity to make the options parameter -# optional and handle missing orchestrator settings gracefully (e.g., skip -# those checks or use defaults). -# For now, this analysis can be used directly other -# orchestrator-aware callers that have access to OrchestratorOptions. - @final class ComplexityRatingAnalysis(Analysis): diff --git a/ax/analysis/overview.py b/ax/analysis/overview.py index b179fa777e3..6b12ba857ba 100644 --- a/ax/analysis/overview.py +++ b/ax/analysis/overview.py @@ -5,7 +5,7 @@ # pyre-strict -from typing import final +from typing import Any, final from ax.adapter.base import Adapter from ax.analysis.analysis import Analysis, ErrorAnalysisCard @@ -14,6 +14,7 @@ from ax.analysis.healthcheck.can_generate_candidates import ( CanGenerateCandidatesAnalysis, ) +from ax.analysis.healthcheck.complexity_rating import ComplexityRatingAnalysis from ax.analysis.healthcheck.constraints_feasibility import ( ConstraintsFeasibilityAnalysis, ) @@ -33,8 +34,10 @@ from ax.core.map_data import MapData from ax.core.map_metric import MapMetric from ax.core.trial_status import TrialStatus +from ax.early_stopping.strategies import BaseEarlyStoppingStrategy from ax.exceptions.core import UserInputError from ax.generation_strategy.generation_strategy import GenerationStrategy +from ax.global_stopping.strategies.base import BaseGlobalStoppingStrategy from pyre_extensions import override @@ -87,6 +90,7 @@ class OverviewAnalysis(Analysis): * ConstraintsFeasibilityAnalysis * SearchSpaceAnalysis * ShouldGenerateCandidates + * ComplexityRatingAnalysis * Trial-Level Analyses * Trial 0 * ArmEffectsPlot @@ -100,6 +104,12 @@ def __init__( can_generate_days_till_fail: int | None = None, should_generate: bool | None = None, should_generate_reason: str | None = None, + early_stopping_strategy: BaseEarlyStoppingStrategy | None = None, + global_stopping_strategy: BaseGlobalStoppingStrategy | None = None, + tolerated_trial_failure_rate: float | None = None, + max_pending_trials: int | None = None, + min_failed_trials_for_failure_rate_check: int | None = None, + tier_metadata: dict[str, Any] | None = None, ) -> 
None: super().__init__() self.can_generate = can_generate @@ -107,6 +117,14 @@ def __init__( self.can_generate_days_till_fail = can_generate_days_till_fail self.should_generate = should_generate self.should_generate_reason = should_generate_reason + self.early_stopping_strategy = early_stopping_strategy + self.global_stopping_strategy = global_stopping_strategy + self.tolerated_trial_failure_rate = tolerated_trial_failure_rate + self.max_pending_trials = max_pending_trials + self.min_failed_trials_for_failure_rate_check = ( + min_failed_trials_for_failure_rate_check + ) + self.tier_metadata = tier_metadata @override def validate_applicable_state( @@ -183,6 +201,16 @@ def compute( and self.can_generate_reason is not None and self.can_generate_days_till_fail is not None else None, + ComplexityRatingAnalysis( + tier_metadata=self.tier_metadata, + early_stopping_strategy=self.early_stopping_strategy, + global_stopping_strategy=self.global_stopping_strategy, + tolerated_trial_failure_rate=self.tolerated_trial_failure_rate, + max_pending_trials=self.max_pending_trials, + min_failed_trials_for_failure_rate_check=( + self.min_failed_trials_for_failure_rate_check + ), + ), ConstraintsFeasibilityAnalysis(), PredictableMetricsAnalysis(), BaselineImprovementAnalysis() if not has_batch_trials else None, From 83a29a40201dcb34d723caadb878a63d6b156f3f Mon Sep 17 00:00:00 2001 From: Shruti Patel Date: Fri, 26 Dec 2025 09:38:30 -0800 Subject: [PATCH 6/6] Add Auto ES Config to OverviewAnalysis (#4710) Summary: This diff adds support for AutoEarlyStoppingConfig in the OverviewAnalysis class. This change enables OverviewAnalysis to forward auto early stopping configuration to the underlying EarlyStoppingAnalysis. Differential Revision: D89793744 Privacy Context Container: L1307644 --- ax/analysis/overview.py | 21 ++++++++++++++++++--- 1 file changed, 18 insertions(+), 3 deletions(-) diff --git a/ax/analysis/overview.py b/ax/analysis/overview.py index 6b12ba857ba..99b3be996ba 100644 --- a/ax/analysis/overview.py +++ b/ax/analysis/overview.py @@ -18,7 +18,10 @@ from ax.analysis.healthcheck.constraints_feasibility import ( ConstraintsFeasibilityAnalysis, ) -from ax.analysis.healthcheck.early_stopping_healthcheck import EarlyStoppingAnalysis +from ax.analysis.healthcheck.early_stopping_healthcheck import ( + AutoEarlyStoppingConfig, + EarlyStoppingAnalysis, +) from ax.analysis.healthcheck.healthcheck_analysis import HealthcheckAnalysisCard from ax.analysis.healthcheck.metric_fetching_errors import MetricFetchingErrorsAnalysis from ax.analysis.healthcheck.predictable_metrics import PredictableMetricsAnalysis @@ -27,7 +30,7 @@ from ax.analysis.insights import InsightsAnalysis from ax.analysis.results import ResultsAnalysis from ax.analysis.trials import AllTrialsAnalysis -from ax.analysis.utils import validate_experiment +from ax.analysis.utils import filter_none, validate_experiment from ax.core.analysis_card import AnalysisCardGroup from ax.core.batch_trial import BatchTrial from ax.core.experiment import Experiment @@ -110,6 +113,7 @@ def __init__( max_pending_trials: int | None = None, min_failed_trials_for_failure_rate_check: int | None = None, tier_metadata: dict[str, Any] | None = None, + auto_early_stopping_config: AutoEarlyStoppingConfig | None = None, ) -> None: super().__init__() self.can_generate = can_generate @@ -125,6 +129,7 @@ def __init__( min_failed_trials_for_failure_rate_check ) self.tier_metadata = tier_metadata + self.auto_early_stopping_config = auto_early_stopping_config @override def 
validate_applicable_state( @@ -191,7 +196,17 @@ def compute( health_check_analyses = [ MetricFetchingErrorsAnalysis(), - EarlyStoppingAnalysis() if has_map_data and has_map_metrics else None, + ( + EarlyStoppingAnalysis( + **filter_none( + early_stopping_strategy=self.early_stopping_strategy, + auto_early_stopping_config=self.auto_early_stopping_config, + max_pending_trials=self.max_pending_trials, + ) + ) + if has_map_data and has_map_metrics + else None + ), CanGenerateCandidatesAnalysis( can_generate_candidates=self.can_generate, reason=self.can_generate_reason,