From 5fd8e2b36102460fb3106e238d86a29295eb4aad Mon Sep 17 00:00:00 2001 From: Philip Greengard Date: Tue, 10 Mar 2020 16:24:31 -0400 Subject: [PATCH 01/10] change derivative calculation for log_sum_exp; fixes #1679 --- stan/math/rev/fun/log_sum_exp.hpp | 6 +++-- test/unit/math/rev/fun/log_sum_exp_test.cpp | 29 +++++++++++++++++++++ 2 files changed, 33 insertions(+), 2 deletions(-) create mode 100644 test/unit/math/rev/fun/log_sum_exp_test.cpp diff --git a/stan/math/rev/fun/log_sum_exp.hpp b/stan/math/rev/fun/log_sum_exp.hpp index 9ca331aab74..fd5596ec1e8 100644 --- a/stan/math/rev/fun/log_sum_exp.hpp +++ b/stan/math/rev/fun/log_sum_exp.hpp @@ -22,8 +22,8 @@ class log_sum_exp_vv_vari : public op_vv_vari { log_sum_exp_vv_vari(vari* avi, vari* bvi) : op_vv_vari(log_sum_exp(avi->val_, bvi->val_), avi, bvi) {} void chain() { - avi_->adj_ += adj_ * calculate_chain(avi_->val_, val_); - bvi_->adj_ += adj_ * calculate_chain(bvi_->val_, val_); + avi_->adj_ += adj_ / (1 + exp(bvi_->val_ - avi_->val_)); + bvi_->adj_ += adj_ / (1 + exp(avi_->val_ - bvi_->val_)); } }; class log_sum_exp_vd_vari : public op_vd_vari { @@ -69,7 +69,9 @@ class log_sum_exp_matrix_vari : public op_matrix_vari { : op_matrix_vari(log_sum_exp(x.val()), x) {} void chain() { Eigen::Map vis_map(vis_, size_); + vis_map.adj().array() += adj_ * (vis_map.val().array() - val_).exp(); + } }; } // namespace internal diff --git a/test/unit/math/rev/fun/log_sum_exp_test.cpp b/test/unit/math/rev/fun/log_sum_exp_test.cpp new file mode 100644 index 00000000000..9efaf0d78a9 --- /dev/null +++ b/test/unit/math/rev/fun/log_sum_exp_test.cpp @@ -0,0 +1,29 @@ +#include +#include +#include +#include + +TEST(log_sum_exp_tests, large_values) { + using stan::math::var; + + var a = 1e50; + var output = stan::math::log_sum_exp(a, a); + output.grad(); + EXPECT_FLOAT_EQ(output.val(), log(2.0) + value_of(a)); + EXPECT_FLOAT_EQ(a.adj(), 1.0); + + var a1 = 1e50; + var a2 = 1; + var output2 = stan::math::log_sum_exp(a1, a2); + output2.grad(); + EXPECT_FLOAT_EQ(a1.adj(), 1.0); + EXPECT_FLOAT_EQ(a2.adj(), 0.0); + + var a3 = 1; + var a4 = 1e50; + var output3 = stan::math::log_sum_exp(a3, a4); + output3.grad(); + EXPECT_FLOAT_EQ(a3.adj(), 0.0); + EXPECT_FLOAT_EQ(a4.adj(), 1.0); + +} From 0869a162b24a84cd7d54d835a587fa716e6d6236 Mon Sep 17 00:00:00 2001 From: Philip Greengard Date: Tue, 10 Mar 2020 16:48:36 -0400 Subject: [PATCH 02/10] remove white space --- stan/math/rev/fun/log_sum_exp.hpp | 2 -- 1 file changed, 2 deletions(-) diff --git a/stan/math/rev/fun/log_sum_exp.hpp b/stan/math/rev/fun/log_sum_exp.hpp index fd5596ec1e8..49cff4e6b1c 100644 --- a/stan/math/rev/fun/log_sum_exp.hpp +++ b/stan/math/rev/fun/log_sum_exp.hpp @@ -69,9 +69,7 @@ class log_sum_exp_matrix_vari : public op_matrix_vari { : op_matrix_vari(log_sum_exp(x.val()), x) {} void chain() { Eigen::Map vis_map(vis_, size_); - vis_map.adj().array() += adj_ * (vis_map.val().array() - val_).exp(); - } }; } // namespace internal From 2f5432fc585e9e12b7b7452875b2f4dcb6e38f1a Mon Sep 17 00:00:00 2001 From: Philip Greengard Date: Tue, 10 Mar 2020 17:25:24 -0400 Subject: [PATCH 03/10] remove more white space --- test/unit/math/rev/fun/log_sum_exp_test.cpp | 1 - 1 file changed, 1 deletion(-) diff --git a/test/unit/math/rev/fun/log_sum_exp_test.cpp b/test/unit/math/rev/fun/log_sum_exp_test.cpp index 9efaf0d78a9..805f2047836 100644 --- a/test/unit/math/rev/fun/log_sum_exp_test.cpp +++ b/test/unit/math/rev/fun/log_sum_exp_test.cpp @@ -25,5 +25,4 @@ TEST(log_sum_exp_tests, large_values) { output3.grad(); EXPECT_FLOAT_EQ(a3.adj(), 0.0); EXPECT_FLOAT_EQ(a4.adj(), 1.0); - } From b5133581d24162e75ca5d04a5a39d43205707767 Mon Sep 17 00:00:00 2001 From: Stan Jenkins Date: Tue, 10 Mar 2020 21:33:55 +0000 Subject: [PATCH 04/10] [Jenkins] auto-formatting by clang-format version 5.0.0-3~16.04.1 (tags/RELEASE_500/final) --- test/unit/math/rev/fun/log_sum_exp_test.cpp | 36 ++++++++++----------- 1 file changed, 18 insertions(+), 18 deletions(-) diff --git a/test/unit/math/rev/fun/log_sum_exp_test.cpp b/test/unit/math/rev/fun/log_sum_exp_test.cpp index 805f2047836..95e2a72797a 100644 --- a/test/unit/math/rev/fun/log_sum_exp_test.cpp +++ b/test/unit/math/rev/fun/log_sum_exp_test.cpp @@ -4,25 +4,25 @@ #include TEST(log_sum_exp_tests, large_values) { - using stan::math::var; + using stan::math::var; - var a = 1e50; - var output = stan::math::log_sum_exp(a, a); - output.grad(); - EXPECT_FLOAT_EQ(output.val(), log(2.0) + value_of(a)); - EXPECT_FLOAT_EQ(a.adj(), 1.0); + var a = 1e50; + var output = stan::math::log_sum_exp(a, a); + output.grad(); + EXPECT_FLOAT_EQ(output.val(), log(2.0) + value_of(a)); + EXPECT_FLOAT_EQ(a.adj(), 1.0); - var a1 = 1e50; - var a2 = 1; - var output2 = stan::math::log_sum_exp(a1, a2); - output2.grad(); - EXPECT_FLOAT_EQ(a1.adj(), 1.0); - EXPECT_FLOAT_EQ(a2.adj(), 0.0); + var a1 = 1e50; + var a2 = 1; + var output2 = stan::math::log_sum_exp(a1, a2); + output2.grad(); + EXPECT_FLOAT_EQ(a1.adj(), 1.0); + EXPECT_FLOAT_EQ(a2.adj(), 0.0); - var a3 = 1; - var a4 = 1e50; - var output3 = stan::math::log_sum_exp(a3, a4); - output3.grad(); - EXPECT_FLOAT_EQ(a3.adj(), 0.0); - EXPECT_FLOAT_EQ(a4.adj(), 1.0); + var a3 = 1; + var a4 = 1e50; + var output3 = stan::math::log_sum_exp(a3, a4); + output3.grad(); + EXPECT_FLOAT_EQ(a3.adj(), 0.0); + EXPECT_FLOAT_EQ(a4.adj(), 1.0); } From 3a03bd98a743bb608af45452ecc9db8a8f045e3a Mon Sep 17 00:00:00 2001 From: Philip Greengard Date: Wed, 11 Mar 2020 18:10:00 -0400 Subject: [PATCH 05/10] replaced calculate_chain and removed function --- stan/math/rev/fun.hpp | 1 - stan/math/rev/fun/calculate_chain.hpp | 16 ---------------- stan/math/rev/fun/log1p_exp.hpp | 3 +-- stan/math/rev/fun/log_diff_exp.hpp | 5 ++--- stan/math/rev/fun/log_sum_exp.hpp | 3 +-- test/unit/math/mix/fun/log_sum_exp_test.cpp | 14 ++++++++++++++ 6 files changed, 18 insertions(+), 24 deletions(-) delete mode 100644 stan/math/rev/fun/calculate_chain.hpp diff --git a/stan/math/rev/fun.hpp b/stan/math/rev/fun.hpp index 1deb92f9738..295e223e860 100644 --- a/stan/math/rev/fun.hpp +++ b/stan/math/rev/fun.hpp @@ -22,7 +22,6 @@ #include #include #include -#include #include #include #include diff --git a/stan/math/rev/fun/calculate_chain.hpp b/stan/math/rev/fun/calculate_chain.hpp deleted file mode 100644 index 2ce66667b3a..00000000000 --- a/stan/math/rev/fun/calculate_chain.hpp +++ /dev/null @@ -1,16 +0,0 @@ -#ifndef STAN_MATH_REV_FUN_CALCULATE_CHAIN_HPP -#define STAN_MATH_REV_FUN_CALCULATE_CHAIN_HPP - -#include -#include - -namespace stan { -namespace math { - -inline double calculate_chain(double x, double val) { - return std::exp(x - val); // works out to inv_logit(x) -} - -} // namespace math -} // namespace stan -#endif diff --git a/stan/math/rev/fun/log1p_exp.hpp b/stan/math/rev/fun/log1p_exp.hpp index 762f8ee356a..068ae600834 100644 --- a/stan/math/rev/fun/log1p_exp.hpp +++ b/stan/math/rev/fun/log1p_exp.hpp @@ -3,7 +3,6 @@ #include #include -#include #include namespace stan { @@ -13,7 +12,7 @@ namespace internal { class log1p_exp_v_vari : public op_v_vari { public: explicit log1p_exp_v_vari(vari* avi) : op_v_vari(log1p_exp(avi->val_), avi) {} - void chain() { avi_->adj_ += adj_ * calculate_chain(avi_->val_, val_); } + void chain() { avi_->adj_ += adj_ / (1 + std::exp(-avi_->val_)); } }; } // namespace internal diff --git a/stan/math/rev/fun/log_diff_exp.hpp b/stan/math/rev/fun/log_diff_exp.hpp index 978987ca5c9..aa8c0034ad9 100644 --- a/stan/math/rev/fun/log_diff_exp.hpp +++ b/stan/math/rev/fun/log_diff_exp.hpp @@ -3,7 +3,6 @@ #include #include -#include #include #include #include @@ -17,7 +16,7 @@ class log_diff_exp_vv_vari : public op_vv_vari { log_diff_exp_vv_vari(vari* avi, vari* bvi) : op_vv_vari(log_diff_exp(avi->val_, bvi->val_), avi, bvi) {} void chain() { - avi_->adj_ += adj_ * calculate_chain(avi_->val_, val_); + avi_->adj_ -= adj_ / expm1(bvi_->val_ - avi_->val_); bvi_->adj_ -= adj_ / expm1(avi_->val_ - bvi_->val_); } }; @@ -29,7 +28,7 @@ class log_diff_exp_vd_vari : public op_vd_vari { if (val_ == NEGATIVE_INFTY) { avi_->adj_ += (bd_ == NEGATIVE_INFTY) ? adj_ : adj_ * INFTY; } else { - avi_->adj_ += adj_ * calculate_chain(avi_->val_, val_); + avi_->adj_ -= adj_ / expm1(bd_ - avi_->val_); } } }; diff --git a/stan/math/rev/fun/log_sum_exp.hpp b/stan/math/rev/fun/log_sum_exp.hpp index 49cff4e6b1c..94112170aa8 100644 --- a/stan/math/rev/fun/log_sum_exp.hpp +++ b/stan/math/rev/fun/log_sum_exp.hpp @@ -3,7 +3,6 @@ #include #include -#include #include #include #include @@ -34,7 +33,7 @@ class log_sum_exp_vd_vari : public op_vd_vari { if (val_ == NEGATIVE_INFTY) { avi_->adj_ += adj_; } else { - avi_->adj_ += adj_ * calculate_chain(avi_->val_, val_); + avi_->adj_ += adj_ / (1 + exp(bd_ - avi_->val_)); } } }; diff --git a/test/unit/math/mix/fun/log_sum_exp_test.cpp b/test/unit/math/mix/fun/log_sum_exp_test.cpp index 7244b70895e..90bb1c5504b 100644 --- a/test/unit/math/mix/fun/log_sum_exp_test.cpp +++ b/test/unit/math/mix/fun/log_sum_exp_test.cpp @@ -90,3 +90,17 @@ TEST(MathMixMatFun, logSumExp) { std::vector(x2c.data(), x2c.data() + x2c.size())}; stan::test::expect_ad(tols, f, ststx); } + +TEST(mathMixVarDouble, logSumExp) { + using stan::math::var; + + var a = 12; + double b = 8; + var out1 = stan::math::log_sum_exp(a, b); + + double a1 = 12; + double b1 = 8; + var out2 = stan::math::log_sum_exp(a1, b1); + + EXPECT_FLOAT_EQ(value_of(out1), value_of(out2)); +} From 817dec871e3aa6dc20f67b11390b03925304c084 Mon Sep 17 00:00:00 2001 From: Philip Greengard Date: Fri, 13 Mar 2020 11:15:34 -0400 Subject: [PATCH 06/10] included inv_logit, added new test --- stan/math/rev/fun/log1p_exp.hpp | 2 +- stan/math/rev/fun/log_sum_exp.hpp | 6 ++-- test/unit/math/mix/fun/log_sum_exp_test.cpp | 14 --------- test/unit/math/rev/fun/log_sum_exp_test.cpp | 35 +++++++++++++++++---- 4 files changed, 33 insertions(+), 24 deletions(-) diff --git a/stan/math/rev/fun/log1p_exp.hpp b/stan/math/rev/fun/log1p_exp.hpp index 068ae600834..bf3a7442023 100644 --- a/stan/math/rev/fun/log1p_exp.hpp +++ b/stan/math/rev/fun/log1p_exp.hpp @@ -12,7 +12,7 @@ namespace internal { class log1p_exp_v_vari : public op_v_vari { public: explicit log1p_exp_v_vari(vari* avi) : op_v_vari(log1p_exp(avi->val_), avi) {} - void chain() { avi_->adj_ += adj_ / (1 + std::exp(-avi_->val_)); } + void chain() { avi_->adj_ += adj_ * inv_logit(avi_->val_); } }; } // namespace internal diff --git a/stan/math/rev/fun/log_sum_exp.hpp b/stan/math/rev/fun/log_sum_exp.hpp index 94112170aa8..f29928bc209 100644 --- a/stan/math/rev/fun/log_sum_exp.hpp +++ b/stan/math/rev/fun/log_sum_exp.hpp @@ -21,8 +21,8 @@ class log_sum_exp_vv_vari : public op_vv_vari { log_sum_exp_vv_vari(vari* avi, vari* bvi) : op_vv_vari(log_sum_exp(avi->val_, bvi->val_), avi, bvi) {} void chain() { - avi_->adj_ += adj_ / (1 + exp(bvi_->val_ - avi_->val_)); - bvi_->adj_ += adj_ / (1 + exp(avi_->val_ - bvi_->val_)); + avi_->adj_ += adj_ * inv_logit(avi_->val_ - bvi_->val_); + bvi_->adj_ += adj_ * inv_logit(bvi_->val_ - avi_->val_); } }; class log_sum_exp_vd_vari : public op_vd_vari { @@ -33,7 +33,7 @@ class log_sum_exp_vd_vari : public op_vd_vari { if (val_ == NEGATIVE_INFTY) { avi_->adj_ += adj_; } else { - avi_->adj_ += adj_ / (1 + exp(bd_ - avi_->val_)); + avi_->adj_ += adj_ * inv_logit(avi_->val_ - bd_); } } }; diff --git a/test/unit/math/mix/fun/log_sum_exp_test.cpp b/test/unit/math/mix/fun/log_sum_exp_test.cpp index 90bb1c5504b..7244b70895e 100644 --- a/test/unit/math/mix/fun/log_sum_exp_test.cpp +++ b/test/unit/math/mix/fun/log_sum_exp_test.cpp @@ -90,17 +90,3 @@ TEST(MathMixMatFun, logSumExp) { std::vector(x2c.data(), x2c.data() + x2c.size())}; stan::test::expect_ad(tols, f, ststx); } - -TEST(mathMixVarDouble, logSumExp) { - using stan::math::var; - - var a = 12; - double b = 8; - var out1 = stan::math::log_sum_exp(a, b); - - double a1 = 12; - double b1 = 8; - var out2 = stan::math::log_sum_exp(a1, b1); - - EXPECT_FLOAT_EQ(value_of(out1), value_of(out2)); -} diff --git a/test/unit/math/rev/fun/log_sum_exp_test.cpp b/test/unit/math/rev/fun/log_sum_exp_test.cpp index 95e2a72797a..576e28c4b0c 100644 --- a/test/unit/math/rev/fun/log_sum_exp_test.cpp +++ b/test/unit/math/rev/fun/log_sum_exp_test.cpp @@ -6,23 +6,46 @@ TEST(log_sum_exp_tests, large_values) { using stan::math::var; + // check combinations of vars and doubles with large argument values var a = 1e50; var output = stan::math::log_sum_exp(a, a); output.grad(); EXPECT_FLOAT_EQ(output.val(), log(2.0) + value_of(a)); EXPECT_FLOAT_EQ(a.adj(), 1.0); - var a1 = 1e50; var a2 = 1; - var output2 = stan::math::log_sum_exp(a1, a2); + var a3 = 1e50; + var output2 = stan::math::log_sum_exp(a2, a3); output2.grad(); - EXPECT_FLOAT_EQ(a1.adj(), 1.0); EXPECT_FLOAT_EQ(a2.adj(), 0.0); + EXPECT_FLOAT_EQ(a3.adj(), 1.0); - var a3 = 1; var a4 = 1e50; - var output3 = stan::math::log_sum_exp(a3, a4); + var a5 = 1; + var output3 = stan::math::log_sum_exp(a4, a5); output3.grad(); - EXPECT_FLOAT_EQ(a3.adj(), 0.0); EXPECT_FLOAT_EQ(a4.adj(), 1.0); + EXPECT_FLOAT_EQ(a5.adj(), 0.0); + + + // check combinations of vars and doubles with large argument values + var b = 1e20; + var output4 = stan::math::log_sum_exp(b, b); + output4.grad(); + EXPECT_FLOAT_EQ(output4.val(), log(2.0) + value_of(b)); + EXPECT_FLOAT_EQ(b.adj(), 1.0); + + var b2 = -2; + var b3 = 1e20; + var output5 = stan::math::log_sum_exp(b2, b3); + output5.grad(); + EXPECT_FLOAT_EQ(b2.adj(), 0.0); + EXPECT_FLOAT_EQ(b3.adj(), 1.0); + + var b4 = 1e20; + var b5 = -2; + var output6 = stan::math::log_sum_exp(b4, b5); + output6.grad(); + EXPECT_FLOAT_EQ(b4.adj(), 1.0); + EXPECT_FLOAT_EQ(b5.adj(), 0.0); } From 829be7f5e20f26089da8d1d68ddfe3e721292d56 Mon Sep 17 00:00:00 2001 From: Stan Jenkins Date: Fri, 13 Mar 2020 11:17:05 -0400 Subject: [PATCH 07/10] [Jenkins] auto-formatting by clang-format version 6.0.0 (tags/google/stable/2017-11-14) --- test/unit/math/rev/fun/log_sum_exp_test.cpp | 1 - 1 file changed, 1 deletion(-) diff --git a/test/unit/math/rev/fun/log_sum_exp_test.cpp b/test/unit/math/rev/fun/log_sum_exp_test.cpp index 576e28c4b0c..f32289d77ba 100644 --- a/test/unit/math/rev/fun/log_sum_exp_test.cpp +++ b/test/unit/math/rev/fun/log_sum_exp_test.cpp @@ -27,7 +27,6 @@ TEST(log_sum_exp_tests, large_values) { EXPECT_FLOAT_EQ(a4.adj(), 1.0); EXPECT_FLOAT_EQ(a5.adj(), 0.0); - // check combinations of vars and doubles with large argument values var b = 1e20; var output4 = stan::math::log_sum_exp(b, b); From c2619312ca8292ecbcbb424e3d53f0f67d18dbf3 Mon Sep 17 00:00:00 2001 From: Philip Greengard Date: Fri, 13 Mar 2020 11:39:43 -0400 Subject: [PATCH 08/10] added test for var and double --- test/unit/math/rev/fun/log_sum_exp_test.cpp | 32 ++++++++++++++------- 1 file changed, 22 insertions(+), 10 deletions(-) diff --git a/test/unit/math/rev/fun/log_sum_exp_test.cpp b/test/unit/math/rev/fun/log_sum_exp_test.cpp index 576e28c4b0c..aa6d49f81a1 100644 --- a/test/unit/math/rev/fun/log_sum_exp_test.cpp +++ b/test/unit/math/rev/fun/log_sum_exp_test.cpp @@ -6,7 +6,7 @@ TEST(log_sum_exp_tests, large_values) { using stan::math::var; - // check combinations of vars and doubles with large argument values + // check autodiffing works with var types with large values var a = 1e50; var output = stan::math::log_sum_exp(a, a); output.grad(); @@ -27,25 +27,37 @@ TEST(log_sum_exp_tests, large_values) { EXPECT_FLOAT_EQ(a4.adj(), 1.0); EXPECT_FLOAT_EQ(a5.adj(), 0.0); - - // check combinations of vars and doubles with large argument values + // check autodiffing works with var types with large values var b = 1e20; - var output4 = stan::math::log_sum_exp(b, b); - output4.grad(); - EXPECT_FLOAT_EQ(output4.val(), log(2.0) + value_of(b)); + var output6 = stan::math::log_sum_exp(b, b); + output6.grad(); + EXPECT_FLOAT_EQ(output6.val(), log(2.0) + value_of(b)); EXPECT_FLOAT_EQ(b.adj(), 1.0); var b2 = -2; var b3 = 1e20; - var output5 = stan::math::log_sum_exp(b2, b3); - output5.grad(); + var output7 = stan::math::log_sum_exp(b2, b3); + output7.grad(); EXPECT_FLOAT_EQ(b2.adj(), 0.0); EXPECT_FLOAT_EQ(b3.adj(), 1.0); var b4 = 1e20; var b5 = -2; - var output6 = stan::math::log_sum_exp(b4, b5); - output6.grad(); + var output8 = stan::math::log_sum_exp(b4, b5); + output8.grad(); EXPECT_FLOAT_EQ(b4.adj(), 1.0); EXPECT_FLOAT_EQ(b5.adj(), 0.0); + + // check arguement combinations of vars and doubles + var a6 = 1e50; + double a7 = 1; + var output4 = stan::math::log_sum_exp(a6, a7); + output4.grad(); + EXPECT_FLOAT_EQ(a6.adj(), 1.0); + + var a8 = 1; + double a9 = 1e50; + var output5 = stan::math::log_sum_exp(a8, a9); + output5.grad(); + EXPECT_FLOAT_EQ(a8.adj(), 0.0); } From 62a3c1fbc7a27a7c55743544bf7248ef5ac458d8 Mon Sep 17 00:00:00 2001 From: Philip Greengard Date: Mon, 16 Mar 2020 13:49:03 -0400 Subject: [PATCH 09/10] updated header --- stan/math/rev/fun/log1p_exp.hpp | 1 + 1 file changed, 1 insertion(+) diff --git a/stan/math/rev/fun/log1p_exp.hpp b/stan/math/rev/fun/log1p_exp.hpp index bf3a7442023..5126b2bcb8d 100644 --- a/stan/math/rev/fun/log1p_exp.hpp +++ b/stan/math/rev/fun/log1p_exp.hpp @@ -3,6 +3,7 @@ #include #include +#include #include namespace stan { From a793873483a3be437c54c934eb17bb8f3b0de386 Mon Sep 17 00:00:00 2001 From: Philip Greengard Date: Mon, 16 Mar 2020 17:54:19 -0400 Subject: [PATCH 10/10] updated another header --- stan/math/rev/fun/log_sum_exp.hpp | 1 + 1 file changed, 1 insertion(+) diff --git a/stan/math/rev/fun/log_sum_exp.hpp b/stan/math/rev/fun/log_sum_exp.hpp index f29928bc209..f003ead2f6b 100644 --- a/stan/math/rev/fun/log_sum_exp.hpp +++ b/stan/math/rev/fun/log_sum_exp.hpp @@ -7,6 +7,7 @@ #include #include #include +#include #include #include #include