1+ # ' Preprocess a parameter list, overriding defaults with any provided parameters.
2+ # '
3+ # ' @param default_params List of parameters with default values set.
4+ # ' @param user_params (Optional) User-supplied overrides to `default_params`.
5+ # '
6+ # ' @return Parameter list with defaults overriden by values supplied in `user_params`
7+ # ' @export
8+ preprocessParams <- function (default_params , user_params = NULL ) {
9+ # Override defaults from general_params
10+ if (! is.null(user_params )) {
11+ for (key in names(user_params )) {
12+ if (key %in% names(default_params )) {
13+ val <- user_params [[key ]]
14+ if (! is.null(val )) default_params [[key ]] <- val
15+ }
16+ }
17+ }
18+
19+ # Return result
20+ return (default_params )
21+ }
22+
123# ' Preprocess BART parameter list. Override defaults with any provided parameters.
224# '
3- # ' @param params Parameter list
25+ # ' @param general_params List of any non-forest-specific parameters
26+ # ' @param mean_forest_params List of any mean forest parameters
27+ # ' @param variance_forest_params List of any variance forest parameters
428# '
5- # ' @return Parameter list with defaults overriden by values supplied in `params`
29+ # ' @return Parameter list with defaults overriden by values supplied in parameter lists
630# ' @export
7- preprocessBartParams <- function (params ) {
31+ preprocessBartParams <- function (general_params , mean_forest_params , variance_forest_params ) {
832 # Default parameter values
933 processed_params <- list (
10- cutpoint_grid_size = 100 , sigma_leaf_init = NULL ,
34+ cutpoint_grid_size = 100 ,
1135 alpha_mean = 0.95 , beta_mean = 2.0 ,
1236 min_samples_leaf_mean = 5 , max_depth_mean = 10 ,
37+ variable_weights_mean = NULL , num_trees_mean = 200 ,
1338 alpha_variance = 0.95 , beta_variance = 2.0 ,
1439 min_samples_leaf_variance = 5 , max_depth_variance = 10 ,
15- a_global = 0 , b_global = 0 , a_leaf = 3 , b_leaf = NULL ,
16- a_forest = NULL , b_forest = NULL , variance_scale = 1 ,
17- sigma2_init = NULL , variance_forest_init = NULL ,
18- pct_var_sigma2_init = 1 , pct_var_variance_forest_init = 1 ,
19- variable_weights_mean = NULL , variable_weights_variance = NULL ,
20- num_trees_mean = 200 , num_trees_variance = 0 ,
21- sample_sigma_global = T , sample_sigma_leaf = F ,
40+ variable_weights_variance = NULL , num_trees_variance = 0 ,
41+ sample_sigma2_global = T , sigma2_global_init = NULL ,
42+ sigma2_global_shape = 0 , sigma2_global_scale = 0 ,
43+ sample_sigma2_leaf = T , sigma2_leaf_init = NULL ,
44+ sigma2_leaf_shape = 3 , sigma2_leaf_scale = NULL ,
45+ var_forest_prior_shape = NULL , var_forest_prior_scale = NULL ,
46+ variance_forest_init = NULL ,
47+ sample_sigma_global = T , sample_sigma2_leaf_mean = F ,
2248 random_seed = - 1 , keep_burnin = F , keep_gfr = F , keep_every = 1 ,
2349 num_chains = 1 , standardize = T , verbose = F
2450 )
2551
26- # Override defaults
27- for (key in names(params )) {
28- if (! key %in% names(processed_params )) {
29- stop(" Variable " , key , " is not a valid BART model parameter" )
52+ # Override defaults from general_params
53+ for (key in names(general_params )) {
54+ if (key %in% names(processed_params )) {
55+ val <- general_params [[key ]]
56+ if (! is.null(val )) processed_params [[key ]] <- val
57+ }
58+ }
59+
60+ # Override defaults from mean_forest_params
61+ for (key in names(mean_forest_params )) {
62+ modified_key <- paste0(key , " _mean" )
63+ if (modified_key %in% names(processed_params )) {
64+ val <- general_params [[key ]]
65+ if (! is.null(val )) processed_params [[modified_key ]] <- val
66+ }
67+ }
68+
69+ # Override defaults from variance_forest_params
70+ for (key in names(variance_forest_params )) {
71+ modified_key <- paste0(key , " _variance" )
72+ if (modified_key %in% names(processed_params )) {
73+ val <- general_params [[key ]]
74+ if (! is.null(val )) processed_params [[modified_key ]] <- val
3075 }
31- val <- params [[key ]]
32- if (! is.null(val )) processed_params [[key ]] <- val
3376 }
3477
3578 # Return result
@@ -38,9 +81,12 @@ preprocessBartParams <- function(params) {
3881
3982# ' Preprocess BCF parameter list. Override defaults with any provided parameters.
4083# '
41- # ' @param params Parameter list
84+ # ' @param general_params List of any non-forest-specific parameters
85+ # ' @param mu_forest_params List of any mu forest parameters
86+ # ' @param tau_forest_params List of any tau forest parameters
87+ # ' @param variance_forest_params List of any variance forest parameters
4288# '
43- # ' @return Parameter list with defaults overriden by values supplied in `params`
89+ # ' @return Parameter list with defaults overriden by values supplied in parameter lists
4490# ' @export
4591preprocessBcfParams <- function (params ) {
4692 # Default parameter values
@@ -57,19 +103,45 @@ preprocessBcfParams <- function(params) {
57103 keep_vars_tau = NULL , drop_vars_tau = NULL , keep_vars_variance = NULL ,
58104 drop_vars_variance = NULL , num_trees_mu = 250 , num_trees_tau = 50 ,
59105 num_trees_variance = 0 , num_gfr = 5 , num_burnin = 0 , num_mcmc = 100 ,
60- sample_sigma_global = T , sample_sigma_leaf_mu = T , sample_sigma_leaf_tau = F ,
106+ sample_sigma_global = T , sample_sigma2_leaf_mu = T , sample_sigma2_leaf_tau = F ,
61107 propensity_covariate = " mu" , adaptive_coding = T , b_0 = - 0.5 , b_1 = 0.5 ,
62108 rfx_prior_var = NULL , random_seed = - 1 , keep_burnin = F , keep_gfr = F ,
63109 keep_every = 1 , num_chains = 1 , standardize = T , verbose = F
64110 )
65111
66112 # Override defaults
67113 for (key in names(params )) {
68- if (! key %in% names(processed_params )) {
69- stop(" Variable " , key , " is not a valid BART model parameter" )
114+ if (key %in% names(processed_params )) {
115+ val <- params [[key ]]
116+ if (! is.null(val )) processed_params [[key ]] <- val
117+ }
118+ }
119+
120+ # Override defaults from mu_forest_params
121+ for (key in names(mu_forest_params )) {
122+ modified_key <- paste0(key , " _mu" )
123+ if (modified_key %in% names(processed_params )) {
124+ val <- general_params [[key ]]
125+ if (! is.null(val )) processed_params [[modified_key ]] <- val
126+ }
127+ }
128+
129+ # Override defaults from tau_forest_params
130+ for (key in names(tau_forest_params )) {
131+ modified_key <- paste0(key , " _tau" )
132+ if (modified_key %in% names(processed_params )) {
133+ val <- general_params [[key ]]
134+ if (! is.null(val )) processed_params [[modified_key ]] <- val
135+ }
136+ }
137+
138+ # Override defaults from variance_forest_params
139+ for (key in names(variance_forest_params )) {
140+ modified_key <- paste0(key , " _variance" )
141+ if (modified_key %in% names(processed_params )) {
142+ val <- general_params [[key ]]
143+ if (! is.null(val )) processed_params [[modified_key ]] <- val
70144 }
71- val <- params [[key ]]
72- if (! is.null(val )) processed_params [[key ]] <- val
73145 }
74146
75147 # Return result
0 commit comments