
Commit 71c58f3

Deprecate AllocDiag Op in favor of equivalent PyTensor graph
1 parent: deea8dd

6 files changed (+52, -81 lines)

pytensor/link/jax/dispatch/tensor_basic.py

Lines changed: 0 additions & 11 deletions
@@ -8,7 +8,6 @@
 from pytensor.tensor import get_vector_length
 from pytensor.tensor.basic import (
     Alloc,
-    AllocDiag,
     AllocEmpty,
     ARange,
     ExtractDiag,
@@ -32,16 +31,6 @@
     """


-@jax_funcify.register(AllocDiag)
-def jax_funcify_AllocDiag(op, **kwargs):
-    offset = op.offset
-
-    def allocdiag(v, offset=offset):
-        return jnp.diag(v, k=offset)
-
-    return allocdiag
-
-
 @jax_funcify.register(AllocEmpty)
 def jax_funcify_AllocEmpty(op, **kwargs):
     def allocempty(*shape):
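
The removed dispatch lowered `AllocDiag` directly to `jnp.diag`. After this commit, the diagonal allocation is built from `zeros`, `arange`, and `set_subtensor`, ops the JAX backend already handles individually, so no dedicated dispatch is needed. A minimal sketch (not part of the commit) checking that the new graph matches what the old dispatch computed:

import numpy as np

import pytensor
import pytensor.tensor as at

v = at.vector("v")
# `at.alloc_diag` is the graph-building helper added in this commit
f = pytensor.function([v], at.alloc_diag(v, offset=1))

x = np.arange(3, dtype=v.dtype)
# Same values the removed JAX dispatch produced via jnp.diag(v, k=1)
assert np.array_equal(f(x), np.diag(x, k=1))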

pytensor/link/numba/dispatch/tensor_basic.py

Lines changed: 0 additions & 12 deletions
@@ -7,7 +7,6 @@
 from pytensor.link.utils import compile_function_src, unique_name_generator
 from pytensor.tensor.basic import (
     Alloc,
-    AllocDiag,
     AllocEmpty,
     ARange,
     ExtractDiag,
@@ -93,17 +92,6 @@ def alloc(val, {", ".join(shape_var_names)}):
     return numba_basic.numba_njit(alloc_fn)


-@numba_funcify.register(AllocDiag)
-def numba_funcify_AllocDiag(op, **kwargs):
-    offset = op.offset
-
-    @numba_basic.numba_njit(inline="always")
-    def allocdiag(v):
-        return np.diag(v, k=offset)
-
-    return allocdiag
-
-
 @numba_funcify.register(ARange)
 def numba_funcify_ARange(op, **kwargs):
     dtype = np.dtype(op.dtype)
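
Likewise, the Numba backend no longer needs a dedicated `AllocDiag` kernel: the replacement graph is composed of ops it already compiles. A minimal sketch (assuming Numba is installed so the "NUMBA" mode is usable):

import numpy as np

import pytensor
import pytensor.tensor as at

v = at.vector("v")
# `at.diag` on a vector now expands to the alloc_diag graph internally
f = pytensor.function([v], at.diag(v, k=-1), mode="NUMBA")

x = np.arange(4, dtype=v.dtype)
assert np.array_equal(f(x), np.diag(x, k=-1))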

pytensor/tensor/basic.py

Lines changed: 44 additions & 2 deletions
@@ -6,6 +6,7 @@
 """

 import builtins
+import warnings
 from functools import partial
 from numbers import Number
 from typing import TYPE_CHECKING, Optional, Sequence, Tuple, Union
@@ -3450,7 +3451,7 @@ def grad(self, inputs, gout):
         x_grad = zeros_like(moveaxis(x, (axis1, axis2), (0, 1)))

         # Fill zeros with output diagonal
-        xdiag = AllocDiag(offset=0, axis1=0, axis2=1)(gz)
+        xdiag = alloc_diag(gz, offset=0, axis1=0, axis2=1)
         z_len = xdiag.shape[0]
         if offset >= 0:
             diag_slices = (slice(None, z_len), slice(offset, offset + z_len))
@@ -3544,6 +3545,10 @@ def __init__(self, offset=0, axis1=0, axis2=1):
             Axis to be used as the second axis of the 2-D sub-arrays to which
             the diagonals will be allocated. Defaults to second axis (i.e. 1).
         """
+        warnings.warn(
+            "AllocDiag is deprecated. Use `alloc_diag` instead",
+            FutureWarning,
+        )
         self.offset = offset
         if axis1 < 0 or axis2 < 0:
             raise NotImplementedError("AllocDiag does not support negative axis")
@@ -3625,6 +3630,43 @@ def __setstate__(self, state):
         self.axis2 = 1


+def alloc_diag(diag, offset=0, axis1=0, axis2=1):
+    """Insert a vector on the diagonal of a zero-ed matrix.
+
+    diagonal(alloc_diag(x)) == x
+    """
+    from pytensor.tensor import set_subtensor
+
+    diag = as_tensor_variable(diag)
+    axis1, axis2 = normalize_axis_tuple((axis1, axis2), ndim=diag.type.ndim + 1)
+    if axis1 > axis2:
+        axis1, axis2 = axis2, axis1
+
+    # Create array with one extra dimension for resulting matrix
+    result_shape = tuple(diag.shape)[:-1] + (diag.shape[-1] + abs(offset),) * 2
+    result = zeros(result_shape, dtype=diag.dtype)
+
+    # Create slice for diagonal in final 2 axes
+    idxs = arange(diag.shape[-1])
+    diagonal_slice = (slice(None),) * (len(result_shape) - 2) + (
+        idxs + np.maximum(0, -offset),
+        idxs + np.maximum(0, offset),
+    )
+
+    # Fill in final 2 axes with diag
+    result = set_subtensor(result[diagonal_slice], diag)
+
+    if diag.type.ndim > 1:
+        # Re-order axes so they correspond to diagonals at axis1, axis2
+        axes = list(range(diag.type.ndim - 1))
+        last_idx = axes[-1]
+        axes = axes[:axis1] + [last_idx + 1] + axes[axis1:]
+        axes = axes[:axis2] + [last_idx + 2] + axes[axis2:]
+        result = result.transpose(axes)
+
+    return result
+
+
 def diag(v, k=0):
     """
     A helper function for two ops: `ExtractDiag` and
@@ -3650,7 +3692,7 @@ def diag(v, k=0):
     _v = as_tensor_variable(v)

     if _v.ndim == 1:
-        return AllocDiag(k)(_v)
+        return alloc_diag(_v, offset=k)
     elif _v.ndim == 2:
         return diagonal(_v, offset=k)
     else:
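
The new helper satisfies the identity stated in its docstring, `diagonal(alloc_diag(x)) == x`, including batched inputs where `axis1`/`axis2` choose where the two new diagonal axes land. A minimal usage sketch (not from the commit; the shapes are just an illustration):

import numpy as np

import pytensor
import pytensor.tensor as at

m = at.matrix("m")  # leading batch axis; last axis holds the diagonal values
out = at.alloc_diag(m, offset=2, axis1=0, axis2=2)
f = pytensor.function([m], out)

x = np.arange(6, dtype=m.dtype).reshape(2, 3)
res = f(x)
assert res.shape == (5, 2, 5)  # diagonal axes get size 3 + |offset| = 5
# Round trip: extracting the diagonal recovers the input
assert np.array_equal(np.diagonal(res, offset=2, axis1=0, axis2=2), x)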

tests/link/jax/test_slinalg.py

Lines changed: 1 addition & 1 deletion
@@ -85,7 +85,7 @@ def test_jax_basic():
         ],
     )

-    out = at.diag(b)
+    out = at.diag(at.specify_shape(b, shape=(10,)))
     out_fg = FunctionGraph([b], [out])
     compare_jax_and_py(out_fg, [np.arange(10).astype(config.floatX)])
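
The test now pins the vector's length with `specify_shape`. A plausible reading (my interpretation, not stated in the commit): the `alloc_diag` graph contains `arange(v.shape[-1])`, and JAX tracing needs that length to be concrete, which `specify_shape` provides. A sketch, assuming JAX is installed:

import numpy as np

import pytensor
import pytensor.tensor as at

b = at.vector("b")
# Without a static shape, arange(b.shape[0]) has no concrete length for JAX
out = at.diag(at.specify_shape(b, shape=(10,)))
f = pytensor.function([b], out, mode="JAX")

x = np.arange(10, dtype=b.dtype)
assert np.array_equal(f(x), np.diag(x))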

tests/link/numba/test_tensor_basic.py

Lines changed: 0 additions & 22 deletions
@@ -57,28 +57,6 @@ def test_AllocEmpty():
     compare_numba_and_py(x_fg, [], assert_fn=compare_shape_dtype)


-@pytest.mark.parametrize(
-    "v, offset",
-    [
-        (set_test_value(at.vector(), np.arange(10, dtype=config.floatX)), 0),
-        (set_test_value(at.vector(), np.arange(10, dtype=config.floatX)), 1),
-        (set_test_value(at.vector(), np.arange(10, dtype=config.floatX)), -1),
-    ],
-)
-def test_AllocDiag(v, offset):
-    g = atb.AllocDiag(offset=offset)(v)
-    g_fg = FunctionGraph(outputs=[g])
-
-    compare_numba_and_py(
-        g_fg,
-        [
-            i.tag.test_value
-            for i in g_fg.inputs
-            if not isinstance(i, (SharedVariable, Constant))
-        ],
-    )
-
-
 @pytest.mark.parametrize(
     "v", [set_test_value(aes.float64(), np.array(1.0, dtype="float64"))]
 )

tests/tensor/test_basic.py

Lines changed: 7 additions & 33 deletions
@@ -23,7 +23,6 @@
 from pytensor.tensor import NoneConst
 from pytensor.tensor.basic import (
     Alloc,
-    AllocDiag,
     AllocEmpty,
     ARange,
     Choose,
@@ -92,7 +91,7 @@
 from pytensor.tensor.exceptions import NotScalarConstantError
 from pytensor.tensor.math import dense_dot
 from pytensor.tensor.math import sum as at_sum
-from pytensor.tensor.shape import Reshape, Shape, Shape_i, shape_padright, specify_shape
+from pytensor.tensor.shape import Reshape, Shape_i, shape_padright, specify_shape
 from pytensor.tensor.type import (
     TensorType,
     bscalar,
@@ -3571,7 +3570,6 @@ def test_diag(self):
         # test vector input
         x = vector()
         g = diag(x)
-        assert isinstance(g.owner.op, AllocDiag)
         f = pytensor.function([x], g)
         for shp in [5, 0, 1]:
             m = rng.random(shp).astype(self.floatX)
@@ -3654,10 +3652,6 @@ def test_grad_3d(self, offset, axis1, axis2):
 class TestAllocDiag:
     # TODO: Separate perform, grad and infer_shape tests

-    def setup_method(self):
-        self.alloc_diag = AllocDiag
-        self.mode = pytensor.compile.mode.get_default_mode()
-
     def _generator(self):
         dims = 4
         shape = (5,) * dims
@@ -3690,34 +3684,28 @@ def test_alloc_diag_values(self):
             # Test perform
             if np.maximum(axis1, axis2) > len(test_val.shape):
                 continue
-            adiag_op = self.alloc_diag(offset=offset, axis1=axis1, axis2=axis2)
-            f = pytensor.function([x], adiag_op(x))
-            # AllocDiag and extract the diagonal again
-            # to check
+            diag_x = at.alloc_diag(x, offset=offset, axis1=axis1, axis2=axis2)
+            f = pytensor.function([x], diag_x)
+            # alloc_diag and extract the diagonal again to check for correctness
             diag_arr = f(test_val)
             rediag = np.diagonal(diag_arr, offset=offset, axis1=axis1, axis2=axis2)
             assert np.all(rediag == test_val)

             # Test infer_shape
-            f_shape = pytensor.function([x], adiag_op(x).shape, mode="FAST_RUN")
+            f_shape = pytensor.function([x], diag_x.shape, mode="FAST_RUN")

             output_shape = f_shape(test_val)
-            assert not any(
-                isinstance(node.op, self.alloc_diag)
-                for node in f_shape.maker.fgraph.toposort()
-            )
             rediag_shape = np.diagonal(
                 np.ones(output_shape), offset=offset, axis1=axis1, axis2=axis2
             ).shape
             assert np.all(rediag_shape == test_val.shape)

             # Test grad
-            diag_x = adiag_op(x)
             sum_diag_x = at_sum(diag_x)
             grad_x = pytensor.grad(sum_diag_x, x)
             grad_diag_x = pytensor.grad(sum_diag_x, diag_x)
-            f_grad_x = pytensor.function([x], grad_x, mode=self.mode)
-            f_grad_diag_x = pytensor.function([x], grad_diag_x, mode=self.mode)
+            f_grad_x = pytensor.function([x], grad_x)
+            f_grad_diag_x = pytensor.function([x], grad_diag_x)
             grad_input = f_grad_x(test_val)
             grad_diag_input = f_grad_diag_x(test_val)
             true_grad_input = np.diagonal(
@@ -3894,20 +3882,6 @@ def test_ExtractDiag(self):
         atens3_diag = ExtractDiag(1, 2, 0)(atens3)
         self._compile_and_check([atens3], [atens3_diag], [atens3_val], ExtractDiag)

-    def test_AllocDiag(self):
-        advec = dvector()
-        advec_val = random(4)
-        self._compile_and_check([advec], [AllocDiag()(advec)], [advec_val], AllocDiag)
-
-        # Shape
-        # 'opt.Makevector' precludes optimizer from disentangling
-        # elements of shape
-        adtens = tensor3()
-        adtens_val = random(4, 5, 3)
-        self._compile_and_check(
-            [adtens], [Shape()(adtens)], [adtens_val], (MakeVector, Shape)
-        )
-
     def test_Split(self):
         aiscal = iscalar()
         aivec = ivector()
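
The gradient block above now differentiates through the `alloc_diag` graph rather than a dedicated Op's `grad` method. Each element of `x` lands exactly once in `alloc_diag(x)`, so the gradient of the sum is a vector of ones; a minimal sketch of that behaviour (not from the commit):

import numpy as np

import pytensor
import pytensor.tensor as at

x = at.vector("x")
diag_x = at.alloc_diag(x, offset=1)
grad_x = pytensor.grad(at.sum(diag_x), x)
f = pytensor.function([x], grad_x)

assert np.array_equal(f(np.arange(4, dtype=x.dtype)), np.ones(4))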
