@@ -35,8 +35,8 @@ test_that("MCMC BCF", {
3535 X_train <- X [train_inds ,]
3636 Z_test <- Z [test_inds ]
3737 Z_train <- Z [train_inds ]
38- pi_test <- pi [test_inds ]
39- pi_train <- pi [train_inds ]
38+ pi_test <- pi_X [test_inds ]
39+ pi_train <- pi_X [train_inds ]
4040 mu_test <- mu_X [test_inds ]
4141 mu_train <- mu_X [train_inds ]
4242 tau_test <- tau_X [test_inds ]
@@ -53,6 +53,32 @@ test_that("MCMC BCF", {
5353 num_mcmc = 10 , general_params = general_param_list )
5454 )
5555
56+ # 1 chain, no thinning, matrix leaf scale parameter provided
57+ general_param_list <- list (num_chains = 1 , keep_every = 1 )
58+ mu_forest_param_list <- list (sigma2_leaf_init = as.matrix(0.5 ))
59+ tau_forest_param_list <- list (sigma2_leaf_init = as.matrix(0.5 ))
60+ expect_no_error(
61+ bcf_model <- bcf(X_train = X_train , y_train = y_train , Z_train = Z_train ,
62+ propensity_train = pi_train , X_test = X_test , Z_test = Z_test ,
63+ propensity_test = pi_test , num_gfr = 0 , num_burnin = 10 ,
64+ num_mcmc = 10 , general_params = general_param_list ,
65+ mu_forest_params = mu_forest_param_list ,
66+ tau_forest_params = tau_forest_param_list )
67+ )
68+
69+ # 1 chain, no thinning, scalar leaf scale parameter provided
70+ general_param_list <- list (num_chains = 1 , keep_every = 1 )
71+ mu_forest_param_list <- list (sigma2_leaf_init = 0.5 )
72+ tau_forest_param_list <- list (sigma2_leaf_init = 0.5 )
73+ expect_no_error(
74+ bcf_model <- bcf(X_train = X_train , y_train = y_train , Z_train = Z_train ,
75+ propensity_train = pi_train , X_test = X_test , Z_test = Z_test ,
76+ propensity_test = pi_test , num_gfr = 0 , num_burnin = 10 ,
77+ num_mcmc = 10 , general_params = general_param_list ,
78+ mu_forest_params = mu_forest_param_list ,
79+ tau_forest_params = tau_forest_param_list )
80+ )
81+
5682 # 3 chains, no thinning
5783 general_param_list <- list (num_chains = 3 , keep_every = 1 )
5884 expect_no_error(
@@ -118,8 +144,8 @@ test_that("GFR BCF", {
118144 X_train <- X [train_inds ,]
119145 Z_test <- Z [test_inds ]
120146 Z_train <- Z [train_inds ]
121- pi_test <- pi [test_inds ]
122- pi_train <- pi [train_inds ]
147+ pi_test <- pi_X [test_inds ]
148+ pi_train <- pi_X [train_inds ]
123149 mu_test <- mu_X [test_inds ]
124150 mu_train <- mu_X [train_inds ]
125151 tau_test <- tau_X [test_inds ]
@@ -219,8 +245,8 @@ test_that("Warmstart BCF", {
219245 X_train <- X [train_inds ,]
220246 Z_test <- Z [test_inds ]
221247 Z_train <- Z [train_inds ]
222- pi_test <- pi [test_inds ]
223- pi_train <- pi [train_inds ]
248+ pi_test <- pi_X [test_inds ]
249+ pi_train <- pi_X [train_inds ]
224250 mu_test <- mu_X [test_inds ]
225251 mu_train <- mu_X [train_inds ]
226252 tau_test <- tau_X [test_inds ]
@@ -287,8 +313,8 @@ test_that("Warmstart BCF", {
287313 X_train <- X [train_inds ,]
288314 Z_test <- Z [test_inds ]
289315 Z_train <- Z [train_inds ]
290- pi_test <- pi [test_inds ]
291- pi_train <- pi [train_inds ]
316+ pi_test <- pi_X [test_inds ]
317+ pi_train <- pi_X [train_inds ]
292318 mu_test <- mu_X [test_inds ]
293319 mu_train <- mu_X [train_inds ]
294320 tau_test <- tau_X [test_inds ]
@@ -329,3 +355,75 @@ test_that("Warmstart BCF", {
329355 general_params = general_param_list )
330356 )
331357})
358+
359+ test_that(" Multivariate Treatment MCMC BCF" , {
360+ skip_on_cran()
361+
362+ # Generate simulated data
363+ n <- 100
364+ p <- 5
365+ X <- matrix (runif(n * p ), ncol = p )
366+ mu_X <- (
367+ ((0 < = X [,1 ]) & (0.25 > X [,1 ])) * (- 7.5 ) +
368+ ((0.25 < = X [,1 ]) & (0.5 > X [,1 ])) * (- 2.5 ) +
369+ ((0.5 < = X [,1 ]) & (0.75 > X [,1 ])) * (2.5 ) +
370+ ((0.75 < = X [,1 ]) & (1 > X [,1 ])) * (7.5 )
371+ )
372+ pi_X_1 <- (
373+ ((0 < = X [,1 ]) & (0.25 > X [,1 ])) * (0.2 ) +
374+ ((0.25 < = X [,1 ]) & (0.5 > X [,1 ])) * (0.4 ) +
375+ ((0.5 < = X [,1 ]) & (0.75 > X [,1 ])) * (0.6 ) +
376+ ((0.75 < = X [,1 ]) & (1 > X [,1 ])) * (0.8 )
377+ )
378+ pi_X_2 <- (
379+ ((0 < = X [,2 ]) & (0.25 > X [,2 ])) * (0.8 ) +
380+ ((0.25 < = X [,2 ]) & (0.5 > X [,2 ])) * (0.4 ) +
381+ ((0.5 < = X [,2 ]) & (0.75 > X [,2 ])) * (0.6 ) +
382+ ((0.75 < = X [,2 ]) & (1 > X [,2 ])) * (0.2 )
383+ )
384+ pi_X <- cbind(pi_X_1 , pi_X_2 )
385+ tau_X_1 <- (
386+ ((0 < = X [,2 ]) & (0.25 > X [,2 ])) * (0.5 ) +
387+ ((0.25 < = X [,2 ]) & (0.5 > X [,2 ])) * (1.0 ) +
388+ ((0.5 < = X [,2 ]) & (0.75 > X [,2 ])) * (1.5 ) +
389+ ((0.75 < = X [,2 ]) & (1 > X [,2 ])) * (2.0 )
390+ )
391+ tau_X_2 <- (
392+ ((0 < = X [,3 ]) & (0.25 > X [,3 ])) * (- 0.5 ) +
393+ ((0.25 < = X [,3 ]) & (0.5 > X [,3 ])) * (- 1.5 ) +
394+ ((0.5 < = X [,3 ]) & (0.75 > X [,3 ])) * (- 1.0 ) +
395+ ((0.75 < = X [,3 ]) & (1 > X [,3 ])) * (0.0 )
396+ )
397+ tau_X <- cbind(tau_X_1 , tau_X_2 )
398+ Z_1 <- as.numeric(rbinom(n , 1 , pi_X_1 ))
399+ Z_2 <- as.numeric(rbinom(n , 1 , pi_X_2 ))
400+ Z <- cbind(Z_1 , Z_2 )
401+ noise_sd <- 1
402+ y <- mu_X + rowSums(tau_X * Z ) + rnorm(n , 0 , noise_sd )
403+ test_set_pct <- 0.2
404+ n_test <- round(test_set_pct * n )
405+ n_train <- n - n_test
406+ test_inds <- sort(sample(1 : n , n_test , replace = FALSE ))
407+ train_inds <- (1 : n )[! ((1 : n ) %in% test_inds )]
408+ X_test <- X [test_inds ,]
409+ X_train <- X [train_inds ,]
410+ Z_test <- Z [test_inds ,]
411+ Z_train <- Z [train_inds ,]
412+ pi_test <- pi_X [test_inds ,]
413+ pi_train <- pi_X [train_inds ,]
414+ mu_test <- mu_X [test_inds ]
415+ mu_train <- mu_X [train_inds ]
416+ tau_test <- tau_X [test_inds ,]
417+ tau_train <- tau_X [train_inds ,]
418+ y_test <- y [test_inds ]
419+ y_train <- y [train_inds ]
420+
421+ # 1 chain, no thinning
422+ general_param_list <- list (num_chains = 1 , keep_every = 1 )
423+ expect_error(
424+ bcf_model <- bcf(X_train = X_train , y_train = y_train , Z_train = Z_train ,
425+ propensity_train = pi_train , X_test = X_test , Z_test = Z_test ,
426+ propensity_test = pi_test , num_gfr = 0 , num_burnin = 10 ,
427+ num_mcmc = 10 , general_params = general_param_list )
428+ )
429+ })
0 commit comments