Expose vecdot, vecmat and matvec helpers #1248

Closed · wants to merge 1 commit

173 changes: 173 additions & 0 deletions pytensor/tensor/math.py
@@ -2841,6 +2841,176 @@ def matmul(x1: "ArrayLike", x2: "ArrayLike", dtype: Optional["DTypeLike"] = None
return out


def vecdot(
x1: "ArrayLike",
x2: "ArrayLike",
axis: int = -1,
dtype: Optional["DTypeLike"] = None,
):
"""Compute the dot product of two vectors along specified dimensions.

Parameters
----------
x1, x2
Input arrays, scalars not allowed.
axis
The axis along which to compute the dot product. By default, the last
axis of each input is used.
dtype
The desired data-type for the array. If not given, then the type will
be determined as the minimum type required to hold the objects in the
sequence.

Returns
-------
out : ndarray
The vector dot product of the inputs computed along the specified axes.

Raises
------
ValueError
If either input is a scalar value.

Notes
-----
This is similar to `dot` but with broadcasting. It computes the dot product
along the specified axis, treating the inputs as stacks of vectors, and
broadcasts across the remaining axes.
"""
x1 = as_tensor_variable(x1)
x2 = as_tensor_variable(x2)

if x1.type.ndim == 0 or x2.type.ndim == 0:
raise ValueError("vecdot operand cannot be scalar")

# Handle negative axis
if axis < 0:
x1_axis = axis % x1.type.ndim
x2_axis = axis % x2.type.ndim
else:
x1_axis = axis
x2_axis = axis

# Move the axes to the end for dot product calculation
x1_perm = list(range(x1.type.ndim))
x1_perm.append(x1_perm.pop(x1_axis))
x1_transposed = x1.transpose(x1_perm)

x2_perm = list(range(x2.type.ndim))
x2_perm.append(x2_perm.pop(x2_axis))
x2_transposed = x2.transpose(x2_perm)

# Use the inner product operation
out = _inner_prod(x1_transposed, x2_transposed)

if dtype is not None:
out = out.astype(dtype)

return out

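# Hypothetical usage sketch (illustration only, not part of this diff):
# evaluates vecdot on a pair of batched vectors, assuming the helper is
# importable from pytensor.tensor.math as added above and using
# Variable.eval to compute a concrete value.
import numpy as np

from pytensor.tensor.math import vecdot
from pytensor.tensor.type import tensor3

x = tensor3("x")
y = tensor3("y")
z = vecdot(x, y)  # contracts the last axis of each input by default (axis=-1)

x_val = np.ones((2, 3, 4), dtype=x.type.dtype)
y_val = np.ones((2, 3, 4), dtype=y.type.dtype)
# Each length-4 dot product of ones is 4.0; the leading (2, 3) axes broadcast.
assert np.allclose(z.eval({x: x_val, y: y_val}), np.full((2, 3), 4.0))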

def matvec(x1: "ArrayLike", x2: "ArrayLike", dtype: Optional["DTypeLike"] = None):
"""Compute the matrix-vector product.

Parameters
----------
x1
Input array for the matrix with shape (..., M, K).
x2
Input array for the vector with shape (..., K).
dtype
The desired data-type for the array. If not given, then the type will
be determined as the minimum type required to hold the objects in the
sequence.

Returns
-------
out : ndarray
The matrix-vector product with shape (..., M).

Raises
------
ValueError
If any input is a scalar or if the last dimension of x2 does not match
the last dimension of x1.

Notes
-----
This is similar to `matmul` where the second argument is a vector,
but with different broadcasting rules. Broadcasting happens over all but
the last two dimensions of x1 and all but the last dimension of x2.
"""
x1 = as_tensor_variable(x1)
x2 = as_tensor_variable(x2)

if x1.type.ndim == 0 or x2.type.ndim == 0:
raise ValueError("matvec operand cannot be scalar")

if x1.type.ndim < 2:
raise ValueError("First input to matvec must have at least 2 dimensions")

if x2.type.ndim < 1:
raise ValueError("Second input to matvec must have at least 1 dimension")

out = _matrix_vec_prod(x1, x2)

if dtype is not None:
out = out.astype(dtype)

return out

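# Hypothetical usage sketch (illustration only, not part of this diff):
# a batched matrix-vector product, assuming matvec is importable from
# pytensor.tensor.math as added above. The batch axes of x1 (all but the
# last two) broadcast against the batch axes of x2 (all but the last).
import numpy as np

from pytensor.tensor.math import matvec
from pytensor.tensor.type import matrix, tensor3

A = tensor3("A")  # shape (..., M, K); here (2, 3, 4)
v = matrix("v")  # shape (..., K); here (2, 4)
w = matvec(A, v)  # shape (..., M); here (2, 3)

A_val = np.ones((2, 3, 4), dtype=A.type.dtype)
v_val = np.ones((2, 4), dtype=v.type.dtype)
# Every output entry is a dot product of four ones, so the result is all 4.0.
assert np.allclose(w.eval({A: A_val, v: v_val}), np.full((2, 3), 4.0))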

def vecmat(x1: "ArrayLike", x2: "ArrayLike", dtype: Optional["DTypeLike"] = None):
"""Compute the vector-matrix product.

Parameters
----------
x1
Input array for the vector with shape (..., K).
x2
Input array for the matrix with shape (..., K, N).
dtype
The desired data-type for the array. If not given, then the type will
be determined as the minimum type required to hold the objects in the
sequence.

Returns
-------
out : ndarray
The vector-matrix product with shape (..., N).

Raises
------
ValueError
If any input is a scalar or if the last dimension of x1 does not match
the second-to-last dimension of x2.

Notes
-----
This is similar to `matmul` where the first argument is a vector,
but with different broadcasting rules. Broadcasting happens over all but
the last dimension of x1 and all but the last two dimensions of x2.
"""
x1 = as_tensor_variable(x1)
x2 = as_tensor_variable(x2)

if x1.type.ndim == 0 or x2.type.ndim == 0:
raise ValueError("vecmat operand cannot be scalar")

if x1.type.ndim < 1:
raise ValueError("First input to vecmat must have at least 1 dimension")

if x2.type.ndim < 2:
raise ValueError("Second input to vecmat must have at least 2 dimensions")

out = _vec_matrix_prod(x1, x2)

if dtype is not None:
out = out.astype(dtype)

return out

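# Hypothetical usage sketch (illustration only, not part of this diff):
# a batched vector-matrix product, assuming vecmat is importable from
# pytensor.tensor.math as added above. Broadcasting runs over all but the
# last axis of x1 and all but the last two axes of x2.
import numpy as np

from pytensor.tensor.math import vecmat
from pytensor.tensor.type import matrix, tensor3

v = matrix("v")  # shape (..., K); here (2, 3)
A = tensor3("A")  # shape (..., K, N); here (2, 3, 4)
w = vecmat(v, A)  # shape (..., N); here (2, 4)

v_val = np.ones((2, 3), dtype=v.type.dtype)
A_val = np.ones((2, 3, 4), dtype=A.type.dtype)
# Every output entry sums three ones, so the result is all 3.0.
assert np.allclose(w.eval({v: v_val, A: A_val}), np.full((2, 4), 3.0))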

@_vectorize_node.register(Dot)
def vectorize_node_dot(op, node, batched_x, batched_y):
old_x, old_y = node.inputs
@@ -2937,6 +3107,9 @@ def nan_to_num(x, nan=0.0, posinf=None, neginf=None):
"max_and_argmax",
"max",
"matmul",
"vecdot",
"matvec",
"vecmat",
"argmax",
"min",
"argmin",
179 changes: 179 additions & 0 deletions tests/tensor/test_math.py
@@ -89,6 +89,7 @@
logaddexp,
logsumexp,
matmul,
matvec,
max,
max_and_argmax,
maximum,
@@ -123,6 +124,8 @@
true_div,
trunc,
var,
vecdot,
vecmat,
)
from pytensor.tensor.math import sum as pt_sum
from pytensor.tensor.type import (
@@ -2076,6 +2079,182 @@ def is_super_shape(var1, var2):
assert is_super_shape(y, g)


class TestMatrixVectorOps:
def test_vecdot(self):
"""Test vecdot function with various input shapes and axis."""
rng = np.random.default_rng(seed=utt.fetch_seed())

# Test vector-vector
x = vector()
y = vector()
z = vecdot(x, y)
f = function([x, y], z)
x_val = random(5, rng=rng).astype(config.floatX)
y_val = random(5, rng=rng).astype(config.floatX)
np.testing.assert_allclose(f(x_val, y_val), np.dot(x_val, y_val))

# Test with axis parameter
x = matrix()
y = matrix()
z0 = vecdot(x, y, axis=0)
z1 = vecdot(x, y, axis=1)
f0 = function([x, y], z0)
f1 = function([x, y], z1)

x_val = random(3, 4, rng=rng).astype(config.floatX)
y_val = random(3, 4, rng=rng).astype(config.floatX)
np.testing.assert_allclose(f0(x_val, y_val), np.sum(x_val * y_val, axis=0))
np.testing.assert_allclose(f1(x_val, y_val), np.sum(x_val * y_val, axis=1))

# Test batched vectors
x = tensor3()
y = tensor3()
z = vecdot(x, y, axis=2)
f = function([x, y], z)

x_val = random(2, 3, 4, rng=rng).astype(config.floatX)
y_val = random(2, 3, 4, rng=rng).astype(config.floatX)
np.testing.assert_allclose(f(x_val, y_val), np.sum(x_val * y_val, axis=2))

# Test error cases
x = scalar()
y = scalar()
with pytest.raises(ValueError):
vecdot(x, y)

def test_matvec(self):
"""Test matvec function with various input shapes."""
rng = np.random.default_rng(seed=utt.fetch_seed())

# Test matrix-vector
x = matrix()
y = vector()
z = matvec(x, y)
f = function([x, y], z)

x_val = random(3, 4, rng=rng).astype(config.floatX)
y_val = random(4, rng=rng).astype(config.floatX)
np.testing.assert_allclose(f(x_val, y_val), np.dot(x_val, y_val))

# Test batched
x = tensor3()
y = matrix()
z = matvec(x, y)
f = function([x, y], z)

x_val = random(2, 3, 4, rng=rng).astype(config.floatX)
y_val = random(2, 4, rng=rng).astype(config.floatX)
expected = np.array([np.dot(x_val[i], y_val[i]) for i in range(2)])
np.testing.assert_allclose(f(x_val, y_val), expected)

# Test error cases
x = vector()
y = vector()
with pytest.raises(ValueError):
matvec(x, y)

x = scalar()
y = vector()
with pytest.raises(ValueError):
matvec(x, y)

def test_vecmat(self):
"""Test vecmat function with various input shapes."""
rng = np.random.default_rng(seed=utt.fetch_seed())

# Test vector-matrix
x = vector()
y = matrix()
z = vecmat(x, y)
f = function([x, y], z)

x_val = random(3, rng=rng).astype(config.floatX)
y_val = random(3, 4, rng=rng).astype(config.floatX)
np.testing.assert_allclose(f(x_val, y_val), np.dot(x_val, y_val))

# Test batched
x = matrix()
y = tensor3()
z = vecmat(x, y)
f = function([x, y], z)

x_val = random(2, 3, rng=rng).astype(config.floatX)
y_val = random(2, 3, 4, rng=rng).astype(config.floatX)
expected = np.array([np.dot(x_val[i], y_val[i]) for i in range(2)])
np.testing.assert_allclose(f(x_val, y_val), expected)

# Test error cases
x = matrix()
y = vector()
with pytest.raises(ValueError):
vecmat(x, y)

x = scalar()
y = matrix()
with pytest.raises(ValueError):
vecmat(x, y)

def test_matmul(self):
"""Test matmul function with various input shapes."""
rng = np.random.default_rng(seed=utt.fetch_seed())

# Test matrix-matrix
x = matrix()
y = matrix()
z = matmul(x, y)
f = function([x, y], z)

x_val = random(3, 4, rng=rng).astype(config.floatX)
y_val = random(4, 5, rng=rng).astype(config.floatX)
np.testing.assert_allclose(f(x_val, y_val), np.matmul(x_val, y_val))

# Test vector-matrix
x = vector()
y = matrix()
z = matmul(x, y)
f = function([x, y], z)

x_val = random(3, rng=rng).astype(config.floatX)
y_val = random(3, 4, rng=rng).astype(config.floatX)
np.testing.assert_allclose(f(x_val, y_val), np.matmul(x_val, y_val))

# Test matrix-vector
x = matrix()
y = vector()
z = matmul(x, y)
f = function([x, y], z)

x_val = random(3, 4, rng=rng).astype(config.floatX)
y_val = random(4, rng=rng).astype(config.floatX)
np.testing.assert_allclose(f(x_val, y_val), np.matmul(x_val, y_val))

# Test vector-vector
x = vector()
y = vector()
z = matmul(x, y)
f = function([x, y], z)

x_val = random(3, rng=rng).astype(config.floatX)
y_val = random(3, rng=rng).astype(config.floatX)
np.testing.assert_allclose(f(x_val, y_val), np.matmul(x_val, y_val))

# Test batched
x = tensor3()
y = tensor3()
z = matmul(x, y)
f = function([x, y], z)

x_val = random(2, 3, 4, rng=rng).astype(config.floatX)
y_val = random(2, 4, 5, rng=rng).astype(config.floatX)
np.testing.assert_allclose(f(x_val, y_val), np.matmul(x_val, y_val))

# Test error cases
x = scalar()
y = scalar()
with pytest.raises(ValueError):
matmul(x, y)


class TestTensordot:
def TensorDot(self, axes):
# Since tensordot is no longer an op, mimic the old op signature