Skip to content

Commit efa845a

Browse files
Add jax dispatch for KroneckerProduct Op (#822)
1 parent 0e29d76 commit efa845a

File tree

2 files changed

+21
-0
lines changed

2 files changed

+21
-0
lines changed

pytensor/link/jax/dispatch/nlinalg.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
Det,
99
Eig,
1010
Eigh,
11+
KroneckerProduct,
1112
MatrixInverse,
1213
MatrixPinv,
1314
QRFull,
@@ -104,6 +105,14 @@ def batched_dot(a, b):
104105
return batched_dot
105106

106107

108+
@jax_funcify.register(KroneckerProduct)
109+
def jax_funcify_KroneckerProduct(op, **kwargs):
110+
def _kron(x, y):
111+
return jnp.kron(x, y)
112+
113+
return _kron
114+
115+
107116
@jax_funcify.register(Max)
108117
def jax_funcify_Max(op, **kwargs):
109118
axis = op.axis

tests/link/jax/test_nlinalg.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -165,3 +165,15 @@ def test_pinv_hermitian():
165165
assert not np.allclose(
166166
jax_fn(A_not_h_test), np.linalg.pinv(A_not_h_test, hermitian=True)
167167
)
168+
169+
170+
def test_kron():
171+
x = matrix("x")
172+
y = matrix("y")
173+
z = pt_nlinalg.kron(x, y)
174+
175+
fgraph = FunctionGraph([x, y], [z])
176+
x_np = np.array([[1.0, 2.0], [3.0, 4.0]], dtype=config.floatX)
177+
y_np = np.array([[1.0, 2.0], [3.0, 4.0]], dtype=config.floatX)
178+
179+
compare_jax_and_py(fgraph, [x_np, y_np])

0 commit comments

Comments
 (0)