Skip to content

Commit 378cb40

Browse files
authored
Rewriting the kron function using JAX implementation (#684)
* Update the kron function to use numpy implementation and move the function to `tensor.nlinalg.py`
1 parent f97d9ea commit 378cb40

File tree

4 files changed

+76
-88
lines changed

4 files changed

+76
-88
lines changed

pytensor/tensor/nlinalg.py

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1010,6 +1010,40 @@ def tensorsolve(a, b, axes=None):
10101010
return TensorSolve(axes)(a, b)
10111011

10121012

1013+
def kron(a, b):
1014+
"""Kronecker product.
1015+
1016+
Same as np.kron(a, b)
1017+
1018+
Parameters
1019+
----------
1020+
a: array_like
1021+
b: array_like
1022+
1023+
Returns
1024+
-------
1025+
array_like with a.ndim + b.ndim - 2 dimensions
1026+
"""
1027+
a = as_tensor_variable(a)
1028+
b = as_tensor_variable(b)
1029+
if a.ndim + b.ndim <= 2:
1030+
raise TypeError(
1031+
"kron: inputs dimensions must sum to 3 or more. "
1032+
f"You passed {int(a.ndim)} and {int(b.ndim)}."
1033+
)
1034+
1035+
if a.ndim < b.ndim:
1036+
a = ptb.expand_dims(a, tuple(range(b.ndim - a.ndim)))
1037+
elif b.ndim < a.ndim:
1038+
b = ptb.expand_dims(b, tuple(range(a.ndim - b.ndim)))
1039+
a_reshaped = ptb.expand_dims(a, tuple(range(1, 2 * a.ndim, 2)))
1040+
b_reshaped = ptb.expand_dims(b, tuple(range(0, 2 * b.ndim, 2)))
1041+
out_shape = tuple(a.shape * b.shape)
1042+
output_out_of_shape = a_reshaped * b_reshaped
1043+
output_reshaped = output_out_of_shape.reshape(out_shape)
1044+
return output_reshaped
1045+
1046+
10131047
__all__ = [
10141048
"pinv",
10151049
"inv",
@@ -1025,4 +1059,5 @@ def tensorsolve(a, b, axes=None):
10251059
"norm",
10261060
"tensorinv",
10271061
"tensorsolve",
1062+
"kron",
10281063
]

pytensor/tensor/slinalg.py

Lines changed: 1 addition & 47 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515
from pytensor.tensor import basic as ptb
1616
from pytensor.tensor import math as ptm
1717
from pytensor.tensor.blockwise import Blockwise
18-
from pytensor.tensor.nlinalg import matrix_dot
18+
from pytensor.tensor.nlinalg import kron, matrix_dot
1919
from pytensor.tensor.shape import reshape
2020
from pytensor.tensor.type import matrix, tensor, vector
2121
from pytensor.tensor.variable import TensorVariable
@@ -559,51 +559,6 @@ def eigvalsh(a, b, lower=True):
559559
return Eigvalsh(lower)(a, b)
560560

561561

562-
def kron(a, b):
563-
"""Kronecker product.
564-
565-
Same as scipy.linalg.kron(a, b).
566-
567-
Parameters
568-
----------
569-
a: array_like
570-
b: array_like
571-
572-
Returns
573-
-------
574-
array_like with a.ndim + b.ndim - 2 dimensions
575-
576-
Notes
577-
-----
578-
numpy.kron(a, b) != scipy.linalg.kron(a, b)!
579-
They don't have the same shape and order when
580-
a.ndim != b.ndim != 2.
581-
582-
"""
583-
a = as_tensor_variable(a)
584-
b = as_tensor_variable(b)
585-
if a.ndim + b.ndim <= 2:
586-
raise TypeError(
587-
"kron: inputs dimensions must sum to 3 or more. "
588-
f"You passed {int(a.ndim)} and {int(b.ndim)}."
589-
)
590-
o = ptm.outer(a, b)
591-
o = o.reshape(ptb.concatenate((a.shape, b.shape)), ndim=a.ndim + b.ndim)
592-
shf = o.dimshuffle(0, 2, 1, *range(3, o.ndim))
593-
if shf.ndim == 3:
594-
shf = o.dimshuffle(1, 0, 2)
595-
o = shf.flatten()
596-
else:
597-
o = shf.reshape(
598-
(
599-
o.shape[0] * o.shape[2],
600-
o.shape[1] * o.shape[3],
601-
*(o.shape[i] for i in range(4, o.ndim)),
602-
)
603-
)
604-
return o
605-
606-
607562
class Expm(Op):
608563
"""
609564
Compute the matrix exponential of a square array.
@@ -1021,7 +976,6 @@ def block_diag(*matrices: TensorVariable):
1021976
"cholesky",
1022977
"solve",
1023978
"eigvalsh",
1024-
"kron",
1025979
"expm",
1026980
"solve_discrete_lyapunov",
1027981
"solve_continuous_lyapunov",

tests/tensor/test_nlinalg.py

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
det,
1818
eig,
1919
eigh,
20+
kron,
2021
lstsq,
2122
matrix_dot,
2223
matrix_inverse,
@@ -580,3 +581,42 @@ def test_eval(self):
580581
t_binv1 = tf_b1(self.b1)
581582
assert _allclose(t_binv, n_binv)
582583
assert _allclose(t_binv1, n_binv1)
584+
585+
586+
class TestKron(utt.InferShapeTester):
587+
rng = np.random.default_rng(43)
588+
589+
def setup_method(self):
590+
self.op = kron
591+
super().setup_method()
592+
593+
@pytest.mark.parametrize("shp0", [(2,), (2, 3), (2, 3, 4), (2, 3, 4, 5)])
594+
@pytest.mark.parametrize("shp1", [(6,), (6, 7), (6, 7, 8), (6, 7, 8, 9)])
595+
def test_perform(self, shp0, shp1):
596+
if len(shp0) + len(shp1) == 2:
597+
pytest.skip("Sum of shp0 and shp1 must be more than 2")
598+
x = tensor(dtype="floatX", shape=(None,) * len(shp0))
599+
a = np.asarray(self.rng.random(shp0)).astype(config.floatX)
600+
y = tensor(dtype="floatX", shape=(None,) * len(shp1))
601+
f = function([x, y], kron(x, y))
602+
b = self.rng.random(shp1).astype(config.floatX)
603+
out = f(a, b)
604+
# Using the np.kron to compare outputs
605+
np_val = np.kron(a, b)
606+
np.testing.assert_allclose(out, np_val)
607+
608+
@pytest.mark.parametrize(
609+
"i, shp0, shp1",
610+
[(0, (2, 3), (6, 7)), (1, (2, 3), (4, 3, 5)), (2, (2, 4, 3), (4, 3, 5))],
611+
)
612+
def test_kron_commutes_with_inv(self, i, shp0, shp1):
613+
if (pytensor.config.floatX == "float32") & (i == 2):
614+
pytest.skip("Half precision insufficient for test 3 to pass")
615+
x = tensor(dtype="floatX", shape=(None,) * len(shp0))
616+
a = np.asarray(self.rng.random(shp0)).astype(config.floatX)
617+
y = tensor(dtype="floatX", shape=(None,) * len(shp1))
618+
b = self.rng.random(shp1).astype(config.floatX)
619+
lhs_f = function([x, y], pinv(kron(x, y)))
620+
rhs_f = function([x, y], kron(pinv(x), pinv(y)))
621+
atol = 1e-4 if config.floatX == "float32" else 1e-12
622+
np.testing.assert_allclose(lhs_f(a, b), rhs_f(a, b), atol=atol)

tests/tensor/test_slinalg.py

Lines changed: 0 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,6 @@
2020
cholesky,
2121
eigvalsh,
2222
expm,
23-
kron,
2423
solve,
2524
solve_continuous_lyapunov,
2625
solve_discrete_are,
@@ -512,46 +511,6 @@ def test_expm_grad_3():
512511
utt.verify_grad(expm, [A], rng=rng)
513512

514513

515-
class TestKron(utt.InferShapeTester):
516-
rng = np.random.default_rng(43)
517-
518-
def setup_method(self):
519-
self.op = kron
520-
super().setup_method()
521-
522-
def test_perform(self):
523-
for shp0 in [(2,), (2, 3), (2, 3, 4), (2, 3, 4, 5)]:
524-
x = tensor(dtype="floatX", shape=(None,) * len(shp0))
525-
a = np.asarray(self.rng.random(shp0)).astype(config.floatX)
526-
for shp1 in [(6,), (6, 7), (6, 7, 8), (6, 7, 8, 9)]:
527-
if len(shp0) + len(shp1) == 2:
528-
continue
529-
y = tensor(dtype="floatX", shape=(None,) * len(shp1))
530-
f = function([x, y], kron(x, y))
531-
b = self.rng.random(shp1).astype(config.floatX)
532-
out = f(a, b)
533-
# Newer versions of scipy want 4 dimensions at least,
534-
# so we have to add a dimension to a and flatten the result.
535-
if len(shp0) + len(shp1) == 3:
536-
scipy_val = scipy.linalg.kron(a[np.newaxis, :], b).flatten()
537-
else:
538-
scipy_val = scipy.linalg.kron(a, b)
539-
np.testing.assert_allclose(out, scipy_val)
540-
541-
def test_numpy_2d(self):
542-
for shp0 in [(2, 3)]:
543-
x = tensor(dtype="floatX", shape=(None,) * len(shp0))
544-
a = np.asarray(self.rng.random(shp0)).astype(config.floatX)
545-
for shp1 in [(6, 7)]:
546-
if len(shp0) + len(shp1) == 2:
547-
continue
548-
y = tensor(dtype="floatX", shape=(None,) * len(shp1))
549-
f = function([x, y], kron(x, y))
550-
b = self.rng.random(shp1).astype(config.floatX)
551-
out = f(a, b)
552-
assert np.allclose(out, np.kron(a, b))
553-
554-
555514
def test_solve_discrete_lyapunov_via_direct_real():
556515
N = 5
557516
rng = np.random.default_rng(utt.fetch_seed())

0 commit comments

Comments
 (0)