|
| 1 | +import re |
1 | 2 | from itertools import product
|
2 | 3 | from typing import Optional, Union
|
3 | 4 |
|
|
13 | 14 | from pytensor.graph.replace import vectorize_node
|
14 | 15 | from pytensor.raise_op import assert_op
|
15 | 16 | from pytensor.tensor import diagonal, log, tensor
|
16 |
| -from pytensor.tensor.blockwise import Blockwise |
| 17 | +from pytensor.tensor.blockwise import Blockwise, vectorize_node_fallback |
17 | 18 | from pytensor.tensor.nlinalg import MatrixInverse
|
18 | 19 | from pytensor.tensor.rewriting.blas import specialize_matmul_to_batched_dot
|
19 | 20 | from pytensor.tensor.slinalg import Cholesky, Solve, cholesky, solve_triangular
|
@@ -43,6 +44,19 @@ def test_vectorize_blockwise():
|
43 | 44 | assert new_vect_node.inputs[0] is tns4
|
44 | 45 |
|
45 | 46 |
|
| 47 | +def test_vectorize_node_fallback_unsupported_type(): |
| 48 | + x = tensor("x", shape=(2, 6)) |
| 49 | + node = x[:, [0, 2, 4]].owner |
| 50 | + |
| 51 | + with pytest.raises( |
| 52 | + NotImplementedError, |
| 53 | + match=re.escape( |
| 54 | + "Cannot vectorize node AdvancedSubtensor(x, MakeSlice.0, [0 2 4]) with input MakeSlice.0 of type slice" |
| 55 | + ), |
| 56 | + ): |
| 57 | + vectorize_node_fallback(node.op, node, node.inputs) |
| 58 | + |
| 59 | + |
46 | 60 | def check_blockwise_runtime_broadcasting(mode):
|
47 | 61 | a = tensor("a", shape=(None, 3, 5))
|
48 | 62 | b = tensor("b", shape=(None, 5, 3))
|
|
0 commit comments