diff --git a/stan/math/prim/fun/binomial_coefficient_log.hpp b/stan/math/prim/fun/binomial_coefficient_log.hpp index 444bf88e742..c7d7bf4f16f 100644 --- a/stan/math/prim/fun/binomial_coefficient_log.hpp +++ b/stan/math/prim/fun/binomial_coefficient_log.hpp @@ -2,9 +2,14 @@ #define STAN_MATH_PRIM_FUN_BINOMIAL_COEFFICIENT_LOG_HPP #include -#include +#include +#include +#include +#include +#include +#include #include -#include +#include namespace stan { namespace math { @@ -13,22 +18,24 @@ namespace math { * Return the log of the binomial coefficient for the specified * arguments. * - * The binomial coefficient, \f${N \choose n}\f$, read "N choose n", is - * defined for \f$0 \leq n \leq N\f$ by + * The binomial coefficient, \f${n \choose k}\f$, read "n choose k", is + * defined for \f$0 \leq k \leq n\f$ by * - * \f${N \choose n} = \frac{N!}{n! (N-n)!}\f$. + * \f${n \choose k} = \frac{n!}{k! (n-k)!}\f$. * * This function uses Gamma functions to define the log - * and generalize the arguments to continuous N and n. + * and generalize the arguments to continuous n and k. + * + * \f$ \log {n \choose k} + * = \log \ \Gamma(n+1) - \log \Gamma(k+1) - \log \Gamma(n-k+1)\f$. * - * \f$ \log {N \choose n} - * = \log \ \Gamma(N+1) - \log \Gamma(n+1) - \log \Gamma(N-n+1)\f$. * \f[ \mbox{binomial\_coefficient\_log}(x, y) = \begin{cases} - \textrm{error} & \mbox{if } y > x \textrm{ or } y < 0\\ - \ln\Gamma(x+1) & \mbox{if } 0\leq y \leq x \\ + \textrm{error} & \mbox{if } y > x + 1 \textrm{ or } y < -1 \textrm{ or } x + < -1\\ + \ln\Gamma(x+1) & \mbox{if } -1 < y < x + 1 \\ \quad -\ln\Gamma(y+1)& \\ \quad -\ln\Gamma(x-y+1)& \\[6pt] \textrm{NaN} & \mbox{if } x = \textrm{NaN or } y = \textrm{NaN} @@ -38,7 +45,8 @@ namespace math { \f[ \frac{\partial\, \mbox{binomial\_coefficient\_log}(x, y)}{\partial x} = \begin{cases} - \textrm{error} & \mbox{if } y > x \textrm{ or } y < 0\\ + \textrm{error} & \mbox{if } y > x + 1 \textrm{ or } y < -1 \textrm{ or } x + < -1\\ \Psi(x+1) & \mbox{if } 0\leq y \leq x \\ \quad -\Psi(x-y+1)& \\[6pt] \textrm{NaN} & \mbox{if } x = \textrm{NaN or } y = \textrm{NaN} @@ -48,32 +56,95 @@ namespace math { \f[ \frac{\partial\, \mbox{binomial\_coefficient\_log}(x, y)}{\partial y} = \begin{cases} - \textrm{error} & \mbox{if } y > x \textrm{ or } y < 0\\ + \textrm{error} & \mbox{if } y > x + 1 \textrm{ or } y < -1 \textrm{ or } x + < -1\\ -\Psi(y+1) & \mbox{if } 0\leq y \leq x \\ \quad +\Psi(x-y+1)& \\[6pt] \textrm{NaN} & \mbox{if } x = \textrm{NaN or } y = \textrm{NaN} \end{cases} \f] * - * @tparam T_N type of the first argument - * @tparam T_n type of the second argument - * @param N total number of objects. - * @param n number of objects chosen. - * @return log (N choose n). + * This function is numerically more stable than naive evaluation via lgamma. + * + * @tparam T_n type of the first argument + * @tparam T_k type of the second argument + * + * @param n total number of objects. + * @param k number of objects chosen. + * @return log (n choose k). */ -template -inline return_type_t binomial_coefficient_log(const T_N N, - const T_n n) { - const double CUTOFF = 1000; - if (N - n < CUTOFF) { - const T_N N_plus_1 = N + 1; - return lgamma(N_plus_1) - lgamma(n + 1) - lgamma(N_plus_1 - n); + +template +inline return_type_t binomial_coefficient_log(const T_n n, + const T_k k) { + using T_partials_return = partials_return_t; + + if (is_any_nan(n, k)) { + return NOT_A_NUMBER; + } + + // Choosing the more stable of the symmetric branches + if (n > -1 && k > value_of_rec(n) / 2.0 + 1e-8) { + return binomial_coefficient_log(n, n - k); + } + + const T_partials_return n_dbl = value_of(n); + const T_partials_return k_dbl = value_of(k); + const T_partials_return n_plus_1 = n_dbl + 1; + const T_partials_return n_plus_1_mk = n_plus_1 - k_dbl; + + static const char* function = "binomial_coefficient_log"; + check_greater_or_equal(function, "first argument", n, -1); + check_greater_or_equal(function, "second argument", k, -1); + check_greater_or_equal(function, "(first argument - second argument + 1)", + n_plus_1_mk, 0.0); + + operands_and_partials ops_partials(n, k); + + T_partials_return value; + if (k_dbl == 0) { + value = 0; + } else if (n_plus_1 < lgamma_stirling_diff_useful) { + value = lgamma(n_plus_1) - lgamma(k_dbl + 1) - lgamma(n_plus_1_mk); } else { - return_type_t N_minus_n = N - n; - const double one_twelfth = inv(12); - return multiply_log(n, N_minus_n) + multiply_log((N + 0.5), N / N_minus_n) - + one_twelfth / N - n - one_twelfth / N_minus_n - lgamma(n + 1); + value = -lbeta(n_plus_1_mk, k_dbl + 1) - log1p(n_dbl); } + + if (!is_constant_all::value) { + // Branching on all the edge cases. + // In direct computation many of those would be NaN + // But one-sided limits from within the domain exist, all of the below + // follows from lim x->0 from above digamma(x) == -Inf + // + // Note that we have k < n / 2 (see the first branch in this function) + // se we can ignore the n == k - 1 edge case. + T_partials_return digamma_n_plus_1_mk = digamma(n_plus_1_mk); + + if (!is_constant_all::value) { + if (n_dbl == -1.0) { + if (k_dbl == 0) { + ops_partials.edge1_.partials_[0] = 0; + } else { + ops_partials.edge1_.partials_[0] = NEGATIVE_INFTY; + } + } else { + ops_partials.edge1_.partials_[0] + = (digamma(n_plus_1) - digamma_n_plus_1_mk); + } + } + if (!is_constant_all::value) { + if (k_dbl == 0 && n_dbl == -1.0) { + ops_partials.edge2_.partials_[0] = NEGATIVE_INFTY; + } else if (k_dbl == -1) { + ops_partials.edge2_.partials_[0] = INFTY; + } else { + ops_partials.edge2_.partials_[0] + = (digamma_n_plus_1_mk - digamma(k_dbl + 1)); + } + } + } + + return ops_partials.build(value); } } // namespace math diff --git a/test/unit/math/mix/fun/binomial_coefficient_log_test.cpp b/test/unit/math/mix/fun/binomial_coefficient_log_test.cpp index 84defe83a3e..a47c850ece6 100644 --- a/test/unit/math/mix/fun/binomial_coefficient_log_test.cpp +++ b/test/unit/math/mix/fun/binomial_coefficient_log_test.cpp @@ -6,5 +6,9 @@ TEST(mathMixScalFun, binomialCoefficientLog) { }; stan::test::expect_ad(f, 3, 2); stan::test::expect_ad(f, 24.0, 12.0); + stan::test::expect_ad(f, 1.0, 0.0); + stan::test::expect_ad(f, 0.0, 1.0); + stan::test::expect_ad(f, -0.3, 0.5); + stan::test::expect_common_nonzero_binary(f); } diff --git a/test/unit/math/prim/fun/binomial_coefficient_log_test.cpp b/test/unit/math/prim/fun/binomial_coefficient_log_test.cpp index 23cea698637..e5df58947b9 100644 --- a/test/unit/math/prim/fun/binomial_coefficient_log_test.cpp +++ b/test/unit/math/prim/fun/binomial_coefficient_log_test.cpp @@ -1,13 +1,14 @@ #include +#include #include #include -#include template void test_binom_coefficient(const T_N& N, const T_n& n) { using stan::math::binomial_coefficient_log; EXPECT_FLOAT_EQ(lgamma(N + 1) - lgamma(n + 1) - lgamma(N - n + 1), - binomial_coefficient_log(N, n)); + binomial_coefficient_log(N, n)) + << "N = " << N << ", n = " << n; } TEST(MathFunctions, binomial_coefficient_log) { @@ -19,6 +20,13 @@ TEST(MathFunctions, binomial_coefficient_log) { EXPECT_FLOAT_EQ(29979.16, binomial_coefficient_log(100000, 91116)); + EXPECT_EQ(binomial_coefficient_log(-1, 0), 0); // Needed for neg_binomial_2 + EXPECT_EQ(binomial_coefficient_log(50, 0), 0); + EXPECT_EQ(binomial_coefficient_log(10000, 0), 0); + + EXPECT_EQ(binomial_coefficient_log(10, 11), stan::math::NEGATIVE_INFTY); + EXPECT_EQ(binomial_coefficient_log(10, -1), stan::math::NEGATIVE_INFTY); + for (int n = 0; n < 1010; ++n) { test_binom_coefficient(1010, n); test_binom_coefficient(1010.0, n); @@ -32,9 +40,27 @@ TEST(MathFunctions, binomial_coefficient_log) { } TEST(MathFunctions, binomial_coefficient_log_nan) { - double nan = std::numeric_limits::quiet_NaN(); + double nan = stan::math::NOT_A_NUMBER; EXPECT_TRUE(std::isnan(stan::math::binomial_coefficient_log(2.0, nan))); EXPECT_TRUE(std::isnan(stan::math::binomial_coefficient_log(nan, 2.0))); EXPECT_TRUE(std::isnan(stan::math::binomial_coefficient_log(nan, nan))); } + +TEST(MathFunctions, binomial_coefficient_log_errors_edge_cases) { + using stan::math::INFTY; + using stan::math::binomial_coefficient_log; + + EXPECT_NO_THROW(binomial_coefficient_log(10, 11)); + EXPECT_THROW(binomial_coefficient_log(10, 11.01), std::domain_error); + EXPECT_THROW(binomial_coefficient_log(10, -1.1), std::domain_error); + EXPECT_THROW(binomial_coefficient_log(-1, 0.3), std::domain_error); + EXPECT_NO_THROW(binomial_coefficient_log(-0.5, 0.49)); + EXPECT_NO_THROW(binomial_coefficient_log(10, -0.9)); + + EXPECT_FLOAT_EQ(binomial_coefficient_log(0, -1), -INFTY); + EXPECT_FLOAT_EQ(binomial_coefficient_log(-1, 0), 0); + EXPECT_FLOAT_EQ(binomial_coefficient_log(-1, -0.3), INFTY); + EXPECT_FLOAT_EQ(binomial_coefficient_log(0.3, -1), -INFTY); + EXPECT_FLOAT_EQ(binomial_coefficient_log(5.0, 6.0), -INFTY); +} diff --git a/test/unit/math/rev/fun/binomial_coefficient_log_test.cpp b/test/unit/math/rev/fun/binomial_coefficient_log_test.cpp new file mode 100644 index 00000000000..c70e37a2c72 --- /dev/null +++ b/test/unit/math/rev/fun/binomial_coefficient_log_test.cpp @@ -0,0 +1,349 @@ +#include +#include +#include +#include +#include +#include +#include +#include +#include + +TEST(MathFunctions, binomial_coefficient_log_identities) { + using stan::math::binomial_coefficient_log; + using stan::math::is_nan; + using stan::math::log; + using stan::math::log_sum_exp; + using stan::math::value_of; + using stan::math::var; + using stan::test::expect_near_rel; + + std::vector n_to_test + = {-0.1, 0, 1e-100, 1e-8, 1e-1, 1, 1 + 1e-6, 15, 10, 1e3, 1e30, 1e100}; + + std::vector k_ratios_to_test + = {-0.1, 1e-10, 1e-5, 1e-3, 1e-1, 0.5, 0.9, 1 - 1e-5, 1 - 1e-10}; + + // Recurrence relation: binomial_coefficient_log(n, k) == + // binomial_coefficient_log(n - 1, k - 1) + log(n) - log(k) + for (double n_dbl : n_to_test) { + for (double k_ratio : k_ratios_to_test) { + double k_dbl = n_dbl * k_ratio; + if (n_dbl <= 0 || k_dbl <= 0) { + continue; + } + // The redundant -1 +1 is necessary as this copies the loss of precision + // for very small n_dbl + if ((n_dbl - 1) + 1 - k_dbl <= 0) { + continue; + } + + stan::math::nested_rev_autodiff nested; + var n(n_dbl); + var k(k_dbl); + + // TODO(martinmodrak) Use the framework for testing identities, once it is + // ready + var val_left = binomial_coefficient_log(n, k); + var val_right_partial; + var val_right; + // Choose the more stable identity + if (n_dbl > 1 && k_dbl > 1 && (n_dbl - 1) + 1 - k_dbl > 0) { + val_right_partial = binomial_coefficient_log(n - 1, k - 1); + val_right = val_right_partial + log(n) - log(k); + } else { + val_right_partial = binomial_coefficient_log(n + 1, k + 1); + val_right = val_right_partial - log(n + 1) + log(k + 1); + } + + std::vector vars; + vars.push_back(n); + vars.push_back(k); + + std::vector gradients_left; + val_left.grad(vars, gradients_left); + + nested.set_zero_all_adjoints(); + + std::vector gradients_right; + val_right.grad(vars, gradients_right); + + for (int i = 0; i < 2; ++i) { + EXPECT_FALSE(is_nan(gradients_left[i])); + EXPECT_FALSE(is_nan(gradients_right[i])); + } + + std::stringstream msg; + msg << std::setprecision(22) << " successor: n = " << n << ", k = " << k + << std::endl + << "val = " << val_left << ", val2 = " << val_right_partial + << std::endl + << ", logn = " << log(n) << ", logk = " << log(k); + + expect_near_rel(std::string("val") + msg.str(), value_of(val_left), + value_of(val_right)); + expect_near_rel(std::string("dn") + msg.str(), gradients_left[0], + gradients_right[0]); + expect_near_rel(std::string("dk") + msg.str(), gradients_left[1], + gradients_right[1]); + } + } +} + +namespace binomial_coefficient_log_test_internal { +struct TestValue { + double n; + double k; + double val; + double dn; + double dk; +}; + +// Hand-checked edge cases. Using one-sided limits from +// within the function domain where the value doesn't exist +std::vector testValuesEdge = { + {-1, 0, 0, 0, stan::math::NEGATIVE_INFTY}, + {0, -1, stan::math::NEGATIVE_INFTY, -1.0, stan::math::INFTY}, + {3, -1, stan::math::NEGATIVE_INFTY, -0.25, stan::math::INFTY}, + {-1, -0.2, stan::math::INFTY, stan::math::NEGATIVE_INFTY, + -4.324031329886049836}, + {-0.5, 0.5, stan::math::NEGATIVE_INFTY, stan::math::INFTY, + stan::math::NEGATIVE_INFTY}, + {4.0, 5.0, stan::math::NEGATIVE_INFTY, stan::math::INFTY, + stan::math::NEGATIVE_INFTY}, + {1, 0, 0, 0, 1}, +}; + +const double NaN = stan::math::NOT_A_NUMBER; +// Test values generated in Mathematica, reproducible notebook at +// https://www.wolframcloud.com/obj/martin.modrak/Published/binomial_coefficient_log.nb +// Mathematica Code reproduced below for convenience: + +// toCString[x_] := ToString[CForm[N[x, 24]]]; +// singleTest[x_,y_]:= Module[{val, cdn,cdk},{ +// val = toCString[bclog[x,y]]; +// cdn = If[x > 10^6 || y > 10^6,"NaN", toCString[dbclogdn[x,y]]]; +// cdk = If[x > 10^6 || y > 10^6,"NaN", toCString[dbclogdk[x,y]]]; +// StringJoin[" {",toCString[x],",",toCString[y],",", +// val,",",cdn,",",cdk,"},\n"] +// }]; +// ns= {-0.1,3*10^-5,2*10^-3,1,8, 1325,845*10^3}; +// ratios = {-1,10^- 10,10^-5,10^-2,1/5,1/2,1-3*10^-2,1-6*10^-8, 1 -3*10^-9,2}; +// out = "std::vector testValues = {\n"; +// For[i = 1, i <= Length[ns], i++, { +// For[j = 1, j <= Length[ratios], j++, { +// cn = ns[[i]]; +// ck = If[ratios[[j]] < 0,-9/10, +// If[ratios[[j]] > 1,cn + 9/10,cn * ratios[[j]] ]]; +// out = StringJoin[out, singleTest[cn,ck]]; +// }] +// }] +// extremeNs = {3*10^15+1/2,10^20 + 1/2}; +// lowKs = {3, 100, 12895}; +// For[i = 1, i <= Length[extremeNs], i++, { +// For[j = 1, j <= Length[lowKs], j++, { +// cn = extremeNs[[i]]; +// ck = lowKs[[j]]; +// out = StringJoin[out,singleTest[cn,ck]]; +// }] +// }] +// out = StringJoin[out,"};\n"]; +// out +std::vector testValues = { + {-0.1, -0.9, -2.11525253908509081592028, -1.0399183832409129390763, + 10.7087463737049383316859}, + {-0.1, -1.0000000000000001e-11, 1.77711285027681189779234e-12, + -1.92253995946119474331752e-11, -0.177711285009843801688898}, + {-0.1, -2.0000000000000003e-6, 3.55415435144048059050025e-7, + -3.8450735153733054435704e-6, -0.177704150099061235922007}, + {-0.1, -0.003, 0.000517083756840281579810978, -0.00575325547677734075296938, + -0.167012379549101675417438}, + {-0.1, -0.020000000000000004, 0.00284168667292114343243947, + -0.0378231526143224327191161, -0.106499800517313678664015}, + {-0.1, -0.05, 0.0044386492587971776182016, -0.0923173337417260614732854, 0}, + {-0.1, -0.097, 0.000517083756840282014391421, -0.172765635025879011871306, + 0.167012379549101666140604}, + {-0.1, -0.09999999400000001, 1.06626764549734262955223e-9, + -0.177711275175914102783629, 0.177711263640674409624395}, + {-0.1, -0.0999999998, 3.55422574122942521222662e-11, + -0.17771128471653172414194, 0.177711284332023727176803}, + {-0.1, 0.8, -2.11525253908509135092963, 9.66882799046403046022088, + -10.7087463737049434361164}, + {0.00003, -0.9, -2.21375637737528044964183, -0.933371118080918307851001, + 10.7799597405306456982508}, + {0.00003, 3.0000000000000002e-15, 1.4804082053556344384283e-19, + 4.93458583906860402434769e-15, 0.0000493469401735864500932795}, + {0.00003, 6.000000000000001e-10, 2.96075719467911309956804e-14, + 9.86917168246424148279385e-10, 0.000049344966305847915513204}, + {0.00003, 9.e-7, 4.30798787797857747053402e-11, + 1.48037672530856143443777e-6, 0.0000463861237716491755189368}, + {0.00003, 6.e-6, 2.36865312869367146776112e-10, + 9.86921494891293022056408e-6, 0.0000296081641072682823142211}, + {0.00003, 0.000015, 3.70102051348524063287983e-10, + 0.0000246731996398828634493583, 0}, + {0.00003, 0.0000291, 4.30798787797858483751685e-11, + 0.0000478665004969577343409155, -0.0000463861237716491702941263}, + {0.00003, 0.000029999998200000002, 8.88244869467749046930373e-17, + 0.0000493469372225745196092216, -0.0000493469342618230179633344}, + {0.00003, 0.00002999999994, 2.96081652032227041059628e-18, + 0.0000493469400847597902807008, -0.0000493469399860680696581904}, + {0.00003, 0.90003, -2.2137563773752804140164, 9.84658862244972705518472, + -10.7799597405306453607622}, + {0.002, -0.9, -2.21559326412971099690821, -0.931489785799666354301045, + 10.7813141298573216196055}, + {0.002, 2.e-13, 6.57013709556564711297722e-16, + 3.28027758802827577274611e-13, 0.00328506854745431617062256}, + {0.002, 4.e-8, 1.31400113866164680767665e-10, 6.56055536734964490768367e-8, + 0.00328493714519690660826286}, + {0.002, 0.00006, 1.91190982195918985147542e-7, + 0.0000984126319888139547815645, 0.0030879641992827061931754}, + {0.002, 0.0004, 1.05122171458287357370166e-6, 0.000656246880415250699426597, + 0.0019710403008191429519067}, + {0.002, 0.001, 1.64253373496215320086901e-6, 0.00164133545687894157470528, + 0}, + {0.002, 0.0019399999999999999, 1.91190982195919466419307e-7, + 0.00318637683127151989160979, -0.00308796419928270568118357}, + {0.002, 0.00199999988, 3.94208201970328159413461e-13, + 0.00328506835071924281368086, -0.00328506815390258758994027}, + {0.002, 0.001999999996, 1.31402741136587958866517e-14, + 0.00328506854153159450171264, -0.0032850685349710393518526}, + {0.002, 0.902, -2.21559326412971125500408, 9.84982434405765769353491, + -10.7813141298573240642835}, + {1., -0.9, -2.85558226198351740582195, -0.459715615539276790357555, + 11.3062548910488207249193}, + {1., 1.e-10, 9.9999999988550662975071e-11, 6.44934066868432150285563e-11, + 0.99999999977101318664035}, + {1., 0.00002, 0.0000199995420290398824268595, + 0.0000128987621603843854005126, 0.999954203037316753927052}, + {1., 0.03, 0.0289783282362563119258558, 0.0195321262846958328746912, + 0.932173296099201809954111}, + {1., 0.2, 0.156457962917688023080733, 0.137792901804605606966842, + 0.57403132988604981390314}, + {1., 0.5, 0.241564475270490444691037, 0.386294361119890618834464, 0}, + {1., 0.97, 0.0289783282362563377988622, 0.951705422383897599096425, + -0.932173296099201747978444}, + {1., 0.99999994, 5.99999958466560848176267e-8, 0.999999901303960368460267, + -0.999999862607915650529701}, + {1., 0.999999998, 2.00000004987870302153356e-9, 0.999999996710131781531233, + -0.999999995420263611904449}, + {1., 1.9, -2.85558226198351640162479, 10.846539275509534925475, + -11.3062548910488116793316}, + {8., -0.9, -4.22528965320883461943031, -0.100538838650771402252215, + 12.6649352570174581939568}, + {8., 8.e-10, 2.17428571372173161903874e-9, 9.40096117596390056755358e-11, + 2.71785714144718599267395}, + {8., 0.00016, 0.000434834585178913876603697, 0.0000188020989077387807855132, + 2.71757518207576360648536}, + {8., 0.24, 0.606274586245453630229602, 0.0286077438730426700300604, + 2.35152741320850413169898}, + {8., 1.6, 2.90678606291134293918723, 0.208248082071609215411629, + 1.18134594311052448766911}, + {8., 4., 4.24849524204935898912334, 0.634523809523809523809524, 0}, + {8., 7.76, 0.606274586245454152373578, 2.38013515708154653288918, + -2.35152741320850383600988}, + {8., 7.99999952, 1.30457122485101544957266e-6, 2.71785635328906813937859, + -2.71785629688329952694258}, + {8., 7.999999984, 4.34857152442032476701103e-8, 2.71785711653819737865347, + -2.71785711465800509058726}, + {8., 8.9, -4.22528965320883911891918, 12.5643964183667228280515, + -12.664935257017494268063}, + {1325.45, -0.9, -8.72391406172695433576549, -0.000678528341300029218848988, + 17.6143179550958346212948}, + {1325.45, 1.32545e-7, 1.02949027509089702892375e-6, + 9.99622864543865370572321e-11, 7.76709993311726359860762}, + {1325.45, 0.026509000000000005, 0.205327156306863379978836, + 0.0000199926571417065267265905, 7.72429965627466779820619}, + {1325.45, 39.7635, 175.846725556752161664879, 0.0304475435453562149043387, + 3.46396589228722918392492}, + {1325.45, 265.09000000000003, 659.660660749231586923507, + 0.223049270402325103613771, 1.38488085895377544787304}, + {1325.45, 662.725, 914.911196853663854616454, 0.692770092488070560492316, + 0}, + {1325.45, 1285.6865, 175.846725556752235503753, 3.49441343583258486943696, + -3.46396589228722863795938}, + {1325.45, 1325.449920473, 0.000617688970012744110414011, + 7.76696934217533185815885, -7.76696928219795817077374}, + {1325.45, 1325.4499973491, 0.0000205898008981706567184859, + 7.76709579069753161291118, -7.76709578869828579554802}, + {1325.45, 1326.3500000000001, -8.72391406172855634865104, + 17.613639426763759896892, -17.6143179551050599946563}, + {845000.3, -0.9, -14.5350966987995578090701, -1.06508718182367185039981e-6, + 24.0708488585827418941125}, + {845000.3, 0.00008450003000000001, 0.00120194862411218431712068, + 9.99999408334467162987851e-11, 14.2241695294903594263306}, + {845000.3, 16.900006, 197.416639012626785645761, + 0.0000200001881681193593499904, 10.790464758593674854536}, + {845000.3, 25350.009000000002, 113851.198555151510249396, + 0.0304591891842282610860051, 3.47607957612220465494725}, + {845000.3, 169000.06000000003, 422833.371816046100941544, + 0.223143403385333866823948, 1.38629214218850240580417}, + {845000.3, 422500.15, 585702.52617952879969091, 0.693146588844529182207756, + 0}, + {845000.3, 819650.291, 113851.198555151775813377, 3.50653876530642990238351, + -3.47607957612220154809006}, + {845000.3, 845000.2492999821, 0.71910904869572166682554, + 14.1438656829570673122744, -14.1438656229571010748933}, + {845000.3, 845000.2983099994, 0.0240367434486935944287041, + 14.2215320063338616170949, -14.2215320043338627454076}, + {845000.3, 845001.2000000001, -14.5350966993600009324015, + 24.0708477958572381039047, -24.0708488609444199551305}, + {3.0000000000000005e15, 3.1, 108.557127724329303822723, NaN, NaN}, + {3.0000000000000005e15, 100.2, 3206.2047392970248044977, NaN, NaN}, + {3.0000000000000005e15, 12895.6, 350403.227999624864153782, NaN, NaN}, + {1.e20, 3.1, 140.841498570865873374959, NaN, NaN}, + {1.e20, 100.2, 4249.7189195624987706026, NaN, NaN}, + {1.e20, 12895.6, 484702.044995974181173534, NaN, NaN}, +}; + +} // namespace binomial_coefficient_log_test_internal + +TEST(MathFunctions, binomial_coefficient_log_precomputed) { + using binomial_coefficient_log_test_internal::TestValue; + using binomial_coefficient_log_test_internal::testValues; + using binomial_coefficient_log_test_internal::testValuesEdge; + using stan::math::is_nan; + using stan::math::value_of; + using stan::math::var; + using stan::test::expect_near_rel; + using stan::test::relative_tolerance; + + std::vector allTestValues = testValues; + allTestValues.insert(allTestValues.end(), testValuesEdge.begin(), + testValuesEdge.end()); + for (TestValue t : allTestValues) { + std::stringstream msg; + msg << std::setprecision(22) << "n = " << t.n << ", k = " << t.k; + + var n(t.n); + var k(t.k); + var val = stan::math::binomial_coefficient_log(n, k); + + std::vector vars; + vars.push_back(n); + vars.push_back(k); + + std::vector gradients; + val.grad(vars, gradients); + + for (int i = 0; i < 2; ++i) { + EXPECT_FALSE(is_nan(gradients[i])); + } + + expect_near_rel(msg.str(), value_of(val), t.val, + relative_tolerance(1e-14, 1e-14)); + + relative_tolerance tol_grad; + if (n < 1 || k < 1) { + tol_grad = relative_tolerance(1e-8, 1e-7); + } else { + tol_grad = relative_tolerance(1e-10, 1e-8); + } + if (!is_nan(t.dn)) { + expect_near_rel(std::string("dn: ") + msg.str(), gradients[0], t.dn, + tol_grad); + } + if (!is_nan(t.dk)) { + expect_near_rel(std::string("dk: ") + msg.str(), gradients[1], t.dk, + tol_grad); + } + } +} diff --git a/test/unit/math/rev/fun/lbeta_test.cpp b/test/unit/math/rev/fun/lbeta_test.cpp index 780cd8186f2..f4ea4ccf6bb 100644 --- a/test/unit/math/rev/fun/lbeta_test.cpp +++ b/test/unit/math/rev/fun/lbeta_test.cpp @@ -1,5 +1,7 @@ #include #include +#include +#include #include #include #include @@ -8,6 +10,105 @@ #include #include +namespace lbeta_test_internal { +// TODO(martinmodrak) the function here should be replaced by helpers for +// testing identities once those are available + +struct identity_tolerances { + stan::test::relative_tolerance value; + stan::test::relative_tolerance gradient; +}; + +template +void expect_identity(const std::string& msg, + const identity_tolerances& tolerances, const F1 lh, + const F2 rh, double x_dbl, double y_dbl) { + using stan::math::var; + using stan::test::expect_near_rel; + + stan::math::nested_rev_autodiff nested; + + var x(x_dbl); + var y(y_dbl); + + std::vector vars = {x, y}; + + var left = lh(x, y); + double left_dbl = value_of(left); + std::vector gradients_left; + left.grad(vars, gradients_left); + + nested.set_zero_all_adjoints(); + + var right = rh(x, y); + double right_dbl = value_of(right); + std::vector gradients_right; + right.grad(vars, gradients_right); + + std::stringstream args; + args << std::setprecision(22) << "args = [" << x << "," << y << "]"; + expect_near_rel(std::string() + args.str() + std::string(": ") + msg, + left_dbl, right_dbl, tolerances.value); + + for (size_t i = 0; i < gradients_left.size(); ++i) { + std::stringstream grad_msg; + grad_msg << "grad_" << i << ", " << args.str() << ": " << msg; + expect_near_rel(grad_msg.str(), gradients_left[i], gradients_right[i], + tolerances.gradient); + } +} +} // namespace lbeta_test_internal + +TEST(MathFunctions, lbeta_identities_gradient) { + using stan::math::lbeta; + using stan::math::pi; + using stan::math::var; + using stan::test::expect_near_rel; + + std::vector to_test + = {1e-100, 1e-8, 1e-1, 1, 2, 1 + 1e-6, 1e3, 1e30, 1e100}; + + lbeta_test_internal::identity_tolerances tol{{1e-15, 1e-15}, {1e-10, 1e-10}}; + + // All identities from https://en.wikipedia.org/wiki/Beta_function#Properties + // Successors: beta(a,b) = beta(a + 1, b) + beta(a, b + 1) + for (double x : to_test) { + for (double y : to_test) { + // TODO(martinmodrak) this restriction on testing should be lifted once + // the log_sum_exp bug (#1679) is resolved + if (x > 1e10 || y > 1e10) { + continue; + } + auto rh = [](const var& a, const var& b) { + return stan::math::log_sum_exp(lbeta(a + 1, b), lbeta(a, b + 1)); + }; + lbeta_test_internal::expect_identity( + "succesors", tol, static_cast(lbeta), + rh, x, y); + } + } + + // Sin: beta(x, 1 - x) == pi / sin(pi * x) + for (double x : to_test) { + if (x < 1) { + std::stringstream msg; + msg << std::setprecision(22) << "sin: x = " << x; + double lh = lbeta(x, 1.0 - x); + double rh = log(pi()) - log(sin(pi() * x)); + expect_near_rel(msg.str(), lh, rh, tol.value); + } + } + + // Inv: beta(1, x) == 1 / x + for (double x : to_test) { + std::stringstream msg; + msg << std::setprecision(22) << "inv: x = " << x; + double lh = lbeta(x, 1.0); + double rh = -log(x); + expect_near_rel(msg.str(), lh, rh, tol.value); + } +} + namespace lbeta_test_internal { struct TestValue { double x;