Skip to content

Commit 7cd054e

Browse files
committed
Fix bug in shape inference of AdvancedSubtensor with slices
1 parent 3e02efa commit 7cd054e

File tree

2 files changed

+40
-5
lines changed

2 files changed

+40
-5
lines changed

pytensor/tensor/subtensor.py

Lines changed: 19 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -59,6 +59,7 @@
5959
zscalar,
6060
)
6161
from pytensor.tensor.type_other import (
62+
MakeSlice,
6263
NoneConst,
6364
NoneTypeT,
6465
SliceConstant,
@@ -527,11 +528,20 @@ def basic_shape(shape, indices):
527528
if isinstance(idx, slice):
528529
res_shape += (slice_len(idx, n),)
529530
elif isinstance(getattr(idx, "type", None), SliceType):
530-
if idx.owner:
531-
idx_inputs = idx.owner.inputs
531+
if idx.owner is None:
532+
if not isinstance(idx, Constant):
533+
# This is an input slice, we can't reason symbolically on it.
534+
# We don't even know if we will get None entries or integers
535+
res_shape += (None,)
536+
continue
537+
else:
538+
sl: slice = idx.data
539+
slice_inputs = (sl.start, sl.stop, sl.step)
540+
elif isinstance(idx.owner.op, MakeSlice):
541+
slice_inputs = idx.owner.inputs
532542
else:
533-
idx_inputs = (None,)
534-
res_shape += (slice_len(slice(*idx_inputs), n),)
543+
raise ValueError(f"Unexpected Slice producing Op {idx.owner.op}")
544+
res_shape += (slice_len(slice(*slice_inputs), n),)
535545
elif idx is None:
536546
res_shape += (ps.ScalarConstant(ps.int64, 1),)
537547
elif isinstance(getattr(idx, "type", None), NoneTypeT):
@@ -2728,6 +2738,11 @@ def is_bool_index(idx):
27282738
res_shape = list(
27292739
indexed_result_shape(ishapes[0], index_shapes, indices_are_shapes=True)
27302740
)
2741+
for i, res_dim_length in enumerate(res_shape):
2742+
if res_dim_length is None:
2743+
# This can happen when we have a Slice provided by the user (not a constant nor the result of MakeSlice)
2744+
# We must compute the Op to find its shape
2745+
res_shape[i] = Shape_i(i)(node.out)
27312746

27322747
adv_indices = [idx for idx in indices if not is_basic_idx(idx)]
27332748
bool_indices = [idx for idx in adv_indices if is_bool_index(idx)]

tests/tensor/test_subtensor.py

Lines changed: 21 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
from pytensor.compile.mode import Mode
1616
from pytensor.configdefaults import config
1717
from pytensor.gradient import grad
18+
from pytensor.graph import Constant
1819
from pytensor.graph.op import get_test_value
1920
from pytensor.graph.rewriting.utils import is_same_graph
2021
from pytensor.printing import pprint
@@ -37,6 +38,7 @@
3738
advanced_inc_subtensor1,
3839
advanced_set_subtensor,
3940
advanced_set_subtensor1,
41+
advanced_subtensor,
4042
advanced_subtensor1,
4143
as_index_literal,
4244
basic_shape,
@@ -2145,7 +2147,17 @@ def test_adv_sub_slice(self):
21452147
slc = slicetype()
21462148
f = pytensor.function([slc], var[slc], mode=self.mode)
21472149
s = slice(1, 3)
2148-
f(s)
2150+
assert f(s).shape == (2, 3)
2151+
2152+
f_shape0 = pytensor.function([slc], var[slc].shape[0], mode=self.mode)
2153+
assert f_shape0(s) == 2
2154+
2155+
f_shape1 = pytensor.function([slc], var[slc].shape[1], mode=self.mode)
2156+
assert not any(
2157+
isinstance(node.op, AdvancedSubtensor)
2158+
for node in f_shape1.maker.fgraph.toposort()
2159+
)
2160+
assert f_shape1(s) == 3
21492161

21502162
def test_adv_grouped(self):
21512163
# Reported in https://github.com/Theano/Theano/issues/6152
@@ -2611,6 +2623,14 @@ def test_AdvancedSubtensor_bool_mixed(self):
26112623
AdvancedSubtensor,
26122624
)
26132625

2626+
def test_advanced_subtensor_constant_slice(self):
2627+
x = dmatrix("x")
2628+
constant_slice = pytensor.as_symbolic(slice(1, None, None))
2629+
assert isinstance(constant_slice, Constant)
2630+
adv_indices = ptb.constant(np.zeros((2, 3)), dtype="int")
2631+
y = advanced_subtensor(x, constant_slice, adv_indices)
2632+
assert tuple(y.shape.eval({x: np.zeros((10, 10))})) == (9, 2, 3)
2633+
26142634

26152635
@config.change_flags(compute_test_value="raise")
26162636
def test_basic_shape():

0 commit comments

Comments
 (0)