Skip to content

Commit 9b7611c

Browse files
merge median
1 parent bc9e7c7 commit 9b7611c

File tree

2 files changed

+85
-22
lines changed

2 files changed

+85
-22
lines changed

pytensor/tensor/math.py

Lines changed: 17 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -26,10 +26,8 @@
2626
concatenate,
2727
constant,
2828
expand_dims,
29-
full_like,
3029
stack,
3130
switch,
32-
take_along_axis,
3331
)
3432
from pytensor.tensor.blockwise import Blockwise, vectorize_node_fallback
3533
from pytensor.tensor.elemwise import (
@@ -2941,11 +2939,11 @@ def percentile(input, q, axis=None):
29412939
axis: None or int or list of int, optional
29422940
Axis or axes along which the percentiles are computed. The default is to compute the percentile(s) along a flattened version of the array.
29432941
"""
2944-
input = as_tensor_variable(input)
2945-
input_ndim = input.type.ndim
2942+
x = as_tensor_variable(input)
2943+
x_ndim = x.type.ndim
29462944

29472945
if axis is None:
2948-
axis = list(range(input_ndim))
2946+
axis = list(range(x_ndim))
29492947
elif isinstance(axis, (int | np.integer)):
29502948
axis = [axis]
29512949
elif isinstance(axis, np.ndarray) and axis.ndim == 0:
@@ -2954,17 +2952,17 @@ def percentile(input, q, axis=None):
29542952
axis = [int(a) for a in axis]
29552953

29562954
# Compute the shape of the remaining axes
2957-
new_axes_order = [i for i in range(input.ndim) if i not in axis] + list(axis)
2958-
input = input.dimshuffle(new_axes_order)
2959-
input_shape = shape(input)
2960-
remaining_axis_size = input_shape[: input.ndim - len(axis)]
2961-
flattened_axis_size = prod(input_shape[input.ndim - len(axis) :])
2962-
input = input.reshape(concatenate([remaining_axis_size, [flattened_axis_size]]))
2963-
axis = -1
2955+
new_axes_order = [i for i in range(x.ndim) if i not in axis] + list(axis)
2956+
x = x.dimshuffle(new_axes_order)
2957+
input_shape = shape(x)
2958+
remaining_axis_size = input_shape[: x.ndim - len(axis)]
2959+
x = x.reshape((*remaining_axis_size, -1))
29642960

29652961
# Sort the input tensor along the specified axis
2966-
sorted_input = input.sort(axis=axis)
2967-
input_shape = input.shape[axis]
2962+
sorted_input = x.sort(axis=-1)
2963+
slices1 = [slice(None)] * sorted_input.ndim
2964+
slices2 = [slice(None)] * sorted_input.ndim
2965+
input_shape = x.shape[-1]
29682966

29692967
if isinstance(q, (int | float)):
29702968
q = [q]
@@ -2979,18 +2977,15 @@ def percentile(input, q, axis=None):
29792977
k_floor = floor(k).astype("int64")
29802978
k_ceil = ceil(k).astype("int64")
29812979

2982-
indices1 = expand_dims(
2983-
full_like(sorted_input.take(0, axis=axis), k_floor), axis
2984-
)
2985-
indices2 = expand_dims(full_like(sorted_input.take(0, axis=axis), k_ceil), axis)
2986-
2987-
val1 = take_along_axis(sorted_input, indices1, axis=axis)
2988-
val2 = take_along_axis(sorted_input, indices2, axis=axis)
2980+
slices1[-1] = slice(k_floor, k_floor + 1)
2981+
slices2[-1] = slice(k_ceil, k_ceil + 1)
2982+
val1 = sorted_input[tuple(slices1)]
2983+
val2 = sorted_input[tuple(slices2)]
29892984

29902985
d = k - k_floor
29912986
percentile_val = val1 + d * (val2 - val1)
29922987

2993-
result.append(percentile_val.squeeze(axis=axis))
2988+
result.append(percentile_val.squeeze(axis=-1))
29942989

29952990
if len(result) == 1:
29962991
result = result[0]

tests/tensor/test_math.py

Lines changed: 68 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -102,9 +102,11 @@
102102
neg,
103103
neq,
104104
outer,
105+
percentile,
105106
polygamma,
106107
power,
107108
ptp,
109+
quantile,
108110
rad2deg,
109111
reciprocal,
110112
round_half_away_from_zero,
@@ -3766,3 +3768,69 @@ def test_median(ndim, axis):
37663768

37673769
assert np.allclose(result_odd, expected_odd)
37683770
assert np.allclose(result_even, expected_even)
3771+
3772+
3773+
@pytest.mark.parametrize(
3774+
"ndim, axis, q",
3775+
[
3776+
(2, None, 50),
3777+
(2, 1, 33),
3778+
(2, (0, 1), 50),
3779+
(3, (1, 2), 50),
3780+
(4, (1, 3, 0), 25),
3781+
(2, None, [25, 50, 75]),
3782+
(3, (1, 2), [10, 90]),
3783+
(3, 1, 75),
3784+
(3, 0, 50),
3785+
],
3786+
)
3787+
def test_percentile(ndim, axis, q):
3788+
shape = tuple(np.arange(1, ndim + 1))
3789+
data = np.random.rand(*shape)
3790+
x = tensor(shape=np.array(data).shape)
3791+
f = function([x], percentile(x, q, axis=axis))
3792+
result = f(data.astype(x.dtype))
3793+
expected = np.percentile(data.astype(x.dtype), q, axis=axis)
3794+
assert np.allclose(result, expected)
3795+
3796+
3797+
@pytest.mark.parametrize(
3798+
"ndim, axis, q",
3799+
[
3800+
(2, None, 0.5),
3801+
(2, None, [0.25, 0.75]),
3802+
(2, 0, 0.5),
3803+
(2, (0, 1), 0.5),
3804+
(3, None, 0.5),
3805+
(3, None, [0.25, 0.75]),
3806+
(3, 0, 0.5),
3807+
(3, (1, 2), 0.5),
3808+
],
3809+
)
3810+
def test_quantile(ndim, axis, q):
3811+
shape = tuple(np.random.randint(2, 6) for _ in range(ndim))
3812+
data = np.random.rand(*shape)
3813+
3814+
x = tensor(dtype="float64", shape=(None,) * ndim)
3815+
f = function([x], quantile(x, q, axis=axis))
3816+
3817+
result = f(data.astype(x.dtype))
3818+
expected = np.quantile(data.astype(x.dtype), q, axis=axis)
3819+
3820+
assert np.allclose(result, expected)
3821+
3822+
3823+
@pytest.mark.parametrize(
3824+
"ndim, axis, q, is_percentile",
3825+
[
3826+
(2, None, [50, 120], True),
3827+
(2, 1, -0.5, False),
3828+
],
3829+
)
3830+
def test_invalid_percentile_quantile(ndim, axis, q, is_percentile):
3831+
x = tensor(dtype="float64", shape=(None,) * ndim)
3832+
with pytest.raises(ValueError):
3833+
if is_percentile:
3834+
percentile(x, q, axis)
3835+
else:
3836+
quantile(x, q, axis)

0 commit comments

Comments
 (0)