
Commit abe8aa6

Merge pull request #80 from StochasticTree/variance-refactor
Refactor variance parameters initialization and prior setting
2 parents 0cc6580 + 9a4b46b commit abe8aa6

23 files changed: +397 −217 lines

NAMESPACE

Lines changed: 1 addition & 0 deletions
@@ -6,6 +6,7 @@ S3method(predict,bartmodel)
 S3method(predict,bcf)
 export(bart)
 export(bcf)
+export(calibrate_inverse_gamma_error_variance)
 export(computeForestKernels)
 export(computeForestLeafIndices)
 export(convertBCFModelToJson)

R/bart.R

Lines changed: 18 additions & 23 deletions
@@ -31,18 +31,19 @@
 #' @param leaf_model Model to use in the leaves, coded as an integer (0 = constant leaf, 1 = univariate leaf regression, 2 = multivariate leaf regression). Default: 0.
 #' @param min_samples_leaf Minimum allowable size of a leaf, in terms of training samples. Default: 5.
 #' @param max_depth Maximum depth of any tree in the ensemble. Default: 10. Can be overridden with ``-1``, which does not enforce any depth limits on trees.
-#' @param nu Shape parameter in the `IG(nu, nu*lambda)` global error variance model. Default: 3.
-#' @param lambda Component of the scale parameter in the `IG(nu, nu*lambda)` global error variance prior. If not specified, this is calibrated as in Sparapani et al (2021).
+#' @param a_global Shape parameter in the `IG(a_global, b_global)` global error variance model. Default: 0.
+#' @param b_global Scale parameter in the `IG(a_global, b_global)` global error variance model. Default: 0.
 #' @param a_leaf Shape parameter in the `IG(a_leaf, b_leaf)` leaf node parameter variance model. Default: 3.
 #' @param b_leaf Scale parameter in the `IG(a_leaf, b_leaf)` leaf node parameter variance model. Calibrated internally as `0.5/num_trees` if not set here.
 #' @param q Quantile used to calibrate `lambda` as in Sparapani et al (2021). Default: 0.9.
-#' @param sigma2_init Starting value of the global variance parameter. Calibrated internally as in Sparapani et al (2021) if not set here.
+#' @param sigma2_init Starting value of the global error variance parameter. Calibrated internally as `pct_var_sigma2_init*var((y-mean(y))/sd(y))` if not set.
+#' @param pct_var_sigma2_init Percentage of the standardized outcome variance used to initialize the global error variance parameter. Default: 0.25. Superseded by `sigma2_init`.
 #' @param variable_weights Numeric weights reflecting the relative probability of splitting on each variable. Does not need to sum to 1 but cannot be negative. Defaults to `rep(1/ncol(X_train), ncol(X_train))` if not set here.
 #' @param num_trees Number of trees in the ensemble. Default: 200.
 #' @param num_gfr Number of "warm-start" iterations run using the grow-from-root algorithm (He and Hahn, 2021). Default: 5.
 #' @param num_burnin Number of "burn-in" iterations of the MCMC sampler. Default: 0.
 #' @param num_mcmc Number of "retained" iterations of the MCMC sampler. Default: 100.
-#' @param sample_sigma Whether or not to update the `sigma^2` global error variance parameter based on `IG(nu, nu*lambda)`. Default: T.
+#' @param sample_sigma Whether or not to update the `sigma^2` global error variance parameter based on `IG(a_global, b_global)`. Default: T.
 #' @param sample_tau Whether or not to update the `tau` leaf scale variance parameter based on `IG(a_leaf, b_leaf)`. Cannot (currently) be set to true if `ncol(W_train)>1`. Default: T.
 #' @param random_seed Integer parameterizing the C++ random number generator. If not specified, the C++ random number generator is seeded according to `std::random_device`.
 #' @param keep_burnin Whether or not "burnin" samples should be included in cached predictions. Default: FALSE. Ignored if num_mcmc = 0.
@@ -81,12 +82,12 @@ bart <- function(X_train, y_train, W_train = NULL, group_ids_train = NULL,
                  group_ids_test = NULL, rfx_basis_test = NULL,
                  cutpoint_grid_size = 100, tau_init = NULL, alpha = 0.95,
                  beta = 2.0, min_samples_leaf = 5, max_depth = 10, leaf_model = 0,
-                 nu = 3, lambda = NULL, a_leaf = 3, b_leaf = NULL,
-                 q = 0.9, sigma2_init = NULL, variable_weights = NULL,
-                 num_trees = 200, num_gfr = 5, num_burnin = 0,
-                 num_mcmc = 100, sample_sigma = T, sample_tau = T,
-                 random_seed = -1, keep_burnin = F, keep_gfr = F,
-                 verbose = F){
+                 a_global = 0, b_global = 0, a_leaf = 3, b_leaf = NULL,
+                 q = 0.9, sigma2_init = NULL, pct_var_sigma2_init = 0.25,
+                 variable_weights = NULL, num_trees = 200, num_gfr = 5,
+                 num_burnin = 0, num_mcmc = 100, sample_sigma = T,
+                 sample_tau = T, random_seed = -1, keep_burnin = F,
+                 keep_gfr = F, verbose = F) {
   # Variable weight preprocessing (and initialization if necessary)
   if (is.null(variable_weights)) {
     variable_weights = rep(1/ncol(X_train), ncol(X_train))
@@ -216,13 +217,7 @@ bart <- function(X_train, y_train, W_train = NULL, group_ids_train = NULL,
   resid_train <- (y_train-y_bar_train)/y_std_train

   # Calibrate priors for sigma^2 and tau
-  reg_basis <- cbind(W_train, X_train)
-  sigma2hat <- (sigma(lm(resid_train~reg_basis)))^2
-  quantile_cutoff <- 0.9
-  if (is.null(lambda)) {
-    lambda <- (sigma2hat*qgamma(1-quantile_cutoff,nu))/nu
-  }
-  if (is.null(sigma2_init)) sigma2_init <- sigma2hat
+  if (is.null(sigma2_init)) sigma2_init <- pct_var_sigma2_init*var(resid_train)
   if (is.null(b_leaf)) b_leaf <- var(resid_train)/(2*num_trees)
   if (is.null(tau_init)) tau_init <- var(resid_train)/(num_trees)
   current_leaf_scale <- as.matrix(tau_init)
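The new default replaces the old regression-based overestimate of sigma^2 with a fixed fraction of the standardized outcome's variance. Because `resid_train` is the standardized outcome, `var(resid_train)` equals 1 by construction, so the default starting value is exactly `pct_var_sigma2_init` (0.25) regardless of the outcome's raw scale. A minimal standalone R sketch of this rule (the helper function here is hypothetical, written only to illustrate the line added above):

init_global_error_variance <- function(y_train, pct_var_sigma2_init = 0.25) {
  # Standardize the outcome, as bart() does before calibration
  resid_train <- (y_train - mean(y_train)) / sd(y_train)
  # var(resid_train) is 1 by construction, so this returns pct_var_sigma2_init
  pct_var_sigma2_init * var(resid_train)
}

set.seed(101)
init_global_error_variance(rnorm(500, mean = 10, sd = 4))  # 0.25, scale-free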
@@ -331,7 +326,7 @@ bart <- function(X_train, y_train, W_train = NULL, group_ids_train = NULL,
       current_sigma2, cutpoint_grid_size, gfr = T, pre_initialized = F
     )
     if (sample_sigma) {
-      global_var_samples[i] <- sample_sigma2_one_iteration(outcome_train, rng, nu, lambda)
+      global_var_samples[i] <- sample_sigma2_one_iteration(outcome_train, rng, a_global, b_global)
       current_sigma2 <- global_var_samples[i]
     }
     if (sample_tau) {
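The prior parameters now pass straight through to `sample_sigma2_one_iteration` as `(a_global, b_global)` rather than `(nu, lambda)`. The sampler's C++ body is not part of this diff, but under the `IG(a_global, b_global)` prior documented above, the standard conditionally conjugate Gibbs update for sigma^2 given the current residuals would look like the following R sketch (an assumption about the internals, for illustration only):

sample_sigma2_sketch <- function(resid, a_global, b_global) {
  n <- length(resid)
  # Conjugate update: sigma^2 | resid ~ IG(a_global + n/2, b_global + sum(resid^2)/2)
  shape <- a_global + n / 2
  rate <- b_global + sum(resid^2) / 2
  1 / rgamma(1, shape = shape, rate = rate)  # IG draw via reciprocal of a gamma
}

Note that the new defaults `a_global = 0, b_global = 0` correspond to the improper scale-invariant prior p(sigma^2) proportional to 1/sigma^2, whereas the old defaults always imposed a proper, data-calibrated `IG(nu, nu*lambda)` prior.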
@@ -373,7 +368,7 @@ bart <- function(X_train, y_train, W_train = NULL, group_ids_train = NULL,
       current_sigma2, cutpoint_grid_size, gfr = F, pre_initialized = F
     )
     if (sample_sigma) {
-      global_var_samples[i] <- sample_sigma2_one_iteration(outcome_train, rng, nu, lambda)
+      global_var_samples[i] <- sample_sigma2_one_iteration(outcome_train, rng, a_global, b_global)
       current_sigma2 <- global_var_samples[i]
     }
     if (sample_tau) {
@@ -442,11 +437,11 @@ bart <- function(X_train, y_train, W_train = NULL, group_ids_train = NULL,
   # Return results as a list
   model_params <- list(
     "sigma2_init" = sigma2_init,
-    "nu" = nu,
-    "lambda" = lambda,
+    "a_global" = a_global,
+    "b_global" = b_global,
     "tau_init" = tau_init,
-    "a" = a_leaf,
-    "b" = b_leaf,
+    "a_leaf" = a_leaf,
+    "b_leaf" = b_leaf,
     "outcome_mean" = y_bar_train,
     "outcome_scale" = y_std_train,
     "output_dimension" = output_dimension,

R/bcf.R

Lines changed: 17 additions & 28 deletions
@@ -32,14 +32,15 @@
 #' @param min_samples_leaf_tau Minimum allowable size of a leaf, in terms of training samples, for the treatment effect forest. Default: 5.
 #' @param max_depth_mu Maximum depth of any tree in the mu ensemble. Default: 10. Can be overridden with ``-1``, which does not enforce any depth limits on trees.
 #' @param max_depth_tau Maximum depth of any tree in the tau ensemble. Default: 5. Can be overridden with ``-1``, which does not enforce any depth limits on trees.
-#' @param nu Shape parameter in the `IG(nu, nu*lambda)` global error variance model. Default: 3.
-#' @param lambda Component of the scale parameter in the `IG(nu, nu*lambda)` global error variance prior. If not specified, this is calibrated as in Sparapani et al (2021).
+#' @param a_global Shape parameter in the `IG(a_global, b_global)` global error variance model. Default: 0.
+#' @param b_global Scale parameter in the `IG(a_global, b_global)` global error variance model. Default: 0.
 #' @param a_leaf_mu Shape parameter in the `IG(a_leaf, b_leaf)` leaf node parameter variance model for the prognostic forest. Default: 3.
 #' @param a_leaf_tau Shape parameter in the `IG(a_leaf, b_leaf)` leaf node parameter variance model for the treatment effect forest. Default: 3.
 #' @param b_leaf_mu Scale parameter in the `IG(a_leaf, b_leaf)` leaf node parameter variance model for the prognostic forest. Calibrated internally as 0.5/num_trees if not set here.
 #' @param b_leaf_tau Scale parameter in the `IG(a_leaf, b_leaf)` leaf node parameter variance model for the treatment effect forest. Calibrated internally as 0.5/num_trees if not set here.
 #' @param q Quantile used to calibrate `lambda` as in Sparapani et al (2021). Default: 0.9.
-#' @param sigma2 Starting value of the global variance parameter. Calibrated internally as in Sparapani et al (2021) if not set here.
+#' @param sigma2 Starting value of the global error variance parameter. Calibrated internally as `pct_var_sigma2_init*var((y-mean(y))/sd(y))` if not set.
+#' @param pct_var_sigma2_init Percentage of the standardized outcome variance used to initialize the global error variance parameter. Default: 0.25. Superseded by `sigma2`.
 #' @param variable_weights Numeric weights reflecting the relative probability of splitting on each variable. Does not need to sum to 1 but cannot be negative. Defaults to `rep(1/ncol(X_train), ncol(X_train))` if not set here. Note that if the propensity score is included as a covariate in either forest, its weight will default to `1/ncol(X_train)`. A workaround, if you wish to provide a custom weight for the propensity score, is to include it as a column in `X_train`, set `propensity_covariate` to `'none'`, and adjust `keep_vars_mu` and `keep_vars_tau` accordingly.
 #' @param keep_vars_mu Vector of variable names or column indices denoting variables that should be included in the prognostic (`mu(X)`) forest. Default: NULL.
 #' @param drop_vars_mu Vector of variable names or column indices denoting variables that should be excluded from the prognostic (`mu(X)`) forest. Default: NULL. If both `drop_vars_mu` and `keep_vars_mu` are set, `drop_vars_mu` will be ignored.
@@ -50,7 +51,7 @@
 #' @param num_gfr Number of "warm-start" iterations run using the grow-from-root algorithm (He and Hahn, 2021). Default: 5.
 #' @param num_burnin Number of "burn-in" iterations of the MCMC sampler. Default: 0.
 #' @param num_mcmc Number of "retained" iterations of the MCMC sampler. Default: 100.
-#' @param sample_sigma_global Whether or not to update the `sigma^2` global error variance parameter based on `IG(nu, nu*lambda)`. Default: T.
+#' @param sample_sigma_global Whether or not to update the `sigma^2` global error variance parameter based on `IG(a_global, b_global)`. Default: T.
 #' @param sample_sigma_leaf_mu Whether or not to update the `sigma_leaf_mu` leaf scale variance parameter in the prognostic forest based on `IG(a_leaf_mu, b_leaf_mu)`. Default: T.
 #' @param sample_sigma_leaf_tau Whether or not to update the `sigma_leaf_tau` leaf scale variance parameter in the treatment effect forest based on `IG(a_leaf_tau, b_leaf_tau)`. Default: T.
 #' @param propensity_covariate Whether to include the propensity score as a covariate in either or both of the forests. Enter "none" for neither, "mu" for the prognostic forest, "tau" for the treatment forest, and "both" for both forests. If this is not "none" and a propensity score is not provided, it will be estimated from (`X_train`, `Z_train`) using `stochtree::bart()`. Default: "mu".
@@ -118,11 +119,11 @@ bcf <- function(X_train, Z_train, y_train, pi_train = NULL, group_ids_train = NU
                 group_ids_test = NULL, rfx_basis_test = NULL, cutpoint_grid_size = 100,
                 sigma_leaf_mu = NULL, sigma_leaf_tau = NULL, alpha_mu = 0.95, alpha_tau = 0.25,
                 beta_mu = 2.0, beta_tau = 3.0, min_samples_leaf_mu = 5, min_samples_leaf_tau = 5,
-                max_depth_mu = 10, max_depth_tau = 5, nu = 3, lambda = NULL, a_leaf_mu = 3, a_leaf_tau = 3,
-                b_leaf_mu = NULL, b_leaf_tau = NULL, q = 0.9, sigma2 = NULL, variable_weights = NULL,
-                keep_vars_mu = NULL, drop_vars_mu = NULL, keep_vars_tau = NULL, drop_vars_tau = NULL,
-                num_trees_mu = 250, num_trees_tau = 50, num_gfr = 5, num_burnin = 0, num_mcmc = 100,
-                sample_sigma_global = T, sample_sigma_leaf_mu = T, sample_sigma_leaf_tau = F,
+                max_depth_mu = 10, max_depth_tau = 5, a_global = 0, b_global = 0, a_leaf_mu = 3, a_leaf_tau = 3,
+                b_leaf_mu = NULL, b_leaf_tau = NULL, q = 0.9, sigma2 = NULL, pct_var_sigma2_init = 0.25,
+                variable_weights = NULL, keep_vars_mu = NULL, drop_vars_mu = NULL, keep_vars_tau = NULL,
+                drop_vars_tau = NULL, num_trees_mu = 250, num_trees_tau = 50, num_gfr = 5, num_burnin = 0,
+                num_mcmc = 100, sample_sigma_global = T, sample_sigma_leaf_mu = T, sample_sigma_leaf_tau = F,
                 propensity_covariate = "mu", adaptive_coding = T, b_0 = -0.5, b_1 = 0.5,
                 rfx_prior_var = NULL, random_seed = -1, keep_burnin = F, keep_gfr = F, verbose = F) {
   # Variable weight preprocessing (and initialization if necessary)
@@ -413,13 +414,7 @@ bcf <- function(X_train, Z_train, y_train, pi_train = NULL, group_ids_train = NU
   resid_train <- (y_train-y_bar_train)/y_std_train

   # Calibrate priors for global sigma^2 and sigma_leaf_mu / sigma_leaf_tau
-  reg_basis <- X_train
-  sigma2hat <- mean(resid(lm(resid_train~reg_basis))^2)
-  quantile_cutoff <- 0.9
-  if (is.null(lambda)) {
-    lambda <- (sigma2hat*qgamma(1-quantile_cutoff,nu))/nu
-  }
-  if (is.null(sigma2)) sigma2 <- sigma2hat
+  if (is.null(sigma2)) sigma2 <- pct_var_sigma2_init*var(resid_train)
   if (is.null(b_leaf_mu)) b_leaf_mu <- var(resid_train)/(num_trees_mu)
   if (is.null(b_leaf_tau)) b_leaf_tau <- var(resid_train)/(2*num_trees_tau)
   if (is.null(sigma_leaf_mu)) sigma_leaf_mu <- var(resid_train)/(num_trees_mu)
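With the internal `lambda` calibration removed from both `bart()` and `bcf()`, the old Sparapani-style prior can still be reproduced by combining the newly exported `calibrate_inverse_gamma_error_variance()` (see R/calibration.R below) with the new `(a_global, b_global)` arguments: since the old prior was `IG(nu, nu*lambda)`, the mapping is `a_global = nu`, `b_global = nu*lambda`. A sketch with made-up data (the data-generating code is illustrative only):

library(stochtree)

n <- 500; p <- 4
X_train <- matrix(runif(n * p), ncol = p)
y_train <- 5 * X_train[, 1] + rnorm(n)

nu <- 3
lambda <- calibrate_inverse_gamma_error_variance(y_train, X_train, nu = nu)
# Recover the old IG(nu, nu*lambda) global error variance prior
bart_model <- bart(X_train = X_train, y_train = y_train,
                   a_global = nu, b_global = nu * lambda)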
@@ -506,16 +501,10 @@ bcf <- function(X_train, Z_train, y_train, pi_train = NULL, group_ids_train = NU
   # Initialize the leaves of each tree in the prognostic forest
   forest_samples_mu$set_root_leaves(0, mean(resid_train) / num_trees_mu)
   forest_samples_mu$adjust_residual(forest_dataset_train, outcome_train, forest_model_mu, F, 0, F)
-  # adjust_residual_forest_container_cpp(forest_dataset_train$data_ptr, outcome_train$data_ptr,
-  #                                      forest_samples_mu$forest_container_ptr, forest_model_mu$tracker_ptr,
-  #                                      F, 0, F)

   # Initialize the leaves of each tree in the treatment effect forest
   forest_samples_tau$set_root_leaves(0, 0.)
   forest_samples_tau$adjust_residual(forest_dataset_train, outcome_train, forest_model_tau, T, 0, F)
-  # adjust_residual_forest_container_cpp(forest_dataset_train$data_ptr, outcome_train$data_ptr,
-  #                                      forest_samples_tau$forest_container_ptr, forest_model_tau$tracker_ptr,
-  #                                      T, 0, F)

   # Run GFR (warm start) if specified
   if (num_gfr > 0){
@@ -537,7 +526,7 @@ bcf <- function(X_train, Z_train, y_train, pi_train = NULL, group_ids_train = NU

       # Sample variance parameters (if requested)
       if (sample_sigma_global) {
-        global_var_samples[i] <- sample_sigma2_one_iteration(outcome_train, rng, nu, lambda)
+        global_var_samples[i] <- sample_sigma2_one_iteration(outcome_train, rng, a_global, b_global)
         current_sigma2 <- global_var_samples[i]
       }
       if (sample_sigma_leaf_mu) {
@@ -589,7 +578,7 @@ bcf <- function(X_train, Z_train, y_train, pi_train = NULL, group_ids_train = NU

       # Sample variance parameters (if requested)
       if (sample_sigma_global) {
-        global_var_samples[i] <- sample_sigma2_one_iteration(outcome_train, rng, nu, lambda)
+        global_var_samples[i] <- sample_sigma2_one_iteration(outcome_train, rng, a_global, b_global)
         current_sigma2 <- global_var_samples[i]
       }
       if (sample_sigma_leaf_tau) {
@@ -636,7 +625,7 @@ bcf <- function(X_train, Z_train, y_train, pi_train = NULL, group_ids_train = NU

       # Sample variance parameters (if requested)
       if (sample_sigma_global) {
-        global_var_samples[i] <- sample_sigma2_one_iteration(outcome_train, rng, nu, lambda)
+        global_var_samples[i] <- sample_sigma2_one_iteration(outcome_train, rng, a_global, b_global)
         current_sigma2 <- global_var_samples[i]
       }
       if (sample_sigma_leaf_mu) {
@@ -688,7 +677,7 @@ bcf <- function(X_train, Z_train, y_train, pi_train = NULL, group_ids_train = NU

       # Sample variance parameters (if requested)
      if (sample_sigma_global) {
-        global_var_samples[i] <- sample_sigma2_one_iteration(outcome_train, rng, nu, lambda)
+        global_var_samples[i] <- sample_sigma2_one_iteration(outcome_train, rng, a_global, b_global)
         current_sigma2 <- global_var_samples[i]
       }
       if (sample_sigma_leaf_tau) {
@@ -786,8 +775,8 @@ bcf <- function(X_train, Z_train, y_train, pi_train = NULL, group_ids_train = NU
     "initial_sigma_leaf_tau" = sigma_leaf_tau,
     "initial_b_0" = b_0,
     "initial_b_1" = b_1,
-    "nu" = nu,
-    "lambda" = lambda,
+    "a_global" = a_global,
+    "b_global" = b_global,
     "a_leaf_mu" = a_leaf_mu,
     "b_leaf_mu" = b_leaf_mu,
     "a_leaf_tau" = a_leaf_tau,

R/calibration.R

Lines changed: 34 additions & 0 deletions
@@ -0,0 +1,34 @@
+#' Calibrate the scale parameter on an inverse gamma prior for the global error variance as in Chipman et al (2022) [1]
+#'
+#' [1] Chipman, H., George, E., Hahn, R., McCulloch, R., Pratola, M. and Sparapani, R. (2022). Bayesian Additive Regression Trees, Computational Approaches. In Wiley StatsRef: Statistics Reference Online (eds N. Balakrishnan, T. Colton, B. Everitt, W. Piegorsch, F. Ruggeri and J.L. Teugels). https://doi.org/10.1002/9781118445112.stat08288
+#'
+#' @param y Outcome to be modeled using BART, BCF or another nonparametric ensemble method.
+#' @param X Covariates to be used to partition trees in an ensemble or series of ensembles.
+#' @param W [Optional] Basis used to define a "leaf regression" model for each decision tree. The "classic" BART model assumes a constant leaf parameter, which is equivalent to a "leaf regression" on a basis of all ones, though it is not necessary to pass a vector of ones, here or to the BART function. Default: `NULL`.
+#' @param nu The shape parameter for the global error variance's IG prior. The scale parameter in the Sparapani et al (2021) parameterization is defined as `nu*lambda`, where `lambda` is the output of this function. Default: `3`.
+#' @param quant [Optional] Quantile of the inverse gamma prior distribution represented by a linear-regression-based overestimate of `sigma^2`. Default: `0.9`.
+#' @param standardize [Optional] Whether or not the outcome should be standardized (`(y-mean(y))/sd(y)`) before calibration of `lambda`. Default: `TRUE`.
+#'
+#' @return Value of `lambda` which determines the scale parameter of the global error variance prior (`sigma^2 ~ IG(nu, nu*lambda)`)
+#' @export
+#'
+#' @examples
+#' n <- 100
+#' p <- 5
+#' X <- matrix(runif(n*p), ncol = p)
+#' y <- 10*X[,1] - 20*X[,2] + rnorm(n)
+#' nu <- 3
+#' lambda <- calibrate_inverse_gamma_error_variance(y, X, nu = nu)
+#' sigma2hat <- mean(resid(lm(y~X))^2)
+#' mean(var(y)/rgamma(100000, nu, rate = nu*lambda) < sigma2hat)
+calibrate_inverse_gamma_error_variance <- function(y, X, W = NULL, nu = 3, quant = 0.9, standardize = TRUE) {
+  # Compute regression basis
+  if (!is.null(W)) basis <- cbind(X, W)
+  else basis <- X
+  # Standardize outcome if requested
+  if (standardize) y <- (y-mean(y))/sd(y)
+  # Compute the "regression-based" overestimate of sigma^2
+  sigma2hat <- mean(resid(lm(y~basis))^2)
+  # Calibrate lambda based on the implied quantile of sigma2hat
+  return((sigma2hat*qgamma(1-quant,nu))/nu)
+}
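As a sanity check on the closed form returned above: if `sigma^2 ~ IG(nu, nu*lambda)`, then `1/sigma^2 ~ Gamma(nu, rate = nu*lambda)`, so requiring `P(sigma^2 <= sigma2hat) = quant` is the same as `1/sigma2hat = qgamma(1-quant, nu)/(nu*lambda)`, which rearranges to exactly `lambda = sigma2hat*qgamma(1-quant, nu)/nu`. A quick numerical verification (illustrative values, not part of the package):

nu <- 3; quant <- 0.9; sigma2hat <- 0.4
lambda <- (sigma2hat * qgamma(1 - quant, nu)) / nu
# Prior mass below the regression-based overestimate should equal quant
pgamma(1 / sigma2hat, nu, rate = nu * lambda, lower.tail = FALSE)  # 0.9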
