Skip to content

Commit f277af7

Browse files
Implement median helper
Co-authored-by: Ricardo Vieira <28983449+ricardov94@users.noreply.github.com>
1 parent ed6ca16 commit f277af7

File tree

2 files changed

+74
-0
lines changed

2 files changed

+74
-0
lines changed

pytensor/tensor/math.py

+43
Original file line numberDiff line numberDiff line change
@@ -1566,6 +1566,48 @@ def std(input, axis=None, ddof=0, keepdims=False, corrected=False):
15661566
return ret
15671567

15681568

1569+
def median(x: TensorLike, axis=None) -> TensorVariable:
1570+
"""
1571+
Computes the median along the given axis(es) of a tensor `input`.
1572+
1573+
Parameters
1574+
----------
1575+
x: TensorVariable
1576+
The input tensor.
1577+
axis: None or int or (list of int) (see `Sum`)
1578+
Compute the median along this axis of the tensor.
1579+
None means all axes (like numpy).
1580+
"""
1581+
from pytensor.ifelse import ifelse
1582+
1583+
x = as_tensor_variable(x)
1584+
x_ndim = x.type.ndim
1585+
if axis is None:
1586+
axis = list(range(x_ndim))
1587+
else:
1588+
axis = list(normalize_axis_tuple(axis, x_ndim))
1589+
1590+
non_axis = [i for i in range(x_ndim) if i not in axis]
1591+
non_axis_shape = [x.shape[i] for i in non_axis]
1592+
1593+
# Put axis at the end and unravel them
1594+
x_raveled = x.transpose(*non_axis, *axis)
1595+
if len(axis) > 1:
1596+
x_raveled = x_raveled.reshape((*non_axis_shape, -1))
1597+
raveled_size = x_raveled.shape[-1]
1598+
k = raveled_size // 2
1599+
1600+
# Sort the input tensor along the specified axis and pick median value
1601+
x_sorted = x_raveled.sort(axis=-1)
1602+
k_values = x_sorted[..., k]
1603+
km1_values = x_sorted[..., k - 1]
1604+
1605+
even_median = (k_values + km1_values) / 2.0
1606+
odd_median = k_values.astype(even_median.type.dtype)
1607+
even_k = eq(mod(raveled_size, 2), 0)
1608+
return ifelse(even_k, even_median, odd_median, name="median")
1609+
1610+
15691611
@scalar_elemwise(symbolname="scalar_maximum")
15701612
def maximum(x, y):
15711613
"""elemwise maximum. See max for the maximum in one tensor"""
@@ -3015,6 +3057,7 @@ def nan_to_num(x, nan=0.0, posinf=None, neginf=None):
30153057
"sum",
30163058
"prod",
30173059
"mean",
3060+
"median",
30183061
"var",
30193062
"std",
30203063
"std",

tests/tensor/test_math.py

+31
Original file line numberDiff line numberDiff line change
@@ -93,6 +93,7 @@
9393
max_and_argmax,
9494
maximum,
9595
mean,
96+
median,
9697
min,
9798
minimum,
9899
mod,
@@ -3735,3 +3736,33 @@ def test_nan_to_num(nan, posinf, neginf):
37353736
out,
37363737
np.nan_to_num(y, nan=nan, posinf=posinf, neginf=neginf),
37373738
)
3739+
3740+
3741+
@pytest.mark.parametrize(
3742+
"ndim, axis",
3743+
[
3744+
(2, None),
3745+
(2, 1),
3746+
(2, (0, 1)),
3747+
(3, None),
3748+
(3, (1, 2)),
3749+
(4, (1, 3, 0)),
3750+
],
3751+
)
3752+
def test_median(ndim, axis):
3753+
# Generate random data with both odd and even lengths
3754+
shape_even = np.arange(1, ndim + 1) * 2
3755+
shape_odd = shape_even - 1
3756+
3757+
data_even = np.random.rand(*shape_even)
3758+
data_odd = np.random.rand(*shape_odd)
3759+
3760+
x = tensor(dtype="float64", shape=(None,) * ndim)
3761+
f = function([x], median(x, axis=axis))
3762+
result_odd = f(data_odd)
3763+
result_even = f(data_even)
3764+
expected_odd = np.median(data_odd, axis=axis)
3765+
expected_even = np.median(data_even, axis=axis)
3766+
3767+
assert np.allclose(result_odd, expected_odd)
3768+
assert np.allclose(result_even, expected_even)

0 commit comments

Comments
 (0)