23
23
from aesara .raise_op import Assert
24
24
from aesara .scalar import int32 as int_t
25
25
from aesara .scalar import upcast
26
+ from aesara .scalar .basic import Composite
26
27
from aesara .tensor import basic as at
27
28
from aesara .tensor import get_vector_length
28
29
from aesara .tensor .exceptions import NotScalarConstantError
@@ -1552,16 +1553,32 @@ def broadcast_shape_iter(
1552
1553
# be broadcastable or equal to the one non-broadcastable
1553
1554
# constant `const_nt_shape_var`.
1554
1555
assert_dim = Assert ("Could not broadcast dimensions" )
1556
+
1557
+ scalar_nonconst_nb_shapes = [
1558
+ at .scalar_from_tensor (s )
1559
+ if isinstance (s .type , TensorType )
1560
+ else s
1561
+ for s in nonconst_nb_shapes
1562
+ ]
1563
+
1564
+ dummy_nonconst_nb_shapes = [
1565
+ aes .get_scalar_type (dtype = v .dtype )()
1566
+ for v in scalar_nonconst_nb_shapes
1567
+ ]
1555
1568
assert_cond = reduce (
1556
1569
aes .and_ ,
1557
1570
(
1558
1571
aes .or_ (
1559
1572
aes .eq (nbv , one_at ), aes .eq (nbv , const_nt_shape_var )
1560
1573
)
1561
- for nbv in nonconst_nb_shapes
1574
+ for nbv in dummy_nonconst_nb_shapes
1562
1575
),
1563
1576
)
1564
- bcast_dim = assert_dim (const_nt_shape_var , assert_cond )
1577
+ assert_cond_op = Composite (dummy_nonconst_nb_shapes , [assert_cond ])
1578
+
1579
+ bcast_dim = assert_dim (
1580
+ const_nt_shape_var , assert_cond_op (* scalar_nonconst_nb_shapes )
1581
+ )
1565
1582
else :
1566
1583
bcast_dim = const_nt_shape_var
1567
1584
else :
@@ -1579,21 +1596,37 @@ def broadcast_shape_iter(
1579
1596
result_dims .append (maybe_non_bcast_shapes [0 ])
1580
1597
continue
1581
1598
1599
+ scalar_maybe_non_bcast_shapes = [
1600
+ at .scalar_from_tensor (s ) if isinstance (s .type , TensorType ) else s
1601
+ for s in maybe_non_bcast_shapes
1602
+ ]
1603
+ dummy_maybe_non_bcast_shapes = [
1604
+ aes .get_scalar_type (dtype = v .dtype )()
1605
+ for v in scalar_maybe_non_bcast_shapes
1606
+ ]
1582
1607
non_bcast_vec = [
1583
1608
aes .switch (aes .eq (nbv , 1 ), - one_at , nbv )
1584
- for nbv in maybe_non_bcast_shapes
1609
+ for nbv in dummy_maybe_non_bcast_shapes
1585
1610
]
1586
1611
dim_max = aes .abs (reduce (aes .scalar_maximum , non_bcast_vec ))
1612
+ dim_max_op = Composite (dummy_maybe_non_bcast_shapes , [dim_max ])
1613
+
1614
+ dummy_dim_max = dim_max_op (* dummy_maybe_non_bcast_shapes )
1587
1615
1588
1616
assert_dim = Assert ("Could not broadcast dimensions" )
1589
1617
assert_cond = reduce (
1590
1618
aes .and_ ,
1591
1619
(
1592
- aes .or_ (aes .eq (nbv , - one_at ), aes .eq (nbv , dim_max ))
1620
+ aes .or_ (aes .eq (nbv , - one_at ), aes .eq (nbv , dummy_dim_max ))
1593
1621
for nbv in non_bcast_vec
1594
1622
),
1595
1623
)
1596
- bcast_dim = assert_dim (dim_max , assert_cond )
1624
+ assert_cond_op = Composite (dummy_maybe_non_bcast_shapes , [assert_cond ])
1625
+
1626
+ bcast_dim = assert_dim (
1627
+ dim_max_op (* scalar_maybe_non_bcast_shapes ),
1628
+ assert_cond_op (* scalar_maybe_non_bcast_shapes ),
1629
+ )
1597
1630
1598
1631
result_dims .append (bcast_dim )
1599
1632
0 commit comments