-
Notifications
You must be signed in to change notification settings - Fork 135
Fix bug in AdvancedSubtensor infer_shape #101
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Fix bug in AdvancedSubtensor infer_shape #101
Conversation
The underlying utility `indexed_result_shape` was off by 1 in terms of when do the advanced index operations have to be brought to the front of the array.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Looks good, apart from the possible parenthesis nitpick :-)
pytensor/tensor/subtensor.py
Outdated
@@ -489,8 +489,10 @@ def indexed_result_shape(array_shape, indices, indices_are_shapes=False): | |||
remaining_dims = range(pytensor.tensor.basic.get_vector_length(array_shape)) | |||
idx_groups = group_indices(indices) | |||
|
|||
if len(idx_groups) > 2 or len(idx_groups) > 1 and not idx_groups[0][0]: | |||
# Bring adv. index groups to the front and merge each group | |||
if len(idx_groups) > 3 or len(idx_groups) == 3 and not idx_groups[0][0]: |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I had to look up operator precedence for and
and or
here. :-)
Turns out and has higher precedence, which I think is correct here.
Is that just me, or should we maybe add parenthesis here to make it obvious?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Always in favor of parenthesis. Brings us closer to lisp!
Codecov ReportAll modified and coverable lines are covered by tests ✅
Additional details and impacted files@@ Coverage Diff @@
## main #101 +/- ##
==========================================
- Coverage 80.42% 74.26% -6.17%
==========================================
Files 170 175 +5
Lines 45376 48929 +3553
Branches 11082 10395 -687
==========================================
- Hits 36495 36335 -160
- Misses 6654 10291 +3637
- Partials 2227 2303 +76
|
The underlying utility
indexed_result_shape
was off by 1 in terms of when do the advanced index operations have to be brought to the front of the array.Closes #98