|
34 | 34 | advanced_inc_subtensor1,
|
35 | 35 | advanced_set_subtensor,
|
36 | 36 | advanced_set_subtensor1,
|
| 37 | + advanced_subtensor1, |
37 | 38 | as_index_literal,
|
38 | 39 | basic_shape,
|
39 | 40 | get_canonical_form_slice,
|
@@ -2707,12 +2708,26 @@ def test_index_vars_to_types():
|
2707 | 2708 | [(7, 13), (slice(None, None, 2), slice(-1, 1, -1)), (4, 11)],
|
2708 | 2709 | ],
|
2709 | 2710 | )
|
2710 |
| -def test_static_shapes(x_shape, indices, expected): |
| 2711 | +def test_subtensor_static_shapes(x_shape, indices, expected): |
2711 | 2712 | x = ptb.tensor(dtype="float64", shape=x_shape)
|
2712 | 2713 | y = x[indices]
|
2713 | 2714 | assert y.type.shape == expected
|
2714 | 2715 |
|
2715 | 2716 |
|
| 2717 | +@pytest.mark.parametrize( |
| 2718 | + "x_shape, indices, expected", |
| 2719 | + [ |
| 2720 | + [(None, 5, None, 3), vector(shape=(1,)), (1, 5, None, 3)], |
| 2721 | + [(None, 5, None, 3), vector(shape=(2,)), (2, 5, None, 3)], |
| 2722 | + [(None, 5, None, 3), vector(shape=(None,)), (None, 5, None, 3)], |
| 2723 | + ], |
| 2724 | +) |
| 2725 | +def test_advanced_subtensor1_static_shapes(x_shape, indices, expected): |
| 2726 | + x = ptb.tensor(dtype="float64", shape=x_shape) |
| 2727 | + y = advanced_subtensor1(x, indices.astype(int)) |
| 2728 | + assert y.type.shape == expected |
| 2729 | + |
| 2730 | + |
2716 | 2731 | def test_vectorize_subtensor_without_batch_indices():
|
2717 | 2732 | signature = "(t1,t2,t3),()->(t1,t3)"
|
2718 | 2733 |
|
|
0 commit comments