From 91c961401baa32ef94af260de227fcb18e9ce1f2 Mon Sep 17 00:00:00 2001
From: Drew Herren
Date: Fri, 25 Jul 2025 18:13:52 -0300
Subject: [PATCH 1/2] Making multivariate treatment BCF work in R

---
Review notes (free text in this position is ignored by `git am`):
 * Z_col is computed with a scalar `if`/`else` instead of `ifelse()`.
 * Row loops added to bcf() use `seq_len()` rather than `1:nrow()`.
 * `tau_hat[i,,]` is reshaped with `matrix(..., nrow = <treatment dim>)`
   before `colSums()`, so the slice stays two-dimensional when only one
   posterior sample is retained (a dropped vector makes colSums() error).
 * The debug script uses `FALSE` instead of `F` and `seq_len()` in its loop.
 * The `Z_train` reference inside predict.bcfmodel is left untouched here
   because patch 2/2 replaces exactly that line.
 * Line counts per hunk are unchanged, so hunk headers and the diffstat
   below still hold. Leading whitespace was reconstructed from a
   whitespace-collapsed copy (apply with `git apply --ignore-whitespace`),
   and the `index` blob ids predate these edits.

 R/bcf.R                              | 96 ++++++++++++++++++++--------
 tools/debug/multivariate_bcf_debug.R | 56 ++++++++++++++++
 2 files changed, 127 insertions(+), 25 deletions(-)
 create mode 100644 tools/debug/multivariate_bcf_debug.R

diff --git a/R/bcf.R b/R/bcf.R
index acede05d..8af79815 100644
--- a/R/bcf.R
+++ b/R/bcf.R
@@ -504,14 +504,13 @@ bcf <- function(X_train, Z_train, y_train, propensity_train = NULL, rfx_group_id
     if (!is.null(X_test)) X_test <- preprocessPredictionData(X_test, X_train_metadata)
 
     # Convert all input data to matrices if not already converted
-    if ((is.null(dim(Z_train))) && (!is.null(Z_train))) {
-        Z_train <- as.matrix(as.numeric(Z_train))
-    }
+    Z_col <- if (is.null(dim(Z_train))) 1 else ncol(Z_train)
+    Z_train <- matrix(as.numeric(Z_train), ncol = Z_col)
     if ((is.null(dim(propensity_train))) && (!is.null(propensity_train))) {
         propensity_train <- as.matrix(propensity_train)
     }
-    if ((is.null(dim(Z_test))) && (!is.null(Z_test))) {
-        Z_test <- as.matrix(as.numeric(Z_test))
+    if (!is.null(Z_test)) {
+        Z_test <- matrix(as.numeric(Z_test), ncol = Z_col)
     }
     if ((is.null(dim(propensity_test))) && (!is.null(propensity_test))) {
         propensity_test <- as.matrix(propensity_test)
@@ -580,9 +579,30 @@ bcf <- function(X_train, Z_train, y_train, propensity_train = NULL, rfx_group_id
         }
     }
 
-    # Stop if multivariate treatment is provided
-    if (ncol(Z_train) > 1) stop("Multivariate treatments are not currently supported")
-
+    # # Stop if multivariate treatment is provided
+    # if (ncol(Z_train) > 1) stop("Multivariate treatments are not currently supported")
+
+    # Handle multivariate treatment
+    has_multivariate_treatment <- ncol(Z_train) > 1
+    if (has_multivariate_treatment) {
+        # Disable adaptive coding, internal propensity model, and
+        # leaf scale sampling if treatment is multivariate
+        if (adaptive_coding) {
+            warning("Adaptive coding is incompatible with multivariate treatment and will be ignored")
+            adaptive_coding <- FALSE
+        }
+        if (is.null(propensity_train)) {
+            if (propensity_covariate != "none") {
+                warning("No propensities were provided for the multivariate treatment; an internal propensity model will not be fitted to the multivariate treatment and propensity_covariate will be set to 'none'")
+                propensity_covariate <- "none"
+            }
+        }
+        if (sample_sigma2_leaf_tau) {
+            warning("Sampling leaf scale not yet supported for multivariate leaf models, so the leaf scale parameter will not be sampled for the treatment forest in this model.")
+            sample_sigma2_leaf_tau <- FALSE
+        }
+    }
+
     # Random effects covariance prior
     if (has_rfx) {
         if (is.null(rfx_prior_var)) {
@@ -835,18 +855,10 @@ bcf <- function(X_train, Z_train, y_train, propensity_train = NULL, rfx_group_id
         current_sigma2 <- sigma2_init
     }
 
-    # Switch off leaf scale sampling for multivariate treatments
-    if (ncol(Z_train) > 1) {
-        if (sample_sigma2_leaf_tau) {
-            warning("Sampling leaf scale not yet supported for multivariate leaf models, so the leaf scale parameter will not be sampled for the treatment forest in this model.")
-            sample_sigma2_leaf_tau <- FALSE
-        }
-    }
-
     # Set mu and tau leaf models / dimensions
     leaf_model_mu_forest <- 0
     leaf_dimension_mu_forest <- 1
-    if (ncol(Z_train) > 1) {
+    if (has_multivariate_treatment) {
         leaf_model_tau_forest <- 2
         leaf_dimension_tau_forest <- ncol(Z_train)
     } else {
@@ -973,21 +985,21 @@ bcf <- function(X_train, Z_train, y_train, propensity_train = NULL, rfx_group_id
 
     # Container of forest samples
     forest_samples_mu <- createForestSamples(num_trees_mu, 1, TRUE)
-    forest_samples_tau <- createForestSamples(num_trees_tau, 1, FALSE)
+    forest_samples_tau <- createForestSamples(num_trees_tau, ncol(Z_train), FALSE)
     active_forest_mu <- createForest(num_trees_mu, 1, TRUE)
-    active_forest_tau <- createForest(num_trees_tau, 1, FALSE)
+    active_forest_tau <- createForest(num_trees_tau, ncol(Z_train), FALSE)
     if (include_variance_forest) {
         forest_samples_variance <- createForestSamples(num_trees_variance, 1, TRUE, TRUE)
         active_forest_variance <- createForest(num_trees_variance, 1, TRUE, TRUE)
     }
 
     # Initialize the leaves of each tree in the prognostic forest
-    active_forest_mu$prepare_for_sampler(forest_dataset_train, outcome_train, forest_model_mu, 0, init_mu)
+    active_forest_mu$prepare_for_sampler(forest_dataset_train, outcome_train, forest_model_mu, leaf_model_mu_forest, init_mu)
     active_forest_mu$adjust_residual(forest_dataset_train, outcome_train, forest_model_mu, FALSE, FALSE)
 
     # Initialize the leaves of each tree in the treatment effect forest
-    init_tau <- 0.
-    active_forest_tau$prepare_for_sampler(forest_dataset_train, outcome_train, forest_model_tau, 1, init_tau)
+    init_tau <- rep(0., ncol(Z_train))
+    active_forest_tau$prepare_for_sampler(forest_dataset_train, outcome_train, forest_model_tau, leaf_model_tau_forest, init_tau)
     active_forest_tau$adjust_residual(forest_dataset_train, outcome_train, forest_model_tau, TRUE, FALSE)
 
     # Initialize the leaves of each tree in the variance forest
@@ -1450,7 +1462,18 @@ bcf <- function(X_train, Z_train, y_train, propensity_train = NULL, rfx_group_id
     } else {
         tau_hat_train <- forest_samples_tau$predict_raw(forest_dataset_train)*y_std_train
     }
-    y_hat_train <- mu_hat_train + tau_hat_train * as.numeric(Z_train)
+    if (has_multivariate_treatment) {
+        tau_train_dim <- dim(tau_hat_train)
+        tau_num_obs <- tau_train_dim[1]
+        tau_num_samples <- tau_train_dim[3]
+        treatment_term_train <- matrix(NA_real_, nrow = tau_num_obs, tau_num_samples)
+        for (i in seq_len(nrow(Z_train))) {
+            treatment_term_train[i,] <- colSums(matrix(tau_hat_train[i,,], nrow = tau_train_dim[2]) * Z_train[i,])
+        }
+    } else {
+        treatment_term_train <- tau_hat_train * as.numeric(Z_train)
+    }
+    y_hat_train <- mu_hat_train + treatment_term_train
     if (has_test) {
         mu_hat_test <- forest_samples_mu$predict(forest_dataset_test)*y_std_train + y_bar_train
         if (adaptive_coding) {
@@ -1459,7 +1482,18 @@ bcf <- function(X_train, Z_train, y_train, propensity_train = NULL, rfx_group_id
         } else {
             tau_hat_test <- forest_samples_tau$predict_raw(forest_dataset_test)*y_std_train
         }
-        y_hat_test <- mu_hat_test + tau_hat_test * as.numeric(Z_test)
+        if (has_multivariate_treatment) {
+            tau_test_dim <- dim(tau_hat_test)
+            tau_num_obs <- tau_test_dim[1]
+            tau_num_samples <- tau_test_dim[3]
+            treatment_term_test <- matrix(NA_real_, nrow = tau_num_obs, tau_num_samples)
+            for (i in seq_len(nrow(Z_test))) {
+                treatment_term_test[i,] <- colSums(matrix(tau_hat_test[i,,], nrow = tau_test_dim[2]) * Z_test[i,])
+            }
+        } else {
+            treatment_term_test <- tau_hat_test * as.numeric(Z_test)
+        }
+        y_hat_test <- mu_hat_test + treatment_term_test
     }
     if (include_variance_forest) {
         sigma2_x_hat_train <- exp(sigma2_x_train_raw)
@@ -1526,6 +1560,7 @@ bcf <- function(X_train, Z_train, y_train, propensity_train = NULL, rfx_group_id
         "treatment_dim" = ncol(Z_train),
         "propensity_covariate" = propensity_covariate,
         "binary_treatment" = binary_treatment,
+        "multivariate_treatment" = has_multivariate_treatment,
         "adaptive_coding" = adaptive_coding,
         "internal_propensity_model" = internal_propensity_model,
         "num_samples" = num_retained_samples,
@@ -1722,6 +1757,17 @@ predict.bcfmodel <- function(object, X, Z, propensity = NULL, rfx_group_ids = NU
     } else {
         tau_hat <- object$forests_tau$predict_raw(forest_dataset_pred)*y_std
     }
+    if (object$model_params$multivariate_treatment) {
+        tau_dim <- dim(tau_hat)
+        tau_num_obs <- tau_dim[1]
+        tau_num_samples <- tau_dim[3]
+        treatment_term <- matrix(NA_real_, nrow = tau_num_obs, tau_num_samples)
+        for (i in 1:nrow(Z_train)) {
+            treatment_term[i,] <- colSums(matrix(tau_hat[i,,], nrow = tau_dim[2]) * Z[i,])
+        }
+    } else {
+        treatment_term <- tau_hat * as.numeric(Z)
+    }
     if (object$model_params$include_variance_forest) {
         s_x_raw <- object$forests_variance$predict(forest_dataset_pred)
     }
@@ -1732,7 +1778,7 @@ predict.bcfmodel <- function(object, X, Z, propensity = NULL, rfx_group_ids = NU
     }
 
     # Compute overall "y_hat" predictions
-    y_hat <- mu_hat + tau_hat * as.numeric(Z)
+    y_hat <- mu_hat + treatment_term
     if (object$model_params$has_rfx) y_hat <- y_hat + rfx_predictions
 
     # Scale variance forest predictions
diff --git a/tools/debug/multivariate_bcf_debug.R b/tools/debug/multivariate_bcf_debug.R
new file mode 100644
index 00000000..f61b9090
--- /dev/null
+++ b/tools/debug/multivariate_bcf_debug.R
@@ -0,0 +1,56 @@
+# Load libraries
+library(stochtree)
+
+# Generate data
+n <- 500
+p <- 5
+snr <- 2.0
+X <- matrix(runif(n*p), ncol = p)
+pi_x <- cbind(0.25 + 0.5 * X[, 1], 0.75 - 0.5 * X[, 2])
+mu_x <- pi_x[, 1] * 5 + pi_x[, 2] * 2 + 2 * X[, 3]
+tau_x <- cbind(X[, 2], X[, 3])
+Z <- matrix(NA_integer_, nrow = n, ncol = ncol(pi_x))
+for (i in seq_len(ncol(pi_x))) {
+    Z[, i] <- rbinom(n, 1, pi_x[, i])
+}
+E_XZ <- mu_x + rowSums(Z * tau_x)
+y <- E_XZ + rnorm(n, 0, 1)*(sd(E_XZ) / snr)
+
+# Split data into test and train sets
+test_set_pct <- 0.2
+n_test <- round(test_set_pct*n)
+n_train <- n - n_test
+test_inds <- sort(sample(1:n, n_test, replace = FALSE))
+train_inds <- (1:n)[!((1:n) %in% test_inds)]
+X_test <- X[test_inds,]
+X_train <- X[train_inds,]
+pi_test <- pi_x[test_inds,]
+pi_train <- pi_x[train_inds,]
+Z_test <- Z[test_inds,]
+Z_train <- Z[train_inds,]
+y_test <- y[test_inds]
+y_train <- y[train_inds]
+mu_test <- mu_x[test_inds]
+mu_train <- mu_x[train_inds]
+tau_test <- tau_x[test_inds,]
+tau_train <- tau_x[train_inds,]
+
+# Run BCF
+num_gfr <- 10
+num_burnin <- 0
+num_mcmc <- 100
+general_params <- list(adaptive_coding = FALSE)
+prognostic_forest_params <- list(sample_sigma2_leaf = FALSE)
+treatment_effect_forest_params <- list(sample_sigma2_leaf = FALSE)
+bcf_model <- bcf(
+    X_train = X_train, Z_train = Z_train, y_train = y_train, propensity_train = pi_train,
+    X_test = X_test, Z_test = Z_test, propensity_test = pi_test,
+    num_gfr = num_gfr, num_burnin = num_burnin, num_mcmc = num_mcmc,
+    general_params = general_params, prognostic_forest_params = prognostic_forest_params,
+    treatment_effect_forest_params = treatment_effect_forest_params
+)
+
+# Check results
+y_hat_test_mean <- rowMeans(bcf_model$y_hat_test)
+plot(y_hat_test_mean, y_test); abline(0,1,col="red")
+sqrt(mean((y_hat_test_mean - y_test)^2))
From 03c291e2c87ee849031fb36739e9347cb724591b Mon Sep 17 00:00:00 2001
From: Drew Herren
Date: Fri, 25 Jul 2025 21:04:19 -0300
Subject: [PATCH 2/2] Updated BCF

---
Review notes (free text in this position is ignored by `git am`):
 * The predict.bcfmodel loop now iterates over `seq_len(nrow(Z))`; this
   replaces the stray `Z_train` reference introduced by patch 1/2.
 * The test passes `adaptive_coding = FALSE` rather than `F`.
 * The first R/bcf.R hunk's context shows the reshaped `colSums()` line
   as written by the edited patch 1/2; hunk line counts are unchanged.

 R/bcf.R                    | 6 +++++-
 test/R/testthat/test-bcf.R | 7 ++++---
 2 files changed, 9 insertions(+), 4 deletions(-)

diff --git a/R/bcf.R b/R/bcf.R
index 8af79815..053f8b21 100644
--- a/R/bcf.R
+++ b/R/bcf.R
@@ -1762,7 +1762,7 @@ predict.bcfmodel <- function(object, X, Z, propensity = NULL, rfx_group_ids = NU
         tau_num_obs <- tau_dim[1]
         tau_num_samples <- tau_dim[3]
         treatment_term <- matrix(NA_real_, nrow = tau_num_obs, tau_num_samples)
-        for (i in 1:nrow(Z_train)) {
+        for (i in seq_len(nrow(Z))) {
             treatment_term[i,] <- colSums(matrix(tau_hat[i,,], nrow = tau_dim[2]) * Z[i,])
         }
     } else {
@@ -2020,6 +2020,7 @@ saveBCFModelToJson <- function(object){
     jsonobj$add_boolean("has_rfx", object$model_params$has_rfx)
     jsonobj$add_boolean("has_rfx_basis", object$model_params$has_rfx_basis)
     jsonobj$add_scalar("num_rfx_basis", object$model_params$num_rfx_basis)
+    jsonobj$add_boolean("multivariate_treatment", object$model_params$multivariate_treatment)
     jsonobj$add_boolean("adaptive_coding", object$model_params$adaptive_coding)
     jsonobj$add_boolean("internal_propensity_model", object$model_params$internal_propensity_model)
     jsonobj$add_scalar("num_gfr", object$model_params$num_gfr)
@@ -2351,6 +2352,7 @@ createBCFModelFromJson <- function(json_object){
     model_params[["has_rfx_basis"]] <- json_object$get_boolean("has_rfx_basis")
     model_params[["num_rfx_basis"]] <- json_object$get_scalar("num_rfx_basis")
     model_params[["adaptive_coding"]] <- json_object$get_boolean("adaptive_coding")
+    model_params[["multivariate_treatment"]] <- json_object$get_boolean("multivariate_treatment")
     model_params[["internal_propensity_model"]] <- json_object$get_boolean("internal_propensity_model")
     model_params[["num_gfr"]] <- json_object$get_scalar("num_gfr")
     model_params[["num_burnin"]] <- json_object$get_scalar("num_burnin")
@@ -2690,6 +2692,7 @@ createBCFModelFromCombinedJson <- function(json_object_list){
     model_params[["num_chains"]] <- json_object_default$get_scalar("num_chains")
     model_params[["keep_every"]] <- json_object_default$get_scalar("keep_every")
     model_params[["adaptive_coding"]] <- json_object_default$get_boolean("adaptive_coding")
+    model_params[["multivariate_treatment"]] <- json_object_default$get_boolean("multivariate_treatment")
     model_params[["internal_propensity_model"]] <- json_object_default$get_boolean("internal_propensity_model")
     model_params[["probit_outcome_model"]] <- json_object_default$get_boolean("probit_outcome_model")
 
@@ -2916,6 +2919,7 @@ createBCFModelFromCombinedJsonString <- function(json_string_list){
     model_params[["num_covariates"]] <- json_object_default$get_scalar("num_covariates")
     model_params[["num_chains"]] <- json_object_default$get_scalar("num_chains")
     model_params[["keep_every"]] <- json_object_default$get_scalar("keep_every")
+    model_params[["multivariate_treatment"]] <- json_object_default$get_boolean("multivariate_treatment")
     model_params[["adaptive_coding"]] <- json_object_default$get_boolean("adaptive_coding")
     model_params[["internal_propensity_model"]] <- json_object_default$get_boolean("internal_propensity_model")
     model_params[["probit_outcome_model"]] <- json_object_default$get_boolean("probit_outcome_model")
diff --git a/test/R/testthat/test-bcf.R b/test/R/testthat/test-bcf.R
index 9b54ac05..531320f6 100644
--- a/test/R/testthat/test-bcf.R
+++ b/test/R/testthat/test-bcf.R
@@ -419,13 +419,14 @@ test_that("Multivariate Treatment MCMC BCF", {
     y_train <- y[train_inds]
 
     # 1 chain, no thinning
-    general_param_list <- list(num_chains = 1, keep_every = 1)
-    expect_error(
+    general_param_list <- list(num_chains = 1, keep_every = 1, adaptive_coding = FALSE)
+    expect_no_error({
         bcf_model <- bcf(X_train = X_train, y_train = y_train, Z_train = Z_train,
                          propensity_train = pi_train, X_test = X_test, Z_test = Z_test,
                          propensity_test = pi_test, num_gfr = 0, num_burnin = 10,
                          num_mcmc = 10, general_params = general_param_list)
-    )
+        predict(bcf_model, X = X_test, Z = Z_test, propensity = pi_test)
+    })
 })
 
 test_that("BCF Predictions", {