Skip to content

Commit 68b41a4

Browse files
committed
Better error for fallback of vectorize_node with non-tensor types
1 parent 301f10d commit 68b41a4

File tree

2 files changed

+23
-2
lines changed

2 files changed

+23
-2
lines changed

pytensor/tensor/blockwise.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,9 +14,10 @@
1414
_vectorize_not_needed,
1515
vectorize_graph,
1616
)
17+
from pytensor.scalar import ScalarType
1718
from pytensor.tensor import as_tensor_variable
1819
from pytensor.tensor.shape import shape_padleft
19-
from pytensor.tensor.type import continuous_dtypes, discrete_dtypes, tensor
20+
from pytensor.tensor.type import TensorType, continuous_dtypes, discrete_dtypes, tensor
2021
from pytensor.tensor.utils import (
2122
_parse_gufunc_signature,
2223
broadcast_static_dim_lengths,
@@ -373,6 +374,12 @@ def __str__(self):
373374

374375
@_vectorize_node.register(Op)
375376
def vectorize_node_fallback(op: Op, node: Apply, *bached_inputs) -> Apply:
377+
for inp in node.inputs:
378+
if not isinstance(inp.type, (TensorType, ScalarType)):
379+
raise NotImplementedError(
380+
f"Cannot vectorize node {node} with input {inp} of type {inp.type}"
381+
)
382+
376383
if hasattr(op, "gufunc_signature"):
377384
signature = op.gufunc_signature
378385
else:

tests/tensor/test_blockwise.py

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import re
12
from itertools import product
23
from typing import Optional, Union
34

@@ -12,7 +13,7 @@
1213
from pytensor.graph.replace import vectorize_node
1314
from pytensor.raise_op import assert_op
1415
from pytensor.tensor import diagonal, log, tensor
15-
from pytensor.tensor.blockwise import Blockwise
16+
from pytensor.tensor.blockwise import Blockwise, vectorize_node_fallback
1617
from pytensor.tensor.nlinalg import MatrixInverse
1718
from pytensor.tensor.rewriting.blas import specialize_matmul_to_batched_dot
1819
from pytensor.tensor.slinalg import Cholesky, Solve, cholesky, solve_triangular
@@ -42,6 +43,19 @@ def test_vectorize_blockwise():
4243
assert new_vect_node.inputs[0] is tns4
4344

4445

46+
def test_vectorize_node_fallback_unsupported_type():
47+
x = tensor("x", shape=(2, 6))
48+
node = x[:, [0, 2, 4]].owner
49+
50+
with pytest.raises(
51+
NotImplementedError,
52+
match=re.escape(
53+
"Cannot vectorize node AdvancedSubtensor(x, MakeSlice.0, [0 2 4]) with input MakeSlice.0 of type slice"
54+
),
55+
):
56+
vectorize_node_fallback(node.op, node, node.inputs)
57+
58+
4559
def check_blockwise_runtime_broadcasting(mode):
4660
a = tensor("a", shape=(None, 3, 5))
4761
b = tensor("b", shape=(None, 5, 3))

0 commit comments

Comments
 (0)