Skip to content

Commit b76423a

Browse files
committed
Correct log integrated likelihood
1 parent decf243 commit b76423a

File tree

1 file changed

+5
-2
lines changed

1 file changed

+5
-2
lines changed

src/leaf_model.cpp

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -218,12 +218,13 @@ double LogLinearVarianceLeafModel::NoSplitLogMarginalLikelihood(LogLinearVarianc
218218
}
219219

220220
double LogLinearVarianceLeafModel::SuffStatLogMarginalLikelihood(LogLinearVarianceSuffStat& suff_stat, double global_variance) {
221+
double prior_terms = a_ * std::log(b_) - boost::math::lgamma(a_);
221222
double a_term = a_ + 0.5 * suff_stat.n;
222223
double b_term = b_ + ((0.5 * suff_stat.weighted_sum_ei) / global_variance);
223224
double log_b_term = std::log(b_term);
224225
double lgamma_a_term = boost::math::lgamma(a_term);
225226
double resid_term = a_term * log_b_term;
226-
double log_ml = lgamma_a_term - resid_term;
227+
double log_ml = prior_terms + lgamma_a_term - resid_term;
227228
return log_ml;
228229
}
229230

@@ -258,7 +259,9 @@ void LogLinearVarianceLeafModel::SampleLeafParameters(ForestDataset& dataset, Fo
258259
node_rate = PosteriorParameterRate(node_suff_stat, global_variance);
259260

260261
// Draw from IG(shape, scale) and set the leaf parameter with each draw
261-
node_mu = std::log(gamma_sampler_.Sample(node_shape, node_rate, gen, true));
262+
std::gamma_distribution<double> gamma_dist_(node_shape, 1.);
263+
node_mu = std::log(gamma_dist_(gen) / node_rate);
264+
// node_mu = std::log(gamma_sampler_.Sample(node_shape, node_rate, gen, true));
262265
tree->SetLeaf(leaf_id, node_mu);
263266
}
264267
}

0 commit comments

Comments
 (0)