Skip to content

Commit 17e6aa3

Browse files
Add quantile and quantile dependent percentile
1 parent 9b7611c commit 17e6aa3

File tree

1 file changed

+21
-21
lines changed

1 file changed

+21
-21
lines changed

pytensor/tensor/math.py

Lines changed: 21 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -2926,18 +2926,18 @@ def nan_to_num(x, nan=0.0, posinf=None, neginf=None):
29262926
return x
29272927

29282928

2929-
def percentile(input, q, axis=None):
2929+
def quantile(input, q, axis=None):
29302930
"""
2931-
Computes the percentile along the given axis(es) of a tensor `input` using linear interpolation.
2931+
Computes the quantile along the given axis(es) of a tensor `input` using linear interpolation.
29322932
29332933
Parameters
29342934
----------
29352935
input: TensorVariable
29362936
The input tensor.
29372937
q: float or list of floats
2938-
Percentile or sequence of percentiles to compute, which must be between 0 and 100 inclusive.
2938+
Quantile or sequence of quantiles to compute, which must be between 0 and 1 inclusive.
29392939
axis: None or int or list of int, optional
2940-
Axis or axes along which the percentiles are computed. The default is to compute the percentile(s) along a flattened version of the array.
2940+
Axis or axes along which the quantiles are computed. The default is to compute the quantile(s) along a flattened version of the array.
29412941
"""
29422942
x = as_tensor_variable(input)
29432943
x_ndim = x.type.ndim
@@ -2967,13 +2967,13 @@ def percentile(input, q, axis=None):
29672967
if isinstance(q, (int | float)):
29682968
q = [q]
29692969

2970-
for percentile in q:
2971-
if percentile < 0 or percentile > 100:
2972-
raise ValueError("Percentiles must be in the range [0, 100]")
2970+
for quantile in q:
2971+
if quantile < 0 or quantile > 1:
2972+
raise ValueError("Quantiles must be in the range [0, 1]")
29732973

29742974
result = []
2975-
for percentile in q:
2976-
k = (percentile / 100.0) * (input_shape - 1)
2975+
for quantile in q:
2976+
k = (quantile) * (input_shape - 1)
29772977
k_floor = floor(k).astype("int64")
29782978
k_ceil = ceil(k).astype("int64")
29792979

@@ -2983,42 +2983,42 @@ def percentile(input, q, axis=None):
29832983
val2 = sorted_input[tuple(slices2)]
29842984

29852985
d = k - k_floor
2986-
percentile_val = val1 + d * (val2 - val1)
2986+
quantile_val = val1 + d * (val2 - val1)
29872987

2988-
result.append(percentile_val.squeeze(axis=-1))
2988+
result.append(quantile_val.squeeze(axis=-1))
29892989

29902990
if len(result) == 1:
29912991
result = result[0]
29922992
else:
29932993
result = stack(result)
29942994

2995-
result.name = "percentile"
2995+
result.name = "quantile"
29962996
return result
29972997

29982998

2999-
def quantile(input, q, axis=None):
2999+
def percentile(input, q, axis=None):
30003000
"""
3001-
Computes the quantile along the given axis(es) of a tensor `input` using linear interpolation.
3001+
Computes the percentile along the given axis(es) of a tensor `input` using linear interpolation.
30023002
30033003
Parameters
30043004
----------
30053005
input: TensorVariable
30063006
The input tensor.
30073007
q: float or list of floats
3008-
Quantile or sequence of quantiles to compute, which must be between 0 and 1 inclusive.
3008+
Percentile or sequence of percentiles to compute, which must be between 0 and 100 inclusive.
30093009
axis: None or int or list of int, optional
3010-
Axis or axes along which the quantiles are computed. The default is to compute the quantile(s) along a flattened version of the array.
3010+
Axis or axes along which the percentiles are computed. The default is to compute the percentile(s) along a flattened version of the array.
30113011
"""
30123012
if isinstance(q, (int | float)):
30133013
q = [q]
30143014

3015-
for quantile in q:
3016-
if quantile < 0 or quantile > 1:
3017-
raise ValueError("Quantiles must be in the range [0, 1]")
3015+
for percentile in q:
3016+
if percentile < 0 or percentile > 100:
3017+
raise ValueError("Percentiles must be in the range [0, 100]")
30183018

3019-
percentiles = [100.0 * x for x in q]
3019+
quantiles = [x / 100 for x in q]
30203020

3021-
return percentile(input, percentiles, axis)
3021+
return quantile(input, quantiles, axis)
30223022

30233023

30243024
# NumPy logical aliases

0 commit comments

Comments
 (0)