diff --git a/CMakeLists.txt b/CMakeLists.txt index 1d1efe55..08957562 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -139,6 +139,7 @@ file( src/io.cpp src/json11.cpp src/leaf_model.cpp + src/ordinal_sampler.cpp src/partition_tracker.cpp src/random_effects.cpp src/tree.cpp diff --git a/DESCRIPTION b/DESCRIPTION index 15e4c0c0..1e1b9806 100644 --- a/DESCRIPTION +++ b/DESCRIPTION @@ -29,7 +29,7 @@ Description: Flexible stochastic tree ensemble software. License: MIT + file LICENSE Encoding: UTF-8 Roxygen: list(markdown = TRUE) -RoxygenNote: 7.3.2 +RoxygenNote: 7.3.3 LinkingTo: cpp11, BH Suggests: diff --git a/NAMESPACE b/NAMESPACE index 2f4103c0..a4062f5e 100644 --- a/NAMESPACE +++ b/NAMESPACE @@ -7,6 +7,7 @@ S3method(predict,bcfmodel) export(bart) export(bcf) export(calibrateInverseGammaErrorVariance) +export(cloglog_ordinal_bart) export(computeForestLeafIndices) export(computeForestLeafVariances) export(computeForestMaxLeafIndex) diff --git a/R/cloglog_ordinal_bart.R b/R/cloglog_ordinal_bart.R new file mode 100644 index 00000000..7220774d --- /dev/null +++ b/R/cloglog_ordinal_bart.R @@ -0,0 +1,218 @@ +#' Run the BART algorithm for ordinal outcomes using a complementary log-log link +#' +#' @param X A numeric matrix of predictors (training data). +#' @param y A numeric vector of ordinal outcomes (positive integers starting from 1). +#' @param X_test An optional numeric matrix of predictors (test data). +#' @param n_trees Number of trees in the BART ensemble. Default: `50`. +#' @param num_gfr Number of GFR samples to draw at the beginning of the sampler. Default: `0`. +#' @param num_burnin Number of burn-in MCMC samples to discard. Default: `1000`. +#' @param num_mcmc Total number of MCMC samples to draw. Default: `500`. +#' @param n_thin Thinning interval for MCMC samples. Default: `1`. +#' @param alpha_gamma Shape parameter for the log-gamma prior on cutpoints. Default: `2.0`. +#' @param beta_gamma Rate parameter for the log-gamma prior on cutpoints. Default: `2.0`. +#' @param variable_weights (Optional) vector of variable weights for splitting (default: equal weights). +#' @param feature_types (Optional) vector indicating feature types (0 for continuous, 1 for categorical; default: all continuous). +#' @param seed (Optional) random seed for reproducibility. +#' @param num_threads (Optional) Number of threads to use in split evaluations and other compute-intensive operations. Default: 1. 
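+#'
+#' @return A list of class `"cloglog_ordinal_bart"` containing forest predictions for the
+#'   training set (and for the test set when `X_test` is provided), sampled cutpoint
+#'   parameters (`gamma_samples`), sampled latent variables, the leaf scale, the
+#'   (0-indexed) ordinal outcome, the number of trees, and the number of retained samples.
+#'
+#' @examples
+#' # Minimal usage sketch on synthetic data; the tree and iteration counts are
+#' # deliberately small here and are illustrative only
+#' n <- 100
+#' p <- 5
+#' X <- matrix(runif(n * p), ncol = p)
+#' y <- sample(1:3, size = n, replace = TRUE)
+#' fit <- cloglog_ordinal_bart(X, y, n_trees = 10, num_burnin = 10, num_mcmc = 10)
+#' dim(fit$gamma_samples) # (number of categories - 1) x (number of retained samples)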
+#' @export
+cloglog_ordinal_bart <- function(X, y, X_test = NULL,
+                                 n_trees = 50,
+                                 num_gfr = 0,
+                                 num_burnin = 1000,
+                                 num_mcmc = 500,
+                                 n_thin = 1,
+                                 alpha_gamma = 2.0,
+                                 beta_gamma = 2.0,
+                                 variable_weights = NULL,
+                                 feature_types = NULL,
+                                 seed = NULL,
+                                 num_threads = 1) {
+    # BART parameters
+    alpha_bart <- 0.95
+    beta_bart <- 2
+    min_samples_in_leaf <- 5
+    max_depth <- 10
+    scale_leaf <- 2 / sqrt(n_trees)
+    cutpoint_grid_size <- 100 # Needed for stochtree::sample_gfr_one_iteration_cpp, not used in MCMC BART
+
+    # Fixed for identifiability (could be passed as an argument later if desired)
+    gamma_0 <- 0.0 # First gamma cutpoint fixed at gamma_0 = 0
+
+    # Determine whether a test dataset is provided
+    has_test <- !is.null(X_test)
+
+    # Data checks
+    if (!is.matrix(X)) X <- as.matrix(X)
+    if (!is.numeric(y)) y <- as.numeric(y)
+    if (has_test && !is.matrix(X_test)) X_test <- as.matrix(X_test)
+
+    n_samples <- nrow(X)
+    n_features <- ncol(X)
+
+    if (any(y < 1) || any(y != round(y))) {
+        stop("Ordinal outcome y must contain positive integers starting from 1")
+    }
+
+    # Convert from 1-based (R) to 0-based (C++) indexing
+    ordinal_outcome <- as.integer(y - 1)
+    n_levels <- max(y) # Number of ordinal categories
+
+    if (n_levels < 2) {
+        stop("Ordinal outcome must have at least 2 categories")
+    }
+
+    if (is.null(variable_weights)) {
+        variable_weights <- rep(1.0, n_features)
+    }
+
+    if (is.null(feature_types)) {
+        feature_types <- rep(0L, n_features)
+    }
+
+    if (!is.null(seed)) {
+        set.seed(seed)
+    }
+
+    # Indices of MCMC samples to keep after GFR, burn-in, and thinning
+    keep_idx <- seq(num_gfr + num_burnin + 1, num_gfr + num_burnin + num_mcmc, by = n_thin)
+    n_keep <- length(keep_idx)
+
+    # Storage for MCMC samples
+    forest_pred_train <- matrix(0, n_samples, n_keep)
+    if (has_test) {
+        n_samples_test <- nrow(X_test)
+        forest_pred_test <- matrix(0, n_samples_test, n_keep)
+    }
+    gamma_samples <- matrix(0, n_levels - 1, n_keep)
+    latent_samples <- matrix(0, n_samples, n_keep)
+
+    # Initialize samplers
+    ordinal_sampler <- stochtree:::ordinal_sampler_cpp()
+    rng <- stochtree::createCppRNG(if (is.null(seed)) sample.int(.Machine$integer.max, 1) else seed)
+
+    # Initialize the remaining model structures
+    dataX <- stochtree::createForestDataset(X)
+    if (has_test) {
+        dataXtest <- stochtree::createForestDataset(X_test)
+    }
+    outcome_data <- stochtree::createOutcome(as.numeric(ordinal_outcome))
+    active_forest <- stochtree::createForest(as.integer(n_trees), 1L, TRUE, FALSE) # Use constant leaves
+    active_forest$set_root_leaves(0.0)
+    split_prior <- stochtree:::tree_prior_cpp(alpha_bart, beta_bart, min_samples_in_leaf, max_depth)
+    forest_samples <- stochtree::createForestSamples(as.integer(n_trees), 1L, TRUE, FALSE) # Use constant leaves
+    forest_tracker <- stochtree:::forest_tracker_cpp(
+        dataX$data_ptr,
+        as.integer(feature_types),
+        as.integer(n_trees),
+        as.integer(n_samples)
+    )
+
+    # Latent variable (Z in Alam et al (2025) notation)
+    dataX$add_auxiliary_dimension(nrow(X))
+    # Forest predictions (eta in Alam et al (2025) notation)
+    dataX$add_auxiliary_dimension(nrow(X))
+    # Log-scale non-cumulative cutpoints (gamma in Alam et al (2025) notation)
+    dataX$add_auxiliary_dimension(n_levels - 1)
+    # Exponentiated cumulative cutpoints (exp(c_k) in Alam et al (2025) notation).
+    # This auxiliary series is designed so that the element stored at position `i`
+    # corresponds to the sum of all exponentiated gamma_j values for j < i.
+    # It has n_levels elements instead of n_levels - 1 because even the largest
+    # categorical index has a valid value of sum_{j < i} exp(gamma_j)
+    dataX$add_auxiliary_dimension(n_levels)
+
+    # Initialize gamma parameters to zero (3rd auxiliary data series, mapped to `dim_idx = 2` with 0-indexing)
+    initial_gamma <- rep(0.0, n_levels - 1)
+    for (i in seq_along(initial_gamma)) {
+        dataX$set_auxiliary_data_value(2, i - 1, initial_gamma[i])
+    }
+
+    # Convert the log-scale parameters into cumulative exponentiated parameters.
+    # This is done under the hood in a C++ function for efficiency.
+    stochtree:::ordinal_sampler_update_cumsum_exp_cpp(ordinal_sampler, dataX$data_ptr)
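+    # For reference, the C++ update above is equivalent to the R expression
+    #   c(0, cumsum(exp(gamma)))
+    # so that element i holds sum_{j < i} exp(gamma_j); the implied cumulative
+    # cutpoints on the log scale are c_k = log(cumsum(exp(gamma))).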
+
+    # Initialize forest predictions to zero (slot 1)
+    for (i in 1:n_samples) {
+        dataX$set_auxiliary_data_value(1, i - 1, 0.0)
+    }
+
+    # Initialize latent variables to zero (slot 0)
+    for (i in 1:n_samples) {
+        dataX$set_auxiliary_data_value(0, i - 1, 0.0)
+    }
+
+    # Set up sweep indices for tree updates (sample all trees each iteration)
+    sweep_indices <- as.integer(seq(0, n_trees - 1))
+
+    sample_counter <- 0
+    for (i in 1:(num_mcmc + num_burnin + num_gfr)) {
+        keep_sample <- i %in% keep_idx
+        if (keep_sample) {
+            sample_counter <- sample_counter + 1
+        }
+
+        # 1. Sample the forest (GFR for the first num_gfr iterations, MCMC thereafter)
+        if (i > num_gfr) {
+            stochtree:::sample_mcmc_one_iteration_cpp(
+                dataX$data_ptr, outcome_data$data_ptr, forest_samples$forest_container_ptr,
+                active_forest$forest_ptr, forest_tracker, split_prior, rng$rng_ptr,
+                sweep_indices, as.integer(feature_types), as.integer(cutpoint_grid_size),
+                scale_leaf, variable_weights, alpha_gamma, beta_gamma, 1.0, 4L, keep_sample,
+                num_threads
+            )
+        } else {
+            stochtree:::sample_gfr_one_iteration_cpp(
+                dataX$data_ptr, outcome_data$data_ptr, forest_samples$forest_container_ptr,
+                active_forest$forest_ptr, forest_tracker, split_prior, rng$rng_ptr,
+                sweep_indices, as.integer(feature_types), as.integer(cutpoint_grid_size),
+                scale_leaf, variable_weights, alpha_gamma, beta_gamma, 1.0, 4L, keep_sample,
+                ncol(X), num_threads
+            )
+        }
+
+        # Set auxiliary data slot 1 to the current forest predictions (lambda_hat, the sum of all tree predictions).
+        # This is needed for updating the gamma parameters and the latent z_i's.
+        # Note: the element index is `j` so that the sampler's iteration counter `i` is not overwritten.
+        forest_pred_current <- active_forest$predict(dataX)
+        for (j in 1:n_samples) {
+            dataX$set_auxiliary_data_value(1, j - 1, forest_pred_current[j])
+        }
+
+        # 2. Sample latent z_i's using truncated exponential
+        stochtree:::ordinal_sampler_update_latent_variables_cpp(
+            ordinal_sampler, dataX$data_ptr, outcome_data$data_ptr, rng$rng_ptr
+        )
+
+        # 3. Sample gamma parameters
+        stochtree:::ordinal_sampler_update_gamma_params_cpp(
+            ordinal_sampler, dataX$data_ptr, outcome_data$data_ptr,
+            alpha_gamma, beta_gamma, gamma_0, rng$rng_ptr
+        )
+
+        # 4.
Update cumulative sum of exp(gamma) values + stochtree:::ordinal_sampler_update_cumsum_exp_cpp(ordinal_sampler, dataX$data_ptr) + + if (keep_sample) { + forest_pred_train[, sample_counter] <- active_forest$predict(dataX) + if (has_test) { + forest_pred_test[, sample_counter] <- active_forest$predict(dataXtest) + } + gamma_current <- dataX$get_auxiliary_data_vector(2) + gamma_samples[, sample_counter] <- gamma_current + latent_current <- dataX$get_auxiliary_data_vector(0) + latent_samples[, sample_counter] <- latent_current + } + } + + result <- list( + forest_predictions_train = forest_pred_train, + forest_predictions_test = if (has_test) forest_pred_test else NULL, + gamma_samples = gamma_samples, + latent_samples = latent_samples, + scale_leaf = scale_leaf, + ordinal_outcome = ordinal_outcome, + n_trees = n_trees, + n_keep = n_keep + ) + + class(result) <- "cloglog_ordinal_bart" + return(result) +} diff --git a/R/cpp11.R b/R/cpp11.R index d77c7472..c714cc75 100644 --- a/R/cpp11.R +++ b/R/cpp11.R @@ -56,6 +56,30 @@ forest_dataset_get_variance_weights_cpp <- function(dataset_ptr) { .Call(`_stochtree_forest_dataset_get_variance_weights_cpp`, dataset_ptr) } +forest_dataset_has_auxiliary_dimension_cpp <- function(dataset_ptr, dim_idx) { + .Call(`_stochtree_forest_dataset_has_auxiliary_dimension_cpp`, dataset_ptr, dim_idx) +} + +forest_dataset_add_auxiliary_dimension_cpp <- function(dataset_ptr, dim_size) { + invisible(.Call(`_stochtree_forest_dataset_add_auxiliary_dimension_cpp`, dataset_ptr, dim_size)) +} + +forest_dataset_get_auxiliary_data_value_cpp <- function(dataset_ptr, dim_idx, element_idx) { + .Call(`_stochtree_forest_dataset_get_auxiliary_data_value_cpp`, dataset_ptr, dim_idx, element_idx) +} + +forest_dataset_set_auxiliary_data_value_cpp <- function(dataset_ptr, dim_idx, element_idx, value) { + invisible(.Call(`_stochtree_forest_dataset_set_auxiliary_data_value_cpp`, dataset_ptr, dim_idx, element_idx, value)) +} + +forest_dataset_get_auxiliary_data_vector_cpp <- function(dataset_ptr, dim_idx) { + .Call(`_stochtree_forest_dataset_get_auxiliary_data_vector_cpp`, dataset_ptr, dim_idx) +} + +forest_dataset_store_auxiliary_data_vector_as_column_cpp <- function(dataset_ptr, output_matrix, dim_idx, matrix_col_idx) { + .Call(`_stochtree_forest_dataset_store_auxiliary_data_vector_as_column_cpp`, dataset_ptr, output_matrix, dim_idx, matrix_col_idx) +} + create_column_vector_cpp <- function(outcome) { .Call(`_stochtree_create_column_vector_cpp`, outcome) } @@ -692,6 +716,22 @@ sample_without_replacement_integer_cpp <- function(population_vector, sampling_p .Call(`_stochtree_sample_without_replacement_integer_cpp`, population_vector, sampling_probs, sample_size) } +ordinal_sampler_cpp <- function() { + .Call(`_stochtree_ordinal_sampler_cpp`) +} + +ordinal_sampler_update_latent_variables_cpp <- function(sampler_ptr, data_ptr, outcome_ptr, rng_ptr) { + invisible(.Call(`_stochtree_ordinal_sampler_update_latent_variables_cpp`, sampler_ptr, data_ptr, outcome_ptr, rng_ptr)) +} + +ordinal_sampler_update_gamma_params_cpp <- function(sampler_ptr, data_ptr, outcome_ptr, alpha_gamma, beta_gamma, gamma_0, rng_ptr) { + invisible(.Call(`_stochtree_ordinal_sampler_update_gamma_params_cpp`, sampler_ptr, data_ptr, outcome_ptr, alpha_gamma, beta_gamma, gamma_0, rng_ptr)) +} + +ordinal_sampler_update_cumsum_exp_cpp <- function(sampler_ptr, data_ptr) { + invisible(.Call(`_stochtree_ordinal_sampler_update_cumsum_exp_cpp`, sampler_ptr, data_ptr)) +} + init_json_cpp <- function() { .Call(`_stochtree_init_json_cpp`) 
}
diff --git a/R/data.R b/R/data.R
index 13cd714f..f2aa46aa 100644
--- a/R/data.R
+++ b/R/data.R
@@ -108,6 +108,59 @@ ForestDataset <- R6::R6Class(
         #' @return True if variance weights are loaded, false otherwise
         has_variance_weights = function() {
             return(dataset_has_variance_weights_cpp(self$data_ptr))
+        },
+
+        #' @description
+        #' Whether or not a dataset has auxiliary data stored at the dimension indicated
+        #' @param dim_idx Dimension of auxiliary data
+        #' @return True if auxiliary data has been allocated for `dim_idx`, false otherwise
+        has_auxiliary_dimension = function(dim_idx) {
+            return(forest_dataset_has_auxiliary_dimension_cpp(self$data_ptr, dim_idx))
+        },
+
+        #' @description
+        #' Initialize a new dimension / lane of auxiliary data and allocate storage for it
+        #' @param dim_size Size of the new vector of data to allocate
+        #' @return None
+        add_auxiliary_dimension = function(dim_size) {
+            return(forest_dataset_add_auxiliary_dimension_cpp(self$data_ptr, dim_size))
+        },
+
+        #' @description
+        #' Retrieve an auxiliary data value
+        #' @param dim_idx Dimension from which the data value is retrieved
+        #' @param element_idx Element to retrieve from dimension `dim_idx`
+        #' @return Floating point value stored in the requested auxiliary data space
+        get_auxiliary_data_value = function(dim_idx, element_idx) {
+            return(forest_dataset_get_auxiliary_data_value_cpp(self$data_ptr, dim_idx, element_idx))
+        },
+
+        #' @description
+        #' Set an auxiliary data value
+        #' @param dim_idx Dimension in which the data value is set
+        #' @param element_idx Element to set within dimension `dim_idx`
+        #' @param value Data value to set at auxiliary data dimension `dim_idx` and element `element_idx`
+        #' @return None
+        set_auxiliary_data_value = function(dim_idx, element_idx, value) {
+            return(forest_dataset_set_auxiliary_data_value_cpp(self$data_ptr, dim_idx, element_idx, value))
+        },
+
+        #' @description
+        #' Retrieve an entire auxiliary data vector
+        #' @param dim_idx Dimension to retrieve
+        #' @return Vector of all of the auxiliary data stored at dimension `dim_idx`
+        get_auxiliary_data_vector = function(dim_idx) {
+            return(forest_dataset_get_auxiliary_data_vector_cpp(self$data_ptr, dim_idx))
+        },
+
+        #' @description
+        #' Retrieve an auxiliary data vector and place it into a column of the supplied matrix
+        #' @param output_matrix Matrix to be overwritten
+        #' @param dim_idx Auxiliary data dimension to retrieve
+        #' @param matrix_col_idx Matrix column into which to copy the auxiliary data
+        #' @return The auxiliary data from dimension `dim_idx`, copied into column `matrix_col_idx` of `output_matrix`
+        store_auxiliary_data_vector_matrix = function(output_matrix, dim_idx, matrix_col_idx) {
+            return(forest_dataset_store_auxiliary_data_vector_as_column_cpp(self$data_ptr, output_matrix, dim_idx, matrix_col_idx))
        }
    )
)
diff --git a/R/model.R b/R/model.R
index 38df5970..a6553dd6 100644
--- a/R/model.R
+++ b/R/model.R
@@ -378,7 +378,6 @@ createForestModel <- function(
    ))
 }
 
-
 #' Draw `sample_size` samples from `population_vector` without replacement, weighted by `sampling_probabilities`
 #'
 #' @param population_vector Vector from which to draw samples.
diff --git a/RC_README.md b/RC_README.md
new file mode 100644
index 00000000..2d677fd0
--- /dev/null
+++ b/RC_README.md
@@ -0,0 +1,16 @@
+# Release Candidate for StochTree Cloglog BART
+
+This branch serves as a staging / testing zone for the planned incorporation of BART / BCF with a complementary log-log link function into `stochtree`.
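+
+## Model
+
+Under the complementary log-log link, ordinal category probabilities are recovered from the sampled cutpoint increments (`gamma`) and forest predictions (`lambda`) via the cumulative cutpoints `log(cumsum(exp(gamma)))`. The helper below is a minimal sketch of that computation on the probability scale; it is illustrative only (`cloglog_ordinal_probs` is not an exported `stochtree` function):
+
+```
+# Class probabilities for a K-category cloglog ordinal model, given the
+# log-scale cutpoint increments gamma (length K - 1) and a vector of
+# linear predictors lambda
+cloglog_ordinal_probs <- function(gamma, lambda) {
+    cutpoints <- log(cumsum(exp(gamma))) # cumulative cutpoints c_k
+    # P(y <= k) = 1 - exp(-exp(c_k + lambda)), with P(y <= K) = 1
+    cdf <- cbind(1 - exp(-exp(outer(lambda, cutpoints, `+`))), 1)
+    cbind(cdf[, 1], t(apply(cdf, 1, diff))) # difference the CDF to get class probabilities
+}
+
+# Example with the cutpoint increments used in the three-category demo
+cloglog_ordinal_probs(c(-2, 1), lambda = c(-1, 0, 1))
+```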
+ +## Installation + +The cloglog release candidate version of `stochtree` can be installed from github via + +``` +remotes::install_github("StochasticTree/stochtree", ref="cloglog-bart-rc") +``` + +## Vignettes and Demos + +Before incorporating this functionality into `stochtree`, we intend to develop a rich set of vignettes. +We have included demo scripts for the cloglog model on synthetic ordinal data with 2, 3 and 4 categories in the `tools` subfolder of this branch. diff --git a/demo/debug/cloglog_ordinary_bart_three_category.py b/demo/debug/cloglog_ordinary_bart_three_category.py new file mode 100644 index 00000000..19e90a98 --- /dev/null +++ b/demo/debug/cloglog_ordinary_bart_three_category.py @@ -0,0 +1,229 @@ +# Load libraries +import numpy as np +import matplotlib.pyplot as plt +from sklearn.model_selection import train_test_split +from stochtree import CloglogOrdinalBARTModel + +# Set seed +seed = 2025 +rng = np.random.default_rng(seed) + +# Sample size and number of predictors +n = 2000 +p = 5 + +# Design matrix and true lambda function +X = rng.normal(0, 1, size=(n, p)) +beta = np.repeat(1 / np.sqrt(p), p) +true_lambda_function = X @ beta + +# Set cutpoints for ordinal categories (3 categories: 1, 2, 3) +n_categories = 3 +gamma_true = np.array([-2, 1]) +ordinal_cutpoints = np.log(np.cumsum(np.exp(gamma_true))) +print("Ordinal cutpoints:", ordinal_cutpoints) + +# True ordinal class probabilities +true_probs = np.zeros((n, n_categories)) +for j in range(n_categories): + if j == 0: + true_probs[:, j] = 1 - np.exp(-np.exp(gamma_true[j] + true_lambda_function)) + elif j == n_categories - 1: + true_probs[:, j] = 1 - np.sum(true_probs[:, :j], axis=1) + else: + true_probs[:, j] = np.exp(-np.exp(gamma_true[j - 1] + true_lambda_function)) * ( + 1 - np.exp(-np.exp(gamma_true[j] + true_lambda_function)) + ) +print(f"Probability distribution: {np.mean(true_probs, axis=0)}") + +# Generate ordinal outcomes +y = np.zeros(n, dtype=int) +for i in range(n): + y[i] = rng.choice(np.arange(n_categories), p=true_probs[i, :]) +print(f"Outcome distribution: {np.bincount(y)}") + +# Train-test split +sample_inds = np.arange(n) +train_inds, test_inds = train_test_split(sample_inds, test_size=0.2) +X_train = X[train_inds, :] +X_test = X[test_inds, :] +y_train = y[train_inds] +y_test = y[test_inds] + +# Run cloglog ordinal BART model +bart_model = CloglogOrdinalBARTModel() +bart_model.sample( + X_train=X_train, + y_train=y_train, + X_test=X_test, + n_trees=50, + num_gfr=0, + num_burnin=1000, + num_mcmc=500, + n_thin=1, +) + +# Traceplots of cutoff parameters +plt.subplot(1, 2, 1) +plt.plot(bart_model.gamma_samples[0, :], linestyle="-", label=r"$\gamma_1$") +plt.subplot(1, 2, 2) +plt.plot(bart_model.gamma_samples[1, :], linestyle="-", label=r"$\gamma_2$") +plt.show() + +# Histograms of cutoff parameters +plt.clf() +gamma1 = bart_model.gamma_samples[0, :] + np.mean(bart_model.forest_pred_train, axis=0) +plt.subplot(1, 2, 1) +plt.hist(gamma1, bins=30, edgecolor="black") +gamma2 = bart_model.gamma_samples[1, :] + np.mean(bart_model.forest_pred_train, axis=0) +plt.subplot(1, 2, 2) +plt.hist(gamma2, bins=30, edgecolor="black") +plt.show() + +# Traceplots of cutoff parameters combined with average forest predictions +plt.clf() +plt.subplot(1, 2, 1) +plt.plot(gamma1, linestyle="-", label=r"$\gamma_1$") +plt.axhline( + y=gamma_true[0] + np.mean(true_lambda_function[train_inds]), + color="red", + linestyle="--", +) +plt.subplot(1, 2, 2) +plt.plot(gamma2, linestyle="-", label=r"$\gamma_2$") +plt.axhline( + 
y=gamma_true[1] + np.mean(true_lambda_function[train_inds]), + color="red", + linestyle="--", +) +plt.show() + +# Compare forest predictions with the truth (for training and test sets) + +# Train set +plt.clf() +lambda_pred_train = np.mean(bart_model.forest_pred_train, axis=1) - np.mean( + bart_model.forest_pred_train +) +plt.subplot(1, 2, 1) +plt.plot(lambda_pred_train, true_lambda_function[train_inds], "o") +plt.xlabel("Predicted lambda (train)") +plt.ylabel("True lambda (train)") +plt.axline((0, 0), slope=1, color="blue", linestyle=(0, (3, 3))) +cor_train = np.corrcoef(true_lambda_function[train_inds], lambda_pred_train)[0, 1] +plt.text( + min(true_lambda_function[train_inds]), + max(true_lambda_function[train_inds]), + f"Correlation: {round(cor_train, 3)}", +) + +# Test set +lambda_pred_test = np.mean(bart_model.forest_pred_test, axis=1) - np.mean( + bart_model.forest_pred_test +) +plt.subplot(1, 2, 2) +plt.plot(lambda_pred_test, true_lambda_function[test_inds], "o") +plt.xlabel("Predicted lambda (test)") +plt.ylabel("True lambda (test)") +plt.axline((0, 0), slope=1, color="blue", linestyle=(0, (3, 3))) +cor_test = np.corrcoef(true_lambda_function[test_inds], lambda_pred_test)[0, 1] +plt.text( + min(true_lambda_function[test_inds]), + max(true_lambda_function[test_inds]), + f"Correlation: {round(cor_test, 3)}", +) +plt.show() + +# Estimated ordinal class probabilities for the training set +est_probs_train = np.zeros((len(train_inds), n_categories)) +for j in range(n_categories): + if j == 0: + est_probs_train[:, j] = np.mean( + 1 + - np.exp( + -np.exp(bart_model.forest_pred_train + bart_model.gamma_samples[j, :]) + ), + axis=1, + ) + elif j == n_categories - 1: + est_probs_train[:, j] = 1 - np.sum(est_probs_train[:, :j], axis=1) + else: + est_probs_train[:, j] = np.mean( + np.exp( + -np.exp( + bart_model.forest_pred_train + bart_model.gamma_samples[j - 1, :] + ) + ) + * ( + 1 + - np.exp( + -np.exp( + bart_model.forest_pred_train + bart_model.gamma_samples[j, :] + ) + ) + ), + axis=1, + ) + +# Plot estimated vs true class probabilities for training set +plt.clf() +for j in range(n_categories): + plt.subplot(1, n_categories, j + 1) + plt.plot(est_probs_train[:, j], true_probs[train_inds, j], "o") + plt.xlabel("Predicted prob (train)") + plt.ylabel("True prob (train)") + plt.axline((0, 0), slope=1, color="blue", linestyle=(0, (3, 3))) + cor_train = np.corrcoef(est_probs_train[:, j], true_probs[train_inds, j])[0, 1] + plt.text( + min(est_probs_train[:, j]), + max(true_probs[train_inds, j]), + f"Correlation: {round(cor_train, 3)}", + ) + plt.show() + +# Estimated ordinal class probabilities for the training set +est_probs_test = np.zeros((len(test_inds), n_categories)) +for j in range(n_categories): + if j == 0: + est_probs_test[:, j] = np.mean( + 1 + - np.exp( + -np.exp(bart_model.forest_pred_test + bart_model.gamma_samples[j, :]) + ), + axis=1, + ) + elif j == n_categories - 1: + est_probs_test[:, j] = 1 - np.sum(est_probs_test[:, :j], axis=1) + else: + est_probs_test[:, j] = np.mean( + np.exp( + -np.exp( + bart_model.forest_pred_test + bart_model.gamma_samples[j - 1, :] + ) + ) + * ( + 1 + - np.exp( + -np.exp( + bart_model.forest_pred_test + bart_model.gamma_samples[j, :] + ) + ) + ), + axis=1, + ) + +# Plot estimated vs true class probabilities for test set +plt.clf() +for j in range(n_categories): + plt.subplot(1, n_categories, j + 1) + plt.plot(est_probs_test[:, j], true_probs[test_inds, j], "o") + plt.xlabel("Predicted prob (test)") + plt.ylabel("True prob (test)") + plt.axline((0, 
0), slope=1, color="blue", linestyle=(0, (3, 3)))
+    cor_test = np.corrcoef(est_probs_test[:, j], true_probs[test_inds, j])[0, 1]
+    plt.text(
+        min(est_probs_test[:, j]),
+        max(true_probs[test_inds, j]),
+        f"Correlation: {round(cor_test, 3)}",
+    )
+plt.show()
diff --git a/include/stochtree/data.h b/include/stochtree/data.h
index a6061f4b..393203b1 100644
--- a/include/stochtree/data.h
+++ b/include/stochtree/data.h
@@ -469,6 +469,32 @@ class ForestDataset {
     if (exponentiate) var_weights_.SetElement(row_id, std::exp(new_value));
     else var_weights_.SetElement(row_id, new_value);
   }
+  /*!
+   * \brief Auxiliary data management methods
+   * Methods to initialize, get, and set auxiliary data for BART models with more structure than the ``classic`` conjugate-Gaussian leaf BART model
+   */
+  void AddAuxiliaryDimension(int dim_size) {
+    if (!has_auxiliary_data_) has_auxiliary_data_ = true;
+    auxiliary_data_.resize(num_auxiliary_dims_ + 1);
+    auxiliary_data_[num_auxiliary_dims_].resize(dim_size);
+    num_auxiliary_dims_++;
+  }
+  double GetAuxiliaryDataValue(int dim_idx, data_size_t element_idx) {
+    return auxiliary_data_[dim_idx][element_idx];
+  }
+  void SetAuxiliaryDataValue(int dim_idx, data_size_t element_idx, double value) {
+    auxiliary_data_[dim_idx][element_idx] = value;
+  }
+  std::vector<double>& GetAuxiliaryDataVector(int dim_idx) {
+    return auxiliary_data_[dim_idx];
+  }
+  const std::vector<double>& GetAuxiliaryDataVectorConst(int dim_idx) {
+    return auxiliary_data_[dim_idx];
+  }
+  bool HasAuxiliaryDimension(int dim_idx) {
+    return (num_auxiliary_dims_ > dim_idx) && (dim_idx >= 0);
+  }
+
  private:
   ColumnMatrix covariates_;
   ColumnMatrix basis_;
@@ -479,6 +505,13 @@ class ForestDataset {
   bool has_covariates_{false};
   bool has_basis_{false};
   bool has_var_weights_{false};
+
+  /*!
+   * \brief Vector of vectors to track (potentially jagged) auxiliary data for complex BART models
+   */
+  std::vector<std::vector<double>> auxiliary_data_;
+  int num_auxiliary_dims_{0};
+  bool has_auxiliary_data_{false};
 };
 
 /*! \brief API for loading and accessing data used to sample (additive) random effects */
diff --git a/include/stochtree/leaf_model.h b/include/stochtree/leaf_model.h
index 5359775d..4ea38014 100644
--- a/include/stochtree/leaf_model.h
+++ b/include/stochtree/leaf_model.h
@@ -23,311 +23,311 @@ namespace StochTree {
 /*!
  * \defgroup leaf_model_group Leaf Model API
  *
  * \brief Classes / functions for implementing leaf models.
  *
  * Stochastic tree algorithms are all essentially hierarchical
  * models with an adaptive group structure defined by an ensemble
  * of decision trees. Each novel model is governed by
  *
  * - A `LeafModel` class, defining the integrated likelihood and posterior, conditional on a particular tree structure
  * - A `SuffStat` class that tracks and accumulates sufficient statistics necessary for a `LeafModel`
  *
  * To provide a thorough overview of this interface (and, importantly, how to extend it), we must introduce some mathematical notation.
  * Any forest-based regression model involves an outcome, which we'll call \f$y\f$, and features (or "covariates"), which we'll call \f$X\f$.
  * Our goal is to predict \f$y\f$ as a function of \f$X\f$, which we'll call \f$f(X)\f$.
- * - * NOTE: if we have a more complicated, but still additive, model, such as \f$y = X\beta + f(X)\f$, then we can just model + * Our goal is to predict \f$y\f$ as a function of \f$X\f$, which we'll call \f$f(X)\f$. + * + * NOTE: if we have a more complicated, but still additive, model, such as \f$y = X\beta + f(X)\f$, then we can just model * \f$y - X\beta = f(X)\f$, treating the residual \f$y - X\beta\f$ as the outcome data, and we are back to the general setting above. - * - * Now, since \f$f(X)\f$ is an additive tree ensemble, we can think of it as the sum of \f$b\f$ separate decision tree functions, + * + * Now, since \f$f(X)\f$ is an additive tree ensemble, we can think of it as the sum of \f$b\f$ separate decision tree functions, * where \f$b\f$ is the number of trees in an ensemble, so that - * + * * \f[ * f(X) = f_1(X) + \dots + f_b(X) * \f] - * - * and each decision tree function \f$f_j\f$ has the property that features \f$X\f$ are used to determine which leaf node an observation - * falls into, and then the parameters attached to that leaf node are used to compute \f$f_j(X)\f$. The exact mechanics of this process + * + * and each decision tree function \f$f_j\f$ has the property that features \f$X\f$ are used to determine which leaf node an observation + * falls into, and then the parameters attached to that leaf node are used to compute \f$f_j(X)\f$. The exact mechanics of this process * are model-dependent, so now we introduce the "leaf node" models that `stochtree` supports. * * \section gaussian_constant_leaf_model Gaussian Constant Leaf Model - * + * * The most standard and common tree ensemble is a sum of "constant leaf" trees, in which a leaf node's parameter uniquely determines the prediction - * for all observations that fall into that leaf. For example, if leaf 2 for a tree is reached by the conditions that \f$X_1 < 0.4 \; \& \; X_2 > 0.6\f$, then - * every observation whose first feature is less than 0.4 and whose second feature is greater than 0.6 will receive the same prediction. Mathematically, + * for all observations that fall into that leaf. For example, if leaf 2 for a tree is reached by the conditions that \f$X_1 < 0.4 \; \& \; X_2 > 0.6\f$, then + * every observation whose first feature is less than 0.4 and whose second feature is greater than 0.6 will receive the same prediction. Mathematically, * for an observation \f$i\f$ this looks like - * + * * \f[ * f_j(X_i) = \sum_{\ell \in L} \mathbb{1}(X_i \in \ell) \mu_{\ell} * \f] - * + * * where \f$L\f$ denotes the indices of every leaf node, \f$\mu_{\ell}\f$ is the parameter attached to leaf node \f$\ell\f$, and \f$\mathbb{1}(X \in \ell)\f$ * checks whether \f$X_i\f$ falls into leaf node \f$\ell\f$. - * + * * The way that we make such a model "stochastic" is by attaching to the leaf node parameters \f$\mu_{\ell}\f$ a "prior" distribution. - * This leaf model corresponds to the "classic" BART model of Chipman et al (2010) - * as well as its "XBART" extension (He and Hahn (2023)). + * This leaf model corresponds to the "classic" BART model of Chipman et al (2010) + * as well as its "XBART" extension (He and Hahn (2023)). * We assign each leaf node parameter a prior - * + * * \f[ * \mu \sim N\left(0, \tau\right) * \f] - * - * Assuming a homoskedastic Gaussian outcome likelihood (i.e. \f$y_i \sim N\left(f(X_i),\sigma^2\right)\f$), - * the log marginal likelihood in this model, for the outcome data in node \f$\ell\f$ of tree \f$j\f$ is given by - * + * + * Assuming a homoskedastic Gaussian outcome likelihood (i.e. 
\f$y_i \sim N\left(f(X_i),\sigma^2\right)\f$), + * the log marginal likelihood in this model, for the outcome data in node \f$\ell\f$ of tree \f$j\f$ is given by + * * \f[ * L(y) = -\frac{n_{\ell}}{2}\log(2\pi) - n_{\ell}\log(\sigma) + \frac{1}{2} \log\left(\frac{\sigma^2}{n_{\ell} \tau + \sigma^2}\right) - \frac{s_{yy,\ell}}{2\sigma^2} + \frac{\tau s_{y,\ell}^2}{2\sigma^2(n_{\ell} \tau + \sigma^2)} * \f] - * + * * where - * + * * \f[ * n_{\ell} = \sum_{i : X_i \in \ell} 1 * \f] - * + * * \f[ * s_{y,\ell} = \sum_{i : X_i \in \ell} r_i * \f] - * + * * \f[ * s_{yy,\ell} = \sum_{i : X_i \in \ell} r_i^2 * \f] - * + * * \f[ * r_i = y_i - \sum_{k \neq j} f_k(X_i) * \f] * - * In words, this model depends on the data for a given leaf node only through three sufficient statistics, \f$n_{\ell}\f$, \f$s_{y,\ell}\f$, and \f$s_{yy,\ell}\f$, - * and it only depends on the other trees in the ensemble through the "partial residual" \f$r_i\f$. The posterior distribution for + * In words, this model depends on the data for a given leaf node only through three sufficient statistics, \f$n_{\ell}\f$, \f$s_{y,\ell}\f$, and \f$s_{yy,\ell}\f$, + * and it only depends on the other trees in the ensemble through the "partial residual" \f$r_i\f$. The posterior distribution for * node \f$\ell\f$'s leaf parameter is similarly defined as: - * + * * \f[ * \mu_{\ell} \mid - \sim N\left(\frac{\tau s_{y,\ell}}{n_{\ell} \tau + \sigma^2}, \frac{\tau \sigma^2}{n_{\ell} \tau + \sigma^2}\right) * \f] - * - * Now, consider the possibility that each observation carries a unique weight \f$w_i\f$. These could be "case weights" in a survey context or + * + * Now, consider the possibility that each observation carries a unique weight \f$w_i\f$. These could be "case weights" in a survey context or * individual-level variances ("heteroskedasticity"). These case weights transform the outcome distribution (and associated likelihood) to - * + * * \f[ - * y_i \mid - \sim N\left(\mu(X_i), \frac{\sigma^2}{w_i}\right) + * y_i \mid - \sim N\left(\mu(X_i), \frac{\sigma^2}{w_i}\right) * \f] - * - * This gives a modified log marginal likelihood of - * + * + * This gives a modified log marginal likelihood of + * * \f[ * L(y) = -\frac{n_{\ell}}{2}\log(2\pi) - \frac{1}{2} \sum_{i : X_i \in \ell} \log\left(\frac{\sigma^2}{w_i}\right) + \frac{1}{2} \log\left(\frac{\sigma^2}{s_{w,\ell} \tau + \sigma^2}\right) - \frac{s_{wyy,\ell}}{2\sigma^2} + \frac{\tau s_{wy,\ell}^2}{2\sigma^2(s_{w,\ell} \tau + \sigma^2)} * \f] - * + * * where - * + * * \f[ * s_{w,\ell} = \sum_{i : X_i \in \ell} w_i * \f] - * + * * \f[ * s_{wy,\ell} = \sum_{i : X_i \in \ell} w_i r_i * \f] - * + * * \f[ * s_{wyy,\ell} = \sum_{i : X_i \in \ell} w_i r_i^2 * \f] - * - * Finally, note that when we consider splitting leaf \f$\ell\f$ into new left and right leaves, or pruning two nodes into a single leaf node, - * we compute the log marginal likelihood of the combined data and the log marginal likelihoods of the left and right leaves and compare these three values. 
- * - * The terms \f$\frac{n_{\ell}}{2}\log(2\pi)\f$, \f$\sum_{i : X_i \in \ell} \log\left(\frac{\sigma^2}{w_i}\right)\f$, and \f$\frac{s_{wyy,\ell}}{2\sigma^2}\f$ - * are such that their left and right node values will always sum to the respective value in the combined log marginal likelihood, so they can be ignored + * + * Finally, note that when we consider splitting leaf \f$\ell\f$ into new left and right leaves, or pruning two nodes into a single leaf node, + * we compute the log marginal likelihood of the combined data and the log marginal likelihoods of the left and right leaves and compare these three values. + * + * The terms \f$\frac{n_{\ell}}{2}\log(2\pi)\f$, \f$\sum_{i : X_i \in \ell} \log\left(\frac{\sigma^2}{w_i}\right)\f$, and \f$\frac{s_{wyy,\ell}}{2\sigma^2}\f$ + * are such that their left and right node values will always sum to the respective value in the combined log marginal likelihood, so they can be ignored * when evaluating splits or prunes and thus the reduced log marginal likelihood is - * + * * \f[ * L(y) \propto \frac{1}{2} \log\left(\frac{\sigma^2}{s_{w,\ell} \tau + \sigma^2}\right) + \frac{\tau s_{wy,\ell}^2}{2\sigma^2(n_{\ell} \tau + \sigma^2)} * \f] - * + * * So the \ref StochTree::GaussianConstantSuffStat "GaussianConstantSuffStat" class tracks a generalized version of these three statistics * (which allows for each observation to have a weight \f$w_i \neq 1\f$): - * + * * - \f$n_{\ell}\f$: `data_size_t n` * - \f$s_{w,\ell}\f$: `double sum_w` * - \f$s_{wy,\ell}\f$: `double sum_yw` - * - * And these values are used by the \ref StochTree::GaussianConstantLeafModel "GaussianConstantLeafModel" class in the - * \ref StochTree::GaussianConstantLeafModel::SplitLogMarginalLikelihood "SplitLogMarginalLikelihood", - * \ref StochTree::GaussianConstantLeafModel::NoSplitLogMarginalLikelihood "NoSplitLogMarginalLikelihood", - * \ref StochTree::GaussianConstantLeafModel::PosteriorParameterMean "PosteriorParameterMean", and - * \ref StochTree::GaussianConstantLeafModel::PosteriorParameterVariance "PosteriorParameterVariance" methods. + * + * And these values are used by the \ref StochTree::GaussianConstantLeafModel "GaussianConstantLeafModel" class in the + * \ref StochTree::GaussianConstantLeafModel::SplitLogMarginalLikelihood "SplitLogMarginalLikelihood", + * \ref StochTree::GaussianConstantLeafModel::NoSplitLogMarginalLikelihood "NoSplitLogMarginalLikelihood", + * \ref StochTree::GaussianConstantLeafModel::PosteriorParameterMean "PosteriorParameterMean", and + * \ref StochTree::GaussianConstantLeafModel::PosteriorParameterVariance "PosteriorParameterVariance" methods. 
* To give one example, below is the implementation of \ref StochTree::GaussianConstantLeafModel::SplitLogMarginalLikelihood "SplitLogMarginalLikelihood": - * + * * \code{.cpp} * double left_log_ml = ( * -0.5*std::log(1 + tau_*(left_stat.sum_w/global_variance)) + ((tau_*left_stat.sum_yw*left_stat.sum_yw)/(2.0*global_variance*(tau_*left_stat.sum_w + global_variance))) * ); - * + * * double right_log_ml = ( * -0.5*std::log(1 + tau_*(right_stat.sum_w/global_variance)) + ((tau_*right_stat.sum_yw*right_stat.sum_yw)/(2.0*global_variance*(tau_*right_stat.sum_w + global_variance))) * ); - * + * * return left_log_ml + right_log_ml; - * \endcode - * + * \endcode + * * \section gaussian_multivariate_regression_leaf_model Gaussian Multivariate Regression Leaf Model - * - * In this model, the tree defines a "partitioned linear model" in which leaf node parameters define regression weights + * + * In this model, the tree defines a "partitioned linear model" in which leaf node parameters define regression weights * that are multiplied by a "basis" \f$\Omega\f$ to determine the prediction for an observation. - * + * * \f[ * f_j(X_i) = \sum_{\ell \in L} \mathbb{1}(X_i \in \ell) \Omega_i \vec{\beta_{\ell}} * \f] - * + * * and we assign \f$\beta_{\ell}\f$ a prior of - * + * * \f[ * \vec{\beta_{\ell}} \sim N\left(\vec{\beta_0}, \Sigma_0\right) * \f] - * + * * where \f$\vec{\beta_0}\f$ is typically a vector of zeros. The outcome likelihood is still - * + * * \f[ * y_i \sim N\left(f(X_i), \sigma^2\right) * \f] - * + * * This gives a reduced log integrated likelihood of - * + * * \f[ * L(y) \propto - \frac{1}{2} \log\left(\textrm{det}\left(I_p + \frac{\Sigma_0\Omega'\Omega}{\sigma^2}\right)\right) + \frac{1}{2}\frac{y'\Omega}{\sigma^2}\left(\Sigma_0^{-1} + \frac{\Omega'\Omega}{\sigma^2}\right)^{-1}\frac{\Omega'y}{\sigma^2} * \f] - * - * where \f$\Omega\f$ is a matrix of bases for every observation in leaf \f$\ell\f$ and \f$p\f$ is the dimension of \f$\Omega\f$. The posterior for \f$\vec{\beta_{\ell}}\f$ is - * + * + * where \f$\Omega\f$ is a matrix of bases for every observation in leaf \f$\ell\f$ and \f$p\f$ is the dimension of \f$\Omega\f$. The posterior for \f$\vec{\beta_{\ell}}\f$ is + * * \f[ * \vec{\beta_{\ell}} \sim N\left(\left(\Sigma_0^{-1} + \frac{\Omega'\Omega}{\sigma^2}\right)^{-1}\left(\frac{\Omega'y}{\sigma^2}\right),\left(\Sigma_0^{-1} + \frac{\Omega'\Omega}{\sigma^2}\right)^{-1}\right) * \f] - * + * * This is an extension of the single-tree model of Chipman et al (2002), with: - * + * * - Support for using a separate basis for leaf model than the partitioning (i.e. tree) model (i.e. 
\f$X \neq \Omega\f$) * - Support for multiple trees and sampling via grow-from-root (GFR) or MCMC - * + * * We can also enable heteroskedasticity by defining a (diagonal) covariance matrix for the outcome likelihood - * + * * \f[ * \Sigma_y = \text{diag}\left(\sigma^2 / w_1,\sigma^2 / w_2,\dots,\sigma^2 / w_n\right) * \f] - * + * * This updates the reduced log integrated likelihood to - * + * * \f[ * L(y) \propto - \frac{1}{2} \log\left(\textrm{det}\left(I_p + \Sigma_{0}\Omega'\Sigma_y^{-1}\Omega\right)\right) + \frac{1}{2}y'\Sigma_{y}^{-1}\Omega\left(\Sigma_{0}^{-1} + \Omega'\Sigma_{y}^{-1}\Omega\right)^{-1}\Omega'\Sigma_{y}^{-1}y * \f] - * - * and a posterior for \f$\vec{\beta_{\ell}}\f$ of - * + * + * and a posterior for \f$\vec{\beta_{\ell}}\f$ of + * * \f[ * \vec{\beta_{\ell}} \sim N\left(\left(\Sigma_{0}^{-1} + \Omega'\Sigma_{y}^{-1}\Omega\right)^{-1}\left(\Omega'\Sigma_{y}^{-1}y\right),\left(\Sigma_{0}^{-1} + \Omega'\Sigma_{y}^{-1}\Omega\right)^{-1}\right) * \f] - * + * * \section gaussian_univariate_regression_leaf_model Gaussian Univariate Regression Leaf Model - * + * * This specializes the Gaussian Multivariate Regression Leaf Model for a univariate leaf basis, which allows for several computational speedups (replacing generalized matrix operations with simple summation or sum-product operations). - * We simplify \f$\Omega\f$ to \f$\omega\f$, a univariate basis for every observation, so that \f$\Omega'\Omega = \sum_{i:i \in \ell}\omega_i^2\f$ and \f$\Omega'y = \sum_{i:i \in \ell}\omega_ir_i\f$. Similarly, the prior for the leaf - * parameter becomes univariate normal as in \ref gaussian_constant_leaf_model: - * + * We simplify \f$\Omega\f$ to \f$\omega\f$, a univariate basis for every observation, so that \f$\Omega'\Omega = \sum_{i:i \in \ell}\omega_i^2\f$ and \f$\Omega'y = \sum_{i:i \in \ell}\omega_ir_i\f$. Similarly, the prior for the leaf + * parameter becomes univariate normal as in \ref gaussian_constant_leaf_model: + * * \f[ * \beta \sim N\left(0, \tau\right) * \f] - * - * Allowing for case / variance weights \f$w_i\f$ as above, we derive a reduced log marginal likelihood of - * + * + * Allowing for case / variance weights \f$w_i\f$ as above, we derive a reduced log marginal likelihood of + * * \f[ * L(y) \propto \frac{1}{2} \log\left(\frac{\sigma^2}{s_{wxx,\ell} \tau + \sigma^2}\right) + \frac{\tau s_{wyx,\ell}^2}{2\sigma^2(s_{wxx,\ell} \tau + \sigma^2)} * \f] - * + * * where - * + * * \f[ * s_{wxx,\ell} = \sum_{i : X_i \in \ell} w_i \omega_i \omega_i * \f] - * + * * \f[ * s_{wyx,\ell} = \sum_{i : X_i \in \ell} w_i r_i \omega_i * \f] - * - * and a posterior of - * + * + * and a posterior of + * * \f[ * \beta_{\ell} \mid - \sim N\left(\frac{\tau s_{wyx,\ell}}{s_{wxx,\ell} \tau + \sigma^2}, \frac{\tau \sigma^2}{s_{wxx,\ell} \tau + \sigma^2}\right) * \f] - * + * * \section inverse_gamma_leaf_model Inverse Gamma Leaf Model - * - * Each of the above models is a variation on a theme: a conjugate, partitioned Gaussian leaf model. + * + * Each of the above models is a variation on a theme: a conjugate, partitioned Gaussian leaf model. 
* The inverse gamma leaf model allows for forest-based heteroskedasticity modeling using an inverse gamma prior on the exponentiated leaf parameter, as discussed in Murray (2021) * Define a variance function based on an ensemble of \f$b\f$ trees as - * + * * \f[ * \sigma^2(X) = \exp\left(s_1(X) + \dots + s_b(X)\right) * \f] - * - * where each tree function \f$s_j(X)\f$ is defined as - * + * + * where each tree function \f$s_j(X)\f$ is defined as + * * \f[ * s_j(X_i) = \sum_{\ell \in L} \mathbb{1}(X_i \in \ell) \lambda_{\ell} * \f] - * + * * We reparameterize \f$\lambda_{\ell} = \log(\mu_{\ell})\f$ and we place an inverse gamma prior on \f$\mu_{\ell}\f$ - * + * * \f[ * \mu_{\ell} \sim \text{IG}\left(a, b\right) * \f] - * - * As noted in Murray (2021), this model no longer enables the "Bayesian backfitting" simplification - * of conjugated Gaussian leaf models, in which sampling updates for a given tree only depend on other trees in the ensemble via their imprint on the partial residual - * \f$r_i = y_i - \sum_{k \neq j} \mu_k(X_i)\f$. + * + * As noted in Murray (2021), this model no longer enables the "Bayesian backfitting" simplification + * of conjugated Gaussian leaf models, in which sampling updates for a given tree only depend on other trees in the ensemble via their imprint on the partial residual + * \f$r_i = y_i - \sum_{k \neq j} \mu_k(X_i)\f$. * However, this model is part of a broader class of models with convenient "blocked MCMC" sampling updates (another important example being multinomial classification). - * + * * Under an outcome model - * + * * \f[ * y \sim N\left(f(X), \sigma_0^2 \sigma^2(X)\right) * \f] - * + * * updates to \f$\mu_{\ell}\f$ for a given tree \f$j\f$ are based on a reduced log marginal likelihood of - * + * * \f[ * L(y) \propto a \log (b) - \log \Gamma (a) + \log \Gamma \left(a + \frac{n_{\ell}}{2}\right) - \left(a + \frac{n_{\ell}}{2}\right) \left(b + \frac{s_{\sigma,\ell}}{2\sigma_0^2}\right) * \f] - * + * * where - * + * * \f[ * n_{\ell} = \sum_{i : X_i \in \ell} 1 * \f] - * + * * \f[ * s_{\sigma,\ell} = \sum_{i: i \in \ell} \frac{(y_i - f(X_i))^2}{\prod_{k \neq j} s_k(X_i)} * \f] - * - * and a posterior of - * + * + * and a posterior of + * * \f[ * \mu_{\ell} \mid - \sim \text{IG}\left( a + \frac{n_{\ell}}{2} , b + \frac{s_{\sigma,\ell}}{2\sigma_0^2} \right) * \f] - * + * * Thus, as above, we implement a sufficient statistic class (\ref StochTree::LogLinearVarianceSuffStat "LogLinearVarianceSuffStat"), which tracks - * + * * - \f$n_{\ell}\f$: `data_size_t n` * - \f$s_{\sigma,\ell}\f$: `double weighted_sum_ei` - * - * And these values are used by the \ref StochTree::LogLinearVarianceLeafModel "LogLinearVarianceLeafModel" class in the - * \ref StochTree::LogLinearVarianceLeafModel::SplitLogMarginalLikelihood "SplitLogMarginalLikelihood", - * \ref StochTree::LogLinearVarianceLeafModel::NoSplitLogMarginalLikelihood "NoSplitLogMarginalLikelihood", - * \ref StochTree::LogLinearVarianceLeafModel::PosteriorParameterShape "PosteriorParameterShape", and - * \ref StochTree::LogLinearVarianceLeafModel::PosteriorParameterScale "PosteriorParameterScale" methods. 
+ * + * And these values are used by the \ref StochTree::LogLinearVarianceLeafModel "LogLinearVarianceLeafModel" class in the + * \ref StochTree::LogLinearVarianceLeafModel::SplitLogMarginalLikelihood "SplitLogMarginalLikelihood", + * \ref StochTree::LogLinearVarianceLeafModel::NoSplitLogMarginalLikelihood "NoSplitLogMarginalLikelihood", + * \ref StochTree::LogLinearVarianceLeafModel::PosteriorParameterShape "PosteriorParameterShape", and + * \ref StochTree::LogLinearVarianceLeafModel::PosteriorParameterScale "PosteriorParameterScale" methods. * To give one example, below is the implementation of \ref StochTree::LogLinearVarianceLeafModel::NoSplitLogMarginalLikelihood "NoSplitLogMarginalLikelihood": - * + * * \code{.cpp} * double prior_terms = a_ * std::log(b_) - boost::math::lgamma(a_); * double a_term = a_ + 0.5 * suff_stat.n; @@ -337,8 +337,8 @@ namespace StochTree { * double resid_term = a_term * log_b_term; * double log_ml = prior_terms + lgamma_a_term - resid_term; * return log_ml; - * \endcode - * + * \endcode + * * \{ */ @@ -347,12 +347,14 @@ namespace StochTree { * 2. `kUnivariateRegressionLeafGaussian`: Every leaf node has a zero-centered univariate normal prior and every leaf is a linear model, multiplying the leaf parameter by a (fixed) basis. * 3. `kMultivariateRegressionLeafGaussian`: Every leaf node has a multivariate normal prior, centered around the zero vector, and every leaf is a linear model, matrix-multiplying the leaf parameters by a (fixed) basis vector. * 4. `kLogLinearVariance`: Every leaf node has a inverse gamma prior and every leaf is constant. + * 5. `kCloglogOrdinal`: Every leaf node has a log-gamma prior and every leaf is constant. */ enum ModelType { - kConstantLeafGaussian, - kUnivariateRegressionLeafGaussian, - kMultivariateRegressionLeafGaussian, - kLogLinearVariance + kConstantLeafGaussian, + kUnivariateRegressionLeafGaussian, + kMultivariateRegressionLeafGaussian, + kLogLinearVariance, + kCloglogOrdinal }; /*! \brief Sufficient statistic and associated operations for gaussian homoskedastic constant leaf outcome model */ @@ -371,7 +373,7 @@ class GaussianConstantSuffStat { } /*! * \brief Accumulate data from observation `row_idx` into the sufficient statistics - * + * * \param dataset Data object containining training data, including covariates, leaf regression bases, and case weights * \param outcome Data object containing the "partial" residual net of all the model's other mean terms, aside from `tree` * \param tracker Tracking data structures that speed up sampler operations, synchronized with `active_forest` tracking a forest's state @@ -398,7 +400,7 @@ class GaussianConstantSuffStat { } /*! * \brief Increment the value of each sufficient statistic by the values provided by `suff_stat` - * + * * \param suff_stat Sufficient statistic to be added to the current sufficient statistics */ void AddSuffStatInplace(GaussianConstantSuffStat& suff_stat) { @@ -408,7 +410,7 @@ class GaussianConstantSuffStat { } /*! * \brief Set the value of each sufficient statistic to the sum of the values provided by `lhs` and `rhs` - * + * * \param lhs First sufficient statistic ("left hand side") * \param rhs Second sufficient statistic ("right hand side") */ @@ -419,7 +421,7 @@ class GaussianConstantSuffStat { } /*! 
* \brief Set the value of each sufficient statistic to the difference between the values provided by `lhs` and those provided by `rhs` - * + * * \param lhs First sufficient statistic ("left hand side") * \param rhs Second sufficient statistic ("right hand side") */ @@ -430,7 +432,7 @@ class GaussianConstantSuffStat { } /*! * \brief Check whether accumulated sample size, `n`, is greater than some threshold - * + * * \param threshold Value used to compute `n > threshold` */ bool SampleGreaterThan(data_size_t threshold) { @@ -438,7 +440,7 @@ class GaussianConstantSuffStat { } /*! * \brief Check whether accumulated sample size, `n`, is greater than or equal to some threshold - * + * * \param threshold Value used to compute `n >= threshold` */ bool SampleGreaterThanEqual(data_size_t threshold) { @@ -457,14 +459,14 @@ class GaussianConstantLeafModel { public: /*! * \brief Construct a new GaussianConstantLeafModel object - * + * * \param tau Leaf node prior scale parameter */ GaussianConstantLeafModel(double tau) {tau_ = tau; normal_sampler_ = UnivariateNormalSampler();} ~GaussianConstantLeafModel() {} /*! * \brief Log marginal likelihood for a proposed split, evaluated only for observations that fall into the node being split. - * + * * \param left_stat Sufficient statistics of the left node formed by the proposed split * \param right_stat Sufficient statistics of the right node formed by the proposed split * \param global_variance Global error variance parameter @@ -472,28 +474,28 @@ class GaussianConstantLeafModel { double SplitLogMarginalLikelihood(GaussianConstantSuffStat& left_stat, GaussianConstantSuffStat& right_stat, double global_variance); /*! * \brief Log marginal likelihood of a node, evaluated only for observations that fall into the node being split. - * + * * \param suff_stat Sufficient statistics of the node being evaluated * \param global_variance Global error variance parameter */ double NoSplitLogMarginalLikelihood(GaussianConstantSuffStat& suff_stat, double global_variance); /*! * \brief Leaf node posterior mean. - * + * * \param suff_stat Sufficient statistics of the node being evaluated * \param global_variance Global error variance parameter */ double PosteriorParameterMean(GaussianConstantSuffStat& suff_stat, double global_variance); /*! * \brief Leaf node posterior variance. - * + * * \param suff_stat Sufficient statistics of the node being evaluated * \param global_variance Global error variance parameter */ double PosteriorParameterVariance(GaussianConstantSuffStat& suff_stat, double global_variance); /*! * \brief Draw new parameters for every leaf node in `tree`, using a Gibbs update that conditions on the data, every other tree in the forest, and all model parameters - * + * * \param dataset Data object containining training data, including covariates, leaf regression bases, and case weights * \param tracker Tracking data structures that speed up sampler operations, synchronized with `active_forest` tracking a forest's state * \param residual Data object containing the "partial" residual net of all the model's other mean terms, aside from `tree` @@ -506,7 +508,7 @@ class GaussianConstantLeafModel { void SetEnsembleRootPredictedValue(ForestDataset& dataset, TreeEnsemble* ensemble, double root_pred_value); /*! * \brief Set a new value for the leaf node scale parameter - * + * * \param tau Leaf node prior scale parameter */ void SetScale(double tau) {tau_ = tau;} @@ -535,7 +537,7 @@ class GaussianUnivariateRegressionSuffStat { } /*! 
* \brief Accumulate data from observation `row_idx` into the sufficient statistics - * + * * \param dataset Data object containining training data, including covariates, leaf regression bases, and case weights * \param outcome Data object containing the "partial" residual net of all the model's other mean terms, aside from `tree` * \param tracker Tracking data structures that speed up sampler operations, synchronized with `active_forest` tracking a forest's state @@ -560,9 +562,9 @@ class GaussianUnivariateRegressionSuffStat { sum_xxw = 0.0; sum_yxw = 0.0; } - /*! + /*! * \brief Increment the value of each sufficient statistic by the values provided by `suff_stat` - * + * * \param suff_stat Sufficient statistic to be added to the current sufficient statistics */ void AddSuffStatInplace(GaussianUnivariateRegressionSuffStat& suff_stat) { @@ -572,7 +574,7 @@ class GaussianUnivariateRegressionSuffStat { } /*! * \brief Set the value of each sufficient statistic to the sum of the values provided by `lhs` and `rhs` - * + * * \param lhs First sufficient statistic ("left hand side") * \param rhs Second sufficient statistic ("right hand side") */ @@ -583,7 +585,7 @@ class GaussianUnivariateRegressionSuffStat { } /*! * \brief Set the value of each sufficient statistic to the difference between the values provided by `lhs` and those provided by `rhs` - * + * * \param lhs First sufficient statistic ("left hand side") * \param rhs Second sufficient statistic ("right hand side") */ @@ -594,7 +596,7 @@ class GaussianUnivariateRegressionSuffStat { } /*! * \brief Check whether accumulated sample size, `n`, is greater than some threshold - * + * * \param threshold Value used to compute `n > threshold` */ bool SampleGreaterThan(data_size_t threshold) { @@ -602,7 +604,7 @@ class GaussianUnivariateRegressionSuffStat { } /*! * \brief Check whether accumulated sample size, `n`, is greater than or equal to some threshold - * + * * \param threshold Value used to compute `n >= threshold` */ bool SampleGreaterThanEqual(data_size_t threshold) { @@ -623,7 +625,7 @@ class GaussianUnivariateRegressionLeafModel { ~GaussianUnivariateRegressionLeafModel() {} /*! * \brief Log marginal likelihood for a proposed split, evaluated only for observations that fall into the node being split. - * + * * \param left_stat Sufficient statistics of the left node formed by the proposed split * \param right_stat Sufficient statistics of the right node formed by the proposed split * \param global_variance Global error variance parameter @@ -631,28 +633,28 @@ class GaussianUnivariateRegressionLeafModel { double SplitLogMarginalLikelihood(GaussianUnivariateRegressionSuffStat& left_stat, GaussianUnivariateRegressionSuffStat& right_stat, double global_variance); /*! * \brief Log marginal likelihood of a node, evaluated only for observations that fall into the node being split. - * + * * \param suff_stat Sufficient statistics of the node being evaluated * \param global_variance Global error variance parameter */ double NoSplitLogMarginalLikelihood(GaussianUnivariateRegressionSuffStat& suff_stat, double global_variance); /*! * \brief Leaf node posterior mean. - * + * * \param suff_stat Sufficient statistics of the node being evaluated * \param global_variance Global error variance parameter */ double PosteriorParameterMean(GaussianUnivariateRegressionSuffStat& suff_stat, double global_variance); /*! * \brief Leaf node posterior variance. 
- * + * * \param suff_stat Sufficient statistics of the node being evaluated * \param global_variance Global error variance parameter */ double PosteriorParameterVariance(GaussianUnivariateRegressionSuffStat& suff_stat, double global_variance); /*! * \brief Draw new parameters for every leaf node in `tree`, using a Gibbs update that conditions on the data, every other tree in the forest, and all model parameters - * + * * \param dataset Data object containining training data, including covariates, leaf regression bases, and case weights * \param tracker Tracking data structures that speed up sampler operations, synchronized with `active_forest` tracking a forest's state * \param residual Data object containing the "partial" residual net of all the model's other mean terms, aside from `tree` @@ -679,7 +681,7 @@ class GaussianMultivariateRegressionSuffStat { Eigen::MatrixXd ytWX; /*! * \brief Construct a new GaussianMultivariateRegressionSuffStat object - * + * * \param basis_dim Size of the basis vector that defines the leaf regression */ GaussianMultivariateRegressionSuffStat(int basis_dim) { @@ -690,7 +692,7 @@ class GaussianMultivariateRegressionSuffStat { } /*! * \brief Accumulate data from observation `row_idx` into the sufficient statistics - * + * * \param dataset Data object containining training data, including covariates, leaf regression bases, and case weights * \param outcome Data object containing the "partial" residual net of all the model's other mean terms, aside from `tree` * \param tracker Tracking data structures that speed up sampler operations, synchronized with `active_forest` tracking a forest's state @@ -715,9 +717,9 @@ class GaussianMultivariateRegressionSuffStat { XtWX = Eigen::MatrixXd::Zero(p, p); ytWX = Eigen::MatrixXd::Zero(1, p); } - /*! + /*! * \brief Increment the value of each sufficient statistic by the values provided by `suff_stat` - * + * * \param suff_stat Sufficient statistic to be added to the current sufficient statistics */ void AddSuffStatInplace(GaussianMultivariateRegressionSuffStat& suff_stat) { @@ -727,7 +729,7 @@ class GaussianMultivariateRegressionSuffStat { } /*! * \brief Set the value of each sufficient statistic to the sum of the values provided by `lhs` and `rhs` - * + * * \param lhs First sufficient statistic ("left hand side") * \param rhs Second sufficient statistic ("right hand side") */ @@ -738,7 +740,7 @@ class GaussianMultivariateRegressionSuffStat { } /*! * \brief Set the value of each sufficient statistic to the difference between the values provided by `lhs` and those provided by `rhs` - * + * * \param lhs First sufficient statistic ("left hand side") * \param rhs Second sufficient statistic ("right hand side") */ @@ -749,7 +751,7 @@ class GaussianMultivariateRegressionSuffStat { } /*! * \brief Check whether accumulated sample size, `n`, is greater than some threshold - * + * * \param threshold Value used to compute `n > threshold` */ bool SampleGreaterThan(data_size_t threshold) { @@ -757,7 +759,7 @@ class GaussianMultivariateRegressionSuffStat { } /*! * \brief Check whether accumulated sample size, `n`, is greater than or equal to some threshold - * + * * \param threshold Value used to compute `n >= threshold` */ bool SampleGreaterThanEqual(data_size_t threshold) { @@ -776,14 +778,14 @@ class GaussianMultivariateRegressionLeafModel { public: /*! 
* \brief Construct a new GaussianMultivariateRegressionLeafModel object - * + * * \param Sigma_0 Prior covariance, must have the same number of rows and columns as dimensions of the basis vector for the multivariate regression problem */ GaussianMultivariateRegressionLeafModel(Eigen::MatrixXd& Sigma_0) {Sigma_0_ = Sigma_0; multivariate_normal_sampler_ = MultivariateNormalSampler();} ~GaussianMultivariateRegressionLeafModel() {} /*! * \brief Log marginal likelihood for a proposed split, evaluated only for observations that fall into the node being split. - * + * * \param left_stat Sufficient statistics of the left node formed by the proposed split * \param right_stat Sufficient statistics of the right node formed by the proposed split * \param global_variance Global error variance parameter @@ -791,28 +793,28 @@ class GaussianMultivariateRegressionLeafModel { double SplitLogMarginalLikelihood(GaussianMultivariateRegressionSuffStat& left_stat, GaussianMultivariateRegressionSuffStat& right_stat, double global_variance); /*! * \brief Log marginal likelihood of a node, evaluated only for observations that fall into the node being split. - * + * * \param suff_stat Sufficient statistics of the node being evaluated * \param global_variance Global error variance parameter */ double NoSplitLogMarginalLikelihood(GaussianMultivariateRegressionSuffStat& suff_stat, double global_variance); /*! * \brief Leaf node posterior mean. - * + * * \param suff_stat Sufficient statistics of the node being evaluated * \param global_variance Global error variance parameter */ Eigen::VectorXd PosteriorParameterMean(GaussianMultivariateRegressionSuffStat& suff_stat, double global_variance); /*! * \brief Leaf node posterior variance. - * + * * \param suff_stat Sufficient statistics of the node being evaluated * \param global_variance Global error variance parameter */ Eigen::MatrixXd PosteriorParameterVariance(GaussianMultivariateRegressionSuffStat& suff_stat, double global_variance); /*! * \brief Draw new parameters for every leaf node in `tree`, using a Gibbs update that conditions on the data, every other tree in the forest, and all model parameters - * + * * \param dataset Data object containining training data, including covariates, leaf regression bases, and case weights * \param tracker Tracking data structures that speed up sampler operations, synchronized with `active_forest` tracking a forest's state * \param residual Data object containing the "partial" residual net of all the model's other mean terms, aside from `tree` @@ -841,7 +843,7 @@ class LogLinearVarianceSuffStat { } /*! * \brief Accumulate data from observation `row_idx` into the sufficient statistics - * + * * \param dataset Data object containining training data, including covariates, leaf regression bases, and case weights * \param outcome Data object containing the "partial" residual net of all the model's other mean terms, aside from `tree` * \param tracker Tracking data structures that speed up sampler operations, synchronized with `active_forest` tracking a forest's state @@ -861,7 +863,7 @@ class LogLinearVarianceSuffStat { } /*! * \brief Increment the value of each sufficient statistic by the values provided by `suff_stat` - * + * * \param suff_stat Sufficient statistic to be added to the current sufficient statistics */ void AddSuffStatInplace(LogLinearVarianceSuffStat& suff_stat) { @@ -870,7 +872,7 @@ class LogLinearVarianceSuffStat { } /*! 
* \brief Set the value of each sufficient statistic to the sum of the values provided by `lhs` and `rhs` - * + * * \param lhs First sufficient statistic ("left hand side") * \param rhs Second sufficient statistic ("right hand side") */ @@ -880,7 +882,7 @@ class LogLinearVarianceSuffStat { } /*! * \brief Set the value of each sufficient statistic to the difference between the values provided by `lhs` and those provided by `rhs` - * + * * \param lhs First sufficient statistic ("left hand side") * \param rhs Second sufficient statistic ("right hand side") */ @@ -890,7 +892,7 @@ class LogLinearVarianceSuffStat { } /*! * \brief Check whether accumulated sample size, `n`, is greater than some threshold - * + * * \param threshold Value used to compute `n > threshold` */ bool SampleGreaterThan(data_size_t threshold) { @@ -898,7 +900,7 @@ class LogLinearVarianceSuffStat { } /*! * \brief Check whether accumulated sample size, `n`, is greater than or equal to some threshold - * + * * \param threshold Value used to compute `n >= threshold` */ bool SampleGreaterThanEqual(data_size_t threshold) { @@ -919,7 +921,7 @@ class LogLinearVarianceLeafModel { ~LogLinearVarianceLeafModel() {} /*! * \brief Log marginal likelihood for a proposed split, evaluated only for observations that fall into the node being split. - * + * * \param left_stat Sufficient statistics of the left node formed by the proposed split * \param right_stat Sufficient statistics of the right node formed by the proposed split * \param global_variance Global error variance parameter @@ -927,7 +929,7 @@ class LogLinearVarianceLeafModel { double SplitLogMarginalLikelihood(LogLinearVarianceSuffStat& left_stat, LogLinearVarianceSuffStat& right_stat, double global_variance); /*! * \brief Log marginal likelihood of a node, evaluated only for observations that fall into the node being split. - * + * * \param suff_stat Sufficient statistics of the node being evaluated * \param global_variance Global error variance parameter */ @@ -935,21 +937,21 @@ class LogLinearVarianceLeafModel { double SuffStatLogMarginalLikelihood(LogLinearVarianceSuffStat& suff_stat, double global_variance); /*! * \brief Leaf node posterior shape parameter. - * + * * \param suff_stat Sufficient statistics of the node being evaluated * \param global_variance Global error variance parameter */ double PosteriorParameterShape(LogLinearVarianceSuffStat& suff_stat, double global_variance); /*! * \brief Leaf node posterior scale parameter. - * + * * \param suff_stat Sufficient statistics of the node being evaluated * \param global_variance Global error variance parameter */ double PosteriorParameterScale(LogLinearVarianceSuffStat& suff_stat, double global_variance); /*! * \brief Draw new parameters for every leaf node in `tree`, using a Gibbs update that conditions on the data, every other tree in the forest, and all model parameters - * + * * \param dataset Data object containining training data, including covariates, leaf regression bases, and case weights * \param tracker Tracking data structures that speed up sampler operations, synchronized with `active_forest` tracking a forest's state * \param residual Data object containing the "full" residual net of all the model's mean terms @@ -969,31 +971,207 @@ class LogLinearVarianceLeafModel { GammaSampler gamma_sampler_; }; -/*! 
- * \brief Unifying layer for disparate sufficient statistic class types
- *
- * Joins together GaussianConstantSuffStat, GaussianUnivariateRegressionSuffStat,
- * GaussianMultivariateRegressionSuffStat, and LogLinearVarianceSuffStat
- * as a combined "variant" type. See the std::variant documentation
+
+/*! \brief Sufficient statistic and associated operations for complementary log-log ordinal BART model */
+class CloglogOrdinalSuffStat {
+ public:
+  data_size_t n;
+  double sum_Y_less_K;
+  double other_sum;
+
+  /*!
+   * \brief Construct a new CloglogOrdinalSuffStat object, setting all sufficient statistics to zero
+   */
+  CloglogOrdinalSuffStat() {
+    n = 0;
+    sum_Y_less_K = 0.0;
+    other_sum = 0.0;
+  }
+
+  /*!
+   * \brief Accumulate data from observation `row_idx` into the sufficient statistics
+   *
+   * \param dataset Data object containing training data, including covariates
+   * \param outcome Data object containing the original ordinal outcome values, which are used to compute sufficient statistics
+   * \param tracker Tracking data structures that speed up sampler operations, synchronized with `active_forest` tracking a forest's state
+   * \param row_idx Index of the training data observation from which the sufficient statistics should be updated
+   * \param tree_idx Index of the tree being updated in the course of this sufficient statistic update
+   */
+  void IncrementSuffStat(ForestDataset& dataset, Eigen::VectorXd& outcome, ForestTracker& tracker, data_size_t row_idx, int tree_idx) {
+    n += 1;
+
+    // Get ordinal outcome value for this observation
+    unsigned int y = static_cast<unsigned int>(outcome(row_idx));
+
+    // Get auxiliary data from the dataset (dimensions: 0 = latents Z, 1 = forest predictions, 2 = cutpoints gamma, 3 = cumulative sums of exp(gamma))
+    double Z = dataset.GetAuxiliaryDataValue(0, row_idx);             // latent variable Z
+    double lambda_minus = dataset.GetAuxiliaryDataValue(1, row_idx);  // forest predictions excluding the current tree
+
+    // Get cutpoints gamma and cumulative sums of exp(gamma)
+    const std::vector<double>& gamma = dataset.GetAuxiliaryDataVectorConst(2);  // cutpoints gamma
+    const std::vector<double>& seg = dataset.GetAuxiliaryDataVectorConst(3);    // cumulative sums of exp(gamma)
+
+    int K = gamma.size() + 1;  // Number of ordinal categories
+
+    if (y == K - 1) {
+      // Top category: only the cumulative cutpoint term enters the exponent
+      other_sum += std::exp(lambda_minus) * seg[y];
+    } else {
+      // Lower categories additionally contribute the truncated-exponential latent term
+      sum_Y_less_K += 1.0;
+      other_sum += std::exp(lambda_minus) * (Z * std::exp(gamma[y]) + seg[y]);
+    }
+  }
+
+  /*!
+   * \brief Reset all of the sufficient statistics to zero
+   */
+  void ResetSuffStat() {
+    n = 0;
+    sum_Y_less_K = 0.0;
+    other_sum = 0.0;
+  }
+
+  /*!
+   * \brief Increment the value of each sufficient statistic by the values provided by `suff_stat`
+   *
+   * \param suff_stat Sufficient statistic to be added to the current sufficient statistics
+   */
+  void AddSuffStatInplace(CloglogOrdinalSuffStat& suff_stat) {
+    n += suff_stat.n;
+    sum_Y_less_K += suff_stat.sum_Y_less_K;
+    other_sum += suff_stat.other_sum;
+  }
+
+  /*!
+   * \brief Set the value of each sufficient statistic to the sum of the values provided by `lhs` and `rhs`
+   *
+   * \param lhs First sufficient statistic ("left hand side")
+   * \param rhs Second sufficient statistic ("right hand side")
+   */
+  void AddSuffStat(CloglogOrdinalSuffStat& lhs, CloglogOrdinalSuffStat& rhs) {
+    n = lhs.n + rhs.n;
+    sum_Y_less_K = lhs.sum_Y_less_K + rhs.sum_Y_less_K;
+    other_sum = lhs.other_sum + rhs.other_sum;
+  }
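+
+  // Note: a sketch of the algebra these statistics support. Writing mu = exp(leaf
+  // parameter), each observation in a node contributes a likelihood factor of the form
+  //   mu^{1(y < K-1)} * exp(-mu * exp(lambda_minus) * (Z * exp(gamma_y) + seg_y)),
+  // with the latent term Z * exp(gamma_y) dropped for the top category, so the node
+  // likelihood is proportional to mu^{sum_Y_less_K} * exp(-mu * other_sum).
+  // Combined with a log-gamma prior with shape a and rate b on the leaf parameter,
+  // this is conjugate: the implied leaf posterior for mu is
+  // Gamma(a + sum_Y_less_K, b + other_sum), which is presumably what
+  // PosteriorParameterShape and PosteriorParameterRate in CloglogOrdinalLeafModel
+  // below compute.
+
+  /*!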
+ * \brief Set the value of each sufficient statistic to the difference between the values provided by `lhs` and those provided by `rhs` + * + * \param lhs First sufficient statistic ("left hand side") + * \param rhs Second sufficient statistic ("right hand side") + */ + void SubtractSuffStat(CloglogOrdinalSuffStat& lhs, CloglogOrdinalSuffStat& rhs) { + n = lhs.n - rhs.n; + sum_Y_less_K = lhs.sum_Y_less_K - rhs.sum_Y_less_K; + other_sum = lhs.other_sum - rhs.other_sum; + } + + /*! + * \brief Check whether accumulated sample size, `n`, is greater than some threshold + * + * \param threshold Value used to compute `n > threshold` + */ + bool SampleGreaterThan(data_size_t threshold) { + return n > threshold; + } + + /*! + * \brief Check whether accumulated sample size, `n`, is greater than or equal to some threshold + * + * \param threshold Value used to compute `n >= threshold` + */ + bool SampleGreaterThanEqual(data_size_t threshold) { + return n >= threshold; + } + + /*! + * \brief Return the sample size accumulated by a sufficient stat object + */ + data_size_t SampleSize() { + return n; + } +}; + +/*! \brief Marginal likelihood and posterior computation for complementary log-log ordinal BART model */ +class CloglogOrdinalLeafModel { + public: + /*! + * \brief Construct a new CloglogOrdinalLeafModel object + * + * \param a shape parameter for log-gamma prior on leaf parameters + * \param b rate parameter for log-gamma prior on leaf parameters + * Log-gamma density: f(x) = b^a / Gamma(a) * exp(a*x - b*exp(x)) + */ + CloglogOrdinalLeafModel(double a, double b) { + a_ = a; + b_ = b; + gamma_sampler_ = GammaSampler(); + } + ~CloglogOrdinalLeafModel() {} + + /*! + * \brief Log marginal likelihood for a proposed split, evaluated only for observations that fall into the node being split. + */ + double SplitLogMarginalLikelihood(CloglogOrdinalSuffStat& left_stat, CloglogOrdinalSuffStat& right_stat, double global_variance); + + /*! + * \brief Log marginal likelihood of a node, evaluated only for observations that fall into the node being split. + */ + double NoSplitLogMarginalLikelihood(CloglogOrdinalSuffStat& suff_stat, double global_variance); + + /*! + * \brief Helper function to compute log marginal likelihood from sufficient statistics + */ + double SuffStatLogMarginalLikelihood(CloglogOrdinalSuffStat& suff_stat, double global_variance); + + /*! + * \brief Posterior shape parameter for leaf node log-gamma distribution + */ + double PosteriorParameterShape(CloglogOrdinalSuffStat& suff_stat, double global_variance); + + /*! + * \brief Posterior rate parameter for leaf node log-gamma distribution + */ + double PosteriorParameterRate(CloglogOrdinalSuffStat& suff_stat, double global_variance); + + /*! + * \brief Draw new parameters for every leaf node in `tree`, using a Gibbs update that conditions on the data, every other tree in the forest, and all model parameters. + * Samples from log-gamma: sample from gamma, then take log. + */ + void SampleLeafParameters(ForestDataset& dataset, ForestTracker& tracker, ColumnVector& residual, Tree* tree, int tree_num, double global_variance, std::mt19937& gen); + inline bool RequiresBasis() {return false;} + + private: + double a_; + double b_; + GammaSampler gamma_sampler_; +}; + +/*! 
\brief Unifying layer for disparate sufficient statistic class types
+ *
+ * Joins together GaussianConstantSuffStat, GaussianUnivariateRegressionSuffStat,
+ * GaussianMultivariateRegressionSuffStat, LogLinearVarianceSuffStat, and CloglogOrdinalSuffStat
+ * as a combined "variant" type. See the std::variant documentation
  * for more detail.
  */
-using SuffStatVariant = std::variant<GaussianConstantSuffStat, GaussianUnivariateRegressionSuffStat, GaussianMultivariateRegressionSuffStat, LogLinearVarianceSuffStat>;
+using SuffStatVariant = std::variant<GaussianConstantSuffStat, GaussianUnivariateRegressionSuffStat, GaussianMultivariateRegressionSuffStat, LogLinearVarianceSuffStat, CloglogOrdinalSuffStat>;
 
 /*!
  * \brief Unifying layer for disparate leaf model class types
- *
- * Joins together GaussianConstantLeafModel, GaussianUnivariateRegressionLeafModel,
- * GaussianMultivariateRegressionLeafModel, and LogLinearVarianceLeafModel
- * as a combined "variant" type. See the std::variant documentation
+ *
+ * Joins together GaussianConstantLeafModel, GaussianUnivariateRegressionLeafModel,
+ * GaussianMultivariateRegressionLeafModel, LogLinearVarianceLeafModel, and CloglogOrdinalLeafModel
+ * as a combined "variant" type. See the std::variant documentation
  * for more detail.
  */
-using LeafModelVariant = std::variant<GaussianConstantLeafModel, GaussianUnivariateRegressionLeafModel, GaussianMultivariateRegressionLeafModel, LogLinearVarianceLeafModel>;
+using LeafModelVariant = std::variant<GaussianConstantLeafModel, GaussianUnivariateRegressionLeafModel, GaussianMultivariateRegressionLeafModel, LogLinearVarianceLeafModel, CloglogOrdinalLeafModel>;
 
 template <typename SuffStatType, typename... SuffStatConstructorArgs>
 static inline SuffStatVariant createSuffStat(SuffStatConstructorArgs... leaf_suff_stat_args) {
@@ -1007,7 +1185,7 @@ static inline LeafModelVariant createLeafModel(LeafModelConstructorArgs... leaf_
 
 /*!
  * \brief Factory function that creates a new `SuffStat` object for the specified model type
- *
+ *
 * \param model_type Enumeration storing the model type
 * \param basis_dim [Optional] dimension of the basis vector, only used if `model_type = kMultivariateRegressionLeafGaussian`
 */
@@ -1018,19 +1196,21 @@ static inline SuffStatVariant suffStatFactory(ModelType model_type, int basis_di
     return createSuffStat<GaussianUnivariateRegressionSuffStat>();
   } else if (model_type == kMultivariateRegressionLeafGaussian) {
     return createSuffStat<GaussianMultivariateRegressionSuffStat>(basis_dim);
-  } else {
+  } else if (model_type == kLogLinearVariance) {
     return createSuffStat<LogLinearVarianceSuffStat>();
+  } else {
+    return createSuffStat<CloglogOrdinalSuffStat>();
   }
 }
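 
// Usage sketch for the two factories (the enumerator name kCloglogOrdinal is
// assumed here for illustration; the ModelType addition itself is not shown in
// this patch, and both factories reach the Cloglog types via their final else
// branch):
//   SuffStatVariant suff_stat = suffStatFactory(kCloglogOrdinal);
//   LeafModelVariant leaf_model = leafModelFactory(kCloglogOrdinal, tau, Sigma0, a, b);
 
 /*!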
* \brief Factory function that creates a new `LeafModel` object for the specified model type
- *
+ *
 * \param model_type Enumeration storing the model type
- * \param tau Value of the leaf node prior scale parameter, only used if `model_type = kConstantLeafGaussian` or `model_type = kUnivariateRegressionLeafGaussian`
+ * \param tau Value of the leaf node prior scale parameter, only used if `model_type = kConstantLeafGaussian` or `model_type = kUnivariateRegressionLeafGaussian`
 * \param Sigma0 Value of the leaf node prior covariance matrix, only used if `model_type = kMultivariateRegressionLeafGaussian`
- * \param a Value of the leaf node inverse gamma prior shape parameter, only used if `model_type = kLogLinearVariance`
- * \param b Value of the leaf node inverse gamma prior scale parameter, only used if `model_type = kLogLinearVariance`
+ * \param a Value of the leaf node inverse gamma prior shape parameter, only used if `model_type = kLogLinearVariance` (or value of the leaf node log-gamma prior shape parameter, only used if `model_type = kCloglogOrdinal`)
+ * \param b Value of the leaf node inverse gamma prior scale parameter, only used if `model_type = kLogLinearVariance` (or value of the leaf node log-gamma prior rate parameter, only used if `model_type = kCloglogOrdinal`)
 */
static inline LeafModelVariant leafModelFactory(ModelType model_type, double tau, Eigen::MatrixXd& Sigma0, double a, double b) {
  if (model_type == kConstantLeafGaussian) {
@@ -1039,15 +1219,17 @@
     return createLeafModel<GaussianUnivariateRegressionLeafModel>(tau);
   } else if (model_type == kMultivariateRegressionLeafGaussian) {
     return createLeafModel<GaussianMultivariateRegressionLeafModel>(Sigma0);
-  } else {
+  } else if (model_type == kLogLinearVariance) {
     return createLeafModel<LogLinearVarianceLeafModel>(a, b);
+  } else {
+    return createLeafModel<CloglogOrdinalLeafModel>(a, b);
   }
 }
 
 template <typename SuffStatType, typename... SuffStatConstructorArgs>
 static inline void AccumulateSuffStatProposed(
-    SuffStatType& node_suff_stat, SuffStatType& left_suff_stat, SuffStatType& right_suff_stat, ForestDataset& dataset, ForestTracker& tracker,
-    ColumnVector& residual, double global_variance, TreeSplit& split, int tree_num, int leaf_num, int split_feature, int num_threads,
+    SuffStatType& node_suff_stat, SuffStatType& left_suff_stat, SuffStatType& right_suff_stat, ForestDataset& dataset, ForestTracker& tracker,
+    ColumnVector& residual, double global_variance, TreeSplit& split, int tree_num, int leaf_num, int split_feature, int num_threads,
     SuffStatConstructorArgs&... suff_stat_args
 ) {
   // Determine the position of the node's indices in the forest tracking data structure
@@ -1072,13 +1254,13 @@
   std::vector<SuffStatType> thread_suff_stats_left;
   std::vector<SuffStatType> thread_suff_stats_right;
   for (int i = 0; i < num_threads; i++) {
-    thread_ranges[i] = std::make_pair(node_begin_index + i * chunk_size,
+    thread_ranges[i] = std::make_pair(node_begin_index + i * chunk_size,
                                       node_begin_index + (i + 1) * chunk_size);
     thread_suff_stats_node.emplace_back(suff_stat_args...);
     thread_suff_stats_left.emplace_back(suff_stat_args...);
     thread_suff_stats_right.emplace_back(suff_stat_args...);
   }
-
+
   // Accumulate sufficient statistics
   StochTree::ParallelFor(0, num_threads, num_threads, [&](int i) {
     int start_idx = thread_ranges[i].first;
@@ -1116,7 +1298,7 @@
 }
 
 template <typename SuffStatType>
-static inline void AccumulateSuffStatExisting(SuffStatType& node_suff_stat, SuffStatType& left_suff_stat, SuffStatType& right_suff_stat, ForestDataset& dataset, ForestTracker& tracker,
+static inline void AccumulateSuffStatExisting(SuffStatType& node_suff_stat, SuffStatType& left_suff_stat, SuffStatType& right_suff_stat, ForestDataset& dataset, ForestTracker& tracker,
                                               ColumnVector& residual, double global_variance, int tree_num, int split_node_id, int left_node_id, int right_node_id) {
   // Acquire iterators
   auto left_node_begin_iter = tracker.UnsortedNodeBeginIterator(tree_num, left_node_id);
@@ -1152,7 +1334,7 @@ static inline void AccumulateSingleNodeSuffStat(SuffStatType& node_suff_stat, Fo
     node_begin_iter = tracker.UnsortedNodeBeginIterator(tree_num, node_id);
     node_end_iter = tracker.UnsortedNodeEndIterator(tree_num, node_id);
   }
-
+
   // Accumulate sufficient statistics
   for (auto i = node_begin_iter; i != node_end_iter; i++) {
     auto idx = *i;
@@ -1161,13 +1343,13 @@
 }
 
 template <typename SuffStatType>
-static inline void AccumulateCutpointBinSuffStat(SuffStatType& left_suff_stat, ForestTracker& tracker, CutpointGridContainer& cutpoint_grid_container,
-                                                 ForestDataset& dataset, ColumnVector& residual, double global_variance, int tree_num, int node_id,
+static inline void AccumulateCutpointBinSuffStat(SuffStatType& left_suff_stat, ForestTracker& tracker, CutpointGridContainer& cutpoint_grid_container,
+                                                 ForestDataset& dataset, ColumnVector& residual, double global_variance, int tree_num, int node_id,
                                                  int feature_num, int cutpoint_num) {
   // Acquire iterators
   auto node_begin_iter = tracker.SortedNodeBeginIterator(node_id, feature_num);
   auto node_end_iter = tracker.SortedNodeEndIterator(node_id, feature_num);
-
+
   // Determine node start point
   data_size_t node_begin = tracker.SortedNodeBegin(node_id, feature_num);
diff --git a/include/stochtree/ordinal_sampler.h b/include/stochtree/ordinal_sampler.h
new file mode 100644
index 00000000..d67563e2
--- /dev/null
+++ b/include/stochtree/ordinal_sampler.h
@@ -0,0 +1,100 @@
+/*!
+ * Copyright (c) 2025 stochtree authors. All rights reserved.
+ * Licensed under the MIT License. See LICENSE file in the project root for license information.
+ */
+#ifndef STOCHTREE_ORDINAL_SAMPLER_H_
+#define STOCHTREE_ORDINAL_SAMPLER_H_
+
+#include <stochtree/data.h>
+#include <stochtree/ensemble.h>
+#include <stochtree/gamma_sampler.h>
+#include <stochtree/meta.h>
+#include <stochtree/partition_tracker.h>
+
+#include <cmath>
+#include <random>
+#include <vector>
+
+namespace StochTree {
+
+// Inverse-CDF samplers for (truncated) exponential distributions, where u is a
+// Uniform(0, 1) draw and rate is the exponential rate parameter.
+
+// Exponential(rate) truncated to [low, high]
+static double sample_truncated_exponential_low_high(double u, double rate, double low, double high) {
+  return -std::log((1-u)*std::exp(-rate*low) + u*std::exp(-rate*high))/rate;
+}
+
+// Exponential(rate) truncated below at low
+static double sample_truncated_exponential_low(double u, double rate, double low) {
+  return -std::log((1-u)*std::exp(-rate*low))/rate;
+}
+
+// Exponential(rate) truncated above at high
+static double sample_truncated_exponential_high(double u, double rate, double high) {
+  return -std::log1p(u*std::expm1(-high*rate))/rate;
+}
+
+// Untruncated Exponential(rate)
+static double sample_exponential(double u, double rate) {
+  return -std::log1p(-u)/rate;
+}
+
+/*!
+ * \brief Sampler for ordinal model hyperparameters
+ *
+ * This class handles MCMC sampling for ordinal-specific parameters:
+ * - Truncated exponential latent variables (Z)
+ * - Cutpoint parameters (gamma)
+ * - Cumulative sum of exp(gamma) (seg) [derived parameter]
+ */
+class OrdinalSampler {
+ public:
+  OrdinalSampler() {
+    gamma_sampler_ = GammaSampler();
+  }
+  ~OrdinalSampler() {}
+
+  /*!
+   * \brief Sample from truncated exponential distribution
+   *
+   * Samples from exponential distribution truncated to [low,high]
+   *
+   * \param gen Random number generator
+   * \param rate Rate parameter for exponential distribution
+   * \param low Lower truncation bound
+   * \param high Upper truncation bound
+   * \return Sampled value from truncated exponential
+   */
+  static double SampleTruncatedExponential(std::mt19937& gen, double rate, double low = 0.0, double high = 1.0);
+
+  /*!
+   * \brief Update truncated exponential latent variables (Z)
+   *
+   * \param dataset Forest dataset containing training data (covariates) and auxiliary data needed for sampling
+   * \param outcome Vector of outcome values
+   * \param gen Random number generator
+   */
+  void UpdateLatentVariables(ForestDataset& dataset, Eigen::VectorXd& outcome, std::mt19937& gen);
+
+  /*!
+   * \brief Update gamma cutpoint parameters
+   *
+   * \param dataset Forest dataset containing training data (covariates) and auxiliary data needed for sampling
+   * \param outcome Vector of outcome values
+   * \param alpha_gamma Shape parameter for log-gamma prior on cutpoints gamma
+   * \param beta_gamma Rate parameter for log-gamma prior on cutpoints gamma
+   * \param gamma_0 Fixed value for first cutpoint parameter (for identifiability)
+   * \param gen Random number generator
+   */
+  void UpdateGammaParams(ForestDataset& dataset, Eigen::VectorXd& outcome,
+                         double alpha_gamma, double beta_gamma,
+                         double gamma_0, std::mt19937& gen);
+
+  /*!
+   * \brief Update cumulative exponential sums (seg)
+   *
+   * \param dataset Forest dataset containing training data (covariates) and auxiliary data needed for sampling
+   */
+  void UpdateCumulativeExpSums(ForestDataset& dataset);
+
+ private:
+  GammaSampler gamma_sampler_;
+};
+
+} // namespace StochTree
+
+#endif // STOCHTREE_ORDINAL_SAMPLER_H_
diff --git a/include/stochtree/partition_tracker.h b/include/stochtree/partition_tracker.h
index 0790d87a..f25c875c 100644
--- a/include/stochtree/partition_tracker.h
+++ b/include/stochtree/partition_tracker.h
@@ -32,7 +32,6 @@
 #include
 #include
-#include
 
 namespace StochTree {
 
@@ -92,7 +91,6 @@ class ForestTracker {
   int GetNumTrees() {return num_trees_;}
   int GetNumFeatures() {return num_features_;}
   bool Initialized() {return initialized_;}
- private:
   /*!
\brief Mapper from observations to predicted values summed over every tree in a forest */
  std::vector<double> sum_predictions_;
diff --git a/include/stochtree/tree_sampler.h b/include/stochtree/tree_sampler.h
index 68c9c15a..b8101fd2 100644
--- a/include/stochtree/tree_sampler.h
+++ b/include/stochtree/tree_sampler.h
@@ -394,6 +394,40 @@ static inline void UpdateVarModelTree(ForestTracker& tracker, ForestDataset& dat
   }
 }
 
+static inline void UpdateCLogLogModelTree(ForestTracker& tracker, ForestDataset& dataset, ColumnVector& residual, Tree* tree, int tree_num,
+                                          bool requires_basis, bool tree_new) {
+  data_size_t n = dataset.GetCovariates().rows();
+
+  double pred_value;
+  int32_t leaf_pred;
+  double pred_delta;
+  for (data_size_t i = 0; i < n; i++) {
+    if (tree_new) {
+      // If the tree has been newly sampled or adjusted, we must rerun the prediction
+      // method and update the SamplePredMapper stored in tracker
+      leaf_pred = tracker.GetNodeId(i, tree_num);
+      if (requires_basis) {
+        pred_value = tree->PredictFromNode(leaf_pred, dataset.GetBasis(), i);
+      } else {
+        pred_value = tree->PredictFromNode(leaf_pred);
+      }
+      pred_delta = pred_value - tracker.GetTreeSamplePrediction(i, tree_num);
+      tracker.SetTreeSamplePrediction(i, tree_num, pred_value);
+      tracker.SetSamplePrediction(i, tracker.GetSamplePrediction(i) + pred_delta);
+      // Set auxiliary data slot 1 to forest predictions excluding the current tree (tree_num)
+      dataset.SetAuxiliaryDataValue(1, i, tracker.GetSamplePrediction(i) - pred_value);
+    } else {
+      // If the tree has not yet been modified via a sampling step,
+      // we can query its prediction directly from the SamplePredMapper stored in tracker
+      pred_value = tracker.GetTreeSamplePrediction(i, tree_num);
+      // Refresh auxiliary data slot 1 with the forest prediction excluding the current
+      // tree (tree_num); even though the tree itself is unchanged here, this cache must
+      // be up to date before the tree's sufficient statistics are accumulated
+      double current_lambda_hat = tracker.GetSamplePrediction(i);
+      double lambda_minus = current_lambda_hat - pred_value;
+      dataset.SetAuxiliaryDataValue(1, i, lambda_minus);
+    }
+  }
+}
+
 template <typename LeafModel, typename LeafSuffStat, typename... LeafSuffStatConstructorArgs>
 static inline std::tuple<double, double, data_size_t, data_size_t> EvaluateProposedSplit(
     ForestDataset& dataset, ForestTracker& tracker, ColumnVector& residual, LeafModel& leaf_model,
@@ -447,8 +481,10 @@ static inline std::tuple<double, double, data_size_t, data_size_t> EvaluateExist
 template <typename LeafModel>
 static inline void AdjustStateBeforeTreeSampling(ForestTracker& tracker, LeafModel& leaf_model, ForestDataset& dataset,
-                                                 ColumnVector& residual, TreePrior& tree_prior, bool backfitting, Tree* tree, int tree_num) {
-  if (backfitting) {
+                                                 ColumnVector& residual, TreePrior& tree_prior, bool backfitting, Tree* tree, int tree_num) {
+  if constexpr (std::is_same_v<LeafModel, CloglogOrdinalLeafModel>) {
+    UpdateCLogLogModelTree(tracker, dataset, residual, tree, tree_num, leaf_model.RequiresBasis(), false);
+  } else if (backfitting) {
     UpdateMeanModelTree(tracker, dataset, residual, tree, tree_num, leaf_model.RequiresBasis(), std::plus<double>(), false);
   } else {
     // TODO: think about a generic way to store "state" corresponding to the other models?
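 
// Note on the cloglog state adjustment above: the cloglog ordinal model does not
// backfit a residual. Instead, auxiliary slot 1 of the dataset caches
//   lambda_minus(i) = sum over trees s != tree_num of tree_s(x_i),
// which CloglogOrdinalSuffStat::IncrementSuffStat reads when scoring splits. The
// false/true flags passed to UpdateCLogLogModelTree here and in
// AdjustStateAfterTreeSampling below control whether tree_num's predictions are
// recomputed when refreshing that cache.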
@@ -458,8 +494,10 @@
 
 template <typename LeafModel>
 static inline void AdjustStateAfterTreeSampling(ForestTracker& tracker, LeafModel& leaf_model, ForestDataset& dataset,
-                                                ColumnVector& residual, TreePrior& tree_prior, bool backfitting, Tree* tree, int tree_num) {
-  if (backfitting) {
+                                                ColumnVector& residual, TreePrior& tree_prior, bool backfitting, Tree* tree, int tree_num) {
+  if constexpr (std::is_same_v<LeafModel, CloglogOrdinalLeafModel>) {
+    UpdateCLogLogModelTree(tracker, dataset, residual, tree, tree_num, leaf_model.RequiresBasis(), true);
+  } else if (backfitting) {
     UpdateMeanModelTree(tracker, dataset, residual, tree, tree_num, leaf_model.RequiresBasis(), std::minus<double>(), true);
   } else {
     // TODO: think about a generic way to store "state" corresponding to the other models?
@@ -776,6 +814,7 @@ static inline void GFRSampleTreeOneIter(Tree* tree, ForestTracker& tracker, Fore
 * \param backfitting Whether or not the sampler uses "backfitting" (wherein the sampler for a given tree only depends on the other trees via
 * their effect on the residual) or the more general "blocked MCMC" (wherein the state of other trees must be more explicitly considered).
 * \param num_features_subsample How many features to subsample when running the GFR algorithm.
+ * \param num_threads Number of threads to use for split evaluations and other compute-intensive operations.
 * \param leaf_suff_stat_args Any arguments which must be supplied to initialize a `LeafSuffStat` object.
 */
 template <typename LeafModel, typename LeafSuffStat, typename... LeafSuffStatConstructorArgs>
@@ -932,7 +971,8 @@ static inline void MCMCGrowTreeOneIter(Tree* tree, ForestTracker& tracker, LeafM
 
 template <typename LeafModel, typename LeafSuffStat, typename... LeafSuffStatConstructorArgs>
 static inline void MCMCPruneTreeOneIter(Tree* tree, ForestTracker& tracker, LeafModel& leaf_model, ForestDataset& dataset, ColumnVector& residual,
-                                        TreePrior& tree_prior, std::mt19937& gen, int tree_num, double global_variance, int num_threads, LeafSuffStatConstructorArgs&... leaf_suff_stat_args) {
+                                        TreePrior& tree_prior, std::mt19937& gen, int tree_num, double global_variance, int num_threads,
+                                        LeafSuffStatConstructorArgs&... leaf_suff_stat_args) {
   // Choose a "leaf parent" node at random
   int num_leaves = tree->NumLeaves();
   int num_leaf_parents = tree->NumLeafParents();
@@ -1085,6 +1125,7 @@ static inline void MCMCSampleTreeOneIter(Tree* tree, ForestTracker& tracker, For
 * \param pre_initialized Whether or not `active_forest` has already been initialized (note: this parameter will be refactored out soon).
 * \param backfitting Whether or not the sampler uses "backfitting" (wherein the sampler for a given tree only depends on the other trees via
 * their effect on the residual) or the more general "blocked MCMC" (wherein the state of other trees must be more explicitly considered).
+ * \param num_threads Number of threads to use for split evaluations and other compute-intensive operations.
 * \param leaf_suff_stat_args Any arguments which must be supplied to initialize a `LeafSuffStat` object.
 */
 template <typename LeafModel, typename LeafSuffStat, typename... LeafSuffStatConstructorArgs>
diff --git a/man/ForestDataset.Rd b/man/ForestDataset.Rd
index dfd7760f..e684cd18 100644
--- a/man/ForestDataset.Rd
+++ b/man/ForestDataset.Rd
@@ -29,6 +29,12 @@ weights are optional.
 \item \href{#method-ForestDataset-get_variance_weights}{\code{ForestDataset$get_variance_weights()}}
 \item \href{#method-ForestDataset-has_basis}{\code{ForestDataset$has_basis()}}
 \item \href{#method-ForestDataset-has_variance_weights}{\code{ForestDataset$has_variance_weights()}}
+\item \href{#method-ForestDataset-has_auxiliary_dimension}{\code{ForestDataset$has_auxiliary_dimension()}}
+\item \href{#method-ForestDataset-add_auxiliary_dimension}{\code{ForestDataset$add_auxiliary_dimension()}}
+\item \href{#method-ForestDataset-get_auxiliary_data_value}{\code{ForestDataset$get_auxiliary_data_value()}}
+\item \href{#method-ForestDataset-set_auxiliary_data_value}{\code{ForestDataset$set_auxiliary_data_value()}}
+\item \href{#method-ForestDataset-get_auxiliary_data_vector}{\code{ForestDataset$get_auxiliary_data_vector()}}
+\item \href{#method-ForestDataset-store_auxiliary_data_vector_matrix}{\code{ForestDataset$store_auxiliary_data_vector_matrix()}}
 }
 }
 \if{html}{\out{<hr>}}
@@ -195,4 +201,138 @@ Whether or not a dataset has variance weights
 True if variance weights are loaded, false otherwise
 }
 }
+\if{html}{\out{<hr>}}
+\if{html}{\out{<a id="method-ForestDataset-has_auxiliary_dimension"></a>}}
+\if{latex}{\out{\hypertarget{method-ForestDataset-has_auxiliary_dimension}{}}}
+\subsection{Method \code{has_auxiliary_dimension()}}{
+Whether or not a dataset has auxiliary data stored at the dimension indicated
+\subsection{Usage}{
+\if{html}{\out{<div class="r">}}\preformatted{ForestDataset$has_auxiliary_dimension(dim_idx)}\if{html}{\out{</div>}}
+}
+
+\subsection{Arguments}{
+\if{html}{\out{<div class="arguments">}}
+\describe{
+\item{\code{dim_idx}}{Dimension of auxiliary data}
+}
+\if{html}{\out{</div>}}
+}
+\subsection{Returns}{
+True if auxiliary data has been allocated for \code{dim_idx}, false otherwise
+}
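+\subsection{Examples}{
+\preformatted{## Illustrative sketch only (not generated from the package sources);
+## dimension indices are assumed 0-based, matching the C++ calls in this patch.
+X <- matrix(runif(20), ncol = 2)
+ds <- createForestDataset(X)
+ds$has_auxiliary_dimension(0)  # FALSE before any allocation
+ds$add_auxiliary_dimension(nrow(X))
+ds$has_auxiliary_dimension(0)  # TRUE afterwards
+}
+}
+}
+\if{html}{\out{<hr>}}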
+\if{html}{\out{<a id="method-ForestDataset-add_auxiliary_dimension"></a>}}
+\if{latex}{\out{\hypertarget{method-ForestDataset-add_auxiliary_dimension}{}}}
+\subsection{Method \code{add_auxiliary_dimension()}}{
+Initialize a new dimension of auxiliary data and allocate storage for it
+\subsection{Usage}{
+\if{html}{\out{<div class="r">}}\preformatted{ForestDataset$add_auxiliary_dimension(dim_size)}\if{html}{\out{</div>}}
+}
+
+\subsection{Arguments}{
+\if{html}{\out{<div class="arguments">}}
+\describe{
+\item{\code{dim_size}}{Size of the new vector of data to allocate}
+}
+\if{html}{\out{</div>}}
+}
+\subsection{Returns}{
+None
+}
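+\subsection{Examples}{
+\preformatted{## Illustrative sketch only (not generated from the package sources).
+X <- matrix(runif(20), ncol = 2)
+ds <- createForestDataset(X)
+ds$add_auxiliary_dimension(nrow(X))  # one auxiliary value per observation
+}
+}
+}
+\if{html}{\out{<hr>}}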
+\if{html}{\out{<a id="method-ForestDataset-get_auxiliary_data_value"></a>}}
+\if{latex}{\out{\hypertarget{method-ForestDataset-get_auxiliary_data_value}{}}}
+\subsection{Method \code{get_auxiliary_data_value()}}{
+Retrieve auxiliary data value
+\subsection{Usage}{
+\if{html}{\out{<div class="r">}}\preformatted{ForestDataset$get_auxiliary_data_value(dim_idx, element_idx)}\if{html}{\out{</div>}}
+}
+
+\subsection{Arguments}{
+\if{html}{\out{<div class="arguments">}}
+\describe{
+\item{\code{dim_idx}}{Dimension from which the data value is to be retrieved}
+
+\item{\code{element_idx}}{Element to retrieve from dimension \code{dim_idx}}
+}
+\if{html}{\out{</div>}}
+}
+\subsection{Returns}{
+Floating point value stored in the requested auxiliary data space
+}
+}
+\if{html}{\out{<hr>}}
+\if{html}{\out{<a id="method-ForestDataset-set_auxiliary_data_value"></a>}}
+\if{latex}{\out{\hypertarget{method-ForestDataset-set_auxiliary_data_value}{}}}
+\subsection{Method \code{set_auxiliary_data_value()}}{
+Set auxiliary data value
+\subsection{Usage}{
+\if{html}{\out{<div class="r">}}\preformatted{ForestDataset$set_auxiliary_data_value(dim_idx, element_idx, value)}\if{html}{\out{</div>}}
+}
+
+\subsection{Arguments}{
+\if{html}{\out{<div class="arguments">}}
+\describe{
+\item{\code{dim_idx}}{Dimension in which the data value is to be set}
+
+\item{\code{element_idx}}{Element to set within dimension \code{dim_idx}}
+
+\item{\code{value}}{Data value to set at auxiliary data dimension \code{dim_idx} and element \code{element_idx}}
+}
+\if{html}{\out{</div>}}
+}
+\subsection{Returns}{
+None
+}
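+\subsection{Examples}{
+\preformatted{## Illustrative sketch only (not generated from the package sources);
+## indices are assumed 0-based, matching the C++ calls in this patch.
+X <- matrix(runif(20), ncol = 2)
+ds <- createForestDataset(X)
+ds$add_auxiliary_dimension(nrow(X))
+ds$set_auxiliary_data_value(0, 3, 1.5)
+ds$get_auxiliary_data_value(0, 3)  # 1.5
+}
+}
+}
+\if{html}{\out{<hr>}}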
+\if{html}{\out{<a id="method-ForestDataset-get_auxiliary_data_vector"></a>}}
+\if{latex}{\out{\hypertarget{method-ForestDataset-get_auxiliary_data_vector}{}}}
+\subsection{Method \code{get_auxiliary_data_vector()}}{
+Retrieve entire auxiliary data vector
+\subsection{Usage}{
+\if{html}{\out{<div class="r">}}\preformatted{ForestDataset$get_auxiliary_data_vector(dim_idx)}\if{html}{\out{</div>}}
+}
+
+\subsection{Arguments}{
+\if{html}{\out{<div class="arguments">}}
+\describe{
+\item{\code{dim_idx}}{Dimension to retrieve}
+}
+\if{html}{\out{</div>}}
+}
+\subsection{Returns}{
+Vector of all of the auxiliary data stored at dimension \code{dim_idx}
+}
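+\subsection{Examples}{
+\preformatted{## Illustrative sketch only (not generated from the package sources).
+X <- matrix(runif(20), ncol = 2)
+ds <- createForestDataset(X)
+ds$add_auxiliary_dimension(nrow(X))
+ds$get_auxiliary_data_vector(0)  # numeric vector of length nrow(X)
+}
+}
+}
+\if{html}{\out{<hr>}}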
+\if{html}{\out{<a id="method-ForestDataset-store_auxiliary_data_vector_matrix"></a>}}
+\if{latex}{\out{\hypertarget{method-ForestDataset-store_auxiliary_data_vector_matrix}{}}}
+\subsection{Method \code{store_auxiliary_data_vector_matrix()}}{
+Retrieve auxiliary data vector and place it into a column of the supplied matrix
+\subsection{Usage}{
+\if{html}{\out{<div class="r">}}\preformatted{ForestDataset$store_auxiliary_data_vector_matrix(
+  output_matrix,
+  dim_idx,
+  matrix_col_idx
+)}\if{html}{\out{</div>}}
+}
+
+\subsection{Arguments}{
+\if{html}{\out{<div class="arguments">}}
+\describe{
+\item{\code{output_matrix}}{Matrix to be overwritten}
+
+\item{\code{dim_idx}}{Auxiliary data dimension to retrieve}
+
+\item{\code{matrix_col_idx}}{Matrix column in which to copy auxiliary data}
+}
+\if{html}{\out{</div>}}
+}
+\subsection{Returns}{
+None; the auxiliary data stored at dimension \code{dim_idx} is copied into column \code{matrix_col_idx} of \code{output_matrix}
+}
+}
 }
diff --git a/man/bart.Rd b/man/bart.Rd
index 66a9b9ad..c11c619b 100644
--- a/man/bart.Rd
+++ b/man/bart.Rd
@@ -136,9 +136,9 @@ n <- 100
 p <- 5
 X <- matrix(runif(n*p), ncol = p)
 f_XW <- (
-    ((0 <= X[,1]) & (0.25 > X[,1])) * (-7.5) + 
-    ((0.25 <= X[,1]) & (0.5 > X[,1])) * (-2.5) + 
-    ((0.5 <= X[,1]) & (0.75 > X[,1])) * (2.5) + 
+    ((0 <= X[,1]) & (0.25 > X[,1])) * (-7.5) +
+    ((0.25 <= X[,1]) & (0.5 > X[,1])) * (-2.5) +
+    ((0.5 <= X[,1]) & (0.75 > X[,1])) * (2.5) +
     ((0.75 <= X[,1]) & (1 > X[,1])) * (7.5)
 )
 noise_sd <- 1
@@ -153,6 +153,6 @@ X_train <- X[train_inds,]
 y_test <- y[test_inds]
 y_train <- y[train_inds]
-bart_model <- bart(X_train = X_train, y_train = y_train, X_test = X_test, 
+bart_model <- bart(X_train = X_train, y_train = y_train, X_test = X_test,
                    num_gfr = 10, num_burnin = 0, num_mcmc = 10)
 }
diff --git a/man/bcf.Rd b/man/bcf.Rd
index 01e5fab8..f7d42e93 100644
--- a/man/bcf.Rd
+++ b/man/bcf.Rd
@@ -162,21 +162,21 @@ n <- 500
 p <- 5
 X <- matrix(runif(n*p), ncol = p)
 mu_x <- (
-    ((0 <= X[,1]) & (0.25 > X[,1])) * (-7.5) + 
-    ((0.25 <= X[,1]) & (0.5 > X[,1])) * (-2.5) + 
-    ((0.5 <= X[,1]) & (0.75 > X[,1])) * (2.5) + 
+    ((0 <= X[,1]) & (0.25 > X[,1])) * (-7.5) +
+    ((0.25 <= X[,1]) & (0.5 > X[,1])) * (-2.5) +
+    ((0.5 <= X[,1]) & (0.75 > X[,1])) * (2.5) +
     ((0.75 <= X[,1]) & (1 > X[,1])) * (7.5)
 )
 pi_x <- (
-    ((0 <= X[,1]) & (0.25 > X[,1])) * (0.2) + 
-    ((0.25 <= X[,1]) & (0.5 > X[,1])) * (0.4) + 
-    ((0.5 <= X[,1]) & (0.75 > X[,1])) * (0.6) + 
+    ((0 <= X[,1]) & (0.25 > X[,1])) * (0.2) +
+    ((0.25 <= X[,1]) & (0.5 > X[,1])) * (0.4) +
+    ((0.5 <= X[,1]) & (0.75 > X[,1])) * (0.6) +
     ((0.75 <= X[,1]) & (1 > X[,1])) * (0.8)
 )
 tau_x <- (
-    ((0 <= X[,2]) & (0.25 > X[,2])) * (0.5) + 
-    ((0.25 <= X[,2]) & (0.5 > X[,2])) * (1.0) + 
-    ((0.5 <= X[,2]) & (0.75 > X[,2])) * (1.5) + 
+    ((0 <= X[,2]) & (0.25 > X[,2])) * (0.5) +
+    ((0.25 <= X[,2]) & (0.5 > X[,2])) * (1.0) +
+    ((0.5 <= X[,2]) & (0.75 > X[,2])) * (1.5) +
     ((0.75 <= X[,2]) & (1 > X[,2])) * (2.0)
 )
 Z <- rbinom(n, 1, pi_x)
@@ -199,8 +199,8 @@ mu_test <- mu_x[test_inds]
 mu_train <- mu_x[train_inds]
 tau_test <- tau_x[test_inds]
 tau_train <- tau_x[train_inds]
-bcf_model <- bcf(X_train = X_train, Z_train = Z_train, y_train = y_train, 
-                 propensity_train = pi_train, X_test = X_test, Z_test = Z_test, 
-                 propensity_test = pi_test, num_gfr = 10, 
+bcf_model <- bcf(X_train = X_train, Z_train = Z_train, y_train = y_train,
+                 propensity_train = pi_train, X_test = X_test, Z_test = Z_test,
+                 propensity_test = pi_test, num_gfr = 10,
                  num_burnin = 0, num_mcmc = 10)
 }
diff --git a/man/cloglog_ordinal_bart.Rd b/man/cloglog_ordinal_bart.Rd
new file mode 100644
index 00000000..049aa532
--- /dev/null
+++ b/man/cloglog_ordinal_bart.Rd
@@ -0,0 +1,55 @@
+% Generated by roxygen2: do not edit by hand
+% Please edit documentation in R/cloglog_ordinal_bart.R
+\name{cloglog_ordinal_bart}
+\alias{cloglog_ordinal_bart}
+\title{Run the BART algorithm for ordinal outcomes using a complementary log-log link}
+\usage{
+cloglog_ordinal_bart(
+  X,
+  y,
+  X_test = NULL,
+  n_trees = 50,
+  num_gfr = 0,
+  num_burnin = 1000,
+  num_mcmc = 500,
+  n_thin = 1,
+  alpha_gamma = 2,
+  beta_gamma = 2,
+  variable_weights = NULL,
+  feature_types = NULL,
+  seed = NULL,
+  num_threads = 1
+)
+}
+\arguments{
+\item{X}{A numeric matrix of predictors (training data).}
+
+\item{y}{A numeric vector of ordinal outcomes (positive integers starting from 1).}
+
+\item{X_test}{An optional numeric matrix of predictors (test
data).} + +\item{n_trees}{Number of trees in the BART ensemble. Default: \code{50}.} + +\item{num_gfr}{Number of GFR samples to draw at the beginning of the sampler. Default: \code{0}.} + +\item{num_burnin}{Number of burn-in MCMC samples to discard. Default: \code{1000}.} + +\item{num_mcmc}{Total number of MCMC samples to draw. Default: \code{500}.} + +\item{n_thin}{Thinning interval for MCMC samples. Default: \code{1}.} + +\item{alpha_gamma}{Shape parameter for the log-gamma prior on cutpoints. Default: \code{2.0}.} + +\item{beta_gamma}{Rate parameter for the log-gamma prior on cutpoints. Default: \code{2.0}.} + +\item{variable_weights}{(Optional) vector of variable weights for splitting (default: equal weights).} + +\item{feature_types}{(Optional) vector indicating feature types (0 for continuous, 1 for categorical; default: all continuous).} + +\item{seed}{(Optional) random seed for reproducibility.} + +\item{num_threads}{(Optional) Number of threads to use in split evaluations and other compute-intensive operations. Default: 1.} +} +\description{ +Run the BART algorithm for ordinal outcomes using a complementary log-log link +} diff --git a/man/createBARTModelFromCombinedJson.Rd b/man/createBARTModelFromCombinedJson.Rd index 35d185c3..83d61d0d 100644 --- a/man/createBARTModelFromCombinedJson.Rd +++ b/man/createBARTModelFromCombinedJson.Rd @@ -22,9 +22,9 @@ n <- 100 p <- 5 X <- matrix(runif(n*p), ncol = p) f_XW <- ( - ((0 <= X[,1]) & (0.25 > X[,1])) * (-7.5) + - ((0.25 <= X[,1]) & (0.5 > X[,1])) * (-2.5) + - ((0.5 <= X[,1]) & (0.75 > X[,1])) * (2.5) + + ((0 <= X[,1]) & (0.25 > X[,1])) * (-7.5) + + ((0.25 <= X[,1]) & (0.5 > X[,1])) * (-2.5) + + ((0.5 <= X[,1]) & (0.75 > X[,1])) * (2.5) + ((0.75 <= X[,1]) & (1 > X[,1])) * (7.5) ) noise_sd <- 1 @@ -38,7 +38,7 @@ X_test <- X[test_inds,] X_train <- X[train_inds,] y_test <- y[test_inds] y_train <- y[train_inds] -bart_model <- bart(X_train = X_train, y_train = y_train, +bart_model <- bart(X_train = X_train, y_train = y_train, num_gfr = 10, num_burnin = 0, num_mcmc = 10) bart_json <- list(saveBARTModelToJson(bart_model)) bart_model_roundtrip <- createBARTModelFromCombinedJson(bart_json) diff --git a/man/createBARTModelFromCombinedJsonString.Rd b/man/createBARTModelFromCombinedJsonString.Rd index a8470dee..7a17484a 100644 --- a/man/createBARTModelFromCombinedJsonString.Rd +++ b/man/createBARTModelFromCombinedJsonString.Rd @@ -22,9 +22,9 @@ n <- 100 p <- 5 X <- matrix(runif(n*p), ncol = p) f_XW <- ( - ((0 <= X[,1]) & (0.25 > X[,1])) * (-7.5) + - ((0.25 <= X[,1]) & (0.5 > X[,1])) * (-2.5) + - ((0.5 <= X[,1]) & (0.75 > X[,1])) * (2.5) + + ((0 <= X[,1]) & (0.25 > X[,1])) * (-7.5) + + ((0.25 <= X[,1]) & (0.5 > X[,1])) * (-2.5) + + ((0.5 <= X[,1]) & (0.75 > X[,1])) * (2.5) + ((0.75 <= X[,1]) & (1 > X[,1])) * (7.5) ) noise_sd <- 1 @@ -38,7 +38,7 @@ X_test <- X[test_inds,] X_train <- X[train_inds,] y_test <- y[test_inds] y_train <- y[train_inds] -bart_model <- bart(X_train = X_train, y_train = y_train, +bart_model <- bart(X_train = X_train, y_train = y_train, num_gfr = 10, num_burnin = 0, num_mcmc = 10) bart_json_string_list <- list(saveBARTModelToJsonString(bart_model)) bart_model_roundtrip <- createBARTModelFromCombinedJsonString(bart_json_string_list) diff --git a/man/createBARTModelFromJson.Rd b/man/createBARTModelFromJson.Rd index 57686122..68a02f0e 100644 --- a/man/createBARTModelFromJson.Rd +++ b/man/createBARTModelFromJson.Rd @@ -22,9 +22,9 @@ n <- 100 p <- 5 X <- matrix(runif(n*p), ncol = p) f_XW <- ( - ((0 <= X[,1]) & (0.25 > X[,1])) * (-7.5) 
+ - ((0.25 <= X[,1]) & (0.5 > X[,1])) * (-2.5) + - ((0.5 <= X[,1]) & (0.75 > X[,1])) * (2.5) + + ((0 <= X[,1]) & (0.25 > X[,1])) * (-7.5) + + ((0.25 <= X[,1]) & (0.5 > X[,1])) * (-2.5) + + ((0.5 <= X[,1]) & (0.75 > X[,1])) * (2.5) + ((0.75 <= X[,1]) & (1 > X[,1])) * (7.5) ) noise_sd <- 1 @@ -38,7 +38,7 @@ X_test <- X[test_inds,] X_train <- X[train_inds,] y_test <- y[test_inds] y_train <- y[train_inds] -bart_model <- bart(X_train = X_train, y_train = y_train, +bart_model <- bart(X_train = X_train, y_train = y_train, num_gfr = 10, num_burnin = 0, num_mcmc = 10) bart_json <- saveBARTModelToJson(bart_model) bart_model_roundtrip <- createBARTModelFromJson(bart_json) diff --git a/man/createBARTModelFromJsonFile.Rd b/man/createBARTModelFromJsonFile.Rd index f714a94a..7608d8d2 100644 --- a/man/createBARTModelFromJsonFile.Rd +++ b/man/createBARTModelFromJsonFile.Rd @@ -22,9 +22,9 @@ n <- 100 p <- 5 X <- matrix(runif(n*p), ncol = p) f_XW <- ( - ((0 <= X[,1]) & (0.25 > X[,1])) * (-7.5) + - ((0.25 <= X[,1]) & (0.5 > X[,1])) * (-2.5) + - ((0.5 <= X[,1]) & (0.75 > X[,1])) * (2.5) + + ((0 <= X[,1]) & (0.25 > X[,1])) * (-7.5) + + ((0.25 <= X[,1]) & (0.5 > X[,1])) * (-2.5) + + ((0.5 <= X[,1]) & (0.75 > X[,1])) * (2.5) + ((0.75 <= X[,1]) & (1 > X[,1])) * (7.5) ) noise_sd <- 1 @@ -38,7 +38,7 @@ X_test <- X[test_inds,] X_train <- X[train_inds,] y_test <- y[test_inds] y_train <- y[train_inds] -bart_model <- bart(X_train = X_train, y_train = y_train, +bart_model <- bart(X_train = X_train, y_train = y_train, num_gfr = 10, num_burnin = 0, num_mcmc = 10) tmpjson <- tempfile(fileext = ".json") saveBARTModelToJsonFile(bart_model, file.path(tmpjson)) diff --git a/man/createBARTModelFromJsonString.Rd b/man/createBARTModelFromJsonString.Rd index 67068fd0..0748d97a 100644 --- a/man/createBARTModelFromJsonString.Rd +++ b/man/createBARTModelFromJsonString.Rd @@ -22,9 +22,9 @@ n <- 100 p <- 5 X <- matrix(runif(n*p), ncol = p) f_XW <- ( - ((0 <= X[,1]) & (0.25 > X[,1])) * (-7.5) + - ((0.25 <= X[,1]) & (0.5 > X[,1])) * (-2.5) + - ((0.5 <= X[,1]) & (0.75 > X[,1])) * (2.5) + + ((0 <= X[,1]) & (0.25 > X[,1])) * (-7.5) + + ((0.25 <= X[,1]) & (0.5 > X[,1])) * (-2.5) + + ((0.5 <= X[,1]) & (0.75 > X[,1])) * (2.5) + ((0.75 <= X[,1]) & (1 > X[,1])) * (7.5) ) noise_sd <- 1 @@ -38,7 +38,7 @@ X_test <- X[test_inds,] X_train <- X[train_inds,] y_test <- y[test_inds] y_train <- y[train_inds] -bart_model <- bart(X_train = X_train, y_train = y_train, +bart_model <- bart(X_train = X_train, y_train = y_train, num_gfr = 10, num_burnin = 0, num_mcmc = 10) bart_json <- saveBARTModelToJsonString(bart_model) bart_model_roundtrip <- createBARTModelFromJsonString(bart_json) diff --git a/man/createBCFModelFromCombinedJson.Rd b/man/createBCFModelFromCombinedJson.Rd index 6f29569e..24c82e4f 100644 --- a/man/createBCFModelFromCombinedJson.Rd +++ b/man/createBCFModelFromCombinedJson.Rd @@ -22,21 +22,21 @@ n <- 500 p <- 5 X <- matrix(runif(n*p), ncol = p) mu_x <- ( - ((0 <= X[,1]) & (0.25 > X[,1])) * (-7.5) + - ((0.25 <= X[,1]) & (0.5 > X[,1])) * (-2.5) + - ((0.5 <= X[,1]) & (0.75 > X[,1])) * (2.5) + + ((0 <= X[,1]) & (0.25 > X[,1])) * (-7.5) + + ((0.25 <= X[,1]) & (0.5 > X[,1])) * (-2.5) + + ((0.5 <= X[,1]) & (0.75 > X[,1])) * (2.5) + ((0.75 <= X[,1]) & (1 > X[,1])) * (7.5) ) pi_x <- ( - ((0 <= X[,1]) & (0.25 > X[,1])) * (0.2) + - ((0.25 <= X[,1]) & (0.5 > X[,1])) * (0.4) + - ((0.5 <= X[,1]) & (0.75 > X[,1])) * (0.6) + + ((0 <= X[,1]) & (0.25 > X[,1])) * (0.2) + + ((0.25 <= X[,1]) & (0.5 > X[,1])) * (0.4) + + ((0.5 <= X[,1]) & (0.75 > X[,1])) * (0.6) + 
((0.75 <= X[,1]) & (1 > X[,1])) * (0.8) ) tau_x <- ( - ((0 <= X[,2]) & (0.25 > X[,2])) * (0.5) + - ((0.25 <= X[,2]) & (0.5 > X[,2])) * (1.0) + - ((0.5 <= X[,2]) & (0.75 > X[,2])) * (1.5) + + ((0 <= X[,2]) & (0.25 > X[,2])) * (0.5) + + ((0.25 <= X[,2]) & (0.5 > X[,2])) * (1.0) + + ((0.5 <= X[,2]) & (0.75 > X[,2])) * (1.5) + ((0.75 <= X[,2]) & (1 > X[,2])) * (2.0) ) Z <- rbinom(n, 1, pi_x) @@ -70,13 +70,13 @@ rfx_basis_test <- rfx_basis[test_inds,] rfx_basis_train <- rfx_basis[train_inds,] rfx_term_test <- rfx_term[test_inds] rfx_term_train <- rfx_term[train_inds] -bcf_model <- bcf(X_train = X_train, Z_train = Z_train, y_train = y_train, - propensity_train = pi_train, - rfx_group_ids_train = rfx_group_ids_train, - rfx_basis_train = rfx_basis_train, X_test = X_test, - Z_test = Z_test, propensity_test = pi_test, +bcf_model <- bcf(X_train = X_train, Z_train = Z_train, y_train = y_train, + propensity_train = pi_train, + rfx_group_ids_train = rfx_group_ids_train, + rfx_basis_train = rfx_basis_train, X_test = X_test, + Z_test = Z_test, propensity_test = pi_test, rfx_group_ids_test = rfx_group_ids_test, - rfx_basis_test = rfx_basis_test, + rfx_basis_test = rfx_basis_test, num_gfr = 10, num_burnin = 0, num_mcmc = 10) bcf_json_list <- list(saveBCFModelToJson(bcf_model)) bcf_model_roundtrip <- createBCFModelFromCombinedJson(bcf_json_list) diff --git a/man/createBCFModelFromCombinedJsonString.Rd b/man/createBCFModelFromCombinedJsonString.Rd index bd7e63f2..e0522f75 100644 --- a/man/createBCFModelFromCombinedJsonString.Rd +++ b/man/createBCFModelFromCombinedJsonString.Rd @@ -22,21 +22,21 @@ n <- 500 p <- 5 X <- matrix(runif(n*p), ncol = p) mu_x <- ( - ((0 <= X[,1]) & (0.25 > X[,1])) * (-7.5) + - ((0.25 <= X[,1]) & (0.5 > X[,1])) * (-2.5) + - ((0.5 <= X[,1]) & (0.75 > X[,1])) * (2.5) + + ((0 <= X[,1]) & (0.25 > X[,1])) * (-7.5) + + ((0.25 <= X[,1]) & (0.5 > X[,1])) * (-2.5) + + ((0.5 <= X[,1]) & (0.75 > X[,1])) * (2.5) + ((0.75 <= X[,1]) & (1 > X[,1])) * (7.5) ) pi_x <- ( - ((0 <= X[,1]) & (0.25 > X[,1])) * (0.2) + - ((0.25 <= X[,1]) & (0.5 > X[,1])) * (0.4) + - ((0.5 <= X[,1]) & (0.75 > X[,1])) * (0.6) + + ((0 <= X[,1]) & (0.25 > X[,1])) * (0.2) + + ((0.25 <= X[,1]) & (0.5 > X[,1])) * (0.4) + + ((0.5 <= X[,1]) & (0.75 > X[,1])) * (0.6) + ((0.75 <= X[,1]) & (1 > X[,1])) * (0.8) ) tau_x <- ( - ((0 <= X[,2]) & (0.25 > X[,2])) * (0.5) + - ((0.25 <= X[,2]) & (0.5 > X[,2])) * (1.0) + - ((0.5 <= X[,2]) & (0.75 > X[,2])) * (1.5) + + ((0 <= X[,2]) & (0.25 > X[,2])) * (0.5) + + ((0.25 <= X[,2]) & (0.5 > X[,2])) * (1.0) + + ((0.5 <= X[,2]) & (0.75 > X[,2])) * (1.5) + ((0.75 <= X[,2]) & (1 > X[,2])) * (2.0) ) Z <- rbinom(n, 1, pi_x) @@ -70,13 +70,13 @@ rfx_basis_test <- rfx_basis[test_inds,] rfx_basis_train <- rfx_basis[train_inds,] rfx_term_test <- rfx_term[test_inds] rfx_term_train <- rfx_term[train_inds] -bcf_model <- bcf(X_train = X_train, Z_train = Z_train, y_train = y_train, - propensity_train = pi_train, - rfx_group_ids_train = rfx_group_ids_train, - rfx_basis_train = rfx_basis_train, X_test = X_test, - Z_test = Z_test, propensity_test = pi_test, +bcf_model <- bcf(X_train = X_train, Z_train = Z_train, y_train = y_train, + propensity_train = pi_train, + rfx_group_ids_train = rfx_group_ids_train, + rfx_basis_train = rfx_basis_train, X_test = X_test, + Z_test = Z_test, propensity_test = pi_test, rfx_group_ids_test = rfx_group_ids_test, - rfx_basis_test = rfx_basis_test, + rfx_basis_test = rfx_basis_test, num_gfr = 10, num_burnin = 0, num_mcmc = 10) bcf_json_string_list <- 
list(saveBCFModelToJsonString(bcf_model)) bcf_model_roundtrip <- createBCFModelFromCombinedJsonString(bcf_json_string_list) diff --git a/man/createBCFModelFromJson.Rd b/man/createBCFModelFromJson.Rd index a579b140..35cff7ce 100644 --- a/man/createBCFModelFromJson.Rd +++ b/man/createBCFModelFromJson.Rd @@ -22,21 +22,21 @@ n <- 500 p <- 5 X <- matrix(runif(n*p), ncol = p) mu_x <- ( - ((0 <= X[,1]) & (0.25 > X[,1])) * (-7.5) + - ((0.25 <= X[,1]) & (0.5 > X[,1])) * (-2.5) + - ((0.5 <= X[,1]) & (0.75 > X[,1])) * (2.5) + + ((0 <= X[,1]) & (0.25 > X[,1])) * (-7.5) + + ((0.25 <= X[,1]) & (0.5 > X[,1])) * (-2.5) + + ((0.5 <= X[,1]) & (0.75 > X[,1])) * (2.5) + ((0.75 <= X[,1]) & (1 > X[,1])) * (7.5) ) pi_x <- ( - ((0 <= X[,1]) & (0.25 > X[,1])) * (0.2) + - ((0.25 <= X[,1]) & (0.5 > X[,1])) * (0.4) + - ((0.5 <= X[,1]) & (0.75 > X[,1])) * (0.6) + + ((0 <= X[,1]) & (0.25 > X[,1])) * (0.2) + + ((0.25 <= X[,1]) & (0.5 > X[,1])) * (0.4) + + ((0.5 <= X[,1]) & (0.75 > X[,1])) * (0.6) + ((0.75 <= X[,1]) & (1 > X[,1])) * (0.8) ) tau_x <- ( - ((0 <= X[,2]) & (0.25 > X[,2])) * (0.5) + - ((0.25 <= X[,2]) & (0.5 > X[,2])) * (1.0) + - ((0.5 <= X[,2]) & (0.75 > X[,2])) * (1.5) + + ((0 <= X[,2]) & (0.25 > X[,2])) * (0.5) + + ((0.25 <= X[,2]) & (0.5 > X[,2])) * (1.0) + + ((0.5 <= X[,2]) & (0.75 > X[,2])) * (1.5) + ((0.75 <= X[,2]) & (1 > X[,2])) * (2.0) ) Z <- rbinom(n, 1, pi_x) @@ -72,15 +72,15 @@ rfx_term_test <- rfx_term[test_inds] rfx_term_train <- rfx_term[train_inds] mu_params <- list(sample_sigma2_leaf = TRUE) tau_params <- list(sample_sigma2_leaf = FALSE) -bcf_model <- bcf(X_train = X_train, Z_train = Z_train, y_train = y_train, - propensity_train = pi_train, - rfx_group_ids_train = rfx_group_ids_train, - rfx_basis_train = rfx_basis_train, X_test = X_test, - Z_test = Z_test, propensity_test = pi_test, +bcf_model <- bcf(X_train = X_train, Z_train = Z_train, y_train = y_train, + propensity_train = pi_train, + rfx_group_ids_train = rfx_group_ids_train, + rfx_basis_train = rfx_basis_train, X_test = X_test, + Z_test = Z_test, propensity_test = pi_test, rfx_group_ids_test = rfx_group_ids_test, - rfx_basis_test = rfx_basis_test, - num_gfr = 10, num_burnin = 0, num_mcmc = 10, - prognostic_forest_params = mu_params, + rfx_basis_test = rfx_basis_test, + num_gfr = 10, num_burnin = 0, num_mcmc = 10, + prognostic_forest_params = mu_params, treatment_effect_forest_params = tau_params) bcf_json <- saveBCFModelToJson(bcf_model) bcf_model_roundtrip <- createBCFModelFromJson(bcf_json) diff --git a/man/createBCFModelFromJsonFile.Rd b/man/createBCFModelFromJsonFile.Rd index 2661d4de..a2496797 100644 --- a/man/createBCFModelFromJsonFile.Rd +++ b/man/createBCFModelFromJsonFile.Rd @@ -22,21 +22,21 @@ n <- 500 p <- 5 X <- matrix(runif(n*p), ncol = p) mu_x <- ( - ((0 <= X[,1]) & (0.25 > X[,1])) * (-7.5) + - ((0.25 <= X[,1]) & (0.5 > X[,1])) * (-2.5) + - ((0.5 <= X[,1]) & (0.75 > X[,1])) * (2.5) + + ((0 <= X[,1]) & (0.25 > X[,1])) * (-7.5) + + ((0.25 <= X[,1]) & (0.5 > X[,1])) * (-2.5) + + ((0.5 <= X[,1]) & (0.75 > X[,1])) * (2.5) + ((0.75 <= X[,1]) & (1 > X[,1])) * (7.5) ) pi_x <- ( - ((0 <= X[,1]) & (0.25 > X[,1])) * (0.2) + - ((0.25 <= X[,1]) & (0.5 > X[,1])) * (0.4) + - ((0.5 <= X[,1]) & (0.75 > X[,1])) * (0.6) + + ((0 <= X[,1]) & (0.25 > X[,1])) * (0.2) + + ((0.25 <= X[,1]) & (0.5 > X[,1])) * (0.4) + + ((0.5 <= X[,1]) & (0.75 > X[,1])) * (0.6) + ((0.75 <= X[,1]) & (1 > X[,1])) * (0.8) ) tau_x <- ( - ((0 <= X[,2]) & (0.25 > X[,2])) * (0.5) + - ((0.25 <= X[,2]) & (0.5 > X[,2])) * (1.0) + - ((0.5 <= X[,2]) & (0.75 > X[,2])) * (1.5) 
+ + ((0 <= X[,2]) & (0.25 > X[,2])) * (0.5) + + ((0.25 <= X[,2]) & (0.5 > X[,2])) * (1.0) + + ((0.5 <= X[,2]) & (0.75 > X[,2])) * (1.5) + ((0.75 <= X[,2]) & (1 > X[,2])) * (2.0) ) Z <- rbinom(n, 1, pi_x) @@ -72,15 +72,15 @@ rfx_term_test <- rfx_term[test_inds] rfx_term_train <- rfx_term[train_inds] mu_params <- list(sample_sigma2_leaf = TRUE) tau_params <- list(sample_sigma2_leaf = FALSE) -bcf_model <- bcf(X_train = X_train, Z_train = Z_train, y_train = y_train, - propensity_train = pi_train, - rfx_group_ids_train = rfx_group_ids_train, - rfx_basis_train = rfx_basis_train, X_test = X_test, - Z_test = Z_test, propensity_test = pi_test, +bcf_model <- bcf(X_train = X_train, Z_train = Z_train, y_train = y_train, + propensity_train = pi_train, + rfx_group_ids_train = rfx_group_ids_train, + rfx_basis_train = rfx_basis_train, X_test = X_test, + Z_test = Z_test, propensity_test = pi_test, rfx_group_ids_test = rfx_group_ids_test, - rfx_basis_test = rfx_basis_test, - num_gfr = 10, num_burnin = 0, num_mcmc = 10, - prognostic_forest_params = mu_params, + rfx_basis_test = rfx_basis_test, + num_gfr = 10, num_burnin = 0, num_mcmc = 10, + prognostic_forest_params = mu_params, treatment_effect_forest_params = tau_params) tmpjson <- tempfile(fileext = ".json") saveBCFModelToJsonFile(bcf_model, file.path(tmpjson)) diff --git a/man/createBCFModelFromJsonString.Rd b/man/createBCFModelFromJsonString.Rd index 5f34724c..cc944f85 100644 --- a/man/createBCFModelFromJsonString.Rd +++ b/man/createBCFModelFromJsonString.Rd @@ -22,21 +22,21 @@ n <- 500 p <- 5 X <- matrix(runif(n*p), ncol = p) mu_x <- ( - ((0 <= X[,1]) & (0.25 > X[,1])) * (-7.5) + - ((0.25 <= X[,1]) & (0.5 > X[,1])) * (-2.5) + - ((0.5 <= X[,1]) & (0.75 > X[,1])) * (2.5) + + ((0 <= X[,1]) & (0.25 > X[,1])) * (-7.5) + + ((0.25 <= X[,1]) & (0.5 > X[,1])) * (-2.5) + + ((0.5 <= X[,1]) & (0.75 > X[,1])) * (2.5) + ((0.75 <= X[,1]) & (1 > X[,1])) * (7.5) ) pi_x <- ( - ((0 <= X[,1]) & (0.25 > X[,1])) * (0.2) + - ((0.25 <= X[,1]) & (0.5 > X[,1])) * (0.4) + - ((0.5 <= X[,1]) & (0.75 > X[,1])) * (0.6) + + ((0 <= X[,1]) & (0.25 > X[,1])) * (0.2) + + ((0.25 <= X[,1]) & (0.5 > X[,1])) * (0.4) + + ((0.5 <= X[,1]) & (0.75 > X[,1])) * (0.6) + ((0.75 <= X[,1]) & (1 > X[,1])) * (0.8) ) tau_x <- ( - ((0 <= X[,2]) & (0.25 > X[,2])) * (0.5) + - ((0.25 <= X[,2]) & (0.5 > X[,2])) * (1.0) + - ((0.5 <= X[,2]) & (0.75 > X[,2])) * (1.5) + + ((0 <= X[,2]) & (0.25 > X[,2])) * (0.5) + + ((0.25 <= X[,2]) & (0.5 > X[,2])) * (1.0) + + ((0.5 <= X[,2]) & (0.75 > X[,2])) * (1.5) + ((0.75 <= X[,2]) & (1 > X[,2])) * (2.0) ) Z <- rbinom(n, 1, pi_x) @@ -70,13 +70,13 @@ rfx_basis_test <- rfx_basis[test_inds,] rfx_basis_train <- rfx_basis[train_inds,] rfx_term_test <- rfx_term[test_inds] rfx_term_train <- rfx_term[train_inds] -bcf_model <- bcf(X_train = X_train, Z_train = Z_train, y_train = y_train, - propensity_train = pi_train, - rfx_group_ids_train = rfx_group_ids_train, - rfx_basis_train = rfx_basis_train, X_test = X_test, - Z_test = Z_test, propensity_test = pi_test, +bcf_model <- bcf(X_train = X_train, Z_train = Z_train, y_train = y_train, + propensity_train = pi_train, + rfx_group_ids_train = rfx_group_ids_train, + rfx_basis_train = rfx_basis_train, X_test = X_test, + Z_test = Z_test, propensity_test = pi_test, rfx_group_ids_test = rfx_group_ids_test, - rfx_basis_test = rfx_basis_test, + rfx_basis_test = rfx_basis_test, num_gfr = 10, num_burnin = 0, num_mcmc = 10) bcf_json <- saveBCFModelToJsonString(bcf_model) bcf_model_roundtrip <- createBCFModelFromJsonString(bcf_json) diff --git 
a/man/createForestModel.Rd b/man/createForestModel.Rd index d9000925..d7a1adae 100644 --- a/man/createForestModel.Rd +++ b/man/createForestModel.Rd @@ -30,10 +30,10 @@ max_depth <- 10 feature_types <- as.integer(rep(0, p)) X <- matrix(runif(n*p), ncol = p) forest_dataset <- createForestDataset(X) -forest_model_config <- createForestModelConfig(feature_types=feature_types, - num_trees=num_trees, num_features=p, - num_observations=n, alpha=alpha, beta=beta, - min_samples_leaf=min_samples_leaf, +forest_model_config <- createForestModelConfig(feature_types=feature_types, + num_trees=num_trees, num_features=p, + num_observations=n, alpha=alpha, beta=beta, + min_samples_leaf=min_samples_leaf, max_depth=max_depth, leaf_model_type=1) global_model_config <- createGlobalModelConfig(global_error_variance=1.0) forest_model <- createForestModel(forest_dataset, forest_model_config, global_model_config) diff --git a/man/getRandomEffectSamples.bartmodel.Rd b/man/getRandomEffectSamples.bartmodel.Rd index 0da1eb98..149586a8 100644 --- a/man/getRandomEffectSamples.bartmodel.Rd +++ b/man/getRandomEffectSamples.bartmodel.Rd @@ -24,9 +24,9 @@ n <- 100 p <- 5 X <- matrix(runif(n*p), ncol = p) f_XW <- ( - ((0 <= X[,1]) & (0.25 > X[,1])) * (-7.5) + - ((0.25 <= X[,1]) & (0.5 > X[,1])) * (-2.5) + - ((0.5 <= X[,1]) & (0.75 > X[,1])) * (2.5) + + ((0 <= X[,1]) & (0.25 > X[,1])) * (-7.5) + + ((0.25 <= X[,1]) & (0.5 > X[,1])) * (-2.5) + + ((0.5 <= X[,1]) & (0.75 > X[,1])) * (2.5) + ((0.75 <= X[,1]) & (1 > X[,1])) * (7.5) ) snr <- 3 @@ -51,11 +51,11 @@ rfx_basis_test <- rfx_basis[test_inds,] rfx_basis_train <- rfx_basis[train_inds,] rfx_term_test <- rfx_term[test_inds] rfx_term_train <- rfx_term[train_inds] -bart_model <- bart(X_train = X_train, y_train = y_train, X_test = X_test, - rfx_group_ids_train = rfx_group_ids_train, - rfx_group_ids_test = rfx_group_ids_test, - rfx_basis_train = rfx_basis_train, - rfx_basis_test = rfx_basis_test, +bart_model <- bart(X_train = X_train, y_train = y_train, X_test = X_test, + rfx_group_ids_train = rfx_group_ids_train, + rfx_group_ids_test = rfx_group_ids_test, + rfx_basis_train = rfx_basis_train, + rfx_basis_test = rfx_basis_test, num_gfr = 10, num_burnin = 0, num_mcmc = 10) rfx_samples <- getRandomEffectSamples(bart_model) } diff --git a/man/getRandomEffectSamples.bcfmodel.Rd b/man/getRandomEffectSamples.bcfmodel.Rd index 6769de62..08a8eae4 100644 --- a/man/getRandomEffectSamples.bcfmodel.Rd +++ b/man/getRandomEffectSamples.bcfmodel.Rd @@ -24,21 +24,21 @@ n <- 500 p <- 5 X <- matrix(runif(n*p), ncol = p) mu_x <- ( - ((0 <= X[,1]) & (0.25 > X[,1])) * (-7.5) + - ((0.25 <= X[,1]) & (0.5 > X[,1])) * (-2.5) + - ((0.5 <= X[,1]) & (0.75 > X[,1])) * (2.5) + + ((0 <= X[,1]) & (0.25 > X[,1])) * (-7.5) + + ((0.25 <= X[,1]) & (0.5 > X[,1])) * (-2.5) + + ((0.5 <= X[,1]) & (0.75 > X[,1])) * (2.5) + ((0.75 <= X[,1]) & (1 > X[,1])) * (7.5) ) pi_x <- ( - ((0 <= X[,1]) & (0.25 > X[,1])) * (0.2) + - ((0.25 <= X[,1]) & (0.5 > X[,1])) * (0.4) + - ((0.5 <= X[,1]) & (0.75 > X[,1])) * (0.6) + + ((0 <= X[,1]) & (0.25 > X[,1])) * (0.2) + + ((0.25 <= X[,1]) & (0.5 > X[,1])) * (0.4) + + ((0.5 <= X[,1]) & (0.75 > X[,1])) * (0.6) + ((0.75 <= X[,1]) & (1 > X[,1])) * (0.8) ) tau_x <- ( - ((0 <= X[,2]) & (0.25 > X[,2])) * (0.5) + - ((0.25 <= X[,2]) & (0.5 > X[,2])) * (1.0) + - ((0.5 <= X[,2]) & (0.75 > X[,2])) * (1.5) + + ((0 <= X[,2]) & (0.25 > X[,2])) * (0.5) + + ((0.25 <= X[,2]) & (0.5 > X[,2])) * (1.0) + + ((0.5 <= X[,2]) & (0.75 > X[,2])) * (1.5) + ((0.75 <= X[,2]) & (1 > X[,2])) * (2.0) ) Z <- rbinom(n, 1, 
pi_x) @@ -74,15 +74,15 @@ rfx_term_test <- rfx_term[test_inds] rfx_term_train <- rfx_term[train_inds] mu_params <- list(sample_sigma2_leaf = TRUE) tau_params <- list(sample_sigma2_leaf = FALSE) -bcf_model <- bcf(X_train = X_train, Z_train = Z_train, y_train = y_train, - propensity_train = pi_train, - rfx_group_ids_train = rfx_group_ids_train, - rfx_basis_train = rfx_basis_train, X_test = X_test, - Z_test = Z_test, propensity_test = pi_test, +bcf_model <- bcf(X_train = X_train, Z_train = Z_train, y_train = y_train, + propensity_train = pi_train, + rfx_group_ids_train = rfx_group_ids_train, + rfx_basis_train = rfx_basis_train, X_test = X_test, + Z_test = Z_test, propensity_test = pi_test, rfx_group_ids_test = rfx_group_ids_test, - rfx_basis_test = rfx_basis_test, - num_gfr = 10, num_burnin = 0, num_mcmc = 10, - prognostic_forest_params = mu_params, + rfx_basis_test = rfx_basis_test, + num_gfr = 10, num_burnin = 0, num_mcmc = 10, + prognostic_forest_params = mu_params, treatment_effect_forest_params = tau_params) rfx_samples <- getRandomEffectSamples(bcf_model) } diff --git a/man/predict.bartmodel.Rd b/man/predict.bartmodel.Rd index 2afccbf6..8a0a47bf 100644 --- a/man/predict.bartmodel.Rd +++ b/man/predict.bartmodel.Rd @@ -40,9 +40,9 @@ n <- 100 p <- 5 X <- matrix(runif(n*p), ncol = p) f_XW <- ( - ((0 <= X[,1]) & (0.25 > X[,1])) * (-7.5) + - ((0.25 <= X[,1]) & (0.5 > X[,1])) * (-2.5) + - ((0.5 <= X[,1]) & (0.75 > X[,1])) * (2.5) + + ((0 <= X[,1]) & (0.25 > X[,1])) * (-7.5) + + ((0.25 <= X[,1]) & (0.5 > X[,1])) * (-2.5) + + ((0.5 <= X[,1]) & (0.75 > X[,1])) * (2.5) + ((0.75 <= X[,1]) & (1 > X[,1])) * (7.5) ) noise_sd <- 1 @@ -56,7 +56,7 @@ X_test <- X[test_inds,] X_train <- X[train_inds,] y_test <- y[test_inds] y_train <- y[train_inds] -bart_model <- bart(X_train = X_train, y_train = y_train, +bart_model <- bart(X_train = X_train, y_train = y_train, num_gfr = 10, num_burnin = 0, num_mcmc = 10) y_hat_test <- predict(bart_model, X_test)$y_hat } diff --git a/man/predict.bcfmodel.Rd b/man/predict.bcfmodel.Rd index ff315808..907e5308 100644 --- a/man/predict.bcfmodel.Rd +++ b/man/predict.bcfmodel.Rd @@ -42,21 +42,21 @@ n <- 500 p <- 5 X <- matrix(runif(n*p), ncol = p) mu_x <- ( - ((0 <= X[,1]) & (0.25 > X[,1])) * (-7.5) + - ((0.25 <= X[,1]) & (0.5 > X[,1])) * (-2.5) + - ((0.5 <= X[,1]) & (0.75 > X[,1])) * (2.5) + + ((0 <= X[,1]) & (0.25 > X[,1])) * (-7.5) + + ((0.25 <= X[,1]) & (0.5 > X[,1])) * (-2.5) + + ((0.5 <= X[,1]) & (0.75 > X[,1])) * (2.5) + ((0.75 <= X[,1]) & (1 > X[,1])) * (7.5) ) pi_x <- ( - ((0 <= X[,1]) & (0.25 > X[,1])) * (0.2) + - ((0.25 <= X[,1]) & (0.5 > X[,1])) * (0.4) + - ((0.5 <= X[,1]) & (0.75 > X[,1])) * (0.6) + + ((0 <= X[,1]) & (0.25 > X[,1])) * (0.2) + + ((0.25 <= X[,1]) & (0.5 > X[,1])) * (0.4) + + ((0.5 <= X[,1]) & (0.75 > X[,1])) * (0.6) + ((0.75 <= X[,1]) & (1 > X[,1])) * (0.8) ) tau_x <- ( - ((0 <= X[,2]) & (0.25 > X[,2])) * (0.5) + - ((0.25 <= X[,2]) & (0.5 > X[,2])) * (1.0) + - ((0.5 <= X[,2]) & (0.75 > X[,2])) * (1.5) + + ((0 <= X[,2]) & (0.25 > X[,2])) * (0.5) + + ((0.25 <= X[,2]) & (0.5 > X[,2])) * (1.0) + + ((0.5 <= X[,2]) & (0.75 > X[,2])) * (1.5) + ((0.75 <= X[,2]) & (1 > X[,2])) * (2.0) ) Z <- rbinom(n, 1, pi_x) @@ -79,8 +79,8 @@ mu_test <- mu_x[test_inds] mu_train <- mu_x[train_inds] tau_test <- tau_x[test_inds] tau_train <- tau_x[train_inds] -bcf_model <- bcf(X_train = X_train, Z_train = Z_train, y_train = y_train, - propensity_train = pi_train, num_gfr = 10, +bcf_model <- bcf(X_train = X_train, Z_train = Z_train, y_train = y_train, + propensity_train = 
pi_train, num_gfr = 10, num_burnin = 0, num_mcmc = 10) preds <- predict(bcf_model, X_test, Z_test, pi_test) } diff --git a/man/preprocessPredictionData.Rd b/man/preprocessPredictionData.Rd index f881fda8..a6382e69 100644 --- a/man/preprocessPredictionData.Rd +++ b/man/preprocessPredictionData.Rd @@ -22,7 +22,7 @@ types. Matrices will be passed through assuming all columns are numeric. } \examples{ cov_df <- data.frame(x1 = 1:5, x2 = 5:1, x3 = 6:10) -metadata <- list(num_ordered_cat_vars = 0, num_unordered_cat_vars = 0, +metadata <- list(num_ordered_cat_vars = 0, num_unordered_cat_vars = 0, num_numeric_vars = 3, numeric_vars = c("x1", "x2", "x3")) X_preprocessed <- preprocessPredictionData(cov_df, metadata) } diff --git a/man/resetForestModel.Rd b/man/resetForestModel.Rd index f0fec6ca..b02158d4 100644 --- a/man/resetForestModel.Rd +++ b/man/resetForestModel.Rd @@ -48,23 +48,23 @@ y <- -5 + 10*(X[,1] > 0.5) + rnorm(n) outcome <- createOutcome(y) rng <- createCppRNG(1234) global_model_config <- createGlobalModelConfig(global_error_variance=sigma2) -forest_model_config <- createForestModelConfig(feature_types=feature_types, - num_trees=num_trees, num_observations=n, - num_features=p, alpha=alpha, beta=beta, - min_samples_leaf=min_samples_leaf, - max_depth=max_depth, - variable_weights=variable_weights, - cutpoint_grid_size=cutpoint_grid_size, - leaf_model_type=leaf_model, +forest_model_config <- createForestModelConfig(feature_types=feature_types, + num_trees=num_trees, num_observations=n, + num_features=p, alpha=alpha, beta=beta, + min_samples_leaf=min_samples_leaf, + max_depth=max_depth, + variable_weights=variable_weights, + cutpoint_grid_size=cutpoint_grid_size, + leaf_model_type=leaf_model, leaf_model_scale=leaf_scale) forest_model <- createForestModel(forest_dataset, forest_model_config, global_model_config) active_forest <- createForest(num_trees, leaf_dimension, is_leaf_constant, is_exponentiated) -forest_samples <- createForestSamples(num_trees, leaf_dimension, +forest_samples <- createForestSamples(num_trees, leaf_dimension, is_leaf_constant, is_exponentiated) active_forest$prepare_for_sampler(forest_dataset, outcome, forest_model, 0, 0.) 
forest_model$sample_one_iteration( - forest_dataset, outcome, forest_samples, active_forest, - rng, forest_model_config, global_model_config, + forest_dataset, outcome, forest_samples, active_forest, + rng, forest_model_config, global_model_config, keep_forest = TRUE, gfr = FALSE ) resetActiveForest(active_forest, forest_samples, 0) diff --git a/man/resetRandomEffectsModel.Rd b/man/resetRandomEffectsModel.Rd index fec99b77..b032ccc2 100644 --- a/man/resetRandomEffectsModel.Rd +++ b/man/resetRandomEffectsModel.Rd @@ -49,8 +49,8 @@ rfx_model$set_group_parameter_cov(sigma_xi_init) rfx_model$set_variance_prior_shape(sigma_xi_shape) rfx_model$set_variance_prior_scale(sigma_xi_scale) for (i in 1:3) { - rfx_model$sample_random_effect(rfx_dataset=rfx_dataset, residual=outcome, - rfx_tracker=rfx_tracker, rfx_samples=rfx_samples, + rfx_model$sample_random_effect(rfx_dataset=rfx_dataset, residual=outcome, + rfx_tracker=rfx_tracker, rfx_samples=rfx_samples, keep_sample=TRUE, global_variance=1.0, rng=rng) } resetRandomEffectsModel(rfx_model, rfx_samples, 0, 1.0) diff --git a/man/resetRandomEffectsTracker.Rd b/man/resetRandomEffectsTracker.Rd index 5249ca96..c57af16a 100644 --- a/man/resetRandomEffectsTracker.Rd +++ b/man/resetRandomEffectsTracker.Rd @@ -57,8 +57,8 @@ rfx_model$set_group_parameter_cov(sigma_xi_init) rfx_model$set_variance_prior_shape(sigma_xi_shape) rfx_model$set_variance_prior_scale(sigma_xi_scale) for (i in 1:3) { - rfx_model$sample_random_effect(rfx_dataset=rfx_dataset, residual=outcome, - rfx_tracker=rfx_tracker, rfx_samples=rfx_samples, + rfx_model$sample_random_effect(rfx_dataset=rfx_dataset, residual=outcome, + rfx_tracker=rfx_tracker, rfx_samples=rfx_samples, keep_sample=TRUE, global_variance=1.0, rng=rng) } resetRandomEffectsModel(rfx_model, rfx_samples, 0, 1.0) diff --git a/man/rootResetRandomEffectsModel.Rd b/man/rootResetRandomEffectsModel.Rd index c58a09e9..4c3cc2f7 100644 --- a/man/rootResetRandomEffectsModel.Rd +++ b/man/rootResetRandomEffectsModel.Rd @@ -63,8 +63,8 @@ rfx_model$set_group_parameter_cov(sigma_xi_init) rfx_model$set_variance_prior_shape(sigma_xi_shape) rfx_model$set_variance_prior_scale(sigma_xi_scale) for (i in 1:3) { - rfx_model$sample_random_effect(rfx_dataset=rfx_dataset, residual=outcome, - rfx_tracker=rfx_tracker, rfx_samples=rfx_samples, + rfx_model$sample_random_effect(rfx_dataset=rfx_dataset, residual=outcome, + rfx_tracker=rfx_tracker, rfx_samples=rfx_samples, keep_sample=TRUE, global_variance=1.0, rng=rng) } rootResetRandomEffectsModel(rfx_model, alpha_init, xi_init, sigma_alpha_init, diff --git a/man/rootResetRandomEffectsTracker.Rd b/man/rootResetRandomEffectsTracker.Rd index 8de2c514..6f2dc843 100644 --- a/man/rootResetRandomEffectsTracker.Rd +++ b/man/rootResetRandomEffectsTracker.Rd @@ -49,8 +49,8 @@ rfx_model$set_group_parameter_cov(sigma_xi_init) rfx_model$set_variance_prior_shape(sigma_xi_shape) rfx_model$set_variance_prior_scale(sigma_xi_scale) for (i in 1:3) { - rfx_model$sample_random_effect(rfx_dataset=rfx_dataset, residual=outcome, - rfx_tracker=rfx_tracker, rfx_samples=rfx_samples, + rfx_model$sample_random_effect(rfx_dataset=rfx_dataset, residual=outcome, + rfx_tracker=rfx_tracker, rfx_samples=rfx_samples, keep_sample=TRUE, global_variance=1.0, rng=rng) } rootResetRandomEffectsModel(rfx_model, alpha_init, xi_init, sigma_alpha_init, diff --git a/man/saveBARTModelToJson.Rd b/man/saveBARTModelToJson.Rd index a617532e..054af24e 100644 --- a/man/saveBARTModelToJson.Rd +++ b/man/saveBARTModelToJson.Rd @@ -20,9 +20,9 @@ n <- 100 p <- 5 X 
<- matrix(runif(n*p), ncol = p) f_XW <- ( - ((0 <= X[,1]) & (0.25 > X[,1])) * (-7.5) + - ((0.25 <= X[,1]) & (0.5 > X[,1])) * (-2.5) + - ((0.5 <= X[,1]) & (0.75 > X[,1])) * (2.5) + + ((0 <= X[,1]) & (0.25 > X[,1])) * (-7.5) + + ((0.25 <= X[,1]) & (0.5 > X[,1])) * (-2.5) + + ((0.5 <= X[,1]) & (0.75 > X[,1])) * (2.5) + ((0.75 <= X[,1]) & (1 > X[,1])) * (7.5) ) noise_sd <- 1 @@ -36,7 +36,7 @@ X_test <- X[test_inds,] X_train <- X[train_inds,] y_test <- y[test_inds] y_train <- y[train_inds] -bart_model <- bart(X_train = X_train, y_train = y_train, +bart_model <- bart(X_train = X_train, y_train = y_train, num_gfr = 10, num_burnin = 0, num_mcmc = 10) bart_json <- saveBARTModelToJson(bart_model) } diff --git a/man/saveBARTModelToJsonFile.Rd b/man/saveBARTModelToJsonFile.Rd index 46a3110e..62ef6ad7 100644 --- a/man/saveBARTModelToJsonFile.Rd +++ b/man/saveBARTModelToJsonFile.Rd @@ -22,9 +22,9 @@ n <- 100 p <- 5 X <- matrix(runif(n*p), ncol = p) f_XW <- ( - ((0 <= X[,1]) & (0.25 > X[,1])) * (-7.5) + - ((0.25 <= X[,1]) & (0.5 > X[,1])) * (-2.5) + - ((0.5 <= X[,1]) & (0.75 > X[,1])) * (2.5) + + ((0 <= X[,1]) & (0.25 > X[,1])) * (-7.5) + + ((0.25 <= X[,1]) & (0.5 > X[,1])) * (-2.5) + + ((0.5 <= X[,1]) & (0.75 > X[,1])) * (2.5) + ((0.75 <= X[,1]) & (1 > X[,1])) * (7.5) ) noise_sd <- 1 @@ -38,7 +38,7 @@ X_test <- X[test_inds,] X_train <- X[train_inds,] y_test <- y[test_inds] y_train <- y[train_inds] -bart_model <- bart(X_train = X_train, y_train = y_train, +bart_model <- bart(X_train = X_train, y_train = y_train, num_gfr = 10, num_burnin = 0, num_mcmc = 10) tmpjson <- tempfile(fileext = ".json") saveBARTModelToJsonFile(bart_model, file.path(tmpjson)) diff --git a/man/saveBARTModelToJsonString.Rd b/man/saveBARTModelToJsonString.Rd index c83f9e5d..10927c20 100644 --- a/man/saveBARTModelToJsonString.Rd +++ b/man/saveBARTModelToJsonString.Rd @@ -20,9 +20,9 @@ n <- 100 p <- 5 X <- matrix(runif(n*p), ncol = p) f_XW <- ( - ((0 <= X[,1]) & (0.25 > X[,1])) * (-7.5) + - ((0.25 <= X[,1]) & (0.5 > X[,1])) * (-2.5) + - ((0.5 <= X[,1]) & (0.75 > X[,1])) * (2.5) + + ((0 <= X[,1]) & (0.25 > X[,1])) * (-7.5) + + ((0.25 <= X[,1]) & (0.5 > X[,1])) * (-2.5) + + ((0.5 <= X[,1]) & (0.75 > X[,1])) * (2.5) + ((0.75 <= X[,1]) & (1 > X[,1])) * (7.5) ) noise_sd <- 1 @@ -36,7 +36,7 @@ X_test <- X[test_inds,] X_train <- X[train_inds,] y_test <- y[test_inds] y_train <- y[train_inds] -bart_model <- bart(X_train = X_train, y_train = y_train, +bart_model <- bart(X_train = X_train, y_train = y_train, num_gfr = 10, num_burnin = 0, num_mcmc = 10) bart_json_string <- saveBARTModelToJsonString(bart_model) } diff --git a/man/saveBCFModelToJson.Rd b/man/saveBCFModelToJson.Rd index ae2c286d..2c04d76c 100644 --- a/man/saveBCFModelToJson.Rd +++ b/man/saveBCFModelToJson.Rd @@ -20,21 +20,21 @@ n <- 500 p <- 5 X <- matrix(runif(n*p), ncol = p) mu_x <- ( - ((0 <= X[,1]) & (0.25 > X[,1])) * (-7.5) + - ((0.25 <= X[,1]) & (0.5 > X[,1])) * (-2.5) + - ((0.5 <= X[,1]) & (0.75 > X[,1])) * (2.5) + + ((0 <= X[,1]) & (0.25 > X[,1])) * (-7.5) + + ((0.25 <= X[,1]) & (0.5 > X[,1])) * (-2.5) + + ((0.5 <= X[,1]) & (0.75 > X[,1])) * (2.5) + ((0.75 <= X[,1]) & (1 > X[,1])) * (7.5) ) pi_x <- ( - ((0 <= X[,1]) & (0.25 > X[,1])) * (0.2) + - ((0.25 <= X[,1]) & (0.5 > X[,1])) * (0.4) + - ((0.5 <= X[,1]) & (0.75 > X[,1])) * (0.6) + + ((0 <= X[,1]) & (0.25 > X[,1])) * (0.2) + + ((0.25 <= X[,1]) & (0.5 > X[,1])) * (0.4) + + ((0.5 <= X[,1]) & (0.75 > X[,1])) * (0.6) + ((0.75 <= X[,1]) & (1 > X[,1])) * (0.8) ) tau_x <- ( - ((0 <= X[,2]) & (0.25 > X[,2])) * (0.5) + - ((0.25 <= 
X[,2]) & (0.5 > X[,2])) * (1.0) + - ((0.5 <= X[,2]) & (0.75 > X[,2])) * (1.5) + + ((0 <= X[,2]) & (0.25 > X[,2])) * (0.5) + + ((0.25 <= X[,2]) & (0.5 > X[,2])) * (1.0) + + ((0.5 <= X[,2]) & (0.75 > X[,2])) * (1.5) + ((0.75 <= X[,2]) & (1 > X[,2])) * (2.0) ) Z <- rbinom(n, 1, pi_x) @@ -70,15 +70,15 @@ rfx_term_test <- rfx_term[test_inds] rfx_term_train <- rfx_term[train_inds] mu_params <- list(sample_sigma2_leaf = TRUE) tau_params <- list(sample_sigma2_leaf = FALSE) -bcf_model <- bcf(X_train = X_train, Z_train = Z_train, y_train = y_train, - propensity_train = pi_train, - rfx_group_ids_train = rfx_group_ids_train, - rfx_basis_train = rfx_basis_train, X_test = X_test, - Z_test = Z_test, propensity_test = pi_test, +bcf_model <- bcf(X_train = X_train, Z_train = Z_train, y_train = y_train, + propensity_train = pi_train, + rfx_group_ids_train = rfx_group_ids_train, + rfx_basis_train = rfx_basis_train, X_test = X_test, + Z_test = Z_test, propensity_test = pi_test, rfx_group_ids_test = rfx_group_ids_test, - rfx_basis_test = rfx_basis_test, - num_gfr = 10, num_burnin = 0, num_mcmc = 10, - prognostic_forest_params = mu_params, + rfx_basis_test = rfx_basis_test, + num_gfr = 10, num_burnin = 0, num_mcmc = 10, + prognostic_forest_params = mu_params, treatment_effect_forest_params = tau_params) bcf_json <- saveBCFModelToJson(bcf_model) } diff --git a/man/saveBCFModelToJsonFile.Rd b/man/saveBCFModelToJsonFile.Rd index e6a9f0aa..584bbbba 100644 --- a/man/saveBCFModelToJsonFile.Rd +++ b/man/saveBCFModelToJsonFile.Rd @@ -22,21 +22,21 @@ n <- 500 p <- 5 X <- matrix(runif(n*p), ncol = p) mu_x <- ( - ((0 <= X[,1]) & (0.25 > X[,1])) * (-7.5) + - ((0.25 <= X[,1]) & (0.5 > X[,1])) * (-2.5) + - ((0.5 <= X[,1]) & (0.75 > X[,1])) * (2.5) + + ((0 <= X[,1]) & (0.25 > X[,1])) * (-7.5) + + ((0.25 <= X[,1]) & (0.5 > X[,1])) * (-2.5) + + ((0.5 <= X[,1]) & (0.75 > X[,1])) * (2.5) + ((0.75 <= X[,1]) & (1 > X[,1])) * (7.5) ) pi_x <- ( - ((0 <= X[,1]) & (0.25 > X[,1])) * (0.2) + - ((0.25 <= X[,1]) & (0.5 > X[,1])) * (0.4) + - ((0.5 <= X[,1]) & (0.75 > X[,1])) * (0.6) + + ((0 <= X[,1]) & (0.25 > X[,1])) * (0.2) + + ((0.25 <= X[,1]) & (0.5 > X[,1])) * (0.4) + + ((0.5 <= X[,1]) & (0.75 > X[,1])) * (0.6) + ((0.75 <= X[,1]) & (1 > X[,1])) * (0.8) ) tau_x <- ( - ((0 <= X[,2]) & (0.25 > X[,2])) * (0.5) + - ((0.25 <= X[,2]) & (0.5 > X[,2])) * (1.0) + - ((0.5 <= X[,2]) & (0.75 > X[,2])) * (1.5) + + ((0 <= X[,2]) & (0.25 > X[,2])) * (0.5) + + ((0.25 <= X[,2]) & (0.5 > X[,2])) * (1.0) + + ((0.5 <= X[,2]) & (0.75 > X[,2])) * (1.5) + ((0.75 <= X[,2]) & (1 > X[,2])) * (2.0) ) Z <- rbinom(n, 1, pi_x) @@ -72,15 +72,15 @@ rfx_term_test <- rfx_term[test_inds] rfx_term_train <- rfx_term[train_inds] mu_params <- list(sample_sigma2_leaf = TRUE) tau_params <- list(sample_sigma2_leaf = FALSE) -bcf_model <- bcf(X_train = X_train, Z_train = Z_train, y_train = y_train, - propensity_train = pi_train, - rfx_group_ids_train = rfx_group_ids_train, - rfx_basis_train = rfx_basis_train, X_test = X_test, - Z_test = Z_test, propensity_test = pi_test, +bcf_model <- bcf(X_train = X_train, Z_train = Z_train, y_train = y_train, + propensity_train = pi_train, + rfx_group_ids_train = rfx_group_ids_train, + rfx_basis_train = rfx_basis_train, X_test = X_test, + Z_test = Z_test, propensity_test = pi_test, rfx_group_ids_test = rfx_group_ids_test, - rfx_basis_test = rfx_basis_test, - num_gfr = 10, num_burnin = 0, num_mcmc = 10, - prognostic_forest_params = mu_params, + rfx_basis_test = rfx_basis_test, + num_gfr = 10, num_burnin = 0, num_mcmc = 10, + 
prognostic_forest_params = mu_params, treatment_effect_forest_params = tau_params) tmpjson <- tempfile(fileext = ".json") saveBCFModelToJsonFile(bcf_model, file.path(tmpjson)) diff --git a/man/saveBCFModelToJsonString.Rd b/man/saveBCFModelToJsonString.Rd index 4328e525..2182bbe3 100644 --- a/man/saveBCFModelToJsonString.Rd +++ b/man/saveBCFModelToJsonString.Rd @@ -20,21 +20,21 @@ n <- 500 p <- 5 X <- matrix(runif(n*p), ncol = p) mu_x <- ( - ((0 <= X[,1]) & (0.25 > X[,1])) * (-7.5) + - ((0.25 <= X[,1]) & (0.5 > X[,1])) * (-2.5) + - ((0.5 <= X[,1]) & (0.75 > X[,1])) * (2.5) + + ((0 <= X[,1]) & (0.25 > X[,1])) * (-7.5) + + ((0.25 <= X[,1]) & (0.5 > X[,1])) * (-2.5) + + ((0.5 <= X[,1]) & (0.75 > X[,1])) * (2.5) + ((0.75 <= X[,1]) & (1 > X[,1])) * (7.5) ) pi_x <- ( - ((0 <= X[,1]) & (0.25 > X[,1])) * (0.2) + - ((0.25 <= X[,1]) & (0.5 > X[,1])) * (0.4) + - ((0.5 <= X[,1]) & (0.75 > X[,1])) * (0.6) + + ((0 <= X[,1]) & (0.25 > X[,1])) * (0.2) + + ((0.25 <= X[,1]) & (0.5 > X[,1])) * (0.4) + + ((0.5 <= X[,1]) & (0.75 > X[,1])) * (0.6) + ((0.75 <= X[,1]) & (1 > X[,1])) * (0.8) ) tau_x <- ( - ((0 <= X[,2]) & (0.25 > X[,2])) * (0.5) + - ((0.25 <= X[,2]) & (0.5 > X[,2])) * (1.0) + - ((0.5 <= X[,2]) & (0.75 > X[,2])) * (1.5) + + ((0 <= X[,2]) & (0.25 > X[,2])) * (0.5) + + ((0.25 <= X[,2]) & (0.5 > X[,2])) * (1.0) + + ((0.5 <= X[,2]) & (0.75 > X[,2])) * (1.5) + ((0.75 <= X[,2]) & (1 > X[,2])) * (2.0) ) Z <- rbinom(n, 1, pi_x) @@ -70,15 +70,15 @@ rfx_term_test <- rfx_term[test_inds] rfx_term_train <- rfx_term[train_inds] mu_params <- list(sample_sigma2_leaf = TRUE) tau_params <- list(sample_sigma2_leaf = FALSE) -bcf_model <- bcf(X_train = X_train, Z_train = Z_train, y_train = y_train, - propensity_train = pi_train, - rfx_group_ids_train = rfx_group_ids_train, - rfx_basis_train = rfx_basis_train, X_test = X_test, - Z_test = Z_test, propensity_test = pi_test, +bcf_model <- bcf(X_train = X_train, Z_train = Z_train, y_train = y_train, + propensity_train = pi_train, + rfx_group_ids_train = rfx_group_ids_train, + rfx_basis_train = rfx_basis_train, X_test = X_test, + Z_test = Z_test, propensity_test = pi_test, rfx_group_ids_test = rfx_group_ids_test, - rfx_basis_test = rfx_basis_test, - num_gfr = 10, num_burnin = 0, num_mcmc = 10, - prognostic_forest_params = mu_params, + rfx_basis_test = rfx_basis_test, + num_gfr = 10, num_burnin = 0, num_mcmc = 10, + prognostic_forest_params = mu_params, treatment_effect_forest_params = tau_params) saveBCFModelToJsonString(bcf_model) } diff --git a/src/Makevars.in b/src/Makevars.in index 4eb970cb..850e2555 100644 --- a/src/Makevars.in +++ b/src/Makevars.in @@ -34,6 +34,7 @@ OBJECTS = \ data.o \ io.o \ leaf_model.o \ + ordinal_sampler.o \ partition_tracker.o \ random_effects.o \ tree.o diff --git a/src/Makevars.win.in b/src/Makevars.win.in index 95bff1dd..e9d54ab6 100644 --- a/src/Makevars.win.in +++ b/src/Makevars.win.in @@ -34,6 +34,7 @@ OBJECTS = \ data.o \ io.o \ leaf_model.o \ + ordinal_sampler.o \ partition_tracker.o \ random_effects.o \ tree.o diff --git a/src/R_data.cpp b/src/R_data.cpp index 39b77ab3..caf3e9bc 100644 --- a/src/R_data.cpp +++ b/src/R_data.cpp @@ -148,6 +148,47 @@ cpp11::writable::doubles forest_dataset_get_variance_weights_cpp(cpp11::external return output; } +[[cpp11::register]] +bool forest_dataset_has_auxiliary_dimension_cpp(cpp11::external_pointer<StochTree::ForestDataset> dataset_ptr, int dim_idx) { + return dataset_ptr->HasAuxiliaryDimension(dim_idx); +} + +[[cpp11::register]] +void forest_dataset_add_auxiliary_dimension_cpp(cpp11::external_pointer<StochTree::ForestDataset> dataset_ptr, int dim_size) { + dataset_ptr->AddAuxiliaryDimension(dim_size); +} + +[[cpp11::register]] +double forest_dataset_get_auxiliary_data_value_cpp(cpp11::external_pointer<StochTree::ForestDataset> dataset_ptr, int dim_idx, int element_idx) { + return dataset_ptr->GetAuxiliaryDataValue(dim_idx, element_idx); +} + +[[cpp11::register]] +void forest_dataset_set_auxiliary_data_value_cpp(cpp11::external_pointer<StochTree::ForestDataset> dataset_ptr, int dim_idx, int element_idx, double value) { + dataset_ptr->SetAuxiliaryDataValue(dim_idx, element_idx, value); +} + +[[cpp11::register]] +cpp11::writable::doubles forest_dataset_get_auxiliary_data_vector_cpp(cpp11::external_pointer<StochTree::ForestDataset> dataset_ptr, int dim_idx) { + const std::vector<double> output_raw = dataset_ptr->GetAuxiliaryDataVector(dim_idx); + int n = output_raw.size(); + cpp11::writable::doubles output(n); + for (int i = 0; i < n; i++) { + output[i] = output_raw[i]; + } + return output; +} + +[[cpp11::register]] +cpp11::writable::doubles_matrix<> forest_dataset_store_auxiliary_data_vector_as_column_cpp(cpp11::external_pointer<StochTree::ForestDataset> dataset_ptr, cpp11::writable::doubles_matrix<> output_matrix, int dim_idx, int matrix_col_idx) { + const std::vector<double> output_raw = dataset_ptr->GetAuxiliaryDataVector(dim_idx); + int n = output_raw.size(); + for (int i = 0; i < n; i++) { + output_matrix(i, matrix_col_idx) = output_raw[i]; + } + return output_matrix; +} + [[cpp11::register]] cpp11::external_pointer<StochTree::ColumnVector> create_column_vector_cpp(cpp11::doubles outcome) { // Unpack pointers to data and dimensions
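Taken together, the six `forest_dataset_*_auxiliary_*` bindings above expose the new ForestDataset auxiliary-data slots to R, which is how the sampler's per-observation and per-cutpoint state is shared with R-level code. A minimal round trip through the API might look like the sketch below (not part of the diff; it assumes auxiliary slots are indexed from 0 on the C++ side and that the R6 dataset object exposes its external pointer as `data_ptr`, as the internal calls elsewhere in the package do):

library(stochtree)

X <- matrix(runif(100 * 5), ncol = 5)
ds <- createForestDataset(X)

# Allocate one auxiliary slot with as many elements as rows of X
stochtree:::forest_dataset_add_auxiliary_dimension_cpp(ds$data_ptr, nrow(X))
stochtree:::forest_dataset_has_auxiliary_dimension_cpp(ds$data_ptr, 0L)  # should be TRUE

# Write one element and read it back (dimension and element indices are 0-based in C++)
stochtree:::forest_dataset_set_auxiliary_data_value_cpp(ds$data_ptr, 0L, 5L, 1.25)
stochtree:::forest_dataset_get_auxiliary_data_value_cpp(ds$data_ptr, 0L, 5L)  # 1.25

# Copy the whole slot back into R as a numeric vector
z <- stochtree:::forest_dataset_get_auxiliary_data_vector_cpp(ds$data_ptr, 0L)

The `store_auxiliary_data_vector_as_column` variant does the same copy but writes directly into a column of a preallocated R matrix, which avoids an intermediate vector when collecting per-iteration draws.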
diff --git a/src/cpp11.cpp b/src/cpp11.cpp index ef98aac0..0f20cdcb 100644 --- a/src/cpp11.cpp +++ b/src/cpp11.cpp @@ -109,6 +109,50 @@ extern "C" SEXP _stochtree_forest_dataset_get_variance_weights_cpp(SEXP dataset_ END_CPP11 } // R_data.cpp +bool forest_dataset_has_auxiliary_dimension_cpp(cpp11::external_pointer<StochTree::ForestDataset> dataset_ptr, int dim_idx); +extern "C" SEXP _stochtree_forest_dataset_has_auxiliary_dimension_cpp(SEXP dataset_ptr, SEXP dim_idx) { + BEGIN_CPP11 + return cpp11::as_sexp(forest_dataset_has_auxiliary_dimension_cpp(cpp11::as_cpp<cpp11::decay_t<cpp11::external_pointer<StochTree::ForestDataset>>>(dataset_ptr), cpp11::as_cpp<cpp11::decay_t<int>>(dim_idx))); + END_CPP11 +} +// R_data.cpp +void forest_dataset_add_auxiliary_dimension_cpp(cpp11::external_pointer<StochTree::ForestDataset> dataset_ptr, int dim_size); +extern "C" SEXP _stochtree_forest_dataset_add_auxiliary_dimension_cpp(SEXP dataset_ptr, SEXP dim_size) { + BEGIN_CPP11 + forest_dataset_add_auxiliary_dimension_cpp(cpp11::as_cpp<cpp11::decay_t<cpp11::external_pointer<StochTree::ForestDataset>>>(dataset_ptr), cpp11::as_cpp<cpp11::decay_t<int>>(dim_size)); + return R_NilValue; + END_CPP11 +} +// R_data.cpp +double forest_dataset_get_auxiliary_data_value_cpp(cpp11::external_pointer<StochTree::ForestDataset> dataset_ptr, int dim_idx, int element_idx); +extern "C" SEXP _stochtree_forest_dataset_get_auxiliary_data_value_cpp(SEXP dataset_ptr, SEXP dim_idx, SEXP element_idx) { + BEGIN_CPP11 + return cpp11::as_sexp(forest_dataset_get_auxiliary_data_value_cpp(cpp11::as_cpp<cpp11::decay_t<cpp11::external_pointer<StochTree::ForestDataset>>>(dataset_ptr), cpp11::as_cpp<cpp11::decay_t<int>>(dim_idx), cpp11::as_cpp<cpp11::decay_t<int>>(element_idx))); + END_CPP11 +} +// R_data.cpp +void forest_dataset_set_auxiliary_data_value_cpp(cpp11::external_pointer<StochTree::ForestDataset> dataset_ptr, int dim_idx, int element_idx, double value); +extern "C" SEXP _stochtree_forest_dataset_set_auxiliary_data_value_cpp(SEXP dataset_ptr, SEXP dim_idx, SEXP element_idx, SEXP value) { + BEGIN_CPP11 + forest_dataset_set_auxiliary_data_value_cpp(cpp11::as_cpp<cpp11::decay_t<cpp11::external_pointer<StochTree::ForestDataset>>>(dataset_ptr), cpp11::as_cpp<cpp11::decay_t<int>>(dim_idx), cpp11::as_cpp<cpp11::decay_t<int>>(element_idx), cpp11::as_cpp<cpp11::decay_t<double>>(value)); + return R_NilValue; + END_CPP11 +} +// R_data.cpp +cpp11::writable::doubles forest_dataset_get_auxiliary_data_vector_cpp(cpp11::external_pointer<StochTree::ForestDataset> dataset_ptr, int dim_idx); +extern "C" SEXP _stochtree_forest_dataset_get_auxiliary_data_vector_cpp(SEXP dataset_ptr, SEXP dim_idx) { + BEGIN_CPP11 + return cpp11::as_sexp(forest_dataset_get_auxiliary_data_vector_cpp(cpp11::as_cpp<cpp11::decay_t<cpp11::external_pointer<StochTree::ForestDataset>>>(dataset_ptr), cpp11::as_cpp<cpp11::decay_t<int>>(dim_idx))); + END_CPP11 +} +// R_data.cpp +cpp11::writable::doubles_matrix<> forest_dataset_store_auxiliary_data_vector_as_column_cpp(cpp11::external_pointer<StochTree::ForestDataset> dataset_ptr, cpp11::writable::doubles_matrix<> output_matrix, int dim_idx, int matrix_col_idx); +extern "C" SEXP _stochtree_forest_dataset_store_auxiliary_data_vector_as_column_cpp(SEXP dataset_ptr, SEXP output_matrix, SEXP dim_idx, SEXP matrix_col_idx) { + BEGIN_CPP11 + return cpp11::as_sexp(forest_dataset_store_auxiliary_data_vector_as_column_cpp(cpp11::as_cpp<cpp11::decay_t<cpp11::external_pointer<StochTree::ForestDataset>>>(dataset_ptr), cpp11::as_cpp<cpp11::decay_t<cpp11::writable::doubles_matrix<>>>(output_matrix), cpp11::as_cpp<cpp11::decay_t<int>>(dim_idx), cpp11::as_cpp<cpp11::decay_t<int>>(matrix_col_idx))); + END_CPP11 +} +// R_data.cpp cpp11::external_pointer<StochTree::ColumnVector> create_column_vector_cpp(cpp11::doubles outcome); extern "C" SEXP _stochtree_create_column_vector_cpp(SEXP outcome) { BEGIN_CPP11 @@ -1281,6 +1325,37 @@ extern "C" SEXP _stochtree_sample_without_replacement_integer_cpp(SEXP populatio return cpp11::as_sexp(sample_without_replacement_integer_cpp(cpp11::as_cpp<cpp11::decay_t<cpp11::integers>>(population_vector), cpp11::as_cpp<cpp11::decay_t<cpp11::doubles>>(sampling_probs), cpp11::as_cpp<cpp11::decay_t<int>>(sample_size))); END_CPP11 } +// sampler.cpp +cpp11::external_pointer<StochTree::OrdinalSampler> ordinal_sampler_cpp(); +extern "C" SEXP _stochtree_ordinal_sampler_cpp() { + BEGIN_CPP11 + return cpp11::as_sexp(ordinal_sampler_cpp()); + END_CPP11 +} +// sampler.cpp +void ordinal_sampler_update_latent_variables_cpp(cpp11::external_pointer<StochTree::OrdinalSampler> sampler_ptr, cpp11::external_pointer<StochTree::ForestDataset> data_ptr, cpp11::external_pointer<StochTree::ColumnVector> outcome_ptr, cpp11::external_pointer<std::mt19937> rng_ptr); +extern "C" SEXP _stochtree_ordinal_sampler_update_latent_variables_cpp(SEXP sampler_ptr, SEXP data_ptr, SEXP outcome_ptr, SEXP rng_ptr) { + BEGIN_CPP11 + ordinal_sampler_update_latent_variables_cpp(cpp11::as_cpp<cpp11::decay_t<cpp11::external_pointer<StochTree::OrdinalSampler>>>(sampler_ptr), cpp11::as_cpp<cpp11::decay_t<cpp11::external_pointer<StochTree::ForestDataset>>>(data_ptr), cpp11::as_cpp<cpp11::decay_t<cpp11::external_pointer<StochTree::ColumnVector>>>(outcome_ptr), cpp11::as_cpp<cpp11::decay_t<cpp11::external_pointer<std::mt19937>>>(rng_ptr)); + return R_NilValue; + END_CPP11 +} +// sampler.cpp +void ordinal_sampler_update_gamma_params_cpp(cpp11::external_pointer<StochTree::OrdinalSampler> sampler_ptr, cpp11::external_pointer<StochTree::ForestDataset> data_ptr, cpp11::external_pointer<StochTree::ColumnVector> outcome_ptr, double alpha_gamma, double beta_gamma, double gamma_0, cpp11::external_pointer<std::mt19937> rng_ptr); +extern "C" SEXP _stochtree_ordinal_sampler_update_gamma_params_cpp(SEXP sampler_ptr, SEXP data_ptr, SEXP outcome_ptr, SEXP alpha_gamma, SEXP beta_gamma, SEXP gamma_0, SEXP rng_ptr) { + BEGIN_CPP11 + ordinal_sampler_update_gamma_params_cpp(cpp11::as_cpp<cpp11::decay_t<cpp11::external_pointer<StochTree::OrdinalSampler>>>(sampler_ptr), cpp11::as_cpp<cpp11::decay_t<cpp11::external_pointer<StochTree::ForestDataset>>>(data_ptr), cpp11::as_cpp<cpp11::decay_t<cpp11::external_pointer<StochTree::ColumnVector>>>(outcome_ptr), cpp11::as_cpp<cpp11::decay_t<double>>(alpha_gamma), cpp11::as_cpp<cpp11::decay_t<double>>(beta_gamma), cpp11::as_cpp<cpp11::decay_t<double>>(gamma_0), cpp11::as_cpp<cpp11::decay_t<cpp11::external_pointer<std::mt19937>>>(rng_ptr)); + return R_NilValue; + END_CPP11 +} +// sampler.cpp +void ordinal_sampler_update_cumsum_exp_cpp(cpp11::external_pointer<StochTree::OrdinalSampler> sampler_ptr, cpp11::external_pointer<StochTree::ForestDataset> data_ptr); +extern "C" SEXP _stochtree_ordinal_sampler_update_cumsum_exp_cpp(SEXP sampler_ptr, SEXP data_ptr) { + BEGIN_CPP11 + ordinal_sampler_update_cumsum_exp_cpp(cpp11::as_cpp<cpp11::decay_t<cpp11::external_pointer<StochTree::OrdinalSampler>>>(sampler_ptr), cpp11::as_cpp<cpp11::decay_t<cpp11::external_pointer<StochTree::ForestDataset>>>(data_ptr)); + return R_NilValue; + END_CPP11 +}
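The `ordinal_sampler_*` registrations above are the full C-level surface of the new sampler: one constructor and three conditional updates (latent variables, gamma cutpoint parameters, and the cached cumulative sums of exponentiated cutpoints). One sweep of those conditionals could be wired from R as in the sketch below (not part of the diff; it assumes the outcome and RNG wrapper objects expose their external pointers as `data_ptr` and `rng_ptr`, mirroring the dataset wrapper, and that the forest step has already refreshed the predictions the latent update conditions on):

ordinal_sampler <- stochtree:::ordinal_sampler_cpp()

one_ordinal_sweep <- function(dataset, outcome, rng,
                              alpha_gamma, beta_gamma, gamma_0) {
  # Draw the latent variables given the current forest predictions and cutpoints
  stochtree:::ordinal_sampler_update_latent_variables_cpp(
    ordinal_sampler, dataset$data_ptr, outcome$data_ptr, rng$rng_ptr
  )
  # Draw the gamma cutpoint parameters (gamma_0 is passed through as the
  # fixed reference value declared in the wrapper signature above)
  stochtree:::ordinal_sampler_update_gamma_params_cpp(
    ordinal_sampler, dataset$data_ptr, outcome$data_ptr,
    alpha_gamma, beta_gamma, gamma_0, rng$rng_ptr
  )
  # Refresh the cached cumulative sums of exponentiated cutpoints
  stochtree:::ordinal_sampler_update_cumsum_exp_cpp(ordinal_sampler, dataset$data_ptr)
}

Note that all three wrappers return R_NilValue: each update mutates state held on the C++ side, so the calls are made purely for their side effects.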
// serialization.cpp cpp11::external_pointer<nlohmann::json> init_json_cpp(); extern "C" SEXP _stochtree_init_json_cpp() { BEGIN_CPP11 return cpp11::as_sexp(init_json_cpp()); END_CPP11 } @@ -1582,219 +1657,229 @@ extern "C" SEXP _stochtree_json_load_string_cpp(SEXP json_ptr, SEXP json_string) extern "C" { static const R_CallMethodDef CallEntries[] = { - {"_stochtree_active_forest_cpp", (DL_FUNC) &_stochtree_active_forest_cpp, 4}, -
{"_stochtree_add_numeric_split_tree_value_active_forest_cpp", (DL_FUNC) &_stochtree_add_numeric_split_tree_value_active_forest_cpp, 7}, - {"_stochtree_add_numeric_split_tree_value_forest_container_cpp", (DL_FUNC) &_stochtree_add_numeric_split_tree_value_forest_container_cpp, 8}, - {"_stochtree_add_numeric_split_tree_vector_active_forest_cpp", (DL_FUNC) &_stochtree_add_numeric_split_tree_vector_active_forest_cpp, 7}, - {"_stochtree_add_numeric_split_tree_vector_forest_container_cpp", (DL_FUNC) &_stochtree_add_numeric_split_tree_vector_forest_container_cpp, 8}, - {"_stochtree_add_sample_forest_container_cpp", (DL_FUNC) &_stochtree_add_sample_forest_container_cpp, 1}, - {"_stochtree_add_sample_value_forest_container_cpp", (DL_FUNC) &_stochtree_add_sample_value_forest_container_cpp, 2}, - {"_stochtree_add_sample_vector_forest_container_cpp", (DL_FUNC) &_stochtree_add_sample_vector_forest_container_cpp, 2}, - {"_stochtree_add_to_column_vector_cpp", (DL_FUNC) &_stochtree_add_to_column_vector_cpp, 2}, - {"_stochtree_add_to_forest_forest_container_cpp", (DL_FUNC) &_stochtree_add_to_forest_forest_container_cpp, 3}, - {"_stochtree_adjust_residual_active_forest_cpp", (DL_FUNC) &_stochtree_adjust_residual_active_forest_cpp, 6}, - {"_stochtree_adjust_residual_forest_container_cpp", (DL_FUNC) &_stochtree_adjust_residual_forest_container_cpp, 7}, - {"_stochtree_all_roots_active_forest_cpp", (DL_FUNC) &_stochtree_all_roots_active_forest_cpp, 1}, - {"_stochtree_all_roots_forest_container_cpp", (DL_FUNC) &_stochtree_all_roots_forest_container_cpp, 2}, - {"_stochtree_average_max_depth_active_forest_cpp", (DL_FUNC) &_stochtree_average_max_depth_active_forest_cpp, 1}, - {"_stochtree_average_max_depth_forest_container_cpp", (DL_FUNC) &_stochtree_average_max_depth_forest_container_cpp, 1}, - {"_stochtree_combine_forests_forest_container_cpp", (DL_FUNC) &_stochtree_combine_forests_forest_container_cpp, 2}, - {"_stochtree_compute_leaf_indices_cpp", (DL_FUNC) &_stochtree_compute_leaf_indices_cpp, 3}, - {"_stochtree_create_column_vector_cpp", (DL_FUNC) &_stochtree_create_column_vector_cpp, 1}, - {"_stochtree_create_forest_dataset_cpp", (DL_FUNC) &_stochtree_create_forest_dataset_cpp, 0}, - {"_stochtree_create_rfx_dataset_cpp", (DL_FUNC) &_stochtree_create_rfx_dataset_cpp, 0}, - {"_stochtree_dataset_has_basis_cpp", (DL_FUNC) &_stochtree_dataset_has_basis_cpp, 1}, - {"_stochtree_dataset_has_variance_weights_cpp", (DL_FUNC) &_stochtree_dataset_has_variance_weights_cpp, 1}, - {"_stochtree_dataset_num_basis_cpp", (DL_FUNC) &_stochtree_dataset_num_basis_cpp, 1}, - {"_stochtree_dataset_num_covariates_cpp", (DL_FUNC) &_stochtree_dataset_num_covariates_cpp, 1}, - {"_stochtree_dataset_num_rows_cpp", (DL_FUNC) &_stochtree_dataset_num_rows_cpp, 1}, - {"_stochtree_ensemble_average_max_depth_forest_container_cpp", (DL_FUNC) &_stochtree_ensemble_average_max_depth_forest_container_cpp, 2}, - {"_stochtree_ensemble_tree_max_depth_active_forest_cpp", (DL_FUNC) &_stochtree_ensemble_tree_max_depth_active_forest_cpp, 2}, - {"_stochtree_ensemble_tree_max_depth_forest_container_cpp", (DL_FUNC) &_stochtree_ensemble_tree_max_depth_forest_container_cpp, 3}, - {"_stochtree_forest_add_constant_cpp", (DL_FUNC) &_stochtree_forest_add_constant_cpp, 2}, - {"_stochtree_forest_container_append_from_json_cpp", (DL_FUNC) &_stochtree_forest_container_append_from_json_cpp, 3}, - {"_stochtree_forest_container_append_from_json_string_cpp", (DL_FUNC) &_stochtree_forest_container_append_from_json_string_cpp, 3}, - {"_stochtree_forest_container_cpp", 
(DL_FUNC) &_stochtree_forest_container_cpp, 4}, - {"_stochtree_forest_container_from_json_cpp", (DL_FUNC) &_stochtree_forest_container_from_json_cpp, 2}, - {"_stochtree_forest_container_from_json_string_cpp", (DL_FUNC) &_stochtree_forest_container_from_json_string_cpp, 2}, - {"_stochtree_forest_container_get_max_leaf_index_cpp", (DL_FUNC) &_stochtree_forest_container_get_max_leaf_index_cpp, 2}, - {"_stochtree_forest_dataset_add_basis_cpp", (DL_FUNC) &_stochtree_forest_dataset_add_basis_cpp, 2}, - {"_stochtree_forest_dataset_add_covariates_cpp", (DL_FUNC) &_stochtree_forest_dataset_add_covariates_cpp, 2}, - {"_stochtree_forest_dataset_add_weights_cpp", (DL_FUNC) &_stochtree_forest_dataset_add_weights_cpp, 2}, - {"_stochtree_forest_dataset_get_basis_cpp", (DL_FUNC) &_stochtree_forest_dataset_get_basis_cpp, 1}, - {"_stochtree_forest_dataset_get_covariates_cpp", (DL_FUNC) &_stochtree_forest_dataset_get_covariates_cpp, 1}, - {"_stochtree_forest_dataset_get_variance_weights_cpp", (DL_FUNC) &_stochtree_forest_dataset_get_variance_weights_cpp, 1}, - {"_stochtree_forest_dataset_update_basis_cpp", (DL_FUNC) &_stochtree_forest_dataset_update_basis_cpp, 2}, - {"_stochtree_forest_dataset_update_var_weights_cpp", (DL_FUNC) &_stochtree_forest_dataset_update_var_weights_cpp, 3}, - {"_stochtree_forest_merge_cpp", (DL_FUNC) &_stochtree_forest_merge_cpp, 2}, - {"_stochtree_forest_multiply_constant_cpp", (DL_FUNC) &_stochtree_forest_multiply_constant_cpp, 2}, - {"_stochtree_forest_tracker_cpp", (DL_FUNC) &_stochtree_forest_tracker_cpp, 4}, - {"_stochtree_get_alpha_tree_prior_cpp", (DL_FUNC) &_stochtree_get_alpha_tree_prior_cpp, 1}, - {"_stochtree_get_beta_tree_prior_cpp", (DL_FUNC) &_stochtree_get_beta_tree_prior_cpp, 1}, - {"_stochtree_get_cached_forest_predictions_cpp", (DL_FUNC) &_stochtree_get_cached_forest_predictions_cpp, 1}, - {"_stochtree_get_forest_split_counts_forest_container_cpp", (DL_FUNC) &_stochtree_get_forest_split_counts_forest_container_cpp, 3}, - {"_stochtree_get_granular_split_count_array_active_forest_cpp", (DL_FUNC) &_stochtree_get_granular_split_count_array_active_forest_cpp, 2}, - {"_stochtree_get_granular_split_count_array_forest_container_cpp", (DL_FUNC) &_stochtree_get_granular_split_count_array_forest_container_cpp, 2}, - {"_stochtree_get_json_string_cpp", (DL_FUNC) &_stochtree_get_json_string_cpp, 1}, - {"_stochtree_get_max_depth_tree_prior_cpp", (DL_FUNC) &_stochtree_get_max_depth_tree_prior_cpp, 1}, - {"_stochtree_get_min_samples_leaf_tree_prior_cpp", (DL_FUNC) &_stochtree_get_min_samples_leaf_tree_prior_cpp, 1}, - {"_stochtree_get_overall_split_counts_active_forest_cpp", (DL_FUNC) &_stochtree_get_overall_split_counts_active_forest_cpp, 2}, - {"_stochtree_get_overall_split_counts_forest_container_cpp", (DL_FUNC) &_stochtree_get_overall_split_counts_forest_container_cpp, 2}, - {"_stochtree_get_residual_cpp", (DL_FUNC) &_stochtree_get_residual_cpp, 1}, - {"_stochtree_get_tree_leaves_active_forest_cpp", (DL_FUNC) &_stochtree_get_tree_leaves_active_forest_cpp, 2}, - {"_stochtree_get_tree_leaves_forest_container_cpp", (DL_FUNC) &_stochtree_get_tree_leaves_forest_container_cpp, 3}, - {"_stochtree_get_tree_split_counts_active_forest_cpp", (DL_FUNC) &_stochtree_get_tree_split_counts_active_forest_cpp, 3}, - {"_stochtree_get_tree_split_counts_forest_container_cpp", (DL_FUNC) &_stochtree_get_tree_split_counts_forest_container_cpp, 4}, - {"_stochtree_init_json_cpp", (DL_FUNC) &_stochtree_init_json_cpp, 0}, - {"_stochtree_initialize_forest_model_active_forest_cpp", (DL_FUNC) 
&_stochtree_initialize_forest_model_active_forest_cpp, 6}, - {"_stochtree_initialize_forest_model_cpp", (DL_FUNC) &_stochtree_initialize_forest_model_cpp, 6}, - {"_stochtree_is_categorical_split_node_forest_container_cpp", (DL_FUNC) &_stochtree_is_categorical_split_node_forest_container_cpp, 4}, - {"_stochtree_is_exponentiated_active_forest_cpp", (DL_FUNC) &_stochtree_is_exponentiated_active_forest_cpp, 1}, - {"_stochtree_is_exponentiated_forest_container_cpp", (DL_FUNC) &_stochtree_is_exponentiated_forest_container_cpp, 1}, - {"_stochtree_is_leaf_constant_active_forest_cpp", (DL_FUNC) &_stochtree_is_leaf_constant_active_forest_cpp, 1}, - {"_stochtree_is_leaf_constant_forest_container_cpp", (DL_FUNC) &_stochtree_is_leaf_constant_forest_container_cpp, 1}, - {"_stochtree_is_leaf_node_forest_container_cpp", (DL_FUNC) &_stochtree_is_leaf_node_forest_container_cpp, 4}, - {"_stochtree_is_numeric_split_node_forest_container_cpp", (DL_FUNC) &_stochtree_is_numeric_split_node_forest_container_cpp, 4}, - {"_stochtree_json_add_bool_cpp", (DL_FUNC) &_stochtree_json_add_bool_cpp, 3}, - {"_stochtree_json_add_bool_subfolder_cpp", (DL_FUNC) &_stochtree_json_add_bool_subfolder_cpp, 4}, - {"_stochtree_json_add_double_cpp", (DL_FUNC) &_stochtree_json_add_double_cpp, 3}, - {"_stochtree_json_add_double_subfolder_cpp", (DL_FUNC) &_stochtree_json_add_double_subfolder_cpp, 4}, - {"_stochtree_json_add_forest_cpp", (DL_FUNC) &_stochtree_json_add_forest_cpp, 2}, - {"_stochtree_json_add_integer_cpp", (DL_FUNC) &_stochtree_json_add_integer_cpp, 3}, - {"_stochtree_json_add_integer_subfolder_cpp", (DL_FUNC) &_stochtree_json_add_integer_subfolder_cpp, 4}, - {"_stochtree_json_add_integer_vector_cpp", (DL_FUNC) &_stochtree_json_add_integer_vector_cpp, 3}, - {"_stochtree_json_add_integer_vector_subfolder_cpp", (DL_FUNC) &_stochtree_json_add_integer_vector_subfolder_cpp, 4}, - {"_stochtree_json_add_rfx_container_cpp", (DL_FUNC) &_stochtree_json_add_rfx_container_cpp, 2}, - {"_stochtree_json_add_rfx_groupids_cpp", (DL_FUNC) &_stochtree_json_add_rfx_groupids_cpp, 2}, - {"_stochtree_json_add_rfx_label_mapper_cpp", (DL_FUNC) &_stochtree_json_add_rfx_label_mapper_cpp, 2}, - {"_stochtree_json_add_string_cpp", (DL_FUNC) &_stochtree_json_add_string_cpp, 3}, - {"_stochtree_json_add_string_subfolder_cpp", (DL_FUNC) &_stochtree_json_add_string_subfolder_cpp, 4}, - {"_stochtree_json_add_string_vector_cpp", (DL_FUNC) &_stochtree_json_add_string_vector_cpp, 3}, - {"_stochtree_json_add_string_vector_subfolder_cpp", (DL_FUNC) &_stochtree_json_add_string_vector_subfolder_cpp, 4}, - {"_stochtree_json_add_vector_cpp", (DL_FUNC) &_stochtree_json_add_vector_cpp, 3}, - {"_stochtree_json_add_vector_subfolder_cpp", (DL_FUNC) &_stochtree_json_add_vector_subfolder_cpp, 4}, - {"_stochtree_json_contains_field_cpp", (DL_FUNC) &_stochtree_json_contains_field_cpp, 2}, - {"_stochtree_json_contains_field_subfolder_cpp", (DL_FUNC) &_stochtree_json_contains_field_subfolder_cpp, 3}, - {"_stochtree_json_extract_bool_cpp", (DL_FUNC) &_stochtree_json_extract_bool_cpp, 2}, - {"_stochtree_json_extract_bool_subfolder_cpp", (DL_FUNC) &_stochtree_json_extract_bool_subfolder_cpp, 3}, - {"_stochtree_json_extract_double_cpp", (DL_FUNC) &_stochtree_json_extract_double_cpp, 2}, - {"_stochtree_json_extract_double_subfolder_cpp", (DL_FUNC) &_stochtree_json_extract_double_subfolder_cpp, 3}, - {"_stochtree_json_extract_integer_cpp", (DL_FUNC) &_stochtree_json_extract_integer_cpp, 2}, - {"_stochtree_json_extract_integer_subfolder_cpp", (DL_FUNC) 
&_stochtree_json_extract_integer_subfolder_cpp, 3}, - {"_stochtree_json_extract_integer_vector_cpp", (DL_FUNC) &_stochtree_json_extract_integer_vector_cpp, 2}, - {"_stochtree_json_extract_integer_vector_subfolder_cpp", (DL_FUNC) &_stochtree_json_extract_integer_vector_subfolder_cpp, 3}, - {"_stochtree_json_extract_string_cpp", (DL_FUNC) &_stochtree_json_extract_string_cpp, 2}, - {"_stochtree_json_extract_string_subfolder_cpp", (DL_FUNC) &_stochtree_json_extract_string_subfolder_cpp, 3}, - {"_stochtree_json_extract_string_vector_cpp", (DL_FUNC) &_stochtree_json_extract_string_vector_cpp, 2}, - {"_stochtree_json_extract_string_vector_subfolder_cpp", (DL_FUNC) &_stochtree_json_extract_string_vector_subfolder_cpp, 3}, - {"_stochtree_json_extract_vector_cpp", (DL_FUNC) &_stochtree_json_extract_vector_cpp, 2}, - {"_stochtree_json_extract_vector_subfolder_cpp", (DL_FUNC) &_stochtree_json_extract_vector_subfolder_cpp, 3}, - {"_stochtree_json_increment_rfx_count_cpp", (DL_FUNC) &_stochtree_json_increment_rfx_count_cpp, 1}, - {"_stochtree_json_load_file_cpp", (DL_FUNC) &_stochtree_json_load_file_cpp, 2}, - {"_stochtree_json_load_forest_container_cpp", (DL_FUNC) &_stochtree_json_load_forest_container_cpp, 2}, - {"_stochtree_json_load_string_cpp", (DL_FUNC) &_stochtree_json_load_string_cpp, 2}, - {"_stochtree_json_save_file_cpp", (DL_FUNC) &_stochtree_json_save_file_cpp, 2}, - {"_stochtree_json_save_forest_container_cpp", (DL_FUNC) &_stochtree_json_save_forest_container_cpp, 2}, - {"_stochtree_leaf_dimension_active_forest_cpp", (DL_FUNC) &_stochtree_leaf_dimension_active_forest_cpp, 1}, - {"_stochtree_leaf_dimension_forest_container_cpp", (DL_FUNC) &_stochtree_leaf_dimension_forest_container_cpp, 1}, - {"_stochtree_leaf_values_forest_container_cpp", (DL_FUNC) &_stochtree_leaf_values_forest_container_cpp, 4}, - {"_stochtree_leaves_forest_container_cpp", (DL_FUNC) &_stochtree_leaves_forest_container_cpp, 3}, - {"_stochtree_left_child_node_forest_container_cpp", (DL_FUNC) &_stochtree_left_child_node_forest_container_cpp, 4}, - {"_stochtree_multiply_forest_forest_container_cpp", (DL_FUNC) &_stochtree_multiply_forest_forest_container_cpp, 3}, - {"_stochtree_node_depth_forest_container_cpp", (DL_FUNC) &_stochtree_node_depth_forest_container_cpp, 4}, - {"_stochtree_nodes_forest_container_cpp", (DL_FUNC) &_stochtree_nodes_forest_container_cpp, 3}, - {"_stochtree_num_leaf_parents_forest_container_cpp", (DL_FUNC) &_stochtree_num_leaf_parents_forest_container_cpp, 3}, - {"_stochtree_num_leaves_ensemble_forest_container_cpp", (DL_FUNC) &_stochtree_num_leaves_ensemble_forest_container_cpp, 2}, - {"_stochtree_num_leaves_forest_container_cpp", (DL_FUNC) &_stochtree_num_leaves_forest_container_cpp, 3}, - {"_stochtree_num_nodes_forest_container_cpp", (DL_FUNC) &_stochtree_num_nodes_forest_container_cpp, 3}, - {"_stochtree_num_samples_forest_container_cpp", (DL_FUNC) &_stochtree_num_samples_forest_container_cpp, 1}, - {"_stochtree_num_split_nodes_forest_container_cpp", (DL_FUNC) &_stochtree_num_split_nodes_forest_container_cpp, 3}, - {"_stochtree_num_trees_active_forest_cpp", (DL_FUNC) &_stochtree_num_trees_active_forest_cpp, 1}, - {"_stochtree_num_trees_forest_container_cpp", (DL_FUNC) &_stochtree_num_trees_forest_container_cpp, 1}, - {"_stochtree_overwrite_column_vector_cpp", (DL_FUNC) &_stochtree_overwrite_column_vector_cpp, 2}, - {"_stochtree_parent_node_forest_container_cpp", (DL_FUNC) &_stochtree_parent_node_forest_container_cpp, 4}, - {"_stochtree_predict_active_forest_cpp", (DL_FUNC) 
&_stochtree_predict_active_forest_cpp, 2}, - {"_stochtree_predict_forest_cpp", (DL_FUNC) &_stochtree_predict_forest_cpp, 2}, - {"_stochtree_predict_forest_raw_cpp", (DL_FUNC) &_stochtree_predict_forest_raw_cpp, 2}, - {"_stochtree_predict_forest_raw_single_forest_cpp", (DL_FUNC) &_stochtree_predict_forest_raw_single_forest_cpp, 3}, - {"_stochtree_predict_forest_raw_single_tree_cpp", (DL_FUNC) &_stochtree_predict_forest_raw_single_tree_cpp, 4}, - {"_stochtree_predict_raw_active_forest_cpp", (DL_FUNC) &_stochtree_predict_raw_active_forest_cpp, 2}, - {"_stochtree_propagate_basis_update_active_forest_cpp", (DL_FUNC) &_stochtree_propagate_basis_update_active_forest_cpp, 4}, - {"_stochtree_propagate_basis_update_forest_container_cpp", (DL_FUNC) &_stochtree_propagate_basis_update_forest_container_cpp, 5}, - {"_stochtree_propagate_trees_column_vector_cpp", (DL_FUNC) &_stochtree_propagate_trees_column_vector_cpp, 2}, - {"_stochtree_remove_sample_forest_container_cpp", (DL_FUNC) &_stochtree_remove_sample_forest_container_cpp, 2}, - {"_stochtree_reset_active_forest_cpp", (DL_FUNC) &_stochtree_reset_active_forest_cpp, 3}, - {"_stochtree_reset_forest_model_cpp", (DL_FUNC) &_stochtree_reset_forest_model_cpp, 5}, - {"_stochtree_reset_rfx_model_cpp", (DL_FUNC) &_stochtree_reset_rfx_model_cpp, 3}, - {"_stochtree_reset_rfx_tracker_cpp", (DL_FUNC) &_stochtree_reset_rfx_tracker_cpp, 4}, - {"_stochtree_rfx_container_append_from_json_cpp", (DL_FUNC) &_stochtree_rfx_container_append_from_json_cpp, 3}, - {"_stochtree_rfx_container_append_from_json_string_cpp", (DL_FUNC) &_stochtree_rfx_container_append_from_json_string_cpp, 3}, - {"_stochtree_rfx_container_cpp", (DL_FUNC) &_stochtree_rfx_container_cpp, 2}, - {"_stochtree_rfx_container_delete_sample_cpp", (DL_FUNC) &_stochtree_rfx_container_delete_sample_cpp, 2}, - {"_stochtree_rfx_container_from_json_cpp", (DL_FUNC) &_stochtree_rfx_container_from_json_cpp, 2}, - {"_stochtree_rfx_container_from_json_string_cpp", (DL_FUNC) &_stochtree_rfx_container_from_json_string_cpp, 2}, - {"_stochtree_rfx_container_get_alpha_cpp", (DL_FUNC) &_stochtree_rfx_container_get_alpha_cpp, 1}, - {"_stochtree_rfx_container_get_beta_cpp", (DL_FUNC) &_stochtree_rfx_container_get_beta_cpp, 1}, - {"_stochtree_rfx_container_get_sigma_cpp", (DL_FUNC) &_stochtree_rfx_container_get_sigma_cpp, 1}, - {"_stochtree_rfx_container_get_xi_cpp", (DL_FUNC) &_stochtree_rfx_container_get_xi_cpp, 1}, - {"_stochtree_rfx_container_num_components_cpp", (DL_FUNC) &_stochtree_rfx_container_num_components_cpp, 1}, - {"_stochtree_rfx_container_num_groups_cpp", (DL_FUNC) &_stochtree_rfx_container_num_groups_cpp, 1}, - {"_stochtree_rfx_container_num_samples_cpp", (DL_FUNC) &_stochtree_rfx_container_num_samples_cpp, 1}, - {"_stochtree_rfx_container_predict_cpp", (DL_FUNC) &_stochtree_rfx_container_predict_cpp, 3}, - {"_stochtree_rfx_dataset_add_basis_cpp", (DL_FUNC) &_stochtree_rfx_dataset_add_basis_cpp, 2}, - {"_stochtree_rfx_dataset_add_group_labels_cpp", (DL_FUNC) &_stochtree_rfx_dataset_add_group_labels_cpp, 2}, - {"_stochtree_rfx_dataset_add_weights_cpp", (DL_FUNC) &_stochtree_rfx_dataset_add_weights_cpp, 2}, - {"_stochtree_rfx_dataset_get_basis_cpp", (DL_FUNC) &_stochtree_rfx_dataset_get_basis_cpp, 1}, - {"_stochtree_rfx_dataset_get_group_labels_cpp", (DL_FUNC) &_stochtree_rfx_dataset_get_group_labels_cpp, 1}, - {"_stochtree_rfx_dataset_get_variance_weights_cpp", (DL_FUNC) &_stochtree_rfx_dataset_get_variance_weights_cpp, 1}, - {"_stochtree_rfx_dataset_has_basis_cpp", (DL_FUNC) 
&_stochtree_rfx_dataset_has_basis_cpp, 1}, - {"_stochtree_rfx_dataset_has_group_labels_cpp", (DL_FUNC) &_stochtree_rfx_dataset_has_group_labels_cpp, 1}, - {"_stochtree_rfx_dataset_has_variance_weights_cpp", (DL_FUNC) &_stochtree_rfx_dataset_has_variance_weights_cpp, 1}, - {"_stochtree_rfx_dataset_num_basis_cpp", (DL_FUNC) &_stochtree_rfx_dataset_num_basis_cpp, 1}, - {"_stochtree_rfx_dataset_num_rows_cpp", (DL_FUNC) &_stochtree_rfx_dataset_num_rows_cpp, 1}, - {"_stochtree_rfx_dataset_update_basis_cpp", (DL_FUNC) &_stochtree_rfx_dataset_update_basis_cpp, 2}, - {"_stochtree_rfx_dataset_update_group_labels_cpp", (DL_FUNC) &_stochtree_rfx_dataset_update_group_labels_cpp, 2}, - {"_stochtree_rfx_dataset_update_var_weights_cpp", (DL_FUNC) &_stochtree_rfx_dataset_update_var_weights_cpp, 3}, - {"_stochtree_rfx_group_ids_from_json_cpp", (DL_FUNC) &_stochtree_rfx_group_ids_from_json_cpp, 2}, - {"_stochtree_rfx_group_ids_from_json_string_cpp", (DL_FUNC) &_stochtree_rfx_group_ids_from_json_string_cpp, 2}, - {"_stochtree_rfx_label_mapper_cpp", (DL_FUNC) &_stochtree_rfx_label_mapper_cpp, 1}, - {"_stochtree_rfx_label_mapper_from_json_cpp", (DL_FUNC) &_stochtree_rfx_label_mapper_from_json_cpp, 2}, - {"_stochtree_rfx_label_mapper_from_json_string_cpp", (DL_FUNC) &_stochtree_rfx_label_mapper_from_json_string_cpp, 2}, - {"_stochtree_rfx_label_mapper_to_list_cpp", (DL_FUNC) &_stochtree_rfx_label_mapper_to_list_cpp, 1}, - {"_stochtree_rfx_model_cpp", (DL_FUNC) &_stochtree_rfx_model_cpp, 2}, - {"_stochtree_rfx_model_predict_cpp", (DL_FUNC) &_stochtree_rfx_model_predict_cpp, 3}, - {"_stochtree_rfx_model_sample_random_effects_cpp", (DL_FUNC) &_stochtree_rfx_model_sample_random_effects_cpp, 8}, - {"_stochtree_rfx_model_set_group_parameter_covariance_cpp", (DL_FUNC) &_stochtree_rfx_model_set_group_parameter_covariance_cpp, 2}, - {"_stochtree_rfx_model_set_group_parameters_cpp", (DL_FUNC) &_stochtree_rfx_model_set_group_parameters_cpp, 2}, - {"_stochtree_rfx_model_set_variance_prior_scale_cpp", (DL_FUNC) &_stochtree_rfx_model_set_variance_prior_scale_cpp, 2}, - {"_stochtree_rfx_model_set_variance_prior_shape_cpp", (DL_FUNC) &_stochtree_rfx_model_set_variance_prior_shape_cpp, 2}, - {"_stochtree_rfx_model_set_working_parameter_covariance_cpp", (DL_FUNC) &_stochtree_rfx_model_set_working_parameter_covariance_cpp, 2}, - {"_stochtree_rfx_model_set_working_parameter_cpp", (DL_FUNC) &_stochtree_rfx_model_set_working_parameter_cpp, 2}, - {"_stochtree_rfx_tracker_cpp", (DL_FUNC) &_stochtree_rfx_tracker_cpp, 1}, - {"_stochtree_rfx_tracker_get_unique_group_ids_cpp", (DL_FUNC) &_stochtree_rfx_tracker_get_unique_group_ids_cpp, 1}, - {"_stochtree_right_child_node_forest_container_cpp", (DL_FUNC) &_stochtree_right_child_node_forest_container_cpp, 4}, - {"_stochtree_rng_cpp", (DL_FUNC) &_stochtree_rng_cpp, 1}, - {"_stochtree_root_reset_active_forest_cpp", (DL_FUNC) &_stochtree_root_reset_active_forest_cpp, 1}, - {"_stochtree_root_reset_rfx_tracker_cpp", (DL_FUNC) &_stochtree_root_reset_rfx_tracker_cpp, 4}, - {"_stochtree_sample_gfr_one_iteration_cpp", (DL_FUNC) &_stochtree_sample_gfr_one_iteration_cpp, 19}, - {"_stochtree_sample_mcmc_one_iteration_cpp", (DL_FUNC) &_stochtree_sample_mcmc_one_iteration_cpp, 18}, - {"_stochtree_sample_sigma2_one_iteration_cpp", (DL_FUNC) &_stochtree_sample_sigma2_one_iteration_cpp, 5}, - {"_stochtree_sample_tau_one_iteration_cpp", (DL_FUNC) &_stochtree_sample_tau_one_iteration_cpp, 4}, - {"_stochtree_sample_without_replacement_integer_cpp", (DL_FUNC) &_stochtree_sample_without_replacement_integer_cpp, 
3}, - {"_stochtree_set_leaf_value_active_forest_cpp", (DL_FUNC) &_stochtree_set_leaf_value_active_forest_cpp, 2}, - {"_stochtree_set_leaf_value_forest_container_cpp", (DL_FUNC) &_stochtree_set_leaf_value_forest_container_cpp, 2}, - {"_stochtree_set_leaf_vector_active_forest_cpp", (DL_FUNC) &_stochtree_set_leaf_vector_active_forest_cpp, 2}, - {"_stochtree_set_leaf_vector_forest_container_cpp", (DL_FUNC) &_stochtree_set_leaf_vector_forest_container_cpp, 2}, - {"_stochtree_split_categories_forest_container_cpp", (DL_FUNC) &_stochtree_split_categories_forest_container_cpp, 4}, - {"_stochtree_split_index_forest_container_cpp", (DL_FUNC) &_stochtree_split_index_forest_container_cpp, 4}, - {"_stochtree_split_theshold_forest_container_cpp", (DL_FUNC) &_stochtree_split_theshold_forest_container_cpp, 4}, - {"_stochtree_subtract_from_column_vector_cpp", (DL_FUNC) &_stochtree_subtract_from_column_vector_cpp, 2}, - {"_stochtree_sum_leaves_squared_ensemble_forest_container_cpp", (DL_FUNC) &_stochtree_sum_leaves_squared_ensemble_forest_container_cpp, 2}, - {"_stochtree_tree_prior_cpp", (DL_FUNC) &_stochtree_tree_prior_cpp, 4}, - {"_stochtree_update_alpha_tree_prior_cpp", (DL_FUNC) &_stochtree_update_alpha_tree_prior_cpp, 2}, - {"_stochtree_update_beta_tree_prior_cpp", (DL_FUNC) &_stochtree_update_beta_tree_prior_cpp, 2}, - {"_stochtree_update_max_depth_tree_prior_cpp", (DL_FUNC) &_stochtree_update_max_depth_tree_prior_cpp, 2}, - {"_stochtree_update_min_samples_leaf_tree_prior_cpp", (DL_FUNC) &_stochtree_update_min_samples_leaf_tree_prior_cpp, 2}, + {"_stochtree_active_forest_cpp", (DL_FUNC) &_stochtree_active_forest_cpp, 4}, + {"_stochtree_add_numeric_split_tree_value_active_forest_cpp", (DL_FUNC) &_stochtree_add_numeric_split_tree_value_active_forest_cpp, 7}, + {"_stochtree_add_numeric_split_tree_value_forest_container_cpp", (DL_FUNC) &_stochtree_add_numeric_split_tree_value_forest_container_cpp, 8}, + {"_stochtree_add_numeric_split_tree_vector_active_forest_cpp", (DL_FUNC) &_stochtree_add_numeric_split_tree_vector_active_forest_cpp, 7}, + {"_stochtree_add_numeric_split_tree_vector_forest_container_cpp", (DL_FUNC) &_stochtree_add_numeric_split_tree_vector_forest_container_cpp, 8}, + {"_stochtree_add_sample_forest_container_cpp", (DL_FUNC) &_stochtree_add_sample_forest_container_cpp, 1}, + {"_stochtree_add_sample_value_forest_container_cpp", (DL_FUNC) &_stochtree_add_sample_value_forest_container_cpp, 2}, + {"_stochtree_add_sample_vector_forest_container_cpp", (DL_FUNC) &_stochtree_add_sample_vector_forest_container_cpp, 2}, + {"_stochtree_add_to_column_vector_cpp", (DL_FUNC) &_stochtree_add_to_column_vector_cpp, 2}, + {"_stochtree_add_to_forest_forest_container_cpp", (DL_FUNC) &_stochtree_add_to_forest_forest_container_cpp, 3}, + {"_stochtree_adjust_residual_active_forest_cpp", (DL_FUNC) &_stochtree_adjust_residual_active_forest_cpp, 6}, + {"_stochtree_adjust_residual_forest_container_cpp", (DL_FUNC) &_stochtree_adjust_residual_forest_container_cpp, 7}, + {"_stochtree_all_roots_active_forest_cpp", (DL_FUNC) &_stochtree_all_roots_active_forest_cpp, 1}, + {"_stochtree_all_roots_forest_container_cpp", (DL_FUNC) &_stochtree_all_roots_forest_container_cpp, 2}, + {"_stochtree_average_max_depth_active_forest_cpp", (DL_FUNC) &_stochtree_average_max_depth_active_forest_cpp, 1}, + {"_stochtree_average_max_depth_forest_container_cpp", (DL_FUNC) &_stochtree_average_max_depth_forest_container_cpp, 1}, + {"_stochtree_combine_forests_forest_container_cpp", (DL_FUNC) &_stochtree_combine_forests_forest_container_cpp, 
2}, + {"_stochtree_compute_leaf_indices_cpp", (DL_FUNC) &_stochtree_compute_leaf_indices_cpp, 3}, + {"_stochtree_create_column_vector_cpp", (DL_FUNC) &_stochtree_create_column_vector_cpp, 1}, + {"_stochtree_create_forest_dataset_cpp", (DL_FUNC) &_stochtree_create_forest_dataset_cpp, 0}, + {"_stochtree_create_rfx_dataset_cpp", (DL_FUNC) &_stochtree_create_rfx_dataset_cpp, 0}, + {"_stochtree_dataset_has_basis_cpp", (DL_FUNC) &_stochtree_dataset_has_basis_cpp, 1}, + {"_stochtree_dataset_has_variance_weights_cpp", (DL_FUNC) &_stochtree_dataset_has_variance_weights_cpp, 1}, + {"_stochtree_dataset_num_basis_cpp", (DL_FUNC) &_stochtree_dataset_num_basis_cpp, 1}, + {"_stochtree_dataset_num_covariates_cpp", (DL_FUNC) &_stochtree_dataset_num_covariates_cpp, 1}, + {"_stochtree_dataset_num_rows_cpp", (DL_FUNC) &_stochtree_dataset_num_rows_cpp, 1}, + {"_stochtree_ensemble_average_max_depth_forest_container_cpp", (DL_FUNC) &_stochtree_ensemble_average_max_depth_forest_container_cpp, 2}, + {"_stochtree_ensemble_tree_max_depth_active_forest_cpp", (DL_FUNC) &_stochtree_ensemble_tree_max_depth_active_forest_cpp, 2}, + {"_stochtree_ensemble_tree_max_depth_forest_container_cpp", (DL_FUNC) &_stochtree_ensemble_tree_max_depth_forest_container_cpp, 3}, + {"_stochtree_forest_add_constant_cpp", (DL_FUNC) &_stochtree_forest_add_constant_cpp, 2}, + {"_stochtree_forest_container_append_from_json_cpp", (DL_FUNC) &_stochtree_forest_container_append_from_json_cpp, 3}, + {"_stochtree_forest_container_append_from_json_string_cpp", (DL_FUNC) &_stochtree_forest_container_append_from_json_string_cpp, 3}, + {"_stochtree_forest_container_cpp", (DL_FUNC) &_stochtree_forest_container_cpp, 4}, + {"_stochtree_forest_container_from_json_cpp", (DL_FUNC) &_stochtree_forest_container_from_json_cpp, 2}, + {"_stochtree_forest_container_from_json_string_cpp", (DL_FUNC) &_stochtree_forest_container_from_json_string_cpp, 2}, + {"_stochtree_forest_container_get_max_leaf_index_cpp", (DL_FUNC) &_stochtree_forest_container_get_max_leaf_index_cpp, 2}, + {"_stochtree_forest_dataset_add_auxiliary_dimension_cpp", (DL_FUNC) &_stochtree_forest_dataset_add_auxiliary_dimension_cpp, 2}, + {"_stochtree_forest_dataset_add_basis_cpp", (DL_FUNC) &_stochtree_forest_dataset_add_basis_cpp, 2}, + {"_stochtree_forest_dataset_add_covariates_cpp", (DL_FUNC) &_stochtree_forest_dataset_add_covariates_cpp, 2}, + {"_stochtree_forest_dataset_add_weights_cpp", (DL_FUNC) &_stochtree_forest_dataset_add_weights_cpp, 2}, + {"_stochtree_forest_dataset_get_auxiliary_data_value_cpp", (DL_FUNC) &_stochtree_forest_dataset_get_auxiliary_data_value_cpp, 3}, + {"_stochtree_forest_dataset_get_auxiliary_data_vector_cpp", (DL_FUNC) &_stochtree_forest_dataset_get_auxiliary_data_vector_cpp, 2}, + {"_stochtree_forest_dataset_get_basis_cpp", (DL_FUNC) &_stochtree_forest_dataset_get_basis_cpp, 1}, + {"_stochtree_forest_dataset_get_covariates_cpp", (DL_FUNC) &_stochtree_forest_dataset_get_covariates_cpp, 1}, + {"_stochtree_forest_dataset_get_variance_weights_cpp", (DL_FUNC) &_stochtree_forest_dataset_get_variance_weights_cpp, 1}, + {"_stochtree_forest_dataset_has_auxiliary_dimension_cpp", (DL_FUNC) &_stochtree_forest_dataset_has_auxiliary_dimension_cpp, 2}, + {"_stochtree_forest_dataset_set_auxiliary_data_value_cpp", (DL_FUNC) &_stochtree_forest_dataset_set_auxiliary_data_value_cpp, 4}, + {"_stochtree_forest_dataset_store_auxiliary_data_vector_as_column_cpp", (DL_FUNC) &_stochtree_forest_dataset_store_auxiliary_data_vector_as_column_cpp, 4}, + {"_stochtree_forest_dataset_update_basis_cpp", 
(DL_FUNC) &_stochtree_forest_dataset_update_basis_cpp, 2}, + {"_stochtree_forest_dataset_update_var_weights_cpp", (DL_FUNC) &_stochtree_forest_dataset_update_var_weights_cpp, 3}, + {"_stochtree_forest_merge_cpp", (DL_FUNC) &_stochtree_forest_merge_cpp, 2}, + {"_stochtree_forest_multiply_constant_cpp", (DL_FUNC) &_stochtree_forest_multiply_constant_cpp, 2}, + {"_stochtree_forest_tracker_cpp", (DL_FUNC) &_stochtree_forest_tracker_cpp, 4}, + {"_stochtree_get_alpha_tree_prior_cpp", (DL_FUNC) &_stochtree_get_alpha_tree_prior_cpp, 1}, + {"_stochtree_get_beta_tree_prior_cpp", (DL_FUNC) &_stochtree_get_beta_tree_prior_cpp, 1}, + {"_stochtree_get_cached_forest_predictions_cpp", (DL_FUNC) &_stochtree_get_cached_forest_predictions_cpp, 1}, + {"_stochtree_get_forest_split_counts_forest_container_cpp", (DL_FUNC) &_stochtree_get_forest_split_counts_forest_container_cpp, 3}, + {"_stochtree_get_granular_split_count_array_active_forest_cpp", (DL_FUNC) &_stochtree_get_granular_split_count_array_active_forest_cpp, 2}, + {"_stochtree_get_granular_split_count_array_forest_container_cpp", (DL_FUNC) &_stochtree_get_granular_split_count_array_forest_container_cpp, 2}, + {"_stochtree_get_json_string_cpp", (DL_FUNC) &_stochtree_get_json_string_cpp, 1}, + {"_stochtree_get_max_depth_tree_prior_cpp", (DL_FUNC) &_stochtree_get_max_depth_tree_prior_cpp, 1}, + {"_stochtree_get_min_samples_leaf_tree_prior_cpp", (DL_FUNC) &_stochtree_get_min_samples_leaf_tree_prior_cpp, 1}, + {"_stochtree_get_overall_split_counts_active_forest_cpp", (DL_FUNC) &_stochtree_get_overall_split_counts_active_forest_cpp, 2}, + {"_stochtree_get_overall_split_counts_forest_container_cpp", (DL_FUNC) &_stochtree_get_overall_split_counts_forest_container_cpp, 2}, + {"_stochtree_get_residual_cpp", (DL_FUNC) &_stochtree_get_residual_cpp, 1}, + {"_stochtree_get_tree_leaves_active_forest_cpp", (DL_FUNC) &_stochtree_get_tree_leaves_active_forest_cpp, 2}, + {"_stochtree_get_tree_leaves_forest_container_cpp", (DL_FUNC) &_stochtree_get_tree_leaves_forest_container_cpp, 3}, + {"_stochtree_get_tree_split_counts_active_forest_cpp", (DL_FUNC) &_stochtree_get_tree_split_counts_active_forest_cpp, 3}, + {"_stochtree_get_tree_split_counts_forest_container_cpp", (DL_FUNC) &_stochtree_get_tree_split_counts_forest_container_cpp, 4}, + {"_stochtree_init_json_cpp", (DL_FUNC) &_stochtree_init_json_cpp, 0}, + {"_stochtree_initialize_forest_model_active_forest_cpp", (DL_FUNC) &_stochtree_initialize_forest_model_active_forest_cpp, 6}, + {"_stochtree_initialize_forest_model_cpp", (DL_FUNC) &_stochtree_initialize_forest_model_cpp, 6}, + {"_stochtree_is_categorical_split_node_forest_container_cpp", (DL_FUNC) &_stochtree_is_categorical_split_node_forest_container_cpp, 4}, + {"_stochtree_is_exponentiated_active_forest_cpp", (DL_FUNC) &_stochtree_is_exponentiated_active_forest_cpp, 1}, + {"_stochtree_is_exponentiated_forest_container_cpp", (DL_FUNC) &_stochtree_is_exponentiated_forest_container_cpp, 1}, + {"_stochtree_is_leaf_constant_active_forest_cpp", (DL_FUNC) &_stochtree_is_leaf_constant_active_forest_cpp, 1}, + {"_stochtree_is_leaf_constant_forest_container_cpp", (DL_FUNC) &_stochtree_is_leaf_constant_forest_container_cpp, 1}, + {"_stochtree_is_leaf_node_forest_container_cpp", (DL_FUNC) &_stochtree_is_leaf_node_forest_container_cpp, 4}, + {"_stochtree_is_numeric_split_node_forest_container_cpp", (DL_FUNC) &_stochtree_is_numeric_split_node_forest_container_cpp, 4}, + {"_stochtree_json_add_bool_cpp", (DL_FUNC) &_stochtree_json_add_bool_cpp, 3}, + 
{"_stochtree_json_add_bool_subfolder_cpp", (DL_FUNC) &_stochtree_json_add_bool_subfolder_cpp, 4}, + {"_stochtree_json_add_double_cpp", (DL_FUNC) &_stochtree_json_add_double_cpp, 3}, + {"_stochtree_json_add_double_subfolder_cpp", (DL_FUNC) &_stochtree_json_add_double_subfolder_cpp, 4}, + {"_stochtree_json_add_forest_cpp", (DL_FUNC) &_stochtree_json_add_forest_cpp, 2}, + {"_stochtree_json_add_integer_cpp", (DL_FUNC) &_stochtree_json_add_integer_cpp, 3}, + {"_stochtree_json_add_integer_subfolder_cpp", (DL_FUNC) &_stochtree_json_add_integer_subfolder_cpp, 4}, + {"_stochtree_json_add_integer_vector_cpp", (DL_FUNC) &_stochtree_json_add_integer_vector_cpp, 3}, + {"_stochtree_json_add_integer_vector_subfolder_cpp", (DL_FUNC) &_stochtree_json_add_integer_vector_subfolder_cpp, 4}, + {"_stochtree_json_add_rfx_container_cpp", (DL_FUNC) &_stochtree_json_add_rfx_container_cpp, 2}, + {"_stochtree_json_add_rfx_groupids_cpp", (DL_FUNC) &_stochtree_json_add_rfx_groupids_cpp, 2}, + {"_stochtree_json_add_rfx_label_mapper_cpp", (DL_FUNC) &_stochtree_json_add_rfx_label_mapper_cpp, 2}, + {"_stochtree_json_add_string_cpp", (DL_FUNC) &_stochtree_json_add_string_cpp, 3}, + {"_stochtree_json_add_string_subfolder_cpp", (DL_FUNC) &_stochtree_json_add_string_subfolder_cpp, 4}, + {"_stochtree_json_add_string_vector_cpp", (DL_FUNC) &_stochtree_json_add_string_vector_cpp, 3}, + {"_stochtree_json_add_string_vector_subfolder_cpp", (DL_FUNC) &_stochtree_json_add_string_vector_subfolder_cpp, 4}, + {"_stochtree_json_add_vector_cpp", (DL_FUNC) &_stochtree_json_add_vector_cpp, 3}, + {"_stochtree_json_add_vector_subfolder_cpp", (DL_FUNC) &_stochtree_json_add_vector_subfolder_cpp, 4}, + {"_stochtree_json_contains_field_cpp", (DL_FUNC) &_stochtree_json_contains_field_cpp, 2}, + {"_stochtree_json_contains_field_subfolder_cpp", (DL_FUNC) &_stochtree_json_contains_field_subfolder_cpp, 3}, + {"_stochtree_json_extract_bool_cpp", (DL_FUNC) &_stochtree_json_extract_bool_cpp, 2}, + {"_stochtree_json_extract_bool_subfolder_cpp", (DL_FUNC) &_stochtree_json_extract_bool_subfolder_cpp, 3}, + {"_stochtree_json_extract_double_cpp", (DL_FUNC) &_stochtree_json_extract_double_cpp, 2}, + {"_stochtree_json_extract_double_subfolder_cpp", (DL_FUNC) &_stochtree_json_extract_double_subfolder_cpp, 3}, + {"_stochtree_json_extract_integer_cpp", (DL_FUNC) &_stochtree_json_extract_integer_cpp, 2}, + {"_stochtree_json_extract_integer_subfolder_cpp", (DL_FUNC) &_stochtree_json_extract_integer_subfolder_cpp, 3}, + {"_stochtree_json_extract_integer_vector_cpp", (DL_FUNC) &_stochtree_json_extract_integer_vector_cpp, 2}, + {"_stochtree_json_extract_integer_vector_subfolder_cpp", (DL_FUNC) &_stochtree_json_extract_integer_vector_subfolder_cpp, 3}, + {"_stochtree_json_extract_string_cpp", (DL_FUNC) &_stochtree_json_extract_string_cpp, 2}, + {"_stochtree_json_extract_string_subfolder_cpp", (DL_FUNC) &_stochtree_json_extract_string_subfolder_cpp, 3}, + {"_stochtree_json_extract_string_vector_cpp", (DL_FUNC) &_stochtree_json_extract_string_vector_cpp, 2}, + {"_stochtree_json_extract_string_vector_subfolder_cpp", (DL_FUNC) &_stochtree_json_extract_string_vector_subfolder_cpp, 3}, + {"_stochtree_json_extract_vector_cpp", (DL_FUNC) &_stochtree_json_extract_vector_cpp, 2}, + {"_stochtree_json_extract_vector_subfolder_cpp", (DL_FUNC) &_stochtree_json_extract_vector_subfolder_cpp, 3}, + {"_stochtree_json_increment_rfx_count_cpp", (DL_FUNC) &_stochtree_json_increment_rfx_count_cpp, 1}, + {"_stochtree_json_load_file_cpp", (DL_FUNC) &_stochtree_json_load_file_cpp, 2}, + 
{"_stochtree_json_load_forest_container_cpp", (DL_FUNC) &_stochtree_json_load_forest_container_cpp, 2}, + {"_stochtree_json_load_string_cpp", (DL_FUNC) &_stochtree_json_load_string_cpp, 2}, + {"_stochtree_json_save_file_cpp", (DL_FUNC) &_stochtree_json_save_file_cpp, 2}, + {"_stochtree_json_save_forest_container_cpp", (DL_FUNC) &_stochtree_json_save_forest_container_cpp, 2}, + {"_stochtree_leaf_dimension_active_forest_cpp", (DL_FUNC) &_stochtree_leaf_dimension_active_forest_cpp, 1}, + {"_stochtree_leaf_dimension_forest_container_cpp", (DL_FUNC) &_stochtree_leaf_dimension_forest_container_cpp, 1}, + {"_stochtree_leaf_values_forest_container_cpp", (DL_FUNC) &_stochtree_leaf_values_forest_container_cpp, 4}, + {"_stochtree_leaves_forest_container_cpp", (DL_FUNC) &_stochtree_leaves_forest_container_cpp, 3}, + {"_stochtree_left_child_node_forest_container_cpp", (DL_FUNC) &_stochtree_left_child_node_forest_container_cpp, 4}, + {"_stochtree_multiply_forest_forest_container_cpp", (DL_FUNC) &_stochtree_multiply_forest_forest_container_cpp, 3}, + {"_stochtree_node_depth_forest_container_cpp", (DL_FUNC) &_stochtree_node_depth_forest_container_cpp, 4}, + {"_stochtree_nodes_forest_container_cpp", (DL_FUNC) &_stochtree_nodes_forest_container_cpp, 3}, + {"_stochtree_num_leaf_parents_forest_container_cpp", (DL_FUNC) &_stochtree_num_leaf_parents_forest_container_cpp, 3}, + {"_stochtree_num_leaves_ensemble_forest_container_cpp", (DL_FUNC) &_stochtree_num_leaves_ensemble_forest_container_cpp, 2}, + {"_stochtree_num_leaves_forest_container_cpp", (DL_FUNC) &_stochtree_num_leaves_forest_container_cpp, 3}, + {"_stochtree_num_nodes_forest_container_cpp", (DL_FUNC) &_stochtree_num_nodes_forest_container_cpp, 3}, + {"_stochtree_num_samples_forest_container_cpp", (DL_FUNC) &_stochtree_num_samples_forest_container_cpp, 1}, + {"_stochtree_num_split_nodes_forest_container_cpp", (DL_FUNC) &_stochtree_num_split_nodes_forest_container_cpp, 3}, + {"_stochtree_num_trees_active_forest_cpp", (DL_FUNC) &_stochtree_num_trees_active_forest_cpp, 1}, + {"_stochtree_num_trees_forest_container_cpp", (DL_FUNC) &_stochtree_num_trees_forest_container_cpp, 1}, + {"_stochtree_ordinal_sampler_cpp", (DL_FUNC) &_stochtree_ordinal_sampler_cpp, 0}, + {"_stochtree_ordinal_sampler_update_cumsum_exp_cpp", (DL_FUNC) &_stochtree_ordinal_sampler_update_cumsum_exp_cpp, 2}, + {"_stochtree_ordinal_sampler_update_gamma_params_cpp", (DL_FUNC) &_stochtree_ordinal_sampler_update_gamma_params_cpp, 7}, + {"_stochtree_ordinal_sampler_update_latent_variables_cpp", (DL_FUNC) &_stochtree_ordinal_sampler_update_latent_variables_cpp, 4}, + {"_stochtree_overwrite_column_vector_cpp", (DL_FUNC) &_stochtree_overwrite_column_vector_cpp, 2}, + {"_stochtree_parent_node_forest_container_cpp", (DL_FUNC) &_stochtree_parent_node_forest_container_cpp, 4}, + {"_stochtree_predict_active_forest_cpp", (DL_FUNC) &_stochtree_predict_active_forest_cpp, 2}, + {"_stochtree_predict_forest_cpp", (DL_FUNC) &_stochtree_predict_forest_cpp, 2}, + {"_stochtree_predict_forest_raw_cpp", (DL_FUNC) &_stochtree_predict_forest_raw_cpp, 2}, + {"_stochtree_predict_forest_raw_single_forest_cpp", (DL_FUNC) &_stochtree_predict_forest_raw_single_forest_cpp, 3}, + {"_stochtree_predict_forest_raw_single_tree_cpp", (DL_FUNC) &_stochtree_predict_forest_raw_single_tree_cpp, 4}, + {"_stochtree_predict_raw_active_forest_cpp", (DL_FUNC) &_stochtree_predict_raw_active_forest_cpp, 2}, + {"_stochtree_propagate_basis_update_active_forest_cpp", (DL_FUNC) &_stochtree_propagate_basis_update_active_forest_cpp, 4}, + 
{"_stochtree_propagate_basis_update_forest_container_cpp", (DL_FUNC) &_stochtree_propagate_basis_update_forest_container_cpp, 5}, + {"_stochtree_propagate_trees_column_vector_cpp", (DL_FUNC) &_stochtree_propagate_trees_column_vector_cpp, 2}, + {"_stochtree_remove_sample_forest_container_cpp", (DL_FUNC) &_stochtree_remove_sample_forest_container_cpp, 2}, + {"_stochtree_reset_active_forest_cpp", (DL_FUNC) &_stochtree_reset_active_forest_cpp, 3}, + {"_stochtree_reset_forest_model_cpp", (DL_FUNC) &_stochtree_reset_forest_model_cpp, 5}, + {"_stochtree_reset_rfx_model_cpp", (DL_FUNC) &_stochtree_reset_rfx_model_cpp, 3}, + {"_stochtree_reset_rfx_tracker_cpp", (DL_FUNC) &_stochtree_reset_rfx_tracker_cpp, 4}, + {"_stochtree_rfx_container_append_from_json_cpp", (DL_FUNC) &_stochtree_rfx_container_append_from_json_cpp, 3}, + {"_stochtree_rfx_container_append_from_json_string_cpp", (DL_FUNC) &_stochtree_rfx_container_append_from_json_string_cpp, 3}, + {"_stochtree_rfx_container_cpp", (DL_FUNC) &_stochtree_rfx_container_cpp, 2}, + {"_stochtree_rfx_container_delete_sample_cpp", (DL_FUNC) &_stochtree_rfx_container_delete_sample_cpp, 2}, + {"_stochtree_rfx_container_from_json_cpp", (DL_FUNC) &_stochtree_rfx_container_from_json_cpp, 2}, + {"_stochtree_rfx_container_from_json_string_cpp", (DL_FUNC) &_stochtree_rfx_container_from_json_string_cpp, 2}, + {"_stochtree_rfx_container_get_alpha_cpp", (DL_FUNC) &_stochtree_rfx_container_get_alpha_cpp, 1}, + {"_stochtree_rfx_container_get_beta_cpp", (DL_FUNC) &_stochtree_rfx_container_get_beta_cpp, 1}, + {"_stochtree_rfx_container_get_sigma_cpp", (DL_FUNC) &_stochtree_rfx_container_get_sigma_cpp, 1}, + {"_stochtree_rfx_container_get_xi_cpp", (DL_FUNC) &_stochtree_rfx_container_get_xi_cpp, 1}, + {"_stochtree_rfx_container_num_components_cpp", (DL_FUNC) &_stochtree_rfx_container_num_components_cpp, 1}, + {"_stochtree_rfx_container_num_groups_cpp", (DL_FUNC) &_stochtree_rfx_container_num_groups_cpp, 1}, + {"_stochtree_rfx_container_num_samples_cpp", (DL_FUNC) &_stochtree_rfx_container_num_samples_cpp, 1}, + {"_stochtree_rfx_container_predict_cpp", (DL_FUNC) &_stochtree_rfx_container_predict_cpp, 3}, + {"_stochtree_rfx_dataset_add_basis_cpp", (DL_FUNC) &_stochtree_rfx_dataset_add_basis_cpp, 2}, + {"_stochtree_rfx_dataset_add_group_labels_cpp", (DL_FUNC) &_stochtree_rfx_dataset_add_group_labels_cpp, 2}, + {"_stochtree_rfx_dataset_add_weights_cpp", (DL_FUNC) &_stochtree_rfx_dataset_add_weights_cpp, 2}, + {"_stochtree_rfx_dataset_get_basis_cpp", (DL_FUNC) &_stochtree_rfx_dataset_get_basis_cpp, 1}, + {"_stochtree_rfx_dataset_get_group_labels_cpp", (DL_FUNC) &_stochtree_rfx_dataset_get_group_labels_cpp, 1}, + {"_stochtree_rfx_dataset_get_variance_weights_cpp", (DL_FUNC) &_stochtree_rfx_dataset_get_variance_weights_cpp, 1}, + {"_stochtree_rfx_dataset_has_basis_cpp", (DL_FUNC) &_stochtree_rfx_dataset_has_basis_cpp, 1}, + {"_stochtree_rfx_dataset_has_group_labels_cpp", (DL_FUNC) &_stochtree_rfx_dataset_has_group_labels_cpp, 1}, + {"_stochtree_rfx_dataset_has_variance_weights_cpp", (DL_FUNC) &_stochtree_rfx_dataset_has_variance_weights_cpp, 1}, + {"_stochtree_rfx_dataset_num_basis_cpp", (DL_FUNC) &_stochtree_rfx_dataset_num_basis_cpp, 1}, + {"_stochtree_rfx_dataset_num_rows_cpp", (DL_FUNC) &_stochtree_rfx_dataset_num_rows_cpp, 1}, + {"_stochtree_rfx_dataset_update_basis_cpp", (DL_FUNC) &_stochtree_rfx_dataset_update_basis_cpp, 2}, + {"_stochtree_rfx_dataset_update_group_labels_cpp", (DL_FUNC) &_stochtree_rfx_dataset_update_group_labels_cpp, 2}, + 
{"_stochtree_rfx_dataset_update_var_weights_cpp", (DL_FUNC) &_stochtree_rfx_dataset_update_var_weights_cpp, 3}, + {"_stochtree_rfx_group_ids_from_json_cpp", (DL_FUNC) &_stochtree_rfx_group_ids_from_json_cpp, 2}, + {"_stochtree_rfx_group_ids_from_json_string_cpp", (DL_FUNC) &_stochtree_rfx_group_ids_from_json_string_cpp, 2}, + {"_stochtree_rfx_label_mapper_cpp", (DL_FUNC) &_stochtree_rfx_label_mapper_cpp, 1}, + {"_stochtree_rfx_label_mapper_from_json_cpp", (DL_FUNC) &_stochtree_rfx_label_mapper_from_json_cpp, 2}, + {"_stochtree_rfx_label_mapper_from_json_string_cpp", (DL_FUNC) &_stochtree_rfx_label_mapper_from_json_string_cpp, 2}, + {"_stochtree_rfx_label_mapper_to_list_cpp", (DL_FUNC) &_stochtree_rfx_label_mapper_to_list_cpp, 1}, + {"_stochtree_rfx_model_cpp", (DL_FUNC) &_stochtree_rfx_model_cpp, 2}, + {"_stochtree_rfx_model_predict_cpp", (DL_FUNC) &_stochtree_rfx_model_predict_cpp, 3}, + {"_stochtree_rfx_model_sample_random_effects_cpp", (DL_FUNC) &_stochtree_rfx_model_sample_random_effects_cpp, 8}, + {"_stochtree_rfx_model_set_group_parameter_covariance_cpp", (DL_FUNC) &_stochtree_rfx_model_set_group_parameter_covariance_cpp, 2}, + {"_stochtree_rfx_model_set_group_parameters_cpp", (DL_FUNC) &_stochtree_rfx_model_set_group_parameters_cpp, 2}, + {"_stochtree_rfx_model_set_variance_prior_scale_cpp", (DL_FUNC) &_stochtree_rfx_model_set_variance_prior_scale_cpp, 2}, + {"_stochtree_rfx_model_set_variance_prior_shape_cpp", (DL_FUNC) &_stochtree_rfx_model_set_variance_prior_shape_cpp, 2}, + {"_stochtree_rfx_model_set_working_parameter_covariance_cpp", (DL_FUNC) &_stochtree_rfx_model_set_working_parameter_covariance_cpp, 2}, + {"_stochtree_rfx_model_set_working_parameter_cpp", (DL_FUNC) &_stochtree_rfx_model_set_working_parameter_cpp, 2}, + {"_stochtree_rfx_tracker_cpp", (DL_FUNC) &_stochtree_rfx_tracker_cpp, 1}, + {"_stochtree_rfx_tracker_get_unique_group_ids_cpp", (DL_FUNC) &_stochtree_rfx_tracker_get_unique_group_ids_cpp, 1}, + {"_stochtree_right_child_node_forest_container_cpp", (DL_FUNC) &_stochtree_right_child_node_forest_container_cpp, 4}, + {"_stochtree_rng_cpp", (DL_FUNC) &_stochtree_rng_cpp, 1}, + {"_stochtree_root_reset_active_forest_cpp", (DL_FUNC) &_stochtree_root_reset_active_forest_cpp, 1}, + {"_stochtree_root_reset_rfx_tracker_cpp", (DL_FUNC) &_stochtree_root_reset_rfx_tracker_cpp, 4}, + {"_stochtree_sample_gfr_one_iteration_cpp", (DL_FUNC) &_stochtree_sample_gfr_one_iteration_cpp, 19}, + {"_stochtree_sample_mcmc_one_iteration_cpp", (DL_FUNC) &_stochtree_sample_mcmc_one_iteration_cpp, 18}, + {"_stochtree_sample_sigma2_one_iteration_cpp", (DL_FUNC) &_stochtree_sample_sigma2_one_iteration_cpp, 5}, + {"_stochtree_sample_tau_one_iteration_cpp", (DL_FUNC) &_stochtree_sample_tau_one_iteration_cpp, 4}, + {"_stochtree_sample_without_replacement_integer_cpp", (DL_FUNC) &_stochtree_sample_without_replacement_integer_cpp, 3}, + {"_stochtree_set_leaf_value_active_forest_cpp", (DL_FUNC) &_stochtree_set_leaf_value_active_forest_cpp, 2}, + {"_stochtree_set_leaf_value_forest_container_cpp", (DL_FUNC) &_stochtree_set_leaf_value_forest_container_cpp, 2}, + {"_stochtree_set_leaf_vector_active_forest_cpp", (DL_FUNC) &_stochtree_set_leaf_vector_active_forest_cpp, 2}, + {"_stochtree_set_leaf_vector_forest_container_cpp", (DL_FUNC) &_stochtree_set_leaf_vector_forest_container_cpp, 2}, + {"_stochtree_split_categories_forest_container_cpp", (DL_FUNC) &_stochtree_split_categories_forest_container_cpp, 4}, + {"_stochtree_split_index_forest_container_cpp", (DL_FUNC) 
&_stochtree_split_index_forest_container_cpp, 4}, + {"_stochtree_split_theshold_forest_container_cpp", (DL_FUNC) &_stochtree_split_theshold_forest_container_cpp, 4}, + {"_stochtree_subtract_from_column_vector_cpp", (DL_FUNC) &_stochtree_subtract_from_column_vector_cpp, 2}, + {"_stochtree_sum_leaves_squared_ensemble_forest_container_cpp", (DL_FUNC) &_stochtree_sum_leaves_squared_ensemble_forest_container_cpp, 2}, + {"_stochtree_tree_prior_cpp", (DL_FUNC) &_stochtree_tree_prior_cpp, 4}, + {"_stochtree_update_alpha_tree_prior_cpp", (DL_FUNC) &_stochtree_update_alpha_tree_prior_cpp, 2}, + {"_stochtree_update_beta_tree_prior_cpp", (DL_FUNC) &_stochtree_update_beta_tree_prior_cpp, 2}, + {"_stochtree_update_max_depth_tree_prior_cpp", (DL_FUNC) &_stochtree_update_max_depth_tree_prior_cpp, 2}, + {"_stochtree_update_min_samples_leaf_tree_prior_cpp", (DL_FUNC) &_stochtree_update_min_samples_leaf_tree_prior_cpp, 2}, {NULL, NULL, 0} }; } diff --git a/src/kernel.cpp b/src/kernel.cpp index 88f12c53..38fdd35c 100644 --- a/src/kernel.cpp +++ b/src/kernel.cpp @@ -2,7 +2,6 @@ #include "stochtree_types.h" #include #include -#include typedef Eigen::Map> DoubleMatrixType; typedef Eigen::Map> IntMatrixType; diff --git a/src/leaf_model.cpp b/src/leaf_model.cpp index 3b59ab96..3c1f91c9 100644 --- a/src/leaf_model.cpp +++ b/src/leaf_model.cpp @@ -274,4 +274,64 @@ void LogLinearVarianceLeafModel::SetEnsembleRootPredictedValue(ForestDataset& da } } +double CloglogOrdinalLeafModel::SplitLogMarginalLikelihood(CloglogOrdinalSuffStat& left_stat, CloglogOrdinalSuffStat& right_stat, double global_variance) { + double left_log_ml = SuffStatLogMarginalLikelihood(left_stat, global_variance); + double right_log_ml = SuffStatLogMarginalLikelihood(right_stat, global_variance); + return left_log_ml + right_log_ml; +} + +double CloglogOrdinalLeafModel::NoSplitLogMarginalLikelihood(CloglogOrdinalSuffStat& suff_stat, double global_variance) { + return SuffStatLogMarginalLikelihood(suff_stat, global_variance); +} + +double CloglogOrdinalLeafModel::SuffStatLogMarginalLikelihood(CloglogOrdinalSuffStat& suff_stat, double global_variance) { + double prior_terms = a_ * std::log(b_) - boost::math::lgamma(a_); + double a_term = a_ + suff_stat.sum_Y_less_K; + double b_term = b_ + suff_stat.other_sum; + double log_b_term = std::log(b_term); + double lgamma_a_term = boost::math::lgamma(a_term); + double resid_term = a_term * log_b_term; + double log_ml = prior_terms + lgamma_a_term - resid_term; + return log_ml; +} + +double CloglogOrdinalLeafModel::PosteriorParameterShape(CloglogOrdinalSuffStat& suff_stat, double global_variance) { + return a_ + suff_stat.sum_Y_less_K; +} + +double CloglogOrdinalLeafModel::PosteriorParameterRate(CloglogOrdinalSuffStat& suff_stat, double global_variance) { + return b_ + suff_stat.other_sum; +} + +void CloglogOrdinalLeafModel::SampleLeafParameters(ForestDataset& dataset, ForestTracker& tracker, ColumnVector& residual, Tree* tree, int tree_num, double global_variance, std::mt19937& gen) { + // Vector of leaf indices for tree + std::vector tree_leaves = tree->GetLeaves(); + + // Initialize sufficient statistics + CloglogOrdinalSuffStat node_suff_stat = CloglogOrdinalSuffStat(); + + // Sample each leaf node parameter + double node_shape; + double node_rate; + double node_mu; + int32_t leaf_id; + for (int i = 0; i < tree_leaves.size(); i++) { + // Compute leaf node sufficient statistics + leaf_id = tree_leaves[i]; + node_suff_stat.ResetSuffStat(); + AccumulateSingleNodeSuffStat(node_suff_stat, dataset, tracker, 
residual, tree_num, leaf_id);
+
+    // Compute posterior shape and rate
+    node_shape = PosteriorParameterShape(node_suff_stat, global_variance);
+    node_rate = PosteriorParameterRate(node_suff_stat, global_variance);
+
+    // Draw e^{mu} ~ Gamma(node_shape, node_rate) and set the leaf parameter to
+    // mu = log(draw), i.e., a draw from the corresponding log-gamma distribution
+    double gamma_sample = gamma_sampler_.Sample(node_shape, node_rate, gen);
+    node_mu = std::log(gamma_sample);
+    tree->SetLeaf(leaf_id, node_mu);
+  }
+}
+
 } // namespace StochTree
diff --git a/src/ordinal_sampler.cpp b/src/ordinal_sampler.cpp
new file mode 100644
index 00000000..6a09a000
--- /dev/null
+++ b/src/ordinal_sampler.cpp
@@ -0,0 +1,102 @@
+#include
+#include
+
+namespace StochTree {
+
+double OrdinalSampler::SampleTruncatedExponential(std::mt19937& gen, double rate, double low, double high) {
+  // Inverse-CDF sampling, dispatching on which truncation bounds are active
+  std::uniform_real_distribution<double> unif(0.0, 1.0);
+  double u = unif(gen);
+  if ((low <= 0.0) && (high <= 0.0)) {
+    return sample_exponential(u, rate);
+  } else if ((low <= 0.0) && (high > 0.0)) {
+    return sample_truncated_exponential_high(u, rate, high);
+  } else if ((low > 0.0) && (high <= 0.0)) {
+    return sample_truncated_exponential_low(u, rate, low);
+  } else {
+    return sample_truncated_exponential_low_high(u, rate, low, high);
+  }
+}
+
+void OrdinalSampler::UpdateLatentVariables(ForestDataset& dataset, Eigen::VectorXd& outcome, std::mt19937& gen) {
+  // Get auxiliary data vectors
+  const std::vector<double>& gamma = dataset.GetAuxiliaryDataVector(2); // gamma cutpoints
+  const std::vector<double>& lambda_hat = dataset.GetAuxiliaryDataVector(1); // forest predictions: lambda_hat_i = sum_t lambda_t(x_i)
+  std::vector<double>& Z = dataset.GetAuxiliaryDataVector(0); // latent variables: z_i ~ TExp(e^{gamma[y_i] + lambda_hat_i}; 0, 1)
+
+  int K = gamma.size() + 1; // Number of ordinal categories
+  int N = dataset.NumObservations();
+
+  // Update truncated exponentials (stored in auxiliary data slot 0):
+  //   z_i ~ TExp(rate = e^{gamma[y_i] + lambda_hat_i}; 0, 1),
+  // where y_i is the ordinal outcome for observation i (already converted to {0, 1, ..., K-1})
+  // and lambda_hat_i is the total forest prediction for observation i.
+  // The latent z_i only enters the likelihood when y_i < K-1, so for the last category
+  // (y_i = K-1) we set z_i = 1.0 deterministically, purely for bookkeeping.
+  for (int i = 0; i < N; i++) {
+    int y = static_cast<int>(outcome(i));
+    if (y == K - 1) {
+      Z[i] = 1.0;
+    } else {
+      double rate = std::exp(gamma[y] + lambda_hat[i]);
+      Z[i] = SampleTruncatedExponential(gen, rate, 0.0, 1.0);
+    }
+  }
+}
+
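+// Gibbs update for the cutpoints gamma_k (see Alam and Linero (2025)). With latent
+// z_i ~ TExp(e^{gamma[y_i] + lambda_hat_i}; 0, 1) and a Gamma(alpha_gamma, beta_gamma)
+// prior on e^{gamma_k}, the full conditional of e^{gamma_k} is
+// Gamma(alpha_gamma + A[k], beta_gamma + B[k]), where
+//   A[k] = #{i : y_i = k}
+//   B[k] = sum_{i : y_i = k} z_i * exp(lambda_hat_i) + sum_{i : y_i > k} exp(lambda_hat_i)
+// The function below accumulates A and B in a single pass over the data, draws each
+// cutpoint on the log scale, and then pins gamma[0] at gamma_0 for identifiability.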
+void OrdinalSampler::UpdateGammaParams(ForestDataset& dataset, Eigen::VectorXd& outcome,
+                                       double alpha_gamma, double beta_gamma,
+                                       double gamma_0, std::mt19937& gen) {
+  // Get auxiliary data vectors
+  std::vector<double>& gamma = dataset.GetAuxiliaryDataVector(2); // cutpoints gamma_k's
+  const std::vector<double>& Z = dataset.GetAuxiliaryDataVector(0); // latent variables z_i's
+  const std::vector<double>& lambda_hat = dataset.GetAuxiliaryDataVector(1); // forest predictions: lambda_hat_i = sum_t lambda_t(x_i)
+
+  int K = gamma.size() + 1; // Number of ordinal categories
+  int N = dataset.NumObservations();
+
+  // Compute sufficient statistics A[k] and B[k] for the gamma[k] update
+  std::vector<double> A(K - 1, 0.0);
+  std::vector<double> B(K - 1, 0.0);
+
+  for (int i = 0; i < N; i++) {
+    int y = static_cast<int>(outcome(i));
+    if (y < K - 1) {
+      A[y] += 1.0;
+      B[y] += Z[i] * std::exp(lambda_hat[i]);
+    }
+    for (int k = 0; k < y; k++) {
+      B[k] += std::exp(lambda_hat[i]);
+    }
+  }
+
+  // Sample all gamma parameters on the log scale (log-gamma draws)
+  for (int k = 0; k < static_cast<int>(gamma.size()); k++) {
+    double shape = A[k] + alpha_gamma;
+    double rate = B[k] + beta_gamma;
+    double gamma_sample = gamma_sampler_.Sample(shape, rate, gen);
+    gamma[k] = std::log(gamma_sample);
+  }
+
+  // Overwrite the first cutpoint with the fixed value gamma_0 (e.g., 0) for identifiability
+  gamma[0] = gamma_0;
+}
+
+void OrdinalSampler::UpdateCumulativeExpSums(ForestDataset& dataset) {
+  // Get auxiliary data vectors
+  const std::vector<double>& gamma = dataset.GetAuxiliaryDataVector(2); // cutpoints gamma_k's
+  std::vector<double>& seg = dataset.GetAuxiliaryDataVector(3); // seg_k = sum_{j=0}^{k-1} exp(gamma_j)
+
+  // Update seg as a running sum of exponentiated gamma cutpoints; seg_0 = 0 is the empty sum
+  for (int j = 0; j < static_cast<int>(seg.size()); j++) {
+    if (j == 0) {
+      seg[j] = 0.0;
+    } else {
+      seg[j] = seg[j - 1] + std::exp(gamma[j - 1]);
+    }
+  }
+}
+
+} // namespace StochTree
diff --git a/src/partition_tracker.cpp b/src/partition_tracker.cpp
index 9d643380..73b37fe8 100644
--- a/src/partition_tracker.cpp
+++ b/src/partition_tracker.cpp
@@ -28,15 +28,15 @@ void ForestTracker::ReconstituteFromForest(TreeEnsemble& forest, ForestDataset&
   // (1) Updates the residual by adding currently cached tree predictions and subtracting predictions from new tree
   // (2) Updates sample_node_mapper_, sample_pred_mapper_, and sum_predictions_ based on the new forest
   UpdateSampleTrackersResidual(forest, dataset, residual, is_mean_model);
-  
+
   // Since GFR always starts over from root, this data structure can always simply be reset
   Eigen::MatrixXd& covariates = dataset.GetCovariates();
   sorted_node_sample_tracker_.reset(new SortedNodeSampleTracker(presort_container_.get(), covariates, feature_types_));
-  
+
   // Reconstitute each of the remaining data structures in the tracker based on splits in the ensemble
   // UnsortedNodeSampleTracker
   unsorted_node_sample_tracker_->ReconstituteFromForest(forest, dataset);
-  
+
 }
 
 void ForestTracker::ResetRoot(Eigen::MatrixXd& covariates, std::vector<FeatureType>& feature_types, int32_t tree_num) {
@@ -156,7 +156,7 @@ void ForestTracker::UpdateSampleTrackersResidualInternalBasis(TreeEnsemble& fore
   for (int j = 0; j < num_trees_; j++) {
     // Query the previously cached prediction for tree j, observation i
     prev_tree_pred = sample_pred_mapper_->GetPred(i, j);
-    
+
     // Compute the new prediction for tree j, observation i
     new_tree_pred = 0.0;
     Tree* tree = forest.GetTree(j);
@@ -164,7 +164,7 @@ void ForestTracker::UpdateSampleTrackersResidualInternalBasis(TreeEnsemble& fore
     for (int32_t k = 0; k < output_dim; k++) {
       new_tree_pred += tree->LeafValue(nidx, k) * basis(i, k);
     }
-    
+
     if (is_mean_model) {
       // Adjust the residual by adding the previous prediction and subtracting the new prediction
       new_resid = residual.GetElement(i) - new_tree_pred + prev_tree_pred;
@@ -202,7 +202,7 @@ void ForestTracker::UpdateSampleTrackersResidualInternalNoBasis(TreeEnsemble& fo
     Tree* tree = forest.GetTree(j);
     std::int32_t nidx = EvaluateTree(*tree, covariates, i);
     new_tree_pred = tree->LeafValue(nidx, 0);
-    
+
     if (is_mean_model) {
       // Adjust the residual by adding the previous prediction and subtracting the new prediction
       new_resid = residual.GetElement(i) - new_tree_pred + prev_tree_pred;
@@ -211,7 +211,7 @@ void 
ForestTracker::UpdateSampleTrackersResidualInternalNoBasis(TreeEnsemble& fo new_weight = std::log(dataset.VarWeightValue(i)) + new_tree_pred - prev_tree_pred; dataset.SetVarWeightValue(i, new_weight, true); } - + // Update the sample node mapper and sample prediction mapper sample_node_mapper_->SetNodeId(i, j, nidx); sample_pred_mapper_->SetPred(i, j, new_tree_pred); @@ -346,21 +346,21 @@ void FeatureUnsortedPartition::ReconstituteFromTree(Tree& tree, ForestDataset& d CHECK_EQ(num_deleted_nodes_, 0); data_size_t n = dataset.NumObservations(); CHECK_EQ(indices_.size(), n); - + // Extract covariates Eigen::MatrixXd& covariates = dataset.GetCovariates(); // Set node counters num_nodes_ = tree.NumNodes(); num_deleted_nodes_ = tree.NumDeletedNodes(); - + // Resize tracking vectors node_begin_.resize(num_nodes_); node_length_.resize(num_nodes_); parent_nodes_.resize(num_nodes_); left_nodes_.resize(num_nodes_); right_nodes_.resize(num_nodes_); - + // Unpack tree's splits into this data structure bool is_deleted; TreeNodeType node_type; @@ -399,11 +399,11 @@ void FeatureUnsortedPartition::ReconstituteFromTree(Tree& tree, ForestDataset& d } else { continue; } - // Partition the node indices + // Partition the node indices auto node_begin = (indices_.begin() + node_begin_[i]); auto node_end = (indices_.begin() + node_begin_[i] + node_length_[i]); auto right_node_begin = std::stable_partition(node_begin, node_end, [&](int row) { return split_rule.SplitTrue(covariates(row, split_index)); }); - + // Determine the number of true and false elements node_begin = (indices_.begin() + node_begin_[i]); num_true = std::distance(node_begin, right_node_begin); @@ -415,7 +415,7 @@ void FeatureUnsortedPartition::ReconstituteFromTree(Tree& tree, ForestDataset& d parent_nodes_[left_nodes_[i]] = i; left_nodes_[left_nodes_[i]] = StochTree::Tree::kInvalidNodeId; left_nodes_[right_nodes_[i]] = StochTree::Tree::kInvalidNodeId; - + // Add right node tracking information node_begin_[right_nodes_[i]] = node_start_idx + num_true; node_length_[right_nodes_[i]] = num_false; @@ -455,11 +455,11 @@ void FeatureUnsortedPartition::PartitionNode(Eigen::MatrixXd& covariates, int no data_size_t node_start_idx = node_begin_[node_id]; data_size_t num_node_elements = node_length_[node_id]; - // Partition the node indices + // Partition the node indices auto node_begin = (indices_.begin() + node_begin_[node_id]); auto node_end = (indices_.begin() + node_begin_[node_id] + node_length_[node_id]); auto right_node_begin = std::stable_partition(node_begin, node_end, [&](int row) { return split.SplitTrue(covariates(row, feature_split)); }); - + // Determine the number of true and false elements node_begin = (indices_.begin() + node_begin_[node_id]); data_size_t num_true = std::distance(node_begin, right_node_begin); @@ -474,11 +474,11 @@ void FeatureUnsortedPartition::PartitionNode(Eigen::MatrixXd& covariates, int no data_size_t node_start_idx = node_begin_[node_id]; data_size_t num_node_elements = node_length_[node_id]; - // Partition the node indices + // Partition the node indices auto node_begin = (indices_.begin() + node_begin_[node_id]); auto node_end = (indices_.begin() + node_begin_[node_id] + node_length_[node_id]); auto right_node_begin = std::stable_partition(node_begin, node_end, [&](int row) { return RowSplitLeft(covariates, row, feature_split, split_value); }); - + // Determine the number of true and false elements node_begin = (indices_.begin() + node_begin_[node_id]); data_size_t num_true = std::distance(node_begin, 
right_node_begin); @@ -493,11 +493,11 @@ void FeatureUnsortedPartition::PartitionNode(Eigen::MatrixXd& covariates, int no data_size_t node_start_idx = node_begin_[node_id]; data_size_t num_node_elements = node_length_[node_id]; - // Partition the node indices + // Partition the node indices auto node_begin = (indices_.begin() + node_begin_[node_id]); auto node_end = (indices_.begin() + node_begin_[node_id] + node_length_[node_id]); auto right_node_begin = std::stable_partition(node_begin, node_end, [&](int row) { return RowSplitLeft(covariates, row, feature_split, category_list); }); - + // Determine the number of true and false elements node_begin = (indices_.begin() + node_begin_[node_id]); data_size_t num_true = std::distance(node_begin, right_node_begin); @@ -536,7 +536,7 @@ void FeatureUnsortedPartition::ExpandNodeTrackingVectors(int node_id, int left_n parent_nodes_[left_node_id] = node_id; left_nodes_[left_node_id] = StochTree::Tree::kInvalidNodeId; left_nodes_[right_node_id] = StochTree::Tree::kInvalidNodeId; - + // Add right node tracking information right_nodes_[node_id] = right_node_id; node_begin_[right_node_id] = node_start_idx + num_left; @@ -578,7 +578,7 @@ bool FeatureUnsortedPartition::RightNodeIsLeaf(int node_id) { } void FeatureUnsortedPartition::PruneNodeToLeaf(int node_id) { - // No need to "un-sift" the indices in the newly pruned node, we don't depend on the indices + // No need to "un-sift" the indices in the newly pruned node, we don't depend on the indices // having any type of sort order, so the indices will simply be "re-sifted" if the node is later partitioned if (IsLeaf(node_id)) return; if (!LeftNodeIsLeaf(node_id)) { @@ -614,7 +614,7 @@ std::vector FeatureUnsortedPartition::NodeIndices(int node_id) { void FeaturePresortPartition::AddLeftRightNodes(data_size_t left_node_begin, data_size_t left_node_size, data_size_t right_node_begin, data_size_t right_node_size) { // Assumes that we aren't pruning / deleting nodes, since this is for use with recursive algorithms - + // Add the left ("true") node to the offset size vector node_offset_sizes_.emplace_back(left_node_begin, left_node_size); // Add the right ("false") node to the offset size vector @@ -627,11 +627,11 @@ void FeaturePresortPartition::SplitFeature(Eigen::MatrixXd& covariates, int32_t data_size_t node_end_idx = NodeEnd(node_id); data_size_t num_node_elements = NodeSize(node_id); - // Partition the node indices + // Partition the node indices auto node_begin = (feature_sort_indices_.begin() + node_start_idx); auto node_end = (feature_sort_indices_.begin() + node_end_idx); auto right_node_begin = std::stable_partition(node_begin, node_end, [&](int row) { return split.SplitTrue(covariates(row, feature_index)); }); - + // Add the left and right nodes to the offset size vector node_begin = (feature_sort_indices_.begin() + node_start_idx); data_size_t num_true = std::distance(node_begin, right_node_begin); @@ -645,11 +645,11 @@ void FeaturePresortPartition::SplitFeatureNumeric(Eigen::MatrixXd& covariates, i data_size_t node_end_idx = NodeEnd(node_id); data_size_t num_node_elements = NodeSize(node_id); - // Partition the node indices + // Partition the node indices auto node_begin = (feature_sort_indices_.begin() + node_start_idx); auto node_end = (feature_sort_indices_.begin() + node_end_idx); auto right_node_begin = std::stable_partition(node_begin, node_end, [&](int row) { return RowSplitLeft(covariates, row, feature_index, split_value); }); - + // Add the left and right nodes to the offset size vector 
node_begin = (feature_sort_indices_.begin() + node_start_idx); data_size_t num_true = std::distance(node_begin, right_node_begin); @@ -663,11 +663,11 @@ void FeaturePresortPartition::SplitFeatureCategorical(Eigen::MatrixXd& covariate data_size_t node_end_idx = NodeEnd(node_id); data_size_t num_node_elements = NodeSize(node_id); - // Partition the node indices + // Partition the node indices auto node_begin = (feature_sort_indices_.begin() + node_start_idx); auto node_end = (feature_sort_indices_.begin() + node_end_idx); auto right_node_begin = std::stable_partition(node_begin, node_end, [&](int row) { return RowSplitLeft(covariates, row, feature_index, category_list); }); - + // Add the left and right nodes to the offset size vector node_begin = (feature_sort_indices_.begin() + node_start_idx); data_size_t num_true = std::distance(node_begin, right_node_begin); diff --git a/src/py_stochtree.cpp b/src/py_stochtree.cpp index 66621d52..284fd211 100644 --- a/src/py_stochtree.cpp +++ b/src/py_stochtree.cpp @@ -6,12 +6,14 @@ #include #include #include +#include #include #include #include #include #include #include +#include #define STRINGIFY(x) #x #define MACRO_STRINGIFY(x) STRINGIFY(x) @@ -141,6 +143,42 @@ class ForestDatasetCpp { return dataset_.get(); } + bool HasAuxiliaryDimension(int dim_idx) { + return dataset_->HasAuxiliaryDimension(dim_idx); + } + + void AddAuxiliaryDimension(int dim_size) { + dataset_->AddAuxiliaryDimension(dim_size); + } + + double GetAuxiliaryDataValue(int dim_idx, data_size_t element_idx) { + return dataset_->GetAuxiliaryDataValue(dim_idx, element_idx); + } + + void SetAuxiliaryDataValue(int dim_idx, data_size_t element_idx, double value) { + dataset_->SetAuxiliaryDataValue(dim_idx, element_idx, value); + } + + py::array_t GetAuxiliaryDataArray(int dim_idx) { + std::vector output_vec = dataset_->GetAuxiliaryDataVector(dim_idx); + int n = output_vec.size(); + auto result = py::array_t(py::detail::any_container({n})); + auto accessor = result.mutable_unchecked<1>(); + for (size_t i = 0; i < n; i++) { + accessor(i) = output_vec[i]; + } + return result; + } + + void StoreAuxiliaryDataArrayMatrix(py::array_t output_matrix, int dim_idx, int matrix_col_idx) { + const std::vector output_raw = dataset_->GetAuxiliaryDataVector(dim_idx); + int n = output_raw.size(); + auto accessor = output_matrix.mutable_unchecked<2>(); + for (int i = 0; i < n; i++) { + accessor(i, matrix_col_idx) = output_raw[i]; + } + } + private: std::unique_ptr dataset_; }; @@ -1104,6 +1142,8 @@ class ForestSamplerCpp { else if (leaf_model_int == 1) model_type = StochTree::ModelType::kUnivariateRegressionLeafGaussian; else if (leaf_model_int == 2) model_type = StochTree::ModelType::kMultivariateRegressionLeafGaussian; else if (leaf_model_int == 3) model_type = StochTree::ModelType::kLogLinearVariance; + else if (leaf_model_int == 4) model_type = StochTree::ModelType::kCloglogOrdinal; + else StochTree::Log::Fatal("Invalid model type"); // Unpack leaf model parameters double leaf_scale; @@ -1147,6 +1187,8 @@ class ForestSamplerCpp { StochTree::GFRSampleOneIter(*active_forest_ptr, *(tracker_.get()), *forest_sample_ptr, std::get(leaf_model), *forest_data_ptr, *residual_data_ptr, *(split_prior_.get()), *rng_ptr, var_weights_vector, sweep_update_indices_, global_variance, feature_types_, cutpoint_grid_size, keep_forest, pre_initialized, true, num_features_subsample, num_threads, num_basis); } else if (model_type == StochTree::ModelType::kLogLinearVariance) { StochTree::GFRSampleOneIter(*active_forest_ptr, 
*(tracker_.get()), *forest_sample_ptr, std::get(leaf_model), *forest_data_ptr, *residual_data_ptr, *(split_prior_.get()), *rng_ptr, var_weights_vector, sweep_update_indices_, global_variance, feature_types_, cutpoint_grid_size, keep_forest, pre_initialized, false, num_features_subsample, num_threads); + } else if (model_type == StochTree::ModelType::kCloglogOrdinal) { + StochTree::GFRSampleOneIter(*active_forest_ptr, *(tracker_.get()), *forest_sample_ptr, std::get(leaf_model), *forest_data_ptr, *residual_data_ptr, *(split_prior_.get()), *rng_ptr, var_weights_vector, sweep_update_indices_, global_variance, feature_types_, cutpoint_grid_size, keep_forest, pre_initialized, false, num_features_subsample, num_threads); } } else { if (model_type == StochTree::ModelType::kConstantLeafGaussian) { @@ -1157,6 +1199,8 @@ class ForestSamplerCpp { StochTree::MCMCSampleOneIter(*active_forest_ptr, *(tracker_.get()), *forest_sample_ptr, std::get(leaf_model), *forest_data_ptr, *residual_data_ptr, *(split_prior_.get()), *rng_ptr, var_weights_vector, sweep_update_indices_, global_variance, keep_forest, pre_initialized, true, num_threads, num_basis); } else if (model_type == StochTree::ModelType::kLogLinearVariance) { StochTree::MCMCSampleOneIter(*active_forest_ptr, *(tracker_.get()), *forest_sample_ptr, std::get(leaf_model), *forest_data_ptr, *residual_data_ptr, *(split_prior_.get()), *rng_ptr, var_weights_vector, sweep_update_indices_, global_variance, keep_forest, pre_initialized, false, num_threads); + } else if (model_type == StochTree::ModelType::kCloglogOrdinal) { + StochTree::MCMCSampleOneIter(*active_forest_ptr, *(tracker_.get()), *forest_sample_ptr, std::get(leaf_model), *forest_data_ptr, *residual_data_ptr, *(split_prior_.get()), *rng_ptr, var_weights_vector, sweep_update_indices_, global_variance, keep_forest, pre_initialized, false, num_threads); } } } @@ -1169,6 +1213,7 @@ class ForestSamplerCpp { else if (leaf_model_int == 1) model_type = StochTree::ModelType::kUnivariateRegressionLeafGaussian; else if (leaf_model_int == 2) model_type = StochTree::ModelType::kMultivariateRegressionLeafGaussian; else if (leaf_model_int == 3) model_type = StochTree::ModelType::kLogLinearVariance; + else if (leaf_model_int == 4) model_type = StochTree::ModelType::kCloglogOrdinal; else StochTree::Log::Fatal("Invalid model type"); // Unpack initial value @@ -1213,6 +1258,10 @@ class ForestSamplerCpp { int n = forest_data_ptr->NumObservations(); std::vector initial_preds(n, init_val); forest_data_ptr->AddVarianceWeights(initial_preds.data(), n); + } else if (model_type == StochTree::ModelType::kCloglogOrdinal) { + leaf_init_val = std::log(init_val) / static_cast(num_trees); + forest_ptr->SetLeafValue(leaf_init_val); + tracker_->UpdatePredictions(forest_ptr, *forest_data_ptr); } } @@ -1275,6 +1324,46 @@ class ForestSamplerCpp { std::unique_ptr split_prior_; }; +class OrdinalSamplerCpp { + public: + OrdinalSamplerCpp() { + // Initialize pointer to C++ OrdinalSampler classes + ordinal_sampler_ = std::make_unique(); + } + ~OrdinalSamplerCpp() {} + + double SampleTruncatedExponential(RngCpp& rng, double rate, double lower_bound = 0.0, double upper_bound = 1.0) { + std::mt19937* rng_ptr = rng.GetRng(); + return ordinal_sampler_->SampleTruncatedExponential(*rng_ptr, rate, lower_bound, upper_bound); + } + + void UpdateLatentVariables(ForestDatasetCpp& dataset, ResidualCpp& residual, RngCpp& rng) { + StochTree::ForestDataset* dataset_ptr = dataset.GetDataset(); + StochTree::ColumnVector* residual_ptr = residual.GetData(); + 
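// Note: for the cloglog ordinal model, the ColumnVector passed in here as "residual" stores
+    // the 0-based ordinal outcome y_i in {0, ..., K-1} rather than a residual; the underlying
+    // OrdinalSampler reads it to choose each latent z_i's truncation rate exp(gamma[y_i] + lambda_hat_i).
+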
Eigen::VectorXd& residual_data_eigen = residual_ptr->GetData(); + std::mt19937* rng_ptr = rng.GetRng(); + ordinal_sampler_->UpdateLatentVariables(*dataset_ptr, residual_data_eigen, *rng_ptr); + } + + void UpdateGammaParams(ForestDatasetCpp& dataset, ResidualCpp& residual, + double alpha_gamma, double beta_gamma, + double gamma_0, RngCpp& rng) { + StochTree::ForestDataset* dataset_ptr = dataset.GetDataset(); + StochTree::ColumnVector* residual_ptr = residual.GetData(); + Eigen::VectorXd& residual_data_eigen = residual_ptr->GetData(); + std::mt19937* rng_ptr = rng.GetRng(); + ordinal_sampler_->UpdateGammaParams(*dataset_ptr, residual_data_eigen, alpha_gamma, beta_gamma, gamma_0, *rng_ptr); + } + + void UpdateCumulativeExpSums(ForestDatasetCpp& dataset) { + StochTree::ForestDataset* dataset_ptr = dataset.GetDataset(); + ordinal_sampler_->UpdateCumulativeExpSums(*dataset_ptr); + } + + private: + std::unique_ptr ordinal_sampler_; +}; + class GlobalVarianceModelCpp { public: GlobalVarianceModelCpp() { @@ -2148,7 +2237,13 @@ PYBIND11_MODULE(stochtree_cpp, m) { .def("GetBasis", &ForestDatasetCpp::GetBasis) .def("GetVarianceWeights", &ForestDatasetCpp::GetVarianceWeights) .def("HasBasis", &ForestDatasetCpp::HasBasis) - .def("HasVarianceWeights", &ForestDatasetCpp::HasVarianceWeights); + .def("HasVarianceWeights", &ForestDatasetCpp::HasVarianceWeights) + .def("HasAuxiliaryDimension", &ForestDatasetCpp::HasAuxiliaryDimension) + .def("AddAuxiliaryDimension", &ForestDatasetCpp::AddAuxiliaryDimension) + .def("GetAuxiliaryDataValue", &ForestDatasetCpp::GetAuxiliaryDataValue) + .def("SetAuxiliaryDataValue", &ForestDatasetCpp::SetAuxiliaryDataValue) + .def("GetAuxiliaryDataArray", &ForestDatasetCpp::GetAuxiliaryDataArray) + .def("StoreAuxiliaryDataArrayMatrix", &ForestDatasetCpp::StoreAuxiliaryDataArrayMatrix); py::class_(m, "ResidualCpp") .def(py::init,data_size_t>()) @@ -2339,6 +2434,13 @@ PYBIND11_MODULE(stochtree_cpp, m) { .def(py::init<>()) .def("SampleOneIteration", &LeafVarianceModelCpp::SampleOneIteration); + py::class_(m, "OrdinalSamplerCpp") + .def(py::init<>()) + .def("SampleTruncatedExponential", &OrdinalSamplerCpp::SampleTruncatedExponential) + .def("UpdateLatentVariables", &OrdinalSamplerCpp::UpdateLatentVariables) + .def("UpdateGammaParams", &OrdinalSamplerCpp::UpdateGammaParams) + .def("UpdateCumulativeExpSums", &OrdinalSamplerCpp::UpdateCumulativeExpSums); + #ifdef VERSION_INFO m.attr("__version__") = MACRO_STRINGIFY(VERSION_INFO); #else diff --git a/src/sampler.cpp b/src/sampler.cpp index 212ccb42..f356d968 100644 --- a/src/sampler.cpp +++ b/src/sampler.cpp @@ -4,38 +4,39 @@ #include #include #include +#include #include #include #include #include [[cpp11::register]] -void sample_gfr_one_iteration_cpp(cpp11::external_pointer data, - cpp11::external_pointer residual, - cpp11::external_pointer forest_samples, - cpp11::external_pointer active_forest, - cpp11::external_pointer tracker, - cpp11::external_pointer split_prior, - cpp11::external_pointer rng, - cpp11::integers sweep_indices, - cpp11::integers feature_types, int cutpoint_grid_size, - cpp11::doubles_matrix<> leaf_model_scale_input, - cpp11::doubles variable_weights, +void sample_gfr_one_iteration_cpp(cpp11::external_pointer data, + cpp11::external_pointer residual, + cpp11::external_pointer forest_samples, + cpp11::external_pointer active_forest, + cpp11::external_pointer tracker, + cpp11::external_pointer split_prior, + cpp11::external_pointer rng, + cpp11::integers sweep_indices, + cpp11::integers feature_types, int 
cutpoint_grid_size, + cpp11::doubles_matrix<> leaf_model_scale_input, + cpp11::doubles variable_weights, double a_forest, double b_forest, - double global_variance, int leaf_model_int, + double global_variance, int leaf_model_int, bool keep_forest, int num_features_subsample, int num_threads ) { // Refactoring completely out of the R interface. // Intention to refactor out of the C++ interface in the future. bool pre_initialized = true; - + // Unpack feature types std::vector feature_types_(feature_types.size()); for (int i = 0; i < feature_types.size(); i++) { feature_types_[i] = static_cast(feature_types[i]); } - + // Unpack sweep indices std::vector sweep_indices_(sweep_indices.size()); // if (sweep_indices.size() > 0) { @@ -44,19 +45,20 @@ void sample_gfr_one_iteration_cpp(cpp11::external_pointer var_weights_vector(variable_weights.size()); for (int i = 0; i < variable_weights.size(); i++) { var_weights_vector[i] = variable_weights[i]; } - + // Prepare the samplers StochTree::LeafModelVariant leaf_model = StochTree::leafModelFactory(model_type, leaf_scale, leaf_scale_matrix, a_forest, b_forest); int num_basis = data->NumBasis(); - + // Run one iteration of the sampler if (model_type == StochTree::ModelType::kConstantLeafGaussian) { StochTree::GFRSampleOneIter(*active_forest, *tracker, *forest_samples, std::get(leaf_model), *data, *residual, *split_prior, *rng, var_weights_vector, sweep_indices_, global_variance, feature_types_, cutpoint_grid_size, keep_forest, pre_initialized, true, num_features_subsample, num_threads); @@ -89,35 +91,37 @@ void sample_gfr_one_iteration_cpp(cpp11::external_pointer(*active_forest, *tracker, *forest_samples, std::get(leaf_model), *data, *residual, *split_prior, *rng, var_weights_vector, sweep_indices_, global_variance, feature_types_, cutpoint_grid_size, keep_forest, pre_initialized, true, num_features_subsample, num_threads, num_basis); } else if (model_type == StochTree::ModelType::kLogLinearVariance) { StochTree::GFRSampleOneIter(*active_forest, *tracker, *forest_samples, std::get(leaf_model), *data, *residual, *split_prior, *rng, var_weights_vector, sweep_indices_, global_variance, feature_types_, cutpoint_grid_size, keep_forest, pre_initialized, false, num_features_subsample, num_threads); + } else if (model_type == StochTree::ModelType::kCloglogOrdinal) { + StochTree::GFRSampleOneIter(*active_forest, *tracker, *forest_samples, std::get(leaf_model), *data, *residual, *split_prior, *rng, var_weights_vector, sweep_indices_, global_variance, feature_types_, cutpoint_grid_size, keep_forest, pre_initialized, false, num_features_subsample, num_threads); } } [[cpp11::register]] -void sample_mcmc_one_iteration_cpp(cpp11::external_pointer data, - cpp11::external_pointer residual, - cpp11::external_pointer forest_samples, - cpp11::external_pointer active_forest, - cpp11::external_pointer tracker, - cpp11::external_pointer split_prior, - cpp11::external_pointer rng, - cpp11::integers sweep_indices, - cpp11::integers feature_types, int cutpoint_grid_size, - cpp11::doubles_matrix<> leaf_model_scale_input, - cpp11::doubles variable_weights, +void sample_mcmc_one_iteration_cpp(cpp11::external_pointer data, + cpp11::external_pointer residual, + cpp11::external_pointer forest_samples, + cpp11::external_pointer active_forest, + cpp11::external_pointer tracker, + cpp11::external_pointer split_prior, + cpp11::external_pointer rng, + cpp11::integers sweep_indices, + cpp11::integers feature_types, int cutpoint_grid_size, + cpp11::doubles_matrix<> leaf_model_scale_input, + 
cpp11::doubles variable_weights, double a_forest, double b_forest, - double global_variance, int leaf_model_int, + double global_variance, int leaf_model_int, bool keep_forest, int num_threads ) { // Refactoring completely out of the R interface. // Intention to refactor out of the C++ interface in the future. bool pre_initialized = true; - + // Unpack feature types std::vector feature_types_(feature_types.size()); for (int i = 0; i < feature_types.size(); i++) { feature_types_[i] = static_cast(feature_types[i]); } - + // Unpack sweep indices std::vector sweep_indices_; if (sweep_indices.size() > 0) { @@ -126,19 +130,20 @@ void sample_mcmc_one_iteration_cpp(cpp11::external_pointer var_weights_vector(variable_weights.size()); for (int i = 0; i < variable_weights.size(); i++) { var_weights_vector[i] = variable_weights[i]; } - + // Prepare the samplers StochTree::LeafModelVariant leaf_model = StochTree::leafModelFactory(model_type, leaf_scale, leaf_scale_matrix, a_forest, b_forest); int num_basis = data->NumBasis(); - + // Run one iteration of the sampler if (model_type == StochTree::ModelType::kConstantLeafGaussian) { StochTree::MCMCSampleOneIter(*active_forest, *tracker, *forest_samples, std::get(leaf_model), *data, *residual, *split_prior, *rng, var_weights_vector, sweep_indices_, global_variance, keep_forest, pre_initialized, true, num_threads); @@ -171,13 +176,15 @@ void sample_mcmc_one_iteration_cpp(cpp11::external_pointer(*active_forest, *tracker, *forest_samples, std::get(leaf_model), *data, *residual, *split_prior, *rng, var_weights_vector, sweep_indices_, global_variance, keep_forest, pre_initialized, true, num_threads, num_basis); } else if (model_type == StochTree::ModelType::kLogLinearVariance) { StochTree::MCMCSampleOneIter(*active_forest, *tracker, *forest_samples, std::get(leaf_model), *data, *residual, *split_prior, *rng, var_weights_vector, sweep_indices_, global_variance, keep_forest, pre_initialized, false, num_threads); + } else if (model_type == StochTree::ModelType::kCloglogOrdinal) { + StochTree::MCMCSampleOneIter(*active_forest, *tracker, *forest_samples, std::get(leaf_model), *data, *residual, *split_prior, *rng, var_weights_vector, sweep_indices_, global_variance, keep_forest, pre_initialized, false, num_threads); } } [[cpp11::register]] -double sample_sigma2_one_iteration_cpp(cpp11::external_pointer residual, - cpp11::external_pointer dataset, - cpp11::external_pointer rng, +double sample_sigma2_one_iteration_cpp(cpp11::external_pointer residual, + cpp11::external_pointer dataset, + cpp11::external_pointer rng, double a, double b ) { // Run one iteration of the sampler @@ -190,8 +197,8 @@ double sample_sigma2_one_iteration_cpp(cpp11::external_pointer active_forest, - cpp11::external_pointer rng, +double sample_tau_one_iteration_cpp(cpp11::external_pointer active_forest, + cpp11::external_pointer rng, double a, double b ) { // Run one iteration of the sampler @@ -208,7 +215,7 @@ cpp11::external_pointer rng_cpp(int random_seed = -1) { } else { rng_ = std::make_unique(random_seed); } - + // Release management of the pointer to R session return cpp11::external_pointer(rng_.release()); } @@ -217,7 +224,7 @@ cpp11::external_pointer rng_cpp(int random_seed = -1) { cpp11::external_pointer tree_prior_cpp(double alpha, double beta, int min_samples_leaf, int max_depth = -1) { // Create smart pointer to newly allocated object std::unique_ptr prior_ptr_ = std::make_unique(alpha, beta, min_samples_leaf, max_depth); - + // Release management of the pointer to R session return 
cpp11::external_pointer(prior_ptr_.release()); } @@ -274,10 +281,10 @@ cpp11::external_pointer forest_tracker_cpp(cpp11::exte for (int i = 0; i < feature_types.size(); i++) { feature_types_[i] = static_cast(feature_types[i]); } - + // Create smart pointer to newly allocated object std::unique_ptr tracker_ptr_ = std::make_unique(data->GetCovariates(), feature_types_, num_trees, n); - + // Release management of the pointer to R session return cpp11::external_pointer(tracker_ptr_.release()); } @@ -294,8 +301,8 @@ cpp11::writable::doubles get_cached_forest_predictions_cpp(cpp11::external_point [[cpp11::register]] cpp11::writable::integers sample_without_replacement_integer_cpp( - cpp11::integers population_vector, - cpp11::doubles sampling_probs, + cpp11::integers population_vector, + cpp11::doubles sampling_probs, int sample_size ) { // Unpack pointer to population vector @@ -307,14 +314,14 @@ cpp11::writable::integers sample_without_replacement_integer_cpp( // Create output vector cpp11::writable::integers output(sample_size); - + // Unpack pointer to output vector int* output_ptr = INTEGER(PROTECT(output)); // Create C++ RNG std::random_device rd; std::mt19937 gen(rd()); - + // Run the sampler StochTree::sample_without_replacement( output_ptr, sampling_probs_ptr, population_vector_ptr, population_size, sample_size, gen @@ -326,3 +333,40 @@ cpp11::writable::integers sample_without_replacement_integer_cpp( // Return result return(output); } + +[[cpp11::register]] +cpp11::external_pointer ordinal_sampler_cpp() { + std::unique_ptr sampler_ptr = std::make_unique(); + return cpp11::external_pointer(sampler_ptr.release()); +} + +[[cpp11::register]] +void ordinal_sampler_update_latent_variables_cpp( + cpp11::external_pointer sampler_ptr, + cpp11::external_pointer data_ptr, + cpp11::external_pointer outcome_ptr, + cpp11::external_pointer rng_ptr +) { + sampler_ptr->UpdateLatentVariables(*data_ptr, outcome_ptr->GetData(), *rng_ptr); +} + +[[cpp11::register]] +void ordinal_sampler_update_gamma_params_cpp( + cpp11::external_pointer sampler_ptr, + cpp11::external_pointer data_ptr, + cpp11::external_pointer outcome_ptr, + double alpha_gamma, + double beta_gamma, + double gamma_0, + cpp11::external_pointer rng_ptr +) { + sampler_ptr->UpdateGammaParams(*data_ptr, outcome_ptr->GetData(), alpha_gamma, beta_gamma, gamma_0, *rng_ptr); +} + +[[cpp11::register]] +void ordinal_sampler_update_cumsum_exp_cpp( + cpp11::external_pointer sampler_ptr, + cpp11::external_pointer data_ptr +) { + sampler_ptr->UpdateCumulativeExpSums(*data_ptr); +} diff --git a/src/stochtree_types.h b/src/stochtree_types.h index d3d6327c..9f4e77df 100644 --- a/src/stochtree_types.h +++ b/src/stochtree_types.h @@ -1,8 +1,10 @@ #include #include #include +#include #include #include +#include #include #include #include diff --git a/stochtree/__init__.py b/stochtree/__init__.py index 318f6219..0b71557c 100644 --- a/stochtree/__init__.py +++ b/stochtree/__init__.py @@ -1,6 +1,7 @@ from .bart import BARTModel from .bcf import BCFModel from .calibration import calibrate_global_error_variance +from .cloglog_ordinal_bart import CloglogOrdinalBARTModel from .config import ForestModelConfig, GlobalModelConfig from .data import Dataset, Residual from .forest import Forest, ForestContainer @@ -39,6 +40,7 @@ __all__ = [ "BARTModel", "BCFModel", + "CloglogOrdinalBARTModel", "Dataset", "Residual", "ForestContainer", diff --git a/stochtree/cloglog_ordinal_bart.py b/stochtree/cloglog_ordinal_bart.py new file mode 100644 index 00000000..5137bf24 --- /dev/null 
+++ b/stochtree/cloglog_ordinal_bart.py
@@ -0,0 +1,698 @@
+import warnings
+from math import log
+from numbers import Integral
+from typing import Any, Dict, Optional, Union
+
+import numpy as np
+import pandas as pd
+from scipy.stats import norm
+
+from stochtree_cpp import OrdinalSamplerCpp
+from .config import ForestModelConfig, GlobalModelConfig
+from .data import Dataset, Residual
+from .forest import Forest, ForestContainer
+from .preprocessing import CovariatePreprocessor, _preprocess_params
+from .sampler import RNG, ForestSampler, GlobalVarianceModel, LeafVarianceModel
+from .serialization import JSONSerializer
+from .utils import (
+    NotSampledError,
+    _expand_dims_1d,
+    _expand_dims_2d,
+    _expand_dims_2d_diag,
+)
+
+
+class CloglogOrdinalBARTModel:
+    r"""
+    Class that handles sampling, storage, and serialization of BART models with a cloglog link
+    for ordinal outcomes. This is an implementation of the model of Alam and Linero (2025), in
+    which y is an ordinal outcome with K categories, ordered from 0 to K-1.
+    """
+
+    def __init__(self) -> None:
+        # Internal flag for whether the sample() method has been run
+        self.sampled = False
+
+    def sample(
+        self,
+        X_train: Union[np.array, pd.DataFrame],
+        y_train: np.array,
+        X_test: Union[np.array, pd.DataFrame] = None,
+        n_trees: int = 50,
+        num_gfr: int = 0,
+        num_burnin: int = 1000,
+        num_mcmc: int = 500,
+        n_thin: int = 1,
+        alpha_gamma: float = 2.0,
+        beta_gamma: float = 2.0,
+        variable_weights: np.array = None,
+        feature_types: np.array = None,
+        seed: int = None,
+        num_threads: int = 1,
+    ) -> None:
+        """Runs a cloglog BART sampler on the provided training set. Predictions will be cached
+        for the training set and (if provided) the test set.
+
+        Parameters
+        ----------
+        X_train : np.array
+            Training set covariates on which trees may be partitioned.
+        y_train : np.array
+            Training set outcome (must be integer-valued from 0 to K-1, where K is the number of outcome categories).
+        X_test : np.array, optional
+            Optional test set covariates.
+        n_trees : int, optional
+            Number of trees in the BART ensemble. Defaults to `50`.
+        num_gfr : int, optional
+            Number of "warm-start" iterations run using the grow-from-root algorithm (He and Hahn, 2021). Defaults to `0`.
+        num_burnin : int, optional
+            Number of "burn-in" iterations of the MCMC sampler. Defaults to `1000`.
+        num_mcmc : int, optional
+            Number of "retained" iterations of the MCMC sampler. Defaults to `500`.
+        n_thin : int, optional
+            Thinning interval for MCMC samples. Defaults to `1` (no thinning).
+        alpha_gamma : float, optional
+            Shape parameter for the log-gamma prior on cutpoints. Defaults to `2.0`.
+        beta_gamma : float, optional
+            Rate parameter for the log-gamma prior on cutpoints. Defaults to `2.0`.
+        variable_weights : np.array, optional
+            Variable weights for covariate selection probabilities. If `None`, uniform weights are used.
+        feature_types : np.array, optional
+            Indicators of feature type for each covariate, as consumed by `ForestModelConfig`. If `None`, feature types are inferred by the covariate preprocessor.
+        seed : int, optional
+            Random seed for reproducibility. If `None`, a random seed is used.
+        num_threads : int, optional
+            Number of threads to use for parallel processing. Defaults to `1`.
+
+        Returns
+        -------
+        None
+            Sampled forests, cutpoints, and latent variables are cached on the model object.
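+
+        Examples
+        --------
+        A minimal usage sketch on synthetic data (illustrative only; the data-generating
+        process, sizes, and settings below are arbitrary):
+
+        >>> import numpy as np
+        >>> from stochtree import CloglogOrdinalBARTModel
+        >>> gen = np.random.default_rng(2025)
+        >>> X = gen.uniform(size=(200, 3))
+        >>> latent = X[:, 0] + gen.normal(scale=0.5, size=200)
+        >>> y = np.digitize(latent, bins=[0.33, 0.66, 1.0]).astype(np.int64)  # 4 ordered categories, coded 0-3
+        >>> model = CloglogOrdinalBARTModel()
+        >>> model.sample(X_train=X, y_train=y, num_burnin=100, num_mcmc=100, seed=42)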
+        """
+        # Check data inputs
+        if not isinstance(X_train, pd.DataFrame) and not isinstance(
+            X_train, np.ndarray
+        ):
+            raise ValueError("X_train must be a pandas dataframe or numpy array")
+        if X_test is not None:
+            if not isinstance(X_test, pd.DataFrame) and not isinstance(
+                X_test, np.ndarray
+            ):
+                raise ValueError("X_test must be a pandas dataframe or numpy array")
+        if not isinstance(y_train, np.ndarray):
+            raise ValueError("y_train must be a numpy array")
+        if y_train.dtype not in [np.int32, np.int64]:
+            raise ValueError("y_train must be an integer-valued numpy array")
+        if np.any(y_train < 0):
+            raise ValueError("y_train must be non-negative integer-valued")
+
+        # Convert everything to standard shape (2-dimensional)
+        if isinstance(X_train, np.ndarray):
+            if X_train.ndim == 1:
+                X_train = np.expand_dims(X_train, 1)
+        if y_train.ndim == 1:
+            y_train = np.expand_dims(y_train, 1)
+        if X_test is not None:
+            if isinstance(X_test, np.ndarray):
+                if X_test.ndim == 1:
+                    X_test = np.expand_dims(X_test, 1)
+
+        # Data checks
+        if X_test is not None:
+            if X_test.shape[1] != X_train.shape[1]:
+                raise ValueError(
+                    "X_train and X_test must have the same number of columns"
+                )
+        if y_train.shape[0] != X_train.shape[0]:
+            raise ValueError("X_train and y_train must have the same number of rows")
+
+        # Variable weight preprocessing (and initialization if necessary)
+        p = X_train.shape[1]
+        if variable_weights is None:
+            if X_train.ndim > 1:
+                variable_weights = np.repeat(1.0 / p, p)
+            else:
+                variable_weights = np.repeat(1.0, 1)
+        if np.any(variable_weights < 0):
+            raise ValueError("variable_weights cannot have any negative weights")
+
+        # Covariate preprocessing
+        self._covariate_preprocessor = CovariatePreprocessor()
+        self._covariate_preprocessor.fit(X_train)
+        X_train_processed = self._covariate_preprocessor.transform(X_train)
+        if X_test is not None:
+            X_test_processed = self._covariate_preprocessor.transform(X_test)
+        feature_types = np.asarray(
+            self._covariate_preprocessor._processed_feature_types
+        )
+        original_var_indices = (
+            self._covariate_preprocessor.fetch_original_feature_indices()
+        )
+
+        # Update variable weights if the covariates have been resized (by e.g. one-hot encoding)
+        if X_train_processed.shape[1] != X_train.shape[1]:
+            variable_counts = [
+                original_var_indices.count(i) for i in original_var_indices
+            ]
+            variable_weights_adj = np.array([1 / i for i in variable_counts])
+            variable_weights = (
+                variable_weights[original_var_indices] * variable_weights_adj
+            )
+
+        # Determine whether a test set is provided
+        self.has_test = X_test is not None
+
+        # Unpack data dimensions
+        self.n_train = y_train.shape[0]
+        self.n_test = X_test_processed.shape[0] if self.has_test else 0
+        self.num_covariates = X_train_processed.shape[1]
+
+        # Determine number of outcome categories
+        self.n_levels = np.max(np.unique(np.squeeze(y_train))) + 1
+
+        # Check that there are at least 2 outcome categories
+        if self.n_levels < 2:
+            raise ValueError("y_train must have at least 2 outcome categories")
+
+        # BART parameters
+        alpha_bart = 0.95
+        beta_bart = 2
+        min_samples_in_leaf = 5
+        max_depth = 10
+        scale_leaf = 2 / np.sqrt(n_trees)
+        cutpoint_grid_size = 100
+
+        # Fixed for identifiability (can be passed as an argument later if desired)
+        gamma_0 = 0.0  # First gamma cutpoint fixed at gamma_0 = 0
+
+        # Indices of MCMC samples to keep after GFR, burn-in, and thinning
+        keep_idx = np.arange(
+            num_gfr + num_burnin, num_gfr + num_burnin + num_mcmc, n_thin
+        )
+        n_keep = len(keep_idx)
+
+        # Container of parameter samples / model draws
+        self.num_gfr = num_gfr
+        self.num_burnin = num_burnin
+        self.num_mcmc = num_mcmc
+        self.forest_pred_train = np.empty((self.n_train, n_keep), dtype=np.float64)
+        if self.has_test:
+            self.forest_pred_test = np.empty((self.n_test, n_keep), dtype=np.float64)
+        self.gamma_samples = np.empty((self.n_levels - 1, n_keep), dtype=np.float64)
+        self.latent_samples = np.empty((self.n_train, n_keep), dtype=np.float64)
+
+        # Initialize samplers
+        ordinal_sampler_cpp = OrdinalSamplerCpp()
+        if seed is None:
+            cpp_rng = RNG(-1)
+            self.rng = np.random.default_rng()
+        else:
+            cpp_rng = RNG(seed)
+            self.rng = np.random.default_rng(seed)
+
+        # Data structures
+        forest_dataset_train = Dataset()
+        forest_dataset_train.add_covariates(X_train_processed)
+        if self.has_test:
+            forest_dataset_test = Dataset()
+            forest_dataset_test.add_covariates(X_test_processed)
+        outcome_train = Residual(y_train)
+        active_forest = Forest(n_trees, 1, True, False)
+        active_forest.set_root_leaves(0.0)
+        self.forest_samples = ForestContainer(n_trees, 1, True, False)
+        global_model_config = GlobalModelConfig(global_error_variance=1.0)
+        forest_model_config = ForestModelConfig(
+            num_trees=n_trees,
+            num_features=self.num_covariates,
+            num_observations=self.n_train,
+            feature_types=feature_types,
+            variable_weights=variable_weights,
+            leaf_dimension=1,
+            alpha=alpha_bart,
+            beta=beta_bart,
+            min_samples_leaf=min_samples_in_leaf,
+            max_depth=max_depth,
+            leaf_model_type=4,
+            cutpoint_grid_size=cutpoint_grid_size,
+            leaf_model_scale=scale_leaf,
+        )
+        forest_sampler = ForestSampler(
+            forest_dataset_train, global_model_config, forest_model_config
+        )
+
+        # Latent variable (Z in Alam et al (2025) notation)
+        forest_dataset_train.add_auxiliary_dimension(self.n_train)
+        # Forest predictions (eta in Alam et al (2025) notation)
+        forest_dataset_train.add_auxiliary_dimension(self.n_train)
+        # Log-scale non-cumulative cutpoint (gamma in Alam et al (2025) notation)
+        forest_dataset_train.add_auxiliary_dimension(self.n_levels - 1)
+        # Exponentiated cumulative cutpoints (exp(c_k) in Alam et al (2025) notation)
+        # This auxiliary series is designed so that the element stored at position `i`
+        # corresponds to the sum of all exponentiated gamma_j values for j < i.
+        # It has n_levels elements instead of n_levels - 1 because even the largest
+        # categorical index has a valid value of sum_{j < i} exp(gamma_j)
+        forest_dataset_train.add_auxiliary_dimension(self.n_levels)
+
+        # Initialize gamma parameters to zero (3rd auxiliary data series, mapped to `dim_idx = 2` with 0-indexing)
+        initial_gamma = np.zeros((self.n_levels - 1,), dtype=np.float64)
+        for i in range(self.n_levels - 1):
+            forest_dataset_train.set_auxiliary_data_value(2, i, initial_gamma[i])
+
+        # Convert the log-scale parameters into cumulative exponentiated parameters.
+        # This is done under the hood in a C++ function for efficiency.
+        ordinal_sampler_cpp.UpdateCumulativeExpSums(forest_dataset_train.dataset_cpp)
+
+        # Initialize forest predictions to zero (slot 1)
+        for i in range(self.n_train):
+            forest_dataset_train.set_auxiliary_data_value(1, i, 0.0)
+
+        # Initialize latent variables to zero (slot 0)
+        for i in range(self.n_train):
+            forest_dataset_train.set_auxiliary_data_value(0, i, 0.0)
+
+        # Run the algorithm
+        sample_counter = -1
+        for i in range(num_gfr + num_burnin + num_mcmc):
+            keep_sample = i in keep_idx
+            if keep_sample:
+                sample_counter += 1
+
+            # 1. Sample forest using MCMC
+            if i > self.num_gfr - 1:
+                forest_sampler.sample_one_iteration(
+                    self.forest_samples,
+                    active_forest,
+                    forest_dataset_train,
+                    outcome_train,
+                    cpp_rng,
+                    global_model_config,
+                    forest_model_config,
+                    keep_sample,
+                    True,
+                    num_threads,
+                )
+            else:
+                forest_sampler.sample_one_iteration(
+                    self.forest_samples,
+                    active_forest,
+                    forest_dataset_train,
+                    outcome_train,
+                    cpp_rng,
+                    global_model_config,
+                    forest_model_config,
+                    keep_sample,
+                    False,
+                    num_threads,
+                )
+
+            # Set auxiliary data slot 1 to current forest predictions = lambda_hat = sum of all the tree predictions
+            # This is needed for updating gamma parameters, latent z_i's
+            # (use a distinct loop variable so the MCMC iteration index `i` is not clobbered)
+            forest_pred_current = active_forest.predict(forest_dataset_train)
+            for obs_idx in range(self.n_train):
+                forest_dataset_train.set_auxiliary_data_value(
+                    1, obs_idx, forest_pred_current[obs_idx]
+                )
+
+            # 2. Sample latent z_i's using truncated exponential
+            ordinal_sampler_cpp.UpdateLatentVariables(
+                forest_dataset_train.dataset_cpp,
+                outcome_train.residual_cpp,
+                cpp_rng.rng_cpp,
+            )
+
+            # 3. Sample gamma cutpoints
+            ordinal_sampler_cpp.UpdateGammaParams(
+                forest_dataset_train.dataset_cpp,
+                outcome_train.residual_cpp,
+                alpha_gamma,
+                beta_gamma,
+                gamma_0,
+                cpp_rng.rng_cpp,
+            )
+
+            # 4. Update cumulative sum of exp(gamma) values
+            ordinal_sampler_cpp.UpdateCumulativeExpSums(
+                forest_dataset_train.dataset_cpp
+            )
+
+            if keep_sample:
+                self.forest_pred_train[:, sample_counter] = active_forest.predict(
+                    forest_dataset_train
+                )
+                if self.has_test:
+                    self.forest_pred_test[:, sample_counter] = active_forest.predict(
+                        forest_dataset_test
+                    )
+                gamma_current = forest_dataset_train.get_auxiliary_data_array(2)
+                self.gamma_samples[:, sample_counter] = gamma_current
+                latent_current = forest_dataset_train.get_auxiliary_data_array(0)
+                self.latent_samples[:, sample_counter] = latent_current
+
+        # Mark the model as sampled
+        self.sampled = True
+
+    def predict(
+        self,
+        X: Union[np.array, pd.DataFrame],
+    ) -> np.array:
+        """Return predictions from the cloglog forest.
+
+        Parameters
+        ----------
+        X : np.array
+            Test set covariates.
+
+        Returns
+        -------
+        lambda_x : np.array
+            Cloglog forest predictions (one column per retained sample)
+        """
+        if not self.is_sampled():
+            msg = (
+                "This CloglogOrdinalBARTModel instance is not sampled yet. Call 'sample' with "
+                "appropriate arguments before using this model."
+            )
+            raise NotSampledError(msg)
+
+        # Data checks
+        if not isinstance(X, pd.DataFrame) and not isinstance(X, np.ndarray):
+            raise ValueError("X must be a pandas dataframe or numpy array")
+
+        # Convert everything to standard shape (2-dimensional)
+        if isinstance(X, np.ndarray):
+            if X.ndim == 1:
+                X = np.expand_dims(X, 1)
+
+        # Covariate preprocessing
+        if not self._covariate_preprocessor._check_is_fitted():
+            if not isinstance(X, np.ndarray):
+                raise ValueError(
+                    "Prediction cannot proceed on a pandas dataframe, since the BART model was not fit with a covariate preprocessor. Please refit your model by passing covariate data as a Pandas dataframe."
+                )
+            else:
+                warnings.warn(
+                    "This BART model has not run any covariate preprocessing routines. We will attempt to predict on the raw covariate values, but this will trigger an error with non-numeric columns. Please refit your model by passing non-numeric covariate data as a Pandas dataframe.",
+                    RuntimeWarning,
+                )
+                if not np.issubdtype(X.dtype, np.floating) and not np.issubdtype(
+                    X.dtype, np.integer
+                ):
+                    raise ValueError(
+                        "Prediction cannot proceed on a non-numeric numpy array, since the BART model was not fit with a covariate preprocessor. Please refit your model by passing non-numeric covariate data as a Pandas dataframe."
+                    )
+                X_processed = X
+        else:
+            X_processed = self._covariate_preprocessor.transform(X)
+
+        # Dataset construction
+        pred_dataset = Dataset()
+        pred_dataset.add_covariates(X_processed)
+
+        # Forest predictions
+        forest_pred = self.forest_samples.forest_container_cpp.Predict(
+            pred_dataset.dataset_cpp
+        )
+
+        return forest_pred
+
+    # def to_json(self) -> str:
+    #     """
+    #     Converts a sampled BART model to JSON string representation (which can then be saved to a file or
+    #     processed using the `json` library)
+
+    #     Returns
+    #     -------
+    #     str
+    #         JSON string representing model metadata (hyperparameters), sampled parameters, and sampled forests
+    #     """
+    #     if not self.is_sampled:
+    #         msg = (
+    #             "This BARTModel instance has not yet been sampled. "
+    #             "Call 'fit' with appropriate arguments before using this model."
+ # ) + # raise NotSampledError(msg) + + # # Initialize JSONSerializer object + # bart_json = JSONSerializer() + + # # Add the forests + # if self.include_mean_forest: + # bart_json.add_forest(self.forest_container_mean) + # if self.include_variance_forest: + # bart_json.add_forest(self.forest_container_variance) + + # # Add the rfx + # if self.has_rfx: + # bart_json.add_random_effects(self.rfx_container) + + # # Add global parameters + # bart_json.add_scalar("outcome_scale", self.y_std) + # bart_json.add_scalar("outcome_mean", self.y_bar) + # bart_json.add_boolean("standardize", self.standardize) + # bart_json.add_scalar("sigma2_init", self.sigma2_init) + # bart_json.add_boolean("sample_sigma2_global", self.sample_sigma2_global) + # bart_json.add_boolean("sample_sigma2_leaf", self.sample_sigma2_leaf) + # bart_json.add_boolean("include_mean_forest", self.include_mean_forest) + # bart_json.add_boolean("include_variance_forest", self.include_variance_forest) + # bart_json.add_boolean("has_rfx", self.has_rfx) + # bart_json.add_integer("num_gfr", self.num_gfr) + # bart_json.add_integer("num_burnin", self.num_burnin) + # bart_json.add_integer("num_mcmc", self.num_mcmc) + # bart_json.add_integer("num_samples", self.num_samples) + # bart_json.add_integer("num_basis", self.num_basis) + # bart_json.add_boolean("requires_basis", self.has_basis) + # bart_json.add_boolean("probit_outcome_model", self.probit_outcome_model) + + # # Add parameter samples + # if self.sample_sigma2_global: + # bart_json.add_numeric_vector( + # "sigma2_global_samples", self.global_var_samples, "parameters" + # ) + # if self.sample_sigma2_leaf: + # bart_json.add_numeric_vector( + # "sigma2_leaf_samples", self.leaf_scale_samples, "parameters" + # ) + + # # Add covariate preprocessor + # covariate_preprocessor_string = self._covariate_preprocessor.to_json() + # bart_json.add_string("covariate_preprocessor", covariate_preprocessor_string) + + # return bart_json.return_json_string() + + # def from_json(self, json_string: str) -> None: + # """ + # Converts a JSON string to an in-memory BART model. 
+ + # Parameters + # ---------- + # json_string : str + # JSON string representing model metadata (hyperparameters), sampled parameters, and sampled forests + # """ + # # Parse string to a JSON object in C++ + # bart_json = JSONSerializer() + # bart_json.load_from_json_string(json_string) + + # # Unpack forests + # self.include_mean_forest = bart_json.get_boolean("include_mean_forest") + # self.include_variance_forest = bart_json.get_boolean("include_variance_forest") + # self.has_rfx = bart_json.get_boolean("has_rfx") + # if self.include_mean_forest: + # # TODO: don't just make this a placeholder that we overwrite + # self.forest_container_mean = ForestContainer(0, 0, False, False) + # self.forest_container_mean.forest_container_cpp.LoadFromJson( + # bart_json.json_cpp, "forest_0" + # ) + # if self.include_variance_forest: + # # TODO: don't just make this a placeholder that we overwrite + # self.forest_container_variance = ForestContainer(0, 0, False, False) + # self.forest_container_variance.forest_container_cpp.LoadFromJson( + # bart_json.json_cpp, "forest_1" + # ) + # else: + # # TODO: don't just make this a placeholder that we overwrite + # self.forest_container_variance = ForestContainer(0, 0, False, False) + # self.forest_container_variance.forest_container_cpp.LoadFromJson( + # bart_json.json_cpp, "forest_0" + # ) + + # # Unpack random effects + # if self.has_rfx: + # self.rfx_container = RandomEffectsContainer() + # self.rfx_container.load_from_json(bart_json, 0) + + # # Unpack global parameters + # self.y_std = bart_json.get_scalar("outcome_scale") + # self.y_bar = bart_json.get_scalar("outcome_mean") + # self.standardize = bart_json.get_boolean("standardize") + # self.sigma2_init = bart_json.get_scalar("sigma2_init") + # self.sample_sigma2_global = bart_json.get_boolean("sample_sigma2_global") + # self.sample_sigma2_leaf = bart_json.get_boolean("sample_sigma2_leaf") + # self.num_gfr = bart_json.get_integer("num_gfr") + # self.num_burnin = bart_json.get_integer("num_burnin") + # self.num_mcmc = bart_json.get_integer("num_mcmc") + # self.num_samples = bart_json.get_integer("num_samples") + # self.num_basis = bart_json.get_integer("num_basis") + # self.has_basis = bart_json.get_boolean("requires_basis") + # self.probit_outcome_model = bart_json.get_boolean("probit_outcome_model") + + # # Unpack parameter samples + # if self.sample_sigma2_global: + # self.global_var_samples = bart_json.get_numeric_vector( + # "sigma2_global_samples", "parameters" + # ) + # if self.sample_sigma2_leaf: + # self.leaf_scale_samples = bart_json.get_numeric_vector( + # "sigma2_leaf_samples", "parameters" + # ) + + # # Unpack covariate preprocessor + # covariate_preprocessor_string = bart_json.get_string("covariate_preprocessor") + # self._covariate_preprocessor = CovariatePreprocessor() + # self._covariate_preprocessor.from_json(covariate_preprocessor_string) + + # # Mark the deserialized model as "sampled" + # self.sampled = True + + # def from_json_string_list(self, json_string_list: list[str]) -> None: + # """ + # Convert a list of (in-memory) JSON strings that represent BART models to a single combined BART model object + # which can be used for prediction, etc... 
+ + # Parameters + # ------- + # json_string_list : list of str + # List of JSON strings which can be parsed to objects of type `JSONSerializer` containing Json representation of a BART model + # """ + # # Convert strings to JSONSerializer + # json_object_list = [] + # for i in range(len(json_string_list)): + # json_string = json_string_list[i] + # json_object_list.append(JSONSerializer()) + # json_object_list[i].load_from_json_string(json_string) + + # # For scalar / preprocessing details which aren't sample-dependent, defer to the first json + # json_object_default = json_object_list[0] + + # # Unpack forests + # self.include_mean_forest = json_object_default.get_boolean( + # "include_mean_forest" + # ) + # self.include_variance_forest = json_object_default.get_boolean( + # "include_variance_forest" + # ) + # if self.include_mean_forest: + # # TODO: don't just make this a placeholder that we overwrite + # self.forest_container_mean = ForestContainer(0, 0, False, False) + # for i in range(len(json_object_list)): + # if i == 0: + # self.forest_container_mean.forest_container_cpp.LoadFromJson( + # json_object_list[i].json_cpp, "forest_0" + # ) + # else: + # self.forest_container_mean.forest_container_cpp.AppendFromJson( + # json_object_list[i].json_cpp, "forest_0" + # ) + # if self.include_variance_forest: + # # TODO: don't just make this a placeholder that we overwrite + # self.forest_container_variance = ForestContainer(0, 0, False, False) + # for i in range(len(json_object_list)): + # if i == 0: + # self.forest_container_variance.forest_container_cpp.LoadFromJson( + # json_object_list[i].json_cpp, "forest_1" + # ) + # else: + # self.forest_container_variance.forest_container_cpp.AppendFromJson( + # json_object_list[i].json_cpp, "forest_1" + # ) + # else: + # # TODO: don't just make this a placeholder that we overwrite + # self.forest_container_variance = ForestContainer(0, 0, False, False) + # for i in range(len(json_object_list)): + # if i == 0: + # self.forest_container_variance.forest_container_cpp.LoadFromJson( + # json_object_list[i].json_cpp, "forest_0" + # ) + # else: + # self.forest_container_variance.forest_container_cpp.AppendFromJson( + # json_object_list[i].json_cpp, "forest_0" + # ) + + # # Unpack random effects + # self.has_rfx = json_object_default.get_boolean("has_rfx") + # if self.has_rfx: + # self.rfx_container = RandomEffectsContainer() + # for i in range(len(json_object_list)): + # if i == 0: + # self.rfx_container.load_from_json(json_object_list[i], 0) + # else: + # self.rfx_container.append_from_json(json_object_list[i], 0) + + # # Unpack global parameters + # self.y_std = json_object_default.get_scalar("outcome_scale") + # self.y_bar = json_object_default.get_scalar("outcome_mean") + # self.standardize = json_object_default.get_boolean("standardize") + # self.sigma2_init = json_object_default.get_scalar("sigma2_init") + # self.sample_sigma2_global = json_object_default.get_boolean( + # "sample_sigma2_global" + # ) + # self.sample_sigma2_leaf = json_object_default.get_boolean("sample_sigma2_leaf") + # self.num_gfr = json_object_default.get_integer("num_gfr") + # self.num_burnin = json_object_default.get_integer("num_burnin") + # self.num_mcmc = json_object_default.get_integer("num_mcmc") + # self.num_basis = json_object_default.get_integer("num_basis") + # self.has_basis = json_object_default.get_boolean("requires_basis") + # self.probit_outcome_model = json_object_default.get_boolean( + # "probit_outcome_model" + # ) + + # # Unpack number of samples + # for i in 
range(len(json_object_list)): + # if i == 0: + # self.num_samples = json_object_list[i].get_integer("num_samples") + # else: + # self.num_samples += json_object_list[i].get_integer("num_samples") + + # # Unpack parameter samples + # if self.sample_sigma2_global: + # for i in range(len(json_object_list)): + # if i == 0: + # self.global_var_samples = json_object_list[i].get_numeric_vector( + # "sigma2_global_samples", "parameters" + # ) + # else: + # global_var_samples = json_object_list[i].get_numeric_vector( + # "sigma2_global_samples", "parameters" + # ) + # self.global_var_samples = np.concatenate( + # (self.global_var_samples, global_var_samples) + # ) + + # if self.sample_sigma2_leaf: + # for i in range(len(json_object_list)): + # if i == 0: + # self.leaf_scale_samples = json_object_list[i].get_numeric_vector( + # "sigma2_leaf_samples", "parameters" + # ) + # else: + # leaf_scale_samples = json_object_list[i].get_numeric_vector( + # "sigma2_leaf_samples", "parameters" + # ) + # self.leaf_scale_samples = np.concatenate( + # (self.leaf_scale_samples, leaf_scale_samples) + # ) + + # # Unpack covariate preprocessor + # covariate_preprocessor_string = json_object_default.get_string( + # "covariate_preprocessor" + # ) + # self._covariate_preprocessor = CovariatePreprocessor() + # self._covariate_preprocessor.from_json(covariate_preprocessor_string) + + # # Mark the deserialized model as "sampled" + # self.sampled = True + + def is_sampled(self) -> bool: + """Whether or not a BART model has been sampled. + + Returns + ------- + bool + `True` if a BART model has been sampled, `False` otherwise + """ + return self.sampled diff --git a/stochtree/config.py b/stochtree/config.py index 72cae512..3c9cd36b 100644 --- a/stochtree/config.py +++ b/stochtree/config.py @@ -44,7 +44,7 @@ class ForestModelConfig: max_depth : int, optional Maximum depth of any tree in the ensemble in the model. Setting to `-1` does not enforce any depth limits on trees. Default: `-1`. leaf_model_type : int, optional - Integer specifying the leaf model type (0 = constant leaf, 1 = univariate leaf regression, 2 = multivariate leaf regression). Default: `0`. + Integer specifying the leaf model type (0 = constant leaf, 1 = univariate leaf regression, 2 = multivariate leaf regression, 3 = log linear variance, 4 = cloglog ordinal regression). Default: `0`. leaf_model_scale : float or np.ndarray, optional Scale parameter used in Gaussian leaf models (can either be a scalar or a q x q matrix, where q is the dimensionality of the basis and is only >1 when `leaf_model_int = 2`). Calibrated internally as `1/num_trees`, propagated along diagonal if needed for multivariate leaf models. 
    variance_forest_shape : int, optional
@@ -110,9 +110,9 @@ def __init__(
         if leaf_model_type is None:
             leaf_model_type = 0
         if not _check_is_int(leaf_model_type):
-            raise ValueError("`leaf_model_type` must be an integer between 0 and 3")
-        elif leaf_model_type < 0 or leaf_model_type > 3:
-            raise ValueError("`leaf_model_type` must be an integer between 0 and 3")
+            raise ValueError("`leaf_model_type` must be an integer between 0 and 4")
+        elif leaf_model_type < 0 or leaf_model_type > 4:
+            raise ValueError("`leaf_model_type` must be an integer between 0 and 4")
         if not _check_is_int(leaf_dimension):
             raise ValueError("`leaf_dimension` must be an integer greater than 0")
         elif leaf_dimension <= 0:
diff --git a/stochtree/data.py b/stochtree/data.py
index 4e40a282..7668ac73 100644
--- a/stochtree/data.py
+++ b/stochtree/data.py
@@ -205,6 +205,97 @@ def has_variance_weights(self) -> bool:
             `True` if the dataset has variance weights, `False` otherwise
         """
         return self.dataset_cpp.HasVarianceWeights()
+
+    def has_auxiliary_dimension(self, dim_idx: int) -> bool:
+        """
+        Whether or not a dataset has an auxiliary dimension
+
+        Parameters
+        ----------
+        dim_idx : int
+            Index of the auxiliary dimension to check
+
+        Returns
+        -------
+        bool
+            `True` if the dataset has the specified auxiliary dimension, `False` otherwise
+        """
+        return self.dataset_cpp.HasAuxiliaryDimension(dim_idx)
+
+    def add_auxiliary_dimension(self, dim_size: int) -> None:
+        """
+        Add an auxiliary dimension to a dataset
+
+        Parameters
+        ----------
+        dim_size : int
+            Size of the auxiliary dimension to add
+        """
+        self.dataset_cpp.AddAuxiliaryDimension(dim_size)
+
+    def get_auxiliary_data_value(self, dim_idx: int, element_idx: int) -> float:
+        """
+        Get a value from an auxiliary dimension
+
+        Parameters
+        ----------
+        dim_idx : int
+            Index of the auxiliary dimension
+        element_idx : int
+            Index of the element within the auxiliary dimension
+
+        Returns
+        -------
+        float
+            Value at the specified index in the auxiliary dimension
+        """
+        return self.dataset_cpp.GetAuxiliaryDataValue(dim_idx, element_idx)
+
+    def set_auxiliary_data_value(self, dim_idx: int, element_idx: int, value: float) -> None:
+        """
+        Set a value in an auxiliary dimension
+
+        Parameters
+        ----------
+        dim_idx : int
+            Index of the auxiliary dimension
+        element_idx : int
+            Index of the element within the auxiliary dimension
+        value : float
+            Value to set at the specified index in the auxiliary dimension
+        """
+        self.dataset_cpp.SetAuxiliaryDataValue(dim_idx, element_idx, value)
+
+    def get_auxiliary_data_array(self, dim_idx: int) -> np.array:
+        """
+        Get an auxiliary dimension as a numpy array
+
+        Parameters
+        ----------
+        dim_idx : int
+            Index of the auxiliary dimension
+
+        Returns
+        -------
+        np.array
+            Numpy array of the specified auxiliary dimension
+        """
+        return self.dataset_cpp.GetAuxiliaryDataArray(dim_idx)
+
+    def store_auxiliary_data_array_matrix(self, output_matrix: np.array, dim_idx: int, matrix_col_idx: int) -> None:
+        """
+        Store an auxiliary dimension into a specified column of a numpy matrix
+
+        Parameters
+        ----------
+        output_matrix : np.array
+            Numpy array to store the auxiliary dimension into
+        dim_idx : int
+            Index of the auxiliary dimension
+        matrix_col_idx : int
+            Column index in the output matrix to store the auxiliary dimension
+        """
+        self.dataset_cpp.StoreAuxiliaryDataArrayMatrix(output_matrix, dim_idx, matrix_col_idx)
 
 
 class Residual:
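A minimal sketch of the auxiliary-data API added to `Dataset` above (the slot size and values here are illustrative only; the ordinal sampler itself lays out four slots, as documented in `cloglog_ordinal_bart.py`):

```python
import numpy as np
from stochtree import Dataset

# Dataset with 10 observations and 2 covariates
ds = Dataset()
ds.add_covariates(np.random.uniform(size=(10, 2)))

# Add an auxiliary slot with 4 elements (e.g., K - 1 = 4 log-scale cutpoints)
ds.add_auxiliary_dimension(4)
assert ds.has_auxiliary_dimension(0)

# Element-wise write / read on slot 0
ds.set_auxiliary_data_value(0, 0, 0.5)
value = ds.get_auxiliary_data_value(0, 0)  # 0.5

# Retrieve an entire slot as a numpy array
gamma_slot = ds.get_auxiliary_data_array(0)
```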
diff --git a/tools/debug/cloglog_ordinal_bart_binary.R b/tools/debug/cloglog_ordinal_bart_binary.R
new file mode 100644
index 00000000..5f167e04
--- /dev/null
+++ b/tools/debug/cloglog_ordinal_bart_binary.R
@@ -0,0 +1,137 @@
+# Load library
+library(stochtree)
+
+# Set seed
+set.seed(2025)
+
+# Sample size and number of predictors
+n <- 2000
+p <- 5
+
+# Design matrix and true lambda function
+X <- matrix(runif(n * p), ncol = p)
+beta <- rep(1 / sqrt(p), p)
+true_lambda_function <- X %*% beta
+
+# Set cutpoints for ordinal categories (2 categories: 1, 2)
+n_categories <- 2
+gamma_true <- c(-1)
+ordinal_cutpoints <- log(cumsum(exp(gamma_true)))
+ordinal_cutpoints
+
+# True ordinal class probabilities
+true_probs <- matrix(0, nrow = n, ncol = n_categories)
+for (j in 1:n_categories) {
+  if (j == 1) {
+    true_probs[, j] <- 1 - exp(-exp(gamma_true[j] + true_lambda_function))
+  } else if (j == n_categories) {
+    true_probs[, j] <- 1 - rowSums(true_probs[, 1:(j - 1), drop = FALSE])
+  } else {
+    true_probs[, j] <- exp(-exp(gamma_true[j - 1] + true_lambda_function)) *
+      (1 - exp(-exp(gamma_true[j] + true_lambda_function)))
+  }
+}
+
+# Generate ordinal outcomes
+y <- sapply(1:nrow(X), function(i) sample(1:n_categories, 1, prob = true_probs[i, ]))
+cat("Outcome distribution:", table(y), "\n")
+
+# Train test split
+train_idx <- sample(1:n, size = floor(0.8 * n))
+test_idx <- setdiff(1:n, train_idx)
+X_train <- X[train_idx, ]
+y_train <- y[train_idx]
+X_test <- X[test_idx, ]
+y_test <- y[test_idx]
+
+# Sample the cloglog ordinal BART model
+out <- cloglog_ordinal_bart(
+  X = X_train,
+  y = y_train,
+  X_test = X_test,
+  num_gfr = 0,
+  num_burnin = 1000,
+  num_mcmc = 1000,
+  n_thin = 1
+)
+
+# Traceplot of cutoff parameters
+par(mfrow = c(2, 1))
+plot(out$gamma_samples[1, ], type = 'l', main = expression(gamma[1]), ylab = "Value", xlab = "MCMC Sample")
+abline(h = gamma_true[1], col = 'red', lty = 2)
+
+# Histogram of cutoff parameters
+gamma1 <- out$gamma_samples[1,] + colMeans(out$forest_predictions_train)
+summary(gamma1)
+hist(gamma1)
+
+# Traceplots of cutoff parameters combined with average forest predictions
+par(mfrow = c(2,1))
+rowMeans(out$gamma_samples)
+moo <- t(out$gamma_samples) + colMeans(out$forest_predictions_train)
+plot(moo[,1])
+abline(h = gamma_true[1] + mean(true_lambda_function[train_idx]))
+plot(out$gamma_samples[1,])
+
+# Compare forest predictions with the truth (for training and test sets)
+par(mfrow = c(2,1))
+
+# Train set
+lambda_pred_train <- rowMeans(out$forest_predictions_train) - mean(out$forest_predictions_train)
+plot(lambda_pred_train, gamma_true[1] + true_lambda_function[train_idx])
+abline(a=0,b=1,col='blue', lwd=2)
+cor_train <- cor(true_lambda_function[train_idx] + mean(out$gamma_samples[1,]), gamma_true[1] + lambda_pred_train)
+text(min(true_lambda_function[train_idx]), max(true_lambda_function[train_idx]), paste('Correlation:', round(cor_train, 3)), adj = 0, col = 'red')
+
+# Test set
+lambda_pred_test <- rowMeans(out$forest_predictions_test) - mean(out$forest_predictions_test)
+plot(lambda_pred_test, gamma_true[1] + true_lambda_function[test_idx])
+abline(a=0,b=1,col='blue', lwd=2)
+cor_test <- cor(true_lambda_function[test_idx], lambda_pred_test)
+text(min(true_lambda_function[test_idx]), max(true_lambda_function[test_idx]), paste('Correlation:', round(cor_test, 3)), adj = 0, col = 'red')
+
+# Estimated ordinal class probabilities for the training set
+est_probs_train <- matrix(0, nrow=length(train_idx), ncol=n_categories)
+for (j in 1:n_categories) {
+  if (j == 1) {
+    est_probs_train[, j] <- rowMeans(1 - exp(-exp(out$forest_predictions_train + out$gamma_samples[j, ])))
+  }
else if (j == n_categories) { + est_probs_train[, j] <- 1 - rowSums(est_probs_train[, 1:(j - 1), drop = FALSE]) + } else { + est_probs_train[, j] <- rowMeans(exp(-exp(out$forest_predictions_train + out$gamma_samples[j-1,])) * + (1 - exp(-exp(out$forest_predictions_train + out$gamma_samples[j,])))) + } +} +# Compute average difference +mean(log(-log(1 - est_probs_train[, 1])) - rowMeans(out$forest_predictions_train)) + +# Plot estimated vs true class probabilities for training set +for (j in 1:n_categories) { + plot(true_probs[train_idx, j], est_probs_train[, j], xlab = paste("True Prob Category", j), ylab = paste("Estimated Prob Category", j)) + abline(a = 0, b = 1, col = 'blue', lwd = 2) + cor_train_prob <- cor(true_probs[train_idx, j], est_probs_train[, j]) + text(min(true_probs[train_idx, j]), max(est_probs_train[, j]), paste('Correlation:', round(cor_train_prob, 3)), adj = 0, col = 'red') +} + +# Estimated ordinal class probabilities for the test set +est_probs_test <- matrix(0, nrow=length(test_idx), ncol=n_categories) +for (j in 1:n_categories) { + if (j == 1) { + est_probs_test[, j] <- rowMeans(1 - exp(-exp(out$forest_predictions_test + out$gamma_samples[j, ]))) + } else if (j == n_categories) { + est_probs_test[, j] <- 1 - rowSums(est_probs_test[, 1:(j - 1), drop = FALSE]) + } else { + est_probs_test[, j] <- rowMeans(exp(-exp(out$forest_predictions_test + out$gamma_samples[j-1,])) * + (1 - exp(-exp(out$forest_predictions_test + out$gamma_samples[j,])))) + } +} +# Compute average difference +mean(log(-log(1 - est_probs_test[, 1])) - rowMeans(out$forest_predictions_test)) + +# Plot estimated vs true class probabilities for test set +for (j in 1:n_categories) { + plot(true_probs[test_idx, j], est_probs_test[, j], xlab = paste("True Prob Category", j), ylab = paste("Estimated Prob Category", j)) + abline(a = 0, b = 1, col = 'blue', lwd = 2) + cor_test_prob <- cor(true_probs[test_idx, j], est_probs_test[, j]) + text(min(true_probs[test_idx, j]), max(est_probs_test[, j]), paste('Correlation:', round(cor_test_prob, 3)), adj = 0, col = 'red') +} diff --git a/tools/debug/cloglog_ordinal_bart_four_category.R b/tools/debug/cloglog_ordinal_bart_four_category.R new file mode 100644 index 00000000..36ee8710 --- /dev/null +++ b/tools/debug/cloglog_ordinal_bart_four_category.R @@ -0,0 +1,156 @@ +# Load library +library(stochtree) + +# Set seed +set.seed(2025) + +# Sample size and number of predictors +n <- 2000 +p <- 5 + +# Design matrix and true lambda function +X <- matrix(rnorm(n * p), n, p) +beta <- rep(1 / sqrt(p), p) +true_lambda_function <- X %*% beta + +# Set cutpoints for ordinal categories (4 categories: 1, 2, 3, 4) +n_categories <- 4 +gamma_true <- c(-2, 0, 1) +ordinal_cutpoints <- log(cumsum(exp(gamma_true))) +ordinal_cutpoints + +# True ordinal class probabilities +true_probs <- matrix(0, nrow = n, ncol = n_categories) +for (j in 1:n_categories) { + if (j == 1) { + true_probs[, j] <- 1 - exp(-exp(gamma_true[j] + true_lambda_function)) + } else if (j == n_categories) { + true_probs[, j] <- 1 - rowSums(true_probs[, 1:(j - 1), drop = FALSE]) + } else { + true_probs[, j] <- apply(sapply(1:(j-1), function(k) exp(-exp(gamma_true[k] + true_lambda_function))), 1, prod) * + (1 - exp(-exp(gamma_true[j] + true_lambda_function))) + } +} + +# Generate ordinal outcomes +y <- sapply(1:nrow(X), function(i) sample(1:n_categories, 1, prob = true_probs[i, ])) +cat("Outcome distribution:", table(y), "\n") + +# Train test split +train_idx <- sample(1:n, size = floor(0.8 * n)) +test_idx <- setdiff(1:n, 
train_idx) +X_train <- X[train_idx, ] +y_train <- y[train_idx] +X_test <- X[test_idx, ] +y_test <- y[test_idx] + +# Sample the cloglog ordinal BART model +out <- cloglog_ordinal_bart( + X = X_train, + y = y_train, + X_test = X_test, + num_gfr = 0, + num_burnin = 1000, + num_mcmc = 1000, + n_thin = 1 +) + +# Traceplots of cutoff parameters +par(mfrow = c(2, 2)) +plot(out$gamma_samples[1, ], type = 'l', main = expression(gamma[1]), ylab = "Value", xlab = "MCMC Sample") +abline(h = gamma_true[1], col = 'red', lty = 2) +plot(out$gamma_samples[2, ], type = 'l', main = expression(gamma[2]), ylab = "Value", xlab = "MCMC Sample") +abline(h = gamma_true[2], col = 'red', lty = 2) +plot(out$gamma_samples[3, ], type = 'l', main = expression(gamma[3]), ylab = "Value", xlab = "MCMC Sample") +abline(h = gamma_true[3], col = 'red', lty = 2) + +# Histograms of cutoff parameters +par(mfrow = c(2, 2)) +gamma1 <- out$gamma_samples[1,] + colMeans(out$forest_predictions_train) +summary(gamma1) +hist(gamma1) +gamma2 <- out$gamma_samples[2,] + colMeans(out$forest_predictions_train) +summary(gamma2) +hist(gamma2) +gamma3 <- out$gamma_samples[3,] + colMeans(out$forest_predictions_train) +summary(gamma3) +hist(gamma3) + +# Traceplots of cutoff parameters combined with average forest predictions +par(mfrow = c(2,3)) +rowMeans(out$gamma_samples) +moo <- t(out$gamma_samples) + colMeans(out$forest_predictions_train) +plot(moo[,1]) +abline(h = gamma_true[1] + mean(true_lambda_function[train_idx])) +plot(moo[,2]) +abline(h = gamma_true[2] + mean(true_lambda_function[train_idx])) +plot(moo[,3]) +abline(h = gamma_true[3] + mean(true_lambda_function[train_idx])) +plot(out$gamma_samples[1,]) +plot(out$gamma_samples[2,]) +plot(out$gamma_samples[3,]) + +# Compare forest predictions with the truth (for training and test sets) +par(mfrow = c(2,1)) + +# Train set +lambda_pred_train <- rowMeans(out$forest_predictions_train) - mean(out$forest_predictions_train) +plot(lambda_pred_train, true_lambda_function[train_idx]) +abline(a=0,b=1,col='blue', lwd=2) +cor_train <- cor(true_lambda_function[train_idx], lambda_pred_train) +text(min(true_lambda_function[train_idx]), max(true_lambda_function[train_idx]), paste('Correlation:', round(cor_train, 3)), adj = 0, col = 'red') + +# Test set +lambda_pred_test <- rowMeans(out$forest_predictions_test) - mean(out$forest_predictions_test) +plot(lambda_pred_test, true_lambda_function[test_idx]) +abline(a=0,b=1,col='blue', lwd=2) +cor_test <- cor(true_lambda_function[test_idx], lambda_pred_test) +text(min(true_lambda_function[test_idx]), max(true_lambda_function[test_idx]), paste('Correlation:', round(cor_test, 3)), adj = 0, col = 'red') + +# Estimated ordinal class probabilities for the training set +est_probs_train <- matrix(0, nrow=length(train_idx), ncol=n_categories) +for (j in 1:n_categories) { + if (j == 1) { + est_probs_train[, j] <- rowMeans(1 - exp(-exp(out$forest_predictions_train + out$gamma_samples[j, ]))) + } else if (j == n_categories) { + est_probs_train[, j] <- 1 - rowSums(est_probs_train[, 1:(j - 1), drop = FALSE]) + } else { + est_probs_train[, j] <- rowMeans(exp(-exp(out$forest_predictions_train + out$gamma_samples[j-1,])) * + (1 - exp(-exp(out$forest_predictions_train + out$gamma_samples[j,])))) + } +} +# Compute average difference +mean(log(-log(1 - est_probs_train[, 1])) - rowMeans(out$forest_predictions_train)) + +# Plot estimated vs true class probabilities for training set +par(mfrow = c(2,2)) +for (j in 1:n_categories) { + plot(true_probs[train_idx, j], est_probs_train[, 
j], xlab = paste("True Prob Category", j), ylab = paste("Estimated Prob Category", j)) + abline(a = 0, b = 1, col = 'blue', lwd = 2) + cor_train_prob <- cor(true_probs[train_idx, j], est_probs_train[, j]) + text(min(true_probs[train_idx, j]), max(est_probs_train[, j]), paste('Correlation:', round(cor_train_prob, 3)), adj = 0, col = 'red') +} + +# Estimated ordinal class probabilities for the test set +est_probs_test <- matrix(0, nrow=length(test_idx), ncol=n_categories) +for (j in 1:n_categories) { + if (j == 1) { + est_probs_test[, j] <- rowMeans(1 - exp(-exp(out$forest_predictions_test + out$gamma_samples[j, ]))) + } else if (j == n_categories) { + est_probs_test[, j] <- 1 - rowSums(est_probs_test[, 1:(j - 1), drop = FALSE]) + } else { + est_probs_test[, j] <- rowMeans(exp(-exp(out$forest_predictions_test + out$gamma_samples[j-1,])) * + (1 - exp(-exp(out$forest_predictions_test + out$gamma_samples[j,])))) + } +} +# Average difference +mean(log(-log(1 - est_probs_test[, 1])) - rowMeans(out$forest_predictions_test)) + +# Plot estimated vs true class probabilities for test set +par(mfrow = c(2,2)) +for (j in 1:n_categories) { + plot(true_probs[test_idx, j], est_probs_test[, j], xlab = paste("True Prob Category", j), ylab = paste("Estimated Prob Category", j)) + abline(a = 0, b = 1, col = 'blue', lwd = 2) + cor_test_prob <- cor(true_probs[test_idx, j], est_probs_test[, j]) + text(min(true_probs[test_idx, j]), max(est_probs_test[, j]), paste('Correlation:', round(cor_test_prob, 3)), adj = 0, col = 'red') +} diff --git a/tools/debug/cloglog_ordinal_bart_three_category.R b/tools/debug/cloglog_ordinal_bart_three_category.R new file mode 100644 index 00000000..96eba51a --- /dev/null +++ b/tools/debug/cloglog_ordinal_bart_three_category.R @@ -0,0 +1,149 @@ +# Load library +library(stochtree) + +# Set seed +set.seed(2025) + +# Sample size and number of predictors +n <- 2000 +p <- 5 + +# Design matrix and true lambda function +X <- matrix(rnorm(n * p), n, p) +beta <- rep(1 / sqrt(p), p) +true_lambda_function <- X %*% beta + +# Set cutpoints for ordinal categories (3 categories: 1, 2, 3) +n_categories <- 3 +gamma_true <- c(-2, 1) +ordinal_cutpoints <- log(cumsum(exp(gamma_true))) +ordinal_cutpoints + +# True ordinal class probabilities +true_probs <- matrix(0, nrow = n, ncol = n_categories) +for (j in 1:n_categories) { + if (j == 1) { + true_probs[, j] <- 1 - exp(-exp(gamma_true[j] + true_lambda_function)) + } else if (j == n_categories) { + true_probs[, j] <- 1 - rowSums(true_probs[, 1:(j - 1), drop = FALSE]) + } else { + true_probs[, j] <- exp(-exp(gamma_true[j - 1] + true_lambda_function)) * + (1 - exp(-exp(gamma_true[j] + true_lambda_function))) + } +} +apply(true_probs, 2, mean) +summary(true_lambda_function) + +# Generate ordinal outcomes +y <- sapply(1:nrow(X), function(i) sample(1:n_categories, 1, prob = true_probs[i, ])) +cat("Outcome distribution:", table(y), "\n") + +# Train test split +train_idx <- sample(1:n, size = floor(0.8 * n)) +test_idx <- setdiff(1:n, train_idx) +X_train <- X[train_idx, ] +y_train <- y[train_idx] +X_test <- X[test_idx, ] +y_test <- y[test_idx] + +# Sample the cloglog ordinal BART model +out <- cloglog_ordinal_bart( + X = X_train, + y = y_train, + X_test = X_test, + num_gfr = 0, + num_burnin = 1000, + num_mcmc = 1000, + n_thin = 1 +) + +# Traceplots of cutoff parameters +par(mfrow = c(2, 1)) +plot(out$gamma_samples[1, ], type = 'l', main = expression(gamma[1]), ylab = "Value", xlab = "MCMC Sample") +abline(h = gamma_true[1], col = 'red', lty = 2) 
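+# Note: the forest lambda(x) can absorb an additive constant, so the raw gamma_k
+# draws need not line up with gamma_true exactly; the centered combinations
+# gamma_k + mean(lambda_hat) (computed as `moo` below) are instead compared
+# against gamma_true + mean(true_lambda).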
+plot(out$gamma_samples[2, ], type = 'l', main = expression(gamma[2]), ylab = "Value", xlab = "MCMC Sample") +abline(h = gamma_true[2], col = 'red', lty = 2) + +# Histograms of cutoff parameters +gamma1 <- out$gamma_samples[1,] + colMeans(out$forest_predictions_train) +summary(gamma1) +hist(gamma1) +gamma2 <- out$gamma_samples[2,] + colMeans(out$forest_predictions_train) +summary(gamma2) +hist(gamma2) + +# Traceplots of cutoff parameters combined with average forest predictions +par(mfrow = c(3,2)) +rowMeans(out$gamma_samples) +moo <- t(out$gamma_samples) + colMeans(out$forest_predictions_train) +plot(moo[,1]) +abline(h = gamma_true[1] + mean(true_lambda_function[train_idx])) +plot(moo[,2]) +abline(h = gamma_true[2] + mean(true_lambda_function[train_idx])) +plot(out$gamma_samples[1,]) +plot(out$gamma_samples[2,]) + +# Compare forest predictions with the truth (for training and test sets) +par(mfrow = c(2,1)) + +# Train set +lambda_pred_train <- rowMeans(out$forest_predictions_train) - mean(out$forest_predictions_train) +plot(lambda_pred_train, true_lambda_function[train_idx]) +abline(a=0,b=1,col='blue', lwd=2) +cor_train <- cor(true_lambda_function[train_idx], lambda_pred_train) +text(min(true_lambda_function[train_idx]), max(true_lambda_function[train_idx]), paste('Correlation:', round(cor_train, 3)), adj = 0, col = 'red') + +# Test set +lambda_pred_test <- rowMeans(out$forest_predictions_test) - mean(out$forest_predictions_test) +plot(lambda_pred_test, true_lambda_function[test_idx]) +abline(a=0,b=1,col='blue', lwd=2) +cor_test <- cor(true_lambda_function[test_idx], lambda_pred_test) +text(min(true_lambda_function[test_idx]), max(true_lambda_function[test_idx]), paste('Correlation:', round(cor_test, 3)), adj = 0, col = 'red') + +# Estimated ordinal class probabilities for the training set +est_probs_train <- matrix(0, nrow=length(train_idx), ncol=n_categories) +for (j in 1:n_categories) { + if (j == 1) { + est_probs_train[, j] <- rowMeans(1 - exp(-exp(out$forest_predictions_train + out$gamma_samples[j, ]))) + } else if (j == n_categories) { + est_probs_train[, j] <- 1 - rowSums(est_probs_train[, 1:(j - 1), drop = FALSE]) + } else { + est_probs_train[, j] <- rowMeans(exp(-exp(out$forest_predictions_train + out$gamma_samples[j-1,])) * + (1 - exp(-exp(out$forest_predictions_train + out$gamma_samples[j,])))) + } +} +# Compute average difference +mean(log(-log(1 - est_probs_train[, 1])) - rowMeans(out$forest_predictions_train)) + +# Plot estimated vs true class probabilities for training set +par(mfrow = c(2,2)) +for (j in 1:n_categories) { + plot(true_probs[train_idx, j], est_probs_train[, j], xlab = paste("True Prob Category", j), ylab = paste("Estimated Prob Category", j)) + abline(a = 0, b = 1, col = 'blue', lwd = 2) + cor_train_prob <- cor(true_probs[train_idx, j], est_probs_train[, j]) + text(min(true_probs[train_idx, j]), max(est_probs_train[, j]), paste('Correlation:', round(cor_train_prob, 3)), adj = 0, col = 'red') +} + +# Estimated ordinal class probabilities for the test set +est_probs_test <- matrix(0, nrow=length(test_idx), ncol=n_categories) +for (j in 1:n_categories) { + if (j == 1) { + est_probs_test[, j] <- rowMeans(1 - exp(-exp(out$forest_predictions_test + out$gamma_samples[j, ]))) + } else if (j == n_categories) { + est_probs_test[, j] <- 1 - rowSums(est_probs_test[, 1:(j - 1), drop = FALSE]) + } else { + est_probs_test[, j] <- rowMeans(exp(-exp(out$forest_predictions_test + out$gamma_samples[j-1,])) * + (1 - exp(-exp(out$forest_predictions_test + 
out$gamma_samples[j,])))) + } +} +# Compute average difference +mean(log(-log(1 - est_probs_test[, 1])) - rowMeans(out$forest_predictions_test)) + +# Compare estimated vs true class probabilities for test set +par(mfrow = c(2,2)) +for (j in 1:n_categories) { + plot(true_probs[test_idx, j], est_probs_test[, j], xlab = paste("True Prob Category", j), ylab = paste("Estimated Prob Category", j)) + abline(a = 0, b = 1, col = 'blue', lwd = 2) + cor_test_prob <- cor(true_probs[test_idx, j], est_probs_test[, j]) + text(min(true_probs[test_idx, j]), max(est_probs_test[, j]), paste('Correlation:', round(cor_test_prob, 3)), adj = 0, col = 'red') +} diff --git a/vignettes/CLogLogOrdinalBart.Rmd b/vignettes/CLogLogOrdinalBart.Rmd new file mode 100644 index 00000000..a87b1ebb --- /dev/null +++ b/vignettes/CLogLogOrdinalBart.Rmd @@ -0,0 +1,173 @@ +--- +title: "Complementary Log-Log Ordinal BART in StochTree" +output: rmarkdown::html_vignette +vignette: > + %\VignetteIndexEntry{CLogLog-Ordinal-BART} + %\VignetteEncoding{UTF-8} + %\VignetteEngine{knitr::rmarkdown} +bibliography: vignettes.bib +editor_options: + markdown: + wrap: 72 +--- + +```{r, include = FALSE} +knitr::opts_chunk$set( + collapse = TRUE, + comment = "#>" +) +``` + +This vignette demonstrates how to use the `cloglog_ordinal_bart()` function for modeling ordinal outcomes using a complementary log-log link function in the BART (Bayesian Additive Regression Trees) framework. + +To begin, we load the `stochtree` package. + +```{r setup} +library(stochtree) +``` + +# Introduction to Ordinal BART with CLogLog Link + +Ordinal data represents outcomes that have a natural ordering but undefined distances between categories. Examples include survey responses (strongly disagree, disagree, neutral, agree, strongly agree), severity ratings (mild, moderate, severe), or educational levels (elementary, high school, college, graduate). + +The complementary log-log (CLogLog) model uses the link function: +$$\text{cloglog}(p) = \log(-\log(1-p))$$ + +This link function is asymmetric and particularly appropriate when the probability of being in higher categories changes rapidly at certain thresholds, making it different from the symmetric probit or logit links commonly used in ordinal regression. + +In the BART framework with CLogLog ordinal regression, we model: +$$P(Y = k \mid Y \geq k, X = x) = 1 - \exp\left(-e^{\gamma_k + \lambda(x)}\right)$$ + +where $\lambda(x)$ is learned by the BART ensemble and $c_k = \log \sum_{j \leq k}e^{\gamma_j}$ are the cutpoints for the ordinal categories. 
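+Chaining these conditional probabilities gives the implied category probabilities, which the simulation code below computes explicitly:
+
+$$P(Y = k \mid X = x) = \left[\prod_{j < k} \exp\left(-e^{\gamma_j + \lambda(x)}\right)\right]\left(1 - \exp\left(-e^{\gamma_k + \lambda(x)}\right)\right),$$
+
+where the product is empty for the first category and the final category receives the remaining probability mass.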
+
+## Data Simulation
+
+```{r demo1-simulation}
+set.seed(2025)
+# Sample size and number of predictors
+n <- 2000
+p <- 5
+
+# Design matrix and true lambda function
+X <- matrix(rnorm(n * p), n, p)
+beta <- rep(1 / sqrt(p), p)
+true_lambda_function <- X %*% beta
+
+# Set cutpoints for ordinal categories (3 categories: 1, 2, 3)
+n_categories <- 3
+gamma_true <- c(-2, 1)
+ordinal_cutpoints <- log(cumsum(exp(gamma_true)))
+ordinal_cutpoints
+
+# True ordinal class probabilities
+true_probs <- matrix(0, nrow = n, ncol = n_categories)
+for (j in 1:n_categories) {
+  if (j == 1) {
+    true_probs[, j] <- 1 - exp(-exp(gamma_true[j] + true_lambda_function))
+  } else if (j == n_categories) {
+    true_probs[, j] <- 1 - rowSums(true_probs[, 1:(j - 1), drop = FALSE])
+  } else {
+    true_probs[, j] <- exp(-exp(gamma_true[j - 1] + true_lambda_function)) *
+      (1 - exp(-exp(gamma_true[j] + true_lambda_function)))
+  }
+}
+
+# Generate ordinal outcomes
+y <- sapply(1:nrow(X), function(i) sample(1:n_categories, 1, prob = true_probs[i, ]))
+cat("Outcome distribution:", table(y), "\n")
+```
+
+## Model Fitting
+
+Now let's fit the CLogLog Ordinal BART model:
+
+```{r demo1-model-fitting}
+# Split data into train and test sets
+train_idx <- sample(1:n, size = floor(0.8 * n))
+test_idx <- setdiff(1:n, train_idx)
+
+X_train <- X[train_idx, ]
+y_train <- y[train_idx]
+X_test <- X[test_idx, ]
+y_test <- y[test_idx]
+
+# Fit CLogLog Ordinal BART model
+out <- stochtree::cloglog_ordinal_bart(
+  X = X_train,
+  y = y_train,
+  X_test = X_test,
+  num_mcmc = 1000,
+  num_burnin = 500,
+  n_thin = 1
+)
+```
+
+## Model Results and Interpretation
+
+Let's examine the posterior samples and model performance:
+
+```{r demo1-results}
+# Compare forest predictions with the true lambda function (for training and test sets)
+lambda_pred_train <- rowMeans(out$forest_predictions_train) - mean(out$forest_predictions_train)
+plot(lambda_pred_train, true_lambda_function[train_idx])
+abline(a=0,b=1,col='blue', lwd=2)
+cor_train <- cor(true_lambda_function[train_idx], lambda_pred_train)
+text(min(true_lambda_function[train_idx]), max(true_lambda_function[train_idx]), paste('Correlation:', round(cor_train, 3)), adj = 0, col = 'red')
+
+lambda_pred_test <- rowMeans(out$forest_predictions_test) - mean(out$forest_predictions_test)
+plot(lambda_pred_test, true_lambda_function[test_idx])
+abline(a=0,b=1,col='blue', lwd=2)
+cor_test <- cor(true_lambda_function[test_idx], lambda_pred_test)
+text(min(true_lambda_function[test_idx]), max(true_lambda_function[test_idx]), paste('Correlation:', round(cor_test, 3)), adj = 0, col = 'red')
+
+# Estimated ordinal class probabilities for the training set
+est_probs_train <- matrix(0, nrow=length(train_idx), ncol=n_categories)
+for (j in 1:n_categories) {
+  if (j == 1) {
+    est_probs_train[, j] <- rowMeans(1 - exp(-exp(out$forest_predictions_train + out$gamma_samples[j, ])))
+  } else if (j == n_categories) {
+    est_probs_train[, j] <- 1 - rowSums(est_probs_train[, 1:(j - 1), drop = FALSE])
+  } else {
+    est_probs_train[, j] <- rowMeans(exp(-exp(out$forest_predictions_train + out$gamma_samples[j-1,])) *
+      (1 - exp(-exp(out$forest_predictions_train + out$gamma_samples[j,]))))
+  }
+}
+
+# Compare estimated vs true class probabilities for training set
+for (j in 1:n_categories) {
+  plot(true_probs[train_idx, j], est_probs_train[, j], xlab = paste("True Prob Category", j), ylab = paste("Estimated Prob Category", j))
+  abline(a = 0, b = 1, col = 'blue', lwd = 2)
+  cor_train_prob <- cor(true_probs[train_idx, j], est_probs_train[, j])
+  text(min(true_probs[train_idx, j]), max(est_probs_train[, j]), paste('Correlation:', round(cor_train_prob, 3)), adj = 0, col = 'red')
+}
+
+# Estimated ordinal class probabilities for the test set
+est_probs_test <- matrix(0, nrow=length(test_idx), ncol=n_categories)
+for (j in 1:n_categories) {
+  if (j == 1) {
+    est_probs_test[, j] <- rowMeans(1 - exp(-exp(out$forest_predictions_test + out$gamma_samples[j, ])))
+  } else if (j == n_categories) {
+    est_probs_test[, j] <- 1 - rowSums(est_probs_test[, 1:(j - 1), drop = FALSE])
+  } else {
+    est_probs_test[, j] <- rowMeans(exp(-exp(out$forest_predictions_test + out$gamma_samples[j-1,])) *
+      (1 - exp(-exp(out$forest_predictions_test + out$gamma_samples[j,]))))
+  }
+}
+
+# Compare estimated vs true class probabilities for test set
+for (j in 1:n_categories) {
+  plot(true_probs[test_idx, j], est_probs_test[, j], xlab = paste("True Prob Category", j), ylab = paste("Estimated Prob Category", j))
+  abline(a = 0, b = 1, col = 'blue', lwd = 2)
+  cor_test_prob <- cor(true_probs[test_idx, j], est_probs_test[, j])
+  text(min(true_probs[test_idx, j]), max(est_probs_test[, j]), paste('Correlation:', round(cor_test_prob, 3)), adj = 0, col = 'red')
+}
+```
+
+# Conclusion
+
+The CLogLog Ordinal BART model in `stochtree` provides a flexible and powerful approach for modeling ordinal outcomes, and it is especially well suited to asymmetric settings: rare events (e.g., credit default, fraud detection, system failures, adverse drug reactions), threshold effects (e.g., credit risk escalation, dose-response toxicity, engagement drop-offs), and discrete survival outcomes (e.g., time-to-default, customer churn, progression-free survival).
+
+The CLogLog Ordinal BART implementation in `stochtree` builds on the paper by @alam2025unified.
+
+# References
diff --git a/vignettes/vignettes.bib b/vignettes/vignettes.bib
index a1b0a768..65a6f152 100644
--- a/vignettes/vignettes.bib
+++ b/vignettes/vignettes.bib
@@ -117,4 +117,11 @@ @book{scholkopf2002learning
   author={Sch{\"o}lkopf, Bernhard and Smola, Alexander J},
   year={2002},
   publisher={MIT press}
-}
\ No newline at end of file
+}
+
+@article{alam2025unified,
+  title={A Unified Bayesian Nonparametric Framework for Ordinal, Survival, and Density Regression Using the Complementary Log-Log Link},
+  author={Alam, Entejar and Linero, Antonio R},
+  journal={arXiv preprint arXiv:2502.00606},
+  year={2025}
+}