@@ -1805,7 +1805,8 @@ def local_join_subtensors(fgraph, node):
1805
1805
def local_uint_constant_indices (fgraph , node ):
1806
1806
"""Convert constant indices to unsigned dtypes."""
1807
1807
1808
- if isinstance (node .op , IncSubtensor | AdvancedIncSubtensor | AdvancedIncSubtensor1 ):
1808
+ op = node .op
1809
+ if isinstance (op , IncSubtensor | AdvancedIncSubtensor | AdvancedIncSubtensor1 ):
1809
1810
x , y , * indices = node .inputs
1810
1811
else :
1811
1812
x , * indices = node .inputs
@@ -1864,21 +1865,18 @@ def local_uint_constant_indices(fgraph, node):
1864
1865
if not has_new_index :
1865
1866
return False
1866
1867
1867
- new_out = x [tuple (new_indices )]
1868
-
1869
- if y is not None :
1870
- new_out = inc_subtensor (
1871
- new_out ,
1872
- y ,
1873
- inplace = node .op .inplace ,
1874
- set_instead_of_inc = node .op .set_instead_of_inc ,
1875
- ignore_duplicates = getattr (node .op , "ignore_duplicates" , False ),
1876
- )
1877
-
1878
- new_outs = new_out .owner .outputs
1879
- copy_stack_trace (node .outputs , new_outs )
1880
-
1881
- return new_outs
1868
+ if isinstance (op , Subtensor | IncSubtensor ):
1869
+ # Basic index Ops contain information about the dtype of the indices, so wee have to recreate them
1870
+ props = op ._props_dict ()
1871
+ props ["idx_list" ] = new_indices
1872
+ op = type (op )(** props )
1873
+ # Basic index Ops don't expect slices, but the respective start/step/stop
1874
+ new_indices = get_slice_elements (new_indices )
1875
+
1876
+ new_args = (x , * new_indices ) if y is None else (x , y , * new_indices )
1877
+ new_out = op (* new_args )
1878
+ copy_stack_trace (node .outputs [0 ], new_out )
1879
+ return [new_out ]
1882
1880
1883
1881
1884
1882
@register_canonicalize ("shape_unsafe" )
0 commit comments