Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
23 commits
Select commit Hold shift + click to select a range
efaa3ab
WIP implementation of multi-threading in GFR using OpenMP
andrewherren Jul 1, 2025
f47873c
Refactored GFR sampler to enable threads to make better use of shared…
andrewherren Jul 2, 2025
ee156fb
Updated stochtree to use multi-threading in sifting observations in t…
andrewherren Jul 3, 2025
2b74ba8
Propagating multi-threading updates to python, temporarily removing e…
andrewherren Jul 22, 2025
ad5b6bb
Update CI to handle stochtree changes expecting OpenMP and autoconf f…
andrewherren Jul 22, 2025
75baf9d
Updating CI workflows
andrewherren Jul 22, 2025
01b6c67
Updated CMakeLists for OpenMP
andrewherren Jul 22, 2025
3f43669
Updated CMakeLists
andrewherren Jul 22, 2025
0eec468
Updating CMakeLists
andrewherren Jul 22, 2025
a84ce21
Updated MACOSX_DEPLOYMENT_TARGET for PyPI wheel build workflow
andrewherren Jul 22, 2025
92461b4
Added windows specific makevars and config scripts for the R package
andrewherren Jul 22, 2025
ecfa4cc
Updated windows configure script
andrewherren Jul 22, 2025
a39eda1
Updated end of line for various R build scripts
andrewherren Jul 22, 2025
7673744
Handling line endings in code and in windows GHA
andrewherren Jul 23, 2025
c62185c
Removing unused headers
andrewherren Jul 23, 2025
61ba8b9
Added BCF profiling script
andrewherren Jul 25, 2025
64e2f7c
Basic regression testing setup, manually dispatched
andrewherren Jul 25, 2025
9c8b865
Merge branch 'main' into multi-threaded-sampling
andrewherren Jul 28, 2025
5778e21
Moving bart regression tests to subfolder
andrewherren Jul 29, 2025
e8978b6
Updated regression testing workflows to include R BCF
andrewherren Aug 2, 2025
46e49f3
Updated regression tests
andrewherren Aug 11, 2025
73bd635
Refactored predict method for BART and BCF to include labels for pred…
andrewherren Aug 11, 2025
a85cfcb
Updated regression testing framework and added python regression tests
andrewherren Aug 11, 2025
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 16 additions & 0 deletions .gitattributes
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
* text=auto

*.c text eol=lf
*.h text eol=lf
*.cc text eol=lf
*.cuh text eol=lf
*.cu text eol=lf
*.py text eol=lf
*.txt text eol=lf
*.R text eol=lf

*.sh text eol=lf
*.ac text eol=lf

*.md text eol=lf
*.csv text eol=lf
7 changes: 7 additions & 0 deletions .github/workflows/cpp-test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,12 @@ jobs:
shell: bash
run: |
echo "build-output-dir=${{ github.workspace }}/build" >> "$GITHUB_OUTPUT"

- name: Set up dependencies (linux clang)
# Install the OpenMP runtime (libomp) on ubuntu-latest when building with clang; unlike GCC and MSVC, clang does not ship with OpenMP support out of the box
if: matrix.os == 'ubuntu-latest' && matrix.cpp_compiler == 'clang++'
run: |
sudo apt-get update && sudo apt-get install -y libomp-dev

- name: Configure CMake
# Configure CMake in a 'build' subdirectory. `CMAKE_BUILD_TYPE` is only required if you are using a single-configuration generator such as make.
Expand All @@ -69,6 +75,7 @@ jobs:
-DUSE_SANITIZER=OFF
-DBUILD_TEST=ON
-DBUILD_DEBUG_TARGETS=OFF
-DUSE_OPENMP=ON
-S ${{ github.workspace }}

- name: Build
Expand Down
13 changes: 12 additions & 1 deletion .github/workflows/pypi-wheels.yml
Original file line number Diff line number Diff line change
Expand Up @@ -21,26 +21,37 @@ jobs:
include:
- os: ubuntu-latest
cibw_archs: "x86_64"
macos_deployment_target: "10.13" # Unused, just setting the variable as a placeholder
- os: ubuntu-24.04-arm
cibw_archs: "aarch64"
macos_deployment_target: "10.13" # Unused, just setting the variable as a placeholder
- os: windows-latest
cibw_archs: "auto64"
macos_deployment_target: "10.13" # Unused, just setting the variable as a placeholder
- os: macos-13
cibw_archs: "x86_64"
macos_deployment_target: "13.0"
- os: macos-14
cibw_archs: "arm64"
macos_deployment_target: "14.0"

steps:
- uses: actions/checkout@v4
with:
submodules: 'recursive'

- name: Set up openmp (macos)
# Install OpenMP on macOS, since the Apple clang compiler suite does not ship with it
if: matrix.os == 'macos-13' || matrix.os == 'macos-14'
run: |
brew install libomp

- name: Build wheels
uses: pypa/cibuildwheel@v2.23.2
env:
CIBW_SKIP: "pp* *-musllinux_* *-win32"
CIBW_ARCHS: ${{ matrix.cibw_archs }}
MACOSX_DEPLOYMENT_TARGET: "10.13"
MACOSX_DEPLOYMENT_TARGET: ${{ matrix.macos_deployment_target }}

- uses: actions/upload-artifact@v4
with:
Expand Down
6 changes: 6 additions & 0 deletions .github/workflows/python-test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,12 @@ jobs:
with:
python-version: "3.10"
cache: "pip"

- name: Set up openmp (macos)
# Install OpenMP on macOS, since the Apple clang compiler suite does not ship with it
if: matrix.os == 'macos-latest'
run: |
brew install libomp

- name: Install Package with Relevant Dependencies
run: |
Expand Down
5 changes: 5 additions & 0 deletions .github/workflows/r-test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,11 @@ jobs:
os: [ubuntu-latest, windows-latest, macos-latest]

steps:
- name: Prevent conversion of line endings on Windows
if: startsWith(matrix.os, 'windows')
shell: pwsh
run: git config --global core.autocrlf false

- uses: actions/checkout@v4
with:
submodules: 'recursive'
Expand Down
59 changes: 59 additions & 0 deletions .github/workflows/regression-test.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
on:
workflow_dispatch:

name: Running stochtree on benchmark datasets

jobs:
stochtree_r:
name: stochtree-r-bart-regression-test
runs-on: ubuntu-latest

steps:
- name: Checkout stochtree repo
uses: actions/checkout@v4
with:
submodules: 'recursive'

- name: Setup pandoc
uses: r-lib/actions/setup-pandoc@v2

- name: Setup R
uses: r-lib/actions/setup-r@v2
with:
use-public-rspm: true

- name: Create a properly formatted version of the stochtree R package in a subfolder
run: |
Rscript cran-bootstrap.R 0 0 1

- name: Setup R dependencies
uses: r-lib/actions/setup-r-dependencies@v2
with:
extra-packages: any::testthat, any::decor, local::stochtree_cran

- name: Create output directory for BART regression test results
run: |
mkdir -p tools/regression/bart/stochtree_bart_r_results
mkdir -p tools/regression/bcf/stochtree_bcf_r_results

- name: Run the BART regression test benchmark suite
run: |
Rscript tools/regression/bart/regression_test_dispatch_bart.R
Rscript tools/regression/bcf/regression_test_dispatch_bcf.R

- name: Collate and analyze regression test results
run: |
Rscript tools/regression/bart/regression_test_analysis_bart.R
Rscript tools/regression/bcf/regression_test_analysis_bcf.R

- name: Store BART benchmark test results as an artifact of the run
uses: actions/upload-artifact@v4
with:
name: stochtree-r-bart-summary
path: tools/regression/bart/stochtree_bart_r_results/stochtree_bart_r_summary.csv

- name: Store BCF benchmark test results as an artifact of the run
uses: actions/upload-artifact@v4
with:
name: stochtree-r-bcf-summary
path: tools/regression/bcf/stochtree_bcf_r_results/stochtree_bcf_r_summary.csv
5 changes: 5 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,11 @@ po/*~
# RStudio Connect folder
rsconnect/

# Configuration files generated by R build
config.status
config.log
src/Makevars

## Python gitignore

# Byte-compiled / optimized / DLL files
Expand Down
83 changes: 69 additions & 14 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
# Build options
option(USE_DEBUG "Set to ON for Debug mode" OFF)
option(USE_DEBUG "Build with debug symbols and without optimization" OFF)
option(USE_SANITIZER "Use santizer flags" OFF)
option(USE_OPENMP "Use openMP" ON)
option(USE_HOMEBREW_FALLBACK "(macOS-only) also look in 'brew --prefix' for libraries (e.g. OpenMP)" ON)
option(BUILD_TEST "Build C++ tests with Google Test" OFF)
option(BUILD_DEBUG_TARGETS "Build Standalone C++ Programs for Debugging" ON)
option(BUILD_PYTHON "Build Shared Library for Python Package" OFF)
Expand All @@ -9,8 +11,8 @@ option(BUILD_PYTHON "Build Shared Library for Python Package" OFF)
set(CMAKE_CXX_STANDARD 17)
set(CMAKE_CXX_STANDARD_REQUIRED ON)

# Default to CMake 3.16
cmake_minimum_required(VERSION 3.16)
# Require CMake 3.20 or newer
cmake_minimum_required(VERSION 3.20)

# Define the project
project(stochtree LANGUAGES C CXX)
Expand All @@ -34,6 +36,13 @@ if(USE_DEBUG)
add_definitions(-DDEBUG)
endif()

# Linker flags (empty by default, updated if using openmp)
set(
STOCHTREE_LINK_FLAGS
""
)

# Unix / MinGW compiler flags
if(UNIX OR MINGW OR CYGWIN)
set(
CMAKE_CXX_FLAGS
Expand All @@ -42,11 +51,12 @@ if(UNIX OR MINGW OR CYGWIN)
if(USE_DEBUG)
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -g -O0")
else()
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -O3")
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -g -O3")
endif()
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Wno-unknown-pragmas -Wno-unused-private-field")
endif()

# MSVC compiler flags
if(MSVC)
set(
variables
Expand All @@ -72,6 +82,33 @@ else()
endif()
endif()

# OpenMP
if(USE_OPENMP)
add_definitions(-DSTOCHTREE_OPENMP_AVAILABLE)
if(APPLE)
find_package(OpenMP)
if(NOT OpenMP_FOUND)
if(USE_HOMEBREW_FALLBACK)
execute_process(COMMAND brew --prefix libomp
OUTPUT_VARIABLE HOMEBREW_LIBOMP_PREFIX
OUTPUT_STRIP_TRAILING_WHITESPACE)
set(OpenMP_C_FLAGS "-Xclang -fopenmp -I${HOMEBREW_LIBOMP_PREFIX}/include")
set(OpenMP_CXX_FLAGS "-Xclang -fopenmp -I${HOMEBREW_LIBOMP_PREFIX}/include")
set(OpenMP_C_INCLUDE_DIR "")
set(OpenMP_CXX_INCLUDE_DIR "")
set(OpenMP_C_LIB_NAMES libomp)
set(OpenMP_CXX_LIB_NAMES libomp)
set(OpenMP_libomp_LIBRARY ${HOMEBREW_LIBOMP_PREFIX}/lib/libomp.dylib)
endif()
find_package(OpenMP REQUIRED)
endif()
else()
find_package(OpenMP REQUIRED)
endif()
# Update flags with openmp
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} ${OpenMP_CXX_FLAGS}")
endif()

# Header file directory
set(StochTree_HEADER_DIR ${PROJECT_SOURCE_DIR}/include)

Expand All @@ -80,6 +117,8 @@ set(BOOSTMATH_HEADER_DIR ${PROJECT_SOURCE_DIR}/deps/boost_math/include)

# Eigen header file directory
set(EIGEN_HEADER_DIR ${PROJECT_SOURCE_DIR}/deps/eigen)
add_definitions(-DEIGEN_MPL2_ONLY)
add_definitions(-DEIGEN_DONT_PARALLELIZE)

# fast_double_parser header file directory
set(FAST_DOUBLE_PARSER_HEADER_DIR ${PROJECT_SOURCE_DIR}/deps/fast_double_parser/include)
Expand Down Expand Up @@ -109,10 +148,11 @@ file(
add_library(stochtree_objs OBJECT ${SOURCES})

# Include the headers in the source library
target_include_directories(stochtree_objs PRIVATE ${StochTree_HEADER_DIR} ${BOOSTMATH_HEADER_DIR} ${EIGEN_HEADER_DIR} ${FAST_DOUBLE_PARSER_HEADER_DIR} ${FMT_HEADER_DIR})

if(APPLE)
set(CMAKE_SHARED_LIBRARY_SUFFIX ".so")
if(USE_OPENMP)
target_include_directories(stochtree_objs PRIVATE ${StochTree_HEADER_DIR} ${BOOSTMATH_HEADER_DIR} ${EIGEN_HEADER_DIR} ${FAST_DOUBLE_PARSER_HEADER_DIR} ${FMT_HEADER_DIR} ${OpenMP_CXX_INCLUDE_DIR})
target_link_libraries(stochtree_objs PRIVATE ${OpenMP_libomp_LIBRARY})
else()
target_include_directories(stochtree_objs PRIVATE ${StochTree_HEADER_DIR} ${BOOSTMATH_HEADER_DIR} ${EIGEN_HEADER_DIR} ${FAST_DOUBLE_PARSER_HEADER_DIR} ${FMT_HEADER_DIR})
endif()

# Python shared library
Expand All @@ -122,8 +162,13 @@ if (BUILD_PYTHON)
pybind11_add_module(stochtree_cpp src/py_stochtree.cpp)

# Link to C++ source and headers
target_include_directories(stochtree_cpp PRIVATE ${StochTree_HEADER_DIR} ${BOOSTMATH_HEADER_DIR} ${EIGEN_HEADER_DIR} ${FAST_DOUBLE_PARSER_HEADER_DIR} ${FMT_HEADER_DIR})
target_link_libraries(stochtree_cpp PRIVATE stochtree_objs)
if(USE_OPENMP)
target_include_directories(stochtree_cpp PRIVATE ${StochTree_HEADER_DIR} ${BOOSTMATH_HEADER_DIR} ${EIGEN_HEADER_DIR} ${FAST_DOUBLE_PARSER_HEADER_DIR} ${FMT_HEADER_DIR} ${OpenMP_CXX_INCLUDE_DIR})
target_link_libraries(stochtree_cpp PRIVATE stochtree_objs ${OpenMP_libomp_LIBRARY})
else()
target_include_directories(stochtree_cpp PRIVATE ${StochTree_HEADER_DIR} ${BOOSTMATH_HEADER_DIR} ${EIGEN_HEADER_DIR} ${FAST_DOUBLE_PARSER_HEADER_DIR} ${FMT_HEADER_DIR})
target_link_libraries(stochtree_cpp PRIVATE stochtree_objs)
endif()

# EXAMPLE_VERSION_INFO is defined by setup.py and passed into the C++ code as a
# define (VERSION_INFO) here.
Expand Down Expand Up @@ -154,8 +199,13 @@ if(BUILD_TEST)
file(GLOB CPP_TEST_SOURCES test/cpp/*.cpp)
add_executable(teststochtree ${CPP_TEST_SOURCES})
set(STOCHTREE_TEST_HEADER_DIR ${PROJECT_SOURCE_DIR}/test/cpp)
target_include_directories(teststochtree PRIVATE ${StochTree_HEADER_DIR} ${BOOSTMATH_HEADER_DIR} ${EIGEN_HEADER_DIR} ${STOCHTREE_TEST_HEADER_DIR} ${FAST_DOUBLE_PARSER_HEADER_DIR} ${FMT_HEADER_DIR})
target_link_libraries(teststochtree PRIVATE stochtree_objs GTest::gtest_main)
if(USE_OPENMP)
target_include_directories(teststochtree PRIVATE ${StochTree_HEADER_DIR} ${BOOSTMATH_HEADER_DIR} ${EIGEN_HEADER_DIR} ${STOCHTREE_TEST_HEADER_DIR} ${FAST_DOUBLE_PARSER_HEADER_DIR} ${FMT_HEADER_DIR} ${OpenMP_CXX_INCLUDE_DIR})
target_link_libraries(teststochtree PRIVATE stochtree_objs GTest::gtest_main ${OpenMP_libomp_LIBRARY})
else()
target_include_directories(teststochtree PRIVATE ${StochTree_HEADER_DIR} ${BOOSTMATH_HEADER_DIR} ${EIGEN_HEADER_DIR} ${STOCHTREE_TEST_HEADER_DIR} ${FAST_DOUBLE_PARSER_HEADER_DIR} ${FMT_HEADER_DIR})
target_link_libraries(teststochtree PRIVATE stochtree_objs GTest::gtest_main)
endif()
gtest_discover_tests(teststochtree)
endif()

Expand All @@ -164,7 +214,12 @@ if(BUILD_DEBUG_TARGETS)
# Build the standalone debugging program
add_executable(debugstochtree debug/api_debug.cpp)
set(StochTree_DEBUG_HEADER_DIR ${PROJECT_SOURCE_DIR}/debug)
target_include_directories(debugstochtree PRIVATE ${StochTree_HEADER_DIR} ${BOOSTMATH_HEADER_DIR} ${EIGEN_HEADER_DIR} ${StochTree_DEBUG_HEADER_DIR} ${FAST_DOUBLE_PARSER_HEADER_DIR} ${FMT_HEADER_DIR})
target_link_libraries(debugstochtree PRIVATE stochtree_objs)
if(USE_OPENMP)
target_include_directories(debugstochtree PRIVATE ${StochTree_HEADER_DIR} ${BOOSTMATH_HEADER_DIR} ${EIGEN_HEADER_DIR} ${StochTree_DEBUG_HEADER_DIR} ${FAST_DOUBLE_PARSER_HEADER_DIR} ${FMT_HEADER_DIR} ${OpenMP_CXX_INCLUDE_DIR})
target_link_libraries(debugstochtree PRIVATE stochtree_objs ${OpenMP_libomp_LIBRARY})
else()
target_include_directories(debugstochtree PRIVATE ${StochTree_HEADER_DIR} ${BOOSTMATH_HEADER_DIR} ${EIGEN_HEADER_DIR} ${StochTree_DEBUG_HEADER_DIR} ${FAST_DOUBLE_PARSER_HEADER_DIR} ${FMT_HEADER_DIR})
target_link_libraries(debugstochtree PRIVATE stochtree_objs)
endif()
endif()

16 changes: 14 additions & 2 deletions R/bart.R
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,7 @@
#' - `rfx_group_parameter_prior_cov` Prior covariance matrix for the random effects "group parameters." Default: `NULL`. Must be a square matrix whose dimension matches the number of random effects bases, or a scalar value that will be expanded to a diagonal matrix.
#' - `rfx_variance_prior_shape` Shape parameter for the inverse gamma prior on the variance of the random effects "group parameter." Default: `1`.
#' - `rfx_variance_prior_scale` Scale parameter for the inverse gamma prior on the variance of the random effects "group parameter." Default: `1`.
#' - `num_threads` Number of threads to use in the GFR and MCMC algorithms, as well as prediction. Default: `-1`, which resolves to `1` if OpenMP is not available on a user's setup and otherwise to the maximum number of available threads.
#'
#' @param mean_forest_params (Optional) A list of mean forest model parameters, each of which has a default value processed internally, so this argument list is optional.
#'
Expand Down Expand Up @@ -130,7 +131,8 @@ bart <- function(X_train, y_train, leaf_basis_train = NULL, rfx_group_ids_train
rfx_working_parameter_prior_cov = NULL,
rfx_group_parameter_prior_cov = NULL,
rfx_variance_prior_shape = 1,
rfx_variance_prior_scale = 1
rfx_variance_prior_scale = 1,
num_threads = -1
)
general_params_updated <- preprocessParams(
general_params_default, general_params
Expand Down Expand Up @@ -186,6 +188,7 @@ bart <- function(X_train, y_train, leaf_basis_train = NULL, rfx_group_ids_train
rfx_group_parameter_prior_cov <- general_params_updated$rfx_group_parameter_prior_cov
rfx_variance_prior_shape <- general_params_updated$rfx_variance_prior_shape
rfx_variance_prior_scale <- general_params_updated$rfx_variance_prior_scale
num_threads <- general_params_updated$num_threads

# 2. Mean forest parameters
num_trees_mean <- mean_forest_params_updated$num_trees
Expand Down Expand Up @@ -795,7 +798,8 @@ bart <- function(X_train, y_train, leaf_basis_train = NULL, rfx_group_ids_train
forest_model_mean$sample_one_iteration(
forest_dataset = forest_dataset_train, residual = outcome_train, forest_samples = forest_samples_mean,
active_forest = active_forest_mean, rng = rng, forest_model_config = forest_model_config_mean,
global_model_config = global_model_config, keep_forest = keep_sample, gfr = TRUE
global_model_config = global_model_config, num_threads = num_threads,
keep_forest = keep_sample, gfr = TRUE
)

# Cache train set predictions since they are already computed during sampling
Expand Down Expand Up @@ -1272,15 +1276,23 @@ predict.bartmodel <- function(object, X, leaf_basis = NULL, rfx_group_ids = NULL
result <- list()
if ((object$model_params$has_rfx) || (object$model_params$include_mean_forest)) {
result[["y_hat"]] = y_hat
} else {
result[["y_hat"]] <- NULL
}
if (object$model_params$include_mean_forest) {
result[["mean_forest_predictions"]] = mean_forest_predictions
} else {
result[["mean_forest_predictions"]] <- NULL
}
if (object$model_params$has_rfx) {
result[["rfx_predictions"]] = rfx_predictions
} else {
result[["rfx_predictions"]] <- NULL
}
if (object$model_params$include_variance_forest) {
result[["variance_forest_predictions"]] = variance_forest_predictions
} else {
result[["variance_forest_predictions"]] <- NULL
}
return(result)
}
Expand Down
Loading
Loading