Skip to content

Commit 394c2fd

Browse files
Modified median logic to use more suitable
1 parent 1c11dbe commit 394c2fd

File tree

1 file changed

+9
-9
lines changed

1 file changed

+9
-9
lines changed

pytensor/tensor/math.py

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1588,20 +1588,20 @@ def median(input, axis=None):
15881588
-----
15891589
This function uses the numpy implementation of median.
15901590
"""
1591+
from pytensor.ifelse import ifelse
15911592

15921593
input = as_tensor_variable(input)
15931594
sorted_input = input.sort(axis=axis)
15941595
shape = input.shape[axis]
15951596
k = extract_constant(shape) // 2
1596-
if extract_constant(shape % 2) == 0:
1597-
indices1 = expand_dims(full_like(sorted_input.take(0, axis=axis), k - 1), axis)
1598-
indices2 = expand_dims(full_like(sorted_input.take(0, axis=axis), k), axis)
1599-
ans1 = take_along_axis(sorted_input, indices1, axis=axis)
1600-
ans2 = take_along_axis(sorted_input, indices2, axis=axis)
1601-
median_val = (ans1 + ans2) / 2.0
1602-
else:
1603-
indices = expand_dims(full_like(sorted_input.take(0, axis=axis), k), axis)
1604-
median_val = take_along_axis(sorted_input, indices, axis=axis)
1597+
indices1 = expand_dims(full_like(sorted_input.take(0, axis=axis), k - 1), axis)
1598+
indices2 = expand_dims(full_like(sorted_input.take(0, axis=axis), k), axis)
1599+
ans1 = take_along_axis(sorted_input, indices1, axis=axis)
1600+
ans2 = take_along_axis(sorted_input, indices2, axis=axis)
1601+
median_val_even = (ans1 + ans2) / 2.0
1602+
indices = expand_dims(full_like(sorted_input.take(0, axis=axis), k), axis)
1603+
median_val_odd = take_along_axis(sorted_input, indices, axis=axis)
1604+
median_val = ifelse(eq(mod(shape, 2), 0), median_val_even, median_val_odd)
16051605
median_val.name = "median"
16061606
return median_val.squeeze(axis=axis)
16071607

0 commit comments

Comments
 (0)