Skip to content

Commit 4d261b3

Browse files
Change infer_broadcastable to infer_static_shape
1 parent e4b15e4 commit 4d261b3

File tree

5 files changed: +51 additions, -28 deletions

aesara/tensor/basic.py

Lines changed: 31 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010
from collections.abc import Sequence
1111
from functools import partial
1212
from numbers import Number
13-
from typing import Optional
13+
from typing import TYPE_CHECKING, Optional
1414
from typing import Sequence as TypeSequence
1515
from typing import Tuple, Union
1616
from typing import cast as type_cast
@@ -68,6 +68,10 @@
6868
from aesara.tensor.var import TensorConstant, TensorVariable, get_unique_value
6969

7070

71+
if TYPE_CHECKING:
72+
from aesara.tensor import TensorLike
73+
74+
7175
def __oplist_tag(thing, tag):
7276
tags = getattr(thing, "__oplist_tags", [])
7377
tags.append(tag)
@@ -1334,11 +1338,25 @@ def identity_like(x, dtype: Optional[Union[str, np.generic, np.dtype]] = None):
13341338
return eye(_x.shape[0], _x.shape[1], k=0, dtype=dtype)
13351339

13361340

1337-
def infer_broadcastable(shape):
1338-
"""Infer the broadcastable dimensions for `shape`.
1341+
def infer_static_shape(
1342+
shape: Union[Variable, TypeSequence[Union[Variable, int]]]
1343+
) -> Tuple[TypeSequence["TensorLike"], TypeSequence[Optional[int]]]:
1344+
"""Infer the static shapes implied by the potentially symbolic elements in `shape`.
1345+
1346+
`shape` will be validated and constant folded. As a result, this function
1347+
can be expensive and shouldn't be used unless absolutely necessary.
1348+
1349+
It mostly exists as a hold-over from pre-static shape times, when it was
1350+
required in order to produce correct broadcastable arrays and prevent
1351+
some graphs from being unusable. Now, it is no longer strictly required,
1352+
so don't use it unless you want the same shape graphs to be rewritten
1353+
multiple times during graph construction.
1354+
1355+
Returns
1356+
-------
1357+
A validated sequence of symbolic shape values, and a sequence of
1358+
``None``/``int`` values that can be used as `TensorType.shape` values.
13391359
1340-
`shape` will be validated and constant folded in order to determine
1341-
which dimensions are broadcastable (i.e. equal to ``1``).
13421360
"""
13431361
from aesara.tensor.rewriting.basic import topo_constant_folding
13441362
from aesara.tensor.rewriting.shape import ShapeFeature
@@ -1362,9 +1380,10 @@ def check_type(s):
13621380
clone=True,
13631381
)
13641382
folded_shape = rewrite_graph(shape_fg, custom_rewrite=topo_constant_folding).outputs
1365-
1366-
bcast = tuple(getattr(s, "data", s) == 1 for s in folded_shape)
1367-
return sh, bcast
1383+
static_shape = tuple(
1384+
s.data.item() if isinstance(s, Constant) else None for s in folded_shape
1385+
)
1386+
return sh, static_shape
13681387

13691388

13701389
class Alloc(COp):
@@ -1394,15 +1413,15 @@ class Alloc(COp):
13941413

13951414
def make_node(self, value, *shape):
13961415
v = as_tensor_variable(value)
1397-
sh, bcast = infer_broadcastable(shape)
1416+
sh, static_shape = infer_static_shape(shape)
13981417
if v.ndim > len(sh):
13991418
raise TypeError(
14001419
"The Alloc value to use has more dimensions"
14011420
" than the specified dimensions",
14021421
v.ndim,
14031422
len(sh),
14041423
)
1405-
otype = TensorType(dtype=v.dtype, shape=bcast)
1424+
otype = TensorType(dtype=v.dtype, shape=static_shape)
14061425
return Apply(self, [v] + sh, [otype()])
14071426

14081427
def perform(self, node, inputs, out_):
@@ -3823,8 +3842,8 @@ def typecode(self):
38233842
return np.dtype(self.dtype).num
38243843

38253844
def make_node(self, *_shape):
3826-
_shape, bcast = infer_broadcastable(_shape)
3827-
otype = TensorType(dtype=self.dtype, shape=bcast)
3845+
_shape, static_shape = infer_static_shape(_shape)
3846+
otype = TensorType(dtype=self.dtype, shape=static_shape)
38283847
output = otype()
38293848

38303849
output.tag.values_eq_approx = values_eq_approx_always_true

aesara/tensor/extra_ops.py

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1646,9 +1646,9 @@ def __call__(self, a, shape, **kwargs):
16461646
def make_node(self, a, *shape):
16471647
a = at.as_tensor_variable(a)
16481648

1649-
shape, bcast = at.infer_broadcastable(shape)
1649+
shape, static_shape = at.infer_static_shape(shape)
16501650

1651-
out = TensorType(dtype=a.type.dtype, shape=bcast)()
1651+
out = TensorType(dtype=a.type.dtype, shape=static_shape)()
16521652

16531653
# Attempt to prevent in-place operations on this view-based output
16541654
out.tag.indestructible = True
@@ -1670,11 +1670,14 @@ def grad(self, inputs, outputs_gradients):
16701670
d_wrt_a = broadcast_to(dout, shape).sum(axis=new_dims)
16711671

16721672
# Determine the dimensions that were broadcast
1673-
_, shape_bcast = at.infer_broadcastable(shape)
1673+
_, static_shape = at.infer_static_shape(shape)
1674+
1675+
# TODO: This needs to be performed at run-time when static shape
1676+
# information isn't available.
16741677
bcast_sums = [
16751678
i
1676-
for i, (a_b, s_b) in enumerate(zip(a.broadcastable, shape_bcast[-a.ndim :]))
1677-
if a_b and not s_b
1679+
for i, (a_s, s_s) in enumerate(zip(a.type.shape, static_shape[-a.ndim :]))
1680+
if a_s == 1 and s_s != 1
16781681
]
16791682

16801683
if bcast_sums:

aesara/tensor/random/op.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414
constant,
1515
get_scalar_constant_value,
1616
get_vector_length,
17-
infer_broadcastable,
17+
infer_static_shape,
1818
)
1919
from aesara.tensor.random.type import RandomGeneratorType, RandomStateType, RandomType
2020
from aesara.tensor.random.utils import normalize_size_param, params_broadcast_shapes
@@ -322,7 +322,7 @@ def make_node(self, rng, size, dtype, *dist_params):
322322
)
323323

324324
shape = self._infer_shape(size, dist_params)
325-
_, bcast = infer_broadcastable(shape)
325+
_, static_shape = infer_static_shape(shape)
326326
dtype = self.dtype or dtype
327327

328328
if dtype == "floatX":
@@ -336,7 +336,7 @@ def make_node(self, rng, size, dtype, *dist_params):
336336
dtype_idx = constant(dtype, dtype="int64")
337337
dtype = all_dtypes[dtype_idx.data]
338338

339-
outtype = TensorType(dtype=dtype, shape=bcast)
339+
outtype = TensorType(dtype=dtype, shape=static_shape)
340340
out_var = outtype()
341341
inputs = (rng, size, dtype_idx) + dist_params
342342
outputs = (rng.type(), out_var)

tests/tensor/rewriting/test_basic.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -276,8 +276,9 @@ def test_inconsistent_constant(self):
276276

277277
assert a.owner and isinstance(a.owner.op, Alloc)
278278

279-
# `local_useless_alloc` should replace the `Alloc` with an `Assert`
280-
with pytest.raises(AssertionError):
279+
# `local_useless_alloc` should attempt to replace the `Alloc` with an
280+
# `Assert` and fail when the static shape information conflicts.
281+
with pytest.raises(TypeError):
281282
f = function([], a, mode=rewrite_mode)
282283

283284
x = at.as_tensor(self.rng.standard_normal((6, 7)))

tests/tensor/test_basic.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -55,7 +55,7 @@
5555
get_vector_length,
5656
horizontal_stack,
5757
identity_like,
58-
infer_broadcastable,
58+
infer_static_shape,
5959
inverse_permutation,
6060
join,
6161
make_vector,
@@ -796,20 +796,20 @@ def test_full(self):
796796

797797
def test_infer_broadcastable():
798798
with pytest.raises(TypeError, match="^Shapes must be scalar integers.*"):
799-
infer_broadcastable([constant(1.0)])
799+
infer_static_shape([constant(1.0)])
800800

801801
with config.change_flags(exception_verbosity="high"), pytest.raises(
802802
TypeError, match=r"A\. x"
803803
):
804-
infer_broadcastable([dscalar("x")])
804+
infer_static_shape([dscalar("x")])
805805

806806
with pytest.raises(ValueError, match=".*could not be cast to have 0 dimensions"):
807-
infer_broadcastable((as_tensor_variable([[1, 2]]),))
807+
infer_static_shape((as_tensor_variable([[1, 2]]),))
808808

809809
constant_size = constant([1])
810810
specify_size = specify_shape(constant_size, [1])
811-
sh, bcast = infer_broadcastable(specify_size)
812-
assert bcast == (True,)
811+
sh, static_shape = infer_static_shape(specify_size)
812+
assert static_shape == (1,)
813813

814814

815815
# This is slow for the ('int8', 3) version.

0 commit comments

Comments (0)