From a187464ef89363d296ab6fd09b4f12f2430dfa3d Mon Sep 17 00:00:00 2001
From: Etienne Duchesne
Date: Wed, 19 Mar 2025 11:22:37 +0100
Subject: [PATCH 1/7] Implement gradient for QR decomposition

---
 pytensor/tensor/nlinalg.py   | 72 ++++++++++++++++++++++++++++++++++++
 tests/tensor/test_nlinalg.py | 12 ++++++
 2 files changed, 84 insertions(+)

diff --git a/pytensor/tensor/nlinalg.py b/pytensor/tensor/nlinalg.py
index ee33f6533c..1df1847781 100644
--- a/pytensor/tensor/nlinalg.py
+++ b/pytensor/tensor/nlinalg.py
@@ -512,6 +512,78 @@ def perform(self, node, inputs, outputs):
         else:
             outputs[0][0] = res
 
+    def L_op(self, inputs, outputs, output_grads):
+        """
+        Reverse-mode gradient of the QR function. Adapted from [1]_, which is used in the forward-mode implementation in JAX here:
+        https://github.com/jax-ml/jax/blob/54691b125ab4b6f88c751dae460e4d51f5cf834a/jax/_src/lax/linalg.py#L1803
+
+        And from [2]_, which describes a solution in the square matrix case.
+
+        References
+        ----------
+        .. [1] Townsend, James. "Differentiating the qr decomposition." online draft https://j-towns.github.io/papers/qr-derivative.pdf (2018)
+        .. [2] Sebastian F. Walter, Lutz Lehmann & René Lamour. "On evaluating higher-order derivatives
+           of the QR decomposition of tall matrices with full column rank in forward and reverse mode algorithmic differentiation",
+           Optimization Methods and Software, 27:2, 391-403, DOI: 10.1080/10556788.2011.610454
+        """
+
+        (A,) = (cast(ptb.TensorVariable, x) for x in inputs)
+        *_, m, n = A.type.shape
+
+        def _H(x: ptb.TensorVariable):
+            return x.conj().T
+
+        def _copyutl(x: ptb.TensorVariable):
+            return ptb.triu(x, k=0) + _H(ptb.triu(x, k=1))
+
+        if self.mode == "raw" or (self.mode == "complete" and m != n):
+            raise NotImplementedError("Gradient of qr not implemented")
+
+        elif m < n:
+            raise NotImplementedError(
+                "Gradient of qr not implemented for m x n matrices with m < n"
+            )
+
+        elif self.mode == "r":
+            # We need all the components of the QR to compute the gradient of A even if we only
+            # use the upper triangular component in the cost function.
+            Q, R = qr(A, mode="reduced")
+            dR = cast(ptb.TensorVariable, output_grads[0])
+            R_dRt = R @ _H(dR)
+            Rinvt = _H(inv(R))
+            A_bar = Q @ ((ptb.tril(R_dRt - _H(R_dRt), k=-1)) @ Rinvt + dR)
+            return [A_bar]
+
+        else:
+            Q, R = (cast(ptb.TensorVariable, x) for x in outputs)
+
+            new_output_grads = []
+            is_disconnected = [
+                isinstance(x.type, DisconnectedType) for x in output_grads
+            ]
+            if all(is_disconnected):
+                # This should never be reached by Pytensor
+                return [DisconnectedType()()]  # pragma: no cover
+
+            for disconnected, output_grad, output in zip(
+                is_disconnected, output_grads, [Q, R], strict=True
+            ):
+                if disconnected:
+                    new_output_grads.append(output.zeros_like())
+                else:
+                    new_output_grads.append(output_grad)
+
+            (dQ, dR) = (cast(ptb.TensorVariable, x) for x in new_output_grads)
+
+            Rinvt = _H(inv(R))
+            Qt_dQ = _H(Q) @ dQ
+            R_dRt = R @ _H(dR)
+            A_bar = (
+                Q @ (ptb.tril(R_dRt - _H(R_dRt), k=-1) - _copyutl(Qt_dQ)) + dQ
+            ) @ Rinvt + Q @ dR
+
+            return [A_bar]
+
 
 def qr(a, mode="reduced"):
     """
diff --git a/tests/tensor/test_nlinalg.py b/tests/tensor/test_nlinalg.py
index 4b83446c5f..6da3594b04 100644
--- a/tests/tensor/test_nlinalg.py
+++ b/tests/tensor/test_nlinalg.py
@@ -152,6 +152,18 @@ def test_qr_modes():
         assert "name 'complete' is not defined" in str(e)
 
 
+@pytest.mark.parametrize("shape", [(3, 3), (6, 3)], ids=["shape=(3, 3)", "shape=(6,3)"])
+@pytest.mark.parametrize("output", [0, 1], ids=["Q", "R"])
+def test_qr_grad(shape, output):
+    rng = np.random.default_rng(utt.fetch_seed())
+
+    def _test_fn(x):
+        return qr(x, mode="reduced")[output]
+
+    a = rng.standard_normal(shape).astype(config.floatX)
+    utt.verify_grad(_test_fn, [a], rng=np.random)
+
+
 class TestSvd(utt.InferShapeTester):
     op_class = SVD
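Two small helpers carry most of the structure in this first version: `_H` is the conjugate transpose, and `_copyutl` rebuilds a Hermitian matrix from an upper triangle, which is how the square-case formula of Walter et al. packages the contribution of `Q^H dQ`. A standalone NumPy sketch of what `_copyutl` computes (illustration only, not the patch code):

```python
import numpy as np

def H(x):
    return x.conj().T

def copyutl(x):
    # keep the upper triangle (diagonal included) and mirror the strictly
    # upper part below the diagonal, so the result is Hermitian
    return np.triu(x, k=0) + H(np.triu(x, k=1))

x = np.arange(9.0).reshape(3, 3)
print(copyutl(x))
# [[0. 1. 2.]
#  [1. 4. 5.]
#  [2. 5. 8.]]
```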
From ee9aaa29dfb37c42aaadb1e900c068be47c20885 Mon Sep 17 00:00:00 2001
From: Etienne Duchesne
Date: Wed, 19 Mar 2025 15:23:36 +0100
Subject: [PATCH 2/7] replace inv by solve_triangular and improve test coverage
 in qr gradient

---
 pytensor/tensor/nlinalg.py   | 16 ++++++-----
 tests/tensor/test_nlinalg.py | 52 +++++++++++++++++++++++++++++-----
 2 files changed, 55 insertions(+), 13 deletions(-)

diff --git a/pytensor/tensor/nlinalg.py b/pytensor/tensor/nlinalg.py
index 1df1847781..9b98f53708 100644
--- a/pytensor/tensor/nlinalg.py
+++ b/pytensor/tensor/nlinalg.py
@@ -527,11 +527,13 @@ def L_op(self, inputs, outputs, output_grads):
            Optimization Methods and Software, 27:2, 391-403, DOI: 10.1080/10556788.2011.610454
         """
 
+        from pytensor.tensor.slinalg import solve_triangular
+
         (A,) = (cast(ptb.TensorVariable, x) for x in inputs)
         *_, m, n = A.type.shape
 
         def _H(x: ptb.TensorVariable):
-            return x.conj().T
+            return x.conj().mT
 
         def _copyutl(x: ptb.TensorVariable):
             return ptb.triu(x, k=0) + _H(ptb.triu(x, k=1))
@@ -550,8 +552,9 @@ def _copyutl(x: ptb.TensorVariable):
             Q, R = qr(A, mode="reduced")
             dR = cast(ptb.TensorVariable, output_grads[0])
             R_dRt = R @ _H(dR)
-            Rinvt = _H(inv(R))
-            A_bar = Q @ ((ptb.tril(R_dRt - _H(R_dRt), k=-1)) @ Rinvt + dR)
+            M = ptb.tril(R_dRt - _H(R_dRt), k=-1)
+            M_Rinvt = _H(solve_triangular(R, _H(M)))
+            A_bar = Q @ (M_Rinvt + dR)
             return [A_bar]
 
         else:
@@ -575,12 +578,11 @@ def _copyutl(x: ptb.TensorVariable):
 
             (dQ, dR) = (cast(ptb.TensorVariable, x) for x in new_output_grads)
 
-            Rinvt = _H(inv(R))
             Qt_dQ = _H(Q) @ dQ
             R_dRt = R @ _H(dR)
-            A_bar = (
-                Q @ (ptb.tril(R_dRt - _H(R_dRt), k=-1) - _copyutl(Qt_dQ)) + dQ
-            ) @ Rinvt + Q @ dR
+            M = Q @ (ptb.tril(R_dRt - _H(R_dRt), k=-1) - _copyutl(Qt_dQ)) + dQ
+            M_Rinvt = _H(solve_triangular(R, _H(M)))
+            A_bar = M_Rinvt + Q @ dR
 
             return [A_bar]
diff --git a/tests/tensor/test_nlinalg.py b/tests/tensor/test_nlinalg.py
index 6da3594b04..a2b4e7beac 100644
--- a/tests/tensor/test_nlinalg.py
+++ b/tests/tensor/test_nlinalg.py
@@ -152,16 +152,56 @@ def test_qr_modes():
         assert "name 'complete' is not defined" in str(e)
 
 
-@pytest.mark.parametrize("shape", [(3, 3), (6, 3)], ids=["shape=(3, 3)", "shape=(6,3)"])
-@pytest.mark.parametrize("output", [0, 1], ids=["Q", "R"])
-def test_qr_grad(shape, output):
+@pytest.mark.parametrize(
+    "shape, gradient_test_case, mode",
+    (
+        [(s, c, "reduced") for s in [(3, 3), (6, 3), (3, 6)] for c in [0, 1, 2]]
+        + [(s, c, "complete") for s in [(3, 3), (6, 3), (3, 6)] for c in [0, 1, 2]]
+        + [(s, 0, "r") for s in [(3, 3), (6, 3), (3, 6)]]
+        + [((3, 3), 0, "raw")]
+    ),
+    ids=(
+        [
+            f"shape={s}, gradient_test_case={c}, mode=reduced"
+            for s in [(3, 3), (6, 3), (3, 6)]
+            for c in ["Q", "R", "both"]
+        ]
+        + [
+            f"shape={s}, gradient_test_case={c}, mode=complete"
+            for s in [(3, 3), (6, 3), (3, 6)]
+            for c in ["Q", "R", "both"]
+        ]
+        + [f"shape={s}, gradient_test_case=R, mode=r" for s in [(3, 3), (6, 3), (3, 6)]]
+        + ["shape=(3, 3), gradient_test_case=Q, mode=raw"]
+    ),
+)
+def test_qr_grad(shape, gradient_test_case, mode):
     rng = np.random.default_rng(utt.fetch_seed())
 
-    def _test_fn(x):
-        return qr(x, mode="reduced")[output]
+    def _test_fn(x, case=2, mode="reduced"):
+        if case == 0:
+            return qr(x, mode=mode)[0].sum()
+        elif case == 1:
+            return qr(x, mode=mode)[1].sum()
+        elif case == 2:
+            Q, R = qr(x, mode=mode)
+            return Q.sum() + R.sum()
 
+    m, n = shape
     a = rng.standard_normal(shape).astype(config.floatX)
-    utt.verify_grad(_test_fn, [a], rng=np.random)
+
+    if m < n or (mode == "complete" and m != n) or mode == "raw":
+        with pytest.raises(NotImplementedError):
+            utt.verify_grad(
+                partial(_test_fn, case=gradient_test_case, mode=mode),
+                [a],
+                rng=np.random,
+            )
+
+    else:
+        utt.verify_grad(
+            partial(_test_fn, case=gradient_test_case, mode=mode), [a], rng=np.random
+        )
 
 
 class TestSvd(utt.InferShapeTester):
     op_class = SVD
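The motivation for this commit: `inv(R)` materializes the explicit inverse, while a triangular solve exploits the structure of `R`, which is both faster and numerically safer. Every product of the form `M @ R^{-H}` is rewritten as `_H(solve_triangular(R, _H(M)))`; the two are algebraically identical. The switch from `.T` to `.mT` transposes only the last two axes, so `_H` would stay correct for batched inputs. A small SciPy check of the identity being used (illustrative, not part of the patch):

```python
import numpy as np
from scipy.linalg import solve_triangular

rng = np.random.default_rng(0)
R = np.triu(rng.standard_normal((4, 4))) + 4.0 * np.eye(4)  # upper triangular, well conditioned
M = rng.standard_normal((4, 4))

H = lambda x: x.conj().T
via_inv = M @ H(np.linalg.inv(R))          # what the previous version computed
via_solve = H(solve_triangular(R, H(M)))   # what the patch computes
np.testing.assert_allclose(via_inv, via_solve, rtol=1e-10, atol=1e-12)
```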
From 0e47b7d5d3162ddb238326fc59877ac420c337bd Mon Sep 17 00:00:00 2001
From: Etienne Duchesne
Date: Fri, 21 Mar 2025 17:03:22 +0100
Subject: [PATCH 3/7] qr decomposition gradient: add symbolic shape check

---
 pytensor/tensor/nlinalg.py   | 29 +++++++++++++++++++++--------
 tests/tensor/test_nlinalg.py |  9 ++++++++-
 2 files changed, 29 insertions(+), 9 deletions(-)

diff --git a/pytensor/tensor/nlinalg.py b/pytensor/tensor/nlinalg.py
index 9b98f53708..77e9fa2b9e 100644
--- a/pytensor/tensor/nlinalg.py
+++ b/pytensor/tensor/nlinalg.py
@@ -11,6 +11,7 @@
 from pytensor.graph.basic import Apply
 from pytensor.graph.op import Op
 from pytensor.npy_2_compat import normalize_axis_tuple
+from pytensor.raise_op import Assert
 from pytensor.tensor import TensorLike
 from pytensor.tensor import basic as ptb
 from pytensor.tensor import math as ptm
@@ -530,7 +531,6 @@ def L_op(self, inputs, outputs, output_grads):
         from pytensor.tensor.slinalg import solve_triangular
 
         (A,) = (cast(ptb.TensorVariable, x) for x in inputs)
-        *_, m, n = A.type.shape
 
         def _H(x: ptb.TensorVariable):
             return x.conj().mT
@@ -538,18 +538,33 @@ def _H(x: ptb.TensorVariable):
         def _copyutl(x: ptb.TensorVariable):
             return ptb.triu(x, k=0) + _H(ptb.triu(x, k=1))
 
-        if self.mode == "raw" or (self.mode == "complete" and m != n):
-            raise NotImplementedError("Gradient of qr not implemented")
+        if self.mode == "raw":
+            raise NotImplementedError("Gradient of qr not implemented for mode=raw")
 
-        elif m < n:
-            raise NotImplementedError(
-                "Gradient of qr not implemented for m x n matrices with m < n"
-            )
+        elif self.mode == "complete":
+            Q, R = (cast(ptb.TensorVariable, x) for x in outputs)
+            qr_assert_op = Assert(
+                "Gradient of qr not implemented for m x n matrices with m != n and mode=complete"
+            )
+            R = qr_assert_op(R, ptm.eq(R.shape[0], R.shape[1]))
 
         elif self.mode == "r":
+            qr_assert_op = Assert(
+                "Gradient of qr not implemented for m x n matrices with m < n and mode=r"
+            )
+            A = qr_assert_op(A, ptm.ge(A.shape[0], A.shape[1]))
             # We need all the components of the QR to compute the gradient of A even if we only
             # use the upper triangular component in the cost function.
             Q, R = qr(A, mode="reduced")
+
+        else:
+            Q, R = (cast(ptb.TensorVariable, x) for x in outputs)
+            qr_assert_op = Assert(
+                "Gradient of qr not implemented for m x n matrices with m < n and mode=reduced"
+            )
+            R = qr_assert_op(R, ptm.eq(R.shape[0], R.shape[1]))
+
+        if self.mode == "r":
             dR = cast(ptb.TensorVariable, output_grads[0])
             R_dRt = R @ _H(dR)
             M = ptb.tril(R_dRt - _H(R_dRt), k=-1)
@@ -558,8 +573,6 @@ def _copyutl(x: ptb.TensorVariable):
             return [A_bar]
 
         else:
-            Q, R = (cast(ptb.TensorVariable, x) for x in outputs)
-
             new_output_grads = []
             is_disconnected = [
                 isinstance(x.type, DisconnectedType) for x in output_grads
diff --git a/tests/tensor/test_nlinalg.py b/tests/tensor/test_nlinalg.py
index a2b4e7beac..26cc5683ac 100644
--- a/tests/tensor/test_nlinalg.py
+++ b/tests/tensor/test_nlinalg.py
@@ -190,13 +190,20 @@ def _test_fn(x, case=2, mode="reduced"):
     m, n = shape
     a = rng.standard_normal(shape).astype(config.floatX)
 
-    if m < n or (mode == "complete" and m != n) or mode == "raw":
+    if mode == "raw":
         with pytest.raises(NotImplementedError):
             utt.verify_grad(
                 partial(_test_fn, case=gradient_test_case, mode=mode),
                 [a],
                 rng=np.random,
             )
+
+    elif m < n or (mode == "complete" and m != n):
+        with pytest.raises(AssertionError):
+            utt.verify_grad(
+                partial(_test_fn, case=gradient_test_case, mode=mode),
+                [a],
+                rng=np.random,
+            )
 
     else:
         utt.verify_grad(
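The earlier revisions read `*_, m, n = A.type.shape` and compared `m` and `n` in Python, which only works when the static shape is known at graph-construction time. This commit moves the restriction into the graph: `Assert` wraps a value together with a boolean condition and raises only at run time, which is why the shape-dependent test cases now expect `AssertionError` instead of `NotImplementedError`. A minimal sketch of the `Assert` pattern used here (variable names are hypothetical):

```python
import numpy as np
import pytensor
import pytensor.tensor as pt
from pytensor.raise_op import Assert

x = pt.matrix("x")  # shape unknown until the function is called
assert_square = Assert("x must be square")
# The condition becomes a node in the graph; it is evaluated at run time.
y = assert_square(x, pt.eq(x.shape[0], x.shape[1]))
f = pytensor.function([x], y.sum())

f(np.eye(3, dtype=pytensor.config.floatX))          # passes
# f(np.ones((2, 3), dtype=pytensor.config.floatX))  # raises AssertionError
```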
From 9e5e76519522f4706204a65533b19f057b0df525 Mon Sep 17 00:00:00 2001
From: Etienne Duchesne
Date: Sat, 22 Mar 2025 17:00:27 +0100
Subject: [PATCH 4/7] qr decomposition gradient: extend gradient to more input
 shapes

* for mode=reduced or mode=r, all input shapes are accepted
* for mode=complete, shapes m x n where m <= n are accepted
---
 pytensor/tensor/nlinalg.py   | 71 +++++++++++++++---------------------
 tests/tensor/test_nlinalg.py |  3 ++-
 2 files changed, 32 insertions(+), 42 deletions(-)

diff --git a/pytensor/tensor/nlinalg.py b/pytensor/tensor/nlinalg.py
index 77e9fa2b9e..908df36404 100644
--- a/pytensor/tensor/nlinalg.py
+++ b/pytensor/tensor/nlinalg.py
@@ -5,6 +5,7 @@
 
 import numpy as np
 
+import pytensor.tensor as pt
 from pytensor import scalar as ps
 from pytensor.compile.builders import OpFromGraph
 from pytensor.gradient import DisconnectedType
@@ -515,64 +516,43 @@ def perform(self, node, inputs, outputs):
 
     def L_op(self, inputs, outputs, output_grads):
         """
-        Reverse-mode gradient of the QR function. Adapted from [1]_, which is used in the forward-mode implementation in JAX here:
-        https://github.com/jax-ml/jax/blob/54691b125ab4b6f88c751dae460e4d51f5cf834a/jax/_src/lax/linalg.py#L1803
-
-        And from [2]_, which describes a solution in the square matrix case.
+        Reverse-mode gradient of the QR function.
 
         References
         ----------
-        .. [1] Townsend, James. "Differentiating the qr decomposition." online draft https://j-towns.github.io/papers/qr-derivative.pdf (2018)
-        .. [2] Sebastian F. Walter, Lutz Lehmann & René Lamour. "On evaluating higher-order derivatives
-           of the QR decomposition of tall matrices with full column rank in forward and reverse mode algorithmic differentiation",
-           Optimization Methods and Software, 27:2, 391-403, DOI: 10.1080/10556788.2011.610454
+        .. [1] Jinguo Liu. "Linear Algebra Autodiff (complex valued)", blog post https://giggleliu.github.io/posts/2019-04-02-einsumbp/
+        .. [2] Hai-Jun Liao, Jin-Guo Liu, Lei Wang, Tao Xiang. "Differentiable Programming Tensor Networks", arXiv:1903.09650v2
         """
 
         from pytensor.tensor.slinalg import solve_triangular
 
         (A,) = (cast(ptb.TensorVariable, x) for x in inputs)
+        m, n = A.shape
 
         def _H(x: ptb.TensorVariable):
             return x.conj().mT
 
-        def _copyutl(x: ptb.TensorVariable):
-            return ptb.triu(x, k=0) + _H(ptb.triu(x, k=1))
+        def _copyltu(x: ptb.TensorVariable):
+            return ptb.tril(x, k=0) + _H(ptb.tril(x, k=-1))
 
         if self.mode == "raw":
             raise NotImplementedError("Gradient of qr not implemented for mode=raw")
 
-        elif self.mode == "complete":
-            Q, R = (cast(ptb.TensorVariable, x) for x in outputs)
-            qr_assert_op = Assert(
-                "Gradient of qr not implemented for m x n matrices with m != n and mode=complete"
-            )
-            R = qr_assert_op(R, ptm.eq(R.shape[0], R.shape[1]))
-
         elif self.mode == "r":
-            qr_assert_op = Assert(
-                "Gradient of qr not implemented for m x n matrices with m < n and mode=r"
-            )
-            A = qr_assert_op(A, ptm.ge(A.shape[0], A.shape[1]))
             # We need all the components of the QR to compute the gradient of A even if we only
             # use the upper triangular component in the cost function.
             Q, R = qr(A, mode="reduced")
+            dQ = Q.zeros_like()
+            dR = cast(ptb.TensorVariable, output_grads[0])
 
         else:
             Q, R = (cast(ptb.TensorVariable, x) for x in outputs)
-            qr_assert_op = Assert(
-                "Gradient of qr not implemented for m x n matrices with m < n and mode=reduced"
-            )
-            R = qr_assert_op(R, ptm.eq(R.shape[0], R.shape[1]))
-
-        if self.mode == "r":
-            dR = cast(ptb.TensorVariable, output_grads[0])
-            R_dRt = R @ _H(dR)
-            M = ptb.tril(R_dRt - _H(R_dRt), k=-1)
-            M_Rinvt = _H(solve_triangular(R, _H(M)))
-            A_bar = Q @ (M_Rinvt + dR)
-            return [A_bar]
+            if self.mode == "complete":
+                qr_assert_op = Assert(
+                    "Gradient of qr not implemented for m x n matrices with m > n and mode=complete"
+                )
+                R = qr_assert_op(R, ptm.le(m, n))
 
-        else:
             new_output_grads = []
             is_disconnected = [
                 isinstance(x.type, DisconnectedType) for x in output_grads
             ]
             if all(is_disconnected):
                 # This should never be reached by Pytensor
                 return [DisconnectedType()()]  # pragma: no cover
 
             for disconnected, output_grad, output in zip(
                 is_disconnected, output_grads, [Q, R], strict=True
             ):
                 if disconnected:
                     new_output_grads.append(output.zeros_like())
                 else:
                     new_output_grads.append(output_grad)
 
             (dQ, dR) = (cast(ptb.TensorVariable, x) for x in new_output_grads)
 
-            Qt_dQ = _H(Q) @ dQ
-            R_dRt = R @ _H(dR)
-            M = Q @ (ptb.tril(R_dRt - _H(R_dRt), k=-1) - _copyutl(Qt_dQ)) + dQ
-            M_Rinvt = _H(solve_triangular(R, _H(M)))
-            A_bar = M_Rinvt + Q @ dR
-
-            return [A_bar]
+        # gradient expression when m >= n
+        M = R @ _H(dR) - _H(dQ) @ Q
+        K = dQ + Q @ _copyltu(M)
+        A_bar_m_ge_n = _H(solve_triangular(R, _H(K)))
+
+        # gradient expression when m < n
+        Y = A[:, m:]
+        U = R[:, :m]
+        dU, dV = dR[:, :m], dR[:, m:]
+        dQ_Yt_dV = dQ + Y @ _H(dV)
+        M = U @ _H(dU) - _H(dQ_Yt_dV) @ Q
+        X_bar = _H(solve_triangular(U, _H(dQ_Yt_dV + Q @ _copyltu(M))))
+        Y_bar = Q @ dV
+        A_bar_m_lt_n = pt.concatenate([X_bar, Y_bar], axis=1)
+
+        return [pt.switch(ptm.ge(m, n), A_bar_m_ge_n, A_bar_m_lt_n)]
 
 
 def qr(a, mode="reduced"):
diff --git a/tests/tensor/test_nlinalg.py b/tests/tensor/test_nlinalg.py
index 26cc5683ac..e6ccd42ffa 100644
--- a/tests/tensor/test_nlinalg.py
+++ b/tests/tensor/test_nlinalg.py
@@ -197,7 +197,8 @@ def _test_fn(x, case=2, mode="reduced"):
                 [a],
                 rng=np.random,
             )
-    elif m < n or (mode == "complete" and m != n):
+
+    elif mode == "complete" and m > n:
         with pytest.raises(AssertionError):
             utt.verify_grad(
                 partial(_test_fn, case=gradient_test_case, mode=mode),
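This rewrite unifies all differentiable modes behind one pair of expressions, following Liu's complex-valued derivation: for m >= n, `A_bar = (dQ + Q @ copyltu(M)) @ R^{-H}` with `M = R @ dR^H - dQ^H @ Q`; for m < n, `A` is split as `[X | Y]` with `X` the leading m x m block (so that `R[:, m:] = Q^H @ Y`), the m >= n rule is applied to `X`, and the trailing columns contribute through `Y_bar = Q @ dV` plus a correction inside `X_bar`. A standalone NumPy/SciPy sketch that checks the m >= n expression against central finite differences (names mirror the patch; this is an illustration, not the patch code):

```python
import numpy as np
from scipy.linalg import solve_triangular

def qr_backward(A, dQ, dR):
    # m >= n branch of the patch: A_bar = (dQ + Q @ copyltu(M)) @ inv(R).conj().T,
    # with the explicit inverse replaced by a triangular solve
    Q, R = np.linalg.qr(A, mode="reduced")
    H = lambda x: x.conj().T
    copyltu = lambda x: np.tril(x, k=0) + H(np.tril(x, k=-1))
    M = R @ H(dR) - H(dQ) @ Q
    K = dQ + Q @ copyltu(M)
    return H(solve_triangular(R, H(K)))

rng = np.random.default_rng(0)
A = rng.standard_normal((6, 3))
dQ = rng.standard_normal((6, 3))  # random dense cotangents, as verify_grad would use
dR = rng.standard_normal((3, 3))

def loss(A):
    Q, R = np.linalg.qr(A, mode="reduced")
    return np.sum(dQ * Q) + np.sum(dR * R)

fd = np.zeros_like(A)
eps = 1e-6
for idx in np.ndindex(*A.shape):
    E = np.zeros_like(A)
    E[idx] = eps
    fd[idx] = (loss(A + E) - loss(A - E)) / (2 * eps)  # central difference

np.testing.assert_allclose(qr_backward(A, dQ, dR), fd, rtol=1e-5, atol=1e-7)
```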
"On evaluating higher-order derivatives - of the QR decomposition of tall matrices with full column rank in forward and reverse mode algorithmic differentiation", - Optimization Methods and Software, 27:2, 391-403, DOI: 10.1080/10556788.2011.610454 + .. [1] Jinguo Liu. "Linear Algebra Autodiff (complex valued)", blog post https://giggleliu.github.io/posts/2019-04-02-einsumbp/ + .. [2] Hai-Jun Liao, Jin-Guo Liu, Lei Wang, Tao Xiang. "Differentiable Programming Tensor Networks", arXiv:1903.09650v2 """ from pytensor.tensor.slinalg import solve_triangular (A,) = (cast(ptb.TensorVariable, x) for x in inputs) + m, n = A.shape def _H(x: ptb.TensorVariable): return x.conj().mT - def _copyutl(x: ptb.TensorVariable): - return ptb.triu(x, k=0) + _H(ptb.triu(x, k=1)) + def _copyltu(x: ptb.TensorVariable): + return ptb.tril(x, k=0) + _H(ptb.tril(x, k=-1)) if self.mode == "raw": raise NotImplementedError("Gradient of qr not implemented for mode=raw") - elif self.mode == "complete": - Q, R = (cast(ptb.TensorVariable, x) for x in outputs) - qr_assert_op = Assert( - "Gradient of qr not implemented for m x n matrices with m != n and mode=complete" - ) - R = qr_assert_op(R, ptm.eq(R.shape[0], R.shape[1])) - elif self.mode == "r": - qr_assert_op = Assert( - "Gradient of qr not implemented for m x n matrices with m < n and mode=r" - ) - A = qr_assert_op(A, ptm.ge(A.shape[0], A.shape[1])) # We need all the components of the QR to compute the gradient of A even if we only # use the upper triangular component in the cost function. Q, R = qr(A, mode="reduced") + dQ = Q.zeros_like() + dR = cast(ptb.TensorVariable, output_grads[0]) else: Q, R = (cast(ptb.TensorVariable, x) for x in outputs) - qr_assert_op = Assert( - "Gradient of qr not implemented for m x n matrices with m < n and mode=reduced" - ) - R = qr_assert_op(R, ptm.eq(R.shape[0], R.shape[1])) - - if self.mode == "r": - dR = cast(ptb.TensorVariable, output_grads[0]) - R_dRt = R @ _H(dR) - M = ptb.tril(R_dRt - _H(R_dRt), k=-1) - M_Rinvt = _H(solve_triangular(R, _H(M))) - A_bar = Q @ (M_Rinvt + dR) - return [A_bar] + if self.mode == "complete": + qr_assert_op = Assert( + "Gradient of qr not implemented for m x n matrices with m > n and mode=complete" + ) + R = qr_assert_op(R, ptm.le(m, n)) - else: new_output_grads = [] is_disconnected = [ isinstance(x.type, DisconnectedType) for x in output_grads @@ -591,13 +571,22 @@ def _copyutl(x: ptb.TensorVariable): (dQ, dR) = (cast(ptb.TensorVariable, x) for x in new_output_grads) - Qt_dQ = _H(Q) @ dQ - R_dRt = R @ _H(dR) - M = Q @ (ptb.tril(R_dRt - _H(R_dRt), k=-1) - _copyutl(Qt_dQ)) + dQ - M_Rinvt = _H(solve_triangular(R, _H(M))) - A_bar = M_Rinvt + Q @ dR - - return [A_bar] + # gradient expression when m >= n + M = R @ _H(dR) - _H(dQ) @ Q + K = dQ + Q @ _copyltu(M) + A_bar_m_ge_n = _H(solve_triangular(R, _H(K))) + + # gradient expression when m < n + Y = A[:, m:] + U = R[:, :m] + dU, dV = dR[:, :m], dR[:, m:] + dQ_Yt_dV = dQ + Y @ _H(dV) + M = U @ _H(dU) - _H(dQ_Yt_dV) @ Q + X_bar = _H(solve_triangular(U, _H(dQ_Yt_dV + Q @ _copyltu(M)))) + Y_bar = Q @ dV + A_bar_m_lt_n = pt.concatenate([X_bar, Y_bar], axis=1) + + return [pt.switch(ptm.ge(m, n), A_bar_m_ge_n, A_bar_m_lt_n)] def qr(a, mode="reduced"): diff --git a/tests/tensor/test_nlinalg.py b/tests/tensor/test_nlinalg.py index 26cc5683ac..e6ccd42ffa 100644 --- a/tests/tensor/test_nlinalg.py +++ b/tests/tensor/test_nlinalg.py @@ -197,7 +197,8 @@ def _test_fn(x, case=2, mode="reduced"): [a], rng=np.random, ) - elif m < n or (mode == "complete" and m != n): + + elif 
mode == "complete" and m > n: with pytest.raises(AssertionError): utt.verify_grad( partial(_test_fn, case=gradient_test_case, mode=mode), From ac48c11fd10576c8df3df8f8a1f990296ba0d940 Mon Sep 17 00:00:00 2001 From: Etienne Duchesne Date: Tue, 25 Mar 2025 16:54:34 +0100 Subject: [PATCH 5/7] qr decompostion gradient: replace swtich by ifelse --- pytensor/tensor/nlinalg.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/pytensor/tensor/nlinalg.py b/pytensor/tensor/nlinalg.py index 908df36404..8fff2a2f59 100644 --- a/pytensor/tensor/nlinalg.py +++ b/pytensor/tensor/nlinalg.py @@ -11,6 +11,7 @@ from pytensor.gradient import DisconnectedType from pytensor.graph.basic import Apply from pytensor.graph.op import Op +from pytensor.ifelse import ifelse from pytensor.npy_2_compat import normalize_axis_tuple from pytensor.raise_op import Assert from pytensor.tensor import TensorLike @@ -586,7 +587,7 @@ def _copyltu(x: ptb.TensorVariable): Y_bar = Q @ dV A_bar_m_lt_n = pt.concatenate([X_bar, Y_bar], axis=1) - return [pt.switch(ptm.ge(m, n), A_bar_m_ge_n, A_bar_m_lt_n)] + return [ifelse(ptm.ge(m, n), A_bar_m_ge_n, A_bar_m_lt_n)] def qr(a, mode="reduced"): From a6ae03b5b60cb830ee2b900db0f274b82c5a270a Mon Sep 17 00:00:00 2001 From: Etienne Duchesne Date: Wed, 26 Mar 2025 12:43:36 +0100 Subject: [PATCH 6/7] qr decomposition gradient: add xfail pytest for complex inputs --- tests/tensor/test_nlinalg.py | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/tests/tensor/test_nlinalg.py b/tests/tensor/test_nlinalg.py index e6ccd42ffa..7811dcb3c9 100644 --- a/tests/tensor/test_nlinalg.py +++ b/tests/tensor/test_nlinalg.py @@ -175,7 +175,8 @@ def test_qr_modes(): + ["shape=(3, 3), gradient_test_case=Q, mode=raw"] ), ) -def test_qr_grad(shape, gradient_test_case, mode): +@pytest.mark.parametrize("is_complex", [True, False], ["complex", "real"]) +def test_qr_grad(shape, gradient_test_case, mode, is_complex): rng = np.random.default_rng(utt.fetch_seed()) def _test_fn(x, case=2, mode="reduced"): @@ -187,8 +188,13 @@ def _test_fn(x, case=2, mode="reduced"): Q, R = qr(x, mode=mode) return Q.sum() + R.sum() + if is_complex: + pytest.xfail("Complex inputs currently not supported by verify_grad") + m, n = shape a = rng.standard_normal(shape).astype(config.floatX) + if is_complex: + a += 1j * rng.standard_normal(shape).astype(config.floatX) if mode == "raw": with pytest.raises(NotImplementedError): From 4edc69867d50da43f47df769bc320732f8ed732a Mon Sep 17 00:00:00 2001 From: Etienne Duchesne Date: Wed, 26 Mar 2025 12:52:32 +0100 Subject: [PATCH 7/7] qr decomposition gradient: fix tests --- tests/tensor/test_nlinalg.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/tensor/test_nlinalg.py b/tests/tensor/test_nlinalg.py index 7811dcb3c9..c8ae3ac4cb 100644 --- a/tests/tensor/test_nlinalg.py +++ b/tests/tensor/test_nlinalg.py @@ -175,7 +175,7 @@ def test_qr_modes(): + ["shape=(3, 3), gradient_test_case=Q, mode=raw"] ), ) -@pytest.mark.parametrize("is_complex", [True, False], ["complex", "real"]) +@pytest.mark.parametrize("is_complex", [True, False], ids=["complex", "real"]) def test_qr_grad(shape, gradient_test_case, mode, is_complex): rng = np.random.default_rng(utt.fetch_seed())