@@ -426,4 +426,71 @@ test_that("Multivariate Treatment MCMC BCF", {
426
426
propensity_test = pi_test , num_gfr = 0 , num_burnin = 10 ,
427
427
num_mcmc = 10 , general_params = general_param_list )
428
428
)
429
- })
429
+ })
430
+
431
+ test_that(" BCF Predictions" , {
432
+ skip_on_cran()
433
+
434
+ # Generate simulated data
435
+ n <- 100
436
+ p <- 5
437
+ X <- matrix (runif(n * p ), ncol = p )
438
+ mu_X <- (
439
+ ((0 < = X [,1 ]) & (0.25 > X [,1 ])) * (- 7.5 ) +
440
+ ((0.25 < = X [,1 ]) & (0.5 > X [,1 ])) * (- 2.5 ) +
441
+ ((0.5 < = X [,1 ]) & (0.75 > X [,1 ])) * (2.5 ) +
442
+ ((0.75 < = X [,1 ]) & (1 > X [,1 ])) * (7.5 )
443
+ )
444
+ pi_X <- (
445
+ ((0 < = X [,1 ]) & (0.25 > X [,1 ])) * (0.2 ) +
446
+ ((0.25 < = X [,1 ]) & (0.5 > X [,1 ])) * (0.4 ) +
447
+ ((0.5 < = X [,1 ]) & (0.75 > X [,1 ])) * (0.6 ) +
448
+ ((0.75 < = X [,1 ]) & (1 > X [,1 ])) * (0.8 )
449
+ )
450
+ tau_X <- (
451
+ ((0 < = X [,2 ]) & (0.25 > X [,2 ])) * (0.5 ) +
452
+ ((0.25 < = X [,2 ]) & (0.5 > X [,2 ])) * (1.0 ) +
453
+ ((0.5 < = X [,2 ]) & (0.75 > X [,2 ])) * (1.5 ) +
454
+ ((0.75 < = X [,2 ]) & (1 > X [,2 ])) * (2.0 )
455
+ )
456
+ Z <- rbinom(n , 1 , pi_X )
457
+ noise_sd <- 1
458
+ y <- mu_X + tau_X * Z + rnorm(n , 0 , noise_sd )
459
+ test_set_pct <- 0.2
460
+ n_test <- round(test_set_pct * n )
461
+ n_train <- n - n_test
462
+ test_inds <- sort(sample(1 : n , n_test , replace = FALSE ))
463
+ train_inds <- (1 : n )[! ((1 : n ) %in% test_inds )]
464
+ X_test <- X [test_inds ,]
465
+ X_train <- X [train_inds ,]
466
+ Z_test <- Z [test_inds ]
467
+ Z_train <- Z [train_inds ]
468
+ pi_test <- pi_X [test_inds ]
469
+ pi_train <- pi_X [train_inds ]
470
+ mu_test <- mu_X [test_inds ]
471
+ mu_train <- mu_X [train_inds ]
472
+ tau_test <- tau_X [test_inds ]
473
+ tau_train <- tau_X [train_inds ]
474
+ y_test <- y [test_inds ]
475
+ y_train <- y [train_inds ]
476
+
477
+ # Run a BCF model with only GFR
478
+ general_params <- list (num_chains = 1 , keep_every = 1 )
479
+ variance_forest_params <- list (num_trees = 50 )
480
+ bcf_model <- bcf(X_train = X_train , y_train = y_train , Z_train = Z_train ,
481
+ propensity_train = pi_train , X_test = X_test , Z_test = Z_test ,
482
+ propensity_test = pi_test , num_gfr = 10 , num_burnin = 0 ,
483
+ num_mcmc = 10 , general_params = general_params ,
484
+ variance_forest_params = variance_forest_params )
485
+
486
+ # Check that cached predictions agree with results of predict() function
487
+ train_preds <- predict(bcf_model , X = X_train , Z = Z_train , propensity = pi_train )
488
+ train_preds_mean_cached <- bcf_model $ y_hat_train
489
+ train_preds_mean_recomputed <- train_preds $ y_hat
490
+ train_preds_variance_cached <- bcf_model $ sigma2_x_hat_train
491
+ train_preds_variance_recomputed <- train_preds $ variance_forest_predictions
492
+
493
+ # Assertion
494
+ expect_equal(train_preds_mean_cached , train_preds_mean_recomputed )
495
+ expect_equal(train_preds_variance_cached , train_preds_variance_recomputed )
496
+ })
0 commit comments