Skip to content

Commit e4b15e4

Browse files
Use Composite graphs in aesara.tensor.extra_ops.broadcast_shape_iter
1 parent 9d5ab76 commit e4b15e4

File tree

1 file changed

+38
-5
lines changed

1 file changed

+38
-5
lines changed

aesara/tensor/extra_ops.py

Lines changed: 38 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323
from aesara.raise_op import Assert
2424
from aesara.scalar import int32 as int_t
2525
from aesara.scalar import upcast
26+
from aesara.scalar.basic import Composite
2627
from aesara.tensor import basic as at
2728
from aesara.tensor import get_vector_length
2829
from aesara.tensor.exceptions import NotScalarConstantError
@@ -1552,16 +1553,32 @@ def broadcast_shape_iter(
15521553
# be broadcastable or equal to the one non-broadcastable
15531554
# constant `const_nt_shape_var`.
15541555
assert_dim = Assert("Could not broadcast dimensions")
1556+
1557+
scalar_nonconst_nb_shapes = [
1558+
at.scalar_from_tensor(s)
1559+
if isinstance(s.type, TensorType)
1560+
else s
1561+
for s in nonconst_nb_shapes
1562+
]
1563+
1564+
dummy_nonconst_nb_shapes = [
1565+
aes.get_scalar_type(dtype=v.dtype)()
1566+
for v in scalar_nonconst_nb_shapes
1567+
]
15551568
assert_cond = reduce(
15561569
aes.and_,
15571570
(
15581571
aes.or_(
15591572
aes.eq(nbv, one_at), aes.eq(nbv, const_nt_shape_var)
15601573
)
1561-
for nbv in nonconst_nb_shapes
1574+
for nbv in dummy_nonconst_nb_shapes
15621575
),
15631576
)
1564-
bcast_dim = assert_dim(const_nt_shape_var, assert_cond)
1577+
assert_cond_op = Composite(dummy_nonconst_nb_shapes, [assert_cond])
1578+
1579+
bcast_dim = assert_dim(
1580+
const_nt_shape_var, assert_cond_op(*scalar_nonconst_nb_shapes)
1581+
)
15651582
else:
15661583
bcast_dim = const_nt_shape_var
15671584
else:
@@ -1579,21 +1596,37 @@ def broadcast_shape_iter(
15791596
result_dims.append(maybe_non_bcast_shapes[0])
15801597
continue
15811598

1599+
scalar_maybe_non_bcast_shapes = [
1600+
at.scalar_from_tensor(s) if isinstance(s.type, TensorType) else s
1601+
for s in maybe_non_bcast_shapes
1602+
]
1603+
dummy_maybe_non_bcast_shapes = [
1604+
aes.get_scalar_type(dtype=v.dtype)()
1605+
for v in scalar_maybe_non_bcast_shapes
1606+
]
15821607
non_bcast_vec = [
15831608
aes.switch(aes.eq(nbv, 1), -one_at, nbv)
1584-
for nbv in maybe_non_bcast_shapes
1609+
for nbv in dummy_maybe_non_bcast_shapes
15851610
]
15861611
dim_max = aes.abs(reduce(aes.scalar_maximum, non_bcast_vec))
1612+
dim_max_op = Composite(dummy_maybe_non_bcast_shapes, [dim_max])
1613+
1614+
dummy_dim_max = dim_max_op(*dummy_maybe_non_bcast_shapes)
15871615

15881616
assert_dim = Assert("Could not broadcast dimensions")
15891617
assert_cond = reduce(
15901618
aes.and_,
15911619
(
1592-
aes.or_(aes.eq(nbv, -one_at), aes.eq(nbv, dim_max))
1620+
aes.or_(aes.eq(nbv, -one_at), aes.eq(nbv, dummy_dim_max))
15931621
for nbv in non_bcast_vec
15941622
),
15951623
)
1596-
bcast_dim = assert_dim(dim_max, assert_cond)
1624+
assert_cond_op = Composite(dummy_maybe_non_bcast_shapes, [assert_cond])
1625+
1626+
bcast_dim = assert_dim(
1627+
dim_max_op(*scalar_maybe_non_bcast_shapes),
1628+
assert_cond_op(*scalar_maybe_non_bcast_shapes),
1629+
)
15971630

15981631
result_dims.append(bcast_dim)
15991632

0 commit comments

Comments
 (0)