|
| 1 | +import re |
1 | 2 | from itertools import product
|
2 | 3 | from typing import Optional, Union
|
3 | 4 |
|
|
12 | 13 | from pytensor.graph.replace import vectorize_node
|
13 | 14 | from pytensor.raise_op import assert_op
|
14 | 15 | from pytensor.tensor import diagonal, log, tensor
|
15 |
| -from pytensor.tensor.blockwise import Blockwise |
| 16 | +from pytensor.tensor.blockwise import Blockwise, vectorize_node_fallback |
16 | 17 | from pytensor.tensor.nlinalg import MatrixInverse
|
17 | 18 | from pytensor.tensor.rewriting.blas import specialize_matmul_to_batched_dot
|
18 | 19 | from pytensor.tensor.slinalg import Cholesky, Solve, cholesky, solve_triangular
|
@@ -42,6 +43,19 @@ def test_vectorize_blockwise():
|
42 | 43 | assert new_vect_node.inputs[0] is tns4
|
43 | 44 |
|
44 | 45 |
|
| 46 | +def test_vectorize_node_fallback_unsupported_type(): |
| 47 | + x = tensor("x", shape=(2, 6)) |
| 48 | + node = x[:, [0, 2, 4]].owner |
| 49 | + |
| 50 | + with pytest.raises( |
| 51 | + NotImplementedError, |
| 52 | + match=re.escape( |
| 53 | + "Cannot vectorize node AdvancedSubtensor(x, MakeSlice.0, [0 2 4]) with input MakeSlice.0 of type slice" |
| 54 | + ), |
| 55 | + ): |
| 56 | + vectorize_node_fallback(node.op, node, node.inputs) |
| 57 | + |
| 58 | + |
45 | 59 | def check_blockwise_runtime_broadcasting(mode):
|
46 | 60 | a = tensor("a", shape=(None, 3, 5))
|
47 | 61 | b = tensor("b", shape=(None, 5, 3))
|
|
0 commit comments