|
3 | 3 |
|
4 | 4 | import pytensor
|
5 | 5 | import pytensor.tensor as pt
|
6 |
| -from pytensor.graph.replace import vectorize_node |
7 |
| -from pytensor.tensor import tensor |
8 |
| -from pytensor.tensor.blockwise import Blockwise |
9 |
| -from pytensor.tensor.nlinalg import MatrixInverse |
10 | 6 | from pytensor.tensor.shape import specify_broadcastable
|
11 | 7 |
|
12 | 8 |
|
13 | 9 | torch = pytest.importorskip("torch")
|
14 | 10 |
|
15 | 11 |
|
16 |
| -def test_vectorize_blockwise(): |
17 |
| - mat = tensor(shape=(None, None)) |
18 |
| - tns = tensor(shape=(None, None, None)) |
19 |
| - |
20 |
| - # Something that falls back to Blockwise |
21 |
| - node = MatrixInverse()(mat).owner |
22 |
| - vect_node = vectorize_node(node, tns) |
23 |
| - assert isinstance(vect_node.op, Blockwise) and isinstance( |
24 |
| - vect_node.op.core_op, MatrixInverse |
25 |
| - ) |
26 |
| - assert vect_node.op.signature == ("(m,m)->(m,m)") |
27 |
| - assert vect_node.inputs[0] is tns |
28 |
| - |
29 |
| - # Useless blockwise |
30 |
| - tns4 = tensor(shape=(5, None, None, None)) |
31 |
| - new_vect_node = vectorize_node(vect_node, tns4) |
32 |
| - assert new_vect_node.op is vect_node.op |
33 |
| - assert isinstance(new_vect_node.op, Blockwise) and isinstance( |
34 |
| - new_vect_node.op.core_op, MatrixInverse |
35 |
| - ) |
36 |
| - assert new_vect_node.inputs[0] is tns4 |
37 |
| - |
38 |
| - |
39 | 12 | def test_blockwise_broadcast():
|
40 | 13 | _x = np.random.rand(5, 1, 2, 3)
|
41 | 14 | _y = np.random.rand(3, 3, 2)
|
|
0 commit comments