26
26
concatenate ,
27
27
constant ,
28
28
expand_dims ,
29
- full_like ,
30
29
stack ,
31
30
switch ,
32
- take_along_axis ,
33
31
)
34
32
from pytensor .tensor .blockwise import Blockwise , vectorize_node_fallback
35
33
from pytensor .tensor .elemwise import (
@@ -2941,11 +2939,11 @@ def percentile(input, q, axis=None):
2941
2939
axis: None or int or list of int, optional
2942
2940
Axis or axes along which the percentiles are computed. The default is to compute the percentile(s) along a flattened version of the array.
2943
2941
"""
2944
- input = as_tensor_variable (input )
2945
- input_ndim = input .type .ndim
2942
+ x = as_tensor_variable (input )
2943
+ x_ndim = x .type .ndim
2946
2944
2947
2945
if axis is None :
2948
- axis = list (range (input_ndim ))
2946
+ axis = list (range (x_ndim ))
2949
2947
elif isinstance (axis , (int | np .integer )):
2950
2948
axis = [axis ]
2951
2949
elif isinstance (axis , np .ndarray ) and axis .ndim == 0 :
@@ -2954,17 +2952,17 @@ def percentile(input, q, axis=None):
2954
2952
axis = [int (a ) for a in axis ]
2955
2953
2956
2954
# Compute the shape of the remaining axes
2957
- new_axes_order = [i for i in range (input .ndim ) if i not in axis ] + list (axis )
2958
- input = input .dimshuffle (new_axes_order )
2959
- input_shape = shape (input )
2960
- remaining_axis_size = input_shape [: input .ndim - len (axis )]
2961
- flattened_axis_size = prod (input_shape [input .ndim - len (axis ) :])
2962
- input = input .reshape (concatenate ([remaining_axis_size , [flattened_axis_size ]]))
2963
- axis = - 1
2955
+ new_axes_order = [i for i in range (x .ndim ) if i not in axis ] + list (axis )
2956
+ x = x .dimshuffle (new_axes_order )
2957
+ input_shape = shape (x )
2958
+ remaining_axis_size = input_shape [: x .ndim - len (axis )]
2959
+ x = x .reshape ((* remaining_axis_size , - 1 ))
2964
2960
2965
2961
# Sort the input tensor along the specified axis
2966
- sorted_input = input .sort (axis = axis )
2967
- input_shape = input .shape [axis ]
2962
+ sorted_input = x .sort (axis = - 1 )
2963
+ slices1 = [slice (None )] * sorted_input .ndim
2964
+ slices2 = [slice (None )] * sorted_input .ndim
2965
+ input_shape = x .shape [- 1 ]
2968
2966
2969
2967
if isinstance (q , (int | float )):
2970
2968
q = [q ]
@@ -2979,18 +2977,15 @@ def percentile(input, q, axis=None):
2979
2977
k_floor = floor (k ).astype ("int64" )
2980
2978
k_ceil = ceil (k ).astype ("int64" )
2981
2979
2982
- indices1 = expand_dims (
2983
- full_like (sorted_input .take (0 , axis = axis ), k_floor ), axis
2984
- )
2985
- indices2 = expand_dims (full_like (sorted_input .take (0 , axis = axis ), k_ceil ), axis )
2986
-
2987
- val1 = take_along_axis (sorted_input , indices1 , axis = axis )
2988
- val2 = take_along_axis (sorted_input , indices2 , axis = axis )
2980
+ slices1 [- 1 ] = slice (k_floor , k_floor + 1 )
2981
+ slices2 [- 1 ] = slice (k_ceil , k_ceil + 1 )
2982
+ val1 = sorted_input [tuple (slices1 )]
2983
+ val2 = sorted_input [tuple (slices2 )]
2989
2984
2990
2985
d = k - k_floor
2991
2986
percentile_val = val1 + d * (val2 - val1 )
2992
2987
2993
- result .append (percentile_val .squeeze (axis = axis ))
2988
+ result .append (percentile_val .squeeze (axis = - 1 ))
2994
2989
2995
2990
if len (result ) == 1 :
2996
2991
result = result [0 ]
0 commit comments