Skip to content

Commit 6991cb1

Browse files
Add support for numpy like percentile and quantile
1 parent 23427a0 commit 6991cb1

File tree

1 file changed

+102
-0
lines changed

1 file changed

+102
-0
lines changed

pytensor/tensor/math.py

+102
Original file line numberDiff line numberDiff line change
@@ -26,8 +26,10 @@
2626
concatenate,
2727
constant,
2828
expand_dims,
29+
full_like,
2930
stack,
3031
switch,
32+
take_along_axis,
3133
)
3234
from pytensor.tensor.blockwise import Blockwise, vectorize_node_fallback
3335
from pytensor.tensor.elemwise import (
@@ -2870,6 +2872,104 @@ def nan_to_num(x, nan=0.0, posinf=None, neginf=None):
28702872
return x
28712873

28722874

2875+
def percentile(input, q, axis=None):
2876+
"""
2877+
Computes the percentile along the given axis(es) of a tensor `input` using linear interpolation.
2878+
2879+
Parameters
2880+
----------
2881+
input: TensorVariable
2882+
The input tensor.
2883+
q: float or list of floats
2884+
Percentile or sequence of percentiles to compute, which must be between 0 and 100 inclusive.
2885+
axis: None or int or list of int, optional
2886+
Axis or axes along which the percentiles are computed. The default is to compute the percentile(s) along a flattened version of the array.
2887+
"""
2888+
input = as_tensor_variable(input)
2889+
input_ndim = input.type.ndim
2890+
2891+
if axis is None:
2892+
axis = list(range(input_ndim))
2893+
elif isinstance(axis, (int | np.integer)):
2894+
axis = [axis]
2895+
elif isinstance(axis, np.ndarray) and axis.ndim == 0:
2896+
axis = [int(axis)]
2897+
else:
2898+
axis = [int(a) for a in axis]
2899+
2900+
# Compute the shape of the remaining axes
2901+
new_axes_order = [i for i in range(input.ndim) if i not in axis] + list(axis)
2902+
input = input.dimshuffle(new_axes_order)
2903+
input_shape = shape(input)
2904+
remaining_axis_size = input_shape[: input.ndim - len(axis)]
2905+
flattened_axis_size = prod(input_shape[input.ndim - len(axis) :])
2906+
input = input.reshape(concatenate([remaining_axis_size, [flattened_axis_size]]))
2907+
axis = -1
2908+
2909+
# Sort the input tensor along the specified axis
2910+
sorted_input = input.sort(axis=axis)
2911+
input_shape = input.shape[axis]
2912+
2913+
if isinstance(q, (int | float)):
2914+
q = [q]
2915+
2916+
for percentile in q:
2917+
if percentile < 0 or percentile > 100:
2918+
raise ValueError("Percentiles must be in the range [0, 100]")
2919+
2920+
result = []
2921+
for percentile in q:
2922+
k = (percentile / 100.0) * (input_shape - 1)
2923+
k_floor = floor(k).astype("int64")
2924+
k_ceil = ceil(k).astype("int64")
2925+
2926+
indices1 = expand_dims(
2927+
full_like(sorted_input.take(0, axis=axis), k_floor), axis
2928+
)
2929+
indices2 = expand_dims(full_like(sorted_input.take(0, axis=axis), k_ceil), axis)
2930+
2931+
val1 = take_along_axis(sorted_input, indices1, axis=axis)
2932+
val2 = take_along_axis(sorted_input, indices2, axis=axis)
2933+
2934+
d = k - k_floor
2935+
percentile_val = val1 + d * (val2 - val1)
2936+
2937+
result.append(percentile_val.squeeze(axis=axis))
2938+
2939+
if len(result) == 1:
2940+
result = result[0]
2941+
else:
2942+
result = stack(result)
2943+
2944+
result.name = "percentile"
2945+
return result
2946+
2947+
2948+
def quantile(input, q, axis=None):
2949+
"""
2950+
Computes the quantile along the given axis(es) of a tensor `input` using linear interpolation.
2951+
2952+
Parameters
2953+
----------
2954+
input: TensorVariable
2955+
The input tensor.
2956+
q: float or list of floats
2957+
Quantile or sequence of quantiles to compute, which must be between 0 and 1 inclusive.
2958+
axis: None or int or list of int, optional
2959+
Axis or axes along which the quantiles are computed. The default is to compute the quantile(s) along a flattened version of the array.
2960+
"""
2961+
if isinstance(q, (int | float)):
2962+
q = [q]
2963+
2964+
for quantile in q:
2965+
if quantile < 0 or quantile > 1:
2966+
raise ValueError("Quantiles must be in the range [0, 1]")
2967+
2968+
percentiles = [100.0 * x for x in q]
2969+
2970+
return percentile(input, percentiles, axis)
2971+
2972+
28732973
# NumPy logical aliases
28742974
square = sqr
28752975

@@ -3023,6 +3123,8 @@ def nan_to_num(x, nan=0.0, posinf=None, neginf=None):
30233123
"outer",
30243124
"any",
30253125
"all",
3126+
"percentile",
3127+
"quantile",
30263128
"ptp",
30273129
"power",
30283130
"logaddexp",

0 commit comments

Comments
 (0)