@@ -9,10 +9,8 @@
 from pytensor.graph.op import Op
 from pytensor.link.numba.dispatch import basic as numba_basic
 from pytensor.link.numba.dispatch.basic import (
-    create_numba_signature,
     numba_funcify,
     numba_njit,
-    use_optimized_cheap_pass,
 )
 from pytensor.link.numba.dispatch.vectorize_codegen import (
     _jit_options,
@@ -245,47 +243,6 @@ def {careduce_fn_name}(x):
     return careduce_fn


-def jit_compile_reducer(
-    node, fn, *, reduce_to_scalar=False, infer_signature=True, **kwds
-):
-    """Compile Python source for reduction loops using additional optimizations.
-
-    Parameters
-    ==========
-    node
-        An node from which the signature can be derived.
-    fn
-        The Python function object to compile.
-    reduce_to_scalar: bool, default False
-        Whether to reduce output to a scalar (instead of 0d array)
-    infer_signature: bool: default True
-        Whether to try and infer the function signature from the Apply node.
-    kwds
-        Extra keywords to be added to the :func:`numba.njit` function.
-
-    Returns
-    =======
-    A :func:`numba.njit`-compiled function.
-
-    """
-    if infer_signature:
-        signature = create_numba_signature(node, reduce_to_scalar=reduce_to_scalar)
-        args = (signature,)
-    else:
-        args = ()
-
-    # Eagerly compile the function using increased optimizations. This should
-    # help improve nested loop reductions.
-    with use_optimized_cheap_pass():
-        res = numba_basic.numba_njit(
-            *args,
-            boundscheck=False,
-            **kwds,
-        )(fn)
-
-    return res
-
-
 def create_axis_apply_fn(fn, axis, ndim, dtype):
     axis = normalize_axis_index(axis, ndim)
 
@@ -448,7 +405,7 @@ def numba_funcify_CAReduce(op, node, **kwargs):
         np.dtype(node.outputs[0].type.dtype),
     )
 
-    careduce_fn = jit_compile_reducer(node, careduce_py_fn, reduce_to_scalar=False)
+    careduce_fn = numba_njit(careduce_py_fn, boundscheck=False)
     return careduce_fn
 
 
@@ -579,7 +536,7 @@ def softmax_py_fn(x):
         sm = e_x / w
         return sm
 
-    softmax = jit_compile_reducer(node, softmax_py_fn)
+    softmax = numba_njit(softmax_py_fn, boundscheck=False)
 
     return softmax
 
@@ -608,8 +565,7 @@ def softmax_grad_py_fn(dy, sm):
         dx = dy_times_sm - sum_dy_times_sm * sm
         return dx
 
-    # The signature inferred by jit_compile_reducer is wrong when dy is a constant (readonly=True)
-    softmax_grad = jit_compile_reducer(node, softmax_grad_py_fn, infer_signature=False)
+    softmax_grad = numba_njit(softmax_grad_py_fn, boundscheck=False)
 
     return softmax_grad
 
@@ -647,7 +603,7 @@ def log_softmax_py_fn(x):
         lsm = xdev - np.log(reduce_sum(np.exp(xdev)))
         return lsm
 
-    log_softmax = jit_compile_reducer(node, log_softmax_py_fn)
+    log_softmax = numba_njit(log_softmax_py_fn, boundscheck=False)
     return log_softmax
 
 
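For context on what this change does: the removed `jit_compile_reducer` built an explicit numba signature from the Apply node and compiled eagerly under `use_optimized_cheap_pass`, while the replacement simply wraps the generated function with `numba_njit(..., boundscheck=False)` and lets numba infer types on first call. A minimal sketch of the two styles using numba directly (the reduction loop `sum_py_fn` and its signature are illustrative stand-ins, not pytensor code):

import numpy as np
from numba import njit, float64


def sum_py_fn(x):
    # Illustrative stand-in for the generated reduction loops.
    acc = 0.0
    for i in range(x.shape[0]):
        acc += x[i]
    return acc


# Old style (removed in this commit): compile eagerly against an explicit
# signature derived ahead of time.
eager_sum = njit(float64(float64[::1]), boundscheck=False)(sum_py_fn)

# New style: pass no signature; numba infers the types lazily on first call.
lazy_sum = njit(boundscheck=False)(sum_py_fn)

x = np.arange(5.0)
assert eager_sum(x) == lazy_sum(x) == 10.0

Lazy type inference also sidesteps the issue flagged in the removed comment above: a signature derived from the node could be wrong for constant (readonly) inputs, whereas inference at call time sees the actual array that is passed in.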
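Separately, the `log_softmax_py_fn` context above relies on the standard max-subtraction trick for numerical stability. A small self-contained illustration of why the shift matters (all names here are illustrative):

import numpy as np


def log_softmax(x):
    # Subtracting the max before exponentiating keeps np.exp in range;
    # the result is mathematically identical because the shift cancels.
    xdev = x - x.max()
    return xdev - np.log(np.exp(xdev).sum())


x = np.array([1000.0, 1001.0, 1002.0])
print(log_softmax(x))                        # finite: ~[-2.41, -1.41, -0.41]
print(np.log(np.exp(x) / np.exp(x).sum()))   # naive form overflows to nan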