Skip to content

Commit cce5966

Browse files
committed
Added heteroskedasticity to BCF
1 parent 2d5b474 commit cce5966

File tree

8 files changed

+339
-99
lines changed

8 files changed

+339
-99
lines changed

R/bart.R

Lines changed: 45 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@
2525
#' that were not in the training set.
2626
#' @param rfx_basis_test (Optional) Test set basis for "random-slope" regression in additive random effects model.
2727
#' @param cutpoint_grid_size Maximum size of the "grid" of potential cutpoints to consider. Default: 100.
28-
#' @param tau_init Starting value of leaf node scale parameter. Calibrated internally as `1/num_trees_mean` if not set here.
28+
#' @param sigma_leaf_init Starting value of leaf node scale parameter. Calibrated internally as `1/num_trees_mean` if not set here.
2929
#' @param leaf_model Model to use in the leaves, coded as integer with (0 = constant leaf, 1 = univariate leaf regression, 2 = multivariate leaf regression). Default: 0.
3030
#' @param alpha_mean Prior probability of splitting for a tree of depth 0 in the mean model. Tree split prior combines `alpha_mean` and `beta_mean` via `alpha_mean*(1+node_depth)^-beta_mean`. Default: 0.95.
3131
#' @param beta_mean Exponent that decreases split probabilities for nodes of depth > 0 in the mean model. Tree split prior combines `alpha_mean` and `beta_mean` via `alpha_mean*(1+node_depth)^-beta_mean`. Default: 2.
@@ -91,7 +91,7 @@
9191
bart <- function(X_train, y_train, W_train = NULL, group_ids_train = NULL,
9292
rfx_basis_train = NULL, X_test = NULL, W_test = NULL,
9393
group_ids_test = NULL, rfx_basis_test = NULL,
94-
cutpoint_grid_size = 100, tau_init = NULL,
94+
cutpoint_grid_size = 100, sigma_leaf_init = NULL,
9595
alpha_mean = 0.95, beta_mean = 2.0, min_samples_leaf_mean = 5,
9696
max_depth_mean = 10, alpha_variance = 0.95, beta_variance = 2.0,
9797
min_samples_leaf_variance = 5, max_depth_variance = 10,
@@ -100,7 +100,7 @@ bart <- function(X_train, y_train, W_train = NULL, group_ids_train = NULL,
100100
variance_forest_init = NULL, pct_var_sigma2_init = 1,
101101
pct_var_variance_forest_init = 1, variance_scale = 1,
102102
variable_weights_mean = NULL, variable_weights_variance = NULL,
103-
num_trees_mean = 200, num_trees_variance = 20,
103+
num_trees_mean = 200, num_trees_variance = 0,
104104
num_gfr = 5, num_burnin = 0, num_mcmc = 100,
105105
sample_sigma_global = T, sample_sigma_leaf = F, random_seed = -1,
106106
keep_burnin = F, keep_gfr = F, verbose = F) {
@@ -115,6 +115,9 @@ bart <- function(X_train, y_train, W_train = NULL, group_ids_train = NULL,
115115
a_0 <- 1.5
116116
if (is.null(a_forest)) a_forest <- num_trees_variance / (a_0^2) + 0.5
117117
if (is.null(b_forest)) b_forest <- num_trees_variance / (a_0^2)
118+
} else {
119+
a_forest <- 1.
120+
b_forest <- 1.
118121
}
119122

120123
# Override tau sampling if there is no mean forest
@@ -274,8 +277,8 @@ bart <- function(X_train, y_train, W_train = NULL, group_ids_train = NULL,
274277
if (is.null(sigma2_init)) sigma2_init <- pct_var_sigma2_init*var(resid_train)
275278
if (is.null(variance_forest_init)) variance_forest_init <- pct_var_variance_forest_init*var(resid_train)
276279
if (is.null(b_leaf)) b_leaf <- var(resid_train)/(2*num_trees_mean)
277-
if (is.null(tau_init)) tau_init <- var(resid_train)/(num_trees_mean)
278-
current_leaf_scale <- as.matrix(tau_init)
280+
if (is.null(sigma_leaf_init)) sigma_leaf_init <- var(resid_train)/(num_trees_mean)
281+
current_leaf_scale <- as.matrix(sigma_leaf_init)
279282
current_sigma2 <- sigma2_init
280283

281284
# Determine leaf model type
@@ -542,7 +545,7 @@ bart <- function(X_train, y_train, W_train = NULL, group_ids_train = NULL,
542545
# Leaf parameter variance
543546
if (sample_sigma_leaf) tau_samples <- leaf_scale_samples[keep_indices]
544547

545-
# Rescale variance forest prediction by sigma2_samples
548+
# Rescale variance forest prediction by global sigma2 (sampled or constant)
546549
if (include_variance_forest) {
547550
if (sample_sigma_global) {
548551
sigma_x_hat_train <- sapply(1:length(keep_indices), function(i) sqrt(sigma_x_hat_train[,i]*sigma2_samples[i]))
@@ -557,9 +560,9 @@ bart <- function(X_train, y_train, W_train = NULL, group_ids_train = NULL,
557560
# TODO: store variance_scale and propagate through predict function
558561
model_params <- list(
559562
"sigma2_init" = sigma2_init,
563+
"sigma_leaf_init" = sigma_leaf_init,
560564
"a_global" = a_global,
561565
"b_global" = b_global,
562-
"tau_init" = tau_init,
563566
"a_leaf" = a_leaf,
564567
"b_leaf" = b_leaf,
565568
"a_forest" = a_forest,
@@ -598,7 +601,7 @@ bart <- function(X_train, y_train, W_train = NULL, group_ids_train = NULL,
598601
if (has_test) result[["y_hat_test"]] = y_hat_test
599602
}
600603
if (include_variance_forest) {
601-
result[["var_forests"]] = forest_samples_variance
604+
result[["variance_forests"]] = forest_samples_variance
602605
result[["sigma_x_hat_train"]] = sigma_x_hat_train
603606
if (has_test) result[["sigma_x_hat_test"]] = sigma_x_hat_test
604607
}
@@ -634,7 +637,6 @@ bart <- function(X_train, y_train, W_train = NULL, group_ids_train = NULL,
634637
#' We do not currently support (but plan to in the near future), test set evaluation for group labels
635638
#' that were not in the training set.
636639
#' @param rfx_basis_test (Optional) Test set basis for "random-slope" regression in additive random effects model.
637-
#' @param predict_all (Optional) Whether to predict the model for all of the samples in the stored objects or the subset of burnt-in / GFR samples as specified at training time. Default FALSE.
638640
#'
639641
#' @return List of prediction matrices. If model does not have random effects, the list has one element -- the predictions from the forest.
640642
#' If the model does have random effects, the list has three elements -- forest predictions, random effects predictions, and their sum (`y_hat`).
@@ -665,7 +667,7 @@ bart <- function(X_train, y_train, W_train = NULL, group_ids_train = NULL,
665667
#' y_hat_test <- predict(bart_model, X_test)
666668
#' # plot(rowMeans(y_hat_test), y_test, xlab = "predicted", ylab = "actual")
667669
#' # abline(0,1,col="red",lty=3,lwd=3)
668-
predict.bartmodel <- function(bart, X_test, W_test = NULL, group_ids_test = NULL, rfx_basis_test = NULL, predict_all = F){
670+
predict.bartmodel <- function(bart, X_test, W_test = NULL, group_ids_test = NULL, rfx_basis_test = NULL){
669671
# Preprocess covariates
670672
if ((!is.data.frame(X_test)) && (!is.matrix(X_test))) {
671673
stop("X_test must be a matrix or dataframe")
@@ -726,11 +728,14 @@ predict.bartmodel <- function(bart, X_test, W_test = NULL, group_ids_test = NULL
726728
variance_scale <- bart$model_params$variance_scale
727729
y_std <- bart$model_params$outcome_scale
728730
y_bar <- bart$model_params$outcome_mean
729-
mean_forest_predictions <- bart$mean_forests$predict(prediction_dataset)*y_std/sqrt(variance_scale) + y_bar
731+
sigma2_init <- bart$model_params$sigma2_init
732+
if (bart$model_params$include_mean_forest) {
733+
mean_forest_predictions <- bart$mean_forests$predict(prediction_dataset)*y_std/sqrt(variance_scale) + y_bar
734+
}
730735

731736
# Compute variance forest predictions
732737
if (bart$model_params$include_variance_forest) {
733-
var_forest_predictions <- bart$variance_forests$predict(prediction_dataset)*(y_std^2)/variance_scale
738+
s_x_raw <- bart$variance_forests$predict(prediction_dataset)
734739
}
735740

736741
# Compute rfx predictions (if needed)
@@ -739,30 +744,43 @@ predict.bartmodel <- function(bart, X_test, W_test = NULL, group_ids_test = NULL
739744
}
740745

741746
# Restrict predictions to the "retained" samples (if applicable)
742-
if (!predict_all) {
743-
keep_indices = bart$keep_indices
747+
keep_indices = bart$keep_indices
748+
if (bart$model_params$include_mean_forest) {
744749
mean_forest_predictions <- mean_forest_predictions[,keep_indices]
745-
if (bart$model_params$include_variance_forest) {
746-
variance_forest_predictions <- variance_forest_predictions[,keep_indices]
747-
}
748-
if (bart$model_params$has_rfx) rfx_predictions <- rfx_predictions[,keep_indices]
749750
}
751+
if (bart$model_params$include_variance_forest) {
752+
s_x_raw <- s_x_raw[,keep_indices]
753+
}
754+
if (bart$model_params$has_rfx) rfx_predictions <- rfx_predictions[,keep_indices]
750755

751-
if (bart$model_params$has_rfx) {
756+
# Scale variance forest predictions
757+
if (bart$model_params$include_variance_forest) {
758+
if (bart$model_params$sample_sigma_global) {
759+
sigma2_samples <- bart$sigma2_global_samples
760+
variance_forest_predictions <- sapply(1:length(keep_indices), function(i) sqrt(s_x_raw[,i]*sigma2_samples[i]))
761+
} else {
762+
variance_forest_predictions <- sqrt(s_x_raw*sigma2_init)*y_std/sqrt(variance_scale)
763+
}
764+
}
765+
766+
if ((bart$model_params$include_mean_forest) && (bart$model_params$has_rfx)) {
752767
y_hat <- mean_forest_predictions + rfx_predictions
753-
} else {
768+
} else if ((bart$model_params$include_mean_forest) && (!bart$model_params$has_rfx)) {
754769
y_hat <- mean_forest_predictions
755-
}
756-
757-
result <- list(
758-
"y_hat" = y_hat,
759-
"mean_forest_predictions" = mean_forest_predictions
760-
)
770+
} else if ((!bart$model_params$include_mean_forest) && (bart$model_params$has_rfx)) {
771+
y_hat <- rfx_predictions
772+
}
761773

774+
result <- list()
775+
if ((bart$model_params$has_rfx) || (bart$model_params$include_mean_forest)) {
776+
result[["y_hat"]] = y_hat
777+
}
778+
if (bart$model_params$include_mean_forest) {
779+
result[["mean_forest_predictions"]] = mean_forest_predictions
780+
}
762781
if (bart$model_params$has_rfx) {
763782
result[["rfx_predictions"]] = rfx_predictions
764783
}
765-
766784
if (bart$model_params$include_variance_forest) {
767785
result[["variance_forest_predictions"]] = variance_forest_predictions
768786
}

0 commit comments

Comments
 (0)