diff --git a/pytensor/tensor/signal/conv.py b/pytensor/tensor/signal/conv.py index 1152f02d8a..ab2856b694 100644 --- a/pytensor/tensor/signal/conv.py +++ b/pytensor/tensor/signal/conv.py @@ -119,8 +119,8 @@ def convolve1d( if mode == "same": # We implement "same" as "valid" with padded `in1`. in1_batch_shape = tuple(in1.shape)[:-1] - zeros_left = in2.shape[0] // 2 - zeros_right = (in2.shape[0] - 1) // 2 + zeros_left = in2.shape[-1] // 2 + zeros_right = (in2.shape[-1] - 1) // 2 in1 = join( -1, zeros((*in1_batch_shape, zeros_left), dtype=in2.dtype), diff --git a/tests/tensor/signal/test_conv.py b/tests/tensor/signal/test_conv.py index 968e408485..fe353b18fb 100644 --- a/tests/tensor/signal/test_conv.py +++ b/tests/tensor/signal/test_conv.py @@ -47,3 +47,16 @@ def test_convolve1d_batch(): res_np = np.convolve(x_test[0], y_test[0]) np.testing.assert_allclose(res[0], res_np, rtol=rtol) np.testing.assert_allclose(res[1], res_np, rtol=rtol) + + +def test_convolve1d_batch_same(): + x = matrix("data") + y = matrix("kernel") + out = convolve1d(x, y, mode="same") + + rng = np.random.default_rng(38) + x_test = rng.normal(size=(2, 8)).astype(x.dtype) + y_test = rng.normal(size=(2, 8)).astype(x.dtype) + + res = out.eval({x: x_test, y: y_test}) + assert res.shape == (2, 8)