From 4bb15bf52381eef0e44b724f15b9ef64630ff8c9 Mon Sep 17 00:00:00 2001 From: Ricardo Vieira Date: Mon, 3 Feb 2025 14:50:24 +0100 Subject: [PATCH 1/2] Avoid manipulation of deprecated _mpm_cheap Internal API changed in numba 0.61 Existing benchmarks don't show any difference in performance --- pytensor/link/numba/dispatch/basic.py | 18 -------- pytensor/link/numba/dispatch/elemwise.py | 52 ++---------------------- 2 files changed, 4 insertions(+), 66 deletions(-) diff --git a/pytensor/link/numba/dispatch/basic.py b/pytensor/link/numba/dispatch/basic.py index 843a4dbf1f..0b2b58904a 100644 --- a/pytensor/link/numba/dispatch/basic.py +++ b/pytensor/link/numba/dispatch/basic.py @@ -1,7 +1,6 @@ import operator import sys import warnings -from contextlib import contextmanager from copy import copy from functools import singledispatch from textwrap import dedent @@ -362,23 +361,6 @@ def create_arg_string(x): return args -@contextmanager -def use_optimized_cheap_pass(*args, **kwargs): - """Temporarily replace the cheap optimization pass with a better one.""" - from numba.core.registry import cpu_target - - context = cpu_target.target_context._internal_codegen - old_pm = context._mpm_cheap - new_pm = context._module_pass_manager( - loop_vectorize=True, slp_vectorize=True, opt=3, cost="cheap" - ) - context._mpm_cheap = new_pm - try: - yield - finally: - context._mpm_cheap = old_pm - - @singledispatch def numba_typify(data, dtype=None, **kwargs): return data diff --git a/pytensor/link/numba/dispatch/elemwise.py b/pytensor/link/numba/dispatch/elemwise.py index ae5ef3dcb1..2a98985efe 100644 --- a/pytensor/link/numba/dispatch/elemwise.py +++ b/pytensor/link/numba/dispatch/elemwise.py @@ -9,10 +9,8 @@ from pytensor.graph.op import Op from pytensor.link.numba.dispatch import basic as numba_basic from pytensor.link.numba.dispatch.basic import ( - create_numba_signature, numba_funcify, numba_njit, - use_optimized_cheap_pass, ) from pytensor.link.numba.dispatch.vectorize_codegen import ( _jit_options, @@ -245,47 +243,6 @@ def {careduce_fn_name}(x): return careduce_fn -def jit_compile_reducer( - node, fn, *, reduce_to_scalar=False, infer_signature=True, **kwds -): - """Compile Python source for reduction loops using additional optimizations. - - Parameters - ========== - node - An node from which the signature can be derived. - fn - The Python function object to compile. - reduce_to_scalar: bool, default False - Whether to reduce output to a scalar (instead of 0d array) - infer_signature: bool: default True - Whether to try and infer the function signature from the Apply node. - kwds - Extra keywords to be added to the :func:`numba.njit` function. - - Returns - ======= - A :func:`numba.njit`-compiled function. - - """ - if infer_signature: - signature = create_numba_signature(node, reduce_to_scalar=reduce_to_scalar) - args = (signature,) - else: - args = () - - # Eagerly compile the function using increased optimizations. This should - # help improve nested loop reductions. 
- with use_optimized_cheap_pass(): - res = numba_basic.numba_njit( - *args, - boundscheck=False, - **kwds, - )(fn) - - return res - - def create_axis_apply_fn(fn, axis, ndim, dtype): axis = normalize_axis_index(axis, ndim) @@ -448,7 +405,7 @@ def numba_funcify_CAReduce(op, node, **kwargs): np.dtype(node.outputs[0].type.dtype), ) - careduce_fn = jit_compile_reducer(node, careduce_py_fn, reduce_to_scalar=False) + careduce_fn = numba_njit(careduce_py_fn, boundscheck=False) return careduce_fn @@ -579,7 +536,7 @@ def softmax_py_fn(x): sm = e_x / w return sm - softmax = jit_compile_reducer(node, softmax_py_fn) + softmax = numba_njit(softmax_py_fn, boundscheck=False) return softmax @@ -608,8 +565,7 @@ def softmax_grad_py_fn(dy, sm): dx = dy_times_sm - sum_dy_times_sm * sm return dx - # The signature inferred by jit_compile_reducer is wrong when dy is a constant (readonly=True) - softmax_grad = jit_compile_reducer(node, softmax_grad_py_fn, infer_signature=False) + softmax_grad = numba_njit(softmax_grad_py_fn, boundscheck=False) return softmax_grad @@ -647,7 +603,7 @@ def log_softmax_py_fn(x): lsm = xdev - np.log(reduce_sum(np.exp(xdev))) return lsm - log_softmax = jit_compile_reducer(node, log_softmax_py_fn) + log_softmax = numba_njit(log_softmax_py_fn, boundscheck=False) return log_softmax From c744636db602967cd0acc0def6fdb4c0b8bddac8 Mon Sep 17 00:00:00 2001 From: Ricardo Vieira Date: Mon, 3 Feb 2025 14:52:42 +0100 Subject: [PATCH 2/2] Group numba benchmark tests in same class --- tests/link/numba/test_elemwise.py | 138 +++++++++++++++--------------- 1 file changed, 68 insertions(+), 70 deletions(-) diff --git a/tests/link/numba/test_elemwise.py b/tests/link/numba/test_elemwise.py index 1da34ff392..b2ccc1ef1e 100644 --- a/tests/link/numba/test_elemwise.py +++ b/tests/link/numba/test_elemwise.py @@ -130,25 +130,6 @@ def test_elemwise_runtime_broadcast(): check_elemwise_runtime_broadcast(get_mode("NUMBA")) -def test_elemwise_speed(benchmark): - x = pt.dmatrix("y") - y = pt.dvector("z") - - out = np.exp(2 * x * y + y) - - rng = np.random.default_rng(42) - - x_val = rng.normal(size=(200, 500)) - y_val = rng.normal(size=500) - - func = function([x, y], out, mode="NUMBA") - func = func.vm.jit_fn - (out,) = func(x_val, y_val) - np.testing.assert_allclose(np.exp(2 * x_val * y_val + y_val), out) - - benchmark(func, x_val, y_val) - - @pytest.mark.parametrize( "v, new_order", [ @@ -631,41 +612,6 @@ def test_Argmax(x, axes, exc): ) -@pytest.mark.parametrize("size", [(10, 10), (1000, 1000), (10000, 10000)]) -@pytest.mark.parametrize("axis", [0, 1]) -def test_logsumexp_benchmark(size, axis, benchmark): - X = pt.matrix("X") - X_max = pt.max(X, axis=axis, keepdims=True) - X_max = pt.switch(pt.isinf(X_max), 0, X_max) - X_lse = pt.log(pt.sum(pt.exp(X - X_max), axis=axis, keepdims=True)) + X_max - - rng = np.random.default_rng(23920) - X_val = rng.normal(size=size) - - X_lse_fn = pytensor.function([X], X_lse, mode="NUMBA") - - # JIT compile first - res = X_lse_fn(X_val) - exp_res = scipy.special.logsumexp(X_val, axis=axis, keepdims=True) - np.testing.assert_array_almost_equal(res, exp_res) - benchmark(X_lse_fn, X_val) - - -def test_fused_elemwise_benchmark(benchmark): - rng = np.random.default_rng(123) - size = 100_000 - x = pytensor.shared(rng.normal(size=size), name="x") - mu = pytensor.shared(rng.normal(size=size), name="mu") - - logp = -((x - mu) ** 2) / 2 - grad_logp = grad(logp.sum(), x) - - func = pytensor.function([], [logp, grad_logp], mode="NUMBA") - # JIT compile first - func() - benchmark(func) - - def 
test_elemwise_out_type(): # Create a graph with an elemwise # Ravel failes if the elemwise output type is reported incorrectly @@ -681,22 +627,6 @@ def test_elemwise_out_type(): assert func(x_val).shape == (18,) -@pytest.mark.parametrize( - "axis", - (0, 1, 2, (0, 1), (0, 2), (1, 2), None), - ids=lambda x: f"axis={x}", -) -@pytest.mark.parametrize( - "c_contiguous", - (True, False), - ids=lambda x: f"c_contiguous={x}", -) -def test_numba_careduce_benchmark(axis, c_contiguous, benchmark): - return careduce_benchmark_tester( - axis, c_contiguous, mode="NUMBA", benchmark=benchmark - ) - - def test_scalar_loop(): a = float64("a") scalar_loop = pytensor.scalar.ScalarLoop([a], [a + a]) @@ -709,3 +639,71 @@ def test_scalar_loop(): ([x], [elemwise_loop]), (np.array([1, 2, 3], dtype="float64"),), ) + + +class TestsBenchmark: + def test_elemwise_speed(self, benchmark): + x = pt.dmatrix("y") + y = pt.dvector("z") + + out = np.exp(2 * x * y + y) + + rng = np.random.default_rng(42) + + x_val = rng.normal(size=(200, 500)) + y_val = rng.normal(size=500) + + func = function([x, y], out, mode="NUMBA") + func = func.vm.jit_fn + (out,) = func(x_val, y_val) + np.testing.assert_allclose(np.exp(2 * x_val * y_val + y_val), out) + + benchmark(func, x_val, y_val) + + def test_fused_elemwise_benchmark(self, benchmark): + rng = np.random.default_rng(123) + size = 100_000 + x = pytensor.shared(rng.normal(size=size), name="x") + mu = pytensor.shared(rng.normal(size=size), name="mu") + + logp = -((x - mu) ** 2) / 2 + grad_logp = grad(logp.sum(), x) + + func = pytensor.function([], [logp, grad_logp], mode="NUMBA") + # JIT compile first + func() + benchmark(func) + + @pytest.mark.parametrize("size", [(10, 10), (1000, 1000), (10000, 10000)]) + @pytest.mark.parametrize("axis", [0, 1]) + def test_logsumexp_benchmark(self, size, axis, benchmark): + X = pt.matrix("X") + X_max = pt.max(X, axis=axis, keepdims=True) + X_max = pt.switch(pt.isinf(X_max), 0, X_max) + X_lse = pt.log(pt.sum(pt.exp(X - X_max), axis=axis, keepdims=True)) + X_max + + rng = np.random.default_rng(23920) + X_val = rng.normal(size=size) + + X_lse_fn = pytensor.function([X], X_lse, mode="NUMBA") + + # JIT compile first + res = X_lse_fn(X_val) + exp_res = scipy.special.logsumexp(X_val, axis=axis, keepdims=True) + np.testing.assert_array_almost_equal(res, exp_res) + benchmark(X_lse_fn, X_val) + + @pytest.mark.parametrize( + "axis", + (0, 1, 2, (0, 1), (0, 2), (1, 2), None), + ids=lambda x: f"axis={x}", + ) + @pytest.mark.parametrize( + "c_contiguous", + (True, False), + ids=lambda x: f"c_contiguous={x}", + ) + def test_numba_careduce_benchmark(self, axis, c_contiguous, benchmark): + return careduce_benchmark_tester( + axis, c_contiguous, mode="NUMBA", benchmark=benchmark + )
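
For reference, a minimal sketch of the compilation pattern the first commit standardizes on. It uses plain numba.njit in place of PyTensor's numba_njit wrapper, and the row_sum kernel below is purely illustrative (it is not taken from the codebase): reduction loops are now compiled directly with boundscheck=False, with no swapping of numba's internal _mpm_cheap module pass manager.

# Illustrative sketch only, not part of the patch. Plain numba.njit stands in
# for PyTensor's numba_njit wrapper; row_sum is a made-up stand-in for the
# generated CAReduce reduction loops.
import numba
import numpy as np


@numba.njit(boundscheck=False)
def row_sum(x):
    # Nested-loop reduction, loosely analogous to the generated reduction kernels.
    n_rows, n_cols = x.shape
    out = np.zeros(n_rows, dtype=np.float64)
    for i in range(n_rows):
        acc = 0.0
        for j in range(n_cols):
            acc += x[i, j]
        out[i] = acc
    return out


x_val = np.random.default_rng(0).normal(size=(4, 5))
np.testing.assert_allclose(row_sum(x_val), x_val.sum(axis=1))

The removed use_optimized_cheap_pass context manager temporarily installed an opt=3 module pass manager with loop and SLP vectorization enabled; since numba 0.61 no longer exposes _mpm_cheap and the existing benchmarks show no difference in performance, the direct njit call with boundscheck=False is sufficient.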