Skip to content

Compute static shape types in outputs of Join #164

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 4 commits into from
Jan 4, 2023

Conversation

michaelosthege
Copy link
Member

@michaelosthege michaelosthege commented Jan 3, 2023

Motivation for these changes

Making static shapes propagate through Joins such as pt.concatenate.

Closes #163

Implementation details

The previous double-for implementation was rather tricky to refactor, which is why I went with a simple matrix-based approach.

The first commit makes the corresponding tests and their output easier to read.

Checklist

New features

  • Join now propagate static shape information more reliably.

Maintenance

  • Docs build was updated to Python 3.9

@michaelosthege michaelosthege added enhancement New feature or request shape inference labels Jan 3, 2023
Comment on lines +2288 to +2291
if not builtins.all(x.ndim == len(out_shape) for x in tensors):
raise TypeError(
"Only tensors with the same number of dimensions can be joined"
)
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Not relevant for the len(tensors) == 1 branch

f = self.join_op(0, a, b, c, d, e)
fb = tuple(s == 1 for s in f.type.shape)
assert not fb[0] and fb[1] and fb[2] and fb[3] and not fb[4] and fb[5]
assert f.type.shape == (5, 1, 1, 1, None, 1)
assert fb == (False, True, True, True, False, True)
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I kept the broadcastability assers even though this information is redundant from the current interpretation of shape == 1.
If in the future we re-introduce a broadcastable attribute, this should help updating the test.

@michaelosthege michaelosthege force-pushed the join-shapes branch 4 times, most recently from 9278ef7 to d8896ea Compare January 3, 2023 16:56
- gcc_linux-64
- gxx_linux-64
- numpy
- scipy
- six
- sphinx>=5.1.0
- sphinx>=5.1.0,<6
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I would also add this to https://github.com/pymc-devs/pymc-sphinx-theme, otherwise we'll soon have the same issue in pymc, pymc-examples... (if we aren't already). We use the theme straight from github so the effect will be immediate

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@michaelosthege michaelosthege marked this pull request as ready for review January 3, 2023 17:11
# Other dims must match exactly,
# or if a mix of None and ? the output will be ?
# otherwise the input shapes are incompatible.
if len(inset) == 1:
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Does join allow broadcasting?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It doesn't look like it does - if so it wasn't tested, and NumPy does not allow broadcasting inputs to np.concatenate either

@ricardoV94 ricardoV94 changed the title Join shapes Provide static shape types in Join outputs Jan 3, 2023
@ricardoV94 ricardoV94 changed the title Provide static shape types in Join outputs Provide static shape types in outputs of Join Jan 3, 2023
@ricardoV94 ricardoV94 changed the title Provide static shape types in outputs of Join Compute static shape types in outputs of Join Jan 3, 2023
@twiecki twiecki merged commit 25236cf into pymc-devs:main Jan 4, 2023
@michaelosthege michaelosthege deleted the join-shapes branch January 4, 2023 09:22
@covertg covertg mentioned this pull request Jan 31, 2023
4 tasks
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
enhancement New feature or request shape inference
Projects
None yet
Development

Successfully merging this pull request may close these issues.

ENH: Join looses static shape information
4 participants