|
6 | 6 |
|
7 | 7 |
|
class Conv1d(Op):
    """1D convolution as a graph `Op`, with gradient support.

    Parameters
    ----------
    mode : str
        Convolution mode forwarded to ``convolve``: one of ``"full"``,
        ``"valid"`` or ``"same"``.
    """

    # Ops with equal props compare/hash equal, which lets the graph
    # rewriter merge duplicate nodes.
    __props__ = ("mode",)

    def __init__(self, mode="full"):
        # Fail fast: an unsupported mode previously surfaced much later as
        # a confusing NameError inside infer_shape / L_op.
        if mode not in ("full", "valid", "same"):
            raise ValueError(
                f"mode must be 'full', 'valid' or 'same', got {mode!r}"
            )
        self.mode = mode
|
    def perform(self, node, inputs, outputs):
        """Numerically evaluate the convolution.

        Follows the Op.perform protocol: ``inputs`` holds the concrete
        (data, kernel) values and the result is stored into
        ``outputs[0][0]`` rather than returned.
        """
        data, kernel = inputs
        # `convolve` comes from a module-level import outside this view —
        # presumably numpy.convolve, which accepts mode="full"/"valid"/"same";
        # TODO confirm against the file's imports.
        outputs[0][0] = convolve(data, kernel, mode=self.mode)
|
| 42 | + def infer_shape(self, fgraph, node, shapes): |
| 43 | + data_shape, kernel_shape = shapes |
| 44 | + n = data_shape[0] |
| 45 | + k = kernel_shape[0] |
| 46 | + if self.mode == "full": |
| 47 | + shape = n + k - 1 |
| 48 | + elif self.mode == "valid": |
| 49 | + shape = pt.maximum(n, k) - pt.minimum(n, k) + 1 |
| 50 | + elif self.mode == "same": |
| 51 | + shape = pt.maximum(n, k) |
| 52 | + return [[shape]] |
| 53 | + |
40 | 54 | def L_op(self, inputs, outputs, output_grads):
|
41 | 55 | data, kernel = inputs
|
42 | 56 | [grad] = output_grads
|
43 | 57 |
|
44 | 58 | if self.mode == "full":
|
45 |
| - data_bar = type(self)(mode="valid")(grad, kernel[::-1]) |
46 |
| - kernel_bar = type(self)(mode="valid")(grad, data[::-1]) |
| 59 | + valid_conv = type(self)(mode="valid") |
| 60 | + data_bar = valid_conv(grad, kernel[::-1]) |
| 61 | + kernel_bar = valid_conv(grad, data[::-1]) |
| 62 | + |
| 63 | + elif self.mode == "valid": |
| 64 | + full_conv = type(self)(mode="full") |
| 65 | + n = data.shape[0] |
| 66 | + k = kernel.shape[0] |
| 67 | + kmn = pt.maximum(0, k - n) |
| 68 | + nkm = pt.maximum(0, n - k) |
| 69 | + # We need mode="full" if k >= n else "same", but we can't have symbolic mode |
| 70 | + # Instead we use mode="full" and then slice the result so it behaves like "same" when needed |
| 71 | + data_bar = full_conv(grad, kernel[::-1]) |
| 72 | + data_bar = data_bar[kmn : data_bar.shape[0] - kmn] |
| 73 | + kernel_bar = full_conv(grad, data[::-1]) |
| 74 | + kernel_bar = kernel_bar[nkm : kernel_bar.shape[0] - nkm] |
47 | 75 |
|
48 | 76 | return [data_bar, kernel_bar]
|
0 commit comments