Skip to content

Commit bc9e7c7

Browse files
Add support for numpy like percentile and quantile
1 parent a377c22 commit bc9e7c7

File tree

1 file changed

+102
-0
lines changed

1 file changed

+102
-0
lines changed

pytensor/tensor/math.py

Lines changed: 102 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,8 +26,10 @@
2626
concatenate,
2727
constant,
2828
expand_dims,
29+
full_like,
2930
stack,
3031
switch,
32+
take_along_axis,
3133
)
3234
from pytensor.tensor.blockwise import Blockwise, vectorize_node_fallback
3335
from pytensor.tensor.elemwise import (
@@ -2926,6 +2928,104 @@ def nan_to_num(x, nan=0.0, posinf=None, neginf=None):
29262928
return x
29272929

29282930

2931+
def percentile(input, q, axis=None):
2932+
"""
2933+
Computes the percentile along the given axis(es) of a tensor `input` using linear interpolation.
2934+
2935+
Parameters
2936+
----------
2937+
input: TensorVariable
2938+
The input tensor.
2939+
q: float or list of floats
2940+
Percentile or sequence of percentiles to compute, which must be between 0 and 100 inclusive.
2941+
axis: None or int or list of int, optional
2942+
Axis or axes along which the percentiles are computed. The default is to compute the percentile(s) along a flattened version of the array.
2943+
"""
2944+
input = as_tensor_variable(input)
2945+
input_ndim = input.type.ndim
2946+
2947+
if axis is None:
2948+
axis = list(range(input_ndim))
2949+
elif isinstance(axis, (int | np.integer)):
2950+
axis = [axis]
2951+
elif isinstance(axis, np.ndarray) and axis.ndim == 0:
2952+
axis = [int(axis)]
2953+
else:
2954+
axis = [int(a) for a in axis]
2955+
2956+
# Compute the shape of the remaining axes
2957+
new_axes_order = [i for i in range(input.ndim) if i not in axis] + list(axis)
2958+
input = input.dimshuffle(new_axes_order)
2959+
input_shape = shape(input)
2960+
remaining_axis_size = input_shape[: input.ndim - len(axis)]
2961+
flattened_axis_size = prod(input_shape[input.ndim - len(axis) :])
2962+
input = input.reshape(concatenate([remaining_axis_size, [flattened_axis_size]]))
2963+
axis = -1
2964+
2965+
# Sort the input tensor along the specified axis
2966+
sorted_input = input.sort(axis=axis)
2967+
input_shape = input.shape[axis]
2968+
2969+
if isinstance(q, (int | float)):
2970+
q = [q]
2971+
2972+
for percentile in q:
2973+
if percentile < 0 or percentile > 100:
2974+
raise ValueError("Percentiles must be in the range [0, 100]")
2975+
2976+
result = []
2977+
for percentile in q:
2978+
k = (percentile / 100.0) * (input_shape - 1)
2979+
k_floor = floor(k).astype("int64")
2980+
k_ceil = ceil(k).astype("int64")
2981+
2982+
indices1 = expand_dims(
2983+
full_like(sorted_input.take(0, axis=axis), k_floor), axis
2984+
)
2985+
indices2 = expand_dims(full_like(sorted_input.take(0, axis=axis), k_ceil), axis)
2986+
2987+
val1 = take_along_axis(sorted_input, indices1, axis=axis)
2988+
val2 = take_along_axis(sorted_input, indices2, axis=axis)
2989+
2990+
d = k - k_floor
2991+
percentile_val = val1 + d * (val2 - val1)
2992+
2993+
result.append(percentile_val.squeeze(axis=axis))
2994+
2995+
if len(result) == 1:
2996+
result = result[0]
2997+
else:
2998+
result = stack(result)
2999+
3000+
result.name = "percentile"
3001+
return result
3002+
3003+
3004+
def quantile(input, q, axis=None):
3005+
"""
3006+
Computes the quantile along the given axis(es) of a tensor `input` using linear interpolation.
3007+
3008+
Parameters
3009+
----------
3010+
input: TensorVariable
3011+
The input tensor.
3012+
q: float or list of floats
3013+
Quantile or sequence of quantiles to compute, which must be between 0 and 1 inclusive.
3014+
axis: None or int or list of int, optional
3015+
Axis or axes along which the quantiles are computed. The default is to compute the quantile(s) along a flattened version of the array.
3016+
"""
3017+
if isinstance(q, (int | float)):
3018+
q = [q]
3019+
3020+
for quantile in q:
3021+
if quantile < 0 or quantile > 1:
3022+
raise ValueError("Quantiles must be in the range [0, 1]")
3023+
3024+
percentiles = [100.0 * x for x in q]
3025+
3026+
return percentile(input, percentiles, axis)
3027+
3028+
29293029
# NumPy logical aliases
29303030
square = sqr
29313031

@@ -3080,6 +3180,8 @@ def nan_to_num(x, nan=0.0, posinf=None, neginf=None):
30803180
"outer",
30813181
"any",
30823182
"all",
3183+
"percentile",
3184+
"quantile",
30833185
"ptp",
30843186
"power",
30853187
"logaddexp",

0 commit comments

Comments
 (0)