diff --git a/R/bart.R b/R/bart.R index ba33ae84..c7124a93 100644 --- a/R/bart.R +++ b/R/bart.R @@ -484,14 +484,26 @@ bart <- function(X_train, y_train, leaf_basis_train = NULL, rfx_group_ids_train if (has_basis) { if (ncol(leaf_basis_train) > 1) { if (is.null(sigma_leaf_init)) sigma_leaf_init <- diag(var(resid_train)/(num_trees_mean), ncol(leaf_basis_train)) - current_leaf_scale <- sigma_leaf_init + if (!is.matrix(sigma_leaf_init)) { + current_leaf_scale <- as.matrix(diag(sigma_leaf_init, ncol(leaf_basis_train))) + } else { + current_leaf_scale <- sigma_leaf_init + } } else { - if (is.null(sigma_leaf_init)) sigma_leaf_init <- var(resid_train)/(num_trees_mean) - current_leaf_scale <- as.matrix(sigma_leaf_init) + if (is.null(sigma_leaf_init)) sigma_leaf_init <- as.matrix(var(resid_train)/(num_trees_mean)) + if (!is.matrix(sigma_leaf_init)) { + current_leaf_scale <- as.matrix(diag(sigma_leaf_init, 1)) + } else { + current_leaf_scale <- sigma_leaf_init + } } } else { - if (is.null(sigma_leaf_init)) sigma_leaf_init <- var(resid_train)/(num_trees_mean) - current_leaf_scale <- as.matrix(sigma_leaf_init) + if (is.null(sigma_leaf_init)) sigma_leaf_init <- as.matrix(var(resid_train)/(num_trees_mean)) + if (!is.matrix(sigma_leaf_init)) { + current_leaf_scale <- as.matrix(diag(sigma_leaf_init, 1)) + } else { + current_leaf_scale <- sigma_leaf_init + } } current_sigma2 <- sigma2_init @@ -522,7 +534,8 @@ bart <- function(X_train, y_train, leaf_basis_train = NULL, rfx_group_ids_train is_leaf_constant = F leaf_regression = T if (sample_sigma_leaf) { - stop("Sampling leaf scale not yet supported for multivariate leaf models") + warning("Sampling leaf scale not yet supported for multivariate leaf models, so the leaf scale parameter will not be sampled in this model.") + sample_sigma_leaf <- F } } diff --git a/R/bcf.R b/R/bcf.R index 9118333c..abef39be 100644 --- a/R/bcf.R +++ b/R/bcf.R @@ -546,6 +546,9 @@ bcf <- function(X_train, Z_train, y_train, propensity_train = NULL, rfx_group_id } } } + + # Stop if multivariate treatment is provided + if (ncol(Z_train) > 1) stop("Multivariate treatments are not currently supported") # Random effects covariance prior if (has_rfx) { @@ -650,20 +653,20 @@ bcf <- function(X_train, Z_train, y_train, propensity_train = NULL, rfx_group_id # Update feature_types and covariates feature_types <- as.integer(feature_types) if (propensity_covariate != "none") { - feature_types <- as.integer(c(feature_types,0)) + feature_types <- as.integer(c(feature_types,rep(0, ncol(propensity_train)))) X_train <- cbind(X_train, propensity_train) if (propensity_covariate == "mu") { variable_weights_mu <- c(variable_weights_mu, rep(1./num_cov_orig, ncol(propensity_train))) - variable_weights_tau <- c(variable_weights_tau, 0) - if (include_variance_forest) variable_weights_variance <- c(variable_weights_variance, 0) + variable_weights_tau <- c(variable_weights_tau, rep(0, ncol(propensity_train))) + if (include_variance_forest) variable_weights_variance <- c(variable_weights_variance, rep(0, ncol(propensity_train))) } else if (propensity_covariate == "tau") { - variable_weights_mu <- c(variable_weights_mu, 0) + variable_weights_mu <- c(variable_weights_mu, rep(0, ncol(propensity_train))) variable_weights_tau <- c(variable_weights_tau, rep(1./num_cov_orig, ncol(propensity_train))) - if (include_variance_forest) variable_weights_variance <- c(variable_weights_variance, 0) + if (include_variance_forest) variable_weights_variance <- c(variable_weights_variance, rep(0, ncol(propensity_train))) } else if (propensity_covariate == "both") { variable_weights_mu <- c(variable_weights_mu, rep(1./num_cov_orig, ncol(propensity_train))) variable_weights_tau <- c(variable_weights_tau, rep(1./num_cov_orig, ncol(propensity_train))) - if (include_variance_forest) variable_weights_variance <- c(variable_weights_variance, 0) + if (include_variance_forest) variable_weights_variance <- c(variable_weights_variance, rep(0, ncol(propensity_train))) } if (has_test) X_test <- cbind(X_test, propensity_test) } @@ -690,11 +693,37 @@ bcf <- function(X_train, Z_train, y_train, propensity_train = NULL, rfx_group_id if (is.null(variance_forest_init)) variance_forest_init <- 1.0*var(resid_train) if (is.null(b_leaf_mu)) b_leaf_mu <- var(resid_train)/(num_trees_mu) if (is.null(b_leaf_tau)) b_leaf_tau <- var(resid_train)/(2*num_trees_tau) - if (is.null(sigma_leaf_mu)) sigma_leaf_mu <- var(resid_train)/(num_trees_mu) - if (is.null(sigma_leaf_tau)) sigma_leaf_tau <- var(resid_train)/(2*num_trees_tau) + if (is.null(sigma_leaf_mu)) { + sigma_leaf_mu <- var(resid_train)/(num_trees_mu) + current_leaf_scale_mu <- as.matrix(sigma_leaf_mu) + } else { + if (!is.matrix(sigma_leaf_mu)) { + current_leaf_scale_mu <- as.matrix(sigma_leaf_mu) + } else { + current_leaf_scale_mu <- sigma_leaf_mu + } + } + if (is.null(sigma_leaf_tau)) { + sigma_leaf_tau <- var(resid_train)/(2*num_trees_tau) + current_leaf_scale_tau <- as.matrix(diag(sigma_leaf_tau, ncol(Z_train))) + } else { + if (!is.matrix(sigma_leaf_tau)) { + current_leaf_scale_tau <- as.matrix(diag(sigma_leaf_tau, ncol(Z_train))) + } else { + if (ncol(sigma_leaf_tau) != ncol(Z_train)) stop("sigma_leaf_init for the tau forest must have the same number of columns / rows as columns in the Z_train matrix") + if (nrow(sigma_leaf_tau) != ncol(Z_train)) stop("sigma_leaf_init for the tau forest must have the same number of columns / rows as columns in the Z_train matrix") + current_leaf_scale_tau <- sigma_leaf_tau + } + } current_sigma2 <- sigma2_init - current_leaf_scale_mu <- as.matrix(sigma_leaf_mu) - current_leaf_scale_tau <- as.matrix(sigma_leaf_tau) + + # Switch off leaf scale sampling for multivariate treatments + if (ncol(Z_train) > 1) { + if (sample_sigma_leaf_tau) { + warning("Sampling leaf scale not yet supported for multivariate leaf models, so the leaf scale parameter will not be sampled for the treatment forest in this model.") + sample_sigma_leaf_tau <- F + } + } # Set mu and tau leaf models / dimensions leaf_model_mu_forest <- 0 diff --git a/test/R/testthat/test-bart.R b/test/R/testthat/test-bart.R index 9a885b91..ac03b9b4 100644 --- a/test/R/testthat/test-bart.R +++ b/test/R/testthat/test-bart.R @@ -54,6 +54,65 @@ test_that("MCMC BART", { num_gfr = 0, num_burnin = 10, num_mcmc = 10, general_params = general_param_list) ) + + # Generate simulated data with a leaf basis + n <- 100 + p <- 5 + p_w <- 2 + X <- matrix(runif(n*p), ncol = p) + W <- matrix(runif(n*p_w), ncol = p_w) + f_XW <- ( + ((0 <= X[,1]) & (0.25 > X[,1])) * (-7.5*W[,1]) + + ((0.25 <= X[,1]) & (0.5 > X[,1])) * (-2.5*W[,1]) + + ((0.5 <= X[,1]) & (0.75 > X[,1])) * (2.5*W[,1]) + + ((0.75 <= X[,1]) & (1 > X[,1])) * (7.5*W[,1]) + ) + noise_sd <- 1 + y <- f_XW + rnorm(n, 0, noise_sd) + test_set_pct <- 0.2 + n_test <- round(test_set_pct*n) + n_train <- n - n_test + test_inds <- sort(sample(1:n, n_test, replace = FALSE)) + train_inds <- (1:n)[!((1:n) %in% test_inds)] + X_test <- X[test_inds,] + X_train <- X[train_inds,] + W_test <- W[test_inds,] + W_train <- W[train_inds,] + y_test <- y[test_inds] + y_train <- y[train_inds] + + # 3 chains, thinning, leaf regression + general_param_list <- list(num_chains = 3, keep_every = 5) + mean_forest_param_list <- list(sample_sigma2_leaf = F) + expect_no_error( + bart_model <- bart(X_train = X_train, y_train = y_train, X_test = X_test, + leaf_basis_train = W_train, leaf_basis_test = W_test, + num_gfr = 0, num_burnin = 10, num_mcmc = 10, + general_params = general_param_list, + mean_forest_params = mean_forest_param_list) + ) + + # 3 chains, thinning, leaf regression with a scalar leaf scale + general_param_list <- list(num_chains = 3, keep_every = 5) + mean_forest_param_list <- list(sample_sigma2_leaf = F, sigma2_leaf_init = 0.5) + expect_no_error( + bart_model <- bart(X_train = X_train, y_train = y_train, X_test = X_test, + leaf_basis_train = W_train, leaf_basis_test = W_test, + num_gfr = 0, num_burnin = 10, num_mcmc = 10, + general_params = general_param_list, + mean_forest_params = mean_forest_param_list) + ) + + # 3 chains, thinning, leaf regression with a scalar leaf scale, random leaf scale + general_param_list <- list(num_chains = 3, keep_every = 5) + mean_forest_param_list <- list(sample_sigma2_leaf = T, sigma2_leaf_init = 0.5) + expect_warning( + bart_model <- bart(X_train = X_train, y_train = y_train, X_test = X_test, + leaf_basis_train = W_train, leaf_basis_test = W_test, + num_gfr = 0, num_burnin = 10, num_mcmc = 10, + general_params = general_param_list, + mean_forest_params = mean_forest_param_list) + ) }) test_that("GFR BART", { diff --git a/test/R/testthat/test-bcf.R b/test/R/testthat/test-bcf.R index 0a34c37c..f78106ac 100644 --- a/test/R/testthat/test-bcf.R +++ b/test/R/testthat/test-bcf.R @@ -35,8 +35,8 @@ test_that("MCMC BCF", { X_train <- X[train_inds,] Z_test <- Z[test_inds] Z_train <- Z[train_inds] - pi_test <- pi[test_inds] - pi_train <- pi[train_inds] + pi_test <- pi_X[test_inds] + pi_train <- pi_X[train_inds] mu_test <- mu_X[test_inds] mu_train <- mu_X[train_inds] tau_test <- tau_X[test_inds] @@ -53,6 +53,32 @@ test_that("MCMC BCF", { num_mcmc = 10, general_params = general_param_list) ) + # 1 chain, no thinning, matrix leaf scale parameter provided + general_param_list <- list(num_chains = 1, keep_every = 1) + mu_forest_param_list <- list(sigma2_leaf_init = as.matrix(0.5)) + tau_forest_param_list <- list(sigma2_leaf_init = as.matrix(0.5)) + expect_no_error( + bcf_model <- bcf(X_train = X_train, y_train = y_train, Z_train = Z_train, + propensity_train = pi_train, X_test = X_test, Z_test = Z_test, + propensity_test = pi_test, num_gfr = 0, num_burnin = 10, + num_mcmc = 10, general_params = general_param_list, + mu_forest_params = mu_forest_param_list, + tau_forest_params = tau_forest_param_list) + ) + + # 1 chain, no thinning, scalar leaf scale parameter provided + general_param_list <- list(num_chains = 1, keep_every = 1) + mu_forest_param_list <- list(sigma2_leaf_init = 0.5) + tau_forest_param_list <- list(sigma2_leaf_init = 0.5) + expect_no_error( + bcf_model <- bcf(X_train = X_train, y_train = y_train, Z_train = Z_train, + propensity_train = pi_train, X_test = X_test, Z_test = Z_test, + propensity_test = pi_test, num_gfr = 0, num_burnin = 10, + num_mcmc = 10, general_params = general_param_list, + mu_forest_params = mu_forest_param_list, + tau_forest_params = tau_forest_param_list) + ) + # 3 chains, no thinning general_param_list <- list(num_chains = 3, keep_every = 1) expect_no_error( @@ -118,8 +144,8 @@ test_that("GFR BCF", { X_train <- X[train_inds,] Z_test <- Z[test_inds] Z_train <- Z[train_inds] - pi_test <- pi[test_inds] - pi_train <- pi[train_inds] + pi_test <- pi_X[test_inds] + pi_train <- pi_X[train_inds] mu_test <- mu_X[test_inds] mu_train <- mu_X[train_inds] tau_test <- tau_X[test_inds] @@ -219,8 +245,8 @@ test_that("Warmstart BCF", { X_train <- X[train_inds,] Z_test <- Z[test_inds] Z_train <- Z[train_inds] - pi_test <- pi[test_inds] - pi_train <- pi[train_inds] + pi_test <- pi_X[test_inds] + pi_train <- pi_X[train_inds] mu_test <- mu_X[test_inds] mu_train <- mu_X[train_inds] tau_test <- tau_X[test_inds] @@ -287,8 +313,8 @@ test_that("Warmstart BCF", { X_train <- X[train_inds,] Z_test <- Z[test_inds] Z_train <- Z[train_inds] - pi_test <- pi[test_inds] - pi_train <- pi[train_inds] + pi_test <- pi_X[test_inds] + pi_train <- pi_X[train_inds] mu_test <- mu_X[test_inds] mu_train <- mu_X[train_inds] tau_test <- tau_X[test_inds] @@ -329,3 +355,75 @@ test_that("Warmstart BCF", { general_params = general_param_list) ) }) + +test_that("Multivariate Treatment MCMC BCF", { + skip_on_cran() + + # Generate simulated data + n <- 100 + p <- 5 + X <- matrix(runif(n*p), ncol = p) + mu_X <- ( + ((0 <= X[,1]) & (0.25 > X[,1])) * (-7.5) + + ((0.25 <= X[,1]) & (0.5 > X[,1])) * (-2.5) + + ((0.5 <= X[,1]) & (0.75 > X[,1])) * (2.5) + + ((0.75 <= X[,1]) & (1 > X[,1])) * (7.5) + ) + pi_X_1 <- ( + ((0 <= X[,1]) & (0.25 > X[,1])) * (0.2) + + ((0.25 <= X[,1]) & (0.5 > X[,1])) * (0.4) + + ((0.5 <= X[,1]) & (0.75 > X[,1])) * (0.6) + + ((0.75 <= X[,1]) & (1 > X[,1])) * (0.8) + ) + pi_X_2 <- ( + ((0 <= X[,2]) & (0.25 > X[,2])) * (0.8) + + ((0.25 <= X[,2]) & (0.5 > X[,2])) * (0.4) + + ((0.5 <= X[,2]) & (0.75 > X[,2])) * (0.6) + + ((0.75 <= X[,2]) & (1 > X[,2])) * (0.2) + ) + pi_X <- cbind(pi_X_1, pi_X_2) + tau_X_1 <- ( + ((0 <= X[,2]) & (0.25 > X[,2])) * (0.5) + + ((0.25 <= X[,2]) & (0.5 > X[,2])) * (1.0) + + ((0.5 <= X[,2]) & (0.75 > X[,2])) * (1.5) + + ((0.75 <= X[,2]) & (1 > X[,2])) * (2.0) + ) + tau_X_2 <- ( + ((0 <= X[,3]) & (0.25 > X[,3])) * (-0.5) + + ((0.25 <= X[,3]) & (0.5 > X[,3])) * (-1.5) + + ((0.5 <= X[,3]) & (0.75 > X[,3])) * (-1.0) + + ((0.75 <= X[,3]) & (1 > X[,3])) * (0.0) + ) + tau_X <- cbind(tau_X_1, tau_X_2) + Z_1 <- as.numeric(rbinom(n, 1, pi_X_1)) + Z_2 <- as.numeric(rbinom(n, 1, pi_X_2)) + Z <- cbind(Z_1, Z_2) + noise_sd <- 1 + y <- mu_X + rowSums(tau_X*Z) + rnorm(n, 0, noise_sd) + test_set_pct <- 0.2 + n_test <- round(test_set_pct*n) + n_train <- n - n_test + test_inds <- sort(sample(1:n, n_test, replace = FALSE)) + train_inds <- (1:n)[!((1:n) %in% test_inds)] + X_test <- X[test_inds,] + X_train <- X[train_inds,] + Z_test <- Z[test_inds,] + Z_train <- Z[train_inds,] + pi_test <- pi_X[test_inds,] + pi_train <- pi_X[train_inds,] + mu_test <- mu_X[test_inds] + mu_train <- mu_X[train_inds] + tau_test <- tau_X[test_inds,] + tau_train <- tau_X[train_inds,] + y_test <- y[test_inds] + y_train <- y[train_inds] + + # 1 chain, no thinning + general_param_list <- list(num_chains = 1, keep_every = 1) + expect_error( + bcf_model <- bcf(X_train = X_train, y_train = y_train, Z_train = Z_train, + propensity_train = pi_train, X_test = X_test, Z_test = Z_test, + propensity_test = pi_test, num_gfr = 0, num_burnin = 10, + num_mcmc = 10, general_params = general_param_list) + ) +}) \ No newline at end of file