From 250f240211f81adc6e50d4094d52579e2700c21e Mon Sep 17 00:00:00 2001 From: Drew Herren Date: Mon, 20 Jan 2025 01:02:03 -0600 Subject: [PATCH 1/3] Added internal BART serialization to BCF --- stochtree/bcf.py | 13 +++++++++++++ 1 file changed, 13 insertions(+) diff --git a/stochtree/bcf.py b/stochtree/bcf.py index 5fd447d3..8c7ca21c 100644 --- a/stochtree/bcf.py +++ b/stochtree/bcf.py @@ -1402,6 +1402,7 @@ def to_json(self) -> str: bcf_json.add_scalar("num_samples", self.num_samples) bcf_json.add_boolean("adaptive_coding", self.adaptive_coding) bcf_json.add_string("propensity_covariate", self.propensity_covariate) + bcf_json.add_boolean("internal_propensity_model", self.internal_propensity_model) # Add parameter samples if self.sample_sigma_global: @@ -1414,6 +1415,11 @@ def to_json(self) -> str: bcf_json.add_numeric_vector("b0_samples", self.b0_samples, "parameters") bcf_json.add_numeric_vector("b1_samples", self.b1_samples, "parameters") + # Add propensity model (if it exists) + if self.internal_propensity_model: + bart_propensity_string = self.bart_propensity_model.to_json() + bcf_json.add_string("bart_propensity_model", bart_propensity_string) + return bcf_json.return_json_string() def from_json(self, json_string: str) -> None: @@ -1457,6 +1463,7 @@ def from_json(self, json_string: str) -> None: self.num_samples = int(bcf_json.get_scalar("num_samples")) self.adaptive_coding = bcf_json.get_boolean("adaptive_coding") self.propensity_covariate = bcf_json.get_string("propensity_covariate") + self.internal_propensity_model = bcf_json.get_boolean("internal_propensity_model") # Unpack parameter samples if self.sample_sigma_global: @@ -1469,6 +1476,12 @@ def from_json(self, json_string: str) -> None: self.b1_samples = bcf_json.get_numeric_vector("b1_samples", "parameters") self.b0_samples = bcf_json.get_numeric_vector("b0_samples", "parameters") + # Unpack internal propensity model + if self.internal_propensity_model: + bart_propensity_string = bcf_json.get_string("bart_propensity_model") + self.bart_propensity_model = BARTModel() + self.bart_propensity_model.from_json(bart_propensity_string) + # Mark the deserialized model as "sampled" self.sampled = True From 3640d0fc93452e02250e4a508b83eb5ea8c63e61 Mon Sep 17 00:00:00 2001 From: Drew Herren Date: Mon, 20 Jan 2025 01:05:54 -0600 Subject: [PATCH 2/3] Added BCF propensity serialization unit test --- test/python/test_json.py | 36 ++++++++++++++++++++++++++++++++++++ 1 file changed, 36 insertions(+) diff --git a/test/python/test_json.py b/test/python/test_json.py index 55d30286..2bd71cd8 100644 --- a/test/python/test_json.py +++ b/test/python/test_json.py @@ -242,3 +242,39 @@ def test_bcf_string(self): np.testing.assert_almost_equal(y_hat_orig, y_hat_reloaded) np.testing.assert_almost_equal(tau_hat_orig, tau_hat_reloaded) np.testing.assert_almost_equal(mu_hat_orig, mu_hat_reloaded) + + def test_bcf_propensity_string(self): + # RNG + random_seed = 1234 + rng = np.random.default_rng(random_seed) + + # Generate covariates and basis + n = 100 + p_X = 5 + X = rng.uniform(0, 1, (n, p_X)) + pi_X = 0.25 + 0.5*X[:,0] + Z = rng.binomial(1, pi_X, n).astype(float) + + # Define the outcome mean functions (prognostic and treatment effects) + mu_X = pi_X*5 + tau_X = X[:,1]*2 + + # Generate outcome + epsilon = rng.normal(0, 1, n) + y = mu_X + tau_X*Z + epsilon + + # Run BCF without passing propensity scores (so an internal propensity model must be constructed) + bcf_orig = BCFModel() + bcf_orig.sample(X_train=X, Z_train=Z, y_train=y, num_gfr=10, num_mcmc=10) + + # Extract predictions from the sampler + mu_hat_orig, tau_hat_orig, y_hat_orig = bcf_orig.predict(X, Z, pi_X) + + # "Round-trip" the model to JSON string and back and check that the predictions agree + bcf_json_string = bcf_orig.to_json() + bcf_reloaded = BCFModel() + bcf_reloaded.from_json(bcf_json_string) + mu_hat_reloaded, tau_hat_reloaded, y_hat_reloaded = bcf_reloaded.predict(X, Z, pi_X) + np.testing.assert_almost_equal(y_hat_orig, y_hat_reloaded) + np.testing.assert_almost_equal(tau_hat_orig, tau_hat_reloaded) + np.testing.assert_almost_equal(mu_hat_orig, mu_hat_reloaded) From 3a851942e48f10233bc1493e5a831ecabf2bd744 Mon Sep 17 00:00:00 2001 From: Drew Herren Date: Mon, 20 Jan 2025 01:51:11 -0600 Subject: [PATCH 3/3] Updated R BCF serialization and unit tests --- R/bcf.R | 39 ++++++- test/R/testthat/test-serialization.R | 162 +++++++++++++++++++++++++++ 2 files changed, 199 insertions(+), 2 deletions(-) create mode 100644 test/R/testthat/test-serialization.R diff --git a/R/bcf.R b/R/bcf.R index 8bab8a06..bc5b9d5f 100644 --- a/R/bcf.R +++ b/R/bcf.R @@ -613,11 +613,13 @@ bcf <- function(X_train, Z_train, y_train, pi_train = NULL, group_ids_train = NU } # Estimate if pre-estimated propensity score is not provided + internal_propensity_model <- F if ((is.null(pi_train)) && (propensity_covariate != "none")) { + internal_propensity_model <- T # Estimate using the last of several iterations of GFR BART num_burnin <- 10 num_total <- 50 - bart_model_propensity <- bart(X_train = X_train_raw, y_train = as.numeric(Z_train), X_test = X_test_raw, + bart_model_propensity <- bart(X_train = X_train, y_train = as.numeric(Z_train), X_test = X_test_raw, num_gfr = num_total, num_burnin = 0, num_mcmc = 0) pi_train <- rowMeans(bart_model_propensity$y_hat_train[,(num_burnin+1):num_total]) if ((is.null(dim(pi_train))) && (!is.null(pi_train))) { @@ -1233,6 +1235,7 @@ bcf <- function(X_train, Z_train, y_train, pi_train = NULL, group_ids_train = NU "propensity_covariate" = propensity_covariate, "binary_treatment" = binary_treatment, "adaptive_coding" = adaptive_coding, + "internal_propensity_model" = internal_propensity_model, "num_samples" = num_retained_samples, "num_gfr" = num_gfr, "num_burnin" = num_burnin, @@ -1277,6 +1280,9 @@ bcf <- function(X_train, Z_train, y_train, pi_train = NULL, group_ids_train = NU result[["rfx_unique_group_ids"]] = levels(group_ids_factor) } if ((has_rfx_test) && (has_test)) result[["rfx_preds_test"]] = rfx_preds_test + if (internal_propensity_model) { + result[["bart_propensity_model"]] = bart_model_propensity + } class(result) <- "bcf" return(result) @@ -1366,7 +1372,11 @@ predict.bcf <- function(bcf, X_test, Z_test, pi_test = NULL, group_ids_test = NU # Data checks if ((bcf$model_params$propensity_covariate != "none") && (is.null(pi_test))) { - stop("pi_test must be provided for this model") + if (!bcf$model_params$internal_propensity_model) { + stop("pi_test must be provided for this model") + } + # Compute propensity score using the internal bart model + pi_test <- rowMeans(predict(bcf$bart_propensity_model, X_test)$y_hat) } if (nrow(X_test) != nrow(Z_test)) { stop("X_test and Z_test must have the same number of rows") @@ -1662,6 +1672,7 @@ convertBCFModelToJson <- function(object){ jsonobj$add_boolean("has_rfx_basis", object$model_params$has_rfx_basis) jsonobj$add_scalar("num_rfx_basis", object$model_params$num_rfx_basis) jsonobj$add_boolean("adaptive_coding", object$model_params$adaptive_coding) + jsonobj$add_boolean("internal_propensity_model", object$model_params$internal_propensity_model) jsonobj$add_scalar("num_gfr", object$model_params$num_gfr) jsonobj$add_scalar("num_burnin", object$model_params$num_burnin) jsonobj$add_scalar("num_mcmc", object$model_params$num_mcmc) @@ -1689,6 +1700,14 @@ convertBCFModelToJson <- function(object){ jsonobj$add_string_vector("rfx_unique_group_ids", object$rfx_unique_group_ids) } + # Add propensity model (if it exists) + if (object$model_params$internal_propensity_model) { + bart_propensity_string <- saveBARTModelToJsonString( + object$bart_propensity_model + ) + jsonobj$add_string("bart_propensity_model", bart_propensity_string) + } + return(jsonobj) } @@ -1962,6 +1981,7 @@ createBCFModelFromJson <- function(json_object){ model_params[["has_rfx_basis"]] <- json_object$get_boolean("has_rfx_basis") model_params[["num_rfx_basis"]] <- json_object$get_scalar("num_rfx_basis") model_params[["adaptive_coding"]] <- json_object$get_boolean("adaptive_coding") + model_params[["internal_propensity_model"]] <- json_object$get_boolean("internal_propensity_model") model_params[["num_gfr"]] <- json_object$get_scalar("num_gfr") model_params[["num_burnin"]] <- json_object$get_scalar("num_burnin") model_params[["num_mcmc"]] <- json_object$get_scalar("num_mcmc") @@ -1990,6 +2010,14 @@ createBCFModelFromJson <- function(json_object){ output[["rfx_samples"]] <- loadRandomEffectSamplesJson(json_object, 0) } + # Unpack propensity model (if it exists) + if (model_params[["internal_propensity_model"]]) { + bart_propensity_string <- json_object$get_string("bart_propensity_model") + output[["bart_propensity_model"]] <- createBARTModelFromJsonString( + bart_propensity_string + ) + } + class(output) <- "bcf" return(output) } @@ -2229,6 +2257,12 @@ createBCFModelFromCombinedJsonString <- function(json_string_list){ for (i in 1:length(json_string_list)) { json_string <- json_string_list[[i]] json_object_list[[i]] <- createCppJsonString(json_string) + # Add runtime check for separately serialized propensity models + # We don't support merging BCF models with independent propensity models + # this way at the moment + if (json_object_list[[i]]$get_boolean("internal_propensity_model")) { + stop("Combining separate BCF models with cached internal propensity models is currently unsupported. To make this work, please first train a propensity model and then pass the propensities as data to the separate BCF models before sampling.") + } } # For scalar / preprocessing details which aren't sample-dependent, @@ -2279,6 +2313,7 @@ createBCFModelFromCombinedJsonString <- function(json_string_list){ model_params[["num_chains"]] <- json_object_default$get_scalar("num_chains") model_params[["keep_every"]] <- json_object_default$get_scalar("keep_every") model_params[["adaptive_coding"]] <- json_object_default$get_boolean("adaptive_coding") + model_params[["internal_propensity_model"]] <- json_object_default$get_boolean("internal_propensity_model") # Combine values that are sample-specific for (i in 1:length(json_object_list)) { diff --git a/test/R/testthat/test-serialization.R b/test/R/testthat/test-serialization.R new file mode 100644 index 00000000..fcce0bb3 --- /dev/null +++ b/test/R/testthat/test-serialization.R @@ -0,0 +1,162 @@ +test_that("BART Serialization", { + skip_on_cran() + + # Generate simulated data + n <- 100 + p <- 5 + X <- matrix(runif(n*p), ncol = p) + f_XW <- ( + ((0 <= X[,1]) & (0.25 > X[,1])) * (-7.5) + + ((0.25 <= X[,1]) & (0.5 > X[,1])) * (-2.5) + + ((0.5 <= X[,1]) & (0.75 > X[,1])) * (2.5) + + ((0.75 <= X[,1]) & (1 > X[,1])) * (7.5) + ) + noise_sd <- 1 + y <- f_XW + rnorm(n, 0, noise_sd) + test_set_pct <- 0.2 + n_test <- round(test_set_pct*n) + n_train <- n - n_test + test_inds <- sort(sample(1:n, n_test, replace = FALSE)) + train_inds <- (1:n)[!((1:n) %in% test_inds)] + X_test <- X[test_inds,] + X_train <- X[train_inds,] + y_test <- y[test_inds] + y_train <- y[train_inds] + + # Sample a BART model + general_param_list <- list(num_chains = 1, keep_every = 1) + bart_model <- bart(X_train = X_train, y_train = y_train, + num_gfr = 0, num_burnin = 10, num_mcmc = 10, + general_params = general_param_list) + y_hat_orig <- rowMeans(predict(bart_model, X_test)$y_hat) + + # Save to JSON + bart_json_string <- saveBARTModelToJsonString(bart_model) + + # Reload as a BART model + bart_model_roundtrip <- createBARTModelFromJsonString(bart_json_string) + + # Predict from the roundtrip BART model + y_hat_reloaded <- rowMeans(predict(bart_model_roundtrip, X_test)$y_hat) + + # Assertion + expect_equal(y_hat_orig, y_hat_reloaded) +}) + +test_that("BCF Serialization", { + skip_on_cran() + + n <- 500 + x1 <- runif(n) + x2 <- runif(n) + x3 <- runif(n) + x4 <- runif(n) + x5 <- runif(n) + X <- cbind(x1,x2,x3,x4,x5) + p <- ncol(X) + pi_x <- 0.25 + 0.5*X[,1] + mu_x <- pi_x * 5 + tau_x <- X[,2] * 2 + Z <- rbinom(n,1,pi_x) + E_XZ <- mu_x + Z*tau_x + y <- E_XZ + rnorm(n, 0, 1) + test_set_pct <- 0.2 + n_test <- round(test_set_pct*n) + n_train <- n - n_test + test_inds <- sort(sample(1:n, n_test, replace = FALSE)) + train_inds <- (1:n)[!((1:n) %in% test_inds)] + X_test <- X[test_inds,] + X_train <- X[train_inds,] + pi_test <- pi_x[test_inds] + pi_train <- pi_x[train_inds] + Z_test <- Z[test_inds] + Z_train <- Z[train_inds] + y_test <- y[test_inds] + y_train <- y[train_inds] + mu_test <- mu_x[test_inds] + mu_train <- mu_x[train_inds] + tau_test <- tau_x[test_inds] + tau_train <- tau_x[train_inds] + + # Sample a BCF model + bcf_model <- bcf(X_train = X_train, Z_train = Z_train, y_train = y_train, + pi_train = pi_train, num_gfr = 100, num_burnin = 0, num_mcmc = 100) + bcf_preds_orig <- predict(bcf_model, X_test, Z_test, pi_test) + mu_hat_orig <- rowMeans(bcf_preds_orig[["mu_hat"]]) + tau_hat_orig <- rowMeans(bcf_preds_orig[["tau_hat"]]) + y_hat_orig <- rowMeans(bcf_preds_orig[["y_hat"]]) + + # Save to JSON + bcf_json_string <- saveBCFModelToJsonString(bcf_model) + + # Reload as a BCF model + bcf_model_roundtrip <- createBCFModelFromJsonString(bcf_json_string) + + # Predict from the roundtrip BCF model + bcf_preds_reloaded <- predict(bcf_model_roundtrip, X_test, Z_test, pi_test) + mu_hat_reloaded <- rowMeans(bcf_preds_reloaded[["mu_hat"]]) + tau_hat_reloaded <- rowMeans(bcf_preds_reloaded[["tau_hat"]]) + y_hat_reloaded <- rowMeans(bcf_preds_reloaded[["y_hat"]]) + + # Assertion + expect_equal(y_hat_orig, y_hat_reloaded) +}) + +test_that("BCF Serialization (no propensity)", { + skip_on_cran() + + n <- 500 + x1 <- runif(n) + x2 <- runif(n) + x3 <- runif(n) + x4 <- runif(n) + x5 <- runif(n) + X <- cbind(x1,x2,x3,x4,x5) + p <- ncol(X) + pi_x <- 0.25 + 0.5*X[,1] + mu_x <- pi_x * 5 + tau_x <- X[,2] * 2 + Z <- rbinom(n,1,pi_x) + E_XZ <- mu_x + Z*tau_x + y <- E_XZ + rnorm(n, 0, 1) + test_set_pct <- 0.2 + n_test <- round(test_set_pct*n) + n_train <- n - n_test + test_inds <- sort(sample(1:n, n_test, replace = FALSE)) + train_inds <- (1:n)[!((1:n) %in% test_inds)] + X_test <- X[test_inds,] + X_train <- X[train_inds,] + pi_test <- pi_x[test_inds] + pi_train <- pi_x[train_inds] + Z_test <- Z[test_inds] + Z_train <- Z[train_inds] + y_test <- y[test_inds] + y_train <- y[train_inds] + mu_test <- mu_x[test_inds] + mu_train <- mu_x[train_inds] + tau_test <- tau_x[test_inds] + tau_train <- tau_x[train_inds] + + # Sample a BCF model + bcf_model <- bcf(X_train = X_train, Z_train = Z_train, y_train = y_train, + num_gfr = 100, num_burnin = 0, num_mcmc = 100) + bcf_preds_orig <- predict(bcf_model, X_test, Z_test) + mu_hat_orig <- rowMeans(bcf_preds_orig[["mu_hat"]]) + tau_hat_orig <- rowMeans(bcf_preds_orig[["tau_hat"]]) + y_hat_orig <- rowMeans(bcf_preds_orig[["y_hat"]]) + + # Save to JSON + bcf_json_string <- saveBCFModelToJsonString(bcf_model) + + # Reload as a BCF model + bcf_model_roundtrip <- createBCFModelFromJsonString(bcf_json_string) + + # Predict from the roundtrip BCF model + bcf_preds_reloaded <- predict(bcf_model_roundtrip, X_test, Z_test) + mu_hat_reloaded <- rowMeans(bcf_preds_reloaded[["mu_hat"]]) + tau_hat_reloaded <- rowMeans(bcf_preds_reloaded[["tau_hat"]]) + y_hat_reloaded <- rowMeans(bcf_preds_reloaded[["y_hat"]]) + + # Assertion + expect_equal(y_hat_orig, y_hat_reloaded) +}) \ No newline at end of file