Skip to content

Commit 9aa7fba

Browse files
don't merge
1 parent 498621c commit 9aa7fba

File tree

4 files changed

+63
-0
lines changed

4 files changed

+63
-0
lines changed

pytensor/signal/__init__.py

Whitespace-only changes.

pytensor/signal/conv.py

Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,48 @@
1+
from scipy.signal import convolve
2+
3+
import pytensor.tensor as pt
4+
from pytensor.graph import Apply, Op
5+
from pytensor.scalar.basic import upcast
6+
7+
8+
class Conv1d(Op):
9+
def __init__(self, mode="full"):
10+
self.mode = mode
11+
12+
def make_node(self, data, kernel):
13+
data = pt.as_tensor_variable(data)
14+
kernel = pt.as_tensor_variable(kernel)
15+
16+
assert data.ndim == 1
17+
assert kernel.ndim == 1
18+
19+
dtype = upcast(data.dtype, kernel.dtype)
20+
21+
n = data.type.shape[0]
22+
k = kernel.type.shape[0]
23+
24+
if n is None or k is None:
25+
out_shape = (None,)
26+
elif self.mode == "full":
27+
out_shape = (n + k - 1,)
28+
elif self.mode == "valid":
29+
out_shape = (max(n, k) - min(n, k) + 1,)
30+
elif self.mode == "same":
31+
out_shape = (max(n, k),)
32+
33+
out = pt.tensor(dtype=dtype, shape=out_shape)
34+
return Apply(self, [data, kernel], [out])
35+
36+
def perform(self, node, inputs, outputs):
37+
data, kernel = inputs
38+
outputs[0][0] = convolve(data, kernel, mode=self.mode)
39+
40+
def L_op(self, inputs, outputs, output_grads):
41+
data, kernel = inputs
42+
[grad] = output_grads
43+
44+
if self.mode == "full":
45+
data_bar = type(self)(mode="valid")(grad, kernel[::-1])
46+
kernel_bar = type(self)(mode="valid")(grad, data[::-1])
47+
48+
return [data_bar, kernel_bar]

tests/signal/__init__.py

Whitespace-only changes.

tests/signal/test_conv.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,15 @@
1+
import numpy as np
2+
3+
from pytensor.signal.conv import Conv1d
4+
from tests import unittest_tools as utt
5+
6+
7+
def test_conv1d_grads():
8+
rng = np.random.default_rng()
9+
10+
data_val = rng.normal(size=(3,))
11+
kernel_val = rng.normal(size=(5,))
12+
13+
op = Conv1d(mode="full")
14+
15+
utt.verify_grad(op=op, pt=[data_val, kernel_val])

0 commit comments

Comments
 (0)