9
9
import numpy as np
10
10
from numba import TypingError , types
11
11
from numba .core import cgutils
12
+ from numba .core .extending import overload
12
13
from numba .np import arrayobj
13
14
from numpy .core .numeric import normalize_axis_index , normalize_axis_tuple
14
15
@@ -174,6 +175,7 @@ def create_axis_reducer(
174
175
ndim : int ,
175
176
dtype : numba .types .Type ,
176
177
keepdims : bool = False ,
178
+ return_scalar = False ,
177
179
) -> numba .core .dispatcher .Dispatcher :
178
180
r"""Create Python function that performs a NumPy-like reduction on a given axis.
179
181
@@ -284,6 +286,8 @@ def {reduce_elemwise_fn_name}(x):
284
286
inplace_update_statement = indent (inplace_update_statement , " " * 4 * 2 )
285
287
286
288
return_expr = "res" if keepdims else "res.item()"
289
+ if not return_scalar :
290
+ return_expr = f"np.asarray({ return_expr } )"
287
291
reduce_elemwise_def_src = f"""
288
292
def { reduce_elemwise_fn_name } (x):
289
293
@@ -305,7 +309,13 @@ def {reduce_elemwise_fn_name}(x):
305
309
306
310
307
311
def create_multiaxis_reducer (
308
- scalar_op , identity , axes , ndim , dtype , input_name = "input"
312
+ scalar_op ,
313
+ identity ,
314
+ axes ,
315
+ ndim ,
316
+ dtype ,
317
+ input_name = "input" ,
318
+ return_scalar = False ,
309
319
):
310
320
r"""Construct a function that reduces multiple axes.
311
321
@@ -336,6 +346,8 @@ def careduce_maximum(input):
336
346
The number of dimensions of the result.
337
347
dtype:
338
348
The data type of the result.
349
+ return_scalar:
350
+ If True, return a scalar, otherwise an array.
339
351
340
352
Returns
341
353
=======
@@ -370,10 +382,17 @@ def careduce_maximum(input):
370
382
)
371
383
372
384
careduce_assign_lines = indent ("\n " .join (careduce_lines_src ), " " * 4 )
385
+ if not return_scalar :
386
+ pre_result = "np.asarray"
387
+ post_result = ""
388
+ else :
389
+ pre_result = "np.asarray"
390
+ post_result = ".item()"
391
+
373
392
careduce_def_src = f"""
374
393
def { careduce_fn_name } ({ input_name } ):
375
394
{ careduce_assign_lines }
376
- return np.asarray ({ var_name } )
395
+ return { pre_result } ({ var_name } ){ post_result }
377
396
"""
378
397
379
398
careduce_fn = compile_function_src (
@@ -383,7 +402,7 @@ def {careduce_fn_name}({input_name}):
383
402
return careduce_fn
384
403
385
404
386
- def jit_compile_reducer (node , fn , ** kwds ):
405
+ def jit_compile_reducer (node , fn , * , reduce_to_scalar = False , * *kwds ):
387
406
"""Compile Python source for reduction loops using additional optimizations.
388
407
389
408
Parameters
@@ -400,7 +419,7 @@ def jit_compile_reducer(node, fn, **kwds):
400
419
A :func:`numba.njit`-compiled function.
401
420
402
421
"""
403
- signature = create_numba_signature (node , reduce_to_scalar = True )
422
+ signature = create_numba_signature (node , reduce_to_scalar = reduce_to_scalar )
404
423
405
424
# Eagerly compile the function using increased optimizations. This should
406
425
# help improve nested loop reductions.
@@ -618,23 +637,58 @@ def numba_funcify_Elemwise(op, node, **kwargs):
618
637
inplace_pattern = tuple (op .inplace_pattern .items ())
619
638
620
639
# numba doesn't support nested literals right now...
621
- input_bc_patterns = base64 .encodebytes (pickle .dumps (input_bc_patterns )).decode ()
622
- output_bc_patterns = base64 .encodebytes (pickle .dumps (output_bc_patterns )).decode ()
623
- output_dtypes = base64 .encodebytes (pickle .dumps (output_dtypes )).decode ()
624
- inplace_pattern = base64 .encodebytes (pickle .dumps (inplace_pattern )).decode ()
640
+ input_bc_patterns_enc = base64 .encodebytes (pickle .dumps (input_bc_patterns )).decode ()
641
+ output_bc_patterns_enc = base64 .encodebytes (
642
+ pickle .dumps (output_bc_patterns )
643
+ ).decode ()
644
+ output_dtypes_enc = base64 .encodebytes (pickle .dumps (output_dtypes )).decode ()
645
+ inplace_pattern_enc = base64 .encodebytes (pickle .dumps (inplace_pattern )).decode ()
625
646
626
- @numba_njit
627
647
def elemwise_wrapper (* inputs ):
628
648
return _vectorized (
629
649
scalar_op_fn ,
630
- input_bc_patterns ,
631
- output_bc_patterns ,
632
- output_dtypes ,
633
- inplace_pattern ,
650
+ input_bc_patterns_enc ,
651
+ output_bc_patterns_enc ,
652
+ output_dtypes_enc ,
653
+ inplace_pattern_enc ,
634
654
inputs ,
635
655
)
636
656
637
- return elemwise_wrapper
657
+ # Pure python implementation, that will be used in tests
658
+ def elemwise (* inputs ):
659
+ inputs = [np .asarray (input ) for input in inputs ]
660
+ inputs_bc = np .broadcast_arrays (* inputs )
661
+ shape = inputs [0 ].shape
662
+ for input , bc in zip (inputs , input_bc_patterns ):
663
+ for length , allow_bc , iter_length in zip (input .shape , bc , shape ):
664
+ if length == 1 and shape and iter_length != 1 and not allow_bc :
665
+ raise ValueError ("Broadcast not allowed." )
666
+
667
+ outputs = []
668
+ for dtype in output_dtypes :
669
+ outputs .append (np .empty (shape , dtype = dtype ))
670
+
671
+ for idx in np .ndindex (shape ):
672
+ vals = [input [idx ] for input in inputs_bc ]
673
+ outs = scalar_op_fn (* vals )
674
+ if not isinstance (outs , tuple ):
675
+ outs = (outs ,)
676
+ for out , out_val in zip (outputs , outs ):
677
+ out [idx ] = out_val
678
+
679
+ outputs_summed = []
680
+ for output , bc in zip (outputs , output_bc_patterns ):
681
+ axes = tuple (np .nonzero (bc )[0 ])
682
+ outputs_summed .append (output .sum (axes , keepdims = True ))
683
+ if len (outputs_summed ) != 1 :
684
+ return tuple (outputs_summed )
685
+ return outputs_summed [0 ]
686
+
687
+ @overload (elemwise )
688
+ def ov_elemwise (* inputs ):
689
+ return elemwise_wrapper
690
+
691
+ return elemwise
638
692
639
693
640
694
@numba_funcify .register (Sum )
@@ -643,7 +697,7 @@ def numba_funcify_Sum(op, node, **kwargs):
643
697
if axes is None :
644
698
axes = list (range (node .inputs [0 ].ndim ))
645
699
646
- axes = list (axes )
700
+ axes = tuple (axes )
647
701
648
702
ndim_input = node .inputs [0 ].ndim
649
703
@@ -658,15 +712,16 @@ def numba_funcify_Sum(op, node, **kwargs):
658
712
659
713
@numba_njit (fastmath = True )
660
714
def impl_sum (array ):
661
- # TODO The accumulation itself should happen in acc_dtype...
662
- return np .asarray (array .sum ()).astype (np_acc_dtype )
715
+ return np .asarray (array .sum (), dtype = np_acc_dtype )
663
716
664
- else :
717
+ elif len ( axes ) == 0 :
665
718
666
719
@numba_njit (fastmath = True )
667
720
def impl_sum (array ):
668
- # TODO The accumulation itself should happen in acc_dtype...
669
- return array .sum (axes ).astype (np_acc_dtype )
721
+ return array
722
+
723
+ else :
724
+ impl_sum = numba_funcify_CAReduce (op , node , ** kwargs )
670
725
671
726
return impl_sum
672
727
@@ -705,7 +760,7 @@ def numba_funcify_CAReduce(op, node, **kwargs):
705
760
input_name = input_name ,
706
761
)
707
762
708
- careduce_fn = jit_compile_reducer (node , careduce_py_fn )
763
+ careduce_fn = jit_compile_reducer (node , careduce_py_fn , reduce_to_scalar = False )
709
764
return careduce_fn
710
765
711
766
@@ -888,7 +943,12 @@ def numba_funcify_LogSoftmax(op, node, **kwargs):
888
943
if axis is not None :
889
944
axis = normalize_axis_index (axis , x_at .ndim )
890
945
reduce_max_py = create_axis_reducer (
891
- scalar_maximum , - np .inf , axis , x_at .ndim , x_dtype , keepdims = True
946
+ scalar_maximum ,
947
+ - np .inf ,
948
+ axis ,
949
+ x_at .ndim ,
950
+ x_dtype ,
951
+ keepdims = True ,
892
952
)
893
953
reduce_sum_py = create_axis_reducer (
894
954
add_as , 0.0 , axis , x_at .ndim , x_dtype , keepdims = True
@@ -935,10 +995,17 @@ def maxandargmax(x):
935
995
keep_axes = tuple (i for i in range (x_ndim ) if i not in axes )
936
996
937
997
reduce_max_py_fn = create_multiaxis_reducer (
938
- scalar_maximum , - np .inf , axes , x_ndim , x_dtype
998
+ scalar_maximum ,
999
+ - np .inf ,
1000
+ axes ,
1001
+ x_ndim ,
1002
+ x_dtype ,
1003
+ return_scalar = False ,
939
1004
)
940
1005
reduce_max = jit_compile_reducer (
941
- Apply (node .op , node .inputs , [node .outputs [0 ].clone ()]), reduce_max_py_fn
1006
+ Apply (node .op , node .inputs , [node .outputs [0 ].clone ()]),
1007
+ reduce_max_py_fn ,
1008
+ reduce_to_scalar = False ,
942
1009
)
943
1010
944
1011
reduced_x_ndim = x_ndim - len (axes ) + 1
0 commit comments