Skip to content

Commit 519c933

Browse files
Add L_op
1 parent 2b5c51d commit 519c933

File tree

2 files changed

+48
-0
lines changed

2 files changed

+48
-0
lines changed

pytensor/tensor/slinalg.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111

1212
import pytensor
1313
import pytensor.tensor as pt
14+
from pytensor import Variable
1415
from pytensor.gradient import DisconnectedType
1516
from pytensor.graph.basic import Apply
1617
from pytensor.graph.op import Op
@@ -1714,6 +1715,24 @@ def perform(self, node, inputs, outputs_storage):
17141715
fn = scipy_linalg.get_blas_funcs("gbmv", dtype=A.dtype)
17151716
outputs_storage[0][0] = fn(m=m, n=n, kl=kl, ku=ku, alpha=1, a=A_banded, x=x)
17161717

1718+
def L_op(
1719+
self,
1720+
inputs: Sequence[Variable],
1721+
outputs: Sequence[Variable],
1722+
output_grads: Sequence[Variable],
1723+
) -> list[Variable]:
1724+
# This is exactly the same as the usual gradient of a matrix-vector product, except that the banded structure
1725+
# is exploited.
1726+
A, x = inputs
1727+
(G_bar,) = output_grads
1728+
1729+
A_bar = pt.outer(G_bar, x.T)
1730+
x_bar = banded_dot(
1731+
A.T, G_bar, lower_diags=self.lower_diags, upper_diags=self.upper_diags
1732+
)
1733+
1734+
return [A_bar, x_bar]
1735+
17171736

17181737
def banded_dot(A: TensorLike, x: TensorLike, lower_diags: int, upper_diags: int):
17191738
"""

tests/tensor/test_slinalg.py

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1090,3 +1090,32 @@ def test_banded_dot(kl, ku, stride):
10901090
rtol = 1e-4 if config.floatX == "float32" else 1e-8
10911091

10921092
np.testing.assert_allclose(out_val, out_2_val, atol=atol, rtol=rtol)
1093+
1094+
1095+
def test_banded_dot_grad():
1096+
rng = np.random.default_rng()
1097+
size = 10
1098+
1099+
A_val = _make_banded_A(rng.normal(size=(size, size)), kl=1, ku=1).astype(
1100+
config.floatX
1101+
)
1102+
x_val = rng.normal(size=(size,)).astype(config.floatX)
1103+
1104+
def make_banded_pt(A):
1105+
# Like structured solve Ops, we have to incldue the transformation from an unconstrained matrix A to a banded
1106+
# matrix on the compute graph. Otherwise, the random perturbations used by verify_grad will result in invalid
1107+
# input matrices.
1108+
1109+
diag_idxs = range(-1, 2)
1110+
diags = (pt.diag(A, k=k) for k in diag_idxs)
1111+
return sum(pt.diag(d, k=k) for k, d in zip(diag_idxs, diags))
1112+
1113+
def test_fn(A, x):
1114+
return banded_dot(make_banded_pt(A), x, lower_diags=1, upper_diags=1).sum()
1115+
1116+
utt.verify_grad(
1117+
test_fn,
1118+
[A_val, x_val],
1119+
rng=rng,
1120+
eps=1e-4 if config.floatX == "float32" else 1e-8,
1121+
)

0 commit comments

Comments
 (0)