@@ -479,6 +479,9 @@ def numba_funcify_FunctionGraph(
479
479
)
480
480
481
481
482
+ SET_OR_INC_OPS = IncSubtensor | AdvancedIncSubtensor1 | AdvancedIncSubtensor
483
+
484
+
482
485
def create_index_func (node , objmode = False ):
483
486
"""Create a Python function that assembles and uses an index on an array."""
484
487
@@ -501,14 +504,13 @@ def convert_indices(indices, entry):
501
504
else :
502
505
raise ValueError ()
503
506
504
- set_or_inc = isinstance (
505
- node .op , IncSubtensor | AdvancedIncSubtensor1 | AdvancedIncSubtensor
506
- )
507
+ op = node .op
508
+ set_or_inc = isinstance (op , SET_OR_INC_OPS ) # type: ignore
507
509
index_start_idx = 1 + int (set_or_inc )
508
510
509
511
input_names = [unique_names (v , force_unique = True ) for v in node .inputs ]
510
512
op_indices = list (node .inputs [index_start_idx :])
511
- idx_list = getattr (node . op , "idx_list" , None )
513
+ idx_list = getattr (op , "idx_list" , None )
512
514
513
515
indices_creation_src = (
514
516
tuple (convert_indices (op_indices , idx ) for idx in idx_list )
@@ -523,8 +525,7 @@ def convert_indices(indices, entry):
523
525
indices_creation_src = f"indices = ({ indices_creation_src } )"
524
526
525
527
if set_or_inc :
526
- fn_name = "incsubtensor"
527
- if node .op .inplace :
528
+ if op .inplace :
528
529
index_prologue = f"z = { input_names [0 ]} "
529
530
else :
530
531
index_prologue = f"z = np.copy({ input_names [0 ]} )"
@@ -536,30 +537,17 @@ def convert_indices(indices, entry):
536
537
else :
537
538
y_name = input_names [1 ]
538
539
539
- if node .op .set_instead_of_inc :
540
+ if op .set_instead_of_inc :
541
+ fn_name = "setsubtensor"
540
542
index_body = f"z[indices] = { y_name } "
541
543
else :
544
+ fn_name = "incsubtensor"
542
545
index_body = f"z[indices] += { y_name } "
543
546
else :
544
547
fn_name = "subtensor"
545
548
index_prologue = ""
546
549
index_body = f"z = { input_names [0 ]} [indices]"
547
550
548
- if objmode :
549
- output_var = node .outputs [0 ]
550
-
551
- if not set_or_inc :
552
- # Since `z` is being "created" while in object mode, it's
553
- # considered an "outgoing" variable and needs to be manually typed
554
- output_sig = f"z='{ output_var .dtype } [{ ', ' .join ([':' ] * output_var .ndim )} ]'"
555
- else :
556
- output_sig = ""
557
-
558
- index_body = f"""
559
- with objmode({ output_sig } ):
560
- { index_body }
561
- """
562
-
563
551
subtensor_def_src = f"""
564
552
def { fn_name } ({ ", " .join (input_names )} ):
565
553
{ index_prologue }
@@ -572,48 +560,34 @@ def {fn_name}({", ".join(input_names)}):
572
560
573
561
574
562
@numba_funcify .register (Subtensor )
563
+ @numba_funcify .register (IncSubtensor )
575
564
@numba_funcify .register (AdvancedSubtensor1 )
576
- def numba_funcify_Subtensor (op , node , ** kwargs ):
577
- objmode = isinstance (op , AdvancedSubtensor )
578
- if objmode :
579
- warnings .warn (
580
- ("Numba will use object mode to allow run " "AdvancedSubtensor." ),
581
- UserWarning ,
582
- )
583
-
584
- subtensor_def_src = create_index_func (node , objmode = objmode )
585
-
586
- global_env = {"np" : np }
587
- if objmode :
588
- global_env ["objmode" ] = numba .objmode
589
-
565
+ def numba_funcify_default_subtensor (op , node , ** kwargs ):
566
+ function_name = "subtensor"
567
+ if isinstance (op , SET_OR_INC_OPS ): # type: ignore
568
+ function_name = "setsubtensor" if op .set_instead_of_inc else "incsubtensor"
569
+ subtensor_def_src = create_index_func (node )
590
570
subtensor_fn = compile_function_src (
591
- subtensor_def_src , "subtensor" , {** globals (), ** global_env }
571
+ subtensor_def_src , function_name , {** globals (), "np" : np }
592
572
)
593
-
594
573
return numba_njit (subtensor_fn , boundscheck = True )
595
574
596
575
597
- @numba_funcify .register (IncSubtensor )
598
- def numba_funcify_IncSubtensor (op , node , ** kwargs ):
599
- objmode = isinstance (op , AdvancedIncSubtensor )
600
- if objmode :
601
- warnings .warn (
602
- ("Numba will use object mode to allow run " "AdvancedIncSubtensor." ),
603
- UserWarning ,
604
- )
605
-
606
- incsubtensor_def_src = create_index_func (node , objmode = objmode )
607
-
608
- global_env = {"np" : np }
609
- if objmode :
610
- global_env ["objmode" ] = numba .objmode
576
+ @numba_funcify .register (AdvancedSubtensor )
577
+ def numba_funcify_AdvancedSubtensor (op , node , ** kwargs ):
578
+ idxs = node .inputs [1 :]
579
+ adv_idxs_dims = [
580
+ idx .type .ndim
581
+ for idx in idxs
582
+ if (isinstance (idx .type , TensorType ) and idx .type .ndim > 0 )
583
+ ]
611
584
612
- incsubtensor_fn = compile_function_src (
613
- incsubtensor_def_src , "incsubtensor" , {** globals (), ** global_env }
614
- )
585
+ # Numba does not support indexes with more than one dimension
586
+ # Nor multiple vector indexes
587
+ if len (adv_idxs_dims ) > 1 or adv_idxs_dims [0 ] > 1 :
588
+ return generate_fallback_impl (op , node , ** kwargs )
615
589
616
- return numba_njit ( incsubtensor_fn , boundscheck = True )
590
+ return numba_funcify_default_subtensor ( op , node , ** kwargs )
617
591
618
592
619
593
@numba_funcify .register (AdvancedIncSubtensor1 )
0 commit comments