|
1 |
| -from scipy.signal import convolve |
| 1 | +from scipy.signal import convolve as scipy_convolve |
2 | 2 |
|
3 | 3 | import pytensor.tensor as pt
|
4 | 4 | from pytensor.graph import Apply, Op
|
5 | 5 | from pytensor.scalar.basic import upcast
|
6 | 6 |
|
7 | 7 |
|
class Conv1d(Op):
    """1D convolution of two vectors, wrapping ``scipy.signal.convolve``.

    Parameters
    ----------
    mode : str
        One of ``"full"``, ``"valid"`` or ``"same"``; selects the output
        length convention (same meaning as in ``scipy.signal.convolve``).

    Raises
    ------
    ValueError
        If ``mode`` is not one of the supported modes. Previously an
        invalid mode was accepted silently and only failed much later
        (e.g. as an UnboundLocalError inside ``infer_shape``).
    """

    __props__ = ("mode",)

    def __init__(self, mode="full"):
        # Validate eagerly so a bad mode fails at Op construction time,
        # not deep inside shape inference or gradient code.
        if mode not in ("full", "valid", "same"):
            raise ValueError(
                f"mode must be one of 'full', 'valid', 'same', got {mode!r}"
            )
        self.mode = mode
|
@@ -35,14 +37,40 @@ def make_node(self, data, kernel):
|
35 | 37 |
|
36 | 38 | def perform(self, node, inputs, outputs):
|
37 | 39 | data, kernel = inputs
|
38 |
| - outputs[0][0] = convolve(data, kernel, mode=self.mode) |
| 40 | + outputs[0][0] = scipy_convolve(data, kernel, mode=self.mode) |
| 41 | + |
| 42 | + def infer_shape(self, fgraph, node, shapes): |
| 43 | + data_shape, kernel_shape = shapes |
| 44 | + n = data_shape[0] |
| 45 | + k = kernel_shape[0] |
| 46 | + if self.mode == "full": |
| 47 | + shape = n + k - 1 |
| 48 | + elif self.mode == "valid": |
| 49 | + shape = pt.maximum(n, k) - pt.minimum(n, k) + 1 |
| 50 | + elif self.mode == "same": |
| 51 | + shape = pt.maximum(n, k) |
| 52 | + return [[shape]] |
39 | 53 |
|
40 | 54 | def L_op(self, inputs, outputs, output_grads):
|
41 | 55 | data, kernel = inputs
|
42 | 56 | [grad] = output_grads
|
43 | 57 |
|
44 | 58 | if self.mode == "full":
|
45 |
| - data_bar = type(self)(mode="valid")(grad, kernel[::-1]) |
46 |
| - kernel_bar = type(self)(mode="valid")(grad, data[::-1]) |
| 59 | + valid_conv = type(self)(mode="valid") |
| 60 | + data_bar = valid_conv(grad, kernel[::-1]) |
| 61 | + kernel_bar = valid_conv(grad, data[::-1]) |
| 62 | + |
| 63 | + elif self.mode == "valid": |
| 64 | + full_conv = type(self)(mode="full") |
| 65 | + n = data.shape[0] |
| 66 | + k = kernel.shape[0] |
| 67 | + kmn = pt.maximum(0, k - n) |
| 68 | + nkm = pt.maximum(0, n - k) |
| 69 | + # We need mode="full" if k >= n else "valid" for data_bar (opposite for kernel_bar), but mode is not symbolic. |
| 70 | + # Instead we always use mode="full" and slice the result so it behaves like "valid" for the input that's shorter. |
| 71 | + data_bar = full_conv(grad, kernel[::-1]) |
| 72 | + data_bar = data_bar[kmn : data_bar.shape[0] - kmn] |
| 73 | + kernel_bar = full_conv(grad, data[::-1]) |
| 74 | + kernel_bar = kernel_bar[nkm : kernel_bar.shape[0] - nkm] |
47 | 75 |
|
48 | 76 | return [data_bar, kernel_bar]
|
0 commit comments