Skip to content

Commit 14967fe

Browse files
Add quantile and quantile dependent percentile
1 parent 28c1529 commit 14967fe

File tree

1 file changed

+33
-33
lines changed

1 file changed

+33
-33
lines changed

pytensor/tensor/math.py

+33-33
Original file line numberDiff line numberDiff line change
@@ -2883,6 +2883,31 @@ def percentile(input, q, axis=None):
28832883
axis: None or int or list of int, optional
28842884
Axis or axes along which the percentiles are computed. The default is to compute the percentile(s) along a flattened version of the array.
28852885
"""
2886+
if isinstance(q, (int | float)):
2887+
q = [q]
2888+
2889+
for percentile in q:
2890+
if percentile < 0 or percentile > 100:
2891+
raise ValueError("Percentiles must be in the range [0, 100]")
2892+
2893+
quantiles = [x / 100 for x in q]
2894+
2895+
return quantile(input, quantiles, axis)
2896+
2897+
2898+
def quantile(input, q, axis=None):
2899+
"""
2900+
Computes the quantile along the given axis(es) of a tensor `input` using linear interpolation.
2901+
2902+
Parameters
2903+
----------
2904+
input: TensorVariable
2905+
The input tensor.
2906+
q: float or list of floats
2907+
Quantile or sequence of quantiles to compute, which must be between 0 and 1 inclusive.
2908+
axis: None or int or list of int, optional
2909+
Axis or axes along which the quantiles are computed. The default is to compute the quantile(s) along a flattened version of the array.
2910+
"""
28862911
x = as_tensor_variable(input)
28872912
x_ndim = x.type.ndim
28882913

@@ -2911,13 +2936,13 @@ def percentile(input, q, axis=None):
29112936
if isinstance(q, (int | float)):
29122937
q = [q]
29132938

2914-
for percentile in q:
2915-
if percentile < 0 or percentile > 100:
2916-
raise ValueError("Percentiles must be in the range [0, 100]")
2939+
for quantile in q:
2940+
if quantile < 0 or quantile > 1:
2941+
raise ValueError("Quantiles must be in the range [0, 1]")
29172942

29182943
result = []
2919-
for percentile in q:
2920-
k = (percentile / 100.0) * (input_shape - 1)
2944+
for quantile in q:
2945+
k = (quantile) * (input_shape - 1)
29212946
k_floor = floor(k).astype("int64")
29222947
k_ceil = ceil(k).astype("int64")
29232948

@@ -2927,44 +2952,19 @@ def percentile(input, q, axis=None):
29272952
val2 = sorted_input[tuple(slices2)]
29282953

29292954
d = k - k_floor
2930-
percentile_val = val1 + d * (val2 - val1)
2955+
quantile_val = val1 + d * (val2 - val1)
29312956

2932-
result.append(percentile_val.squeeze(axis=-1))
2957+
result.append(quantile_val.squeeze(axis=-1))
29332958

29342959
if len(result) == 1:
29352960
result = result[0]
29362961
else:
29372962
result = stack(result)
29382963

2939-
result.name = "percentile"
2964+
result.name = "quantile"
29402965
return result
29412966

29422967

2943-
def quantile(input, q, axis=None):
2944-
"""
2945-
Computes the quantile along the given axis(es) of a tensor `input` using linear interpolation.
2946-
2947-
Parameters
2948-
----------
2949-
input: TensorVariable
2950-
The input tensor.
2951-
q: float or list of floats
2952-
Quantile or sequence of quantiles to compute, which must be between 0 and 1 inclusive.
2953-
axis: None or int or list of int, optional
2954-
Axis or axes along which the quantiles are computed. The default is to compute the quantile(s) along a flattened version of the array.
2955-
"""
2956-
if isinstance(q, (int | float)):
2957-
q = [q]
2958-
2959-
for quantile in q:
2960-
if quantile < 0 or quantile > 1:
2961-
raise ValueError("Quantiles must be in the range [0, 1]")
2962-
2963-
percentiles = [100.0 * x for x in q]
2964-
2965-
return percentile(input, percentiles, axis)
2966-
2967-
29682968
# NumPy logical aliases
29692969
square = sqr
29702970

0 commit comments

Comments
 (0)