@@ -539,44 +539,54 @@ def numba_funcify_DimShuffle(op, **kwargs):
539
539
540
540
ndim_new_shape = len (shuffle ) + len (augment )
541
541
542
+ no_transpose = all (i == j for i , j in enumerate (transposition ))
543
+ if no_transpose :
544
+
545
+ @numba_basic .numba_njit
546
+ def transpose (x ):
547
+ return x
548
+
549
+ else :
550
+
551
+ @numba_basic .numba_njit
552
+ def transpose (x ):
553
+ return np .transpose (x , transposition )
554
+
555
+ shape_template = (1 ,) * ndim_new_shape
556
+
557
+ # When `len(shuffle) == 0`, the `shuffle_shape[j]` expression above is
558
+ # is typed as `getitem(Tuple(), int)`, which has no implementation
559
+ # (since getting an item from an empty sequence doesn't make sense).
560
+ # To avoid this compile-time error, we omit the expression altogether.
542
561
if len (shuffle ) > 0 :
543
562
544
563
@numba_basic .numba_njit
545
- def populate_new_shape (i , j , new_shape , shuffle_shape ):
546
- if i in augment :
547
- new_shape = numba_basic .tuple_setitem (new_shape , i , 1 )
548
- return j , new_shape
549
- else :
550
- new_shape = numba_basic .tuple_setitem (new_shape , i , shuffle_shape [j ])
551
- return j + 1 , new_shape
564
+ def find_shape (array_shape ):
565
+ shape = shape_template
566
+ j = 0
567
+ for i in range (ndim_new_shape ):
568
+ if i not in augment :
569
+ length = array_shape [j ]
570
+ shape = numba_basic .tuple_setitem (shape , i , length )
571
+ j = j + 1
572
+ return shape
552
573
553
574
else :
554
- # When `len(shuffle) == 0`, the `shuffle_shape[j]` expression above is
555
- # is typed as `getitem(Tuple(), int)`, which has no implementation
556
- # (since getting an item from an empty sequence doesn't make sense).
557
- # To avoid this compile-time error, we omit the expression altogether.
558
- @numba_basic .numba_njit (inline = "always" )
559
- def populate_new_shape (i , j , new_shape , shuffle_shape ):
560
- return j , numba_basic .tuple_setitem (new_shape , i , 1 )
575
+
576
+ @numba_basic .numba_njit
577
+ def find_shape (array_shape ):
578
+ return shape_template
561
579
562
580
if ndim_new_shape > 0 :
563
- create_zeros_tuple = numba_basic .create_tuple_creator (
564
- lambda _ : 0 , ndim_new_shape
565
- )
566
581
567
582
@numba_basic .numba_njit
568
583
def dimshuffle_inner (x , shuffle ):
569
- res = np .transpose (x , transposition )
570
- shuffle_shape = res .shape [: len (shuffle )]
571
-
572
- new_shape = create_zeros_tuple ()
573
-
574
- j = 0
575
- for i in range (len (new_shape )):
576
- j , new_shape = populate_new_shape (i , j , new_shape , shuffle_shape )
584
+ x = transpose (x )
585
+ shuffle_shape = x .shape [: len (shuffle )]
586
+ new_shape = find_shape (shuffle_shape )
577
587
578
588
# FIXME: Numba's `array.reshape` only accepts C arrays.
579
- res_reshape = np .reshape (np .ascontiguousarray (res ), new_shape )
589
+ res_reshape = np .reshape (np .ascontiguousarray (x ), new_shape )
580
590
581
591
if not inplace :
582
592
return res_reshape .copy ()
0 commit comments