@@ -194,37 +194,6 @@ earth_engine_args <-
194194 component_id = " engine"
195195 )
196196
197- brulee_mlp_engine_args <-
198- tibble :: tribble(
199- ~ name , ~ call_info ,
200- " momentum" , list (pkg = " dials" , fun = " momentum" , range = c(0.5 , 0.95 )),
201- " batch_size" , list (pkg = " dials" , fun = " batch_size" , range = c(3 , 10 )),
202- " stop_iter" , list (pkg = " dials" , fun = " stop_iter" ),
203- " class_weights" , list (pkg = " dials" , fun = " class_weights" ),
204- " decay" , list (pkg = " dials" , fun = " rate_decay" ),
205- " initial" , list (pkg = " dials" , fun = " rate_initial" ),
206- " largest" , list (pkg = " dials" , fun = " rate_largest" ),
207- " rate_schedule" , list (pkg = " dials" , fun = " rate_schedule" ),
208- " step_size" , list (pkg = " dials" , fun = " rate_step_size" ),
209- " mixture" , list (pkg = " dials" , fun = " mixture" )
210- ) %> %
211- dplyr :: mutate(source = " model_spec" ,
212- component = " mlp" ,
213- component_id = " engine"
214- )
215-
216- brulee_linear_engine_args <-
217- brulee_mlp_engine_args %> %
218- dplyr :: filter(name %in% c(" momentum" , " batch_size" , " stop_iter" ))
219-
220- brulee_logistic_engine_args <-
221- brulee_mlp_engine_args %> %
222- dplyr :: filter(name %in% c(" momentum" , " batch_size" , " stop_iter" , " class_weights" ))
223-
224- brulee_multinomial_engine_args <-
225- brulee_mlp_engine_args %> %
226- dplyr :: filter(name %in% c(" momentum" , " batch_size" , " stop_iter" , " class_weights" ))
227-
228197flexsurvspline_engine_args <-
229198 tibble :: tibble(
230199 name = c(" k" ),
@@ -236,6 +205,42 @@ flexsurvspline_engine_args <-
236205 component_id = " engine"
237206 )
238207
208+ # ------------------------------------------------------------------------------
209+ # used for brulee engines:
210+
211+ tune_activations <- c(" relu" , " tanh" , " elu" , " log_sigmoid" , " tanhshrink" )
212+ tune_sched <- c(" none" , " decay_time" , " decay_expo" , " cyclic" , " step" )
213+
214+ brulee_mlp_args <-
215+ tibble :: tibble(
216+ name = c(' epochs' , ' hidden_units' , ' hidden_units_2' , ' activation' , ' activation_2' ,
217+ ' penalty' , ' mixture' , ' dropout' , ' learn_rate' , ' momentum' , ' batch_size' ,
218+ ' class_weights' , ' stop_iter' , ' rate_schedule' ),
219+ call_info = list (
220+ list (pkg = " dials" , fun = " epochs" , range = c(5L , 500L )),
221+ list (pkg = " dials" , fun = " hidden_units" , range = c(2L , 50L )),
222+ list (pkg = " dials" , fun = " hidden_units_2" , range = c(2L , 50L )),
223+ list (pkg = " dials" , fun = " activation" , values = tune_activations ),
224+ list (pkg = " dials" , fun = " activation_2" , values = tune_activations ),
225+ list (pkg = " dials" , fun = " penalty" ),
226+ list (pkg = " dials" , fun = " mixture" ),
227+ list (pkg = " dials" , fun = " dropout" ),
228+ list (pkg = " dials" , fun = " learn_rate" , range = c(- 3 , - 1 / 5 )),
229+ list (pkg = " dials" , fun = " momentum" , range = c(0.50 , 0.95 )),
230+ list (pkg = " dials" , fun = " batch_size" ),
231+ list (pkg = " dials" , fun = " stop_iter" ),
232+ list (pkg = " dials" , fun = " class_weights" ),
233+ list (pkg = " dials" , fun = " rate_schedule" , values = tune_sched )
234+ )
235+ ) %> %
236+ dplyr :: mutate(source = " model_spec" )
237+
238+ brulee_mlp_only_args <-
239+ tibble :: tibble(
240+ name =
241+ c(' hidden_units' , ' hidden_units_2' , ' activation' , ' activation_2' , ' dropout' )
242+ )
243+
239244# ------------------------------------------------------------------------------
240245
241246# ' @export
@@ -245,31 +250,55 @@ tunable.linear_reg <- function(x, ...) {
245250 res $ call_info [res $ name == " mixture" ] <-
246251 list (list (pkg = " dials" , fun = " mixture" , range = c(0.05 , 1.00 )))
247252 } else if (x $ engine == " brulee" ) {
248- res <- add_engine_parameters(res , brulee_linear_engine_args )
253+ res <-
254+ brulee_mlp_args %> %
255+ dplyr :: anti_join(brulee_mlp_only_args , by = " name" ) %> %
256+ dplyr :: filter(name != " class_weights" ) %> %
257+ dplyr :: mutate(
258+ component = " linear_reg" ,
259+ component_id = ifelse(name %in% names(formals(" linear_reg" )), " main" , " engine" )
260+ ) %> %
261+ dplyr :: select(name , call_info , source , component , component_id )
249262 }
250263 res
251264}
252265
266+ # ' @export
267+
253268# ' @export
254269tunable.logistic_reg <- function (x , ... ) {
255270 res <- NextMethod()
256271 if (x $ engine == " glmnet" ) {
257272 res $ call_info [res $ name == " mixture" ] <-
258273 list (list (pkg = " dials" , fun = " mixture" , range = c(0.05 , 1.00 )))
259274 } else if (x $ engine == " brulee" ) {
260- res <- add_engine_parameters(res , brulee_logistic_engine_args )
275+ res <-
276+ brulee_mlp_args %> %
277+ dplyr :: anti_join(brulee_mlp_only_args , by = " name" ) %> %
278+ dplyr :: mutate(
279+ component = " logistic_reg" ,
280+ component_id = ifelse(name %in% names(formals(" logistic_reg" )), " main" , " engine" )
281+ ) %> %
282+ dplyr :: select(name , call_info , source , component , component_id )
261283 }
262284 res
263285}
264286
265287# ' @export
266- tunable.multinomial_reg <- function (x , ... ) {
288+ tunable.multinom_reg <- function (x , ... ) {
267289 res <- NextMethod()
268290 if (x $ engine == " glmnet" ) {
269291 res $ call_info [res $ name == " mixture" ] <-
270292 list (list (pkg = " dials" , fun = " mixture" , range = c(0.05 , 1.00 )))
271293 } else if (x $ engine == " brulee" ) {
272- res <- add_engine_parameters(res , brulee_multinomial_engine_args )
294+ res <-
295+ brulee_mlp_args %> %
296+ dplyr :: anti_join(brulee_mlp_only_args , by = " name" ) %> %
297+ dplyr :: mutate(
298+ component = " multinom_reg" ,
299+ component_id = ifelse(name %in% names(formals(" multinom_reg" )), " main" , " engine" )
300+ ) %> %
301+ dplyr :: select(name , call_info , source , component , component_id )
273302 }
274303 res
275304}
@@ -345,28 +374,23 @@ tunable.svm_poly <- function(x, ...) {
345374 res
346375}
347376
348-
349377# ' @export
350378tunable.mlp <- function (x , ... ) {
351379 res <- NextMethod()
352- if (x $ engine == " brulee" ) {
353- res <- add_engine_parameters(res , brulee_mlp_engine_args )
354- res $ call_info [res $ name == " learn_rate" ] <-
355- list (list (pkg = " dials" , fun = " learn_rate" , range = c(- 3 , - 1 / 2 )))
356- res $ call_info [res $ name == " epochs" ] <-
357- list (list (pkg = " dials" , fun = " epochs" , range = c(5L , 500L )))
358- activation_values <- rlang :: eval_tidy(
359- rlang :: call2(" brulee_activations" , .ns = " brulee" )
360- )
361- res $ call_info [res $ name == " activation" ] <-
362- list (list (pkg = " dials" , fun = " activation" , values = activation_values ))
363- } else if (x $ engine == " keras" ) {
364- activation_values <- parsnip :: keras_activations()
365- res $ call_info [res $ name == " activation" ] <-
366- list (list (pkg = " dials" , fun = " activation" , values = activation_values ))
380+ if (grepl(" brulee" , x $ engine )) {
381+ res <-
382+ brulee_mlp_args %> %
383+ dplyr :: mutate(
384+ component = " mlp" ,
385+ component_id = ifelse(name %in% names(formals(" mlp" )), " main" , " engine" )
386+ ) %> %
387+ dplyr :: select(name , call_info , source , component , component_id )
388+ if (x $ engine == " brulee" ) {
389+ res <- res [! grepl(" _2" , res $ name ),]
390+ }
367391 }
368392 res
369- }
393+ }
370394
371395# ' @export
372396tunable.survival_reg <- function (x , ... ) {
0 commit comments