Skip to content

Commit fe4b1dd

Browse files
committed
Merge branch 'update/eigen-3.4' of github.com:stan-dev/math into update/eigen-3.4
2 parents 1c90adf + d9e1b11 commit fe4b1dd

3 files changed

Lines changed: 28 additions & 14 deletions

File tree

stan/math/fwd/fun/mdivide_right.hpp

Lines changed: 13 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -54,9 +54,9 @@ mdivide_right(const EigMat1& b, const EigMat2& A) {
5454
}
5555
}
5656
auto A_mult_inv_b = mdivide_right(val_b, val_A).eval();
57-
promote_scalar_t<fvar<inner_ret_t>, decltype(A_mult_inv_b)> ret(A_mult_inv_b.rows(), A_mult_inv_b.cols());
58-
ret.val() = A_mult_inv_b;
59-
ret.d() = mdivide_right(deriv_b, val_A)
57+
promote_scalar_t<fvar<inner_ret_t>, decltype(A_mult_inv_b)>
58+
ret(A_mult_inv_b.rows(), A_mult_inv_b.cols()); ret.val() = A_mult_inv_b; ret.d()
59+
= mdivide_right(deriv_b, val_A)
6060
- multiply(A_mult_inv_b, mdivide_right(deriv_A, val_A));
6161
return ret;
6262
}
@@ -70,10 +70,12 @@ template <typename EigMat1, typename EigMat2,
7070
check_square("mdivide_right", "A", A);
7171
check_multiplicable("mdivide_right", "b", b, "A", A);
7272
if (A.size() == 0) {
73-
using ret_type = decltype(A.transpose().template cast<T_return>().lu().solve(b.template cast<T_return>().transpose()).transpose().eval());
74-
return ret_type{b.rows(), 0};
73+
using ret_type = decltype(A.transpose().template
74+
cast<T_return>().lu().solve(b.template
75+
cast<T_return>().transpose()).transpose().eval()); return ret_type{b.rows(), 0};
7576
}
76-
return A.transpose().template cast<T_return>().lu().solve(b.template cast<T_return>().transpose()).transpose().eval();
77+
return A.transpose().template cast<T_return>().lu().solve(b.template
78+
cast<T_return>().transpose()).transpose().eval();
7779
}
7880
7981
template <typename EigMat1, typename EigMat2,
@@ -85,10 +87,12 @@ template <typename EigMat1, typename EigMat2,
8587
check_square("mdivide_right", "A", A);
8688
check_multiplicable("mdivide_right", "b", b, "A", A);
8789
if (A.size() == 0) {
88-
using ret_type = decltype(A.transpose().template cast<T_return>().lu().solve(b.template cast<T_return>().transpose()).transpose().eval());
89-
return ret_type{b.rows(), 0};
90+
using ret_type = decltype(A.transpose().template
91+
cast<T_return>().lu().solve(b.template
92+
cast<T_return>().transpose()).transpose().eval()); return ret_type{b.rows(), 0};
9093
}
91-
return A.transpose().template cast<T_return>().lu().solve(b.template cast<T_return>().transpose()).transpose().eval();
94+
return A.transpose().template cast<T_return>().lu().solve(b.template
95+
cast<T_return>().transpose()).transpose().eval();
9296
}
9397
*/
9498
} // namespace math

stan/math/prim/fun/mdivide_right.hpp

Lines changed: 14 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -22,16 +22,26 @@ namespace math {
2222
*/
2323
template <typename EigMat1, typename EigMat2,
2424
require_all_eigen_t<EigMat1, EigMat2>* = nullptr>
25-
inline auto
26-
mdivide_right(const EigMat1& b, const EigMat2& A) {
25+
inline auto mdivide_right(const EigMat1& b, const EigMat2& A) {
2726
using T_return = return_type_t<EigMat1, EigMat2>;
2827
check_square("mdivide_right", "A", A);
2928
check_multiplicable("mdivide_right", "b", b, "A", A);
3029
if (A.size() == 0) {
31-
using ret_type = decltype(A.transpose().template cast<T_return>().lu().solve(b.template cast<T_return>().transpose()).transpose().eval());
30+
using ret_type
31+
= decltype(A.transpose()
32+
.template cast<T_return>()
33+
.lu()
34+
.solve(b.template cast<T_return>().transpose())
35+
.transpose()
36+
.eval());
3237
return ret_type{b.rows(), 0};
3338
}
34-
return A.transpose().template cast<T_return>().lu().solve(b.template cast<T_return>().transpose()).transpose().eval();
39+
return A.transpose()
40+
.template cast<T_return>()
41+
.lu()
42+
.solve(b.template cast<T_return>().transpose())
43+
.transpose()
44+
.eval();
3545
}
3646

3747
} // namespace math

test/unit/math/mix/fun/mdivide_right_test.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -74,7 +74,7 @@ TEST(MathMixMatFun, mdivideRight_rowvector_matrix1) {
7474
Eigen::RowVectorXd g(2);
7575
g << 1, 1;
7676

77-
stan::test::expect_ad(f, g, b);
77+
stan::test::expect_ad(f, g, b);
7878
// vector, matrix
7979
/*
8080
for (const auto& m : std::vector<Eigen::MatrixXd>{b}) {

0 commit comments

Comments
 (0)