| 1 | +import base64
| 2 | +import pickle
1 | 3 | from functools import singledispatch
2 | 4 | from numbers import Number
3 | | -import pickle
4 | 5 | from textwrap import indent
5 | | -from typing import Any, Callable, Literal, Optional, Union
6 | | -import base64
| 6 | +from typing import Any, Callable, Optional, Union
7 | 7 |
8 | 8 | import numba
9 | 9 | import numpy as np
10 | | -from llvmlite import ir
11 | | -from numba import TypingError, literal_unroll, types, literally
| 10 | +from numba import TypingError, types
12 | 11 | from numba.core import cgutils
13 | | -from numba.cpython.unsafe.tuple import tuple_setitem
14 | 12 | from numba.np import arrayobj
15 | 13 | from numpy.core.numeric import normalize_axis_index, normalize_axis_tuple
16 | 14 |
17 | 15 | from pytensor import config
18 | 16 | from pytensor.graph.basic import Apply
19 | 17 | from pytensor.graph.op import Op
20 | 18 | from pytensor.link.numba.dispatch import basic as numba_basic
| 19 | +from pytensor.link.numba.dispatch import elemwise_codegen
21 | 20 | from pytensor.link.numba.dispatch.basic import (
22 | 21 |     create_numba_signature,
23 | 22 |     create_tuple_creator,
24 | 23 |     numba_funcify,
25 | 24 |     numba_njit,
26 | 25 |     use_optimized_cheap_pass,
27 | 26 | )
28 | | -from pytensor.link.numba.dispatch.helpers import check_broadcasting, tuple_mapper
29 | | -from pytensor.link.numba.dispatch import elemwise_codegen
30 | 27 | from pytensor.link.utils import compile_function_src, get_name_for_object
31 | 28 | from pytensor.scalar.basic import (
32 | 29 |     AND,

45 | 42 | from pytensor.scalar.basic import add as add_as
46 | 43 | from pytensor.scalar.basic import scalar_maximum
47 | 44 | from pytensor.tensor.elemwise import CAReduce, DimShuffle, Elemwise
48 | | -from pytensor.tensor.math import MaxAndArgmax, MulWithoutZeros
| 45 | +from pytensor.tensor.math import MaxAndArgmax, MulWithoutZeros, Sum
49 | 46 | from pytensor.tensor.special import LogSoftmax, Softmax, SoftmaxGrad
50 | 47 | from pytensor.tensor.type import scalar
51 | 48 |
@@ -376,8 +373,7 @@ def careduce_maximum(input):
376 | 373 |     careduce_def_src = f"""
377 | 374 | def {careduce_fn_name}({input_name}):
378 | 375 | {careduce_assign_lines}
379 | | -    #return np.asarray({var_name})
380 | | -    return {var_name}
| 376 | +    return np.asarray({var_name})
381 | 377 | """
382 | 378 |
383 | 379 |     careduce_fn = compile_function_src(
@@ -447,6 +443,7 @@ def axis_apply_fn(x):
447 | 443 |     }
448 | 444 | }
449 | 445 |
| 446 | +
450 | 447 | @numba.extending.intrinsic(jit_options=_jit_options, prefer_literal=True)
451 | 448 | def _vectorized(
452 | 449 |     typingctx,
@@ -490,7 +487,6 @@ def _vectorized(
490 | 487 |     inplace_pattern = inplace_pattern.literal_value
491 | 488 |     inplace_pattern = pickle.loads(base64.decodebytes(inplace_pattern.encode()))
492 | 489 |
493 | | -    n_inputs = len(inputs)
494 | 490 |     n_outputs = len(output_bc_patterns)
495 | 491 |
496 | 492 |     if not len(inputs) > 0:
@@ -531,7 +527,10 @@ def codegen(
531 | 527 |
532 | 528 |         [_, _, _, _, _, inputs] = args
533 | 529 |         inputs = cgutils.unpack_tuple(builder, inputs)
534 | | -        inputs = [arrayobj.make_array(ty)(ctx, builder, val) for ty, val in zip(input_types, inputs)]
| 530 | +        inputs = [
| 531 | +            arrayobj.make_array(ty)(ctx, builder, val)
| 532 | +            for ty, val in zip(input_types, inputs)
| 533 | +        ]
535 | 534 |         in_shapes = [cgutils.unpack_tuple(builder, obj.shape) for obj in inputs]
536 | 535 |
537 | 536 |         iter_shape = elemwise_codegen.compute_itershape(
@@ -586,14 +585,22 @@ def _check_input_shapes(*_):
586 | 585 |             return outputs[0]._getvalue()
587 | 586 |
588 | 587 |         for inplace_idx in dict(inplace_pattern):
589 | | -            ctx.nrt.incref(builder, sig.return_type.types[inplace_idx], outputs[inplace_idx]._get_value())
590 | | -        return ctx.make_tuple(builder, sig.return_type, [out._getvalue() for out in outputs])
| 588 | +            ctx.nrt.incref(
| 589 | +                builder,
| 590 | +                sig.return_type.types[inplace_idx],
| 591 | +                outputs[inplace_idx]._get_value(),
| 592 | +            )
| 593 | +        return ctx.make_tuple(
| 594 | +            builder, sig.return_type, [out._getvalue() for out in outputs]
| 595 | +        )
591 | 596 |
592 | 597 |     # TODO check inplace_pattern
593 | | -    ret_type = types.Tuple([
594 | | -        types.Array(numba.from_dtype(np.dtype(dtype)), ndim, "C")
595 | | -        for dtype in output_dtypes
596 | | -    ])
| 598 | +    ret_type = types.Tuple(
| 599 | +        [
| 600 | +            types.Array(numba.from_dtype(np.dtype(dtype)), ndim, "C")
| 601 | +            for dtype in output_dtypes
| 602 | +        ]
| 603 | +    )
597 | 604 |     if len(output_dtypes) == 1:
598 | 605 |         ret_type = ret_type.types[0]
599 | 606 |     sig = ret_type(*arg_types)
@@ -649,6 +656,40 @@ def elemwise_wrapper(*inputs):
649 | 656 |     return elemwise_wrapper
650 | 657 |
651 | 658 |
| 659 | +@numba_funcify.register(Sum)
| 660 | +def numba_funcify_Sum(op, node, **kwargs):
| 661 | +    axes = op.axis
| 662 | +    if axes is None:
| 663 | +        axes = list(range(node.inputs[0].ndim))
| 664 | +
| 665 | +    axes = list(axes)
| 666 | +
| 667 | +    ndim_input = node.inputs[0].ndim
| 668 | +
| 669 | +    if hasattr(op, "acc_dtype") and op.acc_dtype is not None:
| 670 | +        acc_dtype = op.acc_dtype
| 671 | +    else:
| 672 | +        acc_dtype = node.outputs[0].type.dtype
| 673 | +
| 674 | +    np_acc_dtype = np.dtype(acc_dtype)
| 675 | +
| 676 | +    if ndim_input == len(axes):
| 677 | +
| 678 | +        @numba_njit(fastmath=True)
| 679 | +        def impl_sum(array):
| 680 | +            # TODO The accumulation itself should happen in acc_dtype...
| 681 | +            return np.asarray(array.sum()).astype(np_acc_dtype)
| 682 | +
| 683 | +    else:
| 684 | +
| 685 | +        @numba_njit(fastmath=True)
| 686 | +        def impl_sum(array):
| 687 | +            # TODO The accumulation itself should happen in acc_dtype...
| 688 | +            return array.sum(axes).astype(np_acc_dtype)
| 689 | +
| 690 | +    return impl_sum
| 691 | +
| 692 | +
652 | 693 | @numba_funcify.register(CAReduce)
653 | 694 | def numba_funcify_CAReduce(op, node, **kwargs):
654 | 695 |     axes = op.axis