diff --git a/pytensor/tensor/basic.py b/pytensor/tensor/basic.py index 4786b71778..39e482542c 100644 --- a/pytensor/tensor/basic.py +++ b/pytensor/tensor/basic.py @@ -3207,13 +3207,14 @@ def tile( return A_replicated.reshape(tiled_shape) -class ARange(Op): +class ARange(COp): """Create an array containing evenly spaced values within a given interval. Parameters and behaviour are the same as numpy.arange(). """ + # TODO: Arange should work with scalars as inputs, not arrays __props__ = ("dtype",) def __init__(self, dtype): @@ -3293,13 +3294,30 @@ def upcast(var): ) ] - def perform(self, node, inp, out_): - start, stop, step = inp - (out,) = out_ - start = start.item() - stop = stop.item() - step = step.item() - out[0] = np.arange(start, stop, step, dtype=self.dtype) + def perform(self, node, inputs, output_storage): + start, stop, step = inputs + output_storage[0][0] = np.arange( + start.item(), stop.item(), step.item(), dtype=self.dtype + ) + + def c_code(self, node, nodename, input_names, output_names, sub): + [start_name, stop_name, step_name] = input_names + [out_name] = output_names + typenum = np.dtype(self.dtype).num + return f""" + double start = ((dtype_{start_name}*)PyArray_DATA({start_name}))[0]; + double stop = ((dtype_{stop_name}*)PyArray_DATA({stop_name}))[0]; + double step = ((dtype_{step_name}*)PyArray_DATA({step_name}))[0]; + //printf("start: %f, stop: %f, step: %f\\n", start, stop, step); + Py_XDECREF({out_name}); + {out_name} = (PyArrayObject*) PyArray_Arange(start, stop, step, {typenum}); + if (!{out_name}) {{ + {sub["fail"]} + }} + """ + + def c_code_cache_version(self): + return (0,) def connection_pattern(self, node): return [[True], [False], [True]] @@ -3686,7 +3704,7 @@ def inverse_permutation(perm): # TODO: optimization to insert ExtractDiag with view=True -class ExtractDiag(Op): +class ExtractDiag(COp): """ Return specified diagonals. @@ -3742,7 +3760,7 @@ class ExtractDiag(Op): __props__ = ("offset", "axis1", "axis2", "view") - def __init__(self, offset=0, axis1=0, axis2=1, view=False): + def __init__(self, offset=0, axis1=0, axis2=1, view=True): self.view = view if self.view: self.view_map = {0: [0]} @@ -3765,24 +3783,74 @@ def make_node(self, x): if x.ndim < 2: raise ValueError("ExtractDiag needs an input with 2 or more dimensions", x) - out_shape = [ - st_dim - for i, st_dim in enumerate(x.type.shape) - if i not in (self.axis1, self.axis2) - ] + [None] + if (dim1 := x.type.shape[self.axis1]) is not None and ( + dim2 := x.type.shape[self.axis2] + ) is not None: + offset = self.offset + if offset > 0: + diag_size = int(np.clip(dim2 - offset, 0, dim1)) + elif offset < 0: + diag_size = int(np.clip(dim1 + offset, 0, dim2)) + else: + diag_size = int(np.minimum(dim1, dim2)) + else: + diag_size = None + + out_shape = ( + *( + dim + for i, dim in enumerate(x.type.shape) + if i not in (self.axis1, self.axis2) + ), + diag_size, + ) return Apply( self, [x], - [x.type.clone(dtype=x.dtype, shape=tuple(out_shape))()], + [x.type.clone(dtype=x.dtype, shape=out_shape)()], ) - def perform(self, node, inputs, outputs): + def perform(self, node, inputs, output_storage): (x,) = inputs - (z,) = outputs - z[0] = x.diagonal(self.offset, self.axis1, self.axis2) - if not self.view: - z[0] = z[0].copy() + out = x.diagonal(self.offset, self.axis1, self.axis2) + if self.view: + try: + out.flags.writeable = True + except ValueError: + # We can't make this array writable + out = out.copy() + else: + out = out.copy() + output_storage[0][0] = out + + def c_code(self, node, nodename, input_names, output_names, sub): + [x_name] = input_names + [out_name] = output_names + return f""" + Py_XDECREF({out_name}); + + {out_name} = (PyArrayObject*) PyArray_Diagonal({x_name}, {self.offset}, {self.axis1}, {self.axis2}); + if (!{out_name}) {{ + {sub["fail"]} // Error already set by Numpy + }} + + if ({int(self.view)} && PyArray_ISWRITEABLE({x_name})) {{ + // Make output writeable if input was writeable + PyArray_ENABLEFLAGS({out_name}, NPY_ARRAY_WRITEABLE); + }} else {{ + // Make a copy + PyArrayObject *{out_name}_copy = (PyArrayObject*) PyArray_Copy({out_name}); + Py_DECREF({out_name}); + if (!{out_name}_copy) {{ + {sub['fail']}; // Error already set by Numpy + }} + {out_name} = {out_name}_copy; + }} + """ + + def c_code_cache_version(self): + return (0,) def grad(self, inputs, gout): # Avoid circular import @@ -3829,19 +3897,6 @@ def infer_shape(self, fgraph, node, shapes): out_shape.append(diag_size) return [tuple(out_shape)] - def __setstate__(self, state): - self.__dict__.update(state) - - if self.view: - self.view_map = {0: [0]} - - if "offset" not in state: - self.offset = 0 - if "axis1" not in state: - self.axis1 = 0 - if "axis2" not in state: - self.axis2 = 1 - def extract_diag(x): warnings.warn(