diff --git a/pytensor/link/pytorch/dispatch/elemwise.py b/pytensor/link/pytorch/dispatch/elemwise.py
index 72f97af1fa..79ca5beec1 100644
--- a/pytensor/link/pytorch/dispatch/elemwise.py
+++ b/pytensor/link/pytorch/dispatch/elemwise.py
@@ -1,3 +1,5 @@
+import importlib
+
 import torch
 
 from pytensor.link.pytorch.dispatch.basic import pytorch_funcify
@@ -11,12 +13,26 @@ def pytorch_funcify_Elemwise(op, node, **kwargs):
     scalar_op = op.scalar_op
     base_fn = pytorch_funcify(scalar_op, node=node, **kwargs)
 
-    if hasattr(scalar_op, "nfunc_spec") and hasattr(torch, scalar_op.nfunc_spec[0]):
+    def check_special_scipy(func_name):
+        if "scipy." not in func_name:
+            return False
+        loc = func_name.split(".")[1:]
+        try:
+            mod = importlib.import_module(".".join(loc[:-1]), "torch")
+            return getattr(mod, loc[-1], False)
+        except ImportError:
+            return False
+
+    if hasattr(scalar_op, "nfunc_spec") and (
+        hasattr(torch, scalar_op.nfunc_spec[0])
+        or check_special_scipy(scalar_op.nfunc_spec[0])
+    ):
         # torch can handle this scalar
         # broadcast, we'll let it.
         def elemwise_fn(*inputs):
             Elemwise._check_runtime_broadcast(node, inputs)
             return base_fn(*inputs)
 
+
     else:
 
         def elemwise_fn(*inputs):
diff --git a/pytensor/link/pytorch/dispatch/scalar.py b/pytensor/link/pytorch/dispatch/scalar.py
index a977c6d4b2..1416e58f55 100644
--- a/pytensor/link/pytorch/dispatch/scalar.py
+++ b/pytensor/link/pytorch/dispatch/scalar.py
@@ -1,3 +1,5 @@
+import importlib
+
 import torch
 
 from pytensor.link.pytorch.dispatch.basic import pytorch_funcify
@@ -5,6 +7,7 @@
     Cast,
     ScalarOp,
 )
+from pytensor.scalar.math import Softplus
 
 
 @pytorch_funcify.register(ScalarOp)
@@ -19,9 +22,14 @@ def pytorch_funcify_ScalarOp(op, node, **kwargs):
     if nfunc_spec is None:
         raise NotImplementedError(f"Dispatch not implemented for Scalar Op {op}")
 
-    func_name = nfunc_spec[0]
+    func_name = nfunc_spec[0].replace("scipy.", "")
 
-    pytorch_func = getattr(torch, func_name)
+    if "." in func_name:
+        loc = func_name.split(".")
+        mod = importlib.import_module(".".join(["torch", *loc[:-1]]))
+        pytorch_func = getattr(mod, loc[-1])
+    else:
+        pytorch_func = getattr(torch, func_name)
 
     if len(node.inputs) > op.nfunc_spec[1]:
         # Some Scalar Ops accept multiple number of inputs, behaving as a variadic function,
@@ -49,3 +57,8 @@ def cast(x):
         return x.to(dtype=dtype)
 
     return cast
+
+
+@pytorch_funcify.register(Softplus)
+def pytorch_funcify_Softplus(op, node, **kwargs):
+    return torch.nn.Softplus()
diff --git a/tests/link/pytorch/test_basic.py b/tests/link/pytorch/test_basic.py
index 25827d23f9..83249d021b 100644
--- a/tests/link/pytorch/test_basic.py
+++ b/tests/link/pytorch/test_basic.py
@@ -17,7 +17,7 @@
 from pytensor.ifelse import ifelse
 from pytensor.link.pytorch.linker import PytorchLinker
 from pytensor.raise_op import CheckAndRaise
-from pytensor.tensor import alloc, arange, as_tensor, empty, eye
+from pytensor.tensor import alloc, arange, as_tensor, empty, expit, eye, softplus
 from pytensor.tensor.type import matrices, matrix, scalar, vector
 
 
@@ -374,3 +374,17 @@ def inner_fn(x):
     f = function([x], out, mode="PYTORCH")
     f(torch.ones(3))
     assert "inner_fn" not in dir(m), "function call reference leaked"
+
+
+def test_pytorch_scipy():
+    x = vector("a", shape=(3,))
+    out = expit(x)
+    f = FunctionGraph([x], [out])
+    compare_pytorch_and_py(f, [np.random.rand(3)])
+
+
+def test_pytorch_softplus():
+    x = vector("a", shape=(3,))
+    out = softplus(x)
+    f = FunctionGraph([x], [out])
+    compare_pytorch_and_py(f, [np.random.rand(3)])
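
For context, a minimal usage sketch of what this change enables (not part of the diff): `nfunc_spec` names such as `"scipy.special.expit"` are now resolved against `torch` submodules (e.g. `torch.special.expit`), and `Softplus` gets a dedicated `torch.nn.Softplus()` dispatch. The snippet below assumes a working torch install and uses the same `mode="PYTORCH"` path as the existing tests; variable names are illustrative.

```python
# Usage sketch: compile expit/softplus graphs with the PyTorch backend,
# exercising the new scipy-name dispatch. Mirrors tests/link/pytorch/test_basic.py.
import numpy as np

from pytensor import function
from pytensor.tensor import expit, softplus
from pytensor.tensor.type import vector

x = vector("x", shape=(3,))
# "scipy.special.expit" in nfunc_spec resolves to torch.special.expit,
# and Softplus dispatches to torch.nn.Softplus().
f = function([x], [expit(x), softplus(x)], mode="PYTORCH")
print(f(np.random.rand(3)))
```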