Skip to content

Commit fa017de

Browse files
Naive implementation, do not merge
1 parent 261aaf3 commit fa017de

File tree

2 files changed

+168
-0
lines changed

2 files changed

+168
-0
lines changed

pytensor/tensor/slinalg.py

Lines changed: 90 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1669,6 +1669,95 @@ def block_diag(*matrices: TensorVariable):
16691669
return _block_diagonal_matrix(*matrices)
16701670

16711671

1672+
def _to_banded_form(A, kl, ku):
1673+
"""
1674+
Convert a full matrix A to LAPACK banded form for gbmv.
1675+
1676+
Parameters
1677+
----------
1678+
A: np.ndarray
1679+
(m, n) banded matrix with nonzero values on the diagonals
1680+
kl: int
1681+
Number of nonzero lower diagonals of A
1682+
ku: int
1683+
Number of nonzero upper diagonals of A
1684+
1685+
Returns
1686+
-------
1687+
ab: np.ndarray
1688+
(kl + ku + 1, n) banded matrix suitable for LAPACK
1689+
"""
1690+
A = np.asarray(A)
1691+
m, n = A.shape
1692+
ab = np.zeros((kl + ku + 1, n), dtype=A.dtype, order="C")
1693+
1694+
for i, k in enumerate(range(ku, -kl - 1, -1)):
1695+
padding = (k, 0) if k >= 0 else (0, -k)
1696+
diag = np.pad(np.diag(A, k=k), padding)
1697+
ab[i, :] = diag
1698+
1699+
return ab
1700+
1701+
1702+
class BandedDot(Op):
    """Matrix-vector product specialized for a banded matrix A.

    Dispatches to BLAS ``gbmv`` on LAPACK band storage. Only the
    ``lower_diags`` sub-diagonals and ``upper_diags`` super-diagonals of A
    are read; any other data in A is silently ignored.
    """

    __props__ = ("lower_diags", "upper_diags")
    # y = A @ b with A (m, n) and b (n,) is a length-m vector; the output
    # only coincides with (n) when A is square.
    gufunc_signature = "(m,n),(n)->(m)"

    def __init__(self, lower_diags, upper_diags):
        self.lower_diags = lower_diags
        self.upper_diags = upper_diags

    def make_node(self, A, b):
        A = as_tensor_variable(A)
        b = as_tensor_variable(b)

        out_dtype = pytensor.scalar.upcast(A.dtype, b.dtype)
        # Build a fresh, owner-less output variable. The previous
        # `b.type().astype(out_dtype)` used the raw input (which may not
        # have a .type) and, when a cast was inserted, produced a variable
        # already owned by the cast Apply — invalid as a new Apply output —
        # while also pinning the output shape to b's length instead of A's
        # row count.
        output = pytensor.tensor.tensor(shape=(A.type.shape[0],), dtype=out_dtype)

        return pytensor.graph.basic.Apply(self, [A, b], [output])

    def perform(self, node, inputs, outputs_storage):
        A, b = inputs
        m, n = A.shape
        alpha = 1

        kl = self.lower_diags
        ku = self.upper_diags

        A_banded = _to_banded_form(A, kl, ku)

        # gbmv computes alpha * A @ x with A in band storage; the routine
        # variant (s/d/c/z) is selected from the input dtypes.
        fn = scipy_linalg.get_blas_funcs("gbmv", (A, b))
        out = fn(m, n, kl, ku, alpha, A_banded, b)
        # gbmv returns the BLAS routine's dtype; honour the node's declared
        # upcast dtype so mixed-precision inputs stay consistent.
        outputs_storage[0][0] = out.astype(node.outputs[0].dtype, copy=False)
1733+
def banded_dot(A: TensorLike, b: TensorLike, lower_diags: int, upper_diags: int):
    """
    Specialized matrix-vector multiplication for cases when A is a banded matrix

    No type-checking is done on A at runtime, so all data in A off the banded
    diagonals will be ignored. This will lead to incorrect results if A is not
    actually a banded matrix.

    Unlike dot, this function is only valid if b is a vector.

    Parameters
    ----------
    A: Tensorlike
        Matrix to perform banded dot on.
    b: Tensorlike
        Vector to perform banded dot on.
    lower_diags: int
        Number of nonzero lower diagonals of A
    upper_diags: int
        Number of nonzero upper diagonals of A

    Returns
    -------
    out: Tensor
        The matrix multiplication result
    """
    core_op = BandedDot(lower_diags, upper_diags)
    return Blockwise(core_op)(A, b)
1760+
16721761
__all__ = [
16731762
"cholesky",
16741763
"solve",
@@ -1683,4 +1772,5 @@ def block_diag(*matrices: TensorVariable):
16831772
"lu",
16841773
"lu_factor",
16851774
"lu_solve",
1775+
"banded_dot",
16861776
]

tests/tensor/test_slinalg.py

Lines changed: 78 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,11 +12,13 @@
1212
from pytensor.graph.basic import equal_computations
1313
from pytensor.tensor import TensorVariable
1414
from pytensor.tensor.slinalg import (
15+
BandedDot,
1516
Cholesky,
1617
CholeskySolve,
1718
Solve,
1819
SolveBase,
1920
SolveTriangular,
21+
banded_dot,
2022
block_diag,
2123
cho_solve,
2224
cholesky,
@@ -1051,3 +1053,79 @@ def test_block_diagonal_blockwise():
10511053
B = np.random.normal(size=(1, batch_size, 4, 4)).astype(config.floatX)
10521054
result = block_diag(A, B).eval()
10531055
assert result.shape == (10, batch_size, 6, 6)
1056+
1057+
1058+
def _make_banded_A(A, kl, ku):
1059+
diag_idxs = range(-kl, ku + 1)
1060+
diags = (np.diag(A, k=k) for k in diag_idxs)
1061+
return sum(np.diag(d, k=k) for k, d in zip(diag_idxs, diags))
1062+
1063+
1064+
@pytest.mark.parametrize(
    "A_shape",
    [
        (10, 10),
    ],
)
@pytest.mark.parametrize(
    "kl, ku", [(1, 1), (0, 1), (2, 2)], ids=["tridiag", "upper-only", "banded"]
)
def test_banded_dot(A_shape, kl, ku):
    """banded_dot must agree with dense A @ b on genuinely banded inputs."""
    rng = np.random.default_rng()

    A_val = _make_banded_A(rng.normal(size=A_shape), kl=kl, ku=ku)
    b_val = rng.normal(size=(A_shape[-1],))

    A = pt.tensor("A", shape=A_val.shape)
    b = pt.tensor("b", shape=b_val.shape)

    banded_res = banded_dot(A, b, kl, ku)
    dense_res = A @ b

    fn = function([A, b], [banded_res, dense_res])

    # The specialized Op must actually survive into the compiled graph.
    graph_ops = [node.op for node in fn.maker.fgraph.apply_nodes]
    assert any(isinstance(op, BandedDot) for op in graph_ops)

    banded_out, dense_out = fn(A_val, b_val)
    np.testing.assert_allclose(banded_out, dense_out)
1090+
1091+
1092+
@pytest.mark.parametrize(
    "A_shape", [(10, 10), (100, 100), (1000, 1000)], ids=["10", "100", "1000"]
)
@pytest.mark.parametrize(
    "kl, ku", [(1, 1), (0, 1), (2, 2)], ids=["tridiag", "upper-only", "banded"]
)
def test_banded_dot_perf(A_shape, kl, ku, benchmark):
    """Benchmark banded_dot across matrix sizes and bandwidths."""
    rng = np.random.default_rng()

    A_val = _make_banded_A(rng.normal(size=A_shape), kl=kl, ku=ku)
    b_val = rng.normal(size=(A_shape[-1],))

    A = pt.tensor("A", shape=A_val.shape)
    b = pt.tensor("b", shape=b_val.shape)

    compiled = function([A, b], banded_dot(A, b, kl, ku))
    benchmark(compiled, A_val, b_val)
1111+
1112+
1113+
@pytest.mark.parametrize(
    "A_shape", [(10, 10), (100, 100), (1000, 1000)], ids=["10", "100", "1000"]
)
@pytest.mark.parametrize(
    "kl, ku", [(1, 1), (0, 1), (2, 2)], ids=["tridiag", "upper-only", "banded"]
)
def test_dot_perf(A_shape, kl, ku, benchmark):
    """Baseline for comparison: benchmark dense A @ b on the same inputs."""
    rng = np.random.default_rng()

    A_val = _make_banded_A(rng.normal(size=A_shape), kl=kl, ku=ku)
    b_val = rng.normal(size=(A_shape[-1],))

    A = pt.tensor("A", shape=A_val.shape)
    b = pt.tensor("b", shape=b_val.shape)

    compiled = function([A, b], A @ b)
    benchmark(compiled, A_val, b_val)

0 commit comments

Comments
 (0)