Skip to content

Implement C code for ExtractDiagonal and ARange #1392

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
123 changes: 89 additions & 34 deletions pytensor/tensor/basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -3207,13 +3207,14 @@
return A_replicated.reshape(tiled_shape)


class ARange(Op):
class ARange(COp):
"""Create an array containing evenly spaced values within a given interval.

Parameters and behaviour are the same as numpy.arange().

"""

# TODO: ARange should work with scalars as inputs, not arrays
Copy link
Preview

Copilot AI May 7, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[nitpick] Consider addressing the TODO by supporting scalar inputs for ARange, which would simplify the interface and align behavior with numpy.arange.

Copilot uses AI. Check for mistakes.

__props__ = ("dtype",)

def __init__(self, dtype):
Expand Down Expand Up @@ -3293,13 +3294,30 @@
)
]

def perform(self, node, inp, out_):
    """Compute the range in Python with ``np.arange``.

    Parameters
    ----------
    node
        Not used by this method.
    inp
        Three values ``(start, stop, step)``; each is converted to a
        Python scalar with ``.item()``.
    out_
        A one-element sequence holding the output cell; the result is
        stored at index 0 of that cell.
    """
    lower, upper, stride = inp
    (cell,) = out_
    cell[0] = np.arange(
        lower.item(), upper.item(), stride.item(), dtype=self.dtype
    )
def perform(self, node, inputs, output_storage):
    """Compute the range in Python with ``np.arange``.

    Parameters
    ----------
    node
        Not used by this method.
    inputs
        Three values ``(start, stop, step)``; each is converted to a
        Python scalar with ``.item()``.
    output_storage
        Sequence of output cells; the result is written to
        ``output_storage[0][0]``.
    """
    start, stop, step = inputs
    bounds = (start.item(), stop.item(), step.item())
    result = np.arange(*bounds, dtype=self.dtype)
    output_storage[0][0] = result

def c_code(self, node, nodename, input_names, output_names, sub):
    """Return the C source that builds the range with ``PyArray_Arange``.

    The generated code reads element 0 of each of the three inputs
    through that input's own ``dtype_<name>`` C type into a ``double``,
    releases whatever the output variable currently holds, and assigns it
    the result of ``PyArray_Arange`` called with the NumPy type number of
    ``self.dtype``.  When that call yields ``NULL`` the ``sub["fail"]``
    snippet is emitted.

    Parameters
    ----------
    node, nodename
        Not used by this method.
    input_names
        Exactly three C variable names: start, stop, step.
    output_names
        Exactly one C variable name for the output array.
    sub
        Mapping that must provide the ``"fail"`` snippet.
    """
    start_name, stop_name, step_name = input_names
    (out_name,) = output_names
    typenum = np.dtype(self.dtype).num
    fail = sub["fail"]
    return f"""
double start = ((dtype_{start_name}*)PyArray_DATA({start_name}))[0];
double stop = ((dtype_{stop_name}*)PyArray_DATA({stop_name}))[0];
double step = ((dtype_{step_name}*)PyArray_DATA({step_name}))[0];
//printf("start: %f, stop: %f, step: %f\\n", start, stop, step);
Py_XDECREF({out_name});
{out_name} = (PyArrayObject*) PyArray_Arange(start, stop, step, {typenum});
if (!{out_name}) {{
{fail}
}}
"""

def c_code_cache_version(self):
    """Return the version tuple ``(0,)`` for this Op's generated C code."""
    version = (0,)
    return version

def connection_pattern(self, node):
    """Return a three-row pattern with one boolean per row.

    The first and third rows are ``[True]`` and the second is ``[False]``.
    NOTE(review): the rows presumably follow the (start, stop, step) input
    order used by ``perform`` -- confirm against the Op's ``make_node``.
    """
    return [[connected] for connected in (True, False, True)]
Expand Down Expand Up @@ -3686,7 +3704,7 @@


# TODO: optimization to insert ExtractDiag with view=True
class ExtractDiag(Op):
class ExtractDiag(COp):
"""
Return specified diagonals.

Expand Down Expand Up @@ -3742,7 +3760,7 @@

__props__ = ("offset", "axis1", "axis2", "view")

def __init__(self, offset=0, axis1=0, axis2=1, view=False):
def __init__(self, offset=0, axis1=0, axis2=1, view=True):
self.view = view
if self.view:
self.view_map = {0: [0]}
Expand All @@ -3765,24 +3783,74 @@
if x.ndim < 2:
raise ValueError("ExtractDiag needs an input with 2 or more dimensions", x)

out_shape = [
st_dim
for i, st_dim in enumerate(x.type.shape)
if i not in (self.axis1, self.axis2)
] + [None]
if (dim1 := x.type.shape[self.axis1]) is not None and (
dim2 := x.type.shape[self.axis2]
) is not None:
offset = self.offset
if offset > 0:
diag_size = int(np.clip(dim2 - offset, 0, dim1))
elif offset < 0:
diag_size = int(np.clip(dim1 + offset, 0, dim2))
else:
diag_size = int(np.minimum(dim1, dim2))
else:
diag_size = None

out_shape = (
*(
dim
for i, dim in enumerate(x.type.shape)
if i not in (self.axis1, self.axis2)
),
diag_size,
)

return Apply(
self,
[x],
[x.type.clone(dtype=x.dtype, shape=tuple(out_shape))()],
[x.type.clone(dtype=x.dtype, shape=out_shape)()],
)

def perform(self, node, inputs, outputs):
def perform(self, node, inputs, output_storage):
(x,) = inputs
(z,) = outputs
z[0] = x.diagonal(self.offset, self.axis1, self.axis2)
if not self.view:
z[0] = z[0].copy()
out = x.diagonal(self.offset, self.axis1, self.axis2)
if self.view:
try:
out.flags.writeable = True
except ValueError:

Check warning on line 3820 in pytensor/tensor/basic.py

View check run for this annotation

Codecov / codecov/patch

pytensor/tensor/basic.py#L3820

Added line #L3820 was not covered by tests
# We can't make this array writable
Copy link
Preview

Copilot AI May 7, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[nitpick] Consider logging a warning or providing additional context here when the array cannot be made writeable and a copy is made, to aid in diagnosing potential performance implications.

Suggested change
# We can't make this array writable
# We can't make this array writable
warnings.warn(
"Unable to make the array writeable. A copy of the array is being created instead. "
"This may have performance implications if the array is large.",
RuntimeWarning,
)

Copilot uses AI. Check for mistakes.

out = out.copy()

Check warning on line 3822 in pytensor/tensor/basic.py

View check run for this annotation

Codecov / codecov/patch

pytensor/tensor/basic.py#L3822

Added line #L3822 was not covered by tests
else:
out = out.copy()

Check warning on line 3824 in pytensor/tensor/basic.py

View check run for this annotation

Codecov / codecov/patch

pytensor/tensor/basic.py#L3824

Added line #L3824 was not covered by tests
output_storage[0][0] = out

def c_code(self, node, nodename, input_names, output_names, sub):
    """Return the C source that extracts the diagonal with NumPy's C-API.

    The generated code releases whatever the output variable currently
    holds and assigns it the result of ``PyArray_Diagonal`` called with
    ``self.offset``, ``self.axis1`` and ``self.axis2``.  Then:

    * if ``self.view`` is true and the input is writeable, the
      ``NPY_ARRAY_WRITEABLE`` flag is enabled on the result;
    * otherwise the result is replaced by a ``PyArray_Copy`` of itself.

    In the copy branch the output variable is rebound to the copy
    (``NULL`` when copying failed) *before* the failure check, so the
    variable never holds a pointer to the already-released diagonal when
    ``sub["fail"]`` runs.

    Parameters
    ----------
    node, nodename
        Not used by this method.
    input_names
        Exactly one C variable name for the input array.
    output_names
        Exactly one C variable name for the output array.
    sub
        Mapping that must provide the ``"fail"`` snippet.
    """
    [x_name] = input_names
    [out_name] = output_names
    fail = sub["fail"]
    return f"""
    Py_XDECREF({out_name});

    {out_name} = (PyArrayObject*) PyArray_Diagonal({x_name}, {self.offset}, {self.axis1}, {self.axis2});
    if (!{out_name}) {{
        {fail}  // Error already set by Numpy
    }}

    if ({int(self.view)} && PyArray_ISWRITEABLE({x_name})) {{
        // Make output writeable if input was writeable
        PyArray_ENABLEFLAGS({out_name}, NPY_ARRAY_WRITEABLE);
    }} else {{
        // Make a copy
        PyArrayObject *{out_name}_copy = (PyArrayObject*) PyArray_Copy({out_name});
        // Drop the temporary diagonal and rebind the output to the copy
        // (NULL on failure) first, so no dangling pointer is left behind.
        Py_DECREF({out_name});
        {out_name} = {out_name}_copy;
        if (!{out_name}) {{
            {fail}  // Error already set by Numpy
        }}
    }}
    """

def c_code_cache_version(self):
    """Return the version tuple ``(0,)`` for this Op's generated C code."""
    version = (0,)
    return version

def grad(self, inputs, gout):
# Avoid circular import
Expand Down Expand Up @@ -3829,19 +3897,6 @@
out_shape.append(diag_size)
return [tuple(out_shape)]

def __setstate__(self, state):
# Restore the instance from the ``state`` mapping by merging it into the
# instance ``__dict__``, then fill in anything the mapping did not carry.
self.__dict__.update(state)

# ``view_map`` is rebuilt from ``view``; this must run after the update
# above because ``self.view`` is read from the restored attributes.
if self.view:
self.view_map = {0: [0]}

# Defaults for attributes missing from ``state``: offset=0, axis1=0, axis2=1.
# NOTE(review): presumably this supports state saved before these
# attributes existed -- confirm before relying on it.
if "offset" not in state:
self.offset = 0
if "axis1" not in state:
self.axis1 = 0
if "axis2" not in state:
self.axis2 = 1


def extract_diag(x):
warnings.warn(
Expand Down