Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
83 changes: 46 additions & 37 deletions include/stochtree/tree_sampler.h
Original file line number Diff line number Diff line change
Expand Up @@ -802,46 +802,55 @@ static inline void MCMCGrowTreeOneIter(Tree* tree, ForestTracker& tracker, LeafM
double no_split_log_marginal_likelihood = std::get<1>(split_eval);
int32_t left_n = std::get<2>(split_eval);
int32_t right_n = std::get<3>(split_eval);

// Determine probability of growing the split node and its two new left and right nodes
double pg = tree_prior.GetAlpha() * std::pow(1+leaf_depth, -tree_prior.GetBeta());
double pgl = tree_prior.GetAlpha() * std::pow(1+leaf_depth+1, -tree_prior.GetBeta());
double pgr = tree_prior.GetAlpha() * std::pow(1+leaf_depth+1, -tree_prior.GetBeta());

// Determine whether a "grow" move is possible from the newly formed tree
// in order to compute the probability of choosing "prune" from the new tree
// (which is always possible by construction)
bool non_constant = NodesNonConstantAfterSplit(dataset, tracker, split, tree_num, leaf_chosen, var_chosen);
bool min_samples_left_check = left_n >= 2*tree_prior.GetMinSamplesLeaf();
bool min_samples_right_check = right_n >= 2*tree_prior.GetMinSamplesLeaf();
double prob_prune_new;
if (non_constant && (min_samples_left_check || min_samples_right_check)) {
prob_prune_new = 0.5;
} else {
prob_prune_new = 1.0;
}

// Determine the number of leaves in the current tree and leaf parents in the proposed tree
int num_leaf_parents = tree->NumLeafParents();
double p_leaf = 1/static_cast<double>(num_leaves);
double p_leaf_parent = 1/static_cast<double>(num_leaf_parents+1);
// Reject the split if either the left or the right node would contain fewer than tree_prior.GetMinSamplesLeaf() samples
bool left_node_sample_cutoff = left_n >= tree_prior.GetMinSamplesLeaf();
bool right_node_sample_cutoff = right_n >= tree_prior.GetMinSamplesLeaf();
if ((left_node_sample_cutoff) && (right_node_sample_cutoff)) {

// Determine probability of growing the split node and its two new left and right nodes
double pg = tree_prior.GetAlpha() * std::pow(1+leaf_depth, -tree_prior.GetBeta());
double pgl = tree_prior.GetAlpha() * std::pow(1+leaf_depth+1, -tree_prior.GetBeta());
double pgr = tree_prior.GetAlpha() * std::pow(1+leaf_depth+1, -tree_prior.GetBeta());

// Determine whether a "grow" move is possible from the newly formed tree
// in order to compute the probability of choosing "prune" from the new tree
// (which is always possible by construction)
bool non_constant = NodesNonConstantAfterSplit(dataset, tracker, split, tree_num, leaf_chosen, var_chosen);
bool min_samples_left_check = left_n >= 2*tree_prior.GetMinSamplesLeaf();
bool min_samples_right_check = right_n >= 2*tree_prior.GetMinSamplesLeaf();
double prob_prune_new;
if (non_constant && (min_samples_left_check || min_samples_right_check)) {
prob_prune_new = 0.5;
} else {
prob_prune_new = 1.0;
}

// Determine the number of leaves in the current tree and leaf parents in the proposed tree
int num_leaf_parents = tree->NumLeafParents();
double p_leaf = 1/static_cast<double>(num_leaves);
double p_leaf_parent = 1/static_cast<double>(num_leaf_parents+1);

// Compute the final MH ratio
double log_mh_ratio = (
std::log(pg) + std::log(1-pgl) + std::log(1-pgr) - std::log(1-pg) + std::log(prob_prune_new) +
std::log(p_leaf_parent) - std::log(prob_grow_old) - std::log(p_leaf) - no_split_log_marginal_likelihood + split_log_marginal_likelihood
);
// Threshold at 0
if (log_mh_ratio > 0) {
log_mh_ratio = 0;
}

// Compute the final MH ratio
double log_mh_ratio = (
std::log(pg) + std::log(1-pgl) + std::log(1-pgr) - std::log(1-pg) + std::log(prob_prune_new) +
std::log(p_leaf_parent) - std::log(prob_grow_old) - std::log(p_leaf) - no_split_log_marginal_likelihood + split_log_marginal_likelihood
);
// Threshold at 0
if (log_mh_ratio > 0) {
log_mh_ratio = 0;
}
// Draw a uniform random variable and accept/reject the proposal on this basis
std::uniform_real_distribution<double> mh_accept(0.0, 1.0);
double log_acceptance_prob = std::log(mh_accept(gen));
if (log_acceptance_prob <= log_mh_ratio) {
accept = true;
AddSplitToModel(tracker, dataset, tree_prior, split, gen, tree, tree_num, leaf_chosen, var_chosen, false);
} else {
accept = false;
}

// Draw a uniform random variable and accept/reject the proposal on this basis
std::uniform_real_distribution<double> mh_accept(0.0, 1.0);
double log_acceptance_prob = std::log(mh_accept(gen));
if (log_acceptance_prob <= log_mh_ratio) {
accept = true;
AddSplitToModel(tracker, dataset, tree_prior, split, gen, tree, tree_num, leaf_chosen, var_chosen, false);
} else {
accept = false;
}
Expand Down
Loading