 import torch.nn as nn
 from torch.autograd import Variable
 from torch import optim
+import pandas as pd
+from sklearn.model_selection import train_test_split
+from sklearn.externals import joblib
+from sklearn.ensemble import RandomForestRegressor

 from .callbacks import NeptuneMonitorSegmentation, ValidationMonitorSegmentation
 from .steps.pytorch.architectures.unet import UNet
 from .steps.pytorch.callbacks import CallbackList, TrainingMonitor, ModelCheckpoint, \
     ExperimentTiming, ExponentialLRScheduler, EarlyStopping
 from .steps.pytorch.models import Model
 from .steps.pytorch.validation import multiclass_segmentation_loss, DiceLoss
+from .steps.sklearn.models import LightGBM, make_transformer, SklearnRegressor
 from .utils import softmax
 from .unet_models import AlbuNet, UNet11, UNetVGG16, UNetResNet
@@ -159,9 +164,12 @@ def __init__(self, architecture_config, training_config, callbacks_config):
 class PyTorchUNetWeightedStream(BasePyTorchUNet):
     def __init__(self, architecture_config, training_config, callbacks_config):
         super().__init__(architecture_config, training_config, callbacks_config)
-        weighted_loss = partial(multiclass_weighted_cross_entropy,
-                                **get_loss_variables(**architecture_config['weighted_cross_entropy']))
-        loss = partial(mixed_dice_cross_entropy_loss, dice_weight=architecture_config['loss_weights']['dice_mask'],
+        weights_function = partial(get_weights, **architecture_config['weighted_cross_entropy'])
+        weighted_loss = partial(multiclass_weighted_cross_entropy, weights_function=weights_function)
+        dice_loss = partial(multiclass_dice_loss, excluded_classes=[0])
+        loss = partial(mixed_dice_cross_entropy_loss,
+                       dice_loss=dice_loss,
+                       dice_weight=architecture_config['loss_weights']['dice_mask'],
                        cross_entropy_weight=architecture_config['loss_weights']['bce_mask'],
                        cross_entropy_loss=weighted_loss,
                        **architecture_config['dice'])
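Note: the refactor above replaces get_loss_variables with a get_weights partial handed to multiclass_weighted_cross_entropy as weights_function, and mixes in a per-class dice term that skips the background channel (excluded_classes=[0]). A minimal sketch of how mixed_dice_cross_entropy_loss presumably combines the two terms, inferred only from the keyword arguments at this call site rather than from the function's actual body:

    # Sketch only: signature inferred from the call above; 'smooth' and
    # 'dice_activation' stand in for whatever architecture_config['dice'] carries.
    def mixed_dice_cross_entropy_loss(output, target,
                                      dice_weight=0.5, cross_entropy_weight=0.5,
                                      dice_loss=None, cross_entropy_loss=None,
                                      smooth=0, dice_activation='softmax'):
        return (dice_weight * dice_loss(output, target, smooth=smooth,
                                        activation=dice_activation)
                + cross_entropy_weight * cross_entropy_loss(output, target))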
@@ -201,6 +209,81 @@ def _transform(self, datagen, validation_datagen=None):
         self.model.train()


+class ScoringLightGBM(LightGBM):
+    def __init__(self, model_params, training_params, train_size, target):
+        self.train_size = train_size
+        self.target = target
+        self.feature_names = []
+        self.estimator = None
+        super().__init__(model_params, training_params)
+
+    def fit(self, features, **kwargs):
+        df_features = _convert_features_to_df(features)
+        train_data, val_data = train_test_split(df_features, train_size=self.train_size)
+        self.feature_names = list(df_features.columns.drop(self.target))
+        super().fit(X=train_data[self.feature_names],
+                    y=train_data[self.target],
+                    X_valid=val_data[self.feature_names],
+                    y_valid=val_data[self.target],
+                    feature_names=self.feature_names,
+                    categorical_features=[])
+        return self
+
+    def transform(self, features, **kwargs):
+        scores = []
+        for image_features in features:
+            image_scores = []
+            for layer_features in image_features:
+                if len(layer_features) > 0:
+                    layer_scores = super().transform(layer_features[self.feature_names])
+                    image_scores.append(list(layer_scores['prediction']))
+                else:
+                    image_scores.append([])
+            scores.append(image_scores)
+        return {'scores': scores}
+
+    def save(self, filepath):
+        joblib.dump((self.estimator, self.feature_names), filepath)
+
+    def load(self, filepath):
+        self.estimator, self.feature_names = joblib.load(filepath)
+
+
+class ScoringRandomForest(SklearnRegressor):
+    def __init__(self, train_size, target, **kwargs):
+        self.train_size = train_size
+        self.target = target
+        self.feature_names = []
+        self.estimator = RandomForestRegressor()
+
+    def fit(self, features, **kwargs):
+        df_features = _convert_features_to_df(features)
+        train_data, val_data = train_test_split(df_features, train_size=self.train_size)
+        self.feature_names = list(df_features.columns.drop(self.target))
+        super().fit(X=train_data[self.feature_names],
+                    y=train_data[self.target])
+        return self
+
+    def transform(self, features, **kwargs):
+        scores = []
+        for image_features in features:
+            image_scores = []
+            for layer_features in image_features:
+                if len(layer_features) > 0:
+                    layer_scores = super().transform(layer_features[self.feature_names])
+                    image_scores.append(list(layer_scores['prediction']))
+                else:
+                    image_scores.append([])
+            scores.append(image_scores)
+        return {'scores': scores}
+
+    def save(self, filepath):
+        joblib.dump((self.estimator, self.feature_names), filepath)
+
+    def load(self, filepath):
+        self.estimator, self.feature_names = joblib.load(filepath)
+
+
 def weight_regularization_unet(model, regularize, weight_decay_conv2d):
     if regularize:
         parameter_list = [{'params': model.parameters(), 'weight_decay': weight_decay_conv2d}]
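Note: both scoring models share one contract: fit() flattens the nested per-image, per-layer feature frames with _convert_features_to_df (added at the bottom of this diff), holds out a validation split via train_test_split, and regresses the target column on the remaining columns; transform() then scores every non-empty layer frame image by image. A hedged usage sketch, where the model_params/training_params values and the 'iou' target column are illustrative placeholders, not this project's configuration:

    # Hypothetical configuration values, for illustration only.
    scorer = ScoringLightGBM(model_params={'objective': 'regression'},
                             training_params={'number_boosting_rounds': 100},
                             train_size=0.8,
                             target='iou')
    scorer.fit(features)                     # features: list (images) of lists (layers) of DataFrames
    scores = scorer.transform(features)      # {'scores': [[per-layer predictions], ...]}
    scorer.save('scoring_lightgbm.joblib')   # persists (estimator, feature_names)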
@@ -369,3 +452,11 @@ def multiclass_dice_loss(output, target, smooth=0, activation='softmax', exclude
         class_target.data = class_target.data.float()
         loss += dice(output[:, class_nr, :, :], class_target)
     return loss
+
+
+def _convert_features_to_df(features):
+    df_features = []
+    for image_features in features:
+        for layer_features in image_features[1:]:
+            df_features.append(layer_features)
+    return pd.concat(df_features)
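Note: _convert_features_to_df skips the first layer of every image (image_features[1:]) before concatenating, so that layer never contributes training rows. A toy illustration with hypothetical column names ('area' and 'iou' are placeholders, not this project's feature schema):

    import pandas as pd

    features = [
        [pd.DataFrame({'area': [10], 'iou': [0.1]}),   # layer 0: dropped by [1:]
         pd.DataFrame({'area': [42], 'iou': [0.9]})],  # layer 1: kept
        [pd.DataFrame({'area': [7], 'iou': [0.2]}),
         pd.DataFrame({'area': [13], 'iou': [0.6]})],
    ]
    df = _convert_features_to_df(features)  # one DataFrame, one row per kept layer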