Skip to content

Commit 7e98dbf

Browse files
committed
Fuse consecutive Elemwise nodes with multiple clients
1 parent 36a5cb6 commit 7e98dbf

File tree

7 files changed

+596
-375
lines changed

7 files changed

+596
-375
lines changed

pytensor/tensor/elemwise.py

Lines changed: 9 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
from pytensor.misc.safe_asarray import _asarray
1717
from pytensor.printing import FunctionPrinter, Printer, pprint
1818
from pytensor.scalar import get_scalar_type
19+
from pytensor.scalar.basic import Composite
1920
from pytensor.scalar.basic import bool as scalar_bool
2021
from pytensor.scalar.basic import identity as scalar_identity
2122
from pytensor.scalar.basic import transfer_type, upcast
@@ -652,10 +653,12 @@ def transform(r):
652653

653654
def prepare_node(self, node, storage_map, compute_map, impl):
654655
# Postpone the ufunc building to the last minutes due to:
655-
# - NumPy ufunc support only up to 31 inputs.
656+
# - NumPy ufunc support only up to 32 operands (inputs and outputs)
656657
# But our c code support more.
657658
# - nfunc is reused for scipy and scipy is optional
658-
if len(node.inputs) > 32 and self.ufunc and impl == "py":
659+
if isinstance(self.scalar_op, Composite):
660+
print("WOW")
661+
if (len(node.inputs) + len(node.outputs)) > 32 and impl == "py":
659662
impl = "c"
660663

661664
if getattr(self, "nfunc_spec", None) and impl != "c":
@@ -677,7 +680,7 @@ def prepare_node(self, node, storage_map, compute_map, impl):
677680
self.nfunc = module
678681

679682
if (
680-
len(node.inputs) < 32
683+
(len(node.inputs) + len(node.inputs)) <= 32
681684
and (self.nfunc is None or self.scalar_op.nin != len(node.inputs))
682685
and self.ufunc is None
683686
and impl == "py"
@@ -727,28 +730,18 @@ def prepare_node(self, node, storage_map, compute_map, impl):
727730
self.scalar_op.prepare_node(node.tag.fake_node, None, None, impl)
728731

729732
def perform(self, node, inputs, output_storage):
730-
if len(node.inputs) >= 32:
733+
if (len(node.inputs) + len(node.outputs)) > 32:
731734
# Some versions of NumPy will segfault, other will raise a
732-
# ValueError, if the number of inputs to a ufunc is 32 or more.
735+
# ValueError, if the number of operands in an ufunc is more than 32.
733736
# In that case, the C version should be used, or Elemwise fusion
734737
# should be disabled.
738+
# FIXME: This no longer calls the C implementation!
735739
super().perform(node, inputs, output_storage)
736740

737741
for d, dim_shapes in enumerate(zip(*(i.shape for i in inputs))):
738742
if len(set(dim_shapes) - {1}) > 1:
739743
raise ValueError(f"Shapes on dimension {d} do not match: {dim_shapes}")
740744

741-
# Determine the shape of outputs
742-
out_shape = []
743-
for values in zip(*[input.shape for input in inputs]):
744-
if any(v == 0 for v in values):
745-
# All non-broadcasted dimensions should be zero
746-
assert max(values) <= 1
747-
out_shape.append(0)
748-
else:
749-
out_shape.append(max(values))
750-
out_shape = tuple(out_shape)
751-
752745
ufunc_args = inputs
753746
ufunc_kwargs = {}
754747
# We supported in the past calling manually op.perform.

0 commit comments

Comments
 (0)