Skip to content
This repository was archived by the owner on Sep 13, 2023. It is now read-only.

Commit 86105ee

Browse files
add separate reader/writer for lgb labels (#421)
* add separate reader/writer for labels
* fix linter
* use free_raw_data as False in serialize test
* use construct in process so that labels is never None; fix some tests; fix lint
* remove the usage of construct
* fix pylint
* do not save labels when they are None
* saved key shouldn't have /data/data when labels is None
* fix test
1 parent: 466fcac · commit: 86105ee

File tree

2 files changed

+180
-32
lines changed

2 files changed

+180
-32
lines changed

mlem/contrib/lightgbm.py

Lines changed: 77 additions & 16 deletions
Original file line number | Diff line number | Diff line change
@@ -1,8 +1,9 @@
11
import os
22
import posixpath
33
import tempfile
4-
from typing import Any, ClassVar, Iterator, List, Optional, Tuple, Type
4+
from typing import Any, ClassVar, Iterator, Optional, Tuple, Type
55

6+
import flatdict
67
import lightgbm as lgb
78
from pydantic import BaseModel
89

@@ -28,6 +29,8 @@
2829
)
2930

3031
LGB_REQUIREMENT = UnixPackageRequirement(package_name="libgomp1")
32+
LIGHTGBM_DATA = "inner"
33+
LIGHTGBM_LABEL = "label"
3134

3235

3336
class LightGBMDataType(
@@ -37,20 +40,38 @@ class LightGBMDataType(
3740
:class:`.DataType` implementation for `lightgbm.Dataset` type
3841
3942
:param inner: :class:`.DataType` instance for underlying data
43+
:param labels: :class:`.DataType` instance for underlying labels
4044
"""
4145

4246
type: ClassVar[str] = "lightgbm"
4347
valid_types: ClassVar = (lgb.Dataset,)
4448
inner: DataType
49+
labels: Optional[DataType]
4550

4651
def serialize(self, instance: Any) -> dict:
4752
self.check_type(instance, lgb.Dataset, SerializationError)
53+
if self.labels is not None:
54+
return {
55+
LIGHTGBM_DATA: self.inner.get_serializer().serialize(
56+
instance.data
57+
),
58+
LIGHTGBM_LABEL: self.labels.get_serializer().serialize(
59+
instance.label
60+
),
61+
}
4862
return self.inner.get_serializer().serialize(instance.data)
4963

5064
def deserialize(self, obj: dict) -> Any:
51-
v = self.inner.get_serializer().deserialize(obj)
65+
if self.labels is not None:
66+
data = self.inner.get_serializer().deserialize(obj[LIGHTGBM_DATA])
67+
label = self.labels.get_serializer().deserialize(
68+
obj[LIGHTGBM_LABEL]
69+
)
70+
else:
71+
data = self.inner.get_serializer().deserialize(obj)
72+
label = None
5273
try:
53-
return lgb.Dataset(v, free_raw_data=False)
74+
return lgb.Dataset(data, label=label, free_raw_data=False)
5475
except ValueError as e:
5576
raise DeserializationError(
5677
f"object: {obj} could not be converted to lightgbm dataset"
@@ -70,7 +91,12 @@ def get_writer(
7091

7192
@classmethod
7293
def process(cls, obj: Any, **kwargs) -> DataType:
73-
return LightGBMDataType(inner=DataAnalyzer.analyze(obj.data))
94+
return LightGBMDataType(
95+
inner=DataAnalyzer.analyze(obj.data),
96+
labels=DataAnalyzer.analyze(obj.label)
97+
if obj.label is not None
98+
else None,
99+
)
74100

75101
def get_model(self, prefix: str = "") -> Type[BaseModel]:
76102
return self.inner.get_serializer().get_model(prefix)
@@ -86,33 +112,68 @@ def write(
86112
raise ValueError(
87113
f"expected data to be of LightGBMDataType, got {type(data)} instead"
88114
)
89-
lightgbm_construct = data.data.construct()
90-
raw_data = lightgbm_construct.get_data()
91-
underlying_labels = lightgbm_construct.get_label().tolist()
92-
inner_reader, art = data.inner.get_writer().write(
93-
data.inner.copy().bind(raw_data), storage, path
94-
)
115+
116+
lightgbm_raw = data.data
117+
118+
if data.labels is not None:
119+
inner_reader, inner_art = data.inner.get_writer().write(
120+
data.inner.copy().bind(lightgbm_raw.data),
121+
storage,
122+
posixpath.join(path, LIGHTGBM_DATA),
123+
)
124+
labels_reader, labels_art = data.labels.get_writer().write(
125+
data.labels.copy().bind(lightgbm_raw.label),
126+
storage,
127+
posixpath.join(path, LIGHTGBM_LABEL),
128+
)
129+
res = dict(
130+
flatdict.FlatterDict(
131+
{LIGHTGBM_DATA: inner_art, LIGHTGBM_LABEL: labels_art},
132+
delimiter="/",
133+
)
134+
)
135+
else:
136+
inner_reader, inner_art = data.inner.get_writer().write(
137+
data.inner.copy().bind(lightgbm_raw.data),
138+
storage,
139+
path,
140+
)
141+
res = inner_art
142+
labels_reader = None
143+
95144
return (
96145
LightGBMDataReader(
97146
data_type=data,
98147
inner=inner_reader,
99-
label=underlying_labels,
148+
labels=labels_reader,
100149
),
101-
art,
150+
res,
102151
)
103152

104153

105154
class LightGBMDataReader(DataReader):
106155
type: ClassVar[str] = "lightgbm"
107156
data_type: LightGBMDataType
108157
inner: DataReader
109-
label: List
158+
labels: Optional[DataReader]
110159

111160
def read(self, artifacts: Artifacts) -> DataType:
112-
inner_data_type = self.inner.read(artifacts)
113-
return LightGBMDataType(inner=inner_data_type).bind(
161+
if self.labels is not None:
162+
artifacts = flatdict.FlatterDict(artifacts, delimiter="/")
163+
inner_data_type = self.inner.read(artifacts[LIGHTGBM_DATA]) # type: ignore[arg-type]
164+
labels_data_type = self.labels.read(artifacts[LIGHTGBM_LABEL]) # type: ignore[arg-type]
165+
else:
166+
inner_data_type = self.inner.read(artifacts)
167+
labels_data_type = None
168+
return LightGBMDataType(
169+
inner=inner_data_type, labels=labels_data_type
170+
).bind(
114171
lgb.Dataset(
115-
inner_data_type.data, label=self.label, free_raw_data=False
172+
inner_data_type.data,
173+
label=labels_data_type.data
174+
if labels_data_type is not None
175+
else None,
176+
free_raw_data=False,
116177
)
117178
)
118179

tests/contrib/test_lightgbm.py

Lines changed: 103 additions & 16 deletions
Original file line number | Diff line number | Diff line change
@@ -4,6 +4,8 @@
44
import pytest
55

66
from mlem.contrib.lightgbm import (
7+
LIGHTGBM_DATA,
8+
LIGHTGBM_LABEL,
79
LightGBMDataReader,
810
LightGBMDataType,
911
LightGBMDataWriter,
@@ -12,7 +14,12 @@
1214
from mlem.contrib.numpy import NumpyNdarrayType
1315
from mlem.contrib.pandas import DataFrameType
1416
from mlem.core.artifacts import LOCAL_STORAGE
15-
from mlem.core.data_type import DataAnalyzer, DataType
17+
from mlem.core.data_type import (
18+
ArrayType,
19+
DataAnalyzer,
20+
DataType,
21+
PrimitiveType,
22+
)
1623
from mlem.core.errors import DeserializationError, SerializationError
1724
from mlem.core.model import ModelAnalyzer, ModelType
1825
from mlem.core.requirements import UnixPackageRequirement
@@ -46,7 +53,7 @@ def df_payload():
4653
def data_df(df_payload):
4754
return lgb.Dataset(
4855
df_payload,
49-
label=np.array([0, 1]).tolist(),
56+
label=np.array([0, 1]),
5057
free_raw_data=False,
5158
)
5259

@@ -75,6 +82,8 @@ def test_hook_np(dtype_np: DataType):
7582
assert set(dtype_np.get_requirements().modules) == {"lightgbm", "numpy"}
7683
assert isinstance(dtype_np, LightGBMDataType)
7784
assert isinstance(dtype_np.inner, NumpyNdarrayType)
85+
assert isinstance(dtype_np.labels, ArrayType)
86+
assert dtype_np.labels.dtype == PrimitiveType(data=None, ptype="float")
7887
assert dtype_np.get_model().__name__ == dtype_np.inner.get_model().__name__
7988
assert dtype_np.get_model().schema() == {
8089
"title": "NumpyNdarray",
@@ -92,6 +101,7 @@ def test_hook_df(dtype_df: DataType):
92101
assert set(dtype_df.get_requirements().modules) == {"lightgbm", "pandas"}
93102
assert isinstance(dtype_df, LightGBMDataType)
94103
assert isinstance(dtype_df.inner, DataFrameType)
104+
assert isinstance(dtype_df.labels, NumpyNdarrayType)
95105
assert dtype_df.get_model().__name__ == dtype_df.inner.get_model().__name__
96106
assert dtype_df.get_model().schema() == {
97107
"title": "DataFrame",
@@ -116,54 +126,131 @@ def test_hook_df(dtype_df: DataType):
116126

117127

118128
@pytest.mark.parametrize(
119-
"lgb_dtype, data_type",
120-
[("dtype_np", NumpyNdarrayType), ("dtype_df", DataFrameType)],
129+
"lgb_dtype, data_type, label_type",
130+
[
131+
("dtype_np", NumpyNdarrayType, ArrayType),
132+
("dtype_df", DataFrameType, NumpyNdarrayType),
133+
],
121134
)
122-
def test_lightgbm_source(lgb_dtype, data_type, request):
135+
def test_lightgbm_source(lgb_dtype, data_type, label_type, request):
123136
lgb_dtype = request.getfixturevalue(lgb_dtype)
124137
assert isinstance(lgb_dtype, LightGBMDataType)
125138
assert isinstance(lgb_dtype.inner, data_type)
139+
assert isinstance(lgb_dtype.labels, label_type)
126140

127141
def custom_assert(x, y):
128142
assert hasattr(x, "data")
129143
assert hasattr(y, "data")
130144
assert all(x.data == y.data)
131-
assert all(x.label == y.label)
145+
label_check = x.label == y.label
146+
if isinstance(label_check, (list, np.ndarray)):
147+
assert all(label_check)
148+
else:
149+
assert label_check
132150

133-
data_write_read_check(
151+
artifacts = data_write_read_check(
134152
lgb_dtype,
135153
writer=LightGBMDataWriter(),
136154
reader_type=LightGBMDataReader,
137155
custom_assert=custom_assert,
138156
)
139157

158+
if isinstance(lgb_dtype.inner, NumpyNdarrayType):
159+
assert list(artifacts.keys()) == [
160+
f"{LIGHTGBM_DATA}/data",
161+
f"{LIGHTGBM_LABEL}/0/data",
162+
f"{LIGHTGBM_LABEL}/1/data",
163+
f"{LIGHTGBM_LABEL}/2/data",
164+
f"{LIGHTGBM_LABEL}/3/data",
165+
f"{LIGHTGBM_LABEL}/4/data",
166+
]
167+
assert artifacts[f"{LIGHTGBM_DATA}/data"].uri.endswith(
168+
f"data/{LIGHTGBM_DATA}"
169+
)
170+
assert artifacts[f"{LIGHTGBM_LABEL}/0/data"].uri.endswith(
171+
f"data/{LIGHTGBM_LABEL}/0"
172+
)
173+
assert artifacts[f"{LIGHTGBM_LABEL}/1/data"].uri.endswith(
174+
f"data/{LIGHTGBM_LABEL}/1"
175+
)
176+
assert artifacts[f"{LIGHTGBM_LABEL}/2/data"].uri.endswith(
177+
f"data/{LIGHTGBM_LABEL}/2"
178+
)
179+
assert artifacts[f"{LIGHTGBM_LABEL}/3/data"].uri.endswith(
180+
f"data/{LIGHTGBM_LABEL}/3"
181+
)
182+
assert artifacts[f"{LIGHTGBM_LABEL}/4/data"].uri.endswith(
183+
f"data/{LIGHTGBM_LABEL}/4"
184+
)
185+
else:
186+
assert list(artifacts.keys()) == [
187+
f"{LIGHTGBM_DATA}/data",
188+
f"{LIGHTGBM_LABEL}/data",
189+
]
190+
assert artifacts[f"{LIGHTGBM_DATA}/data"].uri.endswith(
191+
f"data/{LIGHTGBM_DATA}"
192+
)
193+
assert artifacts[f"{LIGHTGBM_LABEL}/data"].uri.endswith(
194+
f"data/{LIGHTGBM_LABEL}"
195+
)
196+
140197

141198
def test_serialize__np(dtype_np, np_payload):
142-
ds = lgb.Dataset(np_payload)
199+
ds = lgb.Dataset(np_payload, label=np_payload.reshape((-1,)).tolist())
143200
payload = dtype_np.serialize(ds)
144-
assert payload == np_payload.tolist()
201+
assert payload[LIGHTGBM_DATA] == np_payload.tolist()
202+
assert payload[LIGHTGBM_LABEL] == np_payload.reshape((-1,)).tolist()
145203

146204
with pytest.raises(SerializationError):
147205
dtype_np.serialize({"abc": 123}) # wrong type
148206

149207

150208
def test_deserialize__np(dtype_np, np_payload):
151-
ds = dtype_np.deserialize(np_payload)
209+
ds = dtype_np.deserialize(
210+
{
211+
LIGHTGBM_DATA: np_payload,
212+
LIGHTGBM_LABEL: np_payload.reshape((-1,)).tolist(),
213+
}
214+
)
152215
assert isinstance(ds, lgb.Dataset)
153216
assert np.all(ds.data == np_payload)
217+
assert np.all(ds.label == np_payload.reshape((-1,)).tolist())
154218

155219
with pytest.raises(DeserializationError):
156-
dtype_np.deserialize([[1], ["abc"]]) # illegal matrix
220+
dtype_np.deserialize({LIGHTGBM_DATA: [[1], ["abc"]]}) # illegal matrix
157221

158222

159-
def test_serialize__df(dtype_df, df_payload):
160-
ds = lgb.Dataset(df_payload)
161-
payload = dtype_df.serialize(ds)
162-
assert payload["values"] == df_payload.to_dict("records")
223+
def test_serialize__df(df_payload):
224+
ds = lgb.Dataset(df_payload, label=None, free_raw_data=False)
225+
payload = DataType.create(obj=ds)
226+
assert payload.serialize(ds)["values"] == df_payload.to_dict("records")
227+
assert LIGHTGBM_LABEL not in payload
228+
229+
def custom_assert(x, y):
230+
assert hasattr(x, "data")
231+
assert hasattr(y, "data")
232+
assert all(x.data == y.data)
233+
assert x.label == y.label
234+
235+
artifacts = data_write_read_check(
236+
payload,
237+
writer=LightGBMDataWriter(),
238+
reader_type=LightGBMDataReader,
239+
custom_assert=custom_assert,
240+
)
241+
242+
assert len(artifacts.keys()) == 1
243+
assert list(artifacts.keys()) == ["data"]
244+
assert artifacts["data"].uri.endswith("/data")
163245

164246

165247
def test_deserialize__df(dtype_df, df_payload):
166-
ds = dtype_df.deserialize({"values": df_payload})
248+
ds = dtype_df.deserialize(
249+
{
250+
LIGHTGBM_DATA: {"values": df_payload},
251+
LIGHTGBM_LABEL: np.array([0, 1]).tolist(),
252+
}
253+
)
167254
assert isinstance(ds, lgb.Dataset)
168255
assert ds.data.equals(df_payload)
169256

0 commit comments

Comments
 (0)