
Commit 594f46b

ricardoV94 authored and jessegrabowski committed
Cast to output, not input in numba dispatch of scalar Softplus
1 parent bfcad6d commit 594f46b
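Why the cast target matters, as a hedged NumPy-only sketch (not the dispatch code itself): the scalar Softplus op upgrades integer inputs to float (the removed test built it with ps.upgrade_to_float), so the numerically stable kernel has to cast its result to the output dtype; the old numba implementation cast to the input dtype, which truncates the float result whenever the input is an integer scalar.

import numpy as np

def softplus_sketch(x, dtype):
    # Numerically stable piecewise softplus, same branches as the kernel in this diff;
    # the result is cast to `dtype` at the end.
    if x < -37.0:
        value = np.exp(x)
    elif x < 18.0:
        value = np.log1p(np.exp(x))
    elif x < 33.3:
        value = x + np.exp(-x)
    else:
        value = x
    return np.dtype(dtype).type(value)

x = np.int64(0)                       # integer input; the op's output dtype is float64
print(softplus_sketch(x, "float64"))  # 0.6931... -- cast to the output dtype (this commit)
print(softplus_sketch(x, "int64"))    # 0 -- cast to the input dtype (the old, wrong behaviour)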

4 files changed: +57 -71 lines changed


pytensor/link/numba/dispatch/basic.py

Lines changed: 0 additions & 25 deletions
@@ -31,7 +31,6 @@
     fgraph_to_python,
 )
 from pytensor.scalar.basic import ScalarType
-from pytensor.scalar.math import Softplus
 from pytensor.sparse import SparseTensorType
 from pytensor.tensor.basic import Nonzero
 from pytensor.tensor.blas import BatchedDot
@@ -607,25 +606,6 @@ def dot(x, y):
     return dot


-@numba_funcify.register(Softplus)
-def numba_funcify_Softplus(op, node, **kwargs):
-    x_dtype = np.dtype(node.inputs[0].dtype)
-
-    @numba_njit
-    def softplus(x):
-        if x < -37.0:
-            value = np.exp(x)
-        elif x < 18.0:
-            value = np.log1p(np.exp(x))
-        elif x < 33.3:
-            value = x + np.exp(-x)
-        else:
-            value = x
-        return direct_cast(value, x_dtype)
-
-    return softplus
-
-
 @numba_funcify.register(Solve)
 def numba_funcify_Solve(op, node, **kwargs):
     assume_a = op.assume_a
@@ -689,11 +669,6 @@ def batched_dot(x, y):
     return batched_dot


-# NOTE: The remaining `pytensor.tensor.blas` `Op`s appear unnecessary, because
-# they're only used to optimize basic `Dot` nodes, and those GEMV and GEMM
-# optimizations are apparently already performed by Numba
-
-
 @numba_funcify.register(IfElse)
 def numba_funcify_IfElse(op, **kwargs):
     n_outs = op.n_outs

pytensor/link/numba/dispatch/scalar.py

Lines changed: 20 additions & 1 deletion
@@ -28,7 +28,7 @@
     Second,
     Switch,
 )
-from pytensor.scalar.math import Erf, Erfc, GammaLn, Log1mexp, Sigmoid
+from pytensor.scalar.math import Erf, Erfc, GammaLn, Log1mexp, Sigmoid, Softplus


 @numba_funcify.register(ScalarOp)
@@ -312,3 +312,22 @@ def erfc(x):
 @numba_funcify.register(Erfc)
 def numba_funcify_Erfc(op, **kwargs):
     return numba_basic.global_numba_func(erfc)
+
+
+@numba_funcify.register(Softplus)
+def numba_funcify_Softplus(op, node, **kwargs):
+    out_dtype = np.dtype(node.outputs[0].type.dtype)
+
+    @numba_basic.numba_njit
+    def softplus(x):
+        if x < -37.0:
+            value = np.exp(x)
+        elif x < 18.0:
+            value = np.log1p(np.exp(x))
+        elif x < 33.3:
+            value = x + np.exp(-x)
+        else:
+            value = x
+        return numba_basic.direct_cast(value, out_dtype)
+
+    return softplus
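A hedged usage sketch of the fixed dispatch, mirroring the new test below (assumes numba is installed; the "NUMBA" mode name is an assumption, the test itself compiles with the numba_mode helper):

import numpy as np
import pytensor.scalar as ps
import pytensor.scalar.math as psm
from pytensor import function

x = ps.get_scalar_type("int64")("x")
g = psm.softplus(x)                  # output dtype is float64, not int64
fn = function([x], g, mode="NUMBA")  # assumed mode name; see numba_mode in the tests
print(fn(np.int64(0)))               # ~0.6931; before this fix the result was cast back to int64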

tests/link/numba/test_basic.py

Lines changed: 0 additions & 43 deletions
@@ -14,7 +14,6 @@
 numba = pytest.importorskip("numba")

 import pytensor.scalar as ps
-import pytensor.scalar.math as psm
 import pytensor.tensor as pt
 import pytensor.tensor.math as ptm
 from pytensor import config, shared
@@ -643,48 +642,6 @@ def test_Dot(x, y, exc):
     )


-@pytest.mark.parametrize(
-    "x, exc",
-    [
-        (
-            (ps.float64(), np.array(0.0, dtype="float64")),
-            None,
-        ),
-        (
-            (ps.float64(), np.array(-32.0, dtype="float64")),
-            None,
-        ),
-        (
-            (ps.float64(), np.array(-40.0, dtype="float64")),
-            None,
-        ),
-        (
-            (ps.float64(), np.array(32.0, dtype="float64")),
-            None,
-        ),
-        (
-            (ps.float64(), np.array(40.0, dtype="float64")),
-            None,
-        ),
-        (
-            (ps.int64(), np.array(32, dtype="int64")),
-            None,
-        ),
-    ],
-)
-def test_Softplus(x, exc):
-    x, x_test_value = x
-    g = psm.Softplus(ps.upgrade_to_float)(x)
-
-    cm = contextlib.suppress() if exc is None else pytest.warns(exc)
-    with cm:
-        compare_numba_and_py(
-            [x],
-            [g],
-            [x_test_value],
-        )
-
-
 @pytest.mark.parametrize(
     "x, y, exc",
     [

tests/link/numba/test_scalar.py

Lines changed: 37 additions & 2 deletions
@@ -3,12 +3,13 @@

 import pytensor.scalar as ps
 import pytensor.scalar.basic as psb
+import pytensor.scalar.math as psm
 import pytensor.tensor as pt
-from pytensor import config
+from pytensor import config, function
 from pytensor.scalar.basic import Composite
 from pytensor.tensor import tensor
 from pytensor.tensor.elemwise import Elemwise
-from tests.link.numba.test_basic import compare_numba_and_py
+from tests.link.numba.test_basic import compare_numba_and_py, numba_mode, py_mode


 rng = np.random.default_rng(42849)
@@ -149,3 +150,37 @@ def test_isnan(composite):
         [out],
         [np.array([1, 0], dtype="float64")],
     )
+
+
+@pytest.mark.parametrize(
+    "dtype",
+    [
+        pytest.param(
+            "float32",
+            marks=pytest.mark.xfail(reason="Scalar downcasting not supported in numba"),
+        ),
+        "float64",
+        pytest.param(
+            "int16",
+            marks=pytest.mark.xfail(reason="Scalar downcasting not supported in numba"),
+        ),
+        "int64",
+        "uint32",
+    ],
+)
+def test_Softplus(dtype):
+    x = ps.get_scalar_type(dtype)("x")
+    g = psm.softplus(x)
+
+    py_fn = function([x], g, mode=py_mode)
+    numba_fn = function([x], g, mode=numba_mode)
+    for value in (-40, -32, 0, 32, 40):
+        if value < 0 and dtype.startswith("u"):
+            continue
+        test_x = np.dtype(dtype).type(value)
+        np.testing.assert_allclose(
+            py_fn(test_x),
+            numba_fn(test_x),
+            strict=True,
+            err_msg=f"Failed for value {value}",
+        )
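To run just the new test locally, something along these lines should work (a hedged sketch, assuming pytest and numba are installed):

import pytest

pytest.main(["tests/link/numba/test_scalar.py", "-k", "Softplus"])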
