Skip to content

Commit d7d20be

Browse files
committed
Fix spurious warning from FusionOptimizer
1 parent 4730d0c commit d7d20be

File tree

2 files changed

+23
-0
lines changed

2 files changed

+23
-0
lines changed

pytensor/tensor/rewriting/elemwise.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -692,6 +692,11 @@ def elemwise_scalar_op_has_c_code(node: Apply) -> bool:
692692
fuseable_clients: FUSEABLE_MAPPING = defaultdict(list)
693693
unfuseable_clients: UNFUSEABLE_MAPPING = defaultdict(set)
694694
for out, clients in fg.clients.items():
695+
# Old FunctionGraph nodes remain in the clients dictionary
696+
# even after they are removed by rewrites
697+
if not clients:
698+
continue
699+
695700
out_maybe_fuseable = (
696701
out.owner
697702
and isinstance(out.owner.op, Elemwise)

tests/tensor/rewriting/test_elemwise.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
import warnings
2+
13
import numpy as np
24
import pytest
35

@@ -36,6 +38,7 @@
3638
invert,
3739
iround,
3840
log,
41+
log1mexp,
3942
log2,
4043
log10,
4144
mul,
@@ -1370,6 +1373,21 @@ def rewrite_func():
13701373

13711374
assert benchmark(rewrite_func) == 103
13721375

1376+
def test_no_warning_from_old_client(self):
1377+
# There used to be a warning issued when creating fuseable mapping
1378+
# for nodes that are no longer in the FunctionGraph
1379+
with warnings.catch_warnings():
1380+
warnings.simplefilter("error")
1381+
# The -2 integer array cannot be passed directly to the C method
1382+
# of log1mexp as that can only handle floats. There is a rewrite
1383+
# that casts it to a float, but the FunctionGraph client retains
1384+
# the original log1mexp of the integer input, which caused
1385+
# a misleading warning for non C implementation in the FusionRewrite
1386+
assert np.isclose(
1387+
log1mexp(np.array(-2, dtype="int64")).eval(),
1388+
np.log(1 - np.exp(-2)),
1389+
)
1390+
13731391

13741392
class TimesN(aes.basic.UnaryScalarOp):
13751393
"""

0 commit comments

Comments
 (0)