Skip to content

Commit fcd70c1

Browse files
jessegrabowskizaxtaxricardoV94
committed
Implement batched convolve1d
Co-authored-by: Rob Zinkov <zaxtax@users.noreply.github.com> Co-authored-by: Ricardo Vieira <28983449+ricardov94@users.noreply.github.com>
1 parent 2e9d502 commit fcd70c1

File tree

14 files changed

+313
-0
lines changed

14 files changed

+313
-0
lines changed

pytensor/link/jax/dispatch/__init__.py

+1
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
import pytensor.link.jax.dispatch.scalar
1515
import pytensor.link.jax.dispatch.scan
1616
import pytensor.link.jax.dispatch.shape
17+
import pytensor.link.jax.dispatch.signal
1718
import pytensor.link.jax.dispatch.slinalg
1819
import pytensor.link.jax.dispatch.sort
1920
import pytensor.link.jax.dispatch.sparse
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
import pytensor.link.jax.dispatch.signal.conv
+14
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,14 @@
1+
import jax
2+
3+
from pytensor.link.jax.dispatch import jax_funcify
4+
from pytensor.tensor.signal.conv import Conv1d
5+
6+
7+
@jax_funcify.register(Conv1d)
def jax_funcify_Conv1d(op, node, **kwargs):
    """Return the JAX callable implementing a ``Conv1d`` Op.

    The Op's ``mode`` is read once here and captured by the returned
    closure, which forwards its two inputs to ``jax.numpy.convolve``.
    """
    conv_mode = op.mode

    def conv1d(data, kernel):
        # Delegate to JAX's 1D convolution with the mode fixed by the Op.
        return jax.numpy.convolve(data, kernel, mode=conv_mode)

    return conv1d

pytensor/link/numba/dispatch/__init__.py

+2
Original file line numberDiff line numberDiff line change
@@ -9,9 +9,11 @@
99
import pytensor.link.numba.dispatch.random
1010
import pytensor.link.numba.dispatch.scan
1111
import pytensor.link.numba.dispatch.scalar
12+
import pytensor.link.numba.dispatch.signal
1213
import pytensor.link.numba.dispatch.slinalg
1314
import pytensor.link.numba.dispatch.sparse
1415
import pytensor.link.numba.dispatch.subtensor
1516
import pytensor.link.numba.dispatch.tensor_basic
1617

18+
1719
# isort: on
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
import pytensor.link.numba.dispatch.signal.conv
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,16 @@
1+
import numpy as np
2+
3+
from pytensor.link.numba.dispatch import numba_funcify
4+
from pytensor.link.numba.dispatch.basic import numba_njit
5+
from pytensor.tensor.signal.conv import Conv1d
6+
7+
8+
@numba_funcify.register(Conv1d)
def numba_funcify_Conv1d(op, node, **kwargs):
    """Return the Numba-jitted callable implementing a ``Conv1d`` Op.

    The Op's ``mode`` is read once here and captured as a closure
    constant by the jitted function, which calls ``np.convolve``.
    """
    conv_mode = op.mode

    @numba_njit
    def conv1d(data, kernel):
        # Single call so the jitted body stays a direct np.convolve dispatch.
        return np.convolve(data, kernel, mode=conv_mode)

    return conv1d

pytensor/tensor/__init__.py

+1
Original file line numberDiff line numberDiff line change
@@ -116,6 +116,7 @@ def _get_vector_length_Constant(op: Op | Variable, var: Constant) -> int:
116116
# isort: off
117117
from pytensor.tensor import linalg
118118
from pytensor.tensor import special
119+
from pytensor.tensor import signal
119120

120121
# For backward compatibility
121122
from pytensor.tensor import nlinalg

pytensor/tensor/signal/__init__.py

+7
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,7 @@
1+
from pytensor.tensor.signal.conv import convolve, convolve1d
2+
3+
4+
# Names re-exported as the public API of this subpackage.
__all__ = ("convolve", "convolve1d")

pytensor/tensor/signal/conv.py

+171
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,171 @@
1+
from typing import TYPE_CHECKING, Literal, cast
2+
3+
from numpy import convolve as numpy_convolve
4+
5+
from pytensor.graph import Apply, Op
6+
from pytensor.scalar.basic import upcast
7+
from pytensor.tensor.basic import as_tensor_variable, join, zeros
8+
from pytensor.tensor.blockwise import Blockwise
9+
from pytensor.tensor.math import maximum, minimum
10+
from pytensor.tensor.type import vector
11+
from pytensor.tensor.variable import TensorVariable
12+
13+
14+
if TYPE_CHECKING:
15+
from pytensor.tensor import TensorLike
16+
17+
18+
class Conv1d(Op):
    """Discrete linear convolution of two 1D tensors.

    The Op itself only handles vectors; ``gufunc_signature`` describes the
    core ``(n),(k)->(o)`` signature so that it can be wrapped (e.g. in
    ``Blockwise``) to operate on batched inputs.

    Parameters
    ----------
    mode : {'full', 'valid'}
        'full' yields an output of length ``n + k - 1``; 'valid' yields an
        output of length ``max(n, k) - min(n, k) + 1``.

    Raises
    ------
    ValueError
        If ``mode`` is not 'full' or 'valid'.
    """

    __props__ = ("mode",)
    gufunc_signature = "(n),(k)->(o)"

    def __init__(self, mode: Literal["full", "valid"] = "full"):
        if mode not in ("full", "valid"):
            raise ValueError(f"Invalid mode: {mode}")
        self.mode = mode

    def make_node(self, data, kernel):
        """Build the Apply node for convolving ``data`` with ``kernel``.

        Both inputs are converted to tensor variables and must be 1D.
        The output is a vector whose dtype is the upcast of the input
        dtypes and whose static length is derived from the inputs' static
        lengths when both are known.

        Raises
        ------
        ValueError
            If either input is not one-dimensional.
        """
        data = as_tensor_variable(data)
        kernel = as_tensor_variable(kernel)

        # Raise explicitly instead of using `assert`: assertions are removed
        # under `python -O`, which would let non-vector inputs through and
        # produce a node with a wrong output type.
        if data.ndim != 1:
            raise ValueError(
                f"Conv1d requires a 1D data input, got {data.ndim} dimensions"
            )
        if kernel.ndim != 1:
            raise ValueError(
                f"Conv1d requires a 1D kernel input, got {kernel.ndim} dimensions"
            )

        dtype = upcast(data.dtype, kernel.dtype)

        n = data.type.shape[0]
        k = kernel.type.shape[0]

        if n is None or k is None:
            # At least one static length is unknown, so the output's is too.
            out_shape = (None,)
        elif self.mode == "full":
            out_shape = (n + k - 1,)
        else:  # mode == "valid":
            out_shape = (max(n, k) - min(n, k) + 1,)

        out = vector(dtype=dtype, shape=out_shape)
        return Apply(self, [data, kernel], [out])

    def perform(self, node, inputs, outputs):
        """Compute the convolution numerically with ``numpy.convolve``."""
        data, kernel = inputs
        # numpy's convolve is what scipy would use with method="direct" for
        # modes other than "same" (a mode this Op does not cover anyway).
        outputs[0][0] = numpy_convolve(data, kernel, mode=self.mode)

    def infer_shape(self, fgraph, node, shapes):
        """Return the symbolic output shape given the input shapes."""
        data_shape, kernel_shape = shapes
        n = data_shape[0]
        k = kernel_shape[0]
        if self.mode == "full":
            shape = n + k - 1
        else:  # mode == "valid":
            shape = maximum(n, k) - minimum(n, k) + 1
        return [[shape]]

    def L_op(self, inputs, outputs, output_grads):
        """Return the gradient expressions for ``data`` and ``kernel``.

        Each gradient is itself a convolution of the output gradient with
        the reversed other input.
        """
        data, kernel = inputs
        [grad] = output_grads

        if self.mode == "full":
            valid_conv = type(self)(mode="valid")
            data_bar = valid_conv(grad, kernel[::-1])
            kernel_bar = valid_conv(grad, data[::-1])

        else:  # mode == "valid":
            full_conv = type(self)(mode="full")
            n = data.shape[0]
            k = kernel.shape[0]
            kmn = maximum(0, k - n)
            nkm = maximum(0, n - k)
            # We need mode="full" if k >= n else "valid" for data_bar
            # (opposite for kernel_bar), but mode is not symbolic.
            # Instead, always use mode="full" and slice the result so it
            # behaves like "valid" for the input that's shorter.
            data_bar = full_conv(grad, kernel[::-1])
            data_bar = data_bar[kmn : data_bar.shape[0] - kmn]
            kernel_bar = full_conv(grad, data[::-1])
            kernel_bar = kernel_bar[nkm : kernel_bar.shape[0] - nkm]

        return [data_bar, kernel_bar]
88+
89+
90+
def convolve1d(
    in1: "TensorLike",
    in2: "TensorLike",
    mode: Literal["full", "valid", "same"] = "full",
) -> TensorVariable:
    """Convolve two one-dimensional arrays.

    Convolve in1 and in2, with the output size determined by the mode argument.
    The convolution is applied along the last axis; any leading axes are
    handled as batch dimensions by wrapping the core Op in ``Blockwise``.

    Parameters
    ----------
    in1 : (..., N,) tensor_like
        First input.
    in2 : (..., M,) tensor_like
        Second input.
    mode : {'full', 'valid', 'same'}, optional
        A string indicating the size of the output:
        - 'full': The output is the full discrete linear convolution of the inputs, with shape (..., N+M-1,).
        - 'valid': The output consists only of elements that do not rely on zero-padding, with shape (..., max(N, M) - min(N, M) + 1,).
        - 'same': The output is the same size as in1, centered with respect to the 'full' output.

    Returns
    -------
    out: tensor_variable
        The discrete linear convolution of in1 with in2.

    """
    in1 = as_tensor_variable(in1)
    in2 = as_tensor_variable(in2)

    if mode == "same":
        # We implement "same" as "valid" with padded data.
        in1_batch_shape = tuple(in1.shape)[:-1]
        # The kernel length lives on the LAST axis of in2. Using axis 0 would
        # pick up a batch dimension whenever in2 is batched and pad in1 by
        # the wrong amount.
        kernel_size = in2.shape[-1]
        zeros_left = kernel_size // 2
        zeros_right = (kernel_size - 1) // 2
        in1 = join(
            -1,
            zeros((*in1_batch_shape, zeros_left), dtype=in2.dtype),
            in1,
            zeros((*in1_batch_shape, zeros_right), dtype=in2.dtype),
        )
        mode = "valid"

    return cast(TensorVariable, Blockwise(Conv1d(mode=mode))(in1, in2))
134+
135+
136+
def convolve(
    in1: "TensorLike",
    in2: "TensorLike",
    mode: Literal["full", "valid", "same"] = "full",
    method: Literal["auto", "direct", "fft"] = "direct",
) -> TensorVariable:
    """Convolve two N-dimensional arrays.

    Only one-dimensional inputs are currently supported; the work is
    delegated to ``convolve1d``.

    Parameters
    ----------
    in1 : tensor_like
        First input.
    in2 : tensor_like
        Second input.
    mode : {'full', 'valid', 'same'}, optional
        A string indicating the size of the output:
        - 'full': The output is the full discrete linear convolution of the inputs. (Default)
        - 'valid': The output consists only of elements that do not rely on zero-padding.
        - 'same': The output is the same size as in1, centered with respect to the 'full' output.
    method : {'auto', 'direct', 'fft'}, optional
        Unused by PyTensor, only direct method is used.

    Returns
    -------
    out: tensor_variable
        The discrete linear convolution of in1 with in2.

    Raises
    ------
    NotImplementedError
        If either input is not one-dimensional.
    """
    first = as_tensor_variable(in1)
    second = as_tensor_variable(in2)

    both_are_vectors = first.ndim == 1 and second.ndim == 1
    if not both_are_vectors:
        raise NotImplementedError(
            "convolve only implemented for 1D inputs. If you want a batch 1d convolution, use convolve1d."
        )

    return convolve1d(first, second, mode=mode)

tests/link/jax/signal/__init__.py

Whitespace-only changes.

tests/link/jax/signal/test_conv.py

+18
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,18 @@
1+
import numpy as np
2+
import pytest
3+
4+
from pytensor.tensor import matrix
5+
from pytensor.tensor.signal import convolve1d
6+
from tests.link.jax.test_basic import compare_jax_and_py
7+
8+
9+
@pytest.mark.parametrize("mode", ["full", "valid", "same"])
def test_convolve1d(mode):
    """Check that the JAX backend matches the Python backend for batched convolve1d."""
    x = matrix("x")
    y = matrix("y")
    # Insert broadcastable batch axes: x[None] is (1, 3, 5), y[:, None] is (7, 1, 11).
    out = convolve1d(x[None], y[:, None], mode=mode)

    # Fixed seed so that any failure can be reproduced with the same data.
    rng = np.random.default_rng(sum(map(ord, "test_jax_convolve1d")))
    test_x = rng.normal(size=(3, 5))
    test_y = rng.normal(size=(7, 11))
    compare_jax_and_py([x, y], out, [test_x, test_y])

tests/link/numba/signal/test_conv.py

+22
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,22 @@
1+
import numpy as np
2+
import pytest
3+
4+
from pytensor.tensor import matrix
5+
from pytensor.tensor.signal import convolve1d
6+
from tests.link.numba.test_basic import compare_numba_and_py
7+
8+
9+
# Module-wide pytest mark: treat any warning raised by tests in this file as an error.
pytestmark = pytest.mark.filterwarnings("error")
10+
11+
12+
@pytest.mark.parametrize("mode", ["full", "valid", "same"])
def test_convolve1d(mode):
    """Check that the Numba backend matches the Python backend for batched convolve1d."""
    x = matrix("x")
    y = matrix("y")
    # Insert broadcastable batch axes: x[None] is (1, 3, 5), y[:, None] is (7, 1, 11).
    out = convolve1d(x[None], y[:, None], mode=mode)

    # Fixed seed so that any failure can be reproduced with the same data.
    rng = np.random.default_rng(sum(map(ord, "test_numba_convolve1d")))
    test_x = rng.normal(size=(3, 5))
    test_y = rng.normal(size=(7, 11))
    # Object mode is not supported for numba
    compare_numba_and_py([x, y], out, [test_x, test_y], eval_obj_mode=False)

tests/tensor/signal/__init__.py

Whitespace-only changes.

tests/tensor/signal/test_conv.py

+59
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,59 @@
1+
from functools import partial
2+
3+
import numpy as np
4+
import pytest
5+
from scipy.signal import convolve as scipy_convolve
6+
7+
from pytensor import function
8+
from pytensor.graph.basic import equal_computations
9+
from pytensor.tensor import matrix, vector
10+
from pytensor.tensor.signal.conv import convolve, convolve1d
11+
from tests import unittest_tools as utt
12+
13+
14+
@pytest.mark.parametrize("kernel_shape", [3, 5, 8], ids=lambda x: f"kernel_shape={x}")
@pytest.mark.parametrize("data_shape", [3, 5, 8], ids=lambda x: f"data_shape={x}")
@pytest.mark.parametrize("mode", ["full", "valid", "same"])
def test_convolve1d(mode, data_shape, kernel_shape):
    """Compare convolve1d against scipy and numerically verify its gradient."""
    data = vector("data")
    kernel = vector("kernel")
    op = partial(convolve1d, mode=mode)

    # Fixed seed (the parametrized sizes still vary the data) so that any
    # failure can be reproduced.
    rng = np.random.default_rng((2718, data_shape, kernel_shape))
    data_val = rng.normal(size=data_shape)
    kernel_val = rng.normal(size=kernel_shape)

    fn = function([data, kernel], op(data, kernel))
    np.testing.assert_allclose(
        fn(data_val, kernel_val),
        scipy_convolve(data_val, kernel_val, mode=mode),
    )
    utt.verify_grad(op=lambda x: op(x, kernel_val), pt=[data_val])
32+
33+
34+
def test_convolve1d_batch():
    """Check that convolve1d convolves matching rows of two batched inputs."""
    data = matrix("data")
    kernel = matrix("kernel")
    batched_conv = convolve1d(data, kernel)

    # The kernel batch is the data batch with its two rows swapped. Since
    # convolution is unchanged by the order of its operands, both output rows
    # must equal the same reference result.
    rng = np.random.default_rng(38)
    data_val = rng.normal(size=(2, 8))
    kernel_val = data_val[::-1]

    result = batched_conv.eval({data: data_val, kernel: kernel_val})
    expected = np.convolve(data_val[0], kernel_val[0])
    for row in (result[0], result[1]):
        np.testing.assert_allclose(row, expected)
48+
49+
50+
def test_convolve():
    """Check that convolve defers to convolve1d for vectors and rejects matrices."""
    vec_a = vector()
    vec_b = vector()
    result = convolve(vec_a, vec_b, mode="valid")
    reference = convolve1d(vec_a, vec_b, mode="valid")
    assert equal_computations([result], [reference])

    # Inputs that are not 1D are not supported by convolve.
    mat_a = matrix()
    mat_b = matrix()
    with pytest.raises(NotImplementedError):
        convolve(mat_a, mat_b)

0 commit comments

Comments
 (0)