@@ -402,7 +402,9 @@ def {careduce_fn_name}({input_name}):
402
402
return careduce_fn
403
403
404
404
405
- def jit_compile_reducer (node , fn , * , reduce_to_scalar = False , ** kwds ):
405
+ def jit_compile_reducer (
406
+ node , fn , * , reduce_to_scalar = False , infer_signature = True , ** kwds
407
+ ):
406
408
"""Compile Python source for reduction loops using additional optimizations.
407
409
408
410
Parameters
@@ -411,6 +413,10 @@ def jit_compile_reducer(node, fn, *, reduce_to_scalar=False, **kwds):
411
413
A node from which the signature can be derived.
412
414
fn
413
415
The Python function object to compile.
416
+ reduce_to_scalar: bool, default False
417
+ Whether to reduce output to a scalar (instead of 0d array)
418
+ infer_signature: bool, default True
419
+ Whether to try to infer the function signature from the Apply node.
414
420
kwds
415
421
Extra keywords to be added to the :func:`numba.njit` function.
416
422
@@ -419,13 +425,17 @@ def jit_compile_reducer(node, fn, *, reduce_to_scalar=False, **kwds):
419
425
A :func:`numba.njit`-compiled function.
420
426
421
427
"""
422
- signature = create_numba_signature (node , reduce_to_scalar = reduce_to_scalar )
428
+ if infer_signature :
429
+ signature = create_numba_signature (node , reduce_to_scalar = reduce_to_scalar )
430
+ args = (signature ,)
431
+ else :
432
+ args = ()
423
433
424
434
# Eagerly compile the function using increased optimizations. This should
425
435
# help improve nested loop reductions.
426
436
with use_optimized_cheap_pass ():
427
437
res = numba_basic .numba_njit (
428
- signature ,
438
+ * args ,
429
439
boundscheck = False ,
430
440
fastmath = config .numba__fastmath ,
431
441
** kwds ,
@@ -926,11 +936,7 @@ def softmax_grad_py_fn(dy, sm):
926
936
return dx
927
937
928
938
# The signature inferred by jit_compile_reducer is wrong when dy is a constant (readonly=True)
929
- # softmax_grad = jit_compile_reducer(node, softmax_grad_py_fn)
930
- softmax_grad = numba_njit (
931
- boundscheck = False ,
932
- fastmath = config .numba__fastmath ,
933
- )(softmax_grad_py_fn )
939
+ softmax_grad = jit_compile_reducer (node , softmax_grad_py_fn , infer_signature = False )
934
940
935
941
return softmax_grad
936
942
0 commit comments