|
| 1 | +# Load libraries |
| 2 | +library(stochtree) |
| 3 | + |
| 4 | +# Define DGPs |
| 5 | +dgp1 <- function(n, p, snr) { |
| 6 | + X <- matrix(rnorm(n*p), ncol = p) |
| 7 | + plm_term <- ( |
| 8 | + ((0 <= X[,1]) & (0.25 > X[,1])) * (-7.5*X[,2]) + |
| 9 | + ((0.25 <= X[,1]) & (0.5 > X[,1])) * (-2.5*X[,2]) + |
| 10 | + ((0.5 <= X[,1]) & (0.75 > X[,1])) * (2.5*X[,2]) + |
| 11 | + ((0.75 <= X[,1]) & (1 > X[,1])) * (7.5*X[,2]) |
| 12 | + ) |
| 13 | + trig_term <- ( |
| 14 | + 2*sin(X[,3]*2*pi) - |
| 15 | + 2*cos(X[,4]*2*pi) |
| 16 | + ) |
| 17 | + mu_x <- plm_term + trig_term |
| 18 | + pi_x <- 0.8*pnorm((3*mu_x/sd(mu_x)) - 0.5*X[,1]) + 0.05 + runif(n)/10 |
| 19 | + Z <- rbinom(n,1,pi_x) |
| 20 | + tau_x <- 1 + 2*X[,2]*X[,4] |
| 21 | + f_XZ <- mu_x + tau_x * Z |
| 22 | + noise_sd <- sd(f_XZ)/snr |
| 23 | + y <- f_XZ + rnorm(n, 0, noise_sd) |
| 24 | + return(list(covariates = X, treatment = Z, outcome = y, propensity = pi_x, |
| 25 | + prognostic_effect = mu_x, treatment_effect = tau_x, |
| 26 | + conditional_mean = f_XZ, rfx_group_ids = NULL, rfx_basis = NULL)) |
| 27 | +} |
| 28 | +dgp2 <- function(n, p, snr) { |
| 29 | + X <- matrix(runif(n*p), ncol = p) |
| 30 | + pi_x <- cbind(0.125 + 0.75 * X[, 1], 0.875 - 0.75 * X[, 2]) |
| 31 | + mu_x <- pi_x[, 1] * 5 + pi_x[, 2] * 2 + 2 * X[, 3] |
| 32 | + tau_x <- cbind(X[, 2], X[, 3]) * 2 |
| 33 | + Z <- matrix(NA_real_, nrow = n, ncol = ncol(pi_x)) |
| 34 | + for (i in 1:ncol(pi_x)) { |
| 35 | + Z[, i] <- rbinom(n, 1, pi_x[, i]) |
| 36 | + } |
| 37 | + f_XZ <- mu_x + rowSums(Z * tau_x) |
| 38 | + noise_sd <- sd(f_XZ)/snr |
| 39 | + y <- f_XZ + rnorm(n, 0, noise_sd) |
| 40 | + return(list(covariates = X, treatment = Z, outcome = y, propensity = pi_x, |
| 41 | + prognostic_effect = mu_x, treatment_effect = tau_x, |
| 42 | + conditional_mean = f_XZ, rfx_group_ids = NULL, rfx_basis = NULL)) |
| 43 | +} |
| 44 | +dgp3 <- function(n, p, snr) { |
| 45 | + X <- matrix(rnorm(n*p), ncol = p) |
| 46 | + plm_term <- ( |
| 47 | + ((0 <= X[,1]) & (0.25 > X[,1])) * (-7.5*X[,2]) + |
| 48 | + ((0.25 <= X[,1]) & (0.5 > X[,1])) * (-2.5*X[,2]) + |
| 49 | + ((0.5 <= X[,1]) & (0.75 > X[,1])) * (2.5*X[,2]) + |
| 50 | + ((0.75 <= X[,1]) & (1 > X[,1])) * (7.5*X[,2]) |
| 51 | + ) |
| 52 | + trig_term <- ( |
| 53 | + 2*sin(X[,3]*2*pi) - |
| 54 | + 2*cos(X[,4]*2*pi) |
| 55 | + ) |
| 56 | + mu_x <- plm_term + trig_term |
| 57 | + pi_x <- 0.8*pnorm((3*mu_x/sd(mu_x)) - 0.5*X[,1]) + 0.05 + runif(n)/10 |
| 58 | + Z <- rbinom(n,1,pi_x) |
| 59 | + tau_x <- 1 + 2*X[,2]*X[,4] |
| 60 | + rfx_group_ids <- sample(1:3, size = n, replace = T) |
| 61 | + rfx_coefs <- t(matrix(c(-5, -3, -1, 5, 3, 1), nrow=2, byrow=TRUE)) |
| 62 | + rfx_basis <- cbind(1, runif(n, -1, 1)) |
| 63 | + rfx_term <- rowSums(rfx_coefs[rfx_group_ids,] * rfx_basis) |
| 64 | + f_XZ <- mu_x + tau_x * Z + rfx_term |
| 65 | + noise_sd <- sd(f_XZ)/snr |
| 66 | + y <- f_XZ + rnorm(n, 0, noise_sd) |
| 67 | + return(list(covariates = X, treatment = Z, outcome = y, propensity = pi_x, |
| 68 | + prognostic_effect = mu_x, treatment_effect = tau_x, |
| 69 | + conditional_mean = f_XZ, rfx_group_ids = rfx_group_ids, rfx_basis = rfx_basis)) |
| 70 | +} |
| 71 | +dgp4 <- function(n, p, snr) { |
| 72 | + X <- matrix(runif(n*p), ncol = p) |
| 73 | + pi_x <- cbind(0.125 + 0.75 * X[, 1], 0.875 - 0.75 * X[, 2]) |
| 74 | + mu_x <- pi_x[, 1] * 5 + pi_x[, 2] * 2 + 2 * X[, 3] |
| 75 | + tau_x <- cbind(X[, 2], X[, 3]) * 2 |
| 76 | + Z <- matrix(NA_real_, nrow = n, ncol = ncol(pi_x)) |
| 77 | + for (i in 1:ncol(pi_x)) { |
| 78 | + Z[, i] <- rbinom(n, 1, pi_x[, i]) |
| 79 | + } |
| 80 | + rfx_group_ids <- sample(1:3, size = n, replace = T) |
| 81 | + rfx_coefs <- t(matrix(c(-5, -3, -1, 5, 3, 1), nrow=2, byrow=TRUE)) |
| 82 | + rfx_basis <- cbind(1, runif(n, -1, 1)) |
| 83 | + rfx_term <- rowSums(rfx_coefs[rfx_group_ids,] * rfx_basis) |
| 84 | + f_XZ <- mu_x + rowSums(Z * tau_x) + rfx_term |
| 85 | + noise_sd <- sd(f_XZ)/snr |
| 86 | + y <- f_XZ + rnorm(n, 0, noise_sd) |
| 87 | + return(list(covariates = X, treatment = Z, outcome = y, propensity = pi_x, |
| 88 | + prognostic_effect = mu_x, treatment_effect = tau_x, |
| 89 | + conditional_mean = f_XZ, rfx_group_ids = rfx_group_ids, rfx_basis = rfx_basis)) |
| 90 | +} |
| 91 | + |
| 92 | +# Test / train split utilities |
| 93 | +compute_test_train_indices <- function(n, test_set_pct) { |
| 94 | + n_test <- round(test_set_pct*n) |
| 95 | + n_train <- n - n_test |
| 96 | + test_inds <- sort(sample(1:n, n_test, replace = FALSE)) |
| 97 | + train_inds <- (1:n)[!((1:n) %in% test_inds)] |
| 98 | + return(list(test_inds = test_inds, train_inds = train_inds)) |
| 99 | +} |
| 100 | +subset_data <- function(data, subset_inds) { |
| 101 | + if (is.matrix(data)) { |
| 102 | + return(data[subset_inds,]) |
| 103 | + } else { |
| 104 | + return(data[subset_inds]) |
| 105 | + } |
| 106 | +} |
| 107 | + |
| 108 | +# Capture command line arguments |
| 109 | +args <- commandArgs(trailingOnly = T) |
| 110 | +if (length(args) > 0){ |
| 111 | + n_iter <- as.integer(args[1]) |
| 112 | + n <- as.integer(args[2]) |
| 113 | + p <- as.integer(args[3]) |
| 114 | + num_gfr <- as.integer(args[4]) |
| 115 | + num_mcmc <- as.integer(args[5]) |
| 116 | + dgp_num <- as.integer(args[6]) |
| 117 | + snr <- as.numeric(args[7]) |
| 118 | + test_set_pct <- as.numeric(args[8]) |
| 119 | + num_threads <- as.integer(args[9]) |
| 120 | +} else{ |
| 121 | + # Default arguments |
| 122 | + n_iter <- 5 |
| 123 | + n <- 1000 |
| 124 | + p <- 5 |
| 125 | + num_gfr <- 10 |
| 126 | + num_mcmc <- 100 |
| 127 | + dgp_num <- 1 |
| 128 | + snr <- 2.0 |
| 129 | + test_set_pct <- 0.2 |
| 130 | + num_threads <- -1 |
| 131 | +} |
| 132 | +cat("n_iter = ", n_iter, "\nn = ", n, "\np = ", p, "\nnum_gfr = ", num_gfr, |
| 133 | + "\nnum_mcmc = ", num_mcmc, "\ndgp_num = ", dgp_num, "\nsnr = ", snr, |
| 134 | + "\ntest_set_pct = ", test_set_pct, "\nnum_threads = ", num_threads, "\n", sep = "") |
| 135 | + |
| 136 | +# Run the performance evaluation |
| 137 | +results <- matrix(NA, nrow = n_iter, ncol = 6) |
| 138 | +colnames(results) <- c("iter", "outcome_rmse", "outcome_coverage", "treatment_effect_rmse", "treatment_effect_coverage", "runtime") |
| 139 | +for (i in 1:n_iter) { |
| 140 | + # Generate data |
| 141 | + if (dgp_num == 1) { |
| 142 | + data_list <- dgp1(n = n, p = p, snr = snr) |
| 143 | + } else if (dgp_num == 2) { |
| 144 | + data_list <- dgp2(n = n, p = p, snr = snr) |
| 145 | + } else if (dgp_num == 3) { |
| 146 | + data_list <- dgp3(n = n, p = p, snr = snr) |
| 147 | + } else if (dgp_num == 4) { |
| 148 | + data_list <- dgp4(n = n, p = p, snr = snr) |
| 149 | + } else { |
| 150 | + stop("Invalid DGP input") |
| 151 | + } |
| 152 | + covariates <- data_list[['covariates']] |
| 153 | + treatment <- data_list[['treatment']] |
| 154 | + propensity <- data_list[['propensity']] |
| 155 | + prognostic_effect <- data_list[['prognostic_effect']] |
| 156 | + treatment_effect <- data_list[['treatment_effect']] |
| 157 | + conditional_mean <- data_list[['conditional_mean']] |
| 158 | + outcome <- data_list[['outcome']] |
| 159 | + rfx_group_ids <- data_list[['rfx_group_ids']] |
| 160 | + rfx_basis <- data_list[['rfx_basis']] |
| 161 | + if (dgp_num %in% c(2,4)) { |
| 162 | + has_multivariate_treatment <- T |
| 163 | + } else { |
| 164 | + has_multivariate_treatment <- F |
| 165 | + } |
| 166 | + |
| 167 | + # Split into train / test sets |
| 168 | + subset_inds_list <- compute_test_train_indices(n, test_set_pct) |
| 169 | + test_inds <- subset_inds_list$test_inds |
| 170 | + train_inds <- subset_inds_list$train_inds |
| 171 | + covariates_train <- subset_data(covariates, train_inds) |
| 172 | + covariates_test <- subset_data(covariates, test_inds) |
| 173 | + treatment_train <- subset_data(treatment, train_inds) |
| 174 | + treatment_test <- subset_data(treatment, test_inds) |
| 175 | + propensity_train <- subset_data(propensity, train_inds) |
| 176 | + propensity_test <- subset_data(propensity, test_inds) |
| 177 | + outcome_train <- subset_data(outcome, train_inds) |
| 178 | + outcome_test <- subset_data(outcome, test_inds) |
| 179 | + prognostic_effect_train <- subset_data(prognostic_effect, train_inds) |
| 180 | + prognostic_effect_test <- subset_data(prognostic_effect, test_inds) |
| 181 | + treatment_effect_train <- subset_data(treatment_effect, train_inds) |
| 182 | + treatment_effect_test <- subset_data(treatment_effect, test_inds) |
| 183 | + conditional_mean_train <- subset_data(conditional_mean, train_inds) |
| 184 | + conditional_mean_test <- subset_data(conditional_mean, test_inds) |
| 185 | + has_rfx <- !is.null(rfx_group_ids) |
| 186 | + if (has_rfx) { |
| 187 | + rfx_group_ids_train <- subset_data(rfx_group_ids, train_inds) |
| 188 | + rfx_group_ids_test <- subset_data(rfx_group_ids, test_inds) |
| 189 | + rfx_basis_train <- subset_data(rfx_basis, train_inds) |
| 190 | + rfx_basis_test <- subset_data(rfx_basis, test_inds) |
| 191 | + } else { |
| 192 | + rfx_group_ids_train <- NULL |
| 193 | + rfx_group_ids_test <- NULL |
| 194 | + rfx_basis_train <- NULL |
| 195 | + rfx_basis_test <- NULL |
| 196 | + } |
| 197 | + |
| 198 | + # Run (and time) BCF |
| 199 | + bcf_timing <- system.time({ |
| 200 | + # Sample BCF model |
| 201 | + general_params <- list(num_threads = num_threads, adaptive_coding = F) |
| 202 | + prognostic_forest_params <- list(sample_sigma2_leaf = F) |
| 203 | + treatment_effect_forest_params <- list(sample_sigma2_leaf = F) |
| 204 | + bcf_model <- stochtree::bcf( |
| 205 | + X_train = covariates_train, Z_train = treatment_train, |
| 206 | + propensity_train = propensity_train, y_train = outcome_train, |
| 207 | + rfx_group_ids_train = rfx_group_ids_train, rfx_basis_train = rfx_basis_train, |
| 208 | + num_gfr = num_gfr, num_mcmc = num_mcmc, general_params = general_params, |
| 209 | + prognostic_forest_params = prognostic_forest_params, |
| 210 | + treatment_effect_forest_params = treatment_effect_forest_params |
| 211 | + ) |
| 212 | + |
| 213 | + # Predict on the test set |
| 214 | + test_preds <- predict( |
| 215 | + bcf_model, X = covariates_test, Z = treatment_test, propensity = propensity_test, |
| 216 | + rfx_group_ids = rfx_group_ids_test, rfx_basis = rfx_basis_test |
| 217 | + ) |
| 218 | + })[3] |
| 219 | + |
| 220 | + # Compute test set evals |
| 221 | + y_hat_posterior <- test_preds$y_hat |
| 222 | + y_hat_posterior_mean <- rowMeans(y_hat_posterior) |
| 223 | + tau_hat_posterior <- test_preds$tau_hat |
| 224 | + if (has_multivariate_treatment) tau_hat_posterior_mean <- apply(tau_hat_posterior, c(1,2), mean) |
| 225 | + else tau_hat_posterior_mean <- apply(tau_hat_posterior, 1, mean) |
| 226 | + y_hat_rmse_test <- sqrt(mean((y_hat_posterior_mean - outcome_test)^2)) |
| 227 | + tau_hat_rmse_test <- sqrt(mean((tau_hat_posterior_mean - treatment_effect_test)^2)) |
| 228 | + y_hat_posterior_quantile_025 <- apply(y_hat_posterior, 1, function(x) quantile(x, 0.025)) |
| 229 | + y_hat_posterior_quantile_975 <- apply(y_hat_posterior, 1, function(x) quantile(x, 0.975)) |
| 230 | + if (has_multivariate_treatment) { |
| 231 | + tau_hat_posterior_quantile_025 <- apply(tau_hat_posterior, c(1,2), function(x) quantile(x, 0.025)) |
| 232 | + tau_hat_posterior_quantile_975 <- apply(tau_hat_posterior, c(1,2), function(x) quantile(x, 0.975)) |
| 233 | + } else { |
| 234 | + tau_hat_posterior_quantile_025 <- apply(tau_hat_posterior, 1, function(x) quantile(x, 0.025)) |
| 235 | + tau_hat_posterior_quantile_975 <- apply(tau_hat_posterior, 1, function(x) quantile(x, 0.975)) |
| 236 | + } |
| 237 | + y_hat_covered <- rep(NA, nrow(y_hat_posterior)) |
| 238 | + for (j in 1:nrow(y_hat_posterior)) { |
| 239 | + y_hat_covered[j] <- ( |
| 240 | + (conditional_mean_test[j] >= y_hat_posterior_quantile_025[j]) & |
| 241 | + (conditional_mean_test[j] <= y_hat_posterior_quantile_975[j]) |
| 242 | + ) |
| 243 | + } |
| 244 | + y_hat_coverage_test <- mean(y_hat_covered) |
| 245 | + if (has_multivariate_treatment) { |
| 246 | + tau_hat_covered <- matrix(NA_real_, nrow(tau_hat_posterior_mean), ncol(tau_hat_posterior_mean)) |
| 247 | + for (j in 1:nrow(tau_hat_covered)) { |
| 248 | + for (k in 1:ncol(tau_hat_covered)) { |
| 249 | + tau_hat_covered[j,k] <- ( |
| 250 | + (treatment_effect_test[j,k] >= tau_hat_posterior_quantile_025[j,k]) & |
| 251 | + (treatment_effect_test[j,k] <= tau_hat_posterior_quantile_975[j,k]) |
| 252 | + ) |
| 253 | + } |
| 254 | + } |
| 255 | + } else { |
| 256 | + tau_hat_covered <- rep(NA, nrow(tau_hat_posterior)) |
| 257 | + for (j in 1:nrow(tau_hat_posterior)) { |
| 258 | + tau_hat_covered[j] <- ( |
| 259 | + (treatment_effect_test[j] >= tau_hat_posterior_quantile_025[j]) & |
| 260 | + (treatment_effect_test[j] <= tau_hat_posterior_quantile_025[j]) |
| 261 | + ) |
| 262 | + } |
| 263 | + } |
| 264 | + tau_hat_coverage_test <- mean(tau_hat_covered) |
| 265 | + |
| 266 | + # Store evaluations |
| 267 | + results[i,] <- c(i, y_hat_rmse_test, y_hat_coverage_test, tau_hat_rmse_test, tau_hat_coverage_test, bcf_timing) |
| 268 | +} |
| 269 | + |
| 270 | +# Wrangle and save results to CSV |
| 271 | +results_df <- data.frame( |
| 272 | + cbind(n, p, num_gfr, num_mcmc, dgp_num, snr, test_set_pct, num_threads, results) |
| 273 | +) |
| 274 | +snr_rounded <- as.integer(snr) |
| 275 | +test_set_pct_rounded <- as.integer(test_set_pct*100) |
| 276 | +num_threads_clean <- ifelse(num_threads < 0, 0, num_threads) |
| 277 | +filename <- paste( |
| 278 | + "stochtree", "bcf", "r", "n", n, "p", p, "num_gfr", num_gfr, "num_mcmc", num_mcmc, |
| 279 | + "dgp_num", dgp_num, "snr", snr_rounded, "test_set_pct", test_set_pct_rounded, |
| 280 | + "num_threads", num_threads_clean, sep = "_" |
| 281 | +) |
| 282 | +filename_full <- paste0("tools/regression/bcf/stochtree_bcf_r_results/", filename, ".csv") |
| 283 | +write.csv(x = results_df, file = filename_full, row.names = F) |
0 commit comments