Skip to content

Commit c3bda8a

Browse files
Add quantile and quantile dependent percentile
1 parent 28c1529 commit c3bda8a

File tree

1 file changed

+21
-21
lines changed

1 file changed

+21
-21
lines changed

pytensor/tensor/math.py

Lines changed: 21 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -2870,18 +2870,18 @@ def nan_to_num(x, nan=0.0, posinf=None, neginf=None):
28702870
return x
28712871

28722872

2873-
def percentile(input, q, axis=None):
2873+
def quantile(input, q, axis=None):
28742874
"""
2875-
Computes the percentile along the given axis(es) of a tensor `input` using linear interpolation.
2875+
Computes the quantile along the given axis(es) of a tensor `input` using linear interpolation.
28762876
28772877
Parameters
28782878
----------
28792879
input: TensorVariable
28802880
The input tensor.
28812881
q: float or list of floats
2882-
Percentile or sequence of percentiles to compute, which must be between 0 and 100 inclusive.
2882+
Quantile or sequence of quantiles to compute, which must be between 0 and 1 inclusive.
28832883
axis: None or int or list of int, optional
2884-
Axis or axes along which the percentiles are computed. The default is to compute the percentile(s) along a flattened version of the array.
2884+
Axis or axes along which the quantiles are computed. The default is to compute the quantile(s) along a flattened version of the array.
28852885
"""
28862886
x = as_tensor_variable(input)
28872887
x_ndim = x.type.ndim
@@ -2911,13 +2911,13 @@ def percentile(input, q, axis=None):
29112911
if isinstance(q, (int | float)):
29122912
q = [q]
29132913

2914-
for percentile in q:
2915-
if percentile < 0 or percentile > 100:
2916-
raise ValueError("Percentiles must be in the range [0, 100]")
2914+
for quantile in q:
2915+
if quantile < 0 or quantile > 1:
2916+
raise ValueError("Quantiles must be in the range [0, 1]")
29172917

29182918
result = []
2919-
for percentile in q:
2920-
k = (percentile / 100.0) * (input_shape - 1)
2919+
for quantile in q:
2920+
k = (quantile) * (input_shape - 1)
29212921
k_floor = floor(k).astype("int64")
29222922
k_ceil = ceil(k).astype("int64")
29232923

@@ -2927,42 +2927,42 @@ def percentile(input, q, axis=None):
29272927
val2 = sorted_input[tuple(slices2)]
29282928

29292929
d = k - k_floor
2930-
percentile_val = val1 + d * (val2 - val1)
2930+
quantile_val = val1 + d * (val2 - val1)
29312931

2932-
result.append(percentile_val.squeeze(axis=-1))
2932+
result.append(quantile_val.squeeze(axis=-1))
29332933

29342934
if len(result) == 1:
29352935
result = result[0]
29362936
else:
29372937
result = stack(result)
29382938

2939-
result.name = "percentile"
2939+
result.name = "quantile"
29402940
return result
29412941

29422942

2943-
def quantile(input, q, axis=None):
2943+
def percentile(input, q, axis=None):
29442944
"""
2945-
Computes the quantile along the given axis(es) of a tensor `input` using linear interpolation.
2945+
Computes the percentile along the given axis(es) of a tensor `input` using linear interpolation.
29462946
29472947
Parameters
29482948
----------
29492949
input: TensorVariable
29502950
The input tensor.
29512951
q: float or list of floats
2952-
Quantile or sequence of quantiles to compute, which must be between 0 and 1 inclusive.
2952+
Percentile or sequence of percentiles to compute, which must be between 0 and 100 inclusive.
29532953
axis: None or int or list of int, optional
2954-
Axis or axes along which the quantiles are computed. The default is to compute the quantile(s) along a flattened version of the array.
2954+
Axis or axes along which the percentiles are computed. The default is to compute the percentile(s) along a flattened version of the array.
29552955
"""
29562956
if isinstance(q, (int | float)):
29572957
q = [q]
29582958

2959-
for quantile in q:
2960-
if quantile < 0 or quantile > 1:
2961-
raise ValueError("Quantiles must be in the range [0, 1]")
2959+
for percentile in q:
2960+
if percentile < 0 or percentile > 100:
2961+
raise ValueError("Percentiles must be in the range [0, 100]")
29622962

2963-
percentiles = [100.0 * x for x in q]
2963+
quantiles = [x / 100 for x in q]
29642964

2965-
return percentile(input, percentiles, axis)
2965+
return quantile(input, quantiles, axis)
29662966

29672967

29682968
# NumPy logical aliases

0 commit comments

Comments
 (0)