Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions ax/storage/json_store/registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
from ax.adapter.base import DataLoaderConfig
from ax.adapter.registry import GeneratorRegistryBase, Generators
from ax.adapter.transforms.base import Transform
from ax.api.protocols.metric import IMetric
from ax.benchmark.benchmark_method import BenchmarkMethod
from ax.benchmark.benchmark_metric import (
BenchmarkMapMetric,
Expand Down Expand Up @@ -216,6 +217,7 @@
IsSingleObjective: transition_criterion_to_dict,
L2NormMetric: metric_to_dict,
LogNormalPrior: botorch_component_to_dict,
IMetric: metric_to_dict,
MapMetric: metric_to_dict,
MaxGenerationParallelism: transition_criterion_to_dict,
Metric: metric_to_dict,
Expand Down Expand Up @@ -334,6 +336,7 @@
"HierarchicalSearchSpace": HierarchicalSearchSpace,
"ImprovementGlobalStoppingStrategy": ImprovementGlobalStoppingStrategy,
"InputConstructorPurpose": InputConstructorPurpose,
"IMetric": IMetric,
"Interval": Interval,
"IsSingleObjective": IsSingleObjective,
"Keys": Keys,
Expand Down
2 changes: 2 additions & 0 deletions ax/storage/json_store/tests/test_json_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
from ax.adapter.transforms.base import Transform
from ax.adapter.transforms.log import Log
from ax.adapter.transforms.one_hot import OneHot
from ax.api.protocols.metric import IMetric
from ax.benchmark.methods.sobol import get_sobol_benchmark_method
from ax.benchmark.testing.benchmark_stubs import (
get_aggregated_benchmark_result,
Expand Down Expand Up @@ -353,6 +354,7 @@
),
),
("HierarchicalSearchSpace", get_hierarchical_search_space),
("IMetric", lambda: IMetric(name="test")),
("ImprovementGlobalStoppingStrategy", get_improvement_global_stopping_strategy),
("Interval", get_interval),
("MapData", get_map_data),
Expand Down
2 changes: 2 additions & 0 deletions ax/storage/metric_registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from collections.abc import Callable
from typing import Any

from ax.api.protocols.metric import IMetric
from ax.core.map_metric import MapMetric
from ax.core.metric import Metric
from ax.metrics.branin import BraninMetric
Expand Down Expand Up @@ -43,6 +44,7 @@
ChemistryMetric: 7,
MapMetric: 8,
BraninTimestampMapMetric: 9,
IMetric: 10,
}


Expand Down
5 changes: 0 additions & 5 deletions ax/storage/sqa_store/sqa_enum.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,16 +16,13 @@
class BaseNullableEnum(types.TypeDecorator):
cache_ok = True

# pyre-fixme[2]: Parameter annotation cannot be `Any`.
def __init__(self, enum: Any, *arg: list[Any], **kw: dict[Any, Any]) -> None:
types.TypeDecorator.__init__(self, *arg, **kw)
# pyre-fixme[4]: Attribute must be annotated.
self._member_map = enum._member_map_
# pyre-fixme[4]: Attribute must be annotated.
self._value2member_map = enum._value2member_map_

# pyre-fixme[3]: Return annotation cannot be `Any`.
# pyre-fixme[2]: Parameter annotation cannot be `Any`.
def process_bind_param(self, value: Any, dialect: Any) -> Any:
if value is None:
return value
Expand All @@ -40,8 +37,6 @@ def process_bind_param(self, value: Any, dialect: Any) -> Any:
)
return val._value_

# pyre-fixme[3]: Return annotation cannot be `Any`.
# pyre-fixme[2]: Parameter annotation cannot be `Any`.
def process_result_value(self, value: Any, dialect: Any) -> Any:
if value is None:
return value
Expand Down
9 changes: 9 additions & 0 deletions ax/storage/sqa_store/tests/test_sqa_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
from ax.adapter.registry import Generators
from ax.analysis.markdown.markdown_analysis import MarkdownAnalysisCard
from ax.analysis.plotly.plotly_analysis import PlotlyAnalysisCard
from ax.api.protocols.metric import IMetric
from ax.core.analysis_card import AnalysisCard, AnalysisCardGroup
from ax.core.arm import Arm
from ax.core.auxiliary import (
Expand Down Expand Up @@ -1864,6 +1865,14 @@ def test_MetricDecodeWithNoSignatureOverride(self) -> None:
self.assertEqual(metric.name, metric_name)
self.assertEqual(metric.signature, metric_name)

def test_IMetricEncodeDecode(self) -> None:
metric_name = "test_imetric"
imetric = IMetric(name=metric_name)
sqa_metric = self.encoder.metric_to_sqa(imetric)
decoded_metric = cast(Metric, self.decoder.metric_from_sqa(sqa_metric))
self.assertIsInstance(decoded_metric, IMetric)
self.assertEqual(decoded_metric.name, metric_name)

def test_MetricDecodeWithSignatureOverride(self) -> None:
metric_name = "testMetric"
testMetric = Metric(name=metric_name, signature_override="override")
Expand Down
1 change: 0 additions & 1 deletion ax/storage/sqa_store/validation.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,6 @@ def wrapper(fn: Callable) -> Callable:
return wrapper


# pyre-fixme[3]: Return annotation cannot be `Any`.
def consistency_exactly_one(instance: SQABase, exactly_one_fields: list[str]) -> Any:
"""Ensure that exactly one of `exactly_one_fields` has a value set."""
values = [getattr(instance, field) is not None for field in exactly_one_fields]
Expand Down