Skip to content

Commit c26cda1

Browse files
authored
Merge pull request #1772 from pgree/bugfix/1679-log-sum-exp
Changed calculation of log_sum_exp(x1, x2)
2 parents f3cbe21 + a793873 commit c26cda1

6 files changed

Lines changed: 71 additions & 26 deletions

File tree

stan/math/rev/fun.hpp

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,6 @@
2222
#include <stan/math/rev/fun/bessel_second_kind.hpp>
2323
#include <stan/math/rev/fun/beta.hpp>
2424
#include <stan/math/rev/fun/binary_log_loss.hpp>
25-
#include <stan/math/rev/fun/calculate_chain.hpp>
2625
#include <stan/math/rev/fun/cbrt.hpp>
2726
#include <stan/math/rev/fun/ceil.hpp>
2827
#include <stan/math/rev/fun/cholesky_decompose.hpp>

stan/math/rev/fun/calculate_chain.hpp

Lines changed: 0 additions & 16 deletions
This file was deleted.

stan/math/rev/fun/log1p_exp.hpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33

44
#include <stan/math/rev/meta.hpp>
55
#include <stan/math/rev/core.hpp>
6-
#include <stan/math/rev/fun/calculate_chain.hpp>
6+
#include <stan/math/prim/fun/inv_logit.hpp>
77
#include <stan/math/prim/fun/log1p_exp.hpp>
88

99
namespace stan {
@@ -13,7 +13,7 @@ namespace internal {
1313
class log1p_exp_v_vari : public op_v_vari {
1414
public:
1515
explicit log1p_exp_v_vari(vari* avi) : op_v_vari(log1p_exp(avi->val_), avi) {}
16-
void chain() { avi_->adj_ += adj_ * calculate_chain(avi_->val_, val_); }
16+
void chain() { avi_->adj_ += adj_ * inv_logit(avi_->val_); }
1717
};
1818
} // namespace internal
1919

stan/math/rev/fun/log_diff_exp.hpp

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,6 @@
33

44
#include <stan/math/rev/meta.hpp>
55
#include <stan/math/rev/core.hpp>
6-
#include <stan/math/rev/fun/calculate_chain.hpp>
76
#include <stan/math/prim/fun/constants.hpp>
87
#include <stan/math/prim/fun/expm1.hpp>
98
#include <stan/math/prim/fun/log_diff_exp.hpp>
@@ -17,7 +16,7 @@ class log_diff_exp_vv_vari : public op_vv_vari {
1716
log_diff_exp_vv_vari(vari* avi, vari* bvi)
1817
: op_vv_vari(log_diff_exp(avi->val_, bvi->val_), avi, bvi) {}
1918
void chain() {
20-
avi_->adj_ += adj_ * calculate_chain(avi_->val_, val_);
19+
avi_->adj_ -= adj_ / expm1(bvi_->val_ - avi_->val_);
2120
bvi_->adj_ -= adj_ / expm1(avi_->val_ - bvi_->val_);
2221
}
2322
};
@@ -29,7 +28,7 @@ class log_diff_exp_vd_vari : public op_vd_vari {
2928
if (val_ == NEGATIVE_INFTY) {
3029
avi_->adj_ += (bd_ == NEGATIVE_INFTY) ? adj_ : adj_ * INFTY;
3130
} else {
32-
avi_->adj_ += adj_ * calculate_chain(avi_->val_, val_);
31+
avi_->adj_ -= adj_ / expm1(bd_ - avi_->val_);
3332
}
3433
}
3534
};

stan/math/rev/fun/log_sum_exp.hpp

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -3,11 +3,11 @@
33

44
#include <stan/math/rev/meta.hpp>
55
#include <stan/math/rev/core.hpp>
6-
#include <stan/math/rev/fun/calculate_chain.hpp>
76
#include <stan/math/rev/fun/typedefs.hpp>
87
#include <stan/math/prim/meta.hpp>
98
#include <stan/math/prim/fun/constants.hpp>
109
#include <stan/math/prim/fun/Eigen.hpp>
10+
#include <stan/math/prim/fun/inv_logit.hpp>
1111
#include <stan/math/prim/fun/log_sum_exp.hpp>
1212
#include <cmath>
1313
#include <vector>
@@ -22,8 +22,8 @@ class log_sum_exp_vv_vari : public op_vv_vari {
2222
log_sum_exp_vv_vari(vari* avi, vari* bvi)
2323
: op_vv_vari(log_sum_exp(avi->val_, bvi->val_), avi, bvi) {}
2424
void chain() {
25-
avi_->adj_ += adj_ * calculate_chain(avi_->val_, val_);
26-
bvi_->adj_ += adj_ * calculate_chain(bvi_->val_, val_);
25+
avi_->adj_ += adj_ * inv_logit(avi_->val_ - bvi_->val_);
26+
bvi_->adj_ += adj_ * inv_logit(bvi_->val_ - avi_->val_);
2727
}
2828
};
2929
class log_sum_exp_vd_vari : public op_vd_vari {
@@ -34,7 +34,7 @@ class log_sum_exp_vd_vari : public op_vd_vari {
3434
if (val_ == NEGATIVE_INFTY) {
3535
avi_->adj_ += adj_;
3636
} else {
37-
avi_->adj_ += adj_ * calculate_chain(avi_->val_, val_);
37+
avi_->adj_ += adj_ * inv_logit(avi_->val_ - bd_);
3838
}
3939
}
4040
};
Lines changed: 63 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,63 @@
1+
#include <stan/math/rev.hpp>
2+
#include <gtest/gtest.h>
3+
#include <test/unit/math/rev/fun/util.hpp>
4+
#include <test/unit/math/rev/util.hpp>
5+
6+
TEST(log_sum_exp_tests, large_values) {
7+
using stan::math::var;
8+
9+
// check autodiffing works with var types with large values
10+
var a = 1e50;
11+
var output = stan::math::log_sum_exp(a, a);
12+
output.grad();
13+
EXPECT_FLOAT_EQ(output.val(), log(2.0) + value_of(a));
14+
EXPECT_FLOAT_EQ(a.adj(), 1.0);
15+
16+
var a2 = 1;
17+
var a3 = 1e50;
18+
var output2 = stan::math::log_sum_exp(a2, a3);
19+
output2.grad();
20+
EXPECT_FLOAT_EQ(a2.adj(), 0.0);
21+
EXPECT_FLOAT_EQ(a3.adj(), 1.0);
22+
23+
var a4 = 1e50;
24+
var a5 = 1;
25+
var output3 = stan::math::log_sum_exp(a4, a5);
26+
output3.grad();
27+
EXPECT_FLOAT_EQ(a4.adj(), 1.0);
28+
EXPECT_FLOAT_EQ(a5.adj(), 0.0);
29+
30+
// check autodiffing works with var types with large values
31+
var b = 1e20;
32+
var output6 = stan::math::log_sum_exp(b, b);
33+
output6.grad();
34+
EXPECT_FLOAT_EQ(output6.val(), log(2.0) + value_of(b));
35+
EXPECT_FLOAT_EQ(b.adj(), 1.0);
36+
37+
var b2 = -2;
38+
var b3 = 1e20;
39+
var output7 = stan::math::log_sum_exp(b2, b3);
40+
output7.grad();
41+
EXPECT_FLOAT_EQ(b2.adj(), 0.0);
42+
EXPECT_FLOAT_EQ(b3.adj(), 1.0);
43+
44+
var b4 = 1e20;
45+
var b5 = -2;
46+
var output8 = stan::math::log_sum_exp(b4, b5);
47+
output8.grad();
48+
EXPECT_FLOAT_EQ(b4.adj(), 1.0);
49+
EXPECT_FLOAT_EQ(b5.adj(), 0.0);
50+
51+
// check arguement combinations of vars and doubles
52+
var a6 = 1e50;
53+
double a7 = 1;
54+
var output4 = stan::math::log_sum_exp(a6, a7);
55+
output4.grad();
56+
EXPECT_FLOAT_EQ(a6.adj(), 1.0);
57+
58+
var a8 = 1;
59+
double a9 = 1e50;
60+
var output5 = stan::math::log_sum_exp(a8, a9);
61+
output5.grad();
62+
EXPECT_FLOAT_EQ(a8.adj(), 0.0);
63+
}

0 commit comments

Comments
 (0)