Description
The gradient of `x[1:][-1]` has two successive `inc_subtensor` on zeros buffers of increasing size. We should collapse them into a single one, as happens if you take the gradient of the single slice that corresponds to the two slices combined, `x[-1]`.
This shows up in the gradient of Scan when taking the last output of a recurrent sequence.
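Numerically the two slicings select the same element, so their gradients agree; a quick check along these lines (a sketch for illustration) shows both are the one-hot vector at the last position:

```python
import pytensor
import pytensor.tensor as pt

x = pt.vector("x", shape=(4,))
g_nested = pt.grad(x[1:][-1], x)  # two successive inc_subtensor in the graph
g_single = pt.grad(x[-1], x)      # single inc_subtensor
f = pytensor.function([x], [g_nested, g_single])
print(f([1.0, 2.0, 3.0, 4.0]))  # both gradients are [0., 0., 0., 1.]
```

The gradient graphs, however, differ depending on whether the forward graph was simplified first: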
```python
import pytensor.tensor as pt
from pytensor.graph import rewrite_graph
x = pt.vector("x", shape=(4,))
out = x[1:][-1] # When you select the last entry of a scan sitsot this shows up in the graph
g = pt.grad(out, x)
rewrite_graph(g, include=("fast_run",), exclude=("inplace",)).dprint()
# IncSubtensor{start:} [id A]
# ├─ Alloc [id B]
# │ ├─ [0.] [id C]
# │ └─ 4 [id D]
# ├─ IncSubtensor{i} [id E]
# │ ├─ Alloc [id F]
# │ │ ├─ [0.] [id C]
# │ │ └─ 3 [id G]
# │ ├─ 1.0 [id H]
# │ └─ -1 [id I]
# └─ 1 [id J]
new_out = rewrite_graph(out, include=("ShapeOpt", "canonicalize"))
new_g = pt.grad(new_out, x)
rewrite_graph(new_g, include=("fast_run",), exclude=("inplace",)).dprint()
# IncSubtensor{i} [id A]
# ├─ Alloc [id B]
# │ ├─ [0.] [id C]
# │ └─ 4 [id D]
# ├─ 1.0 [id E]
# └─ 3 [id F]
```
I think the rule is: a single `IncSubtensor` on the larger buffer, keeping the negative inner index as-is, or using the outer start plus the inner index when the inner index is non-negative. We may also want to handle an index of unknown sign symbolically, but even the constant case would be a nice start.
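As a rough sketch of the index arithmetic only (plain Python, not an actual PyTensor rewrite; `collapse_index` and its arguments are hypothetical names), the collapsed index into the larger buffer could be computed like this:

```python
def collapse_index(outer_start, outer_stop, buffer_len, inner_idx):
    # Hypothetical helper: map an index into the view x[outer_start:outer_stop]
    # back to an index into the full length-`buffer_len` buffer.
    if inner_idx >= 0:
        # non-negative inner index: offset it by the outer slice's start
        return outer_start + inner_idx
    # negative inner index: offset it from the end of the outer slice
    stop = buffer_len if outer_stop is None else outer_stop
    return stop + inner_idx

# grad of x[1:][-1] with x of length 4: the two IncSubtensor collapse into a
# single IncSubtensor at index 3 of the length-4 zeros buffer, matching the
# rewritten graph above
assert collapse_index(1, None, 4, -1) == 3
```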
Bonus points if we can combine it with an outer flip that the scan gradient also does:
```python
import pytensor.tensor as pt
from pytensor.graph import rewrite_graph
x = pt.vector("x", shape=(4,))
out = x[1:][-1]
new_out = rewrite_graph(out, include=("ShapeOpt", "canonicalize"))
new_g = pt.grad(new_out, x)[::-1]
rewrite_graph(new_g, include=("fast_run",), exclude=("inplace",)).dprint()
# Subtensor{::step} [id A]
# ├─ IncSubtensor{i} [id B]
# │ ├─ Alloc [id C]
# │ │ ├─ [0.] [id D]
# │ │ └─ 4 [id E]
# │ ├─ 1.0 [id F]
# │ └─ 3 [id G]
# └─ -1 [id H]
```
This should be doable by flipping the index. It is not as important, though, since the flip is just a cheap view on the input.
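Again only as a sketch of the arithmetic (plain Python; `flip_index` is a hypothetical name, not a PyTensor helper), reversing a length-`n` buffer maps position `i` to `n - 1 - i`, so the outer `Subtensor{::step}` could be absorbed into the `IncSubtensor` index:

```python
import numpy as np

def flip_index(idx, length):
    # Hypothetical helper: position that a non-negative index `idx` lands on
    # after reversing a length-`length` vector.
    return length - 1 - idx

# Incrementing index 3 of zeros(4) and then flipping the result is the same
# as incrementing index flip_index(3, 4) == 0 directly, with no flip at all.
a = np.zeros(4)
a[3] += 1.0
a = a[::-1]

b = np.zeros(4)
b[flip_index(3, 4)] += 1.0

assert np.array_equal(a, b)
```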