Skip to content

Commit 31d593d

Browse files
committed
Support all gradient cases for ExtractDiag
Also fixes wrong gradient for negative offsets
1 parent c011572 commit 31d593d

File tree

2 files changed

+67
-48
lines changed

2 files changed

+67
-48
lines changed

pytensor/tensor/basic.py

Lines changed: 25 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,6 @@
66
"""
77

88
import builtins
9-
import warnings
109
from functools import partial
1110
from numbers import Number
1211
from typing import TYPE_CHECKING, Optional, Sequence, Tuple, Union
@@ -20,7 +19,7 @@
2019
import pytensor.scalar.sharedvar
2120
from pytensor import compile, config, printing
2221
from pytensor import scalar as aes
23-
from pytensor.gradient import DisconnectedType, grad_not_implemented, grad_undefined
22+
from pytensor.gradient import DisconnectedType, grad_undefined
2423
from pytensor.graph.basic import Apply, Constant, Variable
2524
from pytensor.graph.fg import FunctionGraph
2625
from pytensor.graph.op import Op
@@ -3407,15 +3406,18 @@ def __init__(self, offset=0, axis1=0, axis2=1, view=False):
34073406
self.view = view
34083407
if self.view:
34093408
self.view_map = {0: [0]}
3410-
self.offset = offset
34113409
if axis1 < 0 or axis2 < 0:
34123410
raise NotImplementedError(
34133411
"ExtractDiag does not support negative axis. Use pytensor.tensor.diagonal instead."
34143412
)
34153413
if axis1 == axis2:
34163414
raise ValueError("axis1 and axis2 cannot be the same")
3415+
# Sort axis
3416+
if axis1 > axis2:
3417+
axis1, axis2, offset = axis2, axis1, -offset
34173418
self.axis1 = axis1
34183419
self.axis2 = axis2
3420+
self.offset = offset
34193421

34203422
def make_node(self, x):
34213423
x = as_tensor_variable(x)
@@ -3436,20 +3438,29 @@ def perform(self, node, inputs, outputs):
34363438
z[0] = z[0].copy()
34373439

34383440
def grad(self, inputs, gout):
3441+
# Avoid circular import
3442+
from pytensor.tensor.subtensor import set_subtensor
3443+
34393444
(x,) = inputs
34403445
(gz,) = gout
34413446

3442-
if x.ndim == 2:
3443-
x = zeros_like(x)
3444-
xdiag = AllocDiag(offset=self.offset)(gz)
3445-
return [
3446-
pytensor.tensor.subtensor.set_subtensor(
3447-
x[: xdiag.shape[0], : xdiag.shape[1]], xdiag
3448-
)
3449-
]
3447+
axis1, axis2, offset = self.axis1, self.axis2, self.offset
3448+
3449+
# Start with zeros (and axes in the front)
3450+
x_grad = zeros_like(moveaxis(x, (axis1, axis2), (0, 1)))
3451+
3452+
# Fill zeros with output diagonal
3453+
xdiag = AllocDiag(offset=0, axis1=0, axis2=1)(gz)
3454+
z_len = xdiag.shape[0]
3455+
if offset >= 0:
3456+
diag_slices = (slice(None, z_len), slice(offset, offset + z_len))
34503457
else:
3451-
warnings.warn("Gradient of ExtractDiag only works for matrices.")
3452-
return [grad_not_implemented(self, 0, x)]
3458+
diag_slices = (slice(abs(offset), abs(offset) + z_len), slice(None, z_len))
3459+
x_grad = set_subtensor(x_grad[diag_slices], xdiag)
3460+
3461+
# Put axes back in their original positions
3462+
x_grad = moveaxis(x_grad, (0, 1), (axis1, axis2))
3463+
return [x_grad]
34533464

34543465
def infer_shape(self, fgraph, node, shapes):
34553466
from pytensor.tensor.math import clip, minimum
@@ -3514,10 +3525,7 @@ def diagonal(a, offset=0, axis1=0, axis2=1):
35143525

35153526

35163527
class AllocDiag(Op):
3517-
"""An `Op` that copies a vector to the diagonal of an empty matrix.
3518-
3519-
It does the inverse of `ExtractDiag`.
3520-
"""
3528+
"""An `Op` that copies a vector to the diagonal of a zero-ed matrix."""
35213529

35223530
__props__ = ("offset", "axis1", "axis2")
35233531

tests/tensor/test_basic.py

Lines changed: 42 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -3552,16 +3552,10 @@ class TestDiag:
35523552
"""
35533553
Test that linalg.diag has the same behavior as numpy.diag.
35543554
numpy.diag has two behaviors:
3555-
(1) when given a vector, it returns a matrix with that vector as the
3556-
diagonal.
3557-
(2) when given a matrix, returns a vector which is the diagonal of the
3558-
matrix.
3555+
(1) when given a vector, it returns a matrix with that vector as the diagonal.
3556+
(2) when given a matrix, returns a vector which is the diagonal of the matrix.
35593557
3560-
(1) and (2) are tested by test_alloc_diag and test_extract_diag
3561-
respectively.
3562-
3563-
test_diag test makes sure that linalg.diag instantiates
3564-
the right op based on the dimension of the input.
3558+
(1) and (2) are further tested by TestAllocDiag and TestExtractDiag, respectively.
35653559
"""
35663560

35673561
def setup_method(self):
@@ -3571,6 +3565,7 @@ def setup_method(self):
35713565
self.type = TensorType
35723566

35733567
def test_diag(self):
3568+
"""Makes sure that diag instantiates the right op based on the dimension of the input."""
35743569
rng = np.random.default_rng(utt.fetch_seed())
35753570

35763571
# test vector input
@@ -3609,38 +3604,55 @@ def test_diag(self):
36093604
f = function([], g)
36103605
assert np.array_equal(f(), np.diag(xx))
36113606

3612-
def test_infer_shape(self):
3607+
3608+
class TestExtractDiag:
3609+
@pytest.mark.parametrize("axis1, axis2", [(0, 1), (1, 0)])
3610+
@pytest.mark.parametrize("offset", (-1, 0, 2))
3611+
def test_infer_shape(self, offset, axis1, axis2):
36133612
rng = np.random.default_rng(utt.fetch_seed())
36143613

3615-
x = vector()
3616-
g = diag(x)
3617-
f = pytensor.function([x], g.shape)
3618-
topo = f.maker.fgraph.toposort()
3619-
if config.mode != "FAST_COMPILE":
3620-
assert sum(isinstance(node.op, AllocDiag) for node in topo) == 0
3621-
for shp in [5, 0, 1]:
3622-
m = rng.random(shp).astype(self.floatX)
3623-
assert (f(m) == np.diag(m).shape).all()
3624-
3625-
x = matrix()
3626-
g = diag(x)
3614+
x = matrix("x")
3615+
g = ExtractDiag(offset=offset, axis1=axis1, axis2=axis2)(x)
36273616
f = pytensor.function([x], g.shape)
36283617
topo = f.maker.fgraph.toposort()
36293618
if config.mode != "FAST_COMPILE":
36303619
assert sum(isinstance(node.op, ExtractDiag) for node in topo) == 0
36313620
for shp in [(5, 3), (3, 5), (5, 1), (1, 5), (5, 0), (0, 5), (1, 0), (0, 1)]:
3632-
m = rng.random(shp).astype(self.floatX)
3633-
assert (f(m) == np.diag(m).shape).all()
3621+
m = rng.random(shp).astype(config.floatX)
3622+
assert (
3623+
f(m) == np.diagonal(m, offset=offset, axis1=axis1, axis2=axis2).shape
3624+
).all()
36343625

3635-
def test_diag_grad(self):
3626+
@pytest.mark.parametrize("axis1, axis2", [(0, 1), (1, 0)])
3627+
@pytest.mark.parametrize("offset", (0, 1, -1))
3628+
def test_grad_2d(self, offset, axis1, axis2):
3629+
diag_fn = ExtractDiag(offset=offset, axis1=axis1, axis2=axis2)
36363630
rng = np.random.default_rng(utt.fetch_seed())
3637-
x = rng.random(5)
3638-
utt.verify_grad(diag, [x], rng=rng)
36393631
x = rng.random((5, 3))
3640-
utt.verify_grad(diag, [x], rng=rng)
3632+
utt.verify_grad(diag_fn, [x], rng=rng)
3633+
3634+
@pytest.mark.parametrize(
3635+
"axis1, axis2",
3636+
[
3637+
(0, 1),
3638+
(1, 0),
3639+
(1, 2),
3640+
(2, 1),
3641+
(0, 2),
3642+
(2, 0),
3643+
],
3644+
)
3645+
@pytest.mark.parametrize("offset", (0, 1, -1))
3646+
def test_grad_3d(self, offset, axis1, axis2):
3647+
diag_fn = ExtractDiag(offset=offset, axis1=axis1, axis2=axis2)
3648+
rng = np.random.default_rng(utt.fetch_seed())
3649+
x = rng.random((5, 4, 3))
3650+
utt.verify_grad(diag_fn, [x], rng=rng)
36413651

36423652

36433653
class TestAllocDiag:
3654+
# TODO: Separate perform, grad and infer_shape tests
3655+
36443656
def setup_method(self):
36453657
self.alloc_diag = AllocDiag
36463658
self.mode = pytensor.compile.mode.get_default_mode()
@@ -3674,7 +3686,7 @@ def test_alloc_diag_values(self):
36743686
(-2, 0, 1),
36753687
(-1, 1, 2),
36763688
]:
3677-
# Test AllocDiag values
3689+
# Test perform
36783690
if np.maximum(axis1, axis2) > len(test_val.shape):
36793691
continue
36803692
adiag_op = self.alloc_diag(offset=offset, axis1=axis1, axis2=axis2)
@@ -3688,7 +3700,6 @@ def test_alloc_diag_values(self):
36883700
# Test infer_shape
36893701
f_shape = pytensor.function([x], adiag_op(x).shape, mode="FAST_RUN")
36903702

3691-
# pytensor.printing.debugprint(f_shape.maker.fgraph.outputs[0])
36923703
output_shape = f_shape(test_val)
36933704
assert not any(
36943705
isinstance(node.op, self.alloc_diag)
@@ -3699,6 +3710,7 @@ def test_alloc_diag_values(self):
36993710
).shape
37003711
assert np.all(rediag_shape == test_val.shape)
37013712

3713+
# Test grad
37023714
diag_x = adiag_op(x)
37033715
sum_diag_x = at_sum(diag_x)
37043716
grad_x = pytensor.grad(sum_diag_x, x)
@@ -3710,7 +3722,6 @@ def test_alloc_diag_values(self):
37103722
true_grad_input = np.diagonal(
37113723
grad_diag_input, offset=offset, axis1=axis1, axis2=axis2
37123724
)
3713-
37143725
assert np.all(true_grad_input == grad_input)
37153726

37163727

0 commit comments

Comments (0)