Skip to content

Commit 9a4b46b

Browse files
committed
Updated python initialization and prior setting for global error variance
1 parent 376eb84 commit 9a4b46b

File tree

4 files changed

+104
-107
lines changed

4 files changed

+104
-107
lines changed

stochtree/bart.py

Lines changed: 15 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -27,9 +27,9 @@ def is_sampled(self) -> bool:
2727

2828
def sample(self, X_train: np.array, y_train: np.array, basis_train: np.array = None, X_test: np.array = None, basis_test: np.array = None,
2929
cutpoint_grid_size = 100, sigma_leaf: float = None, alpha: float = 0.95, beta: float = 2.0, min_samples_leaf: int = 5, max_depth: int = 10,
30-
nu: float = 3, lamb: float = None, a_leaf: float = 3, b_leaf: float = None, q: float = 0.9, sigma2: float = None,
31-
num_trees: int = 200, num_gfr: int = 5, num_burnin: int = 0, num_mcmc: int = 100, sample_sigma_global: bool = True,
32-
sample_sigma_leaf: bool = True, random_seed: int = -1, keep_burnin: bool = False, keep_gfr: bool = False) -> None:
30+
a_global: float = 0, b_global: float = 0, a_leaf: float = 3, b_leaf: float = None, q: float = 0.9, sigma2: float = None,
31+
pct_var_sigma2_init: float = 0.25, num_trees: int = 200, num_gfr: int = 5, num_burnin: int = 0, num_mcmc: int = 100,
32+
sample_sigma_global: bool = True, sample_sigma_leaf: bool = True, random_seed: int = -1, keep_burnin: bool = False, keep_gfr: bool = False) -> None:
3333
"""Runs a BART sampler on provided training set. Predictions will be cached for the training set and (if provided) the test set.
3434
Does not require a leaf regression basis.
3535
@@ -60,18 +60,20 @@ def sample(self, X_train: np.array, y_train: np.array, basis_train: np.array = N
6060
Minimum allowable size of a leaf, in terms of training samples. Defaults to ``5``.
6161
max_depth : :obj:`int`, optional
6262
Maximum depth of any tree in the ensemble. Defaults to ``10``. Can be overridden with ``-1`` which does not enforce any depth limits on trees.
63-
nu : :obj:`float`, optional
64-
Shape parameter in the ``IG(nu, nu*lamb)`` global error variance model. Defaults to ``3``.
65-
lamb : :obj:`float`, optional
66-
Component of the scale parameter in the ``IG(nu, nu*lambda)`` global error variance prior. If not specified, this is calibrated as in Sparapani et al (2021).
63+
a_global : :obj:`float`, optional
64+
Shape parameter in the ``IG(a_global, b_global)`` global error variance model. Defaults to ``0``.
65+
b_global : :obj:`float`, optional
66+
Component of the scale parameter in the ``IG(a_global, b_global)`` global error variance prior. Defaults to ``0``.
6767
a_leaf : :obj:`float`, optional
6868
Shape parameter in the ``IG(a_leaf, b_leaf)`` leaf node parameter variance model. Defaults to ``3``.
6969
b_leaf : :obj:`float`, optional
7070
Scale parameter in the ``IG(a_leaf, b_leaf)`` leaf node parameter variance model. Calibrated internally as ``0.5/num_trees`` if not set here.
7171
q : :obj:`float`, optional
7272
Quantile used to calibrate ``lamb`` as in Sparapani et al (2021). Defaults to ``0.9``.
7373
sigma2 : :obj:`float`, optional
74-
Starting value of global variance parameter. Calibrated internally as in Sparapani et al (2021) if not set here.
74+
Starting value of global variance parameter. Set internally as a percentage of the standardized outcome variance if not set here.
75+
pct_var_sigma2_init : :obj:`float`, optional
76+
Percentage of standardized outcome variance used to initialize global error variance parameter. Superseded by ``sigma2``. Defaults to ``0.25``.
7577
num_trees : :obj:`int`, optional
7678
Number of trees in the ensemble. Defaults to ``200``.
7779
num_gfr : :obj:`int`, optional
@@ -81,7 +83,7 @@ def sample(self, X_train: np.array, y_train: np.array, basis_train: np.array = N
8183
num_mcmc : :obj:`int`, optional
8284
Number of "retained" iterations of the MCMC sampler. Defaults to ``100``. If this is set to 0, GFR (XBART) samples will be retained.
8385
sample_sigma_global : :obj:`bool`, optional
84-
Whether or not to update the ``sigma^2`` global error variance parameter based on ``IG(nu, nu*lambda)``. Defaults to ``True``.
86+
Whether or not to update the ``sigma^2`` global error variance parameter based on ``IG(a_global, b_global)``. Defaults to ``True``.
8587
sample_sigma_leaf : :obj:`bool`, optional
8688
Whether or not to update the ``tau`` leaf scale variance parameter based on ``IG(a_leaf, b_leaf)``. Cannot (currently) be set to true if ``basis_train`` has more than one column. Defaults to ``True``.
8789
random_seed : :obj:`int`, optional
@@ -176,10 +178,8 @@ def sample(self, X_train: np.array, y_train: np.array, basis_train: np.array = N
176178
resid_train = (y_train-self.y_bar)/self.y_std
177179

178180
# Calibrate priors for global sigma^2 and sigma_leaf (don't use regression initializer for warm-start or XBART)
179-
if num_gfr > 0:
180-
sigma2, lamb = calibrate_global_error_variance(X_train_processed, np.squeeze(resid_train), sigma2, nu, lamb, q, False)
181-
else:
182-
sigma2, lamb = calibrate_global_error_variance(X_train_processed, np.squeeze(resid_train), sigma2, nu, lamb, q, True)
181+
if not sigma2:
182+
sigma2 = pct_var_sigma2_init*np.var(resid_train)
183183
b_leaf = np.squeeze(np.var(resid_train)) / num_trees if b_leaf is None else b_leaf
184184
sigma_leaf = np.squeeze(np.var(resid_train)) / num_trees if sigma_leaf is None else sigma_leaf
185185
current_sigma2 = sigma2
@@ -254,7 +254,7 @@ def sample(self, X_train: np.array, y_train: np.array, basis_train: np.array = N
254254

255255
# Sample variance parameters (if requested)
256256
if self.sample_sigma_global:
257-
current_sigma2 = global_var_model.sample_one_iteration(residual_train, cpp_rng, nu, lamb)
257+
current_sigma2 = global_var_model.sample_one_iteration(residual_train, cpp_rng, a_global, b_global)
258258
self.global_var_samples[i] = current_sigma2*self.y_std*self.y_std
259259
if self.sample_sigma_leaf:
260260
self.leaf_scale_samples[i] = leaf_var_model.sample_one_iteration(self.forest_container, cpp_rng, a_leaf, b_leaf, i)
@@ -275,7 +275,7 @@ def sample(self, X_train: np.array, y_train: np.array, basis_train: np.array = N
275275

276276
# Sample variance parameters (if requested)
277277
if self.sample_sigma_global:
278-
current_sigma2 = global_var_model.sample_one_iteration(residual_train, cpp_rng, nu, lamb)
278+
current_sigma2 = global_var_model.sample_one_iteration(residual_train, cpp_rng, a_global, b_global)
279279
self.global_var_samples[i] = current_sigma2*self.y_std*self.y_std
280280
if self.sample_sigma_leaf:
281281
self.leaf_scale_samples[i] = leaf_var_model.sample_one_iteration(self.forest_container, cpp_rng, a_leaf, b_leaf, i)

stochtree/bcf.py

Lines changed: 26 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -34,8 +34,9 @@ def sample(self, X_train: Union[pd.DataFrame, np.array], Z_train: np.array, y_tr
3434
cutpoint_grid_size = 100, sigma_leaf_mu: float = None, sigma_leaf_tau: float = None,
3535
alpha_mu: float = 0.95, alpha_tau: float = 0.25, beta_mu: float = 2.0, beta_tau: float = 3.0,
3636
min_samples_leaf_mu: int = 5, min_samples_leaf_tau: int = 5, max_depth_mu: int = 10, max_depth_tau: int = 5,
37-
nu: float = 3, lamb: float = None, a_leaf_mu: float = 3, a_leaf_tau: float = 3, b_leaf_mu: float = None, b_leaf_tau: float = None,
38-
q: float = 0.9, sigma2: float = None, variable_weights: np.array = None,
37+
a_global: float = 0, b_global: float = 0, a_leaf_mu: float = 3, a_leaf_tau: float = 3,
38+
b_leaf_mu: float = None, b_leaf_tau: float = None, q: float = 0.9, sigma2: float = None,
39+
pct_var_sigma2_init: float = 0.25, variable_weights: np.array = None,
3940
keep_vars_mu: Union[list, np.array] = None, drop_vars_mu: Union[list, np.array] = None,
4041
keep_vars_tau: Union[list, np.array] = None, drop_vars_tau: Union[list, np.array] = None,
4142
num_trees_mu: int = 200, num_trees_tau: int = 50, num_gfr: int = 5, num_burnin: int = 0, num_mcmc: int = 100,
@@ -93,10 +94,10 @@ def sample(self, X_train: Union[pd.DataFrame, np.array], Z_train: np.array, y_tr
9394
Maximum depth of any tree in the mu ensemble. Defaults to ``10``. Can be overridden with ``-1`` which does not enforce any depth limits on trees.
9495
max_depth_tau : :obj:`int`, optional
9596
Maximum depth of any tree in the tau ensemble. Defaults to ``5``. Can be overridden with ``-1`` which does not enforce any depth limits on trees.
96-
nu : :obj:`float`, optional
97-
Shape parameter in the ``IG(nu, nu*lamb)`` global error variance model. Defaults to ``3``.
98-
lamb : :obj:`float`, optional
99-
Component of the scale parameter in the ``IG(nu, nu*lambda)`` global error variance prior. If not specified, this is calibrated as in Sparapani et al (2021).
97+
a_global : :obj:`float`, optional
98+
Shape parameter in the ``IG(a_global, b_global)`` global error variance model. Defaults to ``0``.
99+
b_global : :obj:`float`, optional
100+
Component of the scale parameter in the ``IG(a_global, b_global)`` global error variance prior. Defaults to ``0``.
100101
a_leaf_mu : :obj:`float`, optional
101102
Shape parameter in the ``IG(a_leaf, b_leaf)`` leaf node parameter variance model for the prognostic forest. Defaults to ``3``.
102103
a_leaf_tau : :obj:`float`, optional
@@ -109,6 +110,8 @@ def sample(self, X_train: Union[pd.DataFrame, np.array], Z_train: np.array, y_tr
109110
Quantile used to calibrate ``lamb`` as in Sparapani et al (2021). Defaults to ``0.9``.
110111
sigma2 : :obj:`float`, optional
111112
Starting value of global variance parameter. Set internally as a percentage of the standardized outcome variance if not set here.
113+
pct_var_sigma2_init : :obj:`float`, optional
114+
Percentage of standardized outcome variance used to initialize global error variance parameter. Superseded by ``sigma2``. Defaults to ``0.25``.
112115
variable_weights : :obj:`np.array`, optional
113116
Numeric weights reflecting the relative probability of splitting on each variable. Does not need to sum to 1 but cannot be negative. Defaults to ``np.repeat(1/X_train.shape[1], X_train.shape[1])`` if not set here. Note that if the propensity score is included as a covariate in either forest, its weight will default to ``1/X_train.shape[1]``. A workaround if you wish to provide a custom weight for the propensity score is to include it as a column in ``X_train`` and then set ``propensity_covariate`` to ``'none'`` and adjust ``keep_vars_mu`` and ``keep_vars_tau`` accordingly.
114117
keep_vars_mu : obj:`list` or :obj:`np.array`, optional
@@ -130,7 +133,7 @@ def sample(self, X_train: Union[pd.DataFrame, np.array], Z_train: np.array, y_tr
130133
num_mcmc : :obj:`int`, optional
131134
Number of "retained" iterations of the MCMC sampler. Defaults to ``100``. If this is set to 0, GFR (XBART) samples will be retained.
132135
sample_sigma_global : :obj:`bool`, optional
133-
Whether or not to update the ``sigma^2`` global error variance parameter based on ``IG(nu, nu*lambda)``. Defaults to ``True``.
136+
Whether or not to update the ``sigma^2`` global error variance parameter based on ``IG(a_global, b_global)``. Defaults to ``True``.
134137
sample_sigma_leaf_mu : :obj:`bool`, optional
135138
Whether or not to update the ``tau`` leaf scale variance parameter based on ``IG(a_leaf, b_leaf)`` for the prognostic forest.
136139
Cannot (currently) be set to true if ``basis_train`` has more than one column. Defaults to ``True``.
@@ -294,24 +297,24 @@ def sample(self, X_train: Union[pd.DataFrame, np.array], Z_train: np.array, y_tr
294297
if beta_tau is not None:
295298
beta_tau = check_scalar(x=beta_tau, name="beta_tau", target_type=(float,int),
296299
min_val=1, max_val=None, include_boundaries="left")
297-
if nu is not None:
298-
nu = check_scalar(x=nu, name="nu", target_type=(float,int),
299-
min_val=0, max_val=None, include_boundaries="neither")
300-
if lamb is not None:
301-
lamb = check_scalar(x=lamb, name="lamb", target_type=(float,int),
302-
min_val=0, max_val=None, include_boundaries="neither")
300+
if a_global is not None:
301+
a_global = check_scalar(x=a_global, name="a_global", target_type=(float,int),
302+
min_val=0, max_val=None, include_boundaries="left")
303+
if b_global is not None:
304+
b_global = check_scalar(x=b_global, name="b_global", target_type=(float,int),
305+
min_val=0, max_val=None, include_boundaries="left")
303306
if a_leaf_mu is not None:
304307
a_leaf_mu = check_scalar(x=a_leaf_mu, name="a_leaf_mu", target_type=(float,int),
305-
min_val=0, max_val=None, include_boundaries="neither")
308+
min_val=0, max_val=None, include_boundaries="left")
306309
if a_leaf_tau is not None:
307310
a_leaf_tau = check_scalar(x=a_leaf_tau, name="a_leaf_tau", target_type=(float,int),
308-
min_val=0, max_val=None, include_boundaries="neither")
311+
min_val=0, max_val=None, include_boundaries="left")
309312
if b_leaf_mu is not None:
310313
b_leaf_mu = check_scalar(x=b_leaf_mu, name="b_leaf_mu", target_type=(float,int),
311-
min_val=0, max_val=None, include_boundaries="neither")
314+
min_val=0, max_val=None, include_boundaries="left")
312315
if b_leaf_tau is not None:
313316
b_leaf_tau = check_scalar(x=b_leaf_tau, name="b_leaf_tau", target_type=(float,int),
314-
min_val=0, max_val=None, include_boundaries="neither")
317+
min_val=0, max_val=None, include_boundaries="left")
315318
if q is not None:
316319
q = check_scalar(x=q, name="q", target_type=float,
317320
min_val=0, max_val=1, include_boundaries="neither")
@@ -512,10 +515,8 @@ def sample(self, X_train: Union[pd.DataFrame, np.array], Z_train: np.array, y_tr
512515
resid_train = (y_train-self.y_bar)/self.y_std
513516

514517
# Calibrate priors for global sigma^2 and sigma_leaf_mu / sigma_leaf_tau (don't use regression initializer for warm-start or XBART)
515-
if num_gfr > 0:
516-
sigma2, lamb = calibrate_global_error_variance(X_train_processed, np.squeeze(resid_train), sigma2, nu, lamb, q, False)
517-
else:
518-
sigma2, lamb = calibrate_global_error_variance(X_train_processed, np.squeeze(resid_train), sigma2, nu, lamb, q, True)
518+
if not sigma2:
519+
sigma2 = pct_var_sigma2_init*np.var(resid_train)
519520
b_leaf_mu = np.squeeze(np.var(resid_train)) / num_trees_mu if b_leaf_mu is None else b_leaf_mu
520521
b_leaf_tau = np.squeeze(np.var(resid_train)) / (2*num_trees_tau) if b_leaf_tau is None else b_leaf_tau
521522
sigma_leaf_mu = np.squeeze(np.var(resid_train)) / num_trees_mu if sigma_leaf_mu is None else sigma_leaf_mu
@@ -657,7 +658,7 @@ def sample(self, X_train: Union[pd.DataFrame, np.array], Z_train: np.array, y_tr
657658

658659
# Sample variance parameters (if requested)
659660
if self.sample_sigma_global:
660-
current_sigma2 = global_var_model.sample_one_iteration(residual_train, cpp_rng, nu, lamb)
661+
current_sigma2 = global_var_model.sample_one_iteration(residual_train, cpp_rng, a_global, b_global)
661662
self.global_var_samples[i] = current_sigma2*self.y_std*self.y_std
662663
if self.sample_sigma_leaf_mu:
663664
self.leaf_scale_mu_samples[i] = leaf_var_model_mu.sample_one_iteration(self.forest_container_mu, cpp_rng, a_leaf_mu, b_leaf_mu, i)
@@ -671,7 +672,7 @@ def sample(self, X_train: Union[pd.DataFrame, np.array], Z_train: np.array, y_tr
671672

672673
# Sample variance parameters (if requested)
673674
if self.sample_sigma_global:
674-
current_sigma2 = global_var_model.sample_one_iteration(residual_train, cpp_rng, nu, lamb)
675+
current_sigma2 = global_var_model.sample_one_iteration(residual_train, cpp_rng, a_global, b_global)
675676
self.global_var_samples[i] = current_sigma2*self.y_std*self.y_std
676677
if self.sample_sigma_leaf_tau:
677678
self.leaf_scale_tau_samples[i] = leaf_var_model_tau.sample_one_iteration(self.forest_container_tau, cpp_rng, a_leaf_tau, b_leaf_tau, i)
@@ -716,7 +717,7 @@ def sample(self, X_train: Union[pd.DataFrame, np.array], Z_train: np.array, y_tr
716717

717718
# Sample variance parameters (if requested)
718719
if self.sample_sigma_global:
719-
current_sigma2 = global_var_model.sample_one_iteration(residual_train, cpp_rng, nu, lamb)
720+
current_sigma2 = global_var_model.sample_one_iteration(residual_train, cpp_rng, a_global, b_global)
720721
self.global_var_samples[i] = current_sigma2*self.y_std*self.y_std
721722
if self.sample_sigma_leaf_mu:
722723
self.leaf_scale_mu_samples[i] = leaf_var_model_mu.sample_one_iteration(self.forest_container_mu, cpp_rng, a_leaf_mu, b_leaf_mu, i)
@@ -730,7 +731,7 @@ def sample(self, X_train: Union[pd.DataFrame, np.array], Z_train: np.array, y_tr
730731

731732
# Sample variance parameters (if requested)
732733
if self.sample_sigma_global:
733-
current_sigma2 = global_var_model.sample_one_iteration(residual_train, cpp_rng, nu, lamb)
734+
current_sigma2 = global_var_model.sample_one_iteration(residual_train, cpp_rng, a_global, b_global)
734735
self.global_var_samples[i] = current_sigma2*self.y_std*self.y_std
735736
if self.sample_sigma_leaf_tau:
736737
self.leaf_scale_tau_samples[i] = leaf_var_model_tau.sample_one_iteration(self.forest_container_tau, cpp_rng, a_leaf_tau, b_leaf_tau, i)

0 commit comments

Comments
 (0)