Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
26 commits
Select commit Hold shift + click to select a range
af5c7ea
Updating C++ and Python docs
andrewherren Dec 16, 2024
0762310
Updated python documentation and added methods to reset tree prior pa…
andrewherren Dec 27, 2024
f22ed19
Updated python package docs
andrewherren Dec 31, 2024
c1b4293
Updated R version
andrewherren Dec 31, 2024
4900146
Fixed latex error in bcf docs
andrewherren Dec 31, 2024
9be9fe7
Updated R package URL in DESCRIPTION file
andrewherren Jan 1, 2025
58f3c5f
Deleted commented-out C++ code
andrewherren Jan 2, 2025
3208ead
Deleting more commented C++ code
andrewherren Jan 2, 2025
3eb26c6
Move doxyfile to top level project directory and delete deprecated cp…
andrewherren Jan 3, 2025
8fd27ff
Adding documentation for more of the C++ core
andrewherren Jan 3, 2025
043bac4
Updated C++ documentation
andrewherren Jan 6, 2025
61633b7
Updated C++ docs
andrewherren Jan 6, 2025
c49309e
Expanded C++ leaf model documentation
andrewherren Jan 8, 2025
32d27f1
Updated leaf model documentation
andrewherren Jan 8, 2025
ece6ad2
Fix typo in leaf model docs
andrewherren Jan 8, 2025
97eb1f3
Updated leaf model documentation
andrewherren Jan 10, 2025
ab83025
Merge branch 'main' into documentation-updates
andrewherren Jan 14, 2025
ecb8c08
Updated code formatting in python docs
andrewherren Jan 14, 2025
61c3130
Updating docs
andrewherren Jan 14, 2025
7c76405
Fixed typo
andrewherren Jan 14, 2025
d6f3ebd
Fixed some CRAN checks and added CRAN checks to R github action
andrewherren Jan 15, 2025
bfda50e
Fix quoting error in the R CMD check action
andrewherren Jan 15, 2025
5f125ce
Add rcmdcheck as R dependency
andrewherren Jan 15, 2025
5dc5f54
Remove obsolete functions from pkgdown yml
andrewherren Jan 15, 2025
f412ee8
Simplify logic in factory functions with constrained (enum) inputs
andrewherren Jan 15, 2025
324dcf3
Updated R examples to avoid lines running over 100 characters
andrewherren Jan 16, 2025
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 8 additions & 4 deletions .github/workflows/r-test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -34,9 +34,13 @@ jobs:

- uses: r-lib/actions/setup-r-dependencies@v2
with:
extra-packages: any::testthat, any::decor
extra-packages: any::testthat, any::decor, any::rcmdcheck
needs: check

- name: Run unit tests
- name: Create a CRAN-ready version of the R package
run: |
Rscript cran-bootstrap.R
Rscript -e 'testthat::test_local("stochtree_cran")'
Rscript cran-bootstrap.R 0

- uses: r-lib/actions/check-r-package@v2
with:
working-directory: 'stochtree_cran'
6 changes: 3 additions & 3 deletions DESCRIPTION
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
Package: stochtree
Title: Stochastic tree ensembles (XBART and BART) for supervised learning and causal inference
Version: 0.0.0.9000
Version: 0.0.1
Authors@R:
c(
person("Drew", "Herren", email = "drewherrenopensource@gmail.com", role = c("aut", "cre"), comment = c(ORCID = "0000-0003-4109-6611")),
Expand All @@ -17,6 +17,7 @@ RoxygenNote: 7.3.2
LinkingTo:
cpp11, BH
Suggests:
testthat (>= 3.0.0),
doParallel,
foreach,
ggplot2,
Expand All @@ -26,12 +27,11 @@ Suggests:
MASS,
mvtnorm,
rmarkdown,
testthat (>= 3.0.0),
tgp
VignetteBuilder: knitr
SystemRequirements: C++17
Imports:
R6,
stats
URL: https://stochastictree.github.io/stochtree-r/
URL: https://stochtree.ai
Config/testthat/edition: 3
17 changes: 10 additions & 7 deletions cpp_docs/Doxyfile → Doxyfile
Original file line number Diff line number Diff line change
Expand Up @@ -548,7 +548,7 @@ EXTRACT_PACKAGE = NO
# included in the documentation.
# The default value is: NO.

EXTRACT_STATIC = NO
EXTRACT_STATIC = YES

# If the EXTRACT_LOCAL_CLASSES tag is set to YES, classes (and structs) defined
# locally in source files will be included in the documentation. If set to NO,
Expand Down Expand Up @@ -588,7 +588,7 @@ RESOLVE_UNNAMED_PARAMS = YES
# section is generated. This option has no effect if EXTRACT_ALL is enabled.
# The default value is: NO.

HIDE_UNDOC_MEMBERS = NO
HIDE_UNDOC_MEMBERS = YES

# If the HIDE_UNDOC_CLASSES tag is set to YES, Doxygen will hide all
# undocumented classes that are normally visible in the class hierarchy. If set
Expand Down Expand Up @@ -687,7 +687,7 @@ INLINE_INFO = YES
# name. If set to NO, the members will appear in declaration order.
# The default value is: YES.

SORT_MEMBER_DOCS = YES
SORT_MEMBER_DOCS = NO

# If the SORT_BRIEF_DOCS tag is set to YES then Doxygen will sort the brief
# descriptions of file, namespace and class members alphabetically by member
Expand Down Expand Up @@ -740,7 +740,7 @@ STRICT_PROTO_MATCHING = NO
# list. This list is created by putting \todo commands in the documentation.
# The default value is: YES.

GENERATE_TODOLIST = YES
GENERATE_TODOLIST = NO

# The GENERATE_TESTLIST tag can be used to enable (YES) or disable (NO) the test
# list. This list is created by putting \test commands in the documentation.
Expand Down Expand Up @@ -965,7 +965,7 @@ WARN_LOGFILE =
# spaces. See also FILE_PATTERNS and EXTENSION_MAPPING
# Note: If this tag is empty the current directory is searched.

INPUT =
INPUT = include/stochtree

# This tag can be used to specify the character encoding of the source files
# that Doxygen parses. Internally Doxygen uses the UTF-8 encoding. Doxygen uses
Expand Down Expand Up @@ -1081,7 +1081,10 @@ EXCLUDE_PATTERNS = */test/* \
# wildcard * is used, a substring. Examples: ANamespace, AClass,
# ANamespace::AClass, ANamespace::*Test

EXCLUDE_SYMBOLS = StochTree::CommonC
EXCLUDE_SYMBOLS = StochTree::CommonC \
StochTree::CategorySampleTracker \
StochTree::ExtractMultipleFeaturesFromMemory \
StochTree::ExtractSingleFeatureFromMemory

# The EXAMPLE_PATH tag can be used to specify one or more files or directories
# that contain example code fragments that are included (see the \include
Expand Down Expand Up @@ -1805,7 +1808,7 @@ FORMULA_MACROFILE =
# The default value is: NO.
# This tag requires that the tag GENERATE_HTML is set to YES.

USE_MATHJAX = NO
USE_MATHJAX = YES

# With MATHJAX_VERSION it is possible to specify the MathJax version to be used.
# Note that the different versions of MathJax have different requirements with
Expand Down
2 changes: 0 additions & 2 deletions NAMESPACE
Original file line number Diff line number Diff line change
Expand Up @@ -49,8 +49,6 @@ export(oneHotEncode)
export(oneHotInitializeAndEncode)
export(orderedCatInitializeAndPreprocess)
export(orderedCatPreprocess)
export(preprocessBartParams)
export(preprocessBcfParams)
export(preprocessParams)
export(preprocessPredictionData)
export(preprocessPredictionDataFrame)
Expand Down
8 changes: 4 additions & 4 deletions R/bart.R
Original file line number Diff line number Diff line change
Expand Up @@ -1015,9 +1015,9 @@ predict.bartmodel <- function(bart, X_test, W_test = NULL, group_ids_test = NULL
#' rfx_basis_train <- rfx_basis[train_inds,]
#' rfx_term_test <- rfx_term[test_inds]
#' rfx_term_train <- rfx_term[train_inds]
#' bart_model <- bart(X_train = X_train, y_train = y_train,
#' group_ids_train = group_ids_train, rfx_basis_train = rfx_basis_train,
#' X_test = X_test, group_ids_test = group_ids_test, rfx_basis_test = rfx_basis_test,
#' bart_model <- bart(X_train = X_train, y_train = y_train, X_test = X_test,
#' group_ids_train = group_ids_train, group_ids_test = group_ids_test,
#' rfx_basis_train = rfx_basis_train, rfx_basis_test = rfx_basis_test,
#' num_gfr = 100, num_burnin = 0, num_mcmc = 100)
#' rfx_samples <- getRandomEffectSamples(bart_model)
getRandomEffectSamples.bartmodel <- function(object, ...){
Expand Down Expand Up @@ -1180,7 +1180,7 @@ convertBARTStateToJson <- function(param_list, mean_forest = NULL, variance_fore
jsonobj$add_forest(mean_forest)
}
if (param_list$include_variance_forest) {
jsonobj$add_forest(object$variance_forests)
jsonobj$add_forest(variance_forest)
}

# Add sampled parameters
Expand Down
60 changes: 32 additions & 28 deletions R/bcf.R
Original file line number Diff line number Diff line change
Expand Up @@ -1395,36 +1395,28 @@ predict.bcf <- function(bcf, X_test, Z_test, pi_test = NULL, group_ids_test = NU
rfx_basis_test <- matrix(rep(1, nrow(X_test)), ncol = 1)
}

# Add propensities to any covariate set
if (bcf$model_params$propensity_covariate == "both") {
X_test_mu <- cbind(X_test, pi_test)
X_test_tau <- cbind(X_test, pi_test)
} else if (bcf$model_params$propensity_covariate == "mu") {
X_test_mu <- cbind(X_test, pi_test)
X_test_tau <- X_test
} else if (bcf$model_params$propensity_covariate == "tau") {
X_test_mu <- X_test
X_test_tau <- cbind(X_test, pi_test)
# Add propensities to covariate set if necessary. When the model was fit
# without a propensity covariate ("none"), predict from the raw covariates so
# that X_test_combined is always defined before the prediction dataset is built.
if (bcf$model_params$propensity_covariate != "none") {
    X_test_combined <- cbind(X_test, pi_test)
} else {
    X_test_combined <- X_test
}

# Create prediction datasets
prediction_dataset_mu <- createForestDataset(X_test_mu)
prediction_dataset_tau <- createForestDataset(X_test_tau, Z_test)
forest_dataset_pred <- createForestDataset(X_test_combined, Z_test)

# Compute forest predictions
num_samples <- bcf$model_params$num_samples
y_std <- bcf$model_params$outcome_scale
y_bar <- bcf$model_params$outcome_mean
initial_sigma2 <- bcf$model_params$initial_sigma2
mu_hat_test <- bcf$forests_mu$predict(prediction_dataset_mu)*y_std + y_bar
mu_hat_test <- bcf$forests_mu$predict(forest_dataset_pred)*y_std + y_bar
if (bcf$model_params$adaptive_coding) {
tau_hat_test_raw <- bcf$forests_tau$predict_raw(prediction_dataset_tau)
tau_hat_test_raw <- bcf$forests_tau$predict_raw(forest_dataset_pred)
tau_hat_test <- t(t(tau_hat_test_raw) * (bcf$b_1_samples - bcf$b_0_samples))*y_std
} else {
tau_hat_test <- bcf$forests_tau$predict_raw(prediction_dataset_tau)*y_std
tau_hat_test <- bcf$forests_tau$predict_raw(forest_dataset_pred)*y_std
}
if (bcf$model_params$include_variance_forest) {
s_x_raw <- bcf$variance_forests$predict(prediction_dataset)
s_x_raw <- bcf$variance_forests$predict(forest_dataset_pred)
}

# Compute rfx predictions (if needed)
Expand Down Expand Up @@ -1520,14 +1512,16 @@ predict.bcf <- function(bcf, X_test, Z_test, pi_test = NULL, group_ids_test = NU
#' rfx_basis_train <- rfx_basis[train_inds,]
#' rfx_term_test <- rfx_term[test_inds]
#' rfx_term_train <- rfx_term[train_inds]
#' bcf_params <- list(sample_sigma_leaf_mu = TRUE, sample_sigma_leaf_tau = FALSE)
#' mu_params <- list(sample_sigma_leaf = TRUE)
#' tau_params <- list(sample_sigma_leaf = FALSE)
#' bcf_model <- bcf(X_train = X_train, Z_train = Z_train, y_train = y_train,
#' pi_train = pi_train, group_ids_train = group_ids_train,
#' rfx_basis_train = rfx_basis_train, X_test = X_test,
#' Z_test = Z_test, pi_test = pi_test, group_ids_test = group_ids_test,
#' rfx_basis_test = rfx_basis_test,
#' num_gfr = 100, num_burnin = 0, num_mcmc = 100,
#' params = bcf_params)
#' mu_forest_params = mu_params,
#' tau_forest_params = tau_params)
#' rfx_samples <- getRandomEffectSamples(bcf_model)
getRandomEffectSamples.bcf <- function(object, ...){
result = list()
Expand Down Expand Up @@ -1607,14 +1601,16 @@ getRandomEffectSamples.bcf <- function(object, ...){
#' rfx_basis_train <- rfx_basis[train_inds,]
#' rfx_term_test <- rfx_term[test_inds]
#' rfx_term_train <- rfx_term[train_inds]
#' bcf_params <- list(sample_sigma_leaf_mu = TRUE, sample_sigma_leaf_tau = FALSE)
#' mu_params <- list(sample_sigma_leaf = TRUE)
#' tau_params <- list(sample_sigma_leaf = FALSE)
#' bcf_model <- bcf(X_train = X_train, Z_train = Z_train, y_train = y_train,
#' pi_train = pi_train, group_ids_train = group_ids_train,
#' rfx_basis_train = rfx_basis_train, X_test = X_test,
#' Z_test = Z_test, pi_test = pi_test, group_ids_test = group_ids_test,
#' rfx_basis_test = rfx_basis_test,
#' num_gfr = 100, num_burnin = 0, num_mcmc = 100,
#' params = bcf_params)
#' mu_forest_params = mu_params,
#' tau_forest_params = tau_params)
#' # bcf_json <- convertBCFModelToJson(bcf_model)
convertBCFModelToJson <- function(object){
jsonobj <- createCppJson()
Expand Down Expand Up @@ -1749,14 +1745,16 @@ convertBCFModelToJson <- function(object){
#' rfx_basis_train <- rfx_basis[train_inds,]
#' rfx_term_test <- rfx_term[test_inds]
#' rfx_term_train <- rfx_term[train_inds]
#' bcf_params <- list(sample_sigma_leaf_mu = TRUE, sample_sigma_leaf_tau = FALSE)
#' mu_params <- list(sample_sigma_leaf = TRUE)
#' tau_params <- list(sample_sigma_leaf = FALSE)
#' bcf_model <- bcf(X_train = X_train, Z_train = Z_train, y_train = y_train,
#' pi_train = pi_train, group_ids_train = group_ids_train,
#' rfx_basis_train = rfx_basis_train, X_test = X_test,
#' Z_test = Z_test, pi_test = pi_test, group_ids_test = group_ids_test,
#' rfx_basis_test = rfx_basis_test,
#' num_gfr = 100, num_burnin = 0, num_mcmc = 100,
#' params = bcf_params)
#' mu_forest_params = mu_params,
#' tau_forest_params = tau_params)
#' # saveBCFModelToJsonFile(bcf_model, "test.json")
saveBCFModelToJsonFile <- function(object, filename){
# Convert to Json
Expand Down Expand Up @@ -1823,14 +1821,16 @@ saveBCFModelToJsonFile <- function(object, filename){
#' rfx_basis_train <- rfx_basis[train_inds,]
#' rfx_term_test <- rfx_term[test_inds]
#' rfx_term_train <- rfx_term[train_inds]
#' bcf_params <- list(sample_sigma_leaf_mu = TRUE, sample_sigma_leaf_tau = FALSE)
#' mu_params <- list(sample_sigma_leaf = TRUE)
#' tau_params <- list(sample_sigma_leaf = FALSE)
#' bcf_model <- bcf(X_train = X_train, Z_train = Z_train, y_train = y_train,
#' pi_train = pi_train, group_ids_train = group_ids_train,
#' rfx_basis_train = rfx_basis_train, X_test = X_test,
#' Z_test = Z_test, pi_test = pi_test, group_ids_test = group_ids_test,
#' rfx_basis_test = rfx_basis_test,
#' num_gfr = 100, num_burnin = 0, num_mcmc = 100,
#' params = bcf_params)
#' mu_forest_params = mu_params,
#' tau_forest_params = tau_params)
#' # saveBCFModelToJsonString(bcf_model)
saveBCFModelToJsonString <- function(object){
# Convert to Json
Expand Down Expand Up @@ -1899,14 +1899,16 @@ saveBCFModelToJsonString <- function(object){
#' rfx_basis_train <- rfx_basis[train_inds,]
#' rfx_term_test <- rfx_term[test_inds]
#' rfx_term_train <- rfx_term[train_inds]
#' bcf_params <- list(sample_sigma_leaf_mu = TRUE, sample_sigma_leaf_tau = FALSE)
#' mu_params <- list(sample_sigma_leaf = TRUE)
#' tau_params <- list(sample_sigma_leaf = FALSE)
#' bcf_model <- bcf(X_train = X_train, Z_train = Z_train, y_train = y_train,
#' pi_train = pi_train, group_ids_train = group_ids_train,
#' rfx_basis_train = rfx_basis_train, X_test = X_test,
#' Z_test = Z_test, pi_test = pi_test, group_ids_test = group_ids_test,
#' rfx_basis_test = rfx_basis_test,
#' num_gfr = 100, num_burnin = 0, num_mcmc = 100,
#' params = bcf_params)
#' mu_forest_params = mu_params,
#' tau_forest_params = tau_params)
#' # bcf_json <- convertBCFModelToJson(bcf_model)
#' # bcf_model_roundtrip <- createBCFModelFromJson(bcf_json)
createBCFModelFromJson <- function(json_object){
Expand Down Expand Up @@ -2045,14 +2047,16 @@ createBCFModelFromJson <- function(json_object){
#' rfx_basis_train <- rfx_basis[train_inds,]
#' rfx_term_test <- rfx_term[test_inds]
#' rfx_term_train <- rfx_term[train_inds]
#' bcf_params <- list(sample_sigma_leaf_mu = TRUE, sample_sigma_leaf_tau = FALSE)
#' mu_params <- list(sample_sigma_leaf = TRUE)
#' tau_params <- list(sample_sigma_leaf = FALSE)
#' bcf_model <- bcf(X_train = X_train, Z_train = Z_train, y_train = y_train,
#' pi_train = pi_train, group_ids_train = group_ids_train,
#' rfx_basis_train = rfx_basis_train, X_test = X_test,
#' Z_test = Z_test, pi_test = pi_test, group_ids_test = group_ids_test,
#' rfx_basis_test = rfx_basis_test,
#' num_gfr = 100, num_burnin = 0, num_mcmc = 100,
#' params = bcf_params)
#' mu_forest_params = mu_params,
#' tau_forest_params = tau_params)
#' # saveBCFModelToJsonFile(bcf_model, "test.json")
#' # bcf_model_roundtrip <- createBCFModelFromJsonFile("test.json")
createBCFModelFromJsonFile <- function(json_filename){
Expand Down
16 changes: 16 additions & 0 deletions R/cpp11.R
Original file line number Diff line number Diff line change
Expand Up @@ -580,6 +580,22 @@ tree_prior_cpp <- function(alpha, beta, min_samples_leaf, max_depth) {
.Call(`_stochtree_tree_prior_cpp`, alpha, beta, min_samples_leaf, max_depth)
}

# Thin R-side binding: forwards `tree_prior_ptr` and `alpha` unchanged to the
# native routine `_stochtree_update_alpha_tree_prior_cpp` through .Call().
# The native result is wrapped in invisible(), so nothing is auto-printed.
# NOTE(review): this lives in R/cpp11.R, which is presumably generated by
# cpp11 -- confirm before hand-editing; lasting changes likely belong in C++.
update_alpha_tree_prior_cpp <- function(tree_prior_ptr, alpha) {
invisible(.Call(`_stochtree_update_alpha_tree_prior_cpp`, tree_prior_ptr, alpha))
}

# Thin R-side binding: forwards `tree_prior_ptr` and `beta` unchanged to the
# native routine `_stochtree_update_beta_tree_prior_cpp` through .Call().
# The native result is wrapped in invisible(), so nothing is auto-printed.
# NOTE(review): this lives in R/cpp11.R, which is presumably generated by
# cpp11 -- confirm before hand-editing; lasting changes likely belong in C++.
update_beta_tree_prior_cpp <- function(tree_prior_ptr, beta) {
invisible(.Call(`_stochtree_update_beta_tree_prior_cpp`, tree_prior_ptr, beta))
}

# Thin R-side binding: forwards `tree_prior_ptr` and `min_samples_leaf`
# unchanged to the native routine
# `_stochtree_update_min_samples_leaf_tree_prior_cpp` through .Call().
# The native result is wrapped in invisible(), so nothing is auto-printed.
# NOTE(review): this lives in R/cpp11.R, which is presumably generated by
# cpp11 -- confirm before hand-editing; lasting changes likely belong in C++.
update_min_samples_leaf_tree_prior_cpp <- function(tree_prior_ptr, min_samples_leaf) {
invisible(.Call(`_stochtree_update_min_samples_leaf_tree_prior_cpp`, tree_prior_ptr, min_samples_leaf))
}

# Thin R-side binding: forwards `tree_prior_ptr` and `max_depth` unchanged to
# the native routine `_stochtree_update_max_depth_tree_prior_cpp` through .Call().
# The native result is wrapped in invisible(), so nothing is auto-printed.
# NOTE(review): this lives in R/cpp11.R, which is presumably generated by
# cpp11 -- confirm before hand-editing; lasting changes likely belong in C++.
update_max_depth_tree_prior_cpp <- function(tree_prior_ptr, max_depth) {
invisible(.Call(`_stochtree_update_max_depth_tree_prior_cpp`, tree_prior_ptr, max_depth))
}

# Thin R-side binding: forwards `data`, `feature_types`, `num_trees` and `n`
# unchanged to the native routine `_stochtree_forest_tracker_cpp` through
# .Call() and returns the native result visibly (no invisible() wrapper here).
# NOTE(review): presumably the result is a pointer to a C++ forest tracker,
# judging by the name -- confirm against the C++ source.
forest_tracker_cpp <- function(data, feature_types, num_trees, n) {
.Call(`_stochtree_forest_tracker_cpp`, data, feature_types, num_trees, n)
}
Expand Down
2 changes: 1 addition & 1 deletion R/kernel.R
Original file line number Diff line number Diff line change
Expand Up @@ -159,7 +159,7 @@ computeForestLeafVariances <- function(model_object, forest_type, forest_inds=NU
}

# Preprocess forest indices
num_forests <- forest_container$num_samples()
num_forests <- model_object$model_params$num_samples
if (is.null(forest_inds)) {
forest_inds <- as.integer(1:num_forests)
} else {
Expand Down
32 changes: 32 additions & 0 deletions R/model.R
Original file line number Diff line number Diff line change
Expand Up @@ -128,6 +128,38 @@ ForestModel <- R6::R6Class(
#' @return NULL
propagate_residual_update = function(residual) {
  # Pull out both handles first (tracker, then residual data) and pass them
  # to the C++ routine that propagates the residual update through the trees
  tracker_handle <- self$tracker_ptr
  residual_handle <- residual$data_ptr
  propagate_trees_column_vector_cpp(tracker_handle, residual_handle)
},

#' @description
#' Replace the `alpha` value stored in this model's tree prior
#' @param alpha Value of alpha that the tree prior should use from now on
#' @return NULL
update_alpha = function(alpha) {
  prior_handle <- self$tree_prior_ptr
  update_alpha_tree_prior_cpp(prior_handle, alpha)
},

#' @description
#' Replace the `beta` value stored in this model's tree prior
#' @param beta Value of beta that the tree prior should use from now on
#' @return NULL
update_beta = function(beta) {
  prior_handle <- self$tree_prior_ptr
  update_beta_tree_prior_cpp(prior_handle, beta)
},

#' @description
#' Replace the `min_samples_leaf` value stored in this model's tree prior
#' @param min_samples_leaf Value of min_samples_leaf that the tree prior should use from now on
#' @return NULL
update_min_samples_leaf = function(min_samples_leaf) {
  prior_handle <- self$tree_prior_ptr
  update_min_samples_leaf_tree_prior_cpp(prior_handle, min_samples_leaf)
},

#' @description
#' Replace the `max_depth` value stored in this model's tree prior
#' @param max_depth Value of max_depth that the tree prior should use from now on
#' @return NULL
update_max_depth = function(max_depth) {
  prior_handle <- self$tree_prior_ptr
  update_max_depth_tree_prior_cpp(prior_handle, max_depth)
}
)
)
Expand Down
2 changes: 1 addition & 1 deletion R/serialization.R
Original file line number Diff line number Diff line change
Expand Up @@ -413,7 +413,7 @@ loadRandomEffectSamplesCombinedJsonString <- function(json_string_list, json_rfx
json_rfx_mapper_label <- paste0("random_effect_label_mapper_", json_rfx_num)
json_rfx_groupids_label <- paste0("random_effect_groupids_", json_rfx_num)
invisible(output <- RandomEffectSamples$new())
for (i in 1:length(json_object_list)) {
for (i in 1:length(json_string_list)) {
json_string <- json_string_list[[i]]
if (i == 1) {
output$load_from_json_string(json_string, json_rfx_container_label, json_rfx_mapper_label, json_rfx_groupids_label)
Expand Down
Loading
Loading