From ce57b767ef5a2a5f7ca75599613c487ccc6da343 Mon Sep 17 00:00:00 2001 From: Drew Herren Date: Mon, 29 Sep 2025 16:46:27 -0500 Subject: [PATCH] Add AIR format file and initial debug script for stochtree's propagation of random seed --- air.toml | 9 ++++++ tools/debug/bart_random_seed.R | 55 ++++++++++++++++++++++++++++++++++ 2 files changed, 64 insertions(+) create mode 100644 air.toml create mode 100644 tools/debug/bart_random_seed.R diff --git a/air.toml b/air.toml new file mode 100644 index 00000000..22e43b99 --- /dev/null +++ b/air.toml @@ -0,0 +1,9 @@ +[format] +line-width = 80 +indent-width = 2 +indent-style = "space" +line-ending = "auto" +persistent-line-breaks = true +exclude = [] +default-exclude = true +skip = [] \ No newline at end of file diff --git a/tools/debug/bart_random_seed.R b/tools/debug/bart_random_seed.R new file mode 100644 index 00000000..c18bc615 --- /dev/null +++ b/tools/debug/bart_random_seed.R @@ -0,0 +1,55 @@ +# Load libraries +library(stochtree) + +# Generate data +random_seed <- 1234 +set.seed(random_seed) +n <- 500 +p <- 50 +X <- matrix(runif(n * p), ncol = p) +# fmt: skip +f_XW <- ( + ((0 <= X[, 1]) & (0.25 > X[, 1])) * (-7.5) + + ((0.25 <= X[, 1]) & (0.5 > X[, 1])) * (-2.5) + + ((0.5 <= X[, 1]) & (0.75 > X[, 1])) * (2.5) + + ((0.75 <= X[, 1]) & (1 > X[, 1])) * (7.5) +) +noise_sd <- 1 +y <- f_XW + rnorm(n, 0, noise_sd) + +# Split into train and test sets +test_set_pct <- 0.2 +n_test <- round(test_set_pct * n) +n_train <- n - n_test +test_inds <- sort(sample(1:n, n_test, replace = FALSE)) +train_inds <- (1:n)[!((1:n) %in% test_inds)] +X_test <- X[test_inds, ] +X_train <- X[train_inds, ] +y_test <- y[test_inds] +y_train <- y[train_inds] + +# Run BART model +general_params <- list(num_threads = 1, random_seed = random_seed) +bart_model <- bart( + X_train = X_train, + y_train = y_train, + X_test = X_test, + num_gfr = 100, + num_mcmc = 100, + general_params = general_params +) + +# # Save results +# write.csv( +# bart_model$y_hat_test, +# file = "tools/debug/seed_benchmark_y_hat.csv", +# row.names = FALSE +# ) + +# Read results and compare to our estimates +y_hat_test_benchmark <- as.matrix(read.csv( + "tools/debug/seed_benchmark_y_hat.csv" +)) + +# Compare results +sum(abs(y_hat_test_benchmark - bart_model$y_hat_test) > 1e-6)