Skip to content

Commit ab83025

Browse files
committed
Merge branch 'main' into documentation-updates
2 parents 97eb1f3 + d68db11 commit ab83025

28 files changed

+1157
-830
lines changed

DESCRIPTION

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@ Description: Stochastic tree ensembles (XBART and BART) for supervised learning
1313
License: MIT + file LICENSE
1414
Encoding: UTF-8
1515
Roxygen: list(markdown = TRUE)
16-
RoxygenNote: 7.3.1
16+
RoxygenNote: 7.3.2
1717
LinkingTo:
1818
cpp11, BH
1919
Suggests:

NAMESPACE

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,7 @@ export(orderedCatInitializeAndPreprocess)
5151
export(orderedCatPreprocess)
5252
export(preprocessBartParams)
5353
export(preprocessBcfParams)
54+
export(preprocessParams)
5455
export(preprocessPredictionData)
5556
export(preprocessPredictionDataFrame)
5657
export(preprocessPredictionMatrix)

R/bart.R

Lines changed: 119 additions & 111 deletions
Large diffs are not rendered by default.

R/bcf.R

Lines changed: 170 additions & 131 deletions
Large diffs are not rendered by default.

R/forest.R

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -729,7 +729,7 @@ Forest <- R6::R6Class(
729729
#' Retrieve a vector of split counts for every training set variable in the forest
730730
#' @param num_features Total number of features in the training set
731731
get_forest_split_counts = function(num_features) {
732-
return(get_forest_split_counts_active_forest_cpp(self$forest_ptr, num_features))
732+
return(get_overall_split_counts_active_forest_cpp(self$forest_ptr, num_features))
733733
},
734734

735735
#' @description

R/utils.R

Lines changed: 96 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -1,35 +1,78 @@
1+
#' Merge user-supplied parameter overrides into a list of defaults.
#'
#' Only keys that already exist in `default_params` are considered: any other
#' key in `user_params` is ignored without error, and a `NULL` value in
#' `user_params` leaves the corresponding default untouched.
#'
#' @param default_params List of parameters with default values set.
#' @param user_params (Optional) User-supplied overrides to `default_params`.
#'
#' @return Parameter list with defaults overridden by values supplied in `user_params`
#' @export
preprocessParams <- function(default_params, user_params = NULL) {
    # Keep only the user-supplied keys that have a matching default. A NULL
    # or empty `user_params` yields no keys, so the loop below is skipped.
    recognized_keys <- intersect(names(user_params), names(default_params))

    # Replace each matching default, skipping overrides that are NULL
    for (param_name in recognized_keys) {
        override <- user_params[[param_name]]
        if (!is.null(override)) {
            default_params[[param_name]] <- override
        }
    }

    return(default_params)
}
22+
123
#' Preprocess BART parameter list. Override defaults with any provided parameters.
224
#'
3-
#' @param params Parameter list
25+
#' @param general_params List of any non-forest-specific parameters
26+
#' @param mean_forest_params List of any mean forest parameters
27+
#' @param variance_forest_params List of any variance forest parameters
428
#'
5-
#' @return Parameter list with defaults overriden by values supplied in `params`
29+
#' @return Parameter list with defaults overridden by values supplied in parameter lists
630
#' @export
7-
preprocessBartParams <- function(params) {
31+
preprocessBartParams <- function(general_params, mean_forest_params, variance_forest_params) {
832
# Default parameter values
933
processed_params <- list(
10-
cutpoint_grid_size = 100, sigma_leaf_init = NULL,
34+
cutpoint_grid_size = 100,
1135
alpha_mean = 0.95, beta_mean = 2.0,
1236
min_samples_leaf_mean = 5, max_depth_mean = 10,
37+
variable_weights_mean = NULL, num_trees_mean = 200,
1338
alpha_variance = 0.95, beta_variance = 2.0,
1439
min_samples_leaf_variance = 5, max_depth_variance = 10,
15-
a_global = 0, b_global = 0, a_leaf = 3, b_leaf = NULL,
16-
a_forest = NULL, b_forest = NULL, variance_scale = 1,
17-
sigma2_init = NULL, variance_forest_init = NULL,
18-
pct_var_sigma2_init = 1, pct_var_variance_forest_init = 1,
19-
variable_weights_mean = NULL, variable_weights_variance = NULL,
20-
num_trees_mean = 200, num_trees_variance = 0,
21-
sample_sigma_global = T, sample_sigma_leaf = F,
40+
variable_weights_variance = NULL, num_trees_variance = 0,
41+
sample_sigma2_global = T, sigma2_global_init = NULL,
42+
sigma2_global_shape = 0, sigma2_global_scale = 0,
43+
sample_sigma2_leaf = T, sigma2_leaf_init = NULL,
44+
sigma2_leaf_shape = 3, sigma2_leaf_scale = NULL,
45+
var_forest_prior_shape = NULL, var_forest_prior_scale = NULL,
46+
variance_forest_init = NULL,
47+
sample_sigma_global = T, sample_sigma2_leaf_mean = F,
2248
random_seed = -1, keep_burnin = F, keep_gfr = F, keep_every = 1,
2349
num_chains = 1, standardize = T, verbose = F
2450
)
2551

26-
# Override defaults
27-
for (key in names(params)) {
28-
if (!key %in% names(processed_params)) {
29-
stop("Variable ", key, " is not a valid BART model parameter")
52+
# Override defaults from general_params
53+
for (key in names(general_params)) {
54+
if (key %in% names(processed_params)) {
55+
val <- general_params[[key]]
56+
if (!is.null(val)) processed_params[[key]] <- val
57+
}
58+
}
59+
60+
# Override defaults from mean_forest_params
61+
for (key in names(mean_forest_params)) {
62+
modified_key <- paste0(key, "_mean")
63+
if (modified_key %in% names(processed_params)) {
64+
val <- general_params[[key]]
65+
if (!is.null(val)) processed_params[[modified_key]] <- val
66+
}
67+
}
68+
69+
# Override defaults from variance_forest_params
70+
for (key in names(variance_forest_params)) {
71+
modified_key <- paste0(key, "_variance")
72+
if (modified_key %in% names(processed_params)) {
73+
val <- general_params[[key]]
74+
if (!is.null(val)) processed_params[[modified_key]] <- val
3075
}
31-
val <- params[[key]]
32-
if (!is.null(val)) processed_params[[key]] <- val
3376
}
3477

3578
# Return result
@@ -38,9 +81,12 @@ preprocessBartParams <- function(params) {
3881

3982
#' Preprocess BCF parameter list. Override defaults with any provided parameters.
4083
#'
41-
#' @param params Parameter list
84+
#' @param general_params List of any non-forest-specific parameters
85+
#' @param mu_forest_params List of any mu forest parameters
86+
#' @param tau_forest_params List of any tau forest parameters
87+
#' @param variance_forest_params List of any variance forest parameters
4288
#'
43-
#' @return Parameter list with defaults overriden by values supplied in `params`
89+
#' @return Parameter list with defaults overridden by values supplied in parameter lists
4490
#' @export
4591
preprocessBcfParams <- function(params) {
4692
# Default parameter values
@@ -57,19 +103,45 @@ preprocessBcfParams <- function(params) {
57103
keep_vars_tau = NULL, drop_vars_tau = NULL, keep_vars_variance = NULL,
58104
drop_vars_variance = NULL, num_trees_mu = 250, num_trees_tau = 50,
59105
num_trees_variance = 0, num_gfr = 5, num_burnin = 0, num_mcmc = 100,
60-
sample_sigma_global = T, sample_sigma_leaf_mu = T, sample_sigma_leaf_tau = F,
106+
sample_sigma_global = T, sample_sigma2_leaf_mu = T, sample_sigma2_leaf_tau = F,
61107
propensity_covariate = "mu", adaptive_coding = T, b_0 = -0.5, b_1 = 0.5,
62108
rfx_prior_var = NULL, random_seed = -1, keep_burnin = F, keep_gfr = F,
63109
keep_every = 1, num_chains = 1, standardize = T, verbose = F
64110
)
65111

66112
# Override defaults
67113
for (key in names(params)) {
68-
if (!key %in% names(processed_params)) {
69-
stop("Variable ", key, " is not a valid BART model parameter")
114+
if (key %in% names(processed_params)) {
115+
val <- params[[key]]
116+
if (!is.null(val)) processed_params[[key]] <- val
117+
}
118+
}
119+
120+
# Override defaults from mu_forest_params
121+
for (key in names(mu_forest_params)) {
122+
modified_key <- paste0(key, "_mu")
123+
if (modified_key %in% names(processed_params)) {
124+
val <- general_params[[key]]
125+
if (!is.null(val)) processed_params[[modified_key]] <- val
126+
}
127+
}
128+
129+
# Override defaults from tau_forest_params
130+
for (key in names(tau_forest_params)) {
131+
modified_key <- paste0(key, "_tau")
132+
if (modified_key %in% names(processed_params)) {
133+
val <- general_params[[key]]
134+
if (!is.null(val)) processed_params[[modified_key]] <- val
135+
}
136+
}
137+
138+
# Override defaults from variance_forest_params
139+
for (key in names(variance_forest_params)) {
140+
modified_key <- paste0(key, "_variance")
141+
if (modified_key %in% names(processed_params)) {
142+
val <- general_params[[key]]
143+
if (!is.null(val)) processed_params[[modified_key]] <- val
70144
}
71-
val <- params[[key]]
72-
if (!is.null(val)) processed_params[[key]] <- val
73145
}
74146

75147
# Return result

demo/notebooks/causal_inference.ipynb

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -103,7 +103,8 @@
103103
"outputs": [],
104104
"source": [
105105
"bcf_model = BCFModel()\n",
106-
"bcf_model.sample(X_train, Z_train, y_train, pi_train, X_test, Z_test, pi_test, num_gfr=10, num_mcmc=100, params={\"keep_every\": 5})"
106+
"general_params = {\"keep_every\": 5}\n",
107+
"bcf_model.sample(X_train, Z_train, y_train, pi_train, X_test, Z_test, pi_test, num_gfr=10, num_mcmc=100, general_params=general_params)"
107108
]
108109
},
109110
{

demo/notebooks/supervised_learning.ipynb

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -119,8 +119,8 @@
119119
"outputs": [],
120120
"source": [
121121
"bart_model = BARTModel()\n",
122-
"param_dict = {\"num_chains\": 3}\n",
123-
"bart_model.sample(X_train=X_train, y_train=y_train, basis_train=basis_train, X_test=X_test, basis_test=basis_test, num_gfr=10, num_mcmc=100, params=param_dict)"
122+
"general_params = {\"num_chains\": 3}\n",
123+
"bart_model.sample(X_train=X_train, y_train=y_train, basis_train=basis_train, X_test=X_test, basis_test=basis_test, num_gfr=10, num_mcmc=100, general_params=general_params)"
124124
]
125125
},
126126
{

0 commit comments

Comments
 (0)