Skip to content

Commit a8b8248

Browse files
committed
Implement grad for mode="valid"
1 parent 9aa7fba commit a8b8248

File tree

2 files changed

+38
-6
lines changed

2 files changed

+38
-6
lines changed

pytensor/signal/conv.py

+30-2
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,8 @@
66

77

88
class Conv1d(Op):
9+
__props__ = ("mode",)
10+
911
def __init__(self, mode="full"):
1012
self.mode = mode
1113

@@ -37,12 +39,38 @@ def perform(self, node, inputs, outputs):
3739
data, kernel = inputs
3840
outputs[0][0] = convolve(data, kernel, mode=self.mode)
3941

42+
def infer_shape(self, fgraph, node, shapes):
43+
data_shape, kernel_shape = shapes
44+
n = data_shape[0]
45+
k = kernel_shape[0]
46+
if self.mode == "full":
47+
shape = n + k - 1
48+
elif self.mode == "valid":
49+
shape = pt.maximum(n, k) - pt.minimum(n, k) + 1
50+
elif self.mode == "same":
51+
shape = pt.maximum(n, k)
52+
return [[shape]]
53+
4054
def L_op(self, inputs, outputs, output_grads):
4155
data, kernel = inputs
4256
[grad] = output_grads
4357

4458
if self.mode == "full":
45-
data_bar = type(self)(mode="valid")(grad, kernel[::-1])
46-
kernel_bar = type(self)(mode="valid")(grad, data[::-1])
59+
valid_conv = type(self)(mode="valid")
60+
data_bar = valid_conv(grad, kernel[::-1])
61+
kernel_bar = valid_conv(grad, data[::-1])
62+
63+
elif self.mode == "valid":
64+
full_conv = type(self)(mode="full")
65+
n = data.shape[0]
66+
k = kernel.shape[0]
67+
kmn = pt.maximum(0, k - n)
68+
nkm = pt.maximum(0, n - k)
69+
# We need mode="full" if k >= n else "same", but we can't have symbolic mode
70+
# Instead we use mode="full" and then slice the result so it behaves like "same" when needed
71+
data_bar = full_conv(grad, kernel[::-1])
72+
data_bar = data_bar[kmn : data_bar.shape[0] - kmn]
73+
kernel_bar = full_conv(grad, data[::-1])
74+
kernel_bar = kernel_bar[nkm : kernel_bar.shape[0] - nkm]
4775

4876
return [data_bar, kernel_bar]

tests/signal/test_conv.py

+8-4
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,19 @@
11
import numpy as np
2+
import pytest
23

34
from pytensor.signal.conv import Conv1d
45
from tests import unittest_tools as utt
56

67

7-
def test_conv1d_grads():
8+
@pytest.mark.parametrize("data_shape", [3, 5, 8])
9+
@pytest.mark.parametrize("kernel_shape", [3, 5, 8])
10+
@pytest.mark.parametrize("mode", ["full", "valid", "same"])
11+
def test_conv1d_grad(mode, data_shape, kernel_shape):
812
rng = np.random.default_rng()
913

10-
data_val = rng.normal(size=(3,))
11-
kernel_val = rng.normal(size=(5,))
14+
data_val = rng.normal(size=data_shape)
15+
kernel_val = rng.normal(size=kernel_shape)
1216

13-
op = Conv1d(mode="full")
17+
op = Conv1d(mode=mode)
1418

1519
utt.verify_grad(op=op, pt=[data_val, kernel_val])

0 commit comments

Comments
 (0)