Skip to content

Commit f1d7852

Browse files
committed
Add softplus
1 parent a7a62f9 commit f1d7852

File tree

1 file changed

+6
-0
lines changed

1 file changed

+6
-0
lines changed

pytensor/link/pytorch/dispatch/scalar.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
Cast,
88
ScalarOp,
99
)
10+
from pytensor.scalar.math import Softplus
1011

1112

1213
@pytorch_funcify.register(ScalarOp)
@@ -56,3 +57,8 @@ def cast(x):
5657
return x.to(dtype=dtype)
5758

5859
return cast
60+
61+
62+
@pytorch_funcify.register(Softplus)
63+
def pytorch_funcify_Softplus(op, node, **kwargs):
64+
return torch.nn.Softplus()

0 commit comments

Comments
 (0)