Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
34 commits
Select commit Hold shift + click to select a range
2017053
Refactored the sampler classes into stateless templated functions
andrewherren Aug 24, 2024
c3bfa66
Fixed R package bug and rearranged tree_sampler header file
andrewherren Aug 24, 2024
e934db2
Added StochTree scope to sampler function calls
andrewherren Aug 24, 2024
9646a08
Refactor sampler iteration to avoid incremental object creation
andrewherren Aug 27, 2024
894deb2
Refactored R package C++ calls
andrewherren Aug 27, 2024
7e2a110
Updated python library C++ code
andrewherren Aug 27, 2024
64c19e8
Added include <variant>
andrewherren Aug 27, 2024
06efbb5
Updated unit tests
andrewherren Aug 27, 2024
ad03adb
Initial setup for building and publishing C++ documentation
andrewherren Aug 31, 2024
c444a62
Updated C++ documentation
andrewherren Sep 1, 2024
0a78499
Updated C++ documentation and doc build config
andrewherren Sep 2, 2024
5d135b2
Updated C++ documentation
andrewherren Sep 3, 2024
330abf8
Merge branch 'cpp_docs' into cpp-api-streamline
andrewherren Sep 5, 2024
6409d48
Updated C++ doc build instructions
andrewherren Sep 5, 2024
923ae54
Not-yet-fully-functional heteroskedasticity forest implementation
andrewherren Sep 12, 2024
e04c33a
Functional, but numerically incorrect heteroskedasticity BART impleme…
andrewherren Sep 17, 2024
7d7b432
Updated heteroscedasticity model
andrewherren Sep 17, 2024
a0e2422
Fixed prediction bug
andrewherren Sep 18, 2024
d7217b5
Added debugging scripts and data (and a non-working update of varianc…
andrewherren Sep 25, 2024
191fcb8
Parameterizing as precision, rather than variance (still not producin…
andrewherren Sep 25, 2024
dcbdd99
Updated variance forest code and demo
andrewherren Sep 26, 2024
01bd2d2
Added a parameter to rescale y to variance other than 1
andrewherren Sep 26, 2024
7dc9463
Rescale samples by variance_scale after sampling is complete
andrewherren Sep 26, 2024
decf243
Adding TODO
andrewherren Sep 26, 2024
b76423a
Correct log integrated likelihood
andrewherren Sep 26, 2024
9e815df
Updated BART docs and vignettes
andrewherren Sep 27, 2024
acda38a
Simplified variance model sufficient statistic class
andrewherren Oct 1, 2024
e00f9df
Converted internal heteroskedastic model back to variance, rather tha…
andrewherren Oct 1, 2024
150db50
Updated predict.bartmodel() function
andrewherren Oct 1, 2024
18fa86d
Merge branch 'main' into cpp-api-streamline
andrewherren Oct 1, 2024
78114d9
Update R unit tests
andrewherren Oct 1, 2024
e643c4b
Fixed python unit tests
andrewherren Oct 1, 2024
0abaaa0
Updated python interface to include variance forest
andrewherren Oct 4, 2024
9a3160f
Fixed python unit tests
andrewherren Oct 6, 2024
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
@@ -1,14 +1,16 @@
## System and data files
*.pdf
*.csv
*.txt
*.DS_Store
lib/
build/
.vscode/
xcode/
*.json
.vs/
cpp_docs/doxyoutput/html
cpp_docs/doxyoutput/xml
cpp_docs/doxyoutput/latex

## R gitignore

Expand Down
332 changes: 241 additions & 91 deletions R/bart.R

Large diffs are not rendered by default.

8 changes: 4 additions & 4 deletions R/bcf.R
Original file line number Diff line number Diff line change
Expand Up @@ -526,7 +526,7 @@ bcf <- function(X_train, Z_train, y_train, pi_train = NULL, group_ids_train = NU

# Sample variance parameters (if requested)
if (sample_sigma_global) {
global_var_samples[i] <- sample_sigma2_one_iteration(outcome_train, rng, a_global, b_global)
global_var_samples[i] <- sample_sigma2_one_iteration(outcome_train, forest_dataset_train, rng, a_global, b_global)
current_sigma2 <- global_var_samples[i]
}
if (sample_sigma_leaf_mu) {
Expand Down Expand Up @@ -578,7 +578,7 @@ bcf <- function(X_train, Z_train, y_train, pi_train = NULL, group_ids_train = NU

# Sample variance parameters (if requested)
if (sample_sigma_global) {
global_var_samples[i] <- sample_sigma2_one_iteration(outcome_train, rng, a_global, b_global)
global_var_samples[i] <- sample_sigma2_one_iteration(outcome_train, forest_dataset_train, rng, a_global, b_global)
current_sigma2 <- global_var_samples[i]
}
if (sample_sigma_leaf_tau) {
Expand Down Expand Up @@ -625,7 +625,7 @@ bcf <- function(X_train, Z_train, y_train, pi_train = NULL, group_ids_train = NU

# Sample variance parameters (if requested)
if (sample_sigma_global) {
global_var_samples[i] <- sample_sigma2_one_iteration(outcome_train, rng, a_global, b_global)
global_var_samples[i] <- sample_sigma2_one_iteration(outcome_train, forest_dataset_train, rng, a_global, b_global)
current_sigma2 <- global_var_samples[i]
}
if (sample_sigma_leaf_mu) {
Expand Down Expand Up @@ -677,7 +677,7 @@ bcf <- function(X_train, Z_train, y_train, pi_train = NULL, group_ids_train = NU

# Sample variance parameters (if requested)
if (sample_sigma_global) {
global_var_samples[i] <- sample_sigma2_one_iteration(outcome_train, rng, a_global, b_global)
global_var_samples[i] <- sample_sigma2_one_iteration(outcome_train, forest_dataset_train, rng, a_global, b_global)
current_sigma2 <- global_var_samples[i]
}
if (sample_sigma_leaf_tau) {
Expand Down
20 changes: 12 additions & 8 deletions R/cpp11.R
Original file line number Diff line number Diff line change
Expand Up @@ -188,8 +188,8 @@ rfx_label_mapper_to_list_cpp <- function(label_mapper_ptr) {
.Call(`_stochtree_rfx_label_mapper_to_list_cpp`, label_mapper_ptr)
}

forest_container_cpp <- function(num_trees, output_dimension, is_leaf_constant) {
.Call(`_stochtree_forest_container_cpp`, num_trees, output_dimension, is_leaf_constant)
forest_container_cpp <- function(num_trees, output_dimension, is_leaf_constant, is_exponentiated) {
.Call(`_stochtree_forest_container_cpp`, num_trees, output_dimension, is_leaf_constant, is_exponentiated)
}

forest_container_from_json_cpp <- function(json_ptr, forest_label) {
Expand Down Expand Up @@ -284,6 +284,10 @@ set_leaf_vector_forest_container_cpp <- function(forest_samples, leaf_vector) {
invisible(.Call(`_stochtree_set_leaf_vector_forest_container_cpp`, forest_samples, leaf_vector))
}

initialize_forest_model_cpp <- function(data, residual, forest_samples, tracker, init_values, leaf_model_int) {
invisible(.Call(`_stochtree_initialize_forest_model_cpp`, data, residual, forest_samples, tracker, init_values, leaf_model_int))
}

adjust_residual_forest_container_cpp <- function(data, residual, forest_samples, tracker, requires_basis, forest_num, add) {
invisible(.Call(`_stochtree_adjust_residual_forest_container_cpp`, data, residual, forest_samples, tracker, requires_basis, forest_num, add))
}
Expand Down Expand Up @@ -332,16 +336,16 @@ forest_kernel_compute_kernel_train_test_cpp <- function(forest_kernel, covariate
.Call(`_stochtree_forest_kernel_compute_kernel_train_test_cpp`, forest_kernel, covariates_train, covariates_test, forest_container, forest_num)
}

sample_gfr_one_iteration_cpp <- function(data, residual, forest_samples, tracker, split_prior, rng, feature_types, cutpoint_grid_size, leaf_model_scale_input, variable_weights, global_variance, leaf_model_int, pre_initialized) {
invisible(.Call(`_stochtree_sample_gfr_one_iteration_cpp`, data, residual, forest_samples, tracker, split_prior, rng, feature_types, cutpoint_grid_size, leaf_model_scale_input, variable_weights, global_variance, leaf_model_int, pre_initialized))
sample_gfr_one_iteration_cpp <- function(data, residual, forest_samples, tracker, split_prior, rng, feature_types, cutpoint_grid_size, leaf_model_scale_input, variable_weights, a_forest, b_forest, global_variance, leaf_model_int, pre_initialized) {
invisible(.Call(`_stochtree_sample_gfr_one_iteration_cpp`, data, residual, forest_samples, tracker, split_prior, rng, feature_types, cutpoint_grid_size, leaf_model_scale_input, variable_weights, a_forest, b_forest, global_variance, leaf_model_int, pre_initialized))
}

sample_mcmc_one_iteration_cpp <- function(data, residual, forest_samples, tracker, split_prior, rng, feature_types, cutpoint_grid_size, leaf_model_scale_input, variable_weights, global_variance, leaf_model_int, pre_initialized) {
invisible(.Call(`_stochtree_sample_mcmc_one_iteration_cpp`, data, residual, forest_samples, tracker, split_prior, rng, feature_types, cutpoint_grid_size, leaf_model_scale_input, variable_weights, global_variance, leaf_model_int, pre_initialized))
sample_mcmc_one_iteration_cpp <- function(data, residual, forest_samples, tracker, split_prior, rng, feature_types, cutpoint_grid_size, leaf_model_scale_input, variable_weights, a_forest, b_forest, global_variance, leaf_model_int, pre_initialized) {
invisible(.Call(`_stochtree_sample_mcmc_one_iteration_cpp`, data, residual, forest_samples, tracker, split_prior, rng, feature_types, cutpoint_grid_size, leaf_model_scale_input, variable_weights, a_forest, b_forest, global_variance, leaf_model_int, pre_initialized))
}

sample_sigma2_one_iteration_cpp <- function(residual, rng, a, b) {
.Call(`_stochtree_sample_sigma2_one_iteration_cpp`, residual, rng, a, b)
sample_sigma2_one_iteration_cpp <- function(residual, dataset, rng, a, b) {
.Call(`_stochtree_sample_sigma2_one_iteration_cpp`, residual, dataset, rng, a, b)
}

sample_tau_one_iteration_cpp <- function(forest_samples, rng, a, b, sample_num) {
Expand Down
30 changes: 26 additions & 4 deletions R/forest.R
Original file line number Diff line number Diff line change
Expand Up @@ -16,9 +16,10 @@ ForestSamples <- R6::R6Class(
#' @param num_trees Number of trees
#' @param output_dimension Dimensionality of the outcome model
#' @param is_leaf_constant Whether leaf is constant
#' @param is_exponentiated Whether forest predictions should be exponentiated before being returned
#' @return A new `ForestContainer` object.
initialize = function(num_trees, output_dimension=1, is_leaf_constant=F) {
self$forest_container_ptr <- forest_container_cpp(num_trees, output_dimension, is_leaf_constant)
initialize = function(num_trees, output_dimension=1, is_leaf_constant=F, is_exponentiated=F) {
self$forest_container_ptr <- forest_container_cpp(num_trees, output_dimension, is_leaf_constant, is_exponentiated)
},

#' @description
Expand Down Expand Up @@ -106,6 +107,26 @@ ForestSamples <- R6::R6Class(
}
},

#' @description
#' Set a constant predicted value for every tree in the ensemble.
#' Stops program if any tree is more than a root node.
#' @param dataset `ForestDataset` Dataset class (covariates, basis, etc...)
#' @param outcome `Outcome` Outcome class (residual / partial residual)
#' @param forest_model `ForestModel` object storing tracking structures used in training / sampling
#' @param leaf_model_int Integer value encoding the leaf model type (0 = constant gaussian, 1 = univariate gaussian, 2 = multivariate gaussian, 3 = log linear variance).
#' @param leaf_value Constant leaf value(s) to be fixed for each tree in the ensemble indexed by `forest_num`. Can be either a single number or a vector, depending on the forest's leaf dimension.
prepare_for_sampler = function(dataset, outcome, forest_model, leaf_model_int, leaf_value) {
stopifnot(!is.null(dataset$data_ptr))
stopifnot(!is.null(outcome$data_ptr))
stopifnot(!is.null(forest_model$tracker_ptr))
stopifnot(!is.null(self$forest_container_ptr))
stopifnot(num_samples_forest_container_cpp(self$forest_container_ptr) == 0)

# Initialize the model
initialize_forest_model_cpp(dataset$data_ptr, outcome$data_ptr, self$forest_container_ptr,
forest_model$tracker_ptr, leaf_value, leaf_model_int)
},

#' @description
#' Adjusts residual based on the predictions of a forest
#'
Expand Down Expand Up @@ -294,11 +315,12 @@ ForestSamples <- R6::R6Class(
#' @param num_trees Number of trees
#' @param output_dimension Dimensionality of the outcome model
#' @param is_leaf_constant Whether leaf is constant
#' @param is_exponentiated Whether forest predictions should be exponentiated before being returned
#'
#' @return `ForestSamples` object
#' @export
createForestContainer <- function(num_trees, output_dimension=1, is_leaf_constant=F) {
createForestContainer <- function(num_trees, output_dimension=1, is_leaf_constant=F, is_exponentiated=F) {
return(invisible((
ForestSamples$new(num_trees, output_dimension, is_leaf_constant)
ForestSamples$new(num_trees, output_dimension, is_leaf_constant, is_exponentiated)
)))
}
10 changes: 6 additions & 4 deletions R/model.R
Original file line number Diff line number Diff line change
Expand Up @@ -68,27 +68,29 @@ ForestModel <- R6::R6Class(
#' @param leaf_model_int Integer specifying the leaf model type (0 = constant leaf, 1 = univariate leaf regression, 2 = multivariate leaf regression)
#' @param leaf_model_scale Scale parameter used in the leaf node model (should be a q x q matrix where q is the dimensionality of the basis and is only >1 when `leaf_model_int = 2`)
#' @param variable_weights Vector specifying sampling probability for all p covariates in `forest_dataset`
#' @param a_forest Shape parameter on variance forest model (if applicable)
#' @param b_forest Scale parameter on variance forest model (if applicable)
#' @param global_scale Global variance parameter
#' @param cutpoint_grid_size (Optional) Number of unique cutpoints to consider (default: 500, currently only used when `GFR = TRUE`)
#' @param gfr (Optional) Whether or not the forest should be sampled using the "grow-from-root" (GFR) algorithm
#' @param pre_initialized (Optional) Whether or not the leaves are pre-initialized outside of the sampling loop (before any samples are drawn). In multi-forest implementations like BCF, this is true, though in the single-forest supervised learning implementation, we can let C++ do the initialization. Default: F.
sample_one_iteration = function(forest_dataset, residual, forest_samples, rng, feature_types,
leaf_model_int, leaf_model_scale, variable_weights,
global_scale, cutpoint_grid_size = 500, gfr = T,
pre_initialized = F) {
a_forest, b_forest, global_scale, cutpoint_grid_size = 500,
gfr = T, pre_initialized = F) {
if (gfr) {
sample_gfr_one_iteration_cpp(
forest_dataset$data_ptr, residual$data_ptr,
forest_samples$forest_container_ptr, self$tracker_ptr, self$tree_prior_ptr,
rng$rng_ptr, feature_types, cutpoint_grid_size, leaf_model_scale,
variable_weights, global_scale, leaf_model_int, pre_initialized
variable_weights, a_forest, b_forest, global_scale, leaf_model_int, pre_initialized
)
} else {
sample_mcmc_one_iteration_cpp(
forest_dataset$data_ptr, residual$data_ptr,
forest_samples$forest_container_ptr, self$tracker_ptr, self$tree_prior_ptr,
rng$rng_ptr, feature_types, cutpoint_grid_size, leaf_model_scale,
variable_weights, global_scale, leaf_model_int, pre_initialized
variable_weights, a_forest, b_forest, global_scale, leaf_model_int, pre_initialized
)
}
}
Expand Down
5 changes: 3 additions & 2 deletions R/variance.R
Original file line number Diff line number Diff line change
@@ -1,13 +1,14 @@
#' Sample one iteration of the (inverse gamma) global variance model
#'
#' @param residual Outcome class
#' @param dataset ForestDataset class
#' @param rng C++ random number generator
#' @param a Global variance shape parameter
#' @param b Global variance scale parameter
#'
#' @export
sample_sigma2_one_iteration <- function(residual, rng, a, b) {
return(sample_sigma2_one_iteration_cpp(residual$data_ptr, rng$rng_ptr, a, b))
sample_sigma2_one_iteration <- function(residual, dataset, rng, a, b) {
return(sample_sigma2_one_iteration_cpp(residual$data_ptr, dataset$data_ptr, rng$rng_ptr, a, b))
}

#' Sample one iteration of the leaf parameter variance model (only for univariate basis and constant leaf!)
Expand Down
Loading