@@ -1566,6 +1566,48 @@ def std(input, axis=None, ddof=0, keepdims=False, corrected=False):
1566
1566
return ret
1567
1567
1568
1568
1569
+ def median (x : TensorLike , axis = None ) -> TensorVariable :
1570
+ """
1571
+ Computes the median along the given axis(es) of a tensor `input`.
1572
+
1573
+ Parameters
1574
+ ----------
1575
+ x: TensorVariable
1576
+ The input tensor.
1577
+ axis: None or int or (list of int) (see `Sum`)
1578
+ Compute the median along this axis of the tensor.
1579
+ None means all axes (like numpy).
1580
+ """
1581
+ from pytensor .ifelse import ifelse
1582
+
1583
+ x = as_tensor_variable (x )
1584
+ x_ndim = x .type .ndim
1585
+ if axis is None :
1586
+ axis = list (range (x_ndim ))
1587
+ else :
1588
+ axis = list (normalize_axis_tuple (axis , x_ndim ))
1589
+
1590
+ non_axis = [i for i in range (x_ndim ) if i not in axis ]
1591
+ non_axis_shape = [x .shape [i ] for i in non_axis ]
1592
+
1593
+ # Put axis at the end and unravel them
1594
+ x_raveled = x .transpose (* non_axis , * axis )
1595
+ if len (axis ) > 1 :
1596
+ x_raveled = x_raveled .reshape ((* non_axis_shape , - 1 ))
1597
+ raveled_size = x_raveled .shape [- 1 ]
1598
+ k = raveled_size // 2
1599
+
1600
+ # Sort the input tensor along the specified axis and pick median value
1601
+ x_sorted = x_raveled .sort (axis = - 1 )
1602
+ k_values = x_sorted [..., k ]
1603
+ km1_values = x_sorted [..., k - 1 ]
1604
+
1605
+ even_median = (k_values + km1_values ) / 2.0
1606
+ odd_median = k_values .astype (even_median .type .dtype )
1607
+ even_k = eq (mod (raveled_size , 2 ), 0 )
1608
+ return ifelse (even_k , even_median , odd_median , name = "median" )
1609
+
1610
+
1569
1611
@scalar_elemwise (symbolname = "scalar_maximum" )
1570
1612
def maximum (x , y ):
1571
1613
"""elemwise maximum. See max for the maximum in one tensor"""
@@ -3015,6 +3057,7 @@ def nan_to_num(x, nan=0.0, posinf=None, neginf=None):
3015
3057
"sum" ,
3016
3058
"prod" ,
3017
3059
"mean" ,
3060
+ "median" ,
3018
3061
"var" ,
3019
3062
"std" ,
3020
3063
"std" ,
0 commit comments