Skip to content

Commit 679b2f7

Browse files
JAX dispatches for LU Ops
1 parent 1aa9a39 commit 679b2f7

File tree

3 files changed

+105
-26
lines changed

3 files changed

+105
-26
lines changed

pytensor/link/jax/dispatch/slinalg.py

Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,9 +4,12 @@
44

55
from pytensor.link.jax.dispatch.basic import jax_funcify
66
from pytensor.tensor.slinalg import (
7+
LU,
78
BlockDiagonal,
89
Cholesky,
910
Eigvalsh,
11+
LUFactor,
12+
PivotToPermutations,
1013
Solve,
1114
SolveTriangular,
1215
)
@@ -93,3 +96,46 @@ def block_diag(*inputs):
9396
return jax.scipy.linalg.block_diag(*inputs)
9497

9598
return block_diag
99+
100+
101+
@jax_funcify.register(PivotToPermutations)
102+
def jax_funcify_PivotToPermutation(op, **kwargs):
103+
inverse = op.inverse
104+
105+
def pivot_to_permutations(pivots):
106+
p_inv = jax.lax.linalg.lu_pivots_to_permutation(pivots, pivots.shape[0])
107+
if inverse:
108+
return p_inv
109+
return jax.numpy.argsort(p_inv)
110+
111+
return pivot_to_permutations
112+
113+
114+
@jax_funcify.register(LU)
115+
def jax_funcify_LU(op, **kwargs):
116+
permute_l = op.permute_l
117+
p_indices = op.p_indices
118+
check_finite = op.check_finite
119+
120+
if p_indices:
121+
raise ValueError("JAX does not support the p_indices argument")
122+
123+
def lu(*inputs):
124+
return jax.scipy.linalg.lu(
125+
*inputs, permute_l=permute_l, check_finite=check_finite
126+
)
127+
128+
return lu
129+
130+
131+
@jax_funcify.register(LUFactor)
132+
def jax_funcify_LUFactor(op, **kwargs):
133+
check_finite = op.check_finite
134+
overwrite_a = op.overwrite_a
135+
136+
def lu_factor(a):
137+
return jax.scipy.linalg.lu_factor(
138+
a, check_finite=check_finite, overwrite_a=overwrite_a
139+
)
140+
141+
return lu_factor

pytensor/tensor/slinalg.py

Lines changed: 2 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -10,9 +10,8 @@
1010

1111
import pytensor
1212
import pytensor.tensor as pt
13-
from pytensor.compile.builders import OpFromGraph
1413
from pytensor.gradient import DisconnectedType
15-
from pytensor.graph.basic import Apply, Variable
14+
from pytensor.graph.basic import Apply
1615
from pytensor.graph.op import Op
1716
from pytensor.tensor import TensorLike, as_tensor_variable
1817
from pytensor.tensor import basic as ptb
@@ -616,7 +615,7 @@ def perform(self, node, inputs, outputs):
616615
outputs[0][0] = np.argsort(p_inv)
617616

618617

619-
def pivot_to_permutation(p: TensorLike, inverse=False) -> Variable:
618+
def pivot_to_permutation(p: TensorLike, inverse=False):
620619
p = pt.as_tensor_variable(p)
621620
return PivotToPermutations(inverse=inverse)(p)
622621

@@ -724,29 +723,6 @@ def lu_factor(
724723
)
725724

726725

727-
class LUSolve(OpFromGraph):
728-
"""Solve a system of linear equations given the LU decomposition of the matrix."""
729-
730-
__props__ = ("trans", "b_ndim", "check_finite", "overwrite_b")
731-
732-
def __init__(
733-
self,
734-
inputs: list[Variable],
735-
outputs: list[Variable],
736-
trans: bool = False,
737-
b_ndim: int | None = None,
738-
check_finite: bool = False,
739-
overwrite_b: bool = False,
740-
**kwargs,
741-
):
742-
self.trans = trans
743-
self.b_ndim = b_ndim
744-
self.check_finite = check_finite
745-
self.overwrite_b = overwrite_b
746-
747-
super().__init__(inputs=inputs, outputs=outputs, **kwargs)
748-
749-
750726
def lu_solve(
751727
LU_and_pivots: tuple[TensorLike, TensorLike],
752728
b: TensorLike,

tests/link/jax/test_slinalg.py

Lines changed: 57 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -228,3 +228,60 @@ def test_jax_solve_discrete_lyapunov(
228228
jax_mode="JAX",
229229
assert_fn=partial(np.testing.assert_allclose, atol=atol, rtol=rtol),
230230
)
231+
232+
233+
@pytest.mark.parametrize(
234+
"permute_l, p_indices",
235+
[(True, False), (False, True), (False, False)],
236+
ids=["PL", "p_indices", "P"],
237+
)
238+
@pytest.mark.parametrize("complex", [False, True], ids=["real", "complex"])
239+
@pytest.mark.parametrize("shape", [(3, 5, 5), (5, 5)], ids=["batched", "not_batched"])
240+
def test_jax_lu(permute_l, p_indices, complex, shape: tuple[int]):
241+
rng = np.random.default_rng()
242+
A = pt.tensor(
243+
"A",
244+
shape=shape,
245+
dtype=f"complex{int(config.floatX[-2:]) * 2}" if complex else config.floatX,
246+
)
247+
out = pt_slinalg.lu(A, permute_l=permute_l, p_indices=p_indices)
248+
249+
x = rng.normal(size=shape).astype(config.floatX)
250+
if complex:
251+
x = x + 1j * rng.normal(size=shape).astype(config.floatX)
252+
253+
if p_indices:
254+
with pytest.raises(
255+
ValueError, match="JAX does not support the p_indices argument"
256+
):
257+
compare_jax_and_py(graph_inputs=[A], graph_outputs=out, test_inputs=[x])
258+
else:
259+
compare_jax_and_py(graph_inputs=[A], graph_outputs=out, test_inputs=[x])
260+
261+
262+
@pytest.mark.parametrize("shape", [(5, 5), (5, 5, 5)], ids=["matrix", "batch"])
263+
def test_jax_lu_factor(shape):
264+
rng = np.random.default_rng(utt.fetch_seed())
265+
A = pt.tensor(name="A", shape=shape)
266+
A_value = rng.normal(size=shape).astype(config.floatX)
267+
out = pt_slinalg.lu_factor(A)
268+
269+
compare_jax_and_py(
270+
[A],
271+
out,
272+
[A_value],
273+
)
274+
275+
276+
@pytest.mark.parametrize("b_shape", [(5,), (5, 5)])
277+
def test_jax_lu_solve(b_shape):
278+
rng = np.random.default_rng(utt.fetch_seed())
279+
A_val = rng.normal(size=(5, 5)).astype(config.floatX)
280+
b_val = rng.normal(size=b_shape).astype(config.floatX)
281+
282+
A = pt.tensor(name="A", shape=(5, 5))
283+
b = pt.tensor(name="b", shape=b_shape)
284+
lu_and_pivots = pt_slinalg.lu_factor(A)
285+
out = pt_slinalg.lu_solve(lu_and_pivots, b)
286+
287+
compare_jax_and_py([A, b], [out], [A_val, b_val])

0 commit comments

Comments
 (0)