Skip to content

Commit aa48112

Browse files
Modify test for median
1 parent c1cb563 commit aa48112

File tree

2 files changed

+55
-44
lines changed

2 files changed

+55
-44
lines changed

pytensor/tensor/math.py

+32-19
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,6 @@
2626
concatenate,
2727
constant,
2828
expand_dims,
29-
extract_constant,
3029
full_like,
3130
stack,
3231
switch,
@@ -1580,34 +1579,48 @@ def median(input, axis=None):
15801579
15811580
Parameters
15821581
----------
1582+
input: TensorVariable
1583+
The input tensor.
15831584
axis: None or int or (list of int) (see `Sum`)
15841585
Compute the median along this axis of the tensor.
15851586
None means all axes (like numpy).
1586-
1587-
Notes
1588-
-----
1589-
This function uses the numpy implementation of median.
15901587
"""
15911588
from pytensor.ifelse import ifelse
15921589

1590+
x = as_tensor_variable(input)
1591+
x_ndim = x.type.ndim
15931592
if axis is None:
1594-
input = input.flatten()
1595-
axis = 0
1593+
axis = list(range(x_ndim))
1594+
else:
1595+
axis = list(normalize_axis_tuple(axis, x_ndim))
15961596

1597-
input = as_tensor_variable(input)
1598-
sorted_input = input.sort(axis=axis)
1599-
shape = input.shape[axis]
1600-
k = extract_constant(shape) // 2
1601-
indices1 = expand_dims(full_like(sorted_input.take(0, axis=axis), k - 1), axis)
1602-
indices2 = expand_dims(full_like(sorted_input.take(0, axis=axis), k), axis)
1603-
ans1 = take_along_axis(sorted_input, indices1, axis=axis)
1604-
ans2 = take_along_axis(sorted_input, indices2, axis=axis)
1597+
new_axes_order = [i for i in range(x.ndim) if i not in axis] + list(axis)
1598+
x = x.dimshuffle(new_axes_order)
1599+
x_shape = x.shape
1600+
1601+
remaining_axis_size = shape(x)[: x.ndim - len(axis)]
1602+
1603+
x = x.reshape((*remaining_axis_size, -1))
1604+
1605+
# Sort the input tensor along the specified axis
1606+
sorted_x = x.sort(axis=-1)
1607+
x_shape = x.shape[-1]
1608+
k = x_shape // 2
1609+
1610+
indices1 = expand_dims(full_like(sorted_x.take(0, axis=-1), k), -1)
1611+
indices2 = expand_dims(full_like(sorted_x.take(0, axis=-1), k - 1), -1)
1612+
ans1 = take_along_axis(sorted_x, indices1, axis=-1)
1613+
ans2 = take_along_axis(sorted_x, indices2, axis=-1)
16051614
median_val_even = (ans1 + ans2) / 2.0
1606-
indices = expand_dims(full_like(sorted_input.take(0, axis=axis), k), axis)
1607-
median_val_odd = take_along_axis(sorted_input, indices, axis=axis)
1608-
median_val = ifelse(eq(mod(shape, 2), 0), median_val_even, median_val_odd)
1615+
1616+
median_val_odd = (
1617+
take_along_axis(sorted_x, indices1, axis=-1) / 1.0
1618+
) # Divide by one so that the two dtypes passed in ifelse are compatible
1619+
1620+
median_val = ifelse(eq(mod(x_shape, 2), 0), median_val_even, median_val_odd)
16091621
median_val.name = "median"
1610-
return median_val.squeeze(axis=axis)
1622+
1623+
return median_val.squeeze(axis=-1)
16111624

16121625

16131626
@scalar_elemwise(symbolname="scalar_maximum")

tests/tensor/test_math.py

+23-25
Original file line numberDiff line numberDiff line change
@@ -3735,32 +3735,30 @@ def test_nan_to_num(nan, posinf, neginf):
37353735

37363736

37373737
@pytest.mark.parametrize(
3738-
"data, axis",
3738+
"ndim, axis",
37393739
[
3740-
# 1D array
3741-
([1, 7, 3, 6, 5, 2, 4], None),
3742-
([1, 7, 3, 6, 5, 2, 4], 0),
3743-
# 2D array
3744-
([[6, 2], [4, 3], [1, 5]], 0),
3745-
([[6, 2], [4, 3], [1, 5]], 1),
3746-
# 3D array
3747-
([[[6, 2, 3], [1, 5, 8], [4, 7, 9]], [[5, 3, 4], [8, 6, 2], [7, 1, 9]]], None),
3748-
([[[6, 2, 3], [1, 5, 8], [4, 7, 9]], [[5, 3, 4], [8, 6, 2], [7, 1, 9]]], 0),
3749-
([[[6, 2, 3], [1, 5, 8], [4, 7, 9]], [[5, 3, 4], [8, 6, 2], [7, 1, 9]]], 1),
3750-
([[[6, 2, 3], [1, 5, 8], [4, 7, 9]], [[5, 3, 4], [8, 6, 2], [7, 1, 9]]], 2),
3751-
# 4D array
3752-
(
3753-
[
3754-
[[[3, 1], [4, 3]], [[0, 5], [6, 2]], [[7, 8], [9, 4]]],
3755-
[[[10, 11], [12, 13]], [[14, 15], [16, 17]], [[18, 19], [20, 21]]],
3756-
],
3757-
3,
3758-
),
3740+
(2, None),
3741+
(2, 1),
3742+
(2, (0, 1)),
3743+
(3, None),
3744+
(3, (1, 2)),
3745+
(4, (1, 3, 0)),
37593746
],
37603747
)
3761-
def test_median(data, axis):
3762-
x = tensor(shape=np.array(data).shape)
3748+
def test_median(ndim, axis):
3749+
# Generate random data with both odd and even lengths
3750+
shape_even = np.arange(1, ndim + 1) * 2
3751+
shape_odd = shape_even - 1
3752+
3753+
data_even = np.random.rand(*shape_even)
3754+
data_odd = np.random.rand(*shape_odd)
3755+
3756+
x = tensor(dtype="float64", shape=(None,) * ndim)
37633757
f = function([x], median(x, axis=axis))
3764-
result = f(data)
3765-
expected = np.median(data, axis=axis)
3766-
assert np.allclose(result, expected)
3758+
result_odd = f(data_odd)
3759+
result_even = f(data_even)
3760+
expected_odd = np.median(data_odd, axis=axis)
3761+
expected_even = np.median(data_even, axis=axis)
3762+
3763+
assert np.allclose(result_odd, expected_odd)
3764+
assert np.allclose(result_even, expected_even)

0 commit comments

Comments
 (0)