|
26 | 26 | concatenate,
|
27 | 27 | constant,
|
28 | 28 | expand_dims,
|
| 29 | + full_like, |
29 | 30 | stack,
|
30 | 31 | switch,
|
| 32 | + take_along_axis, |
31 | 33 | )
|
32 | 34 | from pytensor.tensor.blockwise import Blockwise, vectorize_node_fallback
|
33 | 35 | from pytensor.tensor.elemwise import (
|
@@ -2870,6 +2872,104 @@ def nan_to_num(x, nan=0.0, posinf=None, neginf=None):
|
2870 | 2872 | return x
|
2871 | 2873 |
|
2872 | 2874 |
|
def percentile(input, q, axis=None):
    """
    Computes the percentile along the given axis(es) of a tensor `input` using linear interpolation.

    Parameters
    ----------
    input: TensorVariable
        The input tensor.
    q: float or list of floats
        Percentile or sequence of percentiles to compute, which must be between 0 and 100 inclusive.
        Python numbers, NumPy scalars and 0-d arrays are all treated as a single percentile.
    axis: None or int or list of int, optional
        Axis or axes along which the percentiles are computed. Negative values count from the
        last axis. The default is to compute the percentile(s) along a flattened version of the array.

    Returns
    -------
    TensorVariable
        The reduced axes are removed from the result. When exactly one percentile is requested
        (scalar `q` or a length-one sequence) the result has no extra leading axis; otherwise the
        per-percentile results are stacked along a new leading axis.

    Raises
    ------
    ValueError
        If a percentile lies outside [0, 100], or if `axis` contains an out-of-bounds or
        repeated entry.
    """
    # Validate the requested percentiles first so that bad input fails before
    # any graph nodes are built.
    q_values = [q] if np.ndim(q) == 0 else list(q)
    for q_value in q_values:
        if q_value < 0 or q_value > 100:
            raise ValueError("Percentiles must be in the range [0, 100]")

    input = as_tensor_variable(input)
    input_ndim = input.type.ndim

    if axis is None:
        axis = list(range(input_ndim))
    else:
        if isinstance(axis, (int, np.integer)):
            axis = [int(axis)]
        elif isinstance(axis, np.ndarray) and axis.ndim == 0:
            axis = [int(axis)]
        else:
            axis = [int(a) for a in axis]

        # Map negative axes onto their non-negative equivalents; the membership
        # test below compares against non-negative positions only.
        for a in axis:
            if not -input_ndim <= a < input_ndim:
                raise ValueError(
                    f"axis {a} is out of bounds for tensor of dimension {input_ndim}"
                )
        axis = [a % input_ndim for a in axis]
        if len(set(axis)) != len(axis):
            raise ValueError("repeated axis")

    # Move the reduced axes to the end and collapse them into one trailing axis.
    kept_axes = [i for i in range(input_ndim) if i not in axis]
    n_kept = len(kept_axes)
    input = input.dimshuffle(kept_axes + axis)
    shuffled_shape = shape(input)
    remaining_axis_size = shuffled_shape[:n_kept]
    flattened_axis_size = prod(shuffled_shape[n_kept:])
    input = input.reshape(concatenate([remaining_axis_size, [flattened_axis_size]]))

    # Sort along the collapsed axis; percentiles are read off by position.
    sorted_input = input.sort(axis=-1)
    n_samples = input.shape[-1]

    # Slice with the shape of the kept axes, used as the template for the
    # index tensors. It does not depend on the percentile, so build it once.
    index_template = sorted_input.take(0, axis=-1)

    result = []
    for q_value in q_values:
        # Fractional position of the percentile within the sorted values.
        k = (q_value / 100.0) * (n_samples - 1)
        k_floor = floor(k).astype("int64")
        k_ceil = ceil(k).astype("int64")

        lower_idx = expand_dims(full_like(index_template, k_floor), -1)
        upper_idx = expand_dims(full_like(index_template, k_ceil), -1)

        lower_val = take_along_axis(sorted_input, lower_idx, axis=-1)
        upper_val = take_along_axis(sorted_input, upper_idx, axis=-1)

        # Linear interpolation between the two neighbouring sorted values.
        weight = k - k_floor
        percentile_val = lower_val + weight * (upper_val - lower_val)

        result.append(percentile_val.squeeze(axis=-1))

    if len(result) == 1:
        result = result[0]
    else:
        result = stack(result)

    result.name = "percentile"
    return result
| 2946 | + |
| 2947 | + |
def quantile(input, q, axis=None):
    """
    Computes the quantile along the given axis(es) of a tensor `input` using linear interpolation.

    Parameters
    ----------
    input: TensorVariable
        The input tensor.
    q: float or list of floats
        Quantile or sequence of quantiles to compute, which must be between 0 and 1 inclusive.
        Python numbers, NumPy scalars and 0-d arrays are all treated as a single quantile.
    axis: None or int or list of int, optional
        Axis or axes along which the quantiles are computed. The default is to compute the quantile(s) along a flattened version of the array.

    Returns
    -------
    The value returned by `percentile` for the same `input` and `axis`, with each quantile
    multiplied by 100.

    Raises
    ------
    ValueError
        If a quantile lies outside [0, 1].
    """
    # np.ndim also recognises NumPy scalars (e.g. np.int64, np.float32), which
    # are not instances of the builtin int/float and are not iterable.
    q_values = [q] if np.ndim(q) == 0 else list(q)

    for q_value in q_values:
        if q_value < 0 or q_value > 1:
            raise ValueError("Quantiles must be in the range [0, 1]")

    # Always forwarded as a list, including for a single quantile.
    percentiles = [100.0 * q_value for q_value in q_values]

    return percentile(input, percentiles, axis)
| 2971 | + |
| 2972 | + |
2873 | 2973 | # NumPy logical aliases
|
2874 | 2974 | square = sqr
|
2875 | 2975 |
|
@@ -3023,6 +3123,8 @@ def nan_to_num(x, nan=0.0, posinf=None, neginf=None):
|
3023 | 3123 | "outer",
|
3024 | 3124 | "any",
|
3025 | 3125 | "all",
|
| 3126 | + "percentile", |
| 3127 | + "quantile", |
3026 | 3128 | "ptp",
|
3027 | 3129 | "power",
|
3028 | 3130 | "logaddexp",
|
|
0 commit comments