From 7c34365aef655fe54bae35cf9577bd7ca743a793 Mon Sep 17 00:00:00 2001 From: Lena Kashtelyan Date: Wed, 17 Sep 2025 08:07:57 -0700 Subject: [PATCH 1/5] Deprecate SQ weight override: SQ is now added as a GR right away, so its weight is encoded into that GR (#4212) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Summary: With SQ always being added as part of a GR and with [this code block](https://fburl.com/code/pnmuwj71) no longer being there as of D80304993, I think we can completely deprecate the weight override –– it's a complication that I believe we don't use, that makes the code a lot harder to follow and the behavior –– hard to predict. I'd like to know if folks have any objections to this move, especially bletham, sdaulton, eonofrey, ItsMrLin, mgarrard, who may have found themselves needing this in notebook experiments (if you do, please help me understand the use cases) : ) Reviewed By: saitcakmak Differential Revision: D80968518 --- ax/core/batch_trial.py | 24 +++++++----------------- ax/core/tests/test_batch_trial.py | 5 ----- ax/storage/json_store/decoder.py | 1 + ax/storage/json_store/decoders.py | 6 +----- ax/storage/json_store/encoders.py | 1 - ax/storage/sqa_store/decoder.py | 2 -- 6 files changed, 9 insertions(+), 30 deletions(-) diff --git a/ax/core/batch_trial.py b/ax/core/batch_trial.py index 57b3f611815..67086a731ea 100644 --- a/ax/core/batch_trial.py +++ b/ax/core/batch_trial.py @@ -8,8 +8,6 @@ from __future__ import annotations -import warnings - from collections import defaultdict, OrderedDict from collections.abc import MutableMapping from dataclasses import dataclass @@ -133,7 +131,6 @@ def __init__( self._arms_by_name: dict[str, Arm] = {} self._generator_runs: list[GeneratorRun] = [] self._abandoned_arms_metadata: dict[str, AbandonedArm] = {} - self._status_quo_weight_override: float | None = None self.should_add_status_quo_arm = should_add_status_quo_arm if generator_run is not None: if generator_runs is not None: @@ -184,6 +181,13 @@ def arm_weights(self) -> MutableMapping[Arm, float]: arm_weights[arm] += weight return arm_weights + @property + def _status_quo_weight_override(self) -> None: + raise DeprecationWarning( + "Status quo weight override is no longer supported. Please " + "contact the Ax developers for help adjusting your application." + ) + @arm_weights.setter def arm_weights(self, arm_weights: MutableMapping[Arm, float]) -> None: raise NotImplementedError("Use `trial.add_arms_and_weights`") @@ -330,16 +334,6 @@ def add_status_quo_arm(self, weight: float = 1.0) -> BatchTrial: if weight is not None: if weight <= 0.0: raise ValueError("Status quo weight must be positive.") - if ( - self._status_quo_weight_override is not None - and weight != self._status_quo_weight_override - ): - warnings.warn( - f"Status quo weight is being overridden from {weight} " - f"to {self._status_quo_weight_override}.", - UserWarning, - stacklevel=3, - ) sq_arm_already_added_with_correct_weight = False for existing_arm in self.arm_weights: @@ -360,7 +354,6 @@ def add_status_quo_arm(self, weight: float = 1.0) -> BatchTrial: weight=weight, generator_run_type=GeneratorRunType.STATUS_QUO, ) - self._status_quo_weight_override = weight self._refresh_arms_by_name() return self @@ -536,9 +529,6 @@ def clone_to( ], should_add_status_quo_arm=include_sq and self.should_add_status_quo_arm, ) - if (sq := self.status_quo) is not None and sq in self.arm_weights: - new_trial._status_quo_weight_override = self.arm_weights[sq] - self._update_trial_attrs_on_clone(new_trial=new_trial) return new_trial diff --git a/ax/core/tests/test_batch_trial.py b/ax/core/tests/test_batch_trial.py index b6b203b5159..4847a8e2b47 100644 --- a/ax/core/tests/test_batch_trial.py +++ b/ax/core/tests/test_batch_trial.py @@ -160,20 +160,15 @@ def test_status_quo_cannot_be_set_directly(self) -> None: def test_status_quo_weight_is_ignored_when_none(self) -> None: tot_weight = sum(self.batch.weights) self.assertEqual(sum(self.batch.weights), tot_weight) - self.assertIsNone(self.batch._status_quo_weight_override) def test_status_quo_set_on_clone_to(self) -> None: batch2 = self.batch.clone_to(include_sq=False) self.assertEqual(batch2.status_quo, self.experiment.status_quo) # Since should_add_status_quo_arm was False, - # _status_quo_weight_override should be False and the - # status_quo arm should not appear in arm_weights - self.assertIsNone(batch2._status_quo_weight_override) self.assertTrue(batch2.status_quo not in batch2.arm_weights) self.assertEqual(sum(batch2.weights), sum(self.weights)) # Test with should_add_status_quo_arm=True batch3 = self.experiment.new_batch_trial(should_add_status_quo_arm=True) - self.assertEqual(batch3._status_quo_weight_override, 1.0) self.assertTrue(batch2.status_quo in batch3.arm_weights) def test_cannot_add_status_quo_arm_without_status_quo(self) -> None: diff --git a/ax/storage/json_store/decoder.py b/ax/storage/json_store/decoder.py index 0d79e2c0f84..246ff0d0ba2 100644 --- a/ax/storage/json_store/decoder.py +++ b/ax/storage/json_store/decoder.py @@ -500,6 +500,7 @@ def trials_from_json( # `GeneratorRunStruct` (deprecated) will be decoded into a `GeneratorRun`, # so all we have to do here is change the key it's stored under. trial_json["generator_runs"] = trial_json.pop("generator_run_structs") + trial_json.pop("status_quo_weight_override", None) # Deprecated. loaded_trials[int(index)] = ( trial_from_json(experiment=experiment, **trial_json) if is_trial diff --git a/ax/storage/json_store/decoders.py b/ax/storage/json_store/decoders.py index 6b337ca4428..95b72cef972 100644 --- a/ax/storage/json_store/decoders.py +++ b/ax/storage/json_store/decoders.py @@ -103,7 +103,6 @@ def batch_trial_from_json( abandoned_arms_metadata: dict[str, AbandonedArm], num_arms_created: int, status_quo: Arm | None, - status_quo_weight_override: float, # Allowing default values for backwards compatibility with # objects stored before these fields were added. failed_reason: str | None = None, @@ -149,11 +148,8 @@ def batch_trial_from_json( # Trial.arms_by_name only returns arms with weights batch.should_add_status_quo_arm = batch.status_quo is not None - batch._status_quo_weight_override = status_quo_weight_override if batch.should_add_status_quo_arm: - batch.add_status_quo_arm( - weight=status_quo_weight_override, - ) + batch.add_status_quo_arm() # Set trial status last, after adding all the arms. batch._status = status diff --git a/ax/storage/json_store/encoders.py b/ax/storage/json_store/encoders.py index baceabaa2f6..a5511af373f 100644 --- a/ax/storage/json_store/encoders.py +++ b/ax/storage/json_store/encoders.py @@ -123,7 +123,6 @@ def batch_to_dict(batch: BatchTrial) -> dict[str, Any]: "ttl_seconds": batch.ttl_seconds, "status": batch.status, "status_quo": batch.status_quo, - "status_quo_weight_override": batch._status_quo_weight_override, "time_created": batch.time_created, "time_completed": batch.time_completed, "time_staged": batch.time_staged, diff --git a/ax/storage/sqa_store/decoder.py b/ax/storage/sqa_store/decoder.py index 5fe85ac03a1..5ce4eb28a87 100644 --- a/ax/storage/sqa_store/decoder.py +++ b/ax/storage/sqa_store/decoder.py @@ -1017,8 +1017,6 @@ def trial_from_sqa( # Most of the time, the status quo arm has its own generator run, # of a dedicated type. if gr.generator_run_type == GeneratorRunType.STATUS_QUO.name: - status_quo_weight = gr.weights[0] - trial._status_quo_weight_override = status_quo_weight trial._status_quo_generator_run_db_id = gr.db_id trial._status_quo_arm_db_id = gr.arms[0].db_id break From bc3bb62b9d3c1ca9d03adaffe451e35163d792a7 Mon Sep 17 00:00:00 2001 From: Lena Kashtelyan Date: Wed, 17 Sep 2025 08:07:57 -0700 Subject: [PATCH 2/5] Remove `generator_run_db_id` (#4301) Summary: I believe this is a remnant of some logic that was necessary in the past but is no longer Reviewed By: mgarrard Differential Revision: D82565836 --- ax/storage/sqa_store/decoder.py | 16 +--------------- 1 file changed, 1 insertion(+), 15 deletions(-) diff --git a/ax/storage/sqa_store/decoder.py b/ax/storage/sqa_store/decoder.py index 5ce4eb28a87..f74f5a095c7 100644 --- a/ax/storage/sqa_store/decoder.py +++ b/ax/storage/sqa_store/decoder.py @@ -29,7 +29,7 @@ from ax.core.batch_trial import AbandonedArm, BatchTrial from ax.core.data import Data from ax.core.experiment import Experiment -from ax.core.generator_run import GeneratorRun, GeneratorRunType +from ax.core.generator_run import GeneratorRun from ax.core.metric import Metric from ax.core.multi_type_experiment import MultiTypeExperiment from ax.core.objective import MultiObjective, Objective, ScalarizedObjective @@ -1011,20 +1011,6 @@ def trial_from_sqa( ) for generator_run_sqa in trial_sqa.generator_runs ] - if trial_sqa.status_quo_name is not None: - sq_arm = experiment.arms_by_name[trial_sqa.status_quo_name] - for gr in generator_runs: - # Most of the time, the status quo arm has its own generator run, - # of a dedicated type. - if gr.generator_run_type == GeneratorRunType.STATUS_QUO.name: - trial._status_quo_generator_run_db_id = gr.db_id - trial._status_quo_arm_db_id = gr.arms[0].db_id - break - # However, sometimes the status quo arm is part of a generator run - # that also includes other arms (e.g. in factorial or TS design). - if sq_arm in gr.arms: - trial._status_quo_generator_run_db_id = gr.db_id - trial._status_quo_arm_db_id = sq_arm.db_id trial._generator_runs = generator_runs trial._abandoned_arms_metadata = { abandoned_arm_sqa.name: self.abandoned_arm_from_sqa( From 8792202b0dfdb4ac829f9e047f7b01e6a3162ae4 Mon Sep 17 00:00:00 2001 From: Lena Kashtelyan Date: Wed, 17 Sep 2025 08:07:57 -0700 Subject: [PATCH 3/5] Deprecate nearly all internal SQA classes (now that legacy PTS is reaped) (#4200) Summary: Pull Request resolved: https://github.com/facebook/Ax/pull/4200 Differential Revision: D80175881 --- ax/adapter/tests/test_torch_adapter.py | 3 +-- ax/storage/sqa_store/encoder.py | 12 +++++++-- ax/storage/sqa_store/sqa_config.py | 27 ++++++++++---------- ax/storage/sqa_store/tests/test_sqa_store.py | 1 - 4 files changed, 24 insertions(+), 19 deletions(-) diff --git a/ax/adapter/tests/test_torch_adapter.py b/ax/adapter/tests/test_torch_adapter.py index 50e7b2a1dd2..12a3095c1a1 100644 --- a/ax/adapter/tests/test_torch_adapter.py +++ b/ax/adapter/tests/test_torch_adapter.py @@ -483,8 +483,7 @@ def test_candidate_metadata_propagation(self) -> None: exp = get_branin_experiment(with_status_quo=True, with_completed_batch=True) # Check that the metadata is correctly re-added to observation # features during `fit`. - # pyre-fixme[16]: `BaseTrial` has no attribute `_generator_run_structs`. - preexisting_batch_gr = exp.trials[0]._generator_runs[0] + preexisting_batch_gr = exp.trials[0].generator_runs[0] preexisting_batch_gr._candidate_metadata_by_arm_signature = { preexisting_batch_gr.arms[0].signature: { "preexisting_batch_cand_metadata": "some_value" diff --git a/ax/storage/sqa_store/encoder.py b/ax/storage/sqa_store/encoder.py index 5d50c9a895b..ae1d529cc2f 100644 --- a/ax/storage/sqa_store/encoder.py +++ b/ax/storage/sqa_store/encoder.py @@ -1043,8 +1043,16 @@ def trial_to_sqa( ) return trial_sqa - def experiment_data_to_sqa(self, experiment: Experiment) -> list[SQAData]: - """Convert Ax experiment data to SQLAlchemy.""" + def experiment_data_to_sqa( + self, + experiment: Experiment, + ) -> list[SQAData]: + if ( + experiment.experiment_type + in self.config.EXPERIMENT_TYPES_WITH_NO_DATA_STORAGE + ): + return [] + return [ self.data_to_sqa(data=data, trial_index=trial_index, timestamp=timestamp) for trial_index, data_by_timestamp in experiment.data_by_trial.items() diff --git a/ax/storage/sqa_store/sqa_config.py b/ax/storage/sqa_store/sqa_config.py index 1102a6f14a7..07fb03cf321 100644 --- a/ax/storage/sqa_store/sqa_config.py +++ b/ax/storage/sqa_store/sqa_config.py @@ -9,7 +9,7 @@ from collections.abc import Callable from dataclasses import dataclass, field from enum import Enum -from typing import Any +from typing import Any, cast from ax.analysis.analysis import AnalysisCard @@ -67,9 +67,10 @@ class SQAConfig: serialization function. """ + EXPERIMENT_TYPES_WITH_NO_DATA_STORAGE: set[str] = field(default_factory=set) + def _default_class_to_sqa_class(self=None) -> dict[type[Base], type[SQABase]]: - # pyre-fixme[7] - return { + ax_cls_to_sqa_cls = { AbandonedArm: SQAAbandonedArm, AnalysisCard: SQAAnalysisCard, Arm: SQAArm, @@ -84,6 +85,10 @@ def _default_class_to_sqa_class(self=None) -> dict[type[Base], type[SQABase]]: Trial: SQATrial, AuxiliaryExperiment: SQAAuxiliaryExperiment, } + return { + cast(type[Base], k): cast(type[SQABase], v) + for k, v in ax_cls_to_sqa_cls.items() + } class_to_sqa_class: dict[type[Base], type[SQABase]] = field( default_factory=_default_class_to_sqa_class @@ -92,27 +97,21 @@ def _default_class_to_sqa_class(self=None) -> dict[type[Base], type[SQABase]]: generator_run_type_enum: Enum | type[Enum] | None = GeneratorRunType auxiliary_experiment_purpose_enum: type[Enum] = AuxiliaryExperimentPurpose - # pyre-fixme[4]: Attribute annotation cannot contain `Any`. - # pyre-fixme[24]: Generic type `type` expects 1 type parameter, use - # `typing.Type` to avoid runtime subscripting errors. - json_encoder_registry: dict[type, Callable[[Any], dict[str, Any]]] = field( + # Encoding and decoding registries: + json_encoder_registry: dict[type[Any], Callable[[Any], dict[str, Any]]] = field( default_factory=lambda: CORE_ENCODER_REGISTRY ) - # pyre-fixme[4]: Attribute annotation cannot contain `Any`. - # pyre-fixme[24]: Generic type `type` expects 1 type parameter, use - # `typing.Type` to avoid runtime subscripting errors. - json_class_encoder_registry: dict[type, Callable[[Any], dict[str, Any]]] = field( - default_factory=lambda: CORE_CLASS_ENCODER_REGISTRY + json_class_encoder_registry: dict[type[Any], Callable[[Any], dict[str, Any]]] = ( + field(default_factory=lambda: CORE_CLASS_ENCODER_REGISTRY) ) - json_decoder_registry: TDecoderRegistry = field( default_factory=lambda: CORE_DECODER_REGISTRY ) - # pyre-fixme[4]: Attribute annotation cannot contain `Any`. json_class_decoder_registry: dict[str, Callable[[dict[str, Any]], Any]] = field( default_factory=lambda: CORE_CLASS_DECODER_REGISTRY ) + # Metric and runner class registries: metric_registry: dict[type[Metric], int] = field( default_factory=lambda: CORE_METRIC_REGISTRY ) diff --git a/ax/storage/sqa_store/tests/test_sqa_store.py b/ax/storage/sqa_store/tests/test_sqa_store.py index 3c1baa72fd1..f97ea8ca502 100644 --- a/ax/storage/sqa_store/tests/test_sqa_store.py +++ b/ax/storage/sqa_store/tests/test_sqa_store.py @@ -218,7 +218,6 @@ def creator() -> Mock: def test_GeneratorRunTypeValidation(self) -> None: experiment = get_experiment_with_batch_trial() - # pyre-fixme[16]: `BaseTrial` has no attribute `generator_run_structs`. generator_run = experiment.trials[0].generator_runs[0] generator_run._generator_run_type = "foobar" with self.assertRaises(SQAEncodeError): From 7d655cb31412a4eb2021096685d0d466e225bded Mon Sep 17 00:00:00 2001 From: Lena Kashtelyan Date: Wed, 17 Sep 2025 08:07:57 -0700 Subject: [PATCH 4/5] Abstract `_add_generator_run` onto `BaseTrial` + other cleanup Summary: As titled, just a small refactor (individual changes commented inline) Differential Revision: D80968586 --- ax/core/base_trial.py | 91 ++++++++++++++++++++++++++++++++---------- ax/core/batch_trial.py | 8 ++-- ax/core/trial.py | 24 +++++------ 3 files changed, 84 insertions(+), 39 deletions(-) diff --git a/ax/core/base_trial.py b/ax/core/base_trial.py index 896c0ac1317..3b0e8b95859 100644 --- a/ax/core/base_trial.py +++ b/ax/core/base_trial.py @@ -30,6 +30,7 @@ from ax.utils.common.base import SortableBase from ax.utils.common.constants import Keys from pyre_extensions import none_throws +from typing_extensions import Self if TYPE_CHECKING: @@ -137,12 +138,60 @@ def __init__( # strategy, this property will be set to the generation step that produced # the generator run(s). self._generation_step_index: int | None = None - # Please do not store any data related to trial deployment or data- + # NOTE: Please do not store any data related to trial deployment or data- # fetching in properties. It is intended to only store properties related # to core Ax functionality and not to any third-system that the trials # might be getting deployed to. - # pyre-fixme[4]: Attribute must be annotated. - self._properties = {} + self._properties: dict[str, Any] = {} + + @abstractproperty + def arms(self) -> list[Arm]: + """All arms associated with this trial.""" + pass + + @abstractproperty + def arms_by_name(self) -> dict[str, Arm]: + """A mapping of from arm names, to all arms associated with + this trial. + """ + pass + + @abstractproperty + def abandoned_arms(self) -> list[Arm]: + """All abandoned arms, associated with this trial.""" + pass + + @abstractmethod + def add_generator_run(self, generator_run: GeneratorRun) -> Self: + """Add a generator run to the trial. + + The arms and weights from the generator run will be merged with + the existing arms and weights on the trial, and the generator run + object will be linked to the trial for tracking. + + Args: + generator_run: The generator run to be added. + + Returns: + The trial instance. + """ + pass + + @abstractmethod + def add_arm( + self, arm: Arm, candidate_metadata: dict[str, Any] | None = None + ) -> Self: + """Add arm to the trial. + + Returns: + The trial instance. + """ + pass + + @abstractmethod + def __repr__(self) -> str: + """String representation of the trial.""" + pass @property def experiment(self) -> core.experiment.Experiment: @@ -237,6 +286,25 @@ def trial_type(self) -> str | None: """ return self._trial_type + def _add_generator_run(self, generator_run: GeneratorRun) -> None: + """Helper called from ``{BatchTrial, Trial}.add_generator_run: validates + and names arms; adds the ``GeneratorRun`` to this trial's ``Experiment``. + """ + # 1. Validate the arm(s) in the generator run. + for arm in generator_run.arms: + self.experiment.search_space.check_types(arm.parameters, raise_error=True) + + # 2. Name any yet-unnamed arms: for arms that are not yet added to this + # trial's experiment, assign a new name; use name on experiment otherwise. + for arm in generator_run.arms: + self._check_existing_and_name_arm(arm) + + # 3.TODO: Add the generator run to the experiment if it's not already there; + # assign an experiment-level index if newly added. + + # 4. TODO: Capture which generator run the arms we are about to add this + # this trial, came from. + def assign_runner(self) -> BaseTrial: """Assigns default experiment runner if trial doesn't already have one.""" runner = self.experiment.runner_for_trial(self) @@ -431,23 +499,6 @@ def _set_generation_step_index(self, generation_step_index: int | None) -> None: ) self._generation_step_index = generation_step_index - @abstractproperty - def arms(self) -> list[Arm]: - pass - - @abstractproperty - def arms_by_name(self) -> dict[str, Arm]: - pass - - @abstractmethod - def __repr__(self) -> str: - pass - - @abstractproperty - def abandoned_arms(self) -> list[Arm]: - """All abandoned arms, associated with this trial.""" - pass - @property def active_arms(self) -> list[Arm]: """All non abandoned arms associated with this trial.""" diff --git a/ax/core/batch_trial.py b/ax/core/batch_trial.py index 67086a731ea..e3939d9cfbb 100644 --- a/ax/core/batch_trial.py +++ b/ax/core/batch_trial.py @@ -293,11 +293,9 @@ def add_generator_run(self, generator_run: GeneratorRun) -> BatchTrial: } ) - # Add names to arms - # For those not yet added to this experiment, create a new name - # Else, use the name of the existing arm - for arm in generator_run.arms: - self._check_existing_and_name_arm(arm) + # Call `BaseTrial._add_generator_run` to validate and name the arms, + # then attach the generator run to the experiment. + self._add_generator_run(generator_run=generator_run) self._generator_runs.append(generator_run) if generator_run._generation_step_index is not None: diff --git a/ax/core/trial.py b/ax/core/trial.py index c730c375208..5e16a8ecc85 100644 --- a/ax/core/trial.py +++ b/ax/core/trial.py @@ -74,8 +74,7 @@ def __init__( ttl_seconds=ttl_seconds, index=index, ) - # pyre-fixme[4]: Attribute must be annotated. - self._generator_run = None + self._generator_run: GeneratorRun | None = None if generator_run is not None: self.add_generator_run(generator_run=generator_run) @@ -84,16 +83,16 @@ def generator_run(self) -> GeneratorRun | None: """Generator run attached to this trial.""" return self._generator_run - # pyre-ignore[6]: T77111662. - @copy_doc(BaseTrial.generator_runs) @property def generator_runs(self) -> list[GeneratorRun]: - gr = self._generator_run - return [gr] if gr is not None else [] + """Generator runs attached to this trial. Since this is a one-arm + ``Trial`` (and not ``BatchTrial``), this will be a list of length 1. + """ + return [self._generator_run] if self._generator_run else [] @property def arm(self) -> Arm | None: - """The arm associated with this batch.""" + """The ``Arm`` associated with this ``Trial``.""" if self.generator_run is None: return None @@ -149,14 +148,11 @@ def add_generator_run(self, generator_run: GeneratorRun) -> Trial: "included multiple." ) - self.experiment.search_space.check_types( - generator_run.arms[0].parameters, raise_error=True - ) - self._check_existing_and_name_arm(generator_run.arms[0]) + # Call `BaseTrial._add_generator_run` to validate and name the arms, + # then attach the generator run to the experiment. + self._add_generator_run(generator_run=generator_run) + self._generator_run = generator_run - self._set_generation_step_index( - generation_step_index=generator_run._generation_step_index - ) return self @property From d8686c4c8c80db5944d7edca9522b8f8e6b60b87 Mon Sep 17 00:00:00 2001 From: Lena Kashtelyan Date: Wed, 17 Sep 2025 08:07:57 -0700 Subject: [PATCH 5/5] Associate arms with experiment ids Differential Revision: D80208015 --- ax/storage/sqa_store/decoder.py | 12 +++++++++++- ax/storage/sqa_store/encoder.py | 2 ++ ax/storage/sqa_store/sqa_classes.py | 4 ++++ ax/storage/sqa_store/validation.py | 1 - 4 files changed, 17 insertions(+), 2 deletions(-) diff --git a/ax/storage/sqa_store/decoder.py b/ax/storage/sqa_store/decoder.py index f74f5a095c7..736c2a7dae6 100644 --- a/ax/storage/sqa_store/decoder.py +++ b/ax/storage/sqa_store/decoder.py @@ -223,7 +223,7 @@ def _init_experiment_from_sqa( else {} ) - return Experiment( + experiment = Experiment( name=experiment_sqa.name, description=experiment_sqa.description, search_space=search_space, @@ -237,6 +237,16 @@ def _init_experiment_from_sqa( auxiliary_experiments_by_purpose=auxiliary_experiments_by_purpose, ) + for arm in experiment_sqa.arms: + experiment._register_arm( + arm=Arm( + parameters=arm.parameters, + name=arm.name, + ) + ) + + return experiment + def _init_mt_experiment_from_sqa( self, experiment_sqa: SQAExperiment, diff --git a/ax/storage/sqa_store/encoder.py b/ax/storage/sqa_store/encoder.py index ae1d529cc2f..3f493e501fe 100644 --- a/ax/storage/sqa_store/encoder.py +++ b/ax/storage/sqa_store/encoder.py @@ -159,6 +159,7 @@ def experiment_to_sqa(self, experiment: Experiment) -> SQAExperiment: create and store copies of the Trials, Metrics, Parameters, ParameterConstraints, and Runner owned by this Experiment. """ + arms = [self.arm_to_sqa(arm) for arm in experiment.arms_by_name.values()] optimization_metrics = self.optimization_config_to_sqa( experiment.optimization_config @@ -239,6 +240,7 @@ def experiment_to_sqa(self, experiment: Experiment) -> SQAExperiment: status_quo_parameters=status_quo_parameters, time_created=experiment.time_created, experiment_type=experiment_type, + arms=arms, metrics=optimization_metrics + tracking_metrics, parameters=parameters, parameter_constraints=parameter_constraints, diff --git a/ax/storage/sqa_store/sqa_classes.py b/ax/storage/sqa_store/sqa_classes.py index 2544eec4f43..61fdf9d1aa5 100644 --- a/ax/storage/sqa_store/sqa_classes.py +++ b/ax/storage/sqa_store/sqa_classes.py @@ -161,6 +161,7 @@ class SQAMetric(Base): class SQAArm(Base): __tablename__: str = "arm_v2" + experiment_id: Column[int | None] = Column(Integer, ForeignKey("experiment_v2.id")) generator_run_id: Column[int] = Column( Integer, ForeignKey("generator_run_v2.id"), nullable=False ) @@ -389,6 +390,9 @@ class SQAExperiment(Base): # a child, the old one will be deleted. # Use selectin loading for collections to prevent idle timeout errors # (https://docs.sqlalchemy.org/en/13/orm/loading_relationships.html#selectin-eager-loading) + arms: list[SQAArm] = relationship( + "SQAArm", cascade="all, delete-orphan", lazy="selectin" + ) data: list[SQAData] = relationship( "SQAData", cascade="all, delete-orphan", lazy="selectin" ) diff --git a/ax/storage/sqa_store/validation.py b/ax/storage/sqa_store/validation.py index 314ee973bd0..9346ccbe6c8 100644 --- a/ax/storage/sqa_store/validation.py +++ b/ax/storage/sqa_store/validation.py @@ -51,7 +51,6 @@ def wrapper(fn: Callable) -> Callable: return wrapper -# pyre-fixme[3]: Return annotation cannot be `Any`. def consistency_exactly_one(instance: SQABase, exactly_one_fields: list[str]) -> Any: """Ensure that exactly one of `exactly_one_fields` has a value set.""" values = [getattr(instance, field) is not None for field in exactly_one_fields]