Skip to content

Commit 2138cd6

Browse files
committed
Restrict diag input ndim to 1 and 2 like numpy
1 parent 31d593d commit 2138cd6

File tree

2 files changed

+9
-8
lines changed

2 files changed

+9
-8
lines changed

pytensor/tensor/basic.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3630,7 +3630,7 @@ def diag(v, k=0):
36303630
A helper function for two ops: `ExtractDiag` and
36313631
`AllocDiag`. The name `diag` is meant to keep it consistent
36323632
with numpy. It both accepts tensor vector and tensor matrix.
3633-
While the passed tensor variable `v` has `v.ndim>=2`, it builds a
3633+
While the passed tensor variable `v` has `v.ndim==2`, it builds a
36343634
`ExtractDiag` instance, and returns a vector with its entries equal to
36353635
`v`'s main diagonal; otherwise if `v.ndim` is `1`, it builds an `AllocDiag`
36363636
instance, and returns a matrix with `v` at its k-th diaogonal.
@@ -3651,10 +3651,10 @@ def diag(v, k=0):
36513651

36523652
if _v.ndim == 1:
36533653
return AllocDiag(k)(_v)
3654-
elif _v.ndim >= 2:
3654+
elif _v.ndim == 2:
36553655
return diagonal(_v, offset=k)
36563656
else:
3657-
raise ValueError("Number of dimensions of `v` must be greater than one.")
3657+
raise ValueError("Input must be 1- or 2-d.")
36583658

36593659

36603660
def stacklists(arg):

tests/tensor/test_basic.py

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -3593,17 +3593,18 @@ def test_diag(self):
35933593
# The right matrix is created
35943594
assert (r == v).all()
35953595

3596-
# Test scalar input
3597-
xx = scalar()
3598-
with pytest.raises(ValueError):
3599-
diag(xx)
3600-
36013596
# Test passing a list
36023597
xx = [[1, 2], [3, 4]]
36033598
g = diag(xx)
36043599
f = function([], g)
36053600
assert np.array_equal(f(), np.diag(xx))
36063601

3602+
@pytest.mark.parametrize("inp", (scalar, tensor3))
3603+
def test_diag_invalid_input_ndim(self, inp):
3604+
x = inp()
3605+
with pytest.raises(ValueError, match="Input must be 1- or 2-d."):
3606+
diag(x)
3607+
36073608

36083609
class TestExtractDiag:
36093610
@pytest.mark.parametrize("axis1, axis2", [(0, 1), (1, 0)])

0 commit comments

Comments
 (0)