Skip to content

Commit 9fdaeca

Browse files
committed
Specialized numba sum impl
1 parent b20e039 commit 9fdaeca

File tree

2 files changed

+74
-43
lines changed

2 files changed

+74
-43
lines changed

pytensor/link/numba/dispatch/elemwise.py

Lines changed: 60 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -1,32 +1,29 @@
1+
import base64
2+
import pickle
13
from functools import singledispatch
24
from numbers import Number
3-
import pickle
45
from textwrap import indent
5-
from typing import Any, Callable, Literal, Optional, Union
6-
import base64
6+
from typing import Any, Callable, Optional, Union
77

88
import numba
99
import numpy as np
10-
from llvmlite import ir
11-
from numba import TypingError, literal_unroll, types, literally
10+
from numba import TypingError, types
1211
from numba.core import cgutils
13-
from numba.cpython.unsafe.tuple import tuple_setitem
1412
from numba.np import arrayobj
1513
from numpy.core.numeric import normalize_axis_index, normalize_axis_tuple
1614

1715
from pytensor import config
1816
from pytensor.graph.basic import Apply
1917
from pytensor.graph.op import Op
2018
from pytensor.link.numba.dispatch import basic as numba_basic
19+
from pytensor.link.numba.dispatch import elemwise_codegen
2120
from pytensor.link.numba.dispatch.basic import (
2221
create_numba_signature,
2322
create_tuple_creator,
2423
numba_funcify,
2524
numba_njit,
2625
use_optimized_cheap_pass,
2726
)
28-
from pytensor.link.numba.dispatch.helpers import check_broadcasting, tuple_mapper
29-
from pytensor.link.numba.dispatch import elemwise_codegen
3027
from pytensor.link.utils import compile_function_src, get_name_for_object
3128
from pytensor.scalar.basic import (
3229
AND,
@@ -45,7 +42,7 @@
4542
from pytensor.scalar.basic import add as add_as
4643
from pytensor.scalar.basic import scalar_maximum
4744
from pytensor.tensor.elemwise import CAReduce, DimShuffle, Elemwise
48-
from pytensor.tensor.math import MaxAndArgmax, MulWithoutZeros
45+
from pytensor.tensor.math import MaxAndArgmax, MulWithoutZeros, Sum
4946
from pytensor.tensor.special import LogSoftmax, Softmax, SoftmaxGrad
5047
from pytensor.tensor.type import scalar
5148

@@ -376,8 +373,7 @@ def careduce_maximum(input):
376373
careduce_def_src = f"""
377374
def {careduce_fn_name}({input_name}):
378375
{careduce_assign_lines}
379-
#return np.asarray({var_name})
380-
return {var_name}
376+
return np.asarray({var_name})
381377
"""
382378

383379
careduce_fn = compile_function_src(
@@ -447,6 +443,7 @@ def axis_apply_fn(x):
447443
}
448444
}
449445

446+
450447
@numba.extending.intrinsic(jit_options=_jit_options, prefer_literal=True)
451448
def _vectorized(
452449
typingctx,
@@ -490,7 +487,6 @@ def _vectorized(
490487
inplace_pattern = inplace_pattern.literal_value
491488
inplace_pattern = pickle.loads(base64.decodebytes(inplace_pattern.encode()))
492489

493-
n_inputs = len(inputs)
494490
n_outputs = len(output_bc_patterns)
495491

496492
if not len(inputs) > 0:
@@ -531,7 +527,10 @@ def codegen(
531527

532528
[_, _, _, _, _, inputs] = args
533529
inputs = cgutils.unpack_tuple(builder, inputs)
534-
inputs = [arrayobj.make_array(ty)(ctx, builder, val) for ty, val in zip(input_types, inputs)]
530+
inputs = [
531+
arrayobj.make_array(ty)(ctx, builder, val)
532+
for ty, val in zip(input_types, inputs)
533+
]
535534
in_shapes = [cgutils.unpack_tuple(builder, obj.shape) for obj in inputs]
536535

537536
iter_shape = elemwise_codegen.compute_itershape(
@@ -586,14 +585,22 @@ def _check_input_shapes(*_):
586585
return outputs[0]._getvalue()
587586

588587
for inplace_idx in dict(inplace_pattern):
589-
ctx.nrt.incref(builder, sig.return_type.types[inplace_idx], outputs[inplace_idx]._get_value())
590-
return ctx.make_tuple(builder, sig.return_type, [out._getvalue() for out in outputs])
588+
ctx.nrt.incref(
589+
builder,
590+
sig.return_type.types[inplace_idx],
591+
outputs[inplace_idx]._get_value(),
592+
)
593+
return ctx.make_tuple(
594+
builder, sig.return_type, [out._getvalue() for out in outputs]
595+
)
591596

592597
# TODO check inplace_pattern
593-
ret_type = types.Tuple([
594-
types.Array(numba.from_dtype(np.dtype(dtype)), ndim, "C")
595-
for dtype in output_dtypes
596-
])
598+
ret_type = types.Tuple(
599+
[
600+
types.Array(numba.from_dtype(np.dtype(dtype)), ndim, "C")
601+
for dtype in output_dtypes
602+
]
603+
)
597604
if len(output_dtypes) == 1:
598605
ret_type = ret_type.types[0]
599606
sig = ret_type(*arg_types)
@@ -649,6 +656,40 @@ def elemwise_wrapper(*inputs):
649656
return elemwise_wrapper
650657

651658

659+
def _make_axis_sum(axis, inner=None):
    """Build a jitted function that sums its argument over one integer axis.

    Parameters
    ----------
    axis
        The single (non-negative) axis to reduce.
    inner
        Optional jitted function applied to the array before the reduction.
        Used to chain several single-axis reductions together.

    Returns
    -------
    A function created with ``numba_njit`` taking one array argument.
    """
    if inner is None:

        @numba_njit(fastmath=True)
        def axis_sum(array):
            return array.sum(axis)

    else:

        @numba_njit(fastmath=True)
        def axis_sum(array):
            return inner(array).sum(axis)

    return axis_sum


@numba_funcify.register(Sum)
def numba_funcify_Sum(op, node, **kwargs):
    """Return a specialized Numba implementation of ``Sum`` for ``node``.

    Three cases are distinguished:

    * every axis is reduced: a single ``array.sum()`` wrapped in a 0-d array;
    * no axis is reduced: the input is only cast to the output dtype;
    * some axes are reduced: single-axis sums are chained, because Numba's
      ``ndarray.sum`` only accepts an integer ``axis`` (not a list or tuple).

    The result is always cast to the dtype of ``node.outputs[0]``, which is
    what the rest of the graph expects; it may differ from ``op.acc_dtype``.
    """
    ndim_input = node.inputs[0].ndim

    axes = op.axis
    if axes is None:
        axes = range(ndim_input)

    # Reduce the highest axis first so that the positions of the remaining
    # axes are unaffected by the dimensions already removed.
    # NOTE(review): negative axes are wrapped defensively here; `op.axis` is
    # presumably already normalized by the Op -- confirm.
    axes_desc = sorted(
        {axis + ndim_input if axis < 0 else axis for axis in axes},
        reverse=True,
    )

    out_dtype = np.dtype(node.outputs[0].type.dtype)

    if ndim_input == len(axes_desc):

        @numba_njit(fastmath=True)
        def impl_sum(array):
            # TODO The accumulation itself should happen in op.acc_dtype...
            return np.asarray(array.sum()).astype(out_dtype)

    elif not axes_desc:

        @numba_njit(fastmath=True)
        def impl_sum(array):
            # Nothing to reduce: only the dtype conversion remains.
            return array.astype(out_dtype)

    else:
        reduce_fn = None
        for axis in axes_desc:
            reduce_fn = _make_axis_sum(axis, reduce_fn)

        @numba_njit(fastmath=True)
        def impl_sum(array):
            # TODO The accumulation itself should happen in op.acc_dtype...
            return reduce_fn(array).astype(out_dtype)

    return impl_sum
691+
692+
652693
@numba_funcify.register(CAReduce)
653694
def numba_funcify_CAReduce(op, node, **kwargs):
654695
axes = op.axis

pytensor/link/numba/dispatch/elemwise_codegen.py

Lines changed: 14 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,9 @@
1+
import numba
2+
import numpy as np
13
from llvmlite import ir
24
from numba import types
3-
from numba.np import arrayobj
45
from numba.core import cgutils
5-
import numba
6-
import numpy as np
6+
from numba.np import arrayobj
77

88

99
def compute_itershape(
@@ -35,7 +35,9 @@ def compute_itershape(
3535
return shape
3636

3737

38-
def make_outputs(ctx, builder: ir.IRBuilder, iter_shape, out_bc, dtypes, inplace, inputs, input_types):
38+
def make_outputs(
39+
ctx, builder: ir.IRBuilder, iter_shape, out_bc, dtypes, inplace, inputs, input_types
40+
):
3941
arrays = []
4042
ar_types: list[types.Array] = []
4143
one = ir.IntType(64)(1)
@@ -52,8 +54,7 @@ def make_outputs(ctx, builder: ir.IRBuilder, iter_shape, out_bc, dtypes, inplace
5254
# This is actually an internal numba function, I guess we could
5355
# call `numba.np.unsafe.ndarray` instead?
5456
shape = [
55-
length if not bc_dim else one
56-
for length, bc_dim in zip(iter_shape, bc)
57+
length if not bc_dim else one for length, bc_dim in zip(iter_shape, bc)
5758
]
5859
array = arrayobj._empty_nd_impl(ctx, builder, arrtype, shape)
5960
arrays.append(array)
@@ -84,7 +85,7 @@ def make_loop_call(
8485
safe = (False, False)
8586
n_outputs = len(outputs)
8687

87-
#context.printf(builder, "iter shape: " + ', '.join(["%i"] * len(iter_shape)) + "\n", *iter_shape)
88+
# context.printf(builder, "iter shape: " + ', '.join(["%i"] * len(iter_shape)) + "\n", *iter_shape)
8889

8990
# Lower the code of the scalar function so that we can use it in the inner loop
9091
# Caching is set to false to avoid a numba bug TODO ref?
@@ -155,12 +156,8 @@ def extract_array(aryty, obj):
155156
# Load values from input arrays
156157
input_vals = []
157158
for array_info, bc in zip(inputs, input_bc, strict=True):
158-
idxs_bc = [
159-
zero if bc else idx for idx, bc in zip(idxs, bc, strict=True)
160-
]
161-
ptr = cgutils.get_item_pointer2(
162-
context, builder, *array_info, idxs_bc, *safe
163-
)
159+
idxs_bc = [zero if bc else idx for idx, bc in zip(idxs, bc, strict=True)]
160+
ptr = cgutils.get_item_pointer2(context, builder, *array_info, idxs_bc, *safe)
164161
val = builder.load(ptr)
165162
# val.set_metadata("alias.scope", input_scope_set)
166163
# val.set_metadata("noalias", output_scope_set)
@@ -193,12 +190,9 @@ def extract_array(aryty, obj):
193190
# store.set_metadata("noalias", input_scope_set)
194191
else:
195192
idxs_bc = [
196-
zero if bc else idx
197-
for idx, bc in zip(idxs, output_bc[i], strict=True)
193+
zero if bc else idx for idx, bc in zip(idxs, output_bc[i], strict=True)
198194
]
199-
ptr = cgutils.get_item_pointer2(
200-
context, builder, *outputs[i], idxs_bc
201-
)
195+
ptr = cgutils.get_item_pointer2(context, builder, *outputs[i], idxs_bc)
202196
# store = builder.store(value, ptr)
203197
arrayobj.store_item(context, builder, output_types[i], value, ptr)
204198
# store.set_metadata("alias.scope", output_scope_set)
@@ -210,9 +204,7 @@ def extract_array(aryty, obj):
210204
if accu_depth == depth:
211205
idxs_bc = [
212206
zero if bc else idx
213-
for idx, bc in zip(
214-
idxs, output_bc[output], strict=True
215-
)
207+
for idx, bc in zip(idxs, output_bc[output], strict=True)
216208
]
217209
ptr = cgutils.get_item_pointer2(
218210
context, builder, *outputs[output], idxs_bc
@@ -221,9 +213,7 @@ def extract_array(aryty, obj):
221213
# load.set_metadata("alias.scope", output_scope_set)
222214
# load.set_metadata("noalias", input_scope_set)
223215
# store = builder.store(load, ptr)
224-
arrayobj.store_item(
225-
context, builder, output_types[output], load, ptr
226-
)
216+
arrayobj.store_item(context, builder, output_types[output], load, ptr)
227217
# store.set_metadata("alias.scope", output_scope_set)
228218
# store.set_metadata("noalias", input_scope_set)
229219
loop.__exit__(None, None, None)

0 commit comments

Comments
 (0)