Skip to content

Commit a5f329e

Browse files
author
Ian Schweer
committed
Check for scipy in elemwise
1 parent 5ebbb1a commit a5f329e

File tree

1 file changed

+15
-1
lines changed

1 file changed

+15
-1
lines changed

pytensor/link/pytorch/dispatch/elemwise.py

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
import importlib
2+
13
import torch
24

35
from pytensor.link.pytorch.dispatch.basic import pytorch_funcify
@@ -10,8 +12,20 @@
1012
def pytorch_funcify_Elemwise(op, node, **kwargs):
1113
scalar_op = op.scalar_op
1214
base_fn = pytorch_funcify(scalar_op, node=node, **kwargs)
15+
16+
def check_special_scipy(func_name):
17+
if "scipy." not in func_name:
18+
return False
19+
loc = func_name.split(".")[1:]
20+
try:
21+
mod = importlib.import_module(".".join(loc[:-1]), "torch")
22+
return getattr(mod, loc[-1], False)
23+
except ImportError:
24+
return False
25+
1326
if hasattr(scalar_op, "nfunc_spec") and (
14-
hasattr(torch, scalar_op.nfunc_spec[0]) or "scipy." in scalar_op.nfunc_spec[0]
27+
hasattr(torch, scalar_op.nfunc_spec[0])
28+
or check_special_scipy(scalar_op.nfunc_spec[0])
1529
):
1630
# torch can handle this scalar
1731
# broadcast, we'll let it.

0 commit comments

Comments
 (0)