We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
There was an error while loading. Please reload this page.
1 parent d17d4a9 commit 5ebbb1aCopy full SHA for 5ebbb1a
pytensor/link/pytorch/dispatch/elemwise.py
@@ -10,13 +10,15 @@
10
def pytorch_funcify_Elemwise(op, node, **kwargs):
11
scalar_op = op.scalar_op
12
base_fn = pytorch_funcify(scalar_op, node=node, **kwargs)
13
-
14
- if hasattr(scalar_op, "nfunc_spec") and hasattr(torch, scalar_op.nfunc_spec[0]):
+ if hasattr(scalar_op, "nfunc_spec") and (
+ hasattr(torch, scalar_op.nfunc_spec[0]) or "scipy." in scalar_op.nfunc_spec[0]
15
+ ):
16
# torch can handle this scalar
17
# broadcast, we'll let it.
18
def elemwise_fn(*inputs):
19
Elemwise._check_runtime_broadcast(node, inputs)
20
return base_fn(*inputs)
21
+
22
else:
23
24
0 commit comments