Skip to content

Commit f11c9e0

Browse files
authored
Merge pull request #1715 from stan-dev/cleanup/1714-mdivide-left-right
Cleanup mdivide_* and tests
2 parents 9b2e93b + ed296c8 commit f11c9e0

27 files changed

+254
-216
lines changed

stan/math/fwd/fun/mdivide_left.hpp

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,6 @@
55
#include <stan/math/prim/fun/Eigen.hpp>
66
#include <stan/math/prim/fun/mdivide_left.hpp>
77
#include <stan/math/prim/fun/multiply.hpp>
8-
#include <stan/math/prim/fun/typedefs.hpp>
98
#include <stan/math/fwd/core.hpp>
109
#include <stan/math/fwd/fun/multiply.hpp>
1110
#include <stan/math/fwd/fun/to_fvar.hpp>

stan/math/fwd/fun/mdivide_left_tri_low.hpp

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@ namespace stan {
1313
namespace math {
1414

1515
template <typename T, int R1, int C1, int R2, int C2>
16-
inline Eigen::Matrix<fvar<T>, R1, C1> mdivide_left_tri_low(
16+
inline Eigen::Matrix<fvar<T>, R1, C2> mdivide_left_tri_low(
1717
const Eigen::Matrix<fvar<T>, R1, C1>& A,
1818
const Eigen::Matrix<fvar<T>, R2, C2>& b) {
1919
check_square("mdivide_left_tri_low", "A", A);
@@ -54,7 +54,7 @@ inline Eigen::Matrix<fvar<T>, R1, C1> mdivide_left_tri_low(
5454
}
5555

5656
template <typename T, int R1, int C1, int R2, int C2>
57-
inline Eigen::Matrix<fvar<T>, R1, C1> mdivide_left_tri_low(
57+
inline Eigen::Matrix<fvar<T>, R1, C2> mdivide_left_tri_low(
5858
const Eigen::Matrix<double, R1, C1>& A,
5959
const Eigen::Matrix<fvar<T>, R2, C2>& b) {
6060
check_square("mdivide_left_tri_low", "A", A);
@@ -90,7 +90,7 @@ inline Eigen::Matrix<fvar<T>, R1, C1> mdivide_left_tri_low(
9090
}
9191

9292
template <typename T, int R1, int C1, int R2, int C2>
93-
inline Eigen::Matrix<fvar<T>, R1, C1> mdivide_left_tri_low(
93+
inline Eigen::Matrix<fvar<T>, R1, C2> mdivide_left_tri_low(
9494
const Eigen::Matrix<fvar<T>, R1, C1>& A,
9595
const Eigen::Matrix<double, R2, C2>& b) {
9696
check_square("mdivide_left_tri_low", "A", A);

stan/math/fwd/fun/mdivide_right.hpp

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,6 @@
88
#include <stan/math/fwd/core.hpp>
99
#include <stan/math/fwd/fun/multiply.hpp>
1010
#include <stan/math/fwd/fun/to_fvar.hpp>
11-
#include <stan/math/fwd/fun/typedefs.hpp>
1211
#include <vector>
1312

1413
namespace stan {

stan/math/opencl/prim/mdivide_left_tri_low.hpp

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,11 +6,13 @@
66
#include <stan/math/opencl/matrix_cl.hpp>
77
#include <stan/math/opencl/multiply.hpp>
88
#include <stan/math/opencl/tri_inverse.hpp>
9+
910
namespace stan {
1011
namespace math {
1112

1213
/**
1314
* Returns the solution of the system Ax=b when A is lower triangular.
15+
*
1416
* @tparam T1 type of elements in A
1517
* @tparam T2 type of elements in b
1618
* @param A Triangular matrix.
@@ -30,6 +32,7 @@ inline matrix_cl<return_type_t<T1, T2>> mdivide_left_tri_low(
3032

3133
/**
3234
* Returns the solution of the system Ax=b when A is triangular and b=I.
35+
*
3336
* @tparam T type of elements in A
3437
* @tparam R1 number of rows in A
3538
* @tparam C1 number of columns in A

stan/math/opencl/prim/mdivide_right_tri_low.hpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,12 +6,14 @@
66
#include <stan/math/opencl/matrix_cl.hpp>
77
#include <stan/math/opencl/multiply.hpp>
88
#include <stan/math/opencl/tri_inverse.hpp>
9+
910
namespace stan {
1011
namespace math {
1112

1213
/**
1314
* Returns the solution of the system Ax=b where A is a
1415
* lower triangular matrix.
16+
*
1517
* @param A Matrix.
1618
* @param b Right hand side matrix or vector.
1719
* @return x = b * tri(A)^-1, solution of the linear system.

stan/math/prim/fun/mdivide_left.hpp

Lines changed: 4 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,9 @@
11
#ifndef STAN_MATH_PRIM_FUN_MDIVIDE_LEFT_HPP
22
#define STAN_MATH_PRIM_FUN_MDIVIDE_LEFT_HPP
33

4+
#include <stan/math/prim/meta.hpp>
45
#include <stan/math/prim/err.hpp>
56
#include <stan/math/prim/fun/Eigen.hpp>
6-
#include <stan/math/prim/fun/promote_common.hpp>
77

88
namespace stan {
99
namespace math {
@@ -31,12 +31,9 @@ inline Eigen::Matrix<return_type_t<T1, T2>, R1, C2> mdivide_left(
3131
const Eigen::Matrix<T1, R1, C1> &A, const Eigen::Matrix<T2, R2, C2> &b) {
3232
check_square("mdivide_left", "A", A);
3333
check_multiplicable("mdivide_left", "A", A, "b", b);
34-
return promote_common<Eigen::Matrix<T1, R1, C1>, Eigen::Matrix<T2, R1, C1> >(
35-
A)
36-
.lu()
37-
.solve(
38-
promote_common<Eigen::Matrix<T1, R2, C2>, Eigen::Matrix<T2, R2, C2> >(
39-
b));
34+
35+
return Eigen::Matrix<return_type_t<T1, T2>, R1, C1>(A).lu().solve(
36+
Eigen::Matrix<return_type_t<T1, T2>, R2, C2>(b));
4037
}
4138

4239
} // namespace math

stan/math/prim/fun/mdivide_left_ldlt.hpp

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,10 @@
11
#ifndef STAN_MATH_PRIM_FUN_MDIVIDE_LEFT_LDLT_HPP
22
#define STAN_MATH_PRIM_FUN_MDIVIDE_LEFT_LDLT_HPP
33

4+
#include <stan/math/prim/meta.hpp>
45
#include <stan/math/prim/err.hpp>
56
#include <stan/math/prim/fun/Eigen.hpp>
67
#include <stan/math/prim/fun/LDLT_factor.hpp>
7-
#include <stan/math/prim/fun/promote_common.hpp>
88
#include <type_traits>
99

1010
namespace stan {
@@ -36,8 +36,7 @@ inline Eigen::Matrix<return_type_t<T1, T2>, R1, C2> mdivide_left_ldlt(
3636

3737
check_multiplicable("mdivide_left_ldlt", "A", A, "b", b);
3838

39-
return A.solve(
40-
promote_common<Eigen::Matrix<T1, R2, C2>, Eigen::Matrix<T2, R2, C2> >(b));
39+
return A.solve(Eigen::Matrix<return_type_t<T1, T2>, R2, C2>(b));
4140
}
4241

4342
} // namespace math

stan/math/prim/fun/mdivide_left_spd.hpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,9 @@
11
#ifndef STAN_MATH_PRIM_FUN_MDIVIDE_LEFT_SPD_HPP
22
#define STAN_MATH_PRIM_FUN_MDIVIDE_LEFT_SPD_HPP
33

4+
#include <stan/math/prim/meta.hpp>
45
#include <stan/math/prim/err.hpp>
56
#include <stan/math/prim/fun/Eigen.hpp>
6-
#include <stan/math/prim/fun/promote_common.hpp>
77

88
namespace stan {
99
namespace math {

stan/math/prim/fun/mdivide_left_tri.hpp

Lines changed: 4 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,9 @@
11
#ifndef STAN_MATH_PRIM_FUN_MDIVIDE_LEFT_TRI_HPP
22
#define STAN_MATH_PRIM_FUN_MDIVIDE_LEFT_TRI_HPP
33

4+
#include <stan/math/prim/meta.hpp>
45
#include <stan/math/prim/err.hpp>
56
#include <stan/math/prim/fun/Eigen.hpp>
6-
#include <stan/math/prim/fun/promote_common.hpp>
77
#ifdef STAN_OPENCL
88
#include <stan/math/opencl/opencl.hpp>
99
#endif
@@ -36,12 +36,10 @@ inline Eigen::Matrix<return_type_t<T1, T2>, R1, C2> mdivide_left_tri(
3636
const Eigen::Matrix<T1, R1, C1> &A, const Eigen::Matrix<T2, R2, C2> &b) {
3737
check_square("mdivide_left_tri", "A", A);
3838
check_multiplicable("mdivide_left_tri", "A", A, "b", b);
39-
return promote_common<Eigen::Matrix<T1, R1, C1>, Eigen::Matrix<T2, R1, C1> >(
40-
A)
39+
40+
return Eigen::Matrix<return_type_t<T1, T2>, R1, C1>(A)
4141
.template triangularView<TriView>()
42-
.solve(
43-
promote_common<Eigen::Matrix<T1, R2, C2>, Eigen::Matrix<T2, R2, C2> >(
44-
b));
42+
.solve(Eigen::Matrix<return_type_t<T1, T2>, R2, C2>(b));
4543
}
4644

4745
/**

stan/math/prim/fun/mdivide_right.hpp

Lines changed: 4 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,9 @@
11
#ifndef STAN_MATH_PRIM_FUN_MDIVIDE_RIGHT_HPP
22
#define STAN_MATH_PRIM_FUN_MDIVIDE_RIGHT_HPP
33

4+
#include <stan/math/prim/meta.hpp>
45
#include <stan/math/prim/err.hpp>
56
#include <stan/math/prim/fun/Eigen.hpp>
6-
#include <stan/math/prim/fun/promote_common.hpp>
77

88
namespace stan {
99
namespace math {
@@ -31,14 +31,11 @@ inline Eigen::Matrix<return_type_t<T1, T2>, R1, C2> mdivide_right(
3131
const Eigen::Matrix<T1, R1, C1> &b, const Eigen::Matrix<T2, R2, C2> &A) {
3232
check_square("mdivide_right", "A", A);
3333
check_multiplicable("mdivide_right", "b", b, "A", A);
34-
return promote_common<Eigen::Matrix<T1, R2, C2>, Eigen::Matrix<T2, R2, C2> >(
35-
A)
34+
35+
return Eigen::Matrix<return_type_t<T1, T2>, R2, C2>(A)
3636
.transpose()
3737
.lu()
38-
.solve(
39-
promote_common<Eigen::Matrix<T1, R1, C1>, Eigen::Matrix<T2, R1, C1> >(
40-
b)
41-
.transpose())
38+
.solve(Eigen::Matrix<return_type_t<T1, T2>, R1, C1>(b).transpose())
4239
.transpose();
4340
}
4441

stan/math/prim/fun/mdivide_right_ldlt.hpp

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,6 @@ namespace math {
2727
* @return x = b A^-1, solution of the linear system.
2828
* @throws std::domain_error if rows of b don't match the size of A.
2929
*/
30-
3130
template <typename T1, typename T2, int R1, int C1, int R2, int C2>
3231
inline Eigen::Matrix<return_type_t<T1, T2>, R1, C2> mdivide_right_ldlt(
3332
const Eigen::Matrix<T1, R1, C1> &b, const LDLT_factor<T2, R2, C2> &A) {

stan/math/prim/fun/mdivide_right_tri.hpp

Lines changed: 4 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,9 @@
11
#ifndef STAN_MATH_PRIM_FUN_MDIVIDE_RIGHT_TRI_HPP
22
#define STAN_MATH_PRIM_FUN_MDIVIDE_RIGHT_TRI_HPP
33

4+
#include <stan/math/prim/meta.hpp>
45
#include <stan/math/prim/err.hpp>
56
#include <stan/math/prim/fun/Eigen.hpp>
6-
#include <stan/math/prim/fun/promote_common.hpp>
77
#ifdef STAN_OPENCL
88
#include <stan/math/opencl/opencl.hpp>
99
#endif
@@ -43,14 +43,11 @@ inline Eigen::Matrix<return_type_t<T1, T2>, R1, C2> mdivide_right_tri(
4343
"triangular view must be Eigen::Lower or Eigen::Upper",
4444
"", "");
4545
}
46-
return promote_common<Eigen::Matrix<T1, R2, C2>, Eigen::Matrix<T2, R2, C2> >(
47-
A)
46+
47+
return Eigen::Matrix<return_type_t<T1, T2>, R2, C2>(A)
4848
.template triangularView<TriView>()
4949
.transpose()
50-
.solve(
51-
promote_common<Eigen::Matrix<T1, R1, C1>, Eigen::Matrix<T2, R1, C1> >(
52-
b)
53-
.transpose())
50+
.solve(Eigen::Matrix<return_type_t<T1, T2>, R1, C1>(b).transpose())
5451
.transpose();
5552
}
5653

stan/math/prim/fun/mdivide_right_tri_low.hpp

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,9 @@
11
#ifndef STAN_MATH_PRIM_FUN_MDIVIDE_RIGHT_TRI_LOW_HPP
22
#define STAN_MATH_PRIM_FUN_MDIVIDE_RIGHT_TRI_LOW_HPP
33

4+
#include <stan/math/prim/meta.hpp>
45
#include <stan/math/prim/fun/Eigen.hpp>
56
#include <stan/math/prim/fun/mdivide_right_tri.hpp>
6-
#include <stan/math/prim/fun/promote_common.hpp>
77

88
namespace stan {
99
namespace math {
@@ -31,8 +31,8 @@ template <typename T1, typename T2, int R1, int C1, int R2, int C2>
3131
inline Eigen::Matrix<return_type_t<T1, T2>, R1, C2> mdivide_right_tri_low(
3232
const Eigen::Matrix<T1, R1, C1> &b, const Eigen::Matrix<T2, R2, C2> &A) {
3333
return mdivide_right_tri<Eigen::Lower>(
34-
promote_common<Eigen::Matrix<T1, R1, C1>, Eigen::Matrix<T2, R1, C1> >(b),
35-
promote_common<Eigen::Matrix<T1, R2, C2>, Eigen::Matrix<T2, R2, C2> >(A));
34+
Eigen::Matrix<return_type_t<T1, T2>, R1, C1>(b),
35+
Eigen::Matrix<return_type_t<T1, T2>, R2, C2>(A));
3636
}
3737

3838
} // namespace math

test/unit/math/mix/fun/mdivide_left_spd_test.cpp

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -8,23 +8,26 @@ TEST(MathMixMatFun, mdivideLeftSpd) {
88
return stan::math::mdivide_left_spd(x_sym, y);
99
};
1010

11-
// signature 1 of 2: matrix-matrix
11+
// size zero inputs
12+
Eigen::MatrixXd m00(0, 0);
13+
Eigen::MatrixXd m02(0, 2);
14+
Eigen::VectorXd v0(0);
15+
stan::test::expect_ad(f, m00, m00);
16+
stan::test::expect_ad(f, m00, m02);
17+
stan::test::expect_ad(f, m00, v0);
18+
1219
Eigen::MatrixXd aa(1, 1);
1320
aa << 1;
1421
Eigen::MatrixXd bb(1, 1);
1522
bb << 2;
1623
stan::test::expect_ad(f, aa, bb);
24+
Eigen::MatrixXd b0(1, 0);
25+
stan::test::expect_ad(f, aa, b0);
1726

18-
// signature 2 of 2: matrix-vector
1927
Eigen::VectorXd cc(1);
2028
cc << 3;
2129
stan::test::expect_ad(f, aa, cc);
2230

23-
Eigen::MatrixXd m00(0, 0);
24-
Eigen::VectorXd v0(0);
25-
stan::test::expect_ad(f, m00, v0);
26-
stan::test::expect_ad(f, m00, m00);
27-
2831
Eigen::MatrixXd a(2, 2);
2932
a << 2, 3, 3, 7;
3033

@@ -45,15 +48,12 @@ TEST(MathMixMatFun, mdivideLeftSpd) {
4548
// matrix, vector : ditto
4649
stan::test::expect_ad(f, a, d);
4750
stan::test::expect_ad(f, b, d);
48-
stan::test::expect_ad(f, a, d);
49-
stan::test::expect_ad(f, b, d);
5051

5152
Eigen::MatrixXd m33 = Eigen::MatrixXd::Zero(3, 3);
5253
Eigen::MatrixXd m44 = Eigen::MatrixXd::Zero(4, 4);
5354
Eigen::VectorXd v3 = Eigen::VectorXd::Zero(3);
5455
Eigen::VectorXd v4 = Eigen::VectorXd::Zero(4);
5556
Eigen::RowVectorXd rv3 = Eigen::RowVectorXd::Zero(3);
56-
Eigen::RowVectorXd rv4 = Eigen::RowVectorXd::Zero(4);
5757

5858
// exceptions: not symmetric
5959
stan::test::expect_ad(f, c, a);

test/unit/math/mix/fun/mdivide_left_test.cpp

Lines changed: 25 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -3,29 +3,29 @@
33

44
TEST(MathMixMatFun, mdivideLeft) {
55
auto f = [](const auto& x, const auto& y) {
6-
if (x.rows() != x.cols())
7-
return stan::math::mdivide_left(x, y);
8-
auto x_sym = ((x + x.transpose()) * 0.5).eval(); // sym for finite diffs
9-
return stan::math::mdivide_left(x_sym, y);
6+
return stan::math::mdivide_left(x, y);
107
};
118

12-
// signature 1 of 2: matrix-matrix
9+
// size zero inputs
10+
Eigen::MatrixXd m00(0, 0);
11+
Eigen::MatrixXd m02(0, 2);
12+
Eigen::VectorXd v0(0);
13+
stan::test::expect_ad(f, m00, m00);
14+
stan::test::expect_ad(f, m00, m02);
15+
stan::test::expect_ad(f, m00, v0);
16+
1317
Eigen::MatrixXd aa(1, 1);
1418
aa << 1;
1519
Eigen::MatrixXd bb(1, 1);
1620
bb << 2;
1721
stan::test::expect_ad(f, aa, bb);
22+
Eigen::MatrixXd b0(1, 0);
23+
stan::test::expect_ad(f, aa, b0);
1824

19-
// signature 2 of 2: matrix-vector
2025
Eigen::VectorXd cc(1);
2126
cc << 3;
2227
stan::test::expect_ad(f, aa, cc);
2328

24-
Eigen::MatrixXd m00(0, 0);
25-
Eigen::VectorXd v0(0);
26-
stan::test::expect_ad(f, m00, v0);
27-
stan::test::expect_ad(f, m00, m00);
28-
2929
Eigen::MatrixXd a(2, 2);
3030
a << 2, 3, 3, 7;
3131

@@ -38,20 +38,29 @@ TEST(MathMixMatFun, mdivideLeft) {
3838
Eigen::MatrixXd d(2, 2);
3939
d << 2, 3, 5, 7;
4040

41-
Eigen::MatrixXd e(2, 2);
42-
e << 1, 2, 3, 4;
41+
Eigen::MatrixXd e(2, 0);
4342

4443
Eigen::VectorXd g(2);
4544
g << 12, 13;
4645

4746
// matrix, matrix
48-
for (const auto& m1 : std::vector<Eigen::MatrixXd>{a, b, c, d, e})
49-
for (const auto& m2 : std::vector<Eigen::MatrixXd>{a, b, c, d, e})
47+
for (const auto& m1 : std::vector<Eigen::MatrixXd>{a, b, c, d}) {
48+
for (const auto& m2 : std::vector<Eigen::MatrixXd>{a, b, c, d, e}) {
5049
stan::test::expect_ad(f, m1, m2);
50+
}
51+
}
5152

5253
// matrix, vector
53-
for (const auto& m : std::vector<Eigen::MatrixXd>{a, b, c, d, e})
54+
for (const auto& m : std::vector<Eigen::MatrixXd>{a, b, c, d}) {
5455
stan::test::expect_ad(f, m, g);
56+
}
57+
58+
Eigen::MatrixXd v(5, 5);
59+
v << 20, 8, -9, 7, 5, 8, 20, 0, 4, 4, -9, 0, 20, 2, 5, 7, 4, 2, 20, -5, 5, 4,
60+
5, -5, 20;
61+
Eigen::VectorXd u(5);
62+
u << 62, 84, 84, 76, 108;
63+
stan::test::expect_ad(f, v, u);
5564

5665
Eigen::MatrixXd m33 = Eigen::MatrixXd::Zero(3, 3);
5766
Eigen::MatrixXd m44 = Eigen::MatrixXd::Zero(4, 4);
@@ -60,10 +69,6 @@ TEST(MathMixMatFun, mdivideLeft) {
6069
Eigen::RowVectorXd rv3 = Eigen::RowVectorXd::Zero(3);
6170
Eigen::RowVectorXd rv4 = Eigen::RowVectorXd::Zero(4);
6271

63-
// exceptions: not symmetric
64-
stan::test::expect_ad(f, c, a);
65-
stan::test::expect_ad(f, c, d);
66-
6772
// exceptions: wrong sizes
6873
stan::test::expect_ad(f, m33, m44);
6974
stan::test::expect_ad(f, m33, v4);

0 commit comments

Comments
 (0)