diff --git a/pytensor/tensor/nlinalg.py b/pytensor/tensor/nlinalg.py index cf510fb065..0252423926 100644 --- a/pytensor/tensor/nlinalg.py +++ b/pytensor/tensor/nlinalg.py @@ -1010,6 +1010,40 @@ def tensorsolve(a, b, axes=None): return TensorSolve(axes)(a, b) +def kron(a, b): + """Kronecker product. + + Same as np.kron(a, b) + + Parameters + ---------- + a: array_like + b: array_like + + Returns + ------- + array_like with a.ndim + b.ndim - 2 dimensions + """ + a = as_tensor_variable(a) + b = as_tensor_variable(b) + if a.ndim + b.ndim <= 2: + raise TypeError( + "kron: inputs dimensions must sum to 3 or more. " + f"You passed {int(a.ndim)} and {int(b.ndim)}." + ) + + if a.ndim < b.ndim: + a = ptb.expand_dims(a, tuple(range(b.ndim - a.ndim))) + elif b.ndim < a.ndim: + b = ptb.expand_dims(b, tuple(range(a.ndim - b.ndim))) + a_reshaped = ptb.expand_dims(a, tuple(range(1, 2 * a.ndim, 2))) + b_reshaped = ptb.expand_dims(b, tuple(range(0, 2 * b.ndim, 2))) + out_shape = tuple(a.shape * b.shape) + output_out_of_shape = a_reshaped * b_reshaped + output_reshaped = output_out_of_shape.reshape(out_shape) + return output_reshaped + + __all__ = [ "pinv", "inv", @@ -1025,4 +1059,5 @@ def tensorsolve(a, b, axes=None): "norm", "tensorinv", "tensorsolve", + "kron", ] diff --git a/pytensor/tensor/slinalg.py b/pytensor/tensor/slinalg.py index c162902c18..2d032e220e 100644 --- a/pytensor/tensor/slinalg.py +++ b/pytensor/tensor/slinalg.py @@ -15,7 +15,7 @@ from pytensor.tensor import basic as ptb from pytensor.tensor import math as ptm from pytensor.tensor.blockwise import Blockwise -from pytensor.tensor.nlinalg import matrix_dot +from pytensor.tensor.nlinalg import kron, matrix_dot from pytensor.tensor.shape import reshape from pytensor.tensor.type import matrix, tensor, vector from pytensor.tensor.variable import TensorVariable @@ -559,51 +559,6 @@ def eigvalsh(a, b, lower=True): return Eigvalsh(lower)(a, b) -def kron(a, b): - """Kronecker product. - - Same as scipy.linalg.kron(a, b). - - Parameters - ---------- - a: array_like - b: array_like - - Returns - ------- - array_like with a.ndim + b.ndim - 2 dimensions - - Notes - ----- - numpy.kron(a, b) != scipy.linalg.kron(a, b)! - They don't have the same shape and order when - a.ndim != b.ndim != 2. - - """ - a = as_tensor_variable(a) - b = as_tensor_variable(b) - if a.ndim + b.ndim <= 2: - raise TypeError( - "kron: inputs dimensions must sum to 3 or more. " - f"You passed {int(a.ndim)} and {int(b.ndim)}." - ) - o = ptm.outer(a, b) - o = o.reshape(ptb.concatenate((a.shape, b.shape)), ndim=a.ndim + b.ndim) - shf = o.dimshuffle(0, 2, 1, *range(3, o.ndim)) - if shf.ndim == 3: - shf = o.dimshuffle(1, 0, 2) - o = shf.flatten() - else: - o = shf.reshape( - ( - o.shape[0] * o.shape[2], - o.shape[1] * o.shape[3], - *(o.shape[i] for i in range(4, o.ndim)), - ) - ) - return o - - class Expm(Op): """ Compute the matrix exponential of a square array. @@ -1021,7 +976,6 @@ def block_diag(*matrices: TensorVariable): "cholesky", "solve", "eigvalsh", - "kron", "expm", "solve_discrete_lyapunov", "solve_continuous_lyapunov", diff --git a/tests/tensor/test_nlinalg.py b/tests/tensor/test_nlinalg.py index d39ab0b777..d10242d388 100644 --- a/tests/tensor/test_nlinalg.py +++ b/tests/tensor/test_nlinalg.py @@ -17,6 +17,7 @@ det, eig, eigh, + kron, lstsq, matrix_dot, matrix_inverse, @@ -580,3 +581,42 @@ def test_eval(self): t_binv1 = tf_b1(self.b1) assert _allclose(t_binv, n_binv) assert _allclose(t_binv1, n_binv1) + + +class TestKron(utt.InferShapeTester): + rng = np.random.default_rng(43) + + def setup_method(self): + self.op = kron + super().setup_method() + + @pytest.mark.parametrize("shp0", [(2,), (2, 3), (2, 3, 4), (2, 3, 4, 5)]) + @pytest.mark.parametrize("shp1", [(6,), (6, 7), (6, 7, 8), (6, 7, 8, 9)]) + def test_perform(self, shp0, shp1): + if len(shp0) + len(shp1) == 2: + pytest.skip("Sum of shp0 and shp1 must be more than 2") + x = tensor(dtype="floatX", shape=(None,) * len(shp0)) + a = np.asarray(self.rng.random(shp0)).astype(config.floatX) + y = tensor(dtype="floatX", shape=(None,) * len(shp1)) + f = function([x, y], kron(x, y)) + b = self.rng.random(shp1).astype(config.floatX) + out = f(a, b) + # Using the np.kron to compare outputs + np_val = np.kron(a, b) + np.testing.assert_allclose(out, np_val) + + @pytest.mark.parametrize( + "i, shp0, shp1", + [(0, (2, 3), (6, 7)), (1, (2, 3), (4, 3, 5)), (2, (2, 4, 3), (4, 3, 5))], + ) + def test_kron_commutes_with_inv(self, i, shp0, shp1): + if (pytensor.config.floatX == "float32") & (i == 2): + pytest.skip("Half precision insufficient for test 3 to pass") + x = tensor(dtype="floatX", shape=(None,) * len(shp0)) + a = np.asarray(self.rng.random(shp0)).astype(config.floatX) + y = tensor(dtype="floatX", shape=(None,) * len(shp1)) + b = self.rng.random(shp1).astype(config.floatX) + lhs_f = function([x, y], pinv(kron(x, y))) + rhs_f = function([x, y], kron(pinv(x), pinv(y))) + atol = 1e-4 if config.floatX == "float32" else 1e-12 + np.testing.assert_allclose(lhs_f(a, b), rhs_f(a, b), atol=atol) diff --git a/tests/tensor/test_slinalg.py b/tests/tensor/test_slinalg.py index a2cc3c52e8..d39c370ed3 100644 --- a/tests/tensor/test_slinalg.py +++ b/tests/tensor/test_slinalg.py @@ -20,7 +20,6 @@ cholesky, eigvalsh, expm, - kron, solve, solve_continuous_lyapunov, solve_discrete_are, @@ -512,46 +511,6 @@ def test_expm_grad_3(): utt.verify_grad(expm, [A], rng=rng) -class TestKron(utt.InferShapeTester): - rng = np.random.default_rng(43) - - def setup_method(self): - self.op = kron - super().setup_method() - - def test_perform(self): - for shp0 in [(2,), (2, 3), (2, 3, 4), (2, 3, 4, 5)]: - x = tensor(dtype="floatX", shape=(None,) * len(shp0)) - a = np.asarray(self.rng.random(shp0)).astype(config.floatX) - for shp1 in [(6,), (6, 7), (6, 7, 8), (6, 7, 8, 9)]: - if len(shp0) + len(shp1) == 2: - continue - y = tensor(dtype="floatX", shape=(None,) * len(shp1)) - f = function([x, y], kron(x, y)) - b = self.rng.random(shp1).astype(config.floatX) - out = f(a, b) - # Newer versions of scipy want 4 dimensions at least, - # so we have to add a dimension to a and flatten the result. - if len(shp0) + len(shp1) == 3: - scipy_val = scipy.linalg.kron(a[np.newaxis, :], b).flatten() - else: - scipy_val = scipy.linalg.kron(a, b) - np.testing.assert_allclose(out, scipy_val) - - def test_numpy_2d(self): - for shp0 in [(2, 3)]: - x = tensor(dtype="floatX", shape=(None,) * len(shp0)) - a = np.asarray(self.rng.random(shp0)).astype(config.floatX) - for shp1 in [(6, 7)]: - if len(shp0) + len(shp1) == 2: - continue - y = tensor(dtype="floatX", shape=(None,) * len(shp1)) - f = function([x, y], kron(x, y)) - b = self.rng.random(shp1).astype(config.floatX) - out = f(a, b) - assert np.allclose(out, np.kron(a, b)) - - def test_solve_discrete_lyapunov_via_direct_real(): N = 5 rng = np.random.default_rng(utt.fetch_seed())