
Commit 7e68db8

Use numba code for supported CAReduce cases

Parent: 7147f7f
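
CAReduce is PyTensor's family of reduction Ops (Sum, Prod, Max, Min, All, Any, ...). Before this commit only Sum had a specialized Numba path; the other reductions went through the generic CAReduce lowering. This commit dispatches the supported cases to Numba's implementations of the corresponding NumPy reductions, falling back to the generic path only for axis combinations Numba cannot handle. A minimal sketch of a graph that now takes the fast path (assumes Numba is installed; "NUMBA" is PyTensor's Numba compilation mode):

    import numpy as np
    import pytensor
    import pytensor.tensor as pt

    x = pt.matrix("x", dtype="float64")
    # x.sum(axis=0) builds a Sum CAReduce node; under mode="NUMBA" it is now
    # lowered to Numba's own np.sum instead of a generic reduction loop.
    f = pytensor.function([x], x.sum(axis=0), mode="NUMBA")
    print(f(np.arange(6.0).reshape(3, 2)))  # [6. 9.]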

File tree

2 files changed: +53 -165 lines

  pytensor/link/numba/dispatch/elemwise.py
  tests/link/numba/test_elemwise.py

pytensor/link/numba/dispatch/elemwise.py

Lines changed: 38 additions & 14 deletions
@@ -44,7 +44,16 @@
 )
 from pytensor.scalar.basic import add as add_as
 from pytensor.tensor.elemwise import CAReduce, DimShuffle, Elemwise
-from pytensor.tensor.math import Argmax, MulWithoutZeros, Sum
+from pytensor.tensor.math import (
+    All,
+    Argmax,
+    Max,
+    Min,
+    MulWithoutZeros,
+    Prod,
+    ProdWithoutZeros,
+    Sum,
+)
 from pytensor.tensor.special import LogSoftmax, Softmax, SoftmaxGrad
 from pytensor.tensor.type import scalar

@@ -546,37 +555,52 @@ def ov_elemwise(*inputs):


 @numba_funcify.register(Sum)
-def numba_funcify_Sum(op, node, **kwargs):
+@numba_funcify.register(Prod)
+@numba_funcify.register(ProdWithoutZeros)
+@numba_funcify.register(Max)
+@numba_funcify.register(Min)
+@numba_funcify.register(All)
+@numba_funcify.register(Any)
+def numba_funcify_CAReduce_specialized(op, node, **kwargs):
+    if isinstance(op, ProdWithoutZeros):
+        # ProdWithoutZeros is the same as Prod, but the gradient can assume no zeros
+        np_op = np.prod
+    else:
+        np_op = getattr(np, op.__class__.__name__.lower())
+
     axes = op.axis
     if axes is None:
         axes = list(range(node.inputs[0].ndim))

-    axes = tuple(axes)
+    axes = tuple(sorted(axes))

     ndim_input = node.inputs[0].ndim
+    out_dtype = np.dtype(node.outputs[0].dtype)

-    if hasattr(op, "acc_dtype") and op.acc_dtype is not None:
-        acc_dtype = op.acc_dtype
-    else:
-        acc_dtype = node.outputs[0].type.dtype
-
-    np_acc_dtype = np.dtype(acc_dtype)
+    if len(axes) == 0:

-    out_dtype = np.dtype(node.outputs[0].dtype)
+        @numba_njit(fastmath=True)
+        def impl_sum(array):
+            return np.asarray(array, dtype=out_dtype)

-    if ndim_input == len(axes):
+    elif (
+        len(axes) == 1
+        # Some Ops don't support the axis argument in Numba
+        and not isinstance(op, Prod | ProdWithoutZeros | All | Any | Mean | Max | Min)
+    ):

         @numba_njit(fastmath=True)
         def impl_sum(array):
-            return np.asarray(array.sum(), dtype=np_acc_dtype).astype(out_dtype)
+            return np.asarray(np_op(array, axis=axes[0])).astype(out_dtype)

-    elif len(axes) == 0:
+    elif len(axes) == ndim_input:

         @numba_njit(fastmath=True)
         def impl_sum(array):
-            return np.asarray(array, dtype=out_dtype)
+            return np.asarray(np_op(array)).astype(out_dtype)

     else:
+        # Slow path
         impl_sum = numba_funcify_CAReduce(op, node, **kwargs)

     return impl_sum
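
The `len(axes) == 1` branch is gated on the Op type because Numba implements the `axis` argument only for some NumPy reductions (np.sum among them), not for np.prod and several others. A small illustration of that constraint, as a sketch (exact support depends on the installed Numba version):

    import numpy as np
    from numba import njit
    from numba.core.errors import TypingError

    @njit
    def sum_axis0(a):
        return np.sum(a, axis=0)  # Numba types axis= for np.sum

    @njit
    def prod_axis0(a):
        return np.prod(a, axis=0)  # Numba does not type axis= for np.prod

    a = np.ones((3, 2))
    print(sum_axis0(a))  # [3. 3.]
    try:
        prod_axis0(a)
    except TypingError:
        print("np.prod with axis= does not compile, hence the isinstance() gate")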

tests/link/numba/test_elemwise.py

Lines changed: 15 additions & 151 deletions
@@ -236,157 +236,21 @@ def test_Dimshuffle_non_contiguous():
     assert func(np.zeros(3), np.array([1])).ndim == 0


-@pytest.mark.parametrize(
-    "careduce_fn, axis, v",
-    [
-        (
-            lambda x, axis=None, dtype=None, acc_dtype=None: Sum(
-                axis=axis, dtype=dtype, acc_dtype=acc_dtype
-            )(x),
-            0,
-            set_test_value(pt.vector(), np.arange(3, dtype=config.floatX)),
-        ),
-        (
-            lambda x, axis=None, dtype=None, acc_dtype=None: All(axis)(x),
-            0,
-            set_test_value(pt.vector(), np.arange(3, dtype=config.floatX)),
-        ),
-        (
-            lambda x, axis=None, dtype=None, acc_dtype=None: Any(axis)(x),
-            0,
-            set_test_value(pt.vector(), np.arange(3, dtype=config.floatX)),
-        ),
-        (
-            lambda x, axis=None, dtype=None, acc_dtype=None: Mean(axis)(x),
-            0,
-            set_test_value(pt.vector(), np.arange(3, dtype=config.floatX)),
-        ),
-        (
-            lambda x, axis=None, dtype=None, acc_dtype=None: Mean(axis)(x),
-            0,
-            set_test_value(
-                pt.matrix(), np.arange(3 * 2, dtype=config.floatX).reshape((3, 2))
-            ),
-        ),
-        (
-            lambda x, axis=None, dtype=None, acc_dtype=None: Sum(
-                axis=axis, dtype=dtype, acc_dtype=acc_dtype
-            )(x),
-            0,
-            set_test_value(
-                pt.matrix(), np.arange(3 * 2, dtype=config.floatX).reshape((3, 2))
-            ),
-        ),
-        (
-            lambda x, axis=None, dtype=None, acc_dtype=None: Sum(
-                axis=axis, dtype=dtype, acc_dtype=acc_dtype
-            )(x),
-            (0, 1),
-            set_test_value(
-                pt.matrix(), np.arange(3 * 2, dtype=config.floatX).reshape((3, 2))
-            ),
-        ),
-        (
-            lambda x, axis=None, dtype=None, acc_dtype=None: Sum(
-                axis=axis, dtype=dtype, acc_dtype=acc_dtype
-            )(x),
-            (1, 0),
-            set_test_value(
-                pt.matrix(), np.arange(3 * 2, dtype=config.floatX).reshape((3, 2))
-            ),
-        ),
-        (
-            lambda x, axis=None, dtype=None, acc_dtype=None: Sum(
-                axis=axis, dtype=dtype, acc_dtype=acc_dtype
-            )(x),
-            None,
-            set_test_value(
-                pt.matrix(), np.arange(3 * 2, dtype=config.floatX).reshape((3, 2))
-            ),
-        ),
-        (
-            lambda x, axis=None, dtype=None, acc_dtype=None: Sum(
-                axis=axis, dtype=dtype, acc_dtype=acc_dtype
-            )(x),
-            1,
-            set_test_value(
-                pt.matrix(), np.arange(3 * 2, dtype=config.floatX).reshape((3, 2))
-            ),
-        ),
-        (
-            lambda x, axis=None, dtype=None, acc_dtype=None: Prod(
-                axis=axis, dtype=dtype, acc_dtype=acc_dtype
-            )(x),
-            0,
-            set_test_value(pt.vector(), np.arange(3, dtype=config.floatX)),
-        ),
-        (
-            lambda x, axis=None, dtype=None, acc_dtype=None: ProdWithoutZeros(
-                axis=axis, dtype=dtype, acc_dtype=acc_dtype
-            )(x),
-            0,
-            set_test_value(pt.vector(), np.arange(3, dtype=config.floatX)),
-        ),
-        (
-            lambda x, axis=None, dtype=None, acc_dtype=None: Prod(
-                axis=axis, dtype=dtype, acc_dtype=acc_dtype
-            )(x),
-            0,
-            set_test_value(
-                pt.matrix(), np.arange(3 * 2, dtype=config.floatX).reshape((3, 2))
-            ),
-        ),
-        (
-            lambda x, axis=None, dtype=None, acc_dtype=None: Prod(
-                axis=axis, dtype=dtype, acc_dtype=acc_dtype
-            )(x),
-            1,
-            set_test_value(
-                pt.matrix(), np.arange(3 * 2, dtype=config.floatX).reshape((3, 2))
-            ),
-        ),
-        (
-            lambda x, axis=None, dtype=None, acc_dtype=None: Max(axis)(x),
-            None,
-            set_test_value(
-                pt.matrix(), np.arange(3 * 2, dtype=config.floatX).reshape((3, 2))
-            ),
-        ),
-        (
-            lambda x, axis=None, dtype=None, acc_dtype=None: Max(axis)(x),
-            None,
-            set_test_value(
-                pt.lmatrix(), np.arange(3 * 2, dtype=np.int64).reshape((3, 2))
-            ),
-        ),
-        (
-            lambda x, axis=None, dtype=None, acc_dtype=None: Min(axis)(x),
-            None,
-            set_test_value(
-                pt.matrix(), np.arange(3 * 2, dtype=config.floatX).reshape((3, 2))
-            ),
-        ),
-        (
-            lambda x, axis=None, dtype=None, acc_dtype=None: Min(axis)(x),
-            None,
-            set_test_value(
-                pt.lmatrix(), np.arange(3 * 2, dtype=np.int64).reshape((3, 2))
-            ),
-        ),
-    ],
-)
-def test_CAReduce(careduce_fn, axis, v):
-    g = careduce_fn(v, axis=axis)
-    g_fg = FunctionGraph(outputs=[g])
-
-    compare_numba_and_py(
-        g_fg,
-        [
-            i.tag.test_value
-            for i in g_fg.inputs
-            if not isinstance(i, SharedVariable | Constant)
-        ],
-    )
+@pytest.mark.parametrize("axis", [0, None, (0, 1)])
+@pytest.mark.parametrize("op", [Sum, Prod, ProdWithoutZeros, All, Any, Mean, Max, Min])
+def test_CAReduce(op, axis):
+    if op == Mean and isinstance(axis, tuple) and len(axis) > 1:
+        pytest.xfail("Mean does not support multiple partial axes")
+
+    bool_reduction = op in (All, Any)
+    x = pt.tensor3("x", dtype=bool if bool_reduction else config.floatX)
+    g = op(axis=axis)(x)
+    g_fg = FunctionGraph([x], [g])
+
+    x_test = np.random.normal(size=(2, 3, 4)).astype(config.floatX)
+    if bool_reduction:
+        x_test = x_test > 0
+    compare_numba_and_py(g_fg, [x_test])


 def test_scalar_Elemwise_Clip():
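
The rewritten test replaces the hand-enumerated cases with a full sweep of every specialized Op over three representative axis patterns. For context, a standalone check in the same spirit as one parametrized case, written as a sketch against the public PyTensor API (Numba must be installed; "NUMBA" and "FAST_COMPILE" are PyTensor mode names):

    import numpy as np
    import pytensor
    import pytensor.tensor as pt

    x = pt.tensor3("x", dtype="float64")
    g = pt.max(x, axis=None)  # one (op, axis) combination from the sweep

    numba_fn = pytensor.function([x], g, mode="NUMBA")
    ref_fn = pytensor.function([x], g, mode="FAST_COMPILE")  # reference backend

    x_test = np.random.normal(size=(2, 3, 4))
    np.testing.assert_allclose(numba_fn(x_test), ref_fn(x_test))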
