diff --git a/stan/math/rev/fun.hpp b/stan/math/rev/fun.hpp index 1deb92f9738..295e223e860 100644 --- a/stan/math/rev/fun.hpp +++ b/stan/math/rev/fun.hpp @@ -22,7 +22,6 @@ #include #include #include -#include #include #include #include diff --git a/stan/math/rev/fun/calculate_chain.hpp b/stan/math/rev/fun/calculate_chain.hpp deleted file mode 100644 index 2ce66667b3a..00000000000 --- a/stan/math/rev/fun/calculate_chain.hpp +++ /dev/null @@ -1,16 +0,0 @@ -#ifndef STAN_MATH_REV_FUN_CALCULATE_CHAIN_HPP -#define STAN_MATH_REV_FUN_CALCULATE_CHAIN_HPP - -#include -#include - -namespace stan { -namespace math { - -inline double calculate_chain(double x, double val) { - return std::exp(x - val); // works out to inv_logit(x) -} - -} // namespace math -} // namespace stan -#endif diff --git a/stan/math/rev/fun/log1p_exp.hpp b/stan/math/rev/fun/log1p_exp.hpp index 762f8ee356a..5126b2bcb8d 100644 --- a/stan/math/rev/fun/log1p_exp.hpp +++ b/stan/math/rev/fun/log1p_exp.hpp @@ -3,7 +3,7 @@ #include #include -#include +#include #include namespace stan { @@ -13,7 +13,7 @@ namespace internal { class log1p_exp_v_vari : public op_v_vari { public: explicit log1p_exp_v_vari(vari* avi) : op_v_vari(log1p_exp(avi->val_), avi) {} - void chain() { avi_->adj_ += adj_ * calculate_chain(avi_->val_, val_); } + void chain() { avi_->adj_ += adj_ * inv_logit(avi_->val_); } }; } // namespace internal diff --git a/stan/math/rev/fun/log_diff_exp.hpp b/stan/math/rev/fun/log_diff_exp.hpp index 978987ca5c9..aa8c0034ad9 100644 --- a/stan/math/rev/fun/log_diff_exp.hpp +++ b/stan/math/rev/fun/log_diff_exp.hpp @@ -3,7 +3,6 @@ #include #include -#include #include #include #include @@ -17,7 +16,7 @@ class log_diff_exp_vv_vari : public op_vv_vari { log_diff_exp_vv_vari(vari* avi, vari* bvi) : op_vv_vari(log_diff_exp(avi->val_, bvi->val_), avi, bvi) {} void chain() { - avi_->adj_ += adj_ * calculate_chain(avi_->val_, val_); + avi_->adj_ -= adj_ / expm1(bvi_->val_ - avi_->val_); bvi_->adj_ -= adj_ / expm1(avi_->val_ - bvi_->val_); } }; @@ -29,7 +28,7 @@ class log_diff_exp_vd_vari : public op_vd_vari { if (val_ == NEGATIVE_INFTY) { avi_->adj_ += (bd_ == NEGATIVE_INFTY) ? adj_ : adj_ * INFTY; } else { - avi_->adj_ += adj_ * calculate_chain(avi_->val_, val_); + avi_->adj_ -= adj_ / expm1(bd_ - avi_->val_); } } }; diff --git a/stan/math/rev/fun/log_sum_exp.hpp b/stan/math/rev/fun/log_sum_exp.hpp index 9ca331aab74..f003ead2f6b 100644 --- a/stan/math/rev/fun/log_sum_exp.hpp +++ b/stan/math/rev/fun/log_sum_exp.hpp @@ -3,11 +3,11 @@ #include #include -#include #include #include #include #include +#include #include #include #include @@ -22,8 +22,8 @@ class log_sum_exp_vv_vari : public op_vv_vari { log_sum_exp_vv_vari(vari* avi, vari* bvi) : op_vv_vari(log_sum_exp(avi->val_, bvi->val_), avi, bvi) {} void chain() { - avi_->adj_ += adj_ * calculate_chain(avi_->val_, val_); - bvi_->adj_ += adj_ * calculate_chain(bvi_->val_, val_); + avi_->adj_ += adj_ * inv_logit(avi_->val_ - bvi_->val_); + bvi_->adj_ += adj_ * inv_logit(bvi_->val_ - avi_->val_); } }; class log_sum_exp_vd_vari : public op_vd_vari { @@ -34,7 +34,7 @@ class log_sum_exp_vd_vari : public op_vd_vari { if (val_ == NEGATIVE_INFTY) { avi_->adj_ += adj_; } else { - avi_->adj_ += adj_ * calculate_chain(avi_->val_, val_); + avi_->adj_ += adj_ * inv_logit(avi_->val_ - bd_); } } }; diff --git a/test/unit/math/rev/fun/log_sum_exp_test.cpp b/test/unit/math/rev/fun/log_sum_exp_test.cpp new file mode 100644 index 00000000000..aa6d49f81a1 --- /dev/null +++ b/test/unit/math/rev/fun/log_sum_exp_test.cpp @@ -0,0 +1,63 @@ +#include +#include +#include +#include + +TEST(log_sum_exp_tests, large_values) { + using stan::math::var; + + // check autodiffing works with var types with large values + var a = 1e50; + var output = stan::math::log_sum_exp(a, a); + output.grad(); + EXPECT_FLOAT_EQ(output.val(), log(2.0) + value_of(a)); + EXPECT_FLOAT_EQ(a.adj(), 1.0); + + var a2 = 1; + var a3 = 1e50; + var output2 = stan::math::log_sum_exp(a2, a3); + output2.grad(); + EXPECT_FLOAT_EQ(a2.adj(), 0.0); + EXPECT_FLOAT_EQ(a3.adj(), 1.0); + + var a4 = 1e50; + var a5 = 1; + var output3 = stan::math::log_sum_exp(a4, a5); + output3.grad(); + EXPECT_FLOAT_EQ(a4.adj(), 1.0); + EXPECT_FLOAT_EQ(a5.adj(), 0.0); + + // check autodiffing works with var types with large values + var b = 1e20; + var output6 = stan::math::log_sum_exp(b, b); + output6.grad(); + EXPECT_FLOAT_EQ(output6.val(), log(2.0) + value_of(b)); + EXPECT_FLOAT_EQ(b.adj(), 1.0); + + var b2 = -2; + var b3 = 1e20; + var output7 = stan::math::log_sum_exp(b2, b3); + output7.grad(); + EXPECT_FLOAT_EQ(b2.adj(), 0.0); + EXPECT_FLOAT_EQ(b3.adj(), 1.0); + + var b4 = 1e20; + var b5 = -2; + var output8 = stan::math::log_sum_exp(b4, b5); + output8.grad(); + EXPECT_FLOAT_EQ(b4.adj(), 1.0); + EXPECT_FLOAT_EQ(b5.adj(), 0.0); + + // check arguement combinations of vars and doubles + var a6 = 1e50; + double a7 = 1; + var output4 = stan::math::log_sum_exp(a6, a7); + output4.grad(); + EXPECT_FLOAT_EQ(a6.adj(), 1.0); + + var a8 = 1; + double a9 = 1e50; + var output5 = stan::math::log_sum_exp(a8, a9); + output5.grad(); + EXPECT_FLOAT_EQ(a8.adj(), 0.0); +}