Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
61 changes: 34 additions & 27 deletions ax/analysis/healthcheck/complexity_rating.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,9 +15,11 @@
HealthcheckAnalysisCard,
HealthcheckStatus,
)
from ax.analysis.utils import filter_none
from ax.core.experiment import Experiment
from ax.early_stopping.strategies import BaseEarlyStoppingStrategy
from ax.generation_strategy.generation_strategy import GenerationStrategy
from ax.service.orchestrator import OrchestratorOptions
from ax.global_stopping.strategies.base import BaseGlobalStoppingStrategy
from ax.utils.common.complexity_utils import (
check_if_in_standard,
DEFAULT_TIER_MESSAGES,
Expand All @@ -27,18 +29,6 @@
)
from pyre_extensions import none_throws, override

# TODO: Enable ComplexityRatingAnalysis in OverviewAnalysis. Currently, this
# analysis depends on OrchestratorOptions to evaluate early stopping, global
# stopping, and other orchestrator-level settings. To enable it in
# OverviewAnalysis, we need to either:
# 1. Extract the relevant settings from OrchestratorOptions into a smaller,
# analysis-specific data structure that can be passed independently, OR
# 2. Modify summarize_ax_optimization_complexity to make the options parameter
# optional and handle missing orchestrator settings gracefully (e.g., skip
# those checks or use defaults).
# For now, this analysis can be used directly other
# orchestrator-aware callers that have access to OrchestratorOptions.


@final
class ComplexityRatingAnalysis(Analysis):
Expand All @@ -61,16 +51,17 @@ class ComplexityRatingAnalysis(Analysis):

def __init__(
self,
options: OrchestratorOptions | None = None,
tier_metadata: dict[str, Any] | None = None,
tier_messages: TierMessages = DEFAULT_TIER_MESSAGES,
early_stopping_strategy: BaseEarlyStoppingStrategy | None = None,
global_stopping_strategy: BaseGlobalStoppingStrategy | None = None,
tolerated_trial_failure_rate: float | None = None,
max_pending_trials: int | None = None,
min_failed_trials_for_failure_rate_check: int | None = None,
) -> None:
"""Initialize the ComplexityRatingAnalysis.

Args:
options: The orchestrator options used for the optimization.
Required to evaluate early stopping, global stopping, and
failure rate settings.
tier_metadata: Additional tier-related metadata from the orchestrator.
Supported keys:
- 'user_supplied_max_trials': Maximum number of trials.
Expand All @@ -82,12 +73,27 @@ def __init__(
generic messages suitable for most users. Pass a custom TierMessages
instance to provide tool-specific descriptions, support SLAs,
links to docs, or contact information.
early_stopping_strategy: The early stopping strategy, if any. Used to
determine if early stopping is enabled. Defaults to None.
global_stopping_strategy: The global stopping strategy, if any. Used to
determine if global stopping is enabled. Defaults to None.
tolerated_trial_failure_rate: Fraction of trials allowed to fail without
the whole optimization ending. Default value used is 0.5.
max_pending_trials: Maximum number of pending trials. Default used is 10.
min_failed_trials_for_failure_rate_check: Minimum failed trials before
failure rate is checked. Default value used is 5.
"""
self.options = options
self.tier_metadata: dict[str, Any] = (
tier_metadata if tier_metadata is not None else {}
)
self.tier_messages = tier_messages
self.early_stopping_strategy = early_stopping_strategy
self.global_stopping_strategy = global_stopping_strategy
self.tolerated_trial_failure_rate = tolerated_trial_failure_rate
self.max_pending_trials = max_pending_trials
self.min_failed_trials_for_failure_rate_check = (
min_failed_trials_for_failure_rate_check
)

@override
def validate_applicable_state(
Expand All @@ -98,11 +104,6 @@ def validate_applicable_state(
) -> str | None:
if experiment is None:
return "Experiment is required for ComplexityRatingAnalysis."
if self.options is None:
return (
"OrchestratorOptions is required for ComplexityRatingAnalysis. "
"Please pass options to the constructor."
)
return None

@override
Expand All @@ -120,8 +121,7 @@ def compute(

Note:
This method assumes ``validate_applicable_state`` has been called
and returned None, ensuring ``experiment`` and ``self.options``
are not None.
and returned None, ensuring ``experiment`` is not None.

Args:
experiment: The Ax Experiment to analyze. Must not be None.
Expand All @@ -135,11 +135,18 @@ def compute(
with key experiment metrics.
"""
experiment = none_throws(experiment)
options = none_throws(self.options)
optimization_summary = summarize_ax_optimization_complexity(
experiment=experiment,
options=options,
tier_metadata=self.tier_metadata,
early_stopping_strategy=self.early_stopping_strategy,
global_stopping_strategy=self.global_stopping_strategy,
**filter_none(
tolerated_trial_failure_rate=self.tolerated_trial_failure_rate,
max_pending_trials=self.max_pending_trials,
min_failed_trials_for_failure_rate_check=(
self.min_failed_trials_for_failure_rate_check
),
),
)

# Determine tier
Expand Down
Loading