Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
69 changes: 32 additions & 37 deletions stan/math/prim/prob/neg_binomial_2_log_lpmf.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -3,17 +3,16 @@

#include <stan/math/prim/meta.hpp>
#include <stan/math/prim/err.hpp>
#include <stan/math/prim/fun/binomial_coefficient_log.hpp>
#include <stan/math/prim/fun/digamma.hpp>
#include <stan/math/prim/fun/exp.hpp>
#include <stan/math/prim/fun/lgamma.hpp>
#include <stan/math/prim/fun/inv.hpp>
#include <stan/math/prim/fun/log.hpp>
#include <stan/math/prim/fun/log1p_exp.hpp>
#include <stan/math/prim/fun/log_sum_exp.hpp>
#include <stan/math/prim/fun/max_size.hpp>
#include <stan/math/prim/fun/multiply_log.hpp>
#include <stan/math/prim/fun/size.hpp>
#include <stan/math/prim/fun/size_zero.hpp>
#include <stan/math/prim/fun/value_of.hpp>
#include <stan/math/prim/prob/poisson_log_lpmf.hpp>
#include <cmath>

namespace stan {
Expand Down Expand Up @@ -52,13 +51,20 @@ return_type_t<T_log_location, T_precision> neg_binomial_2_log_lpmf(
size_t size_phi = stan::math::size(phi);
size_t size_eta_phi = max_size(eta, phi);
size_t size_n_phi = max_size(n, phi);
size_t max_size_seq_view = max_size(n, eta, phi);
size_t size_all = max_size(n, eta, phi);

VectorBuilder<true, T_partials_return, T_log_location> eta_val(size_eta);
for (size_t i = 0; i < size_eta; ++i) {
eta_val[i] = value_of(eta_vec[i]);
}

VectorBuilder<true, T_partials_return, T_precision> phi_val(size_phi);
VectorBuilder<true, T_partials_return, T_precision> log_phi(size_phi);
for (size_t i = 0; i < size_phi; ++i) {
phi_val[i] = value_of(phi_vec[i]);
log_phi[i] = log(phi_val[i]);
}

VectorBuilder<!is_constant_all<T_log_location, T_precision>::value,
T_partials_return, T_log_location>
exp_eta(size_eta);
Expand All @@ -68,17 +74,19 @@ return_type_t<T_log_location, T_precision> neg_binomial_2_log_lpmf(
}
}

VectorBuilder<true, T_partials_return, T_precision> phi_val(size_phi);
VectorBuilder<true, T_partials_return, T_precision> log_phi(size_phi);
for (size_t i = 0; i < size_phi; ++i) {
phi_val[i] = value_of(phi_vec[i]);
log_phi[i] = log(phi_val[i]);
VectorBuilder<!is_constant_all<T_log_location, T_precision>::value,
T_partials_return, T_log_location, T_precision>
exp_eta_over_exp_eta_phi(size_eta_phi);
if (!is_constant_all<T_log_location, T_precision>::value) {
for (size_t i = 0; i < size_eta_phi; ++i) {
exp_eta_over_exp_eta_phi[i] = inv(phi_val[i] / exp_eta[i] + 1);
}
}

VectorBuilder<true, T_partials_return, T_log_location, T_precision>
logsumexp_eta_logphi(size_eta_phi);
log1p_exp_eta_m_logphi(size_eta_phi);
for (size_t i = 0; i < size_eta_phi; ++i) {
logsumexp_eta_logphi[i] = log_sum_exp(eta_val[i], log_phi[i]);
log1p_exp_eta_m_logphi[i] = log1p_exp(eta_val[i] - log_phi[i]);
}

VectorBuilder<true, T_partials_return, T_n, T_precision> n_plus_phi(
Expand All @@ -87,38 +95,25 @@ return_type_t<T_log_location, T_precision> neg_binomial_2_log_lpmf(
n_plus_phi[i] = n_vec[i] + phi_val[i];
}

for (size_t i = 0; i < max_size_seq_view; i++) {
if (phi_val[i] > 1e5) {
// TODO(martinmodrak) This is wrong (doesn't pass propto information),
// and inaccurate for n = 0, but shouldn't break most models.
// Also the 1e5 cutoff is way too low.
// Will be addressed better once PR #1497 is merged
logp += poisson_log_lpmf(n_vec[i], eta_val[i]);
} else {
if (include_summand<propto>::value) {
logp -= lgamma(n_vec[i] + 1.0);
}
if (include_summand<propto, T_precision>::value) {
logp += multiply_log(phi_val[i], phi_val[i]) - lgamma(phi_val[i]);
}
if (include_summand<propto, T_log_location>::value) {
logp += n_vec[i] * eta_val[i];
}
if (include_summand<propto, T_precision>::value) {
logp += lgamma(n_plus_phi[i]);
}
logp -= (n_plus_phi[i]) * logsumexp_eta_logphi[i];
for (size_t i = 0; i < size_all; i++) {
if (include_summand<propto, T_precision>::value) {
logp += binomial_coefficient_log(n_plus_phi[i] - 1, n_vec[i]);
}
if (include_summand<propto, T_log_location>::value) {
logp += n_vec[i] * eta_val[i];
}
logp += -phi_val[i] * log1p_exp_eta_m_logphi[i]
- n_vec[i] * (log_phi[i] + log1p_exp_eta_m_logphi[i]);

if (!is_constant_all<T_log_location>::value) {
ops_partials.edge1_.partials_[i]
+= n_vec[i] - n_plus_phi[i] / (phi_val[i] / exp_eta[i] + 1);
+= n_vec[i] - n_plus_phi[i] * exp_eta_over_exp_eta_phi[i];
}
if (!is_constant_all<T_precision>::value) {
ops_partials.edge2_.partials_[i]
+= 1.0 - n_plus_phi[i] / (exp_eta[i] + phi_val[i]) + log_phi[i]
- logsumexp_eta_logphi[i] - digamma(phi_val[i])
+ digamma(n_plus_phi[i]);
+= exp_eta_over_exp_eta_phi[i] - n_vec[i] / (exp_eta[i] + phi_val[i])
- log1p_exp_eta_m_logphi[i]
- (digamma(phi_val[i]) - digamma(n_plus_phi[i]));
}
}
return ops_partials.build(logp);
Expand Down
17 changes: 5 additions & 12 deletions test/unit/math/prim/prob/neg_binomial_2_log_test.cpp
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
#include <stan/math/prim.hpp>
#include <test/unit/math/prim/prob/vector_rng_test_helper.hpp>
#include <test/unit/math/prim/prob/NegativeBinomial2LogTestRig.hpp>
#include <test/unit/math/expect_near_rel.hpp>
#include <gtest/gtest.h>
#include <boost/random/mersenne_twister.hpp>
#include <boost/math/distributions.hpp>
Expand Down Expand Up @@ -212,29 +213,21 @@ TEST(ProbNegBinomial2, log_matches_lpmf) {
TEST(ProbDistributionsNegBinomial2Log, neg_binomial_2_log_grid_test) {
std::vector<double> mu_log_to_test
= {-101, -27, -3, -1, -0.132, 0, 4, 10, 87};
// TODO(martinmodrak) Reducing the span of the test, should be fixed
// along with #1495
// std::vector<double> phi_to_test = {2e-5, 0.36, 1, 10, 2.3e5, 1.8e10, 6e16};
std::vector<double> phi_to_test = {0.36, 1, 10};
std::vector<double> phi_to_test = {2e-5, 0.36, 1, 10, 2.3e5, 1.8e10, 6e16};
std::vector<int> n_to_test = {0, 1, 10, 39, 101, 3048, 150054};

// TODO(martinmdorak) Only weak tolerance for this quick fix
auto tolerance = [](double x) { return std::max(fabs(x * 1e-8), 1e-8); };

for (double mu_log : mu_log_to_test) {
for (double phi : phi_to_test) {
for (int n : n_to_test) {
double val_log = stan::math::neg_binomial_2_log_lpmf(n, mu_log, phi);
EXPECT_LE(val_log, 0)
<< "neg_binomial_2_log_lpmf yields " << val_log
<< " which si greater than 0 for n = " << n
<< ", mu_log = " << mu_log << ", phi = " << phi << ".";
std::stringstream msg;
double val_orig
= stan::math::neg_binomial_2_lpmf(n, std::exp(mu_log), phi);
EXPECT_NEAR(val_log, val_orig, tolerance(val_orig))
msg << std::setprecision(22)
<< "neg_binomial_2_log_lpmf yields different result (" << val_log
<< ") than neg_binomial_2_lpmf (" << val_orig << ") for n = " << n
<< ", mu_log = " << mu_log << ", phi = " << phi << ".";
stan::test::expect_near_rel(msg.str(), val_log, val_orig);
}
}
}
Expand Down
Loading