Compute static shape types in outputs of Join #164
Changes from all commits
```diff
@@ -2217,15 +2217,14 @@ def make_node(self, axis, *tensors):
         # except for the axis dimension.
         # Initialize bcastable all false, and then fill in some trues with
         # the loops.
-        ndim = tensors[0].type.ndim
-        out_shape = [None] * ndim
 
         if not isinstance(axis, int):
             try:
                 axis = int(get_scalar_constant_value(axis))
             except NotScalarConstantError:
                 pass
 
+        ndim = tensors[0].type.ndim
         if isinstance(axis, int):
             # Basically, broadcastable -> length 1, but the
             # converse does not hold. So we permit e.g. T/F/T
```
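For context on this reordering: the axis must collapse to a concrete Python `int` before any per-dimension inference is possible. A minimal sketch of that collapse, assuming the import paths `pytensor.tensor.basic.get_scalar_constant_value` and `pytensor.tensor.exceptions.NotScalarConstantError` (the functions are used but not imported in this hunk):

```python
import pytensor.tensor as at
from pytensor.tensor.basic import get_scalar_constant_value
from pytensor.tensor.exceptions import NotScalarConstantError

axis = at.constant(1)  # symbolic, but wraps a known scalar
try:
    # Collapses the constant tensor to a plain Python int
    print(int(get_scalar_constant_value(axis)))  # -> 1
except NotScalarConstantError:
    # A genuinely symbolic axis (e.g. at.iscalar()) lands here,
    # and the static output shape stays fully unknown.
    pass
```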
```diff
@@ -2241,30 +2240,55 @@ def make_node(self, axis, *tensors):
                 )
             if axis < 0:
                 axis += ndim
-
-            for x in tensors:
-                for current_axis, s in enumerate(x.type.shape):
-                    # Constant negative axis can no longer be negative at
-                    # this point. It is safe to compare this way.
-                    if current_axis == axis:
-                        continue
-                    if s == 1:
-                        out_shape[current_axis] = 1
-            try:
-                out_shape[axis] = None
-            except IndexError:
+            if axis > ndim - 1:
                 raise ValueError(
                     f"Axis value {axis} is out of range for the given input dimensions"
                 )
+            # NOTE: Constant negative axis can no longer be negative at this point.
+
+            in_shapes = [x.type.shape for x in tensors]
+            in_ndims = [len(s) for s in in_shapes]
+            if set(in_ndims) != {ndim}:
+                raise TypeError(
+                    "Only tensors with the same number of dimensions can be joined."
+                    f" Input ndims were: {in_ndims}."
+                )
+
+            # Determine output shapes from a matrix of input shapes
+            in_shapes = np.array(in_shapes)
+            out_shape = [None] * ndim
+            for d in range(ndim):
+                ins = in_shapes[:, d]
+                if d == axis:
+                    # Any unknown size along the axis means we can't sum
+                    if None in ins:
+                        out_shape[d] = None
+                    else:
+                        out_shape[d] = sum(ins)
+                else:
+                    inset = set(in_shapes[:, d])
+                    # Other dims must match exactly, or a mix of None and a
+                    # single known size resolves to that size; otherwise the
+                    # input shapes are incompatible.
+                    if len(inset) == 1:
```
Review comment (on `if len(inset) == 1:`): Does join allow broadcasting?

Reply: It doesn't look like it does - if so it wasn't tested, and NumPy does not allow broadcasting inputs to `np.concatenate`.
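To illustrate the reply's point, a quick NumPy check (not part of the diff): `np.concatenate` rejects inputs whose non-axis dimensions differ, even when one of them is 1 and could in principle broadcast:

```python
import numpy as np

a = np.ones((2, 3))
b = np.ones((1, 3))  # first dim is 1, but joining on axis=1 still fails
try:
    np.concatenate([a, b], axis=1)
except ValueError as e:
    # Message along the lines of: "all the input array dimensions except
    # for the concatenation axis must match exactly, but along dimension 0 ..."
    print(e)
```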
```diff
+                        (out_shape[d],) = inset
+                    elif len(inset - {None}) == 1:
+                        (out_shape[d],) = inset - {None}
+                    else:
+                        raise ValueError(
+                            f"all input array dimensions other than the specified `axis` ({axis})"
+                            " must match exactly, or be unknown (None),"
+                            f" but along dimension {d}, the inputs shapes are incompatible: {ins}"
+                        )
         else:
             # When the axis may vary, no dimension can be guaranteed to be
             # broadcastable.
             out_shape = [None] * tensors[0].type.ndim
 
-        if not builtins.all(x.ndim == len(out_shape) for x in tensors):
-            raise TypeError(
-                "Only tensors with the same number of dimensions can be joined"
-            )
+            if not builtins.all(x.ndim == len(out_shape) for x in tensors):
+                raise TypeError(
+                    "Only tensors with the same number of dimensions can be joined"
+                )
```
Review comment (on lines +2288 to +2291): Not relevant for the …
```diff
 
         inputs = [as_tensor_variable(axis)] + list(tensors)
```
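The core of the change is the per-dimension combination rule: sizes are summed along the join axis (unknown if any input is unknown there), while every other dimension must agree, with `None` deferring to a known size. A minimal pure-Python sketch of that rule, independent of PyTensor internals (the function name here is illustrative, not from the diff):

```python
def join_static_shape(axis, *shapes):
    """Combine static shapes (tuples of int or None) the way Join now does."""
    (ndim,) = {len(s) for s in shapes}  # all inputs must share an ndim
    out = []
    for d in range(ndim):
        sizes = [s[d] for s in shapes]
        if d == axis:
            # Sum along the join axis; any unknown makes the total unknown
            out.append(None if None in sizes else sum(sizes))
        else:
            known = set(sizes) - {None}
            if len(known) > 1:
                raise ValueError(f"incompatible sizes along dimension {d}: {sizes}")
            # A single known size wins over None; all-None stays None
            out.append(known.pop() if known else None)
    return tuple(out)

print(join_static_shape(1, (2, 3), (2, 5)))        # (2, 8)
print(join_static_shape(1, (2, 3), (None, 5)))     # (2, 8)
print(join_static_shape(0, (2, 3), (None, None)))  # (None, 3)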
```diff
@@ -1667,18 +1667,15 @@ def test_broadcastable_flag_assignment_mixed_otheraxes(self):
         a = self.shared(a_val, shape=(None, None, 1))
         b = self.shared(b_val, shape=(1, None, 1))
         c = self.join_op(1, a, b)
-        assert c.type.shape[0] == 1 and c.type.shape[2] == 1
-        assert c.type.shape[1] != 1
+        assert c.type.shape == (1, None, 1)
 
         # Optimization can replace the int by a PyTensor constant
         c = self.join_op(constant(1), a, b)
-        assert c.type.shape[0] == 1 and c.type.shape[2] == 1
-        assert c.type.shape[1] != 1
+        assert c.type.shape == (1, None, 1)
 
         # In case future optimizations insert other useless stuff
         c = self.join_op(cast(constant(1), dtype="int32"), a, b)
-        assert c.type.shape[0] == 1 and c.type.shape[2] == 1
-        assert c.type.shape[1] != 1
+        assert c.type.shape == (1, None, 1)
 
         f = function([], c, mode=self.mode)
         topo = f.maker.fgraph.toposort()
```
```diff
@@ -1783,15 +1780,21 @@ def test_broadcastable_flags_many_dims_and_inputs(self):
         c = TensorType(dtype=self.floatX, shape=(1, None, None, None, None, None))()
         d = TensorType(dtype=self.floatX, shape=(1, None, 1, 1, None, 1))()
         e = TensorType(dtype=self.floatX, shape=(1, None, 1, None, None, 1))()
 
         f = self.join_op(0, a, b, c, d, e)
         fb = tuple(s == 1 for s in f.type.shape)
-        assert not fb[0] and fb[1] and fb[2] and fb[3] and not fb[4] and fb[5]
+        assert f.type.shape == (5, 1, 1, 1, None, 1)
+        assert fb == (False, True, True, True, False, True)
```
Review comment: I kept the broadcastability asserts even though this information is redundant under the current interpretation of static shapes (a dimension is broadcastable iff its static shape is 1).
```diff
 
         g = self.join_op(1, a, b, c, d, e)
         gb = tuple(s == 1 for s in g.type.shape)
-        assert gb[0] and not gb[1] and gb[2] and gb[3] and not gb[4] and gb[5]
+        assert g.type.shape == (1, None, 1, 1, None, 1)
+        assert gb == (True, False, True, True, False, True)
 
         h = self.join_op(4, a, b, c, d, e)
         hb = tuple(s == 1 for s in h.type.shape)
-        assert hb[0] and hb[1] and hb[2] and hb[3] and not hb[4] and hb[5]
+        assert h.type.shape == (1, 1, 1, 1, None, 1)
+        assert hb == (True, True, True, True, False, True)
 
         f = function([a, b, c, d, e], f, mode=self.mode)
         topo = f.maker.fgraph.toposort()
```
```diff
@@ -1903,9 +1906,24 @@ def test_mixed_ndim_error(self):
         rng = np.random.default_rng(seed=utt.fetch_seed())
         v = self.shared(rng.random(4).astype(self.floatX))
         m = self.shared(rng.random((4, 4)).astype(self.floatX))
-        with pytest.raises(TypeError):
+        with pytest.raises(TypeError, match="same number of dimensions"):
             self.join_op(0, v, m)
 
+    def test_static_shape_inference(self):
+        a = at.tensor(dtype="int8", shape=(2, 3))
+        b = at.tensor(dtype="int8", shape=(2, 5))
+        assert at.join(1, a, b).type.shape == (2, 8)
+        assert at.join(-1, a, b).type.shape == (2, 8)
+
+        # Check early informative errors from static shape info
+        with pytest.raises(ValueError, match="must match exactly"):
+            at.join(0, at.ones((2, 3)), at.ones((2, 5)))
+
+        # Check partial inference
+        d = at.tensor(dtype="int8", shape=(2, None))
+        assert at.join(1, a, b, d).type.shape == (2, None)
+        return
+
     def test_split_0elem(self):
         rng = np.random.default_rng(seed=utt.fetch_seed())
         m = self.shared(rng.random((4, 6)).astype(self.floatX))
```
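Putting the new behaviour together from a user's perspective, a hypothetical interactive check (the shapes are chosen to show both full and partial inference; they are not taken from the test suite):

```python
import pytensor.tensor as at

x = at.tensor(dtype="float64", shape=(2, 3))
y = at.tensor(dtype="float64", shape=(2, None))

# Non-axis dims agree (None defers to 3); axis-0 sizes sum: 2 + 2 = 4
print(at.join(0, x, y).type.shape)  # (4, 3)

# Axis-1 size of y is unknown, so the joined axis stays unknown
print(at.join(1, x, y).type.shape)  # (2, None)
```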
Review comment: I would also add this to https://github.com/pymc-devs/pymc-sphinx-theme, otherwise we'll soon have the same issue in pymc, pymc-examples... (if we aren't already). We use the theme straight from GitHub, so the effect will be immediate.

Reply: Done: pymc-devs/pymc-sphinx-theme@973f6dd