|
26 | 26 | concatenate,
|
27 | 27 | constant,
|
28 | 28 | expand_dims,
|
| 29 | + full_like, |
29 | 30 | stack,
|
30 | 31 | switch,
|
| 32 | + take_along_axis, |
31 | 33 | )
|
32 | 34 | from pytensor.tensor.blockwise import Blockwise, vectorize_node_fallback
|
33 | 35 | from pytensor.tensor.elemwise import (
|
@@ -2926,6 +2928,104 @@ def nan_to_num(x, nan=0.0, posinf=None, neginf=None):
|
2926 | 2928 | return x
|
2927 | 2929 |
|
2928 | 2930 |
|
| 2931 | +def percentile(input, q, axis=None): |
| 2932 | + """ |
| 2933 | + Computes the percentile along the given axis(es) of a tensor `input` using linear interpolation. |
| 2934 | +
|
| 2935 | + Parameters |
| 2936 | + ---------- |
| 2937 | + input: TensorVariable |
| 2938 | + The input tensor. |
| 2939 | + q: float or list of floats |
| 2940 | + Percentile or sequence of percentiles to compute, which must be between 0 and 100 inclusive. |
| 2941 | + axis: None or int or list of int, optional |
| 2942 | + Axis or axes along which the percentiles are computed. The default is to compute the percentile(s) along a flattened version of the array. |
| 2943 | + """ |
| 2944 | + input = as_tensor_variable(input) |
| 2945 | + input_ndim = input.type.ndim |
| 2946 | + |
| 2947 | + if axis is None: |
| 2948 | + axis = list(range(input_ndim)) |
| 2949 | + elif isinstance(axis, (int | np.integer)): |
| 2950 | + axis = [axis] |
| 2951 | + elif isinstance(axis, np.ndarray) and axis.ndim == 0: |
| 2952 | + axis = [int(axis)] |
| 2953 | + else: |
| 2954 | + axis = [int(a) for a in axis] |
| 2955 | + |
| 2956 | + # Compute the shape of the remaining axes |
| 2957 | + new_axes_order = [i for i in range(input.ndim) if i not in axis] + list(axis) |
| 2958 | + input = input.dimshuffle(new_axes_order) |
| 2959 | + input_shape = shape(input) |
| 2960 | + remaining_axis_size = input_shape[: input.ndim - len(axis)] |
| 2961 | + flattened_axis_size = prod(input_shape[input.ndim - len(axis) :]) |
| 2962 | + input = input.reshape(concatenate([remaining_axis_size, [flattened_axis_size]])) |
| 2963 | + axis = -1 |
| 2964 | + |
| 2965 | + # Sort the input tensor along the specified axis |
| 2966 | + sorted_input = input.sort(axis=axis) |
| 2967 | + input_shape = input.shape[axis] |
| 2968 | + |
| 2969 | + if isinstance(q, (int | float)): |
| 2970 | + q = [q] |
| 2971 | + |
| 2972 | + for percentile in q: |
| 2973 | + if percentile < 0 or percentile > 100: |
| 2974 | + raise ValueError("Percentiles must be in the range [0, 100]") |
| 2975 | + |
| 2976 | + result = [] |
| 2977 | + for percentile in q: |
| 2978 | + k = (percentile / 100.0) * (input_shape - 1) |
| 2979 | + k_floor = floor(k).astype("int64") |
| 2980 | + k_ceil = ceil(k).astype("int64") |
| 2981 | + |
| 2982 | + indices1 = expand_dims( |
| 2983 | + full_like(sorted_input.take(0, axis=axis), k_floor), axis |
| 2984 | + ) |
| 2985 | + indices2 = expand_dims(full_like(sorted_input.take(0, axis=axis), k_ceil), axis) |
| 2986 | + |
| 2987 | + val1 = take_along_axis(sorted_input, indices1, axis=axis) |
| 2988 | + val2 = take_along_axis(sorted_input, indices2, axis=axis) |
| 2989 | + |
| 2990 | + d = k - k_floor |
| 2991 | + percentile_val = val1 + d * (val2 - val1) |
| 2992 | + |
| 2993 | + result.append(percentile_val.squeeze(axis=axis)) |
| 2994 | + |
| 2995 | + if len(result) == 1: |
| 2996 | + result = result[0] |
| 2997 | + else: |
| 2998 | + result = stack(result) |
| 2999 | + |
| 3000 | + result.name = "percentile" |
| 3001 | + return result |
| 3002 | + |
| 3003 | + |
| 3004 | +def quantile(input, q, axis=None): |
| 3005 | + """ |
| 3006 | + Computes the quantile along the given axis(es) of a tensor `input` using linear interpolation. |
| 3007 | +
|
| 3008 | + Parameters |
| 3009 | + ---------- |
| 3010 | + input: TensorVariable |
| 3011 | + The input tensor. |
| 3012 | + q: float or list of floats |
| 3013 | + Quantile or sequence of quantiles to compute, which must be between 0 and 1 inclusive. |
| 3014 | + axis: None or int or list of int, optional |
| 3015 | + Axis or axes along which the quantiles are computed. The default is to compute the quantile(s) along a flattened version of the array. |
| 3016 | + """ |
| 3017 | + if isinstance(q, (int | float)): |
| 3018 | + q = [q] |
| 3019 | + |
| 3020 | + for quantile in q: |
| 3021 | + if quantile < 0 or quantile > 1: |
| 3022 | + raise ValueError("Quantiles must be in the range [0, 1]") |
| 3023 | + |
| 3024 | + percentiles = [100.0 * x for x in q] |
| 3025 | + |
| 3026 | + return percentile(input, percentiles, axis) |
| 3027 | + |
| 3028 | + |
2929 | 3029 | # NumPy logical aliases
|
2930 | 3030 | square = sqr
|
2931 | 3031 |
|
@@ -3080,6 +3180,8 @@ def nan_to_num(x, nan=0.0, posinf=None, neginf=None):
|
3080 | 3180 | "outer",
|
3081 | 3181 | "any",
|
3082 | 3182 | "all",
|
| 3183 | + "percentile", |
| 3184 | + "quantile", |
3083 | 3185 | "ptp",
|
3084 | 3186 | "power",
|
3085 | 3187 | "logaddexp",
|
|
0 commit comments