Skip to content

Commit c011572

Browse files
committed
Validate axis in AllocDiag and ExtractDiag
1 parent ed43f02 commit c011572

File tree

2 files changed

+20
-0
lines changed

2 files changed

+20
-0
lines changed

pytensor/tensor/basic.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3408,6 +3408,12 @@ def __init__(self, offset=0, axis1=0, axis2=1, view=False):
34083408
if self.view:
34093409
self.view_map = {0: [0]}
34103410
self.offset = offset
3411+
if axis1 < 0 or axis2 < 0:
3412+
raise NotImplementedError(
3413+
"ExtractDiag does not support negative axis. Use pytensor.tensor.diagonal instead."
3414+
)
3415+
if axis1 == axis2:
3416+
raise ValueError("axis1 and axis2 cannot be the same")
34113417
self.axis1 = axis1
34123418
self.axis2 = axis2
34133419

@@ -3502,6 +3508,8 @@ def diagonal(a, offset=0, axis1=0, axis2=1):
35023508
tensor : symbolic tensor
35033509
35043510
"""
3511+
a = as_tensor_variable(a)
3512+
axis1, axis2 = normalize_axis_tuple((axis1, axis2), ndim=a.type.ndim)
35053513
return ExtractDiag(offset, axis1, axis2)(a)
35063514

35073515

@@ -3529,6 +3537,10 @@ def __init__(self, offset=0, axis1=0, axis2=1):
35293537
the diagonals will be allocated. Defaults to second axis (i.e. 1).
35303538
"""
35313539
self.offset = offset
3540+
if axis1 < 0 or axis2 < 0:
3541+
raise NotImplementedError("AllocDiag does not support negative axis")
3542+
if axis1 == axis2:
3543+
raise ValueError("axis1 and axis2 cannot be the same")
35323544
self.axis1 = axis1
35333545
self.axis2 = axis2
35343546

tests/tensor/test_basic.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3714,6 +3714,14 @@ def test_alloc_diag_values(self):
37143714
assert np.all(true_grad_input == grad_input)
37153715

37163716

3717+
def test_diagonal_negative_axis():
3718+
x = np.arange(2 * 3 * 3).reshape((2, 3, 3))
3719+
np.testing.assert_allclose(
3720+
at.diagonal(x, axis1=-1, axis2=-2).eval(),
3721+
np.diagonal(x, axis1=-1, axis2=-2),
3722+
)
3723+
3724+
37173725
def test_transpose():
37183726
x1 = dvector("x1")
37193727
x2 = dmatrix("x2")

0 commit comments

Comments
 (0)