
Commit 4bb15bf

Avoid manipulation of deprecated _mpm_cheap
Internal API changed in numba 0.61. Existing benchmarks don't show any difference in performance.
1 parent 911c6a3 commit 4bb15bf
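The substance of the change: the reduction helpers in elemwise.py were compiled eagerly, against a signature inferred from the Apply node, inside a context manager that swapped numba's internal _mpm_cheap pass manager for a more aggressively optimizing one. They are now compiled with a plain lazy numba_njit call. A minimal sketch of what the new path amounts to, using numba.njit directly (the pytensor wrapper numba_njit is assumed here to forward its arguments to it):

import numba
import numpy as np

def careduce_py_fn(x):
    # Stand-in for one of the generated reduction loops.
    acc = 0.0
    for i in range(x.shape[0]):
        acc += x[i]
    return acc

# Lazy compilation: numba specializes on the argument types at the first
# call, so no signature has to be inferred up front.
careduce_fn = numba.njit(boundscheck=False)(careduce_py_fn)
print(careduce_fn(np.arange(5.0)))  # 10.0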

2 files changed: 4 additions & 66 deletions

pytensor/link/numba/dispatch/basic.py

Lines changed: 0 additions & 18 deletions
@@ -1,7 +1,6 @@
 import operator
 import sys
 import warnings
-from contextlib import contextmanager
 from copy import copy
 from functools import singledispatch
 from textwrap import dedent
@@ -362,23 +361,6 @@ def create_arg_string(x):
     return args


-@contextmanager
-def use_optimized_cheap_pass(*args, **kwargs):
-    """Temporarily replace the cheap optimization pass with a better one."""
-    from numba.core.registry import cpu_target
-
-    context = cpu_target.target_context._internal_codegen
-    old_pm = context._mpm_cheap
-    new_pm = context._module_pass_manager(
-        loop_vectorize=True, slp_vectorize=True, opt=3, cost="cheap"
-    )
-    context._mpm_cheap = new_pm
-    try:
-        yield
-    finally:
-        context._mpm_cheap = old_pm
-
-
 @singledispatch
 def numba_typify(data, dtype=None, **kwargs):
     return data
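Note that the removed context manager leaned on three private attributes (_internal_codegen, _mpm_cheap, _module_pass_manager) that numba never documented as stable; numba 0.61 reorganized them and the code stopped working. A hypothetical defensive variant, shown only to illustrate the fragility (the commit instead drops the machinery outright, since benchmarks showed no measurable benefit):

from contextlib import contextmanager

@contextmanager
def maybe_optimized_cheap_pass():
    # Hypothetical helper: probe for the pre-0.61 internals before touching
    # them, and degrade to a no-op when they are gone.
    from numba.core.registry import cpu_target

    codegen = getattr(cpu_target.target_context, "_internal_codegen", None)
    if codegen is None or not hasattr(codegen, "_mpm_cheap"):
        yield  # internals moved (numba >= 0.61): silently do nothing
        return
    old_pm = codegen._mpm_cheap
    codegen._mpm_cheap = codegen._module_pass_manager(
        loop_vectorize=True, slp_vectorize=True, opt=3, cost="cheap"
    )
    try:
        yield
    finally:
        codegen._mpm_cheap = old_pm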

pytensor/link/numba/dispatch/elemwise.py

Lines changed: 4 additions & 48 deletions
@@ -9,10 +9,8 @@
 from pytensor.graph.op import Op
 from pytensor.link.numba.dispatch import basic as numba_basic
 from pytensor.link.numba.dispatch.basic import (
-    create_numba_signature,
     numba_funcify,
     numba_njit,
-    use_optimized_cheap_pass,
 )
 from pytensor.link.numba.dispatch.vectorize_codegen import (
     _jit_options,
@@ -245,47 +243,6 @@ def {careduce_fn_name}(x):
     return careduce_fn


-def jit_compile_reducer(
-    node, fn, *, reduce_to_scalar=False, infer_signature=True, **kwds
-):
-    """Compile Python source for reduction loops using additional optimizations.
-
-    Parameters
-    ==========
-    node
-        An node from which the signature can be derived.
-    fn
-        The Python function object to compile.
-    reduce_to_scalar: bool, default False
-        Whether to reduce output to a scalar (instead of 0d array)
-    infer_signature: bool: default True
-        Whether to try and infer the function signature from the Apply node.
-    kwds
-        Extra keywords to be added to the :func:`numba.njit` function.
-
-    Returns
-    =======
-    A :func:`numba.njit`-compiled function.
-
-    """
-    if infer_signature:
-        signature = create_numba_signature(node, reduce_to_scalar=reduce_to_scalar)
-        args = (signature,)
-    else:
-        args = ()
-
-    # Eagerly compile the function using increased optimizations. This should
-    # help improve nested loop reductions.
-    with use_optimized_cheap_pass():
-        res = numba_basic.numba_njit(
-            *args,
-            boundscheck=False,
-            **kwds,
-        )(fn)
-
-    return res
-
-
 def create_axis_apply_fn(fn, axis, ndim, dtype):
     axis = normalize_axis_index(axis, ndim)

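For reference, what infer_signature bought in the deleted helper: handing numba.njit an explicit signature triggers eager compilation at decoration time, whereas the bare calls used at the new call sites compile lazily on first use. A small illustration with plain numba (the signature string is only an example):

import numba

@numba.njit("float64(float64[:])", boundscheck=False)  # eager: compiled at decoration time
def total(x):
    s = 0.0
    for v in x:
        s += v
    return s

print(total.signatures)  # already populated before the first call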

@@ -448,7 +405,7 @@ def numba_funcify_CAReduce(op, node, **kwargs):
         np.dtype(node.outputs[0].type.dtype),
     )

-    careduce_fn = jit_compile_reducer(node, careduce_py_fn, reduce_to_scalar=False)
+    careduce_fn = numba_njit(careduce_py_fn, boundscheck=False)
     return careduce_fn


@@ -579,7 +536,7 @@ def softmax_py_fn(x):
         sm = e_x / w
         return sm

-    softmax = jit_compile_reducer(node, softmax_py_fn)
+    softmax = numba_njit(softmax_py_fn, boundscheck=False)

     return softmax

@@ -608,8 +565,7 @@ def softmax_grad_py_fn(dy, sm):
         dx = dy_times_sm - sum_dy_times_sm * sm
         return dx

-    # The signature inferred by jit_compile_reducer is wrong when dy is a constant (readonly=True)
-    softmax_grad = jit_compile_reducer(node, softmax_grad_py_fn, infer_signature=False)
+    softmax_grad = numba_njit(softmax_grad_py_fn, boundscheck=False)

     return softmax_grad

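The deleted comment above hints at why the inferred signature could be wrong: numba types read-only arrays distinctly from writable ones, so a signature inferred for a writable dy does not match a constant (readonly=True) input. Lazy dispatch sidesteps the problem by specializing on the actual argument. A quick way to see the distinction with plain numba:

import numba
import numpy as np

@numba.njit(boundscheck=False)
def double(dy):
    return dy * 2.0  # reading a readonly input is fine; the output is a new array

dy = np.ones(3)
dy.flags.writeable = False  # mimic a read-only constant input
double(dy)
print(double.signatures)  # e.g. [(readonly array(float64, 1d, C),)]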

@@ -647,7 +603,7 @@ def log_softmax_py_fn(x):
         lsm = xdev - np.log(reduce_sum(np.exp(xdev)))
         return lsm

-    log_softmax = jit_compile_reducer(node, log_softmax_py_fn)
+    log_softmax = numba_njit(log_softmax_py_fn, boundscheck=False)
     return log_softmax

