Skip to content

Commit aee9b14

Browse files
authored
Merge pull request #71 from StochasticTree/multi_chain
Enable multi-chain sampling (serial or parallel)
2 parents 27fbaca + cce5966 commit aee9b14

File tree

106 files changed

+10520
-1692
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

106 files changed

+10520
-1692
lines changed

.gitignore

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,16 @@
11
## System and data files
22
*.pdf
33
*.csv
4-
*.txt
54
*.DS_Store
65
lib/
76
build/
87
.vscode/
98
xcode/
109
*.json
1110
.vs/
11+
cpp_docs/doxyoutput/html
12+
cpp_docs/doxyoutput/xml
13+
cpp_docs/doxyoutput/latex
1214

1315
## R gitignore
1416

NAMESPACE

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,11 +9,19 @@ export(bcf)
99
export(calibrate_inverse_gamma_error_variance)
1010
export(computeForestKernels)
1111
export(computeForestLeafIndices)
12+
export(convertBARTModelToJson)
1213
export(convertBCFModelToJson)
14+
export(createBARTModelFromCombinedJson)
15+
export(createBARTModelFromCombinedJsonString)
16+
export(createBARTModelFromJson)
17+
export(createBARTModelFromJsonFile)
18+
export(createBARTModelFromJsonString)
1319
export(createBCFModelFromJson)
1420
export(createBCFModelFromJsonFile)
21+
export(createBCFModelFromJsonString)
1522
export(createCppJson)
1623
export(createCppJsonFile)
24+
export(createCppJsonString)
1725
export(createForestContainer)
1826
export(createForestCovariates)
1927
export(createForestCovariatesFromMetadata)
@@ -27,7 +35,11 @@ export(createRandomEffectsDataset)
2735
export(createRandomEffectsModel)
2836
export(createRandomEffectsTracker)
2937
export(getRandomEffectSamples)
38+
export(loadForestContainerCombinedJson)
39+
export(loadForestContainerCombinedJsonString)
3040
export(loadForestContainerJson)
41+
export(loadRandomEffectSamplesCombinedJson)
42+
export(loadRandomEffectSamplesCombinedJsonString)
3143
export(loadRandomEffectSamplesJson)
3244
export(loadScalarJson)
3345
export(loadVectorJson)
@@ -43,7 +55,10 @@ export(preprocessTrainDataFrame)
4355
export(preprocessTrainMatrix)
4456
export(sample_sigma2_one_iteration)
4557
export(sample_tau_one_iteration)
58+
export(saveBARTModelToJsonFile)
59+
export(saveBARTModelToJsonString)
4660
export(saveBCFModelToJsonFile)
61+
export(saveBCFModelToJsonString)
4762
importFrom(R6,R6Class)
4863
importFrom(stats,lm)
4964
importFrom(stats,model.matrix)

R/bart.R

Lines changed: 921 additions & 99 deletions
Large diffs are not rendered by default.

R/bcf.R

Lines changed: 356 additions & 55 deletions
Large diffs are not rendered by default.

R/cpp11.R

Lines changed: 56 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -104,6 +104,26 @@ rfx_group_ids_from_json_cpp <- function(json_ptr, rfx_label) {
104104
.Call(`_stochtree_rfx_group_ids_from_json_cpp`, json_ptr, rfx_label)
105105
}
106106

107+
rfx_container_append_from_json_cpp <- function(rfx_container_ptr, json_ptr, rfx_label) {
108+
invisible(.Call(`_stochtree_rfx_container_append_from_json_cpp`, rfx_container_ptr, json_ptr, rfx_label))
109+
}
110+
111+
rfx_container_from_json_string_cpp <- function(json_string, rfx_label) {
112+
.Call(`_stochtree_rfx_container_from_json_string_cpp`, json_string, rfx_label)
113+
}
114+
115+
rfx_label_mapper_from_json_string_cpp <- function(json_string, rfx_label) {
116+
.Call(`_stochtree_rfx_label_mapper_from_json_string_cpp`, json_string, rfx_label)
117+
}
118+
119+
rfx_group_ids_from_json_string_cpp <- function(json_string, rfx_label) {
120+
.Call(`_stochtree_rfx_group_ids_from_json_string_cpp`, json_string, rfx_label)
121+
}
122+
123+
rfx_container_append_from_json_string_cpp <- function(rfx_container_ptr, json_string, rfx_label) {
124+
invisible(.Call(`_stochtree_rfx_container_append_from_json_string_cpp`, rfx_container_ptr, json_string, rfx_label))
125+
}
126+
107127
rfx_model_cpp <- function(num_components, num_groups) {
108128
.Call(`_stochtree_rfx_model_cpp`, num_components, num_groups)
109129
}
@@ -188,14 +208,26 @@ rfx_label_mapper_to_list_cpp <- function(label_mapper_ptr) {
188208
.Call(`_stochtree_rfx_label_mapper_to_list_cpp`, label_mapper_ptr)
189209
}
190210

191-
forest_container_cpp <- function(num_trees, output_dimension, is_leaf_constant) {
192-
.Call(`_stochtree_forest_container_cpp`, num_trees, output_dimension, is_leaf_constant)
211+
forest_container_cpp <- function(num_trees, output_dimension, is_leaf_constant, is_exponentiated) {
212+
.Call(`_stochtree_forest_container_cpp`, num_trees, output_dimension, is_leaf_constant, is_exponentiated)
193213
}
194214

195215
forest_container_from_json_cpp <- function(json_ptr, forest_label) {
196216
.Call(`_stochtree_forest_container_from_json_cpp`, json_ptr, forest_label)
197217
}
198218

219+
forest_container_append_from_json_cpp <- function(forest_sample_ptr, json_ptr, forest_label) {
220+
invisible(.Call(`_stochtree_forest_container_append_from_json_cpp`, forest_sample_ptr, json_ptr, forest_label))
221+
}
222+
223+
forest_container_from_json_string_cpp <- function(json_string, forest_label) {
224+
.Call(`_stochtree_forest_container_from_json_string_cpp`, json_string, forest_label)
225+
}
226+
227+
forest_container_append_from_json_string_cpp <- function(forest_sample_ptr, json_string, forest_label) {
228+
invisible(.Call(`_stochtree_forest_container_append_from_json_string_cpp`, forest_sample_ptr, json_string, forest_label))
229+
}
230+
199231
num_samples_forest_container_cpp <- function(forest_samples) {
200232
.Call(`_stochtree_num_samples_forest_container_cpp`, forest_samples)
201233
}
@@ -284,6 +316,10 @@ set_leaf_vector_forest_container_cpp <- function(forest_samples, leaf_vector) {
284316
invisible(.Call(`_stochtree_set_leaf_vector_forest_container_cpp`, forest_samples, leaf_vector))
285317
}
286318

319+
initialize_forest_model_cpp <- function(data, residual, forest_samples, tracker, init_values, leaf_model_int) {
320+
invisible(.Call(`_stochtree_initialize_forest_model_cpp`, data, residual, forest_samples, tracker, init_values, leaf_model_int))
321+
}
322+
287323
adjust_residual_forest_container_cpp <- function(data, residual, forest_samples, tracker, requires_basis, forest_num, add) {
288324
invisible(.Call(`_stochtree_adjust_residual_forest_container_cpp`, data, residual, forest_samples, tracker, requires_basis, forest_num, add))
289325
}
@@ -332,16 +368,16 @@ forest_kernel_compute_kernel_train_test_cpp <- function(forest_kernel, covariate
332368
.Call(`_stochtree_forest_kernel_compute_kernel_train_test_cpp`, forest_kernel, covariates_train, covariates_test, forest_container, forest_num)
333369
}
334370

335-
sample_gfr_one_iteration_cpp <- function(data, residual, forest_samples, tracker, split_prior, rng, feature_types, cutpoint_grid_size, leaf_model_scale_input, variable_weights, global_variance, leaf_model_int, pre_initialized) {
336-
invisible(.Call(`_stochtree_sample_gfr_one_iteration_cpp`, data, residual, forest_samples, tracker, split_prior, rng, feature_types, cutpoint_grid_size, leaf_model_scale_input, variable_weights, global_variance, leaf_model_int, pre_initialized))
371+
sample_gfr_one_iteration_cpp <- function(data, residual, forest_samples, tracker, split_prior, rng, feature_types, cutpoint_grid_size, leaf_model_scale_input, variable_weights, a_forest, b_forest, global_variance, leaf_model_int, pre_initialized) {
372+
invisible(.Call(`_stochtree_sample_gfr_one_iteration_cpp`, data, residual, forest_samples, tracker, split_prior, rng, feature_types, cutpoint_grid_size, leaf_model_scale_input, variable_weights, a_forest, b_forest, global_variance, leaf_model_int, pre_initialized))
337373
}
338374

339-
sample_mcmc_one_iteration_cpp <- function(data, residual, forest_samples, tracker, split_prior, rng, feature_types, cutpoint_grid_size, leaf_model_scale_input, variable_weights, global_variance, leaf_model_int, pre_initialized) {
340-
invisible(.Call(`_stochtree_sample_mcmc_one_iteration_cpp`, data, residual, forest_samples, tracker, split_prior, rng, feature_types, cutpoint_grid_size, leaf_model_scale_input, variable_weights, global_variance, leaf_model_int, pre_initialized))
375+
sample_mcmc_one_iteration_cpp <- function(data, residual, forest_samples, tracker, split_prior, rng, feature_types, cutpoint_grid_size, leaf_model_scale_input, variable_weights, a_forest, b_forest, global_variance, leaf_model_int, pre_initialized) {
376+
invisible(.Call(`_stochtree_sample_mcmc_one_iteration_cpp`, data, residual, forest_samples, tracker, split_prior, rng, feature_types, cutpoint_grid_size, leaf_model_scale_input, variable_weights, a_forest, b_forest, global_variance, leaf_model_int, pre_initialized))
341377
}
342378

343-
sample_sigma2_one_iteration_cpp <- function(residual, rng, a, b) {
344-
.Call(`_stochtree_sample_sigma2_one_iteration_cpp`, residual, rng, a, b)
379+
sample_sigma2_one_iteration_cpp <- function(residual, dataset, rng, a, b) {
380+
.Call(`_stochtree_sample_sigma2_one_iteration_cpp`, residual, dataset, rng, a, b)
345381
}
346382

347383
sample_tau_one_iteration_cpp <- function(forest_samples, rng, a, b, sample_num) {
@@ -472,10 +508,18 @@ json_add_rfx_groupids_cpp <- function(json_ptr, groupids) {
472508
.Call(`_stochtree_json_add_rfx_groupids_cpp`, json_ptr, groupids)
473509
}
474510

475-
json_save_cpp <- function(json_ptr, filename) {
476-
invisible(.Call(`_stochtree_json_save_cpp`, json_ptr, filename))
511+
get_json_string_cpp <- function(json_ptr) {
512+
.Call(`_stochtree_get_json_string_cpp`, json_ptr)
513+
}
514+
515+
json_save_file_cpp <- function(json_ptr, filename) {
516+
invisible(.Call(`_stochtree_json_save_file_cpp`, json_ptr, filename))
517+
}
518+
519+
json_load_file_cpp <- function(json_ptr, filename) {
520+
invisible(.Call(`_stochtree_json_load_file_cpp`, json_ptr, filename))
477521
}
478522

479-
json_load_cpp <- function(json_ptr, filename) {
480-
invisible(.Call(`_stochtree_json_load_cpp`, json_ptr, filename))
523+
json_load_string_cpp <- function(json_ptr, json_string) {
524+
invisible(.Call(`_stochtree_json_load_string_cpp`, json_ptr, json_string))
481525
}

R/forest.R

Lines changed: 54 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -16,20 +16,48 @@ ForestSamples <- R6::R6Class(
1616
#' @param num_trees Number of trees
1717
#' @param output_dimension Dimensionality of the outcome model
1818
#' @param is_leaf_constant Whether leaf is constant
19+
#' @param is_exponentiated Whether forest predictions should be exponentiated before being returned
1920
#' @return A new `ForestContainer` object.
20-
initialize = function(num_trees, output_dimension=1, is_leaf_constant=F) {
21-
self$forest_container_ptr <- forest_container_cpp(num_trees, output_dimension, is_leaf_constant)
21+
initialize = function(num_trees, output_dimension=1, is_leaf_constant=F, is_exponentiated=F) {
22+
self$forest_container_ptr <- forest_container_cpp(num_trees, output_dimension, is_leaf_constant, is_exponentiated)
2223
},
2324

2425
#' @description
25-
#' Create a new ForestContainer object from a json object
26+
#' Create a new `ForestContainer` object from a json object
2627
#' @param json_object Object of class `CppJson`
2728
#' @param json_forest_label Label referring to a particular forest (e.g. "forest_0") in the overall json hierarchy
2829
#' @return A new `ForestContainer` object.
2930
load_from_json = function(json_object, json_forest_label) {
3031
self$forest_container_ptr <- forest_container_from_json_cpp(json_object$json_ptr, json_forest_label)
3132
},
3233

34+
#' @description
35+
#' Append to a `ForestContainer` object from a json object
36+
#' @param json_object Object of class `CppJson`
37+
#' @param json_forest_label Label referring to a particular forest (e.g. "forest_0") in the overall json hierarchy
38+
#' @return NULL
39+
append_from_json = function(json_object, json_forest_label) {
40+
forest_container_append_from_json_cpp(self$forest_container_ptr, json_object$json_ptr, json_forest_label)
41+
},
42+
43+
#' @description
44+
#' Create a new `ForestContainer` object from a json string
45+
#' @param json_string JSON string which parses into object of class `CppJson`
46+
#' @param json_forest_label Label referring to a particular forest (e.g. "forest_0") in the overall json hierarchy
47+
#' @return A new `ForestContainer` object.
48+
load_from_json_string = function(json_string, json_forest_label) {
49+
self$forest_container_ptr <- forest_container_from_json_string_cpp(json_string, json_forest_label)
50+
},
51+
52+
#' @description
53+
#' Append to a `ForestContainer` object from a json string
54+
#' @param json_string JSON string which parses into object of class `CppJson`
55+
#' @param json_forest_label Label referring to a particular forest (e.g. "forest_0") in the overall json hierarchy
56+
#' @return NULL
57+
append_from_json_string = function(json_string, json_forest_label) {
58+
forest_container_append_from_json_string_cpp(self$forest_container_ptr, json_string, json_forest_label)
59+
},
60+
3361
#' @description
3462
#' Predict every tree ensemble on every sample in `forest_dataset`
3563
#' @param forest_dataset `ForestDataset` R class
@@ -106,6 +134,26 @@ ForestSamples <- R6::R6Class(
106134
}
107135
},
108136

137+
#' @description
138+
#' Set a constant predicted value for every tree in the ensemble.
139+
#' Stops program if any tree is more than a root node.
140+
#' @param dataset `ForestDataset` Dataset class (covariates, basis, etc...)
141+
#' @param outcome `Outcome` Outcome class (residual / partial residual)
142+
#' @param forest_model `ForestModel` object storing tracking structures used in training / sampling
143+
#' @param leaf_model_int Integer value encoding the leaf model type (0 = constant gaussian, 1 = univariate gaussian, 2 = multivariate gaussian, 3 = log linear variance).
144+
#' @param leaf_value Constant leaf value(s) to be fixed for each tree in the ensemble. Can be either a single number or a vector, depending on the forest's leaf dimension.
145+
prepare_for_sampler = function(dataset, outcome, forest_model, leaf_model_int, leaf_value) {
146+
stopifnot(!is.null(dataset$data_ptr))
147+
stopifnot(!is.null(outcome$data_ptr))
148+
stopifnot(!is.null(forest_model$tracker_ptr))
149+
stopifnot(!is.null(self$forest_container_ptr))
150+
stopifnot(num_samples_forest_container_cpp(self$forest_container_ptr) == 0)
151+
152+
# Initialize the model
153+
initialize_forest_model_cpp(dataset$data_ptr, outcome$data_ptr, self$forest_container_ptr,
154+
forest_model$tracker_ptr, leaf_value, leaf_model_int)
155+
},
156+
109157
#' @description
110158
#' Adjusts residual based on the predictions of a forest
111159
#'
@@ -294,11 +342,12 @@ ForestSamples <- R6::R6Class(
294342
#' @param num_trees Number of trees
295343
#' @param output_dimension Dimensionality of the outcome model
296344
#' @param is_leaf_constant Whether leaf is constant
345+
#' @param is_exponentiated Whether forest predictions should be exponentiated before being returned
297346
#'
298347
#' @return `ForestSamples` object
299348
#' @export
300-
createForestContainer <- function(num_trees, output_dimension=1, is_leaf_constant=F) {
349+
createForestContainer <- function(num_trees, output_dimension=1, is_leaf_constant=F, is_exponentiated=F) {
301350
return(invisible((
302-
ForestSamples$new(num_trees, output_dimension, is_leaf_constant)
351+
ForestSamples$new(num_trees, output_dimension, is_leaf_constant, is_exponentiated)
303352
)))
304353
}

R/kernel.R

Lines changed: 52 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -132,17 +132,30 @@ createForestKernel <- function() {
132132
#' corresponds to the observations for which outcomes are unobserved and must be estimated
133133
#' based on the kernels k(X_test,X_test), k(X_test,X_train), and k(X_train,X_train). If not provided,
134134
#' this function will only compute k(X_train, X_train).
135-
#' @param forest_num (Option) Index of the forest sample to use for kernel computation. If not provided,
135+
#' @param forest_num (Optional) Index of the forest sample to use for kernel computation. If not provided,
136136
#' this function will use the last forest.
137+
#' @param forest_type (Optional) Whether to compute the kernel from the mean or variance forest. Default: "mean". Specify "variance" for the variance forest.
138+
#' All other inputs are invalid. Must have sampled the relevant forest or an error will occur.
137139
#' @return List of kernel matrices. If `X_test = NULL`, the list contains
138140
#' one `n_train` x `n_train` matrix, where `n_train = nrow(X_train)`.
139141
#' This matrix is the kernel defined by `W_train %*% t(W_train)` where `W_train`
140142
#' is a matrix with `n_train` rows and as many columns as there are total leaves in an ensemble.
141143
#' If `X_test` is not `NULL`, the list contains two more matrices defined by
142144
#' `W_test %*% t(W_train)` and `W_test %*% t(W_test)`.
143145
#' @export
144-
computeForestKernels <- function(bart_model, X_train, X_test=NULL, forest_num=NULL) {
146+
computeForestKernels <- function(bart_model, X_train, X_test=NULL, forest_num=NULL, forest_type="mean") {
145147
stopifnot(class(bart_model)=="bartmodel")
148+
if (forest_type=="mean") {
149+
if (!bart_model$model_params$include_mean_forest) {
150+
stop("Mean forest was not sampled in the bart model provided")
151+
}
152+
} else if (forest_type=="variance") {
153+
if (!bart_model$model_params$include_variance_forest) {
154+
stop("Variance forest was not sampled in the bart model provided")
155+
}
156+
} else {
157+
stop("Must provide either 'mean' or 'variance' for the `forest_type` parameter")
158+
}
146159

147160
# Preprocess covariates
148161
if (!is.data.frame(X_train)) {
@@ -164,10 +177,17 @@ computeForestKernels <- function(bart_model, X_train, X_test=NULL, forest_num=NU
164177
num_samples <- bart_model$model_params$num_samples
165178
stopifnot(forest_num <= num_samples)
166179
sample_index <- ifelse(is.null(forest_num), num_samples-1, forest_num-1)
167-
return(forest_kernel$compute_kernel(
168-
covariates_train = X_train, covariates_test = X_test,
169-
forest_container = bart_model$forests, forest_num = sample_index
170-
))
180+
if (forest_type=="mean") {
181+
return(forest_kernel$compute_kernel(
182+
covariates_train = X_train, covariates_test = X_test,
183+
forest_container = bart_model$mean_forests, forest_num = sample_index
184+
))
185+
} else if (forest_type=="variance") {
186+
return(forest_kernel$compute_kernel(
187+
covariates_train = X_train, covariates_test = X_test,
188+
forest_container = bart_model$variance_forests, forest_num = sample_index
189+
))
190+
}
171191
}
172192

173193
#' Compute and return a vector representation of a forest's leaf predictions for
@@ -192,21 +212,41 @@ computeForestKernels <- function(bart_model, X_train, X_test=NULL, forest_num=NU
192212
#' corresponds to the observations for which outcomes are unobserved and must be estimated
193213
#' based on the kernels k(X_test,X_test), k(X_test,X_train), and k(X_train,X_train). If not provided,
194214
#' this function will only compute k(X_train, X_train).
195-
#' @param forest_num (Option) Index of the forest sample to use for kernel computation. If not provided,
215+
#' @param forest_num (Optional) Index of the forest sample to use for kernel computation. If not provided,
196216
#' this function will use the last forest.
217+
#' @param forest_type (Optional) Whether to compute the kernel from the mean or variance forest. Default: "mean". Specify "variance" for the variance forest.
218+
#' All other inputs are invalid. Must have sampled the relevant forest or an error will occur.
197219
#' @return List of vectors. If `X_test = NULL`, the list contains
198220
#' one vector of length `n_train * num_trees`, where `n_train = nrow(X_train)`
199221
#' and `num_trees` is the number of trees in `bart_model`. If `X_test` is not `NULL`,
200222
#' the list contains another vector of length `n_test * num_trees`.
201223
#' @export
202-
computeForestLeafIndices <- function(bart_model, X_train, X_test=NULL, forest_num=NULL) {
224+
computeForestLeafIndices <- function(bart_model, X_train, X_test=NULL, forest_num=NULL, forest_type="mean") {
203225
stopifnot(class(bart_model)=="bartmodel")
226+
if (forest_type=="mean") {
227+
if (!bart_model$model_params$include_mean_forest) {
228+
stop("Mean forest was not sampled in the bart model provided")
229+
}
230+
} else if (forest_type=="variance") {
231+
if (!bart_model$model_params$include_variance_forest) {
232+
stop("Variance forest was not sampled in the bart model provided")
233+
}
234+
} else {
235+
stop("Must provide either 'mean' or 'variance' for the `forest_type` parameter")
236+
}
204237
forest_kernel <- createForestKernel()
205238
num_samples <- bart_model$model_params$num_samples
206239
stopifnot(forest_num <= num_samples)
207240
sample_index <- ifelse(is.null(forest_num), num_samples-1, forest_num-1)
208-
return(forest_kernel$compute_leaf_indices(
209-
covariates_train = X_train, covariates_test = X_test,
210-
forest_container = bart_model$forests, forest_num = sample_index
211-
))
241+
if (forest_type == "mean") {
242+
return(forest_kernel$compute_leaf_indices(
243+
covariates_train = X_train, covariates_test = X_test,
244+
forest_container = bart_model$mean_forests, forest_num = sample_index
245+
))
246+
} else if (forest_type == "variance") {
247+
return(forest_kernel$compute_leaf_indices(
248+
covariates_train = X_train, covariates_test = X_test,
249+
forest_container = bart_model$variance_forests, forest_num = sample_index
250+
))
251+
}
212252
}

0 commit comments

Comments
 (0)