Skip to content

Commit 8443857

Browse files
committed
Implement grad for mode="valid"
1 parent 9aa7fba commit 8443857

File tree

2 files changed

+40
-8
lines changed

2 files changed

+40
-8
lines changed

pytensor/signal/conv.py

+32-4
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,13 @@
1-
from scipy.signal import convolve
1+
from scipy.signal import convolve as scipy_convolve
22

33
import pytensor.tensor as pt
44
from pytensor.graph import Apply, Op
55
from pytensor.scalar.basic import upcast
66

77

88
class Conv1d(Op):
9+
__props__ = ("mode",)
10+
911
def __init__(self, mode="full"):
1012
self.mode = mode
1113

@@ -35,14 +37,40 @@ def make_node(self, data, kernel):
3537

3638
def perform(self, node, inputs, outputs):
3739
data, kernel = inputs
38-
outputs[0][0] = convolve(data, kernel, mode=self.mode)
40+
outputs[0][0] = scipy_convolve(data, kernel, mode=self.mode)
41+
42+
def infer_shape(self, fgraph, node, shapes):
    """Return the output shape of the 1-D convolution for ``self.mode``.

    Parameters
    ----------
    fgraph, node
        Unused here; accepted to match the ``infer_shape`` signature.
    shapes
        Pair ``(data_shape, kernel_shape)``; only the first entry of each
        is read (``n`` for the data, ``k`` for the kernel).

    Returns
    -------
    list
        A single-output list ``[[length]]``.

    Raises
    ------
    ValueError
        If ``self.mode`` is not one of ``"full"``, ``"valid"`` or ``"same"``.
    """
    data_shape, kernel_shape = shapes
    n = data_shape[0]
    k = kernel_shape[0]
    if self.mode == "full":
        shape = n + k - 1
    elif self.mode == "valid":
        # Symmetric in n and k: the shorter input slides inside the longer one.
        shape = pt.maximum(n, k) - pt.minimum(n, k) + 1
    elif self.mode == "same":
        # perform() uses scipy.signal.convolve, whose "same" output has the
        # size of its first input (the data), not max(n, k) as numpy.convolve.
        shape = n
    else:
        # Fail with a clear message instead of an unbound-local error.
        raise ValueError(f"Unsupported convolution mode: {self.mode!r}")
    return [[shape]]
3953

4054
def L_op(self, inputs, outputs, output_grads):
4155
data, kernel = inputs
4256
[grad] = output_grads
4357

4458
if self.mode == "full":
45-
data_bar = type(self)(mode="valid")(grad, kernel[::-1])
46-
kernel_bar = type(self)(mode="valid")(grad, data[::-1])
59+
valid_conv = type(self)(mode="valid")
60+
data_bar = valid_conv(grad, kernel[::-1])
61+
kernel_bar = valid_conv(grad, data[::-1])
62+
63+
elif self.mode == "valid":
64+
full_conv = type(self)(mode="full")
65+
n = data.shape[0]
66+
k = kernel.shape[0]
67+
kmn = pt.maximum(0, k - n)
68+
nkm = pt.maximum(0, n - k)
69+
# We need mode="full" if n >= k else "valid" for data_bar (opposite for kernel_bar), but mode is not symbolic.
70+
# Instead, we always use mode="full" and slice the result so that it behaves like mode="valid" for the gradient of the shorter input.
71+
data_bar = full_conv(grad, kernel[::-1])
72+
data_bar = data_bar[kmn : data_bar.shape[0] - kmn]
73+
kernel_bar = full_conv(grad, data[::-1])
74+
kernel_bar = kernel_bar[nkm : kernel_bar.shape[0] - nkm]
4775

4876
return [data_bar, kernel_bar]

tests/signal/test_conv.py

+8-4
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,19 @@
11
import numpy as np
2+
import pytest
23

34
from pytensor.signal.conv import Conv1d
45
from tests import unittest_tools as utt
56

67

7-
def test_conv1d_grads():
8+
@pytest.mark.parametrize("data_shape", [3, 5, 8])
9+
@pytest.mark.parametrize("kernel_shape", [3, 5, 8])
10+
@pytest.mark.parametrize("mode", ["full", "valid", "same"])
11+
def test_conv1d_grad(mode, data_shape, kernel_shape):
812
rng = np.random.default_rng()
913

10-
data_val = rng.normal(size=(3,))
11-
kernel_val = rng.normal(size=(5,))
14+
data_val = rng.normal(size=data_shape)
15+
kernel_val = rng.normal(size=kernel_shape)
1216

13-
op = Conv1d(mode="full")
17+
op = Conv1d(mode=mode)
1418

1519
utt.verify_grad(op=op, pt=[data_val, kernel_val])

0 commit comments

Comments
 (0)