diff --git a/ax/adapter/tests/test_torch_adapter.py b/ax/adapter/tests/test_torch_adapter.py index 50e7b2a1dd2..12a3095c1a1 100644 --- a/ax/adapter/tests/test_torch_adapter.py +++ b/ax/adapter/tests/test_torch_adapter.py @@ -483,8 +483,7 @@ def test_candidate_metadata_propagation(self) -> None: exp = get_branin_experiment(with_status_quo=True, with_completed_batch=True) # Check that the metadata is correctly re-added to observation # features during `fit`. - # pyre-fixme[16]: `BaseTrial` has no attribute `_generator_run_structs`. - preexisting_batch_gr = exp.trials[0]._generator_runs[0] + preexisting_batch_gr = exp.trials[0].generator_runs[0] preexisting_batch_gr._candidate_metadata_by_arm_signature = { preexisting_batch_gr.arms[0].signature: { "preexisting_batch_cand_metadata": "some_value" diff --git a/ax/core/base_trial.py b/ax/core/base_trial.py index 896c0ac1317..3b0e8b95859 100644 --- a/ax/core/base_trial.py +++ b/ax/core/base_trial.py @@ -30,6 +30,7 @@ from ax.utils.common.base import SortableBase from ax.utils.common.constants import Keys from pyre_extensions import none_throws +from typing_extensions import Self if TYPE_CHECKING: @@ -137,12 +138,60 @@ def __init__( # strategy, this property will be set to the generation step that produced # the generator run(s). self._generation_step_index: int | None = None - # Please do not store any data related to trial deployment or data- + # NOTE: Please do not store any data related to trial deployment or data- # fetching in properties. It is intended to only store properties related # to core Ax functionality and not to any third-system that the trials # might be getting deployed to. - # pyre-fixme[4]: Attribute must be annotated. - self._properties = {} + self._properties: dict[str, Any] = {} + + @abstractproperty + def arms(self) -> list[Arm]: + """All arms associated with this trial.""" + pass + + @abstractproperty + def arms_by_name(self) -> dict[str, Arm]: + """A mapping of from arm names, to all arms associated with + this trial. + """ + pass + + @abstractproperty + def abandoned_arms(self) -> list[Arm]: + """All abandoned arms, associated with this trial.""" + pass + + @abstractmethod + def add_generator_run(self, generator_run: GeneratorRun) -> Self: + """Add a generator run to the trial. + + The arms and weights from the generator run will be merged with + the existing arms and weights on the trial, and the generator run + object will be linked to the trial for tracking. + + Args: + generator_run: The generator run to be added. + + Returns: + The trial instance. + """ + pass + + @abstractmethod + def add_arm( + self, arm: Arm, candidate_metadata: dict[str, Any] | None = None + ) -> Self: + """Add arm to the trial. + + Returns: + The trial instance. + """ + pass + + @abstractmethod + def __repr__(self) -> str: + """String representation of the trial.""" + pass @property def experiment(self) -> core.experiment.Experiment: @@ -237,6 +286,25 @@ def trial_type(self) -> str | None: """ return self._trial_type + def _add_generator_run(self, generator_run: GeneratorRun) -> None: + """Helper called from ``{BatchTrial, Trial}.add_generator_run: validates + and names arms; adds the ``GeneratorRun`` to this trial's ``Experiment``. + """ + # 1. Validate the arm(s) in the generator run. + for arm in generator_run.arms: + self.experiment.search_space.check_types(arm.parameters, raise_error=True) + + # 2. Name any yet-unnamed arms: for arms that are not yet added to this + # trial's experiment, assign a new name; use name on experiment otherwise. + for arm in generator_run.arms: + self._check_existing_and_name_arm(arm) + + # 3.TODO: Add the generator run to the experiment if it's not already there; + # assign an experiment-level index if newly added. + + # 4. TODO: Capture which generator run the arms we are about to add this + # this trial, came from. + def assign_runner(self) -> BaseTrial: """Assigns default experiment runner if trial doesn't already have one.""" runner = self.experiment.runner_for_trial(self) @@ -431,23 +499,6 @@ def _set_generation_step_index(self, generation_step_index: int | None) -> None: ) self._generation_step_index = generation_step_index - @abstractproperty - def arms(self) -> list[Arm]: - pass - - @abstractproperty - def arms_by_name(self) -> dict[str, Arm]: - pass - - @abstractmethod - def __repr__(self) -> str: - pass - - @abstractproperty - def abandoned_arms(self) -> list[Arm]: - """All abandoned arms, associated with this trial.""" - pass - @property def active_arms(self) -> list[Arm]: """All non abandoned arms associated with this trial.""" diff --git a/ax/core/batch_trial.py b/ax/core/batch_trial.py index 57b3f611815..e3939d9cfbb 100644 --- a/ax/core/batch_trial.py +++ b/ax/core/batch_trial.py @@ -8,8 +8,6 @@ from __future__ import annotations -import warnings - from collections import defaultdict, OrderedDict from collections.abc import MutableMapping from dataclasses import dataclass @@ -133,7 +131,6 @@ def __init__( self._arms_by_name: dict[str, Arm] = {} self._generator_runs: list[GeneratorRun] = [] self._abandoned_arms_metadata: dict[str, AbandonedArm] = {} - self._status_quo_weight_override: float | None = None self.should_add_status_quo_arm = should_add_status_quo_arm if generator_run is not None: if generator_runs is not None: @@ -184,6 +181,13 @@ def arm_weights(self) -> MutableMapping[Arm, float]: arm_weights[arm] += weight return arm_weights + @property + def _status_quo_weight_override(self) -> None: + raise DeprecationWarning( + "Status quo weight override is no longer supported. Please " + "contact the Ax developers for help adjusting your application." + ) + @arm_weights.setter def arm_weights(self, arm_weights: MutableMapping[Arm, float]) -> None: raise NotImplementedError("Use `trial.add_arms_and_weights`") @@ -289,11 +293,9 @@ def add_generator_run(self, generator_run: GeneratorRun) -> BatchTrial: } ) - # Add names to arms - # For those not yet added to this experiment, create a new name - # Else, use the name of the existing arm - for arm in generator_run.arms: - self._check_existing_and_name_arm(arm) + # Call `BaseTrial._add_generator_run` to validate and name the arms, + # then attach the generator run to the experiment. + self._add_generator_run(generator_run=generator_run) self._generator_runs.append(generator_run) if generator_run._generation_step_index is not None: @@ -330,16 +332,6 @@ def add_status_quo_arm(self, weight: float = 1.0) -> BatchTrial: if weight is not None: if weight <= 0.0: raise ValueError("Status quo weight must be positive.") - if ( - self._status_quo_weight_override is not None - and weight != self._status_quo_weight_override - ): - warnings.warn( - f"Status quo weight is being overridden from {weight} " - f"to {self._status_quo_weight_override}.", - UserWarning, - stacklevel=3, - ) sq_arm_already_added_with_correct_weight = False for existing_arm in self.arm_weights: @@ -360,7 +352,6 @@ def add_status_quo_arm(self, weight: float = 1.0) -> BatchTrial: weight=weight, generator_run_type=GeneratorRunType.STATUS_QUO, ) - self._status_quo_weight_override = weight self._refresh_arms_by_name() return self @@ -536,9 +527,6 @@ def clone_to( ], should_add_status_quo_arm=include_sq and self.should_add_status_quo_arm, ) - if (sq := self.status_quo) is not None and sq in self.arm_weights: - new_trial._status_quo_weight_override = self.arm_weights[sq] - self._update_trial_attrs_on_clone(new_trial=new_trial) return new_trial diff --git a/ax/core/tests/test_batch_trial.py b/ax/core/tests/test_batch_trial.py index b6b203b5159..4847a8e2b47 100644 --- a/ax/core/tests/test_batch_trial.py +++ b/ax/core/tests/test_batch_trial.py @@ -160,20 +160,15 @@ def test_status_quo_cannot_be_set_directly(self) -> None: def test_status_quo_weight_is_ignored_when_none(self) -> None: tot_weight = sum(self.batch.weights) self.assertEqual(sum(self.batch.weights), tot_weight) - self.assertIsNone(self.batch._status_quo_weight_override) def test_status_quo_set_on_clone_to(self) -> None: batch2 = self.batch.clone_to(include_sq=False) self.assertEqual(batch2.status_quo, self.experiment.status_quo) # Since should_add_status_quo_arm was False, - # _status_quo_weight_override should be False and the - # status_quo arm should not appear in arm_weights - self.assertIsNone(batch2._status_quo_weight_override) self.assertTrue(batch2.status_quo not in batch2.arm_weights) self.assertEqual(sum(batch2.weights), sum(self.weights)) # Test with should_add_status_quo_arm=True batch3 = self.experiment.new_batch_trial(should_add_status_quo_arm=True) - self.assertEqual(batch3._status_quo_weight_override, 1.0) self.assertTrue(batch2.status_quo in batch3.arm_weights) def test_cannot_add_status_quo_arm_without_status_quo(self) -> None: diff --git a/ax/core/trial.py b/ax/core/trial.py index c730c375208..5e16a8ecc85 100644 --- a/ax/core/trial.py +++ b/ax/core/trial.py @@ -74,8 +74,7 @@ def __init__( ttl_seconds=ttl_seconds, index=index, ) - # pyre-fixme[4]: Attribute must be annotated. - self._generator_run = None + self._generator_run: GeneratorRun | None = None if generator_run is not None: self.add_generator_run(generator_run=generator_run) @@ -84,16 +83,16 @@ def generator_run(self) -> GeneratorRun | None: """Generator run attached to this trial.""" return self._generator_run - # pyre-ignore[6]: T77111662. - @copy_doc(BaseTrial.generator_runs) @property def generator_runs(self) -> list[GeneratorRun]: - gr = self._generator_run - return [gr] if gr is not None else [] + """Generator runs attached to this trial. Since this is a one-arm + ``Trial`` (and not ``BatchTrial``), this will be a list of length 1. + """ + return [self._generator_run] if self._generator_run else [] @property def arm(self) -> Arm | None: - """The arm associated with this batch.""" + """The ``Arm`` associated with this ``Trial``.""" if self.generator_run is None: return None @@ -149,14 +148,11 @@ def add_generator_run(self, generator_run: GeneratorRun) -> Trial: "included multiple." ) - self.experiment.search_space.check_types( - generator_run.arms[0].parameters, raise_error=True - ) - self._check_existing_and_name_arm(generator_run.arms[0]) + # Call `BaseTrial._add_generator_run` to validate and name the arms, + # then attach the generator run to the experiment. + self._add_generator_run(generator_run=generator_run) + self._generator_run = generator_run - self._set_generation_step_index( - generation_step_index=generator_run._generation_step_index - ) return self @property diff --git a/ax/storage/json_store/decoder.py b/ax/storage/json_store/decoder.py index 0d79e2c0f84..246ff0d0ba2 100644 --- a/ax/storage/json_store/decoder.py +++ b/ax/storage/json_store/decoder.py @@ -500,6 +500,7 @@ def trials_from_json( # `GeneratorRunStruct` (deprecated) will be decoded into a `GeneratorRun`, # so all we have to do here is change the key it's stored under. trial_json["generator_runs"] = trial_json.pop("generator_run_structs") + trial_json.pop("status_quo_weight_override", None) # Deprecated. loaded_trials[int(index)] = ( trial_from_json(experiment=experiment, **trial_json) if is_trial diff --git a/ax/storage/json_store/decoders.py b/ax/storage/json_store/decoders.py index 6b337ca4428..95b72cef972 100644 --- a/ax/storage/json_store/decoders.py +++ b/ax/storage/json_store/decoders.py @@ -103,7 +103,6 @@ def batch_trial_from_json( abandoned_arms_metadata: dict[str, AbandonedArm], num_arms_created: int, status_quo: Arm | None, - status_quo_weight_override: float, # Allowing default values for backwards compatibility with # objects stored before these fields were added. failed_reason: str | None = None, @@ -149,11 +148,8 @@ def batch_trial_from_json( # Trial.arms_by_name only returns arms with weights batch.should_add_status_quo_arm = batch.status_quo is not None - batch._status_quo_weight_override = status_quo_weight_override if batch.should_add_status_quo_arm: - batch.add_status_quo_arm( - weight=status_quo_weight_override, - ) + batch.add_status_quo_arm() # Set trial status last, after adding all the arms. batch._status = status diff --git a/ax/storage/json_store/encoders.py b/ax/storage/json_store/encoders.py index baceabaa2f6..a5511af373f 100644 --- a/ax/storage/json_store/encoders.py +++ b/ax/storage/json_store/encoders.py @@ -123,7 +123,6 @@ def batch_to_dict(batch: BatchTrial) -> dict[str, Any]: "ttl_seconds": batch.ttl_seconds, "status": batch.status, "status_quo": batch.status_quo, - "status_quo_weight_override": batch._status_quo_weight_override, "time_created": batch.time_created, "time_completed": batch.time_completed, "time_staged": batch.time_staged, diff --git a/ax/storage/sqa_store/decoder.py b/ax/storage/sqa_store/decoder.py index 5fe85ac03a1..736c2a7dae6 100644 --- a/ax/storage/sqa_store/decoder.py +++ b/ax/storage/sqa_store/decoder.py @@ -29,7 +29,7 @@ from ax.core.batch_trial import AbandonedArm, BatchTrial from ax.core.data import Data from ax.core.experiment import Experiment -from ax.core.generator_run import GeneratorRun, GeneratorRunType +from ax.core.generator_run import GeneratorRun from ax.core.metric import Metric from ax.core.multi_type_experiment import MultiTypeExperiment from ax.core.objective import MultiObjective, Objective, ScalarizedObjective @@ -223,7 +223,7 @@ def _init_experiment_from_sqa( else {} ) - return Experiment( + experiment = Experiment( name=experiment_sqa.name, description=experiment_sqa.description, search_space=search_space, @@ -237,6 +237,16 @@ def _init_experiment_from_sqa( auxiliary_experiments_by_purpose=auxiliary_experiments_by_purpose, ) + for arm in experiment_sqa.arms: + experiment._register_arm( + arm=Arm( + parameters=arm.parameters, + name=arm.name, + ) + ) + + return experiment + def _init_mt_experiment_from_sqa( self, experiment_sqa: SQAExperiment, @@ -1011,22 +1021,6 @@ def trial_from_sqa( ) for generator_run_sqa in trial_sqa.generator_runs ] - if trial_sqa.status_quo_name is not None: - sq_arm = experiment.arms_by_name[trial_sqa.status_quo_name] - for gr in generator_runs: - # Most of the time, the status quo arm has its own generator run, - # of a dedicated type. - if gr.generator_run_type == GeneratorRunType.STATUS_QUO.name: - status_quo_weight = gr.weights[0] - trial._status_quo_weight_override = status_quo_weight - trial._status_quo_generator_run_db_id = gr.db_id - trial._status_quo_arm_db_id = gr.arms[0].db_id - break - # However, sometimes the status quo arm is part of a generator run - # that also includes other arms (e.g. in factorial or TS design). - if sq_arm in gr.arms: - trial._status_quo_generator_run_db_id = gr.db_id - trial._status_quo_arm_db_id = sq_arm.db_id trial._generator_runs = generator_runs trial._abandoned_arms_metadata = { abandoned_arm_sqa.name: self.abandoned_arm_from_sqa( diff --git a/ax/storage/sqa_store/encoder.py b/ax/storage/sqa_store/encoder.py index 5d50c9a895b..3f493e501fe 100644 --- a/ax/storage/sqa_store/encoder.py +++ b/ax/storage/sqa_store/encoder.py @@ -159,6 +159,7 @@ def experiment_to_sqa(self, experiment: Experiment) -> SQAExperiment: create and store copies of the Trials, Metrics, Parameters, ParameterConstraints, and Runner owned by this Experiment. """ + arms = [self.arm_to_sqa(arm) for arm in experiment.arms_by_name.values()] optimization_metrics = self.optimization_config_to_sqa( experiment.optimization_config @@ -239,6 +240,7 @@ def experiment_to_sqa(self, experiment: Experiment) -> SQAExperiment: status_quo_parameters=status_quo_parameters, time_created=experiment.time_created, experiment_type=experiment_type, + arms=arms, metrics=optimization_metrics + tracking_metrics, parameters=parameters, parameter_constraints=parameter_constraints, @@ -1043,8 +1045,16 @@ def trial_to_sqa( ) return trial_sqa - def experiment_data_to_sqa(self, experiment: Experiment) -> list[SQAData]: - """Convert Ax experiment data to SQLAlchemy.""" + def experiment_data_to_sqa( + self, + experiment: Experiment, + ) -> list[SQAData]: + if ( + experiment.experiment_type + in self.config.EXPERIMENT_TYPES_WITH_NO_DATA_STORAGE + ): + return [] + return [ self.data_to_sqa(data=data, trial_index=trial_index, timestamp=timestamp) for trial_index, data_by_timestamp in experiment.data_by_trial.items() diff --git a/ax/storage/sqa_store/sqa_classes.py b/ax/storage/sqa_store/sqa_classes.py index 2544eec4f43..61fdf9d1aa5 100644 --- a/ax/storage/sqa_store/sqa_classes.py +++ b/ax/storage/sqa_store/sqa_classes.py @@ -161,6 +161,7 @@ class SQAMetric(Base): class SQAArm(Base): __tablename__: str = "arm_v2" + experiment_id: Column[int | None] = Column(Integer, ForeignKey("experiment_v2.id")) generator_run_id: Column[int] = Column( Integer, ForeignKey("generator_run_v2.id"), nullable=False ) @@ -389,6 +390,9 @@ class SQAExperiment(Base): # a child, the old one will be deleted. # Use selectin loading for collections to prevent idle timeout errors # (https://docs.sqlalchemy.org/en/13/orm/loading_relationships.html#selectin-eager-loading) + arms: list[SQAArm] = relationship( + "SQAArm", cascade="all, delete-orphan", lazy="selectin" + ) data: list[SQAData] = relationship( "SQAData", cascade="all, delete-orphan", lazy="selectin" ) diff --git a/ax/storage/sqa_store/sqa_config.py b/ax/storage/sqa_store/sqa_config.py index 1102a6f14a7..07fb03cf321 100644 --- a/ax/storage/sqa_store/sqa_config.py +++ b/ax/storage/sqa_store/sqa_config.py @@ -9,7 +9,7 @@ from collections.abc import Callable from dataclasses import dataclass, field from enum import Enum -from typing import Any +from typing import Any, cast from ax.analysis.analysis import AnalysisCard @@ -67,9 +67,10 @@ class SQAConfig: serialization function. """ + EXPERIMENT_TYPES_WITH_NO_DATA_STORAGE: set[str] = field(default_factory=set) + def _default_class_to_sqa_class(self=None) -> dict[type[Base], type[SQABase]]: - # pyre-fixme[7] - return { + ax_cls_to_sqa_cls = { AbandonedArm: SQAAbandonedArm, AnalysisCard: SQAAnalysisCard, Arm: SQAArm, @@ -84,6 +85,10 @@ def _default_class_to_sqa_class(self=None) -> dict[type[Base], type[SQABase]]: Trial: SQATrial, AuxiliaryExperiment: SQAAuxiliaryExperiment, } + return { + cast(type[Base], k): cast(type[SQABase], v) + for k, v in ax_cls_to_sqa_cls.items() + } class_to_sqa_class: dict[type[Base], type[SQABase]] = field( default_factory=_default_class_to_sqa_class @@ -92,27 +97,21 @@ def _default_class_to_sqa_class(self=None) -> dict[type[Base], type[SQABase]]: generator_run_type_enum: Enum | type[Enum] | None = GeneratorRunType auxiliary_experiment_purpose_enum: type[Enum] = AuxiliaryExperimentPurpose - # pyre-fixme[4]: Attribute annotation cannot contain `Any`. - # pyre-fixme[24]: Generic type `type` expects 1 type parameter, use - # `typing.Type` to avoid runtime subscripting errors. - json_encoder_registry: dict[type, Callable[[Any], dict[str, Any]]] = field( + # Encoding and decoding registries: + json_encoder_registry: dict[type[Any], Callable[[Any], dict[str, Any]]] = field( default_factory=lambda: CORE_ENCODER_REGISTRY ) - # pyre-fixme[4]: Attribute annotation cannot contain `Any`. - # pyre-fixme[24]: Generic type `type` expects 1 type parameter, use - # `typing.Type` to avoid runtime subscripting errors. - json_class_encoder_registry: dict[type, Callable[[Any], dict[str, Any]]] = field( - default_factory=lambda: CORE_CLASS_ENCODER_REGISTRY + json_class_encoder_registry: dict[type[Any], Callable[[Any], dict[str, Any]]] = ( + field(default_factory=lambda: CORE_CLASS_ENCODER_REGISTRY) ) - json_decoder_registry: TDecoderRegistry = field( default_factory=lambda: CORE_DECODER_REGISTRY ) - # pyre-fixme[4]: Attribute annotation cannot contain `Any`. json_class_decoder_registry: dict[str, Callable[[dict[str, Any]], Any]] = field( default_factory=lambda: CORE_CLASS_DECODER_REGISTRY ) + # Metric and runner class registries: metric_registry: dict[type[Metric], int] = field( default_factory=lambda: CORE_METRIC_REGISTRY ) diff --git a/ax/storage/sqa_store/tests/test_sqa_store.py b/ax/storage/sqa_store/tests/test_sqa_store.py index 3c1baa72fd1..f97ea8ca502 100644 --- a/ax/storage/sqa_store/tests/test_sqa_store.py +++ b/ax/storage/sqa_store/tests/test_sqa_store.py @@ -218,7 +218,6 @@ def creator() -> Mock: def test_GeneratorRunTypeValidation(self) -> None: experiment = get_experiment_with_batch_trial() - # pyre-fixme[16]: `BaseTrial` has no attribute `generator_run_structs`. generator_run = experiment.trials[0].generator_runs[0] generator_run._generator_run_type = "foobar" with self.assertRaises(SQAEncodeError): diff --git a/ax/storage/sqa_store/validation.py b/ax/storage/sqa_store/validation.py index 314ee973bd0..9346ccbe6c8 100644 --- a/ax/storage/sqa_store/validation.py +++ b/ax/storage/sqa_store/validation.py @@ -51,7 +51,6 @@ def wrapper(fn: Callable) -> Callable: return wrapper -# pyre-fixme[3]: Return annotation cannot be `Any`. def consistency_exactly_one(instance: SQABase, exactly_one_fields: list[str]) -> Any: """Ensure that exactly one of `exactly_one_fields` has a value set.""" values = [getattr(instance, field) is not None for field in exactly_one_fields]