1
1
from collections .abc import Collection
2
- from functools import reduce
3
2
from typing import Iterable , Set , Tuple , Union
4
3
5
4
import numpy as np
6
- import numpy .core .numeric
7
5
from numpy .core .multiarray import normalize_axis_index
8
6
9
7
import pytensor
14
12
disconnected_type ,
15
13
grad_undefined ,
16
14
)
17
- from pytensor .graph .basic import Apply , Constant , Variable , equal_computations
15
+ from pytensor .graph .basic import Apply , Constant , Variable
18
16
from pytensor .graph .op import Op
19
17
from pytensor .link .c .op import COp
20
18
from pytensor .link .c .params_type import ParamsType
23
21
from pytensor .raise_op import Assert
24
22
from pytensor .scalar import int32 as int_t
25
23
from pytensor .scalar import upcast
26
- from pytensor .scalar .basic import Composite
27
24
from pytensor .tensor import basic as at
28
25
from pytensor .tensor import get_vector_length
29
26
from pytensor .tensor .exceptions import NotScalarConstantError
30
27
from pytensor .tensor .math import abs as at_abs
31
- from pytensor .tensor .math import all as at_all
28
+ from pytensor .tensor .math import all as pt_all
29
+ from pytensor .tensor .math import eq as pt_eq
32
30
from pytensor .tensor .math import ge , lt , maximum , minimum , prod
33
31
from pytensor .tensor .math import sum as at_sum
34
32
from pytensor .tensor .subtensor import advanced_inc_subtensor1 , set_subtensor
@@ -536,7 +534,7 @@ def bincount(x, weights=None, minlength=None, assert_nonneg=False):
536
534
537
535
if assert_nonneg :
538
536
assert_op = Assert ("Input to bincount has negative values!" )
539
- x = assert_op (x , at_all (x >= 0 ))
537
+ x = assert_op (x , pt_all (x >= 0 ))
540
538
541
539
max_value = at .cast (x .max () + 1 , "int64" )
542
540
@@ -1510,8 +1508,8 @@ def broadcast_shape_iter(
1510
1508
result_dims = []
1511
1509
1512
1510
for dim_shapes in zip (* array_shapes ):
1513
- # Get the shapes in this dimension that are not definitively
1514
- # broadcastable (i.e. not symbolically known to be broadcastable)
1511
+ # Get the shapes in this dimension that are not broadcastable
1512
+ # (i.e. not symbolically known to be broadcastable)
1515
1513
maybe_non_bcast_shapes = [shape for shape in dim_shapes if shape != one_at ]
1516
1514
1517
1515
if len (maybe_non_bcast_shapes ) == 0 :
@@ -1532,97 +1530,36 @@ def broadcast_shape_iter(
1532
1530
nonconst_nb_shapes .add (shape )
1533
1531
1534
1532
if len (const_nb_shapes ) > 1 :
1535
- raise ValueError ("Could not broadcast dimensions" )
1536
- elif len (const_nb_shapes ) == 1 :
1537
- (const_nb_shape ,) = const_nb_shapes
1533
+ raise ValueError (f"Could not broadcast dimensions. Incompatible shapes were { array_shapes } ." )
1538
1534
1539
- assert const_nb_shape != 1
1540
-
1541
- const_nt_shape_var = pytensor .scalar .ScalarConstant (
1542
- pytensor .scalar .int64 , const_nb_shape
1543
- )
1544
-
1545
- if len (nonconst_nb_shapes ) > 0 :
1546
- # All the potential non-broadcast shapes need to either
1547
- # be broadcastable or equal to the one non-broadcastable
1548
- # constant `const_nt_shape_var`.
1549
- assert_dim = Assert ("Could not broadcast dimensions" )
1550
-
1551
- scalar_nonconst_nb_shapes = [
1552
- at .scalar_from_tensor (s )
1553
- if isinstance (s .type , TensorType )
1554
- else s
1555
- for s in nonconst_nb_shapes
1556
- ]
1557
-
1558
- dummy_nonconst_nb_shapes = [
1559
- aes .get_scalar_type (dtype = v .dtype )()
1560
- for v in scalar_nonconst_nb_shapes
1561
- ]
1562
- assert_cond = reduce (
1563
- aes .and_ ,
1564
- (
1565
- aes .or_ (
1566
- aes .eq (nbv , one_at ), aes .eq (nbv , const_nt_shape_var )
1567
- )
1568
- for nbv in dummy_nonconst_nb_shapes
1569
- ),
1570
- )
1571
- assert_cond_op = Composite (dummy_nonconst_nb_shapes , [assert_cond ])
1572
-
1573
- bcast_dim = assert_dim (
1574
- const_nt_shape_var , assert_cond_op (* scalar_nonconst_nb_shapes )
1575
- )
1576
- else :
1577
- bcast_dim = const_nt_shape_var
1535
+ assert_op = Assert ("Could not dynamically broadcast dimensions." )
1536
+ if len (const_nb_shapes ) == 1 :
1537
+ (first_length ,) = const_nb_shapes
1538
+ other_lengths = nonconst_nb_shapes
1539
+ first_length = aes .as_scalar (first_length )
1578
1540
else :
1579
- # There are no constant, non-broadcastable shapes in this
1580
- # dimension.
1581
-
1582
- all_dims_equal = all (
1583
- # TODO FIXME: This is a largely deficient, and expensive, means
1584
- # of comparing graphs (and especially shapes)
1585
- equal_computations ([maybe_non_bcast_shapes [0 ]], [dim ])
1586
- for dim in maybe_non_bcast_shapes [1 :]
1587
- )
1588
-
1589
- if all_dims_equal :
1590
- result_dims .append (maybe_non_bcast_shapes [0 ])
1591
- continue
1592
-
1593
- scalar_maybe_non_bcast_shapes = [
1594
- at .scalar_from_tensor (s ) if isinstance (s .type , TensorType ) else s
1595
- for s in maybe_non_bcast_shapes
1596
- ]
1597
- dummy_maybe_non_bcast_shapes = [
1598
- aes .get_scalar_type (dtype = v .dtype )()
1599
- for v in scalar_maybe_non_bcast_shapes
1600
- ]
1601
- non_bcast_vec = [
1602
- aes .switch (aes .eq (nbv , 1 ), - one_at , nbv )
1603
- for nbv in dummy_maybe_non_bcast_shapes
1604
- ]
1605
- dim_max = aes .abs (reduce (aes .scalar_maximum , non_bcast_vec ))
1606
- dim_max_op = Composite (dummy_maybe_non_bcast_shapes , [dim_max ])
1607
-
1608
- dummy_dim_max = dim_max_op (* dummy_maybe_non_bcast_shapes )
1609
-
1610
- assert_dim = Assert ("Could not broadcast dimensions" )
1611
- assert_cond = reduce (
1612
- aes .and_ ,
1613
- (
1614
- aes .or_ (aes .eq (nbv , - one_at ), aes .eq (nbv , dummy_dim_max ))
1615
- for nbv in non_bcast_vec
1616
- ),
1617
- )
1618
- assert_cond_op = Composite (dummy_maybe_non_bcast_shapes , [assert_cond ])
1619
-
1620
- bcast_dim = assert_dim (
1621
- dim_max_op (* scalar_maybe_non_bcast_shapes ),
1622
- assert_cond_op (* scalar_maybe_non_bcast_shapes ),
1623
- )
1624
-
1625
- result_dims .append (bcast_dim )
1541
+ first_length , * other_lengths = nonconst_nb_shapes
1542
+
1543
+ if len (other_lengths ) == 0 :
1544
+ result_dims .append (first_length )
1545
+ continue
1546
+
1547
+ # Add assert that all remaining shapes are equal
1548
+ use_scalars = False
1549
+ if use_scalars :
1550
+ condition = None
1551
+ for other in other_lengths :
1552
+ cond = aes .eq (first_length , other )
1553
+ if condition is None :
1554
+ condition = cond
1555
+ else :
1556
+ condition = aes .and_ (condition , cond )
1557
+ else :
1558
+ condition = pt_all ([pt_eq (first_length , other ) for other in other_lengths ])
1559
+ if condition is None :
1560
+ result_dims .append (first_length )
1561
+ else :
1562
+ result_dims .append (assert_op (first_length , condition ))
1626
1563
1627
1564
return tuple (result_dims )
1628
1565
0 commit comments