Replace our Tri op with an OpFromGraph #1265

Open
@jessegrabowski

Description

Currently we have an Op that calls np.tri, but we can easily build lower-triangular mask matrices with _iota:

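from pytensor.tensor.einsum import _iota

def tri(M, N, k=0):
    # Entry (i, j) is 1 where j <= i + k, matching np.tri(M, N, k); _iota(shape, axis) counts along `axis`
    shape = (M, N)
    return ((_iota(shape, 0) + k) >= _iota(shape, 1)).astype(int)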
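As a quick sanity check of the iota construction, here is a NumPy-only sketch (the helper name `tri_via_iota` is hypothetical, and `np.arange` plus broadcasting stands in for `_iota`). It compares the result against `np.tri` and shows why the comparison needs to be `>=`: a strict `>` would also zero out the k-th diagonal.

```python
import numpy as np


def tri_via_iota(M, N, k=0):
    """Hypothetical NumPy sketch of the iota-based lower-triangular mask."""
    # Row index i along axis 0 and column index j along axis 1, both broadcast to (M, N)
    rows = np.broadcast_to(np.arange(M)[:, None], (M, N))
    cols = np.broadcast_to(np.arange(N)[None, :], (M, N))
    # Entry (i, j) is 1 where j <= i + k; a strict `>` would drop the k-th diagonal
    return (rows + k >= cols).astype(int)


# Matches np.tri for square, wide and tall shapes and several diagonal offsets
for M, N in [(3, 3), (3, 5), (5, 2)]:
    for k in (-2, -1, 0, 1, 2):
        assert np.array_equal(tri_via_iota(M, N, k), np.tri(M, N, k, dtype=int))
```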

This is what JAX does. The benefit is that the graph is composed of Ops that should already have Numba and PyTorch implementations, so we automatically get a dispatchable Op for Numba (Numba supports np.tri, but only under specific circumstances; I tried a naive dispatch and it didn't work) and for PyTorch (#821 asks for Tri, so this would check off that box).

I suggest we wrap this in a dummy OpFromGraph, as we do for Kron and AllocDiag, so that the dprint output is more readable. We could also override the L_op if we want. The current Tri uses grad_undefined, so we could keep that if it's correct, or just rely on autodiff; the proposed _iota graph should be differentiable.

Activity

Nimish-4 commented on Mar 8, 2025

@jessegrabowski New contributor here. I created a PR based on my understanding of the issue. It probably has a few errors. I was unable to run pytest locally due to some circular import issue. Waiting for feedback!

Metadata

Assignees: No one assigned
Type: No type
Projects: No projects
Milestone: No milestone
Relationships: None yet
Development: No branches or pull requests
Participants: @jessegrabowski, @Nimish-4


          Replace our `Tri` op with an `OpFromGraph` · Issue #1265 · pymc-devs/pytensor