Skip to content

Commit d159f06

Browse files
authored
Fix bug in AdvancedSubtensor infer_shape (#101)
* Fix bug in AdvancedSubtensor infer_shape The underlying utility `indexed_result_shape` was off by 1 in terms of when do the advanced index operations have to be brought to the front of the array.
1 parent befc177 commit d159f06

File tree

2 files changed

+8
-2
lines changed

2 files changed

+8
-2
lines changed

pytensor/tensor/subtensor.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -489,8 +489,10 @@ def indexed_result_shape(array_shape, indices, indices_are_shapes=False):
489489
remaining_dims = range(pytensor.tensor.basic.get_vector_length(array_shape))
490490
idx_groups = group_indices(indices)
491491

492-
if len(idx_groups) > 2 or len(idx_groups) > 1 and not idx_groups[0][0]:
493-
# Bring adv. index groups to the front and merge each group
492+
if len(idx_groups) > 3 or (len(idx_groups) == 3 and not idx_groups[0][0]):
493+
# This means that there are at least two groups of advanced indexing separated by basic indexing
494+
# In this case NumPy places the advanced index groups in the front of the array
495+
# https://numpy.org/devdocs/user/basics.indexing.html#combining-advanced-and-basic-indexing
494496
idx_groups = sorted(idx_groups, key=lambda x: x[0])
495497
idx_groups = groupby(
496498
chain.from_iterable(d_idx for _, d_idx in idx_groups),

tests/tensor/test_subtensor.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2517,6 +2517,10 @@ def bcast_shape_tuple(x):
25172517
),
25182518
(np.arange(np.prod((5, 4))).reshape((5, 4)), ([1, 3, 2], slice(1, 3))),
25192519
(np.arange(np.prod((5, 4))).reshape((5, 4)), (slice(1, 3), [1, 3, 2])),
2520+
(
2521+
np.arange(np.prod((5, 6, 7))).reshape((5, 6, 7)),
2522+
(slice(None, None), [1, 2, 3], slice(None, None)),
2523+
),
25202524
],
25212525
)
25222526
@config.change_flags(compute_test_value="raise")

0 commit comments

Comments
 (0)