From dc7bad9ab9a32929942bb27bb34eaa659e7d53b8 Mon Sep 17 00:00:00 2001 From: Drew Herren Date: Fri, 4 Apr 2025 12:30:49 -0500 Subject: [PATCH 1/4] Updated C++ library version --- Doxyfile | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Doxyfile b/Doxyfile index 06c38d94..58837078 100644 --- a/Doxyfile +++ b/Doxyfile @@ -48,7 +48,7 @@ PROJECT_NAME = "StochTree" # could be handy for archiving the generated documentation or if some version # control system is used. -PROJECT_NUMBER = 0.0.1 +PROJECT_NUMBER = 0.1.1 # Using the PROJECT_BRIEF tag one can provide an optional one line description # for a project that appears at the top of each page and should give viewer a From 8cf53ee9d98d67b08038f14ea8a8d5ba2e504e7e Mon Sep 17 00:00:00 2001 From: Drew Herren Date: Sat, 5 Apr 2025 01:26:30 -0500 Subject: [PATCH 2/4] Initial commit of warm-start interface in Python --- R/bart.R | 2 +- demo/debug/multi_chain.py | 147 ++++++++++++++++++++++++++++++++++++++ stochtree/bart.py | 70 ++++++++++++++++++ 3 files changed, 218 insertions(+), 1 deletion(-) create mode 100644 demo/debug/multi_chain.py diff --git a/R/bart.R b/R/bart.R index ca717621..60c536b7 100644 --- a/R/bart.R +++ b/R/bart.R @@ -206,7 +206,7 @@ bart <- function(X_train, y_train, leaf_basis_train = NULL, rfx_group_ids_train if (previous_bart_model$model_params$include_mean_forest) { previous_forest_samples_mean <- previous_bart_model$mean_forests } else previous_forest_samples_mean <- NULL - if (previous_bart_model$model_params$include_mean_forest) { + if (previous_bart_model$model_params$include_variance_forest) { previous_forest_samples_variance <- previous_bart_model$variance_forests } else previous_forest_samples_variance <- NULL if (previous_bart_model$model_params$sample_sigma_global) { diff --git a/demo/debug/multi_chain.py b/demo/debug/multi_chain.py new file mode 100644 index 00000000..d575f0eb --- /dev/null +++ b/demo/debug/multi_chain.py @@ -0,0 +1,147 @@ +# Multi Chain Demo Script + +# Load necessary libraries +import matplotlib.pyplot as plt +import numpy as np +import pandas as pd +import seaborn as sns +from sklearn.model_selection import train_test_split + +from stochtree import BARTModel + +# Generate sample data +# RNG +random_seed = 1234 +rng = np.random.default_rng(random_seed) + +# Generate covariates and basis +n = 500 +p_X = 10 +p_W = 1 +X = rng.uniform(0, 1, (n, p_X)) +W = rng.uniform(0, 1, (n, p_W)) + +# Define the outcome mean function +def outcome_mean(X, W): + return np.where( + (X[:, 0] >= 0.0) & (X[:, 0] < 0.25), + -7.5 * W[:, 0], + np.where( + (X[:, 0] >= 0.25) & (X[:, 0] < 0.5), + -2.5 * W[:, 0], + np.where((X[:, 0] >= 0.5) & (X[:, 0] < 0.75), 2.5 * W[:, 0], 7.5 * W[:, 0]), + ), + ) + +# Generate outcome +f_XW = outcome_mean(X, W) +epsilon = rng.normal(0, 1, n) +y = f_XW + epsilon + +# Test-train split +sample_inds = np.arange(n) +train_inds, test_inds = train_test_split(sample_inds, test_size=0.5, random_state=random_seed) +X_train = X[train_inds, :] +X_test = X[test_inds, :] +basis_train = W[train_inds, :] +basis_test = W[test_inds, :] +y_train = y[train_inds] +y_test = y[test_inds] + +# Run the GFR algorithm for a small number of iterations +general_model_params = {"random_seed": -1} +mean_forest_model_params = {"num_trees": 20} +num_warmstart = 10 +num_mcmc = 10 +bart_model = BARTModel() +bart_model.sample( + X_train=X_train, + y_train=y_train, + leaf_basis_train=basis_train, + X_test=X_test, + leaf_basis_test=basis_test, + num_gfr=num_warmstart, + num_mcmc=0, + 
general_params=general_model_params, + mean_forest_params=mean_forest_model_params +) +bart_model_json = bart_model.to_json() + +# Run several BART MCMC samples from the last GFR forest +bart_model_2 = BARTModel() +bart_model_2.sample( + X_train=X_train, + y_train=y_train, + leaf_basis_train=basis_train, + X_test=X_test, + leaf_basis_test=basis_test, + num_gfr=0, + num_mcmc=num_mcmc, + previous_model_json=bart_model_json, + previous_model_warmstart_sample_num=num_warmstart-1, + general_params=general_model_params, + mean_forest_params=mean_forest_model_params +) + +# Run several BART MCMC samples from the second-to-last GFR forest +bart_model_3 = BARTModel() +bart_model_3.sample( + X_train=X_train, + y_train=y_train, + leaf_basis_train=basis_train, + X_test=X_test, + leaf_basis_test=basis_test, + num_gfr=0, + num_mcmc=num_mcmc, + previous_model_json=bart_model_json, + previous_model_warmstart_sample_num=num_warmstart-2, + general_params=general_model_params, + mean_forest_params=mean_forest_model_params +) + +# Run several BART MCMC samples from root +bart_model_4 = BARTModel() +bart_model_4.sample( + X_train=X_train, + y_train=y_train, + leaf_basis_train=basis_train, + X_test=X_test, + leaf_basis_test=basis_test, + num_gfr=0, + num_mcmc=num_mcmc, + general_params=general_model_params, + mean_forest_params=mean_forest_model_params +) + +# Inspect the model outputs +y_hat_mcmc_2 = bart_model_2.predict(X_test, basis_test) +y_avg_mcmc_2 = np.squeeze(y_hat_mcmc_2).mean(axis=1, keepdims=True) +y_hat_mcmc_3 = bart_model_3.predict(X_test, basis_test) +y_avg_mcmc_3 = np.squeeze(y_hat_mcmc_3).mean(axis=1, keepdims=True) +y_hat_mcmc_4 = bart_model_4.predict(X_test, basis_test) +y_avg_mcmc_4 = np.squeeze(y_hat_mcmc_4).mean(axis=1, keepdims=True) +y_df = pd.DataFrame( + np.concatenate((y_avg_mcmc_2, y_avg_mcmc_3, y_avg_mcmc_4, np.expand_dims(y_test, axis=1)), axis=1), + columns=["First Chain", "Second Chain", "Third Chain", "Outcome"], +) + +# Compare first warm-start chain to root chain with equal number of MCMC draws +sns.scatterplot(data=y_df, x="First Chain", y="Third Chain") +plt.axline((0, 0), slope=1, color="black", linestyle=(0, (3, 3))) +plt.show() + +# Compare first warm-start chain to outcome +sns.scatterplot(data=y_df, x="First Chain", y="Outcome") +plt.axline((0, 0), slope=1, color="black", linestyle=(0, (3, 3))) +plt.show() + +# Compare root chain to outcome +sns.scatterplot(data=y_df, x="Third Chain", y="Outcome") +plt.axline((0, 0), slope=1, color="black", linestyle=(0, (3, 3))) +plt.show() + +# Compute RMSEs +rmse_1 = np.sqrt(np.mean((np.squeeze(y_avg_mcmc_2)-y_test)*(np.squeeze(y_avg_mcmc_2)-y_test))) +rmse_2 = np.sqrt(np.mean((np.squeeze(y_avg_mcmc_3)-y_test)*(np.squeeze(y_avg_mcmc_3)-y_test))) +rmse_3 = np.sqrt(np.mean((np.squeeze(y_avg_mcmc_4)-y_test)*(np.squeeze(y_avg_mcmc_4)-y_test))) +print("Chain 1 rmse: {:0.3f}; Chain 2 rmse: {:0.3f}; Chain 3 rmse: {:0.3f}".format(rmse_1, rmse_2, rmse_3)) diff --git a/stochtree/bart.py b/stochtree/bart.py index 6608f576..4a409831 100644 --- a/stochtree/bart.py +++ b/stochtree/bart.py @@ -77,6 +77,8 @@ def sample( general_params: Optional[Dict[str, Any]] = None, mean_forest_params: Optional[Dict[str, Any]] = None, variance_forest_params: Optional[Dict[str, Any]] = None, + previous_model_json: Optional[str] = None, + previous_model_warmstart_sample_num: Optional[int] = None, ) -> None: """Runs a BART sampler on provided training set. Predictions will be cached for the training set and (if provided) the test set. 
Does not require a leaf regression basis. @@ -154,6 +156,11 @@ def sample( * `var_forest_prior_scale` (`float`): Scale parameter in the [optional] `IG(var_forest_prior_shape, var_forest_prior_scale)` conditional error variance forest (which is only sampled if `num_trees > 0`). Calibrated internally as `num_trees / leaf_prior_calibration_param^2` if not set here. * `keep_vars` (`list` or `np.array`): Vector of variable names or column indices denoting variables that should be included in the variance forest. Defaults to `None`. * `drop_vars` (`list` or `np.array`): Vector of variable names or column indices denoting variables that should be excluded from the variance forest. Defaults to `None`. If both `drop_vars` and `keep_vars` are set, `drop_vars` will be ignored. + + previous_model_json : str, optional + JSON string containing a previous BART model. This can be used to "continue" a sampler interactively after inspecting the samples or to run parallel chains "warm-started" from existing forest samples. Defaults to `None`. + previous_model_warmstart_sample_num : int, optional + Sample number from `previous_model_json` that will be used to warmstart this BART sampler. Zero-indexed (so that the first sample is used for warm-start by setting `previous_model_warmstart_sample_num = 0`). Defaults to `None`. Returns ------- @@ -612,6 +619,51 @@ def sample( else: variable_subset_variance = [i for i in range(X_train.shape[1])] + # Check if previous model JSON is provided and parse it if so + has_prev_model = previous_model_json is not None + if has_prev_model: + if num_gfr > 0: + if num_mcmc == 0: + raise ValueError("A previous model is being used to initialize this sampler, so `num_mcmc` must be greater than zero") + else: + warnings.warn("A previous model is being used to initialize this sampler, so num_gfr will be ignored and the MCMC sampler will be run from the previous samples") + previous_bart_model = BARTModel() + previous_bart_model.from_json(previous_model_json) + previous_y_bar = previous_bart_model.y_bar + previous_y_scale = previous_bart_model.y_std + previous_model_num_samples = previous_bart_model.num_samples + if previous_bart_model.include_mean_forest: + previous_forest_samples_mean = previous_bart_model.forest_container_mean + else: + previous_forest_samples_mean = None + if previous_bart_model.include_variance_forest: + previous_forest_samples_variance = previous_bart_model.forest_container_variance + else: + previous_forest_samples_variance = None + if previous_bart_model.sample_sigma_global: + previous_global_var_samples = previous_bart_model.global_var_samples / (previous_y_scale * previous_y_scale) + else: + previous_global_var_samples = None + if previous_bart_model.sample_sigma_leaf: + previous_leaf_var_samples = previous_bart_model.leaf_scale_samples + else: + previous_leaf_var_samples = None + if previous_bart_model.has_rfx: + previous_rfx_samples = previous_bart_model.rfx_container + else: + previous_rfx_samples = None + if previous_model_warmstart_sample_num + 1 > previous_model_num_samples: + raise ValueError("`previous_model_warmstart_sample_num` exceeds the number of samples in `previous_model_json`") + else: + previous_y_bar = None + previous_y_scale = None + previous_global_var_samples = None + previous_leaf_var_samples = None + previous_rfx_samples = None + previous_forest_samples_mean = None + previous_forest_samples_variance = None + previous_model_num_samples = 0 + # Update variable weights if the covariates have been resized (by e.g. 
one-hot encoding) if X_train_processed.shape[1] != X_train.shape[1]: variable_counts = [ @@ -992,6 +1044,22 @@ def sample( ) if sample_sigma_global: current_sigma2 = self.global_var_samples[forest_ind] + elif has_prev_model: + if self.include_mean_forest: + active_forest_mean.reset(previous_bart_model.forest_container_mean, previous_model_warmstart_sample_num) + forest_sampler_mean.reconstitute_from_forest(active_forest_mean, forest_dataset_train, residual_train, True) + if sample_sigma_leaf and previous_leaf_var_samples is not None: + leaf_scale_double = previous_leaf_var_samples[previous_model_warmstart_sample_num] + current_leaf_scale[0, 0] = leaf_scale_double + forest_model_config_mean.update_leaf_model_scale(leaf_scale_double) + if self.include_variance_forest: + active_forest_variance.reset(previous_bart_model.forest_container_variance, previous_model_warmstart_sample_num) + forest_sampler_variance.reconstitute_from_forest(active_forest_variance, forest_dataset_train, residual_train, True) + # if self.has_rfx: + # pass + if self.sample_sigma_global: + current_sigma2 = previous_global_var_samples[previous_model_warmstart_sample_num] + global_model_config.update_global_error_variance(current_sigma2) else: if self.include_mean_forest: active_forest_mean.reset_root() @@ -1069,12 +1137,14 @@ def sample( current_sigma2 = global_var_model.sample_one_iteration( residual_train, cpp_rng, a_global, b_global ) + global_model_config.update_global_error_variance(current_sigma2) if keep_sample: self.global_var_samples[sample_counter] = current_sigma2 if self.sample_sigma_leaf: current_leaf_scale[0, 0] = leaf_var_model.sample_one_iteration( active_forest_mean, cpp_rng, a_leaf, b_leaf ) + forest_model_config_mean.update_leaf_model_scale(current_leaf_scale) if keep_sample: self.leaf_scale_samples[sample_counter] = ( current_leaf_scale[0, 0] From 294d01e957dba04d01bc4e71ea428f2b91d9814a Mon Sep 17 00:00:00 2001 From: Drew Herren Date: Sat, 5 Apr 2025 01:31:33 -0500 Subject: [PATCH 3/4] Fixed numpy deprecation warning in bcf --- stochtree/bcf.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/stochtree/bcf.py b/stochtree/bcf.py index 276c0f54..df6b1cd1 100644 --- a/stochtree/bcf.py +++ b/stochtree/bcf.py @@ -1555,12 +1555,12 @@ def sample( loc=(s_ty0 / (s_tt0 + 2 * current_sigma2)), scale=np.sqrt(current_sigma2 / (s_tt0 + 2 * current_sigma2)), size=1, - ) + )[0] current_b_1 = self.rng.normal( loc=(s_ty1 / (s_tt1 + 2 * current_sigma2)), scale=np.sqrt(current_sigma2 / (s_tt1 + 2 * current_sigma2)), size=1, - ) + )[0] tau_basis_train = ( 1 - np.squeeze(Z_train) ) * current_b_0 + np.squeeze(Z_train) * current_b_1 @@ -1703,12 +1703,12 @@ def sample( loc=(s_ty0 / (s_tt0 + 2 * current_sigma2)), scale=np.sqrt(current_sigma2 / (s_tt0 + 2 * current_sigma2)), size=1, - ) + )[0] current_b_1 = self.rng.normal( loc=(s_ty1 / (s_tt1 + 2 * current_sigma2)), scale=np.sqrt(current_sigma2 / (s_tt1 + 2 * current_sigma2)), size=1, - ) + )[0] tau_basis_train = ( 1 - np.squeeze(Z_train) ) * current_b_0 + np.squeeze(Z_train) * current_b_1 From daa9243cb11e1b996276497d8dcc82be3e60fa31 Mon Sep 17 00:00:00 2001 From: Drew Herren Date: Wed, 9 Apr 2025 00:21:23 -0500 Subject: [PATCH 4/4] Updated python interface to support parallel multi-chain --- R/bart.R | 2 +- demo/debug/multi_chain.py | 51 ++- demo/debug/parallel_multi_chain.py | 177 +++++++++ src/py_stochtree.cpp | 15 + stochtree/bart.py | 398 +++++++++++++++---- stochtree/bcf.py | 353 +++++++++++++---- stochtree/forest.py | 32 +- 
stochtree/random_effects.py | 32 +- test/python/test_bart.py | 601 ++++++++++++++++++++++++++++- test/python/test_bcf.py | 108 ++++++ 10 files changed, 1560 insertions(+), 209 deletions(-) create mode 100644 demo/debug/parallel_multi_chain.py diff --git a/R/bart.R b/R/bart.R index 60c536b7..96815850 100644 --- a/R/bart.R +++ b/R/bart.R @@ -1853,7 +1853,7 @@ createBARTModelFromCombinedJsonString <- function(json_string_list){ } # Unpack covariate preprocessor - preprocessor_metadata_string <- json_object$get_string("preprocessor_metadata") + preprocessor_metadata_string <- json_object_default$get_string("preprocessor_metadata") output[["train_set_metadata"]] <- createPreprocessorFromJsonString( preprocessor_metadata_string ) diff --git a/demo/debug/multi_chain.py b/demo/debug/multi_chain.py index d575f0eb..6d8aef68 100644 --- a/demo/debug/multi_chain.py +++ b/demo/debug/multi_chain.py @@ -21,6 +21,7 @@ X = rng.uniform(0, 1, (n, p_X)) W = rng.uniform(0, 1, (n, p_W)) + # Define the outcome mean function def outcome_mean(X, W): return np.where( @@ -33,6 +34,7 @@ def outcome_mean(X, W): ), ) + # Generate outcome f_XW = outcome_mean(X, W) epsilon = rng.normal(0, 1, n) @@ -40,7 +42,9 @@ def outcome_mean(X, W): # Test-train split sample_inds = np.arange(n) -train_inds, test_inds = train_test_split(sample_inds, test_size=0.5, random_state=random_seed) +train_inds, test_inds = train_test_split( + sample_inds, test_size=0.5, random_state=random_seed +) X_train = X[train_inds, :] X_test = X[test_inds, :] basis_train = W[train_inds, :] @@ -61,9 +65,9 @@ def outcome_mean(X, W): X_test=X_test, leaf_basis_test=basis_test, num_gfr=num_warmstart, - num_mcmc=0, - general_params=general_model_params, - mean_forest_params=mean_forest_model_params + num_mcmc=0, + general_params=general_model_params, + mean_forest_params=mean_forest_model_params, ) bart_model_json = bart_model.to_json() @@ -78,9 +82,9 @@ def outcome_mean(X, W): num_gfr=0, num_mcmc=num_mcmc, previous_model_json=bart_model_json, - previous_model_warmstart_sample_num=num_warmstart-1, - general_params=general_model_params, - mean_forest_params=mean_forest_model_params + previous_model_warmstart_sample_num=num_warmstart - 1, + general_params=general_model_params, + mean_forest_params=mean_forest_model_params, ) # Run several BART MCMC samples from the second-to-last GFR forest @@ -94,9 +98,9 @@ def outcome_mean(X, W): num_gfr=0, num_mcmc=num_mcmc, previous_model_json=bart_model_json, - previous_model_warmstart_sample_num=num_warmstart-2, - general_params=general_model_params, - mean_forest_params=mean_forest_model_params + previous_model_warmstart_sample_num=num_warmstart - 2, + general_params=general_model_params, + mean_forest_params=mean_forest_model_params, ) # Run several BART MCMC samples from root @@ -109,8 +113,8 @@ def outcome_mean(X, W): leaf_basis_test=basis_test, num_gfr=0, num_mcmc=num_mcmc, - general_params=general_model_params, - mean_forest_params=mean_forest_model_params + general_params=general_model_params, + mean_forest_params=mean_forest_model_params, ) # Inspect the model outputs @@ -121,7 +125,10 @@ def outcome_mean(X, W): y_hat_mcmc_4 = bart_model_4.predict(X_test, basis_test) y_avg_mcmc_4 = np.squeeze(y_hat_mcmc_4).mean(axis=1, keepdims=True) y_df = pd.DataFrame( - np.concatenate((y_avg_mcmc_2, y_avg_mcmc_3, y_avg_mcmc_4, np.expand_dims(y_test, axis=1)), axis=1), + np.concatenate( + (y_avg_mcmc_2, y_avg_mcmc_3, y_avg_mcmc_4, np.expand_dims(y_test, axis=1)), + axis=1, + ), columns=["First Chain", "Second Chain", "Third Chain", 
"Outcome"], ) @@ -141,7 +148,17 @@ def outcome_mean(X, W): plt.show() # Compute RMSEs -rmse_1 = np.sqrt(np.mean((np.squeeze(y_avg_mcmc_2)-y_test)*(np.squeeze(y_avg_mcmc_2)-y_test))) -rmse_2 = np.sqrt(np.mean((np.squeeze(y_avg_mcmc_3)-y_test)*(np.squeeze(y_avg_mcmc_3)-y_test))) -rmse_3 = np.sqrt(np.mean((np.squeeze(y_avg_mcmc_4)-y_test)*(np.squeeze(y_avg_mcmc_4)-y_test))) -print("Chain 1 rmse: {:0.3f}; Chain 2 rmse: {:0.3f}; Chain 3 rmse: {:0.3f}".format(rmse_1, rmse_2, rmse_3)) +rmse_1 = np.sqrt( + np.mean((np.squeeze(y_avg_mcmc_2) - y_test) * (np.squeeze(y_avg_mcmc_2) - y_test)) +) +rmse_2 = np.sqrt( + np.mean((np.squeeze(y_avg_mcmc_3) - y_test) * (np.squeeze(y_avg_mcmc_3) - y_test)) +) +rmse_3 = np.sqrt( + np.mean((np.squeeze(y_avg_mcmc_4) - y_test) * (np.squeeze(y_avg_mcmc_4) - y_test)) +) +print( + "Chain 1 rmse: {:0.3f}; Chain 2 rmse: {:0.3f}; Chain 3 rmse: {:0.3f}".format( + rmse_1, rmse_2, rmse_3 + ) +) diff --git a/demo/debug/parallel_multi_chain.py b/demo/debug/parallel_multi_chain.py new file mode 100644 index 00000000..ee114aee --- /dev/null +++ b/demo/debug/parallel_multi_chain.py @@ -0,0 +1,177 @@ +# Multi Chain Demo Script + +# Load necessary libraries +from multiprocessing import Pool, cpu_count + +import matplotlib.pyplot as plt +import numpy as np +import pandas as pd +import seaborn as sns +from sklearn.model_selection import train_test_split + +from stochtree import BARTModel + + +def fit_bart( + model_string, + X_train, + y_train, + basis_train, + X_test, + basis_test, + num_mcmc, + gen_param_list, + mean_list, + i, +): + bart_model = BARTModel() + bart_model.sample( + X_train=X_train, + y_train=y_train, + leaf_basis_train=basis_train, + X_test=X_test, + leaf_basis_test=basis_test, + num_gfr=0, + num_mcmc=num_mcmc, + previous_model_json=model_string, + previous_model_warmstart_sample_num=i, + general_params=gen_param_list, + mean_forest_params=mean_list, + ) + return (bart_model.to_json(), bart_model.y_hat_test) + + +def bart_warmstart_parallel(X_train, y_train, basis_train, X_test, basis_test): + # Run the GFR algorithm for a small number of iterations + general_model_params = {"random_seed": -1} + mean_forest_model_params = {"num_trees": 100} + num_warmstart = 10 + num_mcmc = 100 + bart_model = BARTModel() + bart_model.sample( + X_train=X_train, + y_train=y_train, + leaf_basis_train=basis_train, + X_test=X_test, + leaf_basis_test=basis_test, + num_gfr=num_warmstart, + num_mcmc=0, + general_params=general_model_params, + mean_forest_params=mean_forest_model_params, + ) + bart_model_json = bart_model.to_json() + + # Warm-start multiple BART fits from a different GFR forest + process_tasks = [ + ( + bart_model_json, + X_train, + y_train, + basis_train, + X_test, + basis_test, + num_mcmc, + general_model_params, + mean_forest_model_params, + i, + ) + for i in range(4) + ] + num_processes = cpu_count() + with Pool(processes=num_processes) as pool: + results = pool.starmap(fit_bart, process_tasks) + + # Extract separate outputs as separate lists + bart_model_json_list, bart_model_pred_list = zip(*results) + + # Process results + combined_bart_model = BARTModel() + combined_bart_model.from_json_string_list(bart_model_json_list) + combined_bart_preds = bart_model_pred_list[0] + for i in range(1, len(bart_model_pred_list)): + combined_bart_preds = np.concatenate( + (combined_bart_preds, bart_model_pred_list[i]), axis=1 + ) + + return (combined_bart_model, combined_bart_preds) + + +if __name__ == "__main__": + # RNG + random_seed = 1234 + rng = 
np.random.default_rng(random_seed) + + # Generate covariates and basis + n = 1000 + p_X = 10 + p_W = 1 + X = rng.uniform(0, 1, (n, p_X)) + W = rng.uniform(0, 1, (n, p_W)) + + # Define the outcome mean function + def outcome_mean(X, W): + return np.where( + (X[:, 0] >= 0.0) & (X[:, 0] < 0.25), + -7.5 * W[:, 0], + np.where( + (X[:, 0] >= 0.25) & (X[:, 0] < 0.5), + -2.5 * W[:, 0], + np.where( + (X[:, 0] >= 0.5) & (X[:, 0] < 0.75), 2.5 * W[:, 0], 7.5 * W[:, 0] + ), + ), + ) + + # Generate outcome + f_XW = outcome_mean(X, W) + epsilon = rng.normal(0, 1, n) + y = f_XW + epsilon + + # Test-train split + sample_inds = np.arange(n) + train_inds, test_inds = train_test_split( + sample_inds, test_size=0.2, random_state=random_seed + ) + X_train = X[train_inds, :] + X_test = X[test_inds, :] + basis_train = W[train_inds, :] + basis_test = W[test_inds, :] + y_train = y[train_inds] + y_test = y[test_inds] + + # Run the parallel BART + combined_bart, combined_bart_preds = bart_warmstart_parallel( + X_train, y_train, basis_train, X_test, basis_test + ) + + # Inspect the model outputs + y_hat_mcmc = combined_bart.predict(X_test, basis_test) + y_avg_mcmc = np.squeeze(y_hat_mcmc).mean(axis=1, keepdims=True) + y_df = pd.DataFrame( + np.concatenate((y_avg_mcmc, np.expand_dims(y_test, axis=1)), axis=1), + columns=["Average BART Predictions", "Outcome"], + ) + + # Compare first warm-start chain to outcome + sns.scatterplot(data=y_df, x="Average BART Predictions", y="Outcome") + plt.axline((0, 0), slope=1, color="black", linestyle=(0, (3, 3))) + plt.show() + + # Compare cached predictions to deserialized predictions for first chain + chain_index = 0 + num_mcmc = 100 + offset_index = num_mcmc * chain_index + chain_inds = slice(offset_index, (offset_index + num_mcmc)) + chain_1_preds_original = np.squeeze(combined_bart_preds[chain_inds]).mean( + axis=1, keepdims=True + ) + chain_1_preds_reloaded = np.squeeze(y_hat_mcmc[chain_inds]).mean( + axis=1, keepdims=True + ) + chain_df = pd.DataFrame( + np.concatenate((chain_1_preds_reloaded, chain_1_preds_original), axis=1), + columns=["New Predictions", "Original Predictions"], + ) + sns.scatterplot(data=chain_df, x="New Predictions", y="Original Predictions") + plt.axline((0, 0), slope=1, color="black", linestyle=(0, (3, 3))) + plt.show() diff --git a/src/py_stochtree.cpp b/src/py_stochtree.cpp index 65f8c927..8c73fd59 100644 --- a/src/py_stochtree.cpp +++ b/src/py_stochtree.cpp @@ -325,6 +325,8 @@ class ForestContainerCpp { void LoadFromJson(JsonCpp& json, std::string forest_label); + void AppendFromJson(JsonCpp& json, std::string forest_label); + std::string DumpJsonString() { return forest_samples_->DumpJsonString(); } @@ -1289,6 +1291,7 @@ class RandomEffectsContainerCpp { rfx_container_->LoadFromJsonString(json_string); } void LoadFromJson(JsonCpp& json, std::string rfx_container_label); + void AppendFromJson(JsonCpp& json, std::string rfx_container_label); StochTree::RandomEffectsContainer* GetRandomEffectsContainer() { return rfx_container_.get(); } @@ -1870,6 +1873,11 @@ void ForestContainerCpp::LoadFromJson(JsonCpp& json, std::string forest_label) { forest_samples_->from_json(forest_json); } +void ForestContainerCpp::AppendFromJson(JsonCpp& json, std::string forest_label) { + nlohmann::json forest_json = json.SubsetJsonForest(forest_label); + forest_samples_->append_from_json(forest_json); +} + void ForestContainerCpp::AdjustResidual(ForestDatasetCpp& dataset, ResidualCpp& residual, ForestSamplerCpp& sampler, bool requires_basis, int forest_num, bool add) { // 
Determine whether or not we are adding forest_num to the residuals std::function op; @@ -1896,6 +1904,11 @@ void RandomEffectsContainerCpp::LoadFromJson(JsonCpp& json, std::string rfx_cont rfx_container_->from_json(rfx_json); } +void RandomEffectsContainerCpp::AppendFromJson(JsonCpp& json, std::string rfx_container_label) { + nlohmann::json rfx_json = json.SubsetJsonRFX().at(rfx_container_label); + rfx_container_->append_from_json(rfx_json); +} + void RandomEffectsContainerCpp::AddSample(RandomEffectsModelCpp& rfx_model) { rfx_container_->AddSample(*rfx_model.GetModel()); } @@ -2012,6 +2025,7 @@ PYBIND11_MODULE(stochtree_cpp, m) { .def("SaveToJsonFile", &ForestContainerCpp::SaveToJsonFile) .def("LoadFromJsonFile", &ForestContainerCpp::LoadFromJsonFile) .def("LoadFromJson", &ForestContainerCpp::LoadFromJson) + .def("AppendFromJson", &ForestContainerCpp::AppendFromJson) .def("DumpJsonString", &ForestContainerCpp::DumpJsonString) .def("LoadFromJsonString", &ForestContainerCpp::LoadFromJsonString) .def("AddSampleValue", &ForestContainerCpp::AddSampleValue) @@ -2125,6 +2139,7 @@ PYBIND11_MODULE(stochtree_cpp, m) { .def("DumpJsonString", &RandomEffectsContainerCpp::DumpJsonString) .def("LoadFromJsonString", &RandomEffectsContainerCpp::LoadFromJsonString) .def("LoadFromJson", &RandomEffectsContainerCpp::LoadFromJson) + .def("AppendFromJson", &RandomEffectsContainerCpp::AppendFromJson) .def("GetRandomEffectsContainer", &RandomEffectsContainerCpp::GetRandomEffectsContainer); py::class_(m, "RandomEffectsTrackerCpp") diff --git a/stochtree/bart.py b/stochtree/bart.py index 4a409831..c8943c85 100644 --- a/stochtree/bart.py +++ b/stochtree/bart.py @@ -14,7 +14,12 @@ from .data import Dataset, Residual from .forest import Forest, ForestContainer from .preprocessing import CovariatePreprocessor, _preprocess_params -from .random_effects import RandomEffectsContainer, RandomEffectsDataset, RandomEffectsModel, RandomEffectsTracker +from .random_effects import ( + RandomEffectsContainer, + RandomEffectsDataset, + RandomEffectsModel, + RandomEffectsTracker, +) from .sampler import RNG, ForestSampler, GlobalVarianceModel, LeafVarianceModel from .serialization import JSONSerializer from .utils import NotSampledError @@ -58,7 +63,6 @@ class BARTModel: def __init__(self) -> None: # Internal flag for whether the sample() method has been run self.sampled = False - self.rng = np.random.default_rng() def sample( self, @@ -101,7 +105,7 @@ def sample( Optional test set basis vector used to define a regression to be run in the leaves of each tree. Must be included / omitted consistently (i.e. if leaf_basis_train is provided, then leaf_basis_test must be provided alongside X_test). rfx_group_ids_test : np.array, optional - Optional test set group labels used for an additive random effects model. We do not currently support (but plan to in the near future), + Optional test set group labels used for an additive random effects model. We do not currently support (but plan to in the near future), test set evaluation for group labels that were not in the training set. rfx_basis_test : np.array, optional Optional test set basis for "random-slope" regression in additive random effects model. @@ -156,10 +160,10 @@ def sample( * `var_forest_prior_scale` (`float`): Scale parameter in the [optional] `IG(var_forest_prior_shape, var_forest_prior_scale)` conditional error variance forest (which is only sampled if `num_trees > 0`). Calibrated internally as `num_trees / leaf_prior_calibration_param^2` if not set here. 
* `keep_vars` (`list` or `np.array`): Vector of variable names or column indices denoting variables that should be included in the variance forest. Defaults to `None`. * `drop_vars` (`list` or `np.array`): Vector of variable names or column indices denoting variables that should be excluded from the variance forest. Defaults to `None`. If both `drop_vars` and `keep_vars` are set, `drop_vars` will be ignored. - + previous_model_json : str, optional JSON string containing a previous BART model. This can be used to "continue" a sampler interactively after inspecting the samples or to run parallel chains "warm-started" from existing forest samples. Defaults to `None`. - previous_model_warmstart_sample_num : int, optional + previous_model_warmstart_sample_num : int, optional Sample number from `previous_model_json` that will be used to warmstart this BART sampler. Zero-indexed (so that the first sample is used for warm-start by setting `previous_model_warmstart_sample_num = 0`). Defaults to `None`. Returns @@ -251,22 +255,24 @@ def sample( drop_vars_mean = mean_forest_params_updated["drop_vars"] # 3. Variance forest parameters - num_trees_variance = variance_forest_params_updated['num_trees'] - alpha_variance = variance_forest_params_updated['alpha'] - beta_variance = variance_forest_params_updated['beta'] - min_samples_leaf_variance = variance_forest_params_updated['min_samples_leaf'] - max_depth_variance = variance_forest_params_updated['max_depth'] - a_0 = variance_forest_params_updated['leaf_prior_calibration_param'] - variance_forest_leaf_init = variance_forest_params_updated['var_forest_leaf_init'] - a_forest = variance_forest_params_updated['var_forest_prior_shape'] - b_forest = variance_forest_params_updated['var_forest_prior_scale'] - keep_vars_variance = variance_forest_params_updated['keep_vars'] - drop_vars_variance = variance_forest_params_updated['drop_vars'] - + num_trees_variance = variance_forest_params_updated["num_trees"] + alpha_variance = variance_forest_params_updated["alpha"] + beta_variance = variance_forest_params_updated["beta"] + min_samples_leaf_variance = variance_forest_params_updated["min_samples_leaf"] + max_depth_variance = variance_forest_params_updated["max_depth"] + a_0 = variance_forest_params_updated["leaf_prior_calibration_param"] + variance_forest_leaf_init = variance_forest_params_updated[ + "var_forest_leaf_init" + ] + a_forest = variance_forest_params_updated["var_forest_prior_shape"] + b_forest = variance_forest_params_updated["var_forest_prior_scale"] + keep_vars_variance = variance_forest_params_updated["keep_vars"] + drop_vars_variance = variance_forest_params_updated["drop_vars"] + # Override keep_gfr if there are no MCMC samples if num_mcmc == 0: keep_gfr = True - + # Check that num_chains >= 1 if not isinstance(num_chains, Integral) or num_chains < 1: raise ValueError("num_chains must be an integer greater than 0") @@ -304,7 +310,9 @@ def sample( if not isinstance(rfx_group_ids_train, np.ndarray): raise ValueError("rfx_group_ids_train must be a numpy array") if not np.issubdtype(rfx_group_ids_train.dtype, np.integer): - raise ValueError("rfx_group_ids_train must be a numpy array of integer-valued group IDs") + raise ValueError( + "rfx_group_ids_train must be a numpy array of integer-valued group IDs" + ) if rfx_basis_train is not None: if not isinstance(rfx_basis_train, np.ndarray): raise ValueError("rfx_basis_train must be a numpy array") @@ -312,11 +320,13 @@ def sample( if not isinstance(rfx_group_ids_test, np.ndarray): raise 
ValueError("rfx_group_ids_test must be a numpy array") if not np.issubdtype(rfx_group_ids_test.dtype, np.integer): - raise ValueError("rfx_group_ids_test must be a numpy array of integer-valued group IDs") + raise ValueError( + "rfx_group_ids_test must be a numpy array of integer-valued group IDs" + ) if rfx_basis_test is not None: if not isinstance(rfx_basis_test, np.ndarray): raise ValueError("rfx_basis_test must be a numpy array") - + # Convert everything to standard shape (2-dimensional) if isinstance(X_train, np.ndarray): if X_train.ndim == 1: @@ -359,7 +369,9 @@ def sample( "leaf_basis_train and leaf_basis_test must have the same number of columns" ) else: - raise ValueError("leaf_basis_test provided but leaf_basis_train was not") + raise ValueError( + "leaf_basis_test provided but leaf_basis_train was not" + ) if leaf_basis_train is not None: if leaf_basis_train.shape[0] != X_train.shape[0]: raise ValueError( @@ -624,9 +636,13 @@ def sample( if has_prev_model: if num_gfr > 0: if num_mcmc == 0: - raise ValueError("A previous model is being used to initialize this sampler, so `num_mcmc` must be greater than zero") + raise ValueError( + "A previous model is being used to initialize this sampler, so `num_mcmc` must be greater than zero" + ) else: - warnings.warn("A previous model is being used to initialize this sampler, so num_gfr will be ignored and the MCMC sampler will be run from the previous samples") + warnings.warn( + "A previous model is being used to initialize this sampler, so num_gfr will be ignored and the MCMC sampler will be run from the previous samples" + ) previous_bart_model = BARTModel() previous_bart_model.from_json(previous_model_json) previous_y_bar = previous_bart_model.y_bar @@ -637,11 +653,15 @@ def sample( else: previous_forest_samples_mean = None if previous_bart_model.include_variance_forest: - previous_forest_samples_variance = previous_bart_model.forest_container_variance + previous_forest_samples_variance = ( + previous_bart_model.forest_container_variance + ) else: previous_forest_samples_variance = None if previous_bart_model.sample_sigma_global: - previous_global_var_samples = previous_bart_model.global_var_samples / (previous_y_scale * previous_y_scale) + previous_global_var_samples = previous_bart_model.global_var_samples / ( + previous_y_scale * previous_y_scale + ) else: previous_global_var_samples = None if previous_bart_model.sample_sigma_leaf: @@ -653,7 +673,9 @@ def sample( else: previous_rfx_samples = None if previous_model_warmstart_sample_num + 1 > previous_model_num_samples: - raise ValueError("`previous_model_warmstart_sample_num` exceeds the number of samples in `previous_model_json`") + raise ValueError( + "`previous_model_warmstart_sample_num` exceeds the number of samples in `previous_model_json`" + ) else: previous_y_bar = None previous_y_scale = None @@ -663,7 +685,7 @@ def sample( previous_forest_samples_mean = None previous_forest_samples_variance = None previous_model_num_samples = 0 - + # Update variable weights if the covariates have been resized (by e.g. 
one-hot encoding) if X_train_processed.shape[1] != X_train.shape[1]: variable_counts = [ @@ -713,35 +735,56 @@ def sample( if self.has_basis: if sigma_leaf is None: current_leaf_scale = np.zeros((self.num_basis, self.num_basis)) - np.fill_diagonal(current_leaf_scale, np.squeeze(np.var(resid_train)) / num_trees_mean) + np.fill_diagonal( + current_leaf_scale, + np.squeeze(np.var(resid_train)) / num_trees_mean, + ) elif isinstance(sigma_leaf, float): current_leaf_scale = np.zeros((self.num_basis, self.num_basis)) np.fill_diagonal(current_leaf_scale, sigma_leaf) elif isinstance(sigma_leaf, np.ndarray): if sigma_leaf.ndim != 2: - raise ValueError("sigma_leaf must be a 2d symmetric numpy array if provided in matrix form") + raise ValueError( + "sigma_leaf must be a 2d symmetric numpy array if provided in matrix form" + ) if sigma_leaf.shape[0] != sigma_leaf.shape[1]: - raise ValueError("sigma_leaf must be a 2d symmetric numpy array if provided in matrix form") + raise ValueError( + "sigma_leaf must be a 2d symmetric numpy array if provided in matrix form" + ) if sigma_leaf.shape[0] != self.num_basis: - raise ValueError("sigma_leaf must be a 2d symmetric numpy array with its dimensionality matching the basis dimension") + raise ValueError( + "sigma_leaf must be a 2d symmetric numpy array with its dimensionality matching the basis dimension" + ) current_leaf_scale = sigma_leaf else: - raise ValueError("sigma_leaf must be either a scalar or a 2d symmetric numpy array") + raise ValueError( + "sigma_leaf must be either a scalar or a 2d symmetric numpy array" + ) else: if sigma_leaf is None: - current_leaf_scale = np.array([[np.squeeze(np.var(resid_train)) / num_trees_mean]]) + current_leaf_scale = np.array( + [[np.squeeze(np.var(resid_train)) / num_trees_mean]] + ) elif isinstance(sigma_leaf, float): current_leaf_scale = np.array([[sigma_leaf]]) elif isinstance(sigma_leaf, np.ndarray): if sigma_leaf.ndim != 2: - raise ValueError("sigma_leaf must be a 2d symmetric numpy array if provided in matrix form") + raise ValueError( + "sigma_leaf must be a 2d symmetric numpy array if provided in matrix form" + ) if sigma_leaf.shape[0] != sigma_leaf.shape[1]: - raise ValueError("sigma_leaf must be a 2d symmetric numpy array if provided in matrix form") + raise ValueError( + "sigma_leaf must be a 2d symmetric numpy array if provided in matrix form" + ) if sigma_leaf.shape[0] != 1: - raise ValueError("sigma_leaf must be a 1x1 numpy array for this leaf model") + raise ValueError( + "sigma_leaf must be a 1x1 numpy array for this leaf model" + ) current_leaf_scale = sigma_leaf else: - raise ValueError("sigma_leaf must be either a scalar or a 2d numpy array") + raise ValueError( + "sigma_leaf must be either a scalar or a 2d numpy array" + ) else: current_leaf_scale = np.array([[1.0]]) if self.include_variance_forest: @@ -754,7 +797,7 @@ def sample( a_forest = 1.0 if not b_forest: b_forest = 1.0 - + # Runtime checks on RFX group ids self.has_rfx = False has_rfx_test = False @@ -763,13 +806,15 @@ def sample( if rfx_group_ids_test is not None: has_rfx_test = True if not np.all(np.isin(rfx_group_ids_test, rfx_group_ids_train)): - raise ValueError("All random effect group labels provided in rfx_group_ids_test must be present in rfx_group_ids_train") - - # Fill in rfx basis as a vector of 1s (random intercept) if a basis not provided + raise ValueError( + "All random effect group labels provided in rfx_group_ids_test must be present in rfx_group_ids_train" + ) + + # Fill in rfx basis as a vector of 1s (random intercept) if 
a basis not provided has_basis_rfx = False if self.has_rfx: if rfx_basis_train is None: - rfx_basis_train = np.ones((rfx_group_ids_train.shape[0],1)) + rfx_basis_train = np.ones((rfx_group_ids_train.shape[0], 1)) else: has_basis_rfx = True num_rfx_groups = np.unique(rfx_group_ids_train).shape[0] @@ -778,22 +823,26 @@ def sample( if has_rfx_test: if rfx_basis_test is None: if has_basis_rfx: - raise ValueError("Random effects basis provided for training set, must also be provided for the test set") - rfx_basis_test = np.ones((rfx_group_ids_test.shape[0],1)) - + raise ValueError( + "Random effects basis provided for training set, must also be provided for the test set" + ) + rfx_basis_test = np.ones((rfx_group_ids_test.shape[0], 1)) + # Set up random effects structures if self.has_rfx: if num_rfx_components == 1: alpha_init = np.array([1]) elif num_rfx_components > 1: - alpha_init = np.concatenate((np.ones(1), np.zeros(num_rfx_components-1))) + alpha_init = np.concatenate( + (np.ones(1), np.zeros(num_rfx_components - 1)) + ) else: raise ValueError("There must be at least 1 random effect component") xi_init = np.tile(np.expand_dims(alpha_init, 1), (1, num_rfx_groups)) sigma_alpha_init = np.identity(num_rfx_components) sigma_xi_init = np.identity(num_rfx_components) - sigma_xi_shape = 1. - sigma_xi_scale = 1. + sigma_xi_shape = 1.0 + sigma_xi_scale = 1.0 rfx_dataset_train = RandomEffectsDataset() rfx_dataset_train.add_group_labels(rfx_group_ids_train) rfx_dataset_train.add_basis(rfx_basis_train) @@ -806,7 +855,9 @@ def sample( rfx_model.set_variance_prior_shape(sigma_xi_shape) rfx_model.set_variance_prior_scale(sigma_xi_scale) self.rfx_container = RandomEffectsContainer() - self.rfx_container.load_new_container(num_rfx_components, num_rfx_groups, rfx_tracker) + self.rfx_container.load_new_container( + num_rfx_components, num_rfx_groups, rfx_tracker + ) # Container of variance parameter samples self.num_gfr = num_gfr @@ -1012,11 +1063,17 @@ def sample( self.leaf_scale_samples[sample_counter] = current_leaf_scale[ 0, 0 ] - + # Sample random effects if self.has_rfx: rfx_model.sample( - rfx_dataset_train, residual_train, rfx_tracker, self.rfx_container, keep_sample, current_sigma2, cpp_rng + rfx_dataset_train, + residual_train, + rfx_tracker, + self.rfx_container, + keep_sample, + current_sigma2, + cpp_rng, ) # Run MCMC @@ -1046,19 +1103,41 @@ def sample( current_sigma2 = self.global_var_samples[forest_ind] elif has_prev_model: if self.include_mean_forest: - active_forest_mean.reset(previous_bart_model.forest_container_mean, previous_model_warmstart_sample_num) - forest_sampler_mean.reconstitute_from_forest(active_forest_mean, forest_dataset_train, residual_train, True) + active_forest_mean.reset( + previous_bart_model.forest_container_mean, + previous_model_warmstart_sample_num, + ) + forest_sampler_mean.reconstitute_from_forest( + active_forest_mean, + forest_dataset_train, + residual_train, + True, + ) if sample_sigma_leaf and previous_leaf_var_samples is not None: - leaf_scale_double = previous_leaf_var_samples[previous_model_warmstart_sample_num] + leaf_scale_double = previous_leaf_var_samples[ + previous_model_warmstart_sample_num + ] current_leaf_scale[0, 0] = leaf_scale_double - forest_model_config_mean.update_leaf_model_scale(leaf_scale_double) + forest_model_config_mean.update_leaf_model_scale( + leaf_scale_double + ) if self.include_variance_forest: - active_forest_variance.reset(previous_bart_model.forest_container_variance, previous_model_warmstart_sample_num) - 
forest_sampler_variance.reconstitute_from_forest(active_forest_variance, forest_dataset_train, residual_train, True) + active_forest_variance.reset( + previous_bart_model.forest_container_variance, + previous_model_warmstart_sample_num, + ) + forest_sampler_variance.reconstitute_from_forest( + active_forest_variance, + forest_dataset_train, + residual_train, + True, + ) # if self.has_rfx: # pass if self.sample_sigma_global: - current_sigma2 = previous_global_var_samples[previous_model_warmstart_sample_num] + current_sigma2 = previous_global_var_samples[ + previous_model_warmstart_sample_num + ] global_model_config.update_global_error_variance(current_sigma2) else: if self.include_mean_forest: @@ -1144,16 +1223,24 @@ def sample( current_leaf_scale[0, 0] = leaf_var_model.sample_one_iteration( active_forest_mean, cpp_rng, a_leaf, b_leaf ) - forest_model_config_mean.update_leaf_model_scale(current_leaf_scale) + forest_model_config_mean.update_leaf_model_scale( + current_leaf_scale + ) if keep_sample: self.leaf_scale_samples[sample_counter] = ( current_leaf_scale[0, 0] ) - + # Sample random effects if self.has_rfx: rfx_model.sample( - rfx_dataset_train, residual_train, rfx_tracker, self.rfx_container, keep_sample, current_sigma2, cpp_rng + rfx_dataset_train, + residual_train, + rfx_tracker, + self.rfx_container, + keep_sample, + current_sigma2, + cpp_rng, ) # Mark the model as sampled @@ -1191,12 +1278,18 @@ def sample( forest_dataset_test.dataset_cpp ) self.y_hat_test = yhat_test_raw * self.y_std + self.y_bar - + # TODO: make rfx_preds_train and rfx_preds_test persistent properties if self.has_rfx: - rfx_preds_train = self.rfx_container.predict(rfx_group_ids_train, rfx_basis_train) * self.y_std + rfx_preds_train = ( + self.rfx_container.predict(rfx_group_ids_train, rfx_basis_train) + * self.y_std + ) if has_rfx_test: - rfx_preds_test = self.rfx_container.predict(rfx_group_ids_test, rfx_basis_test) * self.y_std + rfx_preds_test = ( + self.rfx_container.predict(rfx_group_ids_test, rfx_basis_test) + * self.y_std + ) if self.include_mean_forest: self.y_hat_train = self.y_hat_train + rfx_preds_train if self.has_test: @@ -1240,8 +1333,11 @@ def sample( ) def predict( - self, covariates: Union[np.array, pd.DataFrame], basis: np.array = None, - rfx_group_ids: np.array = None, rfx_basis: np.array = None + self, + covariates: Union[np.array, pd.DataFrame], + basis: np.array = None, + rfx_group_ids: np.array = None, + rfx_basis: np.array = None, ) -> Union[np.array, tuple]: """Return predictions from every forest sampled (either / both of mean and variance). Return type is either a single array of predictions, if a BART model only includes a @@ -1326,9 +1422,11 @@ def predict( pred_dataset.dataset_cpp ) mean_pred = mean_pred_raw * self.y_std + self.y_bar - + if self.has_rfx: - rfx_preds = self.rfx_container.predict(rfx_group_ids, rfx_basis) * self.y_std + rfx_preds = ( + self.rfx_container.predict(rfx_group_ids, rfx_basis) * self.y_std + ) if self.include_mean_forest: mean_pred = mean_pred + rfx_preds else: @@ -1360,8 +1458,11 @@ def predict( return variance_pred def predict_mean( - self, covariates: np.array, basis: np.array = None, - rfx_group_ids: np.array = None, rfx_basis: np.array = None + self, + covariates: np.array, + basis: np.array = None, + rfx_group_ids: np.array = None, + rfx_basis: np.array = None, ) -> np.array: """Predict expected conditional outcome from a BART model. 
@@ -1449,7 +1550,9 @@ def predict_mean( # RFX predictions if self.has_rfx: - rfx_preds = self.rfx_container.predict(rfx_group_ids, rfx_basis) * self.y_std + rfx_preds = ( + self.rfx_container.predict(rfx_group_ids, rfx_basis) * self.y_std + ) if self.include_mean_forest: mean_pred = mean_pred + rfx_preds else: @@ -1577,11 +1680,11 @@ def to_json(self) -> str: bart_json.add_boolean("include_mean_forest", self.include_mean_forest) bart_json.add_boolean("include_variance_forest", self.include_variance_forest) bart_json.add_boolean("has_rfx", self.has_rfx) - bart_json.add_scalar("num_gfr", self.num_gfr) - bart_json.add_scalar("num_burnin", self.num_burnin) - bart_json.add_scalar("num_mcmc", self.num_mcmc) - bart_json.add_scalar("num_samples", self.num_samples) - bart_json.add_scalar("num_basis", self.num_basis) + bart_json.add_integer("num_gfr", self.num_gfr) + bart_json.add_integer("num_burnin", self.num_burnin) + bart_json.add_integer("num_mcmc", self.num_mcmc) + bart_json.add_integer("num_samples", self.num_samples) + bart_json.add_integer("num_basis", self.num_basis) bart_json.add_boolean("requires_basis", self.has_basis) # Add parameter samples @@ -1635,7 +1738,7 @@ def from_json(self, json_string: str) -> None: self.forest_container_variance.forest_container_cpp.LoadFromJson( bart_json.json_cpp, "forest_0" ) - + # Unpack random effects if self.has_rfx: self.rfx_container = RandomEffectsContainer() @@ -1648,11 +1751,11 @@ def from_json(self, json_string: str) -> None: self.sigma2_init = bart_json.get_scalar("sigma2_init") self.sample_sigma_global = bart_json.get_boolean("sample_sigma_global") self.sample_sigma_leaf = bart_json.get_boolean("sample_sigma_leaf") - self.num_gfr = bart_json.get_scalar("num_gfr") - self.num_burnin = bart_json.get_scalar("num_burnin") - self.num_mcmc = bart_json.get_scalar("num_mcmc") - self.num_samples = bart_json.get_scalar("num_samples") - self.num_basis = bart_json.get_scalar("num_basis") + self.num_gfr = bart_json.get_integer("num_gfr") + self.num_burnin = bart_json.get_integer("num_burnin") + self.num_mcmc = bart_json.get_integer("num_mcmc") + self.num_samples = bart_json.get_integer("num_samples") + self.num_basis = bart_json.get_integer("num_basis") self.has_basis = bart_json.get_boolean("requires_basis") # Unpack parameter samples @@ -1673,6 +1776,135 @@ def from_json(self, json_string: str) -> None: # Mark the deserialized model as "sampled" self.sampled = True + def from_json_string_list(self, json_string_list: list[str]) -> None: + """ + Convert a list of (in-memory) JSON strings that represent BART models to a single combined BART model object + which can be used for prediction, etc... 
+ + Parameters + ------- + json_string_list : list of str + List of JSON strings which can be parsed to objects of type `JSONSerializer` containing Json representation of a BART model + """ + # Convert strings to JSONSerializer + json_object_list = [] + for i in range(len(json_string_list)): + json_string = json_string_list[i] + json_object_list.append(JSONSerializer()) + json_object_list[i].load_from_json_string(json_string) + + # For scalar / preprocessing details which aren't sample-dependent, defer to the first json + json_object_default = json_object_list[0] + + # Unpack forests + self.include_mean_forest = json_object_default.get_boolean( + "include_mean_forest" + ) + self.include_variance_forest = json_object_default.get_boolean( + "include_variance_forest" + ) + if self.include_mean_forest: + # TODO: don't just make this a placeholder that we overwrite + self.forest_container_mean = ForestContainer(0, 0, False, False) + for i in range(len(json_object_list)): + if i == 0: + self.forest_container_mean.forest_container_cpp.LoadFromJson( + json_object_list[i].json_cpp, "forest_0" + ) + else: + self.forest_container_mean.forest_container_cpp.AppendFromJson( + json_object_list[i].json_cpp, "forest_0" + ) + if self.include_variance_forest: + # TODO: don't just make this a placeholder that we overwrite + self.forest_container_variance = ForestContainer(0, 0, False, False) + for i in range(len(json_object_list)): + if i == 0: + self.forest_container_variance.forest_container_cpp.LoadFromJson( + json_object_list[i].json_cpp, "forest_1" + ) + else: + self.forest_container_variance.forest_container_cpp.AppendFromJson( + json_object_list[i].json_cpp, "forest_1" + ) + else: + # TODO: don't just make this a placeholder that we overwrite + self.forest_container_variance = ForestContainer(0, 0, False, False) + for i in range(len(json_object_list)): + if i == 0: + self.forest_container_variance.forest_container_cpp.LoadFromJson( + json_object_list[i].json_cpp, "forest_1" + ) + else: + self.forest_container_variance.forest_container_cpp.AppendFromJson( + json_object_list[i].json_cpp, "forest_1" + ) + + # Unpack random effects + self.has_rfx = json_object_default.get_boolean("has_rfx") + if self.has_rfx: + self.rfx_container = RandomEffectsContainer() + for i in range(len(json_object_list)): + if i == 0: + self.rfx_container.load_from_json(json_object_list[i], 0) + else: + self.rfx_container.append_from_json(json_object_list[i], 0) + + # Unpack global parameters + self.y_std = json_object_default.get_scalar("outcome_scale") + self.y_bar = json_object_default.get_scalar("outcome_mean") + self.standardize = json_object_default.get_boolean("standardize") + self.sigma2_init = json_object_default.get_scalar("sigma2_init") + self.sample_sigma_global = json_object_default.get_boolean( + "sample_sigma_global" + ) + self.sample_sigma_leaf = json_object_default.get_boolean("sample_sigma_leaf") + self.num_gfr = json_object_default.get_integer("num_gfr") + self.num_burnin = json_object_default.get_integer("num_burnin") + self.num_mcmc = json_object_default.get_integer("num_mcmc") + self.num_samples = json_object_default.get_integer("num_samples") + self.num_basis = json_object_default.get_integer("num_basis") + self.has_basis = json_object_default.get_boolean("requires_basis") + + # Unpack parameter samples + if self.sample_sigma_global: + for i in range(len(json_object_list)): + if i == 0: + self.global_var_samples = json_object_list[i].get_numeric_vector( + "sigma2_global_samples", "parameters" + ) + else: + 
global_var_samples = json_object_list[i].get_numeric_vector( + "sigma2_global_samples", "parameters" + ) + self.global_var_samples = np.concatenate( + (self.global_var_samples, global_var_samples) + ) + + if self.sample_sigma_leaf: + for i in range(len(json_object_list)): + if i == 0: + self.leaf_scale_samples = json_object_list[i].get_numeric_vector( + "sigma2_leaf_samples", "parameters" + ) + else: + leaf_scale_samples = json_object_list[i].get_numeric_vector( + "sigma2_leaf_samples", "parameters" + ) + self.leaf_scale_samples = np.concatenate( + (self.leaf_scale_samples, leaf_scale_samples) + ) + + # Unpack covariate preprocessor + covariate_preprocessor_string = json_object_default.get_string( + "covariate_preprocessor" + ) + self._covariate_preprocessor = CovariatePreprocessor() + self._covariate_preprocessor.from_json(covariate_preprocessor_string) + + # Mark the deserialized model as "sampled" + self.sampled = True + def is_sampled(self) -> bool: """Whether or not a BART model has been sampled. diff --git a/stochtree/bcf.py b/stochtree/bcf.py index df6b1cd1..37dc80ec 100644 --- a/stochtree/bcf.py +++ b/stochtree/bcf.py @@ -14,7 +14,12 @@ from .data import Dataset, Residual from .forest import Forest, ForestContainer from .preprocessing import CovariatePreprocessor, _preprocess_params -from .random_effects import RandomEffectsContainer, RandomEffectsDataset, RandomEffectsModel, RandomEffectsTracker +from .random_effects import ( + RandomEffectsContainer, + RandomEffectsDataset, + RandomEffectsModel, + RandomEffectsTracker, +) from .sampler import RNG, ForestSampler, GlobalVarianceModel, LeafVarianceModel from .serialization import JSONSerializer from .utils import NotSampledError @@ -115,7 +120,7 @@ def sample( pi_test : np.array, optional Optional test set vector of propensity scores. If not provided (but `X_test` and `Z_test` are), this will be estimated from the data. rfx_group_ids_test : np.array, optional - Optional test set group labels used for an additive random effects model. We do not currently support (but plan to in the near future), + Optional test set group labels used for an additive random effects model. We do not currently support (but plan to in the near future), test set evaluation for group labels that were not in the training set. rfx_basis_test : np.array, optional Optional test set basis for "random-slope" regression in additive random effects model. @@ -310,9 +315,13 @@ def sample( num_trees_tau = treatment_effect_forest_params_updated["num_trees"] alpha_tau = treatment_effect_forest_params_updated["alpha"] beta_tau = treatment_effect_forest_params_updated["beta"] - min_samples_leaf_tau = treatment_effect_forest_params_updated["min_samples_leaf"] + min_samples_leaf_tau = treatment_effect_forest_params_updated[ + "min_samples_leaf" + ] max_depth_tau = treatment_effect_forest_params_updated["max_depth"] - sample_sigma_leaf_tau = treatment_effect_forest_params_updated["sample_sigma2_leaf"] + sample_sigma_leaf_tau = treatment_effect_forest_params_updated[ + "sample_sigma2_leaf" + ] sigma_leaf_tau = treatment_effect_forest_params_updated["sigma2_leaf_init"] a_leaf_tau = treatment_effect_forest_params_updated["sigma2_leaf_shape"] b_leaf_tau = treatment_effect_forest_params_updated["sigma2_leaf_scale"] @@ -320,22 +329,24 @@ def sample( drop_vars_tau = treatment_effect_forest_params_updated["drop_vars"] # 4. 
Variance forest parameters - num_trees_variance = variance_forest_params_updated['num_trees'] - alpha_variance = variance_forest_params_updated['alpha'] - beta_variance = variance_forest_params_updated['beta'] - min_samples_leaf_variance = variance_forest_params_updated['min_samples_leaf'] - max_depth_variance = variance_forest_params_updated['max_depth'] - a_0 = variance_forest_params_updated['leaf_prior_calibration_param'] - variance_forest_leaf_init = variance_forest_params_updated['var_forest_leaf_init'] - a_forest = variance_forest_params_updated['var_forest_prior_shape'] - b_forest = variance_forest_params_updated['var_forest_prior_scale'] - keep_vars_variance = variance_forest_params_updated['keep_vars'] - drop_vars_variance = variance_forest_params_updated['drop_vars'] - + num_trees_variance = variance_forest_params_updated["num_trees"] + alpha_variance = variance_forest_params_updated["alpha"] + beta_variance = variance_forest_params_updated["beta"] + min_samples_leaf_variance = variance_forest_params_updated["min_samples_leaf"] + max_depth_variance = variance_forest_params_updated["max_depth"] + a_0 = variance_forest_params_updated["leaf_prior_calibration_param"] + variance_forest_leaf_init = variance_forest_params_updated[ + "var_forest_leaf_init" + ] + a_forest = variance_forest_params_updated["var_forest_prior_shape"] + b_forest = variance_forest_params_updated["var_forest_prior_scale"] + keep_vars_variance = variance_forest_params_updated["keep_vars"] + drop_vars_variance = variance_forest_params_updated["drop_vars"] + # Override keep_gfr if there are no MCMC samples if num_mcmc == 0: keep_gfr = True - + # Variable weight preprocessing (and initialization if necessary) if variable_weights is None: if X_train.ndim > 1: @@ -378,7 +389,9 @@ def sample( if not isinstance(rfx_group_ids_train, np.ndarray): raise ValueError("rfx_group_ids_train must be a numpy array") if not np.issubdtype(rfx_group_ids_train.dtype, np.integer): - raise ValueError("rfx_group_ids_train must be a numpy array of integer-valued group IDs") + raise ValueError( + "rfx_group_ids_train must be a numpy array of integer-valued group IDs" + ) if rfx_basis_train is not None: if not isinstance(rfx_basis_train, np.ndarray): raise ValueError("rfx_basis_train must be a numpy array") @@ -386,7 +399,9 @@ def sample( if not isinstance(rfx_group_ids_test, np.ndarray): raise ValueError("rfx_group_ids_test must be a numpy array") if not np.issubdtype(rfx_group_ids_test.dtype, np.integer): - raise ValueError("rfx_group_ids_test must be a numpy array of integer-valued group IDs") + raise ValueError( + "rfx_group_ids_test must be a numpy array of integer-valued group IDs" + ) if rfx_basis_test is not None: if not isinstance(rfx_basis_test, np.ndarray): raise ValueError("rfx_basis_test must be a numpy array") @@ -470,7 +485,6 @@ def sample( # Set variance leaf model type (currently only one option) leaf_dimension_variance = 1 leaf_model_variance = 3 - self.variance_scale = 1 # Check parameters if sigma_leaf_tau is not None: @@ -1155,11 +1169,17 @@ def sample( current_leaf_scale_tau = np.array([[sigma_leaf_tau]]) elif isinstance(sigma_leaf_tau, np.ndarray): if sigma_leaf_tau.ndim != 2: - raise ValueError("sigma_leaf_tau must be a 2d symmetric numpy array if provided in matrix form") + raise ValueError( + "sigma_leaf_tau must be a 2d symmetric numpy array if provided in matrix form" + ) if sigma_leaf_tau.shape[0] != sigma_leaf_tau.shape[1]: - raise ValueError("sigma_leaf_tau must be a 2d symmetric numpy array if provided in 
matrix form") + raise ValueError( + "sigma_leaf_tau must be a 2d symmetric numpy array if provided in matrix form" + ) if sigma_leaf_tau.shape[0] != Z_train.shape[1]: - raise ValueError("sigma_leaf_tau must be a 2d numpy array with dimension matching that of the treatment vector") + raise ValueError( + "sigma_leaf_tau must be a 2d numpy array with dimension matching that of the treatment vector" + ) current_leaf_scale_tau = sigma_leaf_tau else: raise ValueError("sigma_leaf_tau must be a scalar or a 2d numpy array") @@ -1173,7 +1193,7 @@ def sample( a_forest = 1.0 if not b_forest: b_forest = 1.0 - + # Runtime checks on RFX group ids self.has_rfx = False has_rfx_test = False @@ -1182,13 +1202,15 @@ def sample( if rfx_group_ids_test is not None: has_rfx_test = True if not np.all(np.isin(rfx_group_ids_test, rfx_group_ids_train)): - raise ValueError("All random effect group labels provided in rfx_group_ids_test must be present in rfx_group_ids_train") - - # Fill in rfx basis as a vector of 1s (random intercept) if a basis not provided + raise ValueError( + "All random effect group labels provided in rfx_group_ids_test must be present in rfx_group_ids_train" + ) + + # Fill in rfx basis as a vector of 1s (random intercept) if a basis not provided has_basis_rfx = False if self.has_rfx: if rfx_basis_train is None: - rfx_basis_train = np.ones((rfx_group_ids_train.shape[0],1)) + rfx_basis_train = np.ones((rfx_group_ids_train.shape[0], 1)) else: has_basis_rfx = True num_rfx_groups = np.unique(rfx_group_ids_train).shape[0] @@ -1197,22 +1219,26 @@ def sample( if has_rfx_test: if rfx_basis_test is None: if has_basis_rfx: - raise ValueError("Random effects basis provided for training set, must also be provided for the test set") - rfx_basis_test = np.ones((rfx_group_ids_test.shape[0],1)) - + raise ValueError( + "Random effects basis provided for training set, must also be provided for the test set" + ) + rfx_basis_test = np.ones((rfx_group_ids_test.shape[0], 1)) + # Set up random effects structures if self.has_rfx: if num_rfx_components == 1: alpha_init = np.array([1]) elif num_rfx_components > 1: - alpha_init = np.concatenate((np.ones(1), np.zeros(num_rfx_components-1))) + alpha_init = np.concatenate( + (np.ones(1), np.zeros(num_rfx_components - 1)) + ) else: raise ValueError("There must be at least 1 random effect component") xi_init = np.tile(np.expand_dims(alpha_init, 1), (1, num_rfx_groups)) sigma_alpha_init = np.identity(num_rfx_components) sigma_xi_init = np.identity(num_rfx_components) - sigma_xi_shape = 1. - sigma_xi_scale = 1. 
+ sigma_xi_shape = 1.0 + sigma_xi_scale = 1.0 rfx_dataset_train = RandomEffectsDataset() rfx_dataset_train.add_group_labels(rfx_group_ids_train) rfx_dataset_train.add_basis(rfx_basis_train) @@ -1225,7 +1251,9 @@ def sample( rfx_model.set_variance_prior_shape(sigma_xi_shape) rfx_model.set_variance_prior_scale(sigma_xi_scale) self.rfx_container = RandomEffectsContainer() - self.rfx_container.load_new_container(num_rfx_components, num_rfx_groups, rfx_tracker) + self.rfx_container.load_new_container( + num_rfx_components, num_rfx_groups, rfx_tracker + ) # Update variable weights variable_counts = [original_var_indices.count(i) for i in original_var_indices] @@ -1614,11 +1642,17 @@ def sample( self.leaf_scale_tau_samples[sample_counter] = ( current_leaf_scale_tau[0, 0] ) - + # Sample random effects if self.has_rfx: rfx_model.sample( - rfx_dataset_train, residual_train, rfx_tracker, self.rfx_container, keep_sample, current_sigma2, cpp_rng + rfx_dataset_train, + residual_train, + rfx_tracker, + self.rfx_container, + keep_sample, + current_sigma2, + cpp_rng, ) # Run MCMC @@ -1762,11 +1796,17 @@ def sample( self.leaf_scale_tau_samples[sample_counter] = ( current_leaf_scale_tau[0, 0] ) - + # Sample random effects if self.has_rfx: rfx_model.sample( - rfx_dataset_train, residual_train, rfx_tracker, self.rfx_container, keep_sample, current_sigma2, cpp_rng + rfx_dataset_train, + residual_train, + rfx_tracker, + self.rfx_container, + keep_sample, + current_sigma2, + cpp_rng, ) # Mark the model as sampled @@ -1839,13 +1879,19 @@ def sample( # TODO: make rfx_preds_train and rfx_preds_test persistent properties if self.has_rfx: - rfx_preds_train = self.rfx_container.predict(rfx_group_ids_train, rfx_basis_train) * self.y_std + rfx_preds_train = ( + self.rfx_container.predict(rfx_group_ids_train, rfx_basis_train) + * self.y_std + ) if has_rfx_test: - rfx_preds_test = self.rfx_container.predict(rfx_group_ids_test, rfx_basis_test) * self.y_std + rfx_preds_test = ( + self.rfx_container.predict(rfx_group_ids_test, rfx_basis_test) + * self.y_std + ) self.y_hat_train = self.y_hat_train + rfx_preds_train if self.has_test: self.y_hat_test = self.y_hat_test + rfx_preds_test - + if self.include_variance_forest: sigma2_x_train_raw = ( self.forest_container_variance.forest_container_cpp.Predict( @@ -1860,11 +1906,7 @@ def sample( ) else: self.sigma2_x_train = ( - sigma2_x_train_raw - * self.sigma2_init - * self.y_std - * self.y_std - / self.variance_scale + sigma2_x_train_raw * self.sigma2_init * self.y_std * self.y_std ) if self.has_test: sigma2_x_test_raw = ( @@ -1880,11 +1922,7 @@ def sample( ) else: self.sigma2_x_test = ( - sigma2_x_test_raw - * self.sigma2_init - * self.y_std - * self.y_std - / self.variance_scale + sigma2_x_test_raw * self.sigma2_init * self.y_std * self.y_std ) if self.sample_sigma_global: @@ -1970,16 +2008,16 @@ def predict_tau( "This BCF model has not run any covariate preprocessing routines. We will attempt to predict on the raw covariate values, but this will trigger an error with non-numeric columns. Please refit your model by passing non-numeric covariate data a a Pandas dataframe.", RuntimeWarning, ) - if not np.issubdtype( - X.dtype, np.floating - ) and not np.issubdtype(X.dtype, np.integer): + if not np.issubdtype(X.dtype, np.floating) and not np.issubdtype( + X.dtype, np.integer + ): raise ValueError( "Prediction cannot proceed on a non-numeric numpy array, since the BCF model was not fit with a covariate preprocessor. 
Please refit your model by passing non-numeric covariate data as a Pandas dataframe." ) covariates_processed = X else: covariates_processed = self._covariate_preprocessor.transform(X) - + # Update covariates to include propensities if requested if self.propensity_covariate == "none": X_combined = covariates_processed @@ -2064,7 +2102,7 @@ def predict_variance( covariates_processed = covariates else: covariates_processed = self._covariate_preprocessor.transform(covariates) - + # Update covariates to include propensities if requested if self.propensity_covariate == "none": X_combined = covariates_processed @@ -2092,16 +2130,19 @@ def predict_variance( ) else: variance_pred = ( - variance_pred_raw - * self.sigma2_init - * self.y_std - * self.y_std - / self.variance_scale + variance_pred_raw * self.sigma2_init * self.y_std * self.y_std ) return variance_pred - def predict(self, X: np.array, Z: np.array, propensity: np.array = None, rfx_group_ids: np.array = None, rfx_basis: np.array = None) -> tuple: + def predict( + self, + X: np.array, + Z: np.array, + propensity: np.array = None, + rfx_group_ids: np.array = None, + rfx_basis: np.array = None, + ) -> tuple: """Predict outcome model components (CATE function and prognostic function) as well as overall outcome for every provided observation. Predicted outcomes are computed as `yhat = mu_x + Z*tau_x` where mu_x is a sample of the prognostic function and tau_x is a sample of the treatment effect (CATE) function. @@ -2167,7 +2208,7 @@ def predict(self, X: np.array, Z: np.array, propensity: np.array = None, rfx_gro propensity = np.mean( self.bart_propensity_model.predict(X), axis=1, keepdims=True ) - + # Covariate preprocessing if not self._covariate_preprocessor._check_is_fitted(): if not isinstance(X, np.ndarray): @@ -2179,9 +2220,9 @@ def predict(self, X: np.array, Z: np.array, propensity: np.array = None, rfx_gro "This BCF model has not run any covariate preprocessing routines. We will attempt to predict on the raw covariate values, but this will trigger an error with non-numeric columns. Please refit your model by passing non-numeric covariate data a a Pandas dataframe.", RuntimeWarning, ) - if not np.issubdtype( - X.dtype, np.floating - ) and not np.issubdtype(X.dtype, np.integer): + if not np.issubdtype(X.dtype, np.floating) and not np.issubdtype( + X.dtype, np.integer + ): raise ValueError( "Prediction cannot proceed on a non-numeric numpy array, since the BCF model was not fit with a covariate preprocessor. Please refit your model by passing non-numeric covariate data as a Pandas dataframe." 
) @@ -2204,7 +2245,7 @@ def predict(self, X: np.array, Z: np.array, propensity: np.array = None, rfx_gro mu_raw = self.forest_container_mu.forest_container_cpp.Predict( forest_dataset_test.dataset_cpp ) - mu_x = mu_raw * self.y_std / np.sqrt(self.variance_scale) + self.y_bar + mu_x = mu_raw * self.y_std + self.y_bar tau_raw = self.forest_container_tau.forest_container_cpp.PredictRaw( forest_dataset_test.dataset_cpp ) @@ -2213,7 +2254,7 @@ def predict(self, X: np.array, Z: np.array, propensity: np.array = None, rfx_gro self.b1_samples - self.b0_samples, axis=(0, 2) ) tau_raw = tau_raw * adaptive_coding_weights - tau_x = np.squeeze(tau_raw * self.y_std / np.sqrt(self.variance_scale)) + tau_x = np.squeeze(tau_raw * self.y_std) if Z.shape[1] > 1: treatment_term = np.multiply(np.atleast_3d(Z).swapaxes(1, 2), tau_x).sum( axis=2 @@ -2221,9 +2262,11 @@ def predict(self, X: np.array, Z: np.array, propensity: np.array = None, rfx_gro else: treatment_term = Z * np.squeeze(tau_x) yhat_x = mu_x + treatment_term - + if self.has_rfx: - rfx_preds = self.rfx_container.predict(rfx_group_ids, rfx_basis) * self.y_std + rfx_preds = ( + self.rfx_container.predict(rfx_group_ids, rfx_basis) * self.y_std + ) yhat_x = yhat_x + rfx_preds # Compute predictions from the variance forest (if included) @@ -2236,13 +2279,7 @@ def predict(self, X: np.array, Z: np.array, propensity: np.array = None, rfx_gro for i in range(self.num_samples): sigma2_x[:, i] = sigma2_x_raw[:, i] * self.global_var_samples[i] else: - sigma2_x = ( - sigma2_x_raw - * self.sigma2_init - * self.y_std - * self.y_std - / self.variance_scale - ) + sigma2_x = sigma2_x_raw * self.sigma2_init * self.y_std * self.y_std # Return result matrices as a tuple if self.has_rfx and self.include_variance_forest: @@ -2285,7 +2322,6 @@ def to_json(self) -> str: bcf_json.add_random_effects(self.rfx_container) # Add global parameters - bcf_json.add_scalar("variance_scale", self.variance_scale) bcf_json.add_scalar("outcome_scale", self.y_std) bcf_json.add_scalar("outcome_mean", self.y_bar) bcf_json.add_boolean("standardize", self.standardize) @@ -2365,14 +2401,13 @@ def from_json(self, json_string: str) -> None: self.forest_container_variance.forest_container_cpp.LoadFromJson( bcf_json.json_cpp, "forest_2" ) - + # Unpack random effects if self.has_rfx: self.rfx_container = RandomEffectsContainer() self.rfx_container.load_from_json(bcf_json, 0) # Unpack global parameters - self.variance_scale = bcf_json.get_scalar("variance_scale") self.y_std = bcf_json.get_scalar("outcome_scale") self.y_bar = bcf_json.get_scalar("outcome_mean") self.standardize = bcf_json.get_boolean("standardize") @@ -2421,6 +2456,162 @@ def from_json(self, json_string: str) -> None: # Mark the deserialized model as "sampled" self.sampled = True + def from_json_string_list(self, json_string_list: list[str]) -> None: + """ + Convert a list of (in-memory) JSON strings that represent BCF models to a single combined BCF model object + which can be used for prediction, etc... 
+ + Parameters + ------- + json_string_list : list of str + List of JSON strings which can be parsed to objects of type `JSONSerializer` containing Json representation of a BCF model + """ + # Convert strings to JSONSerializer + json_object_list = [] + for i in range(len(json_string_list)): + json_string = json_string_list[i] + json_object_list.append(JSONSerializer()) + json_object_list[i].load_from_json_string(json_string) + + # For scalar / preprocessing details which aren't sample-dependent, defer to the first json + json_object_default = json_object_list[0] + + # Unpack forests + # Mu forest + self.forest_container_mu = ForestContainer(0, 0, False, False) + for i in range(len(json_object_list)): + if i == 0: + self.forest_container_mu.forest_container_cpp.LoadFromJson( + json_object_list[i].json_cpp, "forest_0" + ) + else: + self.forest_container_mu.forest_container_cpp.AppendFromJson( + json_object_list[i].json_cpp, "forest_0" + ) + # Tau forest + self.forest_container_tau = ForestContainer(0, 0, False, False) + for i in range(len(json_object_list)): + if i == 0: + self.forest_container_tau.forest_container_cpp.LoadFromJson( + json_object_list[i].json_cpp, "forest_1" + ) + else: + self.forest_container_tau.forest_container_cpp.AppendFromJson( + json_object_list[i].json_cpp, "forest_1" + ) + self.include_variance_forest = json_object_default.get_boolean( + "include_variance_forest" + ) + if self.include_variance_forest: + # TODO: don't just make this a placeholder that we overwrite + self.forest_container_variance = ForestContainer(0, 0, False, False) + for i in range(len(json_object_list)): + if i == 0: + self.forest_container_variance.forest_container_cpp.LoadFromJson( + json_object_list[i].json_cpp, "forest_2" + ) + else: + self.forest_container_variance.forest_container_cpp.AppendFromJson( + json_object_list[i].json_cpp, "forest_2" + ) + + # Unpack random effects + self.has_rfx = json_object_default.get_boolean("has_rfx") + if self.has_rfx: + self.rfx_container = RandomEffectsContainer() + for i in range(len(json_object_list)): + if i == 0: + self.rfx_container.load_from_json(json_object_list[i], 0) + else: + self.rfx_container.append_from_json(json_object_list[i], 0) + + # Unpack global parameters + self.y_std = json_object_default.get_scalar("outcome_scale") + self.y_bar = json_object_default.get_scalar("outcome_mean") + self.standardize = json_object_default.get_boolean("standardize") + self.sigma2_init = json_object_default.get_scalar("sigma2_init") + self.sample_sigma_global = json_object_default.get_boolean( + "sample_sigma_global" + ) + self.sample_sigma_leaf_mu = json_object_default.get_boolean( + "sample_sigma_leaf_mu" + ) + self.sample_sigma_leaf_tau = json_object_default.get_boolean( + "sample_sigma_leaf_tau" + ) + self.num_gfr = json_object_default.get_scalar("num_gfr") + self.num_burnin = json_object_default.get_scalar("num_burnin") + self.num_mcmc = json_object_default.get_scalar("num_mcmc") + self.num_samples = json_object_default.get_scalar("num_samples") + self.adaptive_coding = json_object_default.get_boolean("adaptive_coding") + self.propensity_covariate = json_object_default.get_string( + "propensity_covariate" + ) + self.internal_propensity_model = json_object_default.get_boolean( + "internal_propensity_model" + ) + + # Unpack parameter samples + if self.sample_sigma_global: + for i in range(len(json_object_list)): + if i == 0: + self.global_var_samples = json_object_list[i].get_numeric_vector( + "sigma2_global_samples", "parameters" + ) + else: + 
global_var_samples = json_object_list[i].get_numeric_vector(
+                        "sigma2_global_samples", "parameters"
+                    )
+                    self.global_var_samples = np.concatenate(
+                        (self.global_var_samples, global_var_samples)
+                    )
+
+        if self.sample_sigma_leaf_mu:
+            for i in range(len(json_object_list)):
+                if i == 0:
+                    self.leaf_scale_mu_samples = json_object_list[i].get_numeric_vector(
+                        "sigma2_leaf_mu_samples", "parameters"
+                    )
+                else:
+                    leaf_scale_mu_samples = json_object_list[i].get_numeric_vector(
+                        "sigma2_leaf_mu_samples", "parameters"
+                    )
+                    self.leaf_scale_mu_samples = np.concatenate(
+                        (self.leaf_scale_mu_samples, leaf_scale_mu_samples)
+                    )
+
+        if self.sample_sigma_leaf_tau:
+            for i in range(len(json_object_list)):
+                if i == 0:
+                    self.leaf_scale_tau_samples = json_object_list[i].get_numeric_vector(
+                        "sigma2_leaf_tau_samples", "parameters"
+                    )
+                else:
+                    leaf_scale_tau_samples = json_object_list[i].get_numeric_vector(
+                        "sigma2_leaf_tau_samples", "parameters"
+                    )
+                    self.leaf_scale_tau_samples = np.concatenate(
+                        (self.leaf_scale_tau_samples, leaf_scale_tau_samples)
+                    )
+
+        # Unpack internal propensity model
+        if self.internal_propensity_model:
+            bart_propensity_string = json_object_default.get_string(
+                "bart_propensity_model"
+            )
+            self.bart_propensity_model = BARTModel()
+            self.bart_propensity_model.from_json(bart_propensity_string)
+
+        # Unpack covariate preprocessor
+        covariate_preprocessor_string = json_object_default.get_string(
+            "covariate_preprocessor"
+        )
+        self._covariate_preprocessor = CovariatePreprocessor()
+        self._covariate_preprocessor.from_json(covariate_preprocessor_string)
+
+        # Mark the deserialized model as "sampled"
+        self.sampled = True
+
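Mirroring the BART version, here is a usage sketch of the BCF method above (editorial illustration, not part of this diff): two independently sampled BCF chains are serialized and recombined, after which `predict` returns CATE, prognostic, and outcome draws stacked across chains. The simulated data and settings are assumptions made for the example.

    import numpy as np
    from stochtree import BCFModel

    # Illustrative causal data (assumed for this sketch)
    rng = np.random.default_rng(2025)
    n = 200
    X = rng.uniform(0, 1, (n, 5))
    pi = 0.25 + 0.5 * X[:, 0:1]            # true propensity scores
    Z = rng.binomial(1, pi).astype(float)  # treatment assignment
    y = X[:, 1] + 0.5 * Z[:, 0] + rng.normal(0, 0.5, n)

    # Sample two independent BCF chains
    chain_1 = BCFModel()
    chain_1.sample(X_train=X, Z_train=Z, y_train=y, pi_train=pi,
                   num_gfr=10, num_mcmc=100)
    chain_2 = BCFModel()
    chain_2.sample(X_train=X, Z_train=Z, y_train=y, pi_train=pi,
                   num_gfr=10, num_mcmc=100)

    # Recombine the serialized chains and predict with stacked draws
    combined = BCFModel()
    combined.from_json_string_list([chain_1.to_json(), chain_2.to_json()])
    tau_hat, mu_hat, y_hat = combined.predict(X, Z, pi)
    assert tau_hat.shape == (n, 200)  # 2 chains x 100 retained draws

As in the BART case, non-sample-dependent metadata (outcome scaling, the covariate preprocessor, any internal propensity model) is taken from the first JSON string, so the chains must share a common data setup.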
     def is_sampled(self) -> bool:
         """Whether or not a BCF model has been sampled.

diff --git a/stochtree/forest.py b/stochtree/forest.py
index 809a3908..ba5ff9c3 100644
--- a/stochtree/forest.py
+++ b/stochtree/forest.py
@@ -206,6 +206,17 @@ def load_from_json_string(self, json_string: str) -> None:
         """
         self.forest_container_cpp.LoadFromJsonString(json_string)
 
+    def load_from_json_object(self, json_object) -> None:
+        """
+        Reload a forest container from an in-memory JSONSerializer object.
+
+        Parameters
+        ----------
+        json_object : JSONSerializer
+            In-memory JSONSerializer object.
+        """
+        self.forest_container_cpp.LoadFromJsonObject(json_object)
+
     def add_sample(self, leaf_value: Union[float, np.array]) -> None:
         """
         Add a new all-root ensemble to the container, with all of the leaves set to the value / vector provided
@@ -903,8 +914,13 @@ def set_root_leaves(self, leaf_value: Union[float, np.array]) -> None:
         if isinstance(leaf_value, np.ndarray):
             if len(leaf_value.shape) > 1:
                 leaf_value = np.squeeze(leaf_value)
-            if len(leaf_value.shape) != 1 or leaf_value.shape[0] != self.output_dimension:
-                raise ValueError("leaf_value must be a one-dimensional array with dimension equal to the output_dimension field of the forest")
+            if (
+                len(leaf_value.shape) != 1
+                or leaf_value.shape[0] != self.output_dimension
+            ):
+                raise ValueError(
+                    "leaf_value must be a one-dimensional array with dimension equal to the output_dimension field of the forest"
+                )
             if leaf_value.shape[0] > 1:
                 self.forest_cpp.SetRootVector(leaf_value, leaf_value.shape[0])
             else:
@@ -1368,13 +1384,13 @@ def leaves(self, tree_num: int) -> np.array:
     def is_empty(self) -> bool:
         """
         When a Forest object is created, it is "empty" in the sense that none
-        of its component trees have leaves with values. There are two ways to
-        "initialize" a Forest object. First, the `set_root_leaves()` method of the
-        Forest class simply initializes every tree in the forest to a single node
-        carrying the same (user-specified) leaf value. Second, the `prepare_for_sampler()`
-        method of the ForestSampler class initializes every tree in the forest to a
-        single node with the same value and also propagates this information through
-        to the temporary tracking data structrues in a ForestSampler object, which
+        of its component trees have leaves with values. There are two ways to
+        "initialize" a Forest object. First, the `set_root_leaves()` method of the
+        Forest class simply initializes every tree in the forest to a single node
+        carrying the same (user-specified) leaf value. Second, the `prepare_for_sampler()`
+        method of the ForestSampler class initializes every tree in the forest to a
+        single node with the same value and also propagates this information through
+        to the temporary tracking data structures in a ForestSampler object, which
         must be synchronized with a Forest during a forest sampler loop.
 
         Returns
diff --git a/stochtree/random_effects.py b/stochtree/random_effects.py
index 0badcbb3..18144044 100644
--- a/stochtree/random_effects.py
+++ b/stochtree/random_effects.py
@@ -177,8 +177,10 @@ class RandomEffectsContainer:
 
     def __init__(self) -> None:
         pass
-
-    def load_new_container(self, num_components: int, num_groups: int, rfx_tracker: RandomEffectsTracker) -> None:
+
+    def load_new_container(
+        self, num_components: int, num_groups: int, rfx_tracker: RandomEffectsTracker
+    ) -> None:
         """
         Initializes internal data structures for an "empty" random effects container to be sampled and populated.
 
@@ -198,7 +200,7 @@ def load_new_container(self, num_components: int, num_groups: int, rfx_tracker:
         self.rfx_label_mapper_cpp = RandomEffectsLabelMapperCpp()
         self.rfx_label_mapper_cpp.LoadFromTracker(rfx_tracker.rfx_tracker_cpp)
         self.rfx_group_ids = rfx_tracker.rfx_tracker_cpp.GetUniqueGroupIds()
-
+
     def load_from_json(self, json, rfx_num: int) -> None:
         """
         Initializes internal data structures for an "empty" random effects container to be sampled and populated.
@@ -210,14 +212,30 @@ def load_from_json(self, json, rfx_num: int) -> None:
         rfx_num : int
             Integer index of the RFX term in a JSON model. In practice, this is typically 0 (most models don't contain two RFX terms).
         """
-        rfx_container_key = f'random_effect_container_{rfx_num:d}'
-        rfx_label_mapper_key = f'random_effect_label_mapper_{rfx_num:d}'
-        rfx_group_ids_key = f'random_effect_groupids_{rfx_num:d}'
+        rfx_container_key = f"random_effect_container_{rfx_num:d}"
+        rfx_label_mapper_key = f"random_effect_label_mapper_{rfx_num:d}"
+        rfx_group_ids_key = f"random_effect_groupids_{rfx_num:d}"
         self.rfx_container_cpp = RandomEffectsContainerCpp()
         self.rfx_container_cpp.LoadFromJson(json.json_cpp, rfx_container_key)
         self.rfx_label_mapper_cpp = RandomEffectsLabelMapperCpp()
         self.rfx_label_mapper_cpp.LoadFromJson(json.json_cpp, rfx_label_mapper_key)
-        self.rfx_group_ids = json.get_integer_vector(rfx_group_ids_key, "random_effects")
+        self.rfx_group_ids = json.get_integer_vector(
+            rfx_group_ids_key, "random_effects"
+        )
+
+    def append_from_json(self, json, rfx_num: int) -> None:
+        """
+        Appends the random effects samples stored in another JSON model to this container, which must already have been initialized (e.g. via `load_from_json`).
+
+        Parameters
+        ----------
+        json : JSONSerializer
+            Python object wrapping a C++ `json` object.
+        rfx_num : int
+            Integer index of the RFX term in a JSON model. In practice, this is typically 0 (most models don't contain two RFX terms).
+ """ + rfx_container_key = f"random_effect_container_{rfx_num:d}" + self.rfx_container_cpp.AppendFromJson(json.json_cpp, rfx_container_key) def num_samples(self) -> int: return self.rfx_container_cpp.NumSamples() diff --git a/test/python/test_bart.py b/test/python/test_bart.py index 5b1415bd..a2f2e64c 100644 --- a/test/python/test_bart.py +++ b/test/python/test_bart.py @@ -41,12 +41,12 @@ def outcome_mean(X): n_train = X_train.shape[0] n_test = X_test.shape[0] - # BCF settings + # BART settings num_gfr = 10 num_burnin = 0 num_mcmc = 10 - # Run BCF with test set and propensity score + # Run BART with test set and propensity score bart_model = BARTModel() bart_model.sample( X_train=X_train, @@ -61,6 +61,43 @@ def outcome_mean(X): assert bart_model.y_hat_train.shape == (n_train, num_mcmc) assert bart_model.y_hat_test.shape == (n_test, num_mcmc) + # Run second BART model + bart_model_2 = BARTModel() + bart_model_2.sample( + X_train=X_train, + y_train=y_train, + X_test=X_test, + num_gfr=num_gfr, + num_burnin=num_burnin, + num_mcmc=num_mcmc, + ) + + # Assertions + assert bart_model_2.y_hat_train.shape == (n_train, num_mcmc) + assert bart_model_2.y_hat_test.shape == (n_test, num_mcmc) + + # Combine into a single model + bart_models_json = [bart_model.to_json(), bart_model_2.to_json()] + bart_model_3 = BARTModel() + bart_model_3.from_json_string_list(bart_models_json) + + # Assertions + y_hat_train_combined = bart_model_3.predict(covariates=X_train) + assert y_hat_train_combined.shape == (n_train, num_mcmc * 2) + np.testing.assert_allclose( + y_hat_train_combined[:, 0:num_mcmc], bart_model.y_hat_train + ) + np.testing.assert_allclose( + y_hat_train_combined[:, num_mcmc : (2 * num_mcmc)], bart_model_2.y_hat_train + ) + np.testing.assert_allclose( + bart_model_3.global_var_samples[0:num_mcmc], bart_model.global_var_samples + ) + np.testing.assert_allclose( + bart_model_3.global_var_samples[num_mcmc : (2 * num_mcmc)], + bart_model_2.global_var_samples, + ) + def test_bart_univariate_leaf_regression_homoskedastic(self): # RNG random_seed = 101 @@ -105,12 +142,12 @@ def outcome_mean(X, W): n_train = X_train.shape[0] n_test = X_test.shape[0] - # BCF settings + # BART settings num_gfr = 10 num_burnin = 0 num_mcmc = 10 - # Run BCF with test set and propensity score + # Run BART with test set and propensity score bart_model = BARTModel() bart_model.sample( X_train=X_train, @@ -127,6 +164,47 @@ def outcome_mean(X, W): assert bart_model.y_hat_train.shape == (n_train, num_mcmc) assert bart_model.y_hat_test.shape == (n_test, num_mcmc) + # Run second BART model + bart_model_2 = BARTModel() + bart_model_2.sample( + X_train=X_train, + y_train=y_train, + leaf_basis_train=basis_train, + X_test=X_test, + leaf_basis_test=basis_test, + num_gfr=num_gfr, + num_burnin=num_burnin, + num_mcmc=num_mcmc, + ) + + # Assertions + assert bart_model_2.y_hat_train.shape == (n_train, num_mcmc) + assert bart_model_2.y_hat_test.shape == (n_test, num_mcmc) + + # Combine into a single model + bart_models_json = [bart_model.to_json(), bart_model_2.to_json()] + bart_model_3 = BARTModel() + bart_model_3.from_json_string_list(bart_models_json) + + # Assertions + y_hat_train_combined = bart_model_3.predict( + covariates=X_train, basis=basis_train + ) + assert y_hat_train_combined.shape == (n_train, num_mcmc * 2) + np.testing.assert_allclose( + y_hat_train_combined[:, 0:num_mcmc], bart_model.y_hat_train + ) + np.testing.assert_allclose( + y_hat_train_combined[:, num_mcmc : (2 * num_mcmc)], bart_model_2.y_hat_train + ) + 
np.testing.assert_allclose( + bart_model_3.global_var_samples[0:num_mcmc], bart_model.global_var_samples + ) + np.testing.assert_allclose( + bart_model_3.global_var_samples[num_mcmc : (2 * num_mcmc)], + bart_model_2.global_var_samples, + ) + def test_bart_multivariate_leaf_regression_homoskedastic(self): # RNG random_seed = 101 @@ -171,12 +249,12 @@ def outcome_mean(X, W): n_train = X_train.shape[0] n_test = X_test.shape[0] - # BCF settings + # BART settings num_gfr = 10 num_burnin = 0 num_mcmc = 10 - # Run BCF with test set and propensity score + # Run BART with test set and propensity score bart_model = BARTModel() bart_model.sample( X_train=X_train, @@ -193,6 +271,47 @@ def outcome_mean(X, W): assert bart_model.y_hat_train.shape == (n_train, num_mcmc) assert bart_model.y_hat_test.shape == (n_test, num_mcmc) + # Run second BART model + bart_model_2 = BARTModel() + bart_model_2.sample( + X_train=X_train, + y_train=y_train, + leaf_basis_train=basis_train, + X_test=X_test, + leaf_basis_test=basis_test, + num_gfr=num_gfr, + num_burnin=num_burnin, + num_mcmc=num_mcmc, + ) + + # Assertions + assert bart_model_2.y_hat_train.shape == (n_train, num_mcmc) + assert bart_model_2.y_hat_test.shape == (n_test, num_mcmc) + + # Combine into a single model + bart_models_json = [bart_model.to_json(), bart_model_2.to_json()] + bart_model_3 = BARTModel() + bart_model_3.from_json_string_list(bart_models_json) + + # Assertions + y_hat_train_combined = bart_model_3.predict( + covariates=X_train, basis=basis_train + ) + assert y_hat_train_combined.shape == (n_train, num_mcmc * 2) + np.testing.assert_allclose( + y_hat_train_combined[:, 0:num_mcmc], bart_model.y_hat_train + ) + np.testing.assert_allclose( + y_hat_train_combined[:, num_mcmc : (2 * num_mcmc)], bart_model_2.y_hat_train + ) + np.testing.assert_allclose( + bart_model_3.global_var_samples[0:num_mcmc], bart_model.global_var_samples + ) + np.testing.assert_allclose( + bart_model_3.global_var_samples[num_mcmc : (2 * num_mcmc)], + bart_model_2.global_var_samples, + ) + def test_bart_constant_leaf_heteroskedastic(self): # RNG random_seed = 101 @@ -240,12 +359,12 @@ def conditional_stddev(X): n_train = X_train.shape[0] n_test = X_test.shape[0] - # BCF settings + # BART settings num_gfr = 10 num_burnin = 0 num_mcmc = 10 - # Run BCF with test set and propensity score + # Run BART with test set and propensity score bart_model = BARTModel() general_params = {"sample_sigma2_global": True} variance_forest_params = {"num_trees": 50} @@ -264,6 +383,45 @@ def conditional_stddev(X): assert bart_model.y_hat_train.shape == (n_train, num_mcmc) assert bart_model.y_hat_test.shape == (n_test, num_mcmc) + # Run second BART model + bart_model_2 = BARTModel() + bart_model_2.sample( + X_train=X_train, + y_train=y_train, + X_test=X_test, + general_params=general_params, + variance_forest_params=variance_forest_params, + num_gfr=num_gfr, + num_burnin=num_burnin, + num_mcmc=num_mcmc, + ) + + # Assertions + assert bart_model_2.y_hat_train.shape == (n_train, num_mcmc) + assert bart_model_2.y_hat_test.shape == (n_test, num_mcmc) + + # Combine into a single model + bart_models_json = [bart_model.to_json(), bart_model_2.to_json()] + bart_model_3 = BARTModel() + bart_model_3.from_json_string_list(bart_models_json) + + # Assertions + y_hat_train_combined, _ = bart_model_3.predict(covariates=X_train) + assert y_hat_train_combined.shape == (n_train, num_mcmc * 2) + np.testing.assert_allclose( + y_hat_train_combined[:, 0:num_mcmc], bart_model.y_hat_train + ) + np.testing.assert_allclose( + 
y_hat_train_combined[:, num_mcmc : (2 * num_mcmc)], bart_model_2.y_hat_train + ) + np.testing.assert_allclose( + bart_model_3.global_var_samples[0:num_mcmc], bart_model.global_var_samples + ) + np.testing.assert_allclose( + bart_model_3.global_var_samples[num_mcmc : (2 * num_mcmc)], + bart_model_2.global_var_samples, + ) + def test_bart_univariate_leaf_regression_heteroskedastic(self): # RNG random_seed = 101 @@ -319,12 +477,12 @@ def conditional_stddev(X): n_train = X_train.shape[0] n_test = X_test.shape[0] - # BCF settings + # BART settings num_gfr = 10 num_burnin = 0 num_mcmc = 10 - # Run BCF with test set and propensity score + # Run BART with test set and propensity score bart_model = BARTModel() general_params = {"sample_sigma2_global": True} variance_forest_params = {"num_trees": 50} @@ -345,6 +503,49 @@ def conditional_stddev(X): assert bart_model.y_hat_train.shape == (n_train, num_mcmc) assert bart_model.y_hat_test.shape == (n_test, num_mcmc) + # Run second BART model + bart_model_2 = BARTModel() + bart_model_2.sample( + X_train=X_train, + y_train=y_train, + leaf_basis_train=basis_train, + X_test=X_test, + leaf_basis_test=basis_test, + num_gfr=num_gfr, + num_burnin=num_burnin, + num_mcmc=num_mcmc, + general_params=general_params, + variance_forest_params=variance_forest_params, + ) + + # Assertions + assert bart_model_2.y_hat_train.shape == (n_train, num_mcmc) + assert bart_model_2.y_hat_test.shape == (n_test, num_mcmc) + + # Combine into a single model + bart_models_json = [bart_model.to_json(), bart_model_2.to_json()] + bart_model_3 = BARTModel() + bart_model_3.from_json_string_list(bart_models_json) + + # Assertions + y_hat_train_combined, _ = bart_model_3.predict( + covariates=X_train, basis=basis_train + ) + assert y_hat_train_combined.shape == (n_train, num_mcmc * 2) + np.testing.assert_allclose( + y_hat_train_combined[:, 0:num_mcmc], bart_model.y_hat_train + ) + np.testing.assert_allclose( + y_hat_train_combined[:, num_mcmc : (2 * num_mcmc)], bart_model_2.y_hat_train + ) + np.testing.assert_allclose( + bart_model_3.global_var_samples[0:num_mcmc], bart_model.global_var_samples + ) + np.testing.assert_allclose( + bart_model_3.global_var_samples[num_mcmc : (2 * num_mcmc)], + bart_model_2.global_var_samples, + ) + def test_bart_multivariate_leaf_regression_heteroskedastic(self): # RNG random_seed = 101 @@ -400,12 +601,321 @@ def conditional_stddev(X): n_train = X_train.shape[0] n_test = X_test.shape[0] - # BCF settings + # BART settings + num_gfr = 10 + num_burnin = 0 + num_mcmc = 10 + + # Run BART with test set and propensity score + bart_model = BARTModel() + general_params = {"sample_sigma2_global": True} + variance_forest_params = {"num_trees": 50} + bart_model.sample( + X_train=X_train, + y_train=y_train, + leaf_basis_train=basis_train, + X_test=X_test, + leaf_basis_test=basis_test, + num_gfr=num_gfr, + num_burnin=num_burnin, + num_mcmc=num_mcmc, + general_params=general_params, + variance_forest_params=variance_forest_params, + ) + + # Assertions + assert bart_model.y_hat_train.shape == (n_train, num_mcmc) + assert bart_model.y_hat_test.shape == (n_test, num_mcmc) + + # Run second BART model + bart_model_2 = BARTModel() + bart_model_2.sample( + X_train=X_train, + y_train=y_train, + leaf_basis_train=basis_train, + X_test=X_test, + leaf_basis_test=basis_test, + num_gfr=num_gfr, + num_burnin=num_burnin, + num_mcmc=num_mcmc, + general_params=general_params, + variance_forest_params=variance_forest_params, + ) + + # Assertions + assert bart_model_2.y_hat_train.shape == 
(n_train, num_mcmc) + assert bart_model_2.y_hat_test.shape == (n_test, num_mcmc) + + # Combine into a single model + bart_models_json = [bart_model.to_json(), bart_model_2.to_json()] + bart_model_3 = BARTModel() + bart_model_3.from_json_string_list(bart_models_json) + + # Assertions + y_hat_train_combined, _ = bart_model_3.predict( + covariates=X_train, basis=basis_train + ) + assert y_hat_train_combined.shape == (n_train, num_mcmc * 2) + np.testing.assert_allclose( + y_hat_train_combined[:, 0:num_mcmc], bart_model.y_hat_train + ) + np.testing.assert_allclose( + y_hat_train_combined[:, num_mcmc : (2 * num_mcmc)], bart_model_2.y_hat_train + ) + np.testing.assert_allclose( + bart_model_3.global_var_samples[0:num_mcmc], bart_model.global_var_samples + ) + np.testing.assert_allclose( + bart_model_3.global_var_samples[num_mcmc : (2 * num_mcmc)], + bart_model_2.global_var_samples, + ) + + def test_bart_constant_leaf_heteroskedastic_rfx(self): + # RNG + random_seed = 101 + rng = np.random.default_rng(random_seed) + + # Generate covariates and basis + n = 100 + p_X = 10 + X = rng.uniform(0, 1, (n, p_X)) + + # Generate RFX group labels and basis term + num_rfx_basis = 1 + num_rfx_groups = 4 + group_labels = rng.choice(num_rfx_groups, size=n) + rfx_basis = np.empty((n, num_rfx_basis)) + rfx_basis[:, 0] = 1.0 + if num_rfx_basis > 1: + rfx_basis[:, 1:] = rng.uniform(-1, 1, (n, num_rfx_basis - 1)) + + # Define the outcome mean function + def outcome_mean(X): + return np.where( + (X[:, 0] >= 0.0) & (X[:, 0] < 0.25), + -7.5, + np.where( + (X[:, 0] >= 0.25) & (X[:, 0] < 0.5), + -2.5, + np.where((X[:, 0] >= 0.5) & (X[:, 0] < 0.75), 2.5, 7.5), + ), + ) + + # Define the conditional standard deviation function + def conditional_stddev(X): + return np.where( + (X[:, 0] >= 0.0) & (X[:, 0] < 0.25), + 0.25, + np.where( + (X[:, 0] >= 0.25) & (X[:, 0] < 0.5), + 0.5, + np.where((X[:, 0] >= 0.5) & (X[:, 0] < 0.75), 1, 2), + ), + ) + + # Define the group rfx function + def rfx_term(group_labels, basis): + return np.where( + group_labels == 0, + 0, + np.where(group_labels == 1, 4, np.where(group_labels == 2, 8, 12)), + ) + + # Generate outcome + epsilon = rng.normal(0, 1, n) + y = ( + outcome_mean(X) + + rfx_term(group_labels, rfx_basis) + + epsilon * conditional_stddev(X) + ) + + # Test-train split + sample_inds = np.arange(n) + train_inds, test_inds = train_test_split(sample_inds, test_size=0.5) + X_train = X[train_inds, :] + X_test = X[test_inds, :] + group_labels_train = group_labels[train_inds] + group_labels_test = group_labels[test_inds] + rfx_basis_train = rfx_basis[train_inds, :] + rfx_basis_test = rfx_basis[test_inds, :] + y_train = y[train_inds] + n_train = X_train.shape[0] + n_test = X_test.shape[0] + + # BART settings + num_gfr = 10 + num_burnin = 0 + num_mcmc = 10 + + # Run BART with test set and propensity score + bart_model = BARTModel() + general_params = {"sample_sigma2_global": True} + variance_forest_params = {"num_trees": 50} + bart_model.sample( + X_train=X_train, + y_train=y_train, + X_test=X_test, + rfx_group_ids_train=group_labels_train, + rfx_basis_train=rfx_basis_train, + rfx_group_ids_test=group_labels_test, + rfx_basis_test=rfx_basis_test, + general_params=general_params, + variance_forest_params=variance_forest_params, + num_gfr=num_gfr, + num_burnin=num_burnin, + num_mcmc=num_mcmc, + ) + rfx_preds_train = bart_model.rfx_container.predict( + group_labels_train, rfx_basis_train + ) + + # Assertions + assert bart_model.y_hat_train.shape == (n_train, num_mcmc) + assert 
bart_model.y_hat_test.shape == (n_test, num_mcmc) + + # Run second BART model + bart_model_2 = BARTModel() + bart_model_2.sample( + X_train=X_train, + y_train=y_train, + X_test=X_test, + rfx_group_ids_train=group_labels_train, + rfx_basis_train=rfx_basis_train, + rfx_group_ids_test=group_labels_test, + rfx_basis_test=rfx_basis_test, + general_params=general_params, + variance_forest_params=variance_forest_params, + num_gfr=num_gfr, + num_burnin=num_burnin, + num_mcmc=num_mcmc, + ) + rfx_preds_train_2 = bart_model_2.rfx_container.predict( + group_labels_train, rfx_basis_train + ) + + # Assertions + assert bart_model_2.y_hat_train.shape == (n_train, num_mcmc) + assert bart_model_2.y_hat_test.shape == (n_test, num_mcmc) + + # Combine into a single model + bart_models_json = [bart_model.to_json(), bart_model_2.to_json()] + bart_model_3 = BARTModel() + bart_model_3.from_json_string_list(bart_models_json) + rfx_preds_train_3 = bart_model_3.rfx_container.predict( + group_labels_train, rfx_basis_train + ) + + # Assertions + y_hat_train_combined, _ = bart_model_3.predict( + covariates=X_train, + rfx_group_ids=group_labels_train, + rfx_basis=rfx_basis_train, + ) + assert y_hat_train_combined.shape == (n_train, num_mcmc * 2) + np.testing.assert_allclose( + y_hat_train_combined[:, 0:num_mcmc], bart_model.y_hat_train + ) + np.testing.assert_allclose( + y_hat_train_combined[:, num_mcmc : (2 * num_mcmc)], bart_model_2.y_hat_train + ) + np.testing.assert_allclose( + bart_model_3.global_var_samples[0:num_mcmc], bart_model.global_var_samples + ) + np.testing.assert_allclose( + bart_model_3.global_var_samples[num_mcmc : (2 * num_mcmc)], + bart_model_2.global_var_samples, + ) + np.testing.assert_allclose(rfx_preds_train_3[:, 0:num_mcmc], rfx_preds_train) + np.testing.assert_allclose( + rfx_preds_train_3[:, num_mcmc : (2 * num_mcmc)], rfx_preds_train_2 + ) + + def test_bart_univariate_leaf_regression_heteroskedastic_rfx(self): + # RNG + random_seed = 101 + rng = np.random.default_rng(random_seed) + + # Generate covariates and basis + n = 100 + p_X = 10 + p_W = 1 + X = rng.uniform(0, 1, (n, p_X)) + W = rng.uniform(0, 1, (n, p_W)) + + # Generate RFX group labels and basis term + num_rfx_basis = 1 + num_rfx_groups = 4 + group_labels = rng.choice(num_rfx_groups, size=n) + rfx_basis = np.empty((n, num_rfx_basis)) + rfx_basis[:, 0] = 1.0 + if num_rfx_basis > 1: + rfx_basis[:, 1:] = rng.uniform(-1, 1, (n, num_rfx_basis - 1)) + + # Define the outcome mean function + def outcome_mean(X, W): + return np.where( + (X[:, 0] >= 0.0) & (X[:, 0] < 0.25), + -7.5 * W[:, 0], + np.where( + (X[:, 0] >= 0.25) & (X[:, 0] < 0.5), + -2.5 * W[:, 0], + np.where( + (X[:, 0] >= 0.5) & (X[:, 0] < 0.75), + 2.5 * W[:, 0], + 7.5 * W[:, 0], + ), + ), + ) + + # Define the group rfx function + def rfx_term(group_labels, basis): + return np.where( + group_labels == 0, + 0, + np.where(group_labels == 1, 4, np.where(group_labels == 2, 8, 12)), + ) + + # Define the conditional standard deviation function + def conditional_stddev(X): + return np.where( + (X[:, 0] >= 0.0) & (X[:, 0] < 0.25), + 0.25, + np.where( + (X[:, 0] >= 0.25) & (X[:, 0] < 0.5), + 0.5, + np.where((X[:, 0] >= 0.5) & (X[:, 0] < 0.75), 1, 2), + ), + ) + + # Generate outcome + epsilon = rng.normal(0, 1, n) + y = ( + outcome_mean(X, W) + + rfx_term(group_labels, rfx_basis) + + epsilon * conditional_stddev(X) + ) + + # Test-train split + sample_inds = np.arange(n) + train_inds, test_inds = train_test_split(sample_inds, test_size=0.5) + X_train = X[train_inds, :] + X_test = X[test_inds, 
:] + basis_train = W[train_inds, :] + basis_test = W[test_inds, :] + group_labels_train = group_labels[train_inds] + group_labels_test = group_labels[test_inds] + rfx_basis_train = rfx_basis[train_inds, :] + rfx_basis_test = rfx_basis[test_inds, :] + y_train = y[train_inds] + n_train = X_train.shape[0] + n_test = X_test.shape[0] + + # BART settings num_gfr = 10 num_burnin = 0 num_mcmc = 10 - # Run BCF with test set and propensity score + # Run BART with test set and propensity score bart_model = BARTModel() general_params = {"sample_sigma2_global": True} variance_forest_params = {"num_trees": 50} @@ -413,15 +923,82 @@ def conditional_stddev(X): X_train=X_train, y_train=y_train, leaf_basis_train=basis_train, + rfx_group_ids_train=group_labels_train, + rfx_basis_train=rfx_basis_train, X_test=X_test, leaf_basis_test=basis_test, + rfx_group_ids_test=group_labels_test, + rfx_basis_test=rfx_basis_test, num_gfr=num_gfr, num_burnin=num_burnin, num_mcmc=num_mcmc, general_params=general_params, variance_forest_params=variance_forest_params, ) + rfx_preds_train = bart_model.rfx_container.predict( + group_labels_train, rfx_basis_train + ) # Assertions assert bart_model.y_hat_train.shape == (n_train, num_mcmc) assert bart_model.y_hat_test.shape == (n_test, num_mcmc) + + # Run second BART model + bart_model_2 = BARTModel() + bart_model_2.sample( + X_train=X_train, + y_train=y_train, + leaf_basis_train=basis_train, + rfx_group_ids_train=group_labels_train, + rfx_basis_train=rfx_basis_train, + X_test=X_test, + leaf_basis_test=basis_test, + rfx_group_ids_test=group_labels_test, + rfx_basis_test=rfx_basis_test, + num_gfr=num_gfr, + num_burnin=num_burnin, + num_mcmc=num_mcmc, + general_params=general_params, + variance_forest_params=variance_forest_params, + ) + rfx_preds_train_2 = bart_model_2.rfx_container.predict( + group_labels_train, rfx_basis_train + ) + + # Assertions + assert bart_model_2.y_hat_train.shape == (n_train, num_mcmc) + assert bart_model_2.y_hat_test.shape == (n_test, num_mcmc) + + # Combine into a single model + bart_models_json = [bart_model.to_json(), bart_model_2.to_json()] + bart_model_3 = BARTModel() + bart_model_3.from_json_string_list(bart_models_json) + rfx_preds_train_3 = bart_model_3.rfx_container.predict( + group_labels_train, rfx_basis_train + ) + + # Assertions + y_hat_train_combined, _ = bart_model_3.predict( + covariates=X_train, + basis=basis_train, + rfx_group_ids=group_labels_train, + rfx_basis=rfx_basis_train, + ) + assert y_hat_train_combined.shape == (n_train, num_mcmc * 2) + np.testing.assert_allclose( + y_hat_train_combined[:, 0:num_mcmc], bart_model.y_hat_train + ) + np.testing.assert_allclose( + y_hat_train_combined[:, num_mcmc : (2 * num_mcmc)], bart_model_2.y_hat_train + ) + np.testing.assert_allclose( + bart_model_3.global_var_samples[0:num_mcmc], bart_model.global_var_samples + ) + np.testing.assert_allclose( + bart_model_3.global_var_samples[num_mcmc : (2 * num_mcmc)], + bart_model_2.global_var_samples, + ) + np.testing.assert_allclose(rfx_preds_train_3[:, 0:num_mcmc], rfx_preds_train) + np.testing.assert_allclose( + rfx_preds_train_3[:, num_mcmc : (2 * num_mcmc)], rfx_preds_train_2 + ) diff --git a/test/python/test_bcf.py b/test/python/test_bcf.py index dc6fe162..96f25a34 100644 --- a/test/python/test_bcf.py +++ b/test/python/test_bcf.py @@ -244,6 +244,65 @@ def test_continuous_univariate_bcf(self): tau_hat = bcf_model.predict_tau(X_test, Z_test, pi_test) assert tau_hat.shape == (n_test, num_mcmc) + # Run second BCF model with test set and propensity score + 
bcf_model_2 = BCFModel() + variance_forest_params = {"num_trees": 0} + bcf_model_2.sample( + X_train=X_train, + Z_train=Z_train, + y_train=y_train, + pi_train=pi_train, + X_test=X_test, + Z_test=Z_test, + pi_test=pi_test, + num_gfr=num_gfr, + num_burnin=num_burnin, + num_mcmc=num_mcmc, + variance_forest_params=variance_forest_params, + ) + + # Assertions + assert bcf_model_2.y_hat_train.shape == (n_train, num_mcmc) + assert bcf_model_2.mu_hat_train.shape == (n_train, num_mcmc) + assert bcf_model_2.tau_hat_train.shape == (n_train, num_mcmc) + assert bcf_model_2.y_hat_test.shape == (n_test, num_mcmc) + assert bcf_model_2.mu_hat_test.shape == (n_test, num_mcmc) + assert bcf_model_2.tau_hat_test.shape == (n_test, num_mcmc) + + # Check overall prediction method + tau_hat_2, mu_hat_2, y_hat_2 = bcf_model_2.predict(X_test, Z_test, pi_test) + assert tau_hat_2.shape == (n_test, num_mcmc) + assert mu_hat_2.shape == (n_test, num_mcmc) + assert y_hat_2.shape == (n_test, num_mcmc) + + # Check treatment effect prediction method + tau_hat_2 = bcf_model_2.predict_tau(X_test, Z_test, pi_test) + assert tau_hat_2.shape == (n_test, num_mcmc) + + # Combine into a single model + bcf_models_json = [bcf_model.to_json(), bcf_model_2.to_json()] + bcf_model_3 = BCFModel() + bcf_model_3.from_json_string_list(bcf_models_json) + + # Assertions + tau_hat_3, mu_hat_3, y_hat_3 = bcf_model_3.predict(X_test, Z_test, pi_test) + assert tau_hat_3.shape == (n_train, num_mcmc * 2) + assert mu_hat_3.shape == (n_train, num_mcmc * 2) + assert y_hat_3.shape == (n_train, num_mcmc * 2) + np.testing.assert_allclose(y_hat_3[:, 0:num_mcmc], y_hat) + np.testing.assert_allclose(y_hat_3[:, num_mcmc : (2 * num_mcmc)], y_hat_2) + np.testing.assert_allclose(mu_hat_3[:, 0:num_mcmc], mu_hat) + np.testing.assert_allclose(mu_hat_3[:, num_mcmc : (2 * num_mcmc)], mu_hat_2) + np.testing.assert_allclose(tau_hat_3[:, 0:num_mcmc], tau_hat) + np.testing.assert_allclose(tau_hat_3[:, num_mcmc : (2 * num_mcmc)], tau_hat_2) + np.testing.assert_allclose( + bcf_model_3.global_var_samples[0:num_mcmc], bcf_model.global_var_samples + ) + np.testing.assert_allclose( + bcf_model_3.global_var_samples[num_mcmc : (2 * num_mcmc)], + bcf_model_2.global_var_samples, + ) + # Run BCF without test set and with propensity score bcf_model = BCFModel() variance_forest_params = {"num_trees": 0} @@ -336,6 +395,55 @@ def test_continuous_univariate_bcf(self): # Check treatment effect prediction method tau_hat = bcf_model.predict_tau(X_test, Z_test) + # Run second BCF model with test set and propensity score + bcf_model_2 = BCFModel() + variance_forest_params = {"num_trees": 0} + bcf_model_2.sample( + X_train=X_train, + Z_train=Z_train, + y_train=y_train, + num_gfr=num_gfr, + num_burnin=num_burnin, + num_mcmc=num_mcmc, + variance_forest_params=variance_forest_params, + ) + + # Assertions + assert bcf_model_2.y_hat_train.shape == (n_train, num_mcmc) + assert bcf_model_2.mu_hat_train.shape == (n_train, num_mcmc) + assert bcf_model_2.tau_hat_train.shape == (n_train, num_mcmc) + + # Check overall prediction method + tau_hat_2, mu_hat_2, y_hat_2 = bcf_model_2.predict(X_test, Z_test) + assert tau_hat_2.shape == (n_test, num_mcmc) + assert mu_hat_2.shape == (n_test, num_mcmc) + assert y_hat_2.shape == (n_test, num_mcmc) + + # Check treatment effect prediction method + tau_hat_2 = bcf_model_2.predict_tau(X_test, Z_test) + assert tau_hat_2.shape == (n_test, num_mcmc) + + # Combine into a single model + bcf_models_json = [bcf_model.to_json(), bcf_model_2.to_json()] + bcf_model_3 = BCFModel() 
+ bcf_model_3.from_json_string_list(bcf_models_json) + + # Assertions + tau_hat_3, mu_hat_3, y_hat_3 = bcf_model_3.predict(X_test, Z_test) + assert tau_hat_3.shape == (n_train, num_mcmc * 2) + assert mu_hat_3.shape == (n_train, num_mcmc * 2) + assert y_hat_3.shape == (n_train, num_mcmc * 2) + np.testing.assert_allclose(y_hat_3[:, 0:num_mcmc], y_hat) + np.testing.assert_allclose(mu_hat_3[:, 0:num_mcmc], mu_hat) + np.testing.assert_allclose(tau_hat_3[:, 0:num_mcmc], tau_hat) + np.testing.assert_allclose( + bcf_model_3.global_var_samples[0:num_mcmc], bcf_model.global_var_samples + ) + np.testing.assert_allclose( + bcf_model_3.global_var_samples[num_mcmc : (2 * num_mcmc)], + bcf_model_2.global_var_samples, + ) + def test_multivariate_bcf(self): # RNG random_seed = 101