Skip to content

Commit e8978b6

Browse files
committed
Updated regression testing workflows to include R BCF
1 parent 5778e21 commit e8978b6

File tree

4 files changed

+344
-8
lines changed

4 files changed

+344
-8
lines changed

.github/workflows/regression-test.yml

Lines changed: 17 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@ on:
44
name: Running stochtree on benchmark datasets
55

66
jobs:
7-
stochtree_r_bart:
7+
stochtree_r:
88
name: stochtree-r-bart-regression-test
99
runs-on: ubuntu-latest
1010

@@ -31,20 +31,29 @@ jobs:
3131
with:
3232
extra-packages: any::testthat, any::decor, local::stochtree_cran
3333

34-
- name: Create output directory for regression test results
34+
- name: Create output directory for BART regression test results
3535
run: |
36-
mkdir -p tools/regression/stochtree_bart_r_results
36+
mkdir -p tools/regression/bart/stochtree_bart_r_results
37+
mkdir -p tools/regression/bcf/stochtree_bcf_r_results
3738
38-
- name: Run the regression test benchmark suite
39+
- name: Run the BART regression test benchmark suite
3940
run: |
40-
Rscript tools/regression/regression_test_dispatch_bart.R
41+
Rscript tools/regression/bart/regression_test_dispatch_bart.R
42+
Rscript tools/regression/bcf/regression_test_dispatch_bcf.R
4143
4244
- name: Collate and analyze regression test results
4345
run: |
44-
Rscript tools/regression/regression_test_analysis_bart.R
46+
Rscript tools/regression/bart/regression_test_analysis_bart.R
47+
Rscript tools/regression/bcf/regression_test_analysis_bcf.R
4548
46-
- name: Store benchmark test results as an artifact of the run
49+
- name: Store BART benchmark test results as an artifact of the run
4750
uses: actions/upload-artifact@v4
4851
with:
4952
name: stochtree-r-bart-summary
50-
path: tools/regression/stochtree_bart_r_results/stochtree_bart_r_summary.csv
53+
path: tools/regression/bart/stochtree_bart_r_results/stochtree_bart_r_summary.csv
54+
55+
- name: Store BCF benchmark test results as an artifact of the run
56+
uses: actions/upload-artifact@v4
57+
with:
58+
name: stochtree-r-bcf-summary
59+
path: tools/regression/bcf/stochtree_bcf_r_results/stochtree_bcf_r_summary.csv
Lines changed: 283 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,283 @@
1+
# Load libraries
2+
library(stochtree)
3+
4+
# Define DGPs
5+
dgp1 <- function(n, p, snr) {
6+
X <- matrix(rnorm(n*p), ncol = p)
7+
plm_term <- (
8+
((0 <= X[,1]) & (0.25 > X[,1])) * (-7.5*X[,2]) +
9+
((0.25 <= X[,1]) & (0.5 > X[,1])) * (-2.5*X[,2]) +
10+
((0.5 <= X[,1]) & (0.75 > X[,1])) * (2.5*X[,2]) +
11+
((0.75 <= X[,1]) & (1 > X[,1])) * (7.5*X[,2])
12+
)
13+
trig_term <- (
14+
2*sin(X[,3]*2*pi) -
15+
2*cos(X[,4]*2*pi)
16+
)
17+
mu_x <- plm_term + trig_term
18+
pi_x <- 0.8*pnorm((3*mu_x/sd(mu_x)) - 0.5*X[,1]) + 0.05 + runif(n)/10
19+
Z <- rbinom(n,1,pi_x)
20+
tau_x <- 1 + 2*X[,2]*X[,4]
21+
f_XZ <- mu_x + tau_x * Z
22+
noise_sd <- sd(f_XZ)/snr
23+
y <- f_XZ + rnorm(n, 0, noise_sd)
24+
return(list(covariates = X, treatment = Z, outcome = y, propensity = pi_x,
25+
prognostic_effect = mu_x, treatment_effect = tau_x,
26+
conditional_mean = f_XZ, rfx_group_ids = NULL, rfx_basis = NULL))
27+
}
28+
dgp2 <- function(n, p, snr) {
29+
X <- matrix(runif(n*p), ncol = p)
30+
pi_x <- cbind(0.125 + 0.75 * X[, 1], 0.875 - 0.75 * X[, 2])
31+
mu_x <- pi_x[, 1] * 5 + pi_x[, 2] * 2 + 2 * X[, 3]
32+
tau_x <- cbind(X[, 2], X[, 3]) * 2
33+
Z <- matrix(NA_real_, nrow = n, ncol = ncol(pi_x))
34+
for (i in 1:ncol(pi_x)) {
35+
Z[, i] <- rbinom(n, 1, pi_x[, i])
36+
}
37+
f_XZ <- mu_x + rowSums(Z * tau_x)
38+
noise_sd <- sd(f_XZ)/snr
39+
y <- f_XZ + rnorm(n, 0, noise_sd)
40+
return(list(covariates = X, treatment = Z, outcome = y, propensity = pi_x,
41+
prognostic_effect = mu_x, treatment_effect = tau_x,
42+
conditional_mean = f_XZ, rfx_group_ids = NULL, rfx_basis = NULL))
43+
}
44+
dgp3 <- function(n, p, snr) {
45+
X <- matrix(rnorm(n*p), ncol = p)
46+
plm_term <- (
47+
((0 <= X[,1]) & (0.25 > X[,1])) * (-7.5*X[,2]) +
48+
((0.25 <= X[,1]) & (0.5 > X[,1])) * (-2.5*X[,2]) +
49+
((0.5 <= X[,1]) & (0.75 > X[,1])) * (2.5*X[,2]) +
50+
((0.75 <= X[,1]) & (1 > X[,1])) * (7.5*X[,2])
51+
)
52+
trig_term <- (
53+
2*sin(X[,3]*2*pi) -
54+
2*cos(X[,4]*2*pi)
55+
)
56+
mu_x <- plm_term + trig_term
57+
pi_x <- 0.8*pnorm((3*mu_x/sd(mu_x)) - 0.5*X[,1]) + 0.05 + runif(n)/10
58+
Z <- rbinom(n,1,pi_x)
59+
tau_x <- 1 + 2*X[,2]*X[,4]
60+
rfx_group_ids <- sample(1:3, size = n, replace = T)
61+
rfx_coefs <- t(matrix(c(-5, -3, -1, 5, 3, 1), nrow=2, byrow=TRUE))
62+
rfx_basis <- cbind(1, runif(n, -1, 1))
63+
rfx_term <- rowSums(rfx_coefs[rfx_group_ids,] * rfx_basis)
64+
f_XZ <- mu_x + tau_x * Z + rfx_term
65+
noise_sd <- sd(f_XZ)/snr
66+
y <- f_XZ + rnorm(n, 0, noise_sd)
67+
return(list(covariates = X, treatment = Z, outcome = y, propensity = pi_x,
68+
prognostic_effect = mu_x, treatment_effect = tau_x,
69+
conditional_mean = f_XZ, rfx_group_ids = rfx_group_ids, rfx_basis = rfx_basis))
70+
}
71+
dgp4 <- function(n, p, snr) {
72+
X <- matrix(runif(n*p), ncol = p)
73+
pi_x <- cbind(0.125 + 0.75 * X[, 1], 0.875 - 0.75 * X[, 2])
74+
mu_x <- pi_x[, 1] * 5 + pi_x[, 2] * 2 + 2 * X[, 3]
75+
tau_x <- cbind(X[, 2], X[, 3]) * 2
76+
Z <- matrix(NA_real_, nrow = n, ncol = ncol(pi_x))
77+
for (i in 1:ncol(pi_x)) {
78+
Z[, i] <- rbinom(n, 1, pi_x[, i])
79+
}
80+
rfx_group_ids <- sample(1:3, size = n, replace = T)
81+
rfx_coefs <- t(matrix(c(-5, -3, -1, 5, 3, 1), nrow=2, byrow=TRUE))
82+
rfx_basis <- cbind(1, runif(n, -1, 1))
83+
rfx_term <- rowSums(rfx_coefs[rfx_group_ids,] * rfx_basis)
84+
f_XZ <- mu_x + rowSums(Z * tau_x) + rfx_term
85+
noise_sd <- sd(f_XZ)/snr
86+
y <- f_XZ + rnorm(n, 0, noise_sd)
87+
return(list(covariates = X, treatment = Z, outcome = y, propensity = pi_x,
88+
prognostic_effect = mu_x, treatment_effect = tau_x,
89+
conditional_mean = f_XZ, rfx_group_ids = rfx_group_ids, rfx_basis = rfx_basis))
90+
}
91+
92+
# Test / train split utilities
93+
compute_test_train_indices <- function(n, test_set_pct) {
94+
n_test <- round(test_set_pct*n)
95+
n_train <- n - n_test
96+
test_inds <- sort(sample(1:n, n_test, replace = FALSE))
97+
train_inds <- (1:n)[!((1:n) %in% test_inds)]
98+
return(list(test_inds = test_inds, train_inds = train_inds))
99+
}
100+
subset_data <- function(data, subset_inds) {
101+
if (is.matrix(data)) {
102+
return(data[subset_inds,])
103+
} else {
104+
return(data[subset_inds])
105+
}
106+
}
107+
108+
# Capture command line arguments
109+
args <- commandArgs(trailingOnly = T)
110+
if (length(args) > 0){
111+
n_iter <- as.integer(args[1])
112+
n <- as.integer(args[2])
113+
p <- as.integer(args[3])
114+
num_gfr <- as.integer(args[4])
115+
num_mcmc <- as.integer(args[5])
116+
dgp_num <- as.integer(args[6])
117+
snr <- as.numeric(args[7])
118+
test_set_pct <- as.numeric(args[8])
119+
num_threads <- as.integer(args[9])
120+
} else{
121+
# Default arguments
122+
n_iter <- 5
123+
n <- 1000
124+
p <- 5
125+
num_gfr <- 10
126+
num_mcmc <- 100
127+
dgp_num <- 1
128+
snr <- 2.0
129+
test_set_pct <- 0.2
130+
num_threads <- -1
131+
}
132+
cat("n_iter = ", n_iter, "\nn = ", n, "\np = ", p, "\nnum_gfr = ", num_gfr,
133+
"\nnum_mcmc = ", num_mcmc, "\ndgp_num = ", dgp_num, "\nsnr = ", snr,
134+
"\ntest_set_pct = ", test_set_pct, "\nnum_threads = ", num_threads, "\n", sep = "")
135+
136+
# Run the performance evaluation
137+
results <- matrix(NA, nrow = n_iter, ncol = 6)
138+
colnames(results) <- c("iter", "outcome_rmse", "outcome_coverage", "treatment_effect_rmse", "treatment_effect_coverage", "runtime")
139+
for (i in 1:n_iter) {
140+
# Generate data
141+
if (dgp_num == 1) {
142+
data_list <- dgp1(n = n, p = p, snr = snr)
143+
} else if (dgp_num == 2) {
144+
data_list <- dgp2(n = n, p = p, snr = snr)
145+
} else if (dgp_num == 3) {
146+
data_list <- dgp3(n = n, p = p, snr = snr)
147+
} else if (dgp_num == 4) {
148+
data_list <- dgp4(n = n, p = p, snr = snr)
149+
} else {
150+
stop("Invalid DGP input")
151+
}
152+
covariates <- data_list[['covariates']]
153+
treatment <- data_list[['treatment']]
154+
propensity <- data_list[['propensity']]
155+
prognostic_effect <- data_list[['prognostic_effect']]
156+
treatment_effect <- data_list[['treatment_effect']]
157+
conditional_mean <- data_list[['conditional_mean']]
158+
outcome <- data_list[['outcome']]
159+
rfx_group_ids <- data_list[['rfx_group_ids']]
160+
rfx_basis <- data_list[['rfx_basis']]
161+
if (dgp_num %in% c(2,4)) {
162+
has_multivariate_treatment <- T
163+
} else {
164+
has_multivariate_treatment <- F
165+
}
166+
167+
# Split into train / test sets
168+
subset_inds_list <- compute_test_train_indices(n, test_set_pct)
169+
test_inds <- subset_inds_list$test_inds
170+
train_inds <- subset_inds_list$train_inds
171+
covariates_train <- subset_data(covariates, train_inds)
172+
covariates_test <- subset_data(covariates, test_inds)
173+
treatment_train <- subset_data(treatment, train_inds)
174+
treatment_test <- subset_data(treatment, test_inds)
175+
propensity_train <- subset_data(propensity, train_inds)
176+
propensity_test <- subset_data(propensity, test_inds)
177+
outcome_train <- subset_data(outcome, train_inds)
178+
outcome_test <- subset_data(outcome, test_inds)
179+
prognostic_effect_train <- subset_data(prognostic_effect, train_inds)
180+
prognostic_effect_test <- subset_data(prognostic_effect, test_inds)
181+
treatment_effect_train <- subset_data(treatment_effect, train_inds)
182+
treatment_effect_test <- subset_data(treatment_effect, test_inds)
183+
conditional_mean_train <- subset_data(conditional_mean, train_inds)
184+
conditional_mean_test <- subset_data(conditional_mean, test_inds)
185+
has_rfx <- !is.null(rfx_group_ids)
186+
if (has_rfx) {
187+
rfx_group_ids_train <- subset_data(rfx_group_ids, train_inds)
188+
rfx_group_ids_test <- subset_data(rfx_group_ids, test_inds)
189+
rfx_basis_train <- subset_data(rfx_basis, train_inds)
190+
rfx_basis_test <- subset_data(rfx_basis, test_inds)
191+
} else {
192+
rfx_group_ids_train <- NULL
193+
rfx_group_ids_test <- NULL
194+
rfx_basis_train <- NULL
195+
rfx_basis_test <- NULL
196+
}
197+
198+
# Run (and time) BCF
199+
bcf_timing <- system.time({
200+
# Sample BCF model
201+
general_params <- list(num_threads = num_threads, adaptive_coding = F)
202+
prognostic_forest_params <- list(sample_sigma2_leaf = F)
203+
treatment_effect_forest_params <- list(sample_sigma2_leaf = F)
204+
bcf_model <- stochtree::bcf(
205+
X_train = covariates_train, Z_train = treatment_train,
206+
propensity_train = propensity_train, y_train = outcome_train,
207+
rfx_group_ids_train = rfx_group_ids_train, rfx_basis_train = rfx_basis_train,
208+
num_gfr = num_gfr, num_mcmc = num_mcmc, general_params = general_params,
209+
prognostic_forest_params = prognostic_forest_params,
210+
treatment_effect_forest_params = treatment_effect_forest_params
211+
)
212+
213+
# Predict on the test set
214+
test_preds <- predict(
215+
bcf_model, X = covariates_test, Z = treatment_test, propensity = propensity_test,
216+
rfx_group_ids = rfx_group_ids_test, rfx_basis = rfx_basis_test
217+
)
218+
})[3]
219+
220+
# Compute test set evals
221+
y_hat_posterior <- test_preds$y_hat
222+
y_hat_posterior_mean <- rowMeans(y_hat_posterior)
223+
tau_hat_posterior <- test_preds$tau_hat
224+
if (has_multivariate_treatment) tau_hat_posterior_mean <- apply(tau_hat_posterior, c(1,2), mean)
225+
else tau_hat_posterior_mean <- apply(tau_hat_posterior, 1, mean)
226+
y_hat_rmse_test <- sqrt(mean((y_hat_posterior_mean - outcome_test)^2))
227+
tau_hat_rmse_test <- sqrt(mean((tau_hat_posterior_mean - treatment_effect_test)^2))
228+
y_hat_posterior_quantile_025 <- apply(y_hat_posterior, 1, function(x) quantile(x, 0.025))
229+
y_hat_posterior_quantile_975 <- apply(y_hat_posterior, 1, function(x) quantile(x, 0.975))
230+
if (has_multivariate_treatment) {
231+
tau_hat_posterior_quantile_025 <- apply(tau_hat_posterior, c(1,2), function(x) quantile(x, 0.025))
232+
tau_hat_posterior_quantile_975 <- apply(tau_hat_posterior, c(1,2), function(x) quantile(x, 0.975))
233+
} else {
234+
tau_hat_posterior_quantile_025 <- apply(tau_hat_posterior, 1, function(x) quantile(x, 0.025))
235+
tau_hat_posterior_quantile_975 <- apply(tau_hat_posterior, 1, function(x) quantile(x, 0.975))
236+
}
237+
y_hat_covered <- rep(NA, nrow(y_hat_posterior))
238+
for (j in 1:nrow(y_hat_posterior)) {
239+
y_hat_covered[j] <- (
240+
(conditional_mean_test[j] >= y_hat_posterior_quantile_025[j]) &
241+
(conditional_mean_test[j] <= y_hat_posterior_quantile_975[j])
242+
)
243+
}
244+
y_hat_coverage_test <- mean(y_hat_covered)
245+
if (has_multivariate_treatment) {
246+
tau_hat_covered <- matrix(NA_real_, nrow(tau_hat_posterior_mean), ncol(tau_hat_posterior_mean))
247+
for (j in 1:nrow(tau_hat_covered)) {
248+
for (k in 1:ncol(tau_hat_covered)) {
249+
tau_hat_covered[j,k] <- (
250+
(treatment_effect_test[j,k] >= tau_hat_posterior_quantile_025[j,k]) &
251+
(treatment_effect_test[j,k] <= tau_hat_posterior_quantile_975[j,k])
252+
)
253+
}
254+
}
255+
} else {
256+
tau_hat_covered <- rep(NA, nrow(tau_hat_posterior))
257+
for (j in 1:nrow(tau_hat_posterior)) {
258+
tau_hat_covered[j] <- (
259+
(treatment_effect_test[j] >= tau_hat_posterior_quantile_025[j]) &
260+
(treatment_effect_test[j] <= tau_hat_posterior_quantile_025[j])
261+
)
262+
}
263+
}
264+
tau_hat_coverage_test <- mean(tau_hat_covered)
265+
266+
# Store evaluations
267+
results[i,] <- c(i, y_hat_rmse_test, y_hat_coverage_test, tau_hat_rmse_test, tau_hat_coverage_test, bcf_timing)
268+
}
269+
270+
# Wrangle and save results to CSV
271+
results_df <- data.frame(
272+
cbind(n, p, num_gfr, num_mcmc, dgp_num, snr, test_set_pct, num_threads, results)
273+
)
274+
snr_rounded <- as.integer(snr)
275+
test_set_pct_rounded <- as.integer(test_set_pct*100)
276+
num_threads_clean <- ifelse(num_threads < 0, 0, num_threads)
277+
filename <- paste(
278+
"stochtree", "bcf", "r", "n", n, "p", p, "num_gfr", num_gfr, "num_mcmc", num_mcmc,
279+
"dgp_num", dgp_num, "snr", snr_rounded, "test_set_pct", test_set_pct_rounded,
280+
"num_threads", num_threads_clean, sep = "_"
281+
)
282+
filename_full <- paste0("tools/regression/bcf/stochtree_bcf_r_results/", filename, ".csv")
283+
write.csv(x = results_df, file = filename_full, row.names = F)
Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,16 @@
1+
reg_test_dir <- "tools/regression/bcf/stochtree_bcf_r_results"
2+
reg_test_files <- list.files(reg_test_dir, pattern = ".csv", full.names = T)
3+
4+
reg_test_df <- data.frame()
5+
for (file in reg_test_files) {
6+
temp_df <- read.csv(file)
7+
reg_test_df <- rbind(reg_test_df, temp_df)
8+
}
9+
10+
summary_df <- aggregate(
11+
cbind(rmse, coverage, runtime) ~ n + p + num_gfr + num_mcmc + dgp_num + snr + test_set_pct + num_threads,
12+
data = reg_test_df, FUN = median, drop = TRUE
13+
)
14+
15+
summary_file_output <- file.path(reg_test_dir, "stochtree_bcf_r_summary.csv")
16+
write.csv(summary_df, summary_file_output, row.names = F)
Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,28 @@
1+
# Test case parameters
2+
dgps <- 1:4
3+
ns <- c(1000, 10000)
4+
ps <- c(5, 20)
5+
threads <- c(-1, 1)
6+
varying_param_grid <- expand.grid(dgps, ns, ps, threads)
7+
test_case_grid <- cbind(
8+
5, varying_param_grid[,2], varying_param_grid[,3],
9+
10, 100, varying_param_grid[,1], 2.0, 0.2, varying_param_grid[,4]
10+
)
11+
12+
# Run script for every case
13+
script_path <- "tools/regression/bcf/individual_regression_test_bcf.R"
14+
for (i in 1:nrow(test_case_grid)) {
15+
n_iter <- test_case_grid[i,1]
16+
n <- test_case_grid[i,2]
17+
p <- test_case_grid[i,3]
18+
num_gfr <- test_case_grid[i,4]
19+
num_mcmc <- test_case_grid[i,5]
20+
dgp_num <- test_case_grid[i,6]
21+
snr <- test_case_grid[i,7]
22+
test_set_pct <- test_case_grid[i,8]
23+
num_threads <- test_case_grid[i,9]
24+
system2(
25+
"Rscript",
26+
args = c(script_path, n_iter, n, p, num_gfr, num_mcmc, dgp_num, snr, test_set_pct, num_threads)
27+
)
28+
}

0 commit comments

Comments
 (0)