diff --git a/Doxyfile b/Doxyfile index 06c38d94..58837078 100644 --- a/Doxyfile +++ b/Doxyfile @@ -48,7 +48,7 @@ PROJECT_NAME = "StochTree" # could be handy for archiving the generated documentation or if some version # control system is used. -PROJECT_NUMBER = 0.0.1 +PROJECT_NUMBER = 0.1.1 # Using the PROJECT_BRIEF tag one can provide an optional one line description # for a project that appears at the top of each page and should give viewer a diff --git a/R/bart.R b/R/bart.R index ca717621..96815850 100644 --- a/R/bart.R +++ b/R/bart.R @@ -206,7 +206,7 @@ bart <- function(X_train, y_train, leaf_basis_train = NULL, rfx_group_ids_train if (previous_bart_model$model_params$include_mean_forest) { previous_forest_samples_mean <- previous_bart_model$mean_forests } else previous_forest_samples_mean <- NULL - if (previous_bart_model$model_params$include_mean_forest) { + if (previous_bart_model$model_params$include_variance_forest) { previous_forest_samples_variance <- previous_bart_model$variance_forests } else previous_forest_samples_variance <- NULL if (previous_bart_model$model_params$sample_sigma_global) { @@ -1853,7 +1853,7 @@ createBARTModelFromCombinedJsonString <- function(json_string_list){ } # Unpack covariate preprocessor - preprocessor_metadata_string <- json_object$get_string("preprocessor_metadata") + preprocessor_metadata_string <- json_object_default$get_string("preprocessor_metadata") output[["train_set_metadata"]] <- createPreprocessorFromJsonString( preprocessor_metadata_string ) diff --git a/demo/debug/multi_chain.py b/demo/debug/multi_chain.py new file mode 100644 index 00000000..6d8aef68 --- /dev/null +++ b/demo/debug/multi_chain.py @@ -0,0 +1,164 @@ +# Multi Chain Demo Script + +# Load necessary libraries +import matplotlib.pyplot as plt +import numpy as np +import pandas as pd +import seaborn as sns +from sklearn.model_selection import train_test_split + +from stochtree import BARTModel + +# Generate sample data +# RNG +random_seed = 1234 +rng = 
np.random.default_rng(random_seed) + +# Generate covariates and basis +n = 500 +p_X = 10 +p_W = 1 +X = rng.uniform(0, 1, (n, p_X)) +W = rng.uniform(0, 1, (n, p_W)) + + +# Define the outcome mean function +def outcome_mean(X, W): + return np.where( + (X[:, 0] >= 0.0) & (X[:, 0] < 0.25), + -7.5 * W[:, 0], + np.where( + (X[:, 0] >= 0.25) & (X[:, 0] < 0.5), + -2.5 * W[:, 0], + np.where((X[:, 0] >= 0.5) & (X[:, 0] < 0.75), 2.5 * W[:, 0], 7.5 * W[:, 0]), + ), + ) + + +# Generate outcome +f_XW = outcome_mean(X, W) +epsilon = rng.normal(0, 1, n) +y = f_XW + epsilon + +# Test-train split +sample_inds = np.arange(n) +train_inds, test_inds = train_test_split( + sample_inds, test_size=0.5, random_state=random_seed +) +X_train = X[train_inds, :] +X_test = X[test_inds, :] +basis_train = W[train_inds, :] +basis_test = W[test_inds, :] +y_train = y[train_inds] +y_test = y[test_inds] + +# Run the GFR algorithm for a small number of iterations +general_model_params = {"random_seed": -1} +mean_forest_model_params = {"num_trees": 20} +num_warmstart = 10 +num_mcmc = 10 +bart_model = BARTModel() +bart_model.sample( + X_train=X_train, + y_train=y_train, + leaf_basis_train=basis_train, + X_test=X_test, + leaf_basis_test=basis_test, + num_gfr=num_warmstart, + num_mcmc=0, + general_params=general_model_params, + mean_forest_params=mean_forest_model_params, +) +bart_model_json = bart_model.to_json() + +# Run several BART MCMC samples from the last GFR forest +bart_model_2 = BARTModel() +bart_model_2.sample( + X_train=X_train, + y_train=y_train, + leaf_basis_train=basis_train, + X_test=X_test, + leaf_basis_test=basis_test, + num_gfr=0, + num_mcmc=num_mcmc, + previous_model_json=bart_model_json, + previous_model_warmstart_sample_num=num_warmstart - 1, + general_params=general_model_params, + mean_forest_params=mean_forest_model_params, +) + +# Run several BART MCMC samples from the second-to-last GFR forest +bart_model_3 = BARTModel() +bart_model_3.sample( + X_train=X_train, + y_train=y_train, 
+ leaf_basis_train=basis_train, + X_test=X_test, + leaf_basis_test=basis_test, + num_gfr=0, + num_mcmc=num_mcmc, + previous_model_json=bart_model_json, + previous_model_warmstart_sample_num=num_warmstart - 2, + general_params=general_model_params, + mean_forest_params=mean_forest_model_params, +) + +# Run several BART MCMC samples from root +bart_model_4 = BARTModel() +bart_model_4.sample( + X_train=X_train, + y_train=y_train, + leaf_basis_train=basis_train, + X_test=X_test, + leaf_basis_test=basis_test, + num_gfr=0, + num_mcmc=num_mcmc, + general_params=general_model_params, + mean_forest_params=mean_forest_model_params, +) + +# Inspect the model outputs +y_hat_mcmc_2 = bart_model_2.predict(X_test, basis_test) +y_avg_mcmc_2 = np.squeeze(y_hat_mcmc_2).mean(axis=1, keepdims=True) +y_hat_mcmc_3 = bart_model_3.predict(X_test, basis_test) +y_avg_mcmc_3 = np.squeeze(y_hat_mcmc_3).mean(axis=1, keepdims=True) +y_hat_mcmc_4 = bart_model_4.predict(X_test, basis_test) +y_avg_mcmc_4 = np.squeeze(y_hat_mcmc_4).mean(axis=1, keepdims=True) +y_df = pd.DataFrame( + np.concatenate( + (y_avg_mcmc_2, y_avg_mcmc_3, y_avg_mcmc_4, np.expand_dims(y_test, axis=1)), + axis=1, + ), + columns=["First Chain", "Second Chain", "Third Chain", "Outcome"], +) + +# Compare first warm-start chain to root chain with equal number of MCMC draws +sns.scatterplot(data=y_df, x="First Chain", y="Third Chain") +plt.axline((0, 0), slope=1, color="black", linestyle=(0, (3, 3))) +plt.show() + +# Compare first warm-start chain to outcome +sns.scatterplot(data=y_df, x="First Chain", y="Outcome") +plt.axline((0, 0), slope=1, color="black", linestyle=(0, (3, 3))) +plt.show() + +# Compare root chain to outcome +sns.scatterplot(data=y_df, x="Third Chain", y="Outcome") +plt.axline((0, 0), slope=1, color="black", linestyle=(0, (3, 3))) +plt.show() + +# Compute RMSEs +rmse_1 = np.sqrt( + np.mean((np.squeeze(y_avg_mcmc_2) - y_test) * (np.squeeze(y_avg_mcmc_2) - y_test)) +) +rmse_2 = np.sqrt( + 
np.mean((np.squeeze(y_avg_mcmc_3) - y_test) * (np.squeeze(y_avg_mcmc_3) - y_test)) +) +rmse_3 = np.sqrt( + np.mean((np.squeeze(y_avg_mcmc_4) - y_test) * (np.squeeze(y_avg_mcmc_4) - y_test)) +) +print( + "Chain 1 rmse: {:0.3f}; Chain 2 rmse: {:0.3f}; Chain 3 rmse: {:0.3f}".format( + rmse_1, rmse_2, rmse_3 + ) +) diff --git a/demo/debug/parallel_multi_chain.py b/demo/debug/parallel_multi_chain.py new file mode 100644 index 00000000..ee114aee --- /dev/null +++ b/demo/debug/parallel_multi_chain.py @@ -0,0 +1,177 @@ +# Multi Chain Demo Script + +# Load necessary libraries +from multiprocessing import Pool, cpu_count + +import matplotlib.pyplot as plt +import numpy as np +import pandas as pd +import seaborn as sns +from sklearn.model_selection import train_test_split + +from stochtree import BARTModel + + +def fit_bart( + model_string, + X_train, + y_train, + basis_train, + X_test, + basis_test, + num_mcmc, + gen_param_list, + mean_list, + i, +): + bart_model = BARTModel() + bart_model.sample( + X_train=X_train, + y_train=y_train, + leaf_basis_train=basis_train, + X_test=X_test, + leaf_basis_test=basis_test, + num_gfr=0, + num_mcmc=num_mcmc, + previous_model_json=model_string, + previous_model_warmstart_sample_num=i, + general_params=gen_param_list, + mean_forest_params=mean_list, + ) + return (bart_model.to_json(), bart_model.y_hat_test) + + +def bart_warmstart_parallel(X_train, y_train, basis_train, X_test, basis_test): + # Run the GFR algorithm for a small number of iterations + general_model_params = {"random_seed": -1} + mean_forest_model_params = {"num_trees": 100} + num_warmstart = 10 + num_mcmc = 100 + bart_model = BARTModel() + bart_model.sample( + X_train=X_train, + y_train=y_train, + leaf_basis_train=basis_train, + X_test=X_test, + leaf_basis_test=basis_test, + num_gfr=num_warmstart, + num_mcmc=0, + general_params=general_model_params, + mean_forest_params=mean_forest_model_params, + ) + bart_model_json = bart_model.to_json() + + # Warm-start multiple BART 
fits from a different GFR forest + process_tasks = [ + ( + bart_model_json, + X_train, + y_train, + basis_train, + X_test, + basis_test, + num_mcmc, + general_model_params, + mean_forest_model_params, + i, + ) + for i in range(4) + ] + num_processes = cpu_count() + with Pool(processes=num_processes) as pool: + results = pool.starmap(fit_bart, process_tasks) + + # Extract separate outputs as separate lists + bart_model_json_list, bart_model_pred_list = zip(*results) + + # Process results + combined_bart_model = BARTModel() + combined_bart_model.from_json_string_list(bart_model_json_list) + combined_bart_preds = bart_model_pred_list[0] + for i in range(1, len(bart_model_pred_list)): + combined_bart_preds = np.concatenate( + (combined_bart_preds, bart_model_pred_list[i]), axis=1 + ) + + return (combined_bart_model, combined_bart_preds) + + +if __name__ == "__main__": + # RNG + random_seed = 1234 + rng = np.random.default_rng(random_seed) + + # Generate covariates and basis + n = 1000 + p_X = 10 + p_W = 1 + X = rng.uniform(0, 1, (n, p_X)) + W = rng.uniform(0, 1, (n, p_W)) + + # Define the outcome mean function + def outcome_mean(X, W): + return np.where( + (X[:, 0] >= 0.0) & (X[:, 0] < 0.25), + -7.5 * W[:, 0], + np.where( + (X[:, 0] >= 0.25) & (X[:, 0] < 0.5), + -2.5 * W[:, 0], + np.where( + (X[:, 0] >= 0.5) & (X[:, 0] < 0.75), 2.5 * W[:, 0], 7.5 * W[:, 0] + ), + ), + ) + + # Generate outcome + f_XW = outcome_mean(X, W) + epsilon = rng.normal(0, 1, n) + y = f_XW + epsilon + + # Test-train split + sample_inds = np.arange(n) + train_inds, test_inds = train_test_split( + sample_inds, test_size=0.2, random_state=random_seed + ) + X_train = X[train_inds, :] + X_test = X[test_inds, :] + basis_train = W[train_inds, :] + basis_test = W[test_inds, :] + y_train = y[train_inds] + y_test = y[test_inds] + + # Run the parallel BART + combined_bart, combined_bart_preds = bart_warmstart_parallel( + X_train, y_train, basis_train, X_test, basis_test + ) + + # Inspect the model outputs + 
y_hat_mcmc = combined_bart.predict(X_test, basis_test) + y_avg_mcmc = np.squeeze(y_hat_mcmc).mean(axis=1, keepdims=True) + y_df = pd.DataFrame( + np.concatenate((y_avg_mcmc, np.expand_dims(y_test, axis=1)), axis=1), + columns=["Average BART Predictions", "Outcome"], + ) + + # Compare predictions averaged across all chains to outcome + sns.scatterplot(data=y_df, x="Average BART Predictions", y="Outcome") + plt.axline((0, 0), slope=1, color="black", linestyle=(0, (3, 3))) + plt.show() + + # Compare cached vs. deserialized predictions for first chain (chains are stacked along axis 1, so slice columns) + chain_index = 0 + num_mcmc = 100 + offset_index = num_mcmc * chain_index + chain_inds = slice(offset_index, (offset_index + num_mcmc)) + chain_1_preds_original = np.squeeze(combined_bart_preds[:, chain_inds]).mean( + axis=1, keepdims=True + ) + chain_1_preds_reloaded = np.squeeze(y_hat_mcmc[:, chain_inds]).mean( + axis=1, keepdims=True + ) + chain_df = pd.DataFrame( + np.concatenate((chain_1_preds_reloaded, chain_1_preds_original), axis=1), + columns=["New Predictions", "Original Predictions"], + ) + sns.scatterplot(data=chain_df, x="New Predictions", y="Original Predictions") + plt.axline((0, 0), slope=1, color="black", linestyle=(0, (3, 3))) + plt.show() diff --git a/src/py_stochtree.cpp b/src/py_stochtree.cpp index 65f8c927..8c73fd59 100644 --- a/src/py_stochtree.cpp +++ b/src/py_stochtree.cpp @@ -325,6 +325,8 @@ class ForestContainerCpp { void LoadFromJson(JsonCpp& json, std::string forest_label); + void AppendFromJson(JsonCpp& json, std::string forest_label); + std::string DumpJsonString() { return forest_samples_->DumpJsonString(); } @@ -1289,6 +1291,7 @@ class RandomEffectsContainerCpp { rfx_container_->LoadFromJsonString(json_string); } void LoadFromJson(JsonCpp& json, std::string rfx_container_label); + void AppendFromJson(JsonCpp& json, std::string rfx_container_label); StochTree::RandomEffectsContainer* GetRandomEffectsContainer() { return rfx_container_.get(); } @@ -1870,6 +1873,11 @@ void
ForestContainerCpp::LoadFromJson(JsonCpp& json, std::string forest_label) { forest_samples_->from_json(forest_json); } +void ForestContainerCpp::AppendFromJson(JsonCpp& json, std::string forest_label) { + nlohmann::json forest_json = json.SubsetJsonForest(forest_label); + forest_samples_->append_from_json(forest_json); +} + void ForestContainerCpp::AdjustResidual(ForestDatasetCpp& dataset, ResidualCpp& residual, ForestSamplerCpp& sampler, bool requires_basis, int forest_num, bool add) { // Determine whether or not we are adding forest_num to the residuals std::function op; @@ -1896,6 +1904,11 @@ void RandomEffectsContainerCpp::LoadFromJson(JsonCpp& json, std::string rfx_cont rfx_container_->from_json(rfx_json); } +void RandomEffectsContainerCpp::AppendFromJson(JsonCpp& json, std::string rfx_container_label) { + nlohmann::json rfx_json = json.SubsetJsonRFX().at(rfx_container_label); + rfx_container_->append_from_json(rfx_json); +} + void RandomEffectsContainerCpp::AddSample(RandomEffectsModelCpp& rfx_model) { rfx_container_->AddSample(*rfx_model.GetModel()); } @@ -2012,6 +2025,7 @@ PYBIND11_MODULE(stochtree_cpp, m) { .def("SaveToJsonFile", &ForestContainerCpp::SaveToJsonFile) .def("LoadFromJsonFile", &ForestContainerCpp::LoadFromJsonFile) .def("LoadFromJson", &ForestContainerCpp::LoadFromJson) + .def("AppendFromJson", &ForestContainerCpp::AppendFromJson) .def("DumpJsonString", &ForestContainerCpp::DumpJsonString) .def("LoadFromJsonString", &ForestContainerCpp::LoadFromJsonString) .def("AddSampleValue", &ForestContainerCpp::AddSampleValue) @@ -2125,6 +2139,7 @@ PYBIND11_MODULE(stochtree_cpp, m) { .def("DumpJsonString", &RandomEffectsContainerCpp::DumpJsonString) .def("LoadFromJsonString", &RandomEffectsContainerCpp::LoadFromJsonString) .def("LoadFromJson", &RandomEffectsContainerCpp::LoadFromJson) + .def("AppendFromJson", &RandomEffectsContainerCpp::AppendFromJson) .def("GetRandomEffectsContainer", &RandomEffectsContainerCpp::GetRandomEffectsContainer); py::class_(m, 
"RandomEffectsTrackerCpp") diff --git a/stochtree/bart.py b/stochtree/bart.py index 6608f576..c8943c85 100644 --- a/stochtree/bart.py +++ b/stochtree/bart.py @@ -14,7 +14,12 @@ from .data import Dataset, Residual from .forest import Forest, ForestContainer from .preprocessing import CovariatePreprocessor, _preprocess_params -from .random_effects import RandomEffectsContainer, RandomEffectsDataset, RandomEffectsModel, RandomEffectsTracker +from .random_effects import ( + RandomEffectsContainer, + RandomEffectsDataset, + RandomEffectsModel, + RandomEffectsTracker, +) from .sampler import RNG, ForestSampler, GlobalVarianceModel, LeafVarianceModel from .serialization import JSONSerializer from .utils import NotSampledError @@ -58,7 +63,6 @@ class BARTModel: def __init__(self) -> None: # Internal flag for whether the sample() method has been run self.sampled = False - self.rng = np.random.default_rng() def sample( self, @@ -77,6 +81,8 @@ def sample( general_params: Optional[Dict[str, Any]] = None, mean_forest_params: Optional[Dict[str, Any]] = None, variance_forest_params: Optional[Dict[str, Any]] = None, + previous_model_json: Optional[str] = None, + previous_model_warmstart_sample_num: Optional[int] = None, ) -> None: """Runs a BART sampler on provided training set. Predictions will be cached for the training set and (if provided) the test set. Does not require a leaf regression basis. @@ -99,7 +105,7 @@ def sample( Optional test set basis vector used to define a regression to be run in the leaves of each tree. Must be included / omitted consistently (i.e. if leaf_basis_train is provided, then leaf_basis_test must be provided alongside X_test). rfx_group_ids_test : np.array, optional - Optional test set group labels used for an additive random effects model. We do not currently support (but plan to in the near future), + Optional test set group labels used for an additive random effects model. 
We do not currently support (but plan to in the near future), test set evaluation for group labels that were not in the training set. rfx_basis_test : np.array, optional Optional test set basis for "random-slope" regression in additive random effects model. @@ -155,6 +161,11 @@ def sample( * `keep_vars` (`list` or `np.array`): Vector of variable names or column indices denoting variables that should be included in the variance forest. Defaults to `None`. * `drop_vars` (`list` or `np.array`): Vector of variable names or column indices denoting variables that should be excluded from the variance forest. Defaults to `None`. If both `drop_vars` and `keep_vars` are set, `drop_vars` will be ignored. + previous_model_json : str, optional + JSON string containing a previous BART model. This can be used to "continue" a sampler interactively after inspecting the samples or to run parallel chains "warm-started" from existing forest samples. Defaults to `None`. + previous_model_warmstart_sample_num : int, optional + Sample number from `previous_model_json` that will be used to warmstart this BART sampler. Zero-indexed (so that the first sample is used for warm-start by setting `previous_model_warmstart_sample_num = 0`). Defaults to `None`. + Returns ------- self : BARTModel @@ -244,22 +255,24 @@ def sample( drop_vars_mean = mean_forest_params_updated["drop_vars"] # 3. 
Variance forest parameters - num_trees_variance = variance_forest_params_updated['num_trees'] - alpha_variance = variance_forest_params_updated['alpha'] - beta_variance = variance_forest_params_updated['beta'] - min_samples_leaf_variance = variance_forest_params_updated['min_samples_leaf'] - max_depth_variance = variance_forest_params_updated['max_depth'] - a_0 = variance_forest_params_updated['leaf_prior_calibration_param'] - variance_forest_leaf_init = variance_forest_params_updated['var_forest_leaf_init'] - a_forest = variance_forest_params_updated['var_forest_prior_shape'] - b_forest = variance_forest_params_updated['var_forest_prior_scale'] - keep_vars_variance = variance_forest_params_updated['keep_vars'] - drop_vars_variance = variance_forest_params_updated['drop_vars'] - + num_trees_variance = variance_forest_params_updated["num_trees"] + alpha_variance = variance_forest_params_updated["alpha"] + beta_variance = variance_forest_params_updated["beta"] + min_samples_leaf_variance = variance_forest_params_updated["min_samples_leaf"] + max_depth_variance = variance_forest_params_updated["max_depth"] + a_0 = variance_forest_params_updated["leaf_prior_calibration_param"] + variance_forest_leaf_init = variance_forest_params_updated[ + "var_forest_leaf_init" + ] + a_forest = variance_forest_params_updated["var_forest_prior_shape"] + b_forest = variance_forest_params_updated["var_forest_prior_scale"] + keep_vars_variance = variance_forest_params_updated["keep_vars"] + drop_vars_variance = variance_forest_params_updated["drop_vars"] + # Override keep_gfr if there are no MCMC samples if num_mcmc == 0: keep_gfr = True - + # Check that num_chains >= 1 if not isinstance(num_chains, Integral) or num_chains < 1: raise ValueError("num_chains must be an integer greater than 0") @@ -297,7 +310,9 @@ def sample( if not isinstance(rfx_group_ids_train, np.ndarray): raise ValueError("rfx_group_ids_train must be a numpy array") if not np.issubdtype(rfx_group_ids_train.dtype, 
np.integer): - raise ValueError("rfx_group_ids_train must be a numpy array of integer-valued group IDs") + raise ValueError( + "rfx_group_ids_train must be a numpy array of integer-valued group IDs" + ) if rfx_basis_train is not None: if not isinstance(rfx_basis_train, np.ndarray): raise ValueError("rfx_basis_train must be a numpy array") @@ -305,11 +320,13 @@ def sample( if not isinstance(rfx_group_ids_test, np.ndarray): raise ValueError("rfx_group_ids_test must be a numpy array") if not np.issubdtype(rfx_group_ids_test.dtype, np.integer): - raise ValueError("rfx_group_ids_test must be a numpy array of integer-valued group IDs") + raise ValueError( + "rfx_group_ids_test must be a numpy array of integer-valued group IDs" + ) if rfx_basis_test is not None: if not isinstance(rfx_basis_test, np.ndarray): raise ValueError("rfx_basis_test must be a numpy array") - + # Convert everything to standard shape (2-dimensional) if isinstance(X_train, np.ndarray): if X_train.ndim == 1: @@ -352,7 +369,9 @@ def sample( "leaf_basis_train and leaf_basis_test must have the same number of columns" ) else: - raise ValueError("leaf_basis_test provided but leaf_basis_train was not") + raise ValueError( + "leaf_basis_test provided but leaf_basis_train was not" + ) if leaf_basis_train is not None: if leaf_basis_train.shape[0] != X_train.shape[0]: raise ValueError( @@ -612,6 +631,61 @@ def sample( else: variable_subset_variance = [i for i in range(X_train.shape[1])] + # Check if previous model JSON is provided and parse it if so + has_prev_model = previous_model_json is not None + if has_prev_model: + if num_gfr > 0: + if num_mcmc == 0: + raise ValueError( + "A previous model is being used to initialize this sampler, so `num_mcmc` must be greater than zero" + ) + else: + warnings.warn( + "A previous model is being used to initialize this sampler, so num_gfr will be ignored and the MCMC sampler will be run from the previous samples" + ) + previous_bart_model = BARTModel() + 
previous_bart_model.from_json(previous_model_json) + previous_y_bar = previous_bart_model.y_bar + previous_y_scale = previous_bart_model.y_std + previous_model_num_samples = previous_bart_model.num_samples + if previous_bart_model.include_mean_forest: + previous_forest_samples_mean = previous_bart_model.forest_container_mean + else: + previous_forest_samples_mean = None + if previous_bart_model.include_variance_forest: + previous_forest_samples_variance = ( + previous_bart_model.forest_container_variance + ) + else: + previous_forest_samples_variance = None + if previous_bart_model.sample_sigma_global: + previous_global_var_samples = previous_bart_model.global_var_samples / ( + previous_y_scale * previous_y_scale + ) + else: + previous_global_var_samples = None + if previous_bart_model.sample_sigma_leaf: + previous_leaf_var_samples = previous_bart_model.leaf_scale_samples + else: + previous_leaf_var_samples = None + if previous_bart_model.has_rfx: + previous_rfx_samples = previous_bart_model.rfx_container + else: + previous_rfx_samples = None + if previous_model_warmstart_sample_num + 1 > previous_model_num_samples: + raise ValueError( + "`previous_model_warmstart_sample_num` exceeds the number of samples in `previous_model_json`" + ) + else: + previous_y_bar = None + previous_y_scale = None + previous_global_var_samples = None + previous_leaf_var_samples = None + previous_rfx_samples = None + previous_forest_samples_mean = None + previous_forest_samples_variance = None + previous_model_num_samples = 0 + # Update variable weights if the covariates have been resized (by e.g. 
one-hot encoding) if X_train_processed.shape[1] != X_train.shape[1]: variable_counts = [ @@ -661,35 +735,56 @@ def sample( if self.has_basis: if sigma_leaf is None: current_leaf_scale = np.zeros((self.num_basis, self.num_basis)) - np.fill_diagonal(current_leaf_scale, np.squeeze(np.var(resid_train)) / num_trees_mean) + np.fill_diagonal( + current_leaf_scale, + np.squeeze(np.var(resid_train)) / num_trees_mean, + ) elif isinstance(sigma_leaf, float): current_leaf_scale = np.zeros((self.num_basis, self.num_basis)) np.fill_diagonal(current_leaf_scale, sigma_leaf) elif isinstance(sigma_leaf, np.ndarray): if sigma_leaf.ndim != 2: - raise ValueError("sigma_leaf must be a 2d symmetric numpy array if provided in matrix form") + raise ValueError( + "sigma_leaf must be a 2d symmetric numpy array if provided in matrix form" + ) if sigma_leaf.shape[0] != sigma_leaf.shape[1]: - raise ValueError("sigma_leaf must be a 2d symmetric numpy array if provided in matrix form") + raise ValueError( + "sigma_leaf must be a 2d symmetric numpy array if provided in matrix form" + ) if sigma_leaf.shape[0] != self.num_basis: - raise ValueError("sigma_leaf must be a 2d symmetric numpy array with its dimensionality matching the basis dimension") + raise ValueError( + "sigma_leaf must be a 2d symmetric numpy array with its dimensionality matching the basis dimension" + ) current_leaf_scale = sigma_leaf else: - raise ValueError("sigma_leaf must be either a scalar or a 2d symmetric numpy array") + raise ValueError( + "sigma_leaf must be either a scalar or a 2d symmetric numpy array" + ) else: if sigma_leaf is None: - current_leaf_scale = np.array([[np.squeeze(np.var(resid_train)) / num_trees_mean]]) + current_leaf_scale = np.array( + [[np.squeeze(np.var(resid_train)) / num_trees_mean]] + ) elif isinstance(sigma_leaf, float): current_leaf_scale = np.array([[sigma_leaf]]) elif isinstance(sigma_leaf, np.ndarray): if sigma_leaf.ndim != 2: - raise ValueError("sigma_leaf must be a 2d symmetric numpy array 
if provided in matrix form") + raise ValueError( + "sigma_leaf must be a 2d symmetric numpy array if provided in matrix form" + ) if sigma_leaf.shape[0] != sigma_leaf.shape[1]: - raise ValueError("sigma_leaf must be a 2d symmetric numpy array if provided in matrix form") + raise ValueError( + "sigma_leaf must be a 2d symmetric numpy array if provided in matrix form" + ) if sigma_leaf.shape[0] != 1: - raise ValueError("sigma_leaf must be a 1x1 numpy array for this leaf model") + raise ValueError( + "sigma_leaf must be a 1x1 numpy array for this leaf model" + ) current_leaf_scale = sigma_leaf else: - raise ValueError("sigma_leaf must be either a scalar or a 2d numpy array") + raise ValueError( + "sigma_leaf must be either a scalar or a 2d numpy array" + ) else: current_leaf_scale = np.array([[1.0]]) if self.include_variance_forest: @@ -702,7 +797,7 @@ def sample( a_forest = 1.0 if not b_forest: b_forest = 1.0 - + # Runtime checks on RFX group ids self.has_rfx = False has_rfx_test = False @@ -711,13 +806,15 @@ def sample( if rfx_group_ids_test is not None: has_rfx_test = True if not np.all(np.isin(rfx_group_ids_test, rfx_group_ids_train)): - raise ValueError("All random effect group labels provided in rfx_group_ids_test must be present in rfx_group_ids_train") - - # Fill in rfx basis as a vector of 1s (random intercept) if a basis not provided + raise ValueError( + "All random effect group labels provided in rfx_group_ids_test must be present in rfx_group_ids_train" + ) + + # Fill in rfx basis as a vector of 1s (random intercept) if a basis not provided has_basis_rfx = False if self.has_rfx: if rfx_basis_train is None: - rfx_basis_train = np.ones((rfx_group_ids_train.shape[0],1)) + rfx_basis_train = np.ones((rfx_group_ids_train.shape[0], 1)) else: has_basis_rfx = True num_rfx_groups = np.unique(rfx_group_ids_train).shape[0] @@ -726,22 +823,26 @@ def sample( if has_rfx_test: if rfx_basis_test is None: if has_basis_rfx: - raise ValueError("Random effects basis provided 
for training set, must also be provided for the test set") - rfx_basis_test = np.ones((rfx_group_ids_test.shape[0],1)) - + raise ValueError( + "Random effects basis provided for training set, must also be provided for the test set" + ) + rfx_basis_test = np.ones((rfx_group_ids_test.shape[0], 1)) + # Set up random effects structures if self.has_rfx: if num_rfx_components == 1: alpha_init = np.array([1]) elif num_rfx_components > 1: - alpha_init = np.concatenate((np.ones(1), np.zeros(num_rfx_components-1))) + alpha_init = np.concatenate( + (np.ones(1), np.zeros(num_rfx_components - 1)) + ) else: raise ValueError("There must be at least 1 random effect component") xi_init = np.tile(np.expand_dims(alpha_init, 1), (1, num_rfx_groups)) sigma_alpha_init = np.identity(num_rfx_components) sigma_xi_init = np.identity(num_rfx_components) - sigma_xi_shape = 1. - sigma_xi_scale = 1. + sigma_xi_shape = 1.0 + sigma_xi_scale = 1.0 rfx_dataset_train = RandomEffectsDataset() rfx_dataset_train.add_group_labels(rfx_group_ids_train) rfx_dataset_train.add_basis(rfx_basis_train) @@ -754,7 +855,9 @@ def sample( rfx_model.set_variance_prior_shape(sigma_xi_shape) rfx_model.set_variance_prior_scale(sigma_xi_scale) self.rfx_container = RandomEffectsContainer() - self.rfx_container.load_new_container(num_rfx_components, num_rfx_groups, rfx_tracker) + self.rfx_container.load_new_container( + num_rfx_components, num_rfx_groups, rfx_tracker + ) # Container of variance parameter samples self.num_gfr = num_gfr @@ -960,11 +1063,17 @@ def sample( self.leaf_scale_samples[sample_counter] = current_leaf_scale[ 0, 0 ] - + # Sample random effects if self.has_rfx: rfx_model.sample( - rfx_dataset_train, residual_train, rfx_tracker, self.rfx_container, keep_sample, current_sigma2, cpp_rng + rfx_dataset_train, + residual_train, + rfx_tracker, + self.rfx_container, + keep_sample, + current_sigma2, + cpp_rng, ) # Run MCMC @@ -992,6 +1101,44 @@ def sample( ) if sample_sigma_global: current_sigma2 = 
self.global_var_samples[forest_ind] + elif has_prev_model: + if self.include_mean_forest: + active_forest_mean.reset( + previous_bart_model.forest_container_mean, + previous_model_warmstart_sample_num, + ) + forest_sampler_mean.reconstitute_from_forest( + active_forest_mean, + forest_dataset_train, + residual_train, + True, + ) + if sample_sigma_leaf and previous_leaf_var_samples is not None: + leaf_scale_double = previous_leaf_var_samples[ + previous_model_warmstart_sample_num + ] + current_leaf_scale[0, 0] = leaf_scale_double + forest_model_config_mean.update_leaf_model_scale( + leaf_scale_double + ) + if self.include_variance_forest: + active_forest_variance.reset( + previous_bart_model.forest_container_variance, + previous_model_warmstart_sample_num, + ) + forest_sampler_variance.reconstitute_from_forest( + active_forest_variance, + forest_dataset_train, + residual_train, + True, + ) + # if self.has_rfx: + # pass + if self.sample_sigma_global and previous_global_var_samples is not None: + current_sigma2 = previous_global_var_samples[ + previous_model_warmstart_sample_num + ] + global_model_config.update_global_error_variance(current_sigma2) else: if self.include_mean_forest: active_forest_mean.reset_root() @@ -1069,21 +1216,31 @@ def sample( current_sigma2 = global_var_model.sample_one_iteration( residual_train, cpp_rng, a_global, b_global ) + global_model_config.update_global_error_variance(current_sigma2) if keep_sample: self.global_var_samples[sample_counter] = current_sigma2 if self.sample_sigma_leaf: current_leaf_scale[0, 0] = leaf_var_model.sample_one_iteration( active_forest_mean, cpp_rng, a_leaf, b_leaf ) + forest_model_config_mean.update_leaf_model_scale( + current_leaf_scale + ) if keep_sample: self.leaf_scale_samples[sample_counter] = ( current_leaf_scale[0, 0] ) - + # Sample random effects if self.has_rfx: rfx_model.sample( - rfx_dataset_train, residual_train, rfx_tracker, self.rfx_container, keep_sample, current_sigma2, cpp_rng + rfx_dataset_train, + residual_train, + rfx_tracker, +
self.rfx_container, + keep_sample, + current_sigma2, + cpp_rng, ) # Mark the model as sampled @@ -1121,12 +1278,18 @@ def sample( forest_dataset_test.dataset_cpp ) self.y_hat_test = yhat_test_raw * self.y_std + self.y_bar - + # TODO: make rfx_preds_train and rfx_preds_test persistent properties if self.has_rfx: - rfx_preds_train = self.rfx_container.predict(rfx_group_ids_train, rfx_basis_train) * self.y_std + rfx_preds_train = ( + self.rfx_container.predict(rfx_group_ids_train, rfx_basis_train) + * self.y_std + ) if has_rfx_test: - rfx_preds_test = self.rfx_container.predict(rfx_group_ids_test, rfx_basis_test) * self.y_std + rfx_preds_test = ( + self.rfx_container.predict(rfx_group_ids_test, rfx_basis_test) + * self.y_std + ) if self.include_mean_forest: self.y_hat_train = self.y_hat_train + rfx_preds_train if self.has_test: @@ -1170,8 +1333,11 @@ def sample( ) def predict( - self, covariates: Union[np.array, pd.DataFrame], basis: np.array = None, - rfx_group_ids: np.array = None, rfx_basis: np.array = None + self, + covariates: Union[np.array, pd.DataFrame], + basis: np.array = None, + rfx_group_ids: np.array = None, + rfx_basis: np.array = None, ) -> Union[np.array, tuple]: """Return predictions from every forest sampled (either / both of mean and variance). 
Return type is either a single array of predictions, if a BART model only includes a @@ -1256,9 +1422,11 @@ def predict( pred_dataset.dataset_cpp ) mean_pred = mean_pred_raw * self.y_std + self.y_bar - + if self.has_rfx: - rfx_preds = self.rfx_container.predict(rfx_group_ids, rfx_basis) * self.y_std + rfx_preds = ( + self.rfx_container.predict(rfx_group_ids, rfx_basis) * self.y_std + ) if self.include_mean_forest: mean_pred = mean_pred + rfx_preds else: @@ -1290,8 +1458,11 @@ def predict( return variance_pred def predict_mean( - self, covariates: np.array, basis: np.array = None, - rfx_group_ids: np.array = None, rfx_basis: np.array = None + self, + covariates: np.array, + basis: np.array = None, + rfx_group_ids: np.array = None, + rfx_basis: np.array = None, ) -> np.array: """Predict expected conditional outcome from a BART model. @@ -1379,7 +1550,9 @@ def predict_mean( # RFX predictions if self.has_rfx: - rfx_preds = self.rfx_container.predict(rfx_group_ids, rfx_basis) * self.y_std + rfx_preds = ( + self.rfx_container.predict(rfx_group_ids, rfx_basis) * self.y_std + ) if self.include_mean_forest: mean_pred = mean_pred + rfx_preds else: @@ -1507,11 +1680,11 @@ def to_json(self) -> str: bart_json.add_boolean("include_mean_forest", self.include_mean_forest) bart_json.add_boolean("include_variance_forest", self.include_variance_forest) bart_json.add_boolean("has_rfx", self.has_rfx) - bart_json.add_scalar("num_gfr", self.num_gfr) - bart_json.add_scalar("num_burnin", self.num_burnin) - bart_json.add_scalar("num_mcmc", self.num_mcmc) - bart_json.add_scalar("num_samples", self.num_samples) - bart_json.add_scalar("num_basis", self.num_basis) + bart_json.add_integer("num_gfr", self.num_gfr) + bart_json.add_integer("num_burnin", self.num_burnin) + bart_json.add_integer("num_mcmc", self.num_mcmc) + bart_json.add_integer("num_samples", self.num_samples) + bart_json.add_integer("num_basis", self.num_basis) bart_json.add_boolean("requires_basis", self.has_basis) # Add parameter 
samples @@ -1565,7 +1738,7 @@ def from_json(self, json_string: str) -> None: self.forest_container_variance.forest_container_cpp.LoadFromJson( bart_json.json_cpp, "forest_0" ) - + # Unpack random effects if self.has_rfx: self.rfx_container = RandomEffectsContainer() @@ -1578,11 +1751,11 @@ def from_json(self, json_string: str) -> None: self.sigma2_init = bart_json.get_scalar("sigma2_init") self.sample_sigma_global = bart_json.get_boolean("sample_sigma_global") self.sample_sigma_leaf = bart_json.get_boolean("sample_sigma_leaf") - self.num_gfr = bart_json.get_scalar("num_gfr") - self.num_burnin = bart_json.get_scalar("num_burnin") - self.num_mcmc = bart_json.get_scalar("num_mcmc") - self.num_samples = bart_json.get_scalar("num_samples") - self.num_basis = bart_json.get_scalar("num_basis") + self.num_gfr = bart_json.get_integer("num_gfr") + self.num_burnin = bart_json.get_integer("num_burnin") + self.num_mcmc = bart_json.get_integer("num_mcmc") + self.num_samples = bart_json.get_integer("num_samples") + self.num_basis = bart_json.get_integer("num_basis") self.has_basis = bart_json.get_boolean("requires_basis") # Unpack parameter samples @@ -1603,6 +1776,135 @@ def from_json(self, json_string: str) -> None: # Mark the deserialized model as "sampled" self.sampled = True + def from_json_string_list(self, json_string_list: list[str]) -> None: + """ + Convert a list of (in-memory) JSON strings that represent BART models to a single combined BART model object + which can be used for prediction, etc... 
+ + Parameters + ------- + json_string_list : list of str + List of JSON strings which can be parsed to objects of type `JSONSerializer` containing Json representation of a BART model + """ + # Convert strings to JSONSerializer + json_object_list = [] + for i in range(len(json_string_list)): + json_string = json_string_list[i] + json_object_list.append(JSONSerializer()) + json_object_list[i].load_from_json_string(json_string) + + # For scalar / preprocessing details which aren't sample-dependent, defer to the first json + json_object_default = json_object_list[0] + + # Unpack forests + self.include_mean_forest = json_object_default.get_boolean( + "include_mean_forest" + ) + self.include_variance_forest = json_object_default.get_boolean( + "include_variance_forest" + ) + if self.include_mean_forest: + # TODO: don't just make this a placeholder that we overwrite + self.forest_container_mean = ForestContainer(0, 0, False, False) + for i in range(len(json_object_list)): + if i == 0: + self.forest_container_mean.forest_container_cpp.LoadFromJson( + json_object_list[i].json_cpp, "forest_0" + ) + else: + self.forest_container_mean.forest_container_cpp.AppendFromJson( + json_object_list[i].json_cpp, "forest_0" + ) + if self.include_variance_forest: + # TODO: don't just make this a placeholder that we overwrite + self.forest_container_variance = ForestContainer(0, 0, False, False) + for i in range(len(json_object_list)): + if i == 0: + self.forest_container_variance.forest_container_cpp.LoadFromJson( + json_object_list[i].json_cpp, "forest_1" + ) + else: + self.forest_container_variance.forest_container_cpp.AppendFromJson( + json_object_list[i].json_cpp, "forest_1" + ) + else: + # TODO: don't just make this a placeholder that we overwrite + self.forest_container_variance = ForestContainer(0, 0, False, False) + for i in range(len(json_object_list)): + if i == 0: + self.forest_container_variance.forest_container_cpp.LoadFromJson( + json_object_list[i].json_cpp, "forest_1" + ) 
+ else: + self.forest_container_variance.forest_container_cpp.AppendFromJson( + json_object_list[i].json_cpp, "forest_1" + ) + + # Unpack random effects + self.has_rfx = json_object_default.get_boolean("has_rfx") + if self.has_rfx: + self.rfx_container = RandomEffectsContainer() + for i in range(len(json_object_list)): + if i == 0: + self.rfx_container.load_from_json(json_object_list[i], 0) + else: + self.rfx_container.append_from_json(json_object_list[i], 0) + + # Unpack global parameters + self.y_std = json_object_default.get_scalar("outcome_scale") + self.y_bar = json_object_default.get_scalar("outcome_mean") + self.standardize = json_object_default.get_boolean("standardize") + self.sigma2_init = json_object_default.get_scalar("sigma2_init") + self.sample_sigma_global = json_object_default.get_boolean( + "sample_sigma_global" + ) + self.sample_sigma_leaf = json_object_default.get_boolean("sample_sigma_leaf") + self.num_gfr = json_object_default.get_integer("num_gfr") + self.num_burnin = json_object_default.get_integer("num_burnin") + self.num_mcmc = json_object_default.get_integer("num_mcmc") + self.num_samples = json_object_default.get_integer("num_samples") + self.num_basis = json_object_default.get_integer("num_basis") + self.has_basis = json_object_default.get_boolean("requires_basis") + + # Unpack parameter samples + if self.sample_sigma_global: + for i in range(len(json_object_list)): + if i == 0: + self.global_var_samples = json_object_list[i].get_numeric_vector( + "sigma2_global_samples", "parameters" + ) + else: + global_var_samples = json_object_list[i].get_numeric_vector( + "sigma2_global_samples", "parameters" + ) + self.global_var_samples = np.concatenate( + (self.global_var_samples, global_var_samples) + ) + + if self.sample_sigma_leaf: + for i in range(len(json_object_list)): + if i == 0: + self.leaf_scale_samples = json_object_list[i].get_numeric_vector( + "sigma2_leaf_samples", "parameters" + ) + else: + leaf_scale_samples = 
json_object_list[i].get_numeric_vector( + "sigma2_leaf_samples", "parameters" + ) + self.leaf_scale_samples = np.concatenate( + (self.leaf_scale_samples, leaf_scale_samples) + ) + + # Unpack covariate preprocessor + covariate_preprocessor_string = json_object_default.get_string( + "covariate_preprocessor" + ) + self._covariate_preprocessor = CovariatePreprocessor() + self._covariate_preprocessor.from_json(covariate_preprocessor_string) + + # Mark the deserialized model as "sampled" + self.sampled = True + def is_sampled(self) -> bool: """Whether or not a BART model has been sampled. diff --git a/stochtree/bcf.py b/stochtree/bcf.py index 276c0f54..37dc80ec 100644 --- a/stochtree/bcf.py +++ b/stochtree/bcf.py @@ -14,7 +14,12 @@ from .data import Dataset, Residual from .forest import Forest, ForestContainer from .preprocessing import CovariatePreprocessor, _preprocess_params -from .random_effects import RandomEffectsContainer, RandomEffectsDataset, RandomEffectsModel, RandomEffectsTracker +from .random_effects import ( + RandomEffectsContainer, + RandomEffectsDataset, + RandomEffectsModel, + RandomEffectsTracker, +) from .sampler import RNG, ForestSampler, GlobalVarianceModel, LeafVarianceModel from .serialization import JSONSerializer from .utils import NotSampledError @@ -115,7 +120,7 @@ def sample( pi_test : np.array, optional Optional test set vector of propensity scores. If not provided (but `X_test` and `Z_test` are), this will be estimated from the data. rfx_group_ids_test : np.array, optional - Optional test set group labels used for an additive random effects model. We do not currently support (but plan to in the near future), + Optional test set group labels used for an additive random effects model. We do not currently support (but plan to in the near future), test set evaluation for group labels that were not in the training set. rfx_basis_test : np.array, optional Optional test set basis for "random-slope" regression in additive random effects model. 
@@ -310,9 +315,13 @@ def sample( num_trees_tau = treatment_effect_forest_params_updated["num_trees"] alpha_tau = treatment_effect_forest_params_updated["alpha"] beta_tau = treatment_effect_forest_params_updated["beta"] - min_samples_leaf_tau = treatment_effect_forest_params_updated["min_samples_leaf"] + min_samples_leaf_tau = treatment_effect_forest_params_updated[ + "min_samples_leaf" + ] max_depth_tau = treatment_effect_forest_params_updated["max_depth"] - sample_sigma_leaf_tau = treatment_effect_forest_params_updated["sample_sigma2_leaf"] + sample_sigma_leaf_tau = treatment_effect_forest_params_updated[ + "sample_sigma2_leaf" + ] sigma_leaf_tau = treatment_effect_forest_params_updated["sigma2_leaf_init"] a_leaf_tau = treatment_effect_forest_params_updated["sigma2_leaf_shape"] b_leaf_tau = treatment_effect_forest_params_updated["sigma2_leaf_scale"] @@ -320,22 +329,24 @@ def sample( drop_vars_tau = treatment_effect_forest_params_updated["drop_vars"] # 4. Variance forest parameters - num_trees_variance = variance_forest_params_updated['num_trees'] - alpha_variance = variance_forest_params_updated['alpha'] - beta_variance = variance_forest_params_updated['beta'] - min_samples_leaf_variance = variance_forest_params_updated['min_samples_leaf'] - max_depth_variance = variance_forest_params_updated['max_depth'] - a_0 = variance_forest_params_updated['leaf_prior_calibration_param'] - variance_forest_leaf_init = variance_forest_params_updated['var_forest_leaf_init'] - a_forest = variance_forest_params_updated['var_forest_prior_shape'] - b_forest = variance_forest_params_updated['var_forest_prior_scale'] - keep_vars_variance = variance_forest_params_updated['keep_vars'] - drop_vars_variance = variance_forest_params_updated['drop_vars'] - + num_trees_variance = variance_forest_params_updated["num_trees"] + alpha_variance = variance_forest_params_updated["alpha"] + beta_variance = variance_forest_params_updated["beta"] + min_samples_leaf_variance = 
variance_forest_params_updated["min_samples_leaf"] + max_depth_variance = variance_forest_params_updated["max_depth"] + a_0 = variance_forest_params_updated["leaf_prior_calibration_param"] + variance_forest_leaf_init = variance_forest_params_updated[ + "var_forest_leaf_init" + ] + a_forest = variance_forest_params_updated["var_forest_prior_shape"] + b_forest = variance_forest_params_updated["var_forest_prior_scale"] + keep_vars_variance = variance_forest_params_updated["keep_vars"] + drop_vars_variance = variance_forest_params_updated["drop_vars"] + # Override keep_gfr if there are no MCMC samples if num_mcmc == 0: keep_gfr = True - + # Variable weight preprocessing (and initialization if necessary) if variable_weights is None: if X_train.ndim > 1: @@ -378,7 +389,9 @@ def sample( if not isinstance(rfx_group_ids_train, np.ndarray): raise ValueError("rfx_group_ids_train must be a numpy array") if not np.issubdtype(rfx_group_ids_train.dtype, np.integer): - raise ValueError("rfx_group_ids_train must be a numpy array of integer-valued group IDs") + raise ValueError( + "rfx_group_ids_train must be a numpy array of integer-valued group IDs" + ) if rfx_basis_train is not None: if not isinstance(rfx_basis_train, np.ndarray): raise ValueError("rfx_basis_train must be a numpy array") @@ -386,7 +399,9 @@ def sample( if not isinstance(rfx_group_ids_test, np.ndarray): raise ValueError("rfx_group_ids_test must be a numpy array") if not np.issubdtype(rfx_group_ids_test.dtype, np.integer): - raise ValueError("rfx_group_ids_test must be a numpy array of integer-valued group IDs") + raise ValueError( + "rfx_group_ids_test must be a numpy array of integer-valued group IDs" + ) if rfx_basis_test is not None: if not isinstance(rfx_basis_test, np.ndarray): raise ValueError("rfx_basis_test must be a numpy array") @@ -470,7 +485,6 @@ def sample( # Set variance leaf model type (currently only one option) leaf_dimension_variance = 1 leaf_model_variance = 3 - self.variance_scale = 1 # Check 
parameters if sigma_leaf_tau is not None: @@ -1155,11 +1169,17 @@ def sample( current_leaf_scale_tau = np.array([[sigma_leaf_tau]]) elif isinstance(sigma_leaf_tau, np.ndarray): if sigma_leaf_tau.ndim != 2: - raise ValueError("sigma_leaf_tau must be a 2d symmetric numpy array if provided in matrix form") + raise ValueError( + "sigma_leaf_tau must be a 2d symmetric numpy array if provided in matrix form" + ) if sigma_leaf_tau.shape[0] != sigma_leaf_tau.shape[1]: - raise ValueError("sigma_leaf_tau must be a 2d symmetric numpy array if provided in matrix form") + raise ValueError( + "sigma_leaf_tau must be a 2d symmetric numpy array if provided in matrix form" + ) if sigma_leaf_tau.shape[0] != Z_train.shape[1]: - raise ValueError("sigma_leaf_tau must be a 2d numpy array with dimension matching that of the treatment vector") + raise ValueError( + "sigma_leaf_tau must be a 2d numpy array with dimension matching that of the treatment vector" + ) current_leaf_scale_tau = sigma_leaf_tau else: raise ValueError("sigma_leaf_tau must be a scalar or a 2d numpy array") @@ -1173,7 +1193,7 @@ def sample( a_forest = 1.0 if not b_forest: b_forest = 1.0 - + # Runtime checks on RFX group ids self.has_rfx = False has_rfx_test = False @@ -1182,13 +1202,15 @@ def sample( if rfx_group_ids_test is not None: has_rfx_test = True if not np.all(np.isin(rfx_group_ids_test, rfx_group_ids_train)): - raise ValueError("All random effect group labels provided in rfx_group_ids_test must be present in rfx_group_ids_train") - - # Fill in rfx basis as a vector of 1s (random intercept) if a basis not provided + raise ValueError( + "All random effect group labels provided in rfx_group_ids_test must be present in rfx_group_ids_train" + ) + + # Fill in rfx basis as a vector of 1s (random intercept) if a basis not provided has_basis_rfx = False if self.has_rfx: if rfx_basis_train is None: - rfx_basis_train = np.ones((rfx_group_ids_train.shape[0],1)) + rfx_basis_train = np.ones((rfx_group_ids_train.shape[0], 
1)) else: has_basis_rfx = True num_rfx_groups = np.unique(rfx_group_ids_train).shape[0] @@ -1197,22 +1219,26 @@ def sample( if has_rfx_test: if rfx_basis_test is None: if has_basis_rfx: - raise ValueError("Random effects basis provided for training set, must also be provided for the test set") - rfx_basis_test = np.ones((rfx_group_ids_test.shape[0],1)) - + raise ValueError( + "Random effects basis provided for training set, must also be provided for the test set" + ) + rfx_basis_test = np.ones((rfx_group_ids_test.shape[0], 1)) + # Set up random effects structures if self.has_rfx: if num_rfx_components == 1: alpha_init = np.array([1]) elif num_rfx_components > 1: - alpha_init = np.concatenate((np.ones(1), np.zeros(num_rfx_components-1))) + alpha_init = np.concatenate( + (np.ones(1), np.zeros(num_rfx_components - 1)) + ) else: raise ValueError("There must be at least 1 random effect component") xi_init = np.tile(np.expand_dims(alpha_init, 1), (1, num_rfx_groups)) sigma_alpha_init = np.identity(num_rfx_components) sigma_xi_init = np.identity(num_rfx_components) - sigma_xi_shape = 1. - sigma_xi_scale = 1. 
+ sigma_xi_shape = 1.0 + sigma_xi_scale = 1.0 rfx_dataset_train = RandomEffectsDataset() rfx_dataset_train.add_group_labels(rfx_group_ids_train) rfx_dataset_train.add_basis(rfx_basis_train) @@ -1225,7 +1251,9 @@ def sample( rfx_model.set_variance_prior_shape(sigma_xi_shape) rfx_model.set_variance_prior_scale(sigma_xi_scale) self.rfx_container = RandomEffectsContainer() - self.rfx_container.load_new_container(num_rfx_components, num_rfx_groups, rfx_tracker) + self.rfx_container.load_new_container( + num_rfx_components, num_rfx_groups, rfx_tracker + ) # Update variable weights variable_counts = [original_var_indices.count(i) for i in original_var_indices] @@ -1555,12 +1583,12 @@ def sample( loc=(s_ty0 / (s_tt0 + 2 * current_sigma2)), scale=np.sqrt(current_sigma2 / (s_tt0 + 2 * current_sigma2)), size=1, - ) + )[0] current_b_1 = self.rng.normal( loc=(s_ty1 / (s_tt1 + 2 * current_sigma2)), scale=np.sqrt(current_sigma2 / (s_tt1 + 2 * current_sigma2)), size=1, - ) + )[0] tau_basis_train = ( 1 - np.squeeze(Z_train) ) * current_b_0 + np.squeeze(Z_train) * current_b_1 @@ -1614,11 +1642,17 @@ def sample( self.leaf_scale_tau_samples[sample_counter] = ( current_leaf_scale_tau[0, 0] ) - + # Sample random effects if self.has_rfx: rfx_model.sample( - rfx_dataset_train, residual_train, rfx_tracker, self.rfx_container, keep_sample, current_sigma2, cpp_rng + rfx_dataset_train, + residual_train, + rfx_tracker, + self.rfx_container, + keep_sample, + current_sigma2, + cpp_rng, ) # Run MCMC @@ -1703,12 +1737,12 @@ def sample( loc=(s_ty0 / (s_tt0 + 2 * current_sigma2)), scale=np.sqrt(current_sigma2 / (s_tt0 + 2 * current_sigma2)), size=1, - ) + )[0] current_b_1 = self.rng.normal( loc=(s_ty1 / (s_tt1 + 2 * current_sigma2)), scale=np.sqrt(current_sigma2 / (s_tt1 + 2 * current_sigma2)), size=1, - ) + )[0] tau_basis_train = ( 1 - np.squeeze(Z_train) ) * current_b_0 + np.squeeze(Z_train) * current_b_1 @@ -1762,11 +1796,17 @@ def sample( self.leaf_scale_tau_samples[sample_counter] = ( 
current_leaf_scale_tau[0, 0] ) - + # Sample random effects if self.has_rfx: rfx_model.sample( - rfx_dataset_train, residual_train, rfx_tracker, self.rfx_container, keep_sample, current_sigma2, cpp_rng + rfx_dataset_train, + residual_train, + rfx_tracker, + self.rfx_container, + keep_sample, + current_sigma2, + cpp_rng, ) # Mark the model as sampled @@ -1839,13 +1879,19 @@ def sample( # TODO: make rfx_preds_train and rfx_preds_test persistent properties if self.has_rfx: - rfx_preds_train = self.rfx_container.predict(rfx_group_ids_train, rfx_basis_train) * self.y_std + rfx_preds_train = ( + self.rfx_container.predict(rfx_group_ids_train, rfx_basis_train) + * self.y_std + ) if has_rfx_test: - rfx_preds_test = self.rfx_container.predict(rfx_group_ids_test, rfx_basis_test) * self.y_std + rfx_preds_test = ( + self.rfx_container.predict(rfx_group_ids_test, rfx_basis_test) + * self.y_std + ) self.y_hat_train = self.y_hat_train + rfx_preds_train if self.has_test: self.y_hat_test = self.y_hat_test + rfx_preds_test - + if self.include_variance_forest: sigma2_x_train_raw = ( self.forest_container_variance.forest_container_cpp.Predict( @@ -1860,11 +1906,7 @@ def sample( ) else: self.sigma2_x_train = ( - sigma2_x_train_raw - * self.sigma2_init - * self.y_std - * self.y_std - / self.variance_scale + sigma2_x_train_raw * self.sigma2_init * self.y_std * self.y_std ) if self.has_test: sigma2_x_test_raw = ( @@ -1880,11 +1922,7 @@ def sample( ) else: self.sigma2_x_test = ( - sigma2_x_test_raw - * self.sigma2_init - * self.y_std - * self.y_std - / self.variance_scale + sigma2_x_test_raw * self.sigma2_init * self.y_std * self.y_std ) if self.sample_sigma_global: @@ -1970,16 +2008,16 @@ def predict_tau( "This BCF model has not run any covariate preprocessing routines. We will attempt to predict on the raw covariate values, but this will trigger an error with non-numeric columns. 
Please refit your model by passing non-numeric covariate data a a Pandas dataframe.", RuntimeWarning, ) - if not np.issubdtype( - X.dtype, np.floating - ) and not np.issubdtype(X.dtype, np.integer): + if not np.issubdtype(X.dtype, np.floating) and not np.issubdtype( + X.dtype, np.integer + ): raise ValueError( "Prediction cannot proceed on a non-numeric numpy array, since the BCF model was not fit with a covariate preprocessor. Please refit your model by passing non-numeric covariate data as a Pandas dataframe." ) covariates_processed = X else: covariates_processed = self._covariate_preprocessor.transform(X) - + # Update covariates to include propensities if requested if self.propensity_covariate == "none": X_combined = covariates_processed @@ -2064,7 +2102,7 @@ def predict_variance( covariates_processed = covariates else: covariates_processed = self._covariate_preprocessor.transform(covariates) - + # Update covariates to include propensities if requested if self.propensity_covariate == "none": X_combined = covariates_processed @@ -2092,16 +2130,19 @@ def predict_variance( ) else: variance_pred = ( - variance_pred_raw - * self.sigma2_init - * self.y_std - * self.y_std - / self.variance_scale + variance_pred_raw * self.sigma2_init * self.y_std * self.y_std ) return variance_pred - def predict(self, X: np.array, Z: np.array, propensity: np.array = None, rfx_group_ids: np.array = None, rfx_basis: np.array = None) -> tuple: + def predict( + self, + X: np.array, + Z: np.array, + propensity: np.array = None, + rfx_group_ids: np.array = None, + rfx_basis: np.array = None, + ) -> tuple: """Predict outcome model components (CATE function and prognostic function) as well as overall outcome for every provided observation. Predicted outcomes are computed as `yhat = mu_x + Z*tau_x` where mu_x is a sample of the prognostic function and tau_x is a sample of the treatment effect (CATE) function. 
@@ -2167,7 +2208,7 @@ def predict(self, X: np.array, Z: np.array, propensity: np.array = None, rfx_gro propensity = np.mean( self.bart_propensity_model.predict(X), axis=1, keepdims=True ) - + # Covariate preprocessing if not self._covariate_preprocessor._check_is_fitted(): if not isinstance(X, np.ndarray): @@ -2179,9 +2220,9 @@ def predict(self, X: np.array, Z: np.array, propensity: np.array = None, rfx_gro "This BCF model has not run any covariate preprocessing routines. We will attempt to predict on the raw covariate values, but this will trigger an error with non-numeric columns. Please refit your model by passing non-numeric covariate data a a Pandas dataframe.", RuntimeWarning, ) - if not np.issubdtype( - X.dtype, np.floating - ) and not np.issubdtype(X.dtype, np.integer): + if not np.issubdtype(X.dtype, np.floating) and not np.issubdtype( + X.dtype, np.integer + ): raise ValueError( "Prediction cannot proceed on a non-numeric numpy array, since the BCF model was not fit with a covariate preprocessor. Please refit your model by passing non-numeric covariate data as a Pandas dataframe." 
) @@ -2204,7 +2245,7 @@ def predict(self, X: np.array, Z: np.array, propensity: np.array = None, rfx_gro mu_raw = self.forest_container_mu.forest_container_cpp.Predict( forest_dataset_test.dataset_cpp ) - mu_x = mu_raw * self.y_std / np.sqrt(self.variance_scale) + self.y_bar + mu_x = mu_raw * self.y_std + self.y_bar tau_raw = self.forest_container_tau.forest_container_cpp.PredictRaw( forest_dataset_test.dataset_cpp ) @@ -2213,7 +2254,7 @@ def predict(self, X: np.array, Z: np.array, propensity: np.array = None, rfx_gro self.b1_samples - self.b0_samples, axis=(0, 2) ) tau_raw = tau_raw * adaptive_coding_weights - tau_x = np.squeeze(tau_raw * self.y_std / np.sqrt(self.variance_scale)) + tau_x = np.squeeze(tau_raw * self.y_std) if Z.shape[1] > 1: treatment_term = np.multiply(np.atleast_3d(Z).swapaxes(1, 2), tau_x).sum( axis=2 @@ -2221,9 +2262,11 @@ def predict(self, X: np.array, Z: np.array, propensity: np.array = None, rfx_gro else: treatment_term = Z * np.squeeze(tau_x) yhat_x = mu_x + treatment_term - + if self.has_rfx: - rfx_preds = self.rfx_container.predict(rfx_group_ids, rfx_basis) * self.y_std + rfx_preds = ( + self.rfx_container.predict(rfx_group_ids, rfx_basis) * self.y_std + ) yhat_x = yhat_x + rfx_preds # Compute predictions from the variance forest (if included) @@ -2236,13 +2279,7 @@ def predict(self, X: np.array, Z: np.array, propensity: np.array = None, rfx_gro for i in range(self.num_samples): sigma2_x[:, i] = sigma2_x_raw[:, i] * self.global_var_samples[i] else: - sigma2_x = ( - sigma2_x_raw - * self.sigma2_init - * self.y_std - * self.y_std - / self.variance_scale - ) + sigma2_x = sigma2_x_raw * self.sigma2_init * self.y_std * self.y_std # Return result matrices as a tuple if self.has_rfx and self.include_variance_forest: @@ -2285,7 +2322,6 @@ def to_json(self) -> str: bcf_json.add_random_effects(self.rfx_container) # Add global parameters - bcf_json.add_scalar("variance_scale", self.variance_scale) bcf_json.add_scalar("outcome_scale", self.y_std) 
bcf_json.add_scalar("outcome_mean", self.y_bar) bcf_json.add_boolean("standardize", self.standardize) @@ -2365,14 +2401,13 @@ def from_json(self, json_string: str) -> None: self.forest_container_variance.forest_container_cpp.LoadFromJson( bcf_json.json_cpp, "forest_2" ) - + # Unpack random effects if self.has_rfx: self.rfx_container = RandomEffectsContainer() self.rfx_container.load_from_json(bcf_json, 0) # Unpack global parameters - self.variance_scale = bcf_json.get_scalar("variance_scale") self.y_std = bcf_json.get_scalar("outcome_scale") self.y_bar = bcf_json.get_scalar("outcome_mean") self.standardize = bcf_json.get_boolean("standardize") @@ -2421,6 +2456,162 @@ def from_json(self, json_string: str) -> None: # Mark the deserialized model as "sampled" self.sampled = True + def from_json_string_list(self, json_string_list: list[str]) -> None: + """ + Convert a list of (in-memory) JSON strings that represent BCF models to a single combined BCF model object + which can be used for prediction, etc... 
+ + Parameters + ------- + json_string_list : list of str + List of JSON strings which can be parsed to objects of type `JSONSerializer` containing Json representation of a BCF model + """ + # Convert strings to JSONSerializer + json_object_list = [] + for i in range(len(json_string_list)): + json_string = json_string_list[i] + json_object_list.append(JSONSerializer()) + json_object_list[i].load_from_json_string(json_string) + + # For scalar / preprocessing details which aren't sample-dependent, defer to the first json + json_object_default = json_object_list[0] + + # Unpack forests + # Mu forest + self.forest_container_mu = ForestContainer(0, 0, False, False) + for i in range(len(json_object_list)): + if i == 0: + self.forest_container_mu.forest_container_cpp.LoadFromJson( + json_object_list[i].json_cpp, "forest_0" + ) + else: + self.forest_container_mu.forest_container_cpp.AppendFromJson( + json_object_list[i].json_cpp, "forest_0" + ) + # Tau forest + self.forest_container_tau = ForestContainer(0, 0, False, False) + for i in range(len(json_object_list)): + if i == 0: + self.forest_container_tau.forest_container_cpp.LoadFromJson( + json_object_list[i].json_cpp, "forest_1" + ) + else: + self.forest_container_tau.forest_container_cpp.AppendFromJson( + json_object_list[i].json_cpp, "forest_1" + ) + self.include_variance_forest = json_object_default.get_boolean( + "include_variance_forest" + ) + if self.include_variance_forest: + # TODO: don't just make this a placeholder that we overwrite + self.forest_container_variance = ForestContainer(0, 0, False, False) + for i in range(len(json_object_list)): + if i == 0: + self.forest_container_variance.forest_container_cpp.LoadFromJson( + json_object_list[i].json_cpp, "forest_2" + ) + else: + self.forest_container_variance.forest_container_cpp.AppendFromJson( + json_object_list[i].json_cpp, "forest_2" + ) + + # Unpack random effects + self.has_rfx = json_object_default.get_boolean("has_rfx") + if self.has_rfx: + 
self.rfx_container = RandomEffectsContainer() + for i in range(len(json_object_list)): + if i == 0: + self.rfx_container.load_from_json(json_object_list[i], 0) + else: + self.rfx_container.append_from_json(json_object_list[i], 0) + + # Unpack global parameters + self.y_std = json_object_default.get_scalar("outcome_scale") + self.y_bar = json_object_default.get_scalar("outcome_mean") + self.standardize = json_object_default.get_boolean("standardize") + self.sigma2_init = json_object_default.get_scalar("sigma2_init") + self.sample_sigma_global = json_object_default.get_boolean( + "sample_sigma_global" + ) + self.sample_sigma_leaf_mu = json_object_default.get_boolean( + "sample_sigma_leaf_mu" + ) + self.sample_sigma_leaf_tau = json_object_default.get_boolean( + "sample_sigma_leaf_tau" + ) + self.num_gfr = json_object_default.get_scalar("num_gfr") + self.num_burnin = json_object_default.get_scalar("num_burnin") + self.num_mcmc = json_object_default.get_scalar("num_mcmc") + self.num_samples = json_object_default.get_scalar("num_samples") + self.adaptive_coding = json_object_default.get_boolean("adaptive_coding") + self.propensity_covariate = json_object_default.get_string( + "propensity_covariate" + ) + self.internal_propensity_model = json_object_default.get_boolean( + "internal_propensity_model" + ) + + # Unpack parameter samples + if self.sample_sigma_global: + for i in range(len(json_object_list)): + if i == 0: + self.global_var_samples = json_object_list[i].get_numeric_vector( + "sigma2_global_samples", "parameters" + ) + else: + global_var_samples = json_object_list[i].get_numeric_vector( + "sigma2_global_samples", "parameters" + ) + self.global_var_samples = np.concatenate( + (self.global_var_samples, global_var_samples) + ) + + if self.sample_sigma_leaf_mu: + for i in range(len(json_object_list)): + if i == 0: + self.leaf_scale_mu_samples = json_object_list[i].get_numeric_vector( + "sigma2_leaf_mu_samples", "parameters" + ) + else: + leaf_scale_mu_samples = 
json_object_list[i].get_numeric_vector( + "sigma2_leaf_mu_samples", "parameters" + ) + self.leaf_scale_mu_samples = np.concatenate( + (self.leaf_scale_mu_samples, leaf_scale_mu_samples) + ) + + if self.sample_sigma_leaf_tau: + for i in range(len(json_object_list)): + if i == 0: + self.sample_sigma_leaf_tau = json_object_list[i].get_numeric_vector( + "sigma2_leaf_tau_samples", "parameters" + ) + else: + sample_sigma_leaf_tau = json_object_list[i].get_numeric_vector( + "sigma2_leaf_tau_samples", "parameters" + ) + self.sample_sigma_leaf_tau = np.concatenate( + (self.sample_sigma_leaf_tau, sample_sigma_leaf_tau) + ) + + # Unpack internal propensity model + if self.internal_propensity_model: + bart_propensity_string = json_object_default.get_string( + "bart_propensity_model" + ) + self.bart_propensity_model = BARTModel() + self.bart_propensity_model.from_json(bart_propensity_string) + + # Unpack covariate preprocessor + covariate_preprocessor_string = json_object_default.get_string( + "covariate_preprocessor" + ) + self._covariate_preprocessor = CovariatePreprocessor() + self._covariate_preprocessor.from_json(covariate_preprocessor_string) + + # Mark the deserialized model as "sampled" + self.sampled = True + def is_sampled(self) -> bool: """Whether or not a BCF model has been sampled. diff --git a/stochtree/forest.py b/stochtree/forest.py index 809a3908..ba5ff9c3 100644 --- a/stochtree/forest.py +++ b/stochtree/forest.py @@ -206,6 +206,17 @@ def load_from_json_string(self, json_string: str) -> None: """ self.forest_container_cpp.LoadFromJsonString(json_string) + def load_from_json_object(self, json_object) -> None: + """ + Reload a forest container from an in-memory JSONSerializer object. + + Parameters + ---------- + json_object : JSONSerializer + In-memory JSONSerializer object. 
+ """ + self.forest_container_cpp.LoadFromJsonObject(json_object) + def add_sample(self, leaf_value: Union[float, np.array]) -> None: """ Add a new all-root ensemble to the container, with all of the leaves set to the value / vector provided @@ -903,8 +914,13 @@ def set_root_leaves(self, leaf_value: Union[float, np.array]) -> None: if isinstance(leaf_value, np.ndarray): if len(leaf_value.shape) > 1: leaf_value = np.squeeze(leaf_value) - if len(leaf_value.shape) != 1 or leaf_value.shape[0] != self.output_dimension: - raise ValueError("leaf_value must be a one-dimensional array with dimension equal to the output_dimension field of the forest") + if ( + len(leaf_value.shape) != 1 + or leaf_value.shape[0] != self.output_dimension + ): + raise ValueError( + "leaf_value must be a one-dimensional array with dimension equal to the output_dimension field of the forest" + ) if leaf_value.shape[0] > 1: self.forest_cpp.SetRootVector(leaf_value, leaf_value.shape[0]) else: @@ -1368,13 +1384,13 @@ def leaves(self, tree_num: int) -> np.array: def is_empty(self) -> bool: """ When a Forest object is created, it is "empty" in the sense that none - of its component trees have leaves with values. There are two ways to + of its component trees have leaves with values. There are two ways to "initialize" a Forest object. First, the `set_root_leaves()` method of the - Forest class simply initializes every tree in the forest to a single node - carrying the same (user-specified) leaf value. Second, the `prepare_for_sampler()` - method of the ForestSampler class initializes every tree in the forest to a - single node with the same value and also propagates this information through - to the temporary tracking data structrues in a ForestSampler object, which + Forest class simply initializes every tree in the forest to a single node + carrying the same (user-specified) leaf value. 
Second, the `prepare_for_sampler()` + method of the ForestSampler class initializes every tree in the forest to a + single node with the same value and also propagates this information through + to the temporary tracking data structures in a ForestSampler object, which must be synchronized with a Forest during a forest sampler loop. Returns diff --git a/stochtree/random_effects.py b/stochtree/random_effects.py index 0badcbb3..18144044 100644 --- a/stochtree/random_effects.py +++ b/stochtree/random_effects.py @@ -177,8 +177,10 @@ class RandomEffectsContainer: def __init__(self) -> None: pass - - def load_new_container(self, num_components: int, num_groups: int, rfx_tracker: RandomEffectsTracker) -> None: + + def load_new_container( + self, num_components: int, num_groups: int, rfx_tracker: RandomEffectsTracker + ) -> None: """ Initializes internal data structures for an "empty" random effects container to be sampled and populated. @@ -198,7 +200,7 @@ def load_new_container(self, num_components: int, num_groups: int, rfx_tracker: self.rfx_label_mapper_cpp = RandomEffectsLabelMapperCpp() self.rfx_label_mapper_cpp.LoadFromTracker(rfx_tracker.rfx_tracker_cpp) self.rfx_group_ids = rfx_tracker.rfx_tracker_cpp.GetUniqueGroupIds() - + def load_from_json(self, json, rfx_num: int) -> None: """ Initializes internal data structures for an "empty" random effects container to be sampled and populated. @@ -210,14 +212,30 @@ def load_from_json(self, json, rfx_num: int) -> None: rfx_num : int Integer index of the RFX term in a JSON model. In practice, this is typically 0 (most models don't contain two RFX terms).
""" - rfx_container_key = f'random_effect_container_{rfx_num:d}' - rfx_label_mapper_key = f'random_effect_label_mapper_{rfx_num:d}' - rfx_group_ids_key = f'random_effect_groupids_{rfx_num:d}' + rfx_container_key = f"random_effect_container_{rfx_num:d}" + rfx_label_mapper_key = f"random_effect_label_mapper_{rfx_num:d}" + rfx_group_ids_key = f"random_effect_groupids_{rfx_num:d}" self.rfx_container_cpp = RandomEffectsContainerCpp() self.rfx_container_cpp.LoadFromJson(json.json_cpp, rfx_container_key) self.rfx_label_mapper_cpp = RandomEffectsLabelMapperCpp() self.rfx_label_mapper_cpp.LoadFromJson(json.json_cpp, rfx_label_mapper_key) - self.rfx_group_ids = json.get_integer_vector(rfx_group_ids_key, "random_effects") + self.rfx_group_ids = json.get_integer_vector( + rfx_group_ids_key, "random_effects" + ) + + def append_from_json(self, json, rfx_num: int) -> None: + """ + Appends the random effects samples stored in a JSON model to the samples already held in this (previously loaded) container. + + Parameters + ---------- + json : JSONSerializer + Python object wrapping a C++ `json` object. + rfx_num : int + Integer index of the RFX term in a JSON model. In practice, this is typically 0 (most models don't contain two RFX terms).
+ """ + rfx_container_key = f"random_effect_container_{rfx_num:d}" + self.rfx_container_cpp.AppendFromJson(json.json_cpp, rfx_container_key) def num_samples(self) -> int: return self.rfx_container_cpp.NumSamples() diff --git a/test/python/test_bart.py b/test/python/test_bart.py index 5b1415bd..a2f2e64c 100644 --- a/test/python/test_bart.py +++ b/test/python/test_bart.py @@ -41,12 +41,12 @@ def outcome_mean(X): n_train = X_train.shape[0] n_test = X_test.shape[0] - # BCF settings + # BART settings num_gfr = 10 num_burnin = 0 num_mcmc = 10 - # Run BCF with test set and propensity score + # Run BART with test set and propensity score bart_model = BARTModel() bart_model.sample( X_train=X_train, @@ -61,6 +61,43 @@ def outcome_mean(X): assert bart_model.y_hat_train.shape == (n_train, num_mcmc) assert bart_model.y_hat_test.shape == (n_test, num_mcmc) + # Run second BART model + bart_model_2 = BARTModel() + bart_model_2.sample( + X_train=X_train, + y_train=y_train, + X_test=X_test, + num_gfr=num_gfr, + num_burnin=num_burnin, + num_mcmc=num_mcmc, + ) + + # Assertions + assert bart_model_2.y_hat_train.shape == (n_train, num_mcmc) + assert bart_model_2.y_hat_test.shape == (n_test, num_mcmc) + + # Combine into a single model + bart_models_json = [bart_model.to_json(), bart_model_2.to_json()] + bart_model_3 = BARTModel() + bart_model_3.from_json_string_list(bart_models_json) + + # Assertions + y_hat_train_combined = bart_model_3.predict(covariates=X_train) + assert y_hat_train_combined.shape == (n_train, num_mcmc * 2) + np.testing.assert_allclose( + y_hat_train_combined[:, 0:num_mcmc], bart_model.y_hat_train + ) + np.testing.assert_allclose( + y_hat_train_combined[:, num_mcmc : (2 * num_mcmc)], bart_model_2.y_hat_train + ) + np.testing.assert_allclose( + bart_model_3.global_var_samples[0:num_mcmc], bart_model.global_var_samples + ) + np.testing.assert_allclose( + bart_model_3.global_var_samples[num_mcmc : (2 * num_mcmc)], + bart_model_2.global_var_samples, + ) + def 
test_bart_univariate_leaf_regression_homoskedastic(self): # RNG random_seed = 101 @@ -105,12 +142,12 @@ def outcome_mean(X, W): n_train = X_train.shape[0] n_test = X_test.shape[0] - # BCF settings + # BART settings num_gfr = 10 num_burnin = 0 num_mcmc = 10 - # Run BCF with test set and propensity score + # Run BART with test set and propensity score bart_model = BARTModel() bart_model.sample( X_train=X_train, @@ -127,6 +164,47 @@ def outcome_mean(X, W): assert bart_model.y_hat_train.shape == (n_train, num_mcmc) assert bart_model.y_hat_test.shape == (n_test, num_mcmc) + # Run second BART model + bart_model_2 = BARTModel() + bart_model_2.sample( + X_train=X_train, + y_train=y_train, + leaf_basis_train=basis_train, + X_test=X_test, + leaf_basis_test=basis_test, + num_gfr=num_gfr, + num_burnin=num_burnin, + num_mcmc=num_mcmc, + ) + + # Assertions + assert bart_model_2.y_hat_train.shape == (n_train, num_mcmc) + assert bart_model_2.y_hat_test.shape == (n_test, num_mcmc) + + # Combine into a single model + bart_models_json = [bart_model.to_json(), bart_model_2.to_json()] + bart_model_3 = BARTModel() + bart_model_3.from_json_string_list(bart_models_json) + + # Assertions + y_hat_train_combined = bart_model_3.predict( + covariates=X_train, basis=basis_train + ) + assert y_hat_train_combined.shape == (n_train, num_mcmc * 2) + np.testing.assert_allclose( + y_hat_train_combined[:, 0:num_mcmc], bart_model.y_hat_train + ) + np.testing.assert_allclose( + y_hat_train_combined[:, num_mcmc : (2 * num_mcmc)], bart_model_2.y_hat_train + ) + np.testing.assert_allclose( + bart_model_3.global_var_samples[0:num_mcmc], bart_model.global_var_samples + ) + np.testing.assert_allclose( + bart_model_3.global_var_samples[num_mcmc : (2 * num_mcmc)], + bart_model_2.global_var_samples, + ) + def test_bart_multivariate_leaf_regression_homoskedastic(self): # RNG random_seed = 101 @@ -171,12 +249,12 @@ def outcome_mean(X, W): n_train = X_train.shape[0] n_test = X_test.shape[0] - # BCF settings + # 
BART settings num_gfr = 10 num_burnin = 0 num_mcmc = 10 - # Run BCF with test set and propensity score + # Run BART with test set and propensity score bart_model = BARTModel() bart_model.sample( X_train=X_train, @@ -193,6 +271,47 @@ def outcome_mean(X, W): assert bart_model.y_hat_train.shape == (n_train, num_mcmc) assert bart_model.y_hat_test.shape == (n_test, num_mcmc) + # Run second BART model + bart_model_2 = BARTModel() + bart_model_2.sample( + X_train=X_train, + y_train=y_train, + leaf_basis_train=basis_train, + X_test=X_test, + leaf_basis_test=basis_test, + num_gfr=num_gfr, + num_burnin=num_burnin, + num_mcmc=num_mcmc, + ) + + # Assertions + assert bart_model_2.y_hat_train.shape == (n_train, num_mcmc) + assert bart_model_2.y_hat_test.shape == (n_test, num_mcmc) + + # Combine into a single model + bart_models_json = [bart_model.to_json(), bart_model_2.to_json()] + bart_model_3 = BARTModel() + bart_model_3.from_json_string_list(bart_models_json) + + # Assertions + y_hat_train_combined = bart_model_3.predict( + covariates=X_train, basis=basis_train + ) + assert y_hat_train_combined.shape == (n_train, num_mcmc * 2) + np.testing.assert_allclose( + y_hat_train_combined[:, 0:num_mcmc], bart_model.y_hat_train + ) + np.testing.assert_allclose( + y_hat_train_combined[:, num_mcmc : (2 * num_mcmc)], bart_model_2.y_hat_train + ) + np.testing.assert_allclose( + bart_model_3.global_var_samples[0:num_mcmc], bart_model.global_var_samples + ) + np.testing.assert_allclose( + bart_model_3.global_var_samples[num_mcmc : (2 * num_mcmc)], + bart_model_2.global_var_samples, + ) + def test_bart_constant_leaf_heteroskedastic(self): # RNG random_seed = 101 @@ -240,12 +359,12 @@ def conditional_stddev(X): n_train = X_train.shape[0] n_test = X_test.shape[0] - # BCF settings + # BART settings num_gfr = 10 num_burnin = 0 num_mcmc = 10 - # Run BCF with test set and propensity score + # Run BART with test set and propensity score bart_model = BARTModel() general_params = 
{"sample_sigma2_global": True} variance_forest_params = {"num_trees": 50} @@ -264,6 +383,45 @@ def conditional_stddev(X): assert bart_model.y_hat_train.shape == (n_train, num_mcmc) assert bart_model.y_hat_test.shape == (n_test, num_mcmc) + # Run second BART model + bart_model_2 = BARTModel() + bart_model_2.sample( + X_train=X_train, + y_train=y_train, + X_test=X_test, + general_params=general_params, + variance_forest_params=variance_forest_params, + num_gfr=num_gfr, + num_burnin=num_burnin, + num_mcmc=num_mcmc, + ) + + # Assertions + assert bart_model_2.y_hat_train.shape == (n_train, num_mcmc) + assert bart_model_2.y_hat_test.shape == (n_test, num_mcmc) + + # Combine into a single model + bart_models_json = [bart_model.to_json(), bart_model_2.to_json()] + bart_model_3 = BARTModel() + bart_model_3.from_json_string_list(bart_models_json) + + # Assertions + y_hat_train_combined, _ = bart_model_3.predict(covariates=X_train) + assert y_hat_train_combined.shape == (n_train, num_mcmc * 2) + np.testing.assert_allclose( + y_hat_train_combined[:, 0:num_mcmc], bart_model.y_hat_train + ) + np.testing.assert_allclose( + y_hat_train_combined[:, num_mcmc : (2 * num_mcmc)], bart_model_2.y_hat_train + ) + np.testing.assert_allclose( + bart_model_3.global_var_samples[0:num_mcmc], bart_model.global_var_samples + ) + np.testing.assert_allclose( + bart_model_3.global_var_samples[num_mcmc : (2 * num_mcmc)], + bart_model_2.global_var_samples, + ) + def test_bart_univariate_leaf_regression_heteroskedastic(self): # RNG random_seed = 101 @@ -319,12 +477,12 @@ def conditional_stddev(X): n_train = X_train.shape[0] n_test = X_test.shape[0] - # BCF settings + # BART settings num_gfr = 10 num_burnin = 0 num_mcmc = 10 - # Run BCF with test set and propensity score + # Run BART with test set and propensity score bart_model = BARTModel() general_params = {"sample_sigma2_global": True} variance_forest_params = {"num_trees": 50} @@ -345,6 +503,49 @@ def conditional_stddev(X): assert 
bart_model.y_hat_train.shape == (n_train, num_mcmc) assert bart_model.y_hat_test.shape == (n_test, num_mcmc) + # Run second BART model + bart_model_2 = BARTModel() + bart_model_2.sample( + X_train=X_train, + y_train=y_train, + leaf_basis_train=basis_train, + X_test=X_test, + leaf_basis_test=basis_test, + num_gfr=num_gfr, + num_burnin=num_burnin, + num_mcmc=num_mcmc, + general_params=general_params, + variance_forest_params=variance_forest_params, + ) + + # Assertions + assert bart_model_2.y_hat_train.shape == (n_train, num_mcmc) + assert bart_model_2.y_hat_test.shape == (n_test, num_mcmc) + + # Combine into a single model + bart_models_json = [bart_model.to_json(), bart_model_2.to_json()] + bart_model_3 = BARTModel() + bart_model_3.from_json_string_list(bart_models_json) + + # Assertions + y_hat_train_combined, _ = bart_model_3.predict( + covariates=X_train, basis=basis_train + ) + assert y_hat_train_combined.shape == (n_train, num_mcmc * 2) + np.testing.assert_allclose( + y_hat_train_combined[:, 0:num_mcmc], bart_model.y_hat_train + ) + np.testing.assert_allclose( + y_hat_train_combined[:, num_mcmc : (2 * num_mcmc)], bart_model_2.y_hat_train + ) + np.testing.assert_allclose( + bart_model_3.global_var_samples[0:num_mcmc], bart_model.global_var_samples + ) + np.testing.assert_allclose( + bart_model_3.global_var_samples[num_mcmc : (2 * num_mcmc)], + bart_model_2.global_var_samples, + ) + def test_bart_multivariate_leaf_regression_heteroskedastic(self): # RNG random_seed = 101 @@ -400,12 +601,321 @@ def conditional_stddev(X): n_train = X_train.shape[0] n_test = X_test.shape[0] - # BCF settings + # BART settings + num_gfr = 10 + num_burnin = 0 + num_mcmc = 10 + + # Run BART with test set and propensity score + bart_model = BARTModel() + general_params = {"sample_sigma2_global": True} + variance_forest_params = {"num_trees": 50} + bart_model.sample( + X_train=X_train, + y_train=y_train, + leaf_basis_train=basis_train, + X_test=X_test, + leaf_basis_test=basis_test, + 
num_gfr=num_gfr, + num_burnin=num_burnin, + num_mcmc=num_mcmc, + general_params=general_params, + variance_forest_params=variance_forest_params, + ) + + # Assertions + assert bart_model.y_hat_train.shape == (n_train, num_mcmc) + assert bart_model.y_hat_test.shape == (n_test, num_mcmc) + + # Run second BART model + bart_model_2 = BARTModel() + bart_model_2.sample( + X_train=X_train, + y_train=y_train, + leaf_basis_train=basis_train, + X_test=X_test, + leaf_basis_test=basis_test, + num_gfr=num_gfr, + num_burnin=num_burnin, + num_mcmc=num_mcmc, + general_params=general_params, + variance_forest_params=variance_forest_params, + ) + + # Assertions + assert bart_model_2.y_hat_train.shape == (n_train, num_mcmc) + assert bart_model_2.y_hat_test.shape == (n_test, num_mcmc) + + # Combine into a single model + bart_models_json = [bart_model.to_json(), bart_model_2.to_json()] + bart_model_3 = BARTModel() + bart_model_3.from_json_string_list(bart_models_json) + + # Assertions + y_hat_train_combined, _ = bart_model_3.predict( + covariates=X_train, basis=basis_train + ) + assert y_hat_train_combined.shape == (n_train, num_mcmc * 2) + np.testing.assert_allclose( + y_hat_train_combined[:, 0:num_mcmc], bart_model.y_hat_train + ) + np.testing.assert_allclose( + y_hat_train_combined[:, num_mcmc : (2 * num_mcmc)], bart_model_2.y_hat_train + ) + np.testing.assert_allclose( + bart_model_3.global_var_samples[0:num_mcmc], bart_model.global_var_samples + ) + np.testing.assert_allclose( + bart_model_3.global_var_samples[num_mcmc : (2 * num_mcmc)], + bart_model_2.global_var_samples, + ) + + def test_bart_constant_leaf_heteroskedastic_rfx(self): + # RNG + random_seed = 101 + rng = np.random.default_rng(random_seed) + + # Generate covariates and basis + n = 100 + p_X = 10 + X = rng.uniform(0, 1, (n, p_X)) + + # Generate RFX group labels and basis term + num_rfx_basis = 1 + num_rfx_groups = 4 + group_labels = rng.choice(num_rfx_groups, size=n) + rfx_basis = np.empty((n, num_rfx_basis)) + 
rfx_basis[:, 0] = 1.0 + if num_rfx_basis > 1: + rfx_basis[:, 1:] = rng.uniform(-1, 1, (n, num_rfx_basis - 1)) + + # Define the outcome mean function + def outcome_mean(X): + return np.where( + (X[:, 0] >= 0.0) & (X[:, 0] < 0.25), + -7.5, + np.where( + (X[:, 0] >= 0.25) & (X[:, 0] < 0.5), + -2.5, + np.where((X[:, 0] >= 0.5) & (X[:, 0] < 0.75), 2.5, 7.5), + ), + ) + + # Define the conditional standard deviation function + def conditional_stddev(X): + return np.where( + (X[:, 0] >= 0.0) & (X[:, 0] < 0.25), + 0.25, + np.where( + (X[:, 0] >= 0.25) & (X[:, 0] < 0.5), + 0.5, + np.where((X[:, 0] >= 0.5) & (X[:, 0] < 0.75), 1, 2), + ), + ) + + # Define the group rfx function + def rfx_term(group_labels, basis): + return np.where( + group_labels == 0, + 0, + np.where(group_labels == 1, 4, np.where(group_labels == 2, 8, 12)), + ) + + # Generate outcome + epsilon = rng.normal(0, 1, n) + y = ( + outcome_mean(X) + + rfx_term(group_labels, rfx_basis) + + epsilon * conditional_stddev(X) + ) + + # Test-train split + sample_inds = np.arange(n) + train_inds, test_inds = train_test_split(sample_inds, test_size=0.5) + X_train = X[train_inds, :] + X_test = X[test_inds, :] + group_labels_train = group_labels[train_inds] + group_labels_test = group_labels[test_inds] + rfx_basis_train = rfx_basis[train_inds, :] + rfx_basis_test = rfx_basis[test_inds, :] + y_train = y[train_inds] + n_train = X_train.shape[0] + n_test = X_test.shape[0] + + # BART settings + num_gfr = 10 + num_burnin = 0 + num_mcmc = 10 + + # Run BART with test set and propensity score + bart_model = BARTModel() + general_params = {"sample_sigma2_global": True} + variance_forest_params = {"num_trees": 50} + bart_model.sample( + X_train=X_train, + y_train=y_train, + X_test=X_test, + rfx_group_ids_train=group_labels_train, + rfx_basis_train=rfx_basis_train, + rfx_group_ids_test=group_labels_test, + rfx_basis_test=rfx_basis_test, + general_params=general_params, + variance_forest_params=variance_forest_params, + num_gfr=num_gfr, 
+ num_burnin=num_burnin, + num_mcmc=num_mcmc, + ) + rfx_preds_train = bart_model.rfx_container.predict( + group_labels_train, rfx_basis_train + ) + + # Assertions + assert bart_model.y_hat_train.shape == (n_train, num_mcmc) + assert bart_model.y_hat_test.shape == (n_test, num_mcmc) + + # Run second BART model + bart_model_2 = BARTModel() + bart_model_2.sample( + X_train=X_train, + y_train=y_train, + X_test=X_test, + rfx_group_ids_train=group_labels_train, + rfx_basis_train=rfx_basis_train, + rfx_group_ids_test=group_labels_test, + rfx_basis_test=rfx_basis_test, + general_params=general_params, + variance_forest_params=variance_forest_params, + num_gfr=num_gfr, + num_burnin=num_burnin, + num_mcmc=num_mcmc, + ) + rfx_preds_train_2 = bart_model_2.rfx_container.predict( + group_labels_train, rfx_basis_train + ) + + # Assertions + assert bart_model_2.y_hat_train.shape == (n_train, num_mcmc) + assert bart_model_2.y_hat_test.shape == (n_test, num_mcmc) + + # Combine into a single model + bart_models_json = [bart_model.to_json(), bart_model_2.to_json()] + bart_model_3 = BARTModel() + bart_model_3.from_json_string_list(bart_models_json) + rfx_preds_train_3 = bart_model_3.rfx_container.predict( + group_labels_train, rfx_basis_train + ) + + # Assertions + y_hat_train_combined, _ = bart_model_3.predict( + covariates=X_train, + rfx_group_ids=group_labels_train, + rfx_basis=rfx_basis_train, + ) + assert y_hat_train_combined.shape == (n_train, num_mcmc * 2) + np.testing.assert_allclose( + y_hat_train_combined[:, 0:num_mcmc], bart_model.y_hat_train + ) + np.testing.assert_allclose( + y_hat_train_combined[:, num_mcmc : (2 * num_mcmc)], bart_model_2.y_hat_train + ) + np.testing.assert_allclose( + bart_model_3.global_var_samples[0:num_mcmc], bart_model.global_var_samples + ) + np.testing.assert_allclose( + bart_model_3.global_var_samples[num_mcmc : (2 * num_mcmc)], + bart_model_2.global_var_samples, + ) + np.testing.assert_allclose(rfx_preds_train_3[:, 0:num_mcmc], rfx_preds_train) + 
np.testing.assert_allclose( + rfx_preds_train_3[:, num_mcmc : (2 * num_mcmc)], rfx_preds_train_2 + ) + + def test_bart_univariate_leaf_regression_heteroskedastic_rfx(self): + # RNG + random_seed = 101 + rng = np.random.default_rng(random_seed) + + # Generate covariates and basis + n = 100 + p_X = 10 + p_W = 1 + X = rng.uniform(0, 1, (n, p_X)) + W = rng.uniform(0, 1, (n, p_W)) + + # Generate RFX group labels and basis term + num_rfx_basis = 1 + num_rfx_groups = 4 + group_labels = rng.choice(num_rfx_groups, size=n) + rfx_basis = np.empty((n, num_rfx_basis)) + rfx_basis[:, 0] = 1.0 + if num_rfx_basis > 1: + rfx_basis[:, 1:] = rng.uniform(-1, 1, (n, num_rfx_basis - 1)) + + # Define the outcome mean function + def outcome_mean(X, W): + return np.where( + (X[:, 0] >= 0.0) & (X[:, 0] < 0.25), + -7.5 * W[:, 0], + np.where( + (X[:, 0] >= 0.25) & (X[:, 0] < 0.5), + -2.5 * W[:, 0], + np.where( + (X[:, 0] >= 0.5) & (X[:, 0] < 0.75), + 2.5 * W[:, 0], + 7.5 * W[:, 0], + ), + ), + ) + + # Define the group rfx function + def rfx_term(group_labels, basis): + return np.where( + group_labels == 0, + 0, + np.where(group_labels == 1, 4, np.where(group_labels == 2, 8, 12)), + ) + + # Define the conditional standard deviation function + def conditional_stddev(X): + return np.where( + (X[:, 0] >= 0.0) & (X[:, 0] < 0.25), + 0.25, + np.where( + (X[:, 0] >= 0.25) & (X[:, 0] < 0.5), + 0.5, + np.where((X[:, 0] >= 0.5) & (X[:, 0] < 0.75), 1, 2), + ), + ) + + # Generate outcome + epsilon = rng.normal(0, 1, n) + y = ( + outcome_mean(X, W) + + rfx_term(group_labels, rfx_basis) + + epsilon * conditional_stddev(X) + ) + + # Test-train split + sample_inds = np.arange(n) + train_inds, test_inds = train_test_split(sample_inds, test_size=0.5) + X_train = X[train_inds, :] + X_test = X[test_inds, :] + basis_train = W[train_inds, :] + basis_test = W[test_inds, :] + group_labels_train = group_labels[train_inds] + group_labels_test = group_labels[test_inds] + rfx_basis_train = rfx_basis[train_inds, :] + 
rfx_basis_test = rfx_basis[test_inds, :] + y_train = y[train_inds] + n_train = X_train.shape[0] + n_test = X_test.shape[0] + + # BART settings num_gfr = 10 num_burnin = 0 num_mcmc = 10 - # Run BCF with test set and propensity score + # Run BART with test set and propensity score bart_model = BARTModel() general_params = {"sample_sigma2_global": True} variance_forest_params = {"num_trees": 50} @@ -413,15 +923,82 @@ def conditional_stddev(X): X_train=X_train, y_train=y_train, leaf_basis_train=basis_train, + rfx_group_ids_train=group_labels_train, + rfx_basis_train=rfx_basis_train, X_test=X_test, leaf_basis_test=basis_test, + rfx_group_ids_test=group_labels_test, + rfx_basis_test=rfx_basis_test, num_gfr=num_gfr, num_burnin=num_burnin, num_mcmc=num_mcmc, general_params=general_params, variance_forest_params=variance_forest_params, ) + rfx_preds_train = bart_model.rfx_container.predict( + group_labels_train, rfx_basis_train + ) # Assertions assert bart_model.y_hat_train.shape == (n_train, num_mcmc) assert bart_model.y_hat_test.shape == (n_test, num_mcmc) + + # Run second BART model + bart_model_2 = BARTModel() + bart_model_2.sample( + X_train=X_train, + y_train=y_train, + leaf_basis_train=basis_train, + rfx_group_ids_train=group_labels_train, + rfx_basis_train=rfx_basis_train, + X_test=X_test, + leaf_basis_test=basis_test, + rfx_group_ids_test=group_labels_test, + rfx_basis_test=rfx_basis_test, + num_gfr=num_gfr, + num_burnin=num_burnin, + num_mcmc=num_mcmc, + general_params=general_params, + variance_forest_params=variance_forest_params, + ) + rfx_preds_train_2 = bart_model_2.rfx_container.predict( + group_labels_train, rfx_basis_train + ) + + # Assertions + assert bart_model_2.y_hat_train.shape == (n_train, num_mcmc) + assert bart_model_2.y_hat_test.shape == (n_test, num_mcmc) + + # Combine into a single model + bart_models_json = [bart_model.to_json(), bart_model_2.to_json()] + bart_model_3 = BARTModel() + bart_model_3.from_json_string_list(bart_models_json) + 
rfx_preds_train_3 = bart_model_3.rfx_container.predict( + group_labels_train, rfx_basis_train + ) + + # Assertions + y_hat_train_combined, _ = bart_model_3.predict( + covariates=X_train, + basis=basis_train, + rfx_group_ids=group_labels_train, + rfx_basis=rfx_basis_train, + ) + assert y_hat_train_combined.shape == (n_train, num_mcmc * 2) + np.testing.assert_allclose( + y_hat_train_combined[:, 0:num_mcmc], bart_model.y_hat_train + ) + np.testing.assert_allclose( + y_hat_train_combined[:, num_mcmc : (2 * num_mcmc)], bart_model_2.y_hat_train + ) + np.testing.assert_allclose( + bart_model_3.global_var_samples[0:num_mcmc], bart_model.global_var_samples + ) + np.testing.assert_allclose( + bart_model_3.global_var_samples[num_mcmc : (2 * num_mcmc)], + bart_model_2.global_var_samples, + ) + np.testing.assert_allclose(rfx_preds_train_3[:, 0:num_mcmc], rfx_preds_train) + np.testing.assert_allclose( + rfx_preds_train_3[:, num_mcmc : (2 * num_mcmc)], rfx_preds_train_2 + ) diff --git a/test/python/test_bcf.py b/test/python/test_bcf.py index dc6fe162..96f25a34 100644 --- a/test/python/test_bcf.py +++ b/test/python/test_bcf.py @@ -244,6 +244,65 @@ def test_continuous_univariate_bcf(self): tau_hat = bcf_model.predict_tau(X_test, Z_test, pi_test) assert tau_hat.shape == (n_test, num_mcmc) + # Run second BCF model with test set and propensity score + bcf_model_2 = BCFModel() + variance_forest_params = {"num_trees": 0} + bcf_model_2.sample( + X_train=X_train, + Z_train=Z_train, + y_train=y_train, + pi_train=pi_train, + X_test=X_test, + Z_test=Z_test, + pi_test=pi_test, + num_gfr=num_gfr, + num_burnin=num_burnin, + num_mcmc=num_mcmc, + variance_forest_params=variance_forest_params, + ) + + # Assertions + assert bcf_model_2.y_hat_train.shape == (n_train, num_mcmc) + assert bcf_model_2.mu_hat_train.shape == (n_train, num_mcmc) + assert bcf_model_2.tau_hat_train.shape == (n_train, num_mcmc) + assert bcf_model_2.y_hat_test.shape == (n_test, num_mcmc) + assert bcf_model_2.mu_hat_test.shape 
== (n_test, num_mcmc) + assert bcf_model_2.tau_hat_test.shape == (n_test, num_mcmc) + + # Check overall prediction method + tau_hat_2, mu_hat_2, y_hat_2 = bcf_model_2.predict(X_test, Z_test, pi_test) + assert tau_hat_2.shape == (n_test, num_mcmc) + assert mu_hat_2.shape == (n_test, num_mcmc) + assert y_hat_2.shape == (n_test, num_mcmc) + + # Check treatment effect prediction method + tau_hat_2 = bcf_model_2.predict_tau(X_test, Z_test, pi_test) + assert tau_hat_2.shape == (n_test, num_mcmc) + + # Combine into a single model + bcf_models_json = [bcf_model.to_json(), bcf_model_2.to_json()] + bcf_model_3 = BCFModel() + bcf_model_3.from_json_string_list(bcf_models_json) + + # Assertions + tau_hat_3, mu_hat_3, y_hat_3 = bcf_model_3.predict(X_test, Z_test, pi_test) + assert tau_hat_3.shape == (n_test, num_mcmc * 2) + assert mu_hat_3.shape == (n_test, num_mcmc * 2) + assert y_hat_3.shape == (n_test, num_mcmc * 2) + np.testing.assert_allclose(y_hat_3[:, 0:num_mcmc], y_hat) + np.testing.assert_allclose(y_hat_3[:, num_mcmc : (2 * num_mcmc)], y_hat_2) + np.testing.assert_allclose(mu_hat_3[:, 0:num_mcmc], mu_hat) + np.testing.assert_allclose(mu_hat_3[:, num_mcmc : (2 * num_mcmc)], mu_hat_2) + np.testing.assert_allclose(tau_hat_3[:, 0:num_mcmc], tau_hat) + np.testing.assert_allclose(tau_hat_3[:, num_mcmc : (2 * num_mcmc)], tau_hat_2) + np.testing.assert_allclose( + bcf_model_3.global_var_samples[0:num_mcmc], bcf_model.global_var_samples + ) + np.testing.assert_allclose( + bcf_model_3.global_var_samples[num_mcmc : (2 * num_mcmc)], + bcf_model_2.global_var_samples, + ) + # Run BCF without test set and with propensity score bcf_model = BCFModel() variance_forest_params = {"num_trees": 0} @@ -336,6 +395,55 @@ def test_continuous_univariate_bcf(self): # Check treatment effect prediction method tau_hat = bcf_model.predict_tau(X_test, Z_test) + # Run second BCF model without test set or propensity score + bcf_model_2 = BCFModel() + variance_forest_params = {"num_trees": 0} +
bcf_model_2.sample( + X_train=X_train, + Z_train=Z_train, + y_train=y_train, + num_gfr=num_gfr, + num_burnin=num_burnin, + num_mcmc=num_mcmc, + variance_forest_params=variance_forest_params, + ) + + # Assertions + assert bcf_model_2.y_hat_train.shape == (n_train, num_mcmc) + assert bcf_model_2.mu_hat_train.shape == (n_train, num_mcmc) + assert bcf_model_2.tau_hat_train.shape == (n_train, num_mcmc) + + # Check overall prediction method + tau_hat_2, mu_hat_2, y_hat_2 = bcf_model_2.predict(X_test, Z_test) + assert tau_hat_2.shape == (n_test, num_mcmc) + assert mu_hat_2.shape == (n_test, num_mcmc) + assert y_hat_2.shape == (n_test, num_mcmc) + + # Check treatment effect prediction method + tau_hat_2 = bcf_model_2.predict_tau(X_test, Z_test) + assert tau_hat_2.shape == (n_test, num_mcmc) + + # Combine into a single model + bcf_models_json = [bcf_model.to_json(), bcf_model_2.to_json()] + bcf_model_3 = BCFModel() + bcf_model_3.from_json_string_list(bcf_models_json) + + # Assertions + tau_hat_3, mu_hat_3, y_hat_3 = bcf_model_3.predict(X_test, Z_test) + assert tau_hat_3.shape == (n_test, num_mcmc * 2) + assert mu_hat_3.shape == (n_test, num_mcmc * 2) + assert y_hat_3.shape == (n_test, num_mcmc * 2) + np.testing.assert_allclose(y_hat_3[:, 0:num_mcmc], y_hat) + np.testing.assert_allclose(mu_hat_3[:, 0:num_mcmc], mu_hat) + np.testing.assert_allclose(tau_hat_3[:, 0:num_mcmc], tau_hat) + np.testing.assert_allclose( + bcf_model_3.global_var_samples[0:num_mcmc], bcf_model.global_var_samples + ) + np.testing.assert_allclose( + bcf_model_3.global_var_samples[num_mcmc : (2 * num_mcmc)], + bcf_model_2.global_var_samples, + ) + def test_multivariate_bcf(self): # RNG random_seed = 101