16 changes: 14 additions & 2 deletions R/bart.R
@@ -41,6 +41,7 @@
#' - `sample_sigma_global` Whether or not to update the `sigma^2` global error variance parameter based on `IG(a_global, b_global)`. Default: `TRUE`.
#' - `keep_burnin` Whether or not "burnin" samples should be included in cached predictions. Default `FALSE`. Ignored if `num_mcmc = 0`.
#' - `keep_gfr` Whether or not "grow-from-root" samples should be included in cached predictions. Default `TRUE`. Ignored if `num_mcmc = 0`.
#' - `standardize` Whether or not to standardize the outcome (and store the offset / scale in the model object). Default: `TRUE`.
#' - `verbose` Whether or not to print progress during the sampling loops. Default: `FALSE`.
#'
#' **2. Mean Forest Parameters**
@@ -146,6 +147,7 @@ bart <- function(X_train, y_train, W_train = NULL, group_ids_train = NULL,
random_seed <- bart_params$random_seed
keep_burnin <- bart_params$keep_burnin
keep_gfr <- bart_params$keep_gfr
standardize <- bart_params$standardize
verbose <- bart_params$verbose

# Determine whether conditional mean, variance, or both will be modeled
@@ -309,8 +311,13 @@ bart <- function(X_train, y_train, W_train = NULL, group_ids_train = NULL,
has_test = !is.null(X_test)

# Standardize outcome separately for test and train
y_bar_train <- mean(y_train)
y_std_train <- sd(y_train)
if (standardize) {
y_bar_train <- mean(y_train)
y_std_train <- sd(y_train)
} else {
y_bar_train <- 0
y_std_train <- 1
}
resid_train <- (y_train-y_bar_train)/y_std_train
resid_train <- resid_train*sqrt(variance_scale)
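For readers skimming the hunk, here is a minimal standalone R sketch of the new branch (not the package code itself; the sqrt(variance_scale) rescaling above is omitted). With standardize = FALSE the offset and scale collapse to 0 and 1, so the sampler sees the raw outcome and the inverse transform is a no-op:

standardize <- FALSE
y_train <- c(1.2, 3.4, 2.2, 5.1)
if (standardize) {
  y_bar_train <- mean(y_train)
  y_std_train <- sd(y_train)
} else {
  y_bar_train <- 0
  y_std_train <- 1
}
resid_train <- (y_train - y_bar_train) / y_std_train
# Mapping model-scale values back to the original outcome scale uses the
# same stored constants, so the two code paths stay symmetric:
y_back <- resid_train * y_std_train + y_bar_train
all.equal(y_back, y_train)  # TRUE under either setting of standardize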

@@ -623,6 +630,7 @@ bart <- function(X_train, y_train, W_train = NULL, group_ids_train = NULL,
"b_forest" = b_forest,
"outcome_mean" = y_bar_train,
"outcome_scale" = y_std_train,
"standardize" = standardize,
"output_dimension" = output_dimension,
"is_leaf_constant" = is_leaf_constant,
"leaf_regression" = leaf_regression,
@@ -972,6 +980,7 @@ convertBARTModelToJson <- function(object){
jsonobj$add_scalar("variance_scale", object$model_params$variance_scale)
jsonobj$add_scalar("outcome_scale", object$model_params$outcome_scale)
jsonobj$add_scalar("outcome_mean", object$model_params$outcome_mean)
jsonobj$add_boolean("standardize", object$model_params$standardize)
jsonobj$add_scalar("sigma2_init", object$model_params$sigma2_init)
jsonobj$add_boolean("sample_sigma_global", object$model_params$sample_sigma_global)
jsonobj$add_boolean("sample_sigma_leaf", object$model_params$sample_sigma_leaf)
@@ -1152,6 +1161,7 @@ createBARTModelFromJson <- function(json_object){
model_params[["variance_scale"]] <- json_object$get_scalar("variance_scale")
model_params[["outcome_scale"]] <- json_object$get_scalar("outcome_scale")
model_params[["outcome_mean"]] <- json_object$get_scalar("outcome_mean")
model_params[["standardize"]] <- json_object$get_boolean("standardize")
model_params[["sigma2_init"]] <- json_object$get_scalar("sigma2_init")
model_params[["sample_sigma_global"]] <- json_object$get_boolean("sample_sigma_global")
model_params[["sample_sigma_leaf"]] <- json_object$get_boolean("sample_sigma_leaf")
@@ -1348,6 +1358,7 @@ createBARTModelFromCombinedJson <- function(json_object_list){
model_params = list()
model_params[["outcome_scale"]] <- json_object_default$get_scalar("outcome_scale")
model_params[["outcome_mean"]] <- json_object_default$get_scalar("outcome_mean")
model_params[["standardize"]] <- json_object_default$get_boolean("standardize")
model_params[["sigma2_init"]] <- json_object_default$get_scalar("sigma2_init")
model_params[["sample_sigma_global"]] <- json_object$get_boolean("sample_sigma_global")
model_params[["sample_sigma_leaf"]] <- json_object$get_boolean("sample_sigma_leaf")
@@ -1499,6 +1510,7 @@ createBARTModelFromCombinedJsonString <- function(json_string_list){
model_params[["variance_scale"]] <- json_object_default$get_scalar("variance_scale")
model_params[["outcome_scale"]] <- json_object_default$get_scalar("outcome_scale")
model_params[["outcome_mean"]] <- json_object_default$get_scalar("outcome_mean")
model_params[["standardize"]] <- json_object_default$get_boolean("standardize")
model_params[["sigma2_init"]] <- json_object_default$get_scalar("sigma2_init")
model_params[["sample_sigma_global"]] <- json_object_default$get_boolean("sample_sigma_global")
model_params[["sample_sigma_leaf"]] <- json_object_default$get_boolean("sample_sigma_leaf")
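Taken together, the R-side changes let a model fitted on the raw outcome scale survive a JSON round trip. A hedged usage sketch follows; the `params = list(...)` interface is an assumption inferred from the `bart_params$...` accesses above, and the field lookups use only the `model_params` names added in this diff:

set.seed(1)
n <- 100
X <- matrix(runif(n * 5), ncol = 5)
y <- rowSums(X) + rnorm(n)
# 'params = list(...)' is assumed to be forwarded to preprocessBartParams()
model <- bart(X_train = X, y_train = y, params = list(standardize = FALSE))
model$model_params$outcome_mean    # 0 when standardize = FALSE
model$model_params$outcome_scale   # 1 when standardize = FALSE
json <- convertBARTModelToJson(model)
model_roundtrip <- createBARTModelFromJson(json)
model_roundtrip$model_params$standardize  # FALSE, restored via get_boolean()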
16 changes: 13 additions & 3 deletions R/bcf.R
@@ -41,6 +41,7 @@
#' - `random_seed` Integer parameterizing the C++ random number generator. If not specified, the C++ random number generator is seeded according to `std::random_device`.
#' - `keep_burnin` Whether or not "burnin" samples should be included in cached predictions. Default `FALSE`. Ignored if `num_mcmc = 0`.
#' - `keep_gfr` Whether or not "grow-from-root" samples should be included in cached predictions. Default `FALSE`. Ignored if `num_mcmc = 0`.
#' - `standardize` Whether or not to standardize the outcome (and store the offset / scale in the model object). Default: `TRUE`.
#' - `verbose` Whether or not to print progress during the sampling loops. Default: `FALSE`.
#' - `sample_sigma_global` Whether or not to update the `sigma^2` global error variance parameter based on `IG(a_global, b_global)`. Default: `TRUE`.
#'
@@ -214,6 +215,7 @@ bcf <- function(X_train, Z_train, y_train, pi_train = NULL, group_ids_train = NULL,
random_seed <- bcf_params$random_seed
keep_burnin <- bcf_params$keep_burnin
keep_gfr <- bcf_params$keep_gfr
standardize <- bcf_params$standardize
verbose <- bcf_params$verbose

# Determine whether conditional variance will be modeled
@@ -563,8 +565,13 @@ bcf <- function(X_train, Z_train, y_train, pi_train = NULL, group_ids_train = NULL,
}

# Standardize outcome separately for test and train
y_bar_train <- mean(y_train)
y_std_train <- sd(y_train)
if (standardize) {
y_bar_train <- mean(y_train)
y_std_train <- sd(y_train)
} else {
y_bar_train <- 0
y_std_train <- 1
}
resid_train <- (y_train-y_bar_train)/y_std_train

# Calibrate priors for global sigma^2 and sigma_leaf_mu / sigma_leaf_tau
@@ -991,7 +998,8 @@ bcf <- function(X_train, Z_train, y_train, pi_train = NULL, group_ids_train = NULL,
"a_forest" = a_forest,
"b_forest" = b_forest,
"outcome_mean" = y_bar_train,
"outcome_scale" = y_std_train,
"outcome_scale" = y_std_train,
"standardize" = standardize,
"num_covariates" = num_cov_orig,
"num_prognostic_covariates" = sum(variable_weights_mu > 0),
"num_treatment_covariates" = sum(variable_weights_tau > 0),
@@ -1429,6 +1437,7 @@ convertBCFModelToJson <- function(object){
# Add global parameters
jsonobj$add_scalar("outcome_scale", object$model_params$outcome_scale)
jsonobj$add_scalar("outcome_mean", object$model_params$outcome_mean)
jsonobj$add_boolean("standardize", object$model_params$standardize)
jsonobj$add_scalar("initial_sigma2", object$model_params$initial_sigma2)
jsonobj$add_boolean("sample_sigma_global", object$model_params$sample_sigma_global)
jsonobj$add_boolean("sample_sigma_leaf_mu", object$model_params$sample_sigma_leaf_mu)
@@ -1719,6 +1728,7 @@ createBCFModelFromJson <- function(json_object){
model_params = list()
model_params[["outcome_scale"]] <- json_object$get_scalar("outcome_scale")
model_params[["outcome_mean"]] <- json_object$get_scalar("outcome_mean")
model_params[["standardize"]] <- json_object$get_boolean("standardize")
model_params[["initial_sigma2"]] <- json_object$get_scalar("initial_sigma2")
model_params[["sample_sigma_global"]] <- json_object$get_boolean("sample_sigma_global")
model_params[["sample_sigma_leaf_mu"]] <- json_object$get_boolean("sample_sigma_leaf_mu")
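One practical note on the BCF side: when standardize = TRUE, the sampler's treatment-effect draws live on the standardized outcome scale, so mapping them back needs only the stored outcome_scale (the centering offset cancels in the treated-vs-control contrast), while with standardize = FALSE the multiplier is simply 1. A hedged sketch with placeholder values; the draw-matrix name is hypothetical, not a field from this diff:

y_std <- 2.7                                   # stored outcome_scale; 1 if standardize = FALSE
tau_draws_std <- matrix(rnorm(50), nrow = 10)  # placeholder standardized CATE draws
tau_draws <- tau_draws_std * y_std             # y_bar cancels in the (Z = 1) - (Z = 0) contrast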
6 changes: 4 additions & 2 deletions R/utils.R
@@ -19,7 +19,8 @@ preprocessBartParams <- function(params) {
variable_weights_mean = NULL, variable_weights_variance = NULL,
num_trees_mean = 200, num_trees_variance = 0,
sample_sigma_global = T, sample_sigma_leaf = F,
random_seed = -1, keep_burnin = F, keep_gfr = F, verbose = F
random_seed = -1, keep_burnin = F, keep_gfr = F,
standardize = T, verbose = F
)

# Override defaults
@@ -57,7 +58,8 @@ preprocessBcfParams <- function(params) {
drop_vars_variance = NULL, num_trees_mu = 250, num_trees_tau = 50, num_trees_variance = 0,
num_gfr = 5, num_burnin = 0, num_mcmc = 100, sample_sigma_global = T, sample_sigma_leaf_mu = T,
sample_sigma_leaf_tau = F, propensity_covariate = "mu", adaptive_coding = T, b_0 = -0.5, b_1 = 0.5,
rfx_prior_var = NULL, random_seed = -1, keep_burnin = F, keep_gfr = F, verbose = F
rfx_prior_var = NULL, random_seed = -1, keep_burnin = F, keep_gfr = F,
standardize = T, verbose = F
)

# Override defaults
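The two internal helpers above merge user-supplied values over these defaults, so standardization stays opt-out. A quick sketch of the expected merge behavior (run from within the package namespace, since the helpers are not exported):

params <- preprocessBartParams(list(standardize = FALSE))
params$standardize                        # FALSE: user value overrides the default
params$keep_burnin                        # FALSE: untouched default
preprocessBcfParams(list())$standardize   # TRUE unless explicitly disabled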
11 changes: 9 additions & 2 deletions stochtree/bart.py
@@ -118,6 +118,7 @@ def sample(self, X_train: np.array, y_train: np.array, basis_train: np.array = None,
random_seed = bart_params['random_seed']
keep_burnin = bart_params['keep_burnin']
keep_gfr = bart_params['keep_gfr']
self.standardize = bart_params['standardize']

# Determine which models (conditional mean, conditional variance, or both) we will fit
self.include_mean_forest = True if num_trees_mean > 0 else False
@@ -213,8 +214,12 @@ def sample(self, X_train: np.array, y_train: np.array, basis_train: np.array = None,
variable_weights_variance = variable_weights_variance[original_var_indices]*variable_weights_adj

# Scale outcome
self.y_bar = np.squeeze(np.mean(y_train))
self.y_std = np.squeeze(np.std(y_train))
if self.standardize:
self.y_bar = np.squeeze(np.mean(y_train))
self.y_std = np.squeeze(np.std(y_train))
else:
self.y_bar = 0
self.y_std = 1
if variance_scale > 0:
self.variance_scale = variance_scale
else:
@@ -626,6 +631,7 @@ def to_json(self) -> str:
bart_json.add_scalar("variance_scale", self.variance_scale)
bart_json.add_scalar("outcome_scale", self.y_std)
bart_json.add_scalar("outcome_mean", self.y_bar)
bart_json.add_boolean("standardize", self.standardize)
bart_json.add_scalar("sigma2_init", self.sigma2_init)
bart_json.add_boolean("sample_sigma_global", self.sample_sigma_global)
bart_json.add_boolean("sample_sigma_leaf", self.sample_sigma_leaf)
@@ -680,6 +686,7 @@ def from_json(self, json_string: str) -> None:
self.variance_scale = bart_json.get_scalar("variance_scale")
self.y_std = bart_json.get_scalar("outcome_scale")
self.y_bar = bart_json.get_scalar("outcome_mean")
self.standardize = bart_json.get_boolean("standardize")
self.sigma2_init = bart_json.get_scalar("sigma2_init")
self.sample_sigma_global = bart_json.get_boolean("sample_sigma_global")
self.sample_sigma_leaf = bart_json.get_boolean("sample_sigma_leaf")
9 changes: 7 additions & 2 deletions stochtree/bcf.py
@@ -149,6 +149,7 @@ def sample(self, X_train: Union[pd.DataFrame, np.array], Z_train: np.array, y_train: np.array,
random_seed = bcf_params['random_seed']
keep_burnin = bcf_params['keep_burnin']
keep_gfr = bcf_params['keep_gfr']
self.standardize = bcf_params['standardize']

# Variable weight preprocessing (and initialization if necessary)
if variable_weights is None:
@@ -495,8 +496,12 @@ def sample(self, X_train: Union[pd.DataFrame, np.array], Z_train: np.array, y_train: np.array,
self.internal_propensity_model = False

# Scale outcome
self.y_bar = np.squeeze(np.mean(y_train))
self.y_std = np.squeeze(np.std(y_train))
if self.standardize:
self.y_bar = np.squeeze(np.mean(y_train))
self.y_std = np.squeeze(np.std(y_train))
else:
self.y_bar = 0
self.y_std = 1
resid_train = (y_train-self.y_bar)/self.y_std

# Calibrate priors for global sigma^2 and sigma_leaf_mu / sigma_leaf_tau (don't use regression initializer for warm-start or XBART)
6 changes: 4 additions & 2 deletions stochtree/preprocessing.py
@@ -41,7 +41,8 @@ def _preprocess_bart_params(params: Optional[Dict[str, Any]] = None) -> Dict[str, Any]:
'sample_sigma_leaf' : True,
'random_seed' : -1,
'keep_burnin' : False,
'keep_gfr' : False
'keep_gfr' : False,
'standardize': True
}

if params:
@@ -90,7 +91,8 @@ def _preprocess_bcf_params(params: Optional[Dict[str, Any]] = None) -> Dict[str, Any]:
'b_1': 0.5,
'random_seed': -1,
'keep_burnin': False,
'keep_gfr': False
'keep_gfr': False,
'standardize': True
}

if params: