Skip to content

Commit deea8dd

Browse files
committed
Fully support ExtractDiag in numba
1 parent 2138cd6 commit deea8dd

File tree

2 files changed

+65
-8
lines changed

2 files changed

+65
-8
lines changed

pytensor/link/numba/dispatch/tensor_basic.py

Lines changed: 38 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -150,14 +150,44 @@ def split(tensor, axis, indices):
150150

151151

152152
@numba_funcify.register(ExtractDiag)
153-
def numba_funcify_ExtractDiag(op, **kwargs):
154-
offset = op.offset
155-
# axis1 = op.axis1
156-
# axis2 = op.axis2
157-
158-
@numba_basic.numba_njit(inline="always")
159-
def extract_diag(x):
160-
return np.diag(x, k=offset)
153+
def numba_funcify_ExtractDiag(op, node, **kwargs):
    """Build the numba implementation of an ``ExtractDiag`` node.

    Parameters
    ----------
    op
        The ``ExtractDiag`` op; its ``view``, ``axis1``, ``axis2`` and
        ``offset`` attributes are read here.
    node
        The apply node; only ``node.inputs[0].type.ndim`` is inspected, to
        choose between the matrix path and the general N-d path.
    **kwargs
        Accepted for dispatcher compatibility and ignored.

    Returns
    -------
    callable
        A jitted ``extract_diag(x)`` returning the requested diagonal(s),
        with the diagonal placed on the last axis of the result.
    """
    view = op.view
    axis1, axis2, offset = op.axis1, op.axis2, op.offset

    if node.inputs[0].type.ndim == 2:

        @numba_basic.numba_njit(inline="always")
        def extract_diag(x):
            out = np.diag(x, k=offset)

            # When the op is not declared as a view, hand back an
            # independent buffer.
            if not view:
                out = out.copy()

            return out

    else:
        # NOTE(review): the index arithmetic below assumes axis1 < axis2;
        # presumably ExtractDiag normalizes the axis order -- confirm.
        axis1p1 = axis1 + 1
        axis2p1 = axis2 + 1
        # Full slices for the dimensions before axis1 and for those
        # strictly between axis1 and axis2; dimensions after axis2 are
        # covered implicitly by the shorter index tuple.
        leading_dims = (slice(None),) * axis1
        middle_dims = (slice(None),) * (axis2 - axis1 - 1)

        @numba_basic.numba_njit(inline="always")
        def extract_diag(x):
            # Number of entries on the requested diagonal, clamped at zero
            # when the offset falls outside the matrix.
            if offset >= 0:
                diag_len = min(x.shape[axis1], max(0, x.shape[axis2] - offset))
            else:
                diag_len = min(x.shape[axis2], max(0, x.shape[axis1] + offset))
            # Output shape: every dimension except axis1/axis2, followed
            # by the diagonal length as the last dimension.
            base_shape = x.shape[:axis1] + x.shape[axis1p1:axis2] + x.shape[axis2p1:]
            out_shape = base_shape + (diag_len,)
            # Allocate with the input's dtype; a bare np.empty(shape)
            # would silently produce float64 for integer, float32 or
            # complex inputs.
            out = np.empty(out_shape, dtype=x.dtype)

            # This path always fills a freshly allocated array, so `view`
            # is not consulted here.
            for i in range(diag_len):
                if offset >= 0:
                    new_entry = x[leading_dims + (i,) + middle_dims + (i + offset,)]
                else:
                    new_entry = x[leading_dims + (i - offset,) + middle_dims + (i,)]
                out[..., i] = new_entry
            return out

    return extract_diag
163193

tests/link/numba/test_tensor_basic.py

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,10 @@
1717
)
1818

1919

20+
pytest.importorskip("numba")
21+
from pytensor.link.numba.dispatch import numba_funcify
22+
23+
2024
rng = np.random.default_rng(42849)
2125

2226

@@ -366,6 +370,12 @@ def test_Split_view():
366370
),
367371
0,
368372
),
373+
(
374+
set_test_value(
375+
at.matrix(), np.arange(10 * 10, dtype=config.floatX).reshape((10, 10))
376+
),
377+
-1,
378+
),
369379
(
370380
set_test_value(at.vector(), np.arange(10, dtype=config.floatX)),
371381
0,
@@ -386,6 +396,23 @@ def test_ExtractDiag(val, offset):
386396
)
387397

388398

399+
@pytest.mark.parametrize("k", range(-5, 4))
@pytest.mark.parametrize(
    "axis1, axis2", ((0, 1), (0, 2), (0, 3), (1, 2), (1, 3), (2, 3))
)
@pytest.mark.parametrize("reverse_axis", (False, True))
def test_ExtractDiag_exhaustive(k, axis1, axis2, reverse_axis):
    """Check the numba ``diagonal`` against ``np.diagonal`` on a 4-d input.

    Every listed axis pair is tried in both orders, for each offset in
    ``range(-5, 4)``, on an array of shape ``(2, 3, 4, 5)``.
    """
    if reverse_axis:
        axis1, axis2 = axis2, axis1

    shape = (2, 3, 4, 5)
    data = np.arange(np.prod(shape)).reshape(shape)

    symbolic_input = at.tensor4("x")
    diag_node = at.diagonal(symbolic_input, k, axis1, axis2).owner
    compiled = numba_funcify(diag_node.op, diag_node)

    expected = np.diagonal(data, k, axis1, axis2)
    np.testing.assert_allclose(compiled(data), expected)
414+
415+
389416
@pytest.mark.parametrize(
390417
"n, m, k, dtype",
391418
[

0 commit comments

Comments
 (0)