Skip to content

Commit 734009a

Browse files
committed
Improve static output shape of AdvancedSubtensor1
1 parent 3db127e commit 734009a

File tree

2 files changed

+17
-3
lines changed

2 files changed

+17
-3
lines changed

pytensor/tensor/subtensor.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1992,8 +1992,7 @@ def make_node(self, x, ilist):
19921992
raise TypeError("index must be vector")
19931993
if x_.type.ndim == 0:
19941994
raise TypeError("cannot index into a scalar")
1995-
out_shape = (ilist_.type.shape[0],) + x_.type.shape[1:]
1996-
out_shape = tuple(1 if s == 1 else None for s in out_shape)
1995+
out_shape = (ilist_.type.shape[0], *x_.type.shape[1:])
19971996
return Apply(self, [x_, ilist_], [TensorType(dtype=x.dtype, shape=out_shape)()])
19981997

19991998
def perform(self, node, inp, out_):

tests/tensor/test_subtensor.py

Lines changed: 16 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@
3434
advanced_inc_subtensor1,
3535
advanced_set_subtensor,
3636
advanced_set_subtensor1,
37+
advanced_subtensor1,
3738
as_index_literal,
3839
basic_shape,
3940
get_canonical_form_slice,
@@ -2707,12 +2708,26 @@ def test_index_vars_to_types():
27072708
[(7, 13), (slice(None, None, 2), slice(-1, 1, -1)), (4, 11)],
27082709
],
27092710
)
2710-
def test_static_shapes(x_shape, indices, expected):
2711+
def test_subtensor_static_shapes(x_shape, indices, expected):
27112712
x = ptb.tensor(dtype="float64", shape=x_shape)
27122713
y = x[indices]
27132714
assert y.type.shape == expected
27142715

27152716

2717+
@pytest.mark.parametrize(
2718+
"x_shape, indices, expected",
2719+
[
2720+
[(None, 5, None, 3), vector(shape=(1,)), (1, 5, None, 3)],
2721+
[(None, 5, None, 3), vector(shape=(2,)), (2, 5, None, 3)],
2722+
[(None, 5, None, 3), vector(shape=(None,)), (None, 5, None, 3)],
2723+
],
2724+
)
2725+
def test_advanced_subtensor1_static_shapes(x_shape, indices, expected):
2726+
x = ptb.tensor(dtype="float64", shape=x_shape)
2727+
y = advanced_subtensor1(x, indices.astype(int))
2728+
assert y.type.shape == expected
2729+
2730+
27162731
def test_vectorize_subtensor_without_batch_indices():
27172732
signature = "(t1,t2,t3),()->(t1,t3)"
27182733

0 commit comments

Comments
 (0)