4
4
import pytest
5
5
6
6
from mlem .contrib .lightgbm import (
7
+ LIGHTGBM_DATA ,
8
+ LIGHTGBM_LABEL ,
7
9
LightGBMDataReader ,
8
10
LightGBMDataType ,
9
11
LightGBMDataWriter ,
12
14
from mlem .contrib .numpy import NumpyNdarrayType
13
15
from mlem .contrib .pandas import DataFrameType
14
16
from mlem .core .artifacts import LOCAL_STORAGE
15
- from mlem .core .data_type import DataAnalyzer , DataType
17
+ from mlem .core .data_type import (
18
+ ArrayType ,
19
+ DataAnalyzer ,
20
+ DataType ,
21
+ PrimitiveType ,
22
+ )
16
23
from mlem .core .errors import DeserializationError , SerializationError
17
24
from mlem .core .model import ModelAnalyzer , ModelType
18
25
from mlem .core .requirements import UnixPackageRequirement
@@ -46,7 +53,7 @@ def df_payload():
46
53
def data_df (df_payload ):
47
54
return lgb .Dataset (
48
55
df_payload ,
49
- label = np .array ([0 , 1 ]). tolist () ,
56
+ label = np .array ([0 , 1 ]),
50
57
free_raw_data = False ,
51
58
)
52
59
@@ -75,6 +82,8 @@ def test_hook_np(dtype_np: DataType):
75
82
assert set (dtype_np .get_requirements ().modules ) == {"lightgbm" , "numpy" }
76
83
assert isinstance (dtype_np , LightGBMDataType )
77
84
assert isinstance (dtype_np .inner , NumpyNdarrayType )
85
+ assert isinstance (dtype_np .labels , ArrayType )
86
+ assert dtype_np .labels .dtype == PrimitiveType (data = None , ptype = "float" )
78
87
assert dtype_np .get_model ().__name__ == dtype_np .inner .get_model ().__name__
79
88
assert dtype_np .get_model ().schema () == {
80
89
"title" : "NumpyNdarray" ,
@@ -92,6 +101,7 @@ def test_hook_df(dtype_df: DataType):
92
101
assert set (dtype_df .get_requirements ().modules ) == {"lightgbm" , "pandas" }
93
102
assert isinstance (dtype_df , LightGBMDataType )
94
103
assert isinstance (dtype_df .inner , DataFrameType )
104
+ assert isinstance (dtype_df .labels , NumpyNdarrayType )
95
105
assert dtype_df .get_model ().__name__ == dtype_df .inner .get_model ().__name__
96
106
assert dtype_df .get_model ().schema () == {
97
107
"title" : "DataFrame" ,
@@ -116,54 +126,131 @@ def test_hook_df(dtype_df: DataType):
116
126
117
127
118
128
@pytest .mark .parametrize (
119
- "lgb_dtype, data_type" ,
120
- [("dtype_np" , NumpyNdarrayType ), ("dtype_df" , DataFrameType )],
129
+ "lgb_dtype, data_type, label_type" ,
130
+ [
131
+ ("dtype_np" , NumpyNdarrayType , ArrayType ),
132
+ ("dtype_df" , DataFrameType , NumpyNdarrayType ),
133
+ ],
121
134
)
122
- def test_lightgbm_source (lgb_dtype , data_type , request ):
135
+ def test_lightgbm_source (lgb_dtype , data_type , label_type , request ):
123
136
lgb_dtype = request .getfixturevalue (lgb_dtype )
124
137
assert isinstance (lgb_dtype , LightGBMDataType )
125
138
assert isinstance (lgb_dtype .inner , data_type )
139
+ assert isinstance (lgb_dtype .labels , label_type )
126
140
127
141
def custom_assert (x , y ):
128
142
assert hasattr (x , "data" )
129
143
assert hasattr (y , "data" )
130
144
assert all (x .data == y .data )
131
- assert all (x .label == y .label )
145
+ label_check = x .label == y .label
146
+ if isinstance (label_check , (list , np .ndarray )):
147
+ assert all (label_check )
148
+ else :
149
+ assert label_check
132
150
133
- data_write_read_check (
151
+ artifacts = data_write_read_check (
134
152
lgb_dtype ,
135
153
writer = LightGBMDataWriter (),
136
154
reader_type = LightGBMDataReader ,
137
155
custom_assert = custom_assert ,
138
156
)
139
157
158
+ if isinstance (lgb_dtype .inner , NumpyNdarrayType ):
159
+ assert list (artifacts .keys ()) == [
160
+ f"{ LIGHTGBM_DATA } /data" ,
161
+ f"{ LIGHTGBM_LABEL } /0/data" ,
162
+ f"{ LIGHTGBM_LABEL } /1/data" ,
163
+ f"{ LIGHTGBM_LABEL } /2/data" ,
164
+ f"{ LIGHTGBM_LABEL } /3/data" ,
165
+ f"{ LIGHTGBM_LABEL } /4/data" ,
166
+ ]
167
+ assert artifacts [f"{ LIGHTGBM_DATA } /data" ].uri .endswith (
168
+ f"data/{ LIGHTGBM_DATA } "
169
+ )
170
+ assert artifacts [f"{ LIGHTGBM_LABEL } /0/data" ].uri .endswith (
171
+ f"data/{ LIGHTGBM_LABEL } /0"
172
+ )
173
+ assert artifacts [f"{ LIGHTGBM_LABEL } /1/data" ].uri .endswith (
174
+ f"data/{ LIGHTGBM_LABEL } /1"
175
+ )
176
+ assert artifacts [f"{ LIGHTGBM_LABEL } /2/data" ].uri .endswith (
177
+ f"data/{ LIGHTGBM_LABEL } /2"
178
+ )
179
+ assert artifacts [f"{ LIGHTGBM_LABEL } /3/data" ].uri .endswith (
180
+ f"data/{ LIGHTGBM_LABEL } /3"
181
+ )
182
+ assert artifacts [f"{ LIGHTGBM_LABEL } /4/data" ].uri .endswith (
183
+ f"data/{ LIGHTGBM_LABEL } /4"
184
+ )
185
+ else :
186
+ assert list (artifacts .keys ()) == [
187
+ f"{ LIGHTGBM_DATA } /data" ,
188
+ f"{ LIGHTGBM_LABEL } /data" ,
189
+ ]
190
+ assert artifacts [f"{ LIGHTGBM_DATA } /data" ].uri .endswith (
191
+ f"data/{ LIGHTGBM_DATA } "
192
+ )
193
+ assert artifacts [f"{ LIGHTGBM_LABEL } /data" ].uri .endswith (
194
+ f"data/{ LIGHTGBM_LABEL } "
195
+ )
196
+
140
197
141
198
def test_serialize__np (dtype_np , np_payload ):
142
- ds = lgb .Dataset (np_payload )
199
+ ds = lgb .Dataset (np_payload , label = np_payload . reshape (( - 1 ,)). tolist () )
143
200
payload = dtype_np .serialize (ds )
144
- assert payload == np_payload .tolist ()
201
+ assert payload [LIGHTGBM_DATA ] == np_payload .tolist ()
202
+ assert payload [LIGHTGBM_LABEL ] == np_payload .reshape ((- 1 ,)).tolist ()
145
203
146
204
with pytest .raises (SerializationError ):
147
205
dtype_np .serialize ({"abc" : 123 }) # wrong type
148
206
149
207
150
208
def test_deserialize__np (dtype_np , np_payload ):
151
- ds = dtype_np .deserialize (np_payload )
209
+ ds = dtype_np .deserialize (
210
+ {
211
+ LIGHTGBM_DATA : np_payload ,
212
+ LIGHTGBM_LABEL : np_payload .reshape ((- 1 ,)).tolist (),
213
+ }
214
+ )
152
215
assert isinstance (ds , lgb .Dataset )
153
216
assert np .all (ds .data == np_payload )
217
+ assert np .all (ds .label == np_payload .reshape ((- 1 ,)).tolist ())
154
218
155
219
with pytest .raises (DeserializationError ):
156
- dtype_np .deserialize ([[1 ], ["abc" ]]) # illegal matrix
220
+ dtype_np .deserialize ({ LIGHTGBM_DATA : [[1 ], ["abc" ]]} ) # illegal matrix
157
221
158
222
159
- def test_serialize__df (dtype_df , df_payload ):
160
- ds = lgb .Dataset (df_payload )
161
- payload = dtype_df .serialize (ds )
162
- assert payload ["values" ] == df_payload .to_dict ("records" )
223
+ def test_serialize__df (df_payload ):
224
+ ds = lgb .Dataset (df_payload , label = None , free_raw_data = False )
225
+ payload = DataType .create (obj = ds )
226
+ assert payload .serialize (ds )["values" ] == df_payload .to_dict ("records" )
227
+ assert LIGHTGBM_LABEL not in payload
228
+
229
+ def custom_assert (x , y ):
230
+ assert hasattr (x , "data" )
231
+ assert hasattr (y , "data" )
232
+ assert all (x .data == y .data )
233
+ assert x .label == y .label
234
+
235
+ artifacts = data_write_read_check (
236
+ payload ,
237
+ writer = LightGBMDataWriter (),
238
+ reader_type = LightGBMDataReader ,
239
+ custom_assert = custom_assert ,
240
+ )
241
+
242
+ assert len (artifacts .keys ()) == 1
243
+ assert list (artifacts .keys ()) == ["data" ]
244
+ assert artifacts ["data" ].uri .endswith ("/data" )
163
245
164
246
165
247
def test_deserialize__df (dtype_df , df_payload ):
166
- ds = dtype_df .deserialize ({"values" : df_payload })
248
+ ds = dtype_df .deserialize (
249
+ {
250
+ LIGHTGBM_DATA : {"values" : df_payload },
251
+ LIGHTGBM_LABEL : np .array ([0 , 1 ]).tolist (),
252
+ }
253
+ )
167
254
assert isinstance (ds , lgb .Dataset )
168
255
assert ds .data .equals (df_payload )
169
256
0 commit comments