Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 1 addition & 2 deletions ax/adapter/tests/test_torch_adapter.py
Original file line number Diff line number Diff line change
Expand Up @@ -483,8 +483,7 @@ def test_candidate_metadata_propagation(self) -> None:
exp = get_branin_experiment(with_status_quo=True, with_completed_batch=True)
# Check that the metadata is correctly re-added to observation
# features during `fit`.
# pyre-fixme[16]: `BaseTrial` has no attribute `_generator_run_structs`.
preexisting_batch_gr = exp.trials[0]._generator_runs[0]
preexisting_batch_gr = exp.trials[0].generator_runs[0]
preexisting_batch_gr._candidate_metadata_by_arm_signature = {
preexisting_batch_gr.arms[0].signature: {
"preexisting_batch_cand_metadata": "some_value"
Expand Down
91 changes: 71 additions & 20 deletions ax/core/base_trial.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
from ax.utils.common.base import SortableBase
from ax.utils.common.constants import Keys
from pyre_extensions import none_throws
from typing_extensions import Self


if TYPE_CHECKING:
Expand Down Expand Up @@ -137,12 +138,60 @@ def __init__(
# strategy, this property will be set to the generation step that produced
# the generator run(s).
self._generation_step_index: int | None = None
# Please do not store any data related to trial deployment or data-
# NOTE: Please do not store any data related to trial deployment or data-
# fetching in properties. It is intended to only store properties related
# to core Ax functionality and not to any third-system that the trials
# might be getting deployed to.
# pyre-fixme[4]: Attribute must be annotated.
self._properties = {}
self._properties: dict[str, Any] = {}

@abstractproperty
def arms(self) -> list[Arm]:
"""All arms associated with this trial."""
pass

@abstractproperty
def arms_by_name(self) -> dict[str, Arm]:
"""A mapping of from arm names, to all arms associated with
this trial.
"""
pass

@abstractproperty
def abandoned_arms(self) -> list[Arm]:
"""All abandoned arms, associated with this trial."""
pass

@abstractmethod
def add_generator_run(self, generator_run: GeneratorRun) -> Self:
"""Add a generator run to the trial.

The arms and weights from the generator run will be merged with
the existing arms and weights on the trial, and the generator run
object will be linked to the trial for tracking.

Args:
generator_run: The generator run to be added.

Returns:
The trial instance.
"""
pass

@abstractmethod
def add_arm(
self, arm: Arm, candidate_metadata: dict[str, Any] | None = None
) -> Self:
"""Add arm to the trial.

Returns:
The trial instance.
"""
pass

@abstractmethod
def __repr__(self) -> str:
"""String representation of the trial."""
pass

@property
def experiment(self) -> core.experiment.Experiment:
Expand Down Expand Up @@ -237,6 +286,25 @@ def trial_type(self) -> str | None:
"""
return self._trial_type

def _add_generator_run(self, generator_run: GeneratorRun) -> None:
"""Helper called from ``{BatchTrial, Trial}.add_generator_run: validates
and names arms; adds the ``GeneratorRun`` to this trial's ``Experiment``.
"""
# 1. Validate the arm(s) in the generator run.
for arm in generator_run.arms:
self.experiment.search_space.check_types(arm.parameters, raise_error=True)

# 2. Name any yet-unnamed arms: for arms that are not yet added to this
# trial's experiment, assign a new name; use name on experiment otherwise.
for arm in generator_run.arms:
self._check_existing_and_name_arm(arm)

# 3.TODO: Add the generator run to the experiment if it's not already there;
# assign an experiment-level index if newly added.

# 4. TODO: Capture which generator run the arms we are about to add this
# this trial, came from.

def assign_runner(self) -> BaseTrial:
"""Assigns default experiment runner if trial doesn't already have one."""
runner = self.experiment.runner_for_trial(self)
Expand Down Expand Up @@ -431,23 +499,6 @@ def _set_generation_step_index(self, generation_step_index: int | None) -> None:
)
self._generation_step_index = generation_step_index

@abstractproperty
def arms(self) -> list[Arm]:
pass

@abstractproperty
def arms_by_name(self) -> dict[str, Arm]:
pass

@abstractmethod
def __repr__(self) -> str:
pass

@abstractproperty
def abandoned_arms(self) -> list[Arm]:
"""All abandoned arms, associated with this trial."""
pass

@property
def active_arms(self) -> list[Arm]:
"""All non abandoned arms associated with this trial."""
Expand Down
32 changes: 10 additions & 22 deletions ax/core/batch_trial.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,6 @@

from __future__ import annotations

import warnings

from collections import defaultdict, OrderedDict
from collections.abc import MutableMapping
from dataclasses import dataclass
Expand Down Expand Up @@ -133,7 +131,6 @@ def __init__(
self._arms_by_name: dict[str, Arm] = {}
self._generator_runs: list[GeneratorRun] = []
self._abandoned_arms_metadata: dict[str, AbandonedArm] = {}
self._status_quo_weight_override: float | None = None
self.should_add_status_quo_arm = should_add_status_quo_arm
if generator_run is not None:
if generator_runs is not None:
Expand Down Expand Up @@ -184,6 +181,13 @@ def arm_weights(self) -> MutableMapping[Arm, float]:
arm_weights[arm] += weight
return arm_weights

@property
def _status_quo_weight_override(self) -> None:
raise DeprecationWarning(
"Status quo weight override is no longer supported. Please "
"contact the Ax developers for help adjusting your application."
)

@arm_weights.setter
def arm_weights(self, arm_weights: MutableMapping[Arm, float]) -> None:
raise NotImplementedError("Use `trial.add_arms_and_weights`")
Expand Down Expand Up @@ -289,11 +293,9 @@ def add_generator_run(self, generator_run: GeneratorRun) -> BatchTrial:
}
)

# Add names to arms
# For those not yet added to this experiment, create a new name
# Else, use the name of the existing arm
for arm in generator_run.arms:
self._check_existing_and_name_arm(arm)
# Call `BaseTrial._add_generator_run` to validate and name the arms,
# then attach the generator run to the experiment.
self._add_generator_run(generator_run=generator_run)

self._generator_runs.append(generator_run)
if generator_run._generation_step_index is not None:
Expand Down Expand Up @@ -330,16 +332,6 @@ def add_status_quo_arm(self, weight: float = 1.0) -> BatchTrial:
if weight is not None:
if weight <= 0.0:
raise ValueError("Status quo weight must be positive.")
if (
self._status_quo_weight_override is not None
and weight != self._status_quo_weight_override
):
warnings.warn(
f"Status quo weight is being overridden from {weight} "
f"to {self._status_quo_weight_override}.",
UserWarning,
stacklevel=3,
)

sq_arm_already_added_with_correct_weight = False
for existing_arm in self.arm_weights:
Expand All @@ -360,7 +352,6 @@ def add_status_quo_arm(self, weight: float = 1.0) -> BatchTrial:
weight=weight,
generator_run_type=GeneratorRunType.STATUS_QUO,
)
self._status_quo_weight_override = weight
self._refresh_arms_by_name()
return self

Expand Down Expand Up @@ -536,9 +527,6 @@ def clone_to(
],
should_add_status_quo_arm=include_sq and self.should_add_status_quo_arm,
)
if (sq := self.status_quo) is not None and sq in self.arm_weights:
new_trial._status_quo_weight_override = self.arm_weights[sq]

self._update_trial_attrs_on_clone(new_trial=new_trial)
return new_trial

Expand Down
5 changes: 0 additions & 5 deletions ax/core/tests/test_batch_trial.py
Original file line number Diff line number Diff line change
Expand Up @@ -160,20 +160,15 @@ def test_status_quo_cannot_be_set_directly(self) -> None:
def test_status_quo_weight_is_ignored_when_none(self) -> None:
tot_weight = sum(self.batch.weights)
self.assertEqual(sum(self.batch.weights), tot_weight)
self.assertIsNone(self.batch._status_quo_weight_override)

def test_status_quo_set_on_clone_to(self) -> None:
batch2 = self.batch.clone_to(include_sq=False)
self.assertEqual(batch2.status_quo, self.experiment.status_quo)
# Since should_add_status_quo_arm was False,
# _status_quo_weight_override should be False and the
# status_quo arm should not appear in arm_weights
self.assertIsNone(batch2._status_quo_weight_override)
self.assertTrue(batch2.status_quo not in batch2.arm_weights)
self.assertEqual(sum(batch2.weights), sum(self.weights))
# Test with should_add_status_quo_arm=True
batch3 = self.experiment.new_batch_trial(should_add_status_quo_arm=True)
self.assertEqual(batch3._status_quo_weight_override, 1.0)
self.assertTrue(batch2.status_quo in batch3.arm_weights)

def test_cannot_add_status_quo_arm_without_status_quo(self) -> None:
Expand Down
24 changes: 10 additions & 14 deletions ax/core/trial.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,8 +74,7 @@ def __init__(
ttl_seconds=ttl_seconds,
index=index,
)
# pyre-fixme[4]: Attribute must be annotated.
self._generator_run = None
self._generator_run: GeneratorRun | None = None
if generator_run is not None:
self.add_generator_run(generator_run=generator_run)

Expand All @@ -84,16 +83,16 @@ def generator_run(self) -> GeneratorRun | None:
"""Generator run attached to this trial."""
return self._generator_run

# pyre-ignore[6]: T77111662.
@copy_doc(BaseTrial.generator_runs)
@property
def generator_runs(self) -> list[GeneratorRun]:
gr = self._generator_run
return [gr] if gr is not None else []
"""Generator runs attached to this trial. Since this is a one-arm
``Trial`` (and not ``BatchTrial``), this will be a list of length 1.
"""
return [self._generator_run] if self._generator_run else []

@property
def arm(self) -> Arm | None:
"""The arm associated with this batch."""
"""The ``Arm`` associated with this ``Trial``."""
if self.generator_run is None:
return None

Expand Down Expand Up @@ -149,14 +148,11 @@ def add_generator_run(self, generator_run: GeneratorRun) -> Trial:
"included multiple."
)

self.experiment.search_space.check_types(
generator_run.arms[0].parameters, raise_error=True
)
self._check_existing_and_name_arm(generator_run.arms[0])
# Call `BaseTrial._add_generator_run` to validate and name the arms,
# then attach the generator run to the experiment.
self._add_generator_run(generator_run=generator_run)

self._generator_run = generator_run
self._set_generation_step_index(
generation_step_index=generator_run._generation_step_index
)
return self

@property
Expand Down
1 change: 1 addition & 0 deletions ax/storage/json_store/decoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -500,6 +500,7 @@ def trials_from_json(
# `GeneratorRunStruct` (deprecated) will be decoded into a `GeneratorRun`,
# so all we have to do here is change the key it's stored under.
trial_json["generator_runs"] = trial_json.pop("generator_run_structs")
trial_json.pop("status_quo_weight_override", None) # Deprecated.
loaded_trials[int(index)] = (
trial_from_json(experiment=experiment, **trial_json)
if is_trial
Expand Down
6 changes: 1 addition & 5 deletions ax/storage/json_store/decoders.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,7 +103,6 @@ def batch_trial_from_json(
abandoned_arms_metadata: dict[str, AbandonedArm],
num_arms_created: int,
status_quo: Arm | None,
status_quo_weight_override: float,
# Allowing default values for backwards compatibility with
# objects stored before these fields were added.
failed_reason: str | None = None,
Expand Down Expand Up @@ -149,11 +148,8 @@ def batch_trial_from_json(

# Trial.arms_by_name only returns arms with weights
batch.should_add_status_quo_arm = batch.status_quo is not None
batch._status_quo_weight_override = status_quo_weight_override
if batch.should_add_status_quo_arm:
batch.add_status_quo_arm(
weight=status_quo_weight_override,
)
batch.add_status_quo_arm()

# Set trial status last, after adding all the arms.
batch._status = status
Expand Down
1 change: 0 additions & 1 deletion ax/storage/json_store/encoders.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,7 +123,6 @@ def batch_to_dict(batch: BatchTrial) -> dict[str, Any]:
"ttl_seconds": batch.ttl_seconds,
"status": batch.status,
"status_quo": batch.status_quo,
"status_quo_weight_override": batch._status_quo_weight_override,
"time_created": batch.time_created,
"time_completed": batch.time_completed,
"time_staged": batch.time_staged,
Expand Down
30 changes: 12 additions & 18 deletions ax/storage/sqa_store/decoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@
from ax.core.batch_trial import AbandonedArm, BatchTrial
from ax.core.data import Data
from ax.core.experiment import Experiment
from ax.core.generator_run import GeneratorRun, GeneratorRunType
from ax.core.generator_run import GeneratorRun
from ax.core.metric import Metric
from ax.core.multi_type_experiment import MultiTypeExperiment
from ax.core.objective import MultiObjective, Objective, ScalarizedObjective
Expand Down Expand Up @@ -223,7 +223,7 @@ def _init_experiment_from_sqa(
else {}
)

return Experiment(
experiment = Experiment(
name=experiment_sqa.name,
description=experiment_sqa.description,
search_space=search_space,
Expand All @@ -237,6 +237,16 @@ def _init_experiment_from_sqa(
auxiliary_experiments_by_purpose=auxiliary_experiments_by_purpose,
)

for arm in experiment_sqa.arms:
experiment._register_arm(
arm=Arm(
parameters=arm.parameters,
name=arm.name,
)
)

return experiment

def _init_mt_experiment_from_sqa(
self,
experiment_sqa: SQAExperiment,
Expand Down Expand Up @@ -1011,22 +1021,6 @@ def trial_from_sqa(
)
for generator_run_sqa in trial_sqa.generator_runs
]
if trial_sqa.status_quo_name is not None:
sq_arm = experiment.arms_by_name[trial_sqa.status_quo_name]
for gr in generator_runs:
# Most of the time, the status quo arm has its own generator run,
# of a dedicated type.
if gr.generator_run_type == GeneratorRunType.STATUS_QUO.name:
status_quo_weight = gr.weights[0]
trial._status_quo_weight_override = status_quo_weight
trial._status_quo_generator_run_db_id = gr.db_id
trial._status_quo_arm_db_id = gr.arms[0].db_id
break
# However, sometimes the status quo arm is part of a generator run
# that also includes other arms (e.g. in factorial or TS design).
if sq_arm in gr.arms:
trial._status_quo_generator_run_db_id = gr.db_id
trial._status_quo_arm_db_id = sq_arm.db_id
trial._generator_runs = generator_runs
trial._abandoned_arms_metadata = {
abandoned_arm_sqa.name: self.abandoned_arm_from_sqa(
Expand Down
Loading
Loading