Compute static shape types in outputs of Join
#164
Conversation
if not builtins.all(x.ndim == len(out_shape) for x in tensors):
    raise TypeError(
        "Only tensors with the same number of dimensions can be joined"
    )
Not relevant for the `len(tensors) == 1` branch
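The point of the comment above can be sketched in isolation. Assuming `ndims` holds each input's number of dimensions (a hypothetical helper, not PyTensor's actual signature), the validation is only meaningful when there is more than one tensor:

```python
def check_joinable(ndims):
    """Hypothetical helper: validate that inputs can be joined.

    `ndims` lists each input tensor's number of dimensions
    (an assumption for illustration, not PyTensor's actual API).
    """
    if len(ndims) == 1:
        # Single input: there is nothing to compare, so the
        # same-ndim check is irrelevant for this branch.
        return
    if len(set(ndims)) != 1:
        raise TypeError(
            "Only tensors with the same number of dimensions can be joined"
        )
```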
f = self.join_op(0, a, b, c, d, e)
fb = tuple(s == 1 for s in f.type.shape)
assert not fb[0] and fb[1] and fb[2] and fb[3] and not fb[4] and fb[5]
assert f.type.shape == (5, 1, 1, 1, None, 1)
assert fb == (False, True, True, True, False, True)
I kept the broadcastability asserts even though this information is redundant given the current interpretation of `shape == 1`. If in the future we re-introduce a `broadcastable` attribute, this should help update the test.
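The redundancy mentioned here can be shown directly: under the current interpretation, a dimension is broadcastable exactly when its static size is 1, so the broadcastability pattern is fully derivable from the static shape (plain Python illustration, not the PyTensor API):

```python
# Static shape as asserted in the test above; None marks an unknown size.
static_shape = (5, 1, 1, 1, None, 1)

# Current interpretation: broadcastable <=> static size == 1.
# Unknown sizes (None) are treated as non-broadcastable.
broadcastable = tuple(s == 1 for s in static_shape)
assert broadcastable == (False, True, True, True, False, True)
```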
Force-pushed from 9278ef7 to d8896ea
- gcc_linux-64
- gxx_linux-64
- numpy
- scipy
- six
- sphinx>=5.1.0
- sphinx>=5.1.0,<6
I would also add this to https://github.com/pymc-devs/pymc-sphinx-theme, otherwise we'll soon have the same issue in pymc, pymc-examples... (if we aren't already). We use the theme straight from GitHub, so the effect will be immediate.
Force-pushed from 513c722 to 78ce076
# Other dims must match exactly,
# or if a mix of None and ? the output will be ?
# otherwise the input shapes are incompatible.
if len(inset) == 1:
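One conservative reading of the comment above, as a standalone sketch (a hypothetical helper, not PyTensor's actual code): along a non-join axis, all inputs agreeing on a size yields that size, a mix of one known size and unknowns yields unknown, and two different known sizes is an error.

```python
def merge_static_dim(sizes):
    """Sketch: merge the static sizes of one non-join axis across inputs.

    `sizes` holds each input's static size on that axis; None = unknown.
    This follows one conservative reading of the rule discussed above,
    not necessarily PyTensor's exact implementation.
    """
    inset = set(sizes)
    if len(inset) == 1:
        # All inputs agree (possibly all unknown): take that value.
        return inset.pop()
    if None in inset and len(inset - {None}) == 1:
        # Mix of a single known size and unknowns: output is unknown.
        return None
    # Two or more conflicting known sizes: the shapes are incompatible.
    raise ValueError(f"Incompatible static sizes along a non-join axis: {inset}")
```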
Does join allow broadcasting?
It doesn't look like it does. If so, it wasn't tested, and NumPy does not allow broadcasting inputs to `np.concatenate` either.
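NumPy's behavior is easy to confirm: `np.concatenate` rejects inputs whose number of dimensions differ, rather than broadcasting them.

```python
import numpy as np

a = np.zeros((2, 3))
b = np.zeros(3)  # 1-D: joining with the 2-D `a` would require broadcasting

try:
    np.concatenate([a, b], axis=0)
    broadcasted = True
except ValueError:
    # NumPy raises instead of broadcasting the 1-D input up to 2-D.
    broadcasted = False

assert not broadcasted
```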
Motivation for these changes
Making static shapes propagate through `Join`s such as `pt.concatenate`. Closes #163
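Along the join axis itself, the static output size is the sum of the inputs' sizes when all are known, and unknown otherwise. A minimal sketch of that rule (a hypothetical helper in plain Python, not PyTensor's actual code, assuming all non-join dimensions already agree):

```python
def concat_static_shape(shapes, axis):
    """Sketch: static output shape of concatenating inputs along `axis`.

    `shapes` are the inputs' static shapes, where None means unknown.
    Assumes the non-join dimensions already agree across inputs.
    """
    sizes = [s[axis] for s in shapes]
    out = list(shapes[0])
    # The joined size is only known when every input's size is known.
    out[axis] = sum(sizes) if None not in sizes else None
    return tuple(out)
```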
Implementation details
The previous double-`for` implementation was rather tricky to refactor, which is why I went with a simple matrix-based approach. The first commit makes the corresponding tests and their output easier to read.
Checklist
New features
`Join` now propagates static shape information more reliably.
Maintenance