Skip to content

Commit 1c11dbe

Browse files
Add logic for median from scratch
1 parent 56c30e0 commit 1c11dbe

File tree

2 files changed

+67
-0
lines changed

2 files changed

+67
-0
lines changed

pytensor/tensor/math.py

+36
Original file line numberDiff line numberDiff line change
@@ -26,8 +26,11 @@
2626
concatenate,
2727
constant,
2828
expand_dims,
29+
extract_constant,
30+
full_like,
2931
stack,
3032
switch,
33+
take_along_axis,
3134
)
3235
from pytensor.tensor.blockwise import Blockwise, vectorize_node_fallback
3336
from pytensor.tensor.elemwise import (
@@ -1571,6 +1574,38 @@ def std(input, axis=None, ddof=0, keepdims=False, corrected=False):
15711574
return ret
15721575

15731576

1577+
def median(input, axis=None):
    """
    Computes the median along the given axis(es) of a tensor `input`.

    Parameters
    ----------
    input : tensor-like
        The values to reduce. Converted with `as_tensor_variable`.
    axis : None or int or (list of int)
        Compute the median along this axis (or these axes) of the tensor.
        Negative values count from the last axis.
        None means all axes (like numpy).

    Returns
    -------
    TensorVariable
        `input` with the reduced axes removed, named ``"median"``. The two
        middle elements are always averaged with a float division, so an
        integer input yields a floating point result for both odd and even
        lengths (like numpy).

    Raises
    ------
    ValueError
        If an axis is out of bounds for `input`, or is given more than once.

    Notes
    -----
    The median is computed from scratch: the reduced axes are moved to the
    end and flattened into a single axis, that axis is sorted, and the
    elements at positions ``(n - 1) // 2`` and ``n // 2`` are averaged.
    For an odd `n` both positions coincide, so no branch on the parity of
    `n` is needed and the graph stays valid when `n` is only known at
    run time.
    """
    input = as_tensor_variable(input)
    ndim = input.type.ndim

    # Normalise `axis` into a list of unique, non-negative axis numbers.
    if axis is None:
        reduced_axes = list(range(ndim))
    else:
        requested = axis if isinstance(axis, (list, tuple)) else [axis]
        reduced_axes = []
        for ax in requested:
            ax = int(ax)
            if not -ndim <= ax < ndim:
                raise ValueError(
                    f"axis {ax} is out of bounds for tensor of dimension {ndim}"
                )
            reduced_axes.append(ax % ndim)
        if len(set(reduced_axes)) != len(reduced_axes):
            raise ValueError("repeated axis")

    kept_axes = [i for i in range(ndim) if i not in reduced_axes]

    # Move the reduced axes to the end (skipped when they are already there).
    new_order = kept_axes + reduced_axes
    moved = input
    if new_order != list(range(ndim)):
        moved = input.transpose(*new_order)

    # Collapse the reduced axes into one trailing axis. A single reduced axis
    # is already in that form; zero reduced axes gain a trailing axis of
    # length one, which makes the median the element itself.
    if len(reduced_axes) != 1:
        kept_shape = [input.shape[i] for i in kept_axes]
        moved = moved.reshape((*kept_shape, -1))

    sorted_input = moved.sort(axis=-1)

    # Symbolic length of the reduced axis: the index arithmetic below is part
    # of the graph, so it does not rely on the length being a constant.
    n = sorted_input.shape[-1]
    lower = sorted_input[..., (n - 1) // 2]
    upper = sorted_input[..., n // 2]

    median_val = (lower + upper) / 2.0
    median_val.name = "median"
    return median_val
1607+
1608+
15741609
@scalar_elemwise(symbolname="scalar_maximum")
15751610
def maximum(x, y):
15761611
"""elemwise maximum. See max for the maximum in one tensor"""
@@ -3006,6 +3041,7 @@ def nan_to_num(x, nan=0.0, posinf=None, neginf=None):
30063041
"sum",
30073042
"prod",
30083043
"mean",
3044+
"median",
30093045
"var",
30103046
"std",
30113047
"std",

tests/tensor/test_math.py

+31
Original file line numberDiff line numberDiff line change
@@ -94,6 +94,7 @@
9494
max_and_argmax,
9595
maximum,
9696
mean,
97+
median,
9798
min,
9899
minimum,
99100
mod,
@@ -3731,3 +3732,33 @@ def test_nan_to_num(nan, posinf, neginf):
37313732
out,
37323733
np.nan_to_num(y, nan=nan, posinf=posinf, neginf=neginf),
37333734
)
3735+
3736+
3737+
# Inputs for `test_median`. Across the cases below the reduced axis has
# length 2 (even) or 3 / 7 (odd), so both middle-element rules are exercised.
_MEDIAN_VECTOR = [1, 7, 3, 6, 5, 2, 4]
_MEDIAN_MATRIX = [[6, 2], [4, 3], [1, 5]]
_MEDIAN_CUBE = [
    [[6, 2, 3], [1, 5, 8], [4, 7, 9]],
    [[5, 3, 4], [8, 6, 2], [7, 1, 9]],
]
_MEDIAN_HYPERCUBE = [
    [[[3, 1], [4, 3]], [[0, 5], [6, 2]], [[7, 8], [9, 4]]],
    [[[10, 11], [12, 13]], [[14, 15], [16, 17]], [[18, 19], [20, 21]]],
]


@pytest.mark.parametrize(
    "data, axis",
    [
        # 1D array
        (_MEDIAN_VECTOR, 0),
        # 2D array
        (_MEDIAN_MATRIX, 0),
        (_MEDIAN_MATRIX, 1),
        # 3D array
        (_MEDIAN_CUBE, 0),
        (_MEDIAN_CUBE, 1),
        (_MEDIAN_CUBE, 2),
        # 4D array
        (_MEDIAN_HYPERCUBE, 3),
    ],
)
def test_median(data, axis):
    """Compare the compiled `median` graph with `np.median` along one axis."""
    # The symbolic input takes its static shape from the test data.
    symbolic_input = tensor(shape=np.array(data).shape)
    compiled = function([symbolic_input], median(symbolic_input, axis=axis))

    assert np.allclose(compiled(data), np.median(data, axis=axis))

0 commit comments

Comments
 (0)