Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
34 changes: 11 additions & 23 deletions stochtree/bart.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,7 @@ def sample(self, X_train: np.array, y_train: np.array, basis_train: np.array = N
* ``pct_var_variance_forest_init`` (``float``): Percentage of standardized outcome variance used to initialize global error variance parameter. Default: ``1``. Superseded by ``variance_forest_init``.
* ``variance_scale`` (``float``): Variance after the data have been scaled. Default: ``1``.
* ``variable_weights_mean`` (``np.array``): Numeric weights reflecting the relative probability of splitting on each variable in the mean forest. Does not need to sum to 1 but cannot be negative. Defaults to uniform over the columns of ``X_train`` if not provided.
* ``variable_weights_forest`` (``np.array``): Numeric weights reflecting the relative probability of splitting on each variable in the variance forest. Does not need to sum to 1 but cannot be negative. Defaults to uniform over the columns of ``X_train`` if not provided.
* ``variable_weights_variance`` (``np.array``): Numeric weights reflecting the relative probability of splitting on each variable in the variance forest. Does not need to sum to 1 but cannot be negative. Defaults to uniform over the columns of ``X_train`` if not provided.
* ``num_trees_mean`` (``int``): Number of trees in the ensemble for the conditional mean model. Defaults to ``200``. If ``num_trees_mean = 0``, the conditional mean will not be modeled using a forest and the function will only proceed if ``num_trees_variance > 0``.
* ``num_trees_variance`` (``int``): Number of trees in the ensemble for the conditional variance model. Defaults to ``0``. Variance is only modeled using a tree / forest if ``num_trees_variance > 0``.
* ``sample_sigma_global`` (``bool``): Whether or not to update the ``sigma^2`` global error variance parameter based on ``IG(a_global, b_global)``. Defaults to ``True``.
Expand Down Expand Up @@ -484,19 +484,19 @@ def sample(self, X_train: np.array, y_train: np.array, basis_train: np.array = N
if self.include_variance_forest:
sigma_x_train_raw = self.forest_container_variance.forest_container_cpp.Predict(forest_dataset_train.dataset_cpp)
if self.sample_sigma_global:
self.sigma_x_train = sigma_x_train_raw
self.sigma2_x_train = sigma_x_train_raw
for i in range(self.num_samples):
self.sigma_x_train[:,i] = np.sqrt(sigma_x_train_raw[:,i]*self.global_var_samples[i])
self.sigma2_x_train[:,i] = sigma_x_train_raw[:,i]*self.global_var_samples[i]
else:
self.sigma_x_train = np.sqrt(sigma_x_train_raw*self.sigma2_init)*self.y_std/np.sqrt(self.variance_scale)
self.sigma2_x_train = sigma_x_train_raw*self.sigma2_init*self.y_std*self.y_std/self.variance_scale
if self.has_test:
sigma_x_test_raw = self.forest_container_variance.forest_container_cpp.Predict(forest_dataset_test.dataset_cpp)
if self.sample_sigma_global:
self.sigma_x_test = sigma_x_test_raw
self.sigma2_x_test = sigma_x_test_raw
for i in range(self.num_samples):
self.sigma_x_test[:,i] = np.sqrt(sigma_x_test_raw[:,i]*self.global_var_samples[i])
self.sigma2_x_test[:,i] = sigma_x_test_raw[:,i]*self.global_var_samples[i]
else:
self.sigma_x_test = np.sqrt(sigma_x_test_raw*self.sigma2_init)*self.y_std/np.sqrt(self.variance_scale)
self.sigma2_x_test = sigma_x_test_raw*self.sigma2_init*self.y_std*self.y_std/self.variance_scale

def predict(self, covariates: np.array, basis: np.array = None) -> np.array:
"""Return predictions from every forest sampled (either / both of mean and variance)
Expand Down Expand Up @@ -605,20 +605,18 @@ def predict_mean(self, covariates: np.array, basis: np.array = None) -> np.array

return mean_pred

def predict_variance(self, covariates: np.array, basis: np.array = None) -> np.array:
def predict_variance(self, covariates: np.array) -> np.array:
"""Predict expected conditional variance from a BART model.

Parameters
----------
covariates : np.array
Test set covariates.
basis_train : :obj:`np.array`, optional
Optional test set basis vector, must be provided if the model was trained with a leaf regression basis.

Returns
-------
:obj:`np.array`
Tuple of arrays of predictions corresponding to each forest (mean and variance, depending on whether either / both was included). Each array will contain as many rows as in ``covariates`` and as many columns as retained samples of the algorithm.
Array of predictions from the variance forest, with as many rows as in ``covariates`` and as many columns as retained samples of the algorithm.
"""
if not self.is_sampled():
msg = (
Expand All @@ -637,26 +635,16 @@ def predict_variance(self, covariates: np.array, basis: np.array = None) -> np.a
# Convert everything to standard shape (2-dimensional)
if covariates.ndim == 1:
covariates = np.expand_dims(covariates, 1)
if basis is not None:
if basis.ndim == 1:
basis = np.expand_dims(basis, 1)

# Data checks
if basis is not None:
if basis.shape[0] != covariates.shape[0]:
raise ValueError("covariates and basis must have the same number of rows")

pred_dataset = Dataset()
pred_dataset.add_covariates(covariates)
# if basis is not None:
# pred_dataset.add_basis(basis)
variance_pred_raw = self.forest_container_variance.forest_container_cpp.Predict(pred_dataset.dataset_cpp)
if self.sample_sigma_global:
variance_pred = variance_pred_raw
for i in range(self.num_samples):
variance_pred[:,i] = np.sqrt(variance_pred_raw[:,i]*self.global_var_samples[i])
variance_pred[:,i] = variance_pred_raw[:,i]*self.global_var_samples[i]
else:
variance_pred = np.sqrt(variance_pred_raw*self.sigma2_init)*self.y_std/np.sqrt(self.variance_scale)
variance_pred = variance_pred_raw*self.sigma2_init*self.y_std*self.y_std/self.variance_scale

return variance_pred

Expand Down
Loading
Loading