Closed
Description
Something like the following NumPy implementation:
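import numpy as np
import numpy.typing as npt

# Assumed alias: the original snippet uses NumpyRealArray from the author's own codebase.
NumpyRealArray = npt.NDArray[np.floating]

def create_diagonal(m: NumpyRealArray) -> NumpyRealArray:
    """A vectorized (batched) version of np.diag for building diagonal matrices.

    Args:
        m: Has shape (*k, n).

    Returns:
        Array with shape (*k, n, n) with the elements of m on the diagonals.
    """
    # Select the main diagonal of every trailing (n, n) matrix in the batch.
    indices = (..., *np.diag_indices(m.shape[-1]))
    retval = np.zeros((*m.shape, m.shape[-1]), dtype=m.dtype)
    retval[indices] = m
    return retval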
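The key trick in the snippet is indexing with `(..., rows, cols)`. A small NumPy demonstration of what that index does (the shapes and values below are illustrative only):

```python
import numpy as np

# np.diag_indices(n) returns a pair of index arrays (rows, cols) for the main diagonal.
rows, cols = np.diag_indices(3)
print(rows, cols)  # [0 1 2] [0 1 2]

# Prefixing with Ellipsis applies the same (row, col) pairs to every
# trailing (n, n) matrix, so retval[..., rows, cols] has shape (*k, n).
retval = np.zeros((2, 3, 3))
print(retval[..., rows, cols].shape)  # (2, 3)

# Assigning a (*k, n) array writes each row of m onto one matrix's diagonal.
m = np.arange(6.0).reshape(2, 3)
retval[..., rows, cols] = m
print(retval[1])
```

Because the two advanced indices are adjacent and trail the Ellipsis, the batch dimensions are left untouched and only the diagonal positions are written.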
I noticed that the array API seems to have no way to do this, in either batched or unbatched mode. Is that correct?
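For comparison, here is a minimal sketch of one possible workaround that uses only functions defined in the array API standard (`eye`, `expand_dims`, and broadcasting multiplication). The helper name `create_diagonal_portable` and the `xp` namespace parameter are hypothetical; NumPy stands in as the namespace here. One caveat: off-diagonal entries are computed as `0 * m[..., i]`, so an infinity or NaN in `m` turns into NaN off the diagonal, unlike the assignment-based version.

```python
import numpy as np


def create_diagonal_portable(m, xp=np):
    """Hypothetical batched diagonal-matrix builder using array API operations.

    `xp` is assumed to be an array-API-compatible namespace.
    Off-diagonal entries are 0 * m, so inf/nan in m produce nan there.
    """
    n = m.shape[-1]
    # Identity of shape (n, n) with the same dtype as the input.
    eye = xp.eye(n, dtype=m.dtype)
    # (*k, n) -> (*k, n, 1); broadcasting against (n, n) gives (*k, n, n)
    # where entry [..., i, j] equals m[..., i] * eye[i, j].
    return xp.expand_dims(m, axis=-1) * eye


batch = np.array([[1.0, 2.0], [3.0, 4.0]])
print(create_diagonal_portable(batch).shape)  # (2, 2, 2)
```

This trades one elementwise multiply over the full (*k, n, n) output for not needing in-place assignment, which some array libraries do not support.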