@@ -1588,20 +1588,20 @@ def median(input, axis=None):
1588
1588
-----
1589
1589
This function uses the numpy implementation of median.
1590
1590
"""
1591
+ from pytensor .ifelse import ifelse
1591
1592
1592
1593
input = as_tensor_variable (input )
1593
1594
sorted_input = input .sort (axis = axis )
1594
1595
shape = input .shape [axis ]
1595
1596
k = extract_constant (shape ) // 2
1596
- if extract_constant (shape % 2 ) == 0 :
1597
- indices1 = expand_dims (full_like (sorted_input .take (0 , axis = axis ), k - 1 ), axis )
1598
- indices2 = expand_dims (full_like (sorted_input .take (0 , axis = axis ), k ), axis )
1599
- ans1 = take_along_axis (sorted_input , indices1 , axis = axis )
1600
- ans2 = take_along_axis (sorted_input , indices2 , axis = axis )
1601
- median_val = (ans1 + ans2 ) / 2.0
1602
- else :
1603
- indices = expand_dims (full_like (sorted_input .take (0 , axis = axis ), k ), axis )
1604
- median_val = take_along_axis (sorted_input , indices , axis = axis )
1597
+ indices1 = expand_dims (full_like (sorted_input .take (0 , axis = axis ), k - 1 ), axis )
1598
+ indices2 = expand_dims (full_like (sorted_input .take (0 , axis = axis ), k ), axis )
1599
+ ans1 = take_along_axis (sorted_input , indices1 , axis = axis )
1600
+ ans2 = take_along_axis (sorted_input , indices2 , axis = axis )
1601
+ median_val_even = (ans1 + ans2 ) / 2.0
1602
+ indices = expand_dims (full_like (sorted_input .take (0 , axis = axis ), k ), axis )
1603
+ median_val_odd = take_along_axis (sorted_input , indices , axis = axis )
1604
+ median_val = ifelse (eq (mod (shape , 2 ), 0 ), median_val_even , median_val_odd )
1605
1605
median_val .name = "median"
1606
1606
return median_val .squeeze (axis = axis )
1607
1607
0 commit comments