Skip to content

Commit 64e2f7c

Browse files
committed
Basic regression testing setup, manually dispatched
1 parent 61ba8b9 commit 64e2f7c

File tree

4 files changed

+329
-0
lines changed

4 files changed

+329
-0
lines changed
Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,50 @@
1+
on:
2+
workflow_dispatch:
3+
4+
name: Running stochtree on benchmark datasets
5+
6+
jobs:
7+
stochtree_r_bart:
8+
name: stochtree-r-bart-regression-test
9+
runs-on: ubuntu-latest
10+
11+
steps:
12+
- name: Checkout stochtree repo
13+
uses: actions/checkout@v4
14+
with:
15+
submodules: 'recursive'
16+
17+
- name: Setup pandoc
18+
uses: r-lib/actions/setup-pandoc@v2
19+
20+
- name: Setup R
21+
uses: r-lib/actions/setup-r@v2
22+
with:
23+
use-public-rspm: true
24+
25+
- name: Create a properly formatted version of the stochtree R package in a subfolder
26+
run: |
27+
Rscript cran-bootstrap.R 0 0 1
28+
29+
- name: Setup R dependencies
30+
uses: r-lib/actions/setup-r-dependencies@v2
31+
with:
32+
extra-packages: any::testthat, any::decor, local::stochtree_cran
33+
34+
- name: Create output directory for regression test results
35+
run: |
36+
mkdir -p tools/regression/stochtree_bart_r_results
37+
38+
- name: Run the regression test benchmark suite
39+
run: |
40+
Rscript tools/regression/regression_test_dispatch_bart.R
41+
42+
- name: Collate and analyze regression test results
43+
run: |
44+
Rscript tools/regression/regression_test_analysis_bart.R
45+
46+
- name: Store benchmark test results as an artifact of the run
47+
uses: actions/upload-artifact@v4
48+
with:
49+
name: stochtree-r-bart-summary
50+
path: tools/regression/stochtree_bart_r_results/stochtree_bart_r_summary.csv
Lines changed: 235 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,235 @@
1+
# Load libraries
2+
library(stochtree)
3+
4+
# Define DGPs
5+
dgp1 <- function(n, p, snr) {
6+
X <- matrix(runif(n*p), ncol = p)
7+
plm_term <- (
8+
((0 <= X[,1]) & (0.25 > X[,1])) * (-7.5*X[,2]) +
9+
((0.25 <= X[,1]) & (0.5 > X[,1])) * (-2.5*X[,2]) +
10+
((0.5 <= X[,1]) & (0.75 > X[,1])) * (2.5*X[,2]) +
11+
((0.75 <= X[,1]) & (1 > X[,1])) * (7.5*X[,2])
12+
)
13+
trig_term <- (
14+
2*sin(X[,3]*2*pi) -
15+
1.5*cos(X[,4]*2*pi)
16+
)
17+
f_XW <- plm_term + trig_term
18+
noise_sd <- sd(f_XW)/snr
19+
y <- f_XW + rnorm(n, 0, noise_sd)
20+
return(list(covariates = X, basis = NULL, outcome = y, conditional_mean = f_XW,
21+
rfx_group_ids = NULL, rfx_basis = NULL))
22+
}
23+
dgp2 <- function(n, p, snr) {
24+
X <- matrix(runif(n*p), ncol = p)
25+
W <- matrix(runif(n*2), ncol = 2)
26+
plm_term <- (
27+
((0 <= X[,1]) & (0.25 > X[,1])) * (-7.5*W[,1]) +
28+
((0.25 <= X[,1]) & (0.5 > X[,1])) * (-2.5*W[,1]) +
29+
((0.5 <= X[,1]) & (0.75 > X[,1])) * (2.5*W[,1]) +
30+
((0.75 <= X[,1]) & (1 > X[,1])) * (7.5*W[,1])
31+
)
32+
trig_term <- (
33+
2*sin(X[,3]*2*pi) -
34+
1.5*cos(X[,4]*2*pi)
35+
)
36+
f_XW <- plm_term + trig_term
37+
noise_sd <- sd(f_XW)/snr
38+
y <- f_XW + rnorm(n, 0, noise_sd)
39+
return(list(covariates = X, basis = W, outcome = y, conditional_mean = f_XW,
40+
rfx_group_ids = NULL, rfx_basis = NULL))
41+
}
42+
dgp3 <- function(n, p, snr) {
43+
X <- matrix(runif(n*p), ncol = p)
44+
plm_term <- (
45+
((0 <= X[,1]) & (0.25 > X[,1])) * (-7.5*X[,2]) +
46+
((0.25 <= X[,1]) & (0.5 > X[,1])) * (-2.5*X[,2]) +
47+
((0.5 <= X[,1]) & (0.75 > X[,1])) * (2.5*X[,2]) +
48+
((0.75 <= X[,1]) & (1 > X[,1])) * (7.5*X[,2])
49+
)
50+
trig_term <- (
51+
2*sin(X[,3]*2*pi) -
52+
1.5*cos(X[,4]*2*pi)
53+
)
54+
rfx_group_ids <- sample(1:3, size = n, replace = T)
55+
rfx_coefs <- t(matrix(c(-5, -3, -1, 5, 3, 1), nrow=2, byrow=TRUE))
56+
rfx_basis <- cbind(1, runif(n, -1, 1))
57+
rfx_term <- rowSums(rfx_coefs[rfx_group_ids,] * rfx_basis)
58+
f_XW <- plm_term + trig_term + rfx_term
59+
noise_sd <- sd(f_XW)/snr
60+
y <- f_XW + rnorm(n, 0, noise_sd)
61+
return(list(covariates = X, basis = NULL, outcome = y, conditional_mean = f_XW,
62+
rfx_group_ids = rfx_group_ids, rfx_basis = rfx_basis))
63+
}
64+
dgp4 <- function(n, p, snr) {
65+
X <- matrix(runif(n*p), ncol = p)
66+
W <- matrix(runif(n*2), ncol = 2)
67+
plm_term <- (
68+
((0 <= X[,1]) & (0.25 > X[,1])) * (-7.5*W[,1]) +
69+
((0.25 <= X[,1]) & (0.5 > X[,1])) * (-2.5*W[,1]) +
70+
((0.5 <= X[,1]) & (0.75 > X[,1])) * (2.5*W[,1]) +
71+
((0.75 <= X[,1]) & (1 > X[,1])) * (7.5*W[,1])
72+
)
73+
trig_term <- (
74+
2*sin(X[,3]*2*pi) -
75+
1.5*cos(X[,4]*2*pi)
76+
)
77+
rfx_group_ids <- sample(1:3, size = n, replace = T)
78+
rfx_coefs <- t(matrix(c(-5, -3, -1, 5, 3, 1), nrow=2, byrow=TRUE))
79+
rfx_basis <- cbind(1, runif(n, -1, 1))
80+
rfx_term <- rowSums(rfx_coefs[rfx_group_ids,] * rfx_basis)
81+
f_XW <- plm_term + trig_term + rfx_term
82+
noise_sd <- sd(f_XW)/snr
83+
y <- f_XW + rnorm(n, 0, noise_sd)
84+
return(list(covariates = X, basis = W, outcome = y, conditional_mean = f_XW,
85+
rfx_group_ids = rfx_group_ids, rfx_basis = rfx_basis))
86+
}
87+
88+
# Test / train split utilities
89+
compute_test_train_indices <- function(n, test_set_pct) {
90+
n_test <- round(test_set_pct*n)
91+
n_train <- n - n_test
92+
test_inds <- sort(sample(1:n, n_test, replace = FALSE))
93+
train_inds <- (1:n)[!((1:n) %in% test_inds)]
94+
return(list(test_inds = test_inds, train_inds = train_inds))
95+
}
96+
subset_data <- function(data, subset_inds) {
97+
if (is.matrix(data)) {
98+
return(data[subset_inds,])
99+
} else {
100+
return(data[subset_inds])
101+
}
102+
}
103+
104+
# Capture command line arguments
105+
args <- commandArgs(trailingOnly = T)
106+
if (length(args) > 0){
107+
n_iter <- as.integer(args[1])
108+
n <- as.integer(args[2])
109+
p <- as.integer(args[3])
110+
num_gfr <- as.integer(args[4])
111+
num_mcmc <- as.integer(args[5])
112+
dgp_num <- as.integer(args[6])
113+
snr <- as.numeric(args[7])
114+
test_set_pct <- as.numeric(args[8])
115+
num_threads <- as.integer(args[9])
116+
} else{
117+
# Default arguments
118+
n_iter <- 5
119+
n <- 1000
120+
p <- 5
121+
num_gfr <- 10
122+
num_mcmc <- 100
123+
dgp_num <- 1
124+
snr <- 2.0
125+
test_set_pct <- 0.2
126+
num_threads <- -1
127+
}
128+
cat("n_iter = ", n_iter, "\nn = ", n, "\np = ", p, "\nnum_gfr = ", num_gfr,
129+
"\nnum_mcmc = ", num_mcmc, "\ndgp_num = ", dgp_num, "\nsnr = ", snr,
130+
"\ntest_set_pct = ", test_set_pct, "\nnum_threads = ", num_threads, "\n", sep = "")
131+
132+
# Run the performance evaluation
133+
results <- matrix(NA, nrow = n_iter, ncol = 4)
134+
colnames(results) <- c("iter", "rmse", "coverage", "runtime")
135+
for (i in 1:n_iter) {
136+
# Generate data
137+
if (dgp_num == 1) {
138+
data_list <- dgp1(n = n, p = p, snr = snr)
139+
} else if (dgp_num == 2) {
140+
data_list <- dgp2(n = n, p = p, snr = snr)
141+
} else if (dgp_num == 3) {
142+
data_list <- dgp3(n = n, p = p, snr = snr)
143+
} else if (dgp_num == 4) {
144+
data_list <- dgp4(n = n, p = p, snr = snr)
145+
} else {
146+
stop("Invalid DGP input")
147+
}
148+
covariates <- data_list[['covariates']]
149+
basis <- data_list[['basis']]
150+
conditional_mean <- data_list[['conditional_mean']]
151+
outcome <- data_list[['outcome']]
152+
rfx_group_ids <- data_list[['rfx_group_ids']]
153+
rfx_basis <- data_list[['rfx_basis']]
154+
155+
# Split into train / test sets
156+
subset_inds_list <- compute_test_train_indices(n, test_set_pct)
157+
test_inds <- subset_inds_list$test_inds
158+
train_inds <- subset_inds_list$train_inds
159+
covariates_train <- subset_data(covariates, train_inds)
160+
covariates_test <- subset_data(covariates, test_inds)
161+
outcome_train <- subset_data(outcome, train_inds)
162+
outcome_test <- subset_data(outcome, test_inds)
163+
conditional_mean_train <- subset_data(conditional_mean, train_inds)
164+
conditional_mean_test <- subset_data(conditional_mean, test_inds)
165+
has_basis <- !is.null(basis)
166+
has_rfx <- !is.null(rfx_group_ids)
167+
if (has_basis) {
168+
basis_train <- subset_data(basis, train_inds)
169+
basis_test <- subset_data(basis, test_inds)
170+
} else {
171+
basis_train <- NULL
172+
basis_test <- NULL
173+
}
174+
if (has_rfx) {
175+
rfx_group_ids_train <- subset_data(rfx_group_ids, train_inds)
176+
rfx_group_ids_test <- subset_data(rfx_group_ids, test_inds)
177+
rfx_basis_train <- subset_data(rfx_basis, train_inds)
178+
rfx_basis_test <- subset_data(rfx_basis, test_inds)
179+
} else {
180+
rfx_group_ids_train <- NULL
181+
rfx_group_ids_test <- NULL
182+
rfx_basis_train <- NULL
183+
rfx_basis_test <- NULL
184+
}
185+
186+
# Run (and time) BART
187+
bart_timing <- system.time({
188+
# Sample BART model
189+
general_params <- list(num_threads = num_threads)
190+
bart_model <- stochtree::bart(
191+
X_train = covariates_train, y_train = outcome_train, leaf_basis_train = basis_train,
192+
rfx_group_ids_train = rfx_group_ids_train, rfx_basis_train = rfx_basis_train,
193+
num_gfr = num_gfr, num_mcmc = num_mcmc, general_params = general_params
194+
)
195+
196+
# Predict on the test set
197+
test_preds <- predict(
198+
bart_model, X = covariates_test, leaf_basis = basis_test,
199+
rfx_group_ids = rfx_group_ids_test, rfx_basis = rfx_basis_test
200+
)
201+
})[3]
202+
203+
# Compute test set evals
204+
y_hat_posterior <- test_preds$y_hat
205+
y_hat_posterior_mean <- rowMeans(y_hat_posterior)
206+
rmse_test <- sqrt(mean((y_hat_posterior_mean - outcome_test)^2))
207+
y_hat_posterior_quantile_025 <- apply(y_hat_posterior, 1, function(x) quantile(x, 0.025))
208+
y_hat_posterior_quantile_975 <- apply(y_hat_posterior, 1, function(x) quantile(x, 0.975))
209+
covered <- rep(NA, nrow(y_hat_posterior))
210+
for (j in 1:nrow(y_hat_posterior)) {
211+
covered[j] <- (
212+
(conditional_mean_test[j] >= y_hat_posterior_quantile_025[j]) &
213+
(conditional_mean_test[j] <= y_hat_posterior_quantile_975[j])
214+
)
215+
}
216+
coverage_test <- mean(covered)
217+
218+
# Store evaluations
219+
results[i,] <- c(i, rmse_test, coverage_test, bart_timing)
220+
}
221+
222+
# Wrangle and save results to CSV
223+
results_df <- data.frame(
224+
cbind(n, p, num_gfr, num_mcmc, dgp_num, snr, test_set_pct, num_threads, results)
225+
)
226+
snr_rounded <- as.integer(snr)
227+
test_set_pct_rounded <- as.integer(test_set_pct*100)
228+
num_threads_clean <- ifelse(num_threads < 0, 0, num_threads)
229+
filename <- paste(
230+
"stochtree", "bart", "r", "n", n, "p", p, "num_gfr", num_gfr, "num_mcmc", num_mcmc,
231+
"dgp_num", dgp_num, "snr", snr_rounded, "test_set_pct", test_set_pct_rounded,
232+
"num_threads", num_threads_clean, sep = "_"
233+
)
234+
filename_full <- paste0("tools/regression/stochtree_bart_r_results/", filename, ".csv")
235+
write.csv(x = results_df, file = filename_full, row.names = F)
Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,16 @@
1+
reg_test_dir <- "tools/regression/stochtree_bart_r_results"
2+
reg_test_files <- list.files(reg_test_dir, pattern = ".csv", full.names = T)
3+
4+
reg_test_df <- data.frame()
5+
for (file in reg_test_files) {
6+
temp_df <- read.csv(file)
7+
reg_test_df <- rbind(reg_test_df, temp_df)
8+
}
9+
10+
summary_df <- aggregate(
11+
cbind(rmse, coverage, runtime) ~ n + p + num_gfr + num_mcmc + dgp_num + snr + test_set_pct + num_threads,
12+
data = reg_test_df, FUN = median, drop = TRUE
13+
)
14+
15+
summary_file_output <- file.path(reg_test_dir, "stochtree_bart_r_summary.csv")
16+
write.csv(summary_df, summary_file_output, row.names = F)
Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,28 @@
1+
# Test case parameters
2+
dgps <- 1:4
3+
ns <- c(1000, 10000)
4+
ps <- c(5, 20)
5+
threads <- c(-1, 1)
6+
varying_param_grid <- expand.grid(dgps, ns, ps, threads)
7+
test_case_grid <- cbind(
8+
5, varying_param_grid[,2], varying_param_grid[,3],
9+
10, 100, varying_param_grid[,1], 2.0, 0.2, varying_param_grid[,4]
10+
)
11+
12+
# Run script for every case
13+
script_path <- "tools/regression/individual_regression_test_bart.R"
14+
for (i in 1:nrow(test_case_grid)) {
15+
n_iter <- test_case_grid[i,1]
16+
n <- test_case_grid[i,2]
17+
p <- test_case_grid[i,3]
18+
num_gfr <- test_case_grid[i,4]
19+
num_mcmc <- test_case_grid[i,5]
20+
dgp_num <- test_case_grid[i,6]
21+
snr <- test_case_grid[i,7]
22+
test_set_pct <- test_case_grid[i,8]
23+
num_threads <- test_case_grid[i,9]
24+
system2(
25+
"Rscript",
26+
args = c(script_path, n_iter, n, p, num_gfr, num_mcmc, dgp_num, snr, test_set_pct, num_threads)
27+
)
28+
}

0 commit comments

Comments
 (0)