|
9 | 9 | import pytensor
|
10 | 10 | import pytensor.scalar as scal
|
11 | 11 | import pytensor.tensor.basic as at
|
| 12 | +from pytensor import function |
12 | 13 | from pytensor.compile import DeepCopyOp, shared
|
13 | 14 | from pytensor.compile.io import In
|
14 | 15 | from pytensor.configdefaults import config
|
15 | 16 | from pytensor.graph.op import get_test_value
|
16 | 17 | from pytensor.graph.rewriting.utils import is_same_graph
|
17 | 18 | from pytensor.printing import pprint
|
18 | 19 | from pytensor.scalar.basic import as_scalar
|
19 |
| -from pytensor.tensor import get_vector_length |
| 20 | +from pytensor.tensor import get_vector_length, vectorize |
| 21 | +from pytensor.tensor.blockwise import Blockwise |
20 | 22 | from pytensor.tensor.elemwise import DimShuffle
|
21 | 23 | from pytensor.tensor.math import exp, isinf
|
22 | 24 | from pytensor.tensor.math import sum as at_sum
|
@@ -2709,3 +2711,43 @@ def test_static_shapes(x_shape, indices, expected):
|
2709 | 2711 | x = at.tensor(dtype="float64", shape=x_shape)
|
2710 | 2712 | y = x[indices]
|
2711 | 2713 | assert y.type.shape == expected
|
| 2714 | + |
| 2715 | + |
| 2716 | +def test_vectorize_subtensor_without_batch_indices(): |
| 2717 | + signature = "(t1,t2,t3),()->(t1,t3)" |
| 2718 | + |
| 2719 | + def core_fn(x, start): |
| 2720 | + return x[:, start, :] |
| 2721 | + |
| 2722 | + x = tensor(shape=(11, 7, 5, 3)) |
| 2723 | + start = tensor(shape=(), dtype="int") |
| 2724 | + vectorize_pt = function( |
| 2725 | + [x, start], vectorize(core_fn, signature=signature)(x, start) |
| 2726 | + ) |
| 2727 | + assert not any( |
| 2728 | + isinstance(node.op, Blockwise) for node in vectorize_pt.maker.fgraph.apply_nodes |
| 2729 | + ) |
| 2730 | + x_test = np.random.normal(size=x.type.shape).astype(x.type.dtype) |
| 2731 | + start_test = np.random.randint(0, x.type.shape[-2]) |
| 2732 | + vectorize_np = np.vectorize(core_fn, signature=signature) |
| 2733 | + np.testing.assert_allclose( |
| 2734 | + vectorize_pt(x_test, start_test), |
| 2735 | + vectorize_np(x_test, start_test), |
| 2736 | + ) |
| 2737 | + |
| 2738 | + # If we vectorize start, we should get a Blockwise that still works |
| 2739 | + x = tensor(shape=(11, 7, 5, 3)) |
| 2740 | + start = tensor(shape=(11,), dtype="int") |
| 2741 | + vectorize_pt = function( |
| 2742 | + [x, start], vectorize(core_fn, signature=signature)(x, start) |
| 2743 | + ) |
| 2744 | + assert any( |
| 2745 | + isinstance(node.op, Blockwise) for node in vectorize_pt.maker.fgraph.apply_nodes |
| 2746 | + ) |
| 2747 | + x_test = np.random.normal(size=x.type.shape).astype(x.type.dtype) |
| 2748 | + start_test = np.random.randint(0, x.type.shape[-2], size=start.type.shape[0]) |
| 2749 | + vectorize_np = np.vectorize(core_fn, signature=signature) |
| 2750 | + np.testing.assert_allclose( |
| 2751 | + vectorize_pt(x_test, start_test), |
| 2752 | + vectorize_np(x_test, start_test), |
| 2753 | + ) |
0 commit comments