From 1de6baea0d626514a35066ce1d56bfbc56ff8bfc Mon Sep 17 00:00:00 2001 From: Ricardo Vieira Date: Tue, 20 May 2025 15:13:03 +0200 Subject: [PATCH 1/8] Actually check types and dtypes match in numba testing helper NOTE: CI failing at this point --- tests/link/numba/test_basic.py | 13 ++++++++----- 1 file changed, 8 insertions(+), 5 deletions(-) diff --git a/tests/link/numba/test_basic.py b/tests/link/numba/test_basic.py index e8390b8ebf..b85febc005 100644 --- a/tests/link/numba/test_basic.py +++ b/tests/link/numba/test_basic.py @@ -260,9 +260,12 @@ def compare_numba_and_py( if assert_fn is None: def assert_fn(x, y): - return np.testing.assert_allclose(x, y, rtol=1e-4) and compare_shape_dtype( - x, y - ) + np.testing.assert_allclose(x, y, rtol=1e-4, strict=True) + # Make sure we don't have one input be a np.ndarray while the other is not + if isinstance(x, np.ndarray): + assert isinstance(y, np.ndarray), "y is not a NumPy array, but x is" + else: + assert not isinstance(y, np.ndarray), "y is a NumPy array, but x is not" if any( inp.owner is not None @@ -295,8 +298,8 @@ def assert_fn(x, y): test_inputs_copy = (inp.copy() for inp in test_inputs) if inplace else test_inputs numba_res = pytensor_numba_fn(*test_inputs_copy) if isinstance(graph_outputs, tuple | list): - for j, p in zip(numba_res, py_res, strict=True): - assert_fn(j, p) + for numba_res_i, python_res_i in zip(numba_res, py_res, strict=True): + assert_fn(numba_res_i, python_res_i) else: assert_fn(numba_res, py_res) From 4e42b849b01cec8c363567ff2cb4be9122c86c19 Mon Sep 17 00:00:00 2001 From: Ricardo Vieira Date: Tue, 20 May 2025 13:09:00 +0200 Subject: [PATCH 2/8] Fix numba dispatch of Det and SlogDet returning non-arrays --- pytensor/link/numba/dispatch/nlinalg.py | 6 +-- tests/link/numba/test_nlinalg.py | 68 ++++--------------------- 2 files changed, 12 insertions(+), 62 deletions(-) diff --git a/pytensor/link/numba/dispatch/nlinalg.py b/pytensor/link/numba/dispatch/nlinalg.py index 860560d0a6..3271b5bd26 100644 --- a/pytensor/link/numba/dispatch/nlinalg.py +++ b/pytensor/link/numba/dispatch/nlinalg.py @@ -52,7 +52,7 @@ def numba_funcify_Det(op, node, **kwargs): @numba_basic.numba_njit(inline="always") def det(x): - return numba_basic.direct_cast(np.linalg.det(inputs_cast(x)), out_dtype) + return np.array(np.linalg.det(inputs_cast(x))).astype(out_dtype) return det @@ -68,8 +68,8 @@ def numba_funcify_SLogDet(op, node, **kwargs): def slogdet(x): sign, det = np.linalg.slogdet(inputs_cast(x)) return ( - numba_basic.direct_cast(sign, out_dtype_1), - numba_basic.direct_cast(det, out_dtype_2), + np.array(sign).astype(out_dtype_1), + np.array(det).astype(out_dtype_2), ) return slogdet diff --git a/tests/link/numba/test_nlinalg.py b/tests/link/numba/test_nlinalg.py index 67bdc6f1a0..5c11ca524d 100644 --- a/tests/link/numba/test_nlinalg.py +++ b/tests/link/numba/test_nlinalg.py @@ -11,68 +11,18 @@ rng = np.random.default_rng(42849) -@pytest.mark.parametrize( - "x, exc", - [ - ( - ( - pt.dmatrix(), - (lambda x: x.T.dot(x))(rng.random(size=(3, 3)).astype("float64")), - ), - None, - ), - ( - ( - pt.lmatrix(), - (lambda x: x.T.dot(x))(rng.poisson(size=(3, 3)).astype("int64")), - ), - None, - ), - ], -) -def test_Det(x, exc): - x, test_x = x - g = nlinalg.Det()(x) +@pytest.mark.parametrize("dtype", ("float64", "int64")) +@pytest.mark.parametrize("op", (nlinalg.Det(), nlinalg.SLogDet())) +def test_Det_SLogDet(op, dtype): + x = pt.matrix(dtype=dtype) - cm = contextlib.suppress() if exc is None else pytest.warns(exc) - with cm: - compare_numba_and_py( - [x], - g, - [test_x], - ) + rng = np.random.default_rng([50, sum(map(ord, dtype))]) + x_ = rng.random(size=(3, 3)).astype(dtype) + test_x = x_.T.dot(x_) + g = op(x) -@pytest.mark.parametrize( - "x, exc", - [ - ( - ( - pt.dmatrix(), - (lambda x: x.T.dot(x))(rng.random(size=(3, 3)).astype("float64")), - ), - None, - ), - ( - ( - pt.lmatrix(), - (lambda x: x.T.dot(x))(rng.poisson(size=(3, 3)).astype("int64")), - ), - None, - ), - ], -) -def test_SLogDet(x, exc): - x, test_x = x - g = nlinalg.SLogDet()(x) - - cm = contextlib.suppress() if exc is None else pytest.warns(exc) - with cm: - compare_numba_and_py( - [x], - g, - [test_x], - ) + compare_numba_and_py([x], g, [test_x]) # We were seeing some weird results in CI where the following two almost From e7f4cf96711be39e3507bea9d6d3f11f3740d88d Mon Sep 17 00:00:00 2001 From: Ricardo Vieira Date: Tue, 20 May 2025 15:08:46 +0200 Subject: [PATCH 3/8] Fix indices dtype in numba dispatch of LU --- pytensor/link/numba/dispatch/linalg/decomposition/lu.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pytensor/link/numba/dispatch/linalg/decomposition/lu.py b/pytensor/link/numba/dispatch/linalg/decomposition/lu.py index 570c024b07..739f0a6990 100644 --- a/pytensor/link/numba/dispatch/linalg/decomposition/lu.py +++ b/pytensor/link/numba/dispatch/linalg/decomposition/lu.py @@ -30,7 +30,7 @@ def _lu_factor_to_lu(a, dtype, overwrite_a): # Fortran is 1 indexed, so we need to subtract 1 from the IPIV array IPIV = IPIV - 1 p_inv = _pivot_to_permutation(IPIV, dtype=dtype) - perm = np.argsort(p_inv) + perm = np.argsort(p_inv).astype("int32") return perm, L, U From 70fb7cf965f7c7e5c1e036d3dcdc0bc6f34f07fa Mon Sep 17 00:00:00 2001 From: Ricardo Vieira Date: Tue, 20 May 2025 15:10:47 +0200 Subject: [PATCH 4/8] Fix type of numba Argmax special case --- pytensor/link/numba/dispatch/elemwise.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pytensor/link/numba/dispatch/elemwise.py b/pytensor/link/numba/dispatch/elemwise.py index 9fd81dadcf..7244762b93 100644 --- a/pytensor/link/numba/dispatch/elemwise.py +++ b/pytensor/link/numba/dispatch/elemwise.py @@ -561,7 +561,7 @@ def numba_funcify_Argmax(op, node, **kwargs): @numba_basic.numba_njit(inline="always") def argmax(x): - return 0 + return np.array(0, dtype="int64") else: axes = tuple(int(ax) for ax in axis) From 2672b45ab2888705aa3bcb10fb7f443f25e16061 Mon Sep 17 00:00:00 2001 From: Ricardo Vieira Date: Tue, 20 May 2025 15:11:46 +0200 Subject: [PATCH 5/8] Fix dtype of numba dispatch of ArgSort --- pytensor/link/numba/dispatch/basic.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pytensor/link/numba/dispatch/basic.py b/pytensor/link/numba/dispatch/basic.py index 4938ecc42f..7b84cb4925 100644 --- a/pytensor/link/numba/dispatch/basic.py +++ b/pytensor/link/numba/dispatch/basic.py @@ -466,7 +466,7 @@ def argort_vec(X, axis): axis = axis.item() Y = np.swapaxes(X, axis, 0) - result = np.empty_like(Y) + result = np.empty_like(Y, dtype="int64") indices = list(np.ndindex(Y.shape[1:])) From 4a25c9cd87a8c0d3bf3352e03f93a1cd1d8f1277 Mon Sep 17 00:00:00 2001 From: Ricardo Vieira Date: Tue, 20 May 2025 15:10:16 +0200 Subject: [PATCH 6/8] Mark scalar downcast as failing --- tests/link/numba/test_scalar.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/tests/link/numba/test_scalar.py b/tests/link/numba/test_scalar.py index 504d2a163c..4ffa22a43a 100644 --- a/tests/link/numba/test_scalar.py +++ b/tests/link/numba/test_scalar.py @@ -99,7 +99,11 @@ def test_Composite(inputs, input_values, scalar_fn): "v, dtype", [ ((pt.fscalar(), np.array(1.0, dtype="float32")), psb.float64), - ((pt.dscalar(), np.array(1.0, dtype="float64")), psb.float32), + pytest.param( + (pt.dscalar(), np.array(1.0, dtype="float64")), + psb.float32, + marks=pytest.mark.xfail(reason="Scalar downcasting not supported in numba"), + ), ], ) def test_Cast(v, dtype): From 614ffddc3badb1fbed11fd6ffa39478cafda996f Mon Sep 17 00:00:00 2001 From: Ricardo Vieira Date: Tue, 20 May 2025 15:09:21 +0200 Subject: [PATCH 7/8] Cast to output, not input in numba dispatch of scalar Softplus --- pytensor/link/numba/dispatch/basic.py | 25 --------------- pytensor/link/numba/dispatch/scalar.py | 21 ++++++++++++- tests/link/numba/test_basic.py | 43 -------------------------- tests/link/numba/test_scalar.py | 39 +++++++++++++++++++++-- 4 files changed, 57 insertions(+), 71 deletions(-) diff --git a/pytensor/link/numba/dispatch/basic.py b/pytensor/link/numba/dispatch/basic.py index 7b84cb4925..845d6afc7a 100644 --- a/pytensor/link/numba/dispatch/basic.py +++ b/pytensor/link/numba/dispatch/basic.py @@ -31,7 +31,6 @@ fgraph_to_python, ) from pytensor.scalar.basic import ScalarType -from pytensor.scalar.math import Softplus from pytensor.sparse import SparseTensorType from pytensor.tensor.basic import Nonzero from pytensor.tensor.blas import BatchedDot @@ -607,25 +606,6 @@ def dot(x, y): return dot -@numba_funcify.register(Softplus) -def numba_funcify_Softplus(op, node, **kwargs): - x_dtype = np.dtype(node.inputs[0].dtype) - - @numba_njit - def softplus(x): - if x < -37.0: - value = np.exp(x) - elif x < 18.0: - value = np.log1p(np.exp(x)) - elif x < 33.3: - value = x + np.exp(-x) - else: - value = x - return direct_cast(value, x_dtype) - - return softplus - - @numba_funcify.register(Solve) def numba_funcify_Solve(op, node, **kwargs): assume_a = op.assume_a @@ -689,11 +669,6 @@ def batched_dot(x, y): return batched_dot -# NOTE: The remaining `pytensor.tensor.blas` `Op`s appear unnecessary, because -# they're only used to optimize basic `Dot` nodes, and those GEMV and GEMM -# optimizations are apparently already performed by Numba - - @numba_funcify.register(IfElse) def numba_funcify_IfElse(op, **kwargs): n_outs = op.n_outs diff --git a/pytensor/link/numba/dispatch/scalar.py b/pytensor/link/numba/dispatch/scalar.py index e9b637b00f..7e4703c8df 100644 --- a/pytensor/link/numba/dispatch/scalar.py +++ b/pytensor/link/numba/dispatch/scalar.py @@ -28,7 +28,7 @@ Second, Switch, ) -from pytensor.scalar.math import Erf, Erfc, GammaLn, Log1mexp, Sigmoid +from pytensor.scalar.math import Erf, Erfc, GammaLn, Log1mexp, Sigmoid, Softplus @numba_funcify.register(ScalarOp) @@ -312,3 +312,22 @@ def erfc(x): @numba_funcify.register(Erfc) def numba_funcify_Erfc(op, **kwargs): return numba_basic.global_numba_func(erfc) + + +@numba_funcify.register(Softplus) +def numba_funcify_Softplus(op, node, **kwargs): + out_dtype = np.dtype(node.outputs[0].type.dtype) + + @numba_basic.numba_njit + def softplus(x): + if x < -37.0: + value = np.exp(x) + elif x < 18.0: + value = np.log1p(np.exp(x)) + elif x < 33.3: + value = x + np.exp(-x) + else: + value = x + return numba_basic.direct_cast(value, out_dtype) + + return softplus diff --git a/tests/link/numba/test_basic.py b/tests/link/numba/test_basic.py index b85febc005..9132d7b202 100644 --- a/tests/link/numba/test_basic.py +++ b/tests/link/numba/test_basic.py @@ -14,7 +14,6 @@ numba = pytest.importorskip("numba") import pytensor.scalar as ps -import pytensor.scalar.math as psm import pytensor.tensor as pt import pytensor.tensor.math as ptm from pytensor import config, shared @@ -643,48 +642,6 @@ def test_Dot(x, y, exc): ) -@pytest.mark.parametrize( - "x, exc", - [ - ( - (ps.float64(), np.array(0.0, dtype="float64")), - None, - ), - ( - (ps.float64(), np.array(-32.0, dtype="float64")), - None, - ), - ( - (ps.float64(), np.array(-40.0, dtype="float64")), - None, - ), - ( - (ps.float64(), np.array(32.0, dtype="float64")), - None, - ), - ( - (ps.float64(), np.array(40.0, dtype="float64")), - None, - ), - ( - (ps.int64(), np.array(32, dtype="int64")), - None, - ), - ], -) -def test_Softplus(x, exc): - x, x_test_value = x - g = psm.Softplus(ps.upgrade_to_float)(x) - - cm = contextlib.suppress() if exc is None else pytest.warns(exc) - with cm: - compare_numba_and_py( - [x], - [g], - [x_test_value], - ) - - @pytest.mark.parametrize( "x, y, exc", [ diff --git a/tests/link/numba/test_scalar.py b/tests/link/numba/test_scalar.py index 4ffa22a43a..2125d7cc0e 100644 --- a/tests/link/numba/test_scalar.py +++ b/tests/link/numba/test_scalar.py @@ -3,12 +3,13 @@ import pytensor.scalar as ps import pytensor.scalar.basic as psb +import pytensor.scalar.math as psm import pytensor.tensor as pt -from pytensor import config +from pytensor import config, function from pytensor.scalar.basic import Composite from pytensor.tensor import tensor from pytensor.tensor.elemwise import Elemwise -from tests.link.numba.test_basic import compare_numba_and_py +from tests.link.numba.test_basic import compare_numba_and_py, numba_mode, py_mode rng = np.random.default_rng(42849) @@ -149,3 +150,37 @@ def test_isnan(composite): [out], [np.array([1, 0], dtype="float64")], ) + + +@pytest.mark.parametrize( + "dtype", + [ + pytest.param( + "float32", + marks=pytest.mark.xfail(reason="Scalar downcasting not supported in numba"), + ), + "float64", + pytest.param( + "int16", + marks=pytest.mark.xfail(reason="Scalar downcasting not supported in numba"), + ), + "int64", + "uint32", + ], +) +def test_Softplus(dtype): + x = ps.get_scalar_type(dtype)("x") + g = psm.softplus(x) + + py_fn = function([x], g, mode=py_mode) + numba_fn = function([x], g, mode=numba_mode) + for value in (-40, -32, 0, 32, 40): + if value < 0 and dtype.startswith("u"): + continue + test_x = np.dtype(dtype).type(value) + np.testing.assert_allclose( + py_fn(test_x), + numba_fn(test_x), + strict=True, + err_msg=f"Failed for value {value}", + ) From bff6d08057d33a8aeda80dc417b80a7b4be8c6b3 Mon Sep 17 00:00:00 2001 From: jessegrabowski Date: Thu, 22 May 2025 23:20:17 +0800 Subject: [PATCH 8/8] Add ids to Det_SLogDet test --- tests/link/numba/test_nlinalg.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/tests/link/numba/test_nlinalg.py b/tests/link/numba/test_nlinalg.py index 5c11ca524d..8d7c3a449c 100644 --- a/tests/link/numba/test_nlinalg.py +++ b/tests/link/numba/test_nlinalg.py @@ -12,7 +12,9 @@ @pytest.mark.parametrize("dtype", ("float64", "int64")) -@pytest.mark.parametrize("op", (nlinalg.Det(), nlinalg.SLogDet())) +@pytest.mark.parametrize( + "op", (nlinalg.Det(), nlinalg.SLogDet()), ids=["det", "slogdet"] +) def test_Det_SLogDet(op, dtype): x = pt.matrix(dtype=dtype)