Skip to content

Commit ebd0b00

Browse files
esantorella authored and facebook-github-bot committed
Consolidate MapData.from_multiple_map_data into MapData.from_multiple_data (facebook#4410)
Summary:

**Context**: The following very similar methods exist:

* `Data.from_multiple_data`: Combines multiple `Data` (not `MapData`) into one `Data`.
* `MapData.from_multiple_data`: Converts any data arguments into `MapData` if they are not already, and then calls `MapData.from_multiple_map_data`.
* `MapData.from_multiple_map_data`: Combines `MapData`s into one `MapData`.

**This PR**:

* Removes `MapData.from_multiple_map_data`; `MapData.from_multiple_data` should now be used instead.

Reviewed By: Balandat

Differential Revision: D84066098
1 parent cd009ec commit ebd0b00

File tree

3 files changed: 31 additions (+31) and 24 deletions (-24).

ax/core/map_data.py

Lines changed: 18 additions & 21 deletions
Original file line number | Diff line number | Diff line change
@@ -10,7 +10,7 @@
1010
import warnings
1111

1212
from bisect import bisect_right
13-
from collections.abc import Iterable, Sequence
13+
from collections.abc import Iterable
1414
from logging import Logger
1515
from math import nan
1616
from typing import Any
@@ -119,23 +119,6 @@ def true_df(self) -> pd.DataFrame:
119119
def required_columns(self) -> set[str]:
120120
return super().required_columns().union({MAP_KEY})
121121

122-
@staticmethod
123-
def from_multiple_map_data(data: Sequence[MapData]) -> MapData:
124-
if len(data) == 0:
125-
return MapData()
126-
127-
# Avoid concatenating empty dataframes which logs a warning.
128-
non_empty_dfs = [datum.map_df for datum in data if not datum.map_df.empty]
129-
df = (
130-
pd.concat(non_empty_dfs)
131-
if len(non_empty_dfs) > 0
132-
else pd.DataFrame(
133-
columns=[*{col for datum in data for col in datum.required_columns()}]
134-
)
135-
)
136-
137-
return MapData(df=df)
138-
139122
@property
140123
def map_df(self) -> pd.DataFrame:
141124
return self.full_df
@@ -147,12 +130,26 @@ def from_multiple_data(cls, data: Iterable[Data]) -> MapData:
147130
148131
If no "step" column is present, it will be filled in with NaNs.
149132
"""
150-
map_datas = [
151-
(cls(df=datum.df) if not isinstance(datum, MapData) else datum)
133+
map_dfs = [
134+
datum.full_df
135+
if isinstance(datum, MapData)
136+
else datum.df.assign(**{MAP_KEY: nan})
152137
for datum in data
138+
if not datum.full_df.empty
153139
]
154140

155-
return cls.from_multiple_map_data(data=map_datas)
141+
if len(map_dfs) == 0:
142+
return MapData()
143+
144+
# Avoid concatenating empty dataframes which logs a warning.
145+
df = (
146+
pd.concat(map_dfs)
147+
if len(map_dfs) > 0
148+
else pd.DataFrame(
149+
columns=[*{col for datum in data for col in datum.required_columns()}]
150+
)
151+
)
152+
return MapData(df=df)
156153

157154
@property
158155
def df(self) -> pd.DataFrame:

ax/core/tests/test_map_data.py

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88

99
import numpy as np
1010
import pandas as pd
11+
from ax.core.data import Data
1112
from ax.core.map_data import MAP_KEY, MapData
1213
from ax.core.tests.test_data import TestDataBase
1314
from ax.utils.common.testutils import TestCase
@@ -108,13 +109,22 @@ def test_init(self) -> None:
108109

109110
def test_combine(self) -> None:
110111
with self.subTest("From no MapDatas"):
111-
data = MapData.from_multiple_map_data([])
112+
data = MapData.from_multiple_data([])
113+
self.assertIsInstance(data, MapData)
112114
self.assertEqual(data.map_df.size, 0)
113115

114116
with self.subTest("From two MapDatas"):
115-
mmd_double = MapData.from_multiple_map_data([self.mmd, self.mmd])
117+
mmd_double = MapData.from_multiple_data([self.mmd, self.mmd])
118+
self.assertIsInstance(mmd_double, MapData)
116119
self.assertEqual(mmd_double.map_df.size, 2 * self.mmd.map_df.size)
117120

121+
with self.subTest("From Datas"):
122+
data = Data(df=self.mmd.df)
123+
map_data = MapData.from_multiple_data([data])
124+
self.assertIsInstance(map_data, MapData)
125+
data = Data.from_multiple_data([data])
126+
self.assertEqual(len(data.full_df), len(map_data.full_df))
127+
118128
def test_upcast(self) -> None:
119129
fresh = MapData(df=self.df)
120130
# Assert df is not cached before first call

ax/metrics/branin_map.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -130,7 +130,7 @@ def fetch_trial_data(
130130

131131
datas.append(MapData(df=df))
132132

133-
return Ok(value=MapData.from_multiple_map_data(datas))
133+
return Ok(value=MapData.from_multiple_data(data=datas))
134134

135135
except Exception as e:
136136
return Err(

0 commit comments

Comments
 (0)