Skip to content

Commit cdb9d14

Browse files
authored
Merge pull request #1660 from bstatcomp/generalize_view_and_size
Generalize view and size functions
2 parents 1decede + c98fa84 commit cdb9d14

16 files changed

+155
-118
lines changed

stan/math/prim/err.hpp

+1
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,7 @@
4343
#include <stan/math/prim/err/check_symmetric.hpp>
4444
#include <stan/math/prim/err/check_unit_vector.hpp>
4545
#include <stan/math/prim/err/check_vector.hpp>
46+
#include <stan/math/prim/err/check_vector_index.hpp>
4647
#include <stan/math/prim/err/constraint_tolerance.hpp>
4748
#include <stan/math/prim/err/domain_error.hpp>
4849
#include <stan/math/prim/err/domain_error_vec.hpp>

stan/math/prim/err/check_column_index.hpp

+3-5
Original file line numberDiff line numberDiff line change
@@ -17,18 +17,16 @@ namespace math {
1717
* <code>stan::error_index::value</code>. This function will
1818
* throw an <code>std::out_of_range</code> exception if
1919
* the index is out of bounds.
20-
* @tparam T_y Type of scalar
21-
* @tparam R number of rows or Eigen::Dynamic
22-
* @tparam C number of columns or Eigen::Dynamic
20+
* @tparam T_y Type of matrix
2321
* @param function Function name (for error messages)
2422
* @param name Variable name (for error messages)
2523
* @param y matrix to test
2624
* @param i column index to check
2725
* @throw <code>std::out_of_range</code> if index is an invalid column
2826
*/
29-
template <typename T_y, int R, int C>
27+
template <typename T_y, typename = require_eigen_t<T_y>>
3028
inline void check_column_index(const char* function, const char* name,
31-
const Eigen::Matrix<T_y, R, C>& y, size_t i) {
29+
const T_y& y, size_t i) {
3230
if (i >= stan::error_index::value
3331
&& i < static_cast<size_t>(y.cols()) + stan::error_index::value) {
3432
return;

stan/math/prim/err/check_row_index.hpp

+3-5
Original file line numberDiff line numberDiff line change
@@ -14,18 +14,16 @@ namespace math {
1414
* Check if the specified index is a valid row of the matrix
1515
* This check is 1-indexed by default. This behavior can be changed
1616
* by setting <code>stan::error_index::value</code>.
17-
* @tparam T Scalar type
18-
* @tparam R number of rows or Eigen::Dynamic
19-
* @tparam C number of columns or Eigen::Dynamic
17+
* @tparam T Matrix type
2018
* @param function Function name (for error messages)
2119
* @param name Variable name (for error messages)
2220
* @param y matrix to test
2321
* @param i row index to check
2422
* @throw <code>std::out_of_range</code> if the index is out of range.
2523
*/
26-
template <typename T_y, int R, int C>
24+
template <typename T_y, typename = require_eigen_t<T_y>>
2725
inline void check_row_index(const char* function, const char* name,
28-
const Eigen::Matrix<T_y, R, C>& y, size_t i) {
26+
const T_y& y, size_t i) {
2927
if (i >= stan::error_index::value
3028
&& i < static_cast<size_t>(y.rows()) + stan::error_index::value) {
3129
return;
+40
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,40 @@
1+
#ifndef STAN_MATH_PRIM_ERR_CHECK_VECTOR_INDEX_HPP
2+
#define STAN_MATH_PRIM_ERR_CHECK_VECTOR_INDEX_HPP
3+
4+
#include <stan/math/prim/meta.hpp>
5+
#include <stan/math/prim/err/out_of_range.hpp>
6+
#include <stan/math/prim/fun/Eigen.hpp>
7+
#include <sstream>
8+
#include <string>
9+
10+
namespace stan {
11+
namespace math {
12+
13+
/**
14+
* Check if the specified index is a valid element of the row or column vector
15+
* This check is 1-indexed by default. This behavior can be changed
16+
* by setting <code>stan::error_index::value</code>.
17+
* @tparam T Vector type
18+
* @param function Function name (for error messages)
19+
* @param name Variable name (for error messages)
20+
* @param y vector to test
21+
* @param i row index to check
22+
* @throw <code>std::out_of_range</code> if the index is out of range.
23+
*/
24+
template <typename T, typename = require_eigen_vector_t<T>>
25+
inline void check_vector_index(const char* function, const char* name,
26+
const T& y, size_t i) {
27+
if (i >= stan::error_index::value
28+
&& i < static_cast<size_t>(y.size()) + stan::error_index::value) {
29+
return;
30+
}
31+
32+
std::stringstream msg;
33+
msg << " for size of " << name;
34+
std::string msg_str(msg.str());
35+
out_of_range(function, y.rows(), i, msg_str.c_str());
36+
}
37+
38+
} // namespace math
39+
} // namespace stan
40+
#endif

stan/math/prim/fun/col.hpp

+4-5
Original file line numberDiff line numberDiff line change
@@ -14,17 +14,16 @@ namespace math {
1414
* This is equivalent to calling <code>m.col(i - 1)</code> and
1515
* assigning the resulting template expression to a column vector.
1616
*
17-
* @tparam T type of elements in the matrix
17+
* @tparam T type of the matrix
1818
* @param m Matrix.
1919
* @param j Column index (count from 1).
2020
* @return Specified column of the matrix.
2121
* @throw std::out_of_range if j is out of range.
2222
*/
23-
template <typename T>
24-
inline Eigen::Matrix<T, Eigen::Dynamic, 1> col(
25-
const Eigen::Matrix<T, Eigen::Dynamic, Eigen::Dynamic>& m, size_t j) {
23+
template <typename T, typename = require_eigen_t<T>>
24+
inline auto col(const T& m, size_t j) {
2625
check_column_index("col", "j", m, j);
27-
return m.col(j - 1);
26+
return m.col(j - 1).eval();
2827
}
2928

3029
} // namespace math

stan/math/prim/fun/cols.hpp

+4-5
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
#define STAN_MATH_PRIM_FUN_COLS_HPP
33

44
#include <stan/math/prim/fun/Eigen.hpp>
5+
#include <stan/math/prim/meta.hpp>
56

67
namespace stan {
78
namespace math {
@@ -10,14 +11,12 @@ namespace math {
1011
* Return the number of columns in the specified
1112
* matrix, vector, or row vector.
1213
*
13-
* @tparam T type of elements in the matrix
14-
* @tparam R number of rows, can be Eigen::Dynamic
15-
* @tparam C number of columns, can be Eigen::Dynamic
14+
* @tparam T type of the matrix
1615
* @param[in] m Input matrix, vector, or row vector.
1716
* @return Number of columns.
1817
*/
19-
template <typename T, int R, int C>
20-
inline int cols(const Eigen::Matrix<T, R, C>& m) {
18+
template <typename T, typename = require_eigen_t<T>>
19+
inline int cols(const T& m) {
2120
return m.cols();
2221
}
2322

stan/math/prim/fun/diagonal.hpp

+5-5
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
#define STAN_MATH_PRIM_FUN_DIAGONAL_HPP
33

44
#include <stan/math/prim/fun/Eigen.hpp>
5+
#include <stan/math/prim/meta.hpp>
56

67
namespace stan {
78
namespace math {
@@ -10,14 +11,13 @@ namespace math {
1011
* Return a column vector of the diagonal elements of the
1112
* specified matrix. The matrix is not required to be square.
1213
*
13-
* @tparam T type of elements in the matrix
14+
* @tparam T type of the matrix
1415
* @param m Specified matrix.
1516
* @return Diagonal of the matrix.
1617
*/
17-
template <typename T>
18-
inline Eigen::Matrix<T, Eigen::Dynamic, 1> diagonal(
19-
const Eigen::Matrix<T, Eigen::Dynamic, Eigen::Dynamic>& m) {
20-
return m.diagonal();
18+
template <typename T, typename = require_eigen_t<T>>
19+
inline auto diagonal(const T& m) {
20+
return m.diagonal().eval();
2121
}
2222

2323
} // namespace math

stan/math/prim/fun/dims.hpp

+4-3
Original file line numberDiff line numberDiff line change
@@ -2,17 +2,18 @@
22
#define STAN_MATH_PRIM_FUN_DIMS_HPP
33

44
#include <stan/math/prim/fun/Eigen.hpp>
5+
#include <stan/math/prim/meta.hpp>
56
#include <vector>
67

78
namespace stan {
89
namespace math {
910

10-
template <typename T>
11+
template <typename T, typename = require_stan_scalar_t<T>>
1112
inline void dims(const T& x, std::vector<int>& result) {
1213
/* no op */
1314
}
14-
template <typename T, int R, int C>
15-
inline void dims(const Eigen::Matrix<T, R, C>& x, std::vector<int>& result) {
15+
template <typename T, typename = require_eigen_t<T>, typename = void>
16+
inline void dims(const T& x, std::vector<int>& result) {
1617
result.push_back(x.rows());
1718
result.push_back(x.cols());
1819
}

stan/math/prim/fun/head.hpp

+8-31
Original file line numberDiff line numberDiff line change
@@ -9,41 +9,21 @@ namespace stan {
99
namespace math {
1010

1111
/**
12-
* Return the specified number of elements as a vector
13-
* from the front of the specified vector.
12+
* Return the specified number of elements as a vector or row vector (same as
13+
* input) from the front of the specified vector or row vector.
1414
*
15-
* @tparam T type of elements in the vector
15+
* @tparam T type of the vector
1616
* @param v Vector input.
1717
* @param n Size of return.
1818
* @return The first n elements of v.
1919
* @throw std::out_of_range if n is out of range.
2020
*/
21-
template <typename T>
22-
inline Eigen::Matrix<T, Eigen::Dynamic, 1> head(
23-
const Eigen::Matrix<T, Eigen::Dynamic, 1>& v, size_t n) {
21+
template <typename T, typename = require_eigen_vector_t<T>>
22+
inline auto head(const T& v, size_t n) {
2423
if (n != 0) {
25-
check_row_index("head", "n", v, n);
24+
check_vector_index("head", "n", v, n);
2625
}
27-
return v.head(n);
28-
}
29-
30-
/**
31-
* Return the specified number of elements as a row vector
32-
* from the front of the specified row vector.
33-
*
34-
* @tparam T type of elements in the vector
35-
* @param rv Row vector.
36-
* @param n Size of return row vector.
37-
* @return The first n elements of rv.
38-
* @throw std::out_of_range if n is out of range.
39-
*/
40-
template <typename T>
41-
inline Eigen::Matrix<T, 1, Eigen::Dynamic> head(
42-
const Eigen::Matrix<T, 1, Eigen::Dynamic>& rv, size_t n) {
43-
if (n != 0) {
44-
check_column_index("head", "n", rv, n);
45-
}
46-
return rv.head(n);
26+
return v.head(n).eval();
4727
}
4828

4929
/**
@@ -62,10 +42,7 @@ std::vector<T> head(const std::vector<T>& sv, size_t n) {
6242
check_std_vector_index("head", "n", sv, n);
6343
}
6444

65-
std::vector<T> s;
66-
for (size_t i = 0; i < n; ++i) {
67-
s.push_back(sv[i]);
68-
}
45+
std::vector<T> s(sv.begin(), sv.begin() + n);
6946
return s;
7047
}
7148

stan/math/prim/fun/num_elements.hpp

+5-6
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
#define STAN_MATH_PRIM_FUN_NUM_ELEMENTS_HPP
33

44
#include <stan/math/prim/fun/Eigen.hpp>
5+
#include <stan/math/prim/meta.hpp>
56
#include <vector>
67

78
namespace stan {
@@ -14,23 +15,21 @@ namespace math {
1415
* @param x Argument of primitive type.
1516
* @return 1
1617
*/
17-
template <typename T>
18+
template <typename T, typename = require_stan_scalar_t<T>>
1819
inline int num_elements(const T& x) {
1920
return 1;
2021
}
2122

2223
/**
2324
* Returns the size of the specified matrix.
2425
*
25-
* @tparam T type of elements in the matrix
26-
* @tparam R number of rows, can be Eigen::Dynamic
27-
* @tparam C number of columns, can be Eigen::Dynamic
26+
* @tparam T type of the matrix
2827
*
2928
* @param m argument matrix
3029
* @return size of matrix
3130
*/
32-
template <typename T, int R, int C>
33-
inline int num_elements(const Eigen::Matrix<T, R, C>& m) {
31+
template <typename T, typename = require_eigen_t<T>, typename = void>
32+
inline int num_elements(const T& m) {
3433
return m.size();
3534
}
3635

stan/math/prim/fun/row.hpp

+4-5
Original file line numberDiff line numberDiff line change
@@ -14,18 +14,17 @@ namespace math {
1414
* This is equivalent to calling <code>m.row(i - 1)</code> and
1515
* assigning the resulting template expression to a row vector.
1616
*
17-
* @tparam T type of elements in the matrix
17+
* @tparam T type of the matrix
1818
* @param m Matrix.
1919
* @param i Row index (count from 1).
2020
* @return Specified row of the matrix.
2121
* @throw std::out_of_range if i is out of range.
2222
*/
23-
template <typename T>
24-
inline Eigen::Matrix<T, 1, Eigen::Dynamic> row(
25-
const Eigen::Matrix<T, Eigen::Dynamic, Eigen::Dynamic>& m, size_t i) {
23+
template <typename T, typename = require_eigen_t<T>>
24+
inline auto row(const T& m, size_t i) {
2625
check_row_index("row", "i", m, i);
2726

28-
return m.row(i - 1);
27+
return m.row(i - 1).eval();
2928
}
3029

3130
} // namespace math

stan/math/prim/fun/rows.hpp

+4-6
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
#define STAN_MATH_PRIM_FUN_ROWS_HPP
33

44
#include <stan/math/prim/fun/Eigen.hpp>
5+
#include <stan/math/prim/meta.hpp>
56

67
namespace stan {
78
namespace math {
@@ -10,15 +11,12 @@ namespace math {
1011
* Return the number of rows in the specified
1112
* matrix, vector, or row vector.
1213
*
13-
* @tparam T type of elements in the matrix
14-
* @tparam R number of rows, can be Eigen::Dynamic
15-
* @tparam C number of columns, can be Eigen::Dynamic
16-
*
14+
* @tparam T type of the matrix
1715
* @param[in] m Input matrix, vector, or row vector.
1816
* @return Number of rows.
1917
*/
20-
template <typename T, int R, int C>
21-
inline int rows(const Eigen::Matrix<T, R, C>& m) {
18+
template <typename T, typename = require_eigen_t<T>>
19+
inline int rows(const T& m) {
2220
return m.rows();
2321
}
2422

stan/math/prim/fun/sub_col.hpp

+4-6
Original file line numberDiff line numberDiff line change
@@ -10,23 +10,21 @@ namespace math {
1010
/**
1111
* Return a nrows x 1 subcolumn starting at (i-1, j-1).
1212
*
13-
* @tparam T type of elements in the matrix
13+
* @tparam T type of the matrix
1414
* @param m Matrix.
1515
* @param i Starting row + 1.
1616
* @param j Starting column + 1.
1717
* @param nrows Number of rows in block.
1818
* @throw std::out_of_range if either index is out of range.
1919
*/
20-
template <typename T>
21-
inline Eigen::Matrix<T, Eigen::Dynamic, 1> sub_col(
22-
const Eigen::Matrix<T, Eigen::Dynamic, Eigen::Dynamic>& m, size_t i,
23-
size_t j, size_t nrows) {
20+
template <typename T, typename = require_eigen_t<T>>
21+
inline auto sub_col(const T& m, size_t i, size_t j, size_t nrows) {
2422
check_row_index("sub_col", "i", m, i);
2523
if (nrows > 0) {
2624
check_row_index("sub_col", "i+nrows-1", m, i + nrows - 1);
2725
}
2826
check_column_index("sub_col", "j", m, j);
29-
return m.block(i - 1, j - 1, nrows, 1);
27+
return m.col(j - 1).segment(i - 1, nrows).eval();
3028
}
3129

3230
} // namespace math

stan/math/prim/fun/sub_row.hpp

+4-6
Original file line numberDiff line numberDiff line change
@@ -10,23 +10,21 @@ namespace math {
1010
/**
1111
* Return a 1 x nrows subrow starting at (i-1, j-1).
1212
*
13-
* @tparam T type of elements in the matrix
13+
* @tparam T type of the matrix
1414
* @param m Matrix Input matrix.
1515
* @param i Starting row + 1.
1616
* @param j Starting column + 1.
1717
* @param ncols Number of columns in block.
1818
* @throw std::out_of_range if either index is out of range.
1919
*/
20-
template <typename T>
21-
inline Eigen::Matrix<T, 1, Eigen::Dynamic> sub_row(
22-
const Eigen::Matrix<T, Eigen::Dynamic, Eigen::Dynamic>& m, size_t i,
23-
size_t j, size_t ncols) {
20+
template <typename T, typename = require_eigen_t<T>>
21+
inline auto sub_row(const T& m, size_t i, size_t j, size_t ncols) {
2422
check_row_index("sub_row", "i", m, i);
2523
check_column_index("sub_row", "j", m, j);
2624
if (ncols > 0) {
2725
check_column_index("sub_col", "j+ncols-1", m, j + ncols - 1);
2826
}
29-
return m.block(i - 1, j - 1, 1, ncols);
27+
return m.row(i - 1).segment(j - 1, ncols).eval();
3028
}
3129

3230
} // namespace math

0 commit comments

Comments
 (0)