Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
34 changes: 32 additions & 2 deletions R/bart.R
Original file line number Diff line number Diff line change
Expand Up @@ -707,6 +707,8 @@ bart <- function(X_train, y_train, leaf_basis_train = NULL, rfx_group_ids_train
num_retained_samples <- num_gfr + ifelse(keep_burnin, num_burnin, 0) + num_mcmc * num_chains
if (sample_sigma2_global) global_var_samples <- rep(NA, num_retained_samples)
if (sample_sigma2_leaf) leaf_scale_samples <- rep(NA, num_retained_samples)
if (include_mean_forest) mean_forest_pred_train <- matrix(NA_real_, nrow(X_train), num_retained_samples)
if (include_variance_forest) variance_forest_pred_train <- matrix(NA_real_, nrow(X_train), num_retained_samples)
sample_counter <- 0

# Initialize the leaves of each tree in the mean forest
Expand Down Expand Up @@ -757,13 +759,23 @@ bart <- function(X_train, y_train, leaf_basis_train = NULL, rfx_group_ids_train
active_forest = active_forest_mean, rng = rng, forest_model_config = forest_model_config_mean,
global_model_config = global_model_config, keep_forest = keep_sample, gfr = TRUE
)

# Cache train set predictions since they are already computed during sampling
if (keep_sample) {
mean_forest_pred_train[,sample_counter] <- forest_model_mean$get_cached_forest_predictions()
}
}
if (include_variance_forest) {
forest_model_variance$sample_one_iteration(
forest_dataset = forest_dataset_train, residual = outcome_train, forest_samples = forest_samples_variance,
active_forest = active_forest_variance, rng = rng, forest_model_config = forest_model_config_variance,
global_model_config = global_model_config, keep_forest = keep_sample, gfr = TRUE
)

# Cache train set predictions since they are already computed during sampling
if (keep_sample) {
variance_forest_pred_train[,sample_counter] <- forest_model_variance$get_cached_forest_predictions()
}
}
if (sample_sigma2_global) {
current_sigma2 <- sampleGlobalErrorVarianceOneIteration(outcome_train, forest_dataset_train, rng, a_global, b_global)
Expand Down Expand Up @@ -910,13 +922,23 @@ bart <- function(X_train, y_train, leaf_basis_train = NULL, rfx_group_ids_train
active_forest = active_forest_mean, rng = rng, forest_model_config = forest_model_config_mean,
global_model_config = global_model_config, keep_forest = keep_sample, gfr = FALSE
)

# Cache train set predictions since they are already computed during sampling
if (keep_sample) {
mean_forest_pred_train[,sample_counter] <- forest_model_mean$get_cached_forest_predictions()
}
}
if (include_variance_forest) {
forest_model_variance$sample_one_iteration(
forest_dataset = forest_dataset_train, residual = outcome_train, forest_samples = forest_samples_variance,
active_forest = active_forest_variance, rng = rng, forest_model_config = forest_model_config_variance,
global_model_config = global_model_config, keep_forest = keep_sample, gfr = FALSE
)

# Cache train set predictions since they are already computed during sampling
if (keep_sample) {
variance_forest_pred_train[,sample_counter] <- forest_model_variance$get_cached_forest_predictions()
}
}
if (sample_sigma2_global) {
current_sigma2 <- sampleGlobalErrorVarianceOneIteration(outcome_train, forest_dataset_train, rng, a_global, b_global)
Expand Down Expand Up @@ -949,6 +971,12 @@ bart <- function(X_train, y_train, leaf_basis_train = NULL, rfx_group_ids_train
rfx_samples$delete_sample(0)
}
}
if (include_mean_forest) {
mean_forest_pred_train <- mean_forest_pred_train[,(num_gfr+1):ncol(mean_forest_pred_train)]
}
if (include_variance_forest) {
variance_forest_pred_train <- variance_forest_pred_train[,(num_gfr+1):ncol(variance_forest_pred_train)]
}
if (sample_sigma2_global) {
global_var_samples <- global_var_samples[(num_gfr+1):length(global_var_samples)]
}
Expand All @@ -960,13 +988,15 @@ bart <- function(X_train, y_train, leaf_basis_train = NULL, rfx_group_ids_train

# Mean forest predictions
if (include_mean_forest) {
y_hat_train <- forest_samples_mean$predict(forest_dataset_train)*y_std_train + y_bar_train
# y_hat_train <- forest_samples_mean$predict(forest_dataset_train)*y_std_train + y_bar_train
y_hat_train <- mean_forest_pred_train*y_std_train + y_bar_train
if (has_test) y_hat_test <- forest_samples_mean$predict(forest_dataset_test)*y_std_train + y_bar_train
}

# Variance forest predictions
if (include_variance_forest) {
sigma2_x_hat_train <- forest_samples_variance$predict(forest_dataset_train)
# sigma2_x_hat_train <- forest_samples_variance$predict(forest_dataset_train)
sigma2_x_hat_train <- exp(variance_forest_pred_train)
if (has_test) sigma2_x_hat_test <- forest_samples_variance$predict(forest_dataset_test)
}

Expand Down
38 changes: 36 additions & 2 deletions R/bcf.R
Original file line number Diff line number Diff line change
Expand Up @@ -885,6 +885,8 @@ bcf <- function(X_train, Z_train, y_train, propensity_train = NULL, rfx_group_id
if (sample_sigma2_global) global_var_samples <- rep(NA, num_retained_samples)
if (sample_sigma2_leaf_mu) leaf_scale_mu_samples <- rep(NA, num_retained_samples)
if (sample_sigma2_leaf_tau) leaf_scale_tau_samples <- rep(NA, num_retained_samples)
muhat_train_raw <- matrix(NA_real_, nrow(X_train), num_retained_samples)
if (include_variance_forest) sigma2_x_train_raw <- matrix(NA_real_, nrow(X_train), num_retained_samples)
sample_counter <- 0

# Prepare adaptive coding structure
Expand Down Expand Up @@ -997,6 +999,11 @@ bcf <- function(X_train, Z_train, y_train, propensity_train = NULL, rfx_group_id
global_model_config = global_model_config, keep_forest = keep_sample, gfr = TRUE
)

# Cache train set predictions since they are already computed during sampling
if (keep_sample) {
muhat_train_raw[,sample_counter] <- forest_model_mu$get_cached_forest_predictions()
}

# Sample variance parameters (if requested)
if (sample_sigma2_global) {
current_sigma2 <- sampleGlobalErrorVarianceOneIteration(outcome_train, forest_dataset_train, rng, a_global, b_global)
Expand All @@ -1016,6 +1023,10 @@ bcf <- function(X_train, Z_train, y_train, propensity_train = NULL, rfx_group_id
global_model_config = global_model_config, keep_forest = keep_sample, gfr = TRUE
)

# Cannot cache train set predictions for tau because the cached predictions in the
# tracking data structures are pre-multiplied by the basis (treatment)
# ...

# Sample coding parameters (if requested)
if (adaptive_coding) {
# Estimate mu(X) and tau(X) and compute y - mu(X)
Expand Down Expand Up @@ -1060,6 +1071,11 @@ bcf <- function(X_train, Z_train, y_train, propensity_train = NULL, rfx_group_id
active_forest = active_forest_variance, rng = rng, forest_model_config = forest_model_config_variance,
global_model_config = global_model_config, keep_forest = keep_sample, gfr = TRUE
)

# Cache train set predictions since they are already computed during sampling
if (keep_sample) {
sigma2_x_train_raw[,sample_counter] <- forest_model_variance$get_cached_forest_predictions()
}
}
if (sample_sigma2_global) {
current_sigma2 <- sampleGlobalErrorVarianceOneIteration(outcome_train, forest_dataset_train, rng, a_global, b_global)
Expand Down Expand Up @@ -1263,6 +1279,11 @@ bcf <- function(X_train, Z_train, y_train, propensity_train = NULL, rfx_group_id
global_model_config = global_model_config, keep_forest = keep_sample, gfr = FALSE
)

# Cache train set predictions since they are already computed during sampling
if (keep_sample) {
muhat_train_raw[,sample_counter] <- forest_model_mu$get_cached_forest_predictions()
}

# Sample variance parameters (if requested)
if (sample_sigma2_global) {
current_sigma2 <- sampleGlobalErrorVarianceOneIteration(outcome_train, forest_dataset_train, rng, a_global, b_global)
Expand All @@ -1282,6 +1303,10 @@ bcf <- function(X_train, Z_train, y_train, propensity_train = NULL, rfx_group_id
global_model_config = global_model_config, keep_forest = keep_sample, gfr = FALSE
)

# Cannot cache train set predictions for tau because the cached predictions in the
# tracking data structures are pre-multiplied by the basis (treatment)
# ...

# Sample coding parameters (if requested)
if (adaptive_coding) {
# Estimate mu(X) and tau(X) and compute y - mu(X)
Expand Down Expand Up @@ -1326,6 +1351,11 @@ bcf <- function(X_train, Z_train, y_train, propensity_train = NULL, rfx_group_id
active_forest = active_forest_variance, rng = rng, forest_model_config = forest_model_config_variance,
global_model_config = global_model_config, keep_forest = keep_sample, gfr = FALSE
)

# Cache train set predictions since they are already computed during sampling
if (keep_sample) {
sigma2_x_train_raw[,sample_counter] <- forest_model_variance$get_cached_forest_predictions()
}
}
if (sample_sigma2_global) {
current_sigma2 <- sampleGlobalErrorVarianceOneIteration(outcome_train, forest_dataset_train, rng, a_global, b_global)
Expand Down Expand Up @@ -1372,11 +1402,15 @@ bcf <- function(X_train, Z_train, y_train, propensity_train = NULL, rfx_group_id
b_1_samples <- b_1_samples[(num_gfr+1):length(b_1_samples)]
b_0_samples <- b_0_samples[(num_gfr+1):length(b_0_samples)]
}
muhat_train_raw <- muhat_train_raw[,(num_gfr+1):ncol(muhat_train_raw)]
if (include_variance_forest) {
sigma2_x_train_raw <- sigma2_x_train_raw[,(num_gfr+1):ncol(sigma2_x_train_raw)]
}
num_retained_samples <- num_retained_samples - num_gfr
}

# Forest predictions
mu_hat_train <- forest_samples_mu$predict(forest_dataset_train)*y_std_train + y_bar_train
mu_hat_train <- muhat_train_raw*y_std_train + y_bar_train
if (adaptive_coding) {
tau_hat_train_raw <- forest_samples_tau$predict_raw(forest_dataset_train)
tau_hat_train <- t(t(tau_hat_train_raw) * (b_1_samples - b_0_samples))*y_std_train
Expand All @@ -1395,7 +1429,7 @@ bcf <- function(X_train, Z_train, y_train, propensity_train = NULL, rfx_group_id
y_hat_test <- mu_hat_test + tau_hat_test * as.numeric(Z_test)
}
if (include_variance_forest) {
sigma2_x_hat_train <- forest_samples_variance$predict(forest_dataset_train)
sigma2_x_hat_train <- exp(sigma2_x_train_raw)
if (has_test) sigma2_x_hat_test <- forest_samples_variance$predict(forest_dataset_test)
}

Expand Down
4 changes: 4 additions & 0 deletions R/cpp11.R
Original file line number Diff line number Diff line change
Expand Up @@ -640,6 +640,10 @@ forest_tracker_cpp <- function(data, feature_types, num_trees, n) {
.Call(`_stochtree_forest_tracker_cpp`, data, feature_types, num_trees, n)
}

get_cached_forest_predictions_cpp <- function(tracker_ptr) {
.Call(`_stochtree_get_cached_forest_predictions_cpp`, tracker_ptr)
}

sample_without_replacement_integer_cpp <- function(population_vector, sampling_probs, sample_size) {
.Call(`_stochtree_sample_without_replacement_integer_cpp`, population_vector, sampling_probs, sample_size)
}
Expand Down
7 changes: 7 additions & 0 deletions R/model.R
Original file line number Diff line number Diff line change
Expand Up @@ -126,6 +126,13 @@ ForestModel <- R6::R6Class(
}
},

#' @description
#' Extract an internally-cached prediction of a forest on the training dataset in a sampler.
#' @return Vector with as many elements as observations in the training dataset
get_cached_forest_predictions = function() {
get_cached_forest_predictions_cpp(self$tracker_ptr)
},

#' @description
#' Propagates basis update through to the (full/partial) residual by iteratively
#' (a) adding back in the previous prediction of each tree, (b) recomputing predictions
Expand Down
4 changes: 4 additions & 0 deletions include/stochtree/partition_tracker.h
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,10 @@ class ForestTracker {
SampleNodeMapper* GetSampleNodeMapper() {return sample_node_mapper_.get();}
UnsortedNodeSampleTracker* GetUnsortedNodeSampleTracker() {return unsorted_node_sample_tracker_.get();}
SortedNodeSampleTracker* GetSortedNodeSampleTracker() {return sorted_node_sample_tracker_.get();}
int GetNumObservations() {return num_observations_;}
int GetNumTrees() {return num_trees_;}
int GetNumFeatures() {return num_features_;}
bool Initialized() {return initialized_;}

private:
/*! \brief Mapper from observations to predicted values summed over every tree in a forest */
Expand Down
14 changes: 14 additions & 0 deletions man/ForestModel.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

4 changes: 2 additions & 2 deletions src/container.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -206,6 +206,7 @@ void ForestContainer::append_from_json(const json& forest_container_json) {
CHECK_EQ(this->num_trees_, forest_container_json.at("num_trees"));
CHECK_EQ(this->output_dimension_, forest_container_json.at("output_dimension"));
CHECK_EQ(this->is_leaf_constant_, forest_container_json.at("is_leaf_constant"));
CHECK_EQ(this->is_exponentiated_, forest_container_json.at("is_exponentiated"));
CHECK_EQ(this->initialized_, forest_container_json.at("initialized"));
int new_num_samples = forest_container_json.at("num_samples");

Expand All @@ -215,8 +216,7 @@ void ForestContainer::append_from_json(const json& forest_container_json) {
for (int i = 0; i < forest_container_json.at("num_samples"); i++) {
forest_ind = this->num_samples_ + i;
forest_label = "forest_" + std::to_string(i);
// forests_[forest_ind] = std::make_unique<TreeEnsemble>(this->num_trees_, this->output_dimension_, this->is_leaf_constant_);
forests_.push_back(std::make_unique<TreeEnsemble>(this->num_trees_, this->output_dimension_, this->is_leaf_constant_));
forests_.push_back(std::make_unique<TreeEnsemble>(this->num_trees_, this->output_dimension_, this->is_leaf_constant_, this->is_exponentiated_));
forests_[forest_ind]->from_json(forest_container_json.at(forest_label));
}
this->num_samples_ += new_num_samples;
Expand Down
8 changes: 8 additions & 0 deletions src/cpp11.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1187,6 +1187,13 @@ extern "C" SEXP _stochtree_forest_tracker_cpp(SEXP data, SEXP feature_types, SEX
END_CPP11
}
// sampler.cpp
cpp11::writable::doubles get_cached_forest_predictions_cpp(cpp11::external_pointer<StochTree::ForestTracker> tracker_ptr);
extern "C" SEXP _stochtree_get_cached_forest_predictions_cpp(SEXP tracker_ptr) {
BEGIN_CPP11
return cpp11::as_sexp(get_cached_forest_predictions_cpp(cpp11::as_cpp<cpp11::decay_t<cpp11::external_pointer<StochTree::ForestTracker>>>(tracker_ptr)));
END_CPP11
}
// sampler.cpp
cpp11::writable::integers sample_without_replacement_integer_cpp(cpp11::integers population_vector, cpp11::doubles sampling_probs, int sample_size);
extern "C" SEXP _stochtree_sample_without_replacement_integer_cpp(SEXP population_vector, SEXP sampling_probs, SEXP sample_size) {
BEGIN_CPP11
Expand Down Expand Up @@ -1539,6 +1546,7 @@ static const R_CallMethodDef CallEntries[] = {
{"_stochtree_forest_tracker_cpp", (DL_FUNC) &_stochtree_forest_tracker_cpp, 4},
{"_stochtree_get_alpha_tree_prior_cpp", (DL_FUNC) &_stochtree_get_alpha_tree_prior_cpp, 1},
{"_stochtree_get_beta_tree_prior_cpp", (DL_FUNC) &_stochtree_get_beta_tree_prior_cpp, 1},
{"_stochtree_get_cached_forest_predictions_cpp", (DL_FUNC) &_stochtree_get_cached_forest_predictions_cpp, 1},
{"_stochtree_get_forest_split_counts_forest_container_cpp", (DL_FUNC) &_stochtree_get_forest_split_counts_forest_container_cpp, 3},
{"_stochtree_get_granular_split_count_array_active_forest_cpp", (DL_FUNC) &_stochtree_get_granular_split_count_array_active_forest_cpp, 2},
{"_stochtree_get_granular_split_count_array_forest_container_cpp", (DL_FUNC) &_stochtree_get_granular_split_count_array_forest_container_cpp, 2},
Expand Down
11 changes: 11 additions & 0 deletions src/py_stochtree.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1166,6 +1166,16 @@ class ForestSamplerCpp {
}
}

py::array_t<double> GetCachedForestPredictions() {
int n_train = tracker_->GetNumObservations();
auto output = py::array_t<double>(py::detail::any_container<py::ssize_t>({n_train}));
auto accessor = output.mutable_unchecked<1>();
for (size_t i = 0; i < n_train; i++) {
accessor(i) = tracker_->GetSamplePrediction(i);
}
return output;
}

void PropagateBasisUpdate(ForestDatasetCpp& dataset, ResidualCpp& residual, ForestCpp& forest) {
// Perform the update operation
StochTree::UpdateResidualNewBasis(*tracker_, *(dataset.GetDataset()), *(residual.GetData()), forest.GetEnsemble());
Expand Down Expand Up @@ -2147,6 +2157,7 @@ PYBIND11_MODULE(stochtree_cpp, m) {
.def("ReconstituteTrackerFromForest", &ForestSamplerCpp::ReconstituteTrackerFromForest)
.def("SampleOneIteration", &ForestSamplerCpp::SampleOneIteration)
.def("InitializeForestModel", &ForestSamplerCpp::InitializeForestModel)
.def("GetCachedForestPredictions", &ForestSamplerCpp::GetCachedForestPredictions)
.def("PropagateBasisUpdate", &ForestSamplerCpp::PropagateBasisUpdate)
.def("PropagateResidualUpdate", &ForestSamplerCpp::PropagateResidualUpdate)
.def("UpdateAlpha", &ForestSamplerCpp::UpdateAlpha)
Expand Down
Loading
Loading