From b90ceb01341ab54b66b95c5c9a6471444d70d99d Mon Sep 17 00:00:00 2001
From: Drew Herren
Date: Thu, 14 Nov 2024 21:26:11 -0600
Subject: [PATCH] Add standardization option to BART / BCF for R / Python
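
Both the R and Python interfaces default to `standardize = TRUE`, which
matches the previous behavior; with `standardize = FALSE` the stored
offset and scale are fixed at 0 and 1, so the outcome enters the sampler
uncentered and unscaled (BART's separate `variance_scale` multiplier is
unaffected). A minimal R sketch (simulated data for illustration only;
assumes the optional parameters are passed as a list argument named
`params`, consumed by `preprocessBartParams()`):

    library(stochtree)
    # Simulated data for illustration
    n <- 500
    X <- matrix(runif(n * 10), ncol = 10)
    y <- sin(4 * pi * X[, 1]) + rnorm(n, sd = 0.1)
    # Fit without outcome standardization (new option in this patch)
    bart_model <- bart(X_train = X, y_train = y,
                       params = list(standardize = FALSE))
    # The model object stores the offset / scale as 0 / 1
    bart_model$model_params$outcome_mean
    bart_model$model_params$outcome_scale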
---
 R/bart.R                   | 16 ++++++++++++++--
 R/bcf.R                    | 16 +++++++++++++---
 R/utils.R                  |  6 ++++--
 stochtree/bart.py          | 11 +++++++++--
 stochtree/bcf.py           |  9 +++++++--
 stochtree/preprocessing.py |  6 ++++--
 6 files changed, 51 insertions(+), 13 deletions(-)

diff --git a/R/bart.R b/R/bart.R
index a6668ce7..1e058765 100644
--- a/R/bart.R
+++ b/R/bart.R
@@ -41,6 +41,7 @@
 #' - `sample_sigma_global` Whether or not to update the `sigma^2` global error variance parameter based on `IG(a_global, b_global)`. Default: `TRUE`.
 #' - `keep_burnin` Whether or not "burnin" samples should be included in cached predictions. Default `FALSE`. Ignored if `num_mcmc = 0`.
 #' - `keep_gfr` Whether or not "grow-from-root" samples should be included in cached predictions. Default `TRUE`. Ignored if `num_mcmc = 0`.
+#' - `standardize` Whether or not to standardize the outcome (and store the offset / scale in the model object). Default: `TRUE`.
 #' - `verbose` Whether or not to print progress during the sampling loops. Default: `FALSE`.
 #'
 #' **2. Mean Forest Parameters**
@@ -146,6 +147,7 @@ bart <- function(X_train, y_train, W_train = NULL, group_ids_train = NULL,
     random_seed <- bart_params$random_seed
     keep_burnin <- bart_params$keep_burnin
     keep_gfr <- bart_params$keep_gfr
+    standardize <- bart_params$standardize
     verbose <- bart_params$verbose
 
     # Determine whether conditional mean, variance, or both will be modeled
@@ -309,8 +311,13 @@ bart <- function(X_train, y_train, W_train = NULL, group_ids_train = NULL,
     has_test = !is.null(X_test)
 
     # Standardize outcome separately for test and train
-    y_bar_train <- mean(y_train)
-    y_std_train <- sd(y_train)
+    if (standardize) {
+        y_bar_train <- mean(y_train)
+        y_std_train <- sd(y_train)
+    } else {
+        y_bar_train <- 0
+        y_std_train <- 1
+    }
     resid_train <- (y_train-y_bar_train)/y_std_train
     resid_train <- resid_train*sqrt(variance_scale)
 
@@ -623,6 +630,7 @@ bart <- function(X_train, y_train, W_train = NULL, group_ids_train = NULL,
         "b_forest" = b_forest,
         "outcome_mean" = y_bar_train,
         "outcome_scale" = y_std_train,
+        "standardize" = standardize,
         "output_dimension" = output_dimension,
         "is_leaf_constant" = is_leaf_constant,
         "leaf_regression" = leaf_regression,
@@ -972,6 +980,7 @@ convertBARTModelToJson <- function(object){
     jsonobj$add_scalar("variance_scale", object$model_params$variance_scale)
     jsonobj$add_scalar("outcome_scale", object$model_params$outcome_scale)
     jsonobj$add_scalar("outcome_mean", object$model_params$outcome_mean)
+    jsonobj$add_boolean("standardize", object$model_params$standardize)
     jsonobj$add_scalar("sigma2_init", object$model_params$sigma2_init)
     jsonobj$add_boolean("sample_sigma_global", object$model_params$sample_sigma_global)
     jsonobj$add_boolean("sample_sigma_leaf", object$model_params$sample_sigma_leaf)
@@ -1152,6 +1161,7 @@ createBARTModelFromJson <- function(json_object){
     model_params[["variance_scale"]] <- json_object$get_scalar("variance_scale")
     model_params[["outcome_scale"]] <- json_object$get_scalar("outcome_scale")
     model_params[["outcome_mean"]] <- json_object$get_scalar("outcome_mean")
+    model_params[["standardize"]] <- json_object$get_boolean("standardize")
     model_params[["sigma2_init"]] <- json_object$get_scalar("sigma2_init")
     model_params[["sample_sigma_global"]] <- json_object$get_boolean("sample_sigma_global")
     model_params[["sample_sigma_leaf"]] <- json_object$get_boolean("sample_sigma_leaf")
@@ -1348,6 +1358,7 @@ createBARTModelFromCombinedJson <- function(json_object_list){
     model_params = list()
     model_params[["outcome_scale"]] <- json_object_default$get_scalar("outcome_scale")
     model_params[["outcome_mean"]] <- json_object_default$get_scalar("outcome_mean")
+    model_params[["standardize"]] <- json_object_default$get_boolean("standardize")
     model_params[["sigma2_init"]] <- json_object_default$get_scalar("sigma2_init")
     model_params[["sample_sigma_global"]] <- json_object$get_boolean("sample_sigma_global")
     model_params[["sample_sigma_leaf"]] <- json_object$get_boolean("sample_sigma_leaf")
@@ -1499,6 +1510,7 @@ createBARTModelFromCombinedJsonString <- function(json_string_list){
     model_params[["variance_scale"]] <- json_object_default$get_scalar("variance_scale")
     model_params[["outcome_scale"]] <- json_object_default$get_scalar("outcome_scale")
     model_params[["outcome_mean"]] <- json_object_default$get_scalar("outcome_mean")
+    model_params[["standardize"]] <- json_object_default$get_boolean("standardize")
     model_params[["sigma2_init"]] <- json_object_default$get_scalar("sigma2_init")
     model_params[["sample_sigma_global"]] <- json_object_default$get_boolean("sample_sigma_global")
     model_params[["sample_sigma_leaf"]] <- json_object_default$get_boolean("sample_sigma_leaf")
diff --git a/R/bcf.R b/R/bcf.R
index 41f34623..4f30e7a0 100644
--- a/R/bcf.R
+++ b/R/bcf.R
@@ -41,6 +41,7 @@
 #' - `random_seed` Integer parameterizing the C++ random number generator. If not specified, the C++ random number generator is seeded according to `std::random_device`.
 #' - `keep_burnin` Whether or not "burnin" samples should be included in cached predictions. Default `FALSE`. Ignored if `num_mcmc = 0`.
 #' - `keep_gfr` Whether or not "grow-from-root" samples should be included in cached predictions. Default `FALSE`. Ignored if `num_mcmc = 0`.
+#' - `standardize` Whether or not to standardize the outcome (and store the offset / scale in the model object). Default: `TRUE`.
 #' - `verbose` Whether or not to print progress during the sampling loops. Default: `FALSE`.
 #' - `sample_sigma_global` Whether or not to update the `sigma^2` global error variance parameter based on `IG(a_global, b_global)`. Default: `TRUE`.
 #'
@@ -214,6 +215,7 @@ bcf <- function(X_train, Z_train, y_train, pi_train = NULL, group_ids_train = NU
     random_seed <- bcf_params$random_seed
     keep_burnin <- bcf_params$keep_burnin
     keep_gfr <- bcf_params$keep_gfr
+    standardize <- bcf_params$standardize
     verbose <- bcf_params$verbose
 
     # Determine whether conditional variance will be modeled
@@ -563,8 +565,13 @@ bcf <- function(X_train, Z_train, y_train, pi_train = NULL, group_ids_train = NU
     }
 
     # Standardize outcome separately for test and train
-    y_bar_train <- mean(y_train)
-    y_std_train <- sd(y_train)
+    if (standardize) {
+        y_bar_train <- mean(y_train)
+        y_std_train <- sd(y_train)
+    } else {
+        y_bar_train <- 0
+        y_std_train <- 1
+    }
     resid_train <- (y_train-y_bar_train)/y_std_train
 
     # Calibrate priors for global sigma^2 and sigma_leaf_mu / sigma_leaf_tau
@@ -991,7 +998,8 @@ bcf <- function(X_train, Z_train, y_train, pi_train = NULL, group_ids_train = NU
         "a_forest" = a_forest,
         "b_forest" = b_forest,
         "outcome_mean" = y_bar_train,
-        "outcome_scale" = y_std_train,
+        "outcome_scale" = y_std_train, 
+        "standardize" = standardize,
         "num_covariates" = num_cov_orig,
         "num_prognostic_covariates" = sum(variable_weights_mu > 0),
         "num_treatment_covariates" = sum(variable_weights_tau > 0),
@@ -1429,6 +1437,7 @@ convertBCFModelToJson <- function(object){
     # Add global parameters
     jsonobj$add_scalar("outcome_scale", object$model_params$outcome_scale)
     jsonobj$add_scalar("outcome_mean", object$model_params$outcome_mean)
+    jsonobj$add_boolean("standardize", object$model_params$standardize)
     jsonobj$add_scalar("initial_sigma2", object$model_params$initial_sigma2)
     jsonobj$add_boolean("sample_sigma_global", object$model_params$sample_sigma_global)
     jsonobj$add_boolean("sample_sigma_leaf_mu", object$model_params$sample_sigma_leaf_mu)
@@ -1719,6 +1728,7 @@ createBCFModelFromJson <- function(json_object){
     model_params = list()
     model_params[["outcome_scale"]] <- json_object$get_scalar("outcome_scale")
     model_params[["outcome_mean"]] <- json_object$get_scalar("outcome_mean")
+    model_params[["standardize"]] <- json_object$get_boolean("standardize")
     model_params[["initial_sigma2"]] <- json_object$get_scalar("initial_sigma2")
     model_params[["sample_sigma_global"]] <- json_object$get_boolean("sample_sigma_global")
     model_params[["sample_sigma_leaf_mu"]] <- json_object$get_boolean("sample_sigma_leaf_mu")
diff --git a/R/utils.R b/R/utils.R
index 4d42c835..34edc32a 100644
--- a/R/utils.R
+++ b/R/utils.R
@@ -19,7 +19,8 @@ preprocessBartParams <- function(params) {
         variable_weights_mean = NULL, variable_weights_variance = NULL,
         num_trees_mean = 200, num_trees_variance = 0,
         sample_sigma_global = T, sample_sigma_leaf = F,
-        random_seed = -1, keep_burnin = F, keep_gfr = F, verbose = F
+        random_seed = -1, keep_burnin = F, keep_gfr = F,
+        standardize = T, verbose = F
     )
 
     # Override defaults
@@ -57,7 +58,8 @@ preprocessBcfParams <- function(params) {
         drop_vars_variance = NULL, num_trees_mu = 250, num_trees_tau = 50, num_trees_variance = 0,
         num_gfr = 5, num_burnin = 0, num_mcmc = 100,
         sample_sigma_global = T, sample_sigma_leaf_mu = T, sample_sigma_leaf_tau = F, propensity_covariate = "mu", adaptive_coding = T, b_0 = -0.5, b_1 = 0.5,
-        rfx_prior_var = NULL, random_seed = -1, keep_burnin = F, keep_gfr = F, verbose = F
+        rfx_prior_var = NULL, random_seed = -1, keep_burnin = F, keep_gfr = F,
+        standardize = T, verbose = F
     )
 
     # Override defaults
diff --git a/stochtree/bart.py b/stochtree/bart.py
index 44483e54..1c8afba2 100644
--- a/stochtree/bart.py
+++ b/stochtree/bart.py
@@ -118,6 +118,7 @@ def sample(self, X_train: np.array, y_train: np.array, basis_train: np.array = N
         random_seed = bart_params['random_seed']
         keep_burnin = bart_params['keep_burnin']
         keep_gfr = bart_params['keep_gfr']
+        self.standardize = bart_params['standardize']
 
         # Determine which models (conditional mean, conditional variance, or both) we will fit
         self.include_mean_forest = True if num_trees_mean > 0 else False
@@ -213,8 +214,12 @@ def sample(self, X_train: np.array, y_train: np.array, basis_train: np.array = N
             variable_weights_variance = variable_weights_variance[original_var_indices]*variable_weights_adj
 
         # Scale outcome
-        self.y_bar = np.squeeze(np.mean(y_train))
-        self.y_std = np.squeeze(np.std(y_train))
+        if self.standardize:
+            self.y_bar = np.squeeze(np.mean(y_train))
+            self.y_std = np.squeeze(np.std(y_train))
+        else:
+            self.y_bar = 0
+            self.y_std = 1
         if variance_scale > 0:
             self.variance_scale = variance_scale
         else:
@@ -626,6 +631,7 @@ def to_json(self) -> str:
         bart_json.add_scalar("variance_scale", self.variance_scale)
         bart_json.add_scalar("outcome_scale", self.y_std)
         bart_json.add_scalar("outcome_mean", self.y_bar)
+        bart_json.add_boolean("standardize", self.standardize)
         bart_json.add_scalar("sigma2_init", self.sigma2_init)
         bart_json.add_boolean("sample_sigma_global", self.sample_sigma_global)
         bart_json.add_boolean("sample_sigma_leaf", self.sample_sigma_leaf)
@@ -680,6 +686,7 @@ def from_json(self, json_string: str) -> None:
         self.variance_scale = bart_json.get_scalar("variance_scale")
         self.y_std = bart_json.get_scalar("outcome_scale")
         self.y_bar = bart_json.get_scalar("outcome_mean")
+        self.standardize = bart_json.get_boolean("standardize")
         self.sigma2_init = bart_json.get_scalar("sigma2_init")
         self.sample_sigma_global = bart_json.get_boolean("sample_sigma_global")
         self.sample_sigma_leaf = bart_json.get_boolean("sample_sigma_leaf")
diff --git a/stochtree/bcf.py b/stochtree/bcf.py
index c709989f..ab53f2e3 100644
--- a/stochtree/bcf.py
+++ b/stochtree/bcf.py
@@ -149,6 +149,7 @@ def sample(self, X_train: Union[pd.DataFrame, np.array], Z_train: np.array, y_tr
         random_seed = bcf_params['random_seed']
         keep_burnin = bcf_params['keep_burnin']
         keep_gfr = bcf_params['keep_gfr']
+        self.standardize = bcf_params['standardize']
 
         # Variable weight preprocessing (and initialization if necessary)
         if variable_weights is None:
@@ -495,8 +496,12 @@ def sample(self, X_train: Union[pd.DataFrame, np.array], Z_train: np.array, y_tr
             self.internal_propensity_model = False
 
         # Scale outcome
-        self.y_bar = np.squeeze(np.mean(y_train))
-        self.y_std = np.squeeze(np.std(y_train))
+        if self.standardize:
+            self.y_bar = np.squeeze(np.mean(y_train))
+            self.y_std = np.squeeze(np.std(y_train))
+        else:
+            self.y_bar = 0
+            self.y_std = 1
         resid_train = (y_train-self.y_bar)/self.y_std
 
         # Calibrate priors for global sigma^2 and sigma_leaf_mu / sigma_leaf_tau (don't use regression initializer for warm-start or XBART)
diff --git a/stochtree/preprocessing.py b/stochtree/preprocessing.py
index a9a1b8d9..9dea3845 100644
--- a/stochtree/preprocessing.py
+++ b/stochtree/preprocessing.py
@@ -41,7 +41,8 @@ def _preprocess_bart_params(params: Optional[Dict[str, Any]] = None) -> Dict[str
         'sample_sigma_leaf' : True,
         'random_seed' : -1,
         'keep_burnin' : False,
-        'keep_gfr' : False
+        'keep_gfr' : False,
+        'standardize': True
     }
 
     if params:
@@ -90,7 +91,8 @@ def _preprocess_bcf_params(params: Optional[Dict[str, Any]] = None) -> Dict[str,
         'b_1': 0.5,
         'random_seed': -1,
         'keep_burnin': False,
-        'keep_gfr': False
+        'keep_gfr': False,
+        'standardize': True
     }
 
     if params: