Description
Currently we have an Op that calls `np.tri`, but we can very easily build lower-triangular mask matrices with `_iota`:
```python
from pytensor.tensor.einsum import _iota

def tri(N, M, k=0):
    # 1 where column index <= row index + k, matching np.tri(N, M, k).
    # _iota((N, M), axis) gives the row (axis=0) / column (axis=1) index
    # at each position; note >= rather than > so the diagonal is included.
    return ((_iota((N, M), 0) + k) >= _iota((N, M), 1)).astype(int)
```
This is what JAX does. The benefit of doing things this way is that we'll automatically have a dispatchable Op for Numba (Numba supports `np.tri`, but only under specific circumstances -- I tried a naive dispatch and it didn't work) and PyTorch (#821 asks for `Tri`, so this would check off that box).
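As a plain-NumPy sanity check of the construction (the helper `iota_2d` below is a hypothetical stand-in for PyTensor's `_iota`, built from `np.arange` and broadcasting), the mask can be compared directly against `np.tri`:

```python
import numpy as np

def iota_2d(shape, axis):
    # Stand-in for _iota: index values increasing along `axis`,
    # broadcast to the full 2-D shape.
    return np.broadcast_to(
        np.expand_dims(np.arange(shape[axis]), axis=1 - axis), shape
    )

def tri_mask(N, M, k=0):
    # 1 where column index <= row index + k, same convention as np.tri.
    return (iota_2d((N, M), 0) + k >= iota_2d((N, M), 1)).astype(int)

# The construction reproduces np.tri for rectangular shapes and offsets.
assert np.array_equal(tri_mask(4, 5, 0), np.tri(4, 5, 0))
assert np.array_equal(tri_mask(3, 3, -1), np.tri(3, 3, -1))
```

Since the mask is just a broadcasted comparison of two index grids, the same graph works symbolically for any `N`, `M`, and `k`.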
I suggest we wrap this in a dummy OpFromGraph, like we do for `Kron` and `AllocDiag`, so that the dprints are nicer. We could also overload the `L_op` if we want. The current `tri` has `grad_undefined`, so we could keep that if it's correct, or just keep the autodiff solution -- the proposed `_iota` function should be differentiable.
Activity
Nimish-4 commented on Mar 8, 2025
@jessegrabowski New contributor here. I created a PR based on my understanding of the issue. It probably has a few errors. I was unable to run pytest locally due to some circular import issue. Waiting for feedback!