File tree Expand file tree Collapse file tree 1 file changed +15
-1
lines changed
pytensor/link/pytorch/dispatch Expand file tree Collapse file tree 1 file changed +15
-1
lines changed Original file line number Diff line number Diff line change
1
+ import importlib
2
+
1
3
import torch
2
4
3
5
from pytensor .link .pytorch .dispatch .basic import pytorch_funcify
10
12
def pytorch_funcify_Elemwise (op , node , ** kwargs ):
11
13
scalar_op = op .scalar_op
12
14
base_fn = pytorch_funcify (scalar_op , node = node , ** kwargs )
15
+
16
+ def check_special_scipy (func_name ):
17
+ if "scipy." not in func_name :
18
+ return False
19
+ loc = func_name .split ("." )[1 :]
20
+ try :
21
+ mod = importlib .import_module ("." .join (loc [:- 1 ]), "torch" )
22
+ return getattr (mod , loc [- 1 ], False )
23
+ except ImportError :
24
+ return False
25
+
13
26
if hasattr (scalar_op , "nfunc_spec" ) and (
14
- hasattr (torch , scalar_op .nfunc_spec [0 ]) or "scipy." in scalar_op .nfunc_spec [0 ]
27
+ hasattr (torch , scalar_op .nfunc_spec [0 ])
28
+ or check_special_scipy (scalar_op .nfunc_spec [0 ])
15
29
):
16
30
# torch can handle this scalar
17
31
# broadcast, we'll let it.
You can’t perform that action at this time.
0 commit comments