16
16
from pytensor .misc .safe_asarray import _asarray
17
17
from pytensor .printing import FunctionPrinter , Printer , pprint
18
18
from pytensor .scalar import get_scalar_type
19
+ from pytensor .scalar .basic import Composite
19
20
from pytensor .scalar .basic import bool as scalar_bool
20
21
from pytensor .scalar .basic import identity as scalar_identity
21
22
from pytensor .scalar .basic import transfer_type , upcast
@@ -652,10 +653,12 @@ def transform(r):
652
653
653
654
def prepare_node (self , node , storage_map , compute_map , impl ):
654
655
# Postpone the ufunc building to the last minutes due to:
655
- # - NumPy ufunc support only up to 31 inputs.
656
+ # - NumPy ufunc support only up to 32 operands ( inputs and outputs)
656
657
# But our c code support more.
657
658
# - nfunc is reused for scipy and scipy is optional
658
- if len (node .inputs ) > 32 and self .ufunc and impl == "py" :
659
+ if isinstance (self .scalar_op , Composite ):
660
+ print ("WOW" )
661
+ if (len (node .inputs ) + len (node .outputs )) > 32 and impl == "py" :
659
662
impl = "c"
660
663
661
664
if getattr (self , "nfunc_spec" , None ) and impl != "c" :
@@ -677,7 +680,7 @@ def prepare_node(self, node, storage_map, compute_map, impl):
677
680
self .nfunc = module
678
681
679
682
if (
680
- len (node .inputs ) < 32
683
+ ( len (node .inputs ) + len ( node . inputs )) <= 32
681
684
and (self .nfunc is None or self .scalar_op .nin != len (node .inputs ))
682
685
and self .ufunc is None
683
686
and impl == "py"
@@ -727,28 +730,18 @@ def prepare_node(self, node, storage_map, compute_map, impl):
727
730
self .scalar_op .prepare_node (node .tag .fake_node , None , None , impl )
728
731
729
732
def perform (self , node , inputs , output_storage ):
730
- if len (node .inputs ) >= 32 :
733
+ if ( len (node .inputs ) + len ( node . outputs )) > 32 :
731
734
# Some versions of NumPy will segfault, other will raise a
732
- # ValueError, if the number of inputs to a ufunc is 32 or more .
735
+ # ValueError, if the number of operands in an ufunc is more than 32 .
733
736
# In that case, the C version should be used, or Elemwise fusion
734
737
# should be disabled.
738
+ # FIXME: This no longer calls the C implementation!
735
739
super ().perform (node , inputs , output_storage )
736
740
737
741
for d , dim_shapes in enumerate (zip (* (i .shape for i in inputs ))):
738
742
if len (set (dim_shapes ) - {1 }) > 1 :
739
743
raise ValueError (f"Shapes on dimension { d } do not match: { dim_shapes } " )
740
744
741
- # Determine the shape of outputs
742
- out_shape = []
743
- for values in zip (* [input .shape for input in inputs ]):
744
- if any (v == 0 for v in values ):
745
- # All non-broadcasted dimensions should be zero
746
- assert max (values ) <= 1
747
- out_shape .append (0 )
748
- else :
749
- out_shape .append (max (values ))
750
- out_shape = tuple (out_shape )
751
-
752
745
ufunc_args = inputs
753
746
ufunc_kwargs = {}
754
747
# We supported in the past calling manually op.perform.
0 commit comments