Skip to content

Commit 163736a

Browse files
Modify test for median
1 parent c1cb563 commit 163736a

File tree

2 files changed

+58
-36
lines changed

2 files changed

+58
-36
lines changed

pytensor/tensor/math.py

+32-11
Original file line numberDiff line numberDiff line change
@@ -1580,33 +1580,54 @@ def median(input, axis=None):
15801580
15811581
Parameters
15821582
----------
1583+
input: TensorVariable
1584+
The input tensor.
15831585
axis: None or int or (list of int) (see `Sum`)
15841586
Compute the median along this axis of the tensor.
15851587
None means all axes (like numpy).
1586-
1587-
Notes
1588-
-----
1589-
This function uses the numpy implementation of median.
15901588
"""
15911589
from pytensor.ifelse import ifelse
15921590

1591+
input = as_tensor_variable(input)
1592+
input_ndim = input.type.ndim
15931593
if axis is None:
1594-
input = input.flatten()
1595-
axis = 0
1594+
axis = list(range(input_ndim))
1595+
elif isinstance(axis, int | np.integer):
1596+
axis = [axis]
1597+
elif isinstance(axis, np.ndarray) and axis.ndim == 0:
1598+
axis = [int(axis)]
1599+
else:
1600+
axis = [int(a) for a in axis]
15961601

1597-
input = as_tensor_variable(input)
1602+
new_axes_order = [i for i in range(input.ndim) if i not in axis] + list(axis)
1603+
input = input.dimshuffle(new_axes_order)
1604+
input_shape = input.shape
1605+
1606+
remaining_axis_size = shape(input)[: input.ndim - len(axis)]
1607+
flattened_axis_size = prod(shape(input)[input.ndim - len(axis) :])
1608+
1609+
input = input.reshape(concatenate([remaining_axis_size, [flattened_axis_size]]))
1610+
axis = -1
1611+
1612+
# Sort the input tensor along the specified axis
15981613
sorted_input = input.sort(axis=axis)
1599-
shape = input.shape[axis]
1600-
k = extract_constant(shape) // 2
1614+
input_shape = input.shape[axis]
1615+
k = extract_constant(input_shape) // 2
1616+
16011617
indices1 = expand_dims(full_like(sorted_input.take(0, axis=axis), k - 1), axis)
16021618
indices2 = expand_dims(full_like(sorted_input.take(0, axis=axis), k), axis)
16031619
ans1 = take_along_axis(sorted_input, indices1, axis=axis)
16041620
ans2 = take_along_axis(sorted_input, indices2, axis=axis)
16051621
median_val_even = (ans1 + ans2) / 2.0
1622+
16061623
indices = expand_dims(full_like(sorted_input.take(0, axis=axis), k), axis)
1607-
median_val_odd = take_along_axis(sorted_input, indices, axis=axis)
1608-
median_val = ifelse(eq(mod(shape, 2), 0), median_val_even, median_val_odd)
1624+
median_val_odd = (
1625+
take_along_axis(sorted_input, indices, axis=axis) / 1.0
1626+
) # Divide by one so that the two dtypes passed in ifelse are compatible
1627+
1628+
median_val = ifelse(eq(mod(input_shape, 2), 0), median_val_even, median_val_odd)
16091629
median_val.name = "median"
1630+
16101631
return median_val.squeeze(axis=axis)
16111632

16121633

tests/tensor/test_math.py

+26-25
Original file line numberDiff line numberDiff line change
@@ -3735,32 +3735,33 @@ def test_nan_to_num(nan, posinf, neginf):
37353735

37363736

37373737
@pytest.mark.parametrize(
3738-
"data, axis",
3738+
"ndim, axis",
37393739
[
3740-
# 1D array
3741-
([1, 7, 3, 6, 5, 2, 4], None),
3742-
([1, 7, 3, 6, 5, 2, 4], 0),
3743-
# 2D array
3744-
([[6, 2], [4, 3], [1, 5]], 0),
3745-
([[6, 2], [4, 3], [1, 5]], 1),
3746-
# 3D array
3747-
([[[6, 2, 3], [1, 5, 8], [4, 7, 9]], [[5, 3, 4], [8, 6, 2], [7, 1, 9]]], None),
3748-
([[[6, 2, 3], [1, 5, 8], [4, 7, 9]], [[5, 3, 4], [8, 6, 2], [7, 1, 9]]], 0),
3749-
([[[6, 2, 3], [1, 5, 8], [4, 7, 9]], [[5, 3, 4], [8, 6, 2], [7, 1, 9]]], 1),
3750-
([[[6, 2, 3], [1, 5, 8], [4, 7, 9]], [[5, 3, 4], [8, 6, 2], [7, 1, 9]]], 2),
3751-
# 4D array
3752-
(
3753-
[
3754-
[[[3, 1], [4, 3]], [[0, 5], [6, 2]], [[7, 8], [9, 4]]],
3755-
[[[10, 11], [12, 13]], [[14, 15], [16, 17]], [[18, 19], [20, 21]]],
3756-
],
3757-
3,
3758-
),
3740+
(2, None),
3741+
(2, 1),
3742+
(2, (0, 1)),
3743+
(3, None),
3744+
(3, (1, 2)),
3745+
(4, (1, 3, 0)),
37593746
],
37603747
)
3761-
def test_median(data, axis):
3762-
x = tensor(shape=np.array(data).shape)
3748+
def test_median(ndim, axis):
3749+
# Generate random data with both odd and even lengths
3750+
shape = tuple(np.random.randint(2, 6) for _ in range(ndim))
3751+
data_odd = np.random.rand(*shape) * 100
3752+
data_even = (
3753+
np.random.rand(
3754+
*(dim + 1 if i == ndim - 1 else dim for i, dim in enumerate(shape))
3755+
)
3756+
* 100
3757+
)
3758+
3759+
x = tensor(dtype="float64", shape=(None,) * ndim)
37633760
f = function([x], median(x, axis=axis))
3764-
result = f(data)
3765-
expected = np.median(data, axis=axis)
3766-
assert np.allclose(result, expected)
3761+
result_odd = f(data_odd)
3762+
result_even = f(data_even)
3763+
expected_odd = np.median(data_odd, axis=axis)
3764+
expected_even = np.median(data_even, axis=axis)
3765+
3766+
assert np.allclose(result_odd, expected_odd)
3767+
assert np.allclose(result_even, expected_even)

0 commit comments

Comments
 (0)