Skip to content

Commit 744ca40

Browse files
authored
Merge pull request #135 from StochasticTree/model-config-refactor
Refactor model parameters into "config" objects to future-proof low-level interface
2 parents 0a9352f + 1e73f92 commit 744ca40

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

44 files changed

+2121
-352
lines changed

.github/workflows/r-test.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,7 @@ jobs:
3939

4040
- name: Create a CRAN-ready version of the R package
4141
run: |
42-
Rscript cran-bootstrap.R 0 0
42+
Rscript cran-bootstrap.R 0 0 1
4343
4444
- uses: r-lib/actions/check-r-package@v2
4545
with:

DESCRIPTION

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
Package: stochtree
2-
Title: Stochastic tree ensembles (XBART and BART) for supervised learning and causal inference
2+
Title: Stochastic tree Ensembles (XBART and BART) for Supervised Learning and Causal Inference
33
Version: 0.1.0
44
Authors@R:
55
c(
@@ -10,7 +10,11 @@ Authors@R:
1010
person("Jingyu", "He", role = "aut"),
1111
person("stochtree contributors", role = c("cph"))
1212
)
13-
Description: Stochastic tree ensembles (XBART and BART) for supervised learning and causal inference.
13+
Description: Flexible stochastic tree ensemble software. Robust implementations of
14+
Bayesian Additive Regression Trees (Chipman, George, McCulloch (2010) <doi:10.1214/09-AOAS285>)
15+
for supervised learning and (Bayesian Causal Forests (BCF) Hahn, Murray, Carvalho (2020) <doi:10.1214/19-BA1195>)
16+
for causal inference. Enables model serialization and parallel sampling
17+
and provides a low-level interface for custom stochastic forest samplers.
1418
License: MIT + file LICENSE
1519
Encoding: UTF-8
1620
Roxygen: list(markdown = TRUE)

LICENSE

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,2 +1,2 @@
1-
YEAR: 2024
2-
COPYRIGHT HOLDER: stochtree authors
1+
YEAR: 2025
2+
COPYRIGHT HOLDER: stochtree contributors

LICENSE.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
# MIT License
22

3-
Copyright (c) 2024 stochtree authors
3+
Copyright (c) 2023-2025 stochtree authors
44

55
Permission is hereby granted, free of charge, to any person obtaining a copy
66
of this software and associated documentation files (the "Software"), to deal

NAMESPACE

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,9 @@ export(createCppRNG)
2828
export(createForest)
2929
export(createForestDataset)
3030
export(createForestModel)
31+
export(createForestModelConfig)
3132
export(createForestSamples)
33+
export(createGlobalModelConfig)
3234
export(createOutcome)
3335
export(createPreprocessorFromJson)
3436
export(createPreprocessorFromJsonString)

NEWS.md

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,16 @@
11
# stochtree 0.1.0
22

3-
* Initial CRAN submission.
3+
* Initial release on CRAN.
4+
* Support for sampling stochastic tree ensembles using two algorithms: MCMC and Grow-From-Root (GFR)
5+
* High-level model types supported:
6+
* Supervised learning with constant leaves or user-specified leaf regression models
7+
* Causal effect estimation with binary, continuous, or multivariate treatments
8+
* Additional high-level modeling features:
9+
* Forest-based variance function estimation (heteroskedasticity)
10+
* Additive (univariate or multivariate) group random effects
11+
* Multi-chain sampling and support for parallelism
12+
* "Warm-start" initialization of MCMC forest samplers via the Grow-From-Root (GFR) algorithm
13+
* Automated preprocessing / handling of categorical variables
14+
* Low-level interface:
15+
* Ability to combine a forest sampler with other (additive) model terms, without using C++
16+
* Combine and sample an arbitrary number of forests or random effects terms

R/bart.R

Lines changed: 67 additions & 25 deletions
Large diffs are not rendered by default.

R/bcf.R

Lines changed: 135 additions & 53 deletions
Large diffs are not rendered by default.

R/config.R

Lines changed: 395 additions & 0 deletions
Large diffs are not rendered by default.

R/cpp11.R

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -556,12 +556,12 @@ compute_leaf_indices_cpp <- function(forest_container, covariates, forest_nums)
556556
.Call(`_stochtree_compute_leaf_indices_cpp`, forest_container, covariates, forest_nums)
557557
}
558558

559-
sample_gfr_one_iteration_cpp <- function(data, residual, forest_samples, active_forest, tracker, split_prior, rng, feature_types, cutpoint_grid_size, leaf_model_scale_input, variable_weights, a_forest, b_forest, global_variance, leaf_model_int, keep_forest, pre_initialized) {
560-
invisible(.Call(`_stochtree_sample_gfr_one_iteration_cpp`, data, residual, forest_samples, active_forest, tracker, split_prior, rng, feature_types, cutpoint_grid_size, leaf_model_scale_input, variable_weights, a_forest, b_forest, global_variance, leaf_model_int, keep_forest, pre_initialized))
559+
sample_gfr_one_iteration_cpp <- function(data, residual, forest_samples, active_forest, tracker, split_prior, rng, feature_types, cutpoint_grid_size, leaf_model_scale_input, variable_weights, a_forest, b_forest, global_variance, leaf_model_int, keep_forest) {
560+
invisible(.Call(`_stochtree_sample_gfr_one_iteration_cpp`, data, residual, forest_samples, active_forest, tracker, split_prior, rng, feature_types, cutpoint_grid_size, leaf_model_scale_input, variable_weights, a_forest, b_forest, global_variance, leaf_model_int, keep_forest))
561561
}
562562

563-
sample_mcmc_one_iteration_cpp <- function(data, residual, forest_samples, active_forest, tracker, split_prior, rng, feature_types, cutpoint_grid_size, leaf_model_scale_input, variable_weights, a_forest, b_forest, global_variance, leaf_model_int, keep_forest, pre_initialized) {
564-
invisible(.Call(`_stochtree_sample_mcmc_one_iteration_cpp`, data, residual, forest_samples, active_forest, tracker, split_prior, rng, feature_types, cutpoint_grid_size, leaf_model_scale_input, variable_weights, a_forest, b_forest, global_variance, leaf_model_int, keep_forest, pre_initialized))
563+
sample_mcmc_one_iteration_cpp <- function(data, residual, forest_samples, active_forest, tracker, split_prior, rng, feature_types, cutpoint_grid_size, leaf_model_scale_input, variable_weights, a_forest, b_forest, global_variance, leaf_model_int, keep_forest) {
564+
invisible(.Call(`_stochtree_sample_mcmc_one_iteration_cpp`, data, residual, forest_samples, active_forest, tracker, split_prior, rng, feature_types, cutpoint_grid_size, leaf_model_scale_input, variable_weights, a_forest, b_forest, global_variance, leaf_model_int, keep_forest))
565565
}
566566

567567
sample_sigma2_one_iteration_cpp <- function(residual, dataset, rng, a, b) {

0 commit comments

Comments
 (0)