
Rewrite nested inc/set_subtensor on zeros #1292

Open
@ricardoV94

Description

The gradient of x[1:][-1] contains two nested inc_subtensor operations on zero buffers of increasing size. We should collapse them into a single one, which is what we already get when we take the gradient of the equivalent single index x[-1].

This pattern shows up in the gradient of Scans when only the last entry of a recurrent output is used.

import pytensor.tensor as pt
from pytensor.graph import rewrite_graph

x = pt.vector("x", shape=(4,))
out = x[1:][-1]  # Selecting the last entry of a Scan sit-sot output produces this pattern
g = pt.grad(out, x)
rewrite_graph(g, include=("fast_run",), exclude=("inplace",)).dprint()
# IncSubtensor{start:} [id A]
#  ├─ Alloc [id B]
#  │  ├─ [0.] [id C]
#  │  └─ 4 [id D]
#  ├─ IncSubtensor{i} [id E]
#  │  ├─ Alloc [id F]
#  │  │  ├─ [0.] [id C]
#  │  │  └─ 3 [id G]
#  │  ├─ 1.0 [id H]
#  │  └─ -1 [id I]
#  └─ 1 [id J]
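To make the nested structure concrete, here is a pure-Python emulation of what the graph above computes, with lists standing in for tensors. The helpers inc_at and inc_tail are hypothetical stand-ins for IncSubtensor{i} and IncSubtensor{start:}; this is an illustration of the values involved, not PyTensor code.

```python
def inc_at(buf, i, value):
    """Return a copy of buf with buf[i] += value (mimics IncSubtensor{i})."""
    out = list(buf)
    out[i] += value
    return out


def inc_tail(buf, start, values):
    """Return a copy of buf with buf[start:] += values (mimics IncSubtensor{start:})."""
    out = list(buf)
    tail = out[start:]
    if len(tail) != len(values):
        raise ValueError("values must match the length of buf[start:]")
    out[start:] = [a + b for a, b in zip(tail, values)]
    return out


# Nested form from grad(x[1:][-1]): zero buffers of size 3 (inner) and 4 (outer)
nested = inc_tail([0.0] * 4, 1, inc_at([0.0] * 3, -1, 1.0))

# Collapsed form from grad(x[3]): a single zero buffer of size 4
collapsed = inc_at([0.0] * 4, 3, 1.0)

print(nested)     # [0.0, 0.0, 0.0, 1.0]
print(collapsed)  # [0.0, 0.0, 0.0, 1.0]
```

Both forms produce the same values; the collapsed one just avoids allocating and filling the intermediate size-3 buffer.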

new_out = rewrite_graph(out, include=("ShapeOpt", "canonicalize"))
new_g = pt.grad(new_out, x)
rewrite_graph(new_g, include=("fast_run",), exclude=("inplace",)).dprint()
# IncSubtensor{i} [id A]
#  ├─ Alloc [id B]
#  │  ├─ [0.] [id C]
#  │  └─ 4 [id D]
#  ├─ 1.0 [id E]
#  └─ 3 [id F]

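I think the rule is: apply a single inc_subtensor on the larger buffer, keeping the inner index as-is if it is negative, or using outer start + inner index if it is non-negative. This works because the outer slice (start:) has no stop or step, so the inner buffer is the tail of the larger one and both share the same end. We may also want to handle an index of unknown sign symbolically, but even the constant case would be a nice start.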
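A minimal sketch of the constant-index case, assuming the outer op is IncSubtensor{start:} with a non-negative constant start and no stop or step, and the inner op is IncSubtensor{i} with a constant i. combine_index is a hypothetical helper name, and the brute-force check uses plain Python lists rather than PyTensor graphs.

```python
def combine_index(outer_start: int, inner_index: int) -> int:
    """Map an index into buf[outer_start:] to the equivalent index into buf.

    Assumes the outer slice is `outer_start:` (no stop, no step) with
    outer_start >= 0, so the inner buffer is exactly the tail of the outer one.
    """
    if inner_index < 0:
        # Negative indices count from the end, and both buffers share the same end.
        return inner_index
    # Non-negative indices are shifted by the start of the outer slice.
    return outer_start + inner_index


def check(n: int) -> None:
    """Brute-force check that nested and collapsed increments agree for every valid index."""
    for start in range(n):
        m = n - start  # length of the inner buffer buf[start:]
        for i in range(-m, m):
            inner = [0.0] * m
            inner[i] += 1.0
            nested = [0.0] * start + inner  # zeros(n)[start:] += inner
            collapsed = [0.0] * n
            collapsed[combine_index(start, i)] += 1.0
            assert nested == collapsed, (n, start, i)


check(4)
print(combine_index(1, -1), combine_index(1, 2))  # -1 3
```

For the example in this issue (start=1, inner index -1) the combined index is -1, which addresses the same entry as the index 3 seen in the canonicalized graph.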

Bonus points if we can also absorb the outer flip ([::-1]) that the Scan gradient applies:

import pytensor.tensor as pt
from pytensor.graph import rewrite_graph

x = pt.vector("x", shape=(4,))
out = x[1:][-1]
new_out = rewrite_graph(out, include=("ShapeOpt", "canonicalize"))
new_g = pt.grad(new_out, x)[::-1]
rewrite_graph(new_g, include=("fast_run",), exclude=("inplace",)).dprint()
# Subtensor{::step} [id A]
#  ├─ IncSubtensor{i} [id B]
#  │  ├─ Alloc [id C]
#  │  │  ├─ [0.] [id D]
#  │  │  └─ 4 [id E]
#  │  ├─ 1.0 [id F]
#  │  └─ 3 [id G]
#  └─ -1 [id H]

This should be doable by flipping the index. It is less important, though, since the flip is just a cheap view of the input.
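A sketch of the index flip, under the same assumptions as before (constant index, plain Python lists standing in for tensors; flip_index is a hypothetical helper). Entry i of a length-n buffer lands at position n - 1 - i after reversal, and written as an index that is simply -1 - i, which holds for either sign of i and does not depend on n.

```python
def flip_index(i: int) -> int:
    """Index into buf[::-1] that holds buf[i], valid for any i in [-n, n)."""
    # Position p = i % n maps to n - 1 - p after reversal, which is
    # congruent to -1 - i modulo n regardless of the sign of i.
    return -1 - i


def check(n: int) -> None:
    """Brute-force check that increment-then-flip equals a single increment at the flipped index."""
    for i in range(-n, n):
        buf = [0.0] * n
        buf[i] += 1.0
        flipped_after = buf[::-1]  # IncSubtensor followed by Subtensor{::step}
        direct = [0.0] * n
        direct[flip_index(i)] += 1.0  # single increment at the mirrored index
        assert flipped_after == direct, (n, i)


check(4)
print(flip_index(3), flip_index(-1))  # -4 0
```

For the example above, index 3 in a length-4 buffer becomes -4 (position 0), so the Subtensor{::step} could in principle be dropped in favor of an IncSubtensor{i} at that index.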
