Skip to content

Commit b248e5b

Browse files
twieckiclaude
andcommitted
Simplify matrix/vector helper functions
- Remove redundant dimension checks that Blockwise already handles - Streamline test cases while keeping essential coverage - Based on PR feedback from Ricardo 🤖 Generated with Claude Code Co-Authored-By: Claude <noreply@anthropic.com>
1 parent 20a5540 commit b248e5b

File tree

2 files changed

+0
-66
lines changed

2 files changed

+0
-66
lines changed

pytensor/tensor/math.py

Lines changed: 0 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -2866,11 +2866,6 @@ def vecdot(
28662866
out : ndarray
28672867
The vector dot product of the inputs computed along the specified axes.
28682868
2869-
Raises
2870-
------
2871-
ValueError
2872-
If either input is a scalar value.
2873-
28742869
Notes
28752870
-----
28762871
This is similar to `dot` but with broadcasting. It computes the dot product
@@ -2880,9 +2875,6 @@ def vecdot(
28802875
x1 = as_tensor_variable(x1)
28812876
x2 = as_tensor_variable(x2)
28822877

2883-
if x1.type.ndim == 0 or x2.type.ndim == 0:
2884-
raise ValueError("vecdot operand cannot be scalar")
2885-
28862878
# Handle negative axis
28872879
if axis < 0:
28882880
x1_axis = axis % x1.type.ndim
@@ -2928,12 +2920,6 @@ def matvec(x1: "ArrayLike", x2: "ArrayLike", dtype: Optional["DTypeLike"] = None
29282920
out : ndarray
29292921
The matrix-vector product with shape (..., M).
29302922
2931-
Raises
2932-
------
2933-
ValueError
2934-
If any input is a scalar or if the trailing dimension of x2 does not match
2935-
the second-to-last dimension of x1.
2936-
29372923
Notes
29382924
-----
29392925
This is similar to `matmul` where the second argument is a vector,
@@ -2943,15 +2929,6 @@ def matvec(x1: "ArrayLike", x2: "ArrayLike", dtype: Optional["DTypeLike"] = None
29432929
x1 = as_tensor_variable(x1)
29442930
x2 = as_tensor_variable(x2)
29452931

2946-
if x1.type.ndim == 0 or x2.type.ndim == 0:
2947-
raise ValueError("matvec operand cannot be scalar")
2948-
2949-
if x1.type.ndim < 2:
2950-
raise ValueError("First input to matvec must have at least 2 dimensions")
2951-
2952-
if x2.type.ndim < 1:
2953-
raise ValueError("Second input to matvec must have at least 1 dimension")
2954-
29552932
out = _matrix_vec_prod(x1, x2)
29562933

29572934
if dtype is not None:
@@ -2979,12 +2956,6 @@ def vecmat(x1: "ArrayLike", x2: "ArrayLike", dtype: Optional["DTypeLike"] = None
29792956
out : ndarray
29802957
The vector-matrix product with shape (..., N).
29812958
2982-
Raises
2983-
------
2984-
ValueError
2985-
If any input is a scalar or if the last dimension of x1 does not match
2986-
the second-to-last dimension of x2.
2987-
29882959
Notes
29892960
-----
29902961
This is similar to `matmul` where the first argument is a vector,
@@ -2994,15 +2965,6 @@ def vecmat(x1: "ArrayLike", x2: "ArrayLike", dtype: Optional["DTypeLike"] = None
29942965
x1 = as_tensor_variable(x1)
29952966
x2 = as_tensor_variable(x2)
29962967

2997-
if x1.type.ndim == 0 or x2.type.ndim == 0:
2998-
raise ValueError("vecmat operand cannot be scalar")
2999-
3000-
if x1.type.ndim < 1:
3001-
raise ValueError("First input to vecmat must have at least 1 dimension")
3002-
3003-
if x2.type.ndim < 2:
3004-
raise ValueError("Second input to vecmat must have at least 2 dimensions")
3005-
30062968
out = _vec_matrix_prod(x1, x2)
30072969

30082970
if dtype is not None:

tests/tensor/test_math.py

Lines changed: 0 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -2116,12 +2116,6 @@ def test_vecdot(self):
21162116
y_val = random(2, 3, 4, rng=rng).astype(config.floatX)
21172117
np.testing.assert_allclose(f(x_val, y_val), np.sum(x_val * y_val, axis=2))
21182118

2119-
# Test error cases
2120-
x = scalar()
2121-
y = scalar()
2122-
with pytest.raises(ValueError):
2123-
vecdot(x, y)
2124-
21252119
def test_matvec(self):
21262120
"""Test matvec function with various input shapes."""
21272121
rng = np.random.default_rng(seed=utt.fetch_seed())
@@ -2147,17 +2141,6 @@ def test_matvec(self):
21472141
expected = np.array([np.dot(x_val[i], y_val[i]) for i in range(2)])
21482142
np.testing.assert_allclose(f(x_val, y_val), expected)
21492143

2150-
# Test error cases
2151-
x = vector()
2152-
y = vector()
2153-
with pytest.raises(ValueError):
2154-
matvec(x, y)
2155-
2156-
x = scalar()
2157-
y = vector()
2158-
with pytest.raises(ValueError):
2159-
matvec(x, y)
2160-
21612144
def test_vecmat(self):
21622145
"""Test vecmat function with various input shapes."""
21632146
rng = np.random.default_rng(seed=utt.fetch_seed())
@@ -2183,17 +2166,6 @@ def test_vecmat(self):
21832166
expected = np.array([np.dot(x_val[i], y_val[i]) for i in range(2)])
21842167
np.testing.assert_allclose(f(x_val, y_val), expected)
21852168

2186-
# Test error cases
2187-
x = matrix()
2188-
y = vector()
2189-
with pytest.raises(ValueError):
2190-
vecmat(x, y)
2191-
2192-
x = scalar()
2193-
y = matrix()
2194-
with pytest.raises(ValueError):
2195-
vecmat(x, y)
2196-
21972169
def test_matmul(self):
21982170
"""Test matmul function with various input shapes."""
21992171
rng = np.random.default_rng(seed=utt.fetch_seed())

0 commit comments

Comments
 (0)