
Compute static shape types in outputs of Join #164


Merged
merged 4 commits on Jan 4, 2023
3 changes: 2 additions & 1 deletion doc/conf.py
@@ -17,6 +17,7 @@
# sys.path.append(os.path.abspath('some/directory'))

import os
import sys
import pytensor

# General configuration
@@ -60,7 +61,7 @@
if os.environ.get("READTHEDOCS", False):
rtd_version = os.environ.get("READTHEDOCS_VERSION", "")
if rtd_version.lower() == "stable":
version = pymc.__version__.split("+")[0]
version = pytensor.__version__.split("+")[0]
elif rtd_version.lower() == "latest":
version = "dev"
else:
4 changes: 2 additions & 2 deletions doc/environment.yml
@@ -3,13 +3,13 @@ channels:
- conda-forge
- nodefaults
dependencies:
- python=3.7
- python=3.9
- gcc_linux-64
- gxx_linux-64
- numpy
- scipy
- six
- sphinx>=5.1.0
- sphinx>=5.1.0,<6
Member
I would also add this to https://github.com/pymc-devs/pymc-sphinx-theme, otherwise we'll soon have the same issue in pymc, pymc-examples... (if we aren't already). We use the theme straight from GitHub, so the effect will be immediate.

- mock
- pillow
- pip
2 changes: 1 addition & 1 deletion environment.yml
@@ -32,7 +32,7 @@ dependencies:
- pytest-xdist
- pytest-benchmark
# For building docs
- sphinx>=5.1.0
- sphinx>=5.1.0,<6
- sphinx_rtd_theme
- pygments
- pydot
3 changes: 1 addition & 2 deletions pyproject.toml
@@ -89,8 +89,7 @@ tests = [
"pytest-benchmark",
]
rtd = [
"sphinx>=1.3.0",
"sphinx_rtd_theme",
"sphinx>=5.1.0,<6",
"pygments",
"pydot",
"pydot2",
60 changes: 42 additions & 18 deletions pytensor/tensor/basic.py
@@ -2217,15 +2217,14 @@ def make_node(self, axis, *tensors):
# except for the axis dimension.
# Initialize bcastable all false, and then fill in some trues with
# the loops.
ndim = tensors[0].type.ndim
out_shape = [None] * ndim

if not isinstance(axis, int):
try:
axis = int(get_scalar_constant_value(axis))
except NotScalarConstantError:
pass

ndim = tensors[0].type.ndim
if isinstance(axis, int):
# Basically, broadcastable -> length 1, but the
# converse does not hold. So we permit e.g. T/F/T
@@ -2241,30 +2240,55 @@
)
if axis < 0:
axis += ndim

for x in tensors:
for current_axis, s in enumerate(x.type.shape):
# Constant negative axis can no longer be negative at
# this point. It safe to compare this way.
if current_axis == axis:
continue
if s == 1:
out_shape[current_axis] = 1
try:
out_shape[axis] = None
except IndexError:
if axis > ndim - 1:
raise ValueError(
f"Axis value {axis} is out of range for the given input dimensions"
)
# NOTE: Constant negative axis can no longer be negative at this point.

in_shapes = [x.type.shape for x in tensors]
in_ndims = [len(s) for s in in_shapes]
if set(in_ndims) != {ndim}:
raise TypeError(
"Only tensors with the same number of dimensions can be joined."
f" Input ndims were: {in_ndims}."
)

# Determine output shapes from a matrix of input shapes
in_shapes = np.array(in_shapes)
out_shape = [None] * ndim
for d in range(ndim):
ins = in_shapes[:, d]
if d == axis:
# Any unknown size along the axis means we can't sum
if None in ins:
out_shape[d] = None
else:
out_shape[d] = sum(ins)
else:
inset = set(in_shapes[:, d])
# Other dims must match exactly,
# or if a mix of None and ? the output will be ?
# otherwise the input shapes are incompatible.
if len(inset) == 1:
Member
Does join allow broadcasting?

Member Author
It doesn't look like it does - if so, it wasn't tested, and NumPy does not allow broadcasting inputs to np.concatenate either.
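For reference, a quick NumPy check of that claim (a minimal sketch, not part of this diff):

import numpy as np

a = np.ones((2, 3))
b = np.ones((1, 3))  # broadcastable against `a`, but np.concatenate does not broadcast
np.concatenate([a, b], axis=1)  # raises ValueError: non-axis dimensions must match exactly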

(out_shape[d],) = inset
elif len(inset - {None}) == 1:
(out_shape[d],) = inset - {None}
else:
raise ValueError(
f"all input array dimensions other than the specified `axis` ({axis})"
" must match exactly, or be unknown (None),"
f" but along dimension {d}, the inputs shapes are incompatible: {ins}"
)
else:
# When the axis may vary, no dimension can be guaranteed to be
# broadcastable.
out_shape = [None] * tensors[0].type.ndim

if not builtins.all(x.ndim == len(out_shape) for x in tensors):
raise TypeError(
"Only tensors with the same number of dimensions can be joined"
)
if not builtins.all(x.ndim == len(out_shape) for x in tensors):
raise TypeError(
"Only tensors with the same number of dimensions can be joined"
)
Comment on lines +2288 to +2291

Member Author
Not relevant for the len(tensors) == 1 branch


inputs = [as_tensor_variable(axis)] + list(tensors)

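For readers skimming the diff: the shape-combination rule the new make_node applies can be restated as a standalone plain-Python sketch (the helper name join_static_shape is hypothetical and not part of the PR; axis is assumed already normalized to be non-negative):

def join_static_shape(shapes, axis):
    """Combine static shapes (tuples of int or None) of inputs joined along `axis`."""
    ndim = len(shapes[0])
    if any(len(s) != ndim for s in shapes):
        raise TypeError("Only tensors with the same number of dimensions can be joined.")
    out = []
    for d in range(ndim):
        sizes = [s[d] for s in shapes]
        if d == axis:
            # Any unknown size along the join axis makes the joined size unknown.
            out.append(None if None in sizes else sum(sizes))
        else:
            known = set(sizes) - {None}
            if len(known) > 1:
                raise ValueError("all input dimensions other than `axis` must match exactly")
            # A single known size wins over None; all-None stays unknown.
            out.append(known.pop() if known else None)
    return tuple(out)

assert join_static_shape([(2, 3), (2, 5)], axis=1) == (2, 8)
assert join_static_shape([(2, 3), (2, None)], axis=1) == (2, None)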
1 change: 0 additions & 1 deletion pytensor/tensor/rewriting/subtensor.py
@@ -1194,7 +1194,6 @@ def movable(i):
tip = new_add
for mi in movable_inputs:
assert o_type.is_super(tip.type)
assert mi.owner.inputs[0].type.is_super(tip.type)
tip = mi.owner.op(tip, *mi.owner.inputs[1:])
# Copy over stacktrace from outputs of the original
# "movable" operation to the new operation.
38 changes: 28 additions & 10 deletions tests/tensor/test_basic.py
@@ -1667,18 +1667,15 @@ def test_broadcastable_flag_assignment_mixed_otheraxes(self):
a = self.shared(a_val, shape=(None, None, 1))
b = self.shared(b_val, shape=(1, None, 1))
c = self.join_op(1, a, b)
assert c.type.shape[0] == 1 and c.type.shape[2] == 1
assert c.type.shape[1] != 1
assert c.type.shape == (1, None, 1)

# Opt can replace the int by a PyTensor constant
c = self.join_op(constant(1), a, b)
assert c.type.shape[0] == 1 and c.type.shape[2] == 1
assert c.type.shape[1] != 1
assert c.type.shape == (1, None, 1)

# In case future opts insert other useless stuff
c = self.join_op(cast(constant(1), dtype="int32"), a, b)
assert c.type.shape[0] == 1 and c.type.shape[2] == 1
assert c.type.shape[1] != 1
assert c.type.shape == (1, None, 1)

f = function([], c, mode=self.mode)
topo = f.maker.fgraph.toposort()
@@ -1783,15 +1780,21 @@ def test_broadcastable_flags_many_dims_and_inputs(self):
c = TensorType(dtype=self.floatX, shape=(1, None, None, None, None, None))()
d = TensorType(dtype=self.floatX, shape=(1, None, 1, 1, None, 1))()
e = TensorType(dtype=self.floatX, shape=(1, None, 1, None, None, 1))()

f = self.join_op(0, a, b, c, d, e)
fb = tuple(s == 1 for s in f.type.shape)
assert not fb[0] and fb[1] and fb[2] and fb[3] and not fb[4] and fb[5]
assert f.type.shape == (5, 1, 1, 1, None, 1)
assert fb == (False, True, True, True, False, True)
Member Author
I kept the broadcastability asserts even though this information is redundant given the current interpretation of shape == 1. If in the future we re-introduce a broadcastable attribute, this should help update the test.
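A minimal illustration of that redundancy, assuming pytensor still derives TensorType.broadcastable from the static shape (the at.tensor call mirrors the one used in test_static_shape_inference below):

import pytensor.tensor as at

x = at.tensor(dtype="float64", shape=(1, None))
print(x.type.shape)          # (1, None)
print(x.type.broadcastable)  # (True, False), i.e. tuple(s == 1 for s in x.type.shape)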


g = self.join_op(1, a, b, c, d, e)
gb = tuple(s == 1 for s in g.type.shape)
assert gb[0] and not gb[1] and gb[2] and gb[3] and not gb[4] and gb[5]
assert g.type.shape == (1, None, 1, 1, None, 1)
assert gb == (True, False, True, True, False, True)

h = self.join_op(4, a, b, c, d, e)
hb = tuple(s == 1 for s in h.type.shape)
assert hb[0] and hb[1] and hb[2] and hb[3] and not hb[4] and hb[5]
assert h.type.shape == (1, 1, 1, 1, None, 1)
assert hb == (True, True, True, True, False, True)

f = function([a, b, c, d, e], f, mode=self.mode)
topo = f.maker.fgraph.toposort()
@@ -1903,9 +1906,24 @@ def test_mixed_ndim_error(self):
rng = np.random.default_rng(seed=utt.fetch_seed())
v = self.shared(rng.random(4).astype(self.floatX))
m = self.shared(rng.random((4, 4)).astype(self.floatX))
with pytest.raises(TypeError):
with pytest.raises(TypeError, match="same number of dimensions"):
self.join_op(0, v, m)

def test_static_shape_inference(self):
a = at.tensor(dtype="int8", shape=(2, 3))
b = at.tensor(dtype="int8", shape=(2, 5))
assert at.join(1, a, b).type.shape == (2, 8)
assert at.join(-1, a, b).type.shape == (2, 8)

# Check early informative errors from static shape info
with pytest.raises(ValueError, match="must match exactly"):
at.join(0, at.ones((2, 3)), at.ones((2, 5)))

# Check partial inference
d = at.tensor(dtype="int8", shape=(2, None))
assert at.join(1, a, b, d).type.shape == (2, None)

def test_split_0elem(self):
rng = np.random.default_rng(seed=utt.fetch_seed())
m = self.shared(rng.random((4, 6)).astype(self.floatX))