Skip to content

Commit 49b51ca

Browse files
committed
Added R BART and BCF unit tests for cached predictions
1 parent 2ff2f49 commit 49b51ca

File tree

2 files changed

+113
-1
lines changed

2 files changed

+113
-1
lines changed

test/R/testthat/test-bart.R

Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -291,3 +291,48 @@ test_that("Warmstart BART", {
291291
general_params = general_param_list)
292292
)
293293
})
294+
295+
test_that("BART Predictions", {
296+
skip_on_cran()
297+
298+
# Generate simulated data
299+
n <- 100
300+
p <- 5
301+
X <- matrix(runif(n*p), ncol = p)
302+
f_XW <- (
303+
((0 <= X[,1]) & (0.25 > X[,1])) * (-7.5) +
304+
((0.25 <= X[,1]) & (0.5 > X[,1])) * (-2.5) +
305+
((0.5 <= X[,1]) & (0.75 > X[,1])) * (2.5) +
306+
((0.75 <= X[,1]) & (1 > X[,1])) * (7.5)
307+
)
308+
noise_sd <- 1
309+
y <- f_XW + rnorm(n, 0, noise_sd)
310+
test_set_pct <- 0.2
311+
n_test <- round(test_set_pct*n)
312+
n_train <- n - n_test
313+
test_inds <- sort(sample(1:n, n_test, replace = FALSE))
314+
train_inds <- (1:n)[!((1:n) %in% test_inds)]
315+
X_test <- X[test_inds,]
316+
X_train <- X[train_inds,]
317+
y_test <- y[test_inds]
318+
y_train <- y[train_inds]
319+
320+
# Run a BART model with only GFR
321+
general_params <- list(num_chains = 1)
322+
variance_forest_params <- list(num_trees = 50)
323+
bart_model <- bart(X_train = X_train, y_train = y_train, X_test = X_test,
324+
num_gfr = 10, num_burnin = 0, num_mcmc = 10,
325+
general_params = general_params,
326+
variance_forest_params = variance_forest_params)
327+
328+
# Check that cached predictions agree with results of predict() function
329+
train_preds <- predict(bart_model, X = X_train)
330+
train_preds_mean_cached <- bart_model$y_hat_train
331+
train_preds_mean_recomputed <- train_preds$mean_forest_predictions
332+
train_preds_variance_cached <- bart_model$sigma2_x_hat_train
333+
train_preds_variance_recomputed <- train_preds$variance_forest_predictions
334+
335+
# Assertion
336+
expect_equal(train_preds_mean_cached, train_preds_mean_recomputed)
337+
expect_equal(train_preds_variance_cached, train_preds_variance_recomputed)
338+
})

test/R/testthat/test-bcf.R

Lines changed: 68 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -426,4 +426,71 @@ test_that("Multivariate Treatment MCMC BCF", {
426426
propensity_test = pi_test, num_gfr = 0, num_burnin = 10,
427427
num_mcmc = 10, general_params = general_param_list)
428428
)
429-
})
429+
})
430+
431+
test_that("BCF Predictions", {
432+
skip_on_cran()
433+
434+
# Generate simulated data
435+
n <- 100
436+
p <- 5
437+
X <- matrix(runif(n*p), ncol = p)
438+
mu_X <- (
439+
((0 <= X[,1]) & (0.25 > X[,1])) * (-7.5) +
440+
((0.25 <= X[,1]) & (0.5 > X[,1])) * (-2.5) +
441+
((0.5 <= X[,1]) & (0.75 > X[,1])) * (2.5) +
442+
((0.75 <= X[,1]) & (1 > X[,1])) * (7.5)
443+
)
444+
pi_X <- (
445+
((0 <= X[,1]) & (0.25 > X[,1])) * (0.2) +
446+
((0.25 <= X[,1]) & (0.5 > X[,1])) * (0.4) +
447+
((0.5 <= X[,1]) & (0.75 > X[,1])) * (0.6) +
448+
((0.75 <= X[,1]) & (1 > X[,1])) * (0.8)
449+
)
450+
tau_X <- (
451+
((0 <= X[,2]) & (0.25 > X[,2])) * (0.5) +
452+
((0.25 <= X[,2]) & (0.5 > X[,2])) * (1.0) +
453+
((0.5 <= X[,2]) & (0.75 > X[,2])) * (1.5) +
454+
((0.75 <= X[,2]) & (1 > X[,2])) * (2.0)
455+
)
456+
Z <- rbinom(n, 1, pi_X)
457+
noise_sd <- 1
458+
y <- mu_X + tau_X*Z + rnorm(n, 0, noise_sd)
459+
test_set_pct <- 0.2
460+
n_test <- round(test_set_pct*n)
461+
n_train <- n - n_test
462+
test_inds <- sort(sample(1:n, n_test, replace = FALSE))
463+
train_inds <- (1:n)[!((1:n) %in% test_inds)]
464+
X_test <- X[test_inds,]
465+
X_train <- X[train_inds,]
466+
Z_test <- Z[test_inds]
467+
Z_train <- Z[train_inds]
468+
pi_test <- pi_X[test_inds]
469+
pi_train <- pi_X[train_inds]
470+
mu_test <- mu_X[test_inds]
471+
mu_train <- mu_X[train_inds]
472+
tau_test <- tau_X[test_inds]
473+
tau_train <- tau_X[train_inds]
474+
y_test <- y[test_inds]
475+
y_train <- y[train_inds]
476+
477+
# Run a BCF model with only GFR
478+
general_params <- list(num_chains = 1, keep_every = 1)
479+
variance_forest_params <- list(num_trees = 50)
480+
bcf_model <- bcf(X_train = X_train, y_train = y_train, Z_train = Z_train,
481+
propensity_train = pi_train, X_test = X_test, Z_test = Z_test,
482+
propensity_test = pi_test, num_gfr = 10, num_burnin = 0,
483+
num_mcmc = 10, general_params = general_params,
484+
variance_forest_params = variance_forest_params)
485+
486+
# Check that cached predictions agree with results of predict() function
487+
train_preds <- predict(bcf_model, X = X_train, Z = Z_train, propensity = pi_train)
488+
train_preds_mean_cached <- bcf_model$y_hat_train
489+
train_preds_mean_recomputed <- train_preds$y_hat
490+
train_preds_variance_cached <- bcf_model$sigma2_x_hat_train
491+
train_preds_variance_recomputed <- train_preds$variance_forest_predictions
492+
493+
# Assertion
494+
expect_equal(train_preds_mean_cached, train_preds_mean_recomputed)
495+
expect_equal(train_preds_variance_cached, train_preds_variance_recomputed)
496+
})

0 commit comments

Comments
 (0)