Skip to content

Commit 4f475c8

Browse files
committed
Add blockwise and Cholesky
1 parent b66d859 commit 4f475c8

File tree

5 files changed

+113
-0
lines changed

5 files changed

+113
-0
lines changed

pytensor/link/pytorch/dispatch/__init__.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,4 +11,7 @@
1111
import pytensor.link.pytorch.dispatch.shape
1212
import pytensor.link.pytorch.dispatch.sort
1313
import pytensor.link.pytorch.dispatch.subtensor
14+
import pytensor.link.pytorch.dispatch.nlinalg
15+
import pytensor.link.pytorch.dispatch.slinalg
16+
import pytensor.link.pytorch.dispatch.blockwise
1417
# isort: on
Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,28 @@
1+
import torch
2+
3+
from pytensor.graph import FunctionGraph
4+
from pytensor.link.pytorch.dispatch import pytorch_funcify
5+
from pytensor.tensor.blockwise import Blockwise
6+
7+
8+
@pytorch_funcify.register(Blockwise)
9+
def funcify_Blockwise(op: Blockwise, node, *args, **kwargs):
10+
batched_dims = op.batch_ndim(node)
11+
core_node = op._create_dummy_core_node(node.inputs)
12+
core_fgraph = FunctionGraph(inputs=core_node.inputs, outputs=core_node.outputs)
13+
core_func = pytorch_funcify(core_fgraph)
14+
if len(node.outputs) == 1:
15+
16+
def inner_func(*inputs):
17+
return core_func(*inputs)[0]
18+
else:
19+
inner_func = core_func
20+
21+
for _ in range(batched_dims):
22+
inner_func = torch.vmap(inner_func)
23+
24+
def batcher(*inputs):
25+
op._check_runtime_broadcast(node, inputs)
26+
return inner_func(*inputs)
27+
28+
return batcher
Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,28 @@
1+
import torch.linalg
2+
3+
from pytensor.link.pytorch.dispatch import pytorch_funcify
4+
from pytensor.tensor.slinalg import Cholesky, SolveTriangular
5+
6+
7+
@pytorch_funcify.register(Cholesky)
8+
def pytorch_funcify_Cholesky(op, **kwargs):
9+
lower = op.lower
10+
11+
def cholesky(a, lower=lower):
12+
return torch.linalg.cholesky(a, upper=not lower)
13+
14+
return cholesky
15+
16+
17+
@pytorch_funcify.register(SolveTriangular)
18+
def pytorch_funcify_SolveTriangular(op, **kwargs):
19+
lower = op.lower
20+
trans = op.trans
21+
unit_diagonal = op.unit_diagonal
22+
23+
def solve_triangular(A, b):
24+
return torch.linalg.solve_triangular(
25+
A, b, upper=not lower, unit_triangle=unit_diagonal, left=trans == "T"
26+
)
27+
28+
return solve_triangular

tests/link/pytorch/test_blockwise.py

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,32 @@
1+
import pytest
2+
3+
from pytensor.graph.replace import vectorize_node
4+
from pytensor.tensor import tensor
5+
from pytensor.tensor.blockwise import Blockwise
6+
from pytensor.tensor.nlinalg import MatrixInverse
7+
8+
9+
torch = pytest.importorskip("torch")
10+
11+
12+
def test_vectorize_blockwise():
13+
mat = tensor(shape=(None, None))
14+
tns = tensor(shape=(None, None, None))
15+
16+
# Something that falls back to Blockwise
17+
node = MatrixInverse()(mat).owner
18+
vect_node = vectorize_node(node, tns)
19+
assert isinstance(vect_node.op, Blockwise) and isinstance(
20+
vect_node.op.core_op, MatrixInverse
21+
)
22+
assert vect_node.op.signature == ("(m,m)->(m,m)")
23+
assert vect_node.inputs[0] is tns
24+
25+
# Useless blockwise
26+
tns4 = tensor(shape=(5, None, None, None))
27+
new_vect_node = vectorize_node(vect_node, tns4)
28+
assert new_vect_node.op is vect_node.op
29+
assert isinstance(new_vect_node.op, Blockwise) and isinstance(
30+
new_vect_node.op.core_op, MatrixInverse
31+
)
32+
assert new_vect_node.inputs[0] is tns4

tests/link/pytorch/test_slinalg.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,22 @@
1+
import numpy as np
2+
import pytest
3+
4+
import pytensor
5+
from pytensor.tensor import tensor
6+
from pytensor.tensor.slinalg import cholesky
7+
8+
9+
@pytest.mark.parametrize(
10+
"cov_batch_shape", [(), (1000,), (4, 1000)], ids=lambda arg: f"cov:{arg}"
11+
)
12+
def test_batched_mvnormal_logp_and_dlogp(cov_batch_shape):
13+
rng = np.random.default_rng(sum(map(ord, "batched_mvnormal")))
14+
15+
cov = tensor("cov", shape=(*cov_batch_shape, 10, 10))
16+
17+
test_values = np.eye(cov.type.shape[-1]) * np.abs(rng.normal(size=cov.type.shape))
18+
19+
chol_cov = cholesky(cov, lower=True, on_error="raise")
20+
21+
fn = pytensor.function([cov], [chol_cov])
22+
assert np.all(np.isclose(fn(test_values), np.linalg.cholesky(test_values)))

0 commit comments

Comments
 (0)