From 20a5540e6b149cf4fb60314d35d106105294728c Mon Sep 17 00:00:00 2001 From: Thomas Wiecki Date: Wed, 26 Feb 2025 18:17:56 +0800 Subject: [PATCH] Expose vecdot, vecmat and matvec helpers Add three new functions that expose the underlying Blockwise operations: - vecdot: Computes dot products between vectors with broadcasting - matvec: Computes matrix-vector products with broadcasting - vecmat: Computes vector-matrix products with broadcasting These match the NumPy API for similar operations and complement the existing matmul function. Each comes with appropriate error handling, parameter validation, and comprehensive test coverage. Fixes #1237 --- pytensor/tensor/math.py | 173 ++++++++++++++++++++++++++++++++++++ tests/tensor/test_math.py | 179 ++++++++++++++++++++++++++++++++++++++ 2 files changed, 352 insertions(+) diff --git a/pytensor/tensor/math.py b/pytensor/tensor/math.py index 2aa6ad2381..0649d6dfb2 100644 --- a/pytensor/tensor/math.py +++ b/pytensor/tensor/math.py @@ -2841,6 +2841,176 @@ def matmul(x1: "ArrayLike", x2: "ArrayLike", dtype: Optional["DTypeLike"] = None return out +def vecdot( + x1: "ArrayLike", + x2: "ArrayLike", + axis: int = -1, + dtype: Optional["DTypeLike"] = None, +): + """Compute the dot product of two vectors along specified dimensions. + + Parameters + ---------- + x1, x2 + Input arrays, scalars not allowed. + axis + The axis along which to compute the dot product. By default, the last + axes of the inputs are used. + dtype + The desired data-type for the array. If not given, then the type will + be determined as the minimum type required to hold the objects in the + sequence. + + Returns + ------- + out : ndarray + The vector dot product of the inputs computed along the specified axes. + + Raises + ------ + ValueError + If either input is a scalar value. + + Notes + ----- + This is similar to `dot` but with broadcasting. 
It computes the dot product
+    along the specified axes, treating these as vectors, and broadcasts across
+    the remaining axes.
+    """
+    x1 = as_tensor_variable(x1)
+    x2 = as_tensor_variable(x2)
+
+    if x1.type.ndim == 0 or x2.type.ndim == 0:
+        raise ValueError("vecdot operand cannot be scalar")
+
+    # Handle negative axis
+    if axis < 0:
+        x1_axis = axis % x1.type.ndim
+        x2_axis = axis % x2.type.ndim
+    else:
+        x1_axis = axis
+        x2_axis = axis
+
+    # Move the axes to the end for dot product calculation
+    x1_perm = list(range(x1.type.ndim))
+    x1_perm.append(x1_perm.pop(x1_axis))
+    x1_transposed = x1.transpose(x1_perm)
+
+    x2_perm = list(range(x2.type.ndim))
+    x2_perm.append(x2_perm.pop(x2_axis))
+    x2_transposed = x2.transpose(x2_perm)
+
+    # Use the inner product operation
+    out = _inner_prod(x1_transposed, x2_transposed)
+
+    if dtype is not None:
+        out = out.astype(dtype)
+
+    return out
+
+
+def matvec(x1: "ArrayLike", x2: "ArrayLike", dtype: Optional["DTypeLike"] = None):
+    """Compute the matrix-vector product.
+
+    Parameters
+    ----------
+    x1
+        Input array for the matrix with shape (..., M, K).
+    x2
+        Input array for the vector with shape (..., K).
+    dtype
+        The desired data-type for the array. If not given, then the type will
+        be determined as the minimum type required to hold the objects in the
+        sequence.
+
+    Returns
+    -------
+    out : ndarray
+        The matrix-vector product with shape (..., M).
+
+    Raises
+    ------
+    ValueError
+        If any input is a scalar or if the last dimension of x2 does not match
+        the last dimension of x1.
+
+    Notes
+    -----
+    This is similar to `matmul` where the second argument is a vector,
+    but with different broadcasting rules. Broadcasting happens over all but
+    the last two dimensions of x1 and all but the last dimension of x2. 
+ """ + x1 = as_tensor_variable(x1) + x2 = as_tensor_variable(x2) + + if x1.type.ndim == 0 or x2.type.ndim == 0: + raise ValueError("matvec operand cannot be scalar") + + if x1.type.ndim < 2: + raise ValueError("First input to matvec must have at least 2 dimensions") + + if x2.type.ndim < 1: + raise ValueError("Second input to matvec must have at least 1 dimension") + + out = _matrix_vec_prod(x1, x2) + + if dtype is not None: + out = out.astype(dtype) + + return out + + +def vecmat(x1: "ArrayLike", x2: "ArrayLike", dtype: Optional["DTypeLike"] = None): + """Compute the vector-matrix product. + + Parameters + ---------- + x1 + Input array for the vector with shape (..., K). + x2 + Input array for the matrix with shape (..., K, N). + dtype + The desired data-type for the array. If not given, then the type will + be determined as the minimum type required to hold the objects in the + sequence. + + Returns + ------- + out : ndarray + The vector-matrix product with shape (..., N). + + Raises + ------ + ValueError + If any input is a scalar or if the last dimension of x1 does not match + the second-to-last dimension of x2. + + Notes + ----- + This is similar to `matmul` where the first argument is a vector, + but with different broadcasting rules. Broadcasting happens over all but + the last dimension of x1 and all but the last two dimensions of x2. 
+ """ + x1 = as_tensor_variable(x1) + x2 = as_tensor_variable(x2) + + if x1.type.ndim == 0 or x2.type.ndim == 0: + raise ValueError("vecmat operand cannot be scalar") + + if x1.type.ndim < 1: + raise ValueError("First input to vecmat must have at least 1 dimension") + + if x2.type.ndim < 2: + raise ValueError("Second input to vecmat must have at least 2 dimensions") + + out = _vec_matrix_prod(x1, x2) + + if dtype is not None: + out = out.astype(dtype) + + return out + + @_vectorize_node.register(Dot) def vectorize_node_dot(op, node, batched_x, batched_y): old_x, old_y = node.inputs @@ -2937,6 +3107,9 @@ def nan_to_num(x, nan=0.0, posinf=None, neginf=None): "max_and_argmax", "max", "matmul", + "vecdot", + "matvec", + "vecmat", "argmax", "min", "argmin", diff --git a/tests/tensor/test_math.py b/tests/tensor/test_math.py index 9ab4fd104d..39b6fb3daf 100644 --- a/tests/tensor/test_math.py +++ b/tests/tensor/test_math.py @@ -89,6 +89,7 @@ logaddexp, logsumexp, matmul, + matvec, max, max_and_argmax, maximum, @@ -123,6 +124,8 @@ true_div, trunc, var, + vecdot, + vecmat, ) from pytensor.tensor.math import sum as pt_sum from pytensor.tensor.type import ( @@ -2076,6 +2079,182 @@ def is_super_shape(var1, var2): assert is_super_shape(y, g) +class TestMatrixVectorOps: + def test_vecdot(self): + """Test vecdot function with various input shapes and axis.""" + rng = np.random.default_rng(seed=utt.fetch_seed()) + + # Test vector-vector + x = vector() + y = vector() + z = vecdot(x, y) + f = function([x, y], z) + x_val = random(5, rng=rng).astype(config.floatX) + y_val = random(5, rng=rng).astype(config.floatX) + np.testing.assert_allclose(f(x_val, y_val), np.dot(x_val, y_val)) + + # Test with axis parameter + x = matrix() + y = matrix() + z0 = vecdot(x, y, axis=0) + z1 = vecdot(x, y, axis=1) + f0 = function([x, y], z0) + f1 = function([x, y], z1) + + x_val = random(3, 4, rng=rng).astype(config.floatX) + y_val = random(3, 4, rng=rng).astype(config.floatX) + 
np.testing.assert_allclose(f0(x_val, y_val), np.sum(x_val * y_val, axis=0)) + np.testing.assert_allclose(f1(x_val, y_val), np.sum(x_val * y_val, axis=1)) + + # Test batched vectors + x = tensor3() + y = tensor3() + z = vecdot(x, y, axis=2) + f = function([x, y], z) + + x_val = random(2, 3, 4, rng=rng).astype(config.floatX) + y_val = random(2, 3, 4, rng=rng).astype(config.floatX) + np.testing.assert_allclose(f(x_val, y_val), np.sum(x_val * y_val, axis=2)) + + # Test error cases + x = scalar() + y = scalar() + with pytest.raises(ValueError): + vecdot(x, y) + + def test_matvec(self): + """Test matvec function with various input shapes.""" + rng = np.random.default_rng(seed=utt.fetch_seed()) + + # Test matrix-vector + x = matrix() + y = vector() + z = matvec(x, y) + f = function([x, y], z) + + x_val = random(3, 4, rng=rng).astype(config.floatX) + y_val = random(4, rng=rng).astype(config.floatX) + np.testing.assert_allclose(f(x_val, y_val), np.dot(x_val, y_val)) + + # Test batched + x = tensor3() + y = matrix() + z = matvec(x, y) + f = function([x, y], z) + + x_val = random(2, 3, 4, rng=rng).astype(config.floatX) + y_val = random(2, 4, rng=rng).astype(config.floatX) + expected = np.array([np.dot(x_val[i], y_val[i]) for i in range(2)]) + np.testing.assert_allclose(f(x_val, y_val), expected) + + # Test error cases + x = vector() + y = vector() + with pytest.raises(ValueError): + matvec(x, y) + + x = scalar() + y = vector() + with pytest.raises(ValueError): + matvec(x, y) + + def test_vecmat(self): + """Test vecmat function with various input shapes.""" + rng = np.random.default_rng(seed=utt.fetch_seed()) + + # Test vector-matrix + x = vector() + y = matrix() + z = vecmat(x, y) + f = function([x, y], z) + + x_val = random(3, rng=rng).astype(config.floatX) + y_val = random(3, 4, rng=rng).astype(config.floatX) + np.testing.assert_allclose(f(x_val, y_val), np.dot(x_val, y_val)) + + # Test batched + x = matrix() + y = tensor3() + z = vecmat(x, y) + f = function([x, y], z) + + 
x_val = random(2, 3, rng=rng).astype(config.floatX) + y_val = random(2, 3, 4, rng=rng).astype(config.floatX) + expected = np.array([np.dot(x_val[i], y_val[i]) for i in range(2)]) + np.testing.assert_allclose(f(x_val, y_val), expected) + + # Test error cases + x = matrix() + y = vector() + with pytest.raises(ValueError): + vecmat(x, y) + + x = scalar() + y = matrix() + with pytest.raises(ValueError): + vecmat(x, y) + + def test_matmul(self): + """Test matmul function with various input shapes.""" + rng = np.random.default_rng(seed=utt.fetch_seed()) + + # Test matrix-matrix + x = matrix() + y = matrix() + z = matmul(x, y) + f = function([x, y], z) + + x_val = random(3, 4, rng=rng).astype(config.floatX) + y_val = random(4, 5, rng=rng).astype(config.floatX) + np.testing.assert_allclose(f(x_val, y_val), np.matmul(x_val, y_val)) + + # Test vector-matrix + x = vector() + y = matrix() + z = matmul(x, y) + f = function([x, y], z) + + x_val = random(3, rng=rng).astype(config.floatX) + y_val = random(3, 4, rng=rng).astype(config.floatX) + np.testing.assert_allclose(f(x_val, y_val), np.matmul(x_val, y_val)) + + # Test matrix-vector + x = matrix() + y = vector() + z = matmul(x, y) + f = function([x, y], z) + + x_val = random(3, 4, rng=rng).astype(config.floatX) + y_val = random(4, rng=rng).astype(config.floatX) + np.testing.assert_allclose(f(x_val, y_val), np.matmul(x_val, y_val)) + + # Test vector-vector + x = vector() + y = vector() + z = matmul(x, y) + f = function([x, y], z) + + x_val = random(3, rng=rng).astype(config.floatX) + y_val = random(3, rng=rng).astype(config.floatX) + np.testing.assert_allclose(f(x_val, y_val), np.matmul(x_val, y_val)) + + # Test batched + x = tensor3() + y = tensor3() + z = matmul(x, y) + f = function([x, y], z) + + x_val = random(2, 3, 4, rng=rng).astype(config.floatX) + y_val = random(2, 4, 5, rng=rng).astype(config.floatX) + np.testing.assert_allclose(f(x_val, y_val), np.matmul(x_val, y_val)) + + # Test error cases + x = scalar() + y = 
scalar() + with pytest.raises(ValueError): + matmul(x, y) + + class TestTensordot: def TensorDot(self, axes): # Since tensordot is no longer an op, mimic the old op signature