
Commit 72f9d90

Merge pull request #127 from StochasticTree/variable-weight-consistency
Updated interface to handle several UI configurations consistently
2 parents 5c455f7 + 3950347 · commit 72f9d90

File tree: 6 files changed (+300 −66 lines)

R/bart.R

Lines changed: 109 additions & 29 deletions
@@ -37,6 +37,7 @@
 #' - `sigma2_global_init` Starting value of global error variance parameter. Calibrated internally as `1.0*var(y_train)`, where `y_train` is the possibly standardized outcome, if not set.
 #' - `sigma2_global_shape` Shape parameter in the `IG(sigma2_global_shape, sigma2_global_scale)` global error variance model. Default: `0`.
 #' - `sigma2_global_scale` Scale parameter in the `IG(sigma2_global_shape, sigma2_global_scale)` global error variance model. Default: `0`.
+#' - `variable_weights` Numeric weights reflecting the relative probability of splitting on each variable. Does not need to sum to 1 but cannot be negative. Defaults to `rep(1/ncol(X_train), ncol(X_train))` if not set here. Note that if the propensity score is included as a covariate in either forest, its weight will default to `1/ncol(X_train)`.
 #' - `random_seed` Integer parameterizing the C++ random number generator. If not specified, the C++ random number generator is seeded according to `std::random_device`.
 #' - `keep_burnin` Whether or not "burnin" samples should be included in the stored samples of forests and other parameters. Default `FALSE`. Ignored if `num_mcmc = 0`.
 #' - `keep_gfr` Whether or not "grow-from-root" samples should be included in the stored samples of forests and other parameters. Default `FALSE`. Ignored if `num_mcmc = 0`.
@@ -51,11 +52,12 @@
 #' - `beta` Exponent that decreases split probabilities for nodes of depth > 0 in the mean model. Tree split prior combines `alpha` and `beta` via `alpha*(1+node_depth)^-beta`. Default: `2`.
 #' - `min_samples_leaf` Minimum allowable size of a leaf, in terms of training samples, in the mean model. Default: `5`.
 #' - `max_depth` Maximum depth of any tree in the ensemble in the mean model. Default: `10`. Can be overridden with `-1` which does not enforce any depth limits on trees.
-#' - `variable_weights` Numeric weights reflecting the relative probability of splitting on each variable in the mean forest. Does not need to sum to 1 but cannot be negative. Defaults to `rep(1/ncol(X_train), ncol(X_train))` if not set here.
 #' - `sample_sigma2_leaf` Whether or not to update the leaf scale variance parameter based on `IG(sigma2_leaf_shape, sigma2_leaf_scale)`. Cannot (currently) be set to `TRUE` if `ncol(W_train) > 1`. Default: `FALSE`.
 #' - `sigma2_leaf_init` Starting value of leaf node scale parameter. Calibrated internally as `1/num_trees` if not set here.
 #' - `sigma2_leaf_shape` Shape parameter in the `IG(sigma2_leaf_shape, sigma2_leaf_scale)` leaf node parameter variance model. Default: `3`.
 #' - `sigma2_leaf_scale` Scale parameter in the `IG(sigma2_leaf_shape, sigma2_leaf_scale)` leaf node parameter variance model. Calibrated internally as `0.5/num_trees` if not set here.
+#' - `keep_vars` Vector of variable names or column indices denoting variables that should be included in the forest. Default: `NULL`.
+#' - `drop_vars` Vector of variable names or column indices denoting variables that should be excluded from the forest. Default: `NULL`. If both `drop_vars` and `keep_vars` are set, `drop_vars` will be ignored.
 #'
 #' @param variance_forest_params (Optional) A list of variance forest model parameters, each of which has a default value processed internally, so this argument list is optional.
 #'
@@ -64,10 +66,12 @@
 #' - `beta` Exponent that decreases split probabilities for nodes of depth > 0 in the variance model. Tree split prior combines `alpha` and `beta` via `alpha*(1+node_depth)^-beta`. Default: `2`.
 #' - `min_samples_leaf` Minimum allowable size of a leaf, in terms of training samples, in the variance model. Default: `5`.
 #' - `max_depth` Maximum depth of any tree in the ensemble in the variance model. Default: `10`. Can be overridden with `-1` which does not enforce any depth limits on trees.
-#' - `variable_weights` Numeric weights reflecting the relative probability of splitting on each variable in the variance forest. Does not need to sum to 1 but cannot be negative. Defaults to `rep(1/ncol(X_train), ncol(X_train))` if not set here.
+#' - `leaf_prior_calibration_param` Hyperparameter used to calibrate the `IG(var_forest_prior_shape, var_forest_prior_scale)` conditional error variance model. If `var_forest_prior_shape` and `var_forest_prior_scale` are not set below, this calibration parameter is used to set these values to `num_trees / leaf_prior_calibration_param^2 + 0.5` and `num_trees / leaf_prior_calibration_param^2`, respectively. Default: `1.5`.
 #' - `var_forest_leaf_init` Starting value of root forest prediction in conditional (heteroskedastic) error variance model. Calibrated internally as `log(0.6*var(y_train))/num_trees`, where `y_train` is the possibly standardized outcome, if not set.
-#' - `var_forest_prior_shape` Shape parameter in the `IG(var_forest_prior_shape, var_forest_prior_scale)` conditional error variance model (which is only sampled if `num_trees > 0`). Calibrated internally as `num_trees / 1.5^2 + 0.5` if not set.
-#' - `var_forest_prior_scale` Scale parameter in the `IG(var_forest_prior_shape, var_forest_prior_scale)` conditional error variance model (which is only sampled if `num_trees > 0`). Calibrated internally as `num_trees / 1.5^2` if not set.
+#' - `var_forest_prior_shape` Shape parameter in the `IG(var_forest_prior_shape, var_forest_prior_scale)` conditional error variance model (which is only sampled if `num_trees > 0`). Calibrated internally as `num_trees / leaf_prior_calibration_param^2 + 0.5` if not set.
+#' - `var_forest_prior_scale` Scale parameter in the `IG(var_forest_prior_shape, var_forest_prior_scale)` conditional error variance model (which is only sampled if `num_trees > 0`). Calibrated internally as `num_trees / leaf_prior_calibration_param^2` if not set.
+#' - `keep_vars` Vector of variable names or column indices denoting variables that should be included in the forest. Default: `NULL`.
+#' - `drop_vars` Vector of variable names or column indices denoting variables that should be excluded from the forest. Default: `NULL`. If both `drop_vars` and `keep_vars` are set, `drop_vars` will be ignored.
 #'
 #' @return List of sampling outputs and a wrapper around the sampled forests (which can be used for in-memory prediction on new data, or serialized to JSON on disk).
 #' @export
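
Taken together, the documentation changes above describe the reorganized interface: a single `variable_weights` vector now lives in `general_params`, while per-forest variable selection moves to `keep_vars`/`drop_vars` in the forest-specific parameter lists. A minimal sketch of a call under the new interface; the data and parameter values are invented for illustration, and the package is assumed to be attached (e.g. `library(stochtree)`):

# Hypothetical data: 4 covariates, nonlinear mean, heteroskedastic noise
n <- 500
X <- matrix(runif(n * 4), ncol = 4)
y <- sin(2 * pi * X[, 1]) + rnorm(n, sd = 0.5 + X[, 2])

bart_fit <- bart(
    X_train = X, y_train = y,
    general_params = list(
        variable_weights = c(0.4, 0.4, 0.1, 0.1)  # shared split weights; need not sum to 1
    ),
    mean_forest_params = list(
        keep_vars = c(1, 2)   # mean forest may only split on columns 1 and 2
    ),
    variance_forest_params = list(
        num_trees = 50,       # enable the heteroskedastic variance forest
        drop_vars = c(3)      # variance forest may not split on column 3
    )
)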
@@ -108,8 +112,9 @@ bart <- function(X_train, y_train, W_train = NULL, group_ids_train = NULL,
     cutpoint_grid_size = 100, standardize = T,
     sample_sigma2_global = T, sigma2_global_init = NULL,
     sigma2_global_shape = 0, sigma2_global_scale = 0,
-    random_seed = -1, keep_burnin = F, keep_gfr = F,
-    keep_every = 1, num_chains = 1, verbose = F
+    variable_weights = NULL, random_seed = -1,
+    keep_burnin = F, keep_gfr = F, keep_every = 1,
+    num_chains = 1, verbose = F
 )
 general_params_updated <- preprocessParams(
     general_params_default, general_params
@@ -119,9 +124,9 @@ bart <- function(X_train, y_train, W_train = NULL, group_ids_train = NULL,
 mean_forest_params_default <- list(
     num_trees = 200, alpha = 0.95, beta = 2.0,
     min_samples_leaf = 5, max_depth = 10,
-    variable_weights = NULL,
     sample_sigma2_leaf = T, sigma2_leaf_init = NULL,
-    sigma2_leaf_shape = 3, sigma2_leaf_scale = NULL
+    sigma2_leaf_shape = 3, sigma2_leaf_scale = NULL,
+    keep_vars = NULL, drop_vars = NULL
 )
 mean_forest_params_updated <- preprocessParams(
     mean_forest_params_default, mean_forest_params
@@ -131,8 +136,11 @@ bart <- function(X_train, y_train, W_train = NULL, group_ids_train = NULL,
 variance_forest_params_default <- list(
     num_trees = 0, alpha = 0.95, beta = 2.0,
     min_samples_leaf = 5, max_depth = 10,
-    variable_weights = NULL, var_forest_leaf_init = NULL,
-    var_forest_prior_shape = NULL, var_forest_prior_scale = NULL
+    leaf_prior_calibration_param = 1.5,
+    var_forest_leaf_init = NULL,
+    var_forest_prior_shape = NULL,
+    var_forest_prior_scale = NULL,
+    keep_vars = NULL, drop_vars = NULL
 )
 variance_forest_params_updated <- preprocessParams(
     variance_forest_params_default, variance_forest_params
@@ -146,6 +154,7 @@ bart <- function(X_train, y_train, W_train = NULL, group_ids_train = NULL,
 sigma2_init <- general_params_updated$sigma2_global_init
 a_global <- general_params_updated$sigma2_global_shape
 b_global <- general_params_updated$sigma2_global_scale
+variable_weights <- general_params_updated$variable_weights
 random_seed <- general_params_updated$random_seed
 keep_burnin <- general_params_updated$keep_burnin
 keep_gfr <- general_params_updated$keep_gfr
@@ -159,22 +168,25 @@ bart <- function(X_train, y_train, W_train = NULL, group_ids_train = NULL,
 beta_mean <- mean_forest_params_updated$beta
 min_samples_leaf_mean <- mean_forest_params_updated$min_samples_leaf
 max_depth_mean <- mean_forest_params_updated$max_depth
-variable_weights_mean <- mean_forest_params_updated$variable_weights
 sample_sigma_leaf <- mean_forest_params_updated$sample_sigma2_leaf
 sigma_leaf_init <- mean_forest_params_updated$sigma2_leaf_init
 a_leaf <- mean_forest_params_updated$sigma2_leaf_shape
 b_leaf <- mean_forest_params_updated$sigma2_leaf_scale
+keep_vars_mean <- mean_forest_params_updated$keep_vars
+drop_vars_mean <- mean_forest_params_updated$drop_vars
 
 # 3. Variance forest parameters
 num_trees_variance <- variance_forest_params_updated$num_trees
 alpha_variance <- variance_forest_params_updated$alpha
 beta_variance <- variance_forest_params_updated$beta
 min_samples_leaf_variance <- variance_forest_params_updated$min_samples_leaf
 max_depth_variance <- variance_forest_params_updated$max_depth
-variable_weights_variance <- variance_forest_params_updated$variable_weights
+a_0 <- variance_forest_params_updated$leaf_prior_calibration_param
 variance_forest_init <- variance_forest_params_updated$var_forest_leaf_init
 a_forest <- variance_forest_params_updated$var_forest_prior_shape
 b_forest <- variance_forest_params_updated$var_forest_prior_scale
+keep_vars_variance <- variance_forest_params_updated$keep_vars
+drop_vars_variance <- variance_forest_params_updated$drop_vars
 
 # Check if there are enough GFR samples to seed num_chains samplers
 if (num_gfr > 0) {
@@ -228,7 +240,6 @@ bart <- function(X_train, y_train, W_train = NULL, group_ids_train = NULL,
 
 # Set the variance forest priors if not set
 if (include_variance_forest) {
-    a_0 <- 1.5
     if (is.null(a_forest)) a_forest <- num_trees_variance / (a_0^2) + 0.5
     if (is.null(b_forest)) b_forest <- num_trees_variance / (a_0^2)
 } else {
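
With the hard-coded `a_0 <- 1.5` removed, the calibration constant now comes from `leaf_prior_calibration_param`. As a quick check of the defaults described in the docs (the values below are just the formulas evaluated by hand on an example tree count, not package output):

# Illustrative evaluation of the IG prior defaults for the variance forest
num_trees_variance <- 50   # example value; the package default is 0 (variance forest off)
a_0 <- 1.5                 # leaf_prior_calibration_param default
a_forest <- num_trees_variance / (a_0^2) + 0.5  # shape: 50/2.25 + 0.5 ≈ 22.72
b_forest <- num_trees_variance / (a_0^2)        # scale: 50/2.25 ≈ 22.22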
@@ -240,21 +251,90 @@ bart <- function(X_train, y_train, W_train = NULL, group_ids_train = NULL,
 if (!include_mean_forest) sample_sigma_leaf <- F
 
 # Variable weight preprocessing (and initialization if necessary)
-if (include_mean_forest) {
-    if (is.null(variable_weights_mean)) {
-        variable_weights_mean = rep(1/ncol(X_train), ncol(X_train))
+if (is.null(variable_weights)) {
+    variable_weights = rep(1/ncol(X_train), ncol(X_train))
+}
+if (any(variable_weights < 0)) {
+    stop("variable_weights cannot have any negative weights")
+}
+
+# Check covariates are matrix or dataframe
+if ((!is.data.frame(X_train)) && (!is.matrix(X_train))) {
+    stop("X_train must be a matrix or dataframe")
+}
+if (!is.null(X_test)) {
+    if ((!is.data.frame(X_test)) && (!is.matrix(X_test))) {
+        stop("X_test must be a matrix or dataframe")
     }
-    if (any(variable_weights_mean < 0)) {
-        stop("variable_weights_mean cannot have any negative weights")
+}
+num_cov_orig <- ncol(X_train)
+
+# Standardize the keep variable lists to numeric indices
+if (!is.null(keep_vars_mean)) {
+    if (is.character(keep_vars_mean)) {
+        if (!all(keep_vars_mean %in% names(X_train))) {
+            stop("keep_vars_mean includes some variable names that are not in X_train")
+        }
+        variable_subset_mean <- unname(which(names(X_train) %in% keep_vars_mean))
+    } else {
+        if (any(keep_vars_mean > ncol(X_train))) {
+            stop("keep_vars_mean includes some variable indices that exceed the number of columns in X_train")
+        }
+        if (any(keep_vars_mean < 0)) {
+            stop("keep_vars_mean includes some negative variable indices")
+        }
+        variable_subset_mean <- keep_vars_mean
     }
+} else if ((is.null(keep_vars_mean)) && (!is.null(drop_vars_mean))) {
+    if (is.character(drop_vars_mean)) {
+        if (!all(drop_vars_mean %in% names(X_train))) {
+            stop("drop_vars_mean includes some variable names that are not in X_train")
+        }
+        variable_subset_mean <- unname(which(!(names(X_train) %in% drop_vars_mean)))
+    } else {
+        if (any(drop_vars_mean > ncol(X_train))) {
+            stop("drop_vars_mean includes some variable indices that exceed the number of columns in X_train")
+        }
+        if (any(drop_vars_mean < 0)) {
+            stop("drop_vars_mean includes some negative variable indices")
+        }
+        variable_subset_mean <- (1:ncol(X_train))[!(1:ncol(X_train) %in% drop_vars_mean)]
+    }
+} else {
+    variable_subset_mean <- 1:ncol(X_train)
 }
-if (include_variance_forest) {
-    if (is.null(variable_weights_variance)) {
-        variable_weights_variance = rep(1/ncol(X_train), ncol(X_train))
+if (!is.null(keep_vars_variance)) {
+    if (is.character(keep_vars_variance)) {
+        if (!all(keep_vars_variance %in% names(X_train))) {
+            stop("keep_vars_variance includes some variable names that are not in X_train")
+        }
+        variable_subset_variance <- unname(which(names(X_train) %in% keep_vars_variance))
+    } else {
+        if (any(keep_vars_variance > ncol(X_train))) {
+            stop("keep_vars_variance includes some variable indices that exceed the number of columns in X_train")
+        }
+        if (any(keep_vars_variance < 0)) {
+            stop("keep_vars_variance includes some negative variable indices")
+        }
+        variable_subset_variance <- keep_vars_variance
     }
-    if (any(variable_weights_variance < 0)) {
-        stop("variable_weights_variance cannot have any negative weights")
+} else if ((is.null(keep_vars_variance)) && (!is.null(drop_vars_variance))) {
+    if (is.character(drop_vars_variance)) {
+        if (!all(drop_vars_variance %in% names(X_train))) {
+            stop("drop_vars_variance includes some variable names that are not in X_train")
+        }
+        variable_subset_variance <- unname(which(!(names(X_train) %in% drop_vars_variance)))
+    } else {
+        if (any(drop_vars_variance > ncol(X_train))) {
+            stop("drop_vars_variance includes some variable indices that exceed the number of columns in X_train")
+        }
+        if (any(drop_vars_variance < 0)) {
+            stop("drop_vars_variance includes some negative variable indices")
+        }
+        variable_subset_variance <- (1:ncol(X_train))[!(1:ncol(X_train) %in% drop_vars_variance)]
+    }
+} else {
+    variable_subset_variance <- 1:ncol(X_train)
 }
 
 # Preprocess covariates
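
The block above standardizes `keep_vars`/`drop_vars` to numeric column indices, accepting either names or positions (note that name-based selection presumes named columns, e.g. a data frame). A standalone sketch of that mapping, runnable outside `bart()`; the data frame and selections here are invented:

df <- data.frame(x1 = rnorm(5), x2 = rnorm(5), x3 = rnorm(5), x4 = rnorm(5))

# Name-based keep list -> column indices (mirrors the keep_vars_mean branch)
keep_vars_mean <- c("x1", "x3")
unname(which(names(df) %in% keep_vars_mean))
#> [1] 1 3

# Index-based drop list -> complement of the dropped columns
drop_vars_variance <- c(2, 4)
(1:ncol(df))[!(1:ncol(df) %in% drop_vars_variance)]
#> [1] 1 3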
@@ -266,11 +346,8 @@ bart <- function(X_train, y_train, W_train = NULL, group_ids_train = NULL,
         stop("X_test must be a matrix or dataframe")
     }
 }
-if ((ncol(X_train) != length(variable_weights_mean)) && (include_mean_forest)) {
-    stop("length(variable_weights_mean) must equal ncol(X_train)")
-}
-if ((ncol(X_train) != length(variable_weights_variance)) && (include_variance_forest)) {
-    stop("length(variable_weights_variance) must equal ncol(X_train)")
+if (ncol(X_train) != length(variable_weights)) {
+    stop("length(variable_weights) must equal ncol(X_train)")
 }
 train_cov_preprocess_list <- preprocessTrainData(X_train)
 X_train_metadata <- train_cov_preprocess_list$metadata
@@ -280,14 +357,17 @@ bart <- function(X_train, y_train, W_train = NULL, group_ids_train = NULL,
 if (!is.null(X_test)) X_test <- preprocessPredictionData(X_test, X_train_metadata)
 
 # Update variable weights
+variable_weights_mean <- variable_weights_variance <- variable_weights
 variable_weights_adj <- 1/sapply(original_var_indices, function(x) sum(original_var_indices == x))
 if (include_mean_forest) {
     variable_weights_mean <- variable_weights_mean[original_var_indices]*variable_weights_adj
+    variable_weights_mean[!(original_var_indices %in% variable_subset_mean)] <- 0
 }
 if (include_variance_forest) {
     variable_weights_variance <- variable_weights_variance[original_var_indices]*variable_weights_adj
+    variable_weights_variance[!(original_var_indices %in% variable_subset_variance)] <- 0
 }
-
+
 # Convert all input data to matrices if not already converted
 if ((is.null(dim(W_train))) && (!is.null(W_train))) {
     W_train <- as.matrix(W_train)
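
Downstream, the variable subsets feed back into the shared `variable_weights` vector: both forests start from the same weights, and excluded columns simply get weight zero, so they are never chosen as split candidates. A small worked example of that zeroing step (all values invented; `original_var_indices` is the identity here because no covariates were expanded during preprocessing):

variable_weights <- rep(0.25, 4)     # uniform weights over 4 covariates
original_var_indices <- 1:4          # identity mapping: no factor expansion
variable_subset_mean <- c(1, 2)      # e.g., keep_vars = c(1, 2) for the mean forest

variable_weights_adj <- 1/sapply(original_var_indices, function(x) sum(original_var_indices == x))  # all 1s here
variable_weights_mean <- variable_weights[original_var_indices]*variable_weights_adj
variable_weights_mean[!(original_var_indices %in% variable_subset_mean)] <- 0
variable_weights_mean
#> [1] 0.25 0.25 0.00 0.00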
