@@ -479,6 +479,9 @@ def numba_funcify_FunctionGraph(
479
479
)
480
480
481
481
482
+ SET_OR_INC_OPS = IncSubtensor | AdvancedIncSubtensor1 | AdvancedIncSubtensor
483
+
484
+
482
485
def create_index_func (node , objmode = False ):
483
486
"""Create a Python function that assembles and uses an index on an array."""
484
487
@@ -501,9 +504,7 @@ def convert_indices(indices, entry):
501
504
else :
502
505
raise ValueError ()
503
506
504
- set_or_inc = isinstance (
505
- node .op , IncSubtensor | AdvancedIncSubtensor1 | AdvancedIncSubtensor
506
- )
507
+ set_or_inc = isinstance (node .op , SET_OR_INC_OPS )
507
508
index_start_idx = 1 + int (set_or_inc )
508
509
509
510
input_names = [unique_names (v , force_unique = True ) for v in node .inputs ]
@@ -545,21 +546,6 @@ def convert_indices(indices, entry):
545
546
index_prologue = ""
546
547
index_body = f"z = { input_names [0 ]} [indices]"
547
548
548
- if objmode :
549
- output_var = node .outputs [0 ]
550
-
551
- if not set_or_inc :
552
- # Since `z` is being "created" while in object mode, it's
553
- # considered an "outgoing" variable and needs to be manually typed
554
- output_sig = f"z='{ output_var .dtype } [{ ', ' .join ([':' ] * output_var .ndim )} ]'"
555
- else :
556
- output_sig = ""
557
-
558
- index_body = f"""
559
- with objmode({ output_sig } ):
560
- { index_body }
561
- """
562
-
563
549
subtensor_def_src = f"""
564
550
def { fn_name } ({ ", " .join (input_names )} ):
565
551
{ index_prologue }
@@ -572,48 +558,34 @@ def {fn_name}({", ".join(input_names)}):
572
558
573
559
574
560
@numba_funcify .register (Subtensor )
561
+ @numba_funcify .register (IncSubtensor )
575
562
@numba_funcify .register (AdvancedSubtensor1 )
576
- def numba_funcify_Subtensor (op , node , ** kwargs ):
577
- objmode = isinstance (op , AdvancedSubtensor )
578
- if objmode :
579
- warnings .warn (
580
- ("Numba will use object mode to allow run " "AdvancedSubtensor." ),
581
- UserWarning ,
582
- )
583
-
584
- subtensor_def_src = create_index_func (node , objmode = objmode )
585
-
586
- global_env = {"np" : np }
587
- if objmode :
588
- global_env ["objmode" ] = numba .objmode
589
-
563
+ def numba_funcify_default_subtensor (op , node , ** kwargs ):
564
+ function_name = "subtensor"
565
+ if isinstance (op , SET_OR_INC_OPS ):
566
+ function_name = "setsubtensor" if op .set_instead_of_inc else "incsubtensor"
567
+ subtensor_def_src = create_index_func (node )
590
568
subtensor_fn = compile_function_src (
591
- subtensor_def_src , "subtensor" , {** globals (), ** global_env }
569
+ subtensor_def_src , function_name , {** globals (), "np" : np }
592
570
)
593
-
594
571
return numba_njit (subtensor_fn , boundscheck = True )
595
572
596
573
597
- @numba_funcify .register (IncSubtensor )
598
- def numba_funcify_IncSubtensor (op , node , ** kwargs ):
599
- objmode = isinstance (op , AdvancedIncSubtensor )
600
- if objmode :
601
- warnings .warn (
602
- ("Numba will use object mode to allow run " "AdvancedIncSubtensor." ),
603
- UserWarning ,
604
- )
605
-
606
- incsubtensor_def_src = create_index_func (node , objmode = objmode )
607
-
608
- global_env = {"np" : np }
609
- if objmode :
610
- global_env ["objmode" ] = numba .objmode
574
+ @numba_funcify .register (AdvancedSubtensor )
575
+ def numba_funcify_AdvancedSubtensor (op , node , ** kwargs ):
576
+ idxs = node .inputs [1 :]
577
+ adv_idxs_dims = [
578
+ idx .type .ndim
579
+ for idx in idxs
580
+ if (isinstance (idx .type , TensorType ) and idx .type .ndim > 0 )
581
+ ]
611
582
612
- incsubtensor_fn = compile_function_src (
613
- incsubtensor_def_src , "incsubtensor" , {** globals (), ** global_env }
614
- )
583
+ # Numba does not support indexes with more than one dimension
584
+ # Nor multiple vector indexes
585
+ if len (adv_idxs_dims ) > 1 or adv_idxs_dims [0 ] > 1 :
586
+ return generate_fallback_impl (op , node , ** kwargs )
615
587
616
- return numba_njit ( incsubtensor_fn , boundscheck = True )
588
+ return numba_funcify_default_subtensor ( op , node , ** kwargs )
617
589
618
590
619
591
@numba_funcify .register (AdvancedIncSubtensor1 )
0 commit comments