Skip to content

Commit a7a62f9

Browse files
committed
Allow for scipy module resolution
1 parent e73258b commit a7a62f9

File tree

1 file changed

+9
-2
lines changed

1 file changed

+9
-2
lines changed

pytensor/link/pytorch/dispatch/scalar.py

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
import importlib
2+
13
import torch
24

35
from pytensor.link.pytorch.dispatch.basic import pytorch_funcify
@@ -19,9 +21,14 @@ def pytorch_funcify_ScalarOp(op, node, **kwargs):
1921
if nfunc_spec is None:
2022
raise NotImplementedError(f"Dispatch not implemented for Scalar Op {op}")
2123

22-
func_name = nfunc_spec[0]
24+
func_name = nfunc_spec[0].replace("scipy.", "")
2325

24-
pytorch_func = getattr(torch, func_name)
26+
if "." in func_name:
27+
loc = func_name.split(".")
28+
mod = importlib.import_module(".".join(["torch", *loc[:-1]]))
29+
pytorch_func = getattr(mod, loc[-1])
30+
else:
31+
pytorch_func = getattr(torch, func_name)
2532

2633
if len(node.inputs) > op.nfunc_spec[1]:
2734
# Some Scalar Ops accept multiple number of inputs, behaving as a variadic function,

0 commit comments

Comments
 (0)