Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
31 changes: 10 additions & 21 deletions ax/core/map_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
import warnings

from bisect import bisect_right
from collections.abc import Iterable, Sequence
from collections.abc import Iterable
from logging import Logger
from math import nan
from typing import Any
Expand Down Expand Up @@ -119,23 +119,6 @@ def true_df(self) -> pd.DataFrame:
def required_columns(self) -> set[str]:
return super().required_columns().union({MAP_KEY})

@staticmethod
def from_multiple_map_data(data: Sequence[MapData]) -> MapData:
if len(data) == 0:
return MapData()

# Avoid concatenating empty dataframes which logs a warning.
non_empty_dfs = [datum.map_df for datum in data if not datum.map_df.empty]
df = (
pd.concat(non_empty_dfs)
if len(non_empty_dfs) > 0
else pd.DataFrame(
columns=[*{col for datum in data for col in datum.required_columns()}]
)
)

return MapData(df=df)

@property
def map_df(self) -> pd.DataFrame:
return self.full_df
Expand All @@ -147,12 +130,18 @@ def from_multiple_data(cls, data: Iterable[Data]) -> MapData:

If no "step" column is present, it will be filled in with NaNs.
"""
map_datas = [
(cls(df=datum.df) if not isinstance(datum, MapData) else datum)
map_dfs = [
datum.full_df
if isinstance(datum, MapData)
else datum.df.assign(**{MAP_KEY: nan})
for datum in data
if not datum.full_df.empty
]

return cls.from_multiple_map_data(data=map_datas)
if len(map_dfs) == 0:
return MapData()

return MapData(df=pd.concat(map_dfs))

@property
def df(self) -> pd.DataFrame:
Expand Down
14 changes: 12 additions & 2 deletions ax/core/tests/test_map_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@

import numpy as np
import pandas as pd
from ax.core.data import Data
from ax.core.map_data import MAP_KEY, MapData
from ax.core.tests.test_data import TestDataBase
from ax.utils.common.testutils import TestCase
Expand Down Expand Up @@ -108,13 +109,22 @@ def test_init(self) -> None:

def test_combine(self) -> None:
with self.subTest("From no MapDatas"):
data = MapData.from_multiple_map_data([])
data = MapData.from_multiple_data([])
self.assertIsInstance(data, MapData)
self.assertEqual(data.map_df.size, 0)

with self.subTest("From two MapDatas"):
mmd_double = MapData.from_multiple_map_data([self.mmd, self.mmd])
mmd_double = MapData.from_multiple_data([self.mmd, self.mmd])
self.assertIsInstance(mmd_double, MapData)
self.assertEqual(mmd_double.map_df.size, 2 * self.mmd.map_df.size)

with self.subTest("From Datas"):
data = Data(df=self.mmd.df)
map_data = MapData.from_multiple_data([data])
self.assertIsInstance(map_data, MapData)
data = Data.from_multiple_data([data])
self.assertEqual(len(data.full_df), len(map_data.full_df))

def test_upcast(self) -> None:
fresh = MapData(df=self.df)
# Assert df is not cached before first call
Expand Down
2 changes: 1 addition & 1 deletion ax/metrics/branin_map.py
Original file line number Diff line number Diff line change
Expand Up @@ -130,7 +130,7 @@ def fetch_trial_data(

datas.append(MapData(df=df))

return Ok(value=MapData.from_multiple_map_data(datas))
return Ok(value=MapData.from_multiple_data(data=datas))

except Exception as e:
return Err(
Expand Down