Skip to content

Commit c1cb563

Browse files
Support None
1 parent 394c2fd commit c1cb563

File tree

2 files changed

+6
-0
lines changed

2 files changed

+6
-0
lines changed

pytensor/tensor/math.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1590,6 +1590,10 @@ def median(input, axis=None):
15901590
"""
15911591
from pytensor.ifelse import ifelse
15921592

1593+
if axis is None:
1594+
input = input.flatten()
1595+
axis = 0
1596+
15931597
input = as_tensor_variable(input)
15941598
sorted_input = input.sort(axis=axis)
15951599
shape = input.shape[axis]

tests/tensor/test_math.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3738,11 +3738,13 @@ def test_nan_to_num(nan, posinf, neginf):
37383738
"data, axis",
37393739
[
37403740
# 1D array
3741+
([1, 7, 3, 6, 5, 2, 4], None),
37413742
([1, 7, 3, 6, 5, 2, 4], 0),
37423743
# 2D array
37433744
([[6, 2], [4, 3], [1, 5]], 0),
37443745
([[6, 2], [4, 3], [1, 5]], 1),
37453746
# 3D array
3747+
([[[6, 2, 3], [1, 5, 8], [4, 7, 9]], [[5, 3, 4], [8, 6, 2], [7, 1, 9]]], None),
37463748
([[[6, 2, 3], [1, 5, 8], [4, 7, 9]], [[5, 3, 4], [8, 6, 2], [7, 1, 9]]], 0),
37473749
([[[6, 2, 3], [1, 5, 8], [4, 7, 9]], [[5, 3, 4], [8, 6, 2], [7, 1, 9]]], 1),
37483750
([[[6, 2, 3], [1, 5, 8], [4, 7, 9]], [[5, 3, 4], [8, 6, 2], [7, 1, 9]]], 2),

0 commit comments

Comments
 (0)