diff --git a/pytensor/tensor/math.py b/pytensor/tensor/math.py
index 2aa6ad2381..0649d6dfb2 100644
--- a/pytensor/tensor/math.py
+++ b/pytensor/tensor/math.py
@@ -2841,6 +2841,176 @@ def matmul(x1: "ArrayLike", x2: "ArrayLike", dtype: Optional["DTypeLike"] = None
     return out
 
 
+def vecdot(
+    x1: "ArrayLike",
+    x2: "ArrayLike",
+    axis: int = -1,
+    dtype: Optional["DTypeLike"] = None,
+):
+    """Compute the dot product of two vectors along specified dimensions.
+
+    Parameters
+    ----------
+    x1, x2
+        Input arrays, scalars not allowed.
+    axis
+        The axis along which to compute the dot product. By default, the last
+        axes of the inputs are used.
+    dtype
+        The desired data-type for the array. If not given, then the type will
+        be determined as the minimum type required to hold the objects in the
+        sequence.
+
+    Returns
+    -------
+    out : ndarray
+        The vector dot product of the inputs computed along the specified axes.
+
+    Raises
+    ------
+    ValueError
+        If either input is a scalar, or if `axis` is out of bounds for an input.
+
+    Notes
+    -----
+    This is similar to `dot` but with broadcasting. It computes the dot product
+    along the specified axes, treating these as vectors, and broadcasts across
+    the remaining axes.
+    """
+    x1 = as_tensor_variable(x1)
+    x2 = as_tensor_variable(x2)
+
+    if x1.type.ndim == 0 or x2.type.ndim == 0:
+        raise ValueError("vecdot operand cannot be scalar")
+
+    ndim1, ndim2 = x1.type.ndim, x2.type.ndim
+    # Reject out-of-range axes explicitly: Python's modulo would otherwise
+    # silently wrap (e.g. axis=-3 on a 2-D input would alias axis 1).
+    if not (-ndim1 <= axis < ndim1 and -ndim2 <= axis < ndim2):
+        raise ValueError(f"axis {axis} is out of bounds for the given inputs")
+    x1_axis = axis % ndim1
+    x2_axis = axis % ndim2
+
+    # Move the axes to the end for dot product calculation
+    x1_perm = list(range(x1.type.ndim))
+    x1_perm.append(x1_perm.pop(x1_axis))
+    x1_transposed = x1.transpose(x1_perm)
+
+    x2_perm = list(range(x2.type.ndim))
+    x2_perm.append(x2_perm.pop(x2_axis))
+    x2_transposed = x2.transpose(x2_perm)
+
+    # Use the inner product operation
+    out = _inner_prod(x1_transposed, x2_transposed)
+
+    if dtype is not None:
+        out = out.astype(dtype)
+
+    return out
+
+
+def matvec(x1: "ArrayLike", x2: "ArrayLike", dtype: Optional["DTypeLike"] = None):
+    """Compute the matrix-vector product.
+
+    Parameters
+    ----------
+    x1
+        Input array for the matrix with shape (..., M, K).
+    x2
+        Input array for the vector with shape (..., K).
+    dtype
+        The desired data-type for the array. If not given, then the type will
+        be determined as the minimum type required to hold the objects in the
+        sequence.
+
+    Returns
+    -------
+    out : ndarray
+        The matrix-vector product with shape (..., M).
+
+    Raises
+    ------
+    ValueError
+        If any input is a scalar or if the trailing dimension of x2 does not match
+        the second-to-last dimension of x1.
+
+    Notes
+    -----
+    This is similar to `matmul` where the second argument is a vector,
+    but with different broadcasting rules. Broadcasting happens over all but
+    the last dimension of x1 and all dimensions of x2 except the last.
+    """
+    x1 = as_tensor_variable(x1)
+    x2 = as_tensor_variable(x2)
+
+    if x1.type.ndim == 0 or x2.type.ndim == 0:
+        raise ValueError("matvec operand cannot be scalar")
+
+    if x1.type.ndim < 2:
+        raise ValueError("First input to matvec must have at least 2 dimensions")
+
+    if x2.type.ndim < 1:
+        raise ValueError("Second input to matvec must have at least 1 dimension")
+
+    out = _matrix_vec_prod(x1, x2)
+
+    if dtype is not None:
+        out = out.astype(dtype)
+
+    return out
+
+
+def vecmat(x1: "ArrayLike", x2: "ArrayLike", dtype: Optional["DTypeLike"] = None):
+    """Compute the vector-matrix product.
+
+    Parameters
+    ----------
+    x1
+        Input array for the vector with shape (..., K).
+    x2
+        Input array for the matrix with shape (..., K, N).
+    dtype
+        The desired data-type for the array. If not given, then the type will
+        be determined as the minimum type required to hold the objects in the
+        sequence.
+
+    Returns
+    -------
+    out : ndarray
+        The vector-matrix product with shape (..., N).
+
+    Raises
+    ------
+    ValueError
+        If any input is a scalar or if the last dimension of x1 does not match
+        the second-to-last dimension of x2.
+
+    Notes
+    -----
+    This is similar to `matmul` where the first argument is a vector,
+    but with different broadcasting rules. Broadcasting happens over all but
+    the last dimension of x1 and all but the last two dimensions of x2.
+    """
+    x1 = as_tensor_variable(x1)
+    x2 = as_tensor_variable(x2)
+
+    if x1.type.ndim == 0 or x2.type.ndim == 0:
+        raise ValueError("vecmat operand cannot be scalar")
+
+    if x1.type.ndim < 1:
+        raise ValueError("First input to vecmat must have at least 1 dimension")
+
+    if x2.type.ndim < 2:
+        raise ValueError("Second input to vecmat must have at least 2 dimensions")
+
+    out = _vec_matrix_prod(x1, x2)
+
+    if dtype is not None:
+        out = out.astype(dtype)
+
+    return out
+
+
 @_vectorize_node.register(Dot)
 def vectorize_node_dot(op, node, batched_x, batched_y):
     old_x, old_y = node.inputs
@@ -2937,6 +3107,9 @@ def nan_to_num(x, nan=0.0, posinf=None, neginf=None):
     "max_and_argmax",
     "max",
     "matmul",
+    "vecdot",
+    "matvec",
+    "vecmat",
     "argmax",
     "min",
     "argmin",
diff --git a/tests/tensor/test_math.py b/tests/tensor/test_math.py
index 9ab4fd104d..39b6fb3daf 100644
--- a/tests/tensor/test_math.py
+++ b/tests/tensor/test_math.py
@@ -89,6 +89,7 @@
     logaddexp,
     logsumexp,
     matmul,
+    matvec,
     max,
     max_and_argmax,
     maximum,
@@ -123,6 +124,8 @@
     true_div,
     trunc,
     var,
+    vecdot,
+    vecmat,
 )
 from pytensor.tensor.math import sum as pt_sum
 from pytensor.tensor.type import (
@@ -2076,6 +2079,182 @@ def is_super_shape(var1, var2):
     assert is_super_shape(y, g)
 
 
+class TestMatrixVectorOps:
+    def test_vecdot(self):
+        """Test vecdot function with various input shapes and axis."""
+        rng = np.random.default_rng(seed=utt.fetch_seed())
+
+        # Test vector-vector
+        x = vector()
+        y = vector()
+        z = vecdot(x, y)
+        f = function([x, y], z)
+        x_val = random(5, rng=rng).astype(config.floatX)
+        y_val = random(5, rng=rng).astype(config.floatX)
+        np.testing.assert_allclose(f(x_val, y_val), np.dot(x_val, y_val))
+
+        # Test with axis parameter
+        x = matrix()
+        y = matrix()
+        z0 = vecdot(x, y, axis=0)
+        z1 = vecdot(x, y, axis=1)
+        f0 = function([x, y], z0)
+        f1 = function([x, y], z1)
+
+        x_val = random(3, 4, rng=rng).astype(config.floatX)
+        y_val = random(3, 4, rng=rng).astype(config.floatX)
+
+        np.testing.assert_allclose(f0(x_val, y_val), np.sum(x_val * y_val, axis=0))
+        np.testing.assert_allclose(f1(x_val, y_val), np.sum(x_val * y_val, axis=1))
+
+        # Test batched vectors
+        x = tensor3()
+        y = tensor3()
+        z = vecdot(x, y, axis=2)
+        f = function([x, y], z)
+
+        x_val = random(2, 3, 4, rng=rng).astype(config.floatX)
+        y_val = random(2, 3, 4, rng=rng).astype(config.floatX)
+        np.testing.assert_allclose(f(x_val, y_val), np.sum(x_val * y_val, axis=2))
+
+        # Test error cases
+        x = scalar()
+        y = scalar()
+        with pytest.raises(ValueError):
+            vecdot(x, y)
+
+    def test_matvec(self):
+        """Test matvec function with various input shapes."""
+        rng = np.random.default_rng(seed=utt.fetch_seed())
+
+        # Test matrix-vector
+        x = matrix()
+        y = vector()
+        z = matvec(x, y)
+        f = function([x, y], z)
+
+        x_val = random(3, 4, rng=rng).astype(config.floatX)
+        y_val = random(4, rng=rng).astype(config.floatX)
+        np.testing.assert_allclose(f(x_val, y_val), np.dot(x_val, y_val))
+
+        # Test batched
+        x = tensor3()
+        y = matrix()
+        z = matvec(x, y)
+        f = function([x, y], z)
+
+        x_val = random(2, 3, 4, rng=rng).astype(config.floatX)
+        y_val = random(2, 4, rng=rng).astype(config.floatX)
+        expected = np.array([np.dot(x_val[i], y_val[i]) for i in range(2)])
+        np.testing.assert_allclose(f(x_val, y_val), expected)
+
+        # Test error cases
+        x = vector()
+        y = vector()
+        with pytest.raises(ValueError):
+            matvec(x, y)
+
+        x = scalar()
+        y = vector()
+        with pytest.raises(ValueError):
+            matvec(x, y)
+
+    def test_vecmat(self):
+        """Test vecmat function with various input shapes."""
+        rng = np.random.default_rng(seed=utt.fetch_seed())
+
+        # Test vector-matrix
+        x = vector()
+        y = matrix()
+        z = vecmat(x, y)
+        f = function([x, y], z)
+
+        x_val = random(3, rng=rng).astype(config.floatX)
+        y_val = random(3, 4, rng=rng).astype(config.floatX)
+        np.testing.assert_allclose(f(x_val, y_val), np.dot(x_val, y_val))
+
+        # Test batched
+        x = matrix()
+        y = tensor3()
+        z = vecmat(x, y)
+        f = function([x, y], z)
+
+        x_val = random(2, 3, rng=rng).astype(config.floatX)
+        y_val = random(2, 3, 4, rng=rng).astype(config.floatX)
+        expected = np.array([np.dot(x_val[i], y_val[i]) for i in range(2)])
+        np.testing.assert_allclose(f(x_val, y_val), expected)
+
+        # Test error cases
+        x = matrix()
+        y = vector()
+        with pytest.raises(ValueError):
+            vecmat(x, y)
+
+        x = scalar()
+        y = matrix()
+        with pytest.raises(ValueError):
+            vecmat(x, y)
+
+    def test_matmul(self):
+        """Test matmul function with various input shapes."""
+        rng = np.random.default_rng(seed=utt.fetch_seed())
+
+        # Test matrix-matrix
+        x = matrix()
+        y = matrix()
+        z = matmul(x, y)
+        f = function([x, y], z)
+
+        x_val = random(3, 4, rng=rng).astype(config.floatX)
+        y_val = random(4, 5, rng=rng).astype(config.floatX)
+        np.testing.assert_allclose(f(x_val, y_val), np.matmul(x_val, y_val))
+
+        # Test vector-matrix
+        x = vector()
+        y = matrix()
+        z = matmul(x, y)
+        f = function([x, y], z)
+
+        x_val = random(3, rng=rng).astype(config.floatX)
+        y_val = random(3, 4, rng=rng).astype(config.floatX)
+        np.testing.assert_allclose(f(x_val, y_val), np.matmul(x_val, y_val))
+
+        # Test matrix-vector
+        x = matrix()
+        y = vector()
+        z = matmul(x, y)
+        f = function([x, y], z)
+
+        x_val = random(3, 4, rng=rng).astype(config.floatX)
+        y_val = random(4, rng=rng).astype(config.floatX)
+        np.testing.assert_allclose(f(x_val, y_val), np.matmul(x_val, y_val))
+
+        # Test vector-vector
+        x = vector()
+        y = vector()
+        z = matmul(x, y)
+        f = function([x, y], z)
+
+        x_val = random(3, rng=rng).astype(config.floatX)
+        y_val = random(3, rng=rng).astype(config.floatX)
+        np.testing.assert_allclose(f(x_val, y_val), np.matmul(x_val, y_val))
+
+        # Test batched
+        x = tensor3()
+        y = tensor3()
+        z = matmul(x, y)
+        f = function([x, y], z)
+
+        x_val = random(2, 3, 4, rng=rng).astype(config.floatX)
+        y_val = random(2, 4, 5, rng=rng).astype(config.floatX)
+        np.testing.assert_allclose(f(x_val, y_val), np.matmul(x_val, y_val))
+
+        # Test error cases
+        x = scalar()
+        y = scalar()
+        with pytest.raises(ValueError):
+            matmul(x, y)
+
+
 class TestTensordot:
     def TensorDot(self, axes):
         # Since tensordot is no longer an op, mimic the old op signature