Skip to content

Commit 0ccd306

Browse files
esantorellafacebook-github-bot
authored andcommitted
Merge MapData into Data (facebook#4715)
Summary: This diff merges MapData into Data by giving Data an attribute `has_step_column`. MapData becomes an empty subclass which will be removed in a subsequent PR (D89814417). Note on how this diff is split up: In this diff, functions that previously required MapData now can consume Data, and type checks stop referencing MapData. However, functions can still return MapData; all references are removed in the next diff. Changes: * Functionality from map_data.py is moved into data.py, and functionality from MapData moves into Data; MapData is an empty subclass. * Data (and MapData) get an attribute `has_step_column`; the important distinction becomes `has_step_column`, not type. * It is now possible for both `Data` and `MapData` to either have a step column or not. Having both `Data` and `MapData` like this is an unpleasant intermediate state that we should move off of immediately (landing this and D89814417 together) * Many `isinstance` checks become `has_step_coumn` checks * Data's new methods `subsample` and `latest` get special cases for when there is no "step" column * Make Data's required columns always be the same and not contain "step" (since it isn't really required). Remove `required_columns` method. Reviewed By: lena-kashtelyan Differential Revision: D89820078
1 parent 65fbdd6 commit 0ccd306

File tree

20 files changed

+425
-445
lines changed

20 files changed

+425
-445
lines changed

ax/adapter/data_utils.py

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -22,9 +22,8 @@
2222
from typing import Any
2323

2424
import numpy as np
25-
from ax.core.data import Data
25+
from ax.core.data import Data, MAP_KEY
2626
from ax.core.experiment import Experiment
27-
from ax.core.map_data import MAP_KEY, MapData
2827
from ax.core.map_metric import MapMetric
2928
from ax.core.observation import Observation, ObservationData, ObservationFeatures
3029
from ax.core.trial_status import STATUSES_EXPECTING_DATA, TrialStatus
@@ -403,7 +402,7 @@ def _extract_observation_data(
403402
retrived with ``("metadata", "start_time")``.
404403
"""
405404
data = data if data is not None else experiment.lookup_data()
406-
if isinstance(data, MapData):
405+
if data.has_step_column:
407406
if data_loader_config.latest_rows_per_group is not None:
408407
data = data.latest(
409408
rows_per_group=data_loader_config.latest_rows_per_group,
@@ -455,7 +454,7 @@ def _extract_observation_data(
455454

456455
# Identify potential metadata columns.
457456
index_cols = ["trial_index", "arm_name"]
458-
if isinstance(data, MapData):
457+
if data.has_step_column:
459458
index_cols.append(MAP_KEY)
460459

461460
standard_columns = set(index_cols).union(

ax/analysis/overview.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,6 @@
3131
from ax.core.analysis_card import AnalysisCardGroup
3232
from ax.core.batch_trial import BatchTrial
3333
from ax.core.experiment import Experiment
34-
from ax.core.map_data import MapData
3534
from ax.core.map_metric import MapMetric
3635
from ax.core.trial_status import TrialStatus
3736
from ax.exceptions.core import UserInputError
@@ -179,7 +178,7 @@ def compute(
179178

180179
# Check if the experiment has MapData and MapMetrics (required for
181180
# early stopping)
182-
has_map_data = isinstance(experiment.lookup_data(), MapData)
181+
has_map_data = experiment.lookup_data().has_step_column
183182
has_map_metrics = any(
184183
isinstance(m, MapMetric) for m in experiment.metrics.values()
185184
)

ax/analysis/plotly/progression.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@
1919
from ax.analysis.plotly.utils import select_metric
2020
from ax.analysis.utils import validate_experiment
2121
from ax.core.experiment import Experiment
22-
from ax.core.map_data import MAP_KEY, MapData
22+
from ax.core.map_data import MAP_KEY
2323
from ax.core.trial_status import TrialStatus
2424
from ax.generation_strategy.generation_strategy import GenerationStrategy
2525

@@ -87,8 +87,8 @@ def validate_applicable_state(
8787
return experiment_invalid_reason
8888

8989
data = none_throws(experiment).lookup_data()
90-
if not isinstance(data, MapData):
91-
return "Requires MapData."
90+
if not data.has_step_column:
91+
return "Requires data to have a column 'step.'"
9292

9393
@override
9494
def compute(

ax/analysis/results.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -250,10 +250,9 @@ def compute(
250250
# Compute progression plots for MapMetrics (learning curves)
251251
progression_group = None
252252
data = experiment.lookup_data()
253-
has_map_data = isinstance(data, MapData)
254253
metrics = experiment.metrics.values()
255254
map_metrics = [m for m in metrics if isinstance(m, MapMetric)]
256-
if has_map_data and len(map_metrics) > 0:
255+
if data.has_step_column and len(map_metrics) > 0:
257256
progression_cards = [
258257
ProgressionPlot(
259258
metric_name=m.name, by_wallclock_time=by_wallclock_time

ax/analysis/summary.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,6 @@
1414
from ax.analysis.utils import validate_experiment
1515
from ax.core.analysis_card import AnalysisCard
1616
from ax.core.experiment import Experiment
17-
from ax.core.map_data import MapData
1817
from ax.core.trial_status import NON_STALE_STATUSES, TrialStatus
1918
from ax.exceptions.core import UserInputError
2019
from ax.generation_strategy.generation_strategy import GenerationStrategy
@@ -87,7 +86,7 @@ def compute(
8786
should_relativize = (
8887
len(experiment.metrics) > 0
8988
and experiment.status_quo is not None
90-
and not isinstance(data, MapData)
89+
and not data.has_step_column
9190
)
9291

9392
return self._create_analysis_card(

ax/core/base_trial.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,6 @@
2020
from ax.core.data import Data
2121
from ax.core.evaluations_to_data import raw_evaluations_to_data
2222
from ax.core.generator_run import GeneratorRun, GeneratorRunType
23-
from ax.core.map_data import MapData
2423
from ax.core.metric import Metric, MetricFetchResult
2524
from ax.core.runner import Runner
2625
from ax.core.trial_status import TrialStatus
@@ -449,7 +448,7 @@ def fetch_data(self, metrics: list[Metric] | None = None, **kwargs: Any) -> Data
449448
data = Metric._unwrap_trial_data_multi(
450449
results=self.fetch_data_results(metrics=metrics, **kwargs)
451450
)
452-
if not isinstance(data, MapData):
451+
if not data.has_step_column:
453452
data.full_df = sort_by_trial_index_and_arm_name(data.full_df)
454453

455454
return data

0 commit comments

Comments
 (0)