Skip to content

Commit 0ded1e8

Browse files
authored
Merge pull request #183 from StochasticTree/multivariate-bcf-hotfix
Making multivariate treatment BCF work in R
2 parents adf891c + 03c291e commit 0ded1e8

File tree

3 files changed

+135
-28
lines changed

3 files changed

+135
-28
lines changed

R/bcf.R

Lines changed: 75 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -504,14 +504,13 @@ bcf <- function(X_train, Z_train, y_train, propensity_train = NULL, rfx_group_id
504504
if (!is.null(X_test)) X_test <- preprocessPredictionData(X_test, X_train_metadata)
505505

506506
# Convert all input data to matrices if not already converted
507-
if ((is.null(dim(Z_train))) && (!is.null(Z_train))) {
508-
Z_train <- as.matrix(as.numeric(Z_train))
509-
}
507+
Z_col <- ifelse(is.null(dim(Z_train)), 1, ncol(Z_train))
508+
Z_train <- matrix(as.numeric(Z_train), ncol = Z_col)
510509
if ((is.null(dim(propensity_train))) && (!is.null(propensity_train))) {
511510
propensity_train <- as.matrix(propensity_train)
512511
}
513-
if ((is.null(dim(Z_test))) && (!is.null(Z_test))) {
514-
Z_test <- as.matrix(as.numeric(Z_test))
512+
if (!is.null(Z_test)) {
513+
Z_test <- matrix(as.numeric(Z_test), ncol = Z_col)
515514
}
516515
if ((is.null(dim(propensity_test))) && (!is.null(propensity_test))) {
517516
propensity_test <- as.matrix(propensity_test)
@@ -580,9 +579,30 @@ bcf <- function(X_train, Z_train, y_train, propensity_train = NULL, rfx_group_id
580579
}
581580
}
582581

583-
# Stop if multivariate treatment is provided
584-
if (ncol(Z_train) > 1) stop("Multivariate treatments are not currently supported")
585-
582+
# # Stop if multivariate treatment is provided
583+
# if (ncol(Z_train) > 1) stop("Multivariate treatments are not currently supported")
584+
585+
# Handle multivariate treatment
586+
has_multivariate_treatment <- ncol(Z_train) > 1
587+
if (has_multivariate_treatment) {
588+
# Disable adaptive coding, internal propensity model, and
589+
# leaf scale sampling if treatment is multivariate
590+
if (adaptive_coding) {
591+
warning("Adaptive coding is incompatible with multivariate treatment and will be ignored")
592+
adaptive_coding <- FALSE
593+
}
594+
if (is.null(propensity_train)) {
595+
if (propensity_covariate != "none") {
596+
warning("No propensities were provided for the multivariate treatment; an internal propensity model will not be fitted to the multivariate treatment and propensity_covariate will be set to 'none'")
597+
propensity_covariate <- "none"
598+
}
599+
}
600+
if (sample_sigma2_leaf_tau) {
601+
warning("Sampling leaf scale not yet supported for multivariate leaf models, so the leaf scale parameter will not be sampled for the treatment forest in this model.")
602+
sample_sigma2_leaf_tau <- FALSE
603+
}
604+
}
605+
586606
# Random effects covariance prior
587607
if (has_rfx) {
588608
if (is.null(rfx_prior_var)) {
@@ -835,18 +855,10 @@ bcf <- function(X_train, Z_train, y_train, propensity_train = NULL, rfx_group_id
835855
current_sigma2 <- sigma2_init
836856
}
837857

838-
# Switch off leaf scale sampling for multivariate treatments
839-
if (ncol(Z_train) > 1) {
840-
if (sample_sigma2_leaf_tau) {
841-
warning("Sampling leaf scale not yet supported for multivariate leaf models, so the leaf scale parameter will not be sampled for the treatment forest in this model.")
842-
sample_sigma2_leaf_tau <- FALSE
843-
}
844-
}
845-
846858
# Set mu and tau leaf models / dimensions
847859
leaf_model_mu_forest <- 0
848860
leaf_dimension_mu_forest <- 1
849-
if (ncol(Z_train) > 1) {
861+
if (has_multivariate_treatment) {
850862
leaf_model_tau_forest <- 2
851863
leaf_dimension_tau_forest <- ncol(Z_train)
852864
} else {
@@ -973,21 +985,21 @@ bcf <- function(X_train, Z_train, y_train, propensity_train = NULL, rfx_group_id
973985

974986
# Container of forest samples
975987
forest_samples_mu <- createForestSamples(num_trees_mu, 1, TRUE)
976-
forest_samples_tau <- createForestSamples(num_trees_tau, 1, FALSE)
988+
forest_samples_tau <- createForestSamples(num_trees_tau, ncol(Z_train), FALSE)
977989
active_forest_mu <- createForest(num_trees_mu, 1, TRUE)
978-
active_forest_tau <- createForest(num_trees_tau, 1, FALSE)
990+
active_forest_tau <- createForest(num_trees_tau, ncol(Z_train), FALSE)
979991
if (include_variance_forest) {
980992
forest_samples_variance <- createForestSamples(num_trees_variance, 1, TRUE, TRUE)
981993
active_forest_variance <- createForest(num_trees_variance, 1, TRUE, TRUE)
982994
}
983995

984996
# Initialize the leaves of each tree in the prognostic forest
985-
active_forest_mu$prepare_for_sampler(forest_dataset_train, outcome_train, forest_model_mu, 0, init_mu)
997+
active_forest_mu$prepare_for_sampler(forest_dataset_train, outcome_train, forest_model_mu, leaf_model_mu_forest, init_mu)
986998
active_forest_mu$adjust_residual(forest_dataset_train, outcome_train, forest_model_mu, FALSE, FALSE)
987999

9881000
# Initialize the leaves of each tree in the treatment effect forest
989-
init_tau <- 0.
990-
active_forest_tau$prepare_for_sampler(forest_dataset_train, outcome_train, forest_model_tau, 1, init_tau)
1001+
init_tau <- rep(0., ncol(Z_train))
1002+
active_forest_tau$prepare_for_sampler(forest_dataset_train, outcome_train, forest_model_tau, leaf_model_tau_forest, init_tau)
9911003
active_forest_tau$adjust_residual(forest_dataset_train, outcome_train, forest_model_tau, TRUE, FALSE)
9921004

9931005
# Initialize the leaves of each tree in the variance forest
@@ -1450,7 +1462,18 @@ bcf <- function(X_train, Z_train, y_train, propensity_train = NULL, rfx_group_id
14501462
} else {
14511463
tau_hat_train <- forest_samples_tau$predict_raw(forest_dataset_train)*y_std_train
14521464
}
1453-
y_hat_train <- mu_hat_train + tau_hat_train * as.numeric(Z_train)
1465+
if (has_multivariate_treatment) {
1466+
tau_train_dim <- dim(tau_hat_train)
1467+
tau_num_obs <- tau_train_dim[1]
1468+
tau_num_samples <- tau_train_dim[3]
1469+
treatment_term_train <- matrix(NA_real_, nrow = tau_num_obs, tau_num_samples)
1470+
for (i in 1:nrow(Z_train)) {
1471+
treatment_term_train[i,] <- colSums(tau_hat_train[i,,] * Z_train[i,])
1472+
}
1473+
} else {
1474+
treatment_term_train <- tau_hat_train * as.numeric(Z_train)
1475+
}
1476+
y_hat_train <- mu_hat_train + treatment_term_train
14541477
if (has_test) {
14551478
mu_hat_test <- forest_samples_mu$predict(forest_dataset_test)*y_std_train + y_bar_train
14561479
if (adaptive_coding) {
@@ -1459,7 +1482,18 @@ bcf <- function(X_train, Z_train, y_train, propensity_train = NULL, rfx_group_id
14591482
} else {
14601483
tau_hat_test <- forest_samples_tau$predict_raw(forest_dataset_test)*y_std_train
14611484
}
1462-
y_hat_test <- mu_hat_test + tau_hat_test * as.numeric(Z_test)
1485+
if (has_multivariate_treatment) {
1486+
tau_test_dim <- dim(tau_hat_test)
1487+
tau_num_obs <- tau_test_dim[1]
1488+
tau_num_samples <- tau_test_dim[3]
1489+
treatment_term_test <- matrix(NA_real_, nrow = tau_num_obs, tau_num_samples)
1490+
for (i in 1:nrow(Z_test)) {
1491+
treatment_term_test[i,] <- colSums(tau_hat_test[i,,] * Z_test[i,])
1492+
}
1493+
} else {
1494+
treatment_term_test <- tau_hat_test * as.numeric(Z_test)
1495+
}
1496+
y_hat_test <- mu_hat_test + treatment_term_test
14631497
}
14641498
if (include_variance_forest) {
14651499
sigma2_x_hat_train <- exp(sigma2_x_train_raw)
@@ -1526,6 +1560,7 @@ bcf <- function(X_train, Z_train, y_train, propensity_train = NULL, rfx_group_id
15261560
"treatment_dim" = ncol(Z_train),
15271561
"propensity_covariate" = propensity_covariate,
15281562
"binary_treatment" = binary_treatment,
1563+
"multivariate_treatment" = has_multivariate_treatment,
15291564
"adaptive_coding" = adaptive_coding,
15301565
"internal_propensity_model" = internal_propensity_model,
15311566
"num_samples" = num_retained_samples,
@@ -1722,6 +1757,17 @@ predict.bcfmodel <- function(object, X, Z, propensity = NULL, rfx_group_ids = NU
17221757
} else {
17231758
tau_hat <- object$forests_tau$predict_raw(forest_dataset_pred)*y_std
17241759
}
1760+
if (object$model_params$multivariate_treatment) {
1761+
tau_dim <- dim(tau_hat)
1762+
tau_num_obs <- tau_dim[1]
1763+
tau_num_samples <- tau_dim[3]
1764+
treatment_term <- matrix(NA_real_, nrow = tau_num_obs, tau_num_samples)
1765+
for (i in 1:nrow(Z)) {
1766+
treatment_term[i,] <- colSums(tau_hat[i,,] * Z[i,])
1767+
}
1768+
} else {
1769+
treatment_term <- tau_hat * as.numeric(Z)
1770+
}
17251771
if (object$model_params$include_variance_forest) {
17261772
s_x_raw <- object$forests_variance$predict(forest_dataset_pred)
17271773
}
@@ -1732,7 +1778,7 @@ predict.bcfmodel <- function(object, X, Z, propensity = NULL, rfx_group_ids = NU
17321778
}
17331779

17341780
# Compute overall "y_hat" predictions
1735-
y_hat <- mu_hat + tau_hat * as.numeric(Z)
1781+
y_hat <- mu_hat + treatment_term
17361782
if (object$model_params$has_rfx) y_hat <- y_hat + rfx_predictions
17371783

17381784
# Scale variance forest predictions
@@ -1974,6 +2020,7 @@ saveBCFModelToJson <- function(object){
19742020
jsonobj$add_boolean("has_rfx", object$model_params$has_rfx)
19752021
jsonobj$add_boolean("has_rfx_basis", object$model_params$has_rfx_basis)
19762022
jsonobj$add_scalar("num_rfx_basis", object$model_params$num_rfx_basis)
2023+
jsonobj$add_boolean("multivariate_treatment", object$model_params$multivariate_treatment)
19772024
jsonobj$add_boolean("adaptive_coding", object$model_params$adaptive_coding)
19782025
jsonobj$add_boolean("internal_propensity_model", object$model_params$internal_propensity_model)
19792026
jsonobj$add_scalar("num_gfr", object$model_params$num_gfr)
@@ -2305,6 +2352,7 @@ createBCFModelFromJson <- function(json_object){
23052352
model_params[["has_rfx_basis"]] <- json_object$get_boolean("has_rfx_basis")
23062353
model_params[["num_rfx_basis"]] <- json_object$get_scalar("num_rfx_basis")
23072354
model_params[["adaptive_coding"]] <- json_object$get_boolean("adaptive_coding")
2355+
model_params[["multivariate_treatment"]] <- json_object$get_boolean("multivariate_treatment")
23082356
model_params[["internal_propensity_model"]] <- json_object$get_boolean("internal_propensity_model")
23092357
model_params[["num_gfr"]] <- json_object$get_scalar("num_gfr")
23102358
model_params[["num_burnin"]] <- json_object$get_scalar("num_burnin")
@@ -2644,6 +2692,7 @@ createBCFModelFromCombinedJson <- function(json_object_list){
26442692
model_params[["num_chains"]] <- json_object_default$get_scalar("num_chains")
26452693
model_params[["keep_every"]] <- json_object_default$get_scalar("keep_every")
26462694
model_params[["adaptive_coding"]] <- json_object_default$get_boolean("adaptive_coding")
2695+
model_params[["multivariate_treatment"]] <- json_object_default$get_boolean("multivariate_treatment")
26472696
model_params[["internal_propensity_model"]] <- json_object_default$get_boolean("internal_propensity_model")
26482697
model_params[["probit_outcome_model"]] <- json_object_default$get_boolean("probit_outcome_model")
26492698

@@ -2870,6 +2919,7 @@ createBCFModelFromCombinedJsonString <- function(json_string_list){
28702919
model_params[["num_covariates"]] <- json_object_default$get_scalar("num_covariates")
28712920
model_params[["num_chains"]] <- json_object_default$get_scalar("num_chains")
28722921
model_params[["keep_every"]] <- json_object_default$get_scalar("keep_every")
2922+
model_params[["multivariate_treatment"]] <- json_object_default$get_boolean("multivariate_treatment")
28732923
model_params[["adaptive_coding"]] <- json_object_default$get_boolean("adaptive_coding")
28742924
model_params[["internal_propensity_model"]] <- json_object_default$get_boolean("internal_propensity_model")
28752925
model_params[["probit_outcome_model"]] <- json_object_default$get_boolean("probit_outcome_model")

test/R/testthat/test-bcf.R

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -419,13 +419,14 @@ test_that("Multivariate Treatment MCMC BCF", {
419419
y_train <- y[train_inds]
420420

421421
# 1 chain, no thinning
422-
general_param_list <- list(num_chains = 1, keep_every = 1)
423-
expect_error(
422+
general_param_list <- list(num_chains = 1, keep_every = 1, adaptive_coding = F)
423+
expect_no_error({
424424
bcf_model <- bcf(X_train = X_train, y_train = y_train, Z_train = Z_train,
425425
propensity_train = pi_train, X_test = X_test, Z_test = Z_test,
426426
propensity_test = pi_test, num_gfr = 0, num_burnin = 10,
427427
num_mcmc = 10, general_params = general_param_list)
428-
)
428+
predict(bcf_model, X = X_test, Z = Z_test, propensity = pi_test)
429+
})
429430
})
430431

431432
test_that("BCF Predictions", {
Lines changed: 56 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,56 @@
# Debug script: fit BCF with a bivariate binary treatment on simulated data
# and check out-of-sample predictions of the outcome.

# Load libraries
library(stochtree)

# Generate data -------------------------------------------------------------
n <- 500
p <- 5
snr <- 2.0
X <- matrix(runif(n * p), ncol = p)

# One propensity column per treatment, each depending on a single covariate
pi_x <- cbind(0.25 + 0.5 * X[, 1], 0.75 - 0.5 * X[, 2])
num_treatments <- ncol(pi_x)

# Prognostic function and per-treatment effect functions
mu_x <- pi_x[, 1] * 5 + pi_x[, 2] * 2 + 2 * X[, 3]
tau_x <- cbind(X[, 2], X[, 3])

# Draw each binary treatment column from its own propensity column
Z <- matrix(NA_integer_, nrow = n, ncol = num_treatments)
for (j in seq_len(num_treatments)) {
  Z[, j] <- rbinom(n, 1, pi_x[, j])
}

# Conditional mean sums the effects of the active treatments; the noise
# standard deviation is sd(E_XZ) / snr
E_XZ <- mu_x + rowSums(Z * tau_x)
y <- E_XZ + rnorm(n, 0, 1) * (sd(E_XZ) / snr)

# Split data into test and train sets ---------------------------------------
test_set_pct <- 0.2
n_test <- round(test_set_pct * n)
n_train <- n - n_test
test_inds <- sort(sample(seq_len(n), n_test, replace = FALSE))
train_inds <- setdiff(seq_len(n), test_inds)

# drop = FALSE keeps every split a matrix even if it has a single row
X_test <- X[test_inds, , drop = FALSE]
X_train <- X[train_inds, , drop = FALSE]
pi_test <- pi_x[test_inds, , drop = FALSE]
pi_train <- pi_x[train_inds, , drop = FALSE]
Z_test <- Z[test_inds, , drop = FALSE]
Z_train <- Z[train_inds, , drop = FALSE]
y_test <- y[test_inds]
y_train <- y[train_inds]
mu_test <- mu_x[test_inds]
mu_train <- mu_x[train_inds]
tau_test <- tau_x[test_inds, , drop = FALSE]
tau_train <- tau_x[train_inds, , drop = FALSE]

# Run BCF -------------------------------------------------------------------
num_gfr <- 10
num_burnin <- 0
num_mcmc <- 100

# Adaptive coding is switched off explicitly so that bcf() does not have to
# warn and disable it for the multivariate treatment; leaf scale sampling is
# likewise switched off for both forests.
general_params <- list(adaptive_coding = FALSE)
prognostic_forest_params <- list(sample_sigma2_leaf = FALSE)
treatment_effect_forest_params <- list(sample_sigma2_leaf = FALSE)

bcf_model <- bcf(
  X_train = X_train, Z_train = Z_train, y_train = y_train,
  propensity_train = pi_train,
  X_test = X_test, Z_test = Z_test, propensity_test = pi_test,
  num_gfr = num_gfr, num_burnin = num_burnin, num_mcmc = num_mcmc,
  general_params = general_params,
  prognostic_forest_params = prognostic_forest_params,
  treatment_effect_forest_params = treatment_effect_forest_params
)

# Check results -------------------------------------------------------------
# Average the test-set predictions across retained samples
y_hat_test_mean <- rowMeans(bcf_model$y_hat_test)
plot(y_hat_test_mean, y_test)
abline(0, 1, col = "red")

# Print explicitly so the RMSE is also shown when the script is source()'d
rmse_test <- sqrt(mean((y_hat_test_mean - y_test)^2))
print(rmse_test)

0 commit comments

Comments
 (0)