Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 6 additions & 6 deletions DESCRIPTION
Original file line number Diff line number Diff line change
Expand Up @@ -17,17 +17,17 @@ RoxygenNote: 7.3.1
LinkingTo:
cpp11
Suggests:
doParallel,
foreach,
ggplot2,
knitr,
rmarkdown,
latex2exp,
Matrix,
tgp,
MASS,
mvtnorm,
ggplot2,
latex2exp,
rmarkdown,
testthat (>= 3.0.0),
foreach,
doParallel
tgp
VignetteBuilder: knitr
SystemRequirements: C++17
Imports:
Expand Down
3 changes: 1 addition & 2 deletions NAMESPACE
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,8 @@ S3method(predict,bcf)
export(bart)
export(bcf)
export(calibrate_inverse_gamma_error_variance)
export(computeForestKernels)
export(computeForestLeafIndices)
export(computeMaxLeafIndex)
export(convertBARTModelToJson)
export(convertBCFModelToJson)
export(createBARTModelFromCombinedJson)
Expand All @@ -26,7 +26,6 @@ export(createForestContainer)
export(createForestCovariates)
export(createForestCovariatesFromMetadata)
export(createForestDataset)
export(createForestKernel)
export(createForestModel)
export(createOutcome)
export(createRNG)
Expand Down
28 changes: 4 additions & 24 deletions R/cpp11.R
Original file line number Diff line number Diff line change
Expand Up @@ -340,32 +340,12 @@ predict_forest_raw_single_forest_cpp <- function(forest_samples, dataset, forest
.Call(`_stochtree_predict_forest_raw_single_forest_cpp`, forest_samples, dataset, forest_num)
}

forest_kernel_cpp <- function() {
.Call(`_stochtree_forest_kernel_cpp`)
forest_container_get_max_leaf_index_cpp <- function(forest_container, forest_num) {
.Call(`_stochtree_forest_container_get_max_leaf_index_cpp`, forest_container, forest_num)
}

forest_kernel_compute_leaf_indices_train_cpp <- function(forest_kernel, covariates_train, forest_container, forest_num) {
invisible(.Call(`_stochtree_forest_kernel_compute_leaf_indices_train_cpp`, forest_kernel, covariates_train, forest_container, forest_num))
}

forest_kernel_compute_leaf_indices_train_test_cpp <- function(forest_kernel, covariates_train, covariates_test, forest_container, forest_num) {
invisible(.Call(`_stochtree_forest_kernel_compute_leaf_indices_train_test_cpp`, forest_kernel, covariates_train, covariates_test, forest_container, forest_num))
}

forest_kernel_get_train_leaf_indices_cpp <- function(forest_kernel) {
.Call(`_stochtree_forest_kernel_get_train_leaf_indices_cpp`, forest_kernel)
}

forest_kernel_get_test_leaf_indices_cpp <- function(forest_kernel) {
.Call(`_stochtree_forest_kernel_get_test_leaf_indices_cpp`, forest_kernel)
}

forest_kernel_compute_kernel_train_cpp <- function(forest_kernel, covariates_train, forest_container, forest_num) {
.Call(`_stochtree_forest_kernel_compute_kernel_train_cpp`, forest_kernel, covariates_train, forest_container, forest_num)
}

forest_kernel_compute_kernel_train_test_cpp <- function(forest_kernel, covariates_train, covariates_test, forest_container, forest_num) {
.Call(`_stochtree_forest_kernel_compute_kernel_train_test_cpp`, forest_kernel, covariates_train, covariates_test, forest_container, forest_num)
compute_leaf_indices_cpp <- function(forest_container, covariates, forest_nums) {
.Call(`_stochtree_compute_leaf_indices_cpp`, forest_container, covariates, forest_nums)
}

sample_gfr_one_iteration_cpp <- function(data, residual, forest_samples, tracker, split_prior, rng, feature_types, cutpoint_grid_size, leaf_model_scale_input, variable_weights, a_forest, b_forest, global_variance, leaf_model_int, pre_initialized) {
Expand Down
370 changes: 141 additions & 229 deletions R/kernel.R

Large diffs are not rendered by default.

6 changes: 2 additions & 4 deletions _pkgdown.yml
Original file line number Diff line number Diff line change
Expand Up @@ -81,13 +81,13 @@ reference:
- createForestModel
- ForestSamples
- createForestContainer
- ForestKernel
- createForestKernel
- CppRNG
- createRNG
- calibrate_inverse_gamma_error_variance
- preprocessBartParams
- preprocessBcfParams
- computeForestLeafIndices
- computeMaxLeafIndex

- subtitle: Random Effects
desc: >
Expand All @@ -105,8 +105,6 @@ reference:
- sample_sigma2_one_iteration
- sample_tau_one_iteration
- sample_tau_one_iteration
- computeForestKernels
- computeForestLeafIndices

- title: Package info
desc: >
Expand Down
3 changes: 3 additions & 0 deletions include/stochtree/container.h
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,9 @@ class ForestContainer {
void PredictInPlace(ForestDataset& dataset, std::vector<double>& output);
void PredictRawInPlace(ForestDataset& dataset, std::vector<double>& output);
void PredictRawInPlace(ForestDataset& dataset, int forest_num, std::vector<double>& output);
void PredictLeafIndicesInplace(Eigen::Map<Eigen::Matrix<double, Eigen::Dynamic, Eigen::Dynamic, Eigen::ColMajor>>& covariates,
Eigen::Map<Eigen::Matrix<int, Eigen::Dynamic, Eigen::Dynamic, Eigen::ColMajor>>& output,
std::vector<int>& forest_indices, int num_trees, data_size_t n);

inline TreeEnsemble* GetEnsemble(int i) {return forests_[i].get();}
inline int32_t NumSamples() {return num_samples_;}
Expand Down
49 changes: 48 additions & 1 deletion include/stochtree/ensemble.h
Original file line number Diff line number Diff line change
Expand Up @@ -241,6 +241,20 @@ class TreeEnsemble {
}
}

/*!
* \brief Obtain a 0-based "maximum" leaf index for an ensemble, which is equivalent to the sum of the
* number of leaves in each tree. This is used in conjunction with `PredictLeafIndicesInplace`,
* which returns an observation-specific leaf index for every observation-tree pair.
*/
int GetMaxLeafIndex() {
int max_leaf = 0;
for (int j = 0; j < num_trees_; j++) {
auto &tree = *trees_[j];
max_leaf += tree.NumLeaves();
}
return max_leaf;
}

/*!
* \brief Obtain a 0-based leaf index for every tree in an ensemble and for each
* observation in a ForestDataset. Internally, trees are stored as essentially
Expand Down Expand Up @@ -274,7 +288,7 @@ class TreeEnsemble {
*
* Note: this assumes the creation of a vector of column indices of size
* `dataset.NumObservations()` x `ensemble.NumTrees()`
* \param ForestDataset Dataset with which to predict leaf indices from the tree
* \param covariates Matrix of covariates
* \param output Vector of length num_trees*n which stores the leaf node prediction
* \param num_trees Number of trees in an ensemble
* \param n Size of dataset
Expand All @@ -292,6 +306,39 @@ class TreeEnsemble {
}
}

/*!
* \brief Obtain a 0-based leaf index for every tree in an ensemble and for each
* observation in a ForestDataset. Internally, trees are stored as essentially
* vectors of node information, and the leaves_ vector gives us node IDs for every
* leaf in the tree. Here, we would like to know, for every observation in a dataset,
* which leaf number it is mapped to. Since the leaf numbers themselves
* do not carry any information, we renumber them from 0 to `leaves_.size()-1`.
* We compute this at the tree-level and coordinate this computation at the
* ensemble level.
*
* Note: this assumes the creation of a matrix of column indices with `num_trees*n` rows
* and as many columns as forests that were requested from R / Python
* \param covariates Matrix of covariates
* \param output Matrix with num_trees*n rows and as many columns as forests that were requested from R / Python
* \param column_ind Index of column in `output` into which the result should be unpacked
* \param num_trees Number of trees in an ensemble
* \param n Size of dataset
*/
void PredictLeafIndicesInplace(Eigen::Map<Eigen::Matrix<double, Eigen::Dynamic, Eigen::Dynamic, Eigen::ColMajor>>& covariates,
Eigen::Map<Eigen::Matrix<int, Eigen::Dynamic, Eigen::Dynamic, Eigen::ColMajor>>& output,
int column_ind, int num_trees, data_size_t n) {
CHECK_GE(output.size(), num_trees*n);
int offset = 0;
int max_leaf = 0;
for (int j = 0; j < num_trees; j++) {
auto &tree = *trees_[j];
int num_leaves = tree.NumLeaves();
tree.PredictLeafIndexInplace(covariates, output, column_ind, offset, max_leaf);
offset += n;
max_leaf += num_leaves;
}
}

/*!
* \brief Obtain a 0-based leaf index for every tree in an ensemble and for each
* observation in a ForestDataset. Internally, trees are stored as essentially
Expand Down
Loading
Loading