Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
100 changes: 75 additions & 25 deletions R/bcf.R
Original file line number Diff line number Diff line change
Expand Up @@ -504,14 +504,13 @@ bcf <- function(X_train, Z_train, y_train, propensity_train = NULL, rfx_group_id
if (!is.null(X_test)) X_test <- preprocessPredictionData(X_test, X_train_metadata)

# Convert all input data to matrices if not already converted
if ((is.null(dim(Z_train))) && (!is.null(Z_train))) {
Z_train <- as.matrix(as.numeric(Z_train))
}
Z_col <- ifelse(is.null(dim(Z_train)), 1, ncol(Z_train))
Z_train <- matrix(as.numeric(Z_train), ncol = Z_col)
if ((is.null(dim(propensity_train))) && (!is.null(propensity_train))) {
propensity_train <- as.matrix(propensity_train)
}
if ((is.null(dim(Z_test))) && (!is.null(Z_test))) {
Z_test <- as.matrix(as.numeric(Z_test))
if (!is.null(Z_test)) {
Z_test <- matrix(as.numeric(Z_test), ncol = Z_col)
}
if ((is.null(dim(propensity_test))) && (!is.null(propensity_test))) {
propensity_test <- as.matrix(propensity_test)
Expand Down Expand Up @@ -580,9 +579,30 @@ bcf <- function(X_train, Z_train, y_train, propensity_train = NULL, rfx_group_id
}
}

# Stop if multivariate treatment is provided
if (ncol(Z_train) > 1) stop("Multivariate treatments are not currently supported")

# # Stop if multivariate treatment is provided
# if (ncol(Z_train) > 1) stop("Multivariate treatments are not currently supported")

# Handle multivariate treatment
has_multivariate_treatment <- ncol(Z_train) > 1
if (has_multivariate_treatment) {
# Disable adaptive coding, internal propensity model, and
# leaf scale sampling if treatment is multivariate
if (adaptive_coding) {
warning("Adaptive coding is incompatible with multivariate treatment and will be ignored")
adaptive_coding <- FALSE
}
if (is.null(propensity_train)) {
if (propensity_covariate != "none") {
warning("No propensities were provided for the multivariate treatment; an internal propensity model will not be fitted to the multivariate treatment and propensity_covariate will be set to 'none'")
propensity_covariate <- "none"
}
}
if (sample_sigma2_leaf_tau) {
warning("Sampling leaf scale not yet supported for multivariate leaf models, so the leaf scale parameter will not be sampled for the treatment forest in this model.")
sample_sigma2_leaf_tau <- FALSE
}
}

# Random effects covariance prior
if (has_rfx) {
if (is.null(rfx_prior_var)) {
Expand Down Expand Up @@ -835,18 +855,10 @@ bcf <- function(X_train, Z_train, y_train, propensity_train = NULL, rfx_group_id
current_sigma2 <- sigma2_init
}

# Switch off leaf scale sampling for multivariate treatments
if (ncol(Z_train) > 1) {
if (sample_sigma2_leaf_tau) {
warning("Sampling leaf scale not yet supported for multivariate leaf models, so the leaf scale parameter will not be sampled for the treatment forest in this model.")
sample_sigma2_leaf_tau <- FALSE
}
}

# Set mu and tau leaf models / dimensions
leaf_model_mu_forest <- 0
leaf_dimension_mu_forest <- 1
if (ncol(Z_train) > 1) {
if (has_multivariate_treatment) {
leaf_model_tau_forest <- 2
leaf_dimension_tau_forest <- ncol(Z_train)
} else {
Expand Down Expand Up @@ -973,21 +985,21 @@ bcf <- function(X_train, Z_train, y_train, propensity_train = NULL, rfx_group_id

# Container of forest samples
forest_samples_mu <- createForestSamples(num_trees_mu, 1, TRUE)
forest_samples_tau <- createForestSamples(num_trees_tau, 1, FALSE)
forest_samples_tau <- createForestSamples(num_trees_tau, ncol(Z_train), FALSE)
active_forest_mu <- createForest(num_trees_mu, 1, TRUE)
active_forest_tau <- createForest(num_trees_tau, 1, FALSE)
active_forest_tau <- createForest(num_trees_tau, ncol(Z_train), FALSE)
if (include_variance_forest) {
forest_samples_variance <- createForestSamples(num_trees_variance, 1, TRUE, TRUE)
active_forest_variance <- createForest(num_trees_variance, 1, TRUE, TRUE)
}

# Initialize the leaves of each tree in the prognostic forest
active_forest_mu$prepare_for_sampler(forest_dataset_train, outcome_train, forest_model_mu, 0, init_mu)
active_forest_mu$prepare_for_sampler(forest_dataset_train, outcome_train, forest_model_mu, leaf_model_mu_forest, init_mu)
active_forest_mu$adjust_residual(forest_dataset_train, outcome_train, forest_model_mu, FALSE, FALSE)

# Initialize the leaves of each tree in the treatment effect forest
init_tau <- 0.
active_forest_tau$prepare_for_sampler(forest_dataset_train, outcome_train, forest_model_tau, 1, init_tau)
init_tau <- rep(0., ncol(Z_train))
active_forest_tau$prepare_for_sampler(forest_dataset_train, outcome_train, forest_model_tau, leaf_model_tau_forest, init_tau)
active_forest_tau$adjust_residual(forest_dataset_train, outcome_train, forest_model_tau, TRUE, FALSE)

# Initialize the leaves of each tree in the variance forest
Expand Down Expand Up @@ -1450,7 +1462,18 @@ bcf <- function(X_train, Z_train, y_train, propensity_train = NULL, rfx_group_id
} else {
tau_hat_train <- forest_samples_tau$predict_raw(forest_dataset_train)*y_std_train
}
y_hat_train <- mu_hat_train + tau_hat_train * as.numeric(Z_train)
if (has_multivariate_treatment) {
tau_train_dim <- dim(tau_hat_train)
tau_num_obs <- tau_train_dim[1]
tau_num_samples <- tau_train_dim[3]
treatment_term_train <- matrix(NA_real_, nrow = tau_num_obs, tau_num_samples)
for (i in 1:nrow(Z_train)) {
treatment_term_train[i,] <- colSums(tau_hat_train[i,,] * Z_train[i,])
}
} else {
treatment_term_train <- tau_hat_train * as.numeric(Z_train)
}
y_hat_train <- mu_hat_train + treatment_term_train
if (has_test) {
mu_hat_test <- forest_samples_mu$predict(forest_dataset_test)*y_std_train + y_bar_train
if (adaptive_coding) {
Expand All @@ -1459,7 +1482,18 @@ bcf <- function(X_train, Z_train, y_train, propensity_train = NULL, rfx_group_id
} else {
tau_hat_test <- forest_samples_tau$predict_raw(forest_dataset_test)*y_std_train
}
y_hat_test <- mu_hat_test + tau_hat_test * as.numeric(Z_test)
if (has_multivariate_treatment) {
tau_test_dim <- dim(tau_hat_test)
tau_num_obs <- tau_test_dim[1]
tau_num_samples <- tau_test_dim[3]
treatment_term_test <- matrix(NA_real_, nrow = tau_num_obs, tau_num_samples)
for (i in 1:nrow(Z_test)) {
treatment_term_test[i,] <- colSums(tau_hat_test[i,,] * Z_test[i,])
}
} else {
treatment_term_test <- tau_hat_test * as.numeric(Z_test)
}
y_hat_test <- mu_hat_test + treatment_term_test
}
if (include_variance_forest) {
sigma2_x_hat_train <- exp(sigma2_x_train_raw)
Expand Down Expand Up @@ -1526,6 +1560,7 @@ bcf <- function(X_train, Z_train, y_train, propensity_train = NULL, rfx_group_id
"treatment_dim" = ncol(Z_train),
"propensity_covariate" = propensity_covariate,
"binary_treatment" = binary_treatment,
"multivariate_treatment" = has_multivariate_treatment,
"adaptive_coding" = adaptive_coding,
"internal_propensity_model" = internal_propensity_model,
"num_samples" = num_retained_samples,
Expand Down Expand Up @@ -1722,6 +1757,17 @@ predict.bcfmodel <- function(object, X, Z, propensity = NULL, rfx_group_ids = NU
} else {
tau_hat <- object$forests_tau$predict_raw(forest_dataset_pred)*y_std
}
if (object$model_params$multivariate_treatment) {
tau_dim <- dim(tau_hat)
tau_num_obs <- tau_dim[1]
tau_num_samples <- tau_dim[3]
treatment_term <- matrix(NA_real_, nrow = tau_num_obs, tau_num_samples)
for (i in 1:nrow(Z)) {
treatment_term[i,] <- colSums(tau_hat[i,,] * Z[i,])
}
} else {
treatment_term <- tau_hat * as.numeric(Z)
}
if (object$model_params$include_variance_forest) {
s_x_raw <- object$forests_variance$predict(forest_dataset_pred)
}
Expand All @@ -1732,7 +1778,7 @@ predict.bcfmodel <- function(object, X, Z, propensity = NULL, rfx_group_ids = NU
}

# Compute overall "y_hat" predictions
y_hat <- mu_hat + tau_hat * as.numeric(Z)
y_hat <- mu_hat + treatment_term
if (object$model_params$has_rfx) y_hat <- y_hat + rfx_predictions

# Scale variance forest predictions
Expand Down Expand Up @@ -1974,6 +2020,7 @@ saveBCFModelToJson <- function(object){
jsonobj$add_boolean("has_rfx", object$model_params$has_rfx)
jsonobj$add_boolean("has_rfx_basis", object$model_params$has_rfx_basis)
jsonobj$add_scalar("num_rfx_basis", object$model_params$num_rfx_basis)
jsonobj$add_boolean("multivariate_treatment", object$model_params$multivariate_treatment)
jsonobj$add_boolean("adaptive_coding", object$model_params$adaptive_coding)
jsonobj$add_boolean("internal_propensity_model", object$model_params$internal_propensity_model)
jsonobj$add_scalar("num_gfr", object$model_params$num_gfr)
Expand Down Expand Up @@ -2305,6 +2352,7 @@ createBCFModelFromJson <- function(json_object){
model_params[["has_rfx_basis"]] <- json_object$get_boolean("has_rfx_basis")
model_params[["num_rfx_basis"]] <- json_object$get_scalar("num_rfx_basis")
model_params[["adaptive_coding"]] <- json_object$get_boolean("adaptive_coding")
model_params[["multivariate_treatment"]] <- json_object$get_boolean("multivariate_treatment")
model_params[["internal_propensity_model"]] <- json_object$get_boolean("internal_propensity_model")
model_params[["num_gfr"]] <- json_object$get_scalar("num_gfr")
model_params[["num_burnin"]] <- json_object$get_scalar("num_burnin")
Expand Down Expand Up @@ -2644,6 +2692,7 @@ createBCFModelFromCombinedJson <- function(json_object_list){
model_params[["num_chains"]] <- json_object_default$get_scalar("num_chains")
model_params[["keep_every"]] <- json_object_default$get_scalar("keep_every")
model_params[["adaptive_coding"]] <- json_object_default$get_boolean("adaptive_coding")
model_params[["multivariate_treatment"]] <- json_object_default$get_boolean("multivariate_treatment")
model_params[["internal_propensity_model"]] <- json_object_default$get_boolean("internal_propensity_model")
model_params[["probit_outcome_model"]] <- json_object_default$get_boolean("probit_outcome_model")

Expand Down Expand Up @@ -2870,6 +2919,7 @@ createBCFModelFromCombinedJsonString <- function(json_string_list){
model_params[["num_covariates"]] <- json_object_default$get_scalar("num_covariates")
model_params[["num_chains"]] <- json_object_default$get_scalar("num_chains")
model_params[["keep_every"]] <- json_object_default$get_scalar("keep_every")
model_params[["multivariate_treatment"]] <- json_object_default$get_boolean("multivariate_treatment")
model_params[["adaptive_coding"]] <- json_object_default$get_boolean("adaptive_coding")
model_params[["internal_propensity_model"]] <- json_object_default$get_boolean("internal_propensity_model")
model_params[["probit_outcome_model"]] <- json_object_default$get_boolean("probit_outcome_model")
Expand Down
7 changes: 4 additions & 3 deletions test/R/testthat/test-bcf.R
Original file line number Diff line number Diff line change
Expand Up @@ -419,13 +419,14 @@ test_that("Multivariate Treatment MCMC BCF", {
y_train <- y[train_inds]

# 1 chain, no thinning
general_param_list <- list(num_chains = 1, keep_every = 1)
expect_error(
general_param_list <- list(num_chains = 1, keep_every = 1, adaptive_coding = F)
expect_no_error({
bcf_model <- bcf(X_train = X_train, y_train = y_train, Z_train = Z_train,
propensity_train = pi_train, X_test = X_test, Z_test = Z_test,
propensity_test = pi_test, num_gfr = 0, num_burnin = 10,
num_mcmc = 10, general_params = general_param_list)
)
predict(bcf_model, X = X_test, Z = Z_test, propensity = pi_test)
})
})

test_that("BCF Predictions", {
Expand Down
56 changes: 56 additions & 0 deletions tools/debug/multivariate_bcf_debug.R
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
# Debug script: simulate data with a two-column binary treatment, fit a BCF
# model with `stochtree::bcf()`, and compare posterior-mean test predictions
# with the held-out outcomes.

# Load libraries
library(stochtree)

# Generate data ----
n <- 500
p <- 5
snr <- 2.0
X <- matrix(runif(n * p), ncol = p)
# One propensity column per treatment column
pi_x <- cbind(0.25 + 0.5 * X[, 1], 0.75 - 0.5 * X[, 2])
mu_x <- pi_x[, 1] * 5 + pi_x[, 2] * 2 + 2 * X[, 3]
# One treatment-effect column per treatment column
tau_x <- cbind(X[, 2], X[, 3])
Z <- matrix(NA_integer_, nrow = n, ncol = ncol(pi_x))
for (i in seq_len(ncol(pi_x))) {
  Z[, i] <- rbinom(n, 1, pi_x[, i])
}
# Conditional mean: prognostic term plus the sum of per-column treatment terms
E_XZ <- mu_x + rowSums(Z * tau_x)
# Noise standard deviation is sd(E_XZ) / snr
y <- E_XZ + rnorm(n, 0, 1) * (sd(E_XZ) / snr)

# Split data into test and train sets ----
test_set_pct <- 0.2
n_test <- round(test_set_pct * n)
n_train <- n - n_test
test_inds <- sort(sample(seq_len(n), n_test, replace = FALSE))
train_inds <- setdiff(seq_len(n), test_inds)
X_test <- X[test_inds, ]
X_train <- X[train_inds, ]
pi_test <- pi_x[test_inds, ]
pi_train <- pi_x[train_inds, ]
Z_test <- Z[test_inds, ]
Z_train <- Z[train_inds, ]
y_test <- y[test_inds]
y_train <- y[train_inds]
mu_test <- mu_x[test_inds]
mu_train <- mu_x[train_inds]
tau_test <- tau_x[test_inds, ]
tau_train <- tau_x[train_inds, ]

# Run BCF ----
num_gfr <- 10
num_burnin <- 0
num_mcmc <- 100
# Adaptive coding and leaf scale sampling are switched off explicitly.
# Spell out FALSE: `F` is an ordinary variable and can be reassigned.
general_params <- list(adaptive_coding = FALSE)
prognostic_forest_params <- list(sample_sigma2_leaf = FALSE)
treatment_effect_forest_params <- list(sample_sigma2_leaf = FALSE)
bcf_model <- bcf(
  X_train = X_train, Z_train = Z_train, y_train = y_train,
  propensity_train = pi_train,
  X_test = X_test, Z_test = Z_test, propensity_test = pi_test,
  num_gfr = num_gfr, num_burnin = num_burnin, num_mcmc = num_mcmc,
  general_params = general_params,
  prognostic_forest_params = prognostic_forest_params,
  treatment_effect_forest_params = treatment_effect_forest_params
)

# Check results ----
# Average test-set predictions over the retained samples (columns)
y_hat_test_mean <- rowMeans(bcf_model$y_hat_test)
plot(y_hat_test_mean, y_test)
abline(0, 1, col = "red")
# Test-set RMSE of the averaged prediction
sqrt(mean((y_hat_test_mean - y_test)^2))
Loading