Skip to content

Commit 78cc4c6

Browse files
authored
Merge pull request #178 from StochasticTree/train-set-predicition-caching
Initial BART refactor that avoids double predicting forests on the train dataset
2 parents 33174a6 + 49b51ca commit 78cc4c6

File tree

18 files changed

+604
-55
lines changed

18 files changed

+604
-55
lines changed

R/bart.R

Lines changed: 32 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -707,6 +707,8 @@ bart <- function(X_train, y_train, leaf_basis_train = NULL, rfx_group_ids_train
707707
num_retained_samples <- num_gfr + ifelse(keep_burnin, num_burnin, 0) + num_mcmc * num_chains
708708
if (sample_sigma2_global) global_var_samples <- rep(NA, num_retained_samples)
709709
if (sample_sigma2_leaf) leaf_scale_samples <- rep(NA, num_retained_samples)
710+
if (include_mean_forest) mean_forest_pred_train <- matrix(NA_real_, nrow(X_train), num_retained_samples)
711+
if (include_variance_forest) variance_forest_pred_train <- matrix(NA_real_, nrow(X_train), num_retained_samples)
710712
sample_counter <- 0
711713

712714
# Initialize the leaves of each tree in the mean forest
@@ -757,13 +759,23 @@ bart <- function(X_train, y_train, leaf_basis_train = NULL, rfx_group_ids_train
757759
active_forest = active_forest_mean, rng = rng, forest_model_config = forest_model_config_mean,
758760
global_model_config = global_model_config, keep_forest = keep_sample, gfr = TRUE
759761
)
762+
763+
# Cache train set predictions since they are already computed during sampling
764+
if (keep_sample) {
765+
mean_forest_pred_train[,sample_counter] <- forest_model_mean$get_cached_forest_predictions()
766+
}
760767
}
761768
if (include_variance_forest) {
762769
forest_model_variance$sample_one_iteration(
763770
forest_dataset = forest_dataset_train, residual = outcome_train, forest_samples = forest_samples_variance,
764771
active_forest = active_forest_variance, rng = rng, forest_model_config = forest_model_config_variance,
765772
global_model_config = global_model_config, keep_forest = keep_sample, gfr = TRUE
766773
)
774+
775+
# Cache train set predictions since they are already computed during sampling
776+
if (keep_sample) {
777+
variance_forest_pred_train[,sample_counter] <- forest_model_variance$get_cached_forest_predictions()
778+
}
767779
}
768780
if (sample_sigma2_global) {
769781
current_sigma2 <- sampleGlobalErrorVarianceOneIteration(outcome_train, forest_dataset_train, rng, a_global, b_global)
@@ -910,13 +922,23 @@ bart <- function(X_train, y_train, leaf_basis_train = NULL, rfx_group_ids_train
910922
active_forest = active_forest_mean, rng = rng, forest_model_config = forest_model_config_mean,
911923
global_model_config = global_model_config, keep_forest = keep_sample, gfr = FALSE
912924
)
925+
926+
# Cache train set predictions since they are already computed during sampling
927+
if (keep_sample) {
928+
mean_forest_pred_train[,sample_counter] <- forest_model_mean$get_cached_forest_predictions()
929+
}
913930
}
914931
if (include_variance_forest) {
915932
forest_model_variance$sample_one_iteration(
916933
forest_dataset = forest_dataset_train, residual = outcome_train, forest_samples = forest_samples_variance,
917934
active_forest = active_forest_variance, rng = rng, forest_model_config = forest_model_config_variance,
918935
global_model_config = global_model_config, keep_forest = keep_sample, gfr = FALSE
919936
)
937+
938+
# Cache train set predictions since they are already computed during sampling
939+
if (keep_sample) {
940+
variance_forest_pred_train[,sample_counter] <- forest_model_variance$get_cached_forest_predictions()
941+
}
920942
}
921943
if (sample_sigma2_global) {
922944
current_sigma2 <- sampleGlobalErrorVarianceOneIteration(outcome_train, forest_dataset_train, rng, a_global, b_global)
@@ -949,6 +971,12 @@ bart <- function(X_train, y_train, leaf_basis_train = NULL, rfx_group_ids_train
949971
rfx_samples$delete_sample(0)
950972
}
951973
}
974+
if (include_mean_forest) {
975+
mean_forest_pred_train <- mean_forest_pred_train[,(num_gfr+1):ncol(mean_forest_pred_train)]
976+
}
977+
if (include_variance_forest) {
978+
variance_forest_pred_train <- variance_forest_pred_train[,(num_gfr+1):ncol(variance_forest_pred_train)]
979+
}
952980
if (sample_sigma2_global) {
953981
global_var_samples <- global_var_samples[(num_gfr+1):length(global_var_samples)]
954982
}
@@ -960,13 +988,15 @@ bart <- function(X_train, y_train, leaf_basis_train = NULL, rfx_group_ids_train
960988

961989
# Mean forest predictions
962990
if (include_mean_forest) {
963-
y_hat_train <- forest_samples_mean$predict(forest_dataset_train)*y_std_train + y_bar_train
991+
# y_hat_train <- forest_samples_mean$predict(forest_dataset_train)*y_std_train + y_bar_train
992+
y_hat_train <- mean_forest_pred_train*y_std_train + y_bar_train
964993
if (has_test) y_hat_test <- forest_samples_mean$predict(forest_dataset_test)*y_std_train + y_bar_train
965994
}
966995

967996
# Variance forest predictions
968997
if (include_variance_forest) {
969-
sigma2_x_hat_train <- forest_samples_variance$predict(forest_dataset_train)
998+
# sigma2_x_hat_train <- forest_samples_variance$predict(forest_dataset_train)
999+
sigma2_x_hat_train <- exp(variance_forest_pred_train)
9701000
if (has_test) sigma2_x_hat_test <- forest_samples_variance$predict(forest_dataset_test)
9711001
}
9721002

R/bcf.R

Lines changed: 36 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -885,6 +885,8 @@ bcf <- function(X_train, Z_train, y_train, propensity_train = NULL, rfx_group_id
885885
if (sample_sigma2_global) global_var_samples <- rep(NA, num_retained_samples)
886886
if (sample_sigma2_leaf_mu) leaf_scale_mu_samples <- rep(NA, num_retained_samples)
887887
if (sample_sigma2_leaf_tau) leaf_scale_tau_samples <- rep(NA, num_retained_samples)
888+
muhat_train_raw <- matrix(NA_real_, nrow(X_train), num_retained_samples)
889+
if (include_variance_forest) sigma2_x_train_raw <- matrix(NA_real_, nrow(X_train), num_retained_samples)
888890
sample_counter <- 0
889891

890892
# Prepare adaptive coding structure
@@ -997,6 +999,11 @@ bcf <- function(X_train, Z_train, y_train, propensity_train = NULL, rfx_group_id
997999
global_model_config = global_model_config, keep_forest = keep_sample, gfr = TRUE
9981000
)
9991001

1002+
# Cache train set predictions since they are already computed during sampling
1003+
if (keep_sample) {
1004+
muhat_train_raw[,sample_counter] <- forest_model_mu$get_cached_forest_predictions()
1005+
}
1006+
10001007
# Sample variance parameters (if requested)
10011008
if (sample_sigma2_global) {
10021009
current_sigma2 <- sampleGlobalErrorVarianceOneIteration(outcome_train, forest_dataset_train, rng, a_global, b_global)
@@ -1016,6 +1023,10 @@ bcf <- function(X_train, Z_train, y_train, propensity_train = NULL, rfx_group_id
10161023
global_model_config = global_model_config, keep_forest = keep_sample, gfr = TRUE
10171024
)
10181025

1026+
# Cannot cache train set predictions for tau because the cached predictions in the
1027+
# tracking data structures are pre-multiplied by the basis (treatment)
1028+
# ...
1029+
10191030
# Sample coding parameters (if requested)
10201031
if (adaptive_coding) {
10211032
# Estimate mu(X) and tau(X) and compute y - mu(X)
@@ -1060,6 +1071,11 @@ bcf <- function(X_train, Z_train, y_train, propensity_train = NULL, rfx_group_id
10601071
active_forest = active_forest_variance, rng = rng, forest_model_config = forest_model_config_variance,
10611072
global_model_config = global_model_config, keep_forest = keep_sample, gfr = TRUE
10621073
)
1074+
1075+
# Cache train set predictions since they are already computed during sampling
1076+
if (keep_sample) {
1077+
sigma2_x_train_raw[,sample_counter] <- forest_model_variance$get_cached_forest_predictions()
1078+
}
10631079
}
10641080
if (sample_sigma2_global) {
10651081
current_sigma2 <- sampleGlobalErrorVarianceOneIteration(outcome_train, forest_dataset_train, rng, a_global, b_global)
@@ -1263,6 +1279,11 @@ bcf <- function(X_train, Z_train, y_train, propensity_train = NULL, rfx_group_id
12631279
global_model_config = global_model_config, keep_forest = keep_sample, gfr = FALSE
12641280
)
12651281

1282+
# Cache train set predictions since they are already computed during sampling
1283+
if (keep_sample) {
1284+
muhat_train_raw[,sample_counter] <- forest_model_mu$get_cached_forest_predictions()
1285+
}
1286+
12661287
# Sample variance parameters (if requested)
12671288
if (sample_sigma2_global) {
12681289
current_sigma2 <- sampleGlobalErrorVarianceOneIteration(outcome_train, forest_dataset_train, rng, a_global, b_global)
@@ -1282,6 +1303,10 @@ bcf <- function(X_train, Z_train, y_train, propensity_train = NULL, rfx_group_id
12821303
global_model_config = global_model_config, keep_forest = keep_sample, gfr = FALSE
12831304
)
12841305

1306+
# Cannot cache train set predictions for tau because the cached predictions in the
1307+
# tracking data structures are pre-multiplied by the basis (treatment)
1308+
# ...
1309+
12851310
# Sample coding parameters (if requested)
12861311
if (adaptive_coding) {
12871312
# Estimate mu(X) and tau(X) and compute y - mu(X)
@@ -1326,6 +1351,11 @@ bcf <- function(X_train, Z_train, y_train, propensity_train = NULL, rfx_group_id
13261351
active_forest = active_forest_variance, rng = rng, forest_model_config = forest_model_config_variance,
13271352
global_model_config = global_model_config, keep_forest = keep_sample, gfr = FALSE
13281353
)
1354+
1355+
# Cache train set predictions since they are already computed during sampling
1356+
if (keep_sample) {
1357+
sigma2_x_train_raw[,sample_counter] <- forest_model_variance$get_cached_forest_predictions()
1358+
}
13291359
}
13301360
if (sample_sigma2_global) {
13311361
current_sigma2 <- sampleGlobalErrorVarianceOneIteration(outcome_train, forest_dataset_train, rng, a_global, b_global)
@@ -1372,11 +1402,15 @@ bcf <- function(X_train, Z_train, y_train, propensity_train = NULL, rfx_group_id
13721402
b_1_samples <- b_1_samples[(num_gfr+1):length(b_1_samples)]
13731403
b_0_samples <- b_0_samples[(num_gfr+1):length(b_0_samples)]
13741404
}
1405+
muhat_train_raw <- muhat_train_raw[,(num_gfr+1):ncol(muhat_train_raw)]
1406+
if (include_variance_forest) {
1407+
sigma2_x_train_raw <- sigma2_x_train_raw[,(num_gfr+1):ncol(sigma2_x_train_raw)]
1408+
}
13751409
num_retained_samples <- num_retained_samples - num_gfr
13761410
}
13771411

13781412
# Forest predictions
1379-
mu_hat_train <- forest_samples_mu$predict(forest_dataset_train)*y_std_train + y_bar_train
1413+
mu_hat_train <- muhat_train_raw*y_std_train + y_bar_train
13801414
if (adaptive_coding) {
13811415
tau_hat_train_raw <- forest_samples_tau$predict_raw(forest_dataset_train)
13821416
tau_hat_train <- t(t(tau_hat_train_raw) * (b_1_samples - b_0_samples))*y_std_train
@@ -1395,7 +1429,7 @@ bcf <- function(X_train, Z_train, y_train, propensity_train = NULL, rfx_group_id
13951429
y_hat_test <- mu_hat_test + tau_hat_test * as.numeric(Z_test)
13961430
}
13971431
if (include_variance_forest) {
1398-
sigma2_x_hat_train <- forest_samples_variance$predict(forest_dataset_train)
1432+
sigma2_x_hat_train <- exp(sigma2_x_train_raw)
13991433
if (has_test) sigma2_x_hat_test <- forest_samples_variance$predict(forest_dataset_test)
14001434
}
14011435

R/cpp11.R

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -640,6 +640,10 @@ forest_tracker_cpp <- function(data, feature_types, num_trees, n) {
640640
.Call(`_stochtree_forest_tracker_cpp`, data, feature_types, num_trees, n)
641641
}
642642

643+
get_cached_forest_predictions_cpp <- function(tracker_ptr) {
644+
.Call(`_stochtree_get_cached_forest_predictions_cpp`, tracker_ptr)
645+
}
646+
643647
sample_without_replacement_integer_cpp <- function(population_vector, sampling_probs, sample_size) {
644648
.Call(`_stochtree_sample_without_replacement_integer_cpp`, population_vector, sampling_probs, sample_size)
645649
}

R/model.R

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -126,6 +126,13 @@ ForestModel <- R6::R6Class(
126126
}
127127
},
128128

129+
#' @description
130+
#' Extract an internally-cached prediction of a forest on the training dataset in a sampler.
131+
#' @return Vector with as many elements as observations in the training dataset
132+
get_cached_forest_predictions = function() {
133+
get_cached_forest_predictions_cpp(self$tracker_ptr)
134+
},
135+
129136
#' @description
130137
#' Propagates basis update through to the (full/partial) residual by iteratively
131138
#' (a) adding back in the previous prediction of each tree, (b) recomputing predictions

include/stochtree/partition_tracker.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -91,6 +91,10 @@ class ForestTracker {
9191
SampleNodeMapper* GetSampleNodeMapper() {return sample_node_mapper_.get();}
9292
UnsortedNodeSampleTracker* GetUnsortedNodeSampleTracker() {return unsorted_node_sample_tracker_.get();}
9393
SortedNodeSampleTracker* GetSortedNodeSampleTracker() {return sorted_node_sample_tracker_.get();}
94+
int GetNumObservations() {return num_observations_;}
95+
int GetNumTrees() {return num_trees_;}
96+
int GetNumFeatures() {return num_features_;}
97+
bool Initialized() {return initialized_;}
9498

9599
private:
96100
/*! \brief Mapper from observations to predicted values summed over every tree in a forest */

man/ForestModel.Rd

Lines changed: 14 additions & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

src/container.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -206,6 +206,7 @@ void ForestContainer::append_from_json(const json& forest_container_json) {
206206
CHECK_EQ(this->num_trees_, forest_container_json.at("num_trees"));
207207
CHECK_EQ(this->output_dimension_, forest_container_json.at("output_dimension"));
208208
CHECK_EQ(this->is_leaf_constant_, forest_container_json.at("is_leaf_constant"));
209+
CHECK_EQ(this->is_exponentiated_, forest_container_json.at("is_exponentiated"));
209210
CHECK_EQ(this->initialized_, forest_container_json.at("initialized"));
210211
int new_num_samples = forest_container_json.at("num_samples");
211212

@@ -215,8 +216,7 @@ void ForestContainer::append_from_json(const json& forest_container_json) {
215216
for (int i = 0; i < forest_container_json.at("num_samples"); i++) {
216217
forest_ind = this->num_samples_ + i;
217218
forest_label = "forest_" + std::to_string(i);
218-
// forests_[forest_ind] = std::make_unique<TreeEnsemble>(this->num_trees_, this->output_dimension_, this->is_leaf_constant_);
219-
forests_.push_back(std::make_unique<TreeEnsemble>(this->num_trees_, this->output_dimension_, this->is_leaf_constant_));
219+
forests_.push_back(std::make_unique<TreeEnsemble>(this->num_trees_, this->output_dimension_, this->is_leaf_constant_, this->is_exponentiated_));
220220
forests_[forest_ind]->from_json(forest_container_json.at(forest_label));
221221
}
222222
this->num_samples_ += new_num_samples;

src/cpp11.cpp

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1187,6 +1187,13 @@ extern "C" SEXP _stochtree_forest_tracker_cpp(SEXP data, SEXP feature_types, SEX
11871187
END_CPP11
11881188
}
11891189
// sampler.cpp
1190+
cpp11::writable::doubles get_cached_forest_predictions_cpp(cpp11::external_pointer<StochTree::ForestTracker> tracker_ptr);
1191+
extern "C" SEXP _stochtree_get_cached_forest_predictions_cpp(SEXP tracker_ptr) {
1192+
BEGIN_CPP11
1193+
return cpp11::as_sexp(get_cached_forest_predictions_cpp(cpp11::as_cpp<cpp11::decay_t<cpp11::external_pointer<StochTree::ForestTracker>>>(tracker_ptr)));
1194+
END_CPP11
1195+
}
1196+
// sampler.cpp
11901197
cpp11::writable::integers sample_without_replacement_integer_cpp(cpp11::integers population_vector, cpp11::doubles sampling_probs, int sample_size);
11911198
extern "C" SEXP _stochtree_sample_without_replacement_integer_cpp(SEXP population_vector, SEXP sampling_probs, SEXP sample_size) {
11921199
BEGIN_CPP11
@@ -1539,6 +1546,7 @@ static const R_CallMethodDef CallEntries[] = {
15391546
{"_stochtree_forest_tracker_cpp", (DL_FUNC) &_stochtree_forest_tracker_cpp, 4},
15401547
{"_stochtree_get_alpha_tree_prior_cpp", (DL_FUNC) &_stochtree_get_alpha_tree_prior_cpp, 1},
15411548
{"_stochtree_get_beta_tree_prior_cpp", (DL_FUNC) &_stochtree_get_beta_tree_prior_cpp, 1},
1549+
{"_stochtree_get_cached_forest_predictions_cpp", (DL_FUNC) &_stochtree_get_cached_forest_predictions_cpp, 1},
15421550
{"_stochtree_get_forest_split_counts_forest_container_cpp", (DL_FUNC) &_stochtree_get_forest_split_counts_forest_container_cpp, 3},
15431551
{"_stochtree_get_granular_split_count_array_active_forest_cpp", (DL_FUNC) &_stochtree_get_granular_split_count_array_active_forest_cpp, 2},
15441552
{"_stochtree_get_granular_split_count_array_forest_container_cpp", (DL_FUNC) &_stochtree_get_granular_split_count_array_forest_container_cpp, 2},

src/py_stochtree.cpp

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1166,6 +1166,16 @@ class ForestSamplerCpp {
11661166
}
11671167
}
11681168

1169+
py::array_t<double> GetCachedForestPredictions() {
1170+
int n_train = tracker_->GetNumObservations();
1171+
auto output = py::array_t<double>(py::detail::any_container<py::ssize_t>({n_train}));
1172+
auto accessor = output.mutable_unchecked<1>();
1173+
for (size_t i = 0; i < n_train; i++) {
1174+
accessor(i) = tracker_->GetSamplePrediction(i);
1175+
}
1176+
return output;
1177+
}
1178+
11691179
void PropagateBasisUpdate(ForestDatasetCpp& dataset, ResidualCpp& residual, ForestCpp& forest) {
11701180
// Perform the update operation
11711181
StochTree::UpdateResidualNewBasis(*tracker_, *(dataset.GetDataset()), *(residual.GetData()), forest.GetEnsemble());
@@ -2147,6 +2157,7 @@ PYBIND11_MODULE(stochtree_cpp, m) {
21472157
.def("ReconstituteTrackerFromForest", &ForestSamplerCpp::ReconstituteTrackerFromForest)
21482158
.def("SampleOneIteration", &ForestSamplerCpp::SampleOneIteration)
21492159
.def("InitializeForestModel", &ForestSamplerCpp::InitializeForestModel)
2160+
.def("GetCachedForestPredictions", &ForestSamplerCpp::GetCachedForestPredictions)
21502161
.def("PropagateBasisUpdate", &ForestSamplerCpp::PropagateBasisUpdate)
21512162
.def("PropagateResidualUpdate", &ForestSamplerCpp::PropagateResidualUpdate)
21522163
.def("UpdateAlpha", &ForestSamplerCpp::UpdateAlpha)

0 commit comments

Comments
 (0)