Skip to content

Commit 4efbd19

Browse files
committed
Refactor SoftmaxGrad numba patch
1 parent da66c2e commit 4efbd19

File tree

1 file changed

+14
-8
lines changed

1 file changed

+14
-8
lines changed

pytensor/link/numba/dispatch/elemwise.py

Lines changed: 14 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -402,7 +402,9 @@ def {careduce_fn_name}({input_name}):
402402
return careduce_fn
403403

404404

405-
def jit_compile_reducer(node, fn, *, reduce_to_scalar=False, **kwds):
405+
def jit_compile_reducer(
406+
node, fn, *, reduce_to_scalar=False, infer_signature=True, **kwds
407+
):
406408
"""Compile Python source for reduction loops using additional optimizations.
407409
408410
Parameters
@@ -411,6 +413,10 @@ def jit_compile_reducer(node, fn, *, reduce_to_scalar=False, **kwds):
411413
An node from which the signature can be derived.
412414
fn
413415
The Python function object to compile.
416+
reduce_to_scalar: bool, default False
417+
Whether to reduce output to a scalar (instead of 0d array)
418+
infer_signature: bool: default True
419+
Whether to try and infer the function signature from the Apply node.
414420
kwds
415421
Extra keywords to be added to the :func:`numba.njit` function.
416422
@@ -419,13 +425,17 @@ def jit_compile_reducer(node, fn, *, reduce_to_scalar=False, **kwds):
419425
A :func:`numba.njit`-compiled function.
420426
421427
"""
422-
signature = create_numba_signature(node, reduce_to_scalar=reduce_to_scalar)
428+
if infer_signature:
429+
signature = create_numba_signature(node, reduce_to_scalar=reduce_to_scalar)
430+
args = (signature,)
431+
else:
432+
args = ()
423433

424434
# Eagerly compile the function using increased optimizations. This should
425435
# help improve nested loop reductions.
426436
with use_optimized_cheap_pass():
427437
res = numba_basic.numba_njit(
428-
signature,
438+
*args,
429439
boundscheck=False,
430440
fastmath=config.numba__fastmath,
431441
**kwds,
@@ -926,11 +936,7 @@ def softmax_grad_py_fn(dy, sm):
926936
return dx
927937

928938
# The signature inferred by jit_compile_reducer is wrong when dy is a constant (readonly=True)
929-
# softmax_grad = jit_compile_reducer(node, softmax_grad_py_fn)
930-
softmax_grad = numba_njit(
931-
boundscheck=False,
932-
fastmath=config.numba__fastmath,
933-
)(softmax_grad_py_fn)
939+
softmax_grad = jit_compile_reducer(node, softmax_grad_py_fn, infer_signature=False)
934940

935941
return softmax_grad
936942

0 commit comments

Comments
 (0)