From 0dd7a35b9b13f17c8fa997df3d0ea13fb382318b Mon Sep 17 00:00:00 2001 From: Drew Herren Date: Sun, 20 Apr 2025 18:17:05 -0500 Subject: [PATCH 01/14] WIP implementation of probit link for BART --- R/bart.R | 156 ++++++++++++++++++----- vignettes/BayesianSupervisedLearning.Rmd | 93 ++++++++++++++ 2 files changed, 220 insertions(+), 29 deletions(-) diff --git a/R/bart.R b/R/bart.R index 96815850..2e9db337 100644 --- a/R/bart.R +++ b/R/bart.R @@ -58,6 +58,7 @@ #' - `sigma2_leaf_scale` Scale parameter in the `IG(sigma2_leaf_shape, sigma2_leaf_scale)` leaf node parameter variance model. Calibrated internally as `0.5/num_trees` if not set here. #' - `keep_vars` Vector of variable names or column indices denoting variables that should be included in the forest. Default: `NULL`. #' - `drop_vars` Vector of variable names or column indices denoting variables that should be excluded from the forest. Default: `NULL`. If both `drop_vars` and `keep_vars` are set, `drop_vars` will be ignored. +#' - `probit_outcome_model` Whether or not the outcome should be modeled as explicitly binary via a probit link. If `TRUE`, `y` must only contain the values `0` and `1`. Default: `FALSE`. #' #' @param variance_forest_params (Optional) A list of variance forest model parameters, each of which has a default value processed internally, so this argument list is optional. #' @@ -125,7 +126,8 @@ bart <- function(X_train, y_train, leaf_basis_train = NULL, rfx_group_ids_train min_samples_leaf = 5, max_depth = 10, sample_sigma2_leaf = TRUE, sigma2_leaf_init = NULL, sigma2_leaf_shape = 3, sigma2_leaf_scale = NULL, - keep_vars = NULL, drop_vars = NULL + keep_vars = NULL, drop_vars = NULL, + probit_outcome_model = FALSE ) mean_forest_params_updated <- preprocessParams( mean_forest_params_default, mean_forest_params @@ -173,6 +175,7 @@ bart <- function(X_train, y_train, leaf_basis_train = NULL, rfx_group_ids_train b_leaf <- mean_forest_params_updated$sigma2_leaf_scale keep_vars_mean <- mean_forest_params_updated$keep_vars drop_vars_mean <- mean_forest_params_updated$drop_vars + probit_outcome_model <- mean_forest_params_updated$probit_outcome_model # 3. 
Variance forest parameters num_trees_variance <- variance_forest_params_updated$num_trees @@ -462,50 +465,116 @@ bart <- function(X_train, y_train, leaf_basis_train = NULL, rfx_group_ids_train # Determine whether a test set is provided has_test = !is.null(X_test) + + # Check whether outcome is 0-1 binary + if (probit_outcome_model) { + if (!(length(unique(y_train)) == 2)) { + stop("You specified a probit outcome model, but supplied an outcome with more than 2 unique values") + } + unique_outcomes <- sort(unique(y_train)) + if (!(all(unique_outcomes == c(0,1)))) { + stop("You specified a probit outcome model, but supplied an outcome with 2 unique values other than 0 and 1") + } + if (include_variance_forest) { + stop("We do not support heteroskedasticity with a probit link") + } + if (sample_sigma_global) { + warning("Global error variance will not be sampled with a probit link as it is fixed at 1") + sample_sigma_global <- F + } + } - # Standardize outcome separately for test and train - if (standardize) { - y_bar_train <- mean(y_train) - y_std_train <- sd(y_train) - } else { - y_bar_train <- 0 + # Handle standardization, prior calibration, and initialization of forest + # differently for binary and continuous outcomes + if (probit_outcome_model) { + # Compute a probit-scale offset and fix scale to 1 + y_bar_train <- pnorm(mean(y_train)) y_std_train <- 1 - } - resid_train <- (y_train-y_bar_train)/y_std_train - - # Compute initial value of root nodes in mean forest - init_val_mean <- mean(resid_train) - # Calibrate priors for sigma^2 and tau - if (is.null(sigma2_init)) sigma2_init <- 1.0*var(resid_train) - if (is.null(variance_forest_init)) variance_forest_init <- 1.0*var(resid_train) - if (is.null(b_leaf)) b_leaf <- var(resid_train)/(2*num_trees_mean) - if (has_basis) { - if (ncol(leaf_basis_train) > 1) { - if (is.null(sigma_leaf_init)) sigma_leaf_init <- diag(var(resid_train)/(num_trees_mean), ncol(leaf_basis_train)) - if (!is.matrix(sigma_leaf_init)) { - current_leaf_scale <- as.matrix(diag(sigma_leaf_init, ncol(leaf_basis_train))) + # Set a pseudo outcome by subtracting mean(y_train) from y_train + current_z_train <- y_train - mean(y_train) + resid_train <- y_train - mean(y_train) + + # Set initial values of root nodes to 0.0 (in probit scale) + init_val_mean <- 0.0 + + # Calibrate priors for sigma^2 and tau + # Set sigma2_init to 1, ignoring default provided + sigma2_init <- 1.0 + # Skip variance_forest_init, since variance forests are not supported with probit link + b_leaf <- 1/(num_trees_mean) + if (has_basis) { + if (ncol(leaf_basis_train) > 1) { + if (is.null(sigma_leaf_init)) sigma_leaf_init <- diag(2/(num_trees_mean), ncol(leaf_basis_train)) + if (!is.matrix(sigma_leaf_init)) { + current_leaf_scale <- as.matrix(diag(sigma_leaf_init, ncol(leaf_basis_train))) + } else { + current_leaf_scale <- sigma_leaf_init + } } else { - current_leaf_scale <- sigma_leaf_init + if (is.null(sigma_leaf_init)) sigma_leaf_init <- as.matrix(2/(num_trees_mean)) + if (!is.matrix(sigma_leaf_init)) { + current_leaf_scale <- as.matrix(diag(sigma_leaf_init, 1)) + } else { + current_leaf_scale <- sigma_leaf_init + } } } else { - if (is.null(sigma_leaf_init)) sigma_leaf_init <- as.matrix(var(resid_train)/(num_trees_mean)) + if (is.null(sigma_leaf_init)) sigma_leaf_init <- as.matrix(2/(num_trees_mean)) if (!is.matrix(sigma_leaf_init)) { current_leaf_scale <- as.matrix(diag(sigma_leaf_init, 1)) } else { current_leaf_scale <- sigma_leaf_init } } + current_sigma2 <- sigma2_init } else { - if 
(is.null(sigma_leaf_init)) sigma_leaf_init <- as.matrix(var(resid_train)/(num_trees_mean)) - if (!is.matrix(sigma_leaf_init)) { - current_leaf_scale <- as.matrix(diag(sigma_leaf_init, 1)) + # Only standardize if user requested + if (standardize) { + y_bar_train <- mean(y_train) + y_std_train <- sd(y_train) } else { - current_leaf_scale <- sigma_leaf_init + y_bar_train <- 0 + y_std_train <- 1 } + + # Compute residual value + resid_train <- (y_train-y_bar_train)/y_std_train + + # Compute initial value of root nodes in mean forest + init_val_mean <- mean(resid_train) + + # Calibrate priors for sigma^2 and tau + if (is.null(sigma2_init)) sigma2_init <- 1.0*var(resid_train) + if (is.null(variance_forest_init)) variance_forest_init <- 1.0*var(resid_train) + if (is.null(b_leaf)) b_leaf <- var(resid_train)/(2*num_trees_mean) + if (has_basis) { + if (ncol(leaf_basis_train) > 1) { + if (is.null(sigma_leaf_init)) sigma_leaf_init <- diag(var(resid_train)/(num_trees_mean), ncol(leaf_basis_train)) + if (!is.matrix(sigma_leaf_init)) { + current_leaf_scale <- as.matrix(diag(sigma_leaf_init, ncol(leaf_basis_train))) + } else { + current_leaf_scale <- sigma_leaf_init + } + } else { + if (is.null(sigma_leaf_init)) sigma_leaf_init <- as.matrix(var(resid_train)/(num_trees_mean)) + if (!is.matrix(sigma_leaf_init)) { + current_leaf_scale <- as.matrix(diag(sigma_leaf_init, 1)) + } else { + current_leaf_scale <- sigma_leaf_init + } + } + } else { + if (is.null(sigma_leaf_init)) sigma_leaf_init <- as.matrix(var(resid_train)/(num_trees_mean)) + if (!is.matrix(sigma_leaf_init)) { + current_leaf_scale <- as.matrix(diag(sigma_leaf_init, 1)) + } else { + current_leaf_scale <- sigma_leaf_init + } + } + current_sigma2 <- sigma2_init } - current_sigma2 <- sigma2_init - + # Determine leaf model type if (!has_basis) leaf_model_mean_forest <- 0 else if (ncol(leaf_basis_train) == 1) leaf_model_mean_forest <- 1 @@ -652,6 +721,21 @@ bart <- function(X_train, y_train, leaf_basis_train = NULL, rfx_group_ids_train } if (include_mean_forest) { + if (probit_outcome_model) { + # Sample latent probit variable, z | - + forest_pred <- active_forest_mean$predict(forest_dataset_train) + y_bar_train + mu0 <- forest_pred[y_train == 0] + mu1 <- forest_pred[y_train == 1] + u0 <- runif(sum(y_train == 0), 0, pnorm(0 - mu0)) + u1 <- runif(sum(y_train == 1), pnorm(0 - mu1), 1) + resid_train[y_train==0] <- mu0 + qnorm(u0) + resid_train[y_train==1] <- mu1 + qnorm(u1) + + # Update outcome + outcome_train$update_data(resid_train - forest_pred) + } + + # Sample mean forest forest_model_mean$sample_one_iteration( forest_dataset = forest_dataset_train, residual = outcome_train, forest_samples = forest_samples_mean, active_forest = active_forest_mean, rng = rng, forest_model_config = forest_model_config_mean, @@ -791,6 +875,20 @@ bart <- function(X_train, y_train, leaf_basis_train = NULL, rfx_group_ids_train } if (include_mean_forest) { + if (probit_outcome_model) { + # Sample latent probit variable, z | - + forest_pred <- active_forest_mean$predict(forest_dataset_train) + y_bar_train + mu0 <- forest_pred[y_train == 0] + mu1 <- forest_pred[y_train == 1] + u0 <- runif(sum(y_train == 0), 0, pnorm(0 - mu0)) + u1 <- runif(sum(y_train == 1), pnorm(0 - mu1), 1) + resid_train[y_train==0] <- mu0 + qnorm(u0) + resid_train[y_train==1] <- mu1 + qnorm(u1) + + # Update outcome + outcome_train$update_data(resid_train - forest_pred) + } + forest_model_mean$sample_one_iteration( forest_dataset = forest_dataset_train, residual = outcome_train, forest_samples = 
forest_samples_mean, active_forest = active_forest_mean, rng = rng, forest_model_config = forest_model_config_mean, diff --git a/vignettes/BayesianSupervisedLearning.Rmd b/vignettes/BayesianSupervisedLearning.Rmd index 2b9337c3..66655f62 100644 --- a/vignettes/BayesianSupervisedLearning.Rmd +++ b/vignettes/BayesianSupervisedLearning.Rmd @@ -327,4 +327,97 @@ plot(rowMeans(bart_model_root$y_hat_test), y_test, abline(0,1,col="red",lty=2,lwd=2.5) ``` +# Demo 4: Partitioned Linear Model with Probit Outcome Model + +## Simulation + +Here, we generate data from a simple partitioned linear model. + +```{r} +# Generate the data +n <- 500 +p_x <- 10 +snr <- 3 +X <- matrix(runif(n*p_x), ncol = p_x) +f_X <- ( + ((0 <= X[,1]) & (0.25 > X[,1])) * (-7.5*X[,2]) + + ((0.25 <= X[,1]) & (0.5 > X[,1])) * (-2.5*X[,2]) + + ((0.5 <= X[,1]) & (0.75 > X[,1])) * (2.5*X[,2]) + + ((0.75 <= X[,1]) & (1 > X[,1])) * (7.5*X[,2]) +) +z <- f_X + rnorm(n, 0, 1) +y <- (z>0) * 1.0 + +# Split data into test and train sets +test_set_pct <- 0.2 +n_test <- round(test_set_pct*n) +n_train <- n - n_test +test_inds <- sort(sample(1:n, n_test, replace = FALSE)) +train_inds <- (1:n)[!((1:n) %in% test_inds)] +X_test <- as.data.frame(X[test_inds,]) +X_train <- as.data.frame(X[train_inds,]) +z_test <- z[test_inds] +z_train <- z[train_inds] +y_test <- y[test_inds] +y_train <- y[train_inds] +``` + +## Sampling and Analysis + +### Warmstart + +We first sample from an ensemble model of $y \mid X$ using "warm-start" +initialization samples (@he2023stochastic). This is the default in +`stochtree`. + +```{r} +num_gfr <- 10 +num_burnin <- 0 +num_mcmc <- 100 +num_samples <- num_gfr + num_burnin + num_mcmc +general_params <- list(sample_sigma2_global = F) +mean_forest_params <- list(sample_sigma2_leaf = F, num_trees = 100, + probit_outcome_model = T) +bart_model_warmstart <- stochtree::bart( + X_train = X_train, y_train = y_train, X_test = X_test, + num_gfr = num_gfr, num_burnin = num_burnin, num_mcmc = num_mcmc, + general_params = general_params, mean_forest_params = mean_forest_params +) +``` + +Inspect the MCMC samples + +```{r} +plot(rowMeans(bart_model_warmstart$y_hat_test), z_test, + pch=16, cex=0.75, xlab = "pred", ylab = "actual") +abline(0,1,col="red",lty=2,lwd=2.5) +``` + +### BART MCMC without Warmstart + +Next, we sample from this ensemble model without any warm-start initialization. + +```{r} +num_gfr <- 0 +num_burnin <- 100 +num_mcmc <- 100 +num_samples <- num_gfr + num_burnin + num_mcmc +general_params <- list(sample_sigma2_global = F) +mean_forest_params <- list(sample_sigma2_leaf = F, num_trees = 100, + probit_outcome_model = T) +bart_model_root <- stochtree::bart( + X_train = X_train, y_train = y_train, X_test = X_test, + num_gfr = num_gfr, num_burnin = num_burnin, num_mcmc = num_mcmc, + general_params = general_params, mean_forest_params = mean_forest_params +) +``` + +Inspect the BART samples after burnin. 
+ +```{r} +plot(rowMeans(bart_model_root$y_hat_test), z_test, + pch=16, cex=0.75, xlab = "pred", ylab = "actual") +abline(0,1,col="red",lty=2,lwd=2.5) +``` + # References From 09b62be627475feb2f9bcf48953c96805a92d5ef Mon Sep 17 00:00:00 2001 From: Drew Herren Date: Sun, 20 Apr 2025 18:24:56 -0500 Subject: [PATCH 02/14] Updated probit BART vignettes --- vignettes/BayesianSupervisedLearning.Rmd | 14 ++++++++++++++ 1 file changed, 14 insertions(+) diff --git a/vignettes/BayesianSupervisedLearning.Rmd b/vignettes/BayesianSupervisedLearning.Rmd index 66655f62..ae2589fc 100644 --- a/vignettes/BayesianSupervisedLearning.Rmd +++ b/vignettes/BayesianSupervisedLearning.Rmd @@ -393,6 +393,13 @@ plot(rowMeans(bart_model_warmstart$y_hat_test), z_test, abline(0,1,col="red",lty=2,lwd=2.5) ``` +Check the prediction accuracy + +```{r} +preds_test <- rowMeans(bart_model_warmstart$y_hat_test) > 0 +mean(preds_test == y_test) +``` + ### BART MCMC without Warmstart Next, we sample from this ensemble model without any warm-start initialization. @@ -420,4 +427,11 @@ plot(rowMeans(bart_model_root$y_hat_test), z_test, abline(0,1,col="red",lty=2,lwd=2.5) ``` +Check the prediction accuracy + +```{r} +preds_test <- rowMeans(bart_model_root$y_hat_test) > 0 +mean(preds_test == y_test) +``` + # References From 30e4ccb9032af974a1235aa0283af38767849473 Mon Sep 17 00:00:00 2001 From: Drew Herren Date: Tue, 22 Apr 2025 21:12:53 -0500 Subject: [PATCH 03/14] Updated probit vignette --- vignettes/BayesianSupervisedLearning.Rmd | 93 ++++++++++++++++++++++-- 1 file changed, 85 insertions(+), 8 deletions(-) diff --git a/vignettes/BayesianSupervisedLearning.Rmd b/vignettes/BayesianSupervisedLearning.Rmd index ae2589fc..6bfbf85e 100644 --- a/vignettes/BayesianSupervisedLearning.Rmd +++ b/vignettes/BayesianSupervisedLearning.Rmd @@ -335,9 +335,8 @@ Here, we generate data from a simple partitioned linear model. ```{r} # Generate the data -n <- 500 -p_x <- 10 -snr <- 3 +n <- 1000 +p_x <- 100 X <- matrix(runif(n*p_x), ncol = p_x) f_X <- ( ((0 <= X[,1]) & (0.25 > X[,1])) * (-7.5*X[,2]) + @@ -349,7 +348,7 @@ z <- f_X + rnorm(n, 0, 1) y <- (z>0) * 1.0 # Split data into test and train sets -test_set_pct <- 0.2 +test_set_pct <- 0.5 n_test <- round(test_set_pct*n) n_train <- n - n_test test_inds <- sort(sample(1:n, n_test, replace = FALSE)) @@ -385,7 +384,7 @@ bart_model_warmstart <- stochtree::bart( ) ``` -Inspect the MCMC samples +Since we've simulated this data, we can compare the true latent continuous outcome variable to the probit-scale predictions for a test set. ```{r} plot(rowMeans(bart_model_warmstart$y_hat_test), z_test, @@ -393,13 +392,52 @@ plot(rowMeans(bart_model_warmstart$y_hat_test), z_test, abline(0,1,col="red",lty=2,lwd=2.5) ``` -Check the prediction accuracy +On non-simulated datasets, the first thing we would evaluate is the prediction accuracy. ```{r} preds_test <- rowMeans(bart_model_warmstart$y_hat_test) > 0 mean(preds_test == y_test) ``` +We can also compute the [ROC curve](https://en.wikipedia.org/wiki/Receiver_operating_characteristic) for every posterior sample, as well as the ROC of the posterior mean. 
+
+```{r}
+num_thresholds <- 1000
+thresholds <- seq(0.001,0.999,length.out=num_thresholds)
+tpr_mean <- rep(NA, num_thresholds)
+fpr_mean <- rep(NA, num_thresholds)
+tpr_samples <- matrix(NA, num_thresholds, num_mcmc)
+fpr_samples <- matrix(NA, num_thresholds, num_mcmc)
+yhat_samples <- bart_model_warmstart$y_hat_test
+yhat_mean <- rowMeans(yhat_samples)
+for (i in 1:num_thresholds) {
+  is_above_threshold_samples <- yhat_samples > qnorm(thresholds[i])
+  is_above_threshold_mean <- yhat_mean > qnorm(thresholds[i])
+  n_positive <- sum(y_test)
+  n_negative <- sum(y_test==0)
+  y_above_threshold_mean <- y_test[is_above_threshold_mean]
+  tpr_mean[i] <- sum(y_above_threshold_mean)/n_positive
+  fpr_mean[i] <- sum(y_above_threshold_mean==0)/n_negative
+  for (j in 1:num_mcmc) {
+    y_above_threshold <- y_test[is_above_threshold_samples[,j]]
+    tpr_samples[i,j] <- sum(y_above_threshold)/n_positive
+    fpr_samples[i,j] <- sum(y_above_threshold==0)/n_negative
+  }
+}
+
+for (i in 1:num_mcmc) {
+  if (i == 1) {
+    plot(fpr_samples[,i], tpr_samples[,i], type = "l", col = "blue", lwd = 1, lty = 1,
+         xlab = "False positive rate", ylab = "True positive rate")
+  } else {
+    lines(fpr_samples[,i], tpr_samples[,i], col = "blue", lwd = 1, lty = 1)
+  }
+}
+lines(fpr_mean, tpr_mean, col = "black", lwd = 3, lty = 3)
+```
+
+Note that the nonlinearity of the ROC function means that the ROC curve of the posterior mean lies above most of the individual posterior sample ROC curves (which would not be the case if we had simply taken the mean of the ROC curves).
+
 ### BART MCMC without Warmstart
 Next, we sample from this ensemble model without any warm-start initialization.
@@ -419,7 +457,7 @@ bart_model_root <- stochtree::bart(
 )
 ```
-Inspect the BART samples after burnin.
+Since we've simulated this data, we can compare the true latent continuous outcome variable to the probit-scale predictions for a test set.
 ```{r}
 plot(rowMeans(bart_model_root$y_hat_test), z_test,
@@ -427,11 +465,50 @@ plot(rowMeans(bart_model_root$y_hat_test), z_test,
 abline(0,1,col="red",lty=2,lwd=2.5)
 ```
-Check the prediction accuracy
+On non-simulated datasets, the first thing we would evaluate is the prediction accuracy.
 ```{r}
 preds_test <- rowMeans(bart_model_root$y_hat_test) > 0
 mean(preds_test == y_test)
 ```
+We can also compute the [ROC curve](https://en.wikipedia.org/wiki/Receiver_operating_characteristic) for every posterior sample, as well as the ROC of the posterior mean. 
+
+```{r}
+num_thresholds <- 1000
+thresholds <- seq(0.001,0.999,length.out=num_thresholds)
+tpr_mean <- rep(NA, num_thresholds)
+fpr_mean <- rep(NA, num_thresholds)
+tpr_samples <- matrix(NA, num_thresholds, num_mcmc)
+fpr_samples <- matrix(NA, num_thresholds, num_mcmc)
+yhat_samples <- bart_model_root$y_hat_test
+yhat_mean <- rowMeans(yhat_samples)
+for (i in 1:num_thresholds) {
+  is_above_threshold_samples <- yhat_samples > qnorm(thresholds[i])
+  is_above_threshold_mean <- yhat_mean > qnorm(thresholds[i])
+  n_positive <- sum(y_test)
+  n_negative <- sum(y_test==0)
+  y_above_threshold_mean <- y_test[is_above_threshold_mean]
+  tpr_mean[i] <- sum(y_above_threshold_mean)/n_positive
+  fpr_mean[i] <- sum(y_above_threshold_mean==0)/n_negative
+  for (j in 1:num_mcmc) {
+    y_above_threshold <- y_test[is_above_threshold_samples[,j]]
+    tpr_samples[i,j] <- sum(y_above_threshold)/n_positive
+    fpr_samples[i,j] <- sum(y_above_threshold==0)/n_negative
+  }
+}
+
+for (i in 1:num_mcmc) {
+  if (i == 1) {
+    plot(fpr_samples[,i], tpr_samples[,i], type = "l", col = "blue", lwd = 1, lty = 1,
+         xlab = "False positive rate", ylab = "True positive rate")
+  } else {
+    lines(fpr_samples[,i], tpr_samples[,i], col = "blue", lwd = 1, lty = 1)
+  }
+}
+lines(fpr_mean, tpr_mean, col = "black", lwd = 3, lty = 3)
+```
+
+Note that the nonlinearity of the ROC function means that the ROC curve of the posterior mean lies above most of the individual posterior sample ROC curves (which would not be the case if we had simply taken the mean of the ROC curves).
+
 # References

From 876f79774a6544bccfbba51998a32d59ad90eee4 Mon Sep 17 00:00:00 2001
From: Drew Herren
Date: Wed, 23 Apr 2025 00:56:46 -0500
Subject: [PATCH 04/14] WIP Python probit implementation

---
 R/bart.R | 8 +-
 stochtree/bart.py | 191 ++++++++++++++++++++++++++++++++++++++--------
 stochtree/bcf.py | 5 +-
 3 files changed, 166 insertions(+), 38 deletions(-)

diff --git a/R/bart.R b/R/bart.R
index 2e9db337..8a33f8fd 100644
--- a/R/bart.R
+++ b/R/bart.R
@@ -466,7 +466,10 @@ bart <- function(X_train, y_train, leaf_basis_train = NULL, rfx_group_ids_train
 # Determine whether a test set is provided
 has_test = !is.null(X_test)
- # Check whether outcome is 0-1 binary
+ # Preliminary runtime checks for probit link
+ if (!include_mean_forest) {
+ probit_outcome_model <- FALSE
+ }
 if (probit_outcome_model) {
 if (!(length(unique(y_train)) == 2)) {
 stop("You specified a probit outcome model, but supplied an outcome with more than 2 unique values")
@@ -488,11 +491,10 @@ bart <- function(X_train, y_train, leaf_basis_train = NULL, rfx_group_ids_train
 # differently for binary and continuous outcomes
 if (probit_outcome_model) {
 # Compute a probit-scale offset and fix scale to 1
- y_bar_train <- pnorm(mean(y_train))
+ y_bar_train <- qnorm(mean(y_train))
 y_std_train <- 1
 # Set a pseudo outcome by subtracting mean(y_train) from y_train
- current_z_train <- y_train - mean(y_train)
 resid_train <- y_train - mean(y_train)
 # Set initial values of root nodes to 0.0 (in probit scale)
diff --git a/stochtree/bart.py b/stochtree/bart.py
index dff3362f..679cebe8 100644
--- a/stochtree/bart.py
+++ b/stochtree/bart.py
@@ -9,6 +9,7 @@
 import numpy as np
 import pandas as pd
+from scipy.stats import norm
 from .config import ForestModelConfig, GlobalModelConfig
 from .data import Dataset, Residual
@@ -145,6 +146,7 @@ def sample(
 * `sigma2_leaf_scale` (`float`): Scale parameter in the `IG(sigma2_leaf_shape, sigma2_leaf_scale)` leaf node parameter variance model. 
Calibrated internally as `0.5/num_trees` if not set here. * `keep_vars` (`list` or `np.array`): Vector of variable names or column indices denoting variables that should be included in the mean forest. Defaults to `None`. * `drop_vars` (`list` or `np.array`): Vector of variable names or column indices denoting variables that should be excluded from the mean forest. Defaults to `None`. If both `drop_vars` and `keep_vars` are set, `drop_vars` will be ignored. + * `probit_outcome_model` (`bool`): Whether or not the outcome should be modeled as explicitly binary via a probit link. If `True`, `y` must only contain the values `0` and `1`. Default: `False`. variance_forest_params : dict, optional Dictionary of variance forest model parameters, each of which has a default value processed internally, so this argument is optional. @@ -203,6 +205,7 @@ def sample( "sigma2_leaf_scale": None, "keep_vars": None, "drop_vars": None, + "probit_outcome_model": False, } mean_forest_params_updated = _preprocess_params( mean_forest_params_default, mean_forest_params @@ -253,6 +256,7 @@ def sample( b_leaf = mean_forest_params_updated["sigma2_leaf_scale"] keep_vars_mean = mean_forest_params_updated["keep_vars"] drop_vars_mean = mean_forest_params_updated["drop_vars"] + self.probit_outcome_model = mean_forest_params_updated["probit_outcome_model"] # 3. Variance forest parameters num_trees_variance = variance_forest_params_updated["num_trees"] @@ -710,25 +714,40 @@ def sample( [variable_subset_variance.count(i) == 0 for i in original_var_indices] ] = 0 - # Scale outcome - if self.standardize: - self.y_bar = np.squeeze(np.mean(y_train)) - self.y_std = np.squeeze(np.std(y_train)) - else: - self.y_bar = 0 - self.y_std = 1 - resid_train = (y_train - self.y_bar) / self.y_std - - # Calibrate priors for global sigma^2 and sigma_leaf (don't use regression initializer for warm-start or XBART) - if not sigma2_init: - sigma2_init = 1.0 * np.var(resid_train) - if not variance_forest_leaf_init: - variance_forest_leaf_init = 0.6 * np.var(resid_train) - current_sigma2 = sigma2_init - self.sigma2_init = sigma2_init - if self.include_mean_forest: + # Preliminary runtime checks for probit link + if not self.include_mean_forest: + self.probit_outcome_model = False + if self.probit_outcome_model: + if np.unique(y_train).size != 2: + raise ValueError("You specified a probit outcome model, but supplied an outcome with more than 2 unique values") + unique_outcomes = np.squeeze(np.unique(y_train)) + if not np.array_equal(unique_outcomes, [0,1]): + raise ValueError("You specified a probit outcome model, but supplied an outcome with 2 unique values other than 0 and 1") + if self.include_variance_forest: + raise ValueError("We do not support heteroskedasticity with a probit link") + if self.sample_sigma_global: + warnings.warn("Global error variance will not be sampled with a probit link as it is fixed at 1") + self.sample_sigma_global = False + + # Handle standardization, prior calibration, and initialization of forest + # differently for binary and continuous outcomes + if self.probit_outcome_model: + # Compute a probit-scale offset and fix scale to 1 + self.y_bar = norm.ppf(np.squeeze(np.mean(y_train))) + self.y_std = 1.0 + + # Set a pseudo outcome by subtracting mean(y_train) from y_train + resid_train = y_train - np.squeeze(np.mean(y_train)) + + # Set initial values of root nodes to 0.0 (in probit scale) + init_val_mean = 0.0 + + # Calibrate priors for sigma^2 and tau + # Set sigma2_init to 1, ignoring default provided + sigma2_init = 
1.0 + # Skip variance_forest_init, since variance forests are not supported with probit link b_leaf = ( - np.squeeze(np.var(resid_train)) / num_trees_mean + 1.0 / num_trees_mean if b_leaf is None else b_leaf ) @@ -737,7 +756,7 @@ def sample( current_leaf_scale = np.zeros((self.num_basis, self.num_basis), dtype=float) np.fill_diagonal( current_leaf_scale, - np.squeeze(np.var(resid_train)) / num_trees_mean, + 2.0 / num_trees_mean, ) elif isinstance(sigma_leaf, float): current_leaf_scale = np.zeros((self.num_basis, self.num_basis), dtype=float) @@ -763,7 +782,7 @@ def sample( else: if sigma_leaf is None: current_leaf_scale = np.array( - [[np.squeeze(np.var(resid_train)) / num_trees_mean]] + [[2.0 / num_trees_mean]] ) elif isinstance(sigma_leaf, float): current_leaf_scale = np.array([[sigma_leaf]]) @@ -786,17 +805,98 @@ def sample( "sigma_leaf must be either a scalar or a 2d numpy array" ) else: - current_leaf_scale = np.array([[1.0]]) - if self.include_variance_forest: - if not a_forest: - a_forest = num_trees_variance / a_0**2 + 0.5 - if not b_forest: - b_forest = num_trees_variance / a_0**2 - else: - if not a_forest: - a_forest = 1.0 - if not b_forest: - b_forest = 1.0 + # Standardize if requested + if self.standardize: + self.y_bar = np.squeeze(np.mean(y_train)) + self.y_std = np.squeeze(np.std(y_train)) + else: + self.y_bar = 0 + self.y_std = 1 + + # Compute residual value + resid_train = (y_train - self.y_bar) / self.y_std + + # Compute initial value of root nodes in mean forest + init_val_mean = np.squeeze(np.mean(resid_train)) + + # Calibrate priors for global sigma^2 and sigma_leaf + if not sigma2_init: + sigma2_init = 1.0 * np.var(resid_train) + if not variance_forest_leaf_init: + variance_forest_leaf_init = 0.6 * np.var(resid_train) + current_sigma2 = sigma2_init + self.sigma2_init = sigma2_init + if self.include_mean_forest: + b_leaf = ( + np.squeeze(np.var(resid_train)) / num_trees_mean + if b_leaf is None + else b_leaf + ) + if self.has_basis: + if sigma_leaf is None: + current_leaf_scale = np.zeros((self.num_basis, self.num_basis), dtype=float) + np.fill_diagonal( + current_leaf_scale, + np.squeeze(np.var(resid_train)) / num_trees_mean, + ) + elif isinstance(sigma_leaf, float): + current_leaf_scale = np.zeros((self.num_basis, self.num_basis), dtype=float) + np.fill_diagonal(current_leaf_scale, sigma_leaf) + elif isinstance(sigma_leaf, np.ndarray): + if sigma_leaf.ndim != 2: + raise ValueError( + "sigma_leaf must be a 2d symmetric numpy array if provided in matrix form" + ) + if sigma_leaf.shape[0] != sigma_leaf.shape[1]: + raise ValueError( + "sigma_leaf must be a 2d symmetric numpy array if provided in matrix form" + ) + if sigma_leaf.shape[0] != self.num_basis: + raise ValueError( + "sigma_leaf must be a 2d symmetric numpy array with its dimensionality matching the basis dimension" + ) + current_leaf_scale = sigma_leaf + else: + raise ValueError( + "sigma_leaf must be either a scalar or a 2d symmetric numpy array" + ) + else: + if sigma_leaf is None: + current_leaf_scale = np.array( + [[np.squeeze(np.var(resid_train)) / num_trees_mean]] + ) + elif isinstance(sigma_leaf, float): + current_leaf_scale = np.array([[sigma_leaf]]) + elif isinstance(sigma_leaf, np.ndarray): + if sigma_leaf.ndim != 2: + raise ValueError( + "sigma_leaf must be a 2d symmetric numpy array if provided in matrix form" + ) + if sigma_leaf.shape[0] != sigma_leaf.shape[1]: + raise ValueError( + "sigma_leaf must be a 2d symmetric numpy array if provided in matrix form" + ) + if sigma_leaf.shape[0] != 1: + 
raise ValueError( + "sigma_leaf must be a 1x1 numpy array for this leaf model" + ) + current_leaf_scale = sigma_leaf + else: + raise ValueError( + "sigma_leaf must be either a scalar or a 2d numpy array" + ) + else: + current_leaf_scale = np.array([[1.0]]) + if self.include_variance_forest: + if not a_forest: + a_forest = num_trees_variance / a_0**2 + 0.5 + if not b_forest: + b_forest = num_trees_variance / a_0**2 + else: + if not a_forest: + a_forest = 1.0 + if not b_forest: + b_forest = 1.0 # Runtime checks on RFX group ids self.has_rfx = False @@ -894,11 +994,13 @@ def sample( # Residual residual_train = Residual(resid_train) - # C++ random number generator + # C++ and Numpy random number generator if random_seed is None: cpp_rng = RNG(-1) + self.rng = np.random.default_rng() else: cpp_rng = RNG(random_seed) + self.rng = np.random.default_rng(random_seed) # Set variance leaf model type (currently only one option) leaf_model_variance_forest = 3 @@ -1018,8 +1120,31 @@ def sample( keep_sample = True if keep_sample: sample_counter += 1 - # Sample the mean forest if self.include_mean_forest: + if self.probit_outcome_model: + # Sample latent probit variable z | - + forest_pred = active_forest_mean.predict(forest_dataset_train) + mu0 = forest_pred[y_train == 0] + mu1 = forest_pred[y_train == 1] + n0 = np.sum(y_train == 0) + n1 = np.sum(y_train == 1) + u0 = self.rng.uniform( + low=0.0, + high=norm.cdf(0 - mu0), + size=n0, + ) + u1 = self.rng.uniform( + low=norm.cdf(0 - mu1), + high=1.0, + size=n1, + ) + resid_train[y_train == 0] = mu0 + norm.ppf(u0) + resid_train[y_train == 1] = mu1 + norm.ppf(u1) + + # Update outcome + residual_train.update_data(resid_train - forest_pred) + + # Sample mean forest forest_sampler_mean.sample_one_iteration( self.forest_container_mean, active_forest_mean, diff --git a/stochtree/bcf.py b/stochtree/bcf.py index 49699f70..35c84596 100644 --- a/stochtree/bcf.py +++ b/stochtree/bcf.py @@ -72,7 +72,6 @@ class BCFModel: def __init__(self) -> None: # Internal flag for whether the sample() method has been run self.sampled = False - self.rng = np.random.default_rng() def sample( self, @@ -1384,11 +1383,13 @@ def sample( # Residual residual_train = Residual(resid_train) - # C++ random number generator + # C++ and numpy random number generator if random_seed is None: cpp_rng = RNG(-1) + self.rng = np.random.default_rng() else: cpp_rng = RNG(random_seed) + self.rng = np.random.default_rng(random_seed) # Sampling data structures global_model_config = GlobalModelConfig(global_error_variance=current_sigma2) From ef8d96276db3b2a975cb0bf5471923bf40b7a605 Mon Sep 17 00:00:00 2001 From: Drew Herren Date: Thu, 24 Apr 2025 01:51:57 -0500 Subject: [PATCH 05/14] Updated probit interface and added vignette for python classification --- R/bart.R | 9 +- demo/debug/classification.py | 59 +++++ .../supervised_learning_classification.ipynb | 230 ++++++++++++++++++ stochtree/bart.py | 55 ++++- stochtree/data.py | 2 +- 5 files changed, 340 insertions(+), 15 deletions(-) create mode 100644 demo/debug/classification.py create mode 100644 demo/notebooks/supervised_learning_classification.ipynb diff --git a/R/bart.R b/R/bart.R index 8a33f8fd..340e53ae 100644 --- a/R/bart.R +++ b/R/bart.R @@ -705,7 +705,6 @@ bart <- function(X_train, y_train, leaf_basis_train = NULL, rfx_group_ids_train # Initialize the leaves of each tree in the variance forest if (include_variance_forest) { active_forest_variance$prepare_for_sampler(forest_dataset_train, outcome_train, forest_model_variance, 
leaf_model_variance_forest, variance_forest_init) - } # Run GFR (warm start) if specified @@ -1015,7 +1014,8 @@ bart <- function(X_train, y_train, leaf_basis_train = NULL, rfx_group_ids_train "sample_sigma_global" = sample_sigma_global, "sample_sigma_leaf" = sample_sigma_leaf, "include_mean_forest" = include_mean_forest, - "include_variance_forest" = include_variance_forest + "include_variance_forest" = include_variance_forest, + "probit_outcome_model" = probit_outcome_model ) result <- list( "model_params" = model_params, @@ -1357,6 +1357,7 @@ saveBARTModelToJson <- function(object){ jsonobj$add_scalar("num_chains", object$model_params$num_chains) jsonobj$add_scalar("keep_every", object$model_params$keep_every) jsonobj$add_boolean("requires_basis", object$model_params$requires_basis) + jsonobj$add_boolean("probit_outcome_model", object$model_params$probit_outcome_model) if (object$model_params$sample_sigma_global) { jsonobj$add_vector("sigma2_global_samples", object$sigma2_global_samples, "parameters") } @@ -1548,6 +1549,8 @@ createBARTModelFromJson <- function(json_object){ model_params[["num_chains"]] <- json_object$get_scalar("num_chains") model_params[["keep_every"]] <- json_object$get_scalar("keep_every") model_params[["requires_basis"]] <- json_object$get_boolean("requires_basis") + model_params[["probit_outcome_model"]] <- json_object$get_boolean("probit_outcome_model") + output[["model_params"]] <- model_params # Unpack sampled parameters @@ -1750,6 +1753,7 @@ createBARTModelFromCombinedJson <- function(json_object_list){ model_params[["num_covariates"]] <- json_object_default$get_scalar("num_covariates") model_params[["num_basis"]] <- json_object_default$get_scalar("num_basis") model_params[["requires_basis"]] <- json_object_default$get_boolean("requires_basis") + model_params[["probit_outcome_model"]] <- json_object_default$get_boolean("probit_outcome_model") model_params[["num_chains"]] <- json_object_default$get_scalar("num_chains") model_params[["keep_every"]] <- json_object_default$get_scalar("keep_every") @@ -1905,6 +1909,7 @@ createBARTModelFromCombinedJsonString <- function(json_string_list){ model_params[["num_chains"]] <- json_object_default$get_scalar("num_chains") model_params[["keep_every"]] <- json_object_default$get_scalar("keep_every") model_params[["requires_basis"]] <- json_object_default$get_boolean("requires_basis") + model_params[["probit_outcome_model"]] <- json_object_default$get_boolean("probit_outcome_model") # Combine values that are sample-specific for (i in 1:length(json_object_list)) { diff --git a/demo/debug/classification.py b/demo/debug/classification.py new file mode 100644 index 00000000..4d303289 --- /dev/null +++ b/demo/debug/classification.py @@ -0,0 +1,59 @@ +import matplotlib.pyplot as plt +import numpy as np +import pandas as pd +import seaborn as sns +from sklearn.model_selection import train_test_split +from sklearn.metrics import roc_curve, auc + +from stochtree import BARTModel + +# RNG +rng = np.random.default_rng() + +# Generate covariates +n = 1000 +p_X = 10 +X = rng.uniform(0, 1, (n, p_X)) + + +# Define the outcome mean function +def outcome_mean(X): + return np.where( + (X[:, 0] >= 0.0) & (X[:, 0] < 0.25), + -7.5 * X[:, 1], + np.where( + (X[:, 0] >= 0.25) & (X[:, 0] < 0.5), + -2.5 * X[:, 1], + np.where((X[:, 0] >= 0.5) & (X[:, 0] < 0.75), 2.5 * X[:, 1], 7.5 * X[:, 1]), + ), + ) + + +# Generate outcome +epsilon = rng.normal(0, 1, n) +z = outcome_mean(X) + epsilon +y = np.where(z >= 0, 1, 0) + +# Test-train split +sample_inds = 
np.arange(n) +train_inds, test_inds = train_test_split(sample_inds, test_size=0.5) +X_train = X[train_inds, :] +X_test = X[test_inds, :] +z_train = z[train_inds] +z_test = z[test_inds] +y_train = y[train_inds] +y_test = y[test_inds] + +# Fit Probit BART +bart_model = BARTModel() +general_params = {"num_chains": 1} +mean_forest_params = {"probit_outcome_model": True} +bart_model.sample( + X_train=X_train, + y_train=y_train, + X_test=X_test, + num_gfr=10, + num_mcmc=100, + general_params=general_params, + mean_forest_params=mean_forest_params +) diff --git a/demo/notebooks/supervised_learning_classification.ipynb b/demo/notebooks/supervised_learning_classification.ipynb new file mode 100644 index 00000000..fef2473b --- /dev/null +++ b/demo/notebooks/supervised_learning_classification.ipynb @@ -0,0 +1,230 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Supervised Learning (Classification)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Load necessary libraries" + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "metadata": {}, + "outputs": [], + "source": [ + "import matplotlib.pyplot as plt\n", + "import numpy as np\n", + "import pandas as pd\n", + "import seaborn as sns\n", + "from sklearn.model_selection import train_test_split\n", + "from sklearn.metrics import roc_curve, auc\n", + "\n", + "from stochtree import BARTModel" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Generate sample data" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": {}, + "outputs": [], + "source": [ + "# RNG\n", + "rng = np.random.default_rng()\n", + "\n", + "# Generate covariates\n", + "n = 1000\n", + "p_X = 10\n", + "X = rng.uniform(0, 1, (n, p_X))\n", + "\n", + "\n", + "# Define the outcome mean function\n", + "def outcome_mean(X):\n", + " return np.where(\n", + " (X[:, 0] >= 0.0) & (X[:, 0] < 0.25),\n", + " -7.5 * X[:, 1],\n", + " np.where(\n", + " (X[:, 0] >= 0.25) & (X[:, 0] < 0.5),\n", + " -2.5 * X[:, 1],\n", + " np.where((X[:, 0] >= 0.5) & (X[:, 0] < 0.75), 2.5 * X[:, 1], 7.5 * X[:, 1]),\n", + " ),\n", + " )\n", + "\n", + "\n", + "# Generate outcome\n", + "epsilon = rng.normal(0, 1, n)\n", + "z = outcome_mean(X) + epsilon\n", + "y = np.where(z >= 0, 1, 0)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Test-train split" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": {}, + "outputs": [], + "source": [ + "sample_inds = np.arange(n)\n", + "train_inds, test_inds = train_test_split(sample_inds, test_size=0.5)\n", + "X_train = X[train_inds, :]\n", + "X_test = X[test_inds, :]\n", + "z_train = z[train_inds]\n", + "z_test = z[test_inds]\n", + "y_train = y[train_inds]\n", + "y_test = y[test_inds]" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Run BART" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/Users/drew/Github/stochtree/venv/lib/python3.12/site-packages/stochtree/bart.py:729: UserWarning: Global error variance will not be sampled with a probit link as it is fixed at 1\n", + " warnings.warn(\"Global error variance will not be sampled with a probit link as it is fixed at 1\")\n" + ] + } + ], + "source": [ + "num_gfr = 10\n", + "num_mcmc = 100\n", + "bart_model = BARTModel()\n", + "general_params = {\"num_chains\": 1}\n", + "mean_forest_params = {\"probit_outcome_model\": True}\n", + 
"bart_model.sample(\n", + " X_train=X_train,\n", + " y_train=y_train,\n", + " X_test=X_test,\n", + " num_gfr=num_gfr,\n", + " num_mcmc=num_mcmc,\n", + " general_params=general_params,\n", + " mean_forest_params=mean_forest_params\n", + ")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Since we've simulated this data, we can compare the true latent continuous outcome variable to the probit-scale predictions for a test set." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "plt.scatter(x=np.mean(bart_model.y_hat_test,axis=1), y=z_test)\n", + "plt.axline((0, 0), slope=1, color=\"black\", linestyle=(0, (3, 3)))\n", + "plt.xlabel(\"Predicted\")\n", + "plt.ylabel(\"Actual\")\n", + "plt.show()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "On non-simulated datasets, the first thing we would evaluate is the prediction accuracy." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "preds_test = np.mean(bart_model.y_hat_test,axis=1) > 0\n", + "print(f\"Test set accuracy: {np.mean(y_test == preds_test):.3f}\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "We can also compute the [ROC curve](https://en.wikipedia.org/wiki/Receiver_operating_characteristic) for every posterior sample, as well as the ROC of the posterior mean." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "num_gfr = 10\n", + "num_mcmc = 100\n", + "fpr_list = list()\n", + "tpr_list = list()\n", + "threshold_list = list()\n", + "for i in range(num_mcmc):\n", + " fpr, tpr, thresholds = roc_curve(y_test, bart_model.y_hat_test[:,i], pos_label=1)\n", + " fpr_list.append(fpr)\n", + " tpr_list.append(tpr)\n", + " threshold_list.append(thresholds)\n", + "probit_preds_test_mean = np.mean(bart_model.y_hat_test,axis=1)\n", + "fpr_mean, tpr_mean, thresholds_mean = roc_curve(y_test, probit_preds_test_mean, pos_label=1)\n", + "for i in range(num_mcmc):\n", + " plt.plot(fpr_list[i], tpr_list[i], color = 'blue', linestyle='solid', linewidth = 0.9)\n", + "plt.plot(fpr_mean, tpr_mean, color = 'black', linestyle='dashed', linewidth = 1.75)\n", + "plt.axline((0, 0), slope=1, color=\"red\", linestyle='dashed', linewidth=1.5)\n", + "plt.xlabel(\"False Positive Rate\")\n", + "plt.ylabel(\"True Positive Rate\")\n", + "plt.xlim(0, 1)\n", + "plt.ylim(0, 1)\n", + "plt.show()" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "venv", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.12.9" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/stochtree/bart.py b/stochtree/bart.py index 679cebe8..fa012319 100644 --- a/stochtree/bart.py +++ b/stochtree/bart.py @@ -182,7 +182,7 @@ def sample( "sigma2_global_shape": 0, "sigma2_global_scale": 0, "variable_weights": None, - "random_seed": -1, + "random_seed": None, "keep_burnin": False, "keep_gfr": False, "keep_every": 1, @@ -725,9 +725,9 @@ def sample( raise ValueError("You specified a probit outcome model, but supplied an outcome with 2 unique values other than 0 and 1") if self.include_variance_forest: raise ValueError("We do not support heteroskedasticity with a probit link") 
- if self.sample_sigma_global: + if sample_sigma_global: warnings.warn("Global error variance will not be sampled with a probit link as it is fixed at 1") - self.sample_sigma_global = False + sample_sigma_global = False # Handle standardization, prior calibration, and initialization of forest # differently for binary and continuous outcomes @@ -745,6 +745,8 @@ def sample( # Calibrate priors for sigma^2 and tau # Set sigma2_init to 1, ignoring default provided sigma2_init = 1.0 + current_sigma2 = sigma2_init + self.sigma2_init = sigma2_init # Skip variance_forest_init, since variance forests are not supported with probit link b_leaf = ( 1.0 / num_trees_mean @@ -1124,10 +1126,10 @@ def sample( if self.probit_outcome_model: # Sample latent probit variable z | - forest_pred = active_forest_mean.predict(forest_dataset_train) - mu0 = forest_pred[y_train == 0] - mu1 = forest_pred[y_train == 1] - n0 = np.sum(y_train == 0) - n1 = np.sum(y_train == 1) + mu0 = forest_pred[y_train[:,0] == 0] + mu1 = forest_pred[y_train[:,0] == 1] + n0 = np.sum(y_train[:,0] == 0) + n1 = np.sum(y_train[:,0] == 1) u0 = self.rng.uniform( low=0.0, high=norm.cdf(0 - mu0), @@ -1138,13 +1140,14 @@ def sample( high=1.0, size=n1, ) - resid_train[y_train == 0] = mu0 + norm.ppf(u0) - resid_train[y_train == 1] = mu1 + norm.ppf(u1) + resid_train[y_train[:,0] == 0,0] = mu0 + norm.ppf(u0) + resid_train[y_train[:,0] == 1,0] = mu1 + norm.ppf(u1) # Update outcome - residual_train.update_data(resid_train - forest_pred) + new_outcome = np.squeeze(resid_train) - forest_pred + residual_train.update_data(new_outcome) - # Sample mean forest + # Sample the mean forest forest_sampler_mean.sample_one_iteration( self.forest_container_mean, active_forest_mean, @@ -1308,8 +1311,33 @@ def sample( keep_sample = False if keep_sample: sample_counter += 1 - # Sample the mean forest + if self.include_mean_forest: + if self.probit_outcome_model: + # Sample latent probit variable z | - + forest_pred = active_forest_mean.predict(forest_dataset_train) + mu0 = forest_pred[y_train[:,0] == 0] + mu1 = forest_pred[y_train[:,0] == 1] + n0 = np.sum(y_train[:,0] == 0) + n1 = np.sum(y_train[:,0] == 1) + u0 = self.rng.uniform( + low=0.0, + high=norm.cdf(0 - mu0), + size=n0, + ) + u1 = self.rng.uniform( + low=norm.cdf(0 - mu1), + high=1.0, + size=n1, + ) + resid_train[y_train[:,0] == 0,0] = mu0 + norm.ppf(u0) + resid_train[y_train[:,0] == 1,0] = mu1 + norm.ppf(u1) + + # Update outcome + new_outcome = np.squeeze(resid_train) - forest_pred + residual_train.update_data(new_outcome) + + # Sample the mean forest forest_sampler_mean.sample_one_iteration( self.forest_container_mean, active_forest_mean, @@ -1811,6 +1839,7 @@ def to_json(self) -> str: bart_json.add_integer("num_samples", self.num_samples) bart_json.add_integer("num_basis", self.num_basis) bart_json.add_boolean("requires_basis", self.has_basis) + bart_json.add_boolean("probit_outcome_model", self.probit_outcome_model) # Add parameter samples if self.sample_sigma_global: @@ -1882,6 +1911,7 @@ def from_json(self, json_string: str) -> None: self.num_samples = bart_json.get_integer("num_samples") self.num_basis = bart_json.get_integer("num_basis") self.has_basis = bart_json.get_boolean("requires_basis") + self.probit_outcome_model = bart_json.get_boolean("probit_outcome_model") # Unpack parameter samples if self.sample_sigma_global: @@ -1990,6 +2020,7 @@ def from_json_string_list(self, json_string_list: list[str]) -> None: self.num_samples = json_object_default.get_integer("num_samples") self.num_basis = 
json_object_default.get_integer("num_basis") self.has_basis = json_object_default.get_boolean("requires_basis") + self.probit_outcome_model = json_object_default.get_boolean("probit_outcome_model") # Unpack parameter samples if self.sample_sigma_global: diff --git a/stochtree/data.py b/stochtree/data.py index a29e80f5..8cbe76e0 100644 --- a/stochtree/data.py +++ b/stochtree/data.py @@ -175,4 +175,4 @@ def update_data(self, new_vector: np.array) -> None: Univariate numpy array of new residual values. """ n = new_vector.size - self.residual_cpp.UpdateData(new_vector, n) + self.residual_cpp.ReplaceData(new_vector, n) From e71bef9e7c287989aeb1d59ec7634f1e8c8840ab Mon Sep 17 00:00:00 2001 From: Drew Herren Date: Thu, 24 Apr 2025 10:57:55 -0500 Subject: [PATCH 06/14] Clear outputs from classification notebook --- .../supervised_learning_classification.ipynb | 17 ++++------------- 1 file changed, 4 insertions(+), 13 deletions(-) diff --git a/demo/notebooks/supervised_learning_classification.ipynb b/demo/notebooks/supervised_learning_classification.ipynb index fef2473b..e88b1b7b 100644 --- a/demo/notebooks/supervised_learning_classification.ipynb +++ b/demo/notebooks/supervised_learning_classification.ipynb @@ -16,7 +16,7 @@ }, { "cell_type": "code", - "execution_count": 1, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ @@ -39,7 +39,7 @@ }, { "cell_type": "code", - "execution_count": 2, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ @@ -80,7 +80,7 @@ }, { "cell_type": "code", - "execution_count": 3, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ @@ -105,16 +105,7 @@ "cell_type": "code", "execution_count": null, "metadata": {}, - "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "/Users/drew/Github/stochtree/venv/lib/python3.12/site-packages/stochtree/bart.py:729: UserWarning: Global error variance will not be sampled with a probit link as it is fixed at 1\n", - " warnings.warn(\"Global error variance will not be sampled with a probit link as it is fixed at 1\")\n" - ] - } - ], + "outputs": [], "source": [ "num_gfr = 10\n", "num_mcmc = 100\n", From 58e4c57e50a6d578b40be5b5a60659269309a657 Mon Sep 17 00:00:00 2001 From: Drew Herren Date: Thu, 24 Apr 2025 13:25:21 -0500 Subject: [PATCH 07/14] Updated default RNG seed in BCF --- stochtree/bcf.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/stochtree/bcf.py b/stochtree/bcf.py index 35c84596..4ad3a841 100644 --- a/stochtree/bcf.py +++ b/stochtree/bcf.py @@ -215,7 +215,7 @@ def sample( "adaptive_coding": True, "control_coding_init": -0.5, "treated_coding_init": 0.5, - "random_seed": -1, + "random_seed": None, "keep_burnin": False, "keep_gfr": False, "keep_every": 1, From 7b4bab17dcbfc2cb5ef7a5df4989e5f0af1f1a9a Mon Sep 17 00:00:00 2001 From: Drew Herren Date: Mon, 28 Apr 2025 15:08:57 -0500 Subject: [PATCH 08/14] Updated support for probit link in Python BART --- demo/notebooks/prototype_interface.ipynb | 2 +- src/py_stochtree.cpp | 2 +- stochtree/bart.py | 91 ++++++++++++++---------- 3 files changed, 56 insertions(+), 39 deletions(-) diff --git a/demo/notebooks/prototype_interface.ipynb b/demo/notebooks/prototype_interface.ipynb index ca385e1b..db2d2f22 100644 --- a/demo/notebooks/prototype_interface.ipynb +++ b/demo/notebooks/prototype_interface.ipynb @@ -207,7 +207,7 @@ " num_observations=n,\n", " feature_types=feature_types,\n", " variable_weights=var_weights,\n", - " leaf_dimension=leaf_dimension,\n", + " leaf_dimension=1,\n", " 
alpha=alpha,\n", " beta=beta,\n", " min_samples_leaf=min_samples_leaf,\n", diff --git a/src/py_stochtree.cpp b/src/py_stochtree.cpp index 8c73fd59..df7afdd9 100644 --- a/src/py_stochtree.cpp +++ b/src/py_stochtree.cpp @@ -1889,7 +1889,7 @@ void ForestContainerCpp::AdjustResidual(ForestDatasetCpp& dataset, ResidualCpp& } void ForestCpp::AdjustResidual(ForestDatasetCpp& dataset, ResidualCpp& residual, ForestSamplerCpp& sampler, bool requires_basis, bool add) { - // Determine whether or not we are adding forest_num to the residuals + // Determine whether or not we are adding forest predictions to the residuals std::function op; if (add) op = std::plus(); else op = std::minus(); diff --git a/stochtree/bart.py b/stochtree/bart.py index fa012319..fd1159de 100644 --- a/stochtree/bart.py +++ b/stochtree/bart.py @@ -719,14 +719,22 @@ def sample( self.probit_outcome_model = False if self.probit_outcome_model: if np.unique(y_train).size != 2: - raise ValueError("You specified a probit outcome model, but supplied an outcome with more than 2 unique values") + raise ValueError( + "You specified a probit outcome model, but supplied an outcome with more than 2 unique values" + ) unique_outcomes = np.squeeze(np.unique(y_train)) - if not np.array_equal(unique_outcomes, [0,1]): - raise ValueError("You specified a probit outcome model, but supplied an outcome with 2 unique values other than 0 and 1") + if not np.array_equal(unique_outcomes, [0, 1]): + raise ValueError( + "You specified a probit outcome model, but supplied an outcome with 2 unique values other than 0 and 1" + ) if self.include_variance_forest: - raise ValueError("We do not support heteroskedasticity with a probit link") + raise ValueError( + "We do not support heteroskedasticity with a probit link" + ) if sample_sigma_global: - warnings.warn("Global error variance will not be sampled with a probit link as it is fixed at 1") + warnings.warn( + "Global error variance will not be sampled with a probit link as it is fixed at 1" + ) sample_sigma_global = False # Handle standardization, prior calibration, and initialization of forest @@ -748,20 +756,20 @@ def sample( current_sigma2 = sigma2_init self.sigma2_init = sigma2_init # Skip variance_forest_init, since variance forests are not supported with probit link - b_leaf = ( - 1.0 / num_trees_mean - if b_leaf is None - else b_leaf - ) + b_leaf = 1.0 / num_trees_mean if b_leaf is None else b_leaf if self.has_basis: if sigma_leaf is None: - current_leaf_scale = np.zeros((self.num_basis, self.num_basis), dtype=float) + current_leaf_scale = np.zeros( + (self.num_basis, self.num_basis), dtype=float + ) np.fill_diagonal( current_leaf_scale, 2.0 / num_trees_mean, ) elif isinstance(sigma_leaf, float): - current_leaf_scale = np.zeros((self.num_basis, self.num_basis), dtype=float) + current_leaf_scale = np.zeros( + (self.num_basis, self.num_basis), dtype=float + ) np.fill_diagonal(current_leaf_scale, sigma_leaf) elif isinstance(sigma_leaf, np.ndarray): if sigma_leaf.ndim != 2: @@ -783,9 +791,7 @@ def sample( ) else: if sigma_leaf is None: - current_leaf_scale = np.array( - [[2.0 / num_trees_mean]] - ) + current_leaf_scale = np.array([[2.0 / num_trees_mean]]) elif isinstance(sigma_leaf, float): current_leaf_scale = np.array([[sigma_leaf]]) elif isinstance(sigma_leaf, np.ndarray): @@ -814,10 +820,10 @@ def sample( else: self.y_bar = 0 self.y_std = 1 - + # Compute residual value resid_train = (y_train - self.y_bar) / self.y_std - + # Compute initial value of root nodes in mean forest init_val_mean = 
np.squeeze(np.mean(resid_train)) @@ -836,13 +842,17 @@ def sample( ) if self.has_basis: if sigma_leaf is None: - current_leaf_scale = np.zeros((self.num_basis, self.num_basis), dtype=float) + current_leaf_scale = np.zeros( + (self.num_basis, self.num_basis), dtype=float + ) np.fill_diagonal( current_leaf_scale, np.squeeze(np.var(resid_train)) / num_trees_mean, ) elif isinstance(sigma_leaf, float): - current_leaf_scale = np.zeros((self.num_basis, self.num_basis), dtype=float) + current_leaf_scale = np.zeros( + (self.num_basis, self.num_basis), dtype=float + ) np.fill_diagonal(current_leaf_scale, sigma_leaf) elif isinstance(sigma_leaf, np.ndarray): if sigma_leaf.ndim != 2: @@ -936,7 +946,10 @@ def sample( alpha_init = np.array([1]) elif num_rfx_components > 1: alpha_init = np.concatenate( - (np.ones(1, dtype=float), np.zeros(num_rfx_components - 1, dtype=float)) + ( + np.ones(1, dtype=float), + np.zeros(num_rfx_components - 1, dtype=float), + ) ) else: raise ValueError("There must be at least 1 random effect component") @@ -1126,10 +1139,10 @@ def sample( if self.probit_outcome_model: # Sample latent probit variable z | - forest_pred = active_forest_mean.predict(forest_dataset_train) - mu0 = forest_pred[y_train[:,0] == 0] - mu1 = forest_pred[y_train[:,0] == 1] - n0 = np.sum(y_train[:,0] == 0) - n1 = np.sum(y_train[:,0] == 1) + mu0 = forest_pred[y_train[:, 0] == 0] + mu1 = forest_pred[y_train[:, 0] == 1] + n0 = np.sum(y_train[:, 0] == 0) + n1 = np.sum(y_train[:, 0] == 1) u0 = self.rng.uniform( low=0.0, high=norm.cdf(0 - mu0), @@ -1140,13 +1153,13 @@ def sample( high=1.0, size=n1, ) - resid_train[y_train[:,0] == 0,0] = mu0 + norm.ppf(u0) - resid_train[y_train[:,0] == 1,0] = mu1 + norm.ppf(u1) + resid_train[y_train[:, 0] == 0, 0] = mu0 + norm.ppf(u0) + resid_train[y_train[:, 0] == 1, 0] = mu1 + norm.ppf(u1) # Update outcome new_outcome = np.squeeze(resid_train) - forest_pred residual_train.update_data(new_outcome) - + # Sample the mean forest forest_sampler_mean.sample_one_iteration( self.forest_container_mean, @@ -1311,15 +1324,17 @@ def sample( keep_sample = False if keep_sample: sample_counter += 1 - + if self.include_mean_forest: if self.probit_outcome_model: # Sample latent probit variable z | - - forest_pred = active_forest_mean.predict(forest_dataset_train) - mu0 = forest_pred[y_train[:,0] == 0] - mu1 = forest_pred[y_train[:,0] == 1] - n0 = np.sum(y_train[:,0] == 0) - n1 = np.sum(y_train[:,0] == 1) + forest_pred = active_forest_mean.predict( + forest_dataset_train + ) + mu0 = forest_pred[y_train[:, 0] == 0] + mu1 = forest_pred[y_train[:, 0] == 1] + n0 = np.sum(y_train[:, 0] == 0) + n1 = np.sum(y_train[:, 0] == 1) u0 = self.rng.uniform( low=0.0, high=norm.cdf(0 - mu0), @@ -1330,13 +1345,13 @@ def sample( high=1.0, size=n1, ) - resid_train[y_train[:,0] == 0,0] = mu0 + norm.ppf(u0) - resid_train[y_train[:,0] == 1,0] = mu1 + norm.ppf(u1) + resid_train[y_train[:, 0] == 0, 0] = mu0 + norm.ppf(u0) + resid_train[y_train[:, 0] == 1, 0] = mu1 + norm.ppf(u1) # Update outcome new_outcome = np.squeeze(resid_train) - forest_pred residual_train.update_data(new_outcome) - + # Sample the mean forest forest_sampler_mean.sample_one_iteration( self.forest_container_mean, @@ -2020,7 +2035,9 @@ def from_json_string_list(self, json_string_list: list[str]) -> None: self.num_samples = json_object_default.get_integer("num_samples") self.num_basis = json_object_default.get_integer("num_basis") self.has_basis = json_object_default.get_boolean("requires_basis") - self.probit_outcome_model = 
json_object_default.get_boolean("probit_outcome_model") + self.probit_outcome_model = json_object_default.get_boolean( + "probit_outcome_model" + ) # Unpack parameter samples if self.sample_sigma_global: From a53af2b7fcce8a900c349f784bb5d33a26e6622f Mon Sep 17 00:00:00 2001 From: Drew Herren Date: Thu, 1 May 2025 16:38:11 -0500 Subject: [PATCH 09/14] Updated probit BART and BCF --- R/bart.R | 12 +- R/bcf.R | 185 +++++++++++++++++++---- vignettes/BayesianSupervisedLearning.Rmd | 4 +- vignettes/CausalInference.Rmd | 2 +- 4 files changed, 161 insertions(+), 42 deletions(-) diff --git a/R/bart.R b/R/bart.R index 340e53ae..a8e6ebe2 100644 --- a/R/bart.R +++ b/R/bart.R @@ -540,7 +540,7 @@ bart <- function(X_train, y_train, leaf_basis_train = NULL, rfx_group_ids_train y_std_train <- 1 } - # Compute residual value + # Compute standardized outcome resid_train <- (y_train-y_bar_train)/y_std_train # Compute initial value of root nodes in mean forest @@ -552,14 +552,14 @@ bart <- function(X_train, y_train, leaf_basis_train = NULL, rfx_group_ids_train if (is.null(b_leaf)) b_leaf <- var(resid_train)/(2*num_trees_mean) if (has_basis) { if (ncol(leaf_basis_train) > 1) { - if (is.null(sigma_leaf_init)) sigma_leaf_init <- diag(var(resid_train)/(num_trees_mean), ncol(leaf_basis_train)) + if (is.null(sigma_leaf_init)) sigma_leaf_init <- diag(2*var(resid_train)/(num_trees_mean), ncol(leaf_basis_train)) if (!is.matrix(sigma_leaf_init)) { current_leaf_scale <- as.matrix(diag(sigma_leaf_init, ncol(leaf_basis_train))) } else { current_leaf_scale <- sigma_leaf_init } } else { - if (is.null(sigma_leaf_init)) sigma_leaf_init <- as.matrix(var(resid_train)/(num_trees_mean)) + if (is.null(sigma_leaf_init)) sigma_leaf_init <- as.matrix(2*var(resid_train)/(num_trees_mean)) if (!is.matrix(sigma_leaf_init)) { current_leaf_scale <- as.matrix(diag(sigma_leaf_init, 1)) } else { @@ -567,7 +567,7 @@ bart <- function(X_train, y_train, leaf_basis_train = NULL, rfx_group_ids_train } } } else { - if (is.null(sigma_leaf_init)) sigma_leaf_init <- as.matrix(var(resid_train)/(num_trees_mean)) + if (is.null(sigma_leaf_init)) sigma_leaf_init <- as.matrix(2*var(resid_train)/(num_trees_mean)) if (!is.matrix(sigma_leaf_init)) { current_leaf_scale <- as.matrix(diag(sigma_leaf_init, 1)) } else { @@ -724,7 +724,7 @@ bart <- function(X_train, y_train, leaf_basis_train = NULL, rfx_group_ids_train if (include_mean_forest) { if (probit_outcome_model) { # Sample latent probit variable, z | - - forest_pred <- active_forest_mean$predict(forest_dataset_train) + y_bar_train + forest_pred <- active_forest_mean$predict(forest_dataset_train) mu0 <- forest_pred[y_train == 0] mu1 <- forest_pred[y_train == 1] u0 <- runif(sum(y_train == 0), 0, pnorm(0 - mu0)) @@ -878,7 +878,7 @@ bart <- function(X_train, y_train, leaf_basis_train = NULL, rfx_group_ids_train if (include_mean_forest) { if (probit_outcome_model) { # Sample latent probit variable, z | - - forest_pred <- active_forest_mean$predict(forest_dataset_train) + y_bar_train + forest_pred <- active_forest_mean$predict(forest_dataset_train) mu0 <- forest_pred[y_train == 0] mu1 <- forest_pred[y_train == 1] u0 <- runif(sum(y_train == 0), 0, pnorm(0 - mu0)) diff --git a/R/bcf.R b/R/bcf.R index fcbd1347..8fd46808 100644 --- a/R/bcf.R +++ b/R/bcf.R @@ -46,6 +46,7 @@ #' - `keep_every` How many iterations of the burned-in MCMC sampler should be run before forests and parameters are retained. Default `1`. 
Setting `keep_every <- k` for some `k > 1` will "thin" the MCMC samples by retaining every `k`-th sample, rather than simply every sample. This can reduce the autocorrelation of the MCMC samples. #' - `num_chains` How many independent MCMC chains should be sampled. If `num_mcmc = 0`, this is ignored. If `num_gfr = 0`, then each chain is run from root for `num_mcmc * keep_every + num_burnin` iterations, with `num_mcmc` samples retained. If `num_gfr > 0`, each MCMC chain will be initialized from a separate GFR ensemble, with the requirement that `num_gfr >= num_chains`. Default: `1`. #' - `verbose` Whether or not to print progress during the sampling loops. Default: `FALSE`. +#' - `probit_outcome_model` Whether or not the outcome should be modeled as explicitly binary via a probit link. If `TRUE`, `y` must only contain the values `0` and `1`. Default: `FALSE`. #' #' @param prognostic_forest_params (Optional) A list of prognostic forest model parameters, each of which has a default value processed internally, so this argument list is optional. #' @@ -74,6 +75,7 @@ #' - `sigma2_leaf_init` Starting value of leaf node scale parameter. Calibrated internally as `1/num_trees` if not set here. #' - `sigma2_leaf_shape` Shape parameter in the `IG(sigma2_leaf_shape, sigma2_leaf_scale)` leaf node parameter variance model. Default: `3`. #' - `sigma2_leaf_scale` Scale parameter in the `IG(sigma2_leaf_shape, sigma2_leaf_scale)` leaf node parameter variance model. Calibrated internally as `0.5/num_trees` if not set here. +#' - `delta_max` Maximum plausible conditional distributional treatment effect (i.e. P(Y(1) = 1 | X) - P(Y(0) = 1 | X)) when the outcome is binary. Only used when the outcome is specified as a probit model in `general_params`. Must be > 0 and < 1. Default: `0.9`. Ignored if `sigma2_leaf_init` is set directly, as this parameter is used to calibrate `sigma2_leaf_init`. #' - `keep_vars` Vector of variable names or column indices denoting variables that should be included in the forest. Default: `NULL`. #' - `drop_vars` Vector of variable names or column indices denoting variables that should be excluded from the forest. Default: `NULL`. If both `drop_vars` and `keep_vars` are set, `drop_vars` will be ignored. #' @@ -156,7 +158,8 @@ bcf <- function(X_train, Z_train, y_train, propensity_train = NULL, rfx_group_id adaptive_coding = TRUE, control_coding_init = -0.5, treated_coding_init = 0.5, rfx_prior_var = NULL, random_seed = -1, keep_burnin = FALSE, keep_gfr = FALSE, - keep_every = 1, num_chains = 1, verbose = FALSE + keep_every = 1, num_chains = 1, verbose = FALSE, + probit_outcome_model = FALSE ) general_params_updated <- preprocessParams( general_params_default, general_params @@ -180,7 +183,8 @@ bcf <- function(X_train, Z_train, y_train, propensity_train = NULL, rfx_group_id min_samples_leaf = 5, max_depth = 5, sample_sigma2_leaf = FALSE, sigma2_leaf_init = NULL, sigma2_leaf_shape = 3, sigma2_leaf_scale = NULL, - keep_vars = NULL, drop_vars = NULL + keep_vars = NULL, drop_vars = NULL, + delta_max = 0.9 ) treatment_effect_forest_params_updated <- preprocessParams( treatment_effect_forest_params_default, treatment_effect_forest_params @@ -220,6 +224,7 @@ bcf <- function(X_train, Z_train, y_train, propensity_train = NULL, rfx_group_id keep_every <- general_params_updated$keep_every num_chains <- general_params_updated$num_chains verbose <- general_params_updated$verbose + probit_outcome_model <- general_params_updated$probit_outcome_model # 2. 
Mu forest parameters num_trees_mu <- prognostic_forest_params_updated$num_trees @@ -246,6 +251,7 @@ bcf <- function(X_train, Z_train, y_train, propensity_train = NULL, rfx_group_id b_leaf_tau <- treatment_effect_forest_params_updated$sigma2_leaf_scale keep_vars_tau <- treatment_effect_forest_params_updated$keep_vars drop_vars_tau <- treatment_effect_forest_params_updated$drop_vars + delta_max <- treatment_effect_forest_params_updated$delta_max # 4. Variance forest parameters num_trees_variance <- variance_forest_params_updated$num_trees @@ -352,6 +358,11 @@ bcf <- function(X_train, Z_train, y_train, propensity_train = NULL, rfx_group_id } num_cov_orig <- ncol(X_train) + # Check delta_max is valid + if ((delta_max <= 0) || (delta_max >= 1)) { + stop("delta_max must be > 0 and < 1") + } + # Standardize the keep variable lists to numeric indices if (!is.null(keep_vars_mu)) { if (is.character(keep_vars_mu)) { @@ -674,44 +685,118 @@ bcf <- function(X_train, Z_train, y_train, propensity_train = NULL, rfx_group_id variable_weights_variance <- variable_weights_variance / sum(variable_weights_variance) } - # Standardize outcome separately for test and train - if (standardize) { - y_bar_train <- mean(y_train) - y_std_train <- sd(y_train) - } else { - y_bar_train <- 0 - y_std_train <- 1 + # Preliminary runtime checks for probit link + if (probit_outcome_model) { + if (!(length(unique(y_train)) == 2)) { + stop("You specified a probit outcome model, but supplied an outcome with more than 2 unique values") + } + unique_outcomes <- sort(unique(y_train)) + if (!(all(unique_outcomes == c(0,1)))) { + stop("You specified a probit outcome model, but supplied an outcome with 2 unique values other than 0 and 1") + } + if (include_variance_forest) { + stop("We do not support heteroskedasticity with a probit link") + } + if (sample_sigma_global) { + warning("Global error variance will not be sampled with a probit link as it is fixed at 1") + sample_sigma_global <- F + } } - resid_train <- (y_train-y_bar_train)/y_std_train - # Calibrate priors for global sigma^2 and sigma_leaf_mu / sigma_leaf_tau - if (is.null(sigma2_init)) sigma2_init <- 1.0*var(resid_train) - if (is.null(variance_forest_init)) variance_forest_init <- 1.0*var(resid_train) - if (is.null(b_leaf_mu)) b_leaf_mu <- var(resid_train)/(num_trees_mu) - if (is.null(b_leaf_tau)) b_leaf_tau <- var(resid_train)/(2*num_trees_tau) - if (is.null(sigma_leaf_mu)) { - sigma_leaf_mu <- var(resid_train)/(num_trees_mu) - current_leaf_scale_mu <- as.matrix(sigma_leaf_mu) - } else { - if (!is.matrix(sigma_leaf_mu)) { + # Handle standardization, prior calibration, and initialization of forest + # differently for binary and continuous outcomes + if (probit_outcome_model) { + # Compute a probit-scale offset and fix scale to 1 + y_bar_train <- qnorm(mean(y_train)) + y_std_train <- 1 + + # Set a pseudo outcome by subtracting mean(y_train) from y_train + resid_train <- y_train - mean(y_train) + + # Set initial value for the mu forest + init_mu <- 0.0 + + # Calibrate priors for global sigma^2 and sigma_leaf_mu / sigma_leaf_tau + # Set sigma2_init to 1, ignoring any defaults provided + sigma2_init <- 1.0 + # Skip variance_forest_init, since variance forests are not supported with probit link + if (is.null(b_leaf_mu)) b_leaf_mu <- 1/num_trees_mu + if (is.null(b_leaf_tau)) b_leaf_tau <- 1/(2*num_trees_tau) + if (is.null(sigma_leaf_mu)) { + sigma_leaf_mu <- 2/(num_trees_mu) current_leaf_scale_mu <- as.matrix(sigma_leaf_mu) } else { - current_leaf_scale_mu <- sigma_leaf_mu + if 
(!is.matrix(sigma_leaf_mu)) { + current_leaf_scale_mu <- as.matrix(sigma_leaf_mu) + } else { + current_leaf_scale_mu <- sigma_leaf_mu + } } - } - if (is.null(sigma_leaf_tau)) { - sigma_leaf_tau <- var(resid_train)/(2*num_trees_tau) - current_leaf_scale_tau <- as.matrix(diag(sigma_leaf_tau, ncol(Z_train))) + if (is.null(sigma_leaf_tau)) { + # Calibrate prior so that P(abs(tau(X)) < delta_max / dnorm(0)) = p + # Use p = 0.9 as an internal default rather than adding another + # user-facing "parameter" of the binary outcome BCF prior. + # Can be overriden by specifying `sigma2_leaf_init` in + # treatment_effect_forest_params. + p <- 0.6827 + q_quantile <- qnorm((p+1)/2) + sigma2_leaf_tau <- ((delta_max/(q_quantile*dnorm(0)))^2)/num_trees_tau + current_leaf_scale_tau <- as.matrix(diag(sigma_leaf_tau, ncol(Z_train))) + } else { + if (!is.matrix(sigma_leaf_tau)) { + current_leaf_scale_tau <- as.matrix(diag(sigma_leaf_tau, ncol(Z_train))) + } else { + if (ncol(sigma_leaf_tau) != ncol(Z_train)) stop("sigma_leaf_init for the tau forest must have the same number of columns / rows as columns in the Z_train matrix") + if (nrow(sigma_leaf_tau) != ncol(Z_train)) stop("sigma_leaf_init for the tau forest must have the same number of columns / rows as columns in the Z_train matrix") + current_leaf_scale_tau <- sigma_leaf_tau + } + } + current_sigma2 <- sigma2_init } else { - if (!is.matrix(sigma_leaf_tau)) { + # Only standardize if user requested + if (standardize) { + y_bar_train <- mean(y_train) + y_std_train <- sd(y_train) + } else { + y_bar_train <- 0 + y_std_train <- 1 + } + + # Compute standardized outcome + resid_train <- (y_train-y_bar_train)/y_std_train + + # Set initial value for the mu forest + init_mu <- mean(resid_train) + + # Calibrate priors for global sigma^2 and sigma_leaf_mu / sigma_leaf_tau + if (is.null(sigma2_init)) sigma2_init <- 1.0*var(resid_train) + if (is.null(variance_forest_init)) variance_forest_init <- 1.0*var(resid_train) + if (is.null(b_leaf_mu)) b_leaf_mu <- var(resid_train)/(num_trees_mu) + if (is.null(b_leaf_tau)) b_leaf_tau <- var(resid_train)/(2*num_trees_tau) + if (is.null(sigma_leaf_mu)) { + sigma_leaf_mu <- 2.0*var(resid_train)/(num_trees_mu) + current_leaf_scale_mu <- as.matrix(sigma_leaf_mu) + } else { + if (!is.matrix(sigma_leaf_mu)) { + current_leaf_scale_mu <- as.matrix(sigma_leaf_mu) + } else { + current_leaf_scale_mu <- sigma_leaf_mu + } + } + if (is.null(sigma_leaf_tau)) { + sigma_leaf_tau <- var(resid_train)/(num_trees_tau) current_leaf_scale_tau <- as.matrix(diag(sigma_leaf_tau, ncol(Z_train))) } else { - if (ncol(sigma_leaf_tau) != ncol(Z_train)) stop("sigma_leaf_init for the tau forest must have the same number of columns / rows as columns in the Z_train matrix") - if (nrow(sigma_leaf_tau) != ncol(Z_train)) stop("sigma_leaf_init for the tau forest must have the same number of columns / rows as columns in the Z_train matrix") - current_leaf_scale_tau <- sigma_leaf_tau + if (!is.matrix(sigma_leaf_tau)) { + current_leaf_scale_tau <- as.matrix(diag(sigma_leaf_tau, ncol(Z_train))) + } else { + if (ncol(sigma_leaf_tau) != ncol(Z_train)) stop("sigma_leaf_init for the tau forest must have the same number of columns / rows as columns in the Z_train matrix") + if (nrow(sigma_leaf_tau) != ncol(Z_train)) stop("sigma_leaf_init for the tau forest must have the same number of columns / rows as columns in the Z_train matrix") + current_leaf_scale_tau <- sigma_leaf_tau + } } + current_sigma2 <- sigma2_init } - current_sigma2 <- sigma2_init # Switch off leaf scale sampling 
for multivariate treatments if (ncol(Z_train) > 1) { @@ -842,7 +927,6 @@ bcf <- function(X_train, Z_train, y_train, propensity_train = NULL, rfx_group_id } # Initialize the leaves of each tree in the prognostic forest - init_mu <- mean(resid_train) active_forest_mu$prepare_for_sampler(forest_dataset_train, outcome_train, forest_model_mu, 0, init_mu) active_forest_mu$adjust_residual(forest_dataset_train, outcome_train, forest_model_mu, FALSE, FALSE) @@ -870,6 +954,22 @@ bcf <- function(X_train, Z_train, y_train, propensity_train = NULL, rfx_group_id } } + if (probit_outcome_model) { + # Sample latent probit variable, z | - + mu_forest_pred <- active_forest_mu$predict(forest_dataset_train) + tau_forest_pred <- active_forest_tau$predict(forest_dataset_train) + forest_pred <- mu_forest_pred + tau_forest_pred + mu0 <- forest_pred[y_train == 0] + mu1 <- forest_pred[y_train == 1] + u0 <- runif(sum(y_train == 0), 0, pnorm(0 - mu0)) + u1 <- runif(sum(y_train == 1), pnorm(0 - mu1), 1) + resid_train[y_train==0] <- mu0 + qnorm(u0) + resid_train[y_train==1] <- mu1 + qnorm(u1) + + # Update outcome + outcome_train$update_data(resid_train - forest_pred) + } + # Sample the prognostic forest forest_model_mu$sample_one_iteration( forest_dataset = forest_dataset_train, residual = outcome_train, forest_samples = forest_samples_mu, @@ -1120,6 +1220,22 @@ bcf <- function(X_train, Z_train, y_train, propensity_train = NULL, rfx_group_id } } + if (probit_outcome_model) { + # Sample latent probit variable, z | - + mu_forest_pred <- active_forest_mu$predict(forest_dataset_train) + tau_forest_pred <- active_forest_tau$predict(forest_dataset_train) + forest_pred <- mu_forest_pred + tau_forest_pred + mu0 <- forest_pred[y_train == 0] + mu1 <- forest_pred[y_train == 1] + u0 <- runif(sum(y_train == 0), 0, pnorm(0 - mu0)) + u1 <- runif(sum(y_train == 1), pnorm(0 - mu1), 1) + resid_train[y_train==0] <- mu0 + qnorm(u0) + resid_train[y_train==1] <- mu1 + qnorm(u1) + + # Update outcome + outcome_train$update_data(resid_train - forest_pred) + } + # Sample the prognostic forest forest_model_mu$sample_one_iteration( forest_dataset = forest_dataset_train, residual = outcome_train, forest_samples = forest_samples_mu, @@ -1337,7 +1453,8 @@ bcf <- function(X_train, Z_train, y_train, propensity_train = NULL, rfx_group_id "include_variance_forest" = include_variance_forest, "sample_sigma_global" = sample_sigma_global, "sample_sigma_leaf_mu" = sample_sigma_leaf_mu, - "sample_sigma_leaf_tau" = sample_sigma_leaf_tau + "sample_sigma_leaf_tau" = sample_sigma_leaf_tau, + "probit_outcome_model" = probit_outcome_model ) result <- list( "forests_mu" = forest_samples_mu, @@ -1779,6 +1896,7 @@ saveBCFModelToJson <- function(object){ jsonobj$add_scalar("keep_every", object$model_params$keep_every) jsonobj$add_scalar("num_chains", object$model_params$num_chains) jsonobj$add_scalar("num_covariates", object$model_params$num_covariates) + jsonobj$add_boolean("probit_outcome_model", object$model_params$probit_outcome_model) if (object$model_params$sample_sigma_global) { jsonobj$add_vector("sigma2_samples", object$sigma2_samples, "parameters") } @@ -2106,6 +2224,7 @@ createBCFModelFromJson <- function(json_object){ model_params[["num_mcmc"]] <- json_object$get_scalar("num_mcmc") model_params[["num_samples"]] <- json_object$get_scalar("num_samples") model_params[["num_covariates"]] <- json_object$get_scalar("num_covariates") + model_params[["probit_outcome_model"]] <- json_object$get_boolean("probit_outcome_model") output[["model_params"]] <- model_params 
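A note on the data-augmentation step sampled above: with a probit link the sampler draws a latent normal variable z for every observation, truncated below zero when y = 0 and above zero when y = 1, and then hands z minus the current forest prediction to the tree sampler as the working residual. Both the R and Python code do this with the inverse-CDF trick on a uniform draw. The following is a minimal standalone sketch of that update (numpy/scipy only; `y`, `forest_pred`, and `n` are illustrative stand-ins, not objects from the package):

```python
# Standalone sketch of the truncated-normal ("Albert-Chib") latent update used
# in the probit samplers above. `y`, `forest_pred`, and `n` are illustrative
# stand-ins, not objects from stochtree.
import numpy as np
from scipy.stats import norm

rng = np.random.default_rng(2025)
n = 8
y = rng.binomial(n=1, p=0.5, size=n)                   # observed 0/1 outcomes
forest_pred = rng.normal(loc=0.0, scale=0.5, size=n)   # current forest fit on the probit scale

z = np.empty(n)
is0 = y == 0
is1 = y == 1
# y = 0: draw U ~ Uniform(0, Phi(0 - mu)) so that mu + Phi^{-1}(U) < 0
u0 = rng.uniform(low=0.0, high=norm.cdf(0.0 - forest_pred[is0]))
z[is0] = forest_pred[is0] + norm.ppf(u0)
# y = 1: draw U ~ Uniform(Phi(0 - mu), 1) so that mu + Phi^{-1}(U) >= 0
u1 = rng.uniform(low=norm.cdf(0.0 - forest_pred[is1]), high=1.0)
z[is1] = forest_pred[is1] + norm.ppf(u1)

# The working residual handed to the tree sampler is z minus the forest prediction
residual = z - forest_pred
print(np.all((z < 0) == (y == 0)))  # True: latent signs agree with the labels
```

Because the latent scale is fixed at 1 under this link, no global error variance needs to be sampled, which is why global variance sampling is switched off whenever the probit outcome model is requested.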
# Unpack sampled parameters @@ -2439,6 +2558,7 @@ createBCFModelFromCombinedJson <- function(json_object_list){ model_params[["keep_every"]] <- json_object_default$get_scalar("keep_every") model_params[["adaptive_coding"]] <- json_object_default$get_boolean("adaptive_coding") model_params[["internal_propensity_model"]] <- json_object_default$get_boolean("internal_propensity_model") + model_params[["probit_outcome_model"]] <- json_object_default$get_boolean("probit_outcome_model") # Combine values that are sample-specific for (i in 1:length(json_object_list)) { @@ -2665,6 +2785,7 @@ createBCFModelFromCombinedJsonString <- function(json_string_list){ model_params[["keep_every"]] <- json_object_default$get_scalar("keep_every") model_params[["adaptive_coding"]] <- json_object_default$get_boolean("adaptive_coding") model_params[["internal_propensity_model"]] <- json_object_default$get_boolean("internal_propensity_model") + model_params[["probit_outcome_model"]] <- json_object_default$get_boolean("probit_outcome_model") # Combine values that are sample-specific for (i in 1:length(json_object_list)) { diff --git a/vignettes/BayesianSupervisedLearning.Rmd b/vignettes/BayesianSupervisedLearning.Rmd index 6bfbf85e..250f74c1 100644 --- a/vignettes/BayesianSupervisedLearning.Rmd +++ b/vignettes/BayesianSupervisedLearning.Rmd @@ -436,7 +436,7 @@ for (i in 1:num_mcmc) { lines(fpr_mean, tpr_mean, col = "black", lwd = 3, lty = 3) ``` -Note that the nonlinearity of the ROC function means that the ROC curve of the posterior mean lies above most of the individual posterior sample ROC curves (which would not be the case if we had simply taken the mean of the ROC curves). +Note that the nonlinearity of the ROC function means that the ROC curve of the posterior mean could sit above most of the individual posterior sample ROC curves (which would not be the case if we had simply taken the mean of the ROC curves). ### BART MCMC without Warmstart @@ -509,6 +509,4 @@ for (i in 1:num_mcmc) { lines(fpr_mean, tpr_mean, col = "black", lwd = 3, lty = 3) ``` -Note that the nonlinearity of the ROC function means that the ROC curve of the posterior mean lies above most of the individual posterior sample ROC curves (which would not be the case if we had simply taken the mean of the ROC curves). - # References diff --git a/vignettes/CausalInference.Rmd b/vignettes/CausalInference.Rmd index 287b58eb..2e0dbf89 100644 --- a/vignettes/CausalInference.Rmd +++ b/vignettes/CausalInference.Rmd @@ -15,7 +15,7 @@ knitr::opts_chunk$set( ) ``` -This vignette demonstrates how to use the `bcf()` function for supervised learning. +This vignette demonstrates how to use the `bcf()` function for causal inference. To begin, we load the stochtree package. 
```{r setup} From ed83d4d09fc11256057b5a97f2a20c0ad46d5a5f Mon Sep 17 00:00:00 2001 From: Drew Herren Date: Fri, 2 May 2025 01:13:50 -0500 Subject: [PATCH 10/14] Initial non-working python implementation of probit BCF with debug script --- demo/debug/causal_inference_binary_outcome.py | 130 ++++++++ stochtree/bcf.py | 314 ++++++++++++++---- 2 files changed, 373 insertions(+), 71 deletions(-) create mode 100644 demo/debug/causal_inference_binary_outcome.py diff --git a/demo/debug/causal_inference_binary_outcome.py b/demo/debug/causal_inference_binary_outcome.py new file mode 100644 index 00000000..57cbd5fe --- /dev/null +++ b/demo/debug/causal_inference_binary_outcome.py @@ -0,0 +1,130 @@ +# Load necessary libraries +import numpy as np +import pandas as pd +import seaborn as sns +import matplotlib.pyplot as plt +from stochtree import BCFModel +from sklearn.model_selection import train_test_split +from scipy.stats import norm + +# Generate sample data +# RNG +random_seed = 101 +rng = np.random.default_rng(random_seed) + +# Generate covariates and basis +n = 2000 +x1 = rng.normal(loc=0., scale=1., size=(n,)) +x2 = rng.normal(loc=0., scale=1., size=(n,)) +x3 = rng.normal(loc=0., scale=1., size=(n,)) +x4 = rng.binomial(n=1,p=0.5,size=(n,)) +x5 = rng.choice(a=[0,1,2], size=(n,), replace=True) +x4_cat = pd.Categorical(x4, categories=[0,1], ordered=True) +x5_cat = pd.Categorical(x4, categories=[0,1,2], ordered=True) +p = 5 +X = pd.DataFrame(data={ + "x1": pd.Series(x1), + "x2": pd.Series(x2), + "x3": pd.Series(x3), + "x4": pd.Series(x4), + "x5": pd.Series(x5) +}) +def g(x): + return np.where( + x.loc[:,"x5"] == 0, 2.0, + np.where( + x.loc[:,"x5"] == 1, -1.0, -4.0 + ) + ) +mu_x = (1.0 + g(X) + X.loc[:,"x1"]*X.loc[:,"x3"])*0.25 +tau_x = (1.0 + 2*X.loc[:,"x2"]*X.loc[:,"x4"])*0.5 +pi_x = ( + 0.8*norm.cdf(3.0*mu_x / np.squeeze(np.std(mu_x)) - 0.5*X.loc[:,"x1"]) + + 0.05 + rng.uniform(low=0., high=0.1, size=(n,)) +) +Z = rng.binomial(n=1, p=pi_x, size=(n,)) +E_XZ = mu_x + tau_x*Z +w = E_XZ + rng.normal(loc=0., scale=1., size=(n,)) +y = np.where(w > 0, 1, 0) +delta_x = norm.cdf(mu_x + tau_x) - norm.cdf(mu_x) + +# Test-train split +sample_inds = np.arange(n) +train_inds, test_inds = train_test_split(sample_inds, test_size=0.5) +X_train = X.iloc[train_inds,:] +X_test = X.iloc[test_inds,:] +Z_train = Z[train_inds] +Z_test = Z[test_inds] +y_train = y[train_inds] +y_test = y[test_inds] +pi_train = pi_x[train_inds] +pi_test = pi_x[test_inds] +mu_train = mu_x[train_inds] +mu_test = mu_x[test_inds] +tau_train = tau_x[train_inds] +tau_test = tau_x[test_inds] +w_train = w[train_inds] +w_test = w[test_inds] +delta_train = delta_x[train_inds] +delta_test = delta_x[test_inds] + +# Number of iterations +num_gfr = 10 +num_burnin = 0 +num_mcmc = 100 + +# Tau prior calibration (this is also done internally in the BCF sampler) +num_trees_tau = 50 +p = 0.6827 +q_quantile = norm.ppf((p+1)/2) +delta_max = 0.9 +sigma2_leaf_tau = ((delta_max/(q_quantile*norm.pdf(0.)))**2) / num_trees_tau + +# Mu prior calibration (this is also done internally in the BCF sampler) +num_trees_mu = 200 +sigma2_leaf_mu = 2/num_trees_mu + +# Construct parameter lists +general_params = { + 'keep_every': 1, + 'probit_outcome_model': True, + 'sample_sigma2_global': False, + 'adaptive_coding': False, + 'num_chains': 1} +prognostic_forest_params = { + 'sample_sigma2_leaf': False, + 'sigma2_leaf_init': sigma2_leaf_mu, + 'num_trees': num_trees_mu} +treatment_effect_forest_params = { + 'sample_sigma2_leaf': False, + 'sigma2_leaf_init': sigma2_leaf_tau, + 
'num_trees': num_trees_tau} + +# Run the sampler +bcf_model = BCFModel() +bcf_model.sample(X_train=X_train, Z_train=Z_train, y_train=y_train, pi_train=pi_train, + X_test=X_test, Z_test=Z_test, pi_test=pi_test, num_gfr=num_gfr, + num_burnin=num_burnin, num_mcmc=num_mcmc, general_params=general_params, + prognostic_forest_params=prognostic_forest_params, + treatment_effect_forest_params=treatment_effect_forest_params) + +# Inspect the MCMC (BART) samples +plt.scatter(np.squeeze(bcf_model.y_hat_test).mean(axis = 1), y_test, color="black") +plt.axline((0, 0), slope=1, color="red", linestyle=(0, (3,3))) +plt.show() + +plt.scatter(np.squeeze(bcf_model.tau_hat_test).mean(axis = 1), tau_test, color="black") +plt.axline((0, 0), slope=1, color="red", linestyle=(0, (3,3))) +plt.show() + +plt.scatter(np.squeeze(bcf_model.mu_hat_test).mean(axis = 1), mu_test, color="black") +plt.axline((0, 0), slope=1, color="red", linestyle=(0, (3,3))) +plt.show() + +# # Compute RMSEs +# y_rmse = np.sqrt(np.mean(np.power(np.expand_dims(y_test,1) - y_avg_mcmc, 2))) +# tau_rmse = np.sqrt(np.mean(np.power(np.expand_dims(tau_test,1) - tau_avg_mcmc, 2))) +# mu_rmse = np.sqrt(np.mean(np.power(np.expand_dims(mu_test,1) - mu_avg_mcmc, 2))) +# print("y hat RMSE: {:.2f}".format(y_rmse)) +# print("tau hat RMSE: {:.2f}".format(tau_rmse)) +# print("mu hat RMSE: {:.2f}".format(mu_rmse)) diff --git a/stochtree/bcf.py b/stochtree/bcf.py index 4ad3a841..c532754e 100644 --- a/stochtree/bcf.py +++ b/stochtree/bcf.py @@ -8,6 +8,7 @@ import numpy as np import pandas as pd from sklearn.utils import check_scalar +from scipy.stats import norm from .bart import BARTModel from .config import ForestModelConfig, GlobalModelConfig @@ -150,6 +151,7 @@ def sample( * `keep_gfr` (`bool`): Whether or not "warm-start" / grow-from-root samples should be included in predictions. Defaults to `False`. Ignored if `num_mcmc == 0`. * `keep_every` (`int`): How many iterations of the burned-in MCMC sampler should be run before forests and parameters are retained. Defaults to `1`. Setting `keep_every = k` for some `k > 1` will "thin" the MCMC samples by retaining every `k`-th sample, rather than simply every sample. This can reduce the autocorrelation of the MCMC samples. * `num_chains` (`int`): How many independent MCMC chains should be sampled. If `num_mcmc = 0`, this is ignored. If `num_gfr = 0`, then each chain is run from root for `num_mcmc * keep_every + num_burnin` iterations, with `num_mcmc` samples retained. If `num_gfr > 0`, each MCMC chain will be initialized from a separate GFR ensemble, with the requirement that `num_gfr >= num_chains`. Defaults to `1`. + * `probit_outcome_model` (`bool`): Whether or not the outcome should be modeled as explicitly binary via a probit link. If `True`, `y` must only contain the values `0` and `1`. Default: `False`. prognostic_forest_params : dict, optional Dictionary of prognostic forest model parameters, each of which has a default value processed internally, so this argument is optional. @@ -179,6 +181,7 @@ def sample( * `sigma2_leaf_init` (`float`): Starting value of leaf node scale parameter. Calibrated internally as `1/num_trees` if not set here. * `sigma2_leaf_shape` (`float`): Shape parameter in the `IG(sigma2_leaf_shape, sigma2_leaf_scale)` leaf node parameter variance model. Defaults to `3`. * `sigma2_leaf_scale` (`float`): Scale parameter in the `IG(sigma2_leaf_shape, sigma2_leaf_scale)` leaf node parameter variance model. Calibrated internally as `0.5/num_trees` if not set here. 
+ * `delta_max` (`float`): Maximum plausible conditional distributional treatment effect (i.e. P(Y(1) = 1 | X) - P(Y(0) = 1 | X)) when the outcome is binary. Only used when the outcome is specified as a probit model in `general_params`. Must be > 0 and < 1. Defaults to `0.9`. Ignored if `sigma2_leaf_init` is set directly, as this parameter is used to calibrate `sigma2_leaf_init`. * `keep_vars` (`list` or `np.array`): Vector of variable names or column indices denoting variables that should be included in the treatment effect (`tau(X)`) forest. Defaults to `None`. * `drop_vars` (`list` or `np.array`): Vector of variable names or column indices denoting variables that should be excluded from the treatment effect (`tau(X)`) forest. Defaults to `None`. If both `drop_vars` and `keep_vars` are set, `drop_vars` will be ignored. @@ -220,6 +223,7 @@ def sample( "keep_gfr": False, "keep_every": 1, "num_chains": 1, + "probit_outcome_model": False, } general_params_updated = _preprocess_params( general_params_default, general_params @@ -254,6 +258,7 @@ def sample( "sigma2_leaf_init": None, "sigma2_leaf_shape": 3, "sigma2_leaf_scale": None, + "delta_max": 0.9, "keep_vars": None, "drop_vars": None, } @@ -296,6 +301,7 @@ def sample( keep_burnin = general_params_updated["keep_burnin"] keep_gfr = general_params_updated["keep_gfr"] keep_every = general_params_updated["keep_every"] + self.probit_outcome_model = general_params_updated["probit_outcome_model"] # 2. Mu forest parameters num_trees_mu = prognostic_forest_params_updated["num_trees"] @@ -324,6 +330,7 @@ def sample( sigma_leaf_tau = treatment_effect_forest_params_updated["sigma2_leaf_init"] a_leaf_tau = treatment_effect_forest_params_updated["sigma2_leaf_shape"] b_leaf_tau = treatment_effect_forest_params_updated["sigma2_leaf_scale"] + delta_max = treatment_effect_forest_params_updated["delta_max"] keep_vars_tau = treatment_effect_forest_params_updated["keep_vars"] drop_vars_tau = treatment_effect_forest_params_updated["drop_vars"] @@ -1115,83 +1122,187 @@ def sample( else: self.internal_propensity_model = False - # Scale outcome - if self.standardize: - self.y_bar = np.squeeze(np.mean(y_train)) - self.y_std = np.squeeze(np.std(y_train)) - else: - self.y_bar = 0 - self.y_std = 1 - resid_train = (y_train - self.y_bar) / self.y_std - - # Calibrate priors for global sigma^2 and sigma_leaf_mu / sigma_leaf_tau (don't use regression initializer for warm-start or XBART) - if not sigma2_init: - sigma2_init = 1.0 * np.var(resid_train) - if not variance_forest_leaf_init: - variance_forest_leaf_init = 0.6 * np.var(resid_train) - b_leaf_mu = ( - np.squeeze(np.var(resid_train)) / num_trees_mu - if b_leaf_mu is None - else b_leaf_mu - ) - b_leaf_tau = ( - np.squeeze(np.var(resid_train)) / (2 * num_trees_tau) - if b_leaf_tau is None - else b_leaf_tau - ) - sigma_leaf_mu = ( - np.squeeze(np.var(resid_train)) / num_trees_mu - if sigma_leaf_mu is None - else sigma_leaf_mu - ) - sigma_leaf_tau = ( - np.squeeze(np.var(resid_train)) / (2 * num_trees_tau) - if sigma_leaf_tau is None - else sigma_leaf_tau - ) - if self.multivariate_treatment: - if not isinstance(sigma_leaf_tau, np.ndarray): - sigma_leaf_tau = np.diagflat( - np.repeat(sigma_leaf_tau, self.treatment_dim) - ) - current_sigma2 = sigma2_init - self.sigma2_init = sigma2_init - if isinstance(sigma_leaf_mu, float): - current_leaf_scale_mu = np.array([[sigma_leaf_mu]]) - else: - raise ValueError("sigma_leaf_mu must be a scalar") - if isinstance(sigma_leaf_tau, float): - if Z_train.shape[1] > 1: - 
current_leaf_scale_tau = np.zeros((Z_train.shape[1], Z_train.shape[1]), dtype=float) - np.fill_diagonal(current_leaf_scale_tau, sigma_leaf_tau) - else: - current_leaf_scale_tau = np.array([[sigma_leaf_tau]]) - elif isinstance(sigma_leaf_tau, np.ndarray): - if sigma_leaf_tau.ndim != 2: + # Preliminary runtime checks for probit link + if self.probit_outcome_model: + if np.unique(y_train).size != 2: raise ValueError( - "sigma_leaf_tau must be a 2d symmetric numpy array if provided in matrix form" + "You specified a probit outcome model, but supplied an outcome with more than 2 unique values" ) - if sigma_leaf_tau.shape[0] != sigma_leaf_tau.shape[1]: + unique_outcomes = np.squeeze(np.unique(y_train)) + if not np.array_equal(unique_outcomes, [0, 1]): raise ValueError( - "sigma_leaf_tau must be a 2d symmetric numpy array if provided in matrix form" + "You specified a probit outcome model, but supplied an outcome with 2 unique values other than 0 and 1" ) - if sigma_leaf_tau.shape[0] != Z_train.shape[1]: + if self.include_variance_forest: raise ValueError( - "sigma_leaf_tau must be a 2d numpy array with dimension matching that of the treatment vector" + "We do not support heteroskedasticity with a probit link" ) - current_leaf_scale_tau = sigma_leaf_tau - else: - raise ValueError("sigma_leaf_tau must be a scalar or a 2d numpy array") - if self.include_variance_forest: - if not a_forest: - a_forest = num_trees_variance / a_0**2 + 0.5 - if not b_forest: - b_forest = num_trees_variance / a_0**2 + if sample_sigma_global: + warnings.warn( + "Global error variance will not be sampled with a probit link as it is fixed at 1" + ) + sample_sigma_global = False + + # Handle standardization, prior calibration, and initialization of forest + # differently for binary and continuous outcomes + if self.probit_outcome_model: + # Compute a probit-scale offset and fix scale to 1 + self.y_bar = norm.ppf(np.squeeze(np.mean(y_train))) + self.y_std = 1.0 + + # Set a pseudo outcome by subtracting mean(y_train) from y_train + resid_train = y_train - np.squeeze(np.mean(y_train)) + + # Set initial value for the mu forest + init_mu = 0.0 + + # Calibrate priors for sigma^2 and tau + # Set sigma2_init to 1, ignoring default provided + sigma2_init = 1.0 + current_sigma2 = sigma2_init + self.sigma2_init = sigma2_init + # Skip variance_forest_init, since variance forests are not supported with probit link + b_leaf_mu = ( + 1.0 / num_trees_mu + if b_leaf_mu is None + else b_leaf_mu + ) + b_leaf_tau = ( + 1.0 / (2 * num_trees_tau) + if b_leaf_tau is None + else b_leaf_tau + ) + sigma_leaf_mu = ( + 1 / num_trees_mu + if sigma_leaf_mu is None + else sigma_leaf_mu + ) + if isinstance(sigma_leaf_mu, float): + current_leaf_scale_mu = np.array([[sigma_leaf_mu]]) + else: + raise ValueError("sigma_leaf_mu must be a scalar") + # Calibrate prior so that P(abs(tau(X)) < delta_max / dnorm(0)) = p + # Use p = 0.9 as an internal default rather than adding another + # user-facing "parameter" of the binary outcome BCF prior. + # Can be overriden by specifying `sigma2_leaf_init` in + # treatment_effect_forest_params. 
+ p = 0.6827 + q_quantile = norm.ppf((p + 1) / 2.0) + sigma_leaf_tau = ( + ((delta_max / (q_quantile*norm.pdf(0)))**2) / num_trees_tau + if sigma_leaf_tau is None + else sigma_leaf_tau + ) + if self.multivariate_treatment: + if not isinstance(sigma_leaf_tau, np.ndarray): + sigma_leaf_tau = np.diagflat( + np.repeat(sigma_leaf_tau, self.treatment_dim) + ) + if isinstance(sigma_leaf_tau, float): + if Z_train.shape[1] > 1: + current_leaf_scale_tau = np.zeros((Z_train.shape[1], Z_train.shape[1]), dtype=float) + np.fill_diagonal(current_leaf_scale_tau, sigma_leaf_tau) + else: + current_leaf_scale_tau = np.array([[sigma_leaf_tau]]) + elif isinstance(sigma_leaf_tau, np.ndarray): + if sigma_leaf_tau.ndim != 2: + raise ValueError( + "sigma_leaf_tau must be a 2d symmetric numpy array if provided in matrix form" + ) + if sigma_leaf_tau.shape[0] != sigma_leaf_tau.shape[1]: + raise ValueError( + "sigma_leaf_tau must be a 2d symmetric numpy array if provided in matrix form" + ) + if sigma_leaf_tau.shape[0] != Z_train.shape[1]: + raise ValueError( + "sigma_leaf_tau must be a 2d numpy array with dimension matching that of the treatment vector" + ) + current_leaf_scale_tau = sigma_leaf_tau + else: + raise ValueError("sigma_leaf_tau must be a scalar or a 2d numpy array") else: - if not a_forest: - a_forest = 1.0 - if not b_forest: - b_forest = 1.0 + # Standardize if requested + if self.standardize: + self.y_bar = np.squeeze(np.mean(y_train)) + self.y_std = np.squeeze(np.std(y_train)) + else: + self.y_bar = 0 + self.y_std = 1 + + # Compute residual value + resid_train = (y_train - self.y_bar) / self.y_std + + # Compute initial value of root nodes in mean forest + init_mu = np.squeeze(np.mean(resid_train)) + + # Calibrate priors for global sigma^2 and sigma_leaf + if not sigma2_init: + sigma2_init = 1.0 * np.var(resid_train) + if not variance_forest_leaf_init: + variance_forest_leaf_init = 0.6 * np.var(resid_train) + current_sigma2 = sigma2_init + self.sigma2_init = sigma2_init + b_leaf_mu = ( + np.squeeze(np.var(resid_train)) / num_trees_mu + if b_leaf_mu is None + else b_leaf_mu + ) + b_leaf_tau = ( + np.squeeze(np.var(resid_train)) / (2 * num_trees_tau) + if b_leaf_tau is None + else b_leaf_tau + ) + sigma_leaf_mu = ( + np.squeeze(2 * np.var(resid_train)) / num_trees_mu + if sigma_leaf_mu is None + else sigma_leaf_mu + ) + if isinstance(sigma_leaf_mu, float): + current_leaf_scale_mu = np.array([[sigma_leaf_mu]]) + else: + raise ValueError("sigma_leaf_mu must be a scalar") + sigma_leaf_tau = ( + np.squeeze(np.var(resid_train)) / (num_trees_tau) + if sigma_leaf_tau is None + else sigma_leaf_tau + ) + if self.multivariate_treatment: + if not isinstance(sigma_leaf_tau, np.ndarray): + sigma_leaf_tau = np.diagflat( + np.repeat(sigma_leaf_tau, self.treatment_dim) + ) + if isinstance(sigma_leaf_tau, float): + if Z_train.shape[1] > 1: + current_leaf_scale_tau = np.zeros((Z_train.shape[1], Z_train.shape[1]), dtype=float) + np.fill_diagonal(current_leaf_scale_tau, sigma_leaf_tau) + else: + current_leaf_scale_tau = np.array([[sigma_leaf_tau]]) + elif isinstance(sigma_leaf_tau, np.ndarray): + if sigma_leaf_tau.ndim != 2: + raise ValueError( + "sigma_leaf_tau must be a 2d symmetric numpy array if provided in matrix form" + ) + if sigma_leaf_tau.shape[0] != sigma_leaf_tau.shape[1]: + raise ValueError( + "sigma_leaf_tau must be a 2d symmetric numpy array if provided in matrix form" + ) + if sigma_leaf_tau.shape[0] != Z_train.shape[1]: + raise ValueError( + "sigma_leaf_tau must be a 2d numpy array with dimension matching 
that of the treatment vector" + ) + current_leaf_scale_tau = sigma_leaf_tau + else: + raise ValueError("sigma_leaf_tau must be a scalar or a 2d numpy array") + if self.include_variance_forest: + if not a_forest: + a_forest = num_trees_variance / a_0**2 + 0.5 + if not b_forest: + b_forest = num_trees_variance / a_0**2 + else: + if not a_forest: + a_forest = 1.0 + if not b_forest: + b_forest = 1.0 # Runtime checks on RFX group ids self.has_rfx = False @@ -1478,7 +1589,8 @@ def sample( leaf_var_model_tau = LeafVarianceModel() # Initialize the leaves of each tree in the prognostic forest - init_mu = np.array([np.squeeze(np.mean(resid_train))]) + if not isinstance(init_mu, np.ndarray): + init_mu = np.array([init_mu]) forest_sampler_mu.prepare_for_sampler( forest_dataset_train, residual_train, @@ -1519,6 +1631,33 @@ def sample( keep_sample = True if keep_sample: sample_counter += 1 + + if self.probit_outcome_model: + # Sample latent probit variable z | - + forest_pred_mu = active_forest_mu.predict(forest_dataset_train) + forest_pred_tau = active_forest_mu.predict(forest_dataset_train) + forest_pred = forest_pred_mu + forest_pred_tau + mu0 = forest_pred[y_train[:, 0] == 0] + mu1 = forest_pred[y_train[:, 0] == 1] + n0 = np.sum(y_train[:, 0] == 0) + n1 = np.sum(y_train[:, 0] == 1) + u0 = self.rng.uniform( + low=0.0, + high=norm.cdf(0 - mu0), + size=n0, + ) + u1 = self.rng.uniform( + low=norm.cdf(0 - mu1), + high=1.0, + size=n1, + ) + resid_train[y_train[:, 0] == 0, 0] = mu0 + norm.ppf(u0) + resid_train[y_train[:, 0] == 1, 0] = mu1 + norm.ppf(u1) + + # Update outcome + new_outcome = np.squeeze(resid_train) - forest_pred + residual_train.update_data(new_outcome) + # Sample the prognostic forest forest_sampler_mu.sample_one_iteration( self.forest_container_mu, @@ -1673,6 +1812,33 @@ def sample( keep_sample = False if keep_sample: sample_counter += 1 + + if self.probit_outcome_model: + # Sample latent probit variable z | - + forest_pred_mu = active_forest_mu.predict(forest_dataset_train) + forest_pred_tau = active_forest_mu.predict(forest_dataset_train) + forest_pred = forest_pred_mu + forest_pred_tau + mu0 = forest_pred[y_train[:, 0] == 0] + mu1 = forest_pred[y_train[:, 0] == 1] + n0 = np.sum(y_train[:, 0] == 0) + n1 = np.sum(y_train[:, 0] == 1) + u0 = self.rng.uniform( + low=0.0, + high=norm.cdf(0 - mu0), + size=n0, + ) + u1 = self.rng.uniform( + low=norm.cdf(0 - mu1), + high=1.0, + size=n1, + ) + resid_train[y_train[:, 0] == 0, 0] = mu0 + norm.ppf(u0) + resid_train[y_train[:, 0] == 1, 0] = mu1 + norm.ppf(u1) + + # Update outcome + new_outcome = np.squeeze(resid_train) - forest_pred + residual_train.update_data(new_outcome) + # Sample the prognostic forest forest_sampler_mu.sample_one_iteration( self.forest_container_mu, @@ -2341,6 +2507,9 @@ def to_json(self) -> str: bcf_json.add_boolean( "internal_propensity_model", self.internal_propensity_model ) + bcf_json.add_boolean( + "probit_outcome_model", self.probit_outcome_model + ) # Add parameter samples if self.sample_sigma_global: @@ -2425,6 +2594,9 @@ def from_json(self, json_string: str) -> None: self.internal_propensity_model = bcf_json.get_boolean( "internal_propensity_model" ) + self.probit_outcome_model = bcf_json.get_boolean( + "probit_outcome_model" + ) # Unpack parameter samples if self.sample_sigma_global: From d669b5d75757ba7f36c3337898046e2fedde3940 Mon Sep 17 00:00:00 2001 From: Drew Herren Date: Fri, 2 May 2025 12:40:20 -0500 Subject: [PATCH 11/14] Fixed python probit BCF bug and updated demo script --- 
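The demo updated in this patch reports quantities on two scales, so it is worth spelling out the mapping: under the probit link the latent outcome is w = mu(x) + tau(x)·z + eps with eps ~ N(0, 1), so P(Y = 1 | x, z) = Phi(mu(x) + tau(x)·z), and the distributional treatment effect plotted in the demo is delta(x) = Phi(mu(x) + tau(x)) − Phi(mu(x)). A small illustrative calculation (the values of `mu_x` and `tau_x` below are made up, not taken from the demo):

```python
# Illustrative arithmetic only; `mu_x` and `tau_x` are hypothetical
# probit-scale values, not outputs of the demo script.
import numpy as np
from scipy.stats import norm

mu_x = np.array([-0.5, 0.0, 0.8])   # prognostic function on the latent (probit) scale
tau_x = np.array([0.4, 0.4, 0.4])   # treatment effect on the latent scale

p_control = norm.cdf(mu_x)          # P(Y = 1 | Z = 0, X = x)
p_treated = norm.cdf(mu_x + tau_x)  # P(Y = 1 | Z = 1, X = x)
delta_x = p_treated - p_control     # distributional treatment effect

print(np.round(p_control, 3))  # approx. [0.309, 0.500, 0.788]
print(np.round(p_treated, 3))  # approx. [0.460, 0.655, 0.885]
print(np.round(delta_x, 3))    # approx. [0.152, 0.155, 0.097]
```

Note that a tau(x) that is constant on the probit scale still yields a delta(x) that varies with mu(x), which is why the demo plots the probit-scale and distributional treatment effects separately.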
demo/debug/causal_inference_binary_outcome.py | 69 ++++++++++++------- stochtree/bcf.py | 5 +- stochtree/preprocessing.py | 1 - 3 files changed, 49 insertions(+), 26 deletions(-) diff --git a/demo/debug/causal_inference_binary_outcome.py b/demo/debug/causal_inference_binary_outcome.py index 57cbd5fe..c603927d 100644 --- a/demo/debug/causal_inference_binary_outcome.py +++ b/demo/debug/causal_inference_binary_outcome.py @@ -13,7 +13,7 @@ rng = np.random.default_rng(random_seed) # Generate covariates and basis -n = 2000 +n = 4000 x1 = rng.normal(loc=0., scale=1., size=(n,)) x2 = rng.normal(loc=0., scale=1., size=(n,)) x3 = rng.normal(loc=0., scale=1., size=(n,)) @@ -26,20 +26,20 @@ "x1": pd.Series(x1), "x2": pd.Series(x2), "x3": pd.Series(x3), - "x4": pd.Series(x4), - "x5": pd.Series(x5) + "x4": pd.Series(x4_cat), + "x5": pd.Series(x5_cat) }) -def g(x): +def g(x5): return np.where( - x.loc[:,"x5"] == 0, 2.0, + x5 == 0, 2.0, np.where( - x.loc[:,"x5"] == 1, -1.0, -4.0 + x5 == 1, -1.0, -4.0 ) ) -mu_x = (1.0 + g(X) + X.loc[:,"x1"]*X.loc[:,"x3"])*0.25 -tau_x = (1.0 + 2*X.loc[:,"x2"]*X.loc[:,"x4"])*0.5 +mu_x = (1.0 + g(x5) + x1*x3)*0.25 +tau_x = (1.0 + 2*x2*x4)*0.5 pi_x = ( - 0.8*norm.cdf(3.0*mu_x / np.squeeze(np.std(mu_x)) - 0.5*X.loc[:,"x1"]) + + 0.8*norm.cdf(3.0*mu_x / np.squeeze(np.std(mu_x)) - 0.5*x1) + 0.05 + rng.uniform(low=0., high=0.1, size=(n,)) ) Z = rng.binomial(n=1, p=pi_x, size=(n,)) @@ -50,7 +50,7 @@ def g(x): # Test-train split sample_inds = np.arange(n) -train_inds, test_inds = train_test_split(sample_inds, test_size=0.5) +train_inds, test_inds = train_test_split(sample_inds, test_size=0.2) X_train = X.iloc[train_inds,:] X_test = X.iloc[test_inds,:] Z_train = Z[train_inds] @@ -86,11 +86,10 @@ def g(x): # Construct parameter lists general_params = { - 'keep_every': 1, + 'keep_every': 5, 'probit_outcome_model': True, 'sample_sigma2_global': False, - 'adaptive_coding': False, - 'num_chains': 1} + 'adaptive_coding': False} prognostic_forest_params = { 'sample_sigma2_leaf': False, 'sigma2_leaf_init': sigma2_leaf_mu, @@ -109,22 +108,46 @@ def g(x): treatment_effect_forest_params=treatment_effect_forest_params) # Inspect the MCMC (BART) samples -plt.scatter(np.squeeze(bcf_model.y_hat_test).mean(axis = 1), y_test, color="black") +mu_hat_test = np.squeeze(bcf_model.mu_hat_test).mean(axis = 1) +plt.scatter(mu_hat_test, mu_test, color="black") plt.axline((0, 0), slope=1, color="red", linestyle=(0, (3,3))) +plt.xlabel("Predicted") +plt.ylabel("Actual") +plt.title("Prognostic function") plt.show() -plt.scatter(np.squeeze(bcf_model.tau_hat_test).mean(axis = 1), tau_test, color="black") +tau_hat_test = np.squeeze(bcf_model.tau_hat_test).mean(axis = 1) +plt.scatter(tau_hat_test, tau_test, color="black") plt.axline((0, 0), slope=1, color="red", linestyle=(0, (3,3))) +plt.xlabel("Predicted") +plt.ylabel("Actual") +plt.title("Probit-scale treatment effect function") plt.show() -plt.scatter(np.squeeze(bcf_model.mu_hat_test).mean(axis = 1), mu_test, color="black") +delta_hat_test = norm.cdf(mu_hat_test + tau_hat_test) - norm.cdf(mu_hat_test) +plt.scatter(delta_hat_test, delta_test, color="black") plt.axline((0, 0), slope=1, color="red", linestyle=(0, (3,3))) +plt.xlabel("Predicted") +plt.ylabel("Actual") +plt.title("Distributional treatment effect function") plt.show() -# # Compute RMSEs -# y_rmse = np.sqrt(np.mean(np.power(np.expand_dims(y_test,1) - y_avg_mcmc, 2))) -# tau_rmse = np.sqrt(np.mean(np.power(np.expand_dims(tau_test,1) - tau_avg_mcmc, 2))) -# mu_rmse = 
np.sqrt(np.mean(np.power(np.expand_dims(mu_test,1) - mu_avg_mcmc, 2))) -# print("y hat RMSE: {:.2f}".format(y_rmse)) -# print("tau hat RMSE: {:.2f}".format(tau_rmse)) -# print("mu hat RMSE: {:.2f}".format(mu_rmse)) +w_hat_test = np.squeeze(bcf_model.y_hat_test).mean(axis = 1) +plt.scatter(w_hat_test, w_test, color="black") +plt.axline((0, 0), slope=1, color="red", linestyle=(0, (3,3))) +plt.xlabel("Predicted") +plt.ylabel("Actual") +plt.title("Probit scale latent outcome") +plt.show() + +# Compute prediction accuracy +preds_test = w_hat_test > 0 +print(f"Test set accuracy: {np.mean(y_test == preds_test):.3f}") + +# Compute RMSEs +w_rmse = np.sqrt(np.mean(np.power(w_test - w_hat_test, 2))) +tau_rmse = np.sqrt(np.mean(np.power(tau_test - tau_hat_test, 2))) +mu_rmse = np.sqrt(np.mean(np.power(mu_test - mu_hat_test, 2))) +print("w hat RMSE: {:.2f}".format(w_rmse)) +print("tau hat RMSE: {:.2f}".format(tau_rmse)) +print("mu hat RMSE: {:.2f}".format(mu_rmse)) diff --git a/stochtree/bcf.py b/stochtree/bcf.py index c532754e..6ee743a5 100644 --- a/stochtree/bcf.py +++ b/stochtree/bcf.py @@ -301,6 +301,7 @@ def sample( keep_burnin = general_params_updated["keep_burnin"] keep_gfr = general_params_updated["keep_gfr"] keep_every = general_params_updated["keep_every"] + num_chains = general_params_updated["num_chains"] self.probit_outcome_model = general_params_updated["probit_outcome_model"] # 2. Mu forest parameters @@ -1635,7 +1636,7 @@ def sample( if self.probit_outcome_model: # Sample latent probit variable z | - forest_pred_mu = active_forest_mu.predict(forest_dataset_train) - forest_pred_tau = active_forest_mu.predict(forest_dataset_train) + forest_pred_tau = active_forest_tau.predict(forest_dataset_train) forest_pred = forest_pred_mu + forest_pred_tau mu0 = forest_pred[y_train[:, 0] == 0] mu1 = forest_pred[y_train[:, 0] == 1] @@ -1816,7 +1817,7 @@ def sample( if self.probit_outcome_model: # Sample latent probit variable z | - forest_pred_mu = active_forest_mu.predict(forest_dataset_train) - forest_pred_tau = active_forest_mu.predict(forest_dataset_train) + forest_pred_tau = active_forest_tau.predict(forest_dataset_train) forest_pred = forest_pred_mu + forest_pred_tau mu0 = forest_pred[y_train[:, 0] == 0] mu1 = forest_pred[y_train[:, 0] == 1] diff --git a/stochtree/preprocessing.py b/stochtree/preprocessing.py index 53f12e37..954a0f9b 100644 --- a/stochtree/preprocessing.py +++ b/stochtree/preprocessing.py @@ -389,7 +389,6 @@ def _transform_pandas(self, covariates: pd.DataFrame) -> np.array: ) output_iter = 0 original_feature_indices = [] - print(self._original_feature_types) for i in range(covariates.shape[1]): covariate = covariates.iloc[:, i] if ( From 36944b303b08cbc4642b6d293af40ba3914e9f9f Mon Sep 17 00:00:00 2001 From: Drew Herren Date: Fri, 2 May 2025 13:19:35 -0500 Subject: [PATCH 12/14] Updated changelog --- CHANGELOG.md | 38 ++++++++++++++++++++++++++++++++++++++ NEWS.md | 8 +++++++- 2 files changed, 45 insertions(+), 1 deletion(-) create mode 100644 CHANGELOG.md diff --git a/CHANGELOG.md b/CHANGELOG.md new file mode 100644 index 00000000..d5cf26c5 --- /dev/null +++ b/CHANGELOG.md @@ -0,0 +1,38 @@ +# Changelog + +# stochtree 0.1.2 + +## New Features + +* Support for binary outcomes in BART and BCF with a probit link ([#164](https://github.com/StochasticTree/stochtree/pull/164)) + +## Bug Fixes + +* Fixed indexing bug in cleanup of grow-from-root (GFR) samples in BART and BCF models +* Avoid using covariate preprocessor in `computeForestLeafIndices` R function when a `ForestSamples` 
object is provided (instead of a `bartmodel` or `bcfmodel` object) + +# stochtree 0.1.1 + +## Bug Fixes + +* Fixed initialization bug in several R package code examples for random effects models + +# stochtree 0.1.0 + +Initial "alpha" release + +## New Features + +* Support for sampling stochastic tree ensembles using two algorithms: MCMC and Grow-From-Root (GFR) +* High-level model types supported: + * Supervised learning with constant leaves or user-specified leaf regression models + * Causal effect estimation with binary or continuous treatments +* Additional high-level modeling features: + * Forest-based variance function estimation (heteroskedasticity) + * Additive (univariate or multivariate) group random effects + * Multi-chain sampling and support for parallelism + * "Warm-start" initialization of MCMC forest samplers via the Grow-From-Root (GFR) algorithm + * Automated preprocessing / handling of categorical variables +* Low-level interface: + * Ability to combine a forest sampler with other (additive) model terms, without using C++ + * Combine and sample an arbitrary number of forests or random effects terms diff --git a/NEWS.md b/NEWS.md index 7fddbbf1..fc00c05e 100644 --- a/NEWS.md +++ b/NEWS.md @@ -1,7 +1,13 @@ # stochtree 0.1.2 +## New Features + +* Support for binary outcomes in BART and BCF with a probit link ([#164](https://github.com/StochasticTree/stochtree/pull/164)) + +## Bug Fixes + * Fixed indexing bug in cleanup of grow-from-root (GFR) samples in BART and BCF models -* Avoid using covariate preprocessor in `computeForestLeafIndices` function when a `ForestSamples` object is provided +* Avoid using covariate preprocessor in `computeForestLeafIndices` function when a `ForestSamples` object is provided (rather than a `bartmodel` or `bcfmodel` object) # stochtree 0.1.1 From 7b3a5ae3f5f55eac9cbab4290708ff3308363493 Mon Sep 17 00:00:00 2001 From: Drew Herren Date: Mon, 5 May 2025 13:50:17 -0500 Subject: [PATCH 13/14] Updated python probit BART interface and demos --- demo/debug/supervised_learning.py | 2 +- .../supervised_learning_binary_outcome.py | 112 ++++++++++++++++++ .../supervised_learning_classification.ipynb | 6 +- stochtree/bart.py | 4 +- 4 files changed, 117 insertions(+), 7 deletions(-) create mode 100644 demo/debug/supervised_learning_binary_outcome.py diff --git a/demo/debug/supervised_learning.py b/demo/debug/supervised_learning.py index bbbd99dd..955d83ec 100644 --- a/demo/debug/supervised_learning.py +++ b/demo/debug/supervised_learning.py @@ -44,7 +44,7 @@ def outcome_mean(X, W): # Test-train split sample_inds = np.arange(n) -train_inds, test_inds = train_test_split(sample_inds, test_size=0.5) +train_inds, test_inds = train_test_split(sample_inds, test_size=0.2) X_train = X[train_inds,:] X_test = X[test_inds,:] basis_train = W[train_inds,:] diff --git a/demo/debug/supervised_learning_binary_outcome.py b/demo/debug/supervised_learning_binary_outcome.py new file mode 100644 index 00000000..430d5b59 --- /dev/null +++ b/demo/debug/supervised_learning_binary_outcome.py @@ -0,0 +1,112 @@ +# Supervised Learning Demo Script + +# Load necessary libraries +import numpy as np +import matplotlib.pyplot as plt +from stochtree import BARTModel +from sklearn.model_selection import train_test_split +from sklearn.metrics import roc_curve + +# Generate sample data +# RNG +rng = np.random.default_rng() + +# Generate covariates and basis +n = 1000 +p_X = 10 +p_basis = 1 +X = rng.uniform(0, 1, (n, p_X)) +basis = rng.uniform(0, 1, (n, p_basis)) + +# Define the outcome mean 
function +def outcome_mean(X, basis = None): + if basis is not None: + return np.where( + (X[:,0] >= 0.0) & (X[:,0] < 0.25), -7.5 * basis[:,0], + np.where( + (X[:,0] >= 0.25) & (X[:,0] < 0.5), -2.5 * basis[:,0], + np.where( + (X[:,0] >= 0.5) & (X[:,0] < 0.75), 2.5 * basis[:,0], + 7.5 * basis[:,0] + ) + ) + ) + else: + return np.where( + (X[:,0] >= 0.0) & (X[:,0] < 0.25), -7.5 * X[:,1], + np.where( + (X[:,0] >= 0.25) & (X[:,0] < 0.5), -2.5 * X[:,1], + np.where( + (X[:,0] >= 0.5) & (X[:,0] < 0.75), 2.5 * X[:,1], + 7.5 * X[:,1] + ) + ) + ) + + +# Generate outcome +epsilon = rng.normal(0, 1, n) +w = outcome_mean(X, basis) + epsilon +# w = outcome_mean(X) + epsilon +y = np.where(w > 0, 1, 0) + +# Test-train split +sample_inds = np.arange(n) +train_inds, test_inds = train_test_split(sample_inds, test_size=0.2) +X_train = X[train_inds,:] +X_test = X[test_inds,:] +basis_train = basis[train_inds,:] +basis_test = basis[test_inds,:] +w_train = w[train_inds] +w_test = w[test_inds] +y_train = y[train_inds] +y_test = y[test_inds] + +# Construct parameter lists +general_params = { + 'probit_outcome_model': True, + 'sample_sigma2_global': False +} + +# Run BART +num_gfr = 10 +num_mcmc = 100 +bart_model = BARTModel() +bart_model.sample(X_train=X_train, y_train=y_train, leaf_basis_train=basis_train, + X_test=X_test, leaf_basis_test=basis_test, num_gfr=num_gfr, + num_burnin=0, num_mcmc=num_mcmc, general_params=general_params) +# bart_model.sample(X_train=X_train, y_train=y_train, X_test=X_test, num_gfr=num_gfr, +# num_burnin=0, num_mcmc=num_mcmc, general_params=general_params) + +# Inspect the MCMC (BART) samples +w_hat_test = np.squeeze(bart_model.y_hat_test).mean(axis = 1) +plt.scatter(w_hat_test, w_test, color="black") +plt.axline((0, 0), slope=1, color="red", linestyle=(0, (3,3))) +plt.xlabel("Predicted") +plt.ylabel("Actual") +plt.title("Probit scale latent outcome") +plt.show() + +# Compute prediction accuracy +preds_test = w_hat_test > 0 +print(f"Test set accuracy: {np.mean(y_test == preds_test):.3f}") + +# Present a ROC curve +fpr_list = list() +tpr_list = list() +threshold_list = list() +for i in range(num_mcmc): + fpr, tpr, thresholds = roc_curve(y_test, bart_model.y_hat_test[:,i], pos_label=1) + fpr_list.append(fpr) + tpr_list.append(tpr) + threshold_list.append(thresholds) +fpr_mean, tpr_mean, thresholds_mean = roc_curve(y_test, w_hat_test, pos_label=1) +for i in range(num_mcmc): + plt.plot(fpr_list[i], tpr_list[i], color = 'blue', linestyle='solid', linewidth = 1.25) +plt.plot(fpr_mean, tpr_mean, color = 'black', linestyle='dashed', linewidth = 2.0) +plt.axline((0, 0), slope=1, color="red", linestyle='dashed', linewidth=1.5) +plt.xlabel("False Positive Rate") +plt.ylabel("True Positive Rate") +plt.xlim(0, 1) +plt.ylim(0, 1) +plt.show() diff --git a/demo/notebooks/supervised_learning_classification.ipynb b/demo/notebooks/supervised_learning_classification.ipynb index e88b1b7b..40feb510 100644 --- a/demo/notebooks/supervised_learning_classification.ipynb +++ b/demo/notebooks/supervised_learning_classification.ipynb @@ -110,16 +110,14 @@ "num_gfr = 10\n", "num_mcmc = 100\n", "bart_model = BARTModel()\n", - "general_params = {\"num_chains\": 1}\n", - "mean_forest_params = {\"probit_outcome_model\": True}\n", + "general_params = {\"num_chains\": 1, \"probit_outcome_model\": True}\n", "bart_model.sample(\n", " X_train=X_train,\n", " y_train=y_train,\n", " X_test=X_test,\n", " num_gfr=num_gfr,\n", " num_mcmc=num_mcmc,\n", - " general_params=general_params,\n", - " 
mean_forest_params=mean_forest_params\n", + " general_params=general_params\n", ")" ] }, diff --git a/stochtree/bart.py b/stochtree/bart.py index fd1159de..3bbc8d8e 100644 --- a/stochtree/bart.py +++ b/stochtree/bart.py @@ -187,6 +187,7 @@ def sample( "keep_gfr": False, "keep_every": 1, "num_chains": 1, + "probit_outcome_model": False, } general_params_updated = _preprocess_params( general_params_default, general_params @@ -205,7 +206,6 @@ def sample( "sigma2_leaf_scale": None, "keep_vars": None, "drop_vars": None, - "probit_outcome_model": False, } mean_forest_params_updated = _preprocess_params( mean_forest_params_default, mean_forest_params @@ -243,6 +243,7 @@ def sample( keep_gfr = general_params_updated["keep_gfr"] keep_every = general_params_updated["keep_every"] num_chains = general_params_updated["num_chains"] + self.probit_outcome_model = general_params_updated["probit_outcome_model"] # 2. Mean forest parameters num_trees_mean = mean_forest_params_updated["num_trees"] @@ -256,7 +257,6 @@ def sample( b_leaf = mean_forest_params_updated["sigma2_leaf_scale"] keep_vars_mean = mean_forest_params_updated["keep_vars"] drop_vars_mean = mean_forest_params_updated["drop_vars"] - self.probit_outcome_model = mean_forest_params_updated["probit_outcome_model"] # 3. Variance forest parameters num_trees_variance = variance_forest_params_updated["num_trees"] From 58e6d2282fef307ea6471023638543e351032ba4 Mon Sep 17 00:00:00 2001 From: Drew Herren Date: Tue, 6 May 2025 02:01:26 -0500 Subject: [PATCH 14/14] Updated probit BCF --- R/bcf.R | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/R/bcf.R b/R/bcf.R index 8fd46808..8161fe69 100644 --- a/R/bcf.R +++ b/R/bcf.R @@ -740,7 +740,7 @@ bcf <- function(X_train, Z_train, y_train, propensity_train = NULL, rfx_group_id # treatment_effect_forest_params. p <- 0.6827 q_quantile <- qnorm((p+1)/2) - sigma2_leaf_tau <- ((delta_max/(q_quantile*dnorm(0)))^2)/num_trees_tau + sigma_leaf_tau <- ((delta_max/(q_quantile*dnorm(0)))^2)/num_trees_tau current_leaf_scale_tau <- as.matrix(diag(sigma_leaf_tau, ncol(Z_train))) } else { if (!is.matrix(sigma_leaf_tau)) {
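For reference, the calibration that this final hunk touches can be worked through with the defaults used in the code: each of the `num_trees_tau` trees contributes a leaf value with variance `sigma_leaf_tau`, so the prior variance of tau(x) is `num_trees_tau * sigma_leaf_tau`. Setting that product equal to `(delta_max / (q * dnorm(0)))^2`, with `q` the normal quantile for coverage `p = 0.6827` (one standard deviation), keeps the implied probability-scale effect within `delta_max` with prior probability roughly `p`, because Phi is Lipschitz with constant `dnorm(0)`. A standalone numerical sketch follows (same default values; not the package implementation):

```python
# Standalone sketch of the tau leaf-scale calibration, using the defaults that
# appear in the patch (delta_max = 0.9, num_trees_tau = 50, p = 0.6827); this
# mirrors the formula in bcf.R / bcf.py but is not the package implementation.
import numpy as np
from scipy.stats import norm

delta_max = 0.9       # largest plausible P(Y(1)=1 | x) - P(Y(0)=1 | x)
num_trees_tau = 50
p = 0.6827            # prior coverage level for |tau(x)| (one standard deviation)
q = norm.ppf((p + 1) / 2)  # approx. 1.0

# Per-tree leaf variance so that the sum over trees has the intended scale
sigma_leaf_tau = (delta_max / (q * norm.pdf(0.0))) ** 2 / num_trees_tau

# Implied prior standard deviation of tau(x), the sum of num_trees_tau leaf values
tau_prior_sd = np.sqrt(num_trees_tau * sigma_leaf_tau)
print(round(sigma_leaf_tau, 4))  # approx. 0.1018 per tree
print(round(tau_prior_sd, 3))    # approx. 2.256 on the probit scale

# Phi is Lipschitz with constant norm.pdf(0), so a probit-scale effect of size
# q * tau_prior_sd implies a probability-scale effect of at most delta_max
print(round(norm.pdf(0.0) * q * tau_prior_sd, 3))  # approx. 0.9
```

If `sigma2_leaf_init` is supplied directly in `treatment_effect_forest_params`, this calibration is skipped and `delta_max` is ignored, as the parameter documentation above states.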