From ab4063f92e257246bc4018d8c221baf00d12c2f8 Mon Sep 17 00:00:00 2001 From: Drew Herren Date: Tue, 23 Sep 2025 13:52:38 -0500 Subject: [PATCH] Format R code using air --- R/bart.R | 2111 +++++++++++++++++++-------- R/bcf.R | 3153 +++++++++++++++++++++++++++++------------ R/calibration.R | 28 +- R/config.R | 286 ++-- R/data.R | 175 ++- R/forest.R | 1038 +++++++++----- R/generics.R | 8 +- R/kernel.R | 251 ++-- R/model.R | 283 ++-- R/random_effects.R | 379 +++-- R/serialization.R | 450 ++++-- R/stochtree-package.R | 2 +- R/utils.R | 566 +++++--- R/variance.R | 20 +- 14 files changed, 6093 insertions(+), 2657 deletions(-) diff --git a/R/bart.R b/R/bart.R index 7fe1c3de..83e2f828 100644 --- a/R/bart.R +++ b/R/bart.R @@ -1,26 +1,26 @@ -#' Run the BART algorithm for supervised learning. +#' Run the BART algorithm for supervised learning. #' -#' @param X_train Covariates used to split trees in the ensemble. May be provided either as a dataframe or a matrix. -#' Matrix covariates will be assumed to be all numeric. Covariates passed as a dataframe will be -#' preprocessed based on the variable types (e.g. categorical columns stored as unordered factors will be one-hot encoded, -#' categorical columns stored as ordered factors will passed as integers to the core algorithm, along with the metadata +#' @param X_train Covariates used to split trees in the ensemble. May be provided either as a dataframe or a matrix. +#' Matrix covariates will be assumed to be all numeric. Covariates passed as a dataframe will be +#' preprocessed based on the variable types (e.g. categorical columns stored as unordered factors will be one-hot encoded, +#' categorical columns stored as ordered factors will passed as integers to the core algorithm, along with the metadata #' that the column is ordered categorical). #' @param y_train Outcome to be modeled by the ensemble. -#' @param leaf_basis_train (Optional) Bases used to define a regression model `y ~ W` in -#' each leaf of each regression tree. By default, BART assumes constant leaf node +#' @param leaf_basis_train (Optional) Bases used to define a regression model `y ~ W` in +#' each leaf of each regression tree. By default, BART assumes constant leaf node #' parameters, implicitly regressing on a constant basis of ones (i.e. `y ~ 1`). #' @param rfx_group_ids_train (Optional) Group labels used for an additive random effects model. #' @param rfx_basis_train (Optional) Basis for "random-slope" regression in an additive random effects model. -#' If `rfx_group_ids_train` is provided with a regression basis, an intercept-only random effects model +#' If `rfx_group_ids_train` is provided with a regression basis, an intercept-only random effects model #' will be estimated. -#' @param X_test (Optional) Test set of covariates used to define "out of sample" evaluation data. -#' May be provided either as a dataframe or a matrix, but the format of `X_test` must be consistent with +#' @param X_test (Optional) Test set of covariates used to define "out of sample" evaluation data. +#' May be provided either as a dataframe or a matrix, but the format of `X_test` must be consistent with #' that of `X_train`. -#' @param leaf_basis_test (Optional) Test set of bases used to define "out of sample" evaluation data. -#' While a test set is optional, the structure of any provided test set must match that -#' of the training set (i.e. if both `X_train` and `leaf_basis_train` are provided, then a test set must +#' @param leaf_basis_test (Optional) Test set of bases used to define "out of sample" evaluation data. +#' While a test set is optional, the structure of any provided test set must match that +#' of the training set (i.e. if both `X_train` and `leaf_basis_train` are provided, then a test set must #' consist of `X_test` and `leaf_basis_test` with the same number of columns). -#' @param rfx_group_ids_test (Optional) Test set group labels used for an additive random effects model. +#' @param rfx_group_ids_test (Optional) Test set group labels used for an additive random effects model. #' We do not currently support (but plan to in the near future), test set evaluation for group labels #' that were not in the training set. #' @param rfx_basis_test (Optional) Test set basis for "random-slope" regression in additive random effects model. @@ -82,7 +82,7 @@ #' - `keep_vars` Vector of variable names or column indices denoting variables that should be included in the forest. Default: `NULL`. #' - `drop_vars` Vector of variable names or column indices denoting variables that should be excluded from the forest. Default: `NULL`. If both `drop_vars` and `keep_vars` are set, `drop_vars` will be ignored. #' - `num_features_subsample` How many features to subsample when growing each tree for the GFR algorithm. Defaults to the number of features in the training dataset. -#' +#' #' @return List of sampling outputs and a wrapper around the sampled forests (which can be used for in-memory prediction on new data, or serialized to JSON on disk). #' @export #' @@ -91,9 +91,9 @@ #' p <- 5 #' X <- matrix(runif(n*p), ncol = p) #' f_XW <- ( -#' ((0 <= X[,1]) & (0.25 > X[,1])) * (-7.5) + -#' ((0.25 <= X[,1]) & (0.5 > X[,1])) * (-2.5) + -#' ((0.5 <= X[,1]) & (0.75 > X[,1])) * (2.5) + +#' ((0 <= X[,1]) & (0.25 > X[,1])) * (-7.5) + +#' ((0.25 <= X[,1]) & (0.5 > X[,1])) * (-2.5) + +#' ((0.5 <= X[,1]) & (0.75 > X[,1])) * (2.5) + #' ((0.75 <= X[,1]) & (1 > X[,1])) * (7.5) #' ) #' noise_sd <- 1 @@ -107,24 +107,43 @@ #' X_train <- X[train_inds,] #' y_test <- y[test_inds] #' y_train <- y[train_inds] -#' -#' bart_model <- bart(X_train = X_train, y_train = y_train, X_test = X_test, +#' +#' bart_model <- bart(X_train = X_train, y_train = y_train, X_test = X_test, #' num_gfr = 10, num_burnin = 0, num_mcmc = 10) -bart <- function(X_train, y_train, leaf_basis_train = NULL, rfx_group_ids_train = NULL, - rfx_basis_train = NULL, X_test = NULL, leaf_basis_test = NULL, - rfx_group_ids_test = NULL, rfx_basis_test = NULL, - num_gfr = 5, num_burnin = 0, num_mcmc = 100, - previous_model_json = NULL, previous_model_warmstart_sample_num = NULL, - general_params = list(), mean_forest_params = list(), - variance_forest_params = list()) { +bart <- function( + X_train, + y_train, + leaf_basis_train = NULL, + rfx_group_ids_train = NULL, + rfx_basis_train = NULL, + X_test = NULL, + leaf_basis_test = NULL, + rfx_group_ids_test = NULL, + rfx_basis_test = NULL, + num_gfr = 5, + num_burnin = 0, + num_mcmc = 100, + previous_model_json = NULL, + previous_model_warmstart_sample_num = NULL, + general_params = list(), + mean_forest_params = list(), + variance_forest_params = list() +) { # Update general BART parameters general_params_default <- list( - cutpoint_grid_size = 100, standardize = TRUE, - sample_sigma2_global = TRUE, sigma2_global_init = NULL, - sigma2_global_shape = 0, sigma2_global_scale = 0, - variable_weights = NULL, random_seed = -1, - keep_burnin = FALSE, keep_gfr = FALSE, keep_every = 1, - num_chains = 1, verbose = FALSE, + cutpoint_grid_size = 100, + standardize = TRUE, + sample_sigma2_global = TRUE, + sigma2_global_init = NULL, + sigma2_global_shape = 0, + sigma2_global_scale = 0, + variable_weights = NULL, + random_seed = -1, + keep_burnin = FALSE, + keep_gfr = FALSE, + keep_every = 1, + num_chains = 1, + verbose = FALSE, probit_outcome_model = FALSE, rfx_working_parameter_prior_mean = NULL, rfx_group_parameter_prior_mean = NULL, @@ -135,37 +154,50 @@ bart <- function(X_train, y_train, leaf_basis_train = NULL, rfx_group_ids_train num_threads = -1 ) general_params_updated <- preprocessParams( - general_params_default, general_params + general_params_default, + general_params ) - + # Update mean forest BART parameters mean_forest_params_default <- list( - num_trees = 200, alpha = 0.95, beta = 2.0, - min_samples_leaf = 5, max_depth = 10, - sample_sigma2_leaf = TRUE, sigma2_leaf_init = NULL, - sigma2_leaf_shape = 3, sigma2_leaf_scale = NULL, - keep_vars = NULL, drop_vars = NULL, + num_trees = 200, + alpha = 0.95, + beta = 2.0, + min_samples_leaf = 5, + max_depth = 10, + sample_sigma2_leaf = TRUE, + sigma2_leaf_init = NULL, + sigma2_leaf_shape = 3, + sigma2_leaf_scale = NULL, + keep_vars = NULL, + drop_vars = NULL, num_features_subsample = NULL ) mean_forest_params_updated <- preprocessParams( - mean_forest_params_default, mean_forest_params + mean_forest_params_default, + mean_forest_params ) - + # Update variance forest BART parameters variance_forest_params_default <- list( - num_trees = 0, alpha = 0.95, beta = 2.0, - min_samples_leaf = 5, max_depth = 10, - leaf_prior_calibration_param = 1.5, + num_trees = 0, + alpha = 0.95, + beta = 2.0, + min_samples_leaf = 5, + max_depth = 10, + leaf_prior_calibration_param = 1.5, var_forest_leaf_init = NULL, - var_forest_prior_shape = NULL, - var_forest_prior_scale = NULL, - keep_vars = NULL, drop_vars = NULL, + var_forest_prior_shape = NULL, + var_forest_prior_scale = NULL, + keep_vars = NULL, + drop_vars = NULL, num_features_subsample = NULL ) variance_forest_params_updated <- preprocessParams( - variance_forest_params_default, variance_forest_params + variance_forest_params_default, + variance_forest_params ) - + ### Unpack all parameter values # 1. General parameters cutpoint_grid_size <- general_params_updated$cutpoint_grid_size @@ -189,7 +221,7 @@ bart <- function(X_train, y_train, leaf_basis_train = NULL, rfx_group_ids_train rfx_variance_prior_shape <- general_params_updated$rfx_variance_prior_shape rfx_variance_prior_scale <- general_params_updated$rfx_variance_prior_scale num_threads <- general_params_updated$num_threads - + # 2. Mean forest parameters num_trees_mean <- mean_forest_params_updated$num_trees alpha_mean <- mean_forest_params_updated$alpha @@ -203,7 +235,7 @@ bart <- function(X_train, y_train, leaf_basis_train = NULL, rfx_group_ids_train keep_vars_mean <- mean_forest_params_updated$keep_vars drop_vars_mean <- mean_forest_params_updated$drop_vars num_features_subsample_mean <- mean_forest_params_updated$num_features_subsample - + # 3. Variance forest parameters num_trees_variance <- variance_forest_params_updated$num_trees alpha_variance <- variance_forest_params_updated$alpha @@ -217,43 +249,60 @@ bart <- function(X_train, y_train, leaf_basis_train = NULL, rfx_group_ids_train keep_vars_variance <- variance_forest_params_updated$keep_vars drop_vars_variance <- variance_forest_params_updated$drop_vars num_features_subsample_variance <- variance_forest_params_updated$num_features_subsample - + # Check if there are enough GFR samples to seed num_chains samplers if (num_gfr > 0) { if (num_chains > num_gfr) { - stop("num_chains > num_gfr, meaning we do not have enough GFR samples to seed num_chains distinct MCMC chains") + stop( + "num_chains > num_gfr, meaning we do not have enough GFR samples to seed num_chains distinct MCMC chains" + ) } } - + # Override keep_gfr if there are no MCMC samples - if (num_mcmc == 0) keep_gfr <- TRUE - + if (num_mcmc == 0) { + keep_gfr <- TRUE + } + # Check if previous model JSON is provided and parse it if so has_prev_model <- !is.null(previous_model_json) if (has_prev_model) { - previous_bart_model <- createBARTModelFromJsonString(previous_model_json) + previous_bart_model <- createBARTModelFromJsonString( + previous_model_json + ) previous_y_bar <- previous_bart_model$model_params$outcome_mean previous_y_scale <- previous_bart_model$model_params$outcome_scale if (previous_bart_model$model_params$include_mean_forest) { previous_forest_samples_mean <- previous_bart_model$mean_forests - } else previous_forest_samples_mean <- NULL + } else { + previous_forest_samples_mean <- NULL + } if (previous_bart_model$model_params$include_variance_forest) { previous_forest_samples_variance <- previous_bart_model$variance_forests - } else previous_forest_samples_variance <- NULL + } else { + previous_forest_samples_variance <- NULL + } if (previous_bart_model$model_params$sample_sigma2_global) { - previous_global_var_samples <- previous_bart_model$sigma2_global_samples / ( - previous_y_scale*previous_y_scale - ) - } else previous_global_var_samples <- NULL + previous_global_var_samples <- previous_bart_model$sigma2_global_samples / + (previous_y_scale * previous_y_scale) + } else { + previous_global_var_samples <- NULL + } if (previous_bart_model$model_params$sample_sigma2_leaf) { previous_leaf_var_samples <- previous_bart_model$sigma2_leaf_samples - } else previous_leaf_var_samples <- NULL + } else { + previous_leaf_var_samples <- NULL + } if (previous_bart_model$model_params$has_rfx) { previous_rfx_samples <- previous_bart_model$rfx_samples - } else previous_rfx_samples <- NULL + } else { + previous_rfx_samples <- NULL + } previous_model_num_samples <- previous_bart_model$model_params$num_samples if (previous_model_warmstart_sample_num > previous_model_num_samples) { - stop("`previous_model_warmstart_sample_num` exceeds the number of samples in `previous_model_json`") + stop( + "`previous_model_warmstart_sample_num` exceeds the number of samples in `previous_model_json`" + ) } } else { previous_y_bar <- NULL @@ -265,28 +314,38 @@ bart <- function(X_train, y_train, leaf_basis_train = NULL, rfx_group_ids_train previous_forest_samples_variance <- NULL previous_model_num_samples <- 0 } - + # Determine whether conditional mean, variance, or both will be modeled - if (num_trees_variance > 0) include_variance_forest = TRUE - else include_variance_forest = FALSE - if (num_trees_mean > 0) include_mean_forest = TRUE - else include_mean_forest = FALSE - + if (num_trees_variance > 0) { + include_variance_forest = TRUE + } else { + include_variance_forest = FALSE + } + if (num_trees_mean > 0) { + include_mean_forest = TRUE + } else { + include_mean_forest = FALSE + } + # Set the variance forest priors if not set if (include_variance_forest) { - if (is.null(a_forest)) a_forest <- num_trees_variance / (a_0^2) + 0.5 + if (is.null(a_forest)) { + a_forest <- num_trees_variance / (a_0^2) + 0.5 + } if (is.null(b_forest)) b_forest <- num_trees_variance / (a_0^2) } else { a_forest <- 1. b_forest <- 1. } - + # Override tau sampling if there is no mean forest - if (!include_mean_forest) sample_sigma2_leaf <- FALSE - + if (!include_mean_forest) { + sample_sigma2_leaf <- FALSE + } + # Variable weight preprocessing (and initialization if necessary) if (is.null(variable_weights)) { - variable_weights = rep(1/ncol(X_train), ncol(X_train)) + variable_weights = rep(1 / ncol(X_train), ncol(X_train)) } if (any(variable_weights < 0)) { stop("variable_weights cannot have any negative weights") @@ -296,7 +355,7 @@ bart <- function(X_train, y_train, leaf_basis_train = NULL, rfx_group_ids_train if ((!is.data.frame(X_train)) && (!is.matrix(X_train))) { stop("X_train must be a matrix or dataframe") } - if (!is.null(X_test)){ + if (!is.null(X_test)) { if ((!is.data.frame(X_test)) && (!is.matrix(X_test))) { stop("X_test must be a matrix or dataframe") } @@ -307,12 +366,18 @@ bart <- function(X_train, y_train, leaf_basis_train = NULL, rfx_group_ids_train if (!is.null(keep_vars_mean)) { if (is.character(keep_vars_mean)) { if (!all(keep_vars_mean %in% names(X_train))) { - stop("keep_vars_mean includes some variable names that are not in X_train") + stop( + "keep_vars_mean includes some variable names that are not in X_train" + ) } - variable_subset_mu <- unname(which(names(X_train) %in% keep_vars_mean)) + variable_subset_mu <- unname(which( + names(X_train) %in% keep_vars_mean + )) } else { if (any(keep_vars_mean > ncol(X_train))) { - stop("keep_vars_mean includes some variable indices that exceed the number of columns in X_train") + stop( + "keep_vars_mean includes some variable indices that exceed the number of columns in X_train" + ) } if (any(keep_vars_mean < 0)) { stop("keep_vars_mean includes some negative variable indices") @@ -322,17 +387,25 @@ bart <- function(X_train, y_train, leaf_basis_train = NULL, rfx_group_ids_train } else if ((is.null(keep_vars_mean)) && (!is.null(drop_vars_mean))) { if (is.character(drop_vars_mean)) { if (!all(drop_vars_mean %in% names(X_train))) { - stop("drop_vars_mean includes some variable names that are not in X_train") + stop( + "drop_vars_mean includes some variable names that are not in X_train" + ) } - variable_subset_mean <- unname(which(!(names(X_train) %in% drop_vars_mean))) + variable_subset_mean <- unname(which( + !(names(X_train) %in% drop_vars_mean) + )) } else { if (any(drop_vars_mean > ncol(X_train))) { - stop("drop_vars_mean includes some variable indices that exceed the number of columns in X_train") + stop( + "drop_vars_mean includes some variable indices that exceed the number of columns in X_train" + ) } if (any(drop_vars_mean < 0)) { stop("drop_vars_mean includes some negative variable indices") } - variable_subset_mean <- (1:ncol(X_train))[!(1:ncol(X_train) %in% drop_vars_mean)] + variable_subset_mean <- (1:ncol(X_train))[ + !(1:ncol(X_train) %in% drop_vars_mean) + ] } } else { variable_subset_mean <- 1:ncol(X_train) @@ -340,42 +413,62 @@ bart <- function(X_train, y_train, leaf_basis_train = NULL, rfx_group_ids_train if (!is.null(keep_vars_variance)) { if (is.character(keep_vars_variance)) { if (!all(keep_vars_variance %in% names(X_train))) { - stop("keep_vars_variance includes some variable names that are not in X_train") + stop( + "keep_vars_variance includes some variable names that are not in X_train" + ) } - variable_subset_variance <- unname(which(names(X_train) %in% keep_vars_variance)) + variable_subset_variance <- unname(which( + names(X_train) %in% keep_vars_variance + )) } else { if (any(keep_vars_variance > ncol(X_train))) { - stop("keep_vars_variance includes some variable indices that exceed the number of columns in X_train") + stop( + "keep_vars_variance includes some variable indices that exceed the number of columns in X_train" + ) } if (any(keep_vars_variance < 0)) { - stop("keep_vars_variance includes some negative variable indices") + stop( + "keep_vars_variance includes some negative variable indices" + ) } variable_subset_variance <- keep_vars_variance } - } else if ((is.null(keep_vars_variance)) && (!is.null(drop_vars_variance))) { + } else if ( + (is.null(keep_vars_variance)) && (!is.null(drop_vars_variance)) + ) { if (is.character(drop_vars_variance)) { if (!all(drop_vars_variance %in% names(X_train))) { - stop("drop_vars_variance includes some variable names that are not in X_train") + stop( + "drop_vars_variance includes some variable names that are not in X_train" + ) } - variable_subset_variance <- unname(which(!(names(X_train) %in% drop_vars_variance))) + variable_subset_variance <- unname(which( + !(names(X_train) %in% drop_vars_variance) + )) } else { if (any(drop_vars_variance > ncol(X_train))) { - stop("drop_vars_variance includes some variable indices that exceed the number of columns in X_train") + stop( + "drop_vars_variance includes some variable indices that exceed the number of columns in X_train" + ) } if (any(drop_vars_variance < 0)) { - stop("drop_vars_variance includes some negative variable indices") + stop( + "drop_vars_variance includes some negative variable indices" + ) } - variable_subset_variance <- (1:ncol(X_train))[!(1:ncol(X_train) %in% drop_vars_variance)] + variable_subset_variance <- (1:ncol(X_train))[ + !(1:ncol(X_train) %in% drop_vars_variance) + ] } } else { variable_subset_variance <- 1:ncol(X_train) } - + # Preprocess covariates if ((!is.data.frame(X_train)) && (!is.matrix(X_train))) { stop("X_train must be a matrix or dataframe") } - if (!is.null(X_test)){ + if (!is.null(X_test)) { if ((!is.data.frame(X_test)) && (!is.matrix(X_test))) { stop("X_test must be a matrix or dataframe") } @@ -388,20 +481,31 @@ bart <- function(X_train, y_train, leaf_basis_train = NULL, rfx_group_ids_train X_train <- train_cov_preprocess_list$data original_var_indices <- X_train_metadata$original_var_indices feature_types <- X_train_metadata$feature_types - if (!is.null(X_test)) X_test <- preprocessPredictionData(X_test, X_train_metadata) - + if (!is.null(X_test)) { + X_test <- preprocessPredictionData(X_test, X_train_metadata) + } + # Update variable weights variable_weights_mean <- variable_weights_variance <- variable_weights - variable_weights_adj <- 1/sapply(original_var_indices, function(x) sum(original_var_indices == x)) + variable_weights_adj <- 1 / + sapply(original_var_indices, function(x) sum(original_var_indices == x)) if (include_mean_forest) { - variable_weights_mean <- variable_weights_mean[original_var_indices]*variable_weights_adj - variable_weights_mean[!(original_var_indices %in% variable_subset_mean)] <- 0 + variable_weights_mean <- variable_weights_mean[original_var_indices] * + variable_weights_adj + variable_weights_mean[ + !(original_var_indices %in% variable_subset_mean) + ] <- 0 } if (include_variance_forest) { - variable_weights_variance <- variable_weights_variance[original_var_indices]*variable_weights_adj - variable_weights_variance[!(original_var_indices %in% variable_subset_variance)] <- 0 + variable_weights_variance <- variable_weights_variance[ + original_var_indices + ] * + variable_weights_adj + variable_weights_variance[ + !(original_var_indices %in% variable_subset_variance) + ] <- 0 } - + # Set num_features_subsample to default, ncol(X_train), if not already set if (is.null(num_features_subsample_mean)) { num_features_subsample_mean <- ncol(X_train) @@ -423,7 +527,7 @@ bart <- function(X_train, y_train, leaf_basis_train = NULL, rfx_group_ids_train if ((is.null(dim(rfx_basis_test))) && (!is.null(rfx_basis_test))) { rfx_basis_test <- as.matrix(rfx_basis_test) } - + # Recode group IDs to integer vector (if passed as, for example, a vector of county names, etc...) has_rfx <- FALSE has_rfx_test <- FALSE @@ -432,62 +536,98 @@ bart <- function(X_train, y_train, leaf_basis_train = NULL, rfx_group_ids_train rfx_group_ids_train <- as.integer(group_ids_factor) has_rfx <- TRUE if (!is.null(rfx_group_ids_test)) { - group_ids_factor_test <- factor(rfx_group_ids_test, levels = levels(group_ids_factor)) + group_ids_factor_test <- factor( + rfx_group_ids_test, + levels = levels(group_ids_factor) + ) if (sum(is.na(group_ids_factor_test)) > 0) { - stop("All random effect group labels provided in rfx_group_ids_test must be present in rfx_group_ids_train") + stop( + "All random effect group labels provided in rfx_group_ids_test must be present in rfx_group_ids_train" + ) } rfx_group_ids_test <- as.integer(group_ids_factor_test) has_rfx_test <- TRUE } } - + # Data consistency checks if ((!is.null(X_test)) && (ncol(X_test) != ncol(X_train))) { stop("X_train and X_test must have the same number of columns") } - if ((!is.null(leaf_basis_test)) && (ncol(leaf_basis_test) != ncol(leaf_basis_train))) { - stop("leaf_basis_train and leaf_basis_test must have the same number of columns") - } - if ((!is.null(leaf_basis_train)) && (nrow(leaf_basis_train) != nrow(X_train))) { + if ( + (!is.null(leaf_basis_test)) && + (ncol(leaf_basis_test) != ncol(leaf_basis_train)) + ) { + stop( + "leaf_basis_train and leaf_basis_test must have the same number of columns" + ) + } + if ( + (!is.null(leaf_basis_train)) && + (nrow(leaf_basis_train) != nrow(X_train)) + ) { stop("leaf_basis_train and X_train must have the same number of rows") } - if ((!is.null(leaf_basis_test)) && (nrow(leaf_basis_test) != nrow(X_test))) { + if ( + (!is.null(leaf_basis_test)) && (nrow(leaf_basis_test) != nrow(X_test)) + ) { stop("leaf_basis_test and X_test must have the same number of rows") } if (nrow(X_train) != length(y_train)) { stop("X_train and y_train must have the same number of observations") } - if ((!is.null(rfx_basis_test)) && (ncol(rfx_basis_test) != ncol(rfx_basis_train))) { - stop("rfx_basis_train and rfx_basis_test must have the same number of columns") + if ( + (!is.null(rfx_basis_test)) && + (ncol(rfx_basis_test) != ncol(rfx_basis_train)) + ) { + stop( + "rfx_basis_train and rfx_basis_test must have the same number of columns" + ) } if (!is.null(rfx_group_ids_train)) { if (!is.null(rfx_group_ids_test)) { if ((!is.null(rfx_basis_train)) && (is.null(rfx_basis_test))) { - stop("rfx_basis_train is provided but rfx_basis_test is not provided") + stop( + "rfx_basis_train is provided but rfx_basis_test is not provided" + ) } } } - - # Fill in rfx basis as a vector of 1s (random intercept) if a basis not provided + + # Fill in rfx basis as a vector of 1s (random intercept) if a basis not provided has_basis_rfx <- FALSE num_basis_rfx <- 0 if (has_rfx) { if (is.null(rfx_basis_train)) { - rfx_basis_train <- matrix(rep(1,nrow(X_train)), nrow = nrow(X_train), ncol = 1) + rfx_basis_train <- matrix( + rep(1, nrow(X_train)), + nrow = nrow(X_train), + ncol = 1 + ) } else { has_basis_rfx <- TRUE num_basis_rfx <- ncol(rfx_basis_train) } num_rfx_groups <- length(unique(rfx_group_ids_train)) num_rfx_components <- ncol(rfx_basis_train) - if (num_rfx_groups == 1) warning("Only one group was provided for random effect sampling, so the 'redundant parameterization' is likely overkill") + if (num_rfx_groups == 1) { + warning( + "Only one group was provided for random effect sampling, so the 'redundant parameterization' is likely overkill" + ) + } } if (has_rfx_test) { if (is.null(rfx_basis_test)) { if (has_basis_rfx) { - stop("Random effects basis provided for training set, must also be provided for the test set") + stop( + "Random effects basis provided for training set, must also be provided for the test set" + ) } - rfx_basis_test <- matrix(rep(1,nrow(X_test)), nrow = nrow(X_test), ncol = 1) + rfx_basis_test <- matrix( + rep(1, nrow(X_test)), + nrow = nrow(X_test), + ncol = 1 + ) } } @@ -495,30 +635,36 @@ bart <- function(X_train, y_train, leaf_basis_train = NULL, rfx_group_ids_train if (!is.null(dim(y_train))) { y_train <- as.matrix(y_train) } - + # Determine whether a basis vector is provided has_basis = !is.null(leaf_basis_train) - + # Determine whether a test set is provided has_test = !is.null(X_test) - + # Preliminary runtime checks for probit link if (!include_mean_forest) { probit_outcome_model <- FALSE } if (probit_outcome_model) { if (!(length(unique(y_train)) == 2)) { - stop("You specified a probit outcome model, but supplied an outcome with more than 2 unique values") + stop( + "You specified a probit outcome model, but supplied an outcome with more than 2 unique values" + ) } unique_outcomes <- sort(unique(y_train)) - if (!(all(unique_outcomes == c(0,1)))) { - stop("You specified a probit outcome model, but supplied an outcome with 2 unique values other than 0 and 1") + if (!(all(unique_outcomes == c(0, 1)))) { + stop( + "You specified a probit outcome model, but supplied an outcome with 2 unique values other than 0 and 1" + ) } if (include_variance_forest) { stop("We do not support heteroskedasticity with a probit link") } if (sample_sigma2_global) { - warning("Global error variance will not be sampled with a probit link as it is fixed at 1") + warning( + "Global error variance will not be sampled with a probit link as it is fixed at 1" + ) sample_sigma2_global <- F } } @@ -532,25 +678,35 @@ bart <- function(X_train, y_train, leaf_basis_train = NULL, rfx_group_ids_train # Set a pseudo outcome by subtracting mean(y_train) from y_train resid_train <- y_train - mean(y_train) - + # Set initial values of root nodes to 0.0 (in probit scale) init_val_mean <- 0.0 - + # Calibrate priors for sigma^2 and tau # Set sigma2_init to 1, ignoring default provided sigma2_init <- 1.0 # Skip variance_forest_init, since variance forests are not supported with probit link - b_leaf <- 1/(num_trees_mean) + b_leaf <- 1 / (num_trees_mean) if (has_basis) { if (ncol(leaf_basis_train) > 1) { - if (is.null(sigma2_leaf_init)) sigma2_leaf_init <- diag(2/(num_trees_mean), ncol(leaf_basis_train)) + if (is.null(sigma2_leaf_init)) { + sigma2_leaf_init <- diag( + 2 / (num_trees_mean), + ncol(leaf_basis_train) + ) + } if (!is.matrix(sigma2_leaf_init)) { - current_leaf_scale <- as.matrix(diag(sigma2_leaf_init, ncol(leaf_basis_train))) + current_leaf_scale <- as.matrix(diag( + sigma2_leaf_init, + ncol(leaf_basis_train) + )) } else { current_leaf_scale <- sigma2_leaf_init } } else { - if (is.null(sigma2_leaf_init)) sigma2_leaf_init <- as.matrix(2/(num_trees_mean)) + if (is.null(sigma2_leaf_init)) { + sigma2_leaf_init <- as.matrix(2 / (num_trees_mean)) + } if (!is.matrix(sigma2_leaf_init)) { current_leaf_scale <- as.matrix(diag(sigma2_leaf_init, 1)) } else { @@ -558,7 +714,9 @@ bart <- function(X_train, y_train, leaf_basis_train = NULL, rfx_group_ids_train } } } else { - if (is.null(sigma2_leaf_init)) sigma2_leaf_init <- as.matrix(2/(num_trees_mean)) + if (is.null(sigma2_leaf_init)) { + sigma2_leaf_init <- as.matrix(2 / (num_trees_mean)) + } if (!is.matrix(sigma2_leaf_init)) { current_leaf_scale <- as.matrix(diag(sigma2_leaf_init, 1)) } else { @@ -575,27 +733,45 @@ bart <- function(X_train, y_train, leaf_basis_train = NULL, rfx_group_ids_train y_bar_train <- 0 y_std_train <- 1 } - + # Compute standardized outcome - resid_train <- (y_train-y_bar_train)/y_std_train - + resid_train <- (y_train - y_bar_train) / y_std_train + # Compute initial value of root nodes in mean forest init_val_mean <- mean(resid_train) - + # Calibrate priors for sigma^2 and tau - if (is.null(sigma2_init)) sigma2_init <- 1.0*var(resid_train) - if (is.null(variance_forest_init)) variance_forest_init <- 1.0*var(resid_train) - if (is.null(b_leaf)) b_leaf <- var(resid_train)/(2*num_trees_mean) + if (is.null(sigma2_init)) { + sigma2_init <- 1.0 * var(resid_train) + } + if (is.null(variance_forest_init)) { + variance_forest_init <- 1.0 * var(resid_train) + } + if (is.null(b_leaf)) { + b_leaf <- var(resid_train) / (2 * num_trees_mean) + } if (has_basis) { if (ncol(leaf_basis_train) > 1) { - if (is.null(sigma2_leaf_init)) sigma2_leaf_init <- diag(2*var(resid_train)/(num_trees_mean), ncol(leaf_basis_train)) + if (is.null(sigma2_leaf_init)) { + sigma2_leaf_init <- diag( + 2 * var(resid_train) / (num_trees_mean), + ncol(leaf_basis_train) + ) + } if (!is.matrix(sigma2_leaf_init)) { - current_leaf_scale <- as.matrix(diag(sigma2_leaf_init, ncol(leaf_basis_train))) + current_leaf_scale <- as.matrix(diag( + sigma2_leaf_init, + ncol(leaf_basis_train) + )) } else { current_leaf_scale <- sigma2_leaf_init } } else { - if (is.null(sigma2_leaf_init)) sigma2_leaf_init <- as.matrix(2*var(resid_train)/(num_trees_mean)) + if (is.null(sigma2_leaf_init)) { + sigma2_leaf_init <- as.matrix( + 2 * var(resid_train) / (num_trees_mean) + ) + } if (!is.matrix(sigma2_leaf_init)) { current_leaf_scale <- as.matrix(diag(sigma2_leaf_init, 1)) } else { @@ -603,7 +779,11 @@ bart <- function(X_train, y_train, leaf_basis_train = NULL, rfx_group_ids_train } } } else { - if (is.null(sigma2_leaf_init)) sigma2_leaf_init <- as.matrix(2*var(resid_train)/(num_trees_mean)) + if (is.null(sigma2_leaf_init)) { + sigma2_leaf_init <- as.matrix( + 2 * var(resid_train) / (num_trees_mean) + ) + } if (!is.matrix(sigma2_leaf_init)) { current_leaf_scale <- as.matrix(diag(sigma2_leaf_init, 1)) } else { @@ -612,16 +792,21 @@ bart <- function(X_train, y_train, leaf_basis_train = NULL, rfx_group_ids_train } current_sigma2 <- sigma2_init } - + # Determine leaf model type - if (!has_basis) leaf_model_mean_forest <- 0 - else if (ncol(leaf_basis_train) == 1) leaf_model_mean_forest <- 1 - else if (ncol(leaf_basis_train) > 1) leaf_model_mean_forest <- 2 - else stop("leaf_basis_train passed must be a matrix with at least 1 column") + if (!has_basis) { + leaf_model_mean_forest <- 0 + } else if (ncol(leaf_basis_train) == 1) { + leaf_model_mean_forest <- 1 + } else if (ncol(leaf_basis_train) > 1) { + leaf_model_mean_forest <- 2 + } else { + stop("leaf_basis_train passed must be a matrix with at least 1 column") + } # Set variance leaf model type (currently only one option) leaf_model_variance_forest <- 3 - + # Unpack model type info if (leaf_model_mean_forest == 0) { leaf_dimension = 1 @@ -640,104 +825,190 @@ bart <- function(X_train, y_train, leaf_basis_train = NULL, rfx_group_ids_train is_leaf_constant = FALSE leaf_regression = TRUE if (sample_sigma2_leaf) { - warning("Sampling leaf scale not yet supported for multivariate leaf models, so the leaf scale parameter will not be sampled in this model.") + warning( + "Sampling leaf scale not yet supported for multivariate leaf models, so the leaf scale parameter will not be sampled in this model." + ) sample_sigma2_leaf <- FALSE } } - + # Data if (leaf_regression) { forest_dataset_train <- createForestDataset(X_train, leaf_basis_train) - if (has_test) forest_dataset_test <- createForestDataset(X_test, leaf_basis_test) + if (has_test) { + forest_dataset_test <- createForestDataset(X_test, leaf_basis_test) + } requires_basis <- TRUE } else { forest_dataset_train <- createForestDataset(X_train) - if (has_test) forest_dataset_test <- createForestDataset(X_test) + if (has_test) { + forest_dataset_test <- createForestDataset(X_test) + } requires_basis <- FALSE } outcome_train <- createOutcome(resid_train) - + # Random number generator (std::mt19937) - if (is.null(random_seed)) random_seed = sample(1:10000,1,FALSE) + if (is.null(random_seed)) { + random_seed = sample(1:10000, 1, FALSE) + } rng <- createCppRNG(random_seed) - + # Sampling data structures feature_types <- as.integer(feature_types) - global_model_config <- createGlobalModelConfig(global_error_variance=current_sigma2) + global_model_config <- createGlobalModelConfig( + global_error_variance = current_sigma2 + ) if (include_mean_forest) { - forest_model_config_mean <- createForestModelConfig(feature_types=feature_types, num_trees=num_trees_mean, num_features=ncol(X_train), - num_observations=nrow(X_train), variable_weights=variable_weights_mean, leaf_dimension=leaf_dimension, - alpha=alpha_mean, beta=beta_mean, min_samples_leaf=min_samples_leaf_mean, max_depth=max_depth_mean, - leaf_model_type=leaf_model_mean_forest, leaf_model_scale=current_leaf_scale, - cutpoint_grid_size=cutpoint_grid_size, num_features_subsample=num_features_subsample_mean) - forest_model_mean <- createForestModel(forest_dataset_train, forest_model_config_mean, global_model_config) + forest_model_config_mean <- createForestModelConfig( + feature_types = feature_types, + num_trees = num_trees_mean, + num_features = ncol(X_train), + num_observations = nrow(X_train), + variable_weights = variable_weights_mean, + leaf_dimension = leaf_dimension, + alpha = alpha_mean, + beta = beta_mean, + min_samples_leaf = min_samples_leaf_mean, + max_depth = max_depth_mean, + leaf_model_type = leaf_model_mean_forest, + leaf_model_scale = current_leaf_scale, + cutpoint_grid_size = cutpoint_grid_size, + num_features_subsample = num_features_subsample_mean + ) + forest_model_mean <- createForestModel( + forest_dataset_train, + forest_model_config_mean, + global_model_config + ) } if (include_variance_forest) { - forest_model_config_variance <- createForestModelConfig(feature_types=feature_types, num_trees=num_trees_variance, num_features=ncol(X_train), - num_observations=nrow(X_train), variable_weights=variable_weights_variance, leaf_dimension=1, - alpha=alpha_variance, beta=beta_variance, min_samples_leaf=min_samples_leaf_variance, - max_depth=max_depth_variance, leaf_model_type=leaf_model_variance_forest, - cutpoint_grid_size=cutpoint_grid_size, num_features_subsample=num_features_subsample_variance) - forest_model_variance <- createForestModel(forest_dataset_train, forest_model_config_variance, global_model_config) - } - + forest_model_config_variance <- createForestModelConfig( + feature_types = feature_types, + num_trees = num_trees_variance, + num_features = ncol(X_train), + num_observations = nrow(X_train), + variable_weights = variable_weights_variance, + leaf_dimension = 1, + alpha = alpha_variance, + beta = beta_variance, + min_samples_leaf = min_samples_leaf_variance, + max_depth = max_depth_variance, + leaf_model_type = leaf_model_variance_forest, + cutpoint_grid_size = cutpoint_grid_size, + num_features_subsample = num_features_subsample_variance + ) + forest_model_variance <- createForestModel( + forest_dataset_train, + forest_model_config_variance, + global_model_config + ) + } + # Container of forest samples if (include_mean_forest) { - forest_samples_mean <- createForestSamples(num_trees_mean, leaf_dimension, is_leaf_constant, FALSE) - active_forest_mean <- createForest(num_trees_mean, leaf_dimension, is_leaf_constant, FALSE) + forest_samples_mean <- createForestSamples( + num_trees_mean, + leaf_dimension, + is_leaf_constant, + FALSE + ) + active_forest_mean <- createForest( + num_trees_mean, + leaf_dimension, + is_leaf_constant, + FALSE + ) } if (include_variance_forest) { - forest_samples_variance <- createForestSamples(num_trees_variance, 1, TRUE, TRUE) - active_forest_variance <- createForest(num_trees_variance, 1, TRUE, TRUE) + forest_samples_variance <- createForestSamples( + num_trees_variance, + 1, + TRUE, + TRUE + ) + active_forest_variance <- createForest( + num_trees_variance, + 1, + TRUE, + TRUE + ) } - - # Random effects initialization + + # Random effects initialization if (has_rfx) { # Prior parameters if (is.null(rfx_working_parameter_prior_mean)) { if (num_rfx_components == 1) { alpha_init <- c(0) } else if (num_rfx_components > 1) { - alpha_init <- rep(0,num_rfx_components) + alpha_init <- rep(0, num_rfx_components) } else { stop("There must be at least 1 random effect component") } } else { - alpha_init <- expand_dims_1d(rfx_working_parameter_prior_mean, num_rfx_components) + alpha_init <- expand_dims_1d( + rfx_working_parameter_prior_mean, + num_rfx_components + ) } - + if (is.null(rfx_group_parameter_prior_mean)) { - xi_init <- matrix(rep(alpha_init, num_rfx_groups),num_rfx_components,num_rfx_groups) + xi_init <- matrix( + rep(alpha_init, num_rfx_groups), + num_rfx_components, + num_rfx_groups + ) } else { - xi_init <- expand_dims_2d(rfx_group_parameter_prior_mean, num_rfx_components, num_rfx_groups) + xi_init <- expand_dims_2d( + rfx_group_parameter_prior_mean, + num_rfx_components, + num_rfx_groups + ) } - + if (is.null(rfx_working_parameter_prior_cov)) { - sigma_alpha_init <- diag(1,num_rfx_components,num_rfx_components) + sigma_alpha_init <- diag(1, num_rfx_components, num_rfx_components) } else { - sigma_alpha_init <- expand_dims_2d_diag(rfx_working_parameter_prior_cov, num_rfx_components) + sigma_alpha_init <- expand_dims_2d_diag( + rfx_working_parameter_prior_cov, + num_rfx_components + ) } - + if (is.null(rfx_group_parameter_prior_cov)) { - sigma_xi_init <- diag(1,num_rfx_components,num_rfx_components) + sigma_xi_init <- diag(1, num_rfx_components, num_rfx_components) } else { - sigma_xi_init <- expand_dims_2d_diag(rfx_group_parameter_prior_cov, num_rfx_components) + sigma_xi_init <- expand_dims_2d_diag( + rfx_group_parameter_prior_cov, + num_rfx_components + ) } - + sigma_xi_shape <- rfx_variance_prior_shape sigma_xi_scale <- rfx_variance_prior_scale - + # Random effects data structure and storage container - rfx_dataset_train <- createRandomEffectsDataset(rfx_group_ids_train, rfx_basis_train) + rfx_dataset_train <- createRandomEffectsDataset( + rfx_group_ids_train, + rfx_basis_train + ) rfx_tracker_train <- createRandomEffectsTracker(rfx_group_ids_train) - rfx_model <- createRandomEffectsModel(num_rfx_components, num_rfx_groups) + rfx_model <- createRandomEffectsModel( + num_rfx_components, + num_rfx_groups + ) rfx_model$set_working_parameter(alpha_init) rfx_model$set_group_parameters(xi_init) rfx_model$set_working_parameter_cov(sigma_alpha_init) rfx_model$set_group_parameter_cov(sigma_xi_init) rfx_model$set_variance_prior_shape(sigma_xi_shape) rfx_model$set_variance_prior_scale(sigma_xi_scale) - rfx_samples <- createRandomEffectSamples(num_rfx_components, num_rfx_groups, rfx_tracker_train) + rfx_samples <- createRandomEffectSamples( + num_rfx_components, + num_rfx_groups, + rfx_tracker_train + ) } # Container of variance parameter samples @@ -745,97 +1016,187 @@ bart <- function(X_train, y_train, leaf_basis_train = NULL, rfx_group_ids_train num_samples <- num_gfr + num_burnin + num_actual_mcmc_iter # Delete GFR samples from these containers after the fact if desired # num_retained_samples <- ifelse(keep_gfr, num_gfr, 0) + ifelse(keep_burnin, num_burnin, 0) + num_mcmc - num_retained_samples <- num_gfr + ifelse(keep_burnin, num_burnin, 0) + num_mcmc * num_chains - if (sample_sigma2_global) global_var_samples <- rep(NA, num_retained_samples) - if (sample_sigma2_leaf) leaf_scale_samples <- rep(NA, num_retained_samples) - if (include_mean_forest) mean_forest_pred_train <- matrix(NA_real_, nrow(X_train), num_retained_samples) - if (include_variance_forest) variance_forest_pred_train <- matrix(NA_real_, nrow(X_train), num_retained_samples) + num_retained_samples <- num_gfr + + ifelse(keep_burnin, num_burnin, 0) + + num_mcmc * num_chains + if (sample_sigma2_global) { + global_var_samples <- rep(NA, num_retained_samples) + } + if (sample_sigma2_leaf) { + leaf_scale_samples <- rep(NA, num_retained_samples) + } + if (include_mean_forest) { + mean_forest_pred_train <- matrix( + NA_real_, + nrow(X_train), + num_retained_samples + ) + } + if (include_variance_forest) { + variance_forest_pred_train <- matrix( + NA_real_, + nrow(X_train), + num_retained_samples + ) + } sample_counter <- 0 - + # Initialize the leaves of each tree in the mean forest if (include_mean_forest) { - if (requires_basis) init_values_mean_forest <- rep(0., ncol(leaf_basis_train)) - else init_values_mean_forest <- 0. - active_forest_mean$prepare_for_sampler(forest_dataset_train, outcome_train, forest_model_mean, leaf_model_mean_forest, init_values_mean_forest) - active_forest_mean$adjust_residual(forest_dataset_train, outcome_train, forest_model_mean, requires_basis, FALSE) + if (requires_basis) { + init_values_mean_forest <- rep(0., ncol(leaf_basis_train)) + } else { + init_values_mean_forest <- 0. + } + active_forest_mean$prepare_for_sampler( + forest_dataset_train, + outcome_train, + forest_model_mean, + leaf_model_mean_forest, + init_values_mean_forest + ) + active_forest_mean$adjust_residual( + forest_dataset_train, + outcome_train, + forest_model_mean, + requires_basis, + FALSE + ) } # Initialize the leaves of each tree in the variance forest if (include_variance_forest) { - active_forest_variance$prepare_for_sampler(forest_dataset_train, outcome_train, forest_model_variance, leaf_model_variance_forest, variance_forest_init) + active_forest_variance$prepare_for_sampler( + forest_dataset_train, + outcome_train, + forest_model_variance, + leaf_model_variance_forest, + variance_forest_init + ) } - + # Run GFR (warm start) if specified - if (num_gfr > 0){ + if (num_gfr > 0) { for (i in 1:num_gfr) { # Keep all GFR samples at this stage -- remove from ForestSamples after MCMC # keep_sample <- ifelse(keep_gfr, TRUE, FALSE) keep_sample <- TRUE - if (keep_sample) sample_counter <- sample_counter + 1 + if (keep_sample) { + sample_counter <- sample_counter + 1 + } # Print progress if (verbose) { if ((i %% 10 == 0) || (i == num_gfr)) { - cat("Sampling", i, "out of", num_gfr, "XBART (grow-from-root) draws\n") + cat( + "Sampling", + i, + "out of", + num_gfr, + "XBART (grow-from-root) draws\n" + ) } } - + if (include_mean_forest) { if (probit_outcome_model) { # Sample latent probit variable, z | - - forest_pred <- active_forest_mean$predict(forest_dataset_train) + forest_pred <- active_forest_mean$predict( + forest_dataset_train + ) mu0 <- forest_pred[y_train == 0] mu1 <- forest_pred[y_train == 1] u0 <- runif(sum(y_train == 0), 0, pnorm(0 - mu0)) u1 <- runif(sum(y_train == 1), pnorm(0 - mu1), 1) - resid_train[y_train==0] <- mu0 + qnorm(u0) - resid_train[y_train==1] <- mu1 + qnorm(u1) + resid_train[y_train == 0] <- mu0 + qnorm(u0) + resid_train[y_train == 1] <- mu1 + qnorm(u1) # Update outcome outcome_train$update_data(resid_train - forest_pred) } - + # Sample mean forest forest_model_mean$sample_one_iteration( - forest_dataset = forest_dataset_train, residual = outcome_train, forest_samples = forest_samples_mean, - active_forest = active_forest_mean, rng = rng, forest_model_config = forest_model_config_mean, - global_model_config = global_model_config, num_threads = num_threads, - keep_forest = keep_sample, gfr = TRUE + forest_dataset = forest_dataset_train, + residual = outcome_train, + forest_samples = forest_samples_mean, + active_forest = active_forest_mean, + rng = rng, + forest_model_config = forest_model_config_mean, + global_model_config = global_model_config, + num_threads = num_threads, + keep_forest = keep_sample, + gfr = TRUE ) - + # Cache train set predictions since they are already computed during sampling if (keep_sample) { - mean_forest_pred_train[,sample_counter] <- forest_model_mean$get_cached_forest_predictions() + mean_forest_pred_train[, + sample_counter + ] <- forest_model_mean$get_cached_forest_predictions() } } if (include_variance_forest) { forest_model_variance$sample_one_iteration( - forest_dataset = forest_dataset_train, residual = outcome_train, forest_samples = forest_samples_variance, - active_forest = active_forest_variance, rng = rng, forest_model_config = forest_model_config_variance, - global_model_config = global_model_config, keep_forest = keep_sample, gfr = TRUE + forest_dataset = forest_dataset_train, + residual = outcome_train, + forest_samples = forest_samples_variance, + active_forest = active_forest_variance, + rng = rng, + forest_model_config = forest_model_config_variance, + global_model_config = global_model_config, + keep_forest = keep_sample, + gfr = TRUE ) - + # Cache train set predictions since they are already computed during sampling if (keep_sample) { - variance_forest_pred_train[,sample_counter] <- forest_model_variance$get_cached_forest_predictions() + variance_forest_pred_train[, + sample_counter + ] <- forest_model_variance$get_cached_forest_predictions() } } if (sample_sigma2_global) { - current_sigma2 <- sampleGlobalErrorVarianceOneIteration(outcome_train, forest_dataset_train, rng, a_global, b_global) - if (keep_sample) global_var_samples[sample_counter] <- current_sigma2 + current_sigma2 <- sampleGlobalErrorVarianceOneIteration( + outcome_train, + forest_dataset_train, + rng, + a_global, + b_global + ) + if (keep_sample) { + global_var_samples[sample_counter] <- current_sigma2 + } global_model_config$update_global_error_variance(current_sigma2) } if (sample_sigma2_leaf) { - leaf_scale_double <- sampleLeafVarianceOneIteration(active_forest_mean, rng, a_leaf, b_leaf) + leaf_scale_double <- sampleLeafVarianceOneIteration( + active_forest_mean, + rng, + a_leaf, + b_leaf + ) current_leaf_scale <- as.matrix(leaf_scale_double) - if (keep_sample) leaf_scale_samples[sample_counter] <- leaf_scale_double - forest_model_config_mean$update_leaf_model_scale(current_leaf_scale) + if (keep_sample) { + leaf_scale_samples[sample_counter] <- leaf_scale_double + } + forest_model_config_mean$update_leaf_model_scale( + current_leaf_scale + ) } if (has_rfx) { - rfx_model$sample_random_effect(rfx_dataset_train, outcome_train, rfx_tracker_train, rfx_samples, keep_sample, current_sigma2, rng) + rfx_model$sample_random_effect( + rfx_dataset_train, + outcome_train, + rfx_tracker_train, + rfx_samples, + keep_sample, + current_sigma2, + rng + ) } } } - + # Run MCMC if (num_burnin + num_mcmc > 0) { for (chain_num in 1:num_chains) { @@ -843,163 +1204,361 @@ bart <- function(X_train, y_train, leaf_basis_train = NULL, rfx_group_ids_train # Reset state of active_forest and forest_model based on a previous GFR sample forest_ind <- num_gfr - chain_num if (include_mean_forest) { - resetActiveForest(active_forest_mean, forest_samples_mean, forest_ind) - resetForestModel(forest_model_mean, active_forest_mean, forest_dataset_train, outcome_train, TRUE) + resetActiveForest( + active_forest_mean, + forest_samples_mean, + forest_ind + ) + resetForestModel( + forest_model_mean, + active_forest_mean, + forest_dataset_train, + outcome_train, + TRUE + ) if (sample_sigma2_leaf) { leaf_scale_double <- leaf_scale_samples[forest_ind + 1] current_leaf_scale <- as.matrix(leaf_scale_double) - forest_model_config_mean$update_leaf_model_scale(current_leaf_scale) + forest_model_config_mean$update_leaf_model_scale( + current_leaf_scale + ) } } if (include_variance_forest) { - resetActiveForest(active_forest_variance, forest_samples_variance, forest_ind) - resetForestModel(forest_model_variance, active_forest_variance, forest_dataset_train, outcome_train, FALSE) + resetActiveForest( + active_forest_variance, + forest_samples_variance, + forest_ind + ) + resetForestModel( + forest_model_variance, + active_forest_variance, + forest_dataset_train, + outcome_train, + FALSE + ) } if (has_rfx) { - resetRandomEffectsModel(rfx_model, rfx_samples, forest_ind, sigma_alpha_init) - resetRandomEffectsTracker(rfx_tracker_train, rfx_model, rfx_dataset_train, outcome_train, rfx_samples) + resetRandomEffectsModel( + rfx_model, + rfx_samples, + forest_ind, + sigma_alpha_init + ) + resetRandomEffectsTracker( + rfx_tracker_train, + rfx_model, + rfx_dataset_train, + outcome_train, + rfx_samples + ) } if (sample_sigma2_global) { current_sigma2 <- global_var_samples[forest_ind + 1] - global_model_config$update_global_error_variance(current_sigma2) + global_model_config$update_global_error_variance( + current_sigma2 + ) } } else if (has_prev_model) { if (include_mean_forest) { - resetActiveForest(active_forest_mean, previous_forest_samples_mean, previous_model_warmstart_sample_num - 1) - resetForestModel(forest_model_mean, active_forest_mean, forest_dataset_train, outcome_train, TRUE) - if (sample_sigma2_leaf && (!is.null(previous_leaf_var_samples))) { - leaf_scale_double <- previous_leaf_var_samples[previous_model_warmstart_sample_num] + resetActiveForest( + active_forest_mean, + previous_forest_samples_mean, + previous_model_warmstart_sample_num - 1 + ) + resetForestModel( + forest_model_mean, + active_forest_mean, + forest_dataset_train, + outcome_train, + TRUE + ) + if ( + sample_sigma2_leaf && + (!is.null(previous_leaf_var_samples)) + ) { + leaf_scale_double <- previous_leaf_var_samples[ + previous_model_warmstart_sample_num + ] current_leaf_scale <- as.matrix(leaf_scale_double) - forest_model_config_mean$update_leaf_model_scale(current_leaf_scale) + forest_model_config_mean$update_leaf_model_scale( + current_leaf_scale + ) } } if (include_variance_forest) { - resetActiveForest(active_forest_variance, previous_forest_samples_variance, previous_model_warmstart_sample_num - 1) - resetForestModel(forest_model_variance, active_forest_variance, forest_dataset_train, outcome_train, FALSE) + resetActiveForest( + active_forest_variance, + previous_forest_samples_variance, + previous_model_warmstart_sample_num - 1 + ) + resetForestModel( + forest_model_variance, + active_forest_variance, + forest_dataset_train, + outcome_train, + FALSE + ) } if (has_rfx) { if (is.null(previous_rfx_samples)) { - warning("`previous_model_json` did not have any random effects samples, so the RFX sampler will be run from scratch while the forests and any other parameters are warm started") - rootResetRandomEffectsModel(rfx_model, alpha_init, xi_init, sigma_alpha_init, - sigma_xi_init, sigma_xi_shape, sigma_xi_scale) - rootResetRandomEffectsTracker(rfx_tracker_train, rfx_model, rfx_dataset_train, outcome_train) + warning( + "`previous_model_json` did not have any random effects samples, so the RFX sampler will be run from scratch while the forests and any other parameters are warm started" + ) + rootResetRandomEffectsModel( + rfx_model, + alpha_init, + xi_init, + sigma_alpha_init, + sigma_xi_init, + sigma_xi_shape, + sigma_xi_scale + ) + rootResetRandomEffectsTracker( + rfx_tracker_train, + rfx_model, + rfx_dataset_train, + outcome_train + ) } else { - resetRandomEffectsModel(rfx_model, previous_rfx_samples, previous_model_warmstart_sample_num - 1, sigma_alpha_init) - resetRandomEffectsTracker(rfx_tracker_train, rfx_model, rfx_dataset_train, outcome_train, rfx_samples) + resetRandomEffectsModel( + rfx_model, + previous_rfx_samples, + previous_model_warmstart_sample_num - 1, + sigma_alpha_init + ) + resetRandomEffectsTracker( + rfx_tracker_train, + rfx_model, + rfx_dataset_train, + outcome_train, + rfx_samples + ) } } if (sample_sigma2_global) { if (!is.null(previous_global_var_samples)) { - current_sigma2 <- previous_global_var_samples[previous_model_warmstart_sample_num] - global_model_config$update_global_error_variance(current_sigma2) + current_sigma2 <- previous_global_var_samples[ + previous_model_warmstart_sample_num + ] + global_model_config$update_global_error_variance( + current_sigma2 + ) } } } else { if (include_mean_forest) { resetActiveForest(active_forest_mean) - active_forest_mean$set_root_leaves(init_values_mean_forest / num_trees_mean) - resetForestModel(forest_model_mean, active_forest_mean, forest_dataset_train, outcome_train, TRUE) + active_forest_mean$set_root_leaves( + init_values_mean_forest / num_trees_mean + ) + resetForestModel( + forest_model_mean, + active_forest_mean, + forest_dataset_train, + outcome_train, + TRUE + ) if (sample_sigma2_leaf) { current_leaf_scale <- as.matrix(sigma2_leaf_init) - forest_model_config_mean$update_leaf_model_scale(current_leaf_scale) + forest_model_config_mean$update_leaf_model_scale( + current_leaf_scale + ) } } if (include_variance_forest) { resetActiveForest(active_forest_variance) - active_forest_variance$set_root_leaves(log(variance_forest_init) / num_trees_variance) - resetForestModel(forest_model_variance, active_forest_variance, forest_dataset_train, outcome_train, FALSE) + active_forest_variance$set_root_leaves( + log(variance_forest_init) / num_trees_variance + ) + resetForestModel( + forest_model_variance, + active_forest_variance, + forest_dataset_train, + outcome_train, + FALSE + ) } if (has_rfx) { - rootResetRandomEffectsModel(rfx_model, alpha_init, xi_init, sigma_alpha_init, - sigma_xi_init, sigma_xi_shape, sigma_xi_scale) - rootResetRandomEffectsTracker(rfx_tracker_train, rfx_model, rfx_dataset_train, outcome_train) + rootResetRandomEffectsModel( + rfx_model, + alpha_init, + xi_init, + sigma_alpha_init, + sigma_xi_init, + sigma_xi_shape, + sigma_xi_scale + ) + rootResetRandomEffectsTracker( + rfx_tracker_train, + rfx_model, + rfx_dataset_train, + outcome_train + ) } if (sample_sigma2_global) { current_sigma2 <- sigma2_init - global_model_config$update_global_error_variance(current_sigma2) + global_model_config$update_global_error_variance( + current_sigma2 + ) } } - for (i in (num_gfr+1):num_samples) { + for (i in (num_gfr + 1):num_samples) { is_mcmc <- i > (num_gfr + num_burnin) if (is_mcmc) { mcmc_counter <- i - (num_gfr + num_burnin) - if (mcmc_counter %% keep_every == 0) keep_sample <- TRUE - else keep_sample <- FALSE + if (mcmc_counter %% keep_every == 0) { + keep_sample <- TRUE + } else { + keep_sample <- FALSE + } } else { - if (keep_burnin) keep_sample <- TRUE - else keep_sample <- FALSE + if (keep_burnin) { + keep_sample <- TRUE + } else { + keep_sample <- FALSE + } + } + if (keep_sample) { + sample_counter <- sample_counter + 1 } - if (keep_sample) sample_counter <- sample_counter + 1 # Print progress if (verbose) { if (num_burnin > 0) { - if (((i - num_gfr) %% 100 == 0) || ((i - num_gfr) == num_burnin)) { - cat("Sampling", i - num_gfr, "out of", num_burnin, "BART burn-in draws; Chain number ", chain_num, "\n") + if ( + ((i - num_gfr) %% 100 == 0) || + ((i - num_gfr) == num_burnin) + ) { + cat( + "Sampling", + i - num_gfr, + "out of", + num_burnin, + "BART burn-in draws; Chain number ", + chain_num, + "\n" + ) } } if (num_mcmc > 0) { - if (((i - num_gfr - num_burnin) %% 100 == 0) || (i == num_samples)) { - cat("Sampling", i - num_burnin - num_gfr, "out of", num_mcmc, "BART MCMC draws; Chain number ", chain_num, "\n") + if ( + ((i - num_gfr - num_burnin) %% 100 == 0) || + (i == num_samples) + ) { + cat( + "Sampling", + i - num_burnin - num_gfr, + "out of", + num_mcmc, + "BART MCMC draws; Chain number ", + chain_num, + "\n" + ) } } } - + if (include_mean_forest) { if (probit_outcome_model) { # Sample latent probit variable, z | - - forest_pred <- active_forest_mean$predict(forest_dataset_train) + forest_pred <- active_forest_mean$predict( + forest_dataset_train + ) mu0 <- forest_pred[y_train == 0] mu1 <- forest_pred[y_train == 1] u0 <- runif(sum(y_train == 0), 0, pnorm(0 - mu0)) u1 <- runif(sum(y_train == 1), pnorm(0 - mu1), 1) - resid_train[y_train==0] <- mu0 + qnorm(u0) - resid_train[y_train==1] <- mu1 + qnorm(u1) + resid_train[y_train == 0] <- mu0 + qnorm(u0) + resid_train[y_train == 1] <- mu1 + qnorm(u1) # Update outcome outcome_train$update_data(resid_train - forest_pred) } - + forest_model_mean$sample_one_iteration( - forest_dataset = forest_dataset_train, residual = outcome_train, forest_samples = forest_samples_mean, - active_forest = active_forest_mean, rng = rng, forest_model_config = forest_model_config_mean, - global_model_config = global_model_config, keep_forest = keep_sample, gfr = FALSE + forest_dataset = forest_dataset_train, + residual = outcome_train, + forest_samples = forest_samples_mean, + active_forest = active_forest_mean, + rng = rng, + forest_model_config = forest_model_config_mean, + global_model_config = global_model_config, + keep_forest = keep_sample, + gfr = FALSE ) - + # Cache train set predictions since they are already computed during sampling if (keep_sample) { - mean_forest_pred_train[,sample_counter] <- forest_model_mean$get_cached_forest_predictions() + mean_forest_pred_train[, + sample_counter + ] <- forest_model_mean$get_cached_forest_predictions() } } if (include_variance_forest) { forest_model_variance$sample_one_iteration( - forest_dataset = forest_dataset_train, residual = outcome_train, forest_samples = forest_samples_variance, - active_forest = active_forest_variance, rng = rng, forest_model_config = forest_model_config_variance, - global_model_config = global_model_config, keep_forest = keep_sample, gfr = FALSE + forest_dataset = forest_dataset_train, + residual = outcome_train, + forest_samples = forest_samples_variance, + active_forest = active_forest_variance, + rng = rng, + forest_model_config = forest_model_config_variance, + global_model_config = global_model_config, + keep_forest = keep_sample, + gfr = FALSE ) - + # Cache train set predictions since they are already computed during sampling if (keep_sample) { - variance_forest_pred_train[,sample_counter] <- forest_model_variance$get_cached_forest_predictions() + variance_forest_pred_train[, + sample_counter + ] <- forest_model_variance$get_cached_forest_predictions() } } if (sample_sigma2_global) { - current_sigma2 <- sampleGlobalErrorVarianceOneIteration(outcome_train, forest_dataset_train, rng, a_global, b_global) - if (keep_sample) global_var_samples[sample_counter] <- current_sigma2 - global_model_config$update_global_error_variance(current_sigma2) + current_sigma2 <- sampleGlobalErrorVarianceOneIteration( + outcome_train, + forest_dataset_train, + rng, + a_global, + b_global + ) + if (keep_sample) { + global_var_samples[sample_counter] <- current_sigma2 + } + global_model_config$update_global_error_variance( + current_sigma2 + ) } if (sample_sigma2_leaf) { - leaf_scale_double <- sampleLeafVarianceOneIteration(active_forest_mean, rng, a_leaf, b_leaf) + leaf_scale_double <- sampleLeafVarianceOneIteration( + active_forest_mean, + rng, + a_leaf, + b_leaf + ) current_leaf_scale <- as.matrix(leaf_scale_double) - if (keep_sample) leaf_scale_samples[sample_counter] <- leaf_scale_double - forest_model_config_mean$update_leaf_model_scale(current_leaf_scale) + if (keep_sample) { + leaf_scale_samples[sample_counter] <- leaf_scale_double + } + forest_model_config_mean$update_leaf_model_scale( + current_leaf_scale + ) } if (has_rfx) { - rfx_model$sample_random_effect(rfx_dataset_train, outcome_train, rfx_tracker_train, rfx_samples, keep_sample, current_sigma2, rng) + rfx_model$sample_random_effect( + rfx_dataset_train, + outcome_train, + rfx_tracker_train, + rfx_samples, + keep_sample, + current_sigma2, + rng + ) } } } } - + # Remove GFR samples if they are not to be retained if ((!keep_gfr) && (num_gfr > 0)) { for (i in 1:num_gfr) { @@ -1014,16 +1573,24 @@ bart <- function(X_train, y_train, leaf_basis_train = NULL, rfx_group_ids_train } } if (include_mean_forest) { - mean_forest_pred_train <- mean_forest_pred_train[,(num_gfr+1):ncol(mean_forest_pred_train)] + mean_forest_pred_train <- mean_forest_pred_train[, + (num_gfr + 1):ncol(mean_forest_pred_train) + ] } if (include_variance_forest) { - variance_forest_pred_train <- variance_forest_pred_train[,(num_gfr+1):ncol(variance_forest_pred_train)] + variance_forest_pred_train <- variance_forest_pred_train[, + (num_gfr + 1):ncol(variance_forest_pred_train) + ] } if (sample_sigma2_global) { - global_var_samples <- global_var_samples[(num_gfr+1):length(global_var_samples)] + global_var_samples <- global_var_samples[ + (num_gfr + 1):length(global_var_samples) + ] } if (sample_sigma2_leaf) { - leaf_scale_samples <- leaf_scale_samples[(num_gfr+1):length(leaf_scale_samples)] + leaf_scale_samples <- leaf_scale_samples[ + (num_gfr + 1):length(leaf_scale_samples) + ] } num_retained_samples <- num_retained_samples - num_gfr } @@ -1031,73 +1598,114 @@ bart <- function(X_train, y_train, leaf_basis_train = NULL, rfx_group_ids_train # Mean forest predictions if (include_mean_forest) { # y_hat_train <- forest_samples_mean$predict(forest_dataset_train)*y_std_train + y_bar_train - y_hat_train <- mean_forest_pred_train*y_std_train + y_bar_train - if (has_test) y_hat_test <- forest_samples_mean$predict(forest_dataset_test)*y_std_train + y_bar_train + y_hat_train <- mean_forest_pred_train * y_std_train + y_bar_train + if (has_test) { + y_hat_test <- forest_samples_mean$predict(forest_dataset_test) * + y_std_train + + y_bar_train + } } - + # Variance forest predictions if (include_variance_forest) { # sigma2_x_hat_train <- forest_samples_variance$predict(forest_dataset_train) sigma2_x_hat_train <- exp(variance_forest_pred_train) - if (has_test) sigma2_x_hat_test <- forest_samples_variance$predict(forest_dataset_test) + if (has_test) { + sigma2_x_hat_test <- forest_samples_variance$predict( + forest_dataset_test + ) + } } - + # Random effects predictions if (has_rfx) { - rfx_preds_train <- rfx_samples$predict(rfx_group_ids_train, rfx_basis_train)*y_std_train + rfx_preds_train <- rfx_samples$predict( + rfx_group_ids_train, + rfx_basis_train + ) * + y_std_train y_hat_train <- y_hat_train + rfx_preds_train } if ((has_rfx_test) && (has_test)) { - rfx_preds_test <- rfx_samples$predict(rfx_group_ids_test, rfx_basis_test)*y_std_train + rfx_preds_test <- rfx_samples$predict( + rfx_group_ids_test, + rfx_basis_test + ) * + y_std_train y_hat_test <- y_hat_test + rfx_preds_test } # Global error variance - if (sample_sigma2_global) sigma2_global_samples <- global_var_samples*(y_std_train^2) - + if (sample_sigma2_global) { + sigma2_global_samples <- global_var_samples * (y_std_train^2) + } + # Leaf parameter variance - if (sample_sigma2_leaf) tau_samples <- leaf_scale_samples - + if (sample_sigma2_leaf) { + tau_samples <- leaf_scale_samples + } + # Rescale variance forest prediction by global sigma2 (sampled or constant) if (include_variance_forest) { if (sample_sigma2_global) { - sigma2_x_hat_train <- sapply(1:num_retained_samples, function(i) sigma2_x_hat_train[,i]*sigma2_global_samples[i]) - if (has_test) sigma2_x_hat_test <- sapply(1:num_retained_samples, function(i) sigma2_x_hat_test[,i]*sigma2_global_samples[i]) + sigma2_x_hat_train <- sapply(1:num_retained_samples, function(i) { + sigma2_x_hat_train[, i] * sigma2_global_samples[i] + }) + if (has_test) { + sigma2_x_hat_test <- sapply( + 1:num_retained_samples, + function(i) { + sigma2_x_hat_test[, i] * sigma2_global_samples[i] + } + ) + } } else { - sigma2_x_hat_train <- sigma2_x_hat_train*sigma2_init*y_std_train*y_std_train - if (has_test) sigma2_x_hat_test <- sigma2_x_hat_test*sigma2_init*y_std_train*y_std_train + sigma2_x_hat_train <- sigma2_x_hat_train * + sigma2_init * + y_std_train * + y_std_train + if (has_test) { + sigma2_x_hat_test <- sigma2_x_hat_test * + sigma2_init * + y_std_train * + y_std_train + } } } - + # Return results as a list model_params <- list( - "sigma2_init" = sigma2_init, + "sigma2_init" = sigma2_init, "sigma2_leaf_init" = sigma2_leaf_init, "a_global" = a_global, - "b_global" = b_global, - "a_leaf" = a_leaf, + "b_global" = b_global, + "a_leaf" = a_leaf, "b_leaf" = b_leaf, - "a_forest" = a_forest, + "a_forest" = a_forest, "b_forest" = b_forest, "outcome_mean" = y_bar_train, - "outcome_scale" = y_std_train, - "standardize" = standardize, + "outcome_scale" = y_std_train, + "standardize" = standardize, "leaf_dimension" = leaf_dimension, "is_leaf_constant" = is_leaf_constant, "leaf_regression" = leaf_regression, - "requires_basis" = requires_basis, - "num_covariates" = ncol(X_train), - "num_basis" = ifelse(is.null(leaf_basis_train),0,ncol(leaf_basis_train)), - "num_samples" = num_retained_samples, - "num_gfr" = num_gfr, - "num_burnin" = num_burnin, - "num_mcmc" = num_mcmc, + "requires_basis" = requires_basis, + "num_covariates" = ncol(X_train), + "num_basis" = ifelse( + is.null(leaf_basis_train), + 0, + ncol(leaf_basis_train) + ), + "num_samples" = num_retained_samples, + "num_gfr" = num_gfr, + "num_burnin" = num_burnin, + "num_mcmc" = num_mcmc, "keep_every" = keep_every, "num_chains" = num_chains, - "has_basis" = !is.null(leaf_basis_train), - "has_rfx" = has_rfx, - "has_rfx_basis" = has_basis_rfx, - "num_rfx_basis" = num_basis_rfx, + "has_basis" = !is.null(leaf_basis_train), + "has_rfx" = has_rfx, + "has_rfx_basis" = has_basis_rfx, + "num_rfx_basis" = num_basis_rfx, "sample_sigma2_global" = sample_sigma2_global, "sample_sigma2_leaf" = sample_sigma2_leaf, "include_mean_forest" = include_mean_forest, @@ -1105,7 +1713,7 @@ bart <- function(X_train, y_train, leaf_basis_train = NULL, rfx_group_ids_train "probit_outcome_model" = probit_outcome_model ) result <- list( - "model_params" = model_params, + "model_params" = model_params, "train_set_metadata" = X_train_metadata ) if (include_mean_forest) { @@ -1118,25 +1726,39 @@ bart <- function(X_train, y_train, leaf_basis_train = NULL, rfx_group_ids_train result[["sigma2_x_hat_train"]] = sigma2_x_hat_train if (has_test) result[["sigma2_x_hat_test"]] = sigma2_x_hat_test } - if (sample_sigma2_global) result[["sigma2_global_samples"]] = sigma2_global_samples - if (sample_sigma2_leaf) result[["sigma2_leaf_samples"]] = tau_samples + if (sample_sigma2_global) { + result[["sigma2_global_samples"]] = sigma2_global_samples + } + if (sample_sigma2_leaf) { + result[["sigma2_leaf_samples"]] = tau_samples + } if (has_rfx) { result[["rfx_samples"]] = rfx_samples result[["rfx_preds_train"]] = rfx_preds_train result[["rfx_unique_group_ids"]] = levels(group_ids_factor) } - if ((has_rfx_test) && (has_test)) result[["rfx_preds_test"]] = rfx_preds_test + if ((has_rfx_test) && (has_test)) { + result[["rfx_preds_test"]] = rfx_preds_test + } class(result) <- "bartmodel" - + # Clean up classes with external pointers to C++ data structures - if (include_mean_forest) rm(forest_model_mean) - if (include_variance_forest) rm(forest_model_variance) + if (include_mean_forest) { + rm(forest_model_mean) + } + if (include_variance_forest) { + rm(forest_model_variance) + } rm(forest_dataset_train) - if (has_test) rm(forest_dataset_test) - if (has_rfx) rm(rfx_dataset_train, rfx_tracker_train, rfx_model) + if (has_test) { + rm(forest_dataset_test) + } + if (has_rfx) { + rm(rfx_dataset_train, rfx_tracker_train, rfx_model) + } rm(outcome_train) rm(rng) - + return(result) } @@ -1145,13 +1767,13 @@ bart <- function(X_train, y_train, leaf_basis_train = NULL, rfx_group_ids_train #' @param object Object of type `bart` containing draws of a regression forest and associated sampling outputs. #' @param X Covariates used to determine tree leaf predictions for each observation. Must be passed as a matrix or dataframe. #' @param leaf_basis (Optional) Bases used for prediction (by e.g. dot product with leaf values). Default: `NULL`. -#' @param rfx_group_ids (Optional) Test set group labels used for an additive random effects model. +#' @param rfx_group_ids (Optional) Test set group labels used for an additive random effects model. #' We do not currently support (but plan to in the near future), test set evaluation for group labels #' that were not in the training set. #' @param rfx_basis (Optional) Test set basis for "random-slope" regression in additive random effects model. #' @param ... (Optional) Other prediction parameters. #' -#' @return List of prediction matrices. If model does not have random effects, the list has one element -- the predictions from the forest. +#' @return List of prediction matrices. If model does not have random effects, the list has one element -- the predictions from the forest. #' If the model does have random effects, the list has three elements -- forest predictions, random effects predictions, and their sum (`y_hat`). #' @export #' @@ -1160,9 +1782,9 @@ bart <- function(X_train, y_train, leaf_basis_train = NULL, rfx_group_ids_train #' p <- 5 #' X <- matrix(runif(n*p), ncol = p) #' f_XW <- ( -#' ((0 <= X[,1]) & (0.25 > X[,1])) * (-7.5) + -#' ((0.25 <= X[,1]) & (0.5 > X[,1])) * (-2.5) + -#' ((0.5 <= X[,1]) & (0.75 > X[,1])) * (2.5) + +#' ((0 <= X[,1]) & (0.25 > X[,1])) * (-7.5) + +#' ((0.25 <= X[,1]) & (0.5 > X[,1])) * (-2.5) + +#' ((0.5 <= X[,1]) & (0.75 > X[,1])) * (2.5) + #' ((0.75 <= X[,1]) & (1 > X[,1])) * (7.5) #' ) #' noise_sd <- 1 @@ -1176,17 +1798,24 @@ bart <- function(X_train, y_train, leaf_basis_train = NULL, rfx_group_ids_train #' X_train <- X[train_inds,] #' y_test <- y[test_inds] #' y_train <- y[train_inds] -#' bart_model <- bart(X_train = X_train, y_train = y_train, +#' bart_model <- bart(X_train = X_train, y_train = y_train, #' num_gfr = 10, num_burnin = 0, num_mcmc = 10) #' y_hat_test <- predict(bart_model, X_test)$y_hat -predict.bartmodel <- function(object, X, leaf_basis = NULL, rfx_group_ids = NULL, rfx_basis = NULL, ...){ +predict.bartmodel <- function( + object, + X, + leaf_basis = NULL, + rfx_group_ids = NULL, + rfx_basis = NULL, + ... +) { # Preprocess covariates if ((!is.data.frame(X)) && (!is.matrix(X))) { stop("X must be a matrix or dataframe") } train_set_metadata <- object$train_set_metadata X <- preprocessPredictionData(X, train_set_metadata) - + # Convert all input data to matrices if not already converted if ((is.null(dim(leaf_basis))) && (!is.null(leaf_basis))) { leaf_basis <- as.matrix(leaf_basis) @@ -1194,7 +1823,7 @@ predict.bartmodel <- function(object, X, leaf_basis = NULL, rfx_group_ids = NULL if ((is.null(dim(rfx_basis))) && (!is.null(rfx_basis))) { rfx_basis <- as.matrix(rfx_basis) } - + # Data checks if ((object$model_params$requires_basis) && (is.null(leaf_basis))) { stop("Basis (leaf_basis) must be provided for this model") @@ -1206,75 +1835,109 @@ predict.bartmodel <- function(object, X, leaf_basis = NULL, rfx_group_ids = NULL stop("X and leaf_basis must have the same number of rows") } if ((object$model_params$has_rfx) && (is.null(rfx_group_ids))) { - stop("Random effect group labels (rfx_group_ids) must be provided for this model") + stop( + "Random effect group labels (rfx_group_ids) must be provided for this model" + ) } if ((object$model_params$has_rfx_basis) && (is.null(rfx_basis))) { stop("Random effects basis (rfx_basis) must be provided for this model") } - if ((object$model_params$num_rfx_basis > 0) && (ncol(rfx_basis) != object$model_params$num_rfx_basis)) { - stop("Random effects basis has a different dimension than the basis used to train this model") + if ( + (object$model_params$num_rfx_basis > 0) && + (ncol(rfx_basis) != object$model_params$num_rfx_basis) + ) { + stop( + "Random effects basis has a different dimension than the basis used to train this model" + ) } - + # Recode group IDs to integer vector (if passed as, for example, a vector of county names, etc...) has_rfx <- FALSE if (!is.null(rfx_group_ids)) { rfx_unique_group_ids <- object$rfx_unique_group_ids group_ids_factor <- factor(rfx_group_ids, levels = rfx_unique_group_ids) if (sum(is.na(group_ids_factor)) > 0) { - stop("All random effect group labels provided in rfx_group_ids must be present in rfx_group_ids_train") + stop( + "All random effect group labels provided in rfx_group_ids must be present in rfx_group_ids_train" + ) } rfx_group_ids <- as.integer(group_ids_factor) has_rfx <- TRUE } - + # Produce basis for the "intercept-only" random effects case if ((object$model_params$has_rfx) && (is.null(rfx_basis))) { rfx_basis <- matrix(rep(1, nrow(X)), ncol = 1) } - + # Create prediction dataset - if (!is.null(leaf_basis)) prediction_dataset <- createForestDataset(X, leaf_basis) - else prediction_dataset <- createForestDataset(X) - + if (!is.null(leaf_basis)) { + prediction_dataset <- createForestDataset(X, leaf_basis) + } else { + prediction_dataset <- createForestDataset(X) + } + # Compute mean forest predictions num_samples <- object$model_params$num_samples y_std <- object$model_params$outcome_scale y_bar <- object$model_params$outcome_mean sigma2_init <- object$model_params$sigma2_init if (object$model_params$include_mean_forest) { - mean_forest_predictions <- object$mean_forests$predict(prediction_dataset)*y_std + y_bar + mean_forest_predictions <- object$mean_forests$predict( + prediction_dataset + ) * + y_std + + y_bar } - + # Compute variance forest predictions if (object$model_params$include_variance_forest) { s_x_raw <- object$variance_forests$predict(prediction_dataset) } - + # Compute rfx predictions (if needed) if (object$model_params$has_rfx) { - rfx_predictions <- object$rfx_samples$predict(rfx_group_ids, rfx_basis)*y_std + rfx_predictions <- object$rfx_samples$predict( + rfx_group_ids, + rfx_basis + ) * + y_std } - + # Scale variance forest predictions if (object$model_params$include_variance_forest) { if (object$model_params$sample_sigma2_global) { sigma2_global_samples <- object$sigma2_global_samples - variance_forest_predictions <- sapply(1:num_samples, function(i) s_x_raw[,i]*sigma2_global_samples[i]) + variance_forest_predictions <- sapply(1:num_samples, function(i) { + s_x_raw[, i] * sigma2_global_samples[i] + }) } else { - variance_forest_predictions <- s_x_raw*sigma2_init*y_std*y_std + variance_forest_predictions <- s_x_raw * sigma2_init * y_std * y_std } } - if ((object$model_params$include_mean_forest) && (object$model_params$has_rfx)) { + if ( + (object$model_params$include_mean_forest) && + (object$model_params$has_rfx) + ) { y_hat <- mean_forest_predictions + rfx_predictions - } else if ((object$model_params$include_mean_forest) && (!object$model_params$has_rfx)) { + } else if ( + (object$model_params$include_mean_forest) && + (!object$model_params$has_rfx) + ) { y_hat <- mean_forest_predictions - } else if ((!object$model_params$include_mean_forest) && (object$model_params$has_rfx)) { + } else if ( + (!object$model_params$include_mean_forest) && + (object$model_params$has_rfx) + ) { y_hat <- rfx_predictions - } - + } + result <- list() - if ((object$model_params$has_rfx) || (object$model_params$include_mean_forest)) { + if ( + (object$model_params$has_rfx) || + (object$model_params$include_mean_forest) + ) { result[["y_hat"]] = y_hat } else { result[["y_hat"]] <- NULL @@ -1311,9 +1974,9 @@ predict.bartmodel <- function(object, X, leaf_basis = NULL, rfx_group_ids = NULL #' p <- 5 #' X <- matrix(runif(n*p), ncol = p) #' f_XW <- ( -#' ((0 <= X[,1]) & (0.25 > X[,1])) * (-7.5) + -#' ((0.25 <= X[,1]) & (0.5 > X[,1])) * (-2.5) + -#' ((0.5 <= X[,1]) & (0.75 > X[,1])) * (2.5) + +#' ((0 <= X[,1]) & (0.25 > X[,1])) * (-7.5) + +#' ((0.25 <= X[,1]) & (0.5 > X[,1])) * (-2.5) + +#' ((0.5 <= X[,1]) & (0.75 > X[,1])) * (2.5) + #' ((0.75 <= X[,1]) & (1 > X[,1])) * (7.5) #' ) #' snr <- 3 @@ -1338,30 +2001,33 @@ predict.bartmodel <- function(object, X, leaf_basis = NULL, rfx_group_ids = NULL #' rfx_basis_train <- rfx_basis[train_inds,] #' rfx_term_test <- rfx_term[test_inds] #' rfx_term_train <- rfx_term[train_inds] -#' bart_model <- bart(X_train = X_train, y_train = y_train, X_test = X_test, -#' rfx_group_ids_train = rfx_group_ids_train, -#' rfx_group_ids_test = rfx_group_ids_test, -#' rfx_basis_train = rfx_basis_train, -#' rfx_basis_test = rfx_basis_test, +#' bart_model <- bart(X_train = X_train, y_train = y_train, X_test = X_test, +#' rfx_group_ids_train = rfx_group_ids_train, +#' rfx_group_ids_test = rfx_group_ids_test, +#' rfx_basis_train = rfx_basis_train, +#' rfx_basis_test = rfx_basis_test, #' num_gfr = 10, num_burnin = 0, num_mcmc = 10) #' rfx_samples <- getRandomEffectSamples(bart_model) -getRandomEffectSamples.bartmodel <- function(object, ...){ +getRandomEffectSamples.bartmodel <- function(object, ...) { result = list() - + if (!object$model_params$has_rfx) { warning("This model has no RFX terms, returning an empty list") return(result) } - + # Extract the samples result <- object$rfx_samples$extract_parameter_samples() - + # Scale by sd(y_train) - result$beta_samples <- result$beta_samples*object$model_params$outcome_scale - result$xi_samples <- result$xi_samples*object$model_params$outcome_scale - result$alpha_samples <- result$alpha_samples*object$model_params$outcome_scale - result$sigma_samples <- result$sigma_samples*(object$model_params$outcome_scale^2) - + result$beta_samples <- result$beta_samples * + object$model_params$outcome_scale + result$xi_samples <- result$xi_samples * object$model_params$outcome_scale + result$alpha_samples <- result$alpha_samples * + object$model_params$outcome_scale + result$sigma_samples <- result$sigma_samples * + (object$model_params$outcome_scale^2) + return(result) } @@ -1377,9 +2043,9 @@ getRandomEffectSamples.bartmodel <- function(object, ...){ #' p <- 5 #' X <- matrix(runif(n*p), ncol = p) #' f_XW <- ( -#' ((0 <= X[,1]) & (0.25 > X[,1])) * (-7.5) + -#' ((0.25 <= X[,1]) & (0.5 > X[,1])) * (-2.5) + -#' ((0.5 <= X[,1]) & (0.75 > X[,1])) * (2.5) + +#' ((0 <= X[,1]) & (0.25 > X[,1])) * (-7.5) + +#' ((0.25 <= X[,1]) & (0.5 > X[,1])) * (-2.5) + +#' ((0.5 <= X[,1]) & (0.75 > X[,1])) * (2.5) + #' ((0.75 <= X[,1]) & (1 > X[,1])) * (7.5) #' ) #' noise_sd <- 1 @@ -1393,16 +2059,16 @@ getRandomEffectSamples.bartmodel <- function(object, ...){ #' X_train <- X[train_inds,] #' y_test <- y[test_inds] #' y_train <- y[train_inds] -#' bart_model <- bart(X_train = X_train, y_train = y_train, +#' bart_model <- bart(X_train = X_train, y_train = y_train, #' num_gfr = 10, num_burnin = 0, num_mcmc = 10) #' bart_json <- saveBARTModelToJson(bart_model) -saveBARTModelToJson <- function(object){ +saveBARTModelToJson <- function(object) { jsonobj <- createCppJson() - + if (!inherits(object, "bartmodel")) { stop("`object` must be a BART model") } - + if (is.null(object$model_params)) { stop("This BCF model has not yet been sampled") } @@ -1416,30 +2082,66 @@ saveBARTModelToJson <- function(object){ } # Add metadata - jsonobj$add_scalar("num_numeric_vars", object$train_set_metadata$num_numeric_vars) - jsonobj$add_scalar("num_ordered_cat_vars", object$train_set_metadata$num_ordered_cat_vars) - jsonobj$add_scalar("num_unordered_cat_vars", object$train_set_metadata$num_unordered_cat_vars) + jsonobj$add_scalar( + "num_numeric_vars", + object$train_set_metadata$num_numeric_vars + ) + jsonobj$add_scalar( + "num_ordered_cat_vars", + object$train_set_metadata$num_ordered_cat_vars + ) + jsonobj$add_scalar( + "num_unordered_cat_vars", + object$train_set_metadata$num_unordered_cat_vars + ) if (object$train_set_metadata$num_numeric_vars > 0) { - jsonobj$add_string_vector("numeric_vars", object$train_set_metadata$numeric_vars) + jsonobj$add_string_vector( + "numeric_vars", + object$train_set_metadata$numeric_vars + ) } if (object$train_set_metadata$num_ordered_cat_vars > 0) { - jsonobj$add_string_vector("ordered_cat_vars", object$train_set_metadata$ordered_cat_vars) - jsonobj$add_string_list("ordered_unique_levels", object$train_set_metadata$ordered_unique_levels) + jsonobj$add_string_vector( + "ordered_cat_vars", + object$train_set_metadata$ordered_cat_vars + ) + jsonobj$add_string_list( + "ordered_unique_levels", + object$train_set_metadata$ordered_unique_levels + ) } if (object$train_set_metadata$num_unordered_cat_vars > 0) { - jsonobj$add_string_vector("unordered_cat_vars", object$train_set_metadata$unordered_cat_vars) - jsonobj$add_string_list("unordered_unique_levels", object$train_set_metadata$unordered_unique_levels) + jsonobj$add_string_vector( + "unordered_cat_vars", + object$train_set_metadata$unordered_cat_vars + ) + jsonobj$add_string_list( + "unordered_unique_levels", + object$train_set_metadata$unordered_unique_levels + ) } - + # Add global parameters jsonobj$add_scalar("outcome_scale", object$model_params$outcome_scale) jsonobj$add_scalar("outcome_mean", object$model_params$outcome_mean) jsonobj$add_boolean("standardize", object$model_params$standardize) jsonobj$add_scalar("sigma2_init", object$model_params$sigma2_init) - jsonobj$add_boolean("sample_sigma2_global", object$model_params$sample_sigma2_global) - jsonobj$add_boolean("sample_sigma2_leaf", object$model_params$sample_sigma2_leaf) - jsonobj$add_boolean("include_mean_forest", object$model_params$include_mean_forest) - jsonobj$add_boolean("include_variance_forest", object$model_params$include_variance_forest) + jsonobj$add_boolean( + "sample_sigma2_global", + object$model_params$sample_sigma2_global + ) + jsonobj$add_boolean( + "sample_sigma2_leaf", + object$model_params$sample_sigma2_leaf + ) + jsonobj$add_boolean( + "include_mean_forest", + object$model_params$include_mean_forest + ) + jsonobj$add_boolean( + "include_variance_forest", + object$model_params$include_variance_forest + ) jsonobj$add_boolean("has_rfx", object$model_params$has_rfx) jsonobj$add_boolean("has_rfx_basis", object$model_params$has_rfx_basis) jsonobj$add_scalar("num_rfx_basis", object$model_params$num_rfx_basis) @@ -1452,26 +2154,40 @@ saveBARTModelToJson <- function(object){ jsonobj$add_scalar("num_chains", object$model_params$num_chains) jsonobj$add_scalar("keep_every", object$model_params$keep_every) jsonobj$add_boolean("requires_basis", object$model_params$requires_basis) - jsonobj$add_boolean("probit_outcome_model", object$model_params$probit_outcome_model) + jsonobj$add_boolean( + "probit_outcome_model", + object$model_params$probit_outcome_model + ) if (object$model_params$sample_sigma2_global) { - jsonobj$add_vector("sigma2_global_samples", object$sigma2_global_samples, "parameters") + jsonobj$add_vector( + "sigma2_global_samples", + object$sigma2_global_samples, + "parameters" + ) } if (object$model_params$sample_sigma2_leaf) { - jsonobj$add_vector("sigma2_leaf_samples", object$sigma2_leaf_samples, "parameters") + jsonobj$add_vector( + "sigma2_leaf_samples", + object$sigma2_leaf_samples, + "parameters" + ) } # Add random effects (if present) if (object$model_params$has_rfx) { jsonobj$add_random_effects(object$rfx_samples) - jsonobj$add_string_vector("rfx_unique_group_ids", object$rfx_unique_group_ids) + jsonobj$add_string_vector( + "rfx_unique_group_ids", + object$rfx_unique_group_ids + ) } - + # Add covariate preprocessor metadata preprocessor_metadata_string <- savePreprocessorToJsonString( object$train_set_metadata ) jsonobj$add_string("preprocessor_metadata", preprocessor_metadata_string) - + return(jsonobj) } @@ -1488,9 +2204,9 @@ saveBARTModelToJson <- function(object){ #' p <- 5 #' X <- matrix(runif(n*p), ncol = p) #' f_XW <- ( -#' ((0 <= X[,1]) & (0.25 > X[,1])) * (-7.5) + -#' ((0.25 <= X[,1]) & (0.5 > X[,1])) * (-2.5) + -#' ((0.5 <= X[,1]) & (0.75 > X[,1])) * (2.5) + +#' ((0 <= X[,1]) & (0.25 > X[,1])) * (-7.5) + +#' ((0.25 <= X[,1]) & (0.5 > X[,1])) * (-2.5) + +#' ((0.5 <= X[,1]) & (0.75 > X[,1])) * (2.5) + #' ((0.75 <= X[,1]) & (1 > X[,1])) * (7.5) #' ) #' noise_sd <- 1 @@ -1504,15 +2220,15 @@ saveBARTModelToJson <- function(object){ #' X_train <- X[train_inds,] #' y_test <- y[test_inds] #' y_train <- y[train_inds] -#' bart_model <- bart(X_train = X_train, y_train = y_train, +#' bart_model <- bart(X_train = X_train, y_train = y_train, #' num_gfr = 10, num_burnin = 0, num_mcmc = 10) #' tmpjson <- tempfile(fileext = ".json") #' saveBARTModelToJsonFile(bart_model, file.path(tmpjson)) #' unlink(tmpjson) -saveBARTModelToJsonFile <- function(object, filename){ +saveBARTModelToJsonFile <- function(object, filename) { # Convert to Json jsonobj <- saveBARTModelToJson(object) - + # Save to file jsonobj$save_file(filename) } @@ -1528,9 +2244,9 @@ saveBARTModelToJsonFile <- function(object, filename){ #' p <- 5 #' X <- matrix(runif(n*p), ncol = p) #' f_XW <- ( -#' ((0 <= X[,1]) & (0.25 > X[,1])) * (-7.5) + -#' ((0.25 <= X[,1]) & (0.5 > X[,1])) * (-2.5) + -#' ((0.5 <= X[,1]) & (0.75 > X[,1])) * (2.5) + +#' ((0 <= X[,1]) & (0.25 > X[,1])) * (-7.5) + +#' ((0.25 <= X[,1]) & (0.5 > X[,1])) * (-2.5) + +#' ((0.5 <= X[,1]) & (0.75 > X[,1])) * (2.5) + #' ((0.75 <= X[,1]) & (1 > X[,1])) * (7.5) #' ) #' noise_sd <- 1 @@ -1544,18 +2260,18 @@ saveBARTModelToJsonFile <- function(object, filename){ #' X_train <- X[train_inds,] #' y_test <- y[test_inds] #' y_train <- y[train_inds] -#' bart_model <- bart(X_train = X_train, y_train = y_train, +#' bart_model <- bart(X_train = X_train, y_train = y_train, #' num_gfr = 10, num_burnin = 0, num_mcmc = 10) #' bart_json_string <- saveBARTModelToJsonString(bart_model) -saveBARTModelToJsonString <- function(object){ +saveBARTModelToJsonString <- function(object) { # Convert to Json jsonobj <- saveBARTModelToJson(object) - + # Dump to string return(jsonobj$return_json_string()) } -#' Convert an (in-memory) JSON representation of a BART model to a BART model object +#' Convert an (in-memory) JSON representation of a BART model to a BART model object #' which can be used for prediction, etc... #' #' @param json_object Object of type `CppJson` containing Json representation of a BART model @@ -1568,9 +2284,9 @@ saveBARTModelToJsonString <- function(object){ #' p <- 5 #' X <- matrix(runif(n*p), ncol = p) #' f_XW <- ( -#' ((0 <= X[,1]) & (0.25 > X[,1])) * (-7.5) + -#' ((0.25 <= X[,1]) & (0.5 > X[,1])) * (-2.5) + -#' ((0.5 <= X[,1]) & (0.75 > X[,1])) * (2.5) + +#' ((0 <= X[,1]) & (0.25 > X[,1])) * (-7.5) + +#' ((0.25 <= X[,1]) & (0.5 > X[,1])) * (-2.5) + +#' ((0.5 <= X[,1]) & (0.75 > X[,1])) * (2.5) + #' ((0.75 <= X[,1]) & (1 > X[,1])) * (7.5) #' ) #' noise_sd <- 1 @@ -1584,52 +2300,89 @@ saveBARTModelToJsonString <- function(object){ #' X_train <- X[train_inds,] #' y_test <- y[test_inds] #' y_train <- y[train_inds] -#' bart_model <- bart(X_train = X_train, y_train = y_train, +#' bart_model <- bart(X_train = X_train, y_train = y_train, #' num_gfr = 10, num_burnin = 0, num_mcmc = 10) #' bart_json <- saveBARTModelToJson(bart_model) #' bart_model_roundtrip <- createBARTModelFromJson(bart_json) -createBARTModelFromJson <- function(json_object){ +createBARTModelFromJson <- function(json_object) { # Initialize the BCF model output <- list() - + # Unpack the forests include_mean_forest <- json_object$get_boolean("include_mean_forest") - include_variance_forest <- json_object$get_boolean("include_variance_forest") + include_variance_forest <- json_object$get_boolean( + "include_variance_forest" + ) if (include_mean_forest) { - output[["mean_forests"]] <- loadForestContainerJson(json_object, "forest_0") + output[["mean_forests"]] <- loadForestContainerJson( + json_object, + "forest_0" + ) if (include_variance_forest) { - output[["variance_forests"]] <- loadForestContainerJson(json_object, "forest_1") + output[["variance_forests"]] <- loadForestContainerJson( + json_object, + "forest_1" + ) } } else { - output[["variance_forests"]] <- loadForestContainerJson(json_object, "forest_0") + output[["variance_forests"]] <- loadForestContainerJson( + json_object, + "forest_0" + ) } # Unpack metadata train_set_metadata = list() - train_set_metadata[["num_numeric_vars"]] <- json_object$get_scalar("num_numeric_vars") - train_set_metadata[["num_ordered_cat_vars"]] <- json_object$get_scalar("num_ordered_cat_vars") - train_set_metadata[["num_unordered_cat_vars"]] <- json_object$get_scalar("num_unordered_cat_vars") + train_set_metadata[["num_numeric_vars"]] <- json_object$get_scalar( + "num_numeric_vars" + ) + train_set_metadata[["num_ordered_cat_vars"]] <- json_object$get_scalar( + "num_ordered_cat_vars" + ) + train_set_metadata[["num_unordered_cat_vars"]] <- json_object$get_scalar( + "num_unordered_cat_vars" + ) if (train_set_metadata[["num_numeric_vars"]] > 0) { - train_set_metadata[["numeric_vars"]] <- json_object$get_string_vector("numeric_vars") + train_set_metadata[["numeric_vars"]] <- json_object$get_string_vector( + "numeric_vars" + ) } if (train_set_metadata[["num_ordered_cat_vars"]] > 0) { - train_set_metadata[["ordered_cat_vars"]] <- json_object$get_string_vector("ordered_cat_vars") - train_set_metadata[["ordered_unique_levels"]] <- json_object$get_string_list("ordered_unique_levels", train_set_metadata[["ordered_cat_vars"]]) + train_set_metadata[[ + "ordered_cat_vars" + ]] <- json_object$get_string_vector("ordered_cat_vars") + train_set_metadata[[ + "ordered_unique_levels" + ]] <- json_object$get_string_list( + "ordered_unique_levels", + train_set_metadata[["ordered_cat_vars"]] + ) } if (train_set_metadata[["num_unordered_cat_vars"]] > 0) { - train_set_metadata[["unordered_cat_vars"]] <- json_object$get_string_vector("unordered_cat_vars") - train_set_metadata[["unordered_unique_levels"]] <- json_object$get_string_list("unordered_unique_levels", train_set_metadata[["unordered_cat_vars"]]) + train_set_metadata[[ + "unordered_cat_vars" + ]] <- json_object$get_string_vector("unordered_cat_vars") + train_set_metadata[[ + "unordered_unique_levels" + ]] <- json_object$get_string_list( + "unordered_unique_levels", + train_set_metadata[["unordered_cat_vars"]] + ) } output[["train_set_metadata"]] <- train_set_metadata - + # Unpack model params model_params = list() model_params[["outcome_scale"]] <- json_object$get_scalar("outcome_scale") model_params[["outcome_mean"]] <- json_object$get_scalar("outcome_mean") model_params[["standardize"]] <- json_object$get_boolean("standardize") model_params[["sigma2_init"]] <- json_object$get_scalar("sigma2_init") - model_params[["sample_sigma2_global"]] <- json_object$get_boolean("sample_sigma2_global") - model_params[["sample_sigma2_leaf"]] <- json_object$get_boolean("sample_sigma2_leaf") + model_params[["sample_sigma2_global"]] <- json_object$get_boolean( + "sample_sigma2_global" + ) + model_params[["sample_sigma2_leaf"]] <- json_object$get_boolean( + "sample_sigma2_leaf" + ) model_params[["include_mean_forest"]] <- include_mean_forest model_params[["include_variance_forest"]] <- include_variance_forest model_params[["has_rfx"]] <- json_object$get_boolean("has_rfx") @@ -1643,36 +2396,50 @@ createBARTModelFromJson <- function(json_object){ model_params[["num_basis"]] <- json_object$get_scalar("num_basis") model_params[["num_chains"]] <- json_object$get_scalar("num_chains") model_params[["keep_every"]] <- json_object$get_scalar("keep_every") - model_params[["requires_basis"]] <- json_object$get_boolean("requires_basis") - model_params[["probit_outcome_model"]] <- json_object$get_boolean("probit_outcome_model") - + model_params[["requires_basis"]] <- json_object$get_boolean( + "requires_basis" + ) + model_params[["probit_outcome_model"]] <- json_object$get_boolean( + "probit_outcome_model" + ) + output[["model_params"]] <- model_params - + # Unpack sampled parameters if (model_params[["sample_sigma2_global"]]) { - output[["sigma2_global_samples"]] <- json_object$get_vector("sigma2_global_samples", "parameters") + output[["sigma2_global_samples"]] <- json_object$get_vector( + "sigma2_global_samples", + "parameters" + ) } if (model_params[["sample_sigma2_leaf"]]) { - output[["sigma2_leaf_samples"]] <- json_object$get_vector("sigma2_leaf_samples", "parameters") + output[["sigma2_leaf_samples"]] <- json_object$get_vector( + "sigma2_leaf_samples", + "parameters" + ) } # Unpack random effects if (model_params[["has_rfx"]]) { - output[["rfx_unique_group_ids"]] <- json_object$get_string_vector("rfx_unique_group_ids") + output[["rfx_unique_group_ids"]] <- json_object$get_string_vector( + "rfx_unique_group_ids" + ) output[["rfx_samples"]] <- loadRandomEffectSamplesJson(json_object, 0) } - + # Unpack covariate preprocessor - preprocessor_metadata_string <- json_object$get_string("preprocessor_metadata") + preprocessor_metadata_string <- json_object$get_string( + "preprocessor_metadata" + ) output[["train_set_metadata"]] <- createPreprocessorFromJsonString( preprocessor_metadata_string ) - + class(output) <- "bartmodel" return(output) } -#' Convert a JSON file containing sample information on a trained BART model +#' Convert a JSON file containing sample information on a trained BART model #' to a BART model object which can be used for prediction, etc... #' #' @param json_filename String of filepath, must end in ".json" @@ -1685,9 +2452,9 @@ createBARTModelFromJson <- function(json_object){ #' p <- 5 #' X <- matrix(runif(n*p), ncol = p) #' f_XW <- ( -#' ((0 <= X[,1]) & (0.25 > X[,1])) * (-7.5) + -#' ((0.25 <= X[,1]) & (0.5 > X[,1])) * (-2.5) + -#' ((0.5 <= X[,1]) & (0.75 > X[,1])) * (2.5) + +#' ((0 <= X[,1]) & (0.25 > X[,1])) * (-7.5) + +#' ((0.25 <= X[,1]) & (0.5 > X[,1])) * (-2.5) + +#' ((0.5 <= X[,1]) & (0.75 > X[,1])) * (2.5) + #' ((0.75 <= X[,1]) & (1 > X[,1])) * (7.5) #' ) #' noise_sd <- 1 @@ -1701,23 +2468,23 @@ createBARTModelFromJson <- function(json_object){ #' X_train <- X[train_inds,] #' y_test <- y[test_inds] #' y_train <- y[train_inds] -#' bart_model <- bart(X_train = X_train, y_train = y_train, +#' bart_model <- bart(X_train = X_train, y_train = y_train, #' num_gfr = 10, num_burnin = 0, num_mcmc = 10) #' tmpjson <- tempfile(fileext = ".json") #' saveBARTModelToJsonFile(bart_model, file.path(tmpjson)) #' bart_model_roundtrip <- createBARTModelFromJsonFile(file.path(tmpjson)) #' unlink(tmpjson) -createBARTModelFromJsonFile <- function(json_filename){ +createBARTModelFromJsonFile <- function(json_filename) { # Load a `CppJson` object from file bart_json <- createCppJsonFile(json_filename) - + # Create and return the BART object bart_object <- createBARTModelFromJson(bart_json) - + return(bart_object) } -#' Convert a JSON string containing sample information on a trained BART model +#' Convert a JSON string containing sample information on a trained BART model #' to a BART model object which can be used for prediction, etc... #' #' @param json_string JSON string dump @@ -1730,9 +2497,9 @@ createBARTModelFromJsonFile <- function(json_filename){ #' p <- 5 #' X <- matrix(runif(n*p), ncol = p) #' f_XW <- ( -#' ((0 <= X[,1]) & (0.25 > X[,1])) * (-7.5) + -#' ((0.25 <= X[,1]) & (0.5 > X[,1])) * (-2.5) + -#' ((0.5 <= X[,1]) & (0.75 > X[,1])) * (2.5) + +#' ((0 <= X[,1]) & (0.25 > X[,1])) * (-7.5) + +#' ((0.25 <= X[,1]) & (0.5 > X[,1])) * (-2.5) + +#' ((0.5 <= X[,1]) & (0.75 > X[,1])) * (2.5) + #' ((0.75 <= X[,1]) & (1 > X[,1])) * (7.5) #' ) #' noise_sd <- 1 @@ -1746,22 +2513,22 @@ createBARTModelFromJsonFile <- function(json_filename){ #' X_train <- X[train_inds,] #' y_test <- y[test_inds] #' y_train <- y[train_inds] -#' bart_model <- bart(X_train = X_train, y_train = y_train, +#' bart_model <- bart(X_train = X_train, y_train = y_train, #' num_gfr = 10, num_burnin = 0, num_mcmc = 10) #' bart_json <- saveBARTModelToJsonString(bart_model) #' bart_model_roundtrip <- createBARTModelFromJsonString(bart_json) #' y_hat_mean_roundtrip <- rowMeans(predict(bart_model_roundtrip, X_train)$y_hat) -createBARTModelFromJsonString <- function(json_string){ +createBARTModelFromJsonString <- function(json_string) { # Load a `CppJson` object from string bart_json <- createCppJsonString(json_string) - + # Create and return the BART object bart_object <- createBARTModelFromJson(bart_json) - + return(bart_object) } -#' Convert a list of (in-memory) JSON representations of a BART model to a single combined BART model object +#' Convert a list of (in-memory) JSON representations of a BART model to a single combined BART model object #' which can be used for prediction, etc... #' #' @param json_object_list List of objects of type `CppJson` containing Json representation of a BART model @@ -1774,9 +2541,9 @@ createBARTModelFromJsonString <- function(json_string){ #' p <- 5 #' X <- matrix(runif(n*p), ncol = p) #' f_XW <- ( -#' ((0 <= X[,1]) & (0.25 > X[,1])) * (-7.5) + -#' ((0.25 <= X[,1]) & (0.5 > X[,1])) * (-2.5) + -#' ((0.5 <= X[,1]) & (0.75 > X[,1])) * (2.5) + +#' ((0 <= X[,1]) & (0.25 > X[,1])) * (-7.5) + +#' ((0.25 <= X[,1]) & (0.5 > X[,1])) * (-2.5) + +#' ((0.5 <= X[,1]) & (0.75 > X[,1])) * (2.5) + #' ((0.75 <= X[,1]) & (1 > X[,1])) * (7.5) #' ) #' noise_sd <- 1 @@ -1790,65 +2557,122 @@ createBARTModelFromJsonString <- function(json_string){ #' X_train <- X[train_inds,] #' y_test <- y[test_inds] #' y_train <- y[train_inds] -#' bart_model <- bart(X_train = X_train, y_train = y_train, +#' bart_model <- bart(X_train = X_train, y_train = y_train, #' num_gfr = 10, num_burnin = 0, num_mcmc = 10) #' bart_json <- list(saveBARTModelToJson(bart_model)) #' bart_model_roundtrip <- createBARTModelFromCombinedJson(bart_json) -createBARTModelFromCombinedJson <- function(json_object_list){ +createBARTModelFromCombinedJson <- function(json_object_list) { # Initialize the BCF model output <- list() - # For scalar / preprocessing details which aren't sample-dependent, + # For scalar / preprocessing details which aren't sample-dependent, # defer to the first json json_object_default <- json_object_list[[1]] - + # Unpack the forests - include_mean_forest <- json_object_default$get_boolean("include_mean_forest") - include_variance_forest <- json_object_default$get_boolean("include_variance_forest") + include_mean_forest <- json_object_default$get_boolean( + "include_mean_forest" + ) + include_variance_forest <- json_object_default$get_boolean( + "include_variance_forest" + ) if (include_mean_forest) { - output[["mean_forests"]] <- loadForestContainerCombinedJson(json_object_list, "forest_0") + output[["mean_forests"]] <- loadForestContainerCombinedJson( + json_object_list, + "forest_0" + ) if (include_variance_forest) { - output[["variance_forests"]] <- loadForestContainerCombinedJson(json_object_list, "forest_1") + output[["variance_forests"]] <- loadForestContainerCombinedJson( + json_object_list, + "forest_1" + ) } } else { - output[["variance_forests"]] <- loadForestContainerCombinedJson(json_object_list, "forest_0") + output[["variance_forests"]] <- loadForestContainerCombinedJson( + json_object_list, + "forest_0" + ) } - + # Unpack metadata train_set_metadata = list() - train_set_metadata[["num_numeric_vars"]] <- json_object_default$get_scalar("num_numeric_vars") - train_set_metadata[["num_ordered_cat_vars"]] <- json_object_default$get_scalar("num_ordered_cat_vars") - train_set_metadata[["num_unordered_cat_vars"]] <- json_object_default$get_scalar("num_unordered_cat_vars") + train_set_metadata[["num_numeric_vars"]] <- json_object_default$get_scalar( + "num_numeric_vars" + ) + train_set_metadata[[ + "num_ordered_cat_vars" + ]] <- json_object_default$get_scalar("num_ordered_cat_vars") + train_set_metadata[[ + "num_unordered_cat_vars" + ]] <- json_object_default$get_scalar("num_unordered_cat_vars") if (train_set_metadata[["num_numeric_vars"]] > 0) { - train_set_metadata[["numeric_vars"]] <- json_object_default$get_string_vector("numeric_vars") + train_set_metadata[[ + "numeric_vars" + ]] <- json_object_default$get_string_vector("numeric_vars") } if (train_set_metadata[["num_ordered_cat_vars"]] > 0) { - train_set_metadata[["ordered_cat_vars"]] <- json_object_default$get_string_vector("ordered_cat_vars") - train_set_metadata[["ordered_unique_levels"]] <- json_object_default$get_string_list("ordered_unique_levels", train_set_metadata[["ordered_cat_vars"]]) + train_set_metadata[[ + "ordered_cat_vars" + ]] <- json_object_default$get_string_vector("ordered_cat_vars") + train_set_metadata[[ + "ordered_unique_levels" + ]] <- json_object_default$get_string_list( + "ordered_unique_levels", + train_set_metadata[["ordered_cat_vars"]] + ) } if (train_set_metadata[["num_unordered_cat_vars"]] > 0) { - train_set_metadata[["unordered_cat_vars"]] <- json_object_default$get_string_vector("unordered_cat_vars") - train_set_metadata[["unordered_unique_levels"]] <- json_object_default$get_string_list("unordered_unique_levels", train_set_metadata[["unordered_cat_vars"]]) + train_set_metadata[[ + "unordered_cat_vars" + ]] <- json_object_default$get_string_vector("unordered_cat_vars") + train_set_metadata[[ + "unordered_unique_levels" + ]] <- json_object_default$get_string_list( + "unordered_unique_levels", + train_set_metadata[["unordered_cat_vars"]] + ) } output[["train_set_metadata"]] <- train_set_metadata # Unpack model params model_params = list() - model_params[["outcome_scale"]] <- json_object_default$get_scalar("outcome_scale") - model_params[["outcome_mean"]] <- json_object_default$get_scalar("outcome_mean") - model_params[["standardize"]] <- json_object_default$get_boolean("standardize") - model_params[["sigma2_init"]] <- json_object_default$get_scalar("sigma2_init") - model_params[["sample_sigma2_global"]] <- json_object_default$get_boolean("sample_sigma2_global") - model_params[["sample_sigma2_leaf"]] <- json_object_default$get_boolean("sample_sigma2_leaf") + model_params[["outcome_scale"]] <- json_object_default$get_scalar( + "outcome_scale" + ) + model_params[["outcome_mean"]] <- json_object_default$get_scalar( + "outcome_mean" + ) + model_params[["standardize"]] <- json_object_default$get_boolean( + "standardize" + ) + model_params[["sigma2_init"]] <- json_object_default$get_scalar( + "sigma2_init" + ) + model_params[["sample_sigma2_global"]] <- json_object_default$get_boolean( + "sample_sigma2_global" + ) + model_params[["sample_sigma2_leaf"]] <- json_object_default$get_boolean( + "sample_sigma2_leaf" + ) model_params[["include_mean_forest"]] <- include_mean_forest model_params[["include_variance_forest"]] <- include_variance_forest model_params[["has_rfx"]] <- json_object_default$get_boolean("has_rfx") - model_params[["has_rfx_basis"]] <- json_object_default$get_boolean("has_rfx_basis") - model_params[["num_rfx_basis"]] <- json_object_default$get_scalar("num_rfx_basis") - model_params[["num_covariates"]] <- json_object_default$get_scalar("num_covariates") + model_params[["has_rfx_basis"]] <- json_object_default$get_boolean( + "has_rfx_basis" + ) + model_params[["num_rfx_basis"]] <- json_object_default$get_scalar( + "num_rfx_basis" + ) + model_params[["num_covariates"]] <- json_object_default$get_scalar( + "num_covariates" + ) model_params[["num_basis"]] <- json_object_default$get_scalar("num_basis") - model_params[["requires_basis"]] <- json_object_default$get_boolean("requires_basis") - model_params[["probit_outcome_model"]] <- json_object_default$get_boolean("probit_outcome_model") + model_params[["requires_basis"]] <- json_object_default$get_boolean( + "requires_basis" + ) + model_params[["probit_outcome_model"]] <- json_object_default$get_boolean( + "probit_outcome_model" + ) model_params[["num_chains"]] <- json_object_default$get_scalar("num_chains") model_params[["keep_every"]] <- json_object_default$get_scalar("keep_every") @@ -1859,25 +2683,40 @@ createBARTModelFromCombinedJson <- function(json_object_list){ model_params[["num_gfr"]] <- json_object$get_scalar("num_gfr") model_params[["num_burnin"]] <- json_object$get_scalar("num_burnin") model_params[["num_mcmc"]] <- json_object$get_scalar("num_mcmc") - model_params[["num_samples"]] <- json_object$get_scalar("num_samples") + model_params[["num_samples"]] <- json_object$get_scalar( + "num_samples" + ) } else { - prev_json <- json_object_list[[i-1]] - model_params[["num_gfr"]] <- model_params[["num_gfr"]] + json_object$get_scalar("num_gfr") - model_params[["num_burnin"]] <- model_params[["num_burnin"]] + json_object$get_scalar("num_burnin") - model_params[["num_mcmc"]] <- model_params[["num_mcmc"]] + json_object$get_scalar("num_mcmc") - model_params[["num_samples"]] <- model_params[["num_samples"]] + json_object$get_scalar("num_samples") + prev_json <- json_object_list[[i - 1]] + model_params[["num_gfr"]] <- model_params[["num_gfr"]] + + json_object$get_scalar("num_gfr") + model_params[["num_burnin"]] <- model_params[["num_burnin"]] + + json_object$get_scalar("num_burnin") + model_params[["num_mcmc"]] <- model_params[["num_mcmc"]] + + json_object$get_scalar("num_mcmc") + model_params[["num_samples"]] <- model_params[["num_samples"]] + + json_object$get_scalar("num_samples") } } output[["model_params"]] <- model_params - + # Unpack sampled parameters if (model_params[["sample_sigma2_global"]]) { for (i in 1:length(json_object_list)) { json_object <- json_object_list[[i]] if (i == 1) { - output[["sigma2_global_samples"]] <- json_object$get_vector("sigma2_global_samples", "parameters") + output[["sigma2_global_samples"]] <- json_object$get_vector( + "sigma2_global_samples", + "parameters" + ) } else { - output[["sigma2_global_samples"]] <- c(output[["sigma2_global_samples"]], json_object$get_vector("sigma2_global_samples", "parameters")) + output[["sigma2_global_samples"]] <- c( + output[["sigma2_global_samples"]], + json_object$get_vector( + "sigma2_global_samples", + "parameters" + ) + ) } } } @@ -1885,30 +2724,43 @@ createBARTModelFromCombinedJson <- function(json_object_list){ for (i in 1:length(json_object_list)) { json_object <- json_object_list[[i]] if (i == 1) { - output[["sigma2_leaf_samples"]] <- json_object$get_vector("sigma2_leaf_samples", "parameters") + output[["sigma2_leaf_samples"]] <- json_object$get_vector( + "sigma2_leaf_samples", + "parameters" + ) } else { - output[["sigma2_leaf_samples"]] <- c(output[["sigma2_leaf_samples"]], json_object$get_vector("sigma2_leaf_samples", "parameters")) + output[["sigma2_leaf_samples"]] <- c( + output[["sigma2_leaf_samples"]], + json_object$get_vector("sigma2_leaf_samples", "parameters") + ) } } } - + # Unpack random effects if (model_params[["has_rfx"]]) { - output[["rfx_unique_group_ids"]] <- json_object_default$get_string_vector("rfx_unique_group_ids") - output[["rfx_samples"]] <- loadRandomEffectSamplesCombinedJson(json_object_list, 0) + output[[ + "rfx_unique_group_ids" + ]] <- json_object_default$get_string_vector("rfx_unique_group_ids") + output[["rfx_samples"]] <- loadRandomEffectSamplesCombinedJson( + json_object_list, + 0 + ) } - + # Unpack covariate preprocessor - preprocessor_metadata_string <- json_object$get_string("preprocessor_metadata") + preprocessor_metadata_string <- json_object$get_string( + "preprocessor_metadata" + ) output[["train_set_metadata"]] <- createPreprocessorFromJsonString( preprocessor_metadata_string ) - + class(output) <- "bartmodel" return(output) } -#' Convert a list of (in-memory) JSON strings that represent BART models to a single combined BART model object +#' Convert a list of (in-memory) JSON strings that represent BART models to a single combined BART model object #' which can be used for prediction, etc... #' #' @param json_string_list List of JSON strings which can be parsed to objects of type `CppJson` containing Json representation of a BART model @@ -1921,9 +2773,9 @@ createBARTModelFromCombinedJson <- function(json_object_list){ #' p <- 5 #' X <- matrix(runif(n*p), ncol = p) #' f_XW <- ( -#' ((0 <= X[,1]) & (0.25 > X[,1])) * (-7.5) + -#' ((0.25 <= X[,1]) & (0.5 > X[,1])) * (-2.5) + -#' ((0.5 <= X[,1]) & (0.75 > X[,1])) * (2.5) + +#' ((0 <= X[,1]) & (0.25 > X[,1])) * (-7.5) + +#' ((0.25 <= X[,1]) & (0.5 > X[,1])) * (-2.5) + +#' ((0.5 <= X[,1]) & (0.75 > X[,1])) * (2.5) + #' ((0.75 <= X[,1]) & (1 > X[,1])) * (7.5) #' ) #' noise_sd <- 1 @@ -1937,75 +2789,132 @@ createBARTModelFromCombinedJson <- function(json_object_list){ #' X_train <- X[train_inds,] #' y_test <- y[test_inds] #' y_train <- y[train_inds] -#' bart_model <- bart(X_train = X_train, y_train = y_train, +#' bart_model <- bart(X_train = X_train, y_train = y_train, #' num_gfr = 10, num_burnin = 0, num_mcmc = 10) #' bart_json_string_list <- list(saveBARTModelToJsonString(bart_model)) #' bart_model_roundtrip <- createBARTModelFromCombinedJsonString(bart_json_string_list) -createBARTModelFromCombinedJsonString <- function(json_string_list){ +createBARTModelFromCombinedJsonString <- function(json_string_list) { # Initialize the BCF model output <- list() - + # Convert JSON strings json_object_list <- list() for (i in 1:length(json_string_list)) { json_string <- json_string_list[[i]] json_object_list[[i]] <- createCppJsonString(json_string) } - - # For scalar / preprocessing details which aren't sample-dependent, + + # For scalar / preprocessing details which aren't sample-dependent, # defer to the first json json_object_default <- json_object_list[[1]] - + # Unpack the forests - include_mean_forest <- json_object_default$get_boolean("include_mean_forest") - include_variance_forest <- json_object_default$get_boolean("include_variance_forest") + include_mean_forest <- json_object_default$get_boolean( + "include_mean_forest" + ) + include_variance_forest <- json_object_default$get_boolean( + "include_variance_forest" + ) if (include_mean_forest) { - output[["mean_forests"]] <- loadForestContainerCombinedJson(json_object_list, "forest_0") + output[["mean_forests"]] <- loadForestContainerCombinedJson( + json_object_list, + "forest_0" + ) if (include_variance_forest) { - output[["variance_forests"]] <- loadForestContainerCombinedJson(json_object_list, "forest_1") + output[["variance_forests"]] <- loadForestContainerCombinedJson( + json_object_list, + "forest_1" + ) } } else { - output[["variance_forests"]] <- loadForestContainerCombinedJson(json_object_list, "forest_0") + output[["variance_forests"]] <- loadForestContainerCombinedJson( + json_object_list, + "forest_0" + ) } - + # Unpack metadata train_set_metadata = list() - train_set_metadata[["num_numeric_vars"]] <- json_object_default$get_scalar("num_numeric_vars") - train_set_metadata[["num_ordered_cat_vars"]] <- json_object_default$get_scalar("num_ordered_cat_vars") - train_set_metadata[["num_unordered_cat_vars"]] <- json_object_default$get_scalar("num_unordered_cat_vars") + train_set_metadata[["num_numeric_vars"]] <- json_object_default$get_scalar( + "num_numeric_vars" + ) + train_set_metadata[[ + "num_ordered_cat_vars" + ]] <- json_object_default$get_scalar("num_ordered_cat_vars") + train_set_metadata[[ + "num_unordered_cat_vars" + ]] <- json_object_default$get_scalar("num_unordered_cat_vars") if (train_set_metadata[["num_numeric_vars"]] > 0) { - train_set_metadata[["numeric_vars"]] <- json_object_default$get_string_vector("numeric_vars") + train_set_metadata[[ + "numeric_vars" + ]] <- json_object_default$get_string_vector("numeric_vars") } if (train_set_metadata[["num_ordered_cat_vars"]] > 0) { - train_set_metadata[["ordered_cat_vars"]] <- json_object_default$get_string_vector("ordered_cat_vars") - train_set_metadata[["ordered_unique_levels"]] <- json_object_default$get_string_list("ordered_unique_levels", train_set_metadata[["ordered_cat_vars"]]) + train_set_metadata[[ + "ordered_cat_vars" + ]] <- json_object_default$get_string_vector("ordered_cat_vars") + train_set_metadata[[ + "ordered_unique_levels" + ]] <- json_object_default$get_string_list( + "ordered_unique_levels", + train_set_metadata[["ordered_cat_vars"]] + ) } if (train_set_metadata[["num_unordered_cat_vars"]] > 0) { - train_set_metadata[["unordered_cat_vars"]] <- json_object_default$get_string_vector("unordered_cat_vars") - train_set_metadata[["unordered_unique_levels"]] <- json_object_default$get_string_list("unordered_unique_levels", train_set_metadata[["unordered_cat_vars"]]) + train_set_metadata[[ + "unordered_cat_vars" + ]] <- json_object_default$get_string_vector("unordered_cat_vars") + train_set_metadata[[ + "unordered_unique_levels" + ]] <- json_object_default$get_string_list( + "unordered_unique_levels", + train_set_metadata[["unordered_cat_vars"]] + ) } output[["train_set_metadata"]] <- train_set_metadata # Unpack model params model_params = list() - model_params[["outcome_scale"]] <- json_object_default$get_scalar("outcome_scale") - model_params[["outcome_mean"]] <- json_object_default$get_scalar("outcome_mean") - model_params[["standardize"]] <- json_object_default$get_boolean("standardize") - model_params[["sigma2_init"]] <- json_object_default$get_scalar("sigma2_init") - model_params[["sample_sigma2_global"]] <- json_object_default$get_boolean("sample_sigma2_global") - model_params[["sample_sigma2_leaf"]] <- json_object_default$get_boolean("sample_sigma2_leaf") + model_params[["outcome_scale"]] <- json_object_default$get_scalar( + "outcome_scale" + ) + model_params[["outcome_mean"]] <- json_object_default$get_scalar( + "outcome_mean" + ) + model_params[["standardize"]] <- json_object_default$get_boolean( + "standardize" + ) + model_params[["sigma2_init"]] <- json_object_default$get_scalar( + "sigma2_init" + ) + model_params[["sample_sigma2_global"]] <- json_object_default$get_boolean( + "sample_sigma2_global" + ) + model_params[["sample_sigma2_leaf"]] <- json_object_default$get_boolean( + "sample_sigma2_leaf" + ) model_params[["include_mean_forest"]] <- include_mean_forest model_params[["include_variance_forest"]] <- include_variance_forest model_params[["has_rfx"]] <- json_object_default$get_boolean("has_rfx") - model_params[["has_rfx_basis"]] <- json_object_default$get_boolean("has_rfx_basis") - model_params[["num_rfx_basis"]] <- json_object_default$get_scalar("num_rfx_basis") - model_params[["num_covariates"]] <- json_object_default$get_scalar("num_covariates") + model_params[["has_rfx_basis"]] <- json_object_default$get_boolean( + "has_rfx_basis" + ) + model_params[["num_rfx_basis"]] <- json_object_default$get_scalar( + "num_rfx_basis" + ) + model_params[["num_covariates"]] <- json_object_default$get_scalar( + "num_covariates" + ) model_params[["num_basis"]] <- json_object_default$get_scalar("num_basis") model_params[["num_chains"]] <- json_object_default$get_scalar("num_chains") model_params[["keep_every"]] <- json_object_default$get_scalar("keep_every") - model_params[["requires_basis"]] <- json_object_default$get_boolean("requires_basis") - model_params[["probit_outcome_model"]] <- json_object_default$get_boolean("probit_outcome_model") - + model_params[["requires_basis"]] <- json_object_default$get_boolean( + "requires_basis" + ) + model_params[["probit_outcome_model"]] <- json_object_default$get_boolean( + "probit_outcome_model" + ) + # Combine values that are sample-specific for (i in 1:length(json_object_list)) { json_object <- json_object_list[[i]] @@ -2013,25 +2922,40 @@ createBARTModelFromCombinedJsonString <- function(json_string_list){ model_params[["num_gfr"]] <- json_object$get_scalar("num_gfr") model_params[["num_burnin"]] <- json_object$get_scalar("num_burnin") model_params[["num_mcmc"]] <- json_object$get_scalar("num_mcmc") - model_params[["num_samples"]] <- json_object$get_scalar("num_samples") + model_params[["num_samples"]] <- json_object$get_scalar( + "num_samples" + ) } else { - prev_json <- json_object_list[[i-1]] - model_params[["num_gfr"]] <- model_params[["num_gfr"]] + json_object$get_scalar("num_gfr") - model_params[["num_burnin"]] <- model_params[["num_burnin"]] + json_object$get_scalar("num_burnin") - model_params[["num_mcmc"]] <- model_params[["num_mcmc"]] + json_object$get_scalar("num_mcmc") - model_params[["num_samples"]] <- model_params[["num_samples"]] + json_object$get_scalar("num_samples") + prev_json <- json_object_list[[i - 1]] + model_params[["num_gfr"]] <- model_params[["num_gfr"]] + + json_object$get_scalar("num_gfr") + model_params[["num_burnin"]] <- model_params[["num_burnin"]] + + json_object$get_scalar("num_burnin") + model_params[["num_mcmc"]] <- model_params[["num_mcmc"]] + + json_object$get_scalar("num_mcmc") + model_params[["num_samples"]] <- model_params[["num_samples"]] + + json_object$get_scalar("num_samples") } } output[["model_params"]] <- model_params - + # Unpack sampled parameters if (model_params[["sample_sigma2_global"]]) { for (i in 1:length(json_object_list)) { json_object <- json_object_list[[i]] if (i == 1) { - output[["sigma2_global_samples"]] <- json_object$get_vector("sigma2_global_samples", "parameters") + output[["sigma2_global_samples"]] <- json_object$get_vector( + "sigma2_global_samples", + "parameters" + ) } else { - output[["sigma2_global_samples"]] <- c(output[["sigma2_global_samples"]], json_object$get_vector("sigma2_global_samples", "parameters")) + output[["sigma2_global_samples"]] <- c( + output[["sigma2_global_samples"]], + json_object$get_vector( + "sigma2_global_samples", + "parameters" + ) + ) } } } @@ -2039,25 +2963,38 @@ createBARTModelFromCombinedJsonString <- function(json_string_list){ for (i in 1:length(json_object_list)) { json_object <- json_object_list[[i]] if (i == 1) { - output[["sigma2_leaf_samples"]] <- json_object$get_vector("sigma2_leaf_samples", "parameters") + output[["sigma2_leaf_samples"]] <- json_object$get_vector( + "sigma2_leaf_samples", + "parameters" + ) } else { - output[["sigma2_leaf_samples"]] <- c(output[["sigma2_leaf_samples"]], json_object$get_vector("sigma2_leaf_samples", "parameters")) + output[["sigma2_leaf_samples"]] <- c( + output[["sigma2_leaf_samples"]], + json_object$get_vector("sigma2_leaf_samples", "parameters") + ) } } } - + # Unpack random effects if (model_params[["has_rfx"]]) { - output[["rfx_unique_group_ids"]] <- json_object_default$get_string_vector("rfx_unique_group_ids") - output[["rfx_samples"]] <- loadRandomEffectSamplesCombinedJson(json_object_list, 0) + output[[ + "rfx_unique_group_ids" + ]] <- json_object_default$get_string_vector("rfx_unique_group_ids") + output[["rfx_samples"]] <- loadRandomEffectSamplesCombinedJson( + json_object_list, + 0 + ) } - + # Unpack covariate preprocessor - preprocessor_metadata_string <- json_object_default$get_string("preprocessor_metadata") + preprocessor_metadata_string <- json_object_default$get_string( + "preprocessor_metadata" + ) output[["train_set_metadata"]] <- createPreprocessorFromJsonString( preprocessor_metadata_string ) - + class(output) <- "bartmodel" return(output) } diff --git a/R/bcf.R b/R/bcf.R index 74e9476f..bd5fd5b1 100644 --- a/R/bcf.R +++ b/R/bcf.R @@ -1,23 +1,23 @@ -#' Run the Bayesian Causal Forest (BCF) algorithm for regularized causal effect estimation. +#' Run the Bayesian Causal Forest (BCF) algorithm for regularized causal effect estimation. #' -#' @param X_train Covariates used to split trees in the ensemble. May be provided either as a dataframe or a matrix. -#' Matrix covariates will be assumed to be all numeric. Covariates passed as a dataframe will be -#' preprocessed based on the variable types (e.g. categorical columns stored as unordered factors will be one-hot encoded, -#' categorical columns stored as ordered factors will passed as integers to the core algorithm, along with the metadata +#' @param X_train Covariates used to split trees in the ensemble. May be provided either as a dataframe or a matrix. +#' Matrix covariates will be assumed to be all numeric. Covariates passed as a dataframe will be +#' preprocessed based on the variable types (e.g. categorical columns stored as unordered factors will be one-hot encoded, +#' categorical columns stored as ordered factors will passed as integers to the core algorithm, along with the metadata #' that the column is ordered categorical). #' @param Z_train Vector of (continuous or binary) treatment assignments. #' @param y_train Outcome to be modeled by the ensemble. #' @param propensity_train (Optional) Vector of propensity scores. If not provided, this will be estimated from the data. #' @param rfx_group_ids_train (Optional) Group labels used for an additive random effects model. #' @param rfx_basis_train (Optional) Basis for "random-slope" regression in an additive random effects model. -#' If `rfx_group_ids_train` is provided with a regression basis, an intercept-only random effects model +#' If `rfx_group_ids_train` is provided with a regression basis, an intercept-only random effects model #' will be estimated. -#' @param X_test (Optional) Test set of covariates used to define "out of sample" evaluation data. -#' May be provided either as a dataframe or a matrix, but the format of `X_test` must be consistent with +#' @param X_test (Optional) Test set of covariates used to define "out of sample" evaluation data. +#' May be provided either as a dataframe or a matrix, but the format of `X_test` must be consistent with #' that of `X_train`. #' @param Z_test (Optional) Test set of (continuous or binary) treatment assignments. #' @param propensity_test (Optional) Vector of propensity scores. If not provided, this will be estimated from the data. -#' @param rfx_group_ids_test (Optional) Test set group labels used for an additive random effects model. +#' @param rfx_group_ids_test (Optional) Test set group labels used for an additive random effects model. #' We do not currently support (but plan to in the near future), test set evaluation for group labels #' that were not in the training set. #' @param rfx_basis_test (Optional) Test set basis for "random-slope" regression in additive random effects model. @@ -111,21 +111,21 @@ #' p <- 5 #' X <- matrix(runif(n*p), ncol = p) #' mu_x <- ( -#' ((0 <= X[,1]) & (0.25 > X[,1])) * (-7.5) + -#' ((0.25 <= X[,1]) & (0.5 > X[,1])) * (-2.5) + -#' ((0.5 <= X[,1]) & (0.75 > X[,1])) * (2.5) + +#' ((0 <= X[,1]) & (0.25 > X[,1])) * (-7.5) + +#' ((0.25 <= X[,1]) & (0.5 > X[,1])) * (-2.5) + +#' ((0.5 <= X[,1]) & (0.75 > X[,1])) * (2.5) + #' ((0.75 <= X[,1]) & (1 > X[,1])) * (7.5) #' ) #' pi_x <- ( -#' ((0 <= X[,1]) & (0.25 > X[,1])) * (0.2) + -#' ((0.25 <= X[,1]) & (0.5 > X[,1])) * (0.4) + -#' ((0.5 <= X[,1]) & (0.75 > X[,1])) * (0.6) + +#' ((0 <= X[,1]) & (0.25 > X[,1])) * (0.2) + +#' ((0.25 <= X[,1]) & (0.5 > X[,1])) * (0.4) + +#' ((0.5 <= X[,1]) & (0.75 > X[,1])) * (0.6) + #' ((0.75 <= X[,1]) & (1 > X[,1])) * (0.8) #' ) #' tau_x <- ( -#' ((0 <= X[,2]) & (0.25 > X[,2])) * (0.5) + -#' ((0.25 <= X[,2]) & (0.5 > X[,2])) * (1.0) + -#' ((0.5 <= X[,2]) & (0.75 > X[,2])) * (1.5) + +#' ((0 <= X[,2]) & (0.25 > X[,2])) * (0.5) + +#' ((0.25 <= X[,2]) & (0.5 > X[,2])) * (1.0) + +#' ((0.5 <= X[,2]) & (0.75 > X[,2])) * (1.5) + #' ((0.75 <= X[,2]) & (1 > X[,2])) * (2.0) #' ) #' Z <- rbinom(n, 1, pi_x) @@ -148,81 +148,127 @@ #' mu_train <- mu_x[train_inds] #' tau_test <- tau_x[test_inds] #' tau_train <- tau_x[train_inds] -#' bcf_model <- bcf(X_train = X_train, Z_train = Z_train, y_train = y_train, -#' propensity_train = pi_train, X_test = X_test, Z_test = Z_test, -#' propensity_test = pi_test, num_gfr = 10, +#' bcf_model <- bcf(X_train = X_train, Z_train = Z_train, y_train = y_train, +#' propensity_train = pi_train, X_test = X_test, Z_test = Z_test, +#' propensity_test = pi_test, num_gfr = 10, #' num_burnin = 0, num_mcmc = 10) -bcf <- function(X_train, Z_train, y_train, propensity_train = NULL, rfx_group_ids_train = NULL, - rfx_basis_train = NULL, X_test = NULL, Z_test = NULL, propensity_test = NULL, - rfx_group_ids_test = NULL, rfx_basis_test = NULL, - num_gfr = 5, num_burnin = 0, num_mcmc = 100, - previous_model_json = NULL, previous_model_warmstart_sample_num = NULL, - general_params = list(), prognostic_forest_params = list(), - treatment_effect_forest_params = list(), variance_forest_params = list()) { +bcf <- function( + X_train, + Z_train, + y_train, + propensity_train = NULL, + rfx_group_ids_train = NULL, + rfx_basis_train = NULL, + X_test = NULL, + Z_test = NULL, + propensity_test = NULL, + rfx_group_ids_test = NULL, + rfx_basis_test = NULL, + num_gfr = 5, + num_burnin = 0, + num_mcmc = 100, + previous_model_json = NULL, + previous_model_warmstart_sample_num = NULL, + general_params = list(), + prognostic_forest_params = list(), + treatment_effect_forest_params = list(), + variance_forest_params = list() +) { # Update general BCF parameters general_params_default <- list( - cutpoint_grid_size = 100, standardize = TRUE, - sample_sigma2_global = TRUE, sigma2_global_init = NULL, - sigma2_global_shape = 0, sigma2_global_scale = 0, - variable_weights = NULL, propensity_covariate = "mu", - adaptive_coding = TRUE, control_coding_init = -0.5, - treated_coding_init = 0.5, rfx_prior_var = NULL, - random_seed = -1, keep_burnin = FALSE, keep_gfr = FALSE, - keep_every = 1, num_chains = 1, verbose = FALSE, - probit_outcome_model = FALSE, + cutpoint_grid_size = 100, + standardize = TRUE, + sample_sigma2_global = TRUE, + sigma2_global_init = NULL, + sigma2_global_shape = 0, + sigma2_global_scale = 0, + variable_weights = NULL, + propensity_covariate = "mu", + adaptive_coding = TRUE, + control_coding_init = -0.5, + treated_coding_init = 0.5, + rfx_prior_var = NULL, + random_seed = -1, + keep_burnin = FALSE, + keep_gfr = FALSE, + keep_every = 1, + num_chains = 1, + verbose = FALSE, + probit_outcome_model = FALSE, rfx_working_parameter_prior_mean = NULL, rfx_group_parameter_prior_mean = NULL, rfx_working_parameter_prior_cov = NULL, rfx_group_parameter_prior_cov = NULL, rfx_variance_prior_shape = 1, - rfx_variance_prior_scale = 1, + rfx_variance_prior_scale = 1, num_threads = -1 ) general_params_updated <- preprocessParams( - general_params_default, general_params + general_params_default, + general_params ) # Update mu forest BCF parameters prognostic_forest_params_default <- list( - num_trees = 250, alpha = 0.95, beta = 2.0, - min_samples_leaf = 5, max_depth = 10, - sample_sigma2_leaf = TRUE, sigma2_leaf_init = NULL, - sigma2_leaf_shape = 3, sigma2_leaf_scale = NULL, - keep_vars = NULL, drop_vars = NULL, + num_trees = 250, + alpha = 0.95, + beta = 2.0, + min_samples_leaf = 5, + max_depth = 10, + sample_sigma2_leaf = TRUE, + sigma2_leaf_init = NULL, + sigma2_leaf_shape = 3, + sigma2_leaf_scale = NULL, + keep_vars = NULL, + drop_vars = NULL, num_features_subsample = NULL ) prognostic_forest_params_updated <- preprocessParams( - prognostic_forest_params_default, prognostic_forest_params + prognostic_forest_params_default, + prognostic_forest_params ) - + # Update tau forest BCF parameters treatment_effect_forest_params_default <- list( - num_trees = 50, alpha = 0.25, beta = 3.0, - min_samples_leaf = 5, max_depth = 5, - sample_sigma2_leaf = FALSE, sigma2_leaf_init = NULL, - sigma2_leaf_shape = 3, sigma2_leaf_scale = NULL, - keep_vars = NULL, drop_vars = NULL, delta_max = 0.9, + num_trees = 50, + alpha = 0.25, + beta = 3.0, + min_samples_leaf = 5, + max_depth = 5, + sample_sigma2_leaf = FALSE, + sigma2_leaf_init = NULL, + sigma2_leaf_shape = 3, + sigma2_leaf_scale = NULL, + keep_vars = NULL, + drop_vars = NULL, + delta_max = 0.9, num_features_subsample = NULL ) treatment_effect_forest_params_updated <- preprocessParams( - treatment_effect_forest_params_default, treatment_effect_forest_params + treatment_effect_forest_params_default, + treatment_effect_forest_params ) - + # Update variance forest BCF parameters variance_forest_params_default <- list( - num_trees = 0, alpha = 0.95, beta = 2.0, - min_samples_leaf = 5, max_depth = 10, - leaf_prior_calibration_param = 1.5, - variance_forest_init = NULL, - var_forest_prior_shape = NULL, - var_forest_prior_scale = NULL, - keep_vars = NULL, drop_vars = NULL, + num_trees = 0, + alpha = 0.95, + beta = 2.0, + min_samples_leaf = 5, + max_depth = 10, + leaf_prior_calibration_param = 1.5, + variance_forest_init = NULL, + var_forest_prior_shape = NULL, + var_forest_prior_scale = NULL, + keep_vars = NULL, + drop_vars = NULL, num_features_subsample = NULL ) variance_forest_params_updated <- preprocessParams( - variance_forest_params_default, variance_forest_params + variance_forest_params_default, + variance_forest_params ) - + ### Unpack all parameter values # 1. General parameters cutpoint_grid_size <- general_params_updated$cutpoint_grid_size @@ -251,7 +297,7 @@ bcf <- function(X_train, Z_train, y_train, propensity_train = NULL, rfx_group_id rfx_variance_prior_shape <- general_params_updated$rfx_variance_prior_shape rfx_variance_prior_scale <- general_params_updated$rfx_variance_prior_scale num_threads <- general_params_updated$num_threads - + # 2. Mu forest parameters num_trees_mu <- prognostic_forest_params_updated$num_trees alpha_mu <- prognostic_forest_params_updated$alpha @@ -265,7 +311,7 @@ bcf <- function(X_train, Z_train, y_train, propensity_train = NULL, rfx_group_id keep_vars_mu <- prognostic_forest_params_updated$keep_vars drop_vars_mu <- prognostic_forest_params_updated$drop_vars num_features_subsample_mu <- prognostic_forest_params_updated$num_features_subsample - + # 3. Tau forest parameters num_trees_tau <- treatment_effect_forest_params_updated$num_trees alpha_tau <- treatment_effect_forest_params_updated$alpha @@ -280,7 +326,7 @@ bcf <- function(X_train, Z_train, y_train, propensity_train = NULL, rfx_group_id drop_vars_tau <- treatment_effect_forest_params_updated$drop_vars delta_max <- treatment_effect_forest_params_updated$delta_max num_features_subsample_tau <- treatment_effect_forest_params_updated$num_features_subsample - + # 4. Variance forest parameters num_trees_variance <- variance_forest_params_updated$num_trees alpha_variance <- variance_forest_params_updated$alpha @@ -294,17 +340,21 @@ bcf <- function(X_train, Z_train, y_train, propensity_train = NULL, rfx_group_id keep_vars_variance <- variance_forest_params_updated$keep_vars drop_vars_variance <- variance_forest_params_updated$drop_vars num_features_subsample_variance <- variance_forest_params_updated$num_features_subsample - + # Check if there are enough GFR samples to seed num_chains samplers if (num_gfr > 0) { if (num_chains > num_gfr) { - stop("num_chains > num_gfr, meaning we do not have enough GFR samples to seed num_chains distinct MCMC chains") + stop( + "num_chains > num_gfr, meaning we do not have enough GFR samples to seed num_chains distinct MCMC chains" + ) } } # Override keep_gfr if there are no MCMC samples - if (num_mcmc == 0) keep_gfr <- TRUE - + if (num_mcmc == 0) { + keep_gfr <- TRUE + } + # Check if previous model JSON is provided and parse it if so has_prev_model <- !is.null(previous_model_json) if (has_prev_model) { @@ -315,21 +365,30 @@ bcf <- function(X_train, Z_train, y_train, propensity_train = NULL, rfx_group_id previous_forest_samples_tau <- previous_bcf_model$forests_tau if (previous_bcf_model$model_params$include_variance_forest) { previous_forest_samples_variance <- previous_bcf_model$forests_variance - } else previous_forest_samples_variance <- NULL + } else { + previous_forest_samples_variance <- NULL + } if (previous_bcf_model$model_params$sample_sigma2_global) { - previous_global_var_samples <- previous_bcf_model$sigma2_global_samples / ( - previous_y_scale*previous_y_scale - ) - } else previous_global_var_samples <- NULL + previous_global_var_samples <- previous_bcf_model$sigma2_global_samples / + (previous_y_scale * previous_y_scale) + } else { + previous_global_var_samples <- NULL + } if (previous_bcf_model$model_params$sample_sigma2_leaf_mu) { previous_leaf_var_mu_samples <- previous_bcf_model$sigma2_leaf_mu_samples - } else previous_leaf_var_mu_samples <- NULL + } else { + previous_leaf_var_mu_samples <- NULL + } if (previous_bcf_model$model_params$sample_sigma2_leaf_tau) { previous_leaf_var_tau_samples <- previous_bcf_model$sigma2_leaf_tau_samples - } else previous_leaf_var_tau_samples <- NULL + } else { + previous_leaf_var_tau_samples <- NULL + } if (previous_bcf_model$model_params$has_rfx) { previous_rfx_samples <- previous_bcf_model$rfx_samples - } else previous_rfx_samples <- NULL + } else { + previous_rfx_samples <- NULL + } if (previous_bcf_model$model_params$adaptive_coding) { previous_b_1_samples <- previous_bcf_model$b_1_samples previous_b_0_samples <- previous_bcf_model$b_0_samples @@ -339,7 +398,9 @@ bcf <- function(X_train, Z_train, y_train, propensity_train = NULL, rfx_group_id } previous_model_num_samples <- previous_bcf_model$model_params$num_samples if (previous_model_warmstart_sample_num > previous_model_num_samples) { - stop("`previous_model_warmstart_sample_num` exceeds the number of samples in `previous_model_json`") + stop( + "`previous_model_warmstart_sample_num` exceeds the number of samples in `previous_model_json`" + ) } } else { previous_y_bar <- NULL @@ -354,54 +415,65 @@ bcf <- function(X_train, Z_train, y_train, propensity_train = NULL, rfx_group_id previous_b_1_samples <- NULL previous_b_0_samples <- NULL } - + # Determine whether conditional variance will be modeled - if (num_trees_variance > 0) include_variance_forest = TRUE - else include_variance_forest = FALSE + if (num_trees_variance > 0) { + include_variance_forest = TRUE + } else { + include_variance_forest = FALSE + } # Set the variance forest priors if not set if (include_variance_forest) { - if (is.null(a_forest)) a_forest <- num_trees_variance / (a_0^2) + 0.5 + if (is.null(a_forest)) { + a_forest <- num_trees_variance / (a_0^2) + 0.5 + } if (is.null(b_forest)) b_forest <- num_trees_variance / (a_0^2) } else { a_forest <- 1. b_forest <- 1. } - + # Variable weight preprocessing (and initialization if necessary) if (is.null(variable_weights)) { - variable_weights = rep(1/ncol(X_train), ncol(X_train)) + variable_weights = rep(1 / ncol(X_train), ncol(X_train)) } if (any(variable_weights < 0)) { stop("variable_weights cannot have any negative weights") } - + # Check covariates are matrix or dataframe if ((!is.data.frame(X_train)) && (!is.matrix(X_train))) { stop("X_train must be a matrix or dataframe") } - if (!is.null(X_test)){ + if (!is.null(X_test)) { if ((!is.data.frame(X_test)) && (!is.matrix(X_test))) { stop("X_test must be a matrix or dataframe") } } num_cov_orig <- ncol(X_train) - + # Check delta_max is valid if ((delta_max <= 0) || (delta_max >= 1)) { stop("delta_max must be > 0 and < 1") } - + # Standardize the keep variable lists to numeric indices if (!is.null(keep_vars_mu)) { if (is.character(keep_vars_mu)) { if (!all(keep_vars_mu %in% names(X_train))) { - stop("keep_vars_mu includes some variable names that are not in X_train") + stop( + "keep_vars_mu includes some variable names that are not in X_train" + ) } - variable_subset_mu <- unname(which(names(X_train) %in% keep_vars_mu)) + variable_subset_mu <- unname(which( + names(X_train) %in% keep_vars_mu + )) } else { if (any(keep_vars_mu > ncol(X_train))) { - stop("keep_vars_mu includes some variable indices that exceed the number of columns in X_train") + stop( + "keep_vars_mu includes some variable indices that exceed the number of columns in X_train" + ) } if (any(keep_vars_mu < 0)) { stop("keep_vars_mu includes some negative variable indices") @@ -411,17 +483,25 @@ bcf <- function(X_train, Z_train, y_train, propensity_train = NULL, rfx_group_id } else if ((is.null(keep_vars_mu)) && (!is.null(drop_vars_mu))) { if (is.character(drop_vars_mu)) { if (!all(drop_vars_mu %in% names(X_train))) { - stop("drop_vars_mu includes some variable names that are not in X_train") + stop( + "drop_vars_mu includes some variable names that are not in X_train" + ) } - variable_subset_mu <- unname(which(!(names(X_train) %in% drop_vars_mu))) + variable_subset_mu <- unname(which( + !(names(X_train) %in% drop_vars_mu) + )) } else { if (any(drop_vars_mu > ncol(X_train))) { - stop("drop_vars_mu includes some variable indices that exceed the number of columns in X_train") + stop( + "drop_vars_mu includes some variable indices that exceed the number of columns in X_train" + ) } if (any(drop_vars_mu < 0)) { stop("drop_vars_mu includes some negative variable indices") } - variable_subset_mu <- (1:ncol(X_train))[!(1:ncol(X_train) %in% drop_vars_mu)] + variable_subset_mu <- (1:ncol(X_train))[ + !(1:ncol(X_train) %in% drop_vars_mu) + ] } } else { variable_subset_mu <- 1:ncol(X_train) @@ -429,12 +509,18 @@ bcf <- function(X_train, Z_train, y_train, propensity_train = NULL, rfx_group_id if (!is.null(keep_vars_tau)) { if (is.character(keep_vars_tau)) { if (!all(keep_vars_tau %in% names(X_train))) { - stop("keep_vars_tau includes some variable names that are not in X_train") + stop( + "keep_vars_tau includes some variable names that are not in X_train" + ) } - variable_subset_tau <- unname(which(names(X_train) %in% keep_vars_tau)) + variable_subset_tau <- unname(which( + names(X_train) %in% keep_vars_tau + )) } else { if (any(keep_vars_tau > ncol(X_train))) { - stop("keep_vars_tau includes some variable indices that exceed the number of columns in X_train") + stop( + "keep_vars_tau includes some variable indices that exceed the number of columns in X_train" + ) } if (any(keep_vars_tau < 0)) { stop("keep_vars_tau includes some negative variable indices") @@ -444,17 +530,25 @@ bcf <- function(X_train, Z_train, y_train, propensity_train = NULL, rfx_group_id } else if ((is.null(keep_vars_tau)) && (!is.null(drop_vars_tau))) { if (is.character(drop_vars_tau)) { if (!all(drop_vars_tau %in% names(X_train))) { - stop("drop_vars_tau includes some variable names that are not in X_train") + stop( + "drop_vars_tau includes some variable names that are not in X_train" + ) } - variable_subset_tau <- unname(which(!(names(X_train) %in% drop_vars_tau))) + variable_subset_tau <- unname(which( + !(names(X_train) %in% drop_vars_tau) + )) } else { if (any(drop_vars_tau > ncol(X_train))) { - stop("drop_vars_tau includes some variable indices that exceed the number of columns in X_train") + stop( + "drop_vars_tau includes some variable indices that exceed the number of columns in X_train" + ) } if (any(drop_vars_tau < 0)) { stop("drop_vars_tau includes some negative variable indices") } - variable_subset_tau <- (1:ncol(X_train))[!(1:ncol(X_train) %in% drop_vars_tau)] + variable_subset_tau <- (1:ncol(X_train))[ + !(1:ncol(X_train) %in% drop_vars_tau) + ] } } else { variable_subset_tau <- 1:ncol(X_train) @@ -462,37 +556,57 @@ bcf <- function(X_train, Z_train, y_train, propensity_train = NULL, rfx_group_id if (!is.null(keep_vars_variance)) { if (is.character(keep_vars_variance)) { if (!all(keep_vars_variance %in% names(X_train))) { - stop("keep_vars_variance includes some variable names that are not in X_train") + stop( + "keep_vars_variance includes some variable names that are not in X_train" + ) } - variable_subset_variance <- unname(which(names(X_train) %in% keep_vars_variance)) + variable_subset_variance <- unname(which( + names(X_train) %in% keep_vars_variance + )) } else { if (any(keep_vars_variance > ncol(X_train))) { - stop("keep_vars_variance includes some variable indices that exceed the number of columns in X_train") + stop( + "keep_vars_variance includes some variable indices that exceed the number of columns in X_train" + ) } if (any(keep_vars_variance < 0)) { - stop("keep_vars_variance includes some negative variable indices") + stop( + "keep_vars_variance includes some negative variable indices" + ) } variable_subset_variance <- keep_vars_variance } - } else if ((is.null(keep_vars_variance)) && (!is.null(drop_vars_variance))) { + } else if ( + (is.null(keep_vars_variance)) && (!is.null(drop_vars_variance)) + ) { if (is.character(drop_vars_variance)) { if (!all(drop_vars_variance %in% names(X_train))) { - stop("drop_vars_variance includes some variable names that are not in X_train") + stop( + "drop_vars_variance includes some variable names that are not in X_train" + ) } - variable_subset_variance <- unname(which(!(names(X_train) %in% drop_vars_variance))) + variable_subset_variance <- unname(which( + !(names(X_train) %in% drop_vars_variance) + )) } else { if (any(drop_vars_variance > ncol(X_train))) { - stop("drop_vars_variance includes some variable indices that exceed the number of columns in X_train") + stop( + "drop_vars_variance includes some variable indices that exceed the number of columns in X_train" + ) } if (any(drop_vars_variance < 0)) { - stop("drop_vars_variance includes some negative variable indices") + stop( + "drop_vars_variance includes some negative variable indices" + ) } - variable_subset_variance <- (1:ncol(X_train))[!(1:ncol(X_train) %in% drop_vars_variance)] + variable_subset_variance <- (1:ncol(X_train))[ + !(1:ncol(X_train) %in% drop_vars_variance) + ] } } else { variable_subset_variance <- 1:ncol(X_train) } - + # Preprocess covariates if (ncol(X_train) != length(variable_weights)) { stop("length(variable_weights) must equal ncol(X_train)") @@ -504,8 +618,10 @@ bcf <- function(X_train, Z_train, y_train, propensity_train = NULL, rfx_group_id original_var_indices <- X_train_metadata$original_var_indices feature_types <- X_train_metadata$feature_types X_test_raw <- X_test - if (!is.null(X_test)) X_test <- preprocessPredictionData(X_test, X_train_metadata) - + if (!is.null(X_test)) { + X_test <- preprocessPredictionData(X_test, X_train_metadata) + } + # Convert all input data to matrices if not already converted Z_col <- ifelse(is.null(dim(Z_train)), 1, ncol(Z_train)) Z_train <- matrix(as.numeric(Z_train), ncol = Z_col) @@ -524,7 +640,7 @@ bcf <- function(X_train, Z_train, y_train, propensity_train = NULL, rfx_group_id if ((is.null(dim(rfx_basis_test))) && (!is.null(rfx_basis_test))) { rfx_basis_test <- as.matrix(rfx_basis_test) } - + # Recode group IDs to integer vector (if passed as, for example, a vector of county names, etc...) has_rfx <- FALSE has_rfx_test <- FALSE @@ -533,18 +649,27 @@ bcf <- function(X_train, Z_train, y_train, propensity_train = NULL, rfx_group_id rfx_group_ids_train <- as.integer(group_ids_factor) has_rfx <- TRUE if (!is.null(rfx_group_ids_test)) { - group_ids_factor_test <- factor(rfx_group_ids_test, levels = levels(group_ids_factor)) + group_ids_factor_test <- factor( + rfx_group_ids_test, + levels = levels(group_ids_factor) + ) if (sum(is.na(group_ids_factor_test)) > 0) { - stop("All random effect group labels provided in rfx_group_ids_test must be present in rfx_group_ids_train") + stop( + "All random effect group labels provided in rfx_group_ids_test must be present in rfx_group_ids_train" + ) } rfx_group_ids_test <- as.integer(group_ids_factor_test) has_rfx_test <- TRUE } } - + # Check that outcome and treatment are numeric - if (!is.numeric(y_train)) stop("y_train must be numeric") - if (!is.numeric(Z_train)) stop("Z_train must be numeric") + if (!is.numeric(y_train)) { + stop("y_train must be numeric") + } + if (!is.numeric(Z_train)) { + stop("Z_train must be numeric") + } if (!is.null(Z_test)) { if (!is.numeric(Z_test)) stop("Z_test must be numeric") } @@ -559,98 +684,138 @@ bcf <- function(X_train, Z_train, y_train, propensity_train = NULL, rfx_group_id if ((!is.null(Z_train)) && (nrow(Z_train) != nrow(X_train))) { stop("Z_train and X_train must have the same number of rows") } - if ((!is.null(propensity_train)) && (nrow(propensity_train) != nrow(X_train))) { + if ( + (!is.null(propensity_train)) && + (nrow(propensity_train) != nrow(X_train)) + ) { stop("propensity_train and X_train must have the same number of rows") } if ((!is.null(Z_test)) && (nrow(Z_test) != nrow(X_test))) { stop("Z_test and X_test must have the same number of rows") } - if ((!is.null(propensity_test)) && (nrow(propensity_test) != nrow(X_test))) { + if ( + (!is.null(propensity_test)) && (nrow(propensity_test) != nrow(X_test)) + ) { stop("propensity_test and X_test must have the same number of rows") } if (nrow(X_train) != length(y_train)) { stop("X_train and y_train must have the same number of observations") } - if ((!is.null(rfx_basis_test)) && (ncol(rfx_basis_test) != ncol(rfx_basis_train))) { - stop("rfx_basis_train and rfx_basis_test must have the same number of columns") + if ( + (!is.null(rfx_basis_test)) && + (ncol(rfx_basis_test) != ncol(rfx_basis_train)) + ) { + stop( + "rfx_basis_train and rfx_basis_test must have the same number of columns" + ) } if (!is.null(rfx_group_ids_train)) { if (!is.null(rfx_group_ids_test)) { if ((!is.null(rfx_basis_train)) && (is.null(rfx_basis_test))) { - stop("rfx_basis_train is provided but rfx_basis_test is not provided") + stop( + "rfx_basis_train is provided but rfx_basis_test is not provided" + ) } } } - + # # Stop if multivariate treatment is provided # if (ncol(Z_train) > 1) stop("Multivariate treatments are not currently supported") - + # Handle multivariate treatment has_multivariate_treatment <- ncol(Z_train) > 1 if (has_multivariate_treatment) { - # Disable adaptive coding, internal propensity model, and + # Disable adaptive coding, internal propensity model, and # leaf scale sampling if treatment is multivariate if (adaptive_coding) { - warning("Adaptive coding is incompatible with multivariate treatment and will be ignored") + warning( + "Adaptive coding is incompatible with multivariate treatment and will be ignored" + ) adaptive_coding <- FALSE } if (is.null(propensity_train)) { if (propensity_covariate != "none") { - warning("No propensities were provided for the multivariate treatment; an internal propensity model will not be fitted to the multivariate treatment and propensity_covariate will be set to 'none'") + warning( + "No propensities were provided for the multivariate treatment; an internal propensity model will not be fitted to the multivariate treatment and propensity_covariate will be set to 'none'" + ) propensity_covariate <- "none" } } if (sample_sigma2_leaf_tau) { - warning("Sampling leaf scale not yet supported for multivariate leaf models, so the leaf scale parameter will not be sampled for the treatment forest in this model.") + warning( + "Sampling leaf scale not yet supported for multivariate leaf models, so the leaf scale parameter will not be sampled for the treatment forest in this model." + ) sample_sigma2_leaf_tau <- FALSE } } - + # Random effects covariance prior if (has_rfx) { if (is.null(rfx_prior_var)) { rfx_prior_var <- rep(1, ncol(rfx_basis_train)) } else { - if ((!is.integer(rfx_prior_var)) && (!is.numeric(rfx_prior_var))) stop("rfx_prior_var must be a numeric vector") - if (length(rfx_prior_var) != ncol(rfx_basis_train)) stop("length(rfx_prior_var) must equal ncol(rfx_basis_train)") + if ((!is.integer(rfx_prior_var)) && (!is.numeric(rfx_prior_var))) { + stop("rfx_prior_var must be a numeric vector") + } + if (length(rfx_prior_var) != ncol(rfx_basis_train)) { + stop("length(rfx_prior_var) must equal ncol(rfx_basis_train)") + } } } - + # Update variable weights - variable_weights_adj <- 1/sapply(original_var_indices, function(x) sum(original_var_indices == x)) - variable_weights <- variable_weights[original_var_indices]*variable_weights_adj - + variable_weights_adj <- 1 / + sapply(original_var_indices, function(x) sum(original_var_indices == x)) + variable_weights <- variable_weights[original_var_indices] * + variable_weights_adj + # Create mu and tau (and variance) specific variable weights with weights zeroed out for excluded variables variable_weights_variance <- variable_weights_tau <- variable_weights_mu <- variable_weights variable_weights_mu[!(original_var_indices %in% variable_subset_mu)] <- 0 variable_weights_tau[!(original_var_indices %in% variable_subset_tau)] <- 0 if (include_variance_forest) { - variable_weights_variance[!(original_var_indices %in% variable_subset_variance)] <- 0 + variable_weights_variance[ + !(original_var_indices %in% variable_subset_variance) + ] <- 0 } - # Fill in rfx basis as a vector of 1s (random intercept) if a basis not provided + # Fill in rfx basis as a vector of 1s (random intercept) if a basis not provided has_basis_rfx <- FALSE num_basis_rfx <- 0 if (has_rfx) { if (is.null(rfx_basis_train)) { - rfx_basis_train <- matrix(rep(1,nrow(X_train)), nrow = nrow(X_train), ncol = 1) + rfx_basis_train <- matrix( + rep(1, nrow(X_train)), + nrow = nrow(X_train), + ncol = 1 + ) } else { has_basis_rfx <- TRUE num_basis_rfx <- ncol(rfx_basis_train) } num_rfx_groups <- length(unique(rfx_group_ids_train)) num_rfx_components <- ncol(rfx_basis_train) - if (num_rfx_groups == 1) warning("Only one group was provided for random effect sampling, so the 'redundant parameterization' is likely overkill") + if (num_rfx_groups == 1) { + warning( + "Only one group was provided for random effect sampling, so the 'redundant parameterization' is likely overkill" + ) + } } if (has_rfx_test) { if (is.null(rfx_basis_test)) { if (!is.null(rfx_basis_train)) { - stop("Random effects basis provided for training set, must also be provided for the test set") + stop( + "Random effects basis provided for training set, must also be provided for the test set" + ) } - rfx_basis_test <- matrix(rep(1,nrow(X_test)), nrow = nrow(X_test), ncol = 1) + rfx_basis_test <- matrix( + rep(1, nrow(X_test)), + nrow = nrow(X_test), + ncol = 1 + ) } } - + # Check that number of samples are all nonnegative stopifnot(num_gfr >= 0) stopifnot(num_burnin >= 0) @@ -658,29 +823,31 @@ bcf <- function(X_train, Z_train, y_train, propensity_train = NULL, rfx_group_id # Determine whether a test set is provided has_test = !is.null(X_test) - + # Convert y_train to numeric vector if not already converted if (!is.null(dim(y_train))) { y_train <- as.matrix(y_train) } - + # Check whether treatment is binary (specifically 0-1 binary) binary_treatment <- length(unique(Z_train)) == 2 if (binary_treatment) { unique_treatments <- sort(unique(Z_train)) - if (!(all(unique_treatments == c(0,1)))) binary_treatment <- FALSE + if (!(all(unique_treatments == c(0, 1)))) binary_treatment <- FALSE } - + # Adaptive coding will be ignored for continuous / ordered categorical treatments if ((!binary_treatment) && (adaptive_coding)) { adaptive_coding <- FALSE } - + # Check if propensity_covariate is one of the required inputs - if (!(propensity_covariate %in% c("mu","tau","both","none"))) { - stop("propensity_covariate must equal one of 'none', 'mu', 'tau', or 'both'") + if (!(propensity_covariate %in% c("mu", "tau", "both", "none"))) { + stop( + "propensity_covariate must equal one of 'none', 'mu', 'tau', or 'both'" + ) } - + # Estimate if pre-estimated propensity score is not provided internal_propensity_model <- FALSE if ((is.null(propensity_train)) && (propensity_covariate != "none")) { @@ -688,52 +855,105 @@ bcf <- function(X_train, Z_train, y_train, propensity_train = NULL, rfx_group_id # Estimate using the last of several iterations of GFR BART num_burnin <- 10 num_total <- 50 - bart_model_propensity <- bart(X_train = X_train, y_train = as.numeric(Z_train), X_test = X_test_raw, - num_gfr = num_total, num_burnin = 0, num_mcmc = 0) - propensity_train <- rowMeans(bart_model_propensity$y_hat_train[,(num_burnin+1):num_total]) + bart_model_propensity <- bart( + X_train = X_train, + y_train = as.numeric(Z_train), + X_test = X_test_raw, + num_gfr = num_total, + num_burnin = 0, + num_mcmc = 0 + ) + propensity_train <- rowMeans(bart_model_propensity$y_hat_train[, + (num_burnin + 1):num_total + ]) if ((is.null(dim(propensity_train))) && (!is.null(propensity_train))) { propensity_train <- as.matrix(propensity_train) } if (has_test) { - propensity_test <- rowMeans(bart_model_propensity$y_hat_test[,(num_burnin+1):num_total]) - if ((is.null(dim(propensity_test))) && (!is.null(propensity_test))) { + propensity_test <- rowMeans(bart_model_propensity$y_hat_test[, + (num_burnin + 1):num_total + ]) + if ( + (is.null(dim(propensity_test))) && (!is.null(propensity_test)) + ) { propensity_test <- as.matrix(propensity_test) } } } if (has_test) { - if (is.null(propensity_test)) stop("Propensity score must be provided for the test set if provided for the training set") + if (is.null(propensity_test)) { + stop( + "Propensity score must be provided for the test set if provided for the training set" + ) + } } - + # Update feature_types and covariates feature_types <- as.integer(feature_types) if (propensity_covariate != "none") { - feature_types <- as.integer(c(feature_types,rep(0, ncol(propensity_train)))) + feature_types <- as.integer(c( + feature_types, + rep(0, ncol(propensity_train)) + )) X_train <- cbind(X_train, propensity_train) if (propensity_covariate == "mu") { - variable_weights_mu <- c(variable_weights_mu, rep(1./num_cov_orig, ncol(propensity_train))) - variable_weights_tau <- c(variable_weights_tau, rep(0, ncol(propensity_train))) - if (include_variance_forest) variable_weights_variance <- c(variable_weights_variance, rep(0, ncol(propensity_train))) + variable_weights_mu <- c( + variable_weights_mu, + rep(1. / num_cov_orig, ncol(propensity_train)) + ) + variable_weights_tau <- c( + variable_weights_tau, + rep(0, ncol(propensity_train)) + ) + if (include_variance_forest) { + variable_weights_variance <- c( + variable_weights_variance, + rep(0, ncol(propensity_train)) + ) + } } else if (propensity_covariate == "tau") { - variable_weights_mu <- c(variable_weights_mu, rep(0, ncol(propensity_train))) - variable_weights_tau <- c(variable_weights_tau, rep(1./num_cov_orig, ncol(propensity_train))) - if (include_variance_forest) variable_weights_variance <- c(variable_weights_variance, rep(0, ncol(propensity_train))) + variable_weights_mu <- c( + variable_weights_mu, + rep(0, ncol(propensity_train)) + ) + variable_weights_tau <- c( + variable_weights_tau, + rep(1. / num_cov_orig, ncol(propensity_train)) + ) + if (include_variance_forest) { + variable_weights_variance <- c( + variable_weights_variance, + rep(0, ncol(propensity_train)) + ) + } } else if (propensity_covariate == "both") { - variable_weights_mu <- c(variable_weights_mu, rep(1./num_cov_orig, ncol(propensity_train))) - variable_weights_tau <- c(variable_weights_tau, rep(1./num_cov_orig, ncol(propensity_train))) - if (include_variance_forest) variable_weights_variance <- c(variable_weights_variance, rep(0, ncol(propensity_train))) + variable_weights_mu <- c( + variable_weights_mu, + rep(1. / num_cov_orig, ncol(propensity_train)) + ) + variable_weights_tau <- c( + variable_weights_tau, + rep(1. / num_cov_orig, ncol(propensity_train)) + ) + if (include_variance_forest) { + variable_weights_variance <- c( + variable_weights_variance, + rep(0, ncol(propensity_train)) + ) + } } if (has_test) X_test <- cbind(X_test, propensity_test) } - + # Renormalize variable weights variable_weights_mu <- variable_weights_mu / sum(variable_weights_mu) variable_weights_tau <- variable_weights_tau / sum(variable_weights_tau) if (include_variance_forest) { - variable_weights_variance <- variable_weights_variance / sum(variable_weights_variance) + variable_weights_variance <- variable_weights_variance / + sum(variable_weights_variance) } - + # Set num_features_subsample to default, ncol(X_train), if not already set if (is.null(num_features_subsample_mu)) { num_features_subsample_mu <- ncol(X_train) @@ -744,46 +964,56 @@ bcf <- function(X_train, Z_train, y_train, propensity_train = NULL, rfx_group_id if (is.null(num_features_subsample_variance)) { num_features_subsample_variance <- ncol(X_train) } - + # Preliminary runtime checks for probit link if (probit_outcome_model) { if (!(length(unique(y_train)) == 2)) { - stop("You specified a probit outcome model, but supplied an outcome with more than 2 unique values") + stop( + "You specified a probit outcome model, but supplied an outcome with more than 2 unique values" + ) } unique_outcomes <- sort(unique(y_train)) - if (!(all(unique_outcomes == c(0,1)))) { - stop("You specified a probit outcome model, but supplied an outcome with 2 unique values other than 0 and 1") + if (!(all(unique_outcomes == c(0, 1)))) { + stop( + "You specified a probit outcome model, but supplied an outcome with 2 unique values other than 0 and 1" + ) } if (include_variance_forest) { stop("We do not support heteroskedasticity with a probit link") } if (sample_sigma2_global) { - warning("Global error variance will not be sampled with a probit link as it is fixed at 1") + warning( + "Global error variance will not be sampled with a probit link as it is fixed at 1" + ) sample_sigma2_global <- F } } - + # Handle standardization, prior calibration, and initialization of forest # differently for binary and continuous outcomes if (probit_outcome_model) { # Compute a probit-scale offset and fix scale to 1 y_bar_train <- qnorm(mean(y_train)) y_std_train <- 1 - + # Set a pseudo outcome by subtracting mean(y_train) from y_train resid_train <- y_train - mean(y_train) - + # Set initial value for the mu forest init_mu <- 0.0 - + # Calibrate priors for global sigma^2 and sigma2_leaf_mu / sigma2_leaf_tau # Set sigma2_init to 1, ignoring any defaults provided sigma2_init <- 1.0 # Skip variance_forest_init, since variance forests are not supported with probit link - if (is.null(b_leaf_mu)) b_leaf_mu <- 1/num_trees_mu - if (is.null(b_leaf_tau)) b_leaf_tau <- 1/(2*num_trees_tau) + if (is.null(b_leaf_mu)) { + b_leaf_mu <- 1 / num_trees_mu + } + if (is.null(b_leaf_tau)) { + b_leaf_tau <- 1 / (2 * num_trees_tau) + } if (is.null(sigma2_leaf_mu)) { - sigma2_leaf_mu <- 2/(num_trees_mu) + sigma2_leaf_mu <- 2 / (num_trees_mu) current_leaf_scale_mu <- as.matrix(sigma2_leaf_mu) } else { if (!is.matrix(sigma2_leaf_mu)) { @@ -794,20 +1024,35 @@ bcf <- function(X_train, Z_train, y_train, propensity_train = NULL, rfx_group_id } if (is.null(sigma2_leaf_tau)) { # Calibrate prior so that P(abs(tau(X)) < delta_max / dnorm(0)) = p - # Use p = 0.9 as an internal default rather than adding another - # user-facing "parameter" of the binary outcome BCF prior. - # Can be overriden by specifying `sigma2_leaf_init` in + # Use p = 0.9 as an internal default rather than adding another + # user-facing "parameter" of the binary outcome BCF prior. + # Can be overriden by specifying `sigma2_leaf_init` in # treatment_effect_forest_params. p <- 0.6827 - q_quantile <- qnorm((p+1)/2) - sigma2_leaf_tau <- ((delta_max/(q_quantile*dnorm(0)))^2)/num_trees_tau - current_leaf_scale_tau <- as.matrix(diag(sigma2_leaf_tau, ncol(Z_train))) + q_quantile <- qnorm((p + 1) / 2) + sigma2_leaf_tau <- ((delta_max / (q_quantile * dnorm(0)))^2) / + num_trees_tau + current_leaf_scale_tau <- as.matrix(diag( + sigma2_leaf_tau, + ncol(Z_train) + )) } else { if (!is.matrix(sigma2_leaf_tau)) { - current_leaf_scale_tau <- as.matrix(diag(sigma2_leaf_tau, ncol(Z_train))) + current_leaf_scale_tau <- as.matrix(diag( + sigma2_leaf_tau, + ncol(Z_train) + )) } else { - if (ncol(sigma2_leaf_tau) != ncol(Z_train)) stop("sigma2_leaf_init for the tau forest must have the same number of columns / rows as columns in the Z_train matrix") - if (nrow(sigma2_leaf_tau) != ncol(Z_train)) stop("sigma2_leaf_init for the tau forest must have the same number of columns / rows as columns in the Z_train matrix") + if (ncol(sigma2_leaf_tau) != ncol(Z_train)) { + stop( + "sigma2_leaf_init for the tau forest must have the same number of columns / rows as columns in the Z_train matrix" + ) + } + if (nrow(sigma2_leaf_tau) != ncol(Z_train)) { + stop( + "sigma2_leaf_init for the tau forest must have the same number of columns / rows as columns in the Z_train matrix" + ) + } current_leaf_scale_tau <- sigma2_leaf_tau } } @@ -821,20 +1066,28 @@ bcf <- function(X_train, Z_train, y_train, propensity_train = NULL, rfx_group_id y_bar_train <- 0 y_std_train <- 1 } - + # Compute standardized outcome - resid_train <- (y_train-y_bar_train)/y_std_train - + resid_train <- (y_train - y_bar_train) / y_std_train + # Set initial value for the mu forest init_mu <- mean(resid_train) - + # Calibrate priors for global sigma^2 and sigma2_leaf_mu / sigma2_leaf_tau - if (is.null(sigma2_init)) sigma2_init <- 1.0*var(resid_train) - if (is.null(variance_forest_init)) variance_forest_init <- 1.0*var(resid_train) - if (is.null(b_leaf_mu)) b_leaf_mu <- var(resid_train)/(num_trees_mu) - if (is.null(b_leaf_tau)) b_leaf_tau <- var(resid_train)/(2*num_trees_tau) + if (is.null(sigma2_init)) { + sigma2_init <- 1.0 * var(resid_train) + } + if (is.null(variance_forest_init)) { + variance_forest_init <- 1.0 * var(resid_train) + } + if (is.null(b_leaf_mu)) { + b_leaf_mu <- var(resid_train) / (num_trees_mu) + } + if (is.null(b_leaf_tau)) { + b_leaf_tau <- var(resid_train) / (2 * num_trees_tau) + } if (is.null(sigma2_leaf_mu)) { - sigma2_leaf_mu <- 2.0*var(resid_train)/(num_trees_mu) + sigma2_leaf_mu <- 2.0 * var(resid_train) / (num_trees_mu) current_leaf_scale_mu <- as.matrix(sigma2_leaf_mu) } else { if (!is.matrix(sigma2_leaf_mu)) { @@ -844,20 +1097,34 @@ bcf <- function(X_train, Z_train, y_train, propensity_train = NULL, rfx_group_id } } if (is.null(sigma2_leaf_tau)) { - sigma2_leaf_tau <- var(resid_train)/(num_trees_tau) - current_leaf_scale_tau <- as.matrix(diag(sigma2_leaf_tau, ncol(Z_train))) + sigma2_leaf_tau <- var(resid_train) / (num_trees_tau) + current_leaf_scale_tau <- as.matrix(diag( + sigma2_leaf_tau, + ncol(Z_train) + )) } else { if (!is.matrix(sigma2_leaf_tau)) { - current_leaf_scale_tau <- as.matrix(diag(sigma2_leaf_tau, ncol(Z_train))) + current_leaf_scale_tau <- as.matrix(diag( + sigma2_leaf_tau, + ncol(Z_train) + )) } else { - if (ncol(sigma2_leaf_tau) != ncol(Z_train)) stop("sigma2_leaf_init for the tau forest must have the same number of columns / rows as columns in the Z_train matrix") - if (nrow(sigma2_leaf_tau) != ncol(Z_train)) stop("sigma2_leaf_init for the tau forest must have the same number of columns / rows as columns in the Z_train matrix") + if (ncol(sigma2_leaf_tau) != ncol(Z_train)) { + stop( + "sigma2_leaf_init for the tau forest must have the same number of columns / rows as columns in the Z_train matrix" + ) + } + if (nrow(sigma2_leaf_tau) != ncol(Z_train)) { + stop( + "sigma2_leaf_init for the tau forest must have the same number of columns / rows as columns in the Z_train matrix" + ) + } current_leaf_scale_tau <- sigma2_leaf_tau } } current_sigma2 <- sigma2_init } - + # Set mu and tau leaf models / dimensions leaf_model_mu_forest <- 0 leaf_dimension_mu_forest <- 1 @@ -868,11 +1135,11 @@ bcf <- function(X_train, Z_train, y_train, propensity_train = NULL, rfx_group_id leaf_model_tau_forest <- 1 leaf_dimension_tau_forest <- 1 } - + # Set variance leaf model type (currently only one option) leaf_model_variance_forest <- 3 leaf_dimension_variance_forest <- 1 - + # Random effects prior parameters if (has_rfx) { # Prior parameters @@ -880,65 +1147,111 @@ bcf <- function(X_train, Z_train, y_train, propensity_train = NULL, rfx_group_id if (num_rfx_components == 1) { alpha_init <- c(0) } else if (num_rfx_components > 1) { - alpha_init <- rep(0,num_rfx_components) + alpha_init <- rep(0, num_rfx_components) } else { stop("There must be at least 1 random effect component") } } else { - alpha_init <- expand_dims_1d(rfx_working_parameter_prior_mean, num_rfx_components) + alpha_init <- expand_dims_1d( + rfx_working_parameter_prior_mean, + num_rfx_components + ) } - + if (is.null(rfx_group_parameter_prior_mean)) { - xi_init <- matrix(rep(alpha_init, num_rfx_groups),num_rfx_components,num_rfx_groups) + xi_init <- matrix( + rep(alpha_init, num_rfx_groups), + num_rfx_components, + num_rfx_groups + ) } else { - xi_init <- expand_dims_2d(rfx_group_parameter_prior_mean, num_rfx_components, num_rfx_groups) + xi_init <- expand_dims_2d( + rfx_group_parameter_prior_mean, + num_rfx_components, + num_rfx_groups + ) } - + if (is.null(rfx_working_parameter_prior_cov)) { - sigma_alpha_init <- diag(1,num_rfx_components,num_rfx_components) + sigma_alpha_init <- diag(1, num_rfx_components, num_rfx_components) } else { - sigma_alpha_init <- expand_dims_2d_diag(rfx_working_parameter_prior_cov, num_rfx_components) + sigma_alpha_init <- expand_dims_2d_diag( + rfx_working_parameter_prior_cov, + num_rfx_components + ) } - + if (is.null(rfx_group_parameter_prior_cov)) { - sigma_xi_init <- diag(1,num_rfx_components,num_rfx_components) + sigma_xi_init <- diag(1, num_rfx_components, num_rfx_components) } else { - sigma_xi_init <- expand_dims_2d_diag(rfx_group_parameter_prior_cov, num_rfx_components) + sigma_xi_init <- expand_dims_2d_diag( + rfx_group_parameter_prior_cov, + num_rfx_components + ) } - + sigma_xi_shape <- rfx_variance_prior_shape sigma_xi_scale <- rfx_variance_prior_scale } - + # Random effects data structure and storage container if (has_rfx) { - rfx_dataset_train <- createRandomEffectsDataset(rfx_group_ids_train, rfx_basis_train) + rfx_dataset_train <- createRandomEffectsDataset( + rfx_group_ids_train, + rfx_basis_train + ) rfx_tracker_train <- createRandomEffectsTracker(rfx_group_ids_train) - rfx_model <- createRandomEffectsModel(num_rfx_components, num_rfx_groups) + rfx_model <- createRandomEffectsModel( + num_rfx_components, + num_rfx_groups + ) rfx_model$set_working_parameter(alpha_init) rfx_model$set_group_parameters(xi_init) rfx_model$set_working_parameter_cov(sigma_alpha_init) rfx_model$set_group_parameter_cov(sigma_xi_init) rfx_model$set_variance_prior_shape(sigma_xi_shape) rfx_model$set_variance_prior_scale(sigma_xi_scale) - rfx_samples <- createRandomEffectSamples(num_rfx_components, num_rfx_groups, rfx_tracker_train) + rfx_samples <- createRandomEffectSamples( + num_rfx_components, + num_rfx_groups, + rfx_tracker_train + ) } - + # Container of variance parameter samples num_actual_mcmc_iter <- num_mcmc * keep_every num_samples <- num_gfr + num_burnin + num_actual_mcmc_iter # Delete GFR samples from these containers after the fact if desired # num_retained_samples <- ifelse(keep_gfr, num_gfr, 0) + ifelse(keep_burnin, num_burnin, 0) + num_mcmc - num_retained_samples <- num_gfr + ifelse(keep_burnin, num_burnin, 0) + num_mcmc * num_chains - if (sample_sigma2_global) global_var_samples <- rep(NA, num_retained_samples) - if (sample_sigma2_leaf_mu) leaf_scale_mu_samples <- rep(NA, num_retained_samples) - if (sample_sigma2_leaf_tau) leaf_scale_tau_samples <- rep(NA, num_retained_samples) + num_retained_samples <- num_gfr + + ifelse(keep_burnin, num_burnin, 0) + + num_mcmc * num_chains + if (sample_sigma2_global) { + global_var_samples <- rep(NA, num_retained_samples) + } + if (sample_sigma2_leaf_mu) { + leaf_scale_mu_samples <- rep(NA, num_retained_samples) + } + if (sample_sigma2_leaf_tau) { + leaf_scale_tau_samples <- rep(NA, num_retained_samples) + } muhat_train_raw <- matrix(NA_real_, nrow(X_train), num_retained_samples) - if (include_variance_forest) sigma2_x_train_raw <- matrix(NA_real_, nrow(X_train), num_retained_samples) + if (include_variance_forest) { + sigma2_x_train_raw <- matrix( + NA_real_, + nrow(X_train), + num_retained_samples + ) + } sample_counter <- 0 # Prepare adaptive coding structure - if ((!is.numeric(b_0)) || (!is.numeric(b_1)) || (length(b_0) > 1) || (length(b_1) > 1)) { + if ( + (!is.numeric(b_0)) || + (!is.numeric(b_1)) || + (length(b_0) > 1) || + (length(b_1) > 1) + ) { stop("b_0 and b_1 must be single numeric values") } if (adaptive_coding) { @@ -946,485 +1259,1016 @@ bcf <- function(X_train, Z_train, y_train, propensity_train = NULL, rfx_group_id b_1_samples <- rep(NA, num_retained_samples) current_b_0 <- b_0 current_b_1 <- b_1 - tau_basis_train <- (1-Z_train)*current_b_0 + Z_train*current_b_1 - if (has_test) tau_basis_test <- (1-Z_test)*current_b_0 + Z_test*current_b_1 + tau_basis_train <- (1 - Z_train) * current_b_0 + Z_train * current_b_1 + if (has_test) { + tau_basis_test <- (1 - Z_test) * current_b_0 + Z_test * current_b_1 + } } else { tau_basis_train <- Z_train if (has_test) tau_basis_test <- Z_test } - + # Data forest_dataset_train <- createForestDataset(X_train, tau_basis_train) - if (has_test) forest_dataset_test <- createForestDataset(X_test, tau_basis_test) + if (has_test) { + forest_dataset_test <- createForestDataset(X_test, tau_basis_test) + } outcome_train <- createOutcome(resid_train) - + # Random number generator (std::mt19937) - if (is.null(random_seed)) random_seed = sample(1:10000,1,FALSE) + if (is.null(random_seed)) { + random_seed = sample(1:10000, 1, FALSE) + } rng <- createCppRNG(random_seed) - + # Sampling data structures - global_model_config <- createGlobalModelConfig(global_error_variance=current_sigma2) - forest_model_config_mu <- createForestModelConfig(feature_types=feature_types, num_trees=num_trees_mu, num_features=ncol(X_train), - num_observations=nrow(X_train), variable_weights=variable_weights_mu, leaf_dimension=leaf_dimension_mu_forest, - alpha=alpha_mu, beta=beta_mu, min_samples_leaf=min_samples_leaf_mu, max_depth=max_depth_mu, - leaf_model_type=leaf_model_mu_forest, leaf_model_scale=current_leaf_scale_mu, - cutpoint_grid_size=cutpoint_grid_size, num_features_subsample = num_features_subsample_mu) - forest_model_config_tau <- createForestModelConfig(feature_types=feature_types, num_trees=num_trees_tau, num_features=ncol(X_train), - num_observations=nrow(X_train), variable_weights=variable_weights_tau, leaf_dimension=leaf_dimension_tau_forest, - alpha=alpha_tau, beta=beta_tau, min_samples_leaf=min_samples_leaf_tau, max_depth=max_depth_tau, - leaf_model_type=leaf_model_tau_forest, leaf_model_scale=current_leaf_scale_tau, - cutpoint_grid_size=cutpoint_grid_size, num_features_subsample = num_features_subsample_tau) - forest_model_mu <- createForestModel(forest_dataset_train, forest_model_config_mu, global_model_config) - forest_model_tau <- createForestModel(forest_dataset_train, forest_model_config_tau, global_model_config) + global_model_config <- createGlobalModelConfig( + global_error_variance = current_sigma2 + ) + forest_model_config_mu <- createForestModelConfig( + feature_types = feature_types, + num_trees = num_trees_mu, + num_features = ncol(X_train), + num_observations = nrow(X_train), + variable_weights = variable_weights_mu, + leaf_dimension = leaf_dimension_mu_forest, + alpha = alpha_mu, + beta = beta_mu, + min_samples_leaf = min_samples_leaf_mu, + max_depth = max_depth_mu, + leaf_model_type = leaf_model_mu_forest, + leaf_model_scale = current_leaf_scale_mu, + cutpoint_grid_size = cutpoint_grid_size, + num_features_subsample = num_features_subsample_mu + ) + forest_model_config_tau <- createForestModelConfig( + feature_types = feature_types, + num_trees = num_trees_tau, + num_features = ncol(X_train), + num_observations = nrow(X_train), + variable_weights = variable_weights_tau, + leaf_dimension = leaf_dimension_tau_forest, + alpha = alpha_tau, + beta = beta_tau, + min_samples_leaf = min_samples_leaf_tau, + max_depth = max_depth_tau, + leaf_model_type = leaf_model_tau_forest, + leaf_model_scale = current_leaf_scale_tau, + cutpoint_grid_size = cutpoint_grid_size, + num_features_subsample = num_features_subsample_tau + ) + forest_model_mu <- createForestModel( + forest_dataset_train, + forest_model_config_mu, + global_model_config + ) + forest_model_tau <- createForestModel( + forest_dataset_train, + forest_model_config_tau, + global_model_config + ) if (include_variance_forest) { - forest_model_config_variance <- createForestModelConfig(feature_types=feature_types, num_trees=num_trees_variance, num_features=ncol(X_train), - num_observations=nrow(X_train), variable_weights=variable_weights_variance, - leaf_dimension=leaf_dimension_variance_forest, alpha=alpha_variance, beta=beta_variance, - min_samples_leaf=min_samples_leaf_variance, max_depth=max_depth_variance, - leaf_model_type=leaf_model_variance_forest, cutpoint_grid_size=cutpoint_grid_size, - num_features_subsample=num_features_subsample_variance) - forest_model_variance <- createForestModel(forest_dataset_train, forest_model_config_variance, global_model_config) - } - + forest_model_config_variance <- createForestModelConfig( + feature_types = feature_types, + num_trees = num_trees_variance, + num_features = ncol(X_train), + num_observations = nrow(X_train), + variable_weights = variable_weights_variance, + leaf_dimension = leaf_dimension_variance_forest, + alpha = alpha_variance, + beta = beta_variance, + min_samples_leaf = min_samples_leaf_variance, + max_depth = max_depth_variance, + leaf_model_type = leaf_model_variance_forest, + cutpoint_grid_size = cutpoint_grid_size, + num_features_subsample = num_features_subsample_variance + ) + forest_model_variance <- createForestModel( + forest_dataset_train, + forest_model_config_variance, + global_model_config + ) + } + # Container of forest samples forest_samples_mu <- createForestSamples(num_trees_mu, 1, TRUE) - forest_samples_tau <- createForestSamples(num_trees_tau, ncol(Z_train), FALSE) + forest_samples_tau <- createForestSamples( + num_trees_tau, + ncol(Z_train), + FALSE + ) active_forest_mu <- createForest(num_trees_mu, 1, TRUE) active_forest_tau <- createForest(num_trees_tau, ncol(Z_train), FALSE) if (include_variance_forest) { - forest_samples_variance <- createForestSamples(num_trees_variance, 1, TRUE, TRUE) - active_forest_variance <- createForest(num_trees_variance, 1, TRUE, TRUE) + forest_samples_variance <- createForestSamples( + num_trees_variance, + 1, + TRUE, + TRUE + ) + active_forest_variance <- createForest( + num_trees_variance, + 1, + TRUE, + TRUE + ) } - + # Initialize the leaves of each tree in the prognostic forest - active_forest_mu$prepare_for_sampler(forest_dataset_train, outcome_train, forest_model_mu, leaf_model_mu_forest, init_mu) - active_forest_mu$adjust_residual(forest_dataset_train, outcome_train, forest_model_mu, FALSE, FALSE) + active_forest_mu$prepare_for_sampler( + forest_dataset_train, + outcome_train, + forest_model_mu, + leaf_model_mu_forest, + init_mu + ) + active_forest_mu$adjust_residual( + forest_dataset_train, + outcome_train, + forest_model_mu, + FALSE, + FALSE + ) # Initialize the leaves of each tree in the treatment effect forest init_tau <- rep(0., ncol(Z_train)) - active_forest_tau$prepare_for_sampler(forest_dataset_train, outcome_train, forest_model_tau, leaf_model_tau_forest, init_tau) - active_forest_tau$adjust_residual(forest_dataset_train, outcome_train, forest_model_tau, TRUE, FALSE) + active_forest_tau$prepare_for_sampler( + forest_dataset_train, + outcome_train, + forest_model_tau, + leaf_model_tau_forest, + init_tau + ) + active_forest_tau$adjust_residual( + forest_dataset_train, + outcome_train, + forest_model_tau, + TRUE, + FALSE + ) # Initialize the leaves of each tree in the variance forest if (include_variance_forest) { - active_forest_variance$prepare_for_sampler(forest_dataset_train, outcome_train, forest_model_variance, leaf_model_variance_forest, variance_forest_init) + active_forest_variance$prepare_for_sampler( + forest_dataset_train, + outcome_train, + forest_model_variance, + leaf_model_variance_forest, + variance_forest_init + ) } # Run GFR (warm start) if specified - if (num_gfr > 0){ + if (num_gfr > 0) { for (i in 1:num_gfr) { # Keep all GFR samples at this stage -- remove from ForestSamples after MCMC # keep_sample <- ifelse(keep_gfr, TRUE, FALSE) keep_sample <- TRUE - if (keep_sample) sample_counter <- sample_counter + 1 + if (keep_sample) { + sample_counter <- sample_counter + 1 + } # Print progress if (verbose) { if ((i %% 10 == 0) || (i == num_gfr)) { - cat("Sampling", i, "out of", num_gfr, "XBCF (grow-from-root) draws\n") + cat( + "Sampling", + i, + "out of", + num_gfr, + "XBCF (grow-from-root) draws\n" + ) } } - + if (probit_outcome_model) { # Sample latent probit variable, z | - mu_forest_pred <- active_forest_mu$predict(forest_dataset_train) - tau_forest_pred <- active_forest_tau$predict(forest_dataset_train) + tau_forest_pred <- active_forest_tau$predict( + forest_dataset_train + ) forest_pred <- mu_forest_pred + tau_forest_pred mu0 <- forest_pred[y_train == 0] mu1 <- forest_pred[y_train == 1] u0 <- runif(sum(y_train == 0), 0, pnorm(0 - mu0)) u1 <- runif(sum(y_train == 1), pnorm(0 - mu1), 1) - resid_train[y_train==0] <- mu0 + qnorm(u0) - resid_train[y_train==1] <- mu1 + qnorm(u1) - + resid_train[y_train == 0] <- mu0 + qnorm(u0) + resid_train[y_train == 1] <- mu1 + qnorm(u1) + # Update outcome outcome_train$update_data(resid_train - forest_pred) } - + # Sample the prognostic forest forest_model_mu$sample_one_iteration( - forest_dataset = forest_dataset_train, residual = outcome_train, forest_samples = forest_samples_mu, - active_forest = active_forest_mu, rng = rng, forest_model_config = forest_model_config_mu, - global_model_config = global_model_config, num_threads = num_threads, keep_forest = keep_sample, gfr = TRUE + forest_dataset = forest_dataset_train, + residual = outcome_train, + forest_samples = forest_samples_mu, + active_forest = active_forest_mu, + rng = rng, + forest_model_config = forest_model_config_mu, + global_model_config = global_model_config, + num_threads = num_threads, + keep_forest = keep_sample, + gfr = TRUE ) - + # Cache train set predictions since they are already computed during sampling if (keep_sample) { - muhat_train_raw[,sample_counter] <- forest_model_mu$get_cached_forest_predictions() + muhat_train_raw[, + sample_counter + ] <- forest_model_mu$get_cached_forest_predictions() } - + # Sample variance parameters (if requested) if (sample_sigma2_global) { - current_sigma2 <- sampleGlobalErrorVarianceOneIteration(outcome_train, forest_dataset_train, rng, a_global, b_global) + current_sigma2 <- sampleGlobalErrorVarianceOneIteration( + outcome_train, + forest_dataset_train, + rng, + a_global, + b_global + ) global_model_config$update_global_error_variance(current_sigma2) } if (sample_sigma2_leaf_mu) { - leaf_scale_mu_double <- sampleLeafVarianceOneIteration(active_forest_mu, rng, a_leaf_mu, b_leaf_mu) + leaf_scale_mu_double <- sampleLeafVarianceOneIteration( + active_forest_mu, + rng, + a_leaf_mu, + b_leaf_mu + ) current_leaf_scale_mu <- as.matrix(leaf_scale_mu_double) - if (keep_sample) leaf_scale_mu_samples[sample_counter] <- leaf_scale_mu_double - forest_model_config_mu$update_leaf_model_scale(current_leaf_scale_mu) + if (keep_sample) { + leaf_scale_mu_samples[ + sample_counter + ] <- leaf_scale_mu_double + } + forest_model_config_mu$update_leaf_model_scale( + current_leaf_scale_mu + ) } - + # Sample the treatment forest forest_model_tau$sample_one_iteration( - forest_dataset = forest_dataset_train, residual = outcome_train, forest_samples = forest_samples_tau, - active_forest = active_forest_tau, rng = rng, forest_model_config = forest_model_config_tau, - global_model_config = global_model_config, num_threads = num_threads, keep_forest = keep_sample, gfr = TRUE + forest_dataset = forest_dataset_train, + residual = outcome_train, + forest_samples = forest_samples_tau, + active_forest = active_forest_tau, + rng = rng, + forest_model_config = forest_model_config_tau, + global_model_config = global_model_config, + num_threads = num_threads, + keep_forest = keep_sample, + gfr = TRUE ) - - # Cannot cache train set predictions for tau because the cached predictions in the + + # Cannot cache train set predictions for tau because the cached predictions in the # tracking data structures are pre-multiplied by the basis (treatment) # ... # Sample coding parameters (if requested) if (adaptive_coding) { # Estimate mu(X) and tau(X) and compute y - mu(X) - mu_x_raw_train <- active_forest_mu$predict_raw(forest_dataset_train) - tau_x_raw_train <- active_forest_tau$predict_raw(forest_dataset_train) + mu_x_raw_train <- active_forest_mu$predict_raw( + forest_dataset_train + ) + tau_x_raw_train <- active_forest_tau$predict_raw( + forest_dataset_train + ) partial_resid_mu_train <- resid_train - mu_x_raw_train if (has_rfx) { - rfx_preds_train <- rfx_model$predict(rfx_dataset_train, rfx_tracker_train) - partial_resid_mu_train <- partial_resid_mu_train - rfx_preds_train + rfx_preds_train <- rfx_model$predict( + rfx_dataset_train, + rfx_tracker_train + ) + partial_resid_mu_train <- partial_resid_mu_train - + rfx_preds_train } - + # Compute sufficient statistics for regression of y - mu(X) on [tau(X)(1-Z), tau(X)Z] - s_tt0 <- sum(tau_x_raw_train*tau_x_raw_train*(Z_train==0)) - s_tt1 <- sum(tau_x_raw_train*tau_x_raw_train*(Z_train==1)) - s_ty0 <- sum(tau_x_raw_train*partial_resid_mu_train*(Z_train==0)) - s_ty1 <- sum(tau_x_raw_train*partial_resid_mu_train*(Z_train==1)) - + s_tt0 <- sum(tau_x_raw_train * tau_x_raw_train * (Z_train == 0)) + s_tt1 <- sum(tau_x_raw_train * tau_x_raw_train * (Z_train == 1)) + s_ty0 <- sum( + tau_x_raw_train * partial_resid_mu_train * (Z_train == 0) + ) + s_ty1 <- sum( + tau_x_raw_train * partial_resid_mu_train * (Z_train == 1) + ) + # Sample b0 (coefficient on tau(X)(1-Z)) and b1 (coefficient on tau(X)Z) - current_b_0 <- rnorm(1, (s_ty0/(s_tt0 + 2*current_sigma2)), sqrt(current_sigma2/(s_tt0 + 2*current_sigma2))) - current_b_1 <- rnorm(1, (s_ty1/(s_tt1 + 2*current_sigma2)), sqrt(current_sigma2/(s_tt1 + 2*current_sigma2))) - + current_b_0 <- rnorm( + 1, + (s_ty0 / (s_tt0 + 2 * current_sigma2)), + sqrt(current_sigma2 / (s_tt0 + 2 * current_sigma2)) + ) + current_b_1 <- rnorm( + 1, + (s_ty1 / (s_tt1 + 2 * current_sigma2)), + sqrt(current_sigma2 / (s_tt1 + 2 * current_sigma2)) + ) + # Update basis for the leaf regression - tau_basis_train <- (1-Z_train)*current_b_0 + Z_train*current_b_1 + tau_basis_train <- (1 - Z_train) * + current_b_0 + + Z_train * current_b_1 forest_dataset_train$update_basis(tau_basis_train) if (keep_sample) { b_0_samples[sample_counter] <- current_b_0 b_1_samples[sample_counter] <- current_b_1 } if (has_test) { - tau_basis_test <- (1-Z_test)*current_b_0 + Z_test*current_b_1 + tau_basis_test <- (1 - Z_test) * + current_b_0 + + Z_test * current_b_1 forest_dataset_test$update_basis(tau_basis_test) } - + # Update leaf predictions and residual - forest_model_tau$propagate_basis_update(forest_dataset_train, outcome_train, active_forest_tau) + forest_model_tau$propagate_basis_update( + forest_dataset_train, + outcome_train, + active_forest_tau + ) } - + # Sample variance parameters (if requested) if (include_variance_forest) { forest_model_variance$sample_one_iteration( - forest_dataset = forest_dataset_train, residual = outcome_train, forest_samples = forest_samples_variance, - active_forest = active_forest_variance, rng = rng, forest_model_config = forest_model_config_variance, - global_model_config = global_model_config, num_threads = num_threads, keep_forest = keep_sample, gfr = TRUE + forest_dataset = forest_dataset_train, + residual = outcome_train, + forest_samples = forest_samples_variance, + active_forest = active_forest_variance, + rng = rng, + forest_model_config = forest_model_config_variance, + global_model_config = global_model_config, + num_threads = num_threads, + keep_forest = keep_sample, + gfr = TRUE ) - + # Cache train set predictions since they are already computed during sampling if (keep_sample) { - sigma2_x_train_raw[,sample_counter] <- forest_model_variance$get_cached_forest_predictions() + sigma2_x_train_raw[, + sample_counter + ] <- forest_model_variance$get_cached_forest_predictions() } } if (sample_sigma2_global) { - current_sigma2 <- sampleGlobalErrorVarianceOneIteration(outcome_train, forest_dataset_train, rng, a_global, b_global) - if (keep_sample) global_var_samples[sample_counter] <- current_sigma2 + current_sigma2 <- sampleGlobalErrorVarianceOneIteration( + outcome_train, + forest_dataset_train, + rng, + a_global, + b_global + ) + if (keep_sample) { + global_var_samples[sample_counter] <- current_sigma2 + } global_model_config$update_global_error_variance(current_sigma2) } if (sample_sigma2_leaf_tau) { - leaf_scale_tau_double <- sampleLeafVarianceOneIteration(active_forest_tau, rng, a_leaf_tau, b_leaf_tau) + leaf_scale_tau_double <- sampleLeafVarianceOneIteration( + active_forest_tau, + rng, + a_leaf_tau, + b_leaf_tau + ) current_leaf_scale_tau <- as.matrix(leaf_scale_tau_double) - if (keep_sample) leaf_scale_tau_samples[sample_counter] <- leaf_scale_tau_double - forest_model_config_mu$update_leaf_model_scale(current_leaf_scale_mu) + if (keep_sample) { + leaf_scale_tau_samples[ + sample_counter + ] <- leaf_scale_tau_double + } + forest_model_config_mu$update_leaf_model_scale( + current_leaf_scale_mu + ) } - + # Sample random effects parameters (if requested) if (has_rfx) { - rfx_model$sample_random_effect(rfx_dataset_train, outcome_train, rfx_tracker_train, rfx_samples, keep_sample, current_sigma2, rng) + rfx_model$sample_random_effect( + rfx_dataset_train, + outcome_train, + rfx_tracker_train, + rfx_samples, + keep_sample, + current_sigma2, + rng + ) } } } - + # Run MCMC if (num_burnin + num_mcmc > 0) { for (chain_num in 1:num_chains) { if (num_gfr > 0) { # Reset state of active_forest and forest_model based on a previous GFR sample forest_ind <- num_gfr - chain_num - resetActiveForest(active_forest_mu, forest_samples_mu, forest_ind) - resetForestModel(forest_model_mu, active_forest_mu, forest_dataset_train, outcome_train, TRUE) - resetActiveForest(active_forest_tau, forest_samples_tau, forest_ind) - resetForestModel(forest_model_tau, active_forest_tau, forest_dataset_train, outcome_train, TRUE) + resetActiveForest( + active_forest_mu, + forest_samples_mu, + forest_ind + ) + resetForestModel( + forest_model_mu, + active_forest_mu, + forest_dataset_train, + outcome_train, + TRUE + ) + resetActiveForest( + active_forest_tau, + forest_samples_tau, + forest_ind + ) + resetForestModel( + forest_model_tau, + active_forest_tau, + forest_dataset_train, + outcome_train, + TRUE + ) if (sample_sigma2_leaf_mu) { - leaf_scale_mu_double <- leaf_scale_mu_samples[forest_ind + 1] + leaf_scale_mu_double <- leaf_scale_mu_samples[ + forest_ind + 1 + ] current_leaf_scale_mu <- as.matrix(leaf_scale_mu_double) - forest_model_config_mu$update_leaf_model_scale(current_leaf_scale_mu) + forest_model_config_mu$update_leaf_model_scale( + current_leaf_scale_mu + ) } if (sample_sigma2_leaf_tau) { - leaf_scale_tau_double <- leaf_scale_tau_samples[forest_ind + 1] + leaf_scale_tau_double <- leaf_scale_tau_samples[ + forest_ind + 1 + ] current_leaf_scale_tau <- as.matrix(leaf_scale_tau_double) - forest_model_config_tau$update_leaf_model_scale(current_leaf_scale_tau) + forest_model_config_tau$update_leaf_model_scale( + current_leaf_scale_tau + ) } if (include_variance_forest) { - resetActiveForest(active_forest_variance, forest_samples_variance, forest_ind) - resetForestModel(forest_model_variance, active_forest_variance, forest_dataset_train, outcome_train, FALSE) + resetActiveForest( + active_forest_variance, + forest_samples_variance, + forest_ind + ) + resetForestModel( + forest_model_variance, + active_forest_variance, + forest_dataset_train, + outcome_train, + FALSE + ) } if (has_rfx) { - resetRandomEffectsModel(rfx_model, rfx_samples, forest_ind, sigma_alpha_init) - resetRandomEffectsTracker(rfx_tracker_train, rfx_model, rfx_dataset_train, outcome_train, rfx_samples) + resetRandomEffectsModel( + rfx_model, + rfx_samples, + forest_ind, + sigma_alpha_init + ) + resetRandomEffectsTracker( + rfx_tracker_train, + rfx_model, + rfx_dataset_train, + outcome_train, + rfx_samples + ) } if (adaptive_coding) { current_b_1 <- b_1_samples[forest_ind + 1] current_b_0 <- b_0_samples[forest_ind + 1] - tau_basis_train <- (1-Z_train)*current_b_0 + Z_train*current_b_1 + tau_basis_train <- (1 - Z_train) * + current_b_0 + + Z_train * current_b_1 forest_dataset_train$update_basis(tau_basis_train) if (has_test) { - tau_basis_test <- (1-Z_test)*current_b_0 + Z_test*current_b_1 + tau_basis_test <- (1 - Z_test) * + current_b_0 + + Z_test * current_b_1 forest_dataset_test$update_basis(tau_basis_test) } - forest_model_tau$propagate_basis_update(forest_dataset_train, outcome_train, active_forest_tau) + forest_model_tau$propagate_basis_update( + forest_dataset_train, + outcome_train, + active_forest_tau + ) } if (sample_sigma2_global) { current_sigma2 <- global_var_samples[forest_ind + 1] - global_model_config$update_global_error_variance(current_sigma2) + global_model_config$update_global_error_variance( + current_sigma2 + ) } } else if (has_prev_model) { - resetActiveForest(active_forest_mu, previous_forest_samples_mu, previous_model_warmstart_sample_num - 1) - resetForestModel(forest_model_mu, active_forest_mu, forest_dataset_train, outcome_train, TRUE) - resetActiveForest(active_forest_tau, previous_forest_samples_tau, previous_model_warmstart_sample_num - 1) - resetForestModel(forest_model_tau, active_forest_tau, forest_dataset_train, outcome_train, TRUE) + resetActiveForest( + active_forest_mu, + previous_forest_samples_mu, + previous_model_warmstart_sample_num - 1 + ) + resetForestModel( + forest_model_mu, + active_forest_mu, + forest_dataset_train, + outcome_train, + TRUE + ) + resetActiveForest( + active_forest_tau, + previous_forest_samples_tau, + previous_model_warmstart_sample_num - 1 + ) + resetForestModel( + forest_model_tau, + active_forest_tau, + forest_dataset_train, + outcome_train, + TRUE + ) if (include_variance_forest) { - resetActiveForest(active_forest_variance, previous_forest_samples_variance, previous_model_warmstart_sample_num - 1) - resetForestModel(forest_model_variance, active_forest_variance, forest_dataset_train, outcome_train, FALSE) + resetActiveForest( + active_forest_variance, + previous_forest_samples_variance, + previous_model_warmstart_sample_num - 1 + ) + resetForestModel( + forest_model_variance, + active_forest_variance, + forest_dataset_train, + outcome_train, + FALSE + ) } - if (sample_sigma2_leaf_mu && (!is.null(previous_leaf_var_mu_samples))) { - leaf_scale_mu_double <- previous_leaf_var_mu_samples[previous_model_warmstart_sample_num] + if ( + sample_sigma2_leaf_mu && + (!is.null(previous_leaf_var_mu_samples)) + ) { + leaf_scale_mu_double <- previous_leaf_var_mu_samples[ + previous_model_warmstart_sample_num + ] current_leaf_scale_mu <- as.matrix(leaf_scale_mu_double) - forest_model_config_mu$update_leaf_model_scale(current_leaf_scale_mu) + forest_model_config_mu$update_leaf_model_scale( + current_leaf_scale_mu + ) } - if (sample_sigma2_leaf_tau && (!is.null(previous_leaf_var_tau_samples))) { - leaf_scale_tau_double <- previous_leaf_var_tau_samples[previous_model_warmstart_sample_num] + if ( + sample_sigma2_leaf_tau && + (!is.null(previous_leaf_var_tau_samples)) + ) { + leaf_scale_tau_double <- previous_leaf_var_tau_samples[ + previous_model_warmstart_sample_num + ] current_leaf_scale_tau <- as.matrix(leaf_scale_tau_double) - forest_model_config_tau$update_leaf_model_scale(current_leaf_scale_tau) + forest_model_config_tau$update_leaf_model_scale( + current_leaf_scale_tau + ) } if (adaptive_coding) { if (!is.null(previous_b_1_samples)) { - current_b_1 <- previous_b_1_samples[previous_model_warmstart_sample_num] + current_b_1 <- previous_b_1_samples[ + previous_model_warmstart_sample_num + ] } if (!is.null(previous_b_0_samples)) { - current_b_0 <- previous_b_0_samples[previous_model_warmstart_sample_num] + current_b_0 <- previous_b_0_samples[ + previous_model_warmstart_sample_num + ] } - tau_basis_train <- (1-Z_train)*current_b_0 + Z_train*current_b_1 + tau_basis_train <- (1 - Z_train) * + current_b_0 + + Z_train * current_b_1 forest_dataset_train$update_basis(tau_basis_train) if (has_test) { - tau_basis_test <- (1-Z_test)*current_b_0 + Z_test*current_b_1 + tau_basis_test <- (1 - Z_test) * + current_b_0 + + Z_test * current_b_1 forest_dataset_test$update_basis(tau_basis_test) } - forest_model_tau$propagate_basis_update(forest_dataset_train, outcome_train, active_forest_tau) + forest_model_tau$propagate_basis_update( + forest_dataset_train, + outcome_train, + active_forest_tau + ) } if (has_rfx) { if (is.null(previous_rfx_samples)) { - warning("`previous_model_json` did not have any random effects samples, so the RFX sampler will be run from scratch while the forests and any other parameters are warm started") - rootResetRandomEffectsModel(rfx_model, alpha_init, xi_init, sigma_alpha_init, - sigma_xi_init, sigma_xi_shape, sigma_xi_scale) - rootResetRandomEffectsTracker(rfx_tracker_train, rfx_model, rfx_dataset_train, outcome_train) + warning( + "`previous_model_json` did not have any random effects samples, so the RFX sampler will be run from scratch while the forests and any other parameters are warm started" + ) + rootResetRandomEffectsModel( + rfx_model, + alpha_init, + xi_init, + sigma_alpha_init, + sigma_xi_init, + sigma_xi_shape, + sigma_xi_scale + ) + rootResetRandomEffectsTracker( + rfx_tracker_train, + rfx_model, + rfx_dataset_train, + outcome_train + ) } else { - resetRandomEffectsModel(rfx_model, previous_rfx_samples, previous_model_warmstart_sample_num - 1, sigma_alpha_init) - resetRandomEffectsTracker(rfx_tracker_train, rfx_model, rfx_dataset_train, outcome_train, rfx_samples) + resetRandomEffectsModel( + rfx_model, + previous_rfx_samples, + previous_model_warmstart_sample_num - 1, + sigma_alpha_init + ) + resetRandomEffectsTracker( + rfx_tracker_train, + rfx_model, + rfx_dataset_train, + outcome_train, + rfx_samples + ) } } if (sample_sigma2_global) { if (!is.null(previous_global_var_samples)) { - current_sigma2 <- previous_global_var_samples[previous_model_warmstart_sample_num] + current_sigma2 <- previous_global_var_samples[ + previous_model_warmstart_sample_num + ] } - global_model_config$update_global_error_variance(current_sigma2) + global_model_config$update_global_error_variance( + current_sigma2 + ) } } else { resetActiveForest(active_forest_mu) active_forest_mu$set_root_leaves(init_mu / num_trees_mu) - resetForestModel(forest_model_mu, active_forest_mu, forest_dataset_train, outcome_train, TRUE) + resetForestModel( + forest_model_mu, + active_forest_mu, + forest_dataset_train, + outcome_train, + TRUE + ) resetActiveForest(active_forest_tau) active_forest_tau$set_root_leaves(init_tau / num_trees_tau) - resetForestModel(forest_model_tau, active_forest_tau, forest_dataset_train, outcome_train, TRUE) + resetForestModel( + forest_model_tau, + active_forest_tau, + forest_dataset_train, + outcome_train, + TRUE + ) if (sample_sigma2_leaf_mu) { current_leaf_scale_mu <- as.matrix(sigma2_leaf_mu) - forest_model_config_mu$update_leaf_model_scale(current_leaf_scale_mu) + forest_model_config_mu$update_leaf_model_scale( + current_leaf_scale_mu + ) } if (sample_sigma2_leaf_tau) { current_leaf_scale_tau <- as.matrix(sigma2_leaf_tau) - forest_model_config_tau$update_leaf_model_scale(current_leaf_scale_tau) + forest_model_config_tau$update_leaf_model_scale( + current_leaf_scale_tau + ) } if (include_variance_forest) { resetActiveForest(active_forest_variance) - active_forest_variance$set_root_leaves(log(variance_forest_init) / num_trees_variance) - resetForestModel(forest_model_variance, active_forest_variance, forest_dataset_train, outcome_train, FALSE) + active_forest_variance$set_root_leaves( + log(variance_forest_init) / num_trees_variance + ) + resetForestModel( + forest_model_variance, + active_forest_variance, + forest_dataset_train, + outcome_train, + FALSE + ) } if (has_rfx) { - rootResetRandomEffectsModel(rfx_model, alpha_init, xi_init, sigma_alpha_init, - sigma_xi_init, sigma_xi_shape, sigma_xi_scale) - rootResetRandomEffectsTracker(rfx_tracker_train, rfx_model, rfx_dataset_train, outcome_train) + rootResetRandomEffectsModel( + rfx_model, + alpha_init, + xi_init, + sigma_alpha_init, + sigma_xi_init, + sigma_xi_shape, + sigma_xi_scale + ) + rootResetRandomEffectsTracker( + rfx_tracker_train, + rfx_model, + rfx_dataset_train, + outcome_train + ) } if (adaptive_coding) { current_b_1 <- b_1 current_b_0 <- b_0 - tau_basis_train <- (1-Z_train)*current_b_0 + Z_train*current_b_1 + tau_basis_train <- (1 - Z_train) * + current_b_0 + + Z_train * current_b_1 forest_dataset_train$update_basis(tau_basis_train) if (has_test) { - tau_basis_test <- (1-Z_test)*current_b_0 + Z_test*current_b_1 + tau_basis_test <- (1 - Z_test) * + current_b_0 + + Z_test * current_b_1 forest_dataset_test$update_basis(tau_basis_test) } - forest_model_tau$propagate_basis_update(forest_dataset_train, outcome_train, active_forest_tau) + forest_model_tau$propagate_basis_update( + forest_dataset_train, + outcome_train, + active_forest_tau + ) } if (sample_sigma2_global) { current_sigma2 <- sigma2_init - global_model_config$update_global_error_variance(current_sigma2) + global_model_config$update_global_error_variance( + current_sigma2 + ) } } - for (i in (num_gfr+1):num_samples) { + for (i in (num_gfr + 1):num_samples) { is_mcmc <- i > (num_gfr + num_burnin) if (is_mcmc) { mcmc_counter <- i - (num_gfr + num_burnin) - if (mcmc_counter %% keep_every == 0) keep_sample <- TRUE - else keep_sample <- FALSE + if (mcmc_counter %% keep_every == 0) { + keep_sample <- TRUE + } else { + keep_sample <- FALSE + } } else { - if (keep_burnin) keep_sample <- TRUE - else keep_sample <- FALSE + if (keep_burnin) { + keep_sample <- TRUE + } else { + keep_sample <- FALSE + } + } + if (keep_sample) { + sample_counter <- sample_counter + 1 } - if (keep_sample) sample_counter <- sample_counter + 1 # Print progress if (verbose) { if (num_burnin > 0) { - if (((i - num_gfr) %% 100 == 0) || ((i - num_gfr) == num_burnin)) { - cat("Sampling", i - num_gfr, "out of", num_gfr, "BCF burn-in draws\n") + if ( + ((i - num_gfr) %% 100 == 0) || + ((i - num_gfr) == num_burnin) + ) { + cat( + "Sampling", + i - num_gfr, + "out of", + num_gfr, + "BCF burn-in draws\n" + ) } } if (num_mcmc > 0) { - if (((i - num_gfr - num_burnin) %% 100 == 0) || (i == num_samples)) { - cat("Sampling", i - num_burnin - num_gfr, "out of", num_mcmc, "BCF MCMC draws\n") + if ( + ((i - num_gfr - num_burnin) %% 100 == 0) || + (i == num_samples) + ) { + cat( + "Sampling", + i - num_burnin - num_gfr, + "out of", + num_mcmc, + "BCF MCMC draws\n" + ) } } } - + if (probit_outcome_model) { # Sample latent probit variable, z | - - mu_forest_pred <- active_forest_mu$predict(forest_dataset_train) - tau_forest_pred <- active_forest_tau$predict(forest_dataset_train) + mu_forest_pred <- active_forest_mu$predict( + forest_dataset_train + ) + tau_forest_pred <- active_forest_tau$predict( + forest_dataset_train + ) forest_pred <- mu_forest_pred + tau_forest_pred mu0 <- forest_pred[y_train == 0] mu1 <- forest_pred[y_train == 1] u0 <- runif(sum(y_train == 0), 0, pnorm(0 - mu0)) u1 <- runif(sum(y_train == 1), pnorm(0 - mu1), 1) - resid_train[y_train==0] <- mu0 + qnorm(u0) - resid_train[y_train==1] <- mu1 + qnorm(u1) - + resid_train[y_train == 0] <- mu0 + qnorm(u0) + resid_train[y_train == 1] <- mu1 + qnorm(u1) + # Update outcome outcome_train$update_data(resid_train - forest_pred) } - + # Sample the prognostic forest forest_model_mu$sample_one_iteration( - forest_dataset = forest_dataset_train, residual = outcome_train, forest_samples = forest_samples_mu, - active_forest = active_forest_mu, rng = rng, forest_model_config = forest_model_config_mu, - global_model_config = global_model_config, num_threads = num_threads, keep_forest = keep_sample, gfr = FALSE + forest_dataset = forest_dataset_train, + residual = outcome_train, + forest_samples = forest_samples_mu, + active_forest = active_forest_mu, + rng = rng, + forest_model_config = forest_model_config_mu, + global_model_config = global_model_config, + num_threads = num_threads, + keep_forest = keep_sample, + gfr = FALSE ) - + # Cache train set predictions since they are already computed during sampling if (keep_sample) { - muhat_train_raw[,sample_counter] <- forest_model_mu$get_cached_forest_predictions() + muhat_train_raw[, + sample_counter + ] <- forest_model_mu$get_cached_forest_predictions() } - + # Sample variance parameters (if requested) if (sample_sigma2_global) { - current_sigma2 <- sampleGlobalErrorVarianceOneIteration(outcome_train, forest_dataset_train, rng, a_global, b_global) - global_model_config$update_global_error_variance(current_sigma2) + current_sigma2 <- sampleGlobalErrorVarianceOneIteration( + outcome_train, + forest_dataset_train, + rng, + a_global, + b_global + ) + global_model_config$update_global_error_variance( + current_sigma2 + ) } if (sample_sigma2_leaf_mu) { - leaf_scale_mu_double <- sampleLeafVarianceOneIteration(active_forest_mu, rng, a_leaf_mu, b_leaf_mu) + leaf_scale_mu_double <- sampleLeafVarianceOneIteration( + active_forest_mu, + rng, + a_leaf_mu, + b_leaf_mu + ) current_leaf_scale_mu <- as.matrix(leaf_scale_mu_double) - if (keep_sample) leaf_scale_mu_samples[sample_counter] <- leaf_scale_mu_double - forest_model_config_mu$update_leaf_model_scale(current_leaf_scale_mu) + if (keep_sample) { + leaf_scale_mu_samples[ + sample_counter + ] <- leaf_scale_mu_double + } + forest_model_config_mu$update_leaf_model_scale( + current_leaf_scale_mu + ) } - + # Sample the treatment forest forest_model_tau$sample_one_iteration( - forest_dataset = forest_dataset_train, residual = outcome_train, forest_samples = forest_samples_tau, - active_forest = active_forest_tau, rng = rng, forest_model_config = forest_model_config_tau, - global_model_config = global_model_config, num_threads = num_threads, keep_forest = keep_sample, gfr = FALSE + forest_dataset = forest_dataset_train, + residual = outcome_train, + forest_samples = forest_samples_tau, + active_forest = active_forest_tau, + rng = rng, + forest_model_config = forest_model_config_tau, + global_model_config = global_model_config, + num_threads = num_threads, + keep_forest = keep_sample, + gfr = FALSE ) - - # Cannot cache train set predictions for tau because the cached predictions in the + + # Cannot cache train set predictions for tau because the cached predictions in the # tracking data structures are pre-multiplied by the basis (treatment) # ... - + # Sample coding parameters (if requested) if (adaptive_coding) { # Estimate mu(X) and tau(X) and compute y - mu(X) - mu_x_raw_train <- active_forest_mu$predict_raw(forest_dataset_train) - tau_x_raw_train <- active_forest_tau$predict_raw(forest_dataset_train) + mu_x_raw_train <- active_forest_mu$predict_raw( + forest_dataset_train + ) + tau_x_raw_train <- active_forest_tau$predict_raw( + forest_dataset_train + ) partial_resid_mu_train <- resid_train - mu_x_raw_train if (has_rfx) { - rfx_preds_train <- rfx_model$predict(rfx_dataset_train, rfx_tracker_train) - partial_resid_mu_train <- partial_resid_mu_train - rfx_preds_train + rfx_preds_train <- rfx_model$predict( + rfx_dataset_train, + rfx_tracker_train + ) + partial_resid_mu_train <- partial_resid_mu_train - + rfx_preds_train } - + # Compute sufficient statistics for regression of y - mu(X) on [tau(X)(1-Z), tau(X)Z] - s_tt0 <- sum(tau_x_raw_train*tau_x_raw_train*(Z_train==0)) - s_tt1 <- sum(tau_x_raw_train*tau_x_raw_train*(Z_train==1)) - s_ty0 <- sum(tau_x_raw_train*partial_resid_mu_train*(Z_train==0)) - s_ty1 <- sum(tau_x_raw_train*partial_resid_mu_train*(Z_train==1)) - + s_tt0 <- sum( + tau_x_raw_train * tau_x_raw_train * (Z_train == 0) + ) + s_tt1 <- sum( + tau_x_raw_train * tau_x_raw_train * (Z_train == 1) + ) + s_ty0 <- sum( + tau_x_raw_train * + partial_resid_mu_train * + (Z_train == 0) + ) + s_ty1 <- sum( + tau_x_raw_train * + partial_resid_mu_train * + (Z_train == 1) + ) + # Sample b0 (coefficient on tau(X)(1-Z)) and b1 (coefficient on tau(X)Z) - current_b_0 <- rnorm(1, (s_ty0/(s_tt0 + 2*current_sigma2)), sqrt(current_sigma2/(s_tt0 + 2*current_sigma2))) - current_b_1 <- rnorm(1, (s_ty1/(s_tt1 + 2*current_sigma2)), sqrt(current_sigma2/(s_tt1 + 2*current_sigma2))) - + current_b_0 <- rnorm( + 1, + (s_ty0 / (s_tt0 + 2 * current_sigma2)), + sqrt(current_sigma2 / (s_tt0 + 2 * current_sigma2)) + ) + current_b_1 <- rnorm( + 1, + (s_ty1 / (s_tt1 + 2 * current_sigma2)), + sqrt(current_sigma2 / (s_tt1 + 2 * current_sigma2)) + ) + # Update basis for the leaf regression - tau_basis_train <- (1-Z_train)*current_b_0 + Z_train*current_b_1 + tau_basis_train <- (1 - Z_train) * + current_b_0 + + Z_train * current_b_1 forest_dataset_train$update_basis(tau_basis_train) if (keep_sample) { b_0_samples[sample_counter] <- current_b_0 b_1_samples[sample_counter] <- current_b_1 } if (has_test) { - tau_basis_test <- (1-Z_test)*current_b_0 + Z_test*current_b_1 + tau_basis_test <- (1 - Z_test) * + current_b_0 + + Z_test * current_b_1 forest_dataset_test$update_basis(tau_basis_test) } - + # Update leaf predictions and residual - forest_model_tau$propagate_basis_update(forest_dataset_train, outcome_train, active_forest_tau) + forest_model_tau$propagate_basis_update( + forest_dataset_train, + outcome_train, + active_forest_tau + ) } - + # Sample variance parameters (if requested) if (include_variance_forest) { forest_model_variance$sample_one_iteration( - forest_dataset = forest_dataset_train, residual = outcome_train, forest_samples = forest_samples_variance, - active_forest = active_forest_variance, rng = rng, forest_model_config = forest_model_config_variance, - global_model_config = global_model_config, num_threads = num_threads, keep_forest = keep_sample, gfr = FALSE + forest_dataset = forest_dataset_train, + residual = outcome_train, + forest_samples = forest_samples_variance, + active_forest = active_forest_variance, + rng = rng, + forest_model_config = forest_model_config_variance, + global_model_config = global_model_config, + num_threads = num_threads, + keep_forest = keep_sample, + gfr = FALSE ) - + # Cache train set predictions since they are already computed during sampling if (keep_sample) { - sigma2_x_train_raw[,sample_counter] <- forest_model_variance$get_cached_forest_predictions() + sigma2_x_train_raw[, + sample_counter + ] <- forest_model_variance$get_cached_forest_predictions() } } if (sample_sigma2_global) { - current_sigma2 <- sampleGlobalErrorVarianceOneIteration(outcome_train, forest_dataset_train, rng, a_global, b_global) - if (keep_sample) global_var_samples[sample_counter] <- current_sigma2 - global_model_config$update_global_error_variance(current_sigma2) + current_sigma2 <- sampleGlobalErrorVarianceOneIteration( + outcome_train, + forest_dataset_train, + rng, + a_global, + b_global + ) + if (keep_sample) { + global_var_samples[sample_counter] <- current_sigma2 + } + global_model_config$update_global_error_variance( + current_sigma2 + ) } if (sample_sigma2_leaf_tau) { - leaf_scale_tau_double <- sampleLeafVarianceOneIteration(active_forest_tau, rng, a_leaf_tau, b_leaf_tau) + leaf_scale_tau_double <- sampleLeafVarianceOneIteration( + active_forest_tau, + rng, + a_leaf_tau, + b_leaf_tau + ) current_leaf_scale_tau <- as.matrix(leaf_scale_tau_double) - if (keep_sample) leaf_scale_tau_samples[sample_counter] <- leaf_scale_tau_double - forest_model_config_tau$update_leaf_model_scale(current_leaf_scale_tau) + if (keep_sample) { + leaf_scale_tau_samples[ + sample_counter + ] <- leaf_scale_tau_double + } + forest_model_config_tau$update_leaf_model_scale( + current_leaf_scale_tau + ) } - + # Sample random effects parameters (if requested) if (has_rfx) { - rfx_model$sample_random_effect(rfx_dataset_train, outcome_train, rfx_tracker_train, rfx_samples, keep_sample, current_sigma2, rng) + rfx_model$sample_random_effect( + rfx_dataset_train, + outcome_train, + rfx_tracker_train, + rfx_samples, + keep_sample, + current_sigma2, + rng + ) } } } } - + # Remove GFR samples if they are not to be retained if ((!keep_gfr) && (num_gfr > 0)) { for (i in 1:num_gfr) { @@ -1438,60 +2282,96 @@ bcf <- function(X_train, Z_train, y_train, propensity_train = NULL, rfx_group_id } } if (sample_sigma2_global) { - global_var_samples <- global_var_samples[(num_gfr+1):length(global_var_samples)] + global_var_samples <- global_var_samples[ + (num_gfr + 1):length(global_var_samples) + ] } if (sample_sigma2_leaf_mu) { - leaf_scale_mu_samples <- leaf_scale_mu_samples[(num_gfr+1):length(leaf_scale_mu_samples)] + leaf_scale_mu_samples <- leaf_scale_mu_samples[ + (num_gfr + 1):length(leaf_scale_mu_samples) + ] } if (sample_sigma2_leaf_tau) { - leaf_scale_tau_samples <- leaf_scale_tau_samples[(num_gfr+1):length(leaf_scale_tau_samples)] + leaf_scale_tau_samples <- leaf_scale_tau_samples[ + (num_gfr + 1):length(leaf_scale_tau_samples) + ] } if (adaptive_coding) { - b_1_samples <- b_1_samples[(num_gfr+1):length(b_1_samples)] - b_0_samples <- b_0_samples[(num_gfr+1):length(b_0_samples)] + b_1_samples <- b_1_samples[(num_gfr + 1):length(b_1_samples)] + b_0_samples <- b_0_samples[(num_gfr + 1):length(b_0_samples)] } - muhat_train_raw <- muhat_train_raw[,(num_gfr+1):ncol(muhat_train_raw)] + muhat_train_raw <- muhat_train_raw[, + (num_gfr + 1):ncol(muhat_train_raw) + ] if (include_variance_forest) { - sigma2_x_train_raw <- sigma2_x_train_raw[,(num_gfr+1):ncol(sigma2_x_train_raw)] + sigma2_x_train_raw <- sigma2_x_train_raw[, + (num_gfr + 1):ncol(sigma2_x_train_raw) + ] } num_retained_samples <- num_retained_samples - num_gfr } # Forest predictions - mu_hat_train <- muhat_train_raw*y_std_train + y_bar_train + mu_hat_train <- muhat_train_raw * y_std_train + y_bar_train if (adaptive_coding) { - tau_hat_train_raw <- forest_samples_tau$predict_raw(forest_dataset_train) - tau_hat_train <- t(t(tau_hat_train_raw) * (b_1_samples - b_0_samples))*y_std_train + tau_hat_train_raw <- forest_samples_tau$predict_raw( + forest_dataset_train + ) + tau_hat_train <- t(t(tau_hat_train_raw) * (b_1_samples - b_0_samples)) * + y_std_train } else { - tau_hat_train <- forest_samples_tau$predict_raw(forest_dataset_train)*y_std_train + tau_hat_train <- forest_samples_tau$predict_raw(forest_dataset_train) * + y_std_train } if (has_multivariate_treatment) { tau_train_dim <- dim(tau_hat_train) tau_num_obs <- tau_train_dim[1] tau_num_samples <- tau_train_dim[3] - treatment_term_train <- matrix(NA_real_, nrow = tau_num_obs, tau_num_samples) + treatment_term_train <- matrix( + NA_real_, + nrow = tau_num_obs, + tau_num_samples + ) for (i in 1:nrow(Z_train)) { - treatment_term_train[i,] <- colSums(tau_hat_train[i,,] * Z_train[i,]) + treatment_term_train[i, ] <- colSums( + tau_hat_train[i, , ] * Z_train[i, ] + ) } } else { treatment_term_train <- tau_hat_train * as.numeric(Z_train) } y_hat_train <- mu_hat_train + treatment_term_train if (has_test) { - mu_hat_test <- forest_samples_mu$predict(forest_dataset_test)*y_std_train + y_bar_train + mu_hat_test <- forest_samples_mu$predict(forest_dataset_test) * + y_std_train + + y_bar_train if (adaptive_coding) { - tau_hat_test_raw <- forest_samples_tau$predict_raw(forest_dataset_test) - tau_hat_test <- t(t(tau_hat_test_raw) * (b_1_samples - b_0_samples))*y_std_train + tau_hat_test_raw <- forest_samples_tau$predict_raw( + forest_dataset_test + ) + tau_hat_test <- t( + t(tau_hat_test_raw) * (b_1_samples - b_0_samples) + ) * + y_std_train } else { - tau_hat_test <- forest_samples_tau$predict_raw(forest_dataset_test)*y_std_train + tau_hat_test <- forest_samples_tau$predict_raw( + forest_dataset_test + ) * + y_std_train } if (has_multivariate_treatment) { tau_test_dim <- dim(tau_hat_test) tau_num_obs <- tau_test_dim[1] tau_num_samples <- tau_test_dim[3] - treatment_term_test <- matrix(NA_real_, nrow = tau_num_obs, tau_num_samples) + treatment_term_test <- matrix( + NA_real_, + nrow = tau_num_obs, + tau_num_samples + ) for (i in 1:nrow(Z_test)) { - treatment_term_test[i,] <- colSums(tau_hat_test[i,,] * Z_test[i,]) + treatment_term_test[i, ] <- colSums( + tau_hat_test[i, , ] * Z_test[i, ] + ) } } else { treatment_term_test <- tau_hat_test * as.numeric(Z_test) @@ -1500,39 +2380,74 @@ bcf <- function(X_train, Z_train, y_train, propensity_train = NULL, rfx_group_id } if (include_variance_forest) { sigma2_x_hat_train <- exp(sigma2_x_train_raw) - if (has_test) sigma2_x_hat_test <- forest_samples_variance$predict(forest_dataset_test) + if (has_test) { + sigma2_x_hat_test <- forest_samples_variance$predict( + forest_dataset_test + ) + } } # Random effects predictions if (has_rfx) { - rfx_preds_train <- rfx_samples$predict(rfx_group_ids_train, rfx_basis_train)*y_std_train + rfx_preds_train <- rfx_samples$predict( + rfx_group_ids_train, + rfx_basis_train + ) * + y_std_train y_hat_train <- y_hat_train + rfx_preds_train } if ((has_rfx_test) && (has_test)) { - rfx_preds_test <- rfx_samples$predict(rfx_group_ids_test, rfx_basis_test)*y_std_train + rfx_preds_test <- rfx_samples$predict( + rfx_group_ids_test, + rfx_basis_test + ) * + y_std_train y_hat_test <- y_hat_test + rfx_preds_test } - + # Global error variance - if (sample_sigma2_global) sigma2_global_samples <- global_var_samples*(y_std_train^2) - + if (sample_sigma2_global) { + sigma2_global_samples <- global_var_samples * (y_std_train^2) + } + # Leaf parameter variance for prognostic forest - if (sample_sigma2_leaf_mu) sigma2_leaf_mu_samples <- leaf_scale_mu_samples - + if (sample_sigma2_leaf_mu) { + sigma2_leaf_mu_samples <- leaf_scale_mu_samples + } + # Leaf parameter variance for treatment effect forest - if (sample_sigma2_leaf_tau) sigma2_leaf_tau_samples <- leaf_scale_tau_samples - + if (sample_sigma2_leaf_tau) { + sigma2_leaf_tau_samples <- leaf_scale_tau_samples + } + # Rescale variance forest prediction by global sigma2 (sampled or constant) if (include_variance_forest) { if (sample_sigma2_global) { - sigma2_x_hat_train <- sapply(1:num_retained_samples, function(i) sigma2_x_hat_train[,i]*sigma2_global_samples[i]) - if (has_test) sigma2_x_hat_test <- sapply(1:num_retained_samples, function(i) sigma2_x_hat_test[,i]*sigma2_global_samples[i]) + sigma2_x_hat_train <- sapply(1:num_retained_samples, function(i) { + sigma2_x_hat_train[, i] * sigma2_global_samples[i] + }) + if (has_test) { + sigma2_x_hat_test <- sapply( + 1:num_retained_samples, + function(i) { + sigma2_x_hat_test[, i] * sigma2_global_samples[i] + } + ) + } } else { - sigma2_x_hat_train <- sigma2_x_hat_train*sigma2_init*y_std_train*y_std_train - if (has_test) sigma2_x_hat_test <- sigma2_x_hat_test*sigma2_init*y_std_train*y_std_train + sigma2_x_hat_train <- sigma2_x_hat_train * + sigma2_init * + y_std_train * + y_std_train + if (has_test) { + sigma2_x_hat_test <- sigma2_x_hat_test * + sigma2_init * + y_std_train * + y_std_train + } } } - + # Return results as a list if (include_variance_forest) { num_variance_covariates <- sum(variable_weights_variance > 0) @@ -1540,67 +2455,79 @@ bcf <- function(X_train, Z_train, y_train, propensity_train = NULL, rfx_group_id num_variance_covariates <- 0 } model_params <- list( - "initial_sigma2" = sigma2_init, + "initial_sigma2" = sigma2_init, "initial_sigma2_leaf_mu" = sigma2_leaf_mu, "initial_sigma2_leaf_tau" = sigma2_leaf_tau, "initial_b_0" = b_0, "initial_b_1" = b_1, "a_global" = a_global, "b_global" = b_global, - "a_leaf_mu" = a_leaf_mu, + "a_leaf_mu" = a_leaf_mu, "b_leaf_mu" = b_leaf_mu, - "a_leaf_tau" = a_leaf_tau, + "a_leaf_tau" = a_leaf_tau, "b_leaf_tau" = b_leaf_tau, - "a_forest" = a_forest, + "a_forest" = a_forest, "b_forest" = b_forest, "outcome_mean" = y_bar_train, "outcome_scale" = y_std_train, - "standardize" = standardize, + "standardize" = standardize, "num_covariates" = num_cov_orig, "num_prognostic_covariates" = sum(variable_weights_mu > 0), "num_treatment_covariates" = sum(variable_weights_tau > 0), "num_variance_covariates" = num_variance_covariates, - "treatment_dim" = ncol(Z_train), - "propensity_covariate" = propensity_covariate, - "binary_treatment" = binary_treatment, - "multivariate_treatment" = has_multivariate_treatment, - "adaptive_coding" = adaptive_coding, - "internal_propensity_model" = internal_propensity_model, - "num_samples" = num_retained_samples, - "num_gfr" = num_gfr, - "num_burnin" = num_burnin, - "num_mcmc" = num_mcmc, + "treatment_dim" = ncol(Z_train), + "propensity_covariate" = propensity_covariate, + "binary_treatment" = binary_treatment, + "multivariate_treatment" = has_multivariate_treatment, + "adaptive_coding" = adaptive_coding, + "internal_propensity_model" = internal_propensity_model, + "num_samples" = num_retained_samples, + "num_gfr" = num_gfr, + "num_burnin" = num_burnin, + "num_mcmc" = num_mcmc, "keep_every" = keep_every, "num_chains" = num_chains, - "has_rfx" = has_rfx, - "has_rfx_basis" = has_basis_rfx, - "num_rfx_basis" = num_basis_rfx, - "include_variance_forest" = include_variance_forest, + "has_rfx" = has_rfx, + "has_rfx_basis" = has_basis_rfx, + "num_rfx_basis" = num_basis_rfx, + "include_variance_forest" = include_variance_forest, "sample_sigma2_global" = sample_sigma2_global, "sample_sigma2_leaf_mu" = sample_sigma2_leaf_mu, - "sample_sigma2_leaf_tau" = sample_sigma2_leaf_tau, + "sample_sigma2_leaf_tau" = sample_sigma2_leaf_tau, "probit_outcome_model" = probit_outcome_model ) result <- list( - "forests_mu" = forest_samples_mu, - "forests_tau" = forest_samples_tau, - "model_params" = model_params, - "mu_hat_train" = mu_hat_train, - "tau_hat_train" = tau_hat_train, - "y_hat_train" = y_hat_train, + "forests_mu" = forest_samples_mu, + "forests_tau" = forest_samples_tau, + "model_params" = model_params, + "mu_hat_train" = mu_hat_train, + "tau_hat_train" = tau_hat_train, + "y_hat_train" = y_hat_train, "train_set_metadata" = X_train_metadata ) - if (has_test) result[["mu_hat_test"]] = mu_hat_test - if (has_test) result[["tau_hat_test"]] = tau_hat_test - if (has_test) result[["y_hat_test"]] = y_hat_test + if (has_test) { + result[["mu_hat_test"]] = mu_hat_test + } + if (has_test) { + result[["tau_hat_test"]] = tau_hat_test + } + if (has_test) { + result[["y_hat_test"]] = y_hat_test + } if (include_variance_forest) { result[["forests_variance"]] = forest_samples_variance result[["sigma2_x_hat_train"]] = sigma2_x_hat_train if (has_test) result[["sigma2_x_hat_test"]] = sigma2_x_hat_test } - if (sample_sigma2_global) result[["sigma2_global_samples"]] = sigma2_global_samples - if (sample_sigma2_leaf_mu) result[["sigma2_leaf_mu_samples"]] = sigma2_leaf_mu_samples - if (sample_sigma2_leaf_tau) result[["sigma2_leaf_tau_samples"]] = sigma2_leaf_tau_samples + if (sample_sigma2_global) { + result[["sigma2_global_samples"]] = sigma2_global_samples + } + if (sample_sigma2_leaf_mu) { + result[["sigma2_leaf_mu_samples"]] = sigma2_leaf_mu_samples + } + if (sample_sigma2_leaf_tau) { + result[["sigma2_leaf_tau_samples"]] = sigma2_leaf_tau_samples + } if (adaptive_coding) { result[["b_0_samples"]] = b_0_samples result[["b_1_samples"]] = b_1_samples @@ -1610,12 +2537,14 @@ bcf <- function(X_train, Z_train, y_train, propensity_train = NULL, rfx_group_id result[["rfx_preds_train"]] = rfx_preds_train result[["rfx_unique_group_ids"]] = levels(group_ids_factor) } - if ((has_rfx_test) && (has_test)) result[["rfx_preds_test"]] = rfx_preds_test + if ((has_rfx_test) && (has_test)) { + result[["rfx_preds_test"]] = rfx_preds_test + } if (internal_propensity_model) { result[["bart_propensity_model"]] = bart_model_propensity } class(result) <- "bcfmodel" - + return(result) } @@ -1625,7 +2554,7 @@ bcf <- function(X_train, Z_train, y_train, propensity_train = NULL, rfx_group_id #' @param X Covariates used to determine tree leaf predictions for each observation. Must be passed as a matrix or dataframe. #' @param Z Treatments used for prediction. #' @param propensity (Optional) Propensities used for prediction. -#' @param rfx_group_ids (Optional) Test set group labels used for an additive random effects model. +#' @param rfx_group_ids (Optional) Test set group labels used for an additive random effects model. #' We do not currently support (but plan to in the near future), test set evaluation for group labels #' that were not in the training set. #' @param rfx_basis (Optional) Test set basis for "random-slope" regression in additive random effects model. @@ -1639,21 +2568,21 @@ bcf <- function(X_train, Z_train, y_train, propensity_train = NULL, rfx_group_id #' p <- 5 #' X <- matrix(runif(n*p), ncol = p) #' mu_x <- ( -#' ((0 <= X[,1]) & (0.25 > X[,1])) * (-7.5) + -#' ((0.25 <= X[,1]) & (0.5 > X[,1])) * (-2.5) + -#' ((0.5 <= X[,1]) & (0.75 > X[,1])) * (2.5) + +#' ((0 <= X[,1]) & (0.25 > X[,1])) * (-7.5) + +#' ((0.25 <= X[,1]) & (0.5 > X[,1])) * (-2.5) + +#' ((0.5 <= X[,1]) & (0.75 > X[,1])) * (2.5) + #' ((0.75 <= X[,1]) & (1 > X[,1])) * (7.5) #' ) #' pi_x <- ( -#' ((0 <= X[,1]) & (0.25 > X[,1])) * (0.2) + -#' ((0.25 <= X[,1]) & (0.5 > X[,1])) * (0.4) + -#' ((0.5 <= X[,1]) & (0.75 > X[,1])) * (0.6) + +#' ((0 <= X[,1]) & (0.25 > X[,1])) * (0.2) + +#' ((0.25 <= X[,1]) & (0.5 > X[,1])) * (0.4) + +#' ((0.5 <= X[,1]) & (0.75 > X[,1])) * (0.6) + #' ((0.75 <= X[,1]) & (1 > X[,1])) * (0.8) #' ) #' tau_x <- ( -#' ((0 <= X[,2]) & (0.25 > X[,2])) * (0.5) + -#' ((0.25 <= X[,2]) & (0.5 > X[,2])) * (1.0) + -#' ((0.5 <= X[,2]) & (0.75 > X[,2])) * (1.5) + +#' ((0 <= X[,2]) & (0.25 > X[,2])) * (0.5) + +#' ((0.25 <= X[,2]) & (0.5 > X[,2])) * (1.0) + +#' ((0.5 <= X[,2]) & (0.75 > X[,2])) * (1.5) + #' ((0.75 <= X[,2]) & (1 > X[,2])) * (2.0) #' ) #' Z <- rbinom(n, 1, pi_x) @@ -1676,18 +2605,26 @@ bcf <- function(X_train, Z_train, y_train, propensity_train = NULL, rfx_group_id #' mu_train <- mu_x[train_inds] #' tau_test <- tau_x[test_inds] #' tau_train <- tau_x[train_inds] -#' bcf_model <- bcf(X_train = X_train, Z_train = Z_train, y_train = y_train, -#' propensity_train = pi_train, num_gfr = 10, +#' bcf_model <- bcf(X_train = X_train, Z_train = Z_train, y_train = y_train, +#' propensity_train = pi_train, num_gfr = 10, #' num_burnin = 0, num_mcmc = 10) #' preds <- predict(bcf_model, X_test, Z_test, pi_test) -predict.bcfmodel <- function(object, X, Z, propensity = NULL, rfx_group_ids = NULL, rfx_basis = NULL, ...){ +predict.bcfmodel <- function( + object, + X, + Z, + propensity = NULL, + rfx_group_ids = NULL, + rfx_basis = NULL, + ... +) { # Preprocess covariates if ((!is.data.frame(X)) && (!is.matrix(X))) { stop("X must be a matrix or dataframe") } train_set_metadata <- object$train_set_metadata X <- preprocessPredictionData(X, train_set_metadata) - + # Convert all input data to matrices if not already converted if ((is.null(dim(Z))) && (!is.null(Z))) { Z <- as.matrix(as.numeric(Z)) @@ -1698,9 +2635,12 @@ predict.bcfmodel <- function(object, X, Z, propensity = NULL, rfx_group_ids = NU if ((is.null(dim(rfx_basis))) && (!is.null(rfx_basis))) { rfx_basis <- as.matrix(rfx_basis) } - + # Data checks - if ((object$model_params$propensity_covariate != "none") && (is.null(propensity))) { + if ( + (object$model_params$propensity_covariate != "none") && + (is.null(propensity)) + ) { if (!object$model_params$internal_propensity_model) { stop("propensity must be provided for this model") } @@ -1711,25 +2651,36 @@ predict.bcfmodel <- function(object, X, Z, propensity = NULL, rfx_group_ids = NU stop("X and Z must have the same number of rows") } if (object$model_params$num_covariates != ncol(X)) { - stop("X and must have the same number of columns as the covariates used to train the model") + stop( + "X and must have the same number of columns as the covariates used to train the model" + ) } if ((object$model_params$has_rfx) && (is.null(rfx_group_ids))) { - stop("Random effect group labels (rfx_group_ids) must be provided for this model") + stop( + "Random effect group labels (rfx_group_ids) must be provided for this model" + ) } if ((object$model_params$has_rfx_basis) && (is.null(rfx_basis))) { stop("Random effects basis (rfx_basis) must be provided for this model") } - if ((object$model_params$num_rfx_basis > 0) && (ncol(rfx_basis) != object$model_params$num_rfx_basis)) { - stop("Random effects basis has a different dimension than the basis used to train this model") + if ( + (object$model_params$num_rfx_basis > 0) && + (ncol(rfx_basis) != object$model_params$num_rfx_basis) + ) { + stop( + "Random effects basis has a different dimension than the basis used to train this model" + ) } - + # Recode group IDs to integer vector (if passed as, for example, a vector of county names, etc...) has_rfx <- FALSE if (!is.null(rfx_group_ids)) { rfx_unique_group_ids <- object$rfx_unique_group_ids group_ids_factor <- factor(rfx_group_ids, levels = rfx_unique_group_ids) if (sum(is.na(group_ids_factor)) > 0) { - stop("All random effect group labels provided in rfx_group_ids must be present in rfx_group_ids_train") + stop( + "All random effect group labels provided in rfx_group_ids must be present in rfx_group_ids_train" + ) } rfx_group_ids <- as.integer(group_ids_factor) has_rfx <- TRUE @@ -1739,12 +2690,12 @@ predict.bcfmodel <- function(object, X, Z, propensity = NULL, rfx_group_ids = NU if ((object$model_params$has_rfx) && (is.null(rfx_basis))) { rfx_basis <- matrix(rep(1, nrow(X)), ncol = 1) } - + # Add propensities to covariate set if necessary if (object$model_params$propensity_covariate != "none") { X_combined <- cbind(X, propensity) } - + # Create prediction datasets forest_dataset_pred <- createForestDataset(X_combined, Z) @@ -1753,12 +2704,15 @@ predict.bcfmodel <- function(object, X, Z, propensity = NULL, rfx_group_ids = NU y_std <- object$model_params$outcome_scale y_bar <- object$model_params$outcome_mean initial_sigma2 <- object$model_params$initial_sigma2 - mu_hat <- object$forests_mu$predict(forest_dataset_pred)*y_std + y_bar + mu_hat <- object$forests_mu$predict(forest_dataset_pred) * y_std + y_bar if (object$model_params$adaptive_coding) { tau_hat_raw <- object$forests_tau$predict_raw(forest_dataset_pred) - tau_hat <- t(t(tau_hat_raw) * (object$b_1_samples - object$b_0_samples))*y_std + tau_hat <- t( + t(tau_hat_raw) * (object$b_1_samples - object$b_0_samples) + ) * + y_std } else { - tau_hat <- object$forests_tau$predict_raw(forest_dataset_pred)*y_std + tau_hat <- object$forests_tau$predict_raw(forest_dataset_pred) * y_std } if (object$model_params$multivariate_treatment) { tau_dim <- dim(tau_hat) @@ -1766,7 +2720,7 @@ predict.bcfmodel <- function(object, X, Z, propensity = NULL, rfx_group_ids = NU tau_num_samples <- tau_dim[3] treatment_term <- matrix(NA_real_, nrow = tau_num_obs, tau_num_samples) for (i in 1:nrow(Z)) { - treatment_term[i,] <- colSums(tau_hat[i,,] * Z[i,]) + treatment_term[i, ] <- colSums(tau_hat[i, , ] * Z[i, ]) } } else { treatment_term <- tau_hat * as.numeric(Z) @@ -1774,29 +2728,40 @@ predict.bcfmodel <- function(object, X, Z, propensity = NULL, rfx_group_ids = NU if (object$model_params$include_variance_forest) { s_x_raw <- object$forests_variance$predict(forest_dataset_pred) } - + # Compute rfx predictions (if needed) if (object$model_params$has_rfx) { - rfx_predictions <- object$rfx_samples$predict(rfx_group_ids, rfx_basis)*y_std + rfx_predictions <- object$rfx_samples$predict( + rfx_group_ids, + rfx_basis + ) * + y_std } - + # Compute overall "y_hat" predictions y_hat <- mu_hat + treatment_term - if (object$model_params$has_rfx) y_hat <- y_hat + rfx_predictions - + if (object$model_params$has_rfx) { + y_hat <- y_hat + rfx_predictions + } + # Scale variance forest predictions if (object$model_params$include_variance_forest) { if (object$model_params$sample_sigma2_global) { sigma2_global_samples <- object$sigma2_global_samples - variance_forest_predictions <- sapply(1:num_samples, function(i) s_x_raw[,i]*sigma2_global_samples[i]) + variance_forest_predictions <- sapply(1:num_samples, function(i) { + s_x_raw[, i] * sigma2_global_samples[i] + }) } else { - variance_forest_predictions <- s_x_raw*initial_sigma2*y_std*y_std + variance_forest_predictions <- s_x_raw * + initial_sigma2 * + y_std * + y_std } } result <- list( - "mu_hat" = mu_hat, - "tau_hat" = tau_hat, + "mu_hat" = mu_hat, + "tau_hat" = tau_hat, "y_hat" = y_hat ) if (object$model_params$has_rfx) { @@ -1826,21 +2791,21 @@ predict.bcfmodel <- function(object, X, Z, propensity = NULL, rfx_group_ids = NU #' p <- 5 #' X <- matrix(runif(n*p), ncol = p) #' mu_x <- ( -#' ((0 <= X[,1]) & (0.25 > X[,1])) * (-7.5) + -#' ((0.25 <= X[,1]) & (0.5 > X[,1])) * (-2.5) + -#' ((0.5 <= X[,1]) & (0.75 > X[,1])) * (2.5) + +#' ((0 <= X[,1]) & (0.25 > X[,1])) * (-7.5) + +#' ((0.25 <= X[,1]) & (0.5 > X[,1])) * (-2.5) + +#' ((0.5 <= X[,1]) & (0.75 > X[,1])) * (2.5) + #' ((0.75 <= X[,1]) & (1 > X[,1])) * (7.5) #' ) #' pi_x <- ( -#' ((0 <= X[,1]) & (0.25 > X[,1])) * (0.2) + -#' ((0.25 <= X[,1]) & (0.5 > X[,1])) * (0.4) + -#' ((0.5 <= X[,1]) & (0.75 > X[,1])) * (0.6) + +#' ((0 <= X[,1]) & (0.25 > X[,1])) * (0.2) + +#' ((0.25 <= X[,1]) & (0.5 > X[,1])) * (0.4) + +#' ((0.5 <= X[,1]) & (0.75 > X[,1])) * (0.6) + #' ((0.75 <= X[,1]) & (1 > X[,1])) * (0.8) #' ) #' tau_x <- ( -#' ((0 <= X[,2]) & (0.25 > X[,2])) * (0.5) + -#' ((0.25 <= X[,2]) & (0.5 > X[,2])) * (1.0) + -#' ((0.5 <= X[,2]) & (0.75 > X[,2])) * (1.5) + +#' ((0 <= X[,2]) & (0.25 > X[,2])) * (0.5) + +#' ((0.25 <= X[,2]) & (0.5 > X[,2])) * (1.0) + +#' ((0.5 <= X[,2]) & (0.75 > X[,2])) * (1.5) + #' ((0.75 <= X[,2]) & (1 > X[,2])) * (2.0) #' ) #' Z <- rbinom(n, 1, pi_x) @@ -1876,34 +2841,37 @@ predict.bcfmodel <- function(object, X, Z, propensity = NULL, rfx_group_ids = NU #' rfx_term_train <- rfx_term[train_inds] #' mu_params <- list(sample_sigma2_leaf = TRUE) #' tau_params <- list(sample_sigma2_leaf = FALSE) -#' bcf_model <- bcf(X_train = X_train, Z_train = Z_train, y_train = y_train, -#' propensity_train = pi_train, -#' rfx_group_ids_train = rfx_group_ids_train, -#' rfx_basis_train = rfx_basis_train, X_test = X_test, -#' Z_test = Z_test, propensity_test = pi_test, +#' bcf_model <- bcf(X_train = X_train, Z_train = Z_train, y_train = y_train, +#' propensity_train = pi_train, +#' rfx_group_ids_train = rfx_group_ids_train, +#' rfx_basis_train = rfx_basis_train, X_test = X_test, +#' Z_test = Z_test, propensity_test = pi_test, #' rfx_group_ids_test = rfx_group_ids_test, -#' rfx_basis_test = rfx_basis_test, -#' num_gfr = 10, num_burnin = 0, num_mcmc = 10, -#' prognostic_forest_params = mu_params, +#' rfx_basis_test = rfx_basis_test, +#' num_gfr = 10, num_burnin = 0, num_mcmc = 10, +#' prognostic_forest_params = mu_params, #' treatment_effect_forest_params = tau_params) #' rfx_samples <- getRandomEffectSamples(bcf_model) -getRandomEffectSamples.bcfmodel <- function(object, ...){ +getRandomEffectSamples.bcfmodel <- function(object, ...) { result = list() - + if (!object$model_params$has_rfx) { warning("This model has no RFX terms, returning an empty list") return(result) } - + # Extract the samples result <- object$rfx_samples$extract_parameter_samples() - + # Scale by sd(y_train) - result$beta_samples <- result$beta_samples*object$model_params$outcome_scale - result$xi_samples <- result$xi_samples*object$model_params$outcome_scale - result$alpha_samples <- result$alpha_samples*object$model_params$outcome_scale - result$sigma_samples <- result$sigma_samples*(object$model_params$outcome_scale^2) - + result$beta_samples <- result$beta_samples * + object$model_params$outcome_scale + result$xi_samples <- result$xi_samples * object$model_params$outcome_scale + result$alpha_samples <- result$alpha_samples * + object$model_params$outcome_scale + result$sigma_samples <- result$sigma_samples * + (object$model_params$outcome_scale^2) + return(result) } @@ -1919,21 +2887,21 @@ getRandomEffectSamples.bcfmodel <- function(object, ...){ #' p <- 5 #' X <- matrix(runif(n*p), ncol = p) #' mu_x <- ( -#' ((0 <= X[,1]) & (0.25 > X[,1])) * (-7.5) + -#' ((0.25 <= X[,1]) & (0.5 > X[,1])) * (-2.5) + -#' ((0.5 <= X[,1]) & (0.75 > X[,1])) * (2.5) + +#' ((0 <= X[,1]) & (0.25 > X[,1])) * (-7.5) + +#' ((0.25 <= X[,1]) & (0.5 > X[,1])) * (-2.5) + +#' ((0.5 <= X[,1]) & (0.75 > X[,1])) * (2.5) + #' ((0.75 <= X[,1]) & (1 > X[,1])) * (7.5) #' ) #' pi_x <- ( -#' ((0 <= X[,1]) & (0.25 > X[,1])) * (0.2) + -#' ((0.25 <= X[,1]) & (0.5 > X[,1])) * (0.4) + -#' ((0.5 <= X[,1]) & (0.75 > X[,1])) * (0.6) + +#' ((0 <= X[,1]) & (0.25 > X[,1])) * (0.2) + +#' ((0.25 <= X[,1]) & (0.5 > X[,1])) * (0.4) + +#' ((0.5 <= X[,1]) & (0.75 > X[,1])) * (0.6) + #' ((0.75 <= X[,1]) & (1 > X[,1])) * (0.8) #' ) #' tau_x <- ( -#' ((0 <= X[,2]) & (0.25 > X[,2])) * (0.5) + -#' ((0.25 <= X[,2]) & (0.5 > X[,2])) * (1.0) + -#' ((0.5 <= X[,2]) & (0.75 > X[,2])) * (1.5) + +#' ((0 <= X[,2]) & (0.25 > X[,2])) * (0.5) + +#' ((0.25 <= X[,2]) & (0.5 > X[,2])) * (1.0) + +#' ((0.5 <= X[,2]) & (0.75 > X[,2])) * (1.5) + #' ((0.75 <= X[,2]) & (1 > X[,2])) * (2.0) #' ) #' Z <- rbinom(n, 1, pi_x) @@ -1969,24 +2937,24 @@ getRandomEffectSamples.bcfmodel <- function(object, ...){ #' rfx_term_train <- rfx_term[train_inds] #' mu_params <- list(sample_sigma2_leaf = TRUE) #' tau_params <- list(sample_sigma2_leaf = FALSE) -#' bcf_model <- bcf(X_train = X_train, Z_train = Z_train, y_train = y_train, -#' propensity_train = pi_train, -#' rfx_group_ids_train = rfx_group_ids_train, -#' rfx_basis_train = rfx_basis_train, X_test = X_test, -#' Z_test = Z_test, propensity_test = pi_test, +#' bcf_model <- bcf(X_train = X_train, Z_train = Z_train, y_train = y_train, +#' propensity_train = pi_train, +#' rfx_group_ids_train = rfx_group_ids_train, +#' rfx_basis_train = rfx_basis_train, X_test = X_test, +#' Z_test = Z_test, propensity_test = pi_test, #' rfx_group_ids_test = rfx_group_ids_test, -#' rfx_basis_test = rfx_basis_test, -#' num_gfr = 10, num_burnin = 0, num_mcmc = 10, -#' prognostic_forest_params = mu_params, +#' rfx_basis_test = rfx_basis_test, +#' num_gfr = 10, num_burnin = 0, num_mcmc = 10, +#' prognostic_forest_params = mu_params, #' treatment_effect_forest_params = tau_params) #' bcf_json <- saveBCFModelToJson(bcf_model) -saveBCFModelToJson <- function(object){ +saveBCFModelToJson <- function(object) { jsonobj <- createCppJson() - + if (!inherits(object, "bcfmodel")) { stop("`object` must be a BCF model") } - + if (is.null(object$model_params)) { stop("This BCF model has not yet been sampled") } @@ -1997,39 +2965,84 @@ saveBCFModelToJson <- function(object){ if (object$model_params$include_variance_forest) { jsonobj$add_forest(object$forests_variance) } - + # Add metadata - jsonobj$add_scalar("num_numeric_vars", object$train_set_metadata$num_numeric_vars) - jsonobj$add_scalar("num_ordered_cat_vars", object$train_set_metadata$num_ordered_cat_vars) - jsonobj$add_scalar("num_unordered_cat_vars", object$train_set_metadata$num_unordered_cat_vars) + jsonobj$add_scalar( + "num_numeric_vars", + object$train_set_metadata$num_numeric_vars + ) + jsonobj$add_scalar( + "num_ordered_cat_vars", + object$train_set_metadata$num_ordered_cat_vars + ) + jsonobj$add_scalar( + "num_unordered_cat_vars", + object$train_set_metadata$num_unordered_cat_vars + ) if (object$train_set_metadata$num_numeric_vars > 0) { - jsonobj$add_string_vector("numeric_vars", object$train_set_metadata$numeric_vars) + jsonobj$add_string_vector( + "numeric_vars", + object$train_set_metadata$numeric_vars + ) } if (object$train_set_metadata$num_ordered_cat_vars > 0) { - jsonobj$add_string_vector("ordered_cat_vars", object$train_set_metadata$ordered_cat_vars) - jsonobj$add_string_list("ordered_unique_levels", object$train_set_metadata$ordered_unique_levels) + jsonobj$add_string_vector( + "ordered_cat_vars", + object$train_set_metadata$ordered_cat_vars + ) + jsonobj$add_string_list( + "ordered_unique_levels", + object$train_set_metadata$ordered_unique_levels + ) } if (object$train_set_metadata$num_unordered_cat_vars > 0) { - jsonobj$add_string_vector("unordered_cat_vars", object$train_set_metadata$unordered_cat_vars) - jsonobj$add_string_list("unordered_unique_levels", object$train_set_metadata$unordered_unique_levels) + jsonobj$add_string_vector( + "unordered_cat_vars", + object$train_set_metadata$unordered_cat_vars + ) + jsonobj$add_string_list( + "unordered_unique_levels", + object$train_set_metadata$unordered_unique_levels + ) } - + # Add global parameters jsonobj$add_scalar("outcome_scale", object$model_params$outcome_scale) jsonobj$add_scalar("outcome_mean", object$model_params$outcome_mean) jsonobj$add_boolean("standardize", object$model_params$standardize) jsonobj$add_scalar("initial_sigma2", object$model_params$initial_sigma2) - jsonobj$add_boolean("sample_sigma2_global", object$model_params$sample_sigma2_global) - jsonobj$add_boolean("sample_sigma2_leaf_mu", object$model_params$sample_sigma2_leaf_mu) - jsonobj$add_boolean("sample_sigma2_leaf_tau", object$model_params$sample_sigma2_leaf_tau) - jsonobj$add_boolean("include_variance_forest", object$model_params$include_variance_forest) - jsonobj$add_string("propensity_covariate", object$model_params$propensity_covariate) + jsonobj$add_boolean( + "sample_sigma2_global", + object$model_params$sample_sigma2_global + ) + jsonobj$add_boolean( + "sample_sigma2_leaf_mu", + object$model_params$sample_sigma2_leaf_mu + ) + jsonobj$add_boolean( + "sample_sigma2_leaf_tau", + object$model_params$sample_sigma2_leaf_tau + ) + jsonobj$add_boolean( + "include_variance_forest", + object$model_params$include_variance_forest + ) + jsonobj$add_string( + "propensity_covariate", + object$model_params$propensity_covariate + ) jsonobj$add_boolean("has_rfx", object$model_params$has_rfx) jsonobj$add_boolean("has_rfx_basis", object$model_params$has_rfx_basis) jsonobj$add_scalar("num_rfx_basis", object$model_params$num_rfx_basis) - jsonobj$add_boolean("multivariate_treatment", object$model_params$multivariate_treatment) + jsonobj$add_boolean( + "multivariate_treatment", + object$model_params$multivariate_treatment + ) jsonobj$add_boolean("adaptive_coding", object$model_params$adaptive_coding) - jsonobj$add_boolean("internal_propensity_model", object$model_params$internal_propensity_model) + jsonobj$add_boolean( + "internal_propensity_model", + object$model_params$internal_propensity_model + ) jsonobj$add_scalar("num_gfr", object$model_params$num_gfr) jsonobj$add_scalar("num_burnin", object$model_params$num_burnin) jsonobj$add_scalar("num_mcmc", object$model_params$num_mcmc) @@ -2037,15 +3050,30 @@ saveBCFModelToJson <- function(object){ jsonobj$add_scalar("keep_every", object$model_params$keep_every) jsonobj$add_scalar("num_chains", object$model_params$num_chains) jsonobj$add_scalar("num_covariates", object$model_params$num_covariates) - jsonobj$add_boolean("probit_outcome_model", object$model_params$probit_outcome_model) + jsonobj$add_boolean( + "probit_outcome_model", + object$model_params$probit_outcome_model + ) if (object$model_params$sample_sigma2_global) { - jsonobj$add_vector("sigma2_global_samples", object$sigma2_global_samples, "parameters") + jsonobj$add_vector( + "sigma2_global_samples", + object$sigma2_global_samples, + "parameters" + ) } if (object$model_params$sample_sigma2_leaf_mu) { - jsonobj$add_vector("sigma2_leaf_mu_samples", object$sigma2_leaf_mu_samples, "parameters") + jsonobj$add_vector( + "sigma2_leaf_mu_samples", + object$sigma2_leaf_mu_samples, + "parameters" + ) } if (object$model_params$sample_sigma2_leaf_tau) { - jsonobj$add_vector("sigma2_leaf_tau_samples", object$sigma2_leaf_tau_samples, "parameters") + jsonobj$add_vector( + "sigma2_leaf_tau_samples", + object$sigma2_leaf_tau_samples, + "parameters" + ) } if (object$model_params$adaptive_coding) { jsonobj$add_vector("b_1_samples", object$b_1_samples, "parameters") @@ -2055,9 +3083,12 @@ saveBCFModelToJson <- function(object){ # Add random effects (if present) if (object$model_params$has_rfx) { jsonobj$add_random_effects(object$rfx_samples) - jsonobj$add_string_vector("rfx_unique_group_ids", object$rfx_unique_group_ids) + jsonobj$add_string_vector( + "rfx_unique_group_ids", + object$rfx_unique_group_ids + ) } - + # Add propensity model (if it exists) if (object$model_params$internal_propensity_model) { bart_propensity_string <- saveBARTModelToJsonString( @@ -2065,7 +3096,7 @@ saveBCFModelToJson <- function(object){ ) jsonobj$add_string("bart_propensity_model", bart_propensity_string) } - + # Add covariate preprocessor metadata preprocessor_metadata_string <- savePreprocessorToJsonString( object$train_set_metadata @@ -2088,21 +3119,21 @@ saveBCFModelToJson <- function(object){ #' p <- 5 #' X <- matrix(runif(n*p), ncol = p) #' mu_x <- ( -#' ((0 <= X[,1]) & (0.25 > X[,1])) * (-7.5) + -#' ((0.25 <= X[,1]) & (0.5 > X[,1])) * (-2.5) + -#' ((0.5 <= X[,1]) & (0.75 > X[,1])) * (2.5) + +#' ((0 <= X[,1]) & (0.25 > X[,1])) * (-7.5) + +#' ((0.25 <= X[,1]) & (0.5 > X[,1])) * (-2.5) + +#' ((0.5 <= X[,1]) & (0.75 > X[,1])) * (2.5) + #' ((0.75 <= X[,1]) & (1 > X[,1])) * (7.5) #' ) #' pi_x <- ( -#' ((0 <= X[,1]) & (0.25 > X[,1])) * (0.2) + -#' ((0.25 <= X[,1]) & (0.5 > X[,1])) * (0.4) + -#' ((0.5 <= X[,1]) & (0.75 > X[,1])) * (0.6) + +#' ((0 <= X[,1]) & (0.25 > X[,1])) * (0.2) + +#' ((0.25 <= X[,1]) & (0.5 > X[,1])) * (0.4) + +#' ((0.5 <= X[,1]) & (0.75 > X[,1])) * (0.6) + #' ((0.75 <= X[,1]) & (1 > X[,1])) * (0.8) #' ) #' tau_x <- ( -#' ((0 <= X[,2]) & (0.25 > X[,2])) * (0.5) + -#' ((0.25 <= X[,2]) & (0.5 > X[,2])) * (1.0) + -#' ((0.5 <= X[,2]) & (0.75 > X[,2])) * (1.5) + +#' ((0 <= X[,2]) & (0.25 > X[,2])) * (0.5) + +#' ((0.25 <= X[,2]) & (0.5 > X[,2])) * (1.0) + +#' ((0.5 <= X[,2]) & (0.75 > X[,2])) * (1.5) + #' ((0.75 <= X[,2]) & (1 > X[,2])) * (2.0) #' ) #' Z <- rbinom(n, 1, pi_x) @@ -2138,23 +3169,23 @@ saveBCFModelToJson <- function(object){ #' rfx_term_train <- rfx_term[train_inds] #' mu_params <- list(sample_sigma2_leaf = TRUE) #' tau_params <- list(sample_sigma2_leaf = FALSE) -#' bcf_model <- bcf(X_train = X_train, Z_train = Z_train, y_train = y_train, -#' propensity_train = pi_train, -#' rfx_group_ids_train = rfx_group_ids_train, -#' rfx_basis_train = rfx_basis_train, X_test = X_test, -#' Z_test = Z_test, propensity_test = pi_test, +#' bcf_model <- bcf(X_train = X_train, Z_train = Z_train, y_train = y_train, +#' propensity_train = pi_train, +#' rfx_group_ids_train = rfx_group_ids_train, +#' rfx_basis_train = rfx_basis_train, X_test = X_test, +#' Z_test = Z_test, propensity_test = pi_test, #' rfx_group_ids_test = rfx_group_ids_test, -#' rfx_basis_test = rfx_basis_test, -#' num_gfr = 10, num_burnin = 0, num_mcmc = 10, -#' prognostic_forest_params = mu_params, +#' rfx_basis_test = rfx_basis_test, +#' num_gfr = 10, num_burnin = 0, num_mcmc = 10, +#' prognostic_forest_params = mu_params, #' treatment_effect_forest_params = tau_params) #' tmpjson <- tempfile(fileext = ".json") #' saveBCFModelToJsonFile(bcf_model, file.path(tmpjson)) #' unlink(tmpjson) -saveBCFModelToJsonFile <- function(object, filename){ +saveBCFModelToJsonFile <- function(object, filename) { # Convert to Json jsonobj <- saveBCFModelToJson(object) - + # Save to file jsonobj$save_file(filename) } @@ -2170,21 +3201,21 @@ saveBCFModelToJsonFile <- function(object, filename){ #' p <- 5 #' X <- matrix(runif(n*p), ncol = p) #' mu_x <- ( -#' ((0 <= X[,1]) & (0.25 > X[,1])) * (-7.5) + -#' ((0.25 <= X[,1]) & (0.5 > X[,1])) * (-2.5) + -#' ((0.5 <= X[,1]) & (0.75 > X[,1])) * (2.5) + +#' ((0 <= X[,1]) & (0.25 > X[,1])) * (-7.5) + +#' ((0.25 <= X[,1]) & (0.5 > X[,1])) * (-2.5) + +#' ((0.5 <= X[,1]) & (0.75 > X[,1])) * (2.5) + #' ((0.75 <= X[,1]) & (1 > X[,1])) * (7.5) #' ) #' pi_x <- ( -#' ((0 <= X[,1]) & (0.25 > X[,1])) * (0.2) + -#' ((0.25 <= X[,1]) & (0.5 > X[,1])) * (0.4) + -#' ((0.5 <= X[,1]) & (0.75 > X[,1])) * (0.6) + +#' ((0 <= X[,1]) & (0.25 > X[,1])) * (0.2) + +#' ((0.25 <= X[,1]) & (0.5 > X[,1])) * (0.4) + +#' ((0.5 <= X[,1]) & (0.75 > X[,1])) * (0.6) + #' ((0.75 <= X[,1]) & (1 > X[,1])) * (0.8) #' ) #' tau_x <- ( -#' ((0 <= X[,2]) & (0.25 > X[,2])) * (0.5) + -#' ((0.25 <= X[,2]) & (0.5 > X[,2])) * (1.0) + -#' ((0.5 <= X[,2]) & (0.75 > X[,2])) * (1.5) + +#' ((0 <= X[,2]) & (0.25 > X[,2])) * (0.5) + +#' ((0.25 <= X[,2]) & (0.5 > X[,2])) * (1.0) + +#' ((0.5 <= X[,2]) & (0.75 > X[,2])) * (1.5) + #' ((0.75 <= X[,2]) & (1 > X[,2])) * (2.0) #' ) #' Z <- rbinom(n, 1, pi_x) @@ -2220,26 +3251,26 @@ saveBCFModelToJsonFile <- function(object, filename){ #' rfx_term_train <- rfx_term[train_inds] #' mu_params <- list(sample_sigma2_leaf = TRUE) #' tau_params <- list(sample_sigma2_leaf = FALSE) -#' bcf_model <- bcf(X_train = X_train, Z_train = Z_train, y_train = y_train, -#' propensity_train = pi_train, -#' rfx_group_ids_train = rfx_group_ids_train, -#' rfx_basis_train = rfx_basis_train, X_test = X_test, -#' Z_test = Z_test, propensity_test = pi_test, +#' bcf_model <- bcf(X_train = X_train, Z_train = Z_train, y_train = y_train, +#' propensity_train = pi_train, +#' rfx_group_ids_train = rfx_group_ids_train, +#' rfx_basis_train = rfx_basis_train, X_test = X_test, +#' Z_test = Z_test, propensity_test = pi_test, #' rfx_group_ids_test = rfx_group_ids_test, -#' rfx_basis_test = rfx_basis_test, -#' num_gfr = 10, num_burnin = 0, num_mcmc = 10, -#' prognostic_forest_params = mu_params, +#' rfx_basis_test = rfx_basis_test, +#' num_gfr = 10, num_burnin = 0, num_mcmc = 10, +#' prognostic_forest_params = mu_params, #' treatment_effect_forest_params = tau_params) #' saveBCFModelToJsonString(bcf_model) -saveBCFModelToJsonString <- function(object){ +saveBCFModelToJsonString <- function(object) { # Convert to Json jsonobj <- saveBCFModelToJson(object) - + # Dump to string return(jsonobj$return_json_string()) } -#' Convert an (in-memory) JSON representation of a BCF model to a BCF model object +#' Convert an (in-memory) JSON representation of a BCF model to a BCF model object #' which can be used for prediction, etc... #' #' @param json_object Object of type `CppJson` containing Json representation of a BCF model @@ -2252,21 +3283,21 @@ saveBCFModelToJsonString <- function(object){ #' p <- 5 #' X <- matrix(runif(n*p), ncol = p) #' mu_x <- ( -#' ((0 <= X[,1]) & (0.25 > X[,1])) * (-7.5) + -#' ((0.25 <= X[,1]) & (0.5 > X[,1])) * (-2.5) + -#' ((0.5 <= X[,1]) & (0.75 > X[,1])) * (2.5) + +#' ((0 <= X[,1]) & (0.25 > X[,1])) * (-7.5) + +#' ((0.25 <= X[,1]) & (0.5 > X[,1])) * (-2.5) + +#' ((0.5 <= X[,1]) & (0.75 > X[,1])) * (2.5) + #' ((0.75 <= X[,1]) & (1 > X[,1])) * (7.5) #' ) #' pi_x <- ( -#' ((0 <= X[,1]) & (0.25 > X[,1])) * (0.2) + -#' ((0.25 <= X[,1]) & (0.5 > X[,1])) * (0.4) + -#' ((0.5 <= X[,1]) & (0.75 > X[,1])) * (0.6) + +#' ((0 <= X[,1]) & (0.25 > X[,1])) * (0.2) + +#' ((0.25 <= X[,1]) & (0.5 > X[,1])) * (0.4) + +#' ((0.5 <= X[,1]) & (0.75 > X[,1])) * (0.6) + #' ((0.75 <= X[,1]) & (1 > X[,1])) * (0.8) #' ) #' tau_x <- ( -#' ((0 <= X[,2]) & (0.25 > X[,2])) * (0.5) + -#' ((0.25 <= X[,2]) & (0.5 > X[,2])) * (1.0) + -#' ((0.5 <= X[,2]) & (0.75 > X[,2])) * (1.5) + +#' ((0 <= X[,2]) & (0.25 > X[,2])) * (0.5) + +#' ((0.25 <= X[,2]) & (0.5 > X[,2])) * (1.0) + +#' ((0.5 <= X[,2]) & (0.75 > X[,2])) * (1.5) + #' ((0.75 <= X[,2]) & (1 > X[,2])) * (2.0) #' ) #' Z <- rbinom(n, 1, pi_x) @@ -2302,45 +3333,72 @@ saveBCFModelToJsonString <- function(object){ #' rfx_term_train <- rfx_term[train_inds] #' mu_params <- list(sample_sigma2_leaf = TRUE) #' tau_params <- list(sample_sigma2_leaf = FALSE) -#' bcf_model <- bcf(X_train = X_train, Z_train = Z_train, y_train = y_train, -#' propensity_train = pi_train, -#' rfx_group_ids_train = rfx_group_ids_train, -#' rfx_basis_train = rfx_basis_train, X_test = X_test, -#' Z_test = Z_test, propensity_test = pi_test, +#' bcf_model <- bcf(X_train = X_train, Z_train = Z_train, y_train = y_train, +#' propensity_train = pi_train, +#' rfx_group_ids_train = rfx_group_ids_train, +#' rfx_basis_train = rfx_basis_train, X_test = X_test, +#' Z_test = Z_test, propensity_test = pi_test, #' rfx_group_ids_test = rfx_group_ids_test, -#' rfx_basis_test = rfx_basis_test, -#' num_gfr = 10, num_burnin = 0, num_mcmc = 10, -#' prognostic_forest_params = mu_params, +#' rfx_basis_test = rfx_basis_test, +#' num_gfr = 10, num_burnin = 0, num_mcmc = 10, +#' prognostic_forest_params = mu_params, #' treatment_effect_forest_params = tau_params) #' bcf_json <- saveBCFModelToJson(bcf_model) #' bcf_model_roundtrip <- createBCFModelFromJson(bcf_json) -createBCFModelFromJson <- function(json_object){ +createBCFModelFromJson <- function(json_object) { # Initialize the BCF model output <- list() - + # Unpack the forests output[["forests_mu"]] <- loadForestContainerJson(json_object, "forest_0") output[["forests_tau"]] <- loadForestContainerJson(json_object, "forest_1") - include_variance_forest <- json_object$get_boolean("include_variance_forest") + include_variance_forest <- json_object$get_boolean( + "include_variance_forest" + ) if (include_variance_forest) { - output[["forests_variance"]] <- loadForestContainerJson(json_object, "forest_2") + output[["forests_variance"]] <- loadForestContainerJson( + json_object, + "forest_2" + ) } # Unpack metadata train_set_metadata = list() - train_set_metadata[["num_numeric_vars"]] <- json_object$get_scalar("num_numeric_vars") - train_set_metadata[["num_ordered_cat_vars"]] <- json_object$get_scalar("num_ordered_cat_vars") - train_set_metadata[["num_unordered_cat_vars"]] <- json_object$get_scalar("num_unordered_cat_vars") + train_set_metadata[["num_numeric_vars"]] <- json_object$get_scalar( + "num_numeric_vars" + ) + train_set_metadata[["num_ordered_cat_vars"]] <- json_object$get_scalar( + "num_ordered_cat_vars" + ) + train_set_metadata[["num_unordered_cat_vars"]] <- json_object$get_scalar( + "num_unordered_cat_vars" + ) if (train_set_metadata[["num_numeric_vars"]] > 0) { - train_set_metadata[["numeric_vars"]] <- json_object$get_string_vector("numeric_vars") + train_set_metadata[["numeric_vars"]] <- json_object$get_string_vector( + "numeric_vars" + ) } if (train_set_metadata[["num_ordered_cat_vars"]] > 0) { - train_set_metadata[["ordered_cat_vars"]] <- json_object$get_string_vector("ordered_cat_vars") - train_set_metadata[["ordered_unique_levels"]] <- json_object$get_string_list("ordered_unique_levels", train_set_metadata[["ordered_cat_vars"]]) + train_set_metadata[[ + "ordered_cat_vars" + ]] <- json_object$get_string_vector("ordered_cat_vars") + train_set_metadata[[ + "ordered_unique_levels" + ]] <- json_object$get_string_list( + "ordered_unique_levels", + train_set_metadata[["ordered_cat_vars"]] + ) } if (train_set_metadata[["num_unordered_cat_vars"]] > 0) { - train_set_metadata[["unordered_cat_vars"]] <- json_object$get_string_vector("unordered_cat_vars") - train_set_metadata[["unordered_unique_levels"]] <- json_object$get_string_list("unordered_unique_levels", train_set_metadata[["unordered_cat_vars"]]) + train_set_metadata[[ + "unordered_cat_vars" + ]] <- json_object$get_string_vector("unordered_cat_vars") + train_set_metadata[[ + "unordered_unique_levels" + ]] <- json_object$get_string_list( + "unordered_unique_levels", + train_set_metadata[["unordered_cat_vars"]] + ) } output[["train_set_metadata"]] <- train_set_metadata @@ -2350,56 +3408,93 @@ createBCFModelFromJson <- function(json_object){ model_params[["outcome_mean"]] <- json_object$get_scalar("outcome_mean") model_params[["standardize"]] <- json_object$get_boolean("standardize") model_params[["initial_sigma2"]] <- json_object$get_scalar("initial_sigma2") - model_params[["sample_sigma2_global"]] <- json_object$get_boolean("sample_sigma2_global") - model_params[["sample_sigma2_leaf_mu"]] <- json_object$get_boolean("sample_sigma2_leaf_mu") - model_params[["sample_sigma2_leaf_tau"]] <- json_object$get_boolean("sample_sigma2_leaf_tau") + model_params[["sample_sigma2_global"]] <- json_object$get_boolean( + "sample_sigma2_global" + ) + model_params[["sample_sigma2_leaf_mu"]] <- json_object$get_boolean( + "sample_sigma2_leaf_mu" + ) + model_params[["sample_sigma2_leaf_tau"]] <- json_object$get_boolean( + "sample_sigma2_leaf_tau" + ) model_params[["include_variance_forest"]] <- include_variance_forest - model_params[["propensity_covariate"]] <- json_object$get_string("propensity_covariate") + model_params[["propensity_covariate"]] <- json_object$get_string( + "propensity_covariate" + ) model_params[["has_rfx"]] <- json_object$get_boolean("has_rfx") model_params[["has_rfx_basis"]] <- json_object$get_boolean("has_rfx_basis") model_params[["num_rfx_basis"]] <- json_object$get_scalar("num_rfx_basis") - model_params[["adaptive_coding"]] <- json_object$get_boolean("adaptive_coding") - model_params[["multivariate_treatment"]] <- json_object$get_boolean("multivariate_treatment") - model_params[["internal_propensity_model"]] <- json_object$get_boolean("internal_propensity_model") + model_params[["adaptive_coding"]] <- json_object$get_boolean( + "adaptive_coding" + ) + model_params[["multivariate_treatment"]] <- json_object$get_boolean( + "multivariate_treatment" + ) + model_params[["internal_propensity_model"]] <- json_object$get_boolean( + "internal_propensity_model" + ) model_params[["num_gfr"]] <- json_object$get_scalar("num_gfr") model_params[["num_burnin"]] <- json_object$get_scalar("num_burnin") model_params[["num_mcmc"]] <- json_object$get_scalar("num_mcmc") model_params[["num_samples"]] <- json_object$get_scalar("num_samples") model_params[["num_covariates"]] <- json_object$get_scalar("num_covariates") - model_params[["probit_outcome_model"]] <- json_object$get_boolean("probit_outcome_model") + model_params[["probit_outcome_model"]] <- json_object$get_boolean( + "probit_outcome_model" + ) output[["model_params"]] <- model_params - + # Unpack sampled parameters if (model_params[["sample_sigma2_global"]]) { - output[["sigma2_global_samples"]] <- json_object$get_vector("sigma2_global_samples", "parameters") + output[["sigma2_global_samples"]] <- json_object$get_vector( + "sigma2_global_samples", + "parameters" + ) } if (model_params[["sample_sigma2_leaf_mu"]]) { - output[["sigma2_leaf_mu_samples"]] <- json_object$get_vector("sigma2_leaf_mu_samples", "parameters") + output[["sigma2_leaf_mu_samples"]] <- json_object$get_vector( + "sigma2_leaf_mu_samples", + "parameters" + ) } if (model_params[["sample_sigma2_leaf_tau"]]) { - output[["sigma2_leaf_tau_samples"]] <- json_object$get_vector("sigma2_leaf_tau_samples", "parameters") + output[["sigma2_leaf_tau_samples"]] <- json_object$get_vector( + "sigma2_leaf_tau_samples", + "parameters" + ) } if (model_params[["adaptive_coding"]]) { - output[["b_1_samples"]] <- json_object$get_vector("b_1_samples", "parameters") - output[["b_0_samples"]] <- json_object$get_vector("b_0_samples", "parameters") + output[["b_1_samples"]] <- json_object$get_vector( + "b_1_samples", + "parameters" + ) + output[["b_0_samples"]] <- json_object$get_vector( + "b_0_samples", + "parameters" + ) } - + # Unpack random effects if (model_params[["has_rfx"]]) { - output[["rfx_unique_group_ids"]] <- json_object$get_string_vector("rfx_unique_group_ids") + output[["rfx_unique_group_ids"]] <- json_object$get_string_vector( + "rfx_unique_group_ids" + ) output[["rfx_samples"]] <- loadRandomEffectSamplesJson(json_object, 0) } - + # Unpack propensity model (if it exists) if (model_params[["internal_propensity_model"]]) { - bart_propensity_string <- json_object$get_string("bart_propensity_model") + bart_propensity_string <- json_object$get_string( + "bart_propensity_model" + ) output[["bart_propensity_model"]] <- createBARTModelFromJsonString( bart_propensity_string ) } - + # Unpack covariate preprocessor - preprocessor_metadata_string <- json_object$get_string("preprocessor_metadata") + preprocessor_metadata_string <- json_object$get_string( + "preprocessor_metadata" + ) output[["train_set_metadata"]] <- createPreprocessorFromJsonString( preprocessor_metadata_string ) @@ -2408,7 +3503,7 @@ createBCFModelFromJson <- function(json_object){ return(output) } -#' Convert a JSON file containing sample information on a trained BCF model +#' Convert a JSON file containing sample information on a trained BCF model #' to a BCF model object which can be used for prediction, etc... #' #' @param json_filename String of filepath, must end in ".json" @@ -2421,21 +3516,21 @@ createBCFModelFromJson <- function(json_object){ #' p <- 5 #' X <- matrix(runif(n*p), ncol = p) #' mu_x <- ( -#' ((0 <= X[,1]) & (0.25 > X[,1])) * (-7.5) + -#' ((0.25 <= X[,1]) & (0.5 > X[,1])) * (-2.5) + -#' ((0.5 <= X[,1]) & (0.75 > X[,1])) * (2.5) + +#' ((0 <= X[,1]) & (0.25 > X[,1])) * (-7.5) + +#' ((0.25 <= X[,1]) & (0.5 > X[,1])) * (-2.5) + +#' ((0.5 <= X[,1]) & (0.75 > X[,1])) * (2.5) + #' ((0.75 <= X[,1]) & (1 > X[,1])) * (7.5) #' ) #' pi_x <- ( -#' ((0 <= X[,1]) & (0.25 > X[,1])) * (0.2) + -#' ((0.25 <= X[,1]) & (0.5 > X[,1])) * (0.4) + -#' ((0.5 <= X[,1]) & (0.75 > X[,1])) * (0.6) + +#' ((0 <= X[,1]) & (0.25 > X[,1])) * (0.2) + +#' ((0.25 <= X[,1]) & (0.5 > X[,1])) * (0.4) + +#' ((0.5 <= X[,1]) & (0.75 > X[,1])) * (0.6) + #' ((0.75 <= X[,1]) & (1 > X[,1])) * (0.8) #' ) #' tau_x <- ( -#' ((0 <= X[,2]) & (0.25 > X[,2])) * (0.5) + -#' ((0.25 <= X[,2]) & (0.5 > X[,2])) * (1.0) + -#' ((0.5 <= X[,2]) & (0.75 > X[,2])) * (1.5) + +#' ((0 <= X[,2]) & (0.25 > X[,2])) * (0.5) + +#' ((0.25 <= X[,2]) & (0.5 > X[,2])) * (1.0) + +#' ((0.5 <= X[,2]) & (0.75 > X[,2])) * (1.5) + #' ((0.75 <= X[,2]) & (1 > X[,2])) * (2.0) #' ) #' Z <- rbinom(n, 1, pi_x) @@ -2471,31 +3566,31 @@ createBCFModelFromJson <- function(json_object){ #' rfx_term_train <- rfx_term[train_inds] #' mu_params <- list(sample_sigma2_leaf = TRUE) #' tau_params <- list(sample_sigma2_leaf = FALSE) -#' bcf_model <- bcf(X_train = X_train, Z_train = Z_train, y_train = y_train, -#' propensity_train = pi_train, -#' rfx_group_ids_train = rfx_group_ids_train, -#' rfx_basis_train = rfx_basis_train, X_test = X_test, -#' Z_test = Z_test, propensity_test = pi_test, +#' bcf_model <- bcf(X_train = X_train, Z_train = Z_train, y_train = y_train, +#' propensity_train = pi_train, +#' rfx_group_ids_train = rfx_group_ids_train, +#' rfx_basis_train = rfx_basis_train, X_test = X_test, +#' Z_test = Z_test, propensity_test = pi_test, #' rfx_group_ids_test = rfx_group_ids_test, -#' rfx_basis_test = rfx_basis_test, -#' num_gfr = 10, num_burnin = 0, num_mcmc = 10, -#' prognostic_forest_params = mu_params, +#' rfx_basis_test = rfx_basis_test, +#' num_gfr = 10, num_burnin = 0, num_mcmc = 10, +#' prognostic_forest_params = mu_params, #' treatment_effect_forest_params = tau_params) #' tmpjson <- tempfile(fileext = ".json") #' saveBCFModelToJsonFile(bcf_model, file.path(tmpjson)) #' bcf_model_roundtrip <- createBCFModelFromJsonFile(file.path(tmpjson)) #' unlink(tmpjson) -createBCFModelFromJsonFile <- function(json_filename){ +createBCFModelFromJsonFile <- function(json_filename) { # Load a `CppJson` object from file bcf_json <- createCppJsonFile(json_filename) - + # Create and return the BCF object bcf_object <- createBCFModelFromJson(bcf_json) - + return(bcf_object) } -#' Convert a JSON string containing sample information on a trained BCF model +#' Convert a JSON string containing sample information on a trained BCF model #' to a BCF model object which can be used for prediction, etc... #' #' @param json_string JSON string dump @@ -2508,21 +3603,21 @@ createBCFModelFromJsonFile <- function(json_filename){ #' p <- 5 #' X <- matrix(runif(n*p), ncol = p) #' mu_x <- ( -#' ((0 <= X[,1]) & (0.25 > X[,1])) * (-7.5) + -#' ((0.25 <= X[,1]) & (0.5 > X[,1])) * (-2.5) + -#' ((0.5 <= X[,1]) & (0.75 > X[,1])) * (2.5) + +#' ((0 <= X[,1]) & (0.25 > X[,1])) * (-7.5) + +#' ((0.25 <= X[,1]) & (0.5 > X[,1])) * (-2.5) + +#' ((0.5 <= X[,1]) & (0.75 > X[,1])) * (2.5) + #' ((0.75 <= X[,1]) & (1 > X[,1])) * (7.5) #' ) #' pi_x <- ( -#' ((0 <= X[,1]) & (0.25 > X[,1])) * (0.2) + -#' ((0.25 <= X[,1]) & (0.5 > X[,1])) * (0.4) + -#' ((0.5 <= X[,1]) & (0.75 > X[,1])) * (0.6) + +#' ((0 <= X[,1]) & (0.25 > X[,1])) * (0.2) + +#' ((0.25 <= X[,1]) & (0.5 > X[,1])) * (0.4) + +#' ((0.5 <= X[,1]) & (0.75 > X[,1])) * (0.6) + #' ((0.75 <= X[,1]) & (1 > X[,1])) * (0.8) #' ) #' tau_x <- ( -#' ((0 <= X[,2]) & (0.25 > X[,2])) * (0.5) + -#' ((0.25 <= X[,2]) & (0.5 > X[,2])) * (1.0) + -#' ((0.5 <= X[,2]) & (0.75 > X[,2])) * (1.5) + +#' ((0 <= X[,2]) & (0.25 > X[,2])) * (0.5) + +#' ((0.25 <= X[,2]) & (0.5 > X[,2])) * (1.0) + +#' ((0.5 <= X[,2]) & (0.75 > X[,2])) * (1.5) + #' ((0.75 <= X[,2]) & (1 > X[,2])) * (2.0) #' ) #' Z <- rbinom(n, 1, pi_x) @@ -2556,27 +3651,27 @@ createBCFModelFromJsonFile <- function(json_filename){ #' rfx_basis_train <- rfx_basis[train_inds,] #' rfx_term_test <- rfx_term[test_inds] #' rfx_term_train <- rfx_term[train_inds] -#' bcf_model <- bcf(X_train = X_train, Z_train = Z_train, y_train = y_train, -#' propensity_train = pi_train, -#' rfx_group_ids_train = rfx_group_ids_train, -#' rfx_basis_train = rfx_basis_train, X_test = X_test, -#' Z_test = Z_test, propensity_test = pi_test, +#' bcf_model <- bcf(X_train = X_train, Z_train = Z_train, y_train = y_train, +#' propensity_train = pi_train, +#' rfx_group_ids_train = rfx_group_ids_train, +#' rfx_basis_train = rfx_basis_train, X_test = X_test, +#' Z_test = Z_test, propensity_test = pi_test, #' rfx_group_ids_test = rfx_group_ids_test, -#' rfx_basis_test = rfx_basis_test, +#' rfx_basis_test = rfx_basis_test, #' num_gfr = 10, num_burnin = 0, num_mcmc = 10) #' bcf_json <- saveBCFModelToJsonString(bcf_model) #' bcf_model_roundtrip <- createBCFModelFromJsonString(bcf_json) -createBCFModelFromJsonString <- function(json_string){ +createBCFModelFromJsonString <- function(json_string) { # Load a `CppJson` object from string bcf_json <- createCppJsonString(json_string) - + # Create and return the BCF object bcf_object <- createBCFModelFromJson(bcf_json) - + return(bcf_object) } -#' Convert a list of (in-memory) JSON strings that represent BCF models to a single combined BCF model object +#' Convert a list of (in-memory) JSON strings that represent BCF models to a single combined BCF model object #' which can be used for prediction, etc... #' #' @param json_object_list List of objects of type `CppJson` containing Json representation of a BCF model @@ -2589,21 +3684,21 @@ createBCFModelFromJsonString <- function(json_string){ #' p <- 5 #' X <- matrix(runif(n*p), ncol = p) #' mu_x <- ( -#' ((0 <= X[,1]) & (0.25 > X[,1])) * (-7.5) + -#' ((0.25 <= X[,1]) & (0.5 > X[,1])) * (-2.5) + -#' ((0.5 <= X[,1]) & (0.75 > X[,1])) * (2.5) + +#' ((0 <= X[,1]) & (0.25 > X[,1])) * (-7.5) + +#' ((0.25 <= X[,1]) & (0.5 > X[,1])) * (-2.5) + +#' ((0.5 <= X[,1]) & (0.75 > X[,1])) * (2.5) + #' ((0.75 <= X[,1]) & (1 > X[,1])) * (7.5) #' ) #' pi_x <- ( -#' ((0 <= X[,1]) & (0.25 > X[,1])) * (0.2) + -#' ((0.25 <= X[,1]) & (0.5 > X[,1])) * (0.4) + -#' ((0.5 <= X[,1]) & (0.75 > X[,1])) * (0.6) + +#' ((0 <= X[,1]) & (0.25 > X[,1])) * (0.2) + +#' ((0.25 <= X[,1]) & (0.5 > X[,1])) * (0.4) + +#' ((0.5 <= X[,1]) & (0.75 > X[,1])) * (0.6) + #' ((0.75 <= X[,1]) & (1 > X[,1])) * (0.8) #' ) #' tau_x <- ( -#' ((0 <= X[,2]) & (0.25 > X[,2])) * (0.5) + -#' ((0.25 <= X[,2]) & (0.5 > X[,2])) * (1.0) + -#' ((0.5 <= X[,2]) & (0.75 > X[,2])) * (1.5) + +#' ((0 <= X[,2]) & (0.25 > X[,2])) * (0.5) + +#' ((0.25 <= X[,2]) & (0.5 > X[,2])) * (1.0) + +#' ((0.5 <= X[,2]) & (0.75 > X[,2])) * (1.5) + #' ((0.75 <= X[,2]) & (1 > X[,2])) * (2.0) #' ) #' Z <- rbinom(n, 1, pi_x) @@ -2637,72 +3732,135 @@ createBCFModelFromJsonString <- function(json_string){ #' rfx_basis_train <- rfx_basis[train_inds,] #' rfx_term_test <- rfx_term[test_inds] #' rfx_term_train <- rfx_term[train_inds] -#' bcf_model <- bcf(X_train = X_train, Z_train = Z_train, y_train = y_train, -#' propensity_train = pi_train, -#' rfx_group_ids_train = rfx_group_ids_train, -#' rfx_basis_train = rfx_basis_train, X_test = X_test, -#' Z_test = Z_test, propensity_test = pi_test, +#' bcf_model <- bcf(X_train = X_train, Z_train = Z_train, y_train = y_train, +#' propensity_train = pi_train, +#' rfx_group_ids_train = rfx_group_ids_train, +#' rfx_basis_train = rfx_basis_train, X_test = X_test, +#' Z_test = Z_test, propensity_test = pi_test, #' rfx_group_ids_test = rfx_group_ids_test, -#' rfx_basis_test = rfx_basis_test, +#' rfx_basis_test = rfx_basis_test, #' num_gfr = 10, num_burnin = 0, num_mcmc = 10) #' bcf_json_list <- list(saveBCFModelToJson(bcf_model)) #' bcf_model_roundtrip <- createBCFModelFromCombinedJson(bcf_json_list) -createBCFModelFromCombinedJson <- function(json_object_list){ +createBCFModelFromCombinedJson <- function(json_object_list) { # Initialize the BCF model output <- list() - - # For scalar / preprocessing details which aren't sample-dependent, + + # For scalar / preprocessing details which aren't sample-dependent, # defer to the first json json_object_default <- json_object_list[[1]] - + # Unpack the forests - output[["forests_mu"]] <- loadForestContainerCombinedJson(json_object_list, "forest_0") - output[["forests_tau"]] <- loadForestContainerCombinedJson(json_object_list, "forest_1") - include_variance_forest <- json_object_default$get_boolean("include_variance_forest") + output[["forests_mu"]] <- loadForestContainerCombinedJson( + json_object_list, + "forest_0" + ) + output[["forests_tau"]] <- loadForestContainerCombinedJson( + json_object_list, + "forest_1" + ) + include_variance_forest <- json_object_default$get_boolean( + "include_variance_forest" + ) if (include_variance_forest) { - output[["forests_variance"]] <- loadForestContainerCombinedJson(json_object_list, "forest_2") + output[["forests_variance"]] <- loadForestContainerCombinedJson( + json_object_list, + "forest_2" + ) } - + # Unpack metadata train_set_metadata = list() - train_set_metadata[["num_numeric_vars"]] <- json_object_default$get_scalar("num_numeric_vars") - train_set_metadata[["num_ordered_cat_vars"]] <- json_object_default$get_scalar("num_ordered_cat_vars") - train_set_metadata[["num_unordered_cat_vars"]] <- json_object_default$get_scalar("num_unordered_cat_vars") + train_set_metadata[["num_numeric_vars"]] <- json_object_default$get_scalar( + "num_numeric_vars" + ) + train_set_metadata[[ + "num_ordered_cat_vars" + ]] <- json_object_default$get_scalar("num_ordered_cat_vars") + train_set_metadata[[ + "num_unordered_cat_vars" + ]] <- json_object_default$get_scalar("num_unordered_cat_vars") if (train_set_metadata[["num_numeric_vars"]] > 0) { - train_set_metadata[["numeric_vars"]] <- json_object_default$get_string_vector("numeric_vars") + train_set_metadata[[ + "numeric_vars" + ]] <- json_object_default$get_string_vector("numeric_vars") } if (train_set_metadata[["num_ordered_cat_vars"]] > 0) { - train_set_metadata[["ordered_cat_vars"]] <- json_object_default$get_string_vector("ordered_cat_vars") - train_set_metadata[["ordered_unique_levels"]] <- json_object_default$get_string_list("ordered_unique_levels", train_set_metadata[["ordered_cat_vars"]]) + train_set_metadata[[ + "ordered_cat_vars" + ]] <- json_object_default$get_string_vector("ordered_cat_vars") + train_set_metadata[[ + "ordered_unique_levels" + ]] <- json_object_default$get_string_list( + "ordered_unique_levels", + train_set_metadata[["ordered_cat_vars"]] + ) } if (train_set_metadata[["num_unordered_cat_vars"]] > 0) { - train_set_metadata[["unordered_cat_vars"]] <- json_object_default$get_string_vector("unordered_cat_vars") - train_set_metadata[["unordered_unique_levels"]] <- json_object_default$get_string_list("unordered_unique_levels", train_set_metadata[["unordered_cat_vars"]]) + train_set_metadata[[ + "unordered_cat_vars" + ]] <- json_object_default$get_string_vector("unordered_cat_vars") + train_set_metadata[[ + "unordered_unique_levels" + ]] <- json_object_default$get_string_list( + "unordered_unique_levels", + train_set_metadata[["unordered_cat_vars"]] + ) } output[["train_set_metadata"]] <- train_set_metadata - + # Unpack model params model_params = list() - model_params[["outcome_scale"]] <- json_object_default$get_scalar("outcome_scale") - model_params[["outcome_mean"]] <- json_object_default$get_scalar("outcome_mean") - model_params[["standardize"]] <- json_object_default$get_boolean("standardize") - model_params[["initial_sigma2"]] <- json_object_default$get_scalar("initial_sigma2") - model_params[["sample_sigma2_global"]] <- json_object_default$get_boolean("sample_sigma2_global") - model_params[["sample_sigma2_leaf_mu"]] <- json_object_default$get_boolean("sample_sigma2_leaf_mu") - model_params[["sample_sigma2_leaf_tau"]] <- json_object_default$get_boolean("sample_sigma2_leaf_tau") + model_params[["outcome_scale"]] <- json_object_default$get_scalar( + "outcome_scale" + ) + model_params[["outcome_mean"]] <- json_object_default$get_scalar( + "outcome_mean" + ) + model_params[["standardize"]] <- json_object_default$get_boolean( + "standardize" + ) + model_params[["initial_sigma2"]] <- json_object_default$get_scalar( + "initial_sigma2" + ) + model_params[["sample_sigma2_global"]] <- json_object_default$get_boolean( + "sample_sigma2_global" + ) + model_params[["sample_sigma2_leaf_mu"]] <- json_object_default$get_boolean( + "sample_sigma2_leaf_mu" + ) + model_params[["sample_sigma2_leaf_tau"]] <- json_object_default$get_boolean( + "sample_sigma2_leaf_tau" + ) model_params[["include_variance_forest"]] <- include_variance_forest - model_params[["propensity_covariate"]] <- json_object_default$get_string("propensity_covariate") + model_params[["propensity_covariate"]] <- json_object_default$get_string( + "propensity_covariate" + ) model_params[["has_rfx"]] <- json_object_default$get_boolean("has_rfx") - model_params[["has_rfx_basis"]] <- json_object_default$get_boolean("has_rfx_basis") - model_params[["num_rfx_basis"]] <- json_object_default$get_scalar("num_rfx_basis") - model_params[["num_covariates"]] <- json_object_default$get_scalar("num_covariates") + model_params[["has_rfx_basis"]] <- json_object_default$get_boolean( + "has_rfx_basis" + ) + model_params[["num_rfx_basis"]] <- json_object_default$get_scalar( + "num_rfx_basis" + ) + model_params[["num_covariates"]] <- json_object_default$get_scalar( + "num_covariates" + ) model_params[["num_chains"]] <- json_object_default$get_scalar("num_chains") model_params[["keep_every"]] <- json_object_default$get_scalar("keep_every") - model_params[["adaptive_coding"]] <- json_object_default$get_boolean("adaptive_coding") - model_params[["multivariate_treatment"]] <- json_object_default$get_boolean("multivariate_treatment") - model_params[["internal_propensity_model"]] <- json_object_default$get_boolean("internal_propensity_model") - model_params[["probit_outcome_model"]] <- json_object_default$get_boolean("probit_outcome_model") - + model_params[["adaptive_coding"]] <- json_object_default$get_boolean( + "adaptive_coding" + ) + model_params[["multivariate_treatment"]] <- json_object_default$get_boolean( + "multivariate_treatment" + ) + model_params[[ + "internal_propensity_model" + ]] <- json_object_default$get_boolean("internal_propensity_model") + model_params[["probit_outcome_model"]] <- json_object_default$get_boolean( + "probit_outcome_model" + ) + # Combine values that are sample-specific for (i in 1:length(json_object_list)) { json_object <- json_object_list[[i]] @@ -2710,25 +3868,40 @@ createBCFModelFromCombinedJson <- function(json_object_list){ model_params[["num_gfr"]] <- json_object$get_scalar("num_gfr") model_params[["num_burnin"]] <- json_object$get_scalar("num_burnin") model_params[["num_mcmc"]] <- json_object$get_scalar("num_mcmc") - model_params[["num_samples"]] <- json_object$get_scalar("num_samples") + model_params[["num_samples"]] <- json_object$get_scalar( + "num_samples" + ) } else { - prev_json <- json_object_list[[i-1]] - model_params[["num_gfr"]] <- model_params[["num_gfr"]] + json_object$get_scalar("num_gfr") - model_params[["num_burnin"]] <- model_params[["num_burnin"]] + json_object$get_scalar("num_burnin") - model_params[["num_mcmc"]] <- model_params[["num_mcmc"]] + json_object$get_scalar("num_mcmc") - model_params[["num_samples"]] <- model_params[["num_samples"]] + json_object$get_scalar("num_samples") + prev_json <- json_object_list[[i - 1]] + model_params[["num_gfr"]] <- model_params[["num_gfr"]] + + json_object$get_scalar("num_gfr") + model_params[["num_burnin"]] <- model_params[["num_burnin"]] + + json_object$get_scalar("num_burnin") + model_params[["num_mcmc"]] <- model_params[["num_mcmc"]] + + json_object$get_scalar("num_mcmc") + model_params[["num_samples"]] <- model_params[["num_samples"]] + + json_object$get_scalar("num_samples") } } output[["model_params"]] <- model_params - + # Unpack sampled parameters if (model_params[["sample_sigma2_global"]]) { for (i in 1:length(json_object_list)) { json_object <- json_object_list[[i]] if (i == 1) { - output[["sigma2_global_samples"]] <- json_object$get_vector("sigma2_global_samples", "parameters") + output[["sigma2_global_samples"]] <- json_object$get_vector( + "sigma2_global_samples", + "parameters" + ) } else { - output[["sigma2_global_samples"]] <- c(output[["sigma2_global_samples"]], json_object$get_vector("sigma2_global_samples", "parameters")) + output[["sigma2_global_samples"]] <- c( + output[["sigma2_global_samples"]], + json_object$get_vector( + "sigma2_global_samples", + "parameters" + ) + ) } } } @@ -2736,9 +3909,18 @@ createBCFModelFromCombinedJson <- function(json_object_list){ for (i in 1:length(json_object_list)) { json_object <- json_object_list[[i]] if (i == 1) { - output[["sigma2_leaf_mu_samples"]] <- json_object$get_vector("sigma2_leaf_mu_samples", "parameters") + output[["sigma2_leaf_mu_samples"]] <- json_object$get_vector( + "sigma2_leaf_mu_samples", + "parameters" + ) } else { - output[["sigma2_leaf_mu_samples"]] <- c(output[["sigma2_leaf_mu_samples"]], json_object$get_vector("sigma2_leaf_mu_samples", "parameters")) + output[["sigma2_leaf_mu_samples"]] <- c( + output[["sigma2_leaf_mu_samples"]], + json_object$get_vector( + "sigma2_leaf_mu_samples", + "parameters" + ) + ) } } } @@ -2746,9 +3928,18 @@ createBCFModelFromCombinedJson <- function(json_object_list){ for (i in 1:length(json_object_list)) { json_object <- json_object_list[[i]] if (i == 1) { - output[["sigma2_leaf_tau_samples"]] <- json_object$get_vector("sigma2_leaf_tau_samples", "parameters") + output[["sigma2_leaf_tau_samples"]] <- json_object$get_vector( + "sigma2_leaf_tau_samples", + "parameters" + ) } else { - output[["sigma2_leaf_tau_samples"]] <- c(output[["sigma2_leaf_tau_samples"]], json_object$get_vector("sigma2_leaf_tau_samples", "parameters")) + output[["sigma2_leaf_tau_samples"]] <- c( + output[["sigma2_leaf_tau_samples"]], + json_object$get_vector( + "sigma2_leaf_tau_samples", + "parameters" + ) + ) } } } @@ -2756,9 +3947,18 @@ createBCFModelFromCombinedJson <- function(json_object_list){ for (i in 1:length(json_object_list)) { json_object <- json_object_list[[i]] if (i == 1) { - output[["sigma2_leaf_tau_samples"]] <- json_object$get_vector("sigma2_leaf_tau_samples", "parameters") + output[["sigma2_leaf_tau_samples"]] <- json_object$get_vector( + "sigma2_leaf_tau_samples", + "parameters" + ) } else { - output[["sigma2_leaf_tau_samples"]] <- c(output[["sigma2_leaf_tau_samples"]], json_object$get_vector("sigma2_leaf_tau_samples", "parameters")) + output[["sigma2_leaf_tau_samples"]] <- c( + output[["sigma2_leaf_tau_samples"]], + json_object$get_vector( + "sigma2_leaf_tau_samples", + "parameters" + ) + ) } } } @@ -2766,32 +3966,51 @@ createBCFModelFromCombinedJson <- function(json_object_list){ for (i in 1:length(json_object_list)) { json_object <- json_object_list[[i]] if (i == 1) { - output[["b_1_samples"]] <- json_object$get_vector("b_1_samples", "parameters") - output[["b_0_samples"]] <- json_object$get_vector("b_0_samples", "parameters") + output[["b_1_samples"]] <- json_object$get_vector( + "b_1_samples", + "parameters" + ) + output[["b_0_samples"]] <- json_object$get_vector( + "b_0_samples", + "parameters" + ) } else { - output[["b_1_samples"]] <- c(output[["b_1_samples"]], json_object$get_vector("b_1_samples", "parameters")) - output[["b_0_samples"]] <- c(output[["b_0_samples"]], json_object$get_vector("b_0_samples", "parameters")) + output[["b_1_samples"]] <- c( + output[["b_1_samples"]], + json_object$get_vector("b_1_samples", "parameters") + ) + output[["b_0_samples"]] <- c( + output[["b_0_samples"]], + json_object$get_vector("b_0_samples", "parameters") + ) } } } - + # Unpack random effects if (model_params[["has_rfx"]]) { - output[["rfx_unique_group_ids"]] <- json_object_default$get_string_vector("rfx_unique_group_ids") - output[["rfx_samples"]] <- loadRandomEffectSamplesCombinedJson(json_object_list, 0) + output[[ + "rfx_unique_group_ids" + ]] <- json_object_default$get_string_vector("rfx_unique_group_ids") + output[["rfx_samples"]] <- loadRandomEffectSamplesCombinedJson( + json_object_list, + 0 + ) } - + # Unpack covariate preprocessor - preprocessor_metadata_string <- json_object_default$get_string("preprocessor_metadata") + preprocessor_metadata_string <- json_object_default$get_string( + "preprocessor_metadata" + ) output[["train_set_metadata"]] <- createPreprocessorFromJsonString( preprocessor_metadata_string ) - + class(output) <- "bcfmodel" return(output) } -#' Convert a list of (in-memory) JSON strings that represent BCF models to a single combined BCF model object +#' Convert a list of (in-memory) JSON strings that represent BCF models to a single combined BCF model object #' which can be used for prediction, etc... #' #' @param json_string_list List of JSON strings which can be parsed to objects of type `CppJson` containing Json representation of a BCF model @@ -2804,21 +4023,21 @@ createBCFModelFromCombinedJson <- function(json_object_list){ #' p <- 5 #' X <- matrix(runif(n*p), ncol = p) #' mu_x <- ( -#' ((0 <= X[,1]) & (0.25 > X[,1])) * (-7.5) + -#' ((0.25 <= X[,1]) & (0.5 > X[,1])) * (-2.5) + -#' ((0.5 <= X[,1]) & (0.75 > X[,1])) * (2.5) + +#' ((0 <= X[,1]) & (0.25 > X[,1])) * (-7.5) + +#' ((0.25 <= X[,1]) & (0.5 > X[,1])) * (-2.5) + +#' ((0.5 <= X[,1]) & (0.75 > X[,1])) * (2.5) + #' ((0.75 <= X[,1]) & (1 > X[,1])) * (7.5) #' ) #' pi_x <- ( -#' ((0 <= X[,1]) & (0.25 > X[,1])) * (0.2) + -#' ((0.25 <= X[,1]) & (0.5 > X[,1])) * (0.4) + -#' ((0.5 <= X[,1]) & (0.75 > X[,1])) * (0.6) + +#' ((0 <= X[,1]) & (0.25 > X[,1])) * (0.2) + +#' ((0.25 <= X[,1]) & (0.5 > X[,1])) * (0.4) + +#' ((0.5 <= X[,1]) & (0.75 > X[,1])) * (0.6) + #' ((0.75 <= X[,1]) & (1 > X[,1])) * (0.8) #' ) #' tau_x <- ( -#' ((0 <= X[,2]) & (0.25 > X[,2])) * (0.5) + -#' ((0.25 <= X[,2]) & (0.5 > X[,2])) * (1.0) + -#' ((0.5 <= X[,2]) & (0.75 > X[,2])) * (1.5) + +#' ((0 <= X[,2]) & (0.25 > X[,2])) * (0.5) + +#' ((0.25 <= X[,2]) & (0.5 > X[,2])) * (1.0) + +#' ((0.5 <= X[,2]) & (0.75 > X[,2])) * (1.5) + #' ((0.75 <= X[,2]) & (1 > X[,2])) * (2.0) #' ) #' Z <- rbinom(n, 1, pi_x) @@ -2852,20 +4071,20 @@ createBCFModelFromCombinedJson <- function(json_object_list){ #' rfx_basis_train <- rfx_basis[train_inds,] #' rfx_term_test <- rfx_term[test_inds] #' rfx_term_train <- rfx_term[train_inds] -#' bcf_model <- bcf(X_train = X_train, Z_train = Z_train, y_train = y_train, -#' propensity_train = pi_train, -#' rfx_group_ids_train = rfx_group_ids_train, -#' rfx_basis_train = rfx_basis_train, X_test = X_test, -#' Z_test = Z_test, propensity_test = pi_test, +#' bcf_model <- bcf(X_train = X_train, Z_train = Z_train, y_train = y_train, +#' propensity_train = pi_train, +#' rfx_group_ids_train = rfx_group_ids_train, +#' rfx_basis_train = rfx_basis_train, X_test = X_test, +#' Z_test = Z_test, propensity_test = pi_test, #' rfx_group_ids_test = rfx_group_ids_test, -#' rfx_basis_test = rfx_basis_test, +#' rfx_basis_test = rfx_basis_test, #' num_gfr = 10, num_burnin = 0, num_mcmc = 10) #' bcf_json_string_list <- list(saveBCFModelToJsonString(bcf_model)) #' bcf_model_roundtrip <- createBCFModelFromCombinedJsonString(bcf_json_string_list) -createBCFModelFromCombinedJsonString <- function(json_string_list){ +createBCFModelFromCombinedJsonString <- function(json_string_list) { # Initialize the BCF model output <- list() - + # Convert JSON strings json_object_list <- list() for (i in 1:length(json_string_list)) { @@ -2875,62 +4094,127 @@ createBCFModelFromCombinedJsonString <- function(json_string_list){ # We don't support merging BCF models with independent propensity models # this way at the moment if (json_object_list[[i]]$get_boolean("internal_propensity_model")) { - stop("Combining separate BCF models with cached internal propensity models is currently unsupported. To make this work, please first train a propensity model and then pass the propensities as data to the separate BCF models before sampling.") + stop( + "Combining separate BCF models with cached internal propensity models is currently unsupported. To make this work, please first train a propensity model and then pass the propensities as data to the separate BCF models before sampling." + ) } } - - # For scalar / preprocessing details which aren't sample-dependent, + + # For scalar / preprocessing details which aren't sample-dependent, # defer to the first json json_object_default <- json_object_list[[1]] - + # Unpack the forests - output[["forests_mu"]] <- loadForestContainerCombinedJson(json_object_list, "forest_0") - output[["forests_tau"]] <- loadForestContainerCombinedJson(json_object_list, "forest_1") - include_variance_forest <- json_object_default$get_boolean("include_variance_forest") + output[["forests_mu"]] <- loadForestContainerCombinedJson( + json_object_list, + "forest_0" + ) + output[["forests_tau"]] <- loadForestContainerCombinedJson( + json_object_list, + "forest_1" + ) + include_variance_forest <- json_object_default$get_boolean( + "include_variance_forest" + ) if (include_variance_forest) { - output[["forests_variance"]] <- loadForestContainerCombinedJson(json_object_list, "forest_2") + output[["forests_variance"]] <- loadForestContainerCombinedJson( + json_object_list, + "forest_2" + ) } # Unpack metadata train_set_metadata = list() - train_set_metadata[["num_numeric_vars"]] <- json_object_default$get_scalar("num_numeric_vars") - train_set_metadata[["num_ordered_cat_vars"]] <- json_object_default$get_scalar("num_ordered_cat_vars") - train_set_metadata[["num_unordered_cat_vars"]] <- json_object_default$get_scalar("num_unordered_cat_vars") + train_set_metadata[["num_numeric_vars"]] <- json_object_default$get_scalar( + "num_numeric_vars" + ) + train_set_metadata[[ + "num_ordered_cat_vars" + ]] <- json_object_default$get_scalar("num_ordered_cat_vars") + train_set_metadata[[ + "num_unordered_cat_vars" + ]] <- json_object_default$get_scalar("num_unordered_cat_vars") if (train_set_metadata[["num_numeric_vars"]] > 0) { - train_set_metadata[["numeric_vars"]] <- json_object_default$get_string_vector("numeric_vars") + train_set_metadata[[ + "numeric_vars" + ]] <- json_object_default$get_string_vector("numeric_vars") } if (train_set_metadata[["num_ordered_cat_vars"]] > 0) { - train_set_metadata[["ordered_cat_vars"]] <- json_object_default$get_string_vector("ordered_cat_vars") - train_set_metadata[["ordered_unique_levels"]] <- json_object_default$get_string_list("ordered_unique_levels", train_set_metadata[["ordered_cat_vars"]]) + train_set_metadata[[ + "ordered_cat_vars" + ]] <- json_object_default$get_string_vector("ordered_cat_vars") + train_set_metadata[[ + "ordered_unique_levels" + ]] <- json_object_default$get_string_list( + "ordered_unique_levels", + train_set_metadata[["ordered_cat_vars"]] + ) } if (train_set_metadata[["num_unordered_cat_vars"]] > 0) { - train_set_metadata[["unordered_cat_vars"]] <- json_object_default$get_string_vector("unordered_cat_vars") - train_set_metadata[["unordered_unique_levels"]] <- json_object_default$get_string_list("unordered_unique_levels", train_set_metadata[["unordered_cat_vars"]]) + train_set_metadata[[ + "unordered_cat_vars" + ]] <- json_object_default$get_string_vector("unordered_cat_vars") + train_set_metadata[[ + "unordered_unique_levels" + ]] <- json_object_default$get_string_list( + "unordered_unique_levels", + train_set_metadata[["unordered_cat_vars"]] + ) } output[["train_set_metadata"]] <- train_set_metadata - + # Unpack model params model_params = list() - model_params[["outcome_scale"]] <- json_object_default$get_scalar("outcome_scale") - model_params[["outcome_mean"]] <- json_object_default$get_scalar("outcome_mean") - model_params[["standardize"]] <- json_object_default$get_boolean("standardize") - model_params[["initial_sigma2"]] <- json_object_default$get_scalar("initial_sigma2") - model_params[["sample_sigma2_global"]] <- json_object_default$get_boolean("sample_sigma2_global") - model_params[["sample_sigma2_leaf_mu"]] <- json_object_default$get_boolean("sample_sigma2_leaf_mu") - model_params[["sample_sigma2_leaf_tau"]] <- json_object_default$get_boolean("sample_sigma2_leaf_tau") + model_params[["outcome_scale"]] <- json_object_default$get_scalar( + "outcome_scale" + ) + model_params[["outcome_mean"]] <- json_object_default$get_scalar( + "outcome_mean" + ) + model_params[["standardize"]] <- json_object_default$get_boolean( + "standardize" + ) + model_params[["initial_sigma2"]] <- json_object_default$get_scalar( + "initial_sigma2" + ) + model_params[["sample_sigma2_global"]] <- json_object_default$get_boolean( + "sample_sigma2_global" + ) + model_params[["sample_sigma2_leaf_mu"]] <- json_object_default$get_boolean( + "sample_sigma2_leaf_mu" + ) + model_params[["sample_sigma2_leaf_tau"]] <- json_object_default$get_boolean( + "sample_sigma2_leaf_tau" + ) model_params[["include_variance_forest"]] <- include_variance_forest - model_params[["propensity_covariate"]] <- json_object_default$get_string("propensity_covariate") + model_params[["propensity_covariate"]] <- json_object_default$get_string( + "propensity_covariate" + ) model_params[["has_rfx"]] <- json_object_default$get_boolean("has_rfx") - model_params[["has_rfx_basis"]] <- json_object_default$get_boolean("has_rfx_basis") - model_params[["num_rfx_basis"]] <- json_object_default$get_scalar("num_rfx_basis") - model_params[["num_covariates"]] <- json_object_default$get_scalar("num_covariates") + model_params[["has_rfx_basis"]] <- json_object_default$get_boolean( + "has_rfx_basis" + ) + model_params[["num_rfx_basis"]] <- json_object_default$get_scalar( + "num_rfx_basis" + ) + model_params[["num_covariates"]] <- json_object_default$get_scalar( + "num_covariates" + ) model_params[["num_chains"]] <- json_object_default$get_scalar("num_chains") model_params[["keep_every"]] <- json_object_default$get_scalar("keep_every") - model_params[["multivariate_treatment"]] <- json_object_default$get_boolean("multivariate_treatment") - model_params[["adaptive_coding"]] <- json_object_default$get_boolean("adaptive_coding") - model_params[["internal_propensity_model"]] <- json_object_default$get_boolean("internal_propensity_model") - model_params[["probit_outcome_model"]] <- json_object_default$get_boolean("probit_outcome_model") - + model_params[["multivariate_treatment"]] <- json_object_default$get_boolean( + "multivariate_treatment" + ) + model_params[["adaptive_coding"]] <- json_object_default$get_boolean( + "adaptive_coding" + ) + model_params[[ + "internal_propensity_model" + ]] <- json_object_default$get_boolean("internal_propensity_model") + model_params[["probit_outcome_model"]] <- json_object_default$get_boolean( + "probit_outcome_model" + ) + # Combine values that are sample-specific for (i in 1:length(json_object_list)) { json_object <- json_object_list[[i]] @@ -2938,25 +4222,40 @@ createBCFModelFromCombinedJsonString <- function(json_string_list){ model_params[["num_gfr"]] <- json_object$get_scalar("num_gfr") model_params[["num_burnin"]] <- json_object$get_scalar("num_burnin") model_params[["num_mcmc"]] <- json_object$get_scalar("num_mcmc") - model_params[["num_samples"]] <- json_object$get_scalar("num_samples") + model_params[["num_samples"]] <- json_object$get_scalar( + "num_samples" + ) } else { - prev_json <- json_object_list[[i-1]] - model_params[["num_gfr"]] <- model_params[["num_gfr"]] + json_object$get_scalar("num_gfr") - model_params[["num_burnin"]] <- model_params[["num_burnin"]] + json_object$get_scalar("num_burnin") - model_params[["num_mcmc"]] <- model_params[["num_mcmc"]] + json_object$get_scalar("num_mcmc") - model_params[["num_samples"]] <- model_params[["num_samples"]] + json_object$get_scalar("num_samples") + prev_json <- json_object_list[[i - 1]] + model_params[["num_gfr"]] <- model_params[["num_gfr"]] + + json_object$get_scalar("num_gfr") + model_params[["num_burnin"]] <- model_params[["num_burnin"]] + + json_object$get_scalar("num_burnin") + model_params[["num_mcmc"]] <- model_params[["num_mcmc"]] + + json_object$get_scalar("num_mcmc") + model_params[["num_samples"]] <- model_params[["num_samples"]] + + json_object$get_scalar("num_samples") } } output[["model_params"]] <- model_params - + # Unpack sampled parameters if (model_params[["sample_sigma2_global"]]) { for (i in 1:length(json_object_list)) { json_object <- json_object_list[[i]] if (i == 1) { - output[["sigma2_global_samples"]] <- json_object$get_vector("sigma2_global_samples", "parameters") + output[["sigma2_global_samples"]] <- json_object$get_vector( + "sigma2_global_samples", + "parameters" + ) } else { - output[["sigma2_global_samples"]] <- c(output[["sigma2_global_samples"]], json_object$get_vector("sigma2_global_samples", "parameters")) + output[["sigma2_global_samples"]] <- c( + output[["sigma2_global_samples"]], + json_object$get_vector( + "sigma2_global_samples", + "parameters" + ) + ) } } } @@ -2964,9 +4263,18 @@ createBCFModelFromCombinedJsonString <- function(json_string_list){ for (i in 1:length(json_object_list)) { json_object <- json_object_list[[i]] if (i == 1) { - output[["sigma2_leaf_mu_samples"]] <- json_object$get_vector("sigma2_leaf_mu_samples", "parameters") + output[["sigma2_leaf_mu_samples"]] <- json_object$get_vector( + "sigma2_leaf_mu_samples", + "parameters" + ) } else { - output[["sigma2_leaf_mu_samples"]] <- c(output[["sigma2_leaf_mu_samples"]], json_object$get_vector("sigma2_leaf_mu_samples", "parameters")) + output[["sigma2_leaf_mu_samples"]] <- c( + output[["sigma2_leaf_mu_samples"]], + json_object$get_vector( + "sigma2_leaf_mu_samples", + "parameters" + ) + ) } } } @@ -2974,9 +4282,18 @@ createBCFModelFromCombinedJsonString <- function(json_string_list){ for (i in 1:length(json_object_list)) { json_object <- json_object_list[[i]] if (i == 1) { - output[["sigma2_leaf_tau_samples"]] <- json_object$get_vector("sigma2_leaf_tau_samples", "parameters") + output[["sigma2_leaf_tau_samples"]] <- json_object$get_vector( + "sigma2_leaf_tau_samples", + "parameters" + ) } else { - output[["sigma2_leaf_tau_samples"]] <- c(output[["sigma2_leaf_tau_samples"]], json_object$get_vector("sigma2_leaf_tau_samples", "parameters")) + output[["sigma2_leaf_tau_samples"]] <- c( + output[["sigma2_leaf_tau_samples"]], + json_object$get_vector( + "sigma2_leaf_tau_samples", + "parameters" + ) + ) } } } @@ -2984,9 +4301,18 @@ createBCFModelFromCombinedJsonString <- function(json_string_list){ for (i in 1:length(json_object_list)) { json_object <- json_object_list[[i]] if (i == 1) { - output[["sigma2_leaf_tau_samples"]] <- json_object$get_vector("sigma2_leaf_tau_samples", "parameters") + output[["sigma2_leaf_tau_samples"]] <- json_object$get_vector( + "sigma2_leaf_tau_samples", + "parameters" + ) } else { - output[["sigma2_leaf_tau_samples"]] <- c(output[["sigma2_leaf_tau_samples"]], json_object$get_vector("sigma2_leaf_tau_samples", "parameters")) + output[["sigma2_leaf_tau_samples"]] <- c( + output[["sigma2_leaf_tau_samples"]], + json_object$get_vector( + "sigma2_leaf_tau_samples", + "parameters" + ) + ) } } } @@ -2994,27 +4320,46 @@ createBCFModelFromCombinedJsonString <- function(json_string_list){ for (i in 1:length(json_object_list)) { json_object <- json_object_list[[i]] if (i == 1) { - output[["b_1_samples"]] <- json_object$get_vector("b_1_samples", "parameters") - output[["b_0_samples"]] <- json_object$get_vector("b_0_samples", "parameters") + output[["b_1_samples"]] <- json_object$get_vector( + "b_1_samples", + "parameters" + ) + output[["b_0_samples"]] <- json_object$get_vector( + "b_0_samples", + "parameters" + ) } else { - output[["b_1_samples"]] <- c(output[["b_1_samples"]], json_object$get_vector("b_1_samples", "parameters")) - output[["b_0_samples"]] <- c(output[["b_0_samples"]], json_object$get_vector("b_0_samples", "parameters")) + output[["b_1_samples"]] <- c( + output[["b_1_samples"]], + json_object$get_vector("b_1_samples", "parameters") + ) + output[["b_0_samples"]] <- c( + output[["b_0_samples"]], + json_object$get_vector("b_0_samples", "parameters") + ) } } } - + # Unpack random effects if (model_params[["has_rfx"]]) { - output[["rfx_unique_group_ids"]] <- json_object_default$get_string_vector("rfx_unique_group_ids") - output[["rfx_samples"]] <- loadRandomEffectSamplesCombinedJson(json_object_list, 0) + output[[ + "rfx_unique_group_ids" + ]] <- json_object_default$get_string_vector("rfx_unique_group_ids") + output[["rfx_samples"]] <- loadRandomEffectSamplesCombinedJson( + json_object_list, + 0 + ) } - + # Unpack covariate preprocessor - preprocessor_metadata_string <- json_object_default$get_string("preprocessor_metadata") + preprocessor_metadata_string <- json_object_default$get_string( + "preprocessor_metadata" + ) output[["train_set_metadata"]] <- createPreprocessorFromJsonString( preprocessor_metadata_string ) - + class(output) <- "bcfmodel" return(output) } diff --git a/R/calibration.R b/R/calibration.R index ea91436f..cbcd293e 100644 --- a/R/calibration.R +++ b/R/calibration.R @@ -1,5 +1,5 @@ #' Calibrate the scale parameter on an inverse gamma prior for the global error variance as in Chipman et al (2022) -#' +#' #' Chipman, H., George, E., Hahn, R., McCulloch, R., Pratola, M. and Sparapani, R. (2022). Bayesian Additive Regression Trees, Computational Approaches. In Wiley StatsRef: Statistics Reference Online (eds N. Balakrishnan, T. Colton, B. Everitt, W. Piegorsch, F. Ruggeri and J.L. Teugels). https://doi.org/10.1002/9781118445112.stat08288 #' #' @param y Outcome to be modeled using BART, BCF or another nonparametric ensemble method. @@ -10,7 +10,7 @@ #' @param standardize (Optional) Whether or not outcome should be standardized (`(y-mean(y))/sd(y)`) before calibration of `lambda`. Default: `TRUE`. #' #' @return Value of `lambda` which determines the scale parameter of the global error variance prior (`sigma^2 ~ IG(nu,nu*lambda)`) -#' @export +#' @export #' #' @examples #' n <- 100 @@ -21,14 +21,26 @@ #' lambda <- calibrateInverseGammaErrorVariance(y, X, nu = nu) #' sigma2hat <- mean(resid(lm(y~X))^2) #' mean(var(y)/rgamma(100000, nu, rate = nu*lambda) < sigma2hat) -calibrateInverseGammaErrorVariance <- function(y, X, W = NULL, nu = 3, quant = 0.9, standardize = TRUE) { +calibrateInverseGammaErrorVariance <- function( + y, + X, + W = NULL, + nu = 3, + quant = 0.9, + standardize = TRUE +) { # Compute regression basis - if (!is.null(W)) basis <- cbind(X, W) - else basis <- X + if (!is.null(W)) { + basis <- cbind(X, W) + } else { + basis <- X + } # Standardize outcome if requested - if (standardize) y <- (y-mean(y))/sd(y) + if (standardize) { + y <- (y - mean(y)) / sd(y) + } # Compute the "regression-based" overestimate of sigma^2 - sigma2hat <- mean(resid(lm(y~basis))^2) + sigma2hat <- mean(resid(lm(y ~ basis))^2) # Calibrate lambda based on the implied quantile of sigma2hat - return((sigma2hat*qgamma(1-quant,nu))/nu) + return((sigma2hat * qgamma(1 - quant, nu)) / nu) } diff --git a/R/config.R b/R/config.R index 0e2a0b40..f63ce906 100644 --- a/R/config.R +++ b/R/config.R @@ -13,55 +13,54 @@ ForestModelConfig <- R6::R6Class( classname = "ForestModelConfig", cloneable = FALSE, public = list( - #' @field feature_types Vector of integer-coded feature types (integers where 0 = numeric, 1 = ordered categorical, 2 = unordered categorical) feature_types = NULL, - + #' @field sweep_update_indices Vector of trees to update in a sweep sweep_update_indices = NULL, - + #' @field num_trees Number of trees in the forest being sampled num_trees = NULL, - + #' @field num_features Number of features in training dataset num_features = NULL, - + #' @field num_observations Number of observations in training dataset num_observations = NULL, - + #' @field leaf_dimension Dimension of the leaf model leaf_dimension = NULL, - + #' @field alpha Root node split probability in tree prior alpha = NULL, - + #' @field beta Depth prior penalty in tree prior beta = NULL, - + #' @field min_samples_leaf Minimum number of samples in a tree leaf min_samples_leaf = NULL, - + #' @field max_depth Maximum depth of any tree in the ensemble in the model. Setting to `-1` does not enforce any depth limits on trees. max_depth = NULL, - + #' @field leaf_model_type Integer specifying the leaf model type (0 = constant leaf, 1 = univariate leaf regression, 2 = multivariate leaf regression) leaf_model_type = NULL, - + #' @field leaf_model_scale Scale parameter used in Gaussian leaf models leaf_model_scale = NULL, - + #' @field variable_weights Vector specifying sampling probability for all p covariates in ForestDataset variable_weights = NULL, - + #' @field variance_forest_shape Shape parameter for IG leaf models (applicable when `leaf_model_type = 3`) variance_forest_shape = NULL, - + #' @field variance_forest_scale Scale parameter for IG leaf models (applicable when `leaf_model_type = 3`) variance_forest_scale = NULL, - + #' @field cutpoint_grid_size Number of unique cutpoints to consider cutpoint_grid_size = NULL, - + #' @field num_features_subsample Number of features to subsample for the GFR algorithm num_features_subsample = NULL, @@ -84,18 +83,36 @@ ForestModelConfig <- R6::R6Class( #' @param variance_forest_scale Scale parameter for IG leaf models (applicable when `leaf_model_type = 3`). Default: `1`. #' @param cutpoint_grid_size Number of unique cutpoints to consider (default: `100`) #' @param num_features_subsample Number of features to subsample for the GFR algorithm - #' + #' #' @return A new ForestModelConfig object. - initialize = function(feature_types = NULL, sweep_update_indices = NULL, num_trees = NULL, num_features = NULL, - num_observations = NULL, variable_weights = NULL, leaf_dimension = 1, - alpha = 0.95, beta = 2.0, min_samples_leaf = 5, max_depth = -1, - leaf_model_type = 1, leaf_model_scale = NULL, variance_forest_shape = 1.0, - variance_forest_scale = 1.0, cutpoint_grid_size = 100, num_features_subsample = NULL) { + initialize = function( + feature_types = NULL, + sweep_update_indices = NULL, + num_trees = NULL, + num_features = NULL, + num_observations = NULL, + variable_weights = NULL, + leaf_dimension = 1, + alpha = 0.95, + beta = 2.0, + min_samples_leaf = 5, + max_depth = -1, + leaf_model_type = 1, + leaf_model_scale = NULL, + variance_forest_shape = 1.0, + variance_forest_scale = 1.0, + cutpoint_grid_size = 100, + num_features_subsample = NULL + ) { if (is.null(feature_types)) { if (is.null(num_features)) { - stop("Neither of `num_features` nor `feature_types` (a vector from which `num_features` can be inferred) was provided. Please provide at least one of these inputs when creating a ForestModelConfig object.") + stop( + "Neither of `num_features` nor `feature_types` (a vector from which `num_features` can be inferred) was provided. Please provide at least one of these inputs when creating a ForestModelConfig object." + ) } - warning("`feature_types` not provided, will be assumed to be numeric") + warning( + "`feature_types` not provided, will be assumed to be numeric" + ) feature_types <- rep(0, num_features) } else { if (is.null(num_features)) { @@ -103,8 +120,10 @@ ForestModelConfig <- R6::R6Class( } } if (is.null(variable_weights)) { - warning("`variable_weights` not provided, will be assumed to be equal-weighted") - variable_weights <- rep(1/num_features, num_features) + warning( + "`variable_weights` not provided, will be assumed to be equal-weighted" + ) + variable_weights <- rep(1 / num_features, num_features) } if (is.null(num_trees)) { stop("num_trees must be provided") @@ -120,7 +139,9 @@ ForestModelConfig <- R6::R6Class( stop("`feature_types` must have `num_features` total elements") } if (num_features != length(variable_weights)) { - stop("`variable_weights` must have `num_features` total elements") + stop( + "`variable_weights` must have `num_features` total elements" + ) } self$feature_types <- feature_types self$sweep_update_indices <- sweep_update_indices @@ -140,13 +161,15 @@ ForestModelConfig <- R6::R6Class( num_features_subsample <- num_features } if (num_features_subsample > num_features) { - stop("`num_features_subsample` cannot be larger than `num_features`") + stop( + "`num_features_subsample` cannot be larger than `num_features`" + ) } if (num_features_subsample <= 0) { stop("`num_features_subsample` must be at least 1") } self$num_features_subsample <- num_features_subsample - + if (!(as.integer(leaf_model_type) == leaf_model_type)) { stop("`leaf_model_type` must be an integer between 0 and 3") if ((leaf_model_type < 0) | (leaf_model_type > 3)) { @@ -154,33 +177,37 @@ ForestModelConfig <- R6::R6Class( } } self$leaf_model_type <- leaf_model_type - + if (is.null(leaf_model_scale)) { - self$leaf_model_scale <- diag(1/num_trees, leaf_dimension) + self$leaf_model_scale <- diag(1 / num_trees, leaf_dimension) } else if (is.matrix(leaf_model_scale)) { if (ncol(leaf_model_scale) != nrow(leaf_model_scale)) { stop("`leaf_model_scale` must be a square matrix") } if (ncol(leaf_model_scale) != leaf_dimension) { - stop("`leaf_model_scale` must have `leaf_dimension` rows and columns") + stop( + "`leaf_model_scale` must have `leaf_dimension` rows and columns" + ) } self$leaf_model_scale <- leaf_model_scale } else { if (leaf_model_scale <= 0) { - stop("`leaf_model_scale` must be positive, if provided as scalar") + stop( + "`leaf_model_scale` must be positive, if provided as scalar" + ) } self$leaf_model_scale <- diag(leaf_model_scale, leaf_dimension) } }, - + #' @description #' Update feature types #' @param feature_types Vector of integer-coded feature types (integers where 0 = numeric, 1 = ordered categorical, 2 = unordered categorical) update_feature_types = function(feature_types) { stopifnot(length(feature_types) == self$num_features) self$feature_types <- feature_types - }, - + }, + #' @description #' Update sweep update indices #' @param sweep_update_indices Vector of (0-indexed) indices of trees to update in a sweep @@ -190,44 +217,44 @@ ForestModelConfig <- R6::R6Class( stopifnot(max(sweep_update_indices) < self$num_trees) } self$sweep_update_indices <- sweep_update_indices - }, - + }, + #' @description #' Update variable weights #' @param variable_weights Vector specifying sampling probability for all p covariates in ForestDataset update_variable_weights = function(variable_weights) { stopifnot(length(variable_weights) == self$num_features) self$variable_weights <- variable_weights - }, - + }, + #' @description #' Update root node split probability in tree prior #' @param alpha Root node split probability in tree prior update_alpha = function(alpha) { self$alpha <- alpha - }, - + }, + #' @description #' Update depth prior penalty in tree prior #' @param beta Depth prior penalty in tree prior update_beta = function(beta) { self$beta <- beta - }, - + }, + #' @description #' Update minimum number of samples per leaf node in the tree prior #' @param min_samples_leaf Minimum number of samples in a tree leaf update_min_samples_leaf = function(min_samples_leaf) { self$min_samples_leaf <- min_samples_leaf - }, - + }, + #' @description #' Update max depth in the tree prior #' @param max_depth Maximum depth of any tree in the ensemble in the model update_max_depth = function(max_depth) { self$max_depth <- max_depth - }, - + }, + #' @description #' Update scale parameter used in Gaussian leaf models #' @param leaf_model_scale Scale parameter used in Gaussian leaf models @@ -237,156 +264,162 @@ ForestModelConfig <- R6::R6Class( stop("`leaf_model_scale` must be a square matrix") } if (ncol(leaf_model_scale) != self$leaf_dimension) { - stop("`leaf_model_scale` must have `leaf_dimension` rows and columns") + stop( + "`leaf_model_scale` must have `leaf_dimension` rows and columns" + ) } self$leaf_model_scale <- leaf_model_scale } else { if (leaf_model_scale <= 0) { - stop("`leaf_model_scale` must be positive, if provided as scalar") + stop( + "`leaf_model_scale` must be positive, if provided as scalar" + ) } self$leaf_model_scale <- diag(leaf_model_scale, leaf_dimension) } - }, - + }, + #' @description #' Update shape parameter for IG leaf models #' @param variance_forest_shape Shape parameter for IG leaf models update_variance_forest_shape = function(variance_forest_shape) { self$variance_forest_shape <- variance_forest_shape - }, - + }, + #' @description #' Update scale parameter for IG leaf models #' @param variance_forest_scale Scale parameter for IG leaf models update_variance_forest_scale = function(variance_forest_scale) { self$variance_forest_scale <- variance_forest_scale - }, - + }, + #' @description #' Update number of unique cutpoints to consider #' @param cutpoint_grid_size Number of unique cutpoints to consider update_cutpoint_grid_size = function(cutpoint_grid_size) { self$cutpoint_grid_size <- cutpoint_grid_size }, - + #' @description #' Update number of features to subsample for the GFR algorithm #' @param num_features_subsample Number of features to subsample for the GFR algorithm update_num_features_subsample = function(num_features_subsample) { if (num_features_subsample > self$num_features) { - stop("`num_features_subsample` cannot be larger than `num_features`") + stop( + "`num_features_subsample` cannot be larger than `num_features`" + ) } if (num_features_subsample <= 0) { stop("`num_features_subsample` must at least 1") } self$num_features_subsample <- num_features_subsample - }, - + }, + #' @description #' Query feature types for this ForestModelConfig object #' @returns Vector of integer-coded feature types (integers where 0 = numeric, 1 = ordered categorical, 2 = unordered categorical) get_feature_types = function() { return(self$feature_types) - }, - + }, + #' @description #' Query sweep update indices for this ForestModelConfig object #' @returns Vector of (0-indexed) indices of trees to update in a sweep get_sweep_indices = function() { return(self$sweep_update_indices) - }, - + }, + #' @description #' Query variable weights for this ForestModelConfig object #' @returns Vector specifying sampling probability for all p covariates in ForestDataset get_variable_weights = function() { return(self$variable_weights) - }, - + }, + #' @description #' Query number of trees #' @returns Number of trees in a forest get_num_trees = function() { return(self$num_trees) - }, - + }, + #' @description #' Query number of features #' @returns Number of features in a forest model training set get_num_features = function() { return(self$num_features) - }, - + }, + #' @description #' Query number of observations #' @returns Number of observations in a forest model training set get_num_observations = function() { return(self$num_observations) - }, - + }, + #' @description #' Query root node split probability in tree prior for this ForestModelConfig object #' @returns Root node split probability in tree prior get_alpha = function() { return(self$alpha) - }, - + }, + #' @description #' Query depth prior penalty in tree prior for this ForestModelConfig object #' @returns Depth prior penalty in tree prior get_beta = function() { return(self$beta) - }, - + }, + #' @description #' Query root node split probability in tree prior for this ForestModelConfig object #' @returns Minimum number of samples in a tree leaf get_min_samples_leaf = function() { return(self$min_samples_leaf) - }, - + }, + #' @description #' Query root node split probability in tree prior for this ForestModelConfig object #' @returns Maximum depth of any tree in the ensemble in the model get_max_depth = function() { return(self$max_depth) - }, - + }, + #' @description #' Query (integer-coded) type of leaf model #' @returns Integer coded leaf model type get_leaf_model_type = function() { return(self$leaf_model_type) - }, - + }, + #' @description #' Query scale parameter used in Gaussian leaf models for this ForestModelConfig object #' @returns Scale parameter used in Gaussian leaf models get_leaf_model_scale = function() { return(self$leaf_model_scale) - }, - + }, + #' @description #' Query shape parameter for IG leaf models for this ForestModelConfig object #' @returns Shape parameter for IG leaf models get_variance_forest_shape = function() { return(self$variance_forest_shape) - }, - + }, + #' @description #' Query scale parameter for IG leaf models for this ForestModelConfig object #' @returns Scale parameter for IG leaf models get_variance_forest_scale = function() { return(self$variance_forest_scale) - }, - + }, + #' @description #' Query number of unique cutpoints to consider for this ForestModelConfig object #' @returns Number of unique cutpoints to consider get_cutpoint_grid_size = function() { return(self$cutpoint_grid_size) - }, - + }, + #' @description #' Query number of features to subsample for the GFR algorithm #' @returns Number of features to subsample for the GFR algorithm @@ -396,7 +429,7 @@ ForestModelConfig <- R6::R6Class( ) ) -#' Object used to get / set global parameters and other global model +#' Object used to get / set global parameters and other global model #' configuration options in the "low-level" stochtree interface #' #' @description @@ -404,33 +437,32 @@ ForestModelConfig <- R6::R6Class( #' customization, in which users employ R wrappers around C++ objects #' like ForestDataset, Outcome, CppRng, and ForestModel to run the #' Gibbs sampler of a BART model with custom modifications. -#' GlobalModelConfig allows users to specify / query the global parameters +#' GlobalModelConfig allows users to specify / query the global parameters #' of a model they wish to run. GlobalModelConfig <- R6::R6Class( classname = "GlobalModelConfig", cloneable = FALSE, public = list( - #' @field global_error_variance Global error variance parameter global_error_variance = NULL, #' Create a new GlobalModelConfig object. #' #' @param global_error_variance Global error variance parameter (default: `1.0`) - #' + #' #' @return A new GlobalModelConfig object. initialize = function(global_error_variance = 1.0) { self$global_error_variance <- global_error_variance }, - + #' @description #' Update global error variance parameter #' @param global_error_variance Global error variance parameter update_global_error_variance = function(global_error_variance) { self$global_error_variance <- global_error_variance - }, - + }, + #' @description #' Query global error variance parameter for this GlobalModelConfig object #' @returns Global error variance parameter @@ -464,18 +496,46 @@ GlobalModelConfig <- R6::R6Class( #' #' @examples #' config <- createForestModelConfig(num_trees = 10, num_features = 5, num_observations = 100) -createForestModelConfig <- function(feature_types = NULL, sweep_update_indices = NULL, num_trees = NULL, num_features = NULL, - num_observations = NULL, variable_weights = NULL, leaf_dimension = 1, - alpha = 0.95, beta = 2.0, min_samples_leaf = 5, max_depth = -1, - leaf_model_type = 1, leaf_model_scale = NULL, variance_forest_shape = 1.0, - variance_forest_scale = 1.0, cutpoint_grid_size = 100, - num_features_subsample = NULL){ - return(invisible(( - ForestModelConfig$new(feature_types, sweep_update_indices, num_trees, num_features, num_observations, - variable_weights, leaf_dimension, alpha, beta, min_samples_leaf, - max_depth, leaf_model_type, leaf_model_scale, variance_forest_shape, - variance_forest_scale, cutpoint_grid_size, num_features_subsample) - ))) +createForestModelConfig <- function( + feature_types = NULL, + sweep_update_indices = NULL, + num_trees = NULL, + num_features = NULL, + num_observations = NULL, + variable_weights = NULL, + leaf_dimension = 1, + alpha = 0.95, + beta = 2.0, + min_samples_leaf = 5, + max_depth = -1, + leaf_model_type = 1, + leaf_model_scale = NULL, + variance_forest_shape = 1.0, + variance_forest_scale = 1.0, + cutpoint_grid_size = 100, + num_features_subsample = NULL +) { + return(invisible( + (ForestModelConfig$new( + feature_types, + sweep_update_indices, + num_trees, + num_features, + num_observations, + variable_weights, + leaf_dimension, + alpha, + beta, + min_samples_leaf, + max_depth, + leaf_model_type, + leaf_model_scale, + variance_forest_shape, + variance_forest_scale, + cutpoint_grid_size, + num_features_subsample + )) + )) } #' Create a global model config object @@ -486,8 +546,6 @@ createForestModelConfig <- function(feature_types = NULL, sweep_update_indices = #' #' @examples #' config <- createGlobalModelConfig(global_error_variance = 100) -createGlobalModelConfig <- function(global_error_variance = 1.0){ - return(invisible(( - GlobalModelConfig$new(global_error_variance) - ))) +createGlobalModelConfig <- function(global_error_variance = 1.0) { + return(invisible((GlobalModelConfig$new(global_error_variance)))) } diff --git a/R/data.R b/R/data.R index 8bea2823..13cd714f 100644 --- a/R/data.R +++ b/R/data.R @@ -1,25 +1,28 @@ #' Dataset used to sample a forest #' #' @description -#' A dataset consists of three matrices / vectors: covariates, -#' bases, and variance weights. Both the basis vector and variance +#' A dataset consists of three matrices / vectors: covariates, +#' bases, and variance weights. Both the basis vector and variance #' weights are optional. ForestDataset <- R6::R6Class( classname = "ForestDataset", cloneable = FALSE, public = list( - #' @field data_ptr External pointer to a C++ ForestDataset class data_ptr = NULL, - + #' @description #' Create a new ForestDataset object. #' @param covariates Matrix of covariates #' @param basis (Optional) Matrix of bases used to define a leaf regression #' @param variance_weights (Optional) Vector of observation-specific variance weights #' @return A new `ForestDataset` object. - initialize = function(covariates, basis=NULL, variance_weights=NULL) { + initialize = function( + covariates, + basis = NULL, + variance_weights = NULL + ) { self$data_ptr <- create_forest_dataset_cpp() forest_dataset_add_covariates_cpp(self$data_ptr, covariates) if (!is.null(basis)) { @@ -28,8 +31,8 @@ ForestDataset <- R6::R6Class( if (!is.null(variance_weights)) { forest_dataset_add_weights_cpp(self$data_ptr, variance_weights) } - }, - + }, + #' @description #' Update basis matrix in a dataset #' @param basis Updated matrix of bases used to define a leaf regression @@ -44,58 +47,62 @@ ForestDataset <- R6::R6Class( #' @param exponentiate Whether or not input vector should be exponentiated before being written to the Dataset's variance weights. Default: F. update_variance_weights = function(variance_weights, exponentiate = F) { stopifnot(self$has_variance_weights()) - forest_dataset_update_var_weights_cpp(self$data_ptr, variance_weights, exponentiate) + forest_dataset_update_var_weights_cpp( + self$data_ptr, + variance_weights, + exponentiate + ) }, - + #' @description #' Return number of observations in a `ForestDataset` object #' @return Observation count num_observations = function() { return(dataset_num_rows_cpp(self$data_ptr)) - }, - + }, + #' @description #' Return number of covariates in a `ForestDataset` object #' @return Covariate count num_covariates = function() { return(dataset_num_covariates_cpp(self$data_ptr)) - }, - + }, + #' @description #' Return number of bases in a `ForestDataset` object #' @return Basis count num_basis = function() { return(dataset_num_basis_cpp(self$data_ptr)) - }, - + }, + #' @description #' Return covariates as an R matrix #' @return Covariate data get_covariates = function() { return(forest_dataset_get_covariates_cpp(self$data_ptr)) - }, - + }, + #' @description #' Return bases as an R matrix #' @return Basis data get_basis = function() { return(forest_dataset_get_basis_cpp(self$data_ptr)) - }, - + }, + #' @description #' Return variance weights as an R vector #' @return Variance weight data get_variance_weights = function() { return(forest_dataset_get_variance_weights_cpp(self$data_ptr)) - }, - + }, + #' @description #' Whether or not a dataset has a basis matrix #' @return True if basis matrix is loaded, false otherwise has_basis = function() { return(dataset_has_basis_cpp(self$data_ptr)) - }, - + }, + #' @description #' Whether or not a dataset has variance weights #' @return True if variance weights are loaded, false otherwise @@ -110,19 +117,18 @@ ForestDataset <- R6::R6Class( #' @description #' The outcome class is wrapper around a vector of (mutable) #' outcomes for ML tasks (supervised learning, causal inference). -#' When an additive tree ensemble is sampled, the outcome used to -#' sample a specific model term is the "partial residual" consisting -#' of the outcome minus the predictions of every other model term +#' When an additive tree ensemble is sampled, the outcome used to +#' sample a specific model term is the "partial residual" consisting +#' of the outcome minus the predictions of every other model term #' (trees, group random effects, etc...). Outcome <- R6::R6Class( classname = "Outcome", cloneable = FALSE, public = list( - #' @field data_ptr External pointer to a C++ Outcome class data_ptr = NULL, - + #' @description #' Create a new Outcome object. #' @param outcome Vector of outcome values @@ -130,14 +136,14 @@ Outcome <- R6::R6Class( initialize = function(outcome) { self$data_ptr <- create_column_vector_cpp(outcome) }, - + #' @description #' Extract raw data in R from the underlying C++ object #' @return R vector containing (copy of) the values in `Outcome` object get_data = function() { return(get_residual_cpp(self$data_ptr)) - }, - + }, + #' @description #' Update the current state of the outcome (i.e. partial residual) data by adding the values of `update_vector` #' @param update_vector Vector to be added to outcome @@ -148,13 +154,17 @@ Outcome <- R6::R6Class( } else { dim_vec <- dim(update_vector) if (!is.null(dim_vec)) { - if (length(dim_vec) > 2) stop("if update_vector is provided as a matrix, it must be 2d") + if (length(dim_vec) > 2) { + stop( + "if update_vector is provided as a matrix, it must be 2d" + ) + } update_vector <- as.numeric(update_vector) } } add_to_column_vector_cpp(self$data_ptr, update_vector) - }, - + }, + #' @description #' Update the current state of the outcome (i.e. partial residual) data by subtracting the values of `update_vector` #' @param update_vector Vector to be subtracted from outcome @@ -165,13 +175,17 @@ Outcome <- R6::R6Class( } else { dim_vec <- dim(update_vector) if (!is.null(dim_vec)) { - if (length(dim_vec) > 2) stop("if update_vector is provided as a matrix, it must be 2d") + if (length(dim_vec) > 2) { + stop( + "if update_vector is provided as a matrix, it must be 2d" + ) + } update_vector <- as.numeric(update_vector) } } subtract_from_column_vector_cpp(self$data_ptr, update_vector) - }, - + }, + #' @description #' Update the current state of the outcome (i.e. partial residual) data by replacing each element with the elements of `new_vector` #' @param new_vector Vector from which to overwrite the current data @@ -182,7 +196,11 @@ Outcome <- R6::R6Class( } else { dim_vec <- dim(new_vector) if (!is.null(dim_vec)) { - if (length(dim_vec) > 2) stop("if update_vector is provided as a matrix, it must be 2d") + if (length(dim_vec) > 2) { + stop( + "if update_vector is provided as a matrix, it must be 2d" + ) + } new_vector <- as.numeric(new_vector) } } @@ -194,32 +212,31 @@ Outcome <- R6::R6Class( #' Dataset used to sample a random effects model #' #' @description -#' A dataset consists of three matrices / vectors: group labels, +#' A dataset consists of three matrices / vectors: group labels, #' bases, and variance weights. Variance weights are optional. RandomEffectsDataset <- R6::R6Class( classname = "RandomEffectsDataset", cloneable = FALSE, public = list( - #' @field data_ptr External pointer to a C++ RandomEffectsDataset class data_ptr = NULL, - + #' @description #' Create a new RandomEffectsDataset object. #' @param group_labels Vector of group labels #' @param basis Matrix of bases used to define the random effects regression (for an intercept-only model, pass an array of ones) #' @param variance_weights (Optional) Vector of observation-specific variance weights #' @return A new `RandomEffectsDataset` object. - initialize = function(group_labels, basis, variance_weights=NULL) { + initialize = function(group_labels, basis, variance_weights = NULL) { self$data_ptr <- create_rfx_dataset_cpp() rfx_dataset_add_group_labels_cpp(self$data_ptr, group_labels) rfx_dataset_add_basis_cpp(self$data_ptr, basis) if (!is.null(variance_weights)) { rfx_dataset_add_weights_cpp(self$data_ptr, variance_weights) } - }, - + }, + #' @description #' Update basis matrix in a dataset #' @param basis Updated matrix of bases used to define random slopes / intercepts @@ -227,65 +244,69 @@ RandomEffectsDataset <- R6::R6Class( stopifnot(self$has_basis()) rfx_dataset_update_basis_cpp(self$data_ptr, basis) }, - + #' @description #' Update variance_weights in a dataset #' @param variance_weights Updated vector of variance weights used to define individual variance / case weights #' @param exponentiate Whether or not input vector should be exponentiated before being written to the RandomEffectsDataset's variance weights. Default: F. update_variance_weights = function(variance_weights, exponentiate = F) { stopifnot(self$has_variance_weights()) - rfx_dataset_update_var_weights_cpp(self$data_ptr, variance_weights, exponentiate) + rfx_dataset_update_var_weights_cpp( + self$data_ptr, + variance_weights, + exponentiate + ) }, - + #' @description #' Return number of observations in a `RandomEffectsDataset` object #' @return Observation count num_observations = function() { return(rfx_dataset_num_rows_cpp(self$data_ptr)) }, - + #' @description #' Return dimension of the basis matrix in a `RandomEffectsDataset` object #' @return Basis vector count num_basis = function() { return(rfx_dataset_num_basis_cpp(self$data_ptr)) - }, - + }, + #' @description #' Return group labels as an R vector #' @return Group label data get_group_labels = function() { return(rfx_dataset_get_group_labels_cpp(self$data_ptr)) - }, - + }, + #' @description #' Return bases as an R matrix #' @return Basis data get_basis = function() { return(rfx_dataset_get_basis_cpp(self$data_ptr)) - }, - + }, + #' @description #' Return variance weights as an R vector #' @return Variance weight data get_variance_weights = function() { return(rfx_dataset_get_variance_weights_cpp(self$data_ptr)) - }, - + }, + #' @description #' Whether or not a dataset has group label indices #' @return True if group label vector is loaded, false otherwise has_group_labels = function() { return(rfx_dataset_has_group_labels_cpp(self$data_ptr)) - }, - + }, + #' @description #' Whether or not a dataset has a basis matrix #' @return True if basis matrix is loaded, false otherwise has_basis = function() { return(rfx_dataset_has_basis_cpp(self$data_ptr)) - }, - + }, + #' @description #' Whether or not a dataset has variance weights #' @return True if variance weights are loaded, false otherwise @@ -303,7 +324,7 @@ RandomEffectsDataset <- R6::R6Class( #' #' @return `ForestDataset` object #' @export -#' +#' #' @examples #' covariate_matrix <- matrix(runif(10*100), ncol = 10) #' basis_matrix <- matrix(rnorm(3*100), ncol = 3) @@ -311,10 +332,12 @@ RandomEffectsDataset <- R6::R6Class( #' forest_dataset <- createForestDataset(covariate_matrix) #' forest_dataset <- createForestDataset(covariate_matrix, basis_matrix) #' forest_dataset <- createForestDataset(covariate_matrix, basis_matrix, weight_vector) -createForestDataset <- function(covariates, basis=NULL, variance_weights=NULL){ - return(invisible(( - ForestDataset$new(covariates, basis, variance_weights) - ))) +createForestDataset <- function( + covariates, + basis = NULL, + variance_weights = NULL +) { + return(invisible((ForestDataset$new(covariates, basis, variance_weights)))) } #' Create an outcome object @@ -323,15 +346,13 @@ createForestDataset <- function(covariates, basis=NULL, variance_weights=NULL){ #' #' @return `Outcome` object #' @export -#' +#' #' @examples #' X <- matrix(runif(10*100), ncol = 10) #' y <- -5 + 10*(X[,1] > 0.5) + rnorm(100) #' outcome <- createOutcome(y) -createOutcome <- function(outcome){ - return(invisible(( - Outcome$new(outcome) - ))) +createOutcome <- function(outcome) { + return(invisible((Outcome$new(outcome)))) } #' Create a random effects dataset object @@ -342,15 +363,19 @@ createOutcome <- function(outcome){ #' #' @return `RandomEffectsDataset` object #' @export -#' +#' #' @examples #' rfx_group_ids <- sample(1:2, size = 100, replace = TRUE) #' rfx_basis <- matrix(rnorm(3*100), ncol = 3) #' weight_vector <- rnorm(100) #' rfx_dataset <- createRandomEffectsDataset(rfx_group_ids, rfx_basis) #' rfx_dataset <- createRandomEffectsDataset(rfx_group_ids, rfx_basis, weight_vector) -createRandomEffectsDataset <- function(group_labels, basis, variance_weights=NULL){ - return(invisible(( - RandomEffectsDataset$new(group_labels, basis, variance_weights) - ))) +createRandomEffectsDataset <- function( + group_labels, + basis, + variance_weights = NULL +) { + return(invisible( + (RandomEffectsDataset$new(group_labels, basis, variance_weights)) + )) } diff --git a/R/forest.R b/R/forest.R index 09f202ff..a554c2d5 100644 --- a/R/forest.R +++ b/R/forest.R @@ -7,10 +7,9 @@ ForestSamples <- R6::R6Class( classname = "ForestSamples", cloneable = FALSE, public = list( - #' @field forest_container_ptr External pointer to a C++ ForestContainer class forest_container_ptr = NULL, - + #' @description #' Create a new ForestContainer object. #' @param num_trees Number of trees @@ -18,18 +17,28 @@ ForestSamples <- R6::R6Class( #' @param is_leaf_constant Whether leaf is constant #' @param is_exponentiated Whether forest predictions should be exponentiated before being returned #' @return A new `ForestContainer` object. - initialize = function(num_trees, leaf_dimension=1, is_leaf_constant=FALSE, is_exponentiated=FALSE) { - self$forest_container_ptr <- forest_container_cpp(num_trees, leaf_dimension, is_leaf_constant, is_exponentiated) - }, - - #' @description - #' Collapse forests in this container by a pre-specified batch size. - #' For example, if we have a container of twenty 10-tree forests, and we - #' specify a `batch_size` of 5, then this method will yield four 50-tree - #' forests. "Excess" forests remaining after the size of a forest container - #' is divided by `batch_size` will be pruned from the beginning of the - #' container (i.e. earlier sampled forests will be deleted). This method - #' has no effect if `batch_size` is larger than the number of forests + initialize = function( + num_trees, + leaf_dimension = 1, + is_leaf_constant = FALSE, + is_exponentiated = FALSE + ) { + self$forest_container_ptr <- forest_container_cpp( + num_trees, + leaf_dimension, + is_leaf_constant, + is_exponentiated + ) + }, + + #' @description + #' Collapse forests in this container by a pre-specified batch size. + #' For example, if we have a container of twenty 10-tree forests, and we + #' specify a `batch_size` of 5, then this method will yield four 50-tree + #' forests. "Excess" forests remaining after the size of a forest container + #' is divided by `batch_size` will be pruned from the beginning of the + #' container (i.e. earlier sampled forests will be deleted). This method + #' has no effect if `batch_size` is larger than the number of forests #' in a container. #' @param batch_size Number of forests to be collapsed into a single forest collapse = function(batch_size) { @@ -37,26 +46,38 @@ ForestSamples <- R6::R6Class( if ((batch_size <= container_size) && (batch_size > 1)) { reverse_container_inds <- seq(container_size, 1, -1) num_clean_batches <- container_size %/% batch_size - batch_inds <- (reverse_container_inds - (container_size - (container_size %/% num_clean_batches) * num_clean_batches) - 1) %/% batch_size + batch_inds <- (reverse_container_inds - + (container_size - + (container_size %/% num_clean_batches) * + num_clean_batches) - + 1) %/% + batch_size for (batch_ind in unique(batch_inds[batch_inds >= 0])) { - merge_forest_inds <- sort(reverse_container_inds[batch_inds == batch_ind] - 1) + merge_forest_inds <- sort( + reverse_container_inds[batch_inds == batch_ind] - 1 + ) num_merge_forests <- length(merge_forest_inds) self$combine_forests(merge_forest_inds) for (i in num_merge_forests:2) { self$delete_sample(merge_forest_inds[i]) } forest_scale_factor <- 1.0 / num_merge_forests - self$multiply_forest(merge_forest_inds[1], forest_scale_factor) + self$multiply_forest( + merge_forest_inds[1], + forest_scale_factor + ) } if (min(batch_inds) < 0) { - delete_forest_inds <- sort(reverse_container_inds[batch_inds < 0] - 1) + delete_forest_inds <- sort( + reverse_container_inds[batch_inds < 0] - 1 + ) for (i in length(delete_forest_inds):1) { self$delete_sample(delete_forest_inds[i]) } } } - }, - + }, + #' @description #' Merge specified forests into a single forest #' @param forest_inds Indices of forests to be combined (0-indexed) @@ -66,9 +87,12 @@ ForestSamples <- R6::R6Class( stopifnot(length(forest_inds) > 1) stopifnot(all(as.integer(forest_inds) == forest_inds)) forest_inds_sorted <- as.integer(sort(forest_inds)) - combine_forests_forest_container_cpp(self$forest_container_ptr, forest_inds_sorted) - }, - + combine_forests_forest_container_cpp( + self$forest_container_ptr, + forest_inds_sorted + ) + }, + #' @description #' Add a constant value to every leaf of every tree of a given forest #' @param forest_index Index of forest whose leaves will be modified (0-indexed) @@ -76,9 +100,13 @@ ForestSamples <- R6::R6Class( add_to_forest = function(forest_index, constant_value) { stopifnot(forest_index < self$num_samples()) stopifnot(forest_index >= 0) - add_to_forest_forest_container_cpp(self$forest_container_ptr, forest_index, constant_value) - }, - + add_to_forest_forest_container_cpp( + self$forest_container_ptr, + forest_index, + constant_value + ) + }, + #' @description #' Multiply every leaf of every tree of a given forest by constant value #' @param forest_index Index of forest whose leaves will be modified (0-indexed) @@ -86,160 +114,237 @@ ForestSamples <- R6::R6Class( multiply_forest = function(forest_index, constant_multiple) { stopifnot(forest_index < self$num_samples()) stopifnot(forest_index >= 0) - multiply_forest_forest_container_cpp(self$forest_container_ptr, forest_index, constant_multiple) - }, - + multiply_forest_forest_container_cpp( + self$forest_container_ptr, + forest_index, + constant_multiple + ) + }, + #' @description #' Create a new `ForestContainer` object from a json object #' @param json_object Object of class `CppJson` #' @param json_forest_label Label referring to a particular forest (i.e. "forest_0") in the overall json hierarchy #' @return A new `ForestContainer` object. load_from_json = function(json_object, json_forest_label) { - self$forest_container_ptr <- forest_container_from_json_cpp(json_object$json_ptr, json_forest_label) - }, - + self$forest_container_ptr <- forest_container_from_json_cpp( + json_object$json_ptr, + json_forest_label + ) + }, + #' @description #' Append to a `ForestContainer` object from a json object #' @param json_object Object of class `CppJson` #' @param json_forest_label Label referring to a particular forest (i.e. "forest_0") in the overall json hierarchy #' @return None append_from_json = function(json_object, json_forest_label) { - forest_container_append_from_json_cpp(self$forest_container_ptr, json_object$json_ptr, json_forest_label) - }, - + forest_container_append_from_json_cpp( + self$forest_container_ptr, + json_object$json_ptr, + json_forest_label + ) + }, + #' @description #' Create a new `ForestContainer` object from a json object #' @param json_string JSON string which parses into object of class `CppJson` #' @param json_forest_label Label referring to a particular forest (i.e. "forest_0") in the overall json hierarchy #' @return A new `ForestContainer` object. load_from_json_string = function(json_string, json_forest_label) { - self$forest_container_ptr <- forest_container_from_json_string_cpp(json_string, json_forest_label) - }, - + self$forest_container_ptr <- forest_container_from_json_string_cpp( + json_string, + json_forest_label + ) + }, + #' @description #' Append to a `ForestContainer` object from a json object #' @param json_string JSON string which parses into object of class `CppJson` #' @param json_forest_label Label referring to a particular forest (i.e. "forest_0") in the overall json hierarchy #' @return None append_from_json_string = function(json_string, json_forest_label) { - forest_container_append_from_json_string_cpp(self$forest_container_ptr, json_string, json_forest_label) - }, - + forest_container_append_from_json_string_cpp( + self$forest_container_ptr, + json_string, + json_forest_label + ) + }, + #' @description #' Predict every tree ensemble on every sample in `forest_dataset` #' @param forest_dataset `ForestDataset` R class - #' @return matrix of predictions with as many rows as in forest_dataset + #' @return matrix of predictions with as many rows as in forest_dataset #' and as many columns as samples in the `ForestContainer` predict = function(forest_dataset) { stopifnot(!is.null(forest_dataset$data_ptr)) - return(predict_forest_cpp(self$forest_container_ptr, forest_dataset$data_ptr)) - }, - + return(predict_forest_cpp( + self$forest_container_ptr, + forest_dataset$data_ptr + )) + }, + #' @description #' Predict "raw" leaf values (without being multiplied by basis) for every tree ensemble on every sample in `forest_dataset` #' @param forest_dataset `ForestDataset` R class - #' @return Array of predictions for each observation in `forest_dataset` and - #' each sample in the `ForestSamples` class with each prediction having the - #' dimensionality of the forests' leaf model. In the case of a constant leaf model - #' or univariate leaf regression, this array is two-dimensional (number of observations, - #' number of forest samples). In the case of a multivariate leaf regression, - #' this array is three-dimension (number of observations, leaf model dimension, + #' @return Array of predictions for each observation in `forest_dataset` and + #' each sample in the `ForestSamples` class with each prediction having the + #' dimensionality of the forests' leaf model. In the case of a constant leaf model + #' or univariate leaf regression, this array is two-dimensional (number of observations, + #' number of forest samples). In the case of a multivariate leaf regression, + #' this array is three-dimension (number of observations, leaf model dimension, #' number of samples). predict_raw = function(forest_dataset) { stopifnot(!is.null(forest_dataset$data_ptr)) # Unpack dimensions - output_dim <- leaf_dimension_forest_container_cpp(self$forest_container_ptr) - num_samples <- num_samples_forest_container_cpp(self$forest_container_ptr) + output_dim <- leaf_dimension_forest_container_cpp( + self$forest_container_ptr + ) + num_samples <- num_samples_forest_container_cpp( + self$forest_container_ptr + ) n <- dataset_num_rows_cpp(forest_dataset$data_ptr) - + # Predict leaf values from forest - predictions <- predict_forest_raw_cpp(self$forest_container_ptr, forest_dataset$data_ptr) + predictions <- predict_forest_raw_cpp( + self$forest_container_ptr, + forest_dataset$data_ptr + ) if (output_dim > 1) { dim(predictions) <- c(n, output_dim, num_samples) } else { dim(predictions) <- c(n, num_samples) } - + return(predictions) - }, - + }, + #' @description #' Predict "raw" leaf values (without being multiplied by basis) for a specific forest on every sample in `forest_dataset` #' @param forest_dataset `ForestDataset` R class #' @param forest_num Index of the forest sample within the container - #' @return matrix of predictions with as many rows as in forest_dataset + #' @return matrix of predictions with as many rows as in forest_dataset #' and as many columns as dimensions in the leaves of trees in `ForestContainer` predict_raw_single_forest = function(forest_dataset, forest_num) { stopifnot(!is.null(forest_dataset$data_ptr)) # Unpack dimensions - output_dim <- leaf_dimension_forest_container_cpp(self$forest_container_ptr) + output_dim <- leaf_dimension_forest_container_cpp( + self$forest_container_ptr + ) n <- dataset_num_rows_cpp(forest_dataset$data_ptr) - + # Predict leaf values from forest - output <- predict_forest_raw_single_forest_cpp(self$forest_container_ptr, forest_dataset$data_ptr, forest_num) + output <- predict_forest_raw_single_forest_cpp( + self$forest_container_ptr, + forest_dataset$data_ptr, + forest_num + ) return(output) - }, - + }, + #' @description #' Predict "raw" leaf values (without being multiplied by basis) for a specific tree in a specific forest on every observation in `forest_dataset` #' @param forest_dataset `ForestDataset` R class #' @param forest_num Index of the forest sample within the container #' @param tree_num Index of the tree to be queried - #' @return matrix of predictions with as many rows as in `forest_dataset` + #' @return matrix of predictions with as many rows as in `forest_dataset` #' and as many columns as dimensions in the leaves of trees in `ForestContainer` - predict_raw_single_tree = function(forest_dataset, forest_num, tree_num) { + predict_raw_single_tree = function( + forest_dataset, + forest_num, + tree_num + ) { stopifnot(!is.null(forest_dataset$data_ptr)) # Predict leaf values from forest - output <- predict_forest_raw_single_tree_cpp(self$forest_container_ptr, forest_dataset$data_ptr, forest_num, tree_num) + output <- predict_forest_raw_single_tree_cpp( + self$forest_container_ptr, + forest_dataset$data_ptr, + forest_num, + tree_num + ) return(output) - }, - + }, + #' @description - #' Set a constant predicted value for every tree in the ensemble. - #' Stops program if any tree is more than a root node. + #' Set a constant predicted value for every tree in the ensemble. + #' Stops program if any tree is more than a root node. #' @param forest_num Index of the forest sample within the container. #' @param leaf_value Constant leaf value(s) to be fixed for each tree in the ensemble indexed by `forest_num`. Can be either a single number or a vector, depending on the forest's leaf dimension. set_root_leaves = function(forest_num, leaf_value) { stopifnot(!is.null(self$forest_container_ptr)) - stopifnot(num_samples_forest_container_cpp(self$forest_container_ptr) == 0) - + stopifnot( + num_samples_forest_container_cpp(self$forest_container_ptr) == 0 + ) + # Set leaf values if (length(leaf_value) == 1) { - stopifnot(leaf_dimension_forest_container_cpp(self$forest_container_ptr) == 1) - set_leaf_value_forest_container_cpp(self$forest_container_ptr, leaf_value) + stopifnot( + leaf_dimension_forest_container_cpp( + self$forest_container_ptr + ) == + 1 + ) + set_leaf_value_forest_container_cpp( + self$forest_container_ptr, + leaf_value + ) } else if (length(leaf_value) > 1) { - stopifnot(leaf_dimension_forest_container_cpp(self$forest_container_ptr) == length(leaf_value)) - set_leaf_vector_forest_container_cpp(self$forest_container_ptr, leaf_value) + stopifnot( + leaf_dimension_forest_container_cpp( + self$forest_container_ptr + ) == + length(leaf_value) + ) + set_leaf_vector_forest_container_cpp( + self$forest_container_ptr, + leaf_value + ) } else { - stop("leaf_value must be a numeric value or vector of length >= 1") + stop( + "leaf_value must be a numeric value or vector of length >= 1" + ) } - }, - + }, + #' @description - #' Set a constant predicted value for every tree in the ensemble. - #' Stops program if any tree is more than a root node. + #' Set a constant predicted value for every tree in the ensemble. + #' Stops program if any tree is more than a root node. #' @param dataset `ForestDataset` Dataset class (covariates, basis, etc...) #' @param outcome `Outcome` Outcome class (residual / partial residual) #' @param forest_model `ForestModel` object storing tracking structures used in training / sampling #' @param leaf_model_int Integer value encoding the leaf model type (0 = constant gaussian, 1 = univariate gaussian, 2 = multivariate gaussian, 3 = log linear variance). #' @param leaf_value Constant leaf value(s) to be fixed for each tree in the ensemble indexed by `forest_num`. Can be either a single number or a vector, depending on the forest's leaf dimension. - prepare_for_sampler = function(dataset, outcome, forest_model, leaf_model_int, leaf_value) { + prepare_for_sampler = function( + dataset, + outcome, + forest_model, + leaf_model_int, + leaf_value + ) { stopifnot(!is.null(dataset$data_ptr)) stopifnot(!is.null(outcome$data_ptr)) stopifnot(!is.null(forest_model$tracker_ptr)) stopifnot(!is.null(self$forest_container_ptr)) - stopifnot(num_samples_forest_container_cpp(self$forest_container_ptr) == 0) - + stopifnot( + num_samples_forest_container_cpp(self$forest_container_ptr) == 0 + ) + # Initialize the model - initialize_forest_model_cpp(dataset$data_ptr, outcome$data_ptr, self$forest_container_ptr, - forest_model$tracker_ptr, leaf_value, leaf_model_int) - }, - - #' @description - #' Adjusts residual based on the predictions of a forest - #' - #' This is typically run just once at the beginning of a forest sampling algorithm. + initialize_forest_model_cpp( + dataset$data_ptr, + outcome$data_ptr, + self$forest_container_ptr, + forest_model$tracker_ptr, + leaf_value, + leaf_model_int + ) + }, + + #' @description + #' Adjusts residual based on the predictions of a forest + #' + #' This is typically run just once at the beginning of a forest sampling algorithm. #' After trees are initialized with constant root node predictions, their root predictions are subtracted out of the residual. #' @param dataset `ForestDataset` object storing the covariates and bases for a given forest #' @param outcome `Outcome` object storing the residuals to be updated based on forest predictions @@ -247,81 +352,111 @@ ForestSamples <- R6::R6Class( #' @param requires_basis Whether or not a forest requires a basis for prediction #' @param forest_num Index of forest used to update residuals #' @param add Whether forest predictions should be added to or subtracted from residuals - adjust_residual = function(dataset, outcome, forest_model, requires_basis, forest_num, add) { + adjust_residual = function( + dataset, + outcome, + forest_model, + requires_basis, + forest_num, + add + ) { stopifnot(!is.null(dataset$data_ptr)) stopifnot(!is.null(outcome$data_ptr)) stopifnot(!is.null(forest_model$tracker_ptr)) stopifnot(!is.null(self$forest_container_ptr)) - + adjust_residual_forest_container_cpp( - dataset$data_ptr, outcome$data_ptr, self$forest_container_ptr, - forest_model$tracker_ptr, requires_basis, forest_num, add + dataset$data_ptr, + outcome$data_ptr, + self$forest_container_ptr, + forest_model$tracker_ptr, + requires_basis, + forest_num, + add ) - }, - + }, + #' @description #' Store the trees and metadata of `ForestDataset` class in a json file #' @param json_filename Name of output json file (must end in ".json") save_json = function(json_filename) { - invisible(json_save_forest_container_cpp(self$forest_container_ptr, json_filename)) - }, - + invisible(json_save_forest_container_cpp( + self$forest_container_ptr, + json_filename + )) + }, + #' @description - #' Load trees and metadata for an ensemble from a json file. Note that - #' any trees and metadata already present in `ForestDataset` class will + #' Load trees and metadata for an ensemble from a json file. Note that + #' any trees and metadata already present in `ForestDataset` class will #' be overwritten. #' @param json_filename Name of model input json file (must end in ".json") load_json = function(json_filename) { - invisible(json_load_forest_container_cpp(self$forest_container_ptr, json_filename)) - }, - + invisible(json_load_forest_container_cpp( + self$forest_container_ptr, + json_filename + )) + }, + #' @description #' Return number of samples in a `ForestContainer` object #' @return Sample count num_samples = function() { return(num_samples_forest_container_cpp(self$forest_container_ptr)) - }, - + }, + #' @description #' Return number of trees in each ensemble of a `ForestContainer` object #' @return Tree count num_trees = function() { return(num_trees_forest_container_cpp(self$forest_container_ptr)) - }, - + }, + #' @description #' Return output dimension of trees in a `ForestContainer` object #' @return Leaf node parameter size leaf_dimension = function() { - return(leaf_dimension_forest_container_cpp(self$forest_container_ptr)) - }, - + return(leaf_dimension_forest_container_cpp( + self$forest_container_ptr + )) + }, + #' @description #' Return constant leaf status of trees in a `ForestContainer` object #' @return `TRUE` if leaves are constant, `FALSE` otherwise is_constant_leaf = function() { - return(is_constant_leaf_forest_container_cpp(self$forest_container_ptr)) - }, - + return(is_constant_leaf_forest_container_cpp( + self$forest_container_ptr + )) + }, + #' @description #' Return exponentiation status of trees in a `ForestContainer` object #' @return `TRUE` if leaf predictions must be exponentiated, `FALSE` otherwise is_exponentiated = function() { - return(is_exponentiated_forest_container_cpp(self$forest_container_ptr)) - }, - + return(is_exponentiated_forest_container_cpp( + self$forest_container_ptr + )) + }, + #' @description - #' Add a new all-root ensemble to the container, with all of the leaves + #' Add a new all-root ensemble to the container, with all of the leaves #' set to the value / vector provided #' @param leaf_value Value (or vector of values) to initialize root nodes in tree add_forest_with_constant_leaves = function(leaf_value) { if (length(leaf_value) > 1) { - add_sample_vector_forest_container_cpp(self$forest_container_ptr, leaf_value) + add_sample_vector_forest_container_cpp( + self$forest_container_ptr, + leaf_value + ) } else { - add_sample_value_forest_container_cpp(self$forest_container_ptr, leaf_value) + add_sample_value_forest_container_cpp( + self$forest_container_ptr, + leaf_value + ) } - }, - + }, + #' @description #' Add a numeric (i.e. `X[,i] <= c`) split to a given tree in the ensemble #' @param forest_num Index of the forest which contains the tree to be split @@ -331,56 +466,101 @@ ForestSamples <- R6::R6Class( #' @param split_threshold Value that defines the cutoff of the new split #' @param left_leaf_value Value (or vector of values) to assign to the newly created left node #' @param right_leaf_value Value (or vector of values) to assign to the newly created right node - add_numeric_split_tree = function(forest_num, tree_num, leaf_num, feature_num, split_threshold, left_leaf_value, right_leaf_value) { + add_numeric_split_tree = function( + forest_num, + tree_num, + leaf_num, + feature_num, + split_threshold, + left_leaf_value, + right_leaf_value + ) { if (length(left_leaf_value) > 1) { - add_numeric_split_tree_vector_forest_container_cpp(self$forest_container_ptr, forest_num, tree_num, leaf_num, feature_num, split_threshold, left_leaf_value, right_leaf_value) + add_numeric_split_tree_vector_forest_container_cpp( + self$forest_container_ptr, + forest_num, + tree_num, + leaf_num, + feature_num, + split_threshold, + left_leaf_value, + right_leaf_value + ) } else { - add_numeric_split_tree_value_forest_container_cpp(self$forest_container_ptr, forest_num, tree_num, leaf_num, feature_num, split_threshold, left_leaf_value, right_leaf_value) + add_numeric_split_tree_value_forest_container_cpp( + self$forest_container_ptr, + forest_num, + tree_num, + leaf_num, + feature_num, + split_threshold, + left_leaf_value, + right_leaf_value + ) } - }, - + }, + #' @description #' Retrieve a vector of indices of leaf nodes for a given tree in a given forest #' @param forest_num Index of the forest which contains tree `tree_num` #' @param tree_num Index of the tree for which leaf indices will be retrieved get_tree_leaves = function(forest_num, tree_num) { - return(get_tree_leaves_forest_container_cpp(self$forest_container_ptr, forest_num, tree_num)) - }, - + return(get_tree_leaves_forest_container_cpp( + self$forest_container_ptr, + forest_num, + tree_num + )) + }, + #' @description #' Retrieve a vector of split counts for every training set variable in a given tree in a given forest #' @param forest_num Index of the forest which contains tree `tree_num` #' @param tree_num Index of the tree for which split counts will be retrieved #' @param num_features Total number of features in the training set get_tree_split_counts = function(forest_num, tree_num, num_features) { - return(get_tree_split_counts_forest_container_cpp(self$forest_container_ptr, forest_num, tree_num, num_features)) - }, - + return(get_tree_split_counts_forest_container_cpp( + self$forest_container_ptr, + forest_num, + tree_num, + num_features + )) + }, + #' @description #' Retrieve a vector of split counts for every training set variable in a given forest #' @param forest_num Index of the forest for which split counts will be retrieved #' @param num_features Total number of features in the training set get_forest_split_counts = function(forest_num, num_features) { - return(get_forest_split_counts_forest_container_cpp(self$forest_container_ptr, forest_num, num_features)) - }, - + return(get_forest_split_counts_forest_container_cpp( + self$forest_container_ptr, + forest_num, + num_features + )) + }, + #' @description #' Retrieve a vector of split counts for every training set variable in a given forest, aggregated across ensembles and trees #' @param num_features Total number of features in the training set get_aggregate_split_counts = function(num_features) { - return(get_overall_split_counts_forest_container_cpp(self$forest_container_ptr, num_features)) - }, - + return(get_overall_split_counts_forest_container_cpp( + self$forest_container_ptr, + num_features + )) + }, + #' @description #' Retrieve a vector of split counts for every training set variable in a given forest, reported separately for each ensemble and tree #' @param num_features Total number of features in the training set get_granular_split_counts = function(num_features) { n_samples <- self$num_samples() n_trees <- self$num_trees() - output <- get_granular_split_count_array_forest_container_cpp(self$forest_container_ptr, num_features) + output <- get_granular_split_count_array_forest_container_cpp( + self$forest_container_ptr, + num_features + ) dim(output) <- c(n_samples, n_trees, num_features) return(output) - }, + }, #' @description #' Maximum depth of a specific tree in a specific ensemble in a `ForestSamples` object @@ -388,40 +568,55 @@ ForestSamples <- R6::R6Class( #' @param tree_num Tree index within ensemble `ensemble_num` #' @return Maximum leaf depth ensemble_tree_max_depth = function(ensemble_num, tree_num) { - return(ensemble_tree_max_depth_forest_container_cpp(self$forest_container_ptr, ensemble_num, tree_num)) - }, + return(ensemble_tree_max_depth_forest_container_cpp( + self$forest_container_ptr, + ensemble_num, + tree_num + )) + }, #' @description #' Average the maximum depth of each tree in a given ensemble in a `ForestSamples` object #' @param ensemble_num Ensemble number #' @return Average maximum depth average_ensemble_max_depth = function(ensemble_num) { - return(ensemble_average_max_depth_forest_container_cpp(self$forest_container_ptr, ensemble_num)) - }, + return(ensemble_average_max_depth_forest_container_cpp( + self$forest_container_ptr, + ensemble_num + )) + }, #' @description #' Average the maximum depth of each tree in each ensemble in a `ForestContainer` object #' @return Average maximum depth average_max_depth = function() { - return(average_max_depth_forest_container_cpp(self$forest_container_ptr)) - }, - + return(average_max_depth_forest_container_cpp( + self$forest_container_ptr + )) + }, + #' @description #' Number of leaves in a given ensemble in a `ForestSamples` object #' @param forest_num Index of the ensemble to be queried #' @return Count of leaves in the ensemble stored at `forest_num` num_forest_leaves = function(forest_num) { - return(num_leaves_ensemble_forest_container_cpp(self$forest_container_ptr, forest_num)) - }, - + return(num_leaves_ensemble_forest_container_cpp( + self$forest_container_ptr, + forest_num + )) + }, + #' @description #' Sum of squared (raw) leaf values in a given ensemble in a `ForestSamples` object #' @param forest_num Index of the ensemble to be queried #' @return Average maximum depth sum_leaves_squared = function(forest_num) { - return(sum_leaves_squared_ensemble_forest_container_cpp(self$forest_container_ptr, forest_num)) + return(sum_leaves_squared_ensemble_forest_container_cpp( + self$forest_container_ptr, + forest_num + )) }, - + #' @description #' Whether or not a given node of a given tree in a given forest in the `ForestSamples` is a leaf #' @param forest_num Index of the forest to be queried @@ -429,9 +624,14 @@ ForestSamples <- R6::R6Class( #' @param node_id Index of the node to be queried #' @return `TRUE` if node is a leaf, `FALSE` otherwise is_leaf_node = function(forest_num, tree_num, node_id) { - return(is_leaf_node_forest_container_cpp(self$forest_container_ptr, forest_num, tree_num, node_id)) + return(is_leaf_node_forest_container_cpp( + self$forest_container_ptr, + forest_num, + tree_num, + node_id + )) }, - + #' @description #' Whether or not a given node of a given tree in a given forest in the `ForestSamples` is a numeric split node #' @param forest_num Index of the forest to be queried @@ -439,9 +639,14 @@ ForestSamples <- R6::R6Class( #' @param node_id Index of the node to be queried #' @return `TRUE` if node is a numeric split node, `FALSE` otherwise is_numeric_split_node = function(forest_num, tree_num, node_id) { - return(is_numeric_split_node_forest_container_cpp(self$forest_container_ptr, forest_num, tree_num, node_id)) + return(is_numeric_split_node_forest_container_cpp( + self$forest_container_ptr, + forest_num, + tree_num, + node_id + )) }, - + #' @description #' Whether or not a given node of a given tree in a given forest in the `ForestSamples` is a categorical split node #' @param forest_num Index of the forest to be queried @@ -449,9 +654,14 @@ ForestSamples <- R6::R6Class( #' @param node_id Index of the node to be queried #' @return `TRUE` if node is a categorical split node, `FALSE` otherwise is_categorical_split_node = function(forest_num, tree_num, node_id) { - return(is_categorical_split_node_forest_container_cpp(self$forest_container_ptr, forest_num, tree_num, node_id)) + return(is_categorical_split_node_forest_container_cpp( + self$forest_container_ptr, + forest_num, + tree_num, + node_id + )) }, - + #' @description #' Parent node of given node of a given tree in a given forest in a `ForestSamples` object #' @param forest_num Index of the forest to be queried @@ -459,9 +669,14 @@ ForestSamples <- R6::R6Class( #' @param node_id Index of the node to be queried #' @return Integer ID of the parent node parent_node = function(forest_num, tree_num, node_id) { - return(parent_node_forest_container_cpp(self$forest_container_ptr, forest_num, tree_num, node_id)) + return(parent_node_forest_container_cpp( + self$forest_container_ptr, + forest_num, + tree_num, + node_id + )) }, - + #' @description #' Left child node of given node of a given tree in a given forest in a `ForestSamples` object #' @param forest_num Index of the forest to be queried @@ -469,9 +684,14 @@ ForestSamples <- R6::R6Class( #' @param node_id Index of the node to be queried #' @return Integer ID of the left child node left_child_node = function(forest_num, tree_num, node_id) { - return(left_child_node_forest_container_cpp(self$forest_container_ptr, forest_num, tree_num, node_id)) + return(left_child_node_forest_container_cpp( + self$forest_container_ptr, + forest_num, + tree_num, + node_id + )) }, - + #' @description #' Right child node of given node of a given tree in a given forest in a `ForestSamples` object #' @param forest_num Index of the forest to be queried @@ -479,9 +699,14 @@ ForestSamples <- R6::R6Class( #' @param node_id Index of the node to be queried #' @return Integer ID of the right child node right_child_node = function(forest_num, tree_num, node_id) { - return(right_child_node_forest_container_cpp(self$forest_container_ptr, forest_num, tree_num, node_id)) + return(right_child_node_forest_container_cpp( + self$forest_container_ptr, + forest_num, + tree_num, + node_id + )) }, - + #' @description #' Depth of given node of a given tree in a given forest in a `ForestSamples` object, with 0 depth for the root node. #' @param forest_num Index of the forest to be queried @@ -489,9 +714,14 @@ ForestSamples <- R6::R6Class( #' @param node_id Index of the node to be queried #' @return Integer valued depth of the node node_depth = function(forest_num, tree_num, node_id) { - return(node_depth_forest_container_cpp(self$forest_container_ptr, forest_num, tree_num, node_id)) + return(node_depth_forest_container_cpp( + self$forest_container_ptr, + forest_num, + tree_num, + node_id + )) }, - + #' @description #' Split index of given node of a given tree in a given forest in a `ForestSamples` object. Returns `-1` is node is a leaf. #' @param forest_num Index of the forest to be queried @@ -502,10 +732,15 @@ ForestSamples <- R6::R6Class( if (self$is_leaf_node(forest_num, tree_num, node_id)) { return(-1) } else { - return(split_index_forest_container_cpp(self$forest_container_ptr, forest_num, tree_num, node_id)) + return(split_index_forest_container_cpp( + self$forest_container_ptr, + forest_num, + tree_num, + node_id + )) } }, - + #' @description #' Threshold that defines a numeric split for a given node of a given tree in a given forest in a `ForestSamples` object. #' Returns `Inf` if the node is a leaf or a categorical split node. @@ -514,14 +749,25 @@ ForestSamples <- R6::R6Class( #' @param node_id Index of the node to be queried #' @return Threshold defining a split for the node node_split_threshold = function(forest_num, tree_num, node_id) { - if (self$is_leaf_node(forest_num, tree_num, node_id) || - self$is_categorical_split_node(forest_num, tree_num, node_id)) { + if ( + self$is_leaf_node(forest_num, tree_num, node_id) || + self$is_categorical_split_node( + forest_num, + tree_num, + node_id + ) + ) { return(Inf) } else { - return(split_theshold_forest_container_cpp(self$forest_container_ptr, forest_num, tree_num, node_id)) + return(split_theshold_forest_container_cpp( + self$forest_container_ptr, + forest_num, + tree_num, + node_id + )) } }, - + #' @description #' Array of category indices that define a categorical split for a given node of a given tree in a given forest in a `ForestSamples` object. #' Returns `c(Inf)` if the node is a leaf or a numeric split node. @@ -530,14 +776,21 @@ ForestSamples <- R6::R6Class( #' @param node_id Index of the node to be queried #' @return Categories defining a split for the node node_split_categories = function(forest_num, tree_num, node_id) { - if (self$is_leaf_node(forest_num, tree_num, node_id) || - self$is_numeric_split_node(forest_num, tree_num, node_id)) { + if ( + self$is_leaf_node(forest_num, tree_num, node_id) || + self$is_numeric_split_node(forest_num, tree_num, node_id) + ) { return(c(Inf)) } else { - return(split_categories_forest_container_cpp(self$forest_container_ptr, forest_num, tree_num, node_id)) + return(split_categories_forest_container_cpp( + self$forest_container_ptr, + forest_num, + tree_num, + node_id + )) } }, - + #' @description #' Leaf node value(s) for a given node of a given tree in a given forest in a `ForestSamples` object. #' Values are stale if the node is a split node. @@ -546,68 +799,100 @@ ForestSamples <- R6::R6Class( #' @param node_id Index of the node to be queried #' @return Vector (often univariate) of leaf values node_leaf_values = function(forest_num, tree_num, node_id) { - return(leaf_values_forest_container_cpp(self$forest_container_ptr, forest_num, tree_num, node_id)) + return(leaf_values_forest_container_cpp( + self$forest_container_ptr, + forest_num, + tree_num, + node_id + )) }, - + #' @description #' Number of nodes in a given tree in a given forest in a `ForestSamples` object. #' @param forest_num Index of the forest to be queried #' @param tree_num Index of the tree to be queried #' @return Count of total tree nodes num_nodes = function(forest_num, tree_num) { - return(num_nodes_forest_container_cpp(self$forest_container_ptr, forest_num, tree_num)) + return(num_nodes_forest_container_cpp( + self$forest_container_ptr, + forest_num, + tree_num + )) }, - + #' @description #' Number of leaves in a given tree in a given forest in a `ForestSamples` object. #' @param forest_num Index of the forest to be queried #' @param tree_num Index of the tree to be queried #' @return Count of total tree leaves num_leaves = function(forest_num, tree_num) { - return(num_leaves_forest_container_cpp(self$forest_container_ptr, forest_num, tree_num)) + return(num_leaves_forest_container_cpp( + self$forest_container_ptr, + forest_num, + tree_num + )) }, - + #' @description #' Number of leaf parents (split nodes with two leaves as children) in a given tree in a given forest in a `ForestSamples` object. #' @param forest_num Index of the forest to be queried #' @param tree_num Index of the tree to be queried #' @return Count of total tree leaf parents num_leaf_parents = function(forest_num, tree_num) { - return(num_leaf_parents_forest_container_cpp(self$forest_container_ptr, forest_num, tree_num)) + return(num_leaf_parents_forest_container_cpp( + self$forest_container_ptr, + forest_num, + tree_num + )) }, - + #' @description #' Number of split nodes in a given tree in a given forest in a `ForestSamples` object. #' @param forest_num Index of the forest to be queried #' @param tree_num Index of the tree to be queried #' @return Count of total tree split nodes num_split_nodes = function(forest_num, tree_num) { - return(num_split_nodes_forest_container_cpp(self$forest_container_ptr, forest_num, tree_num)) + return(num_split_nodes_forest_container_cpp( + self$forest_container_ptr, + forest_num, + tree_num + )) }, - + #' @description #' Array of node indices in a given tree in a given forest in a `ForestSamples` object. #' @param forest_num Index of the forest to be queried #' @param tree_num Index of the tree to be queried #' @return Indices of tree nodes nodes = function(forest_num, tree_num) { - return(nodes_forest_container_cpp(self$forest_container_ptr, forest_num, tree_num)) + return(nodes_forest_container_cpp( + self$forest_container_ptr, + forest_num, + tree_num + )) }, - + #' @description #' Array of leaf indices in a given tree in a given forest in a `ForestSamples` object. #' @param forest_num Index of the forest to be queried #' @param tree_num Index of the tree to be queried #' @return Indices of leaf nodes leaves = function(forest_num, tree_num) { - return(leaves_forest_container_cpp(self$forest_container_ptr, forest_num, tree_num)) + return(leaves_forest_container_cpp( + self$forest_container_ptr, + forest_num, + tree_num + )) }, - + #' @description #' Modify the ``ForestSamples`` object by removing the forest sample indexed by `forest_num #' @param forest_num Index of the forest to be removed delete_sample = function(forest_num) { - return(remove_sample_forest_container_cpp(self$forest_container_ptr, forest_num)) + return(remove_sample_forest_container_cpp( + self$forest_container_ptr, + forest_num + )) } ) ) @@ -621,13 +906,12 @@ Forest <- R6::R6Class( classname = "Forest", cloneable = FALSE, public = list( - #' @field forest_ptr External pointer to a C++ TreeEnsemble class forest_ptr = NULL, - + #' @field internal_forest_is_empty Whether the forest has not yet been "initialized" such that its `predict` function can be called. internal_forest_is_empty = TRUE, - + #' @description #' Create a new Forest object. #' @param num_trees Number of trees in the forest @@ -635,11 +919,21 @@ Forest <- R6::R6Class( #' @param is_leaf_constant Whether leaf is constant #' @param is_exponentiated Whether forest predictions should be exponentiated before being returned #' @return A new `Forest` object. - initialize = function(num_trees, leaf_dimension=1, is_leaf_constant=FALSE, is_exponentiated=FALSE) { - self$forest_ptr <- active_forest_cpp(num_trees, leaf_dimension, is_leaf_constant, is_exponentiated) + initialize = function( + num_trees, + leaf_dimension = 1, + is_leaf_constant = FALSE, + is_exponentiated = FALSE + ) { + self$forest_ptr <- active_forest_cpp( + num_trees, + leaf_dimension, + is_leaf_constant, + is_exponentiated + ) self$internal_forest_is_empty <- TRUE - }, - + }, + #' @description #' Create a larger forest by merging the trees of this forest with those of another forest #' @param forest Forest to be merged into this forest @@ -648,22 +942,22 @@ Forest <- R6::R6Class( stopifnot(self$is_constant_leaf() == forest$is_constant_leaf()) stopifnot(self$is_exponentiated() == forest$is_exponentiated()) forest_merge_cpp(self$forest_ptr, forest$forest_ptr) - }, - + }, + #' @description #' Add a constant value to every leaf of every tree in an ensemble. If leaves are multi-dimensional, `constant_value` will be added to every dimension of the leaves. #' @param constant_value Value that will be added to every leaf of every tree add_constant = function(constant_value) { forest_add_constant_cpp(self$forest_ptr, constant_value) - }, - + }, + #' @description #' Multiply every leaf of every tree by a constant value. If leaves are multi-dimensional, `constant_multiple` will be multiplied through every dimension of the leaves. #' @param constant_multiple Value that will be multiplied by every leaf of every tree multiply_constant = function(constant_multiple) { forest_multiply_constant_cpp(self$forest_ptr, constant_multiple) - }, - + }, + #' @description #' Predict forest on every sample in `forest_dataset` #' @param forest_dataset `ForestDataset` R class @@ -671,130 +965,163 @@ Forest <- R6::R6Class( predict = function(forest_dataset) { stopifnot(!is.null(forest_dataset$data_ptr)) stopifnot(!is.null(self$forest_ptr)) - return(predict_active_forest_cpp(self$forest_ptr, forest_dataset$data_ptr)) - }, - + return(predict_active_forest_cpp( + self$forest_ptr, + forest_dataset$data_ptr + )) + }, + #' @description #' Predict "raw" leaf values (without being multiplied by basis) for every sample in `forest_dataset` #' @param forest_dataset `ForestDataset` R class - #' @return Array of predictions for each observation in `forest_dataset` and - #' each sample in the `ForestSamples` class with each prediction having the - #' dimensionality of the forests' leaf model. In the case of a constant leaf model - #' or univariate leaf regression, this array is a vector (length is the number of - #' observations). In the case of a multivariate leaf regression, - #' this array is a matrix (number of observations by leaf model dimension, + #' @return Array of predictions for each observation in `forest_dataset` and + #' each sample in the `ForestSamples` class with each prediction having the + #' dimensionality of the forests' leaf model. In the case of a constant leaf model + #' or univariate leaf regression, this array is a vector (length is the number of + #' observations). In the case of a multivariate leaf regression, + #' this array is a matrix (number of observations by leaf model dimension, #' number of samples). predict_raw = function(forest_dataset) { stopifnot(!is.null(forest_dataset$data_ptr)) # Unpack dimensions output_dim <- leaf_dimension_active_forest_cpp(self$forest_ptr) n <- dataset_num_rows_cpp(forest_dataset$data_ptr) - + # Predict leaf values from forest - predictions <- predict_raw_active_forest_cpp(self$forest_ptr, forest_dataset$data_ptr) + predictions <- predict_raw_active_forest_cpp( + self$forest_ptr, + forest_dataset$data_ptr + ) if (output_dim > 1) { dim(predictions) <- c(n, output_dim) } - + return(predictions) - }, - + }, + #' @description - #' Set a constant predicted value for every tree in the ensemble. - #' Stops program if any tree is more than a root node. + #' Set a constant predicted value for every tree in the ensemble. + #' Stops program if any tree is more than a root node. #' @param leaf_value Constant leaf value(s) to be fixed for each tree in the ensemble indexed by `forest_num`. Can be either a single number or a vector, depending on the forest's leaf dimension. set_root_leaves = function(leaf_value) { stopifnot(!is.null(self$forest_ptr)) stopifnot(self$internal_forest_is_empty) - + # Set leaf values if (length(leaf_value) == 1) { - stopifnot(leaf_dimension_active_forest_cpp(self$forest_ptr) == 1) + stopifnot( + leaf_dimension_active_forest_cpp(self$forest_ptr) == 1 + ) set_leaf_value_active_forest_cpp(self$forest_ptr, leaf_value) } else if (length(leaf_value) > 1) { - stopifnot(leaf_dimension_active_forest_cpp(self$forest_ptr) == length(leaf_value)) + stopifnot( + leaf_dimension_active_forest_cpp(self$forest_ptr) == + length(leaf_value) + ) set_leaf_vector_active_forest_cpp(self$forest_ptr, leaf_value) } else { - stop("leaf_value must be a numeric value or vector of length >= 1") + stop( + "leaf_value must be a numeric value or vector of length >= 1" + ) } - + self$internal_forest_is_empty = FALSE - }, - + }, + #' @description - #' Set a constant predicted value for every tree in the ensemble. - #' Stops program if any tree is more than a root node. + #' Set a constant predicted value for every tree in the ensemble. + #' Stops program if any tree is more than a root node. #' @param dataset `ForestDataset` Dataset class (covariates, basis, etc...) #' @param outcome `Outcome` Outcome class (residual / partial residual) #' @param forest_model `ForestModel` object storing tracking structures used in training / sampling #' @param leaf_model_int Integer value encoding the leaf model type (0 = constant gaussian, 1 = univariate gaussian, 2 = multivariate gaussian, 3 = log linear variance). #' @param leaf_value Constant leaf value(s) to be fixed for each tree in the ensemble indexed by `forest_num`. Can be either a single number or a vector, depending on the forest's leaf dimension. - prepare_for_sampler = function(dataset, outcome, forest_model, leaf_model_int, leaf_value) { + prepare_for_sampler = function( + dataset, + outcome, + forest_model, + leaf_model_int, + leaf_value + ) { stopifnot(!is.null(dataset$data_ptr)) stopifnot(!is.null(outcome$data_ptr)) stopifnot(!is.null(forest_model$tracker_ptr)) stopifnot(!is.null(self$forest_ptr)) stopifnot(self$internal_forest_is_empty) - + # Initialize the model initialize_forest_model_active_forest_cpp( - dataset$data_ptr, outcome$data_ptr, self$forest_ptr, - forest_model$tracker_ptr, leaf_value, leaf_model_int + dataset$data_ptr, + outcome$data_ptr, + self$forest_ptr, + forest_model$tracker_ptr, + leaf_value, + leaf_model_int ) - + self$internal_forest_is_empty = FALSE - }, - + }, + #' @description - #' Adjusts residual based on the predictions of a forest - #' - #' This is typically run just once at the beginning of a forest sampling algorithm. + #' Adjusts residual based on the predictions of a forest + #' + #' This is typically run just once at the beginning of a forest sampling algorithm. #' After trees are initialized with constant root node predictions, their root predictions are subtracted out of the residual. #' @param dataset `ForestDataset` object storing the covariates and bases for a given forest #' @param outcome `Outcome` object storing the residuals to be updated based on forest predictions #' @param forest_model `ForestModel` object storing tracking structures used in training / sampling #' @param requires_basis Whether or not a forest requires a basis for prediction #' @param add Whether forest predictions should be added to or subtracted from residuals - adjust_residual = function(dataset, outcome, forest_model, requires_basis, add) { + adjust_residual = function( + dataset, + outcome, + forest_model, + requires_basis, + add + ) { stopifnot(!is.null(dataset$data_ptr)) stopifnot(!is.null(outcome$data_ptr)) stopifnot(!is.null(forest_model$tracker_ptr)) stopifnot(!is.null(self$forest_ptr)) - + adjust_residual_active_forest_cpp( - dataset$data_ptr, outcome$data_ptr, self$forest_ptr, - forest_model$tracker_ptr, requires_basis, add + dataset$data_ptr, + outcome$data_ptr, + self$forest_ptr, + forest_model$tracker_ptr, + requires_basis, + add ) - }, - + }, + #' @description #' Return number of trees in each ensemble of a `Forest` object #' @return Tree count num_trees = function() { return(num_trees_active_forest_cpp(self$forest_ptr)) - }, - + }, + #' @description #' Return output dimension of trees in a `Forest` object #' @return Leaf node parameter size leaf_dimension = function() { return(leaf_dimension_active_forest_cpp(self$forest_ptr)) - }, - + }, + #' @description #' Return constant leaf status of trees in a `Forest` object #' @return `TRUE` if leaves are constant, `FALSE` otherwise is_constant_leaf = function() { return(is_leaf_constant_active_forest_cpp(self$forest_ptr)) - }, - + }, + #' @description #' Return exponentiation status of trees in a `Forest` object #' @return `TRUE` if leaf predictions must be exponentiated, `FALSE` otherwise is_exponentiated = function() { return(is_exponentiated_active_forest_cpp(self$forest_ptr)) - }, - + }, + #' @description #' Add a numeric (i.e. `X[,i] <= c`) split to a given tree in the ensemble #' @param tree_num Index of the tree to be split @@ -803,63 +1130,98 @@ Forest <- R6::R6Class( #' @param split_threshold Value that defines the cutoff of the new split #' @param left_leaf_value Value (or vector of values) to assign to the newly created left node #' @param right_leaf_value Value (or vector of values) to assign to the newly created right node - add_numeric_split_tree = function(tree_num, leaf_num, feature_num, split_threshold, left_leaf_value, right_leaf_value) { + add_numeric_split_tree = function( + tree_num, + leaf_num, + feature_num, + split_threshold, + left_leaf_value, + right_leaf_value + ) { if (length(left_leaf_value) > 1) { - add_numeric_split_tree_vector_active_forest_cpp(self$forest_ptr, tree_num, leaf_num, feature_num, split_threshold, left_leaf_value, right_leaf_value) + add_numeric_split_tree_vector_active_forest_cpp( + self$forest_ptr, + tree_num, + leaf_num, + feature_num, + split_threshold, + left_leaf_value, + right_leaf_value + ) } else { - add_numeric_split_tree_value_active_forest_cpp(self$forest_ptr, tree_num, leaf_num, feature_num, split_threshold, left_leaf_value, right_leaf_value) + add_numeric_split_tree_value_active_forest_cpp( + self$forest_ptr, + tree_num, + leaf_num, + feature_num, + split_threshold, + left_leaf_value, + right_leaf_value + ) } - }, - + }, + #' @description #' Retrieve a vector of indices of leaf nodes for a given tree in a given forest #' @param tree_num Index of the tree for which leaf indices will be retrieved get_tree_leaves = function(tree_num) { return(get_tree_leaves_active_forest_cpp(self$forest_ptr, tree_num)) - }, - + }, + #' @description #' Retrieve a vector of split counts for every training set variable in a given tree in the forest #' @param tree_num Index of the tree for which split counts will be retrieved #' @param num_features Total number of features in the training set get_tree_split_counts = function(tree_num, num_features) { - return(get_tree_split_counts_active_forest_cpp(self$forest_ptr, tree_num, num_features)) - }, - + return(get_tree_split_counts_active_forest_cpp( + self$forest_ptr, + tree_num, + num_features + )) + }, + #' @description #' Retrieve a vector of split counts for every training set variable in the forest #' @param num_features Total number of features in the training set get_forest_split_counts = function(num_features) { - return(get_overall_split_counts_active_forest_cpp(self$forest_ptr, num_features)) - }, - + return(get_overall_split_counts_active_forest_cpp( + self$forest_ptr, + num_features + )) + }, + #' @description #' Maximum depth of a specific tree in the forest #' @param tree_num Tree index within forest #' @return Maximum leaf depth tree_max_depth = function(tree_num) { - return(ensemble_tree_max_depth_active_forest_cpp(self$forest_ptr, tree_num)) - }, - + return(ensemble_tree_max_depth_active_forest_cpp( + self$forest_ptr, + tree_num + )) + }, + #' @description #' Average the maximum depth of each tree in the forest #' @return Average maximum depth average_max_depth = function() { - return(ensemble_average_max_depth_active_forest_cpp(self$forest_ptr)) - }, - - #' @description - #' When a forest object is created, it is "empty" in the sense that none - #' of its component trees have leaves with values. There are two ways to - #' "initialize" a Forest object. First, the `set_root_leaves()` method - #' simply initializes every tree in the forest to a single node carrying - #' the same (user-specified) leaf value. Second, the `prepare_for_sampler()` - #' method initializes every tree in the forest to a single node with the + return(ensemble_average_max_depth_active_forest_cpp( + self$forest_ptr + )) + }, + + #' @description + #' When a forest object is created, it is "empty" in the sense that none + #' of its component trees have leaves with values. There are two ways to + #' "initialize" a Forest object. First, the `set_root_leaves()` method + #' simply initializes every tree in the forest to a single node carrying + #' the same (user-specified) leaf value. Second, the `prepare_for_sampler()` + #' method initializes every tree in the forest to a single node with the #' same value and also propagates this information through to a ForestModel - #' object, which must be synchronized with a Forest during a forest + #' object, which must be synchronized with a Forest during a forest #' sampler loop. - #' @return `TRUE` if a Forest has not yet been initialized with a constant - #' root value, `FALSE` otherwise if the forest has already been + #' @return `TRUE` if a Forest has not yet been initialized with a constant + #' root value, `FALSE` otherwise if the forest has already been #' initialized / grown. is_empty = function() { return(self$internal_forest_is_empty) @@ -876,17 +1238,27 @@ Forest <- R6::R6Class( #' #' @return `ForestSamples` object #' @export -#' +#' #' @examples #' num_trees <- 100 #' leaf_dimension <- 2 #' is_leaf_constant <- FALSE #' is_exponentiated <- FALSE #' forest_samples <- createForestSamples(num_trees, leaf_dimension, is_leaf_constant, is_exponentiated) -createForestSamples <- function(num_trees, leaf_dimension=1, is_leaf_constant=FALSE, is_exponentiated=FALSE) { - return(invisible(( - ForestSamples$new(num_trees, leaf_dimension, is_leaf_constant, is_exponentiated) - ))) +createForestSamples <- function( + num_trees, + leaf_dimension = 1, + is_leaf_constant = FALSE, + is_exponentiated = FALSE +) { + return(invisible( + (ForestSamples$new( + num_trees, + leaf_dimension, + is_leaf_constant, + is_exponentiated + )) + )) } #' Create a forest @@ -898,28 +1270,38 @@ createForestSamples <- function(num_trees, leaf_dimension=1, is_leaf_constant=FA #' #' @return `Forest` object #' @export -#' +#' #' @examples #' num_trees <- 100 #' leaf_dimension <- 2 #' is_leaf_constant <- FALSE #' is_exponentiated <- FALSE #' forest <- createForest(num_trees, leaf_dimension, is_leaf_constant, is_exponentiated) -createForest <- function(num_trees, leaf_dimension=1, is_leaf_constant=FALSE, is_exponentiated=FALSE) { - return(invisible(( - Forest$new(num_trees, leaf_dimension, is_leaf_constant, is_exponentiated) - ))) +createForest <- function( + num_trees, + leaf_dimension = 1, + is_leaf_constant = FALSE, + is_exponentiated = FALSE +) { + return(invisible( + (Forest$new( + num_trees, + leaf_dimension, + is_leaf_constant, + is_exponentiated + )) + )) } -#' Reset an active forest, either from a specific forest in a `ForestContainer` +#' Reset an active forest, either from a specific forest in a `ForestContainer` #' or to an ensemble of single-node (i.e. root) trees -#' +#' #' @param active_forest Current active forest #' @param forest_samples (Optional) Container of forest samples from which to re-initialize active forest. If not provided, active forest will be reset to an ensemble of single-node (i.e. root) trees. #' @param forest_num (Optional) Index of forest samples from which to initialize active forest. If not provided, active forest will be reset to an ensemble of single-node (i.e. root) trees. #' @return None #' @export -#' +#' #' @examples #' num_trees <- 100 #' leaf_dimension <- 1 @@ -933,20 +1315,30 @@ createForest <- function(num_trees, leaf_dimension=1, is_leaf_constant=FALSE, is #' active_forest$set_root_leaves(0.1) #' resetActiveForest(active_forest, forest_samples, 0) #' resetActiveForest(active_forest) -resetActiveForest <- function(active_forest, forest_samples=NULL, forest_num=NULL) { +resetActiveForest <- function( + active_forest, + forest_samples = NULL, + forest_num = NULL +) { if (is.null(forest_samples)) { root_reset_active_forest_cpp(active_forest$forest_ptr) active_forest$internal_forest_is_empty = TRUE } else { if (is.null(forest_num)) { - stop("`forest_num` must be specified if `forest_samples` is provided") + stop( + "`forest_num` must be specified if `forest_samples` is provided" + ) } - reset_active_forest_cpp(active_forest$forest_ptr, forest_samples$forest_container_ptr, forest_num) + reset_active_forest_cpp( + active_forest$forest_ptr, + forest_samples$forest_container_ptr, + forest_num + ) } } #' Re-initialize a forest model (tracking data structures) from a specific forest in a `ForestContainer` -#' +#' #' @param forest_model Forest model with tracking data structures #' @param forest Forest from which to re-initialize forest model #' @param dataset Training dataset object @@ -954,7 +1346,7 @@ resetActiveForest <- function(active_forest, forest_samples=NULL, forest_num=NUL #' @param is_mean_model Whether the model being updated is a conditional mean model #' @return None #' @export -#' +#' #' @examples #' n <- 100 #' p <- 10 @@ -980,27 +1372,39 @@ resetActiveForest <- function(active_forest, forest_samples=NULL, forest_num=NUL #' outcome <- createOutcome(y) #' rng <- createCppRNG(1234) #' global_model_config <- createGlobalModelConfig(global_error_variance=sigma2) -#' forest_model_config <- createForestModelConfig(feature_types=feature_types, -#' num_trees=num_trees, num_observations=n, -#' num_features=p, alpha=alpha, beta=beta, -#' min_samples_leaf=min_samples_leaf, -#' max_depth=max_depth, -#' variable_weights=variable_weights, -#' cutpoint_grid_size=cutpoint_grid_size, -#' leaf_model_type=leaf_model, +#' forest_model_config <- createForestModelConfig(feature_types=feature_types, +#' num_trees=num_trees, num_observations=n, +#' num_features=p, alpha=alpha, beta=beta, +#' min_samples_leaf=min_samples_leaf, +#' max_depth=max_depth, +#' variable_weights=variable_weights, +#' cutpoint_grid_size=cutpoint_grid_size, +#' leaf_model_type=leaf_model, #' leaf_model_scale=leaf_scale) #' forest_model <- createForestModel(forest_dataset, forest_model_config, global_model_config) #' active_forest <- createForest(num_trees, leaf_dimension, is_leaf_constant, is_exponentiated) -#' forest_samples <- createForestSamples(num_trees, leaf_dimension, +#' forest_samples <- createForestSamples(num_trees, leaf_dimension, #' is_leaf_constant, is_exponentiated) #' active_forest$prepare_for_sampler(forest_dataset, outcome, forest_model, 0, 0.) #' forest_model$sample_one_iteration( -#' forest_dataset, outcome, forest_samples, active_forest, -#' rng, forest_model_config, global_model_config, +#' forest_dataset, outcome, forest_samples, active_forest, +#' rng, forest_model_config, global_model_config, #' keep_forest = TRUE, gfr = FALSE #' ) #' resetActiveForest(active_forest, forest_samples, 0) #' resetForestModel(forest_model, active_forest, forest_dataset, outcome, TRUE) -resetForestModel <- function(forest_model, forest, dataset, residual, is_mean_model) { - reset_forest_model_cpp(forest_model$tracker_ptr, forest$forest_ptr, dataset$data_ptr, residual$data_ptr, is_mean_model) +resetForestModel <- function( + forest_model, + forest, + dataset, + residual, + is_mean_model +) { + reset_forest_model_cpp( + forest_model$tracker_ptr, + forest$forest_ptr, + dataset$data_ptr, + residual$data_ptr, + is_mean_model + ) } diff --git a/R/generics.R b/R/generics.R index 1c73ad9a..1df1a174 100644 --- a/R/generics.R +++ b/R/generics.R @@ -1,10 +1,10 @@ #' Generic function for extracting random effect samples from a model object (BCF, BART, etc...) -#' +#' #' @param object Fitted model object from which to extract random effects #' @param ... Other parameters to be used in random effects extraction #' @return List of random effect samples #' @export -#' +#' #' @examples #' n <- 100 #' p <- 10 @@ -15,4 +15,6 @@ #' bart_model <- bart(X_train=X, y_train=y, rfx_group_ids_train=rfx_group_ids, #' rfx_basis_train = rfx_basis, num_gfr=0, num_mcmc=10) #' rfx_samples <- getRandomEffectSamples(bart_model) -getRandomEffectSamples <- function(object, ...) UseMethod("getRandomEffectSamples") +getRandomEffectSamples <- function(object, ...) { + UseMethod("getRandomEffectSamples") +} diff --git a/R/kernel.R b/R/kernel.R index 381e13bf..d7e9661e 100644 --- a/R/kernel.R +++ b/R/kernel.R @@ -1,45 +1,45 @@ #' Compute vector of forest leaf indices -#' -#' @description Compute and return a vector representation of a forest's leaf predictions for +#' +#' @description Compute and return a vector representation of a forest's leaf predictions for #' every observation in a dataset. -#' -#' The vector has a "row-major" format that can be easily re-represented as -#' as a CSR sparse matrix: elements are organized so that the first `n` elements -#' correspond to leaf predictions for all `n` observations in a dataset for the -#' first tree in an ensemble, the next `n` elements correspond to predictions for -#' the second tree and so on. The "data" for each element corresponds to a uniquely -#' mapped column index that corresponds to a single leaf of a single tree (i.e. -#' if tree 1 has 3 leaves, its column indices range from 0 to 2, and then tree 2's +#' +#' The vector has a "row-major" format that can be easily re-represented as +#' as a CSR sparse matrix: elements are organized so that the first `n` elements +#' correspond to leaf predictions for all `n` observations in a dataset for the +#' first tree in an ensemble, the next `n` elements correspond to predictions for +#' the second tree and so on. The "data" for each element corresponds to a uniquely +#' mapped column index that corresponds to a single leaf of a single tree (i.e. +#' if tree 1 has 3 leaves, its column indices range from 0 to 2, and then tree 2's #' leaf indices begin at 3, etc...). #' #' @param model_object Object of type `bartmodel`, `bcfmodel`, or `ForestSamples` corresponding to a BART / BCF model with at least one forest sample, or a low-level `ForestSamples` object. #' @param covariates Covariates to use for prediction. Must have the same dimensions / column types as the data used to train a forest. -#' @param forest_type Which forest to use from `model_object`. +#' @param forest_type Which forest to use from `model_object`. #' Valid inputs depend on the model type, and whether or not a given forest was sampled in that model. -#' +#' #' **1. BART** #' #' - `'mean'`: Extracts leaf indices for the mean forest #' - `'variance'`: Extracts leaf indices for the variance forest -#' +#' #' **2. BCF** #' #' - `'prognostic'`: Extracts leaf indices for the prognostic forest #' - `'treatment'`: Extracts leaf indices for the treatment effect forest #' - `'variance'`: Extracts leaf indices for the variance forest -#' +#' #' **3. ForestSamples** #' #' - `NULL`: It is not necessary to disambiguate when this function is called directly on a `ForestSamples` object. This is the default value of this -#' +#' #' @param propensity (Optional) Propensities used for prediction (BCF-only). -#' @param forest_inds (Optional) Indices of the forest sample(s) for which to compute leaf indices. If not provided, -#' this function will return leaf indices for every sample of a forest. +#' @param forest_inds (Optional) Indices of the forest sample(s) for which to compute leaf indices. If not provided, +#' this function will return leaf indices for every sample of a forest. #' This function uses 0-indexing, so the first forest sample corresponds to `forest_num = 0`, and so on. -#' @return Vector of size `num_obs * num_trees`, where `num_obs = nrow(covariates)` -#' and `num_trees` is the number of trees in the relevant forest of `model_object`. +#' @return Vector of size `num_obs * num_trees`, where `num_obs = nrow(covariates)` +#' and `num_trees` is the number of trees in the relevant forest of `model_object`. #' @export -#' +#' #' @examples #' X <- matrix(runif(10*100), ncol = 10) #' y <- -5 + 10*(X[,1] > 0.5) + rnorm(100) @@ -47,53 +47,76 @@ #' computeForestLeafIndices(bart_model, X, "mean") #' computeForestLeafIndices(bart_model, X, "mean", 0) #' computeForestLeafIndices(bart_model, X, "mean", c(1,3,9)) -computeForestLeafIndices <- function(model_object, covariates, forest_type=NULL, propensity=NULL, forest_inds=NULL) { +computeForestLeafIndices <- function( + model_object, + covariates, + forest_type = NULL, + propensity = NULL, + forest_inds = NULL +) { # Extract relevant forest container - stopifnot(any(c(inherits(model_object, "bartmodel"), inherits(model_object, "bcfmodel"), inherits(model_object, "ForestSamples")))) - model_type <- ifelse(inherits(model_object, "bartmodel"), "bart", ifelse(inherits(model_object, "bcfmodel"), "bcf", "forest_samples")) + stopifnot(any(c( + inherits(model_object, "bartmodel"), + inherits(model_object, "bcfmodel"), + inherits(model_object, "ForestSamples") + ))) + model_type <- ifelse( + inherits(model_object, "bartmodel"), + "bart", + ifelse(inherits(model_object, "bcfmodel"), "bcf", "forest_samples") + ) if (model_type == "bart") { stopifnot(forest_type %in% c("mean", "variance")) - if (forest_type=="mean") { + if (forest_type == "mean") { if (!model_object$model_params$include_mean_forest) { stop("Mean forest was not sampled in the bart model provided") } forest_container <- model_object$mean_forests - } else if (forest_type=="variance") { + } else if (forest_type == "variance") { if (!model_object$model_params$include_variance_forest) { - stop("Variance forest was not sampled in the bart model provided") + stop( + "Variance forest was not sampled in the bart model provided" + ) } forest_container <- model_object$variance_forests } } else if (model_type == "bcf") { stopifnot(forest_type %in% c("prognostic", "treatment", "variance")) - if (forest_type=="prognostic") { + if (forest_type == "prognostic") { forest_container <- model_object$forests_mu - } else if (forest_type=="treatment") { + } else if (forest_type == "treatment") { forest_container <- model_object$forests_tau - } else if (forest_type=="variance") { + } else if (forest_type == "variance") { if (!model_object$model_params$include_variance_forest) { - stop("Variance forest was not sampled in the bcf model provided") + stop( + "Variance forest was not sampled in the bcf model provided" + ) } forest_container <- model_object$variance_forests } } else { forest_container <- model_object } - + # Preprocess covariates if ((!is.data.frame(covariates)) && (!is.matrix(covariates))) { stop("covariates must be a matrix or dataframe") } if (model_type %in% c("bart", "bcf")) { train_set_metadata <- model_object$train_set_metadata - covariates_processed <- preprocessPredictionData(covariates, train_set_metadata) + covariates_processed <- preprocessPredictionData( + covariates, + train_set_metadata + ) } else { if (!is.matrix(covariates)) { - stop("covariates must be a matrix since no covariate preprocessor is stored in a `ForestSamples` object provided as `model_object`") + stop( + "covariates must be a matrix since no covariate preprocessor is stored in a `ForestSamples` object provided as `model_object`" + ) } covariates_processed <- covariates } - + # Handle BCF propensity covariate if (model_type == "bcf") { # Add propensities to covariate set if necessary @@ -103,59 +126,65 @@ computeForestLeafIndices <- function(model_object, covariates, forest_type=NULL, stop("propensity must be provided for this model") } # Compute propensity score using the internal bart model - propensity <- rowMeans(predict(model_object$bart_propensity_model, covariates)$y_hat) + propensity <- rowMeans( + predict( + model_object$bart_propensity_model, + covariates + )$y_hat + ) } covariates_processed <- cbind(covariates_processed, propensity) } } - + # Preprocess forest indices num_forests <- forest_container$num_samples() if (is.null(forest_inds)) { forest_inds <- as.integer(1:num_forests - 1) } else { - stopifnot(all(forest_inds <= num_forests-1)) + stopifnot(all(forest_inds <= num_forests - 1)) stopifnot(all(forest_inds >= 0)) forest_inds <- as.integer(forest_inds) } - + # Compute leaf indices leaf_ind_matrix <- compute_leaf_indices_cpp( - forest_container$forest_container_ptr, - covariates_processed, forest_inds + forest_container$forest_container_ptr, + covariates_processed, + forest_inds ) return(leaf_ind_matrix) } #' Compute vector of forest leaf scale parameters -#' +#' #' @description Return each forest's leaf node scale parameters. -#' -#' If leaf scale is not sampled for the forest in question, throws an error that the +#' +#' If leaf scale is not sampled for the forest in question, throws an error that the #' leaf model does not have a stochastic scale parameter. -#' +#' #' @param model_object Object of type `bartmodel` or `bcfmodel` corresponding to a BART / BCF model with at least one forest sample -#' @param forest_type Which forest to use from `model_object`. +#' @param forest_type Which forest to use from `model_object`. #' Valid inputs depend on the model type, and whether or not a given forest was sampled in that model. -#' +#' #' **1. BART** #' #' - `'mean'`: Extracts leaf indices for the mean forest #' - `'variance'`: Extracts leaf indices for the variance forest -#' +#' #' **2. BCF** #' #' - `'prognostic'`: Extracts leaf indices for the prognostic forest #' - `'treatment'`: Extracts leaf indices for the treatment effect forest #' - `'variance'`: Extracts leaf indices for the variance forest -#' -#' @param forest_inds (Optional) Indices of the forest sample(s) for which to compute leaf indices. If not provided, -#' this function will return leaf indices for every sample of a forest. +#' +#' @param forest_inds (Optional) Indices of the forest sample(s) for which to compute leaf indices. If not provided, +#' this function will return leaf indices for every sample of a forest. #' This function uses 0-indexing, so the first forest sample corresponds to `forest_num = 0`, and so on. #' @return Vector of size `length(forest_inds)` with the leaf scale parameter for each requested forest. #' @export -#' +#' #' @examples #' X <- matrix(runif(10*100), ncol = 10) #' y <- -5 + 10*(X[,1] > 0.5) + rnorm(100) @@ -163,56 +192,77 @@ computeForestLeafIndices <- function(model_object, covariates, forest_type=NULL, #' computeForestLeafVariances(bart_model, "mean") #' computeForestLeafVariances(bart_model, "mean", 0) #' computeForestLeafVariances(bart_model, "mean", c(1,3,5)) -computeForestLeafVariances <- function(model_object, forest_type, forest_inds=NULL) { +computeForestLeafVariances <- function( + model_object, + forest_type, + forest_inds = NULL +) { # Extract relevant forest container - stopifnot(any(c(inherits(model_object, "bartmodel"), inherits(model_object, "bcfmodel")))) + stopifnot(any(c( + inherits(model_object, "bartmodel"), + inherits(model_object, "bcfmodel") + ))) model_type <- ifelse(inherits(model_object, "bartmodel"), "bart", "bcf") if (model_type == "bart") { stopifnot(forest_type %in% c("mean", "variance")) - if (forest_type=="mean") { + if (forest_type == "mean") { if (!model_object$model_params$include_mean_forest) { stop("Mean forest was not sampled in the bart model provided") } if (!model_object$model_params$sample_sigma2_leaf) { - stop("Leaf scale parameter was not sampled for the mean forest in the bart model provided") + stop( + "Leaf scale parameter was not sampled for the mean forest in the bart model provided" + ) } leaf_scale_vector <- model_object$sigma2_leaf_samples - } else if (forest_type=="variance") { + } else if (forest_type == "variance") { if (!model_object$model_params$include_variance_forest) { - stop("Variance forest was not sampled in the bart model provided") + stop( + "Variance forest was not sampled in the bart model provided" + ) } - stop("Leaf scale parameter was not sampled for the variance forest in the bart model provided") + stop( + "Leaf scale parameter was not sampled for the variance forest in the bart model provided" + ) } } else { stopifnot(forest_type %in% c("prognostic", "treatment", "variance")) - if (forest_type=="prognostic") { + if (forest_type == "prognostic") { if (!model_object$model_params$sample_sigma2_leaf_mu) { - stop("Leaf scale parameter was not sampled for the prognostic forest in the bcf model provided") + stop( + "Leaf scale parameter was not sampled for the prognostic forest in the bcf model provided" + ) } leaf_scale_vector <- model_object$sigma2_leaf_mu_samples - } else if (forest_type=="treatment") { + } else if (forest_type == "treatment") { if (!model_object$model_params$sample_sigma2_leaf_tau) { - stop("Leaf scale parameter was not sampled for the treatment effect forest in the bcf model provided") + stop( + "Leaf scale parameter was not sampled for the treatment effect forest in the bcf model provided" + ) } leaf_scale_vector <- model_object$sigma2_leaf_tau_samples - } else if (forest_type=="variance") { + } else if (forest_type == "variance") { if (!model_object$model_params$include_variance_forest) { - stop("Variance forest was not sampled in the bcf model provided") + stop( + "Variance forest was not sampled in the bcf model provided" + ) } - stop("Leaf scale parameter was not sampled for the variance forest in the bcf model provided") + stop( + "Leaf scale parameter was not sampled for the variance forest in the bcf model provided" + ) } } - + # Preprocess forest indices num_forests <- model_object$model_params$num_samples if (is.null(forest_inds)) { forest_inds <- as.integer(1:num_forests) } else { - stopifnot(all(forest_inds <= num_forests-1)) + stopifnot(all(forest_inds <= num_forests - 1)) stopifnot(all(forest_inds >= 0)) forest_inds <- as.integer(forest_inds + 1) } - + # Gather leaf scale parameters leaf_scale_params <- leaf_scale_vector[forest_inds] @@ -222,30 +272,30 @@ computeForestLeafVariances <- function(model_object, forest_type, forest_inds=NU #' Compute and return the largest possible leaf index computable by `computeForestLeafIndices` for the forests in a designated forest sample container. #' #' @param model_object Object of type `bartmodel`, `bcfmodel`, or `ForestSamples` corresponding to a BART / BCF model with at least one forest sample, or a low-level `ForestSamples` object. -#' @param forest_type Which forest to use from `model_object`. -#' Valid inputs depend on the model type, and whether or not a -#' +#' @param forest_type Which forest to use from `model_object`. +#' Valid inputs depend on the model type, and whether or not a +#' #' **1. BART** #' #' - `'mean'`: Extracts leaf indices for the mean forest #' - `'variance'`: Extracts leaf indices for the variance forest -#' +#' #' **2. BCF** #' #' - `'prognostic'`: Extracts leaf indices for the prognostic forest #' - `'treatment'`: Extracts leaf indices for the treatment effect forest #' - `'variance'`: Extracts leaf indices for the variance forest -#' +#' #' **3. ForestSamples** #' #' - `NULL`: It is not necessary to disambiguate when this function is called directly on a `ForestSamples` object. This is the default value of this -#' -#' @param forest_inds (Optional) Indices of the forest sample(s) for which to compute max leaf indices. If not provided, -#' this function will return max leaf indices for every sample of a forest. +#' +#' @param forest_inds (Optional) Indices of the forest sample(s) for which to compute max leaf indices. If not provided, +#' this function will return max leaf indices for every sample of a forest. #' This function uses 0-indexing, so the first forest sample corresponds to `forest_num = 0`, and so on. #' @return Vector containing the largest possible leaf index computable by `computeForestLeafIndices` for the forests in a designated forest sample container. #' @export -#' +#' #' @examples #' X <- matrix(runif(10*100), ncol = 10) #' y <- -5 + 10*(X[,1] > 0.5) + rnorm(100) @@ -253,56 +303,73 @@ computeForestLeafVariances <- function(model_object, forest_type, forest_inds=NU #' computeForestMaxLeafIndex(bart_model, "mean") #' computeForestMaxLeafIndex(bart_model, "mean", 0) #' computeForestMaxLeafIndex(bart_model, "mean", c(1,3,9)) -computeForestMaxLeafIndex <- function(model_object, forest_type=NULL, forest_inds=NULL) { +computeForestMaxLeafIndex <- function( + model_object, + forest_type = NULL, + forest_inds = NULL +) { # Extract relevant forest container - stopifnot(any(c(inherits(model_object, "bartmodel"), inherits(model_object, "bcfmodel"), inherits(model_object, "ForestSamples")))) - model_type <- ifelse(inherits(model_object, "bartmodel"), "bart", ifelse(inherits(model_object, "bcfmodel"), "bcf", "forest_samples")) + stopifnot(any(c( + inherits(model_object, "bartmodel"), + inherits(model_object, "bcfmodel"), + inherits(model_object, "ForestSamples") + ))) + model_type <- ifelse( + inherits(model_object, "bartmodel"), + "bart", + ifelse(inherits(model_object, "bcfmodel"), "bcf", "forest_samples") + ) if (model_type == "bart") { stopifnot(forest_type %in% c("mean", "variance")) - if (forest_type=="mean") { + if (forest_type == "mean") { if (!model_object$model_params$include_mean_forest) { stop("Mean forest was not sampled in the bart model provided") } forest_container <- model_object$mean_forests - } else if (forest_type=="variance") { + } else if (forest_type == "variance") { if (!model_object$model_params$include_variance_forest) { - stop("Variance forest was not sampled in the bart model provided") + stop( + "Variance forest was not sampled in the bart model provided" + ) } forest_container <- model_object$variance_forests } } else if (model_type == "bcf") { stopifnot(forest_type %in% c("prognostic", "treatment", "variance")) - if (forest_type=="prognostic") { + if (forest_type == "prognostic") { forest_container <- model_object$forests_mu - } else if (forest_type=="treatment") { + } else if (forest_type == "treatment") { forest_container <- model_object$forests_tau - } else if (forest_type=="variance") { + } else if (forest_type == "variance") { if (!model_object$model_params$include_variance_forest) { - stop("Variance forest was not sampled in the bcf model provided") + stop( + "Variance forest was not sampled in the bcf model provided" + ) } forest_container <- model_object$variance_forests } } else { forest_container <- model_object } - + # Preprocess forest indices num_forests <- forest_container$num_samples() if (is.null(forest_inds)) { forest_inds <- as.integer(1:num_forests - 1) } else { - stopifnot(all(forest_inds <= num_forests-1)) + stopifnot(all(forest_inds <= num_forests - 1)) stopifnot(all(forest_inds >= 0)) forest_inds <- as.integer(forest_inds) } - + # Compute leaf indices output <- rep(NA, length(forest_inds)) for (i in 1:length(forest_inds)) { output[i] <- forest_container_get_max_leaf_index_cpp( - forest_container$forest_container_ptr,forest_inds[i] + forest_container$forest_container_ptr, + forest_inds[i] ) } - + return(output) } diff --git a/R/model.R b/R/model.R index 57523c0d..38df5970 100644 --- a/R/model.R +++ b/R/model.R @@ -1,15 +1,14 @@ #' Class that wraps a C++ random number generator (for reproducibility) #' #' @description -#' Persists a C++ random number generator throughout an R session to -#' ensure reproducibility from a given random seed. If no seed is provided, +#' Persists a C++ random number generator throughout an R session to +#' ensure reproducibility from a given random seed. If no seed is provided, #' the C++ random number generator is initialized using `std::random_device`. CppRNG <- R6::R6Class( classname = "CppRNG", cloneable = FALSE, public = list( - #' @field rng_ptr External pointer to a C++ std::mt19937 class rng_ptr = NULL, @@ -26,21 +25,20 @@ CppRNG <- R6::R6Class( #' Class that defines and samples a forest model #' #' @description -#' Hosts the C++ data structures needed to sample an ensemble of decision -#' trees, and exposes functionality to run a forest sampler +#' Hosts the C++ data structures needed to sample an ensemble of decision +#' trees, and exposes functionality to run a forest sampler #' (using either MCMC or the grow-from-root algorithm). ForestModel <- R6::R6Class( classname = "ForestModel", cloneable = FALSE, public = list( - #' @field tracker_ptr External pointer to a C++ ForestTracker class tracker_ptr = NULL, - + #' @field tree_prior_ptr External pointer to a C++ TreePrior class - tree_prior_ptr = NULL, - + tree_prior_ptr = NULL, + #' @description #' Create a new ForestModel object. #' @param forest_dataset `ForestDataset` object, used to initialize forest sampling data structures @@ -52,12 +50,31 @@ ForestModel <- R6::R6Class( #' @param min_samples_leaf Minimum number of samples in a tree leaf #' @param max_depth Maximum depth that any tree can reach #' @return A new `ForestModel` object. - initialize = function(forest_dataset, feature_types, num_trees, n, alpha, beta, min_samples_leaf, max_depth = -1) { + initialize = function( + forest_dataset, + feature_types, + num_trees, + n, + alpha, + beta, + min_samples_leaf, + max_depth = -1 + ) { stopifnot(!is.null(forest_dataset$data_ptr)) - self$tracker_ptr <- forest_tracker_cpp(forest_dataset$data_ptr, feature_types, num_trees, n) - self$tree_prior_ptr <- tree_prior_cpp(alpha, beta, min_samples_leaf, max_depth) - }, - + self$tracker_ptr <- forest_tracker_cpp( + forest_dataset$data_ptr, + feature_types, + num_trees, + n + ) + self$tree_prior_ptr <- tree_prior_cpp( + alpha, + beta, + min_samples_leaf, + max_depth + ) + }, + #' @description #' Run a single iteration of the forest sampling algorithm (MCMC or GFR) #' @param forest_dataset Dataset used to sample the forest @@ -70,11 +87,22 @@ ForestModel <- R6::R6Class( #' @param num_threads Number of threads to use in the GFR and MCMC algorithms, as well as prediction. If OpenMP is not available on a user's system, this will default to `1`, otherwise to the maximum number of available threads. #' @param keep_forest (Optional) Whether the updated forest sample should be saved to `forest_samples`. Default: `TRUE`. #' @param gfr (Optional) Whether or not the forest should be sampled using the "grow-from-root" (GFR) algorithm. Default: `TRUE`. - sample_one_iteration = function(forest_dataset, residual, forest_samples, active_forest, - rng, forest_model_config, global_model_config, num_threads = -1, - keep_forest = TRUE, gfr = TRUE) { + sample_one_iteration = function( + forest_dataset, + residual, + forest_samples, + active_forest, + rng, + forest_model_config, + global_model_config, + num_threads = -1, + keep_forest = TRUE, + gfr = TRUE + ) { if (active_forest$is_empty()) { - stop("`active_forest` has not yet been initialized, which is necessary to run the sampler. Please set constant values for `active_forest`'s leaves using either the `set_root_leaves` or `prepare_for_sampler` methods.") + stop( + "`active_forest` has not yet been initialized, which is necessary to run the sampler. Please set constant values for `active_forest`'s leaves using either the `set_root_leaves` or `prepare_for_sampler` methods." + ) } # Unpack parameters from model config object @@ -88,61 +116,113 @@ ForestModel <- R6::R6Class( global_scale <- global_model_config$global_error_variance cutpoint_grid_size <- forest_model_config$cutpoint_grid_size num_features_subsample <- forest_model_config$num_features_subsample - + # Default to empty integer vector if sweep_update_indices is NULL if (is.null(sweep_update_indices)) { # sweep_update_indices <- integer(0) sweep_update_indices <- 0:(forest_model_config$num_trees - 1) } - + # Detect changes to tree prior - if (forest_model_config$alpha != get_alpha_tree_prior_cpp(self$tree_prior_ptr)) { - update_alpha_tree_prior_cpp(self$tree_prior_ptr, forest_model_config$alpha) + if ( + forest_model_config$alpha != + get_alpha_tree_prior_cpp(self$tree_prior_ptr) + ) { + update_alpha_tree_prior_cpp( + self$tree_prior_ptr, + forest_model_config$alpha + ) } - if (forest_model_config$beta != get_beta_tree_prior_cpp(self$tree_prior_ptr)) { - update_beta_tree_prior_cpp(self$tree_prior_ptr, forest_model_config$beta) + if ( + forest_model_config$beta != + get_beta_tree_prior_cpp(self$tree_prior_ptr) + ) { + update_beta_tree_prior_cpp( + self$tree_prior_ptr, + forest_model_config$beta + ) } - if (forest_model_config$min_samples_leaf != get_min_samples_leaf_tree_prior_cpp(self$tree_prior_ptr)) { - update_min_samples_leaf_tree_prior_cpp(self$tree_prior_ptr, forest_model_config$min_samples_leaf) + if ( + forest_model_config$min_samples_leaf != + get_min_samples_leaf_tree_prior_cpp(self$tree_prior_ptr) + ) { + update_min_samples_leaf_tree_prior_cpp( + self$tree_prior_ptr, + forest_model_config$min_samples_leaf + ) } - if (forest_model_config$max_depth != get_max_depth_tree_prior_cpp(self$tree_prior_ptr)) { - update_max_depth_tree_prior_cpp(self$tree_prior_ptr, forest_model_config$max_depth) + if ( + forest_model_config$max_depth != + get_max_depth_tree_prior_cpp(self$tree_prior_ptr) + ) { + update_max_depth_tree_prior_cpp( + self$tree_prior_ptr, + forest_model_config$max_depth + ) } - + # Run the sampler if (gfr) { sample_gfr_one_iteration_cpp( - forest_dataset$data_ptr, residual$data_ptr, - forest_samples$forest_container_ptr, active_forest$forest_ptr, self$tracker_ptr, - self$tree_prior_ptr, rng$rng_ptr, sweep_update_indices, feature_types, cutpoint_grid_size, leaf_model_scale, - variable_weights, a_forest, b_forest, global_scale, leaf_model_int, keep_forest, num_features_subsample, + forest_dataset$data_ptr, + residual$data_ptr, + forest_samples$forest_container_ptr, + active_forest$forest_ptr, + self$tracker_ptr, + self$tree_prior_ptr, + rng$rng_ptr, + sweep_update_indices, + feature_types, + cutpoint_grid_size, + leaf_model_scale, + variable_weights, + a_forest, + b_forest, + global_scale, + leaf_model_int, + keep_forest, + num_features_subsample, num_threads ) } else { sample_mcmc_one_iteration_cpp( - forest_dataset$data_ptr, residual$data_ptr, - forest_samples$forest_container_ptr, active_forest$forest_ptr, self$tracker_ptr, - self$tree_prior_ptr, rng$rng_ptr, sweep_update_indices, feature_types, cutpoint_grid_size, leaf_model_scale, - variable_weights, a_forest, b_forest, global_scale, leaf_model_int, keep_forest, num_threads - ) + forest_dataset$data_ptr, + residual$data_ptr, + forest_samples$forest_container_ptr, + active_forest$forest_ptr, + self$tracker_ptr, + self$tree_prior_ptr, + rng$rng_ptr, + sweep_update_indices, + feature_types, + cutpoint_grid_size, + leaf_model_scale, + variable_weights, + a_forest, + b_forest, + global_scale, + leaf_model_int, + keep_forest, + num_threads + ) } - }, - + }, + #' @description #' Extract an internally-cached prediction of a forest on the training dataset in a sampler. #' @return Vector with as many elements as observations in the training dataset get_cached_forest_predictions = function() { get_cached_forest_predictions_cpp(self$tracker_ptr) - }, - + }, + #' @description - #' Propagates basis update through to the (full/partial) residual by iteratively - #' (a) adding back in the previous prediction of each tree, (b) recomputing predictions + #' Propagates basis update through to the (full/partial) residual by iteratively + #' (a) adding back in the previous prediction of each tree, (b) recomputing predictions #' for each tree (caching on the C++ side), (c) subtracting the new predictions from the residual. - #' - #' This is useful in cases where a basis (for e.g. leaf regression) is updated outside - #' of a tree sampler (as with e.g. adaptive coding for binary treatment BCF). - #' Once a basis has been updated, the overall "function" represented by a tree model has + #' + #' This is useful in cases where a basis (for e.g. leaf regression) is updated outside + #' of a tree sampler (as with e.g. adaptive coding for binary treatment BCF). + #' Once a basis has been updated, the overall "function" represented by a tree model has #' changed and this should be reflected through to the residual before the next sampling loop is run. #' @param dataset `ForestDataset` object storing the covariates and bases for a given forest #' @param outcome `Outcome` object storing the residuals to be updated based on forest predictions @@ -152,75 +232,83 @@ ForestModel <- R6::R6Class( stopifnot(!is.null(outcome$data_ptr)) stopifnot(!is.null(self$tracker_ptr)) stopifnot(!is.null(active_forest$forest_ptr)) - + propagate_basis_update_active_forest_cpp( - dataset$data_ptr, outcome$data_ptr, active_forest$forest_ptr, + dataset$data_ptr, + outcome$data_ptr, + active_forest$forest_ptr, self$tracker_ptr ) - }, - + }, + #' @description - #' Update the current state of the outcome (i.e. partial residual) data by subtracting the current predictions of each tree. + #' Update the current state of the outcome (i.e. partial residual) data by subtracting the current predictions of each tree. #' This function is run after the `Outcome` class's `update_data` method, which overwrites the partial residual with an entirely new stream of outcome data. #' @param residual Outcome used to sample the forest #' @return None propagate_residual_update = function(residual) { - propagate_trees_column_vector_cpp(self$tracker_ptr, residual$data_ptr) - }, - + propagate_trees_column_vector_cpp( + self$tracker_ptr, + residual$data_ptr + ) + }, + #' @description #' Update alpha in the tree prior #' @param alpha New value of alpha to be used #' @return None update_alpha = function(alpha) { update_alpha_tree_prior_cpp(self$tree_prior_ptr, alpha) - }, - + }, + #' @description #' Update beta in the tree prior #' @param beta New value of beta to be used #' @return None update_beta = function(beta) { update_beta_tree_prior_cpp(self$tree_prior_ptr, beta) - }, - + }, + #' @description #' Update min_samples_leaf in the tree prior #' @param min_samples_leaf New value of min_samples_leaf to be used #' @return None update_min_samples_leaf = function(min_samples_leaf) { - update_min_samples_leaf_tree_prior_cpp(self$tree_prior_ptr, min_samples_leaf) - }, - + update_min_samples_leaf_tree_prior_cpp( + self$tree_prior_ptr, + min_samples_leaf + ) + }, + #' @description #' Update max_depth in the tree prior #' @param max_depth New value of max_depth to be used #' @return None update_max_depth = function(max_depth) { update_max_depth_tree_prior_cpp(self$tree_prior_ptr, max_depth) - }, - + }, + #' @description #' Update alpha in the tree prior #' @return Value of alpha in the tree prior get_alpha = function() { get_alpha_tree_prior_cpp(self$tree_prior_ptr) - }, - + }, + #' @description #' Update beta in the tree prior #' @return Value of beta in the tree prior get_beta = function() { get_beta_tree_prior_cpp(self$tree_prior_ptr) - }, - + }, + #' @description #' Query min_samples_leaf in the tree prior #' @return Value of min_samples_leaf in the tree prior get_min_samples_leaf = function() { get_min_samples_leaf_tree_prior_cpp(self$tree_prior_ptr) - }, - + }, + #' @description #' Query max_depth in the tree prior #' @return Value of max_depth in the tree prior @@ -236,14 +324,12 @@ ForestModel <- R6::R6Class( #' #' @return `CppRng` object #' @export -#' +#' #' @examples #' rng <- createCppRNG(1234) #' rng <- createCppRNG() -createCppRNG <- function(random_seed = -1){ - return(invisible(( - CppRNG$new(random_seed) - ))) +createCppRNG <- function(random_seed = -1) { + return(invisible((CppRNG$new(random_seed)))) } #' Create a forest model object @@ -254,7 +340,7 @@ createCppRNG <- function(random_seed = -1){ #' #' @return `ForestModel` object #' @export -#' +#' #' @examples #' num_trees <- 100 #' n <- 100 @@ -266,19 +352,30 @@ createCppRNG <- function(random_seed = -1){ #' feature_types <- as.integer(rep(0, p)) #' X <- matrix(runif(n*p), ncol = p) #' forest_dataset <- createForestDataset(X) -#' forest_model_config <- createForestModelConfig(feature_types=feature_types, -#' num_trees=num_trees, num_features=p, -#' num_observations=n, alpha=alpha, beta=beta, -#' min_samples_leaf=min_samples_leaf, +#' forest_model_config <- createForestModelConfig(feature_types=feature_types, +#' num_trees=num_trees, num_features=p, +#' num_observations=n, alpha=alpha, beta=beta, +#' min_samples_leaf=min_samples_leaf, #' max_depth=max_depth, leaf_model_type=1) #' global_model_config <- createGlobalModelConfig(global_error_variance=1.0) #' forest_model <- createForestModel(forest_dataset, forest_model_config, global_model_config) -createForestModel <- function(forest_dataset, forest_model_config, global_model_config) { - return(invisible(( - ForestModel$new(forest_dataset, forest_model_config$feature_types, forest_model_config$num_trees, - forest_model_config$num_observations, forest_model_config$alpha, forest_model_config$beta, - forest_model_config$min_samples_leaf, forest_model_config$max_depth) - ))) +createForestModel <- function( + forest_dataset, + forest_model_config, + global_model_config +) { + return(invisible( + (ForestModel$new( + forest_dataset, + forest_model_config$feature_types, + forest_model_config$num_trees, + forest_model_config$num_observations, + forest_model_config$alpha, + forest_model_config$beta, + forest_model_config$min_samples_leaf, + forest_model_config$max_depth + )) + )) } @@ -296,6 +393,14 @@ createForestModel <- function(forest_dataset, forest_model_config, global_model_ #' p <- c(0.7,0.2,0.05,0.02,0.01,0.01,0.01) #' num_samples <- 5 #' sample_without_replacement(a, p, num_samples) -sample_without_replacement <- function(population_vector, sampling_probabilities, sample_size) { - return(sample_without_replacement_integer_cpp(population_vector, sampling_probabilities, sample_size)) +sample_without_replacement <- function( + population_vector, + sampling_probabilities, + sample_size +) { + return(sample_without_replacement_integer_cpp( + population_vector, + sampling_probabilities, + sample_size + )) } diff --git a/R/random_effects.R b/R/random_effects.R index d737ef8e..b91c2678 100644 --- a/R/random_effects.R +++ b/R/random_effects.R @@ -1,45 +1,55 @@ #' Class that wraps the "persistent" aspects of a C++ random effects model -#' (draws of the parameters and a map from the original label indices to the -#' 0-indexed label numbers used to place group samples in memory (i.e. the -#' first label is stored in column 0 of the sample matrix, the second label +#' (draws of the parameters and a map from the original label indices to the +#' 0-indexed label numbers used to place group samples in memory (i.e. the +#' first label is stored in column 0 of the sample matrix, the second label #' is store in column 1 of the sample matrix, etc...)) #' #' @description -#' Coordinates various C++ random effects classes and persists those +#' Coordinates various C++ random effects classes and persists those #' needed for prediction / serialization RandomEffectSamples <- R6::R6Class( classname = "RandomEffectSamples", cloneable = FALSE, public = list( - #' @field rfx_container_ptr External pointer to a C++ StochTree::RandomEffectsContainer class rfx_container_ptr = NULL, - + #' @field label_mapper_ptr External pointer to a C++ StochTree::LabelMapper class label_mapper_ptr = NULL, - + #' @field training_group_ids Unique vector of group IDs that were in the training dataset training_group_ids = NULL, - + #' @description #' Create a new RandomEffectSamples object. #' @return A new `RandomEffectSamples` object. - initialize = function() {}, - + initialize = function() {}, + #' @description #' Construct RandomEffectSamples object from other "in-session" R objects #' @param num_components Number of "components" or bases defining the random effects regression #' @param num_groups Number of random effects groups #' @param random_effects_tracker Object of type `RandomEffectsTracker` #' @return None - load_in_session = function(num_components, num_groups, random_effects_tracker) { + load_in_session = function( + num_components, + num_groups, + random_effects_tracker + ) { # Initialize - self$rfx_container_ptr <- rfx_container_cpp(num_components, num_groups) - self$label_mapper_ptr <- rfx_label_mapper_cpp(random_effects_tracker$rfx_tracker_ptr) - self$training_group_ids <- rfx_tracker_get_unique_group_ids_cpp(random_effects_tracker$rfx_tracker_ptr) - }, - + self$rfx_container_ptr <- rfx_container_cpp( + num_components, + num_groups + ) + self$label_mapper_ptr <- rfx_label_mapper_cpp( + random_effects_tracker$rfx_tracker_ptr + ) + self$training_group_ids <- rfx_tracker_get_unique_group_ids_cpp( + random_effects_tracker$rfx_tracker_ptr + ) + }, + #' @description #' Construct RandomEffectSamples object from a json object #' @param json_object Object of class `CppJson` @@ -47,12 +57,26 @@ RandomEffectSamples <- R6::R6Class( #' @param json_rfx_mapper_label Label referring to a particular rfx label mapper (i.e. "random_effect_label_mapper_0") in the overall json hierarchy #' @param json_rfx_groupids_label Label referring to a particular set of rfx group IDs (i.e. "random_effect_groupids_0") in the overall json hierarchy #' @return A new `RandomEffectSamples` object. - load_from_json = function(json_object, json_rfx_container_label, json_rfx_mapper_label, json_rfx_groupids_label) { - self$rfx_container_ptr <- rfx_container_from_json_cpp(json_object$json_ptr, json_rfx_container_label) - self$label_mapper_ptr <- rfx_label_mapper_from_json_cpp(json_object$json_ptr, json_rfx_mapper_label) - self$training_group_ids <- rfx_group_ids_from_json_cpp(json_object$json_ptr, json_rfx_groupids_label) - }, - + load_from_json = function( + json_object, + json_rfx_container_label, + json_rfx_mapper_label, + json_rfx_groupids_label + ) { + self$rfx_container_ptr <- rfx_container_from_json_cpp( + json_object$json_ptr, + json_rfx_container_label + ) + self$label_mapper_ptr <- rfx_label_mapper_from_json_cpp( + json_object$json_ptr, + json_rfx_mapper_label + ) + self$training_group_ids <- rfx_group_ids_from_json_cpp( + json_object$json_ptr, + json_rfx_groupids_label + ) + }, + #' @description #' Append random effect draws to `RandomEffectSamples` object from a json object #' @param json_object Object of class `CppJson` @@ -60,10 +84,19 @@ RandomEffectSamples <- R6::R6Class( #' @param json_rfx_mapper_label Label referring to a particular rfx label mapper (i.e. "random_effect_label_mapper_0") in the overall json hierarchy #' @param json_rfx_groupids_label Label referring to a particular set of rfx group IDs (i.e. "random_effect_groupids_0") in the overall json hierarchy #' @return None - append_from_json = function(json_object, json_rfx_container_label, json_rfx_mapper_label, json_rfx_groupids_label) { - rfx_container_append_from_json_cpp(self$rfx_container_ptr, json_object$json_ptr, json_rfx_container_label) - }, - + append_from_json = function( + json_object, + json_rfx_container_label, + json_rfx_mapper_label, + json_rfx_groupids_label + ) { + rfx_container_append_from_json_cpp( + self$rfx_container_ptr, + json_object$json_ptr, + json_rfx_container_label + ) + }, + #' @description #' Construct RandomEffectSamples object from a json object #' @param json_string JSON string which parses into object of class `CppJson` @@ -71,12 +104,26 @@ RandomEffectSamples <- R6::R6Class( #' @param json_rfx_mapper_label Label referring to a particular rfx label mapper (i.e. "random_effect_label_mapper_0") in the overall json hierarchy #' @param json_rfx_groupids_label Label referring to a particular set of rfx group IDs (i.e. "random_effect_groupids_0") in the overall json hierarchy #' @return A new `RandomEffectSamples` object. - load_from_json_string = function(json_string, json_rfx_container_label, json_rfx_mapper_label, json_rfx_groupids_label) { - self$rfx_container_ptr <- rfx_container_from_json_string_cpp(json_string, json_rfx_container_label) - self$label_mapper_ptr <- rfx_label_mapper_from_json_string_cpp(json_string, json_rfx_mapper_label) - self$training_group_ids <- rfx_group_ids_from_json_string_cpp(json_string, json_rfx_groupids_label) - }, - + load_from_json_string = function( + json_string, + json_rfx_container_label, + json_rfx_mapper_label, + json_rfx_groupids_label + ) { + self$rfx_container_ptr <- rfx_container_from_json_string_cpp( + json_string, + json_rfx_container_label + ) + self$label_mapper_ptr <- rfx_label_mapper_from_json_string_cpp( + json_string, + json_rfx_mapper_label + ) + self$training_group_ids <- rfx_group_ids_from_json_string_cpp( + json_string, + json_rfx_groupids_label + ) + }, + #' @description #' Append random effect draws to `RandomEffectSamples` object from a json object #' @param json_string JSON string which parses into object of class `CppJson` @@ -84,45 +131,67 @@ RandomEffectSamples <- R6::R6Class( #' @param json_rfx_mapper_label Label referring to a particular rfx label mapper (i.e. "random_effect_label_mapper_0") in the overall json hierarchy #' @param json_rfx_groupids_label Label referring to a particular set of rfx group IDs (i.e. "random_effect_groupids_0") in the overall json hierarchy #' @return None - append_from_json_string = function(json_string, json_rfx_container_label, json_rfx_mapper_label, json_rfx_groupids_label) { + append_from_json_string = function( + json_string, + json_rfx_container_label, + json_rfx_mapper_label, + json_rfx_groupids_label + ) { # Append RFX objects - rfx_container_append_from_json_string_cpp(self$rfx_container_ptr, json_string, json_rfx_container_label) - }, - + rfx_container_append_from_json_string_cpp( + self$rfx_container_ptr, + json_string, + json_rfx_container_label + ) + }, + #' @description - #' Predict random effects for each observation implied by `rfx_group_ids` and `rfx_basis`. + #' Predict random effects for each observation implied by `rfx_group_ids` and `rfx_basis`. #' If a random effects model is "intercept-only" the `rfx_basis` will be a vector of ones of size `length(rfx_group_ids)`. #' @param rfx_group_ids Indices of random effects groups in a prediction set #' @param rfx_basis (Optional) Basis used for random effects prediction #' @return Matrix with as many rows as observations provided and as many columns as samples drawn of the model. predict = function(rfx_group_ids, rfx_basis = NULL) { num_obs = length(rfx_group_ids) - if (is.null(rfx_basis)) rfx_basis <- matrix(rep(1,num_obs), ncol = 1) + if (is.null(rfx_basis)) { + rfx_basis <- matrix(rep(1, num_obs), ncol = 1) + } num_samples = rfx_container_num_samples_cpp(self$rfx_container_ptr) - num_components = rfx_container_num_components_cpp(self$rfx_container_ptr) + num_components = rfx_container_num_components_cpp( + self$rfx_container_ptr + ) num_groups = rfx_container_num_groups_cpp(self$rfx_container_ptr) rfx_group_ids_int <- as.integer(rfx_group_ids) - stopifnot(sum(abs(rfx_group_ids_int-rfx_group_ids)) < 1e-6) + stopifnot(sum(abs(rfx_group_ids_int - rfx_group_ids)) < 1e-6) stopifnot(sum(!(rfx_group_ids %in% self$training_group_ids)) == 0) stopifnot(ncol(rfx_basis) == num_components) - rfx_dataset <- createRandomEffectsDataset(rfx_group_ids_int, rfx_basis) - output <- rfx_container_predict_cpp(self$rfx_container_ptr, rfx_dataset$data_ptr, self$label_mapper_ptr) + rfx_dataset <- createRandomEffectsDataset( + rfx_group_ids_int, + rfx_basis + ) + output <- rfx_container_predict_cpp( + self$rfx_container_ptr, + rfx_dataset$data_ptr, + self$label_mapper_ptr + ) dim(output) <- c(num_obs, num_samples) return(output) - }, - + }, + #' @description - #' Extract the random effects parameters sampled. With the "redundant parameterization" - #' of Gelman et al (2008), this includes four parameters: alpha (the "working parameter" - #' shared across every group), xi (the "group parameter" sampled separately for each group), - #' beta (the product of alpha and xi, which corresponds to the overall group-level random effects), + #' Extract the random effects parameters sampled. With the "redundant parameterization" + #' of Gelman et al (2008), this includes four parameters: alpha (the "working parameter" + #' shared across every group), xi (the "group parameter" sampled separately for each group), + #' beta (the product of alpha and xi, which corresponds to the overall group-level random effects), #' and sigma (group-independent prior variance for each component of xi). #' @return List of arrays. The alpha array has dimension (`num_components`, `num_samples`) and is simply a vector if `num_components = 1`. #' The xi and beta arrays have dimension (`num_components`, `num_groups`, `num_samples`) and is simply a matrix if `num_components = 1`. #' The sigma array has dimension (`num_components`, `num_samples`) and is simply a vector if `num_components = 1`. extract_parameter_samples = function() { num_samples = rfx_container_num_samples_cpp(self$rfx_container_ptr) - num_components = rfx_container_num_components_cpp(self$rfx_container_ptr) + num_components = rfx_container_num_components_cpp( + self$rfx_container_ptr + ) num_groups = rfx_container_num_groups_cpp(self$rfx_container_ptr) beta_samples <- rfx_container_get_beta_cpp(self$rfx_container_ptr) xi_samples <- rfx_container_get_xi_cpp(self$rfx_container_ptr) @@ -136,24 +205,28 @@ RandomEffectSamples <- R6::R6Class( dim(xi_samples) <- c(num_components, num_groups, num_samples) dim(alpha_samples) <- c(num_components, num_samples) dim(sigma_samples) <- c(num_components, num_samples) - } else stop("Invalid random effects sample container, num_components is less than 1") - + } else { + stop( + "Invalid random effects sample container, num_components is less than 1" + ) + } + output = list( - "beta_samples" = beta_samples, - "xi_samples" = xi_samples, - "alpha_samples" = alpha_samples, + "beta_samples" = beta_samples, + "xi_samples" = xi_samples, + "alpha_samples" = alpha_samples, "sigma_samples" = sigma_samples ) return(output) - }, - + }, + #' @description #' Modify the `RandomEffectsSamples` object by removing the parameter samples index by `sample_num`. #' @param sample_num Index of the RFX sample to be removed delete_sample = function(sample_num) { rfx_container_delete_sample_cpp(self$rfx_container_ptr, sample_num) - }, - + }, + #' @description #' Convert the mapping of group IDs to random effect components indices from C++ to R native format #' @return List mapping group ID to random effect components. @@ -166,23 +239,22 @@ RandomEffectSamples <- R6::R6Class( ) ) -#' Class that defines a "tracker" for random effects models, most notably -#' storing the data indices available in each group for quicker posterior +#' Class that defines a "tracker" for random effects models, most notably +#' storing the data indices available in each group for quicker posterior #' computation and sampling of random effects terms. #' #' @description -#' Stores a mapping from every observation to its group index, a mapping -#' from group indices to the training sample observations available in that +#' Stores a mapping from every observation to its group index, a mapping +#' from group indices to the training sample observations available in that #' group, and predictions for each observation. RandomEffectsTracker <- R6::R6Class( classname = "RandomEffectsTracker", cloneable = FALSE, public = list( - #' @field rfx_tracker_ptr External pointer to a C++ StochTree::RandomEffectsTracker class rfx_tracker_ptr = NULL, - + #' @description #' Create a new RandomEffectsTracker object. #' @param rfx_group_indices Integer indices indicating groups used to define random effects @@ -197,23 +269,22 @@ RandomEffectsTracker <- R6::R6Class( #' The core "model" class for sampling random effects. #' #' @description -#' Stores current model state, prior parameters, and procedures for +#' Stores current model state, prior parameters, and procedures for #' sampling from the conditional posterior of each parameter. RandomEffectsModel <- R6::R6Class( classname = "RandomEffectsModel", cloneable = FALSE, public = list( - #' @field rfx_model_ptr External pointer to a C++ StochTree::RandomEffectsModel class rfx_model_ptr = NULL, - + #' @field num_groups Number of groups in the random effects model num_groups = NULL, - + #' @field num_components Number of components (i.e. dimension of basis) in the random effects model num_components = NULL, - + #' @description #' Create a new RandomEffectsModel object. #' @param num_components Number of "components" or bases defining the random effects regression @@ -225,7 +296,7 @@ RandomEffectsModel <- R6::R6Class( self$num_components <- num_components self$num_groups <- num_groups }, - + #' @description #' Sample from random effects model. #' @param rfx_dataset Object of type `RandomEffectsDataset` @@ -236,25 +307,44 @@ RandomEffectsModel <- R6::R6Class( #' @param global_variance Scalar global variance parameter #' @param rng Object of type `CppRNG` #' @return None - sample_random_effect = function(rfx_dataset, residual, rfx_tracker, rfx_samples, keep_sample, global_variance, rng) { - rfx_model_sample_random_effects_cpp(self$rfx_model_ptr, rfx_dataset$data_ptr, - residual$data_ptr, rfx_tracker$rfx_tracker_ptr, - rfx_samples$rfx_container_ptr, keep_sample, global_variance, rng$rng_ptr) + sample_random_effect = function( + rfx_dataset, + residual, + rfx_tracker, + rfx_samples, + keep_sample, + global_variance, + rng + ) { + rfx_model_sample_random_effects_cpp( + self$rfx_model_ptr, + rfx_dataset$data_ptr, + residual$data_ptr, + rfx_tracker$rfx_tracker_ptr, + rfx_samples$rfx_container_ptr, + keep_sample, + global_variance, + rng$rng_ptr + ) }, - + #' @description #' Predict from (a single sample of a) random effects model. #' @param rfx_dataset Object of type `RandomEffectsDataset` #' @param rfx_tracker Object of type `RandomEffectsTracker` #' @return Vector of predictions with size matching number of observations in rfx_dataset predict = function(rfx_dataset, rfx_tracker) { - pred <- rfx_model_predict_cpp(self$rfx_model_ptr, rfx_dataset$data_ptr, rfx_tracker$rfx_tracker_ptr) + pred <- rfx_model_predict_cpp( + self$rfx_model_ptr, + rfx_dataset$data_ptr, + rfx_tracker$rfx_tracker_ptr + ) return(pred) }, - + #' @description - #' Set value for the "working parameter." This is typically - #' used for initialization, but could also be used to interrupt + #' Set value for the "working parameter." This is typically + #' used for initialization, but could also be used to interrupt #' or override the sampler. #' @param value Parameter input #' @return None @@ -264,10 +354,10 @@ RandomEffectsModel <- R6::R6Class( stopifnot(length(value) == self$num_components) rfx_model_set_working_parameter_cpp(self$rfx_model_ptr, value) }, - + #' @description - #' Set value for the "group parameters." This is typically - #' used for initialization, but could also be used to interrupt + #' Set value for the "group parameters." This is typically + #' used for initialization, but could also be used to interrupt #' or override the sampler. #' @param value Parameter input #' @return None @@ -278,10 +368,10 @@ RandomEffectsModel <- R6::R6Class( stopifnot(ncol(value) == self$num_groups) rfx_model_set_group_parameters_cpp(self$rfx_model_ptr, value) }, - + #' @description - #' Set value for the working parameter covariance. This is typically - #' used for initialization, but could also be used to interrupt + #' Set value for the working parameter covariance. This is typically + #' used for initialization, but could also be used to interrupt #' or override the sampler. #' @param value Parameter input #' @return None @@ -290,12 +380,15 @@ RandomEffectsModel <- R6::R6Class( stopifnot(is.matrix(value)) stopifnot(nrow(value) == self$num_components) stopifnot(ncol(value) == self$num_components) - rfx_model_set_working_parameter_covariance_cpp(self$rfx_model_ptr, value) + rfx_model_set_working_parameter_covariance_cpp( + self$rfx_model_ptr, + value + ) }, - + #' @description - #' Set value for the group parameter covariance. This is typically - #' used for initialization, but could also be used to interrupt + #' Set value for the group parameter covariance. This is typically + #' used for initialization, but could also be used to interrupt #' or override the sampler. #' @param value Parameter input #' @return None @@ -304,9 +397,12 @@ RandomEffectsModel <- R6::R6Class( stopifnot(is.matrix(value)) stopifnot(nrow(value) == self$num_components) stopifnot(ncol(value) == self$num_components) - rfx_model_set_group_parameter_covariance_cpp(self$rfx_model_ptr, value) - }, - + rfx_model_set_group_parameter_covariance_cpp( + self$rfx_model_ptr, + value + ) + }, + #' @description #' Set shape parameter for the group parameter variance prior. #' @param value Parameter input @@ -317,7 +413,7 @@ RandomEffectsModel <- R6::R6Class( stopifnot(length(value) == 1) rfx_model_set_variance_prior_shape_cpp(self$rfx_model_ptr, value) }, - + #' @description #' Set shape parameter for the group parameter variance prior. #' @param value Parameter input @@ -338,7 +434,7 @@ RandomEffectsModel <- R6::R6Class( #' @param random_effects_tracker Object of type `RandomEffectsTracker` #' @return `RandomEffectSamples` object #' @export -#' +#' #' @examples #' n <- 100 #' rfx_group_ids <- sample(1:2, size = n, replace = TRUE) @@ -347,7 +443,11 @@ RandomEffectsModel <- R6::R6Class( #' num_components <- ncol(rfx_basis) #' rfx_tracker <- createRandomEffectsTracker(rfx_group_ids) #' rfx_samples <- createRandomEffectSamples(num_components, num_groups, rfx_tracker) -createRandomEffectSamples <- function(num_components, num_groups, random_effects_tracker) { +createRandomEffectSamples <- function( + num_components, + num_groups, + random_effects_tracker +) { invisible(output <- RandomEffectSamples$new()) output$load_in_session(num_components, num_groups, random_effects_tracker) return(output) @@ -358,7 +458,7 @@ createRandomEffectSamples <- function(num_components, num_groups, random_effects #' @param rfx_group_indices Integer indices indicating groups used to define random effects #' @return `RandomEffectsTracker` object #' @export -#' +#' #' @examples #' n <- 100 #' rfx_group_ids <- sample(1:2, size = n, replace = TRUE) @@ -367,9 +467,7 @@ createRandomEffectSamples <- function(num_components, num_groups, random_effects #' num_components <- ncol(rfx_basis) #' rfx_tracker <- createRandomEffectsTracker(rfx_group_ids) createRandomEffectsTracker <- function(rfx_group_indices) { - return(invisible(( - RandomEffectsTracker$new(rfx_group_indices) - ))) + return(invisible((RandomEffectsTracker$new(rfx_group_indices)))) } #' Create a `RandomEffectsModel` object @@ -378,7 +476,7 @@ createRandomEffectsTracker <- function(rfx_group_indices) { #' @param num_groups Number of random effects groups #' @return `RandomEffectsModel` object #' @export -#' +#' #' @examples #' n <- 100 #' rfx_group_ids <- sample(1:2, size = n, replace = TRUE) @@ -387,9 +485,7 @@ createRandomEffectsTracker <- function(rfx_group_indices) { #' num_components <- ncol(rfx_basis) #' rfx_model <- createRandomEffectsModel(num_components, num_groups) createRandomEffectsModel <- function(num_components, num_groups) { - return(invisible(( - RandomEffectsModel$new(num_components, num_groups) - ))) + return(invisible((RandomEffectsModel$new(num_components, num_groups)))) } #' Reset a `RandomEffectsModel` object based on the parameters indexed by `sample_num` in a `RandomEffectsSamples` object @@ -400,7 +496,7 @@ createRandomEffectsModel <- function(num_components, num_groups) { #' @param sigma_alpha_init Initial value of the "working parameter" scale parameter. #' @return None #' @export -#' +#' #' @examples #' n <- 100 #' p <- 10 @@ -429,19 +525,28 @@ createRandomEffectsModel <- function(num_components, num_groups) { #' rfx_model$set_variance_prior_shape(sigma_xi_shape) #' rfx_model$set_variance_prior_scale(sigma_xi_scale) #' for (i in 1:3) { -#' rfx_model$sample_random_effect(rfx_dataset=rfx_dataset, residual=outcome, -#' rfx_tracker=rfx_tracker, rfx_samples=rfx_samples, +#' rfx_model$sample_random_effect(rfx_dataset=rfx_dataset, residual=outcome, +#' rfx_tracker=rfx_tracker, rfx_samples=rfx_samples, #' keep_sample=TRUE, global_variance=1.0, rng=rng) #' } #' resetRandomEffectsModel(rfx_model, rfx_samples, 0, 1.0) -resetRandomEffectsModel <- function(rfx_model, rfx_samples, sample_num, sigma_alpha_init) { +resetRandomEffectsModel <- function( + rfx_model, + rfx_samples, + sample_num, + sigma_alpha_init +) { if (!is.matrix(sigma_alpha_init)) { if (!is.double(sigma_alpha_init)) { stop("`sigma_alpha_init` must be a numeric scalar or matrix") } sigma_alpha_init <- as.matrix(sigma_alpha_init) } - reset_rfx_model_cpp(rfx_model$rfx_model_ptr, rfx_samples$rfx_container_ptr, sample_num) + reset_rfx_model_cpp( + rfx_model$rfx_model_ptr, + rfx_samples$rfx_container_ptr, + sample_num + ) rfx_model$set_working_parameter_cov(sigma_alpha_init) } @@ -454,7 +559,7 @@ resetRandomEffectsModel <- function(rfx_model, rfx_samples, sample_num, sigma_al #' @param rfx_samples Object of type `RandomEffectSamples`. #' @return None #' @export -#' +#' #' @examples #' n <- 100 #' p <- 10 @@ -483,14 +588,25 @@ resetRandomEffectsModel <- function(rfx_model, rfx_samples, sample_num, sigma_al #' rfx_model$set_variance_prior_shape(sigma_xi_shape) #' rfx_model$set_variance_prior_scale(sigma_xi_scale) #' for (i in 1:3) { -#' rfx_model$sample_random_effect(rfx_dataset=rfx_dataset, residual=outcome, -#' rfx_tracker=rfx_tracker, rfx_samples=rfx_samples, +#' rfx_model$sample_random_effect(rfx_dataset=rfx_dataset, residual=outcome, +#' rfx_tracker=rfx_tracker, rfx_samples=rfx_samples, #' keep_sample=TRUE, global_variance=1.0, rng=rng) #' } #' resetRandomEffectsModel(rfx_model, rfx_samples, 0, 1.0) #' resetRandomEffectsTracker(rfx_tracker, rfx_model, rfx_dataset, outcome, rfx_samples) -resetRandomEffectsTracker <- function(rfx_tracker, rfx_model, rfx_dataset, residual, rfx_samples) { - reset_rfx_tracker_cpp(rfx_tracker$rfx_tracker_ptr, rfx_dataset$data_ptr, residual$data_ptr, rfx_model$rfx_model_ptr) +resetRandomEffectsTracker <- function( + rfx_tracker, + rfx_model, + rfx_dataset, + residual, + rfx_samples +) { + reset_rfx_tracker_cpp( + rfx_tracker$rfx_tracker_ptr, + rfx_dataset$data_ptr, + residual$data_ptr, + rfx_model$rfx_model_ptr + ) } #' Reset a `RandomEffectsModel` object to its "default" state @@ -504,7 +620,7 @@ resetRandomEffectsTracker <- function(rfx_tracker, rfx_model, rfx_dataset, resid #' @param sigma_xi_scale Scale parameter for the inverse gamma variance model on the group parameters. #' @return None #' @export -#' +#' #' @examples #' n <- 100 #' p <- 10 @@ -533,14 +649,21 @@ resetRandomEffectsTracker <- function(rfx_tracker, rfx_model, rfx_dataset, resid #' rfx_model$set_variance_prior_shape(sigma_xi_shape) #' rfx_model$set_variance_prior_scale(sigma_xi_scale) #' for (i in 1:3) { -#' rfx_model$sample_random_effect(rfx_dataset=rfx_dataset, residual=outcome, -#' rfx_tracker=rfx_tracker, rfx_samples=rfx_samples, +#' rfx_model$sample_random_effect(rfx_dataset=rfx_dataset, residual=outcome, +#' rfx_tracker=rfx_tracker, rfx_samples=rfx_samples, #' keep_sample=TRUE, global_variance=1.0, rng=rng) #' } #' rootResetRandomEffectsModel(rfx_model, alpha_init, xi_init, sigma_alpha_init, #' sigma_xi_init, sigma_xi_shape, sigma_xi_scale) -rootResetRandomEffectsModel <- function(rfx_model, alpha_init, xi_init, sigma_alpha_init, - sigma_xi_init, sigma_xi_shape, sigma_xi_scale) { +rootResetRandomEffectsModel <- function( + rfx_model, + alpha_init, + xi_init, + sigma_alpha_init, + sigma_xi_init, + sigma_xi_shape, + sigma_xi_scale +) { rfx_model$set_working_parameter(alpha_init) rfx_model$set_group_parameters(xi_init) rfx_model$set_working_parameter_cov(sigma_alpha_init) @@ -557,7 +680,7 @@ rootResetRandomEffectsModel <- function(rfx_model, alpha_init, xi_init, sigma_al #' @param residual Object of type `Outcome`. #' @return None #' @export -#' +#' #' @examples #' n <- 100 #' p <- 10 @@ -586,13 +709,23 @@ rootResetRandomEffectsModel <- function(rfx_model, alpha_init, xi_init, sigma_al #' rfx_model$set_variance_prior_shape(sigma_xi_shape) #' rfx_model$set_variance_prior_scale(sigma_xi_scale) #' for (i in 1:3) { -#' rfx_model$sample_random_effect(rfx_dataset=rfx_dataset, residual=outcome, -#' rfx_tracker=rfx_tracker, rfx_samples=rfx_samples, +#' rfx_model$sample_random_effect(rfx_dataset=rfx_dataset, residual=outcome, +#' rfx_tracker=rfx_tracker, rfx_samples=rfx_samples, #' keep_sample=TRUE, global_variance=1.0, rng=rng) #' } #' rootResetRandomEffectsModel(rfx_model, alpha_init, xi_init, sigma_alpha_init, #' sigma_xi_init, sigma_xi_shape, sigma_xi_scale) #' rootResetRandomEffectsTracker(rfx_tracker, rfx_model, rfx_dataset, outcome) -rootResetRandomEffectsTracker <- function(rfx_tracker, rfx_model, rfx_dataset, residual) { - root_reset_rfx_tracker_cpp(rfx_tracker$rfx_tracker_ptr, rfx_dataset$data_ptr, residual$data_ptr, rfx_model$rfx_model_ptr) +rootResetRandomEffectsTracker <- function( + rfx_tracker, + rfx_model, + rfx_dataset, + residual +) { + root_reset_rfx_tracker_cpp( + rfx_tracker$rfx_tracker_ptr, + rfx_dataset$data_ptr, + residual$data_ptr, + rfx_model$rfx_model_ptr + ) } diff --git a/R/serialization.R b/R/serialization.R index 812b752e..c52e9fec 100644 --- a/R/serialization.R +++ b/R/serialization.R @@ -7,28 +7,27 @@ CppJson <- R6::R6Class( classname = "CppJson", cloneable = FALSE, public = list( - #' @field json_ptr External pointer to a C++ nlohmann::json object json_ptr = NULL, - + #' @field num_forests Number of forests in the nlohmann::json object num_forests = NULL, - + #' @field forest_labels Names of forest objects in the overall nlohmann::json object forest_labels = NULL, - + #' @field num_rfx Number of random effects terms in the nlohman::json object num_rfx = NULL, - + #' @field rfx_container_labels Names of rfx container objects in the overall nlohmann::json object rfx_container_labels = NULL, - + #' @field rfx_mapper_labels Names of rfx label mapper objects in the overall nlohmann::json object rfx_mapper_labels = NULL, - + #' @field rfx_groupid_labels Names of rfx group id objects in the overall nlohmann::json object rfx_groupid_labels = NULL, - + #' @description #' Create a new CppJson object. #' @return A new `CppJson` object. @@ -40,33 +39,54 @@ CppJson <- R6::R6Class( self$rfx_container_labels <- c() self$rfx_mapper_labels <- c() self$rfx_groupid_labels <- c() - }, - + }, + #' @description #' Convert a forest container to json and add to the current `CppJson` object #' @param forest_samples `ForestSamples` R class #' @return None add_forest = function(forest_samples) { - forest_label <- json_add_forest_cpp(self$json_ptr, forest_samples$forest_container_ptr) + forest_label <- json_add_forest_cpp( + self$json_ptr, + forest_samples$forest_container_ptr + ) self$num_forests <- self$num_forests + 1 self$forest_labels <- c(self$forest_labels, forest_label) - }, - + }, + #' @description #' Convert a random effects container to json and add to the current `CppJson` object #' @param rfx_samples `RandomEffectSamples` R class #' @return None add_random_effects = function(rfx_samples) { - rfx_container_label <- json_add_rfx_container_cpp(self$json_ptr, rfx_samples$rfx_container_ptr) - self$rfx_container_labels <- c(self$rfx_container_labels, rfx_container_label) - rfx_mapper_label <- json_add_rfx_label_mapper_cpp(self$json_ptr, rfx_samples$label_mapper_ptr) - self$rfx_mapper_labels <- c(self$rfx_mapper_labels, rfx_mapper_label) - rfx_groupid_label <- json_add_rfx_groupids_cpp(self$json_ptr, rfx_samples$training_group_ids) - self$rfx_groupid_labels <- c(self$rfx_groupid_labels, rfx_groupid_label) + rfx_container_label <- json_add_rfx_container_cpp( + self$json_ptr, + rfx_samples$rfx_container_ptr + ) + self$rfx_container_labels <- c( + self$rfx_container_labels, + rfx_container_label + ) + rfx_mapper_label <- json_add_rfx_label_mapper_cpp( + self$json_ptr, + rfx_samples$label_mapper_ptr + ) + self$rfx_mapper_labels <- c( + self$rfx_mapper_labels, + rfx_mapper_label + ) + rfx_groupid_label <- json_add_rfx_groupids_cpp( + self$json_ptr, + rfx_samples$training_group_ids + ) + self$rfx_groupid_labels <- c( + self$rfx_groupid_labels, + rfx_groupid_label + ) json_increment_rfx_count_cpp(self$json_ptr) self$num_rfx <- self$num_rfx + 1 - }, - + }, + #' @description #' Add a scalar to the json object under the name "field_name" (with optional subfolder "subfolder_name") #' @param field_name The name of the field to be added to json @@ -77,10 +97,15 @@ CppJson <- R6::R6Class( if (is.null(subfolder_name)) { json_add_double_cpp(self$json_ptr, field_name, field_value) } else { - json_add_double_subfolder_cpp(self$json_ptr, subfolder_name, field_name, field_value) + json_add_double_subfolder_cpp( + self$json_ptr, + subfolder_name, + field_name, + field_value + ) } - }, - + }, + #' @description #' Add a scalar to the json object under the name "field_name" (with optional subfolder "subfolder_name") #' @param field_name The name of the field to be added to json @@ -91,10 +116,15 @@ CppJson <- R6::R6Class( if (is.null(subfolder_name)) { json_add_integer_cpp(self$json_ptr, field_name, field_value) } else { - json_add_integer_subfolder_cpp(self$json_ptr, subfolder_name, field_name, field_value) + json_add_integer_subfolder_cpp( + self$json_ptr, + subfolder_name, + field_name, + field_value + ) } - }, - + }, + #' @description #' Add a boolean value to the json object under the name "field_name" (with optional subfolder "subfolder_name") #' @param field_name The name of the field to be added to json @@ -105,10 +135,15 @@ CppJson <- R6::R6Class( if (is.null(subfolder_name)) { json_add_bool_cpp(self$json_ptr, field_name, field_value) } else { - json_add_bool_subfolder_cpp(self$json_ptr, subfolder_name, field_name, field_value) + json_add_bool_subfolder_cpp( + self$json_ptr, + subfolder_name, + field_name, + field_value + ) } - }, - + }, + #' @description #' Add a string value to the json object under the name "field_name" (with optional subfolder "subfolder_name") #' @param field_name The name of the field to be added to json @@ -119,10 +154,15 @@ CppJson <- R6::R6Class( if (is.null(subfolder_name)) { json_add_string_cpp(self$json_ptr, field_name, field_value) } else { - json_add_string_subfolder_cpp(self$json_ptr, subfolder_name, field_name, field_value) + json_add_string_subfolder_cpp( + self$json_ptr, + subfolder_name, + field_name, + field_value + ) } - }, - + }, + #' @description #' Add a vector to the json object under the name "field_name" (with optional subfolder "subfolder_name") #' @param field_name The name of the field to be added to json @@ -134,69 +174,110 @@ CppJson <- R6::R6Class( if (is.null(subfolder_name)) { json_add_vector_cpp(self$json_ptr, field_name, field_vector) } else { - json_add_vector_subfolder_cpp(self$json_ptr, subfolder_name, field_name, field_vector) + json_add_vector_subfolder_cpp( + self$json_ptr, + subfolder_name, + field_name, + field_vector + ) } - }, - + }, + #' @description #' Add an integer vector to the json object under the name "field_name" (with optional subfolder "subfolder_name") #' @param field_name The name of the field to be added to json #' @param field_vector Vector to be stored in json #' @param subfolder_name (Optional) Name of the subfolder / hierarchy under which to place the value #' @return None - add_integer_vector = function(field_name, field_vector, subfolder_name = NULL) { + add_integer_vector = function( + field_name, + field_vector, + subfolder_name = NULL + ) { field_vector <- as.numeric(field_vector) if (is.null(subfolder_name)) { - json_add_integer_vector_cpp(self$json_ptr, field_name, field_vector) + json_add_integer_vector_cpp( + self$json_ptr, + field_name, + field_vector + ) } else { - json_add_integer_vector_subfolder_cpp(self$json_ptr, subfolder_name, field_name, field_vector) + json_add_integer_vector_subfolder_cpp( + self$json_ptr, + subfolder_name, + field_name, + field_vector + ) } - }, - + }, + #' @description #' Add an array to the json object under the name "field_name" (with optional subfolder "subfolder_name") #' @param field_name The name of the field to be added to json #' @param field_vector Character vector to be stored in json #' @param subfolder_name (Optional) Name of the subfolder / hierarchy under which to place the value #' @return None - add_string_vector = function(field_name, field_vector, subfolder_name = NULL) { + add_string_vector = function( + field_name, + field_vector, + subfolder_name = NULL + ) { if (is.null(subfolder_name)) { - json_add_string_vector_cpp(self$json_ptr, field_name, field_vector) + json_add_string_vector_cpp( + self$json_ptr, + field_name, + field_vector + ) } else { - json_add_string_vector_subfolder_cpp(self$json_ptr, subfolder_name, field_name, field_vector) + json_add_string_vector_subfolder_cpp( + self$json_ptr, + subfolder_name, + field_name, + field_vector + ) } - }, - + }, + #' @description #' Add a list of vectors (as an object map of arrays) to the json object under the name "field_name" #' @param field_name The name of the field to be added to json #' @param field_list List to be stored in json #' @return None add_list = function(field_name, field_list) { - stopifnot(sum(!sapply(field_list, is.vector))==0) + stopifnot(sum(!sapply(field_list, is.vector)) == 0) list_names <- names(field_list) for (i in 1:length(field_list)) { vec_name <- list_names[i] vec <- field_list[[i]] - json_add_vector_subfolder_cpp(self$json_ptr, field_name, vec_name, vec) + json_add_vector_subfolder_cpp( + self$json_ptr, + field_name, + vec_name, + vec + ) } - }, - + }, + #' @description #' Add a list of vectors (as an object map of arrays) to the json object under the name "field_name" #' @param field_name The name of the field to be added to json #' @param field_list List to be stored in json #' @return None add_string_list = function(field_name, field_list) { - stopifnot(sum(!sapply(field_list, is.vector))==0) + stopifnot(sum(!sapply(field_list, is.vector)) == 0) list_names <- names(field_list) for (i in 1:length(field_list)) { vec_name <- list_names[i] vec <- field_list[[i]] - json_add_string_vector_subfolder_cpp(self$json_ptr, field_name, vec_name, vec) + json_add_string_vector_subfolder_cpp( + self$json_ptr, + field_name, + vec_name, + vec + ) } - }, - + }, + #' @description #' Retrieve a scalar value from the json object under the name "field_name" (with optional subfolder "subfolder_name") #' @param field_name The name of the field to be accessed from json @@ -207,12 +288,20 @@ CppJson <- R6::R6Class( stopifnot(json_contains_field_cpp(self$json_ptr, field_name)) result <- json_extract_double_cpp(self$json_ptr, field_name) } else { - stopifnot(json_contains_field_subfolder_cpp(self$json_ptr, subfolder_name, field_name)) - result <- json_extract_double_subfolder_cpp(self$json_ptr, subfolder_name, field_name) + stopifnot(json_contains_field_subfolder_cpp( + self$json_ptr, + subfolder_name, + field_name + )) + result <- json_extract_double_subfolder_cpp( + self$json_ptr, + subfolder_name, + field_name + ) } return(result) - }, - + }, + #' @description #' Retrieve a integer value from the json object under the name "field_name" (with optional subfolder "subfolder_name") #' @param field_name The name of the field to be accessed from json @@ -223,12 +312,20 @@ CppJson <- R6::R6Class( stopifnot(json_contains_field_cpp(self$json_ptr, field_name)) result <- json_extract_integer_cpp(self$json_ptr, field_name) } else { - stopifnot(json_contains_field_subfolder_cpp(self$json_ptr, subfolder_name, field_name)) - result <- json_extract_integer_subfolder_cpp(self$json_ptr, subfolder_name, field_name) + stopifnot(json_contains_field_subfolder_cpp( + self$json_ptr, + subfolder_name, + field_name + )) + result <- json_extract_integer_subfolder_cpp( + self$json_ptr, + subfolder_name, + field_name + ) } return(result) - }, - + }, + #' @description #' Retrieve a boolean value from the json object under the name "field_name" (with optional subfolder "subfolder_name") #' @param field_name The name of the field to be accessed from json @@ -239,12 +336,20 @@ CppJson <- R6::R6Class( stopifnot(json_contains_field_cpp(self$json_ptr, field_name)) result <- json_extract_bool_cpp(self$json_ptr, field_name) } else { - stopifnot(json_contains_field_subfolder_cpp(self$json_ptr, subfolder_name, field_name)) - result <- json_extract_bool_subfolder_cpp(self$json_ptr, subfolder_name, field_name) + stopifnot(json_contains_field_subfolder_cpp( + self$json_ptr, + subfolder_name, + field_name + )) + result <- json_extract_bool_subfolder_cpp( + self$json_ptr, + subfolder_name, + field_name + ) } return(result) - }, - + }, + #' @description #' Retrieve a string value from the json object under the name "field_name" (with optional subfolder "subfolder_name") #' @param field_name The name of the field to be accessed from json @@ -255,12 +360,20 @@ CppJson <- R6::R6Class( stopifnot(json_contains_field_cpp(self$json_ptr, field_name)) result <- json_extract_string_cpp(self$json_ptr, field_name) } else { - stopifnot(json_contains_field_subfolder_cpp(self$json_ptr, subfolder_name, field_name)) - result <- json_extract_string_subfolder_cpp(self$json_ptr, subfolder_name, field_name) + stopifnot(json_contains_field_subfolder_cpp( + self$json_ptr, + subfolder_name, + field_name + )) + result <- json_extract_string_subfolder_cpp( + self$json_ptr, + subfolder_name, + field_name + ) } return(result) - }, - + }, + #' @description #' Retrieve a vector from the json object under the name "field_name" (with optional subfolder "subfolder_name") #' @param field_name The name of the field to be accessed from json @@ -271,12 +384,20 @@ CppJson <- R6::R6Class( stopifnot(json_contains_field_cpp(self$json_ptr, field_name)) result <- json_extract_vector_cpp(self$json_ptr, field_name) } else { - stopifnot(json_contains_field_subfolder_cpp(self$json_ptr, subfolder_name, field_name)) - result <- json_extract_vector_subfolder_cpp(self$json_ptr, subfolder_name, field_name) + stopifnot(json_contains_field_subfolder_cpp( + self$json_ptr, + subfolder_name, + field_name + )) + result <- json_extract_vector_subfolder_cpp( + self$json_ptr, + subfolder_name, + field_name + ) } return(result) - }, - + }, + #' @description #' Retrieve an integer vector from the json object under the name "field_name" (with optional subfolder "subfolder_name") #' @param field_name The name of the field to be accessed from json @@ -285,14 +406,25 @@ CppJson <- R6::R6Class( get_integer_vector = function(field_name, subfolder_name = NULL) { if (is.null(subfolder_name)) { stopifnot(json_contains_field_cpp(self$json_ptr, field_name)) - result <- json_extract_integer_vector_cpp(self$json_ptr, field_name) + result <- json_extract_integer_vector_cpp( + self$json_ptr, + field_name + ) } else { - stopifnot(json_contains_field_subfolder_cpp(self$json_ptr, subfolder_name, field_name)) - result <- json_extract_integer_vector_subfolder_cpp(self$json_ptr, subfolder_name, field_name) + stopifnot(json_contains_field_subfolder_cpp( + self$json_ptr, + subfolder_name, + field_name + )) + result <- json_extract_integer_vector_subfolder_cpp( + self$json_ptr, + subfolder_name, + field_name + ) } return(result) - }, - + }, + #' @description #' Retrieve a character vector from the json object under the name "field_name" (with optional subfolder "subfolder_name") #' @param field_name The name of the field to be accessed from json @@ -301,14 +433,25 @@ CppJson <- R6::R6Class( get_string_vector = function(field_name, subfolder_name = NULL) { if (is.null(subfolder_name)) { stopifnot(json_contains_field_cpp(self$json_ptr, field_name)) - result <- json_extract_string_vector_cpp(self$json_ptr, field_name) + result <- json_extract_string_vector_cpp( + self$json_ptr, + field_name + ) } else { - stopifnot(json_contains_field_subfolder_cpp(self$json_ptr, subfolder_name, field_name)) - result <- json_extract_string_vector_subfolder_cpp(self$json_ptr, subfolder_name, field_name) + stopifnot(json_contains_field_subfolder_cpp( + self$json_ptr, + subfolder_name, + field_name + )) + result <- json_extract_string_vector_subfolder_cpp( + self$json_ptr, + subfolder_name, + field_name + ) } return(result) - }, - + }, + #' @description #' Reconstruct a list of numeric vectors from the json object stored under "field_name" #' @param field_name The name of the field to be added to json @@ -318,11 +461,15 @@ CppJson <- R6::R6Class( output <- list() for (i in 1:length(key_names)) { vec_name <- key_names[i] - output[[vec_name]] <- json_extract_vector_subfolder_cpp(self$json_ptr, field_name, vec_name) + output[[vec_name]] <- json_extract_vector_subfolder_cpp( + self$json_ptr, + field_name, + vec_name + ) } return(output) - }, - + }, + #' @description #' Reconstruct a list of string vectors from the json object stored under "field_name" #' @param field_name The name of the field to be added to json @@ -332,34 +479,38 @@ CppJson <- R6::R6Class( output <- list() for (i in 1:length(key_names)) { vec_name <- key_names[i] - output[[vec_name]] <- json_extract_string_vector_subfolder_cpp(self$json_ptr, field_name, vec_name) + output[[vec_name]] <- json_extract_string_vector_subfolder_cpp( + self$json_ptr, + field_name, + vec_name + ) } return(output) - }, - + }, + #' @description #' Convert a JSON object to in-memory string #' @return JSON string return_json_string = function() { return(get_json_string_cpp(self$json_ptr)) - }, - + }, + #' @description #' Save a json object to file #' @param filename String of filepath, must end in ".json" #' @return None save_file = function(filename) { json_save_file_cpp(self$json_ptr, filename) - }, - + }, + #' @description #' Load a json object from file #' @param filename String of filepath, must end in ".json" #' @return None load_from_file = function(filename) { json_load_file_cpp(self$json_ptr, filename) - }, - + }, + #' @description #' Load a json object from string #' @param json_string JSON string dump @@ -377,7 +528,7 @@ CppJson <- R6::R6Class( #' #' @return `ForestSamples` object #' @export -#' +#' #' @examples #' X <- matrix(runif(10*100), ncol = 10) #' y <- -5 + 10*(X[,1] > 0.5) + rnorm(100) @@ -385,7 +536,7 @@ CppJson <- R6::R6Class( #' bart_json <- saveBARTModelToJson(bart_model) #' mean_forest <- loadForestContainerJson(bart_json, "forest_0") loadForestContainerJson <- function(json_object, json_forest_label) { - invisible(output <- ForestSamples$new(0,1,T)) + invisible(output <- ForestSamples$new(0, 1, T)) output$load_from_json(json_object, json_forest_label) return(output) } @@ -397,15 +548,18 @@ loadForestContainerJson <- function(json_object, json_forest_label) { #' #' @return `ForestSamples` object #' @export -#' +#' #' @examples #' X <- matrix(runif(10*100), ncol = 10) #' y <- -5 + 10*(X[,1] > 0.5) + rnorm(100) #' bart_model <- bart(X, y, num_gfr=0, num_mcmc=10) #' bart_json <- list(saveBARTModelToJson(bart_model)) #' mean_forest <- loadForestContainerCombinedJson(bart_json, "forest_0") -loadForestContainerCombinedJson <- function(json_object_list, json_forest_label) { - invisible(output <- ForestSamples$new(0,1,T)) +loadForestContainerCombinedJson <- function( + json_object_list, + json_forest_label +) { + invisible(output <- ForestSamples$new(0, 1, T)) for (i in 1:length(json_object_list)) { json_object <- json_object_list[[i]] if (i == 1) { @@ -424,15 +578,18 @@ loadForestContainerCombinedJson <- function(json_object_list, json_forest_label) #' #' @return `ForestSamples` object #' @export -#' +#' #' @examples #' X <- matrix(runif(10*100), ncol = 10) #' y <- -5 + 10*(X[,1] > 0.5) + rnorm(100) #' bart_model <- bart(X, y, num_gfr=0, num_mcmc=10) #' bart_json_string <- list(saveBARTModelToJsonString(bart_model)) #' mean_forest <- loadForestContainerCombinedJsonString(bart_json_string, "forest_0") -loadForestContainerCombinedJsonString <- function(json_string_list, json_forest_label) { - invisible(output <- ForestSamples$new(0,1,T)) +loadForestContainerCombinedJsonString <- function( + json_string_list, + json_forest_label +) { + invisible(output <- ForestSamples$new(0, 1, T)) for (i in 1:length(json_string_list)) { json_string <- json_string_list[[i]] if (i == 1) { @@ -451,7 +608,7 @@ loadForestContainerCombinedJsonString <- function(json_string_list, json_forest_ #' #' @return `RandomEffectSamples` object #' @export -#' +#' #' @examples #' n <- 100 #' p <- 10 @@ -468,7 +625,12 @@ loadRandomEffectSamplesJson <- function(json_object, json_rfx_num) { json_rfx_mapper_label <- paste0("random_effect_label_mapper_", json_rfx_num) json_rfx_groupids_label <- paste0("random_effect_groupids_", json_rfx_num) invisible(output <- RandomEffectSamples$new()) - output$load_from_json(json_object, json_rfx_container_label, json_rfx_mapper_label, json_rfx_groupids_label) + output$load_from_json( + json_object, + json_rfx_container_label, + json_rfx_mapper_label, + json_rfx_groupids_label + ) return(output) } @@ -479,7 +641,7 @@ loadRandomEffectSamplesJson <- function(json_object, json_rfx_num) { #' #' @return `RandomEffectSamples` object #' @export -#' +#' #' @examples #' n <- 100 #' p <- 10 @@ -491,7 +653,10 @@ loadRandomEffectSamplesJson <- function(json_object, json_rfx_num) { #' rfx_basis_train = rfx_basis, num_gfr=0, num_mcmc=10) #' bart_json <- list(saveBARTModelToJson(bart_model)) #' rfx_samples <- loadRandomEffectSamplesCombinedJson(bart_json, 0) -loadRandomEffectSamplesCombinedJson <- function(json_object_list, json_rfx_num) { +loadRandomEffectSamplesCombinedJson <- function( + json_object_list, + json_rfx_num +) { json_rfx_container_label <- paste0("random_effect_container_", json_rfx_num) json_rfx_mapper_label <- paste0("random_effect_label_mapper_", json_rfx_num) json_rfx_groupids_label <- paste0("random_effect_groupids_", json_rfx_num) @@ -499,9 +664,19 @@ loadRandomEffectSamplesCombinedJson <- function(json_object_list, json_rfx_num) for (i in 1:length(json_object_list)) { json_object <- json_object_list[[i]] if (i == 1) { - output$load_from_json(json_object, json_rfx_container_label, json_rfx_mapper_label, json_rfx_groupids_label) + output$load_from_json( + json_object, + json_rfx_container_label, + json_rfx_mapper_label, + json_rfx_groupids_label + ) } else { - output$append_from_json(json_object, json_rfx_container_label, json_rfx_mapper_label, json_rfx_groupids_label) + output$append_from_json( + json_object, + json_rfx_container_label, + json_rfx_mapper_label, + json_rfx_groupids_label + ) } } return(output) @@ -514,7 +689,7 @@ loadRandomEffectSamplesCombinedJson <- function(json_object_list, json_rfx_num) #' #' @return `RandomEffectSamples` object #' @export -#' +#' #' @examples #' n <- 100 #' p <- 10 @@ -526,7 +701,10 @@ loadRandomEffectSamplesCombinedJson <- function(json_object_list, json_rfx_num) #' rfx_basis_train = rfx_basis, num_gfr=0, num_mcmc=10) #' bart_json_string <- list(saveBARTModelToJsonString(bart_model)) #' rfx_samples <- loadRandomEffectSamplesCombinedJsonString(bart_json_string, 0) -loadRandomEffectSamplesCombinedJsonString <- function(json_string_list, json_rfx_num) { +loadRandomEffectSamplesCombinedJsonString <- function( + json_string_list, + json_rfx_num +) { json_rfx_container_label <- paste0("random_effect_container_", json_rfx_num) json_rfx_mapper_label <- paste0("random_effect_label_mapper_", json_rfx_num) json_rfx_groupids_label <- paste0("random_effect_groupids_", json_rfx_num) @@ -534,9 +712,19 @@ loadRandomEffectSamplesCombinedJsonString <- function(json_string_list, json_rfx for (i in 1:length(json_string_list)) { json_string <- json_string_list[[i]] if (i == 1) { - output$load_from_json_string(json_string, json_rfx_container_label, json_rfx_mapper_label, json_rfx_groupids_label) + output$load_from_json_string( + json_string, + json_rfx_container_label, + json_rfx_mapper_label, + json_rfx_groupids_label + ) } else { - output$append_from_json_string(json_string, json_rfx_container_label, json_rfx_mapper_label, json_rfx_groupids_label) + output$append_from_json_string( + json_string, + json_rfx_container_label, + json_rfx_mapper_label, + json_rfx_groupids_label + ) } } return(output) @@ -550,13 +738,17 @@ loadRandomEffectSamplesCombinedJsonString <- function(json_string_list, json_rfx #' #' @return R vector #' @export -#' +#' #' @examples #' example_vec <- runif(10) #' example_json <- createCppJson() #' example_json$add_vector("myvec", example_vec) #' roundtrip_vec <- loadVectorJson(example_json, "myvec") -loadVectorJson <- function(json_object, json_vector_label, subfolder_name = NULL) { +loadVectorJson <- function( + json_object, + json_vector_label, + subfolder_name = NULL +) { if (is.null(subfolder_name)) { output <- json_object$get_vector(json_vector_label) } else { @@ -573,13 +765,17 @@ loadVectorJson <- function(json_object, json_vector_label, subfolder_name = NULL #' #' @return R vector #' @export -#' +#' #' @examples #' example_scalar <- 5.4 #' example_json <- createCppJson() #' example_json$add_scalar("myscalar", example_scalar) #' roundtrip_scalar <- loadScalarJson(example_json, "myscalar") -loadScalarJson <- function(json_object, json_scalar_label, subfolder_name = NULL) { +loadScalarJson <- function( + json_object, + json_scalar_label, + subfolder_name = NULL +) { if (is.null(subfolder_name)) { output <- json_object$get_scalar(json_scalar_label) } else { @@ -592,15 +788,13 @@ loadScalarJson <- function(json_object, json_scalar_label, subfolder_name = NULL #' #' @return `CppJson` object #' @export -#' +#' #' @examples #' example_vec <- runif(10) #' example_json <- createCppJson() #' example_json$add_vector("myvec", example_vec) createCppJson <- function() { - return(invisible(( - CppJson$new() - ))) + return(invisible((CppJson$new()))) } #' Create a C++ Json object from a Json file @@ -608,7 +802,7 @@ createCppJson <- function() { #' @param json_filename Name of file to read. Must end in `.json`. #' @return `CppJson` object #' @export -#' +#' #' @examples #' example_vec <- runif(10) #' example_json <- createCppJson() @@ -618,9 +812,7 @@ createCppJson <- function() { #' example_json_roundtrip <- createCppJsonFile(file.path(tmpjson)) #' unlink(tmpjson) createCppJsonFile <- function(json_filename) { - invisible(( - output <- CppJson$new() - )) + invisible((output <- CppJson$new())) output$load_from_file(json_filename) return(output) } @@ -630,7 +822,7 @@ createCppJsonFile <- function(json_filename) { #' @param json_string JSON string dump #' @return `CppJson` object #' @export -#' +#' #' @examples #' example_vec <- runif(10) #' example_json <- createCppJson() @@ -638,9 +830,7 @@ createCppJsonFile <- function(json_filename) { #' example_json_string <- example_json$return_json_string() #' example_json_roundtrip <- createCppJsonString(example_json_string) createCppJsonString <- function(json_string) { - invisible(( - output <- CppJson$new() - )) + invisible((output <- CppJson$new())) output$load_from_string(json_string) return(output) } diff --git a/R/stochtree-package.R b/R/stochtree-package.R index f3fd5c43..97eeded1 100644 --- a/R/stochtree-package.R +++ b/R/stochtree-package.R @@ -18,4 +18,4 @@ NULL #' @useDynLib stochtree, .registration = TRUE -"_PACKAGE" \ No newline at end of file +"_PACKAGE" diff --git a/R/utils.R b/R/utils.R index a4329ad9..ad9752bc 100644 --- a/R/utils.R +++ b/R/utils.R @@ -20,13 +20,13 @@ preprocessParams <- function(default_params, user_params = NULL) { return(default_params) } -#' Preprocess covariates. DataFrames will be preprocessed based on their column +#' Preprocess covariates. DataFrames will be preprocessed based on their column #' types. Matrices will be passed through assuming all columns are numeric. #' #' @param input_data Covariates, provided as either a dataframe or a matrix #' -#' @return List with preprocessed (unmodified) data and details on the number of each type -#' of variable, unique categories associated with categorical variables, and the +#' @return List with preprocessed (unmodified) data and details on the number of each type +#' of variable, unique categories associated with categorical variables, and the #' vector of feature types needed for calls to BART and BCF. #' @export #' @@ -39,22 +39,22 @@ preprocessTrainData <- function(input_data) { if ((!is.matrix(input_data)) && (!is.data.frame(input_data))) { stop("Covariates provided must be a dataframe or matrix") } - + # Routing the correct preprocessing function if (is.matrix(input_data)) { output <- preprocessTrainMatrix(input_data) } else { output <- preprocessTrainDataFrame(input_data) } - + return(output) } -#' Preprocess covariates. DataFrames will be preprocessed based on their column +#' Preprocess covariates. DataFrames will be preprocessed based on their column #' types. Matrices will be passed through assuming all columns are numeric. #' #' @param input_data Covariates, provided as either a dataframe or a matrix -#' @param metadata List containing information on variables, including train set +#' @param metadata List containing information on variables, including train set #' categories for categorical variables #' #' @return Preprocessed data with categorical variables appropriately handled @@ -62,7 +62,7 @@ preprocessTrainData <- function(input_data) { #' #' @examples #' cov_df <- data.frame(x1 = 1:5, x2 = 5:1, x3 = 6:10) -#' metadata <- list(num_ordered_cat_vars = 0, num_unordered_cat_vars = 0, +#' metadata <- list(num_ordered_cat_vars = 0, num_unordered_cat_vars = 0, #' num_numeric_vars = 3, numeric_vars = c("x1", "x2", "x3")) #' X_preprocessed <- preprocessPredictionData(cov_df, metadata) preprocessPredictionData <- function(input_data, metadata) { @@ -70,14 +70,14 @@ preprocessPredictionData <- function(input_data, metadata) { if ((!is.matrix(input_data)) && (!is.data.frame(input_data))) { stop("Covariates provided must be a dataframe or matrix") } - + # Routing the correct preprocessing function if (is.matrix(input_data)) { X <- preprocessPredictionMatrix(input_data, metadata) } else { X <- preprocessPredictionDataFrame(input_data, metadata) } - + return(X) } @@ -86,8 +86,8 @@ preprocessPredictionData <- function(input_data, metadata) { #' #' @param input_matrix Covariate matrix. #' -#' @return List with preprocessed (unmodified) data and details on the number of each type -#' of variable, unique categories associated with categorical variables, and the +#' @return List with preprocessed (unmodified) data and details on the number of each type +#' of variable, unique categories associated with categorical variables, and the #' vector of feature types needed for calls to BART and BCF. #' @noRd #' @@ -100,7 +100,7 @@ preprocessTrainMatrix <- function(input_matrix) { if (!is.matrix(input_matrix)) { stop("covariates provided must be a matrix") } - + # Unpack metadata (assuming all variables are numeric) names(input_matrix) <- paste0("x", 1:ncol(input_matrix)) df_vars <- names(input_matrix) @@ -115,25 +115,25 @@ preprocessTrainMatrix <- function(input_matrix) { # Aggregate results into a list metadata <- list( - feature_types = feature_types, - num_ordered_cat_vars = num_ordered_cat_vars, - num_unordered_cat_vars = num_unordered_cat_vars, - num_numeric_vars = num_numeric_vars, - numeric_vars = numeric_vars, + feature_types = feature_types, + num_ordered_cat_vars = num_ordered_cat_vars, + num_unordered_cat_vars = num_unordered_cat_vars, + num_numeric_vars = num_numeric_vars, + numeric_vars = numeric_vars, original_var_indices = 1:num_numeric_vars ) output <- list( - data = X, + data = X, metadata = metadata ) - + return(output) } #' Preprocess a matrix of covariate values, assuming all columns are numeric. #' #' @param input_matrix Covariate matrix. -#' @param metadata List containing information on variables, including train set +#' @param metadata List containing information on variables, including train set #' categories for categorical variables #' #' @return Preprocessed data with categorical variables appropriately preprocessed @@ -141,7 +141,7 @@ preprocessTrainMatrix <- function(input_matrix) { #' #' @examples #' cov_mat <- matrix(c(1:5, 5:1, 6:10), ncol = 3) -#' metadata <- list(num_ordered_cat_vars = 0, num_unordered_cat_vars = 0, +#' metadata <- list(num_ordered_cat_vars = 0, num_unordered_cat_vars = 0, #' num_numeric_vars = 3, numeric_vars = c("x1", "x2", "x3")) #' X_preprocessed <- preprocessPredictionMatrix(cov_mat, metadata) preprocessPredictionMatrix <- function(input_matrix, metadata) { @@ -150,22 +150,24 @@ preprocessPredictionMatrix <- function(input_matrix, metadata) { stop("covariates provided must be a matrix") } if (!(ncol(input_matrix) == metadata$num_numeric_vars)) { - stop("Prediction set covariates have inconsistent dimension from train set covariates") + stop( + "Prediction set covariates have inconsistent dimension from train set covariates" + ) } - + return(input_matrix) } -#' Preprocess a dataframe of covariate values, converting categorical variables -#' to integers and one-hot encoding if need be. Returns a list including a +#' Preprocess a dataframe of covariate values, converting categorical variables +#' to integers and one-hot encoding if need be. Returns a list including a #' matrix of preprocessed covariate values and associated tracking. #' -#' @param input_df Dataframe of covariates. Users must pre-process any +#' @param input_df Dataframe of covariates. Users must pre-process any #' categorical variables as factors (ordered for ordered categorical). #' @noRd #' -#' @return List with preprocessed data and details on the number of each type -#' of variable, unique categories associated with categorical variables, and the +#' @return List with preprocessed data and details on the number of each type +#' of variable, unique categories associated with categorical variables, and the #' vector of feature types needed for calls to BART and BCF. preprocessTrainDataFrame <- function(input_df) { # Input checks / details @@ -173,10 +175,10 @@ preprocessTrainDataFrame <- function(input_df) { stop("covariates provided must be a data frame") } df_vars <- names(input_df) - + # Detect ordered and unordered categorical variables - - # First, ordered categorical: users must have explicitly + + # First, ordered categorical: users must have explicitly # converted this to a factor with ordered = TRUE factor_mask <- sapply(input_df, is.factor) ordered_mask <- sapply(input_df, is.ordered) @@ -184,50 +186,58 @@ preprocessTrainDataFrame <- function(input_df) { ordered_cat_vars <- df_vars[ordered_cat_matches] ordered_cat_var_inds <- unname(which(ordered_cat_matches)) num_ordered_cat_vars <- length(ordered_cat_vars) - if (num_ordered_cat_vars > 0) ordered_cat_df <- input_df[,ordered_cat_vars,drop=FALSE] - - # Next, unordered categorical: we will convert character - # columns but not integer columns (users must explicitly + if (num_ordered_cat_vars > 0) { + ordered_cat_df <- input_df[, ordered_cat_vars, drop = FALSE] + } + + # Next, unordered categorical: we will convert character + # columns but not integer columns (users must explicitly # convert these to factor) character_mask <- sapply(input_df, is.character) unordered_cat_matches <- (factor_mask & (!ordered_mask)) | character_mask unordered_cat_vars <- df_vars[unordered_cat_matches] unordered_cat_var_inds <- unname(which(unordered_cat_matches)) num_unordered_cat_vars <- length(unordered_cat_vars) - if (num_unordered_cat_vars > 0) unordered_cat_df <- input_df[,unordered_cat_vars,drop=FALSE] - + if (num_unordered_cat_vars > 0) { + unordered_cat_df <- input_df[, unordered_cat_vars, drop = FALSE] + } + # Numeric variables numeric_matches <- (!ordered_cat_matches) & (!unordered_cat_matches) numeric_vars <- df_vars[numeric_matches] numeric_var_inds <- unname(which(numeric_matches)) num_numeric_vars <- length(numeric_vars) - if (num_numeric_vars > 0) numeric_df <- input_df[,numeric_vars,drop=FALSE] - + if (num_numeric_vars > 0) { + numeric_df <- input_df[, numeric_vars, drop = FALSE] + } + # Empty outputs X <- double(0) unordered_unique_levels <- list() ordered_unique_levels <- list() feature_types <- integer(0) original_var_indices <- integer(0) - + # First, extract the numeric covariates if (num_numeric_vars > 0) { Xnum <- double(0) for (i in 1:ncol(numeric_df)) { - stopifnot(is.numeric(numeric_df[,i])) - Xnum <- cbind(Xnum, numeric_df[,i]) + stopifnot(is.numeric(numeric_df[, i])) + Xnum <- cbind(Xnum, numeric_df[, i]) } X <- cbind(X, unname(Xnum)) feature_types <- c(feature_types, rep(0, ncol(Xnum))) original_var_indices <- c(original_var_indices, numeric_var_inds) } - + # Next, run some simple preprocessing on the ordered categorical covariates if (num_ordered_cat_vars > 0) { Xordcat <- double(0) for (i in 1:ncol(ordered_cat_df)) { var_name <- names(ordered_cat_df)[i] - preprocess_list <- orderedCatInitializeAndPreprocess(ordered_cat_df[,i]) + preprocess_list <- orderedCatInitializeAndPreprocess(ordered_cat_df[, + i + ]) ordered_unique_levels[[var_name]] <- preprocess_list$unique_levels Xordcat <- cbind(Xordcat, preprocess_list$x_preprocessed) } @@ -235,29 +245,32 @@ preprocessTrainDataFrame <- function(input_df) { feature_types <- c(feature_types, rep(1, ncol(Xordcat))) original_var_indices <- c(original_var_indices, ordered_cat_var_inds) } - + # Finally, one-hot encode the unordered categorical covariates if (num_unordered_cat_vars > 0) { one_hot_mats <- list() for (i in 1:ncol(unordered_cat_df)) { var_name <- names(unordered_cat_df)[i] - encode_list <- oneHotInitializeAndEncode(unordered_cat_df[,i]) + encode_list <- oneHotInitializeAndEncode(unordered_cat_df[, i]) unordered_unique_levels[[var_name]] <- encode_list$unique_levels one_hot_mats[[var_name]] <- encode_list$Xtilde - one_hot_var <- rep(unordered_cat_var_inds[i], ncol(encode_list$Xtilde)) + one_hot_var <- rep( + unordered_cat_var_inds[i], + ncol(encode_list$Xtilde) + ) original_var_indices <- c(original_var_indices, one_hot_var) } Xcat <- do.call(cbind, one_hot_mats) X <- cbind(X, unname(Xcat)) feature_types <- c(feature_types, rep(1, ncol(Xcat))) } - + # Aggregate results into a list metadata <- list( - feature_types = feature_types, - num_ordered_cat_vars = num_ordered_cat_vars, - num_unordered_cat_vars = num_unordered_cat_vars, - num_numeric_vars = num_numeric_vars, + feature_types = feature_types, + num_ordered_cat_vars = num_ordered_cat_vars, + num_unordered_cat_vars = num_unordered_cat_vars, + num_numeric_vars = num_numeric_vars, original_var_indices = original_var_indices ) if (num_ordered_cat_vars > 0) { @@ -268,21 +281,23 @@ preprocessTrainDataFrame <- function(input_df) { metadata[["unordered_cat_vars"]] = unordered_cat_vars metadata[["unordered_unique_levels"]] = unordered_unique_levels } - if (num_numeric_vars > 0) metadata[["numeric_vars"]] = numeric_vars + if (num_numeric_vars > 0) { + metadata[["numeric_vars"]] = numeric_vars + } output <- list( - data = X, + data = X, metadata = metadata ) - + return(output) } -#' Preprocess a dataframe of covariate values, converting categorical variables +#' Preprocess a dataframe of covariate values, converting categorical variables #' to integers and one-hot encoding if need be. #' -#' @param input_df Dataframe of covariates. Users must pre-process any +#' @param input_df Dataframe of covariates. Users must pre-process any #' categorical variables as factors (ordered for ordered categorical). -#' @param metadata List containing information on variables, including train set +#' @param metadata List containing information on variables, including train set #' categories for categorical variables #' #' @return Preprocessed data with categorical variables appropriately preprocessed @@ -290,7 +305,7 @@ preprocessTrainDataFrame <- function(input_df) { #' #' @examples #' cov_df <- data.frame(x1 = 1:5, x2 = 5:1, x3 = 6:10) -#' metadata <- list(num_ordered_cat_vars = 0, num_unordered_cat_vars = 0, +#' metadata <- list(num_ordered_cat_vars = 0, num_unordered_cat_vars = 0, #' num_numeric_vars = 3, numeric_vars = c("x1", "x2", "x3")) #' X_preprocessed <- preprocessPredictionDataFrame(cov_df, metadata) preprocessPredictionDataFrame <- function(input_df, metadata) { @@ -301,63 +316,69 @@ preprocessPredictionDataFrame <- function(input_df, metadata) { num_ordered_cat_vars <- metadata$num_ordered_cat_vars num_unordered_cat_vars <- metadata$num_unordered_cat_vars num_numeric_vars <- metadata$num_numeric_vars - + if (num_ordered_cat_vars > 0) { ordered_cat_vars <- metadata$ordered_cat_vars - ordered_cat_df <- input_df[,ordered_cat_vars,drop=FALSE] + ordered_cat_df <- input_df[, ordered_cat_vars, drop = FALSE] } if (num_unordered_cat_vars > 0) { unordered_cat_vars <- metadata$unordered_cat_vars - unordered_cat_df <- input_df[,unordered_cat_vars,drop=FALSE] + unordered_cat_df <- input_df[, unordered_cat_vars, drop = FALSE] } if (num_numeric_vars > 0) { numeric_vars <- metadata$numeric_vars - numeric_df <- input_df[,numeric_vars,drop=FALSE] + numeric_df <- input_df[, numeric_vars, drop = FALSE] } - + # Empty outputs X <- double(0) - + # First, extract the numeric covariates if (num_numeric_vars > 0) { Xnum <- double(0) for (i in 1:ncol(numeric_df)) { - stopifnot(is.numeric(numeric_df[,i])) - Xnum <- cbind(Xnum, numeric_df[,i]) + stopifnot(is.numeric(numeric_df[, i])) + Xnum <- cbind(Xnum, numeric_df[, i]) } X <- cbind(X, unname(Xnum)) } - + # Next, run some simple preprocessing on the ordered categorical covariates if (num_ordered_cat_vars > 0) { Xordcat <- double(0) for (i in 1:ncol(ordered_cat_df)) { var_name <- names(ordered_cat_df)[i] - x_preprocessed <- orderedCatPreprocess(ordered_cat_df[,i], metadata$ordered_unique_levels[[var_name]]) + x_preprocessed <- orderedCatPreprocess( + ordered_cat_df[, i], + metadata$ordered_unique_levels[[var_name]] + ) Xordcat <- cbind(Xordcat, x_preprocessed) } X <- cbind(X, unname(Xordcat)) } - + # Finally, one-hot encode the unordered categorical covariates if (num_unordered_cat_vars > 0) { one_hot_mats <- list() for (i in 1:ncol(unordered_cat_df)) { var_name <- names(unordered_cat_df)[i] - Xtilde <- oneHotEncode(unordered_cat_df[,i], metadata$unordered_unique_levels[[var_name]]) + Xtilde <- oneHotEncode( + unordered_cat_df[, i], + metadata$unordered_unique_levels[[var_name]] + ) one_hot_mats[[var_name]] <- Xtilde } Xcat <- do.call(cbind, one_hot_mats) X <- cbind(X, unname(Xcat)) } - + return(X) } #' Convert the persistent aspects of a covariate preprocessor to (in-memory) C++ JSON object #' -#' @param object List containing information on variables, including train set -#' categories for categorical variables +#' @param object List containing information on variables, including train set +#' categories for categorical variables #' #' @return wrapper around in-memory C++ JSON object #' @export @@ -371,12 +392,12 @@ convertPreprocessorToJson <- function(object) { if (is.null(object$feature_types)) { stop("This covariate preprocessor has not yet been fit") } - + # Add internal scalars jsonobj$add_integer("num_numeric_vars", object$num_numeric_vars) jsonobj$add_integer("num_ordered_cat_vars", object$num_ordered_cat_vars) jsonobj$add_integer("num_unordered_cat_vars", object$num_unordered_cat_vars) - + # Add internal vectors jsonobj$add_vector("feature_types", object$feature_types) jsonobj$add_vector("original_var_indices", object$original_var_indices) @@ -387,26 +408,45 @@ convertPreprocessorToJson <- function(object) { jsonobj$add_string_vector("ordered_cat_vars", object$ordered_cat_vars) for (i in 1:object$num_ordered_cat_vars) { var_key <- names(object$ordered_unique_levels)[i] - jsonobj$add_string(paste0("key_", i), var_key, "ordered_unique_level_keys") - jsonobj$add_string_vector(var_key, object$ordered_unique_levels[[i]], "ordered_unique_levels") + jsonobj$add_string( + paste0("key_", i), + var_key, + "ordered_unique_level_keys" + ) + jsonobj$add_string_vector( + var_key, + object$ordered_unique_levels[[i]], + "ordered_unique_levels" + ) } } if (object$num_unordered_cat_vars > 0) { - jsonobj$add_string_vector("unordered_cat_vars", object$unordered_cat_vars) + jsonobj$add_string_vector( + "unordered_cat_vars", + object$unordered_cat_vars + ) for (i in 1:object$num_unordered_cat_vars) { var_key <- names(object$unordered_unique_levels)[i] - jsonobj$add_string(paste0("key_", i), var_key, "unordered_unique_level_keys") - jsonobj$add_string_vector(var_key, object$unordered_unique_levels[[i]], "unordered_unique_levels") + jsonobj$add_string( + paste0("key_", i), + var_key, + "unordered_unique_level_keys" + ) + jsonobj$add_string_vector( + var_key, + object$unordered_unique_levels[[i]], + "unordered_unique_levels" + ) } } - + return(jsonobj) } #' Convert the persistent aspects of a covariate preprocessor to (in-memory) JSON string #' -#' @param object List containing information on variables, including train set -#' categories for categorical variables +#' @param object List containing information on variables, including train set +#' categories for categorical variables #' #' @return in-memory JSON string #' @export @@ -415,10 +455,10 @@ convertPreprocessorToJson <- function(object) { #' cov_mat <- matrix(1:12, ncol = 3) #' preprocess_list <- preprocessTrainData(cov_mat) #' preprocessor_json_string <- savePreprocessorToJsonString(preprocess_list$metadata) -savePreprocessorToJsonString <- function(object){ +savePreprocessorToJsonString <- function(object) { # Convert to Json jsonobj <- convertPreprocessorToJson(object) - + # Dump to string return(jsonobj$return_json_string()) } @@ -435,40 +475,66 @@ savePreprocessorToJsonString <- function(object){ #' preprocess_list <- preprocessTrainData(cov_mat) #' preprocessor_json <- convertPreprocessorToJson(preprocess_list$metadata) #' preprocessor_roundtrip <- createPreprocessorFromJson(preprocessor_json) -createPreprocessorFromJson <- function(json_object){ +createPreprocessorFromJson <- function(json_object) { # Initialize the metadata list metadata <- list() - + # Unpack internal scalars - metadata[["num_numeric_vars"]] <- json_object$get_integer("num_numeric_vars") - metadata[["num_ordered_cat_vars"]] <- json_object$get_integer("num_ordered_cat_vars") - metadata[["num_unordered_cat_vars"]] <- json_object$get_integer("num_unordered_cat_vars") - + metadata[["num_numeric_vars"]] <- json_object$get_integer( + "num_numeric_vars" + ) + metadata[["num_ordered_cat_vars"]] <- json_object$get_integer( + "num_ordered_cat_vars" + ) + metadata[["num_unordered_cat_vars"]] <- json_object$get_integer( + "num_unordered_cat_vars" + ) + # Unpack internal vectors metadata[["feature_types"]] <- json_object$get_vector("feature_types") - metadata[["original_var_indices"]] <- json_object$get_vector("original_var_indices") + metadata[["original_var_indices"]] <- json_object$get_vector( + "original_var_indices" + ) if (metadata$num_numeric_vars > 0) { - metadata[["numeric_vars"]] <- json_object$get_string_vector("numeric_vars") + metadata[["numeric_vars"]] <- json_object$get_string_vector( + "numeric_vars" + ) } if (metadata$num_ordered_cat_vars > 0) { - metadata[["ordered_cat_vars"]] <- json_object$get_string_vector("ordered_cat_vars") + metadata[["ordered_cat_vars"]] <- json_object$get_string_vector( + "ordered_cat_vars" + ) ordered_unique_levels <- list() for (i in 1:metadata$num_ordered_cat_vars) { - var_key <- json_object$get_string(paste0("key_", i), "ordered_unique_level_keys") - ordered_unique_levels[[var_key]] <- json_object$get_string_vector(var_key, "ordered_unique_levels") + var_key <- json_object$get_string( + paste0("key_", i), + "ordered_unique_level_keys" + ) + ordered_unique_levels[[var_key]] <- json_object$get_string_vector( + var_key, + "ordered_unique_levels" + ) } metadata[["ordered_unique_levels"]] <- ordered_unique_levels } if (metadata$num_unordered_cat_vars > 0) { - metadata[["unordered_cat_vars"]] <- json_object$get_string_vector("unordered_cat_vars") + metadata[["unordered_cat_vars"]] <- json_object$get_string_vector( + "unordered_cat_vars" + ) unordered_unique_levels <- list() for (i in 1:metadata$num_unordered_cat_vars) { - var_key <- json_object$get_string(paste0("key_", i), "unordered_unique_level_keys") - unordered_unique_levels[[var_key]] <- json_object$get_string_vector(var_key, "unordered_unique_levels") + var_key <- json_object$get_string( + paste0("key_", i), + "unordered_unique_level_keys" + ) + unordered_unique_levels[[var_key]] <- json_object$get_string_vector( + var_key, + "unordered_unique_levels" + ) } metadata[["unordered_unique_levels"]] <- unordered_unique_levels } - + return(metadata) } @@ -484,27 +550,27 @@ createPreprocessorFromJson <- function(json_object){ #' preprocess_list <- preprocessTrainData(cov_mat) #' preprocessor_json_string <- savePreprocessorToJsonString(preprocess_list$metadata) #' preprocessor_roundtrip <- createPreprocessorFromJsonString(preprocessor_json_string) -createPreprocessorFromJsonString <- function(json_string){ +createPreprocessorFromJsonString <- function(json_string) { # Load a `CppJson` object from string preprocessor_json <- createCppJsonString(json_string) - + # Create and return the BCF object preprocessor_object <- createPreprocessorFromJson(preprocessor_json) - + return(preprocessor_object) } -#' Preprocess a dataframe of covariate values, converting categorical variables -#' to integers and one-hot encoding if need be. Returns a list including a +#' Preprocess a dataframe of covariate values, converting categorical variables +#' to integers and one-hot encoding if need be. Returns a list including a #' matrix of preprocessed covariate values and associated tracking. #' -#' @param input_data Dataframe or matrix of covariates. Users may pre-process any +#' @param input_data Dataframe or matrix of covariates. Users may pre-process any #' categorical variables as factors but it is not necessary. #' @param ordered_cat_vars (Optional) Vector of names of ordered categorical variables, or vector of column indices if `input_data` is a matrix. #' @param unordered_cat_vars (Optional) Vector of names of unordered categorical variables, or vector of column indices if `input_data` is a matrix. #' -#' @return List with preprocessed data and details on the number of each type -#' of variable, unique categories associated with categorical variables, and the +#' @return List with preprocessed data and details on the number of each type +#' of variable, unique categories associated with categorical variables, and the #' vector of feature types needed for calls to BART and BCF. #' @noRd #' @@ -512,15 +578,26 @@ createPreprocessorFromJsonString <- function(json_string){ #' cov_df <- data.frame(x1 = 1:5, x2 = 5:1, x3 = 6:10) #' preprocess_list <- createForestCovariates(cov_df) #' X <- preprocess_list$X -createForestCovariates <- function(input_data, ordered_cat_vars = NULL, unordered_cat_vars = NULL) { +createForestCovariates <- function( + input_data, + ordered_cat_vars = NULL, + unordered_cat_vars = NULL +) { if (is.matrix(input_data)) { input_df <- as.data.frame(input_data) names(input_df) <- paste0("x", 1:ncol(input_data)) if (!is.null(ordered_cat_vars)) { - if (is.numeric(ordered_cat_vars)) ordered_cat_vars <- paste0("x", as.integer(ordered_cat_vars)) + if (is.numeric(ordered_cat_vars)) { + ordered_cat_vars <- paste0("x", as.integer(ordered_cat_vars)) + } } if (!is.null(unordered_cat_vars)) { - if (is.numeric(unordered_cat_vars)) unordered_cat_vars <- paste0("x", as.integer(unordered_cat_vars)) + if (is.numeric(unordered_cat_vars)) { + unordered_cat_vars <- paste0( + "x", + as.integer(unordered_cat_vars) + ) + } } } else if (is.data.frame(input_data)) { input_df <- input_data @@ -528,10 +605,16 @@ createForestCovariates <- function(input_data, ordered_cat_vars = NULL, unordere stop("input_data must be either a matrix or a data frame") } df_vars <- names(input_df) - if (is.null(ordered_cat_vars)) ordered_cat_matches <- rep(FALSE, length(df_vars)) - else ordered_cat_matches <- df_vars %in% ordered_cat_vars - if (is.null(unordered_cat_vars)) unordered_cat_matches <- rep(FALSE, length(df_vars)) - else unordered_cat_matches <- df_vars %in% unordered_cat_vars + if (is.null(ordered_cat_vars)) { + ordered_cat_matches <- rep(FALSE, length(df_vars)) + } else { + ordered_cat_matches <- df_vars %in% ordered_cat_vars + } + if (is.null(unordered_cat_vars)) { + unordered_cat_matches <- rep(FALSE, length(df_vars)) + } else { + unordered_cat_matches <- df_vars %in% unordered_cat_vars + } numeric_matches <- ((!ordered_cat_matches) & (!unordered_cat_matches)) ordered_cat_vars <- df_vars[ordered_cat_matches] unordered_cat_vars <- df_vars[unordered_cat_matches] @@ -539,46 +622,54 @@ createForestCovariates <- function(input_data, ordered_cat_vars = NULL, unordere num_ordered_cat_vars <- length(ordered_cat_vars) num_unordered_cat_vars <- length(unordered_cat_vars) num_numeric_vars <- length(numeric_vars) - if (num_ordered_cat_vars > 0) ordered_cat_df <- input_df[,ordered_cat_vars,drop=FALSE] - if (num_unordered_cat_vars > 0) unordered_cat_df <- input_df[,unordered_cat_vars,drop=FALSE] - if (num_numeric_vars > 0) numeric_df <- input_df[,numeric_vars,drop=FALSE] - + if (num_ordered_cat_vars > 0) { + ordered_cat_df <- input_df[, ordered_cat_vars, drop = FALSE] + } + if (num_unordered_cat_vars > 0) { + unordered_cat_df <- input_df[, unordered_cat_vars, drop = FALSE] + } + if (num_numeric_vars > 0) { + numeric_df <- input_df[, numeric_vars, drop = FALSE] + } + # Empty outputs X <- double(0) unordered_unique_levels <- list() ordered_unique_levels <- list() feature_types <- integer(0) - + # First, extract the numeric covariates if (num_numeric_vars > 0) { Xnum <- double(0) for (i in 1:ncol(numeric_df)) { - stopifnot(is.numeric(numeric_df[,i])) - Xnum <- cbind(Xnum, numeric_df[,i]) + stopifnot(is.numeric(numeric_df[, i])) + Xnum <- cbind(Xnum, numeric_df[, i]) } X <- cbind(X, unname(Xnum)) feature_types <- c(feature_types, rep(0, ncol(Xnum))) } - + # Next, run some simple preprocessing on the ordered categorical covariates if (num_ordered_cat_vars > 0) { Xordcat <- double(0) for (i in 1:ncol(ordered_cat_df)) { var_name <- names(ordered_cat_df)[i] - preprocess_list <- orderedCatInitializeAndPreprocess(ordered_cat_df[,i]) + preprocess_list <- orderedCatInitializeAndPreprocess(ordered_cat_df[, + i + ]) ordered_unique_levels[[var_name]] <- preprocess_list$unique_levels Xordcat <- cbind(Xordcat, preprocess_list$x_preprocessed) } X <- cbind(X, unname(Xordcat)) feature_types <- c(feature_types, rep(1, ncol(Xordcat))) } - + # Finally, one-hot encode the unordered categorical covariates if (num_unordered_cat_vars > 0) { one_hot_mats <- list() for (i in 1:ncol(unordered_cat_df)) { var_name <- names(unordered_cat_df)[i] - encode_list <- oneHotInitializeAndEncode(unordered_cat_df[,i]) + encode_list <- oneHotInitializeAndEncode(unordered_cat_df[, i]) unordered_unique_levels[[var_name]] <- encode_list$unique_levels one_hot_mats[[var_name]] <- encode_list$Xtilde } @@ -586,12 +677,12 @@ createForestCovariates <- function(input_data, ordered_cat_vars = NULL, unordere X <- cbind(X, unname(Xcat)) feature_types <- c(feature_types, rep(1, ncol(Xcat))) } - + # Aggregate results into a list metadata <- list( - feature_types = feature_types, - num_ordered_cat_vars = num_ordered_cat_vars, - num_unordered_cat_vars = num_unordered_cat_vars, + feature_types = feature_types, + num_ordered_cat_vars = num_ordered_cat_vars, + num_unordered_cat_vars = num_unordered_cat_vars, num_numeric_vars = num_numeric_vars ) if (num_ordered_cat_vars > 0) { @@ -602,22 +693,24 @@ createForestCovariates <- function(input_data, ordered_cat_vars = NULL, unordere metadata[["unordered_cat_vars"]] = unordered_cat_vars metadata[["unordered_unique_levels"]] = unordered_unique_levels } - if (num_numeric_vars > 0) metadata[["numeric_vars"]] = numeric_vars + if (num_numeric_vars > 0) { + metadata[["numeric_vars"]] = numeric_vars + } output <- list( - data = X, + data = X, metadata = metadata ) - + return(output) } -#' Preprocess a dataframe of covariate values, converting categorical variables -#' to integers and one-hot encoding if need be. Returns a list including a +#' Preprocess a dataframe of covariate values, converting categorical variables +#' to integers and one-hot encoding if need be. Returns a list including a #' matrix of preprocessed covariate values and associated tracking. #' -#' @param input_data Dataframe or matrix of covariates. Users may pre-process any +#' @param input_data Dataframe or matrix of covariates. Users may pre-process any #' categorical variables as factors but it is not necessary. -#' @param metadata List containing information on variables, including train set +#' @param metadata List containing information on variables, including train set #' categories for categorical variables #' #' @return Preprocessed data with categorical variables appropriately preprocessed @@ -625,7 +718,7 @@ createForestCovariates <- function(input_data, ordered_cat_vars = NULL, unordere #' #' @examples #' cov_df <- data.frame(x1 = 1:5, x2 = 5:1, x3 = 6:10) -#' metadata <- list(num_ordered_cat_vars = 0, num_unordered_cat_vars = 0, +#' metadata <- list(num_ordered_cat_vars = 0, num_unordered_cat_vars = 0, #' num_numeric_vars = 3, numeric_vars = c("x1", "x2", "x3")) #' X_preprocessed <- createForestCovariatesFromMetadata(cov_df, metadata) createForestCovariatesFromMetadata <- function(input_data, metadata) { @@ -641,20 +734,20 @@ createForestCovariatesFromMetadata <- function(input_data, metadata) { num_ordered_cat_vars <- metadata$num_ordered_cat_vars num_unordered_cat_vars <- metadata$num_unordered_cat_vars num_numeric_vars <- metadata$num_numeric_vars - + if (num_ordered_cat_vars > 0) { ordered_cat_vars <- metadata$ordered_cat_vars - ordered_cat_df <- input_df[,ordered_cat_vars,drop=FALSE] + ordered_cat_df <- input_df[, ordered_cat_vars, drop = FALSE] } if (num_unordered_cat_vars > 0) { unordered_cat_vars <- metadata$unordered_cat_vars - unordered_cat_df <- input_df[,unordered_cat_vars,drop=FALSE] + unordered_cat_df <- input_df[, unordered_cat_vars, drop = FALSE] } if (num_numeric_vars > 0) { numeric_vars <- metadata$numeric_vars - numeric_df <- input_df[,numeric_vars,drop=FALSE] + numeric_df <- input_df[, numeric_vars, drop = FALSE] } - + # Empty outputs X <- double(0) @@ -662,51 +755,57 @@ createForestCovariatesFromMetadata <- function(input_data, metadata) { if (num_numeric_vars > 0) { Xnum <- double(0) for (i in 1:ncol(numeric_df)) { - stopifnot(is.numeric(numeric_df[,i])) - Xnum <- cbind(Xnum, numeric_df[,i]) + stopifnot(is.numeric(numeric_df[, i])) + Xnum <- cbind(Xnum, numeric_df[, i]) } X <- cbind(X, unname(Xnum)) } - + # Next, run some simple preprocessing on the ordered categorical covariates if (num_ordered_cat_vars > 0) { Xordcat <- double(0) for (i in 1:ncol(ordered_cat_df)) { var_name <- names(ordered_cat_df)[i] - x_preprocessed <- orderedCatPreprocess(ordered_cat_df[,i], metadata$ordered_unique_levels[[var_name]]) + x_preprocessed <- orderedCatPreprocess( + ordered_cat_df[, i], + metadata$ordered_unique_levels[[var_name]] + ) Xordcat <- cbind(Xordcat, x_preprocessed) } X <- cbind(X, unname(Xordcat)) } - + # Finally, one-hot encode the unordered categorical covariates if (num_unordered_cat_vars > 0) { one_hot_mats <- list() for (i in 1:ncol(unordered_cat_df)) { var_name <- names(unordered_cat_df)[i] - Xtilde <- oneHotEncode(unordered_cat_df[,i], metadata$unordered_unique_levels[[var_name]]) + Xtilde <- oneHotEncode( + unordered_cat_df[, i], + metadata$unordered_unique_levels[[var_name]] + ) one_hot_mats[[var_name]] <- Xtilde } Xcat <- do.call(cbind, one_hot_mats) X <- cbind(X, unname(Xcat)) } - + return(X) } -#' Convert a vector of unordered categorical data (either numeric or character -#' labels) to a "one-hot" encoded matrix in which a 1 in a column indicates -#' the presence of the relevant category. -#' -#' To allow for prediction on "unseen" categories in a test dataset, this -#' procedure pads the one-hot matrix with a blank "other" column. +#' Convert a vector of unordered categorical data (either numeric or character +#' labels) to a "one-hot" encoded matrix in which a 1 in a column indicates +#' the presence of the relevant category. +#' +#' To allow for prediction on "unseen" categories in a test dataset, this +#' procedure pads the one-hot matrix with a blank "other" column. #' Test set observations that contain categories not in `levels(factor(x_input))` #' will all be mapped to this column. #' -#' @param x_input Vector of unordered categorical data (typically either strings +#' @param x_input Vector of unordered categorical data (typically either strings #' integers, but this function also accepts floating point data). #' -#' @return List containing a binary one-hot matrix and the unique levels of the +#' @return List containing a binary one-hot matrix and the unique levels of the #' input variable. These unique levels are used in the BCF and BART functions. #' @noRd #' @@ -715,28 +814,30 @@ createForestCovariatesFromMetadata <- function(input_data, metadata) { #' x_onehot <- oneHotInitializeAndEncode(x) oneHotInitializeAndEncode <- function(x_input) { stopifnot((is.null(dim(x_input)) && length(x_input) > 0)) - if (is.factor(x_input) && is.ordered(x_input)) warning("One-hot encoding an ordered categorical variable") + if (is.factor(x_input) && is.ordered(x_input)) { + warning("One-hot encoding an ordered categorical variable") + } x_factor <- factor(x_input) unique_levels <- levels(x_factor) - Xtilde <- cbind(unname(model.matrix(~0+x_factor)), 0) + Xtilde <- cbind(unname(model.matrix(~ 0 + x_factor)), 0) output <- list(Xtilde = Xtilde, unique_levels = unique_levels) return(output) } -#' Convert a vector of unordered categorical data (either numeric or character -#' labels) to a "one-hot" encoded matrix in which a 1 in a column indicates -#' the presence of the relevant category. -#' -#' This procedure assumes that a reference set of observations for this variable +#' Convert a vector of unordered categorical data (either numeric or character +#' labels) to a "one-hot" encoded matrix in which a 1 in a column indicates +#' the presence of the relevant category. +#' +#' This procedure assumes that a reference set of observations for this variable #' (typically a training set that was used to sample a forest) has already been -#' one-hot encoded and that the unique levels of the training set variable are -#' available (and passed as `unique_levels`). Test set observations that contain -#' categories not in `unique_levels` will all be mapped to the last column of +#' one-hot encoded and that the unique levels of the training set variable are +#' available (and passed as `unique_levels`). Test set observations that contain +#' categories not in `unique_levels` will all be mapped to the last column of #' this matrix #' -#' @param x_input Vector of unordered categorical data (typically either strings +#' @param x_input Vector of unordered categorical data (typically either strings #' integers, but this function also accepts floating point data). -#' @param unique_levels Unique values of the categorical variable used to create +#' @param unique_levels Unique values of the categorical variable used to create #' the initial one-hot matrix (typically a training set) #' #' @return Binary one-hot matrix @@ -755,33 +856,43 @@ oneHotEncode <- function(x_input, unique_levels) { has_out_of_sample <- sum(out_of_sample) > 0 if (has_out_of_sample) { x_factor_insample <- factor(x_input[in_sample], levels = unique_levels) - Xtilde <- matrix(0, nrow = length(x_input), ncol = num_unique_levels + 1) - Xtilde_insample <- cbind(unname(model.matrix(~0+x_factor_insample)), 0) - Xtilde_out_of_sample <- cbind(matrix(0, nrow=sum(out_of_sample), ncol=num_unique_levels), 1) - Xtilde[in_sample,] <- Xtilde_insample - Xtilde[out_of_sample,] <- Xtilde_out_of_sample + Xtilde <- matrix( + 0, + nrow = length(x_input), + ncol = num_unique_levels + 1 + ) + Xtilde_insample <- cbind( + unname(model.matrix(~ 0 + x_factor_insample)), + 0 + ) + Xtilde_out_of_sample <- cbind( + matrix(0, nrow = sum(out_of_sample), ncol = num_unique_levels), + 1 + ) + Xtilde[in_sample, ] <- Xtilde_insample + Xtilde[out_of_sample, ] <- Xtilde_out_of_sample } else { x_factor <- factor(x_input, levels = unique_levels) - Xtilde <- cbind(unname(model.matrix(~0+x_factor)), 0) + Xtilde <- cbind(unname(model.matrix(~ 0 + x_factor)), 0) } return(Xtilde) } -#' Run some simple preprocessing of ordered categorical variables, converting -#' ordered levels to integers if necessary, and storing the unique levels of a +#' Run some simple preprocessing of ordered categorical variables, converting +#' ordered levels to integers if necessary, and storing the unique levels of a #' variable. #' -#' @param x_input Vector of ordered categorical data. If the data is not already -#' stored as an ordered factor, it will be converted to one using the default +#' @param x_input Vector of ordered categorical data. If the data is not already +#' stored as an ordered factor, it will be converted to one using the default #' sort order. #' -#' @return List containing a preprocessed vector of integer-converted ordered -#' categorical observations and the unique level of the original ordered +#' @return List containing a preprocessed vector of integer-converted ordered +#' categorical observations and the unique level of the original ordered #' categorical feature. #' @noRd #' #' @examples -#' x <- c("1. Strongly disagree", "3. Neither agree nor disagree", "2. Disagree", +#' x <- c("1. Strongly disagree", "3. Neither agree nor disagree", "2. Disagree", #' "4. Agree", "3. Neither agree nor disagree", "5. Strongly agree", "4. Agree") #' preprocess_list <- orderedCatInitializeAndPreprocess(x) #' x_preprocessed <- preprocess_list$x_preprocessed @@ -799,26 +910,26 @@ orderedCatInitializeAndPreprocess <- function(x_input) { return(list(x_preprocessed = x_preprocessed, unique_levels = unique_levels)) } -#' Run some simple preprocessing of ordered categorical variables, converting -#' ordered levels to integers if necessary, and storing the unique levels of a +#' Run some simple preprocessing of ordered categorical variables, converting +#' ordered levels to integers if necessary, and storing the unique levels of a #' variable. #' -#' @param x_input Vector of ordered categorical data. If the data is not already -#' stored as an ordered factor, it will be converted to one using the default +#' @param x_input Vector of ordered categorical data. If the data is not already +#' stored as an ordered factor, it will be converted to one using the default #' sort order. #' @param unique_levels Vector of unique levels for a categorical feature. #' @param var_name (Optional) Name of variable. #' -#' @return List containing a preprocessed vector of integer-converted ordered -#' categorical observations and the unique level of the original ordered +#' @return List containing a preprocessed vector of integer-converted ordered +#' categorical observations and the unique level of the original ordered #' categorical feature. #' @noRd #' #' @examples -#' x_levels <- c("1. Strongly disagree", "2. Disagree", -#' "3. Neither agree nor disagree", +#' x_levels <- c("1. Strongly disagree", "2. Disagree", +#' "3. Neither agree nor disagree", #' "4. Agree", "5. Strongly agree") -#' x <- c("1. Strongly disagree", "3. Neither agree nor disagree", "2. Disagree", +#' x <- c("1. Strongly disagree", "3. Neither agree nor disagree", "2. Disagree", #' "4. Agree", "3. Neither agree nor disagree", "5. Strongly agree", "4. Agree") #' x_processed <- orderedCatPreprocess(x, x_levels) orderedCatPreprocess <- function(x_input, unique_levels, var_name = NULL) { @@ -829,8 +940,17 @@ orderedCatPreprocess <- function(x_input, unique_levels, var_name = NULL) { # Run time checks levels_not_in_reflist <- !(levels(x_input) %in% unique_levels) if (sum(levels_not_in_reflist) > 0) { - if (!is.null(var_name)) warning_message <- paste0("Variable ", var_name, " includes ordered categorical levels not included in the original training set") - else warning_message <- paste0("Variable includes ordered categorical levels not included in the original training set") + if (!is.null(var_name)) { + warning_message <- paste0( + "Variable ", + var_name, + " includes ordered categorical levels not included in the original training set" + ) + } else { + warning_message <- paste0( + "Variable includes ordered categorical levels not included in the original training set" + ) + } warning(warning_message) } # Preprocessing @@ -843,8 +963,17 @@ orderedCatPreprocess <- function(x_input, unique_levels, var_name = NULL) { # Run time checks levels_not_in_reflist <- !(levels(x_factor) %in% unique_levels) if (sum(levels_not_in_reflist) > 0) { - if (!is.null(var_name)) warning_message <- paste0("Variable ", var_name, " includes ordered categorical levels not included in the original training set") - else warning_message <- paste0("Variable includes ordered categorical levels not included in the original training set") + if (!is.null(var_name)) { + warning_message <- paste0( + "Variable ", + var_name, + " includes ordered categorical levels not included in the original training set" + ) + } else { + warning_message <- paste0( + "Variable includes ordered categorical levels not included in the original training set" + ) + } warning(warning_message) } # Preprocessing @@ -856,9 +985,9 @@ orderedCatPreprocess <- function(x_input, unique_levels, var_name = NULL) { return(x_preprocessed) } -#' Convert scalar input to vector of dimension `output_size`, +#' Convert scalar input to vector of dimension `output_size`, #' or check that input array is equivalent to a vector of dimension `output_size`. -#' +#' #' @param input Input to be converted to a vector (or passed through as-is) #' @param output_size Intended size of the output vector #' @return A vector of length `output_size` @@ -872,19 +1001,21 @@ expand_dims_1d <- function(input, output_size) { } output <- input } else { - stop("`input` must be either a 1D numpy array or a scalar that can be repeated `output_size` times") + stop( + "`input` must be either a 1D numpy array or a scalar that can be repeated `output_size` times" + ) } return(output) } -#' Ensures that input is propagated appropriately to a matrix of dimension `output_rows` x `output_cols`. +#' Ensures that input is propagated appropriately to a matrix of dimension `output_rows` x `output_cols`. #' Handles the following cases: #' 1. `input` is a scalar: output is simply a (`output_rows`, `output_cols`) matrix with `input` repeated for each element #' 2. `input` is a vector of length `output_rows`: output is a (`output_rows`, `output_cols`) array with `input` broadcast across each of `output_cols` columns #' 3. `input` is a vector of length `output_cols`: output is a (`output_rows`, `output_cols`) array with `input` broadcast across each of `output_rows` rows #' 4. `input` is a matrix of dimension (`output_rows`, `output_cols`): input is passed through as-is #' All other cases throw an error. -#' +#' #' @param input Input to be converted to a matrix (or passed through as-is) #' @param output_rows Intended number of rows in the output array #' @param output_cols Intended number of columns in the output array @@ -892,14 +1023,27 @@ expand_dims_1d <- function(input, output_size) { #' @export expand_dims_2d <- function(input, output_rows, output_cols) { if (length(input) == 1) { - output <- matrix(rep(input, output_rows * output_cols), ncol = output_cols) + output <- matrix( + rep(input, output_rows * output_cols), + ncol = output_cols + ) } else if (is.numeric(input)) { if (length(input) == output_cols) { - output <- matrix(rep(input, output_rows), nrow=output_rows, byrow = T) + output <- matrix( + rep(input, output_rows), + nrow = output_rows, + byrow = T + ) } else if (length(input) == output_rows) { - output <- matrix(rep(input, output_cols), ncol=output_cols, byrow = F) + output <- matrix( + rep(input, output_cols), + ncol = output_cols, + byrow = F + ) } else { - stop("If `input` is a vector, it must either contain `output_rows` or `output_cols` elements") + stop( + "If `input` is a vector, it must either contain `output_rows` or `output_cols` elements" + ) } } else if (is.matrix(input)) { if (nrow(input) != output_rows) { @@ -915,9 +1059,9 @@ expand_dims_2d <- function(input, output_rows, output_cols) { return(output) } -#' Convert scalar input to square matrix of dimension `output_size` x `output_size` with `input` along the diagonal, +#' Convert scalar input to square matrix of dimension `output_size` x `output_size` with `input` along the diagonal, #' or check that input array is equivalent to a square matrix of dimension `output_size` x `output_size`. -#' +#' #' @param input Input to be converted to a square matrix (or passed through as-is) #' @param output_size Intended row and column dimension of the square output matrix #' @return A square matrix of dimension `output_size` x `output_size` @@ -930,7 +1074,9 @@ expand_dims_2d_diag <- function(input, output_size) { stop("`input` must be a square matrix") } if (nrow(input) != output_size) { - stop("`input` must be a square matrix with `output_size` rows and columns") + stop( + "`input` must be a square matrix with `output_size` rows and columns" + ) } output <- input } else { diff --git a/R/variance.R b/R/variance.R index 32fc9a21..b0ad722e 100644 --- a/R/variance.R +++ b/R/variance.R @@ -7,7 +7,7 @@ #' @param b Global variance scale parameter #' @return None #' @export -#' +#' #' @examples #' X <- matrix(runif(10*100), ncol = 10) #' y <- -5 + 10*(X[,1] > 0.5) + rnorm(100) @@ -18,8 +18,20 @@ #' a <- 1.0 #' b <- 1.0 #' sigma2 <- sampleGlobalErrorVarianceOneIteration(outcome, forest_dataset, rng, a, b) -sampleGlobalErrorVarianceOneIteration <- function(residual, dataset, rng, a, b) { - return(sample_sigma2_one_iteration_cpp(residual$data_ptr, dataset$data_ptr, rng$rng_ptr, a, b)) +sampleGlobalErrorVarianceOneIteration <- function( + residual, + dataset, + rng, + a, + b +) { + return(sample_sigma2_one_iteration_cpp( + residual$data_ptr, + dataset$data_ptr, + rng$rng_ptr, + a, + b + )) } #' Sample one iteration of the leaf parameter variance model (only for univariate basis and constant leaf!) @@ -30,7 +42,7 @@ sampleGlobalErrorVarianceOneIteration <- function(residual, dataset, rng, a, b) #' @param b Leaf variance scale parameter #' @return None #' @export -#' +#' #' @examples #' num_trees <- 100 #' leaf_dimension <- 1