Skip to content

Commit 67a5eba

Browse files
committed
Improve numba DimShuffle compile time
1 parent 1827703 commit 67a5eba

File tree

1 file changed

+36
-26
lines changed

1 file changed

+36
-26
lines changed

pytensor/link/numba/dispatch/elemwise.py

Lines changed: 36 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -539,44 +539,54 @@ def numba_funcify_DimShuffle(op, **kwargs):
539539

540540
ndim_new_shape = len(shuffle) + len(augment)
541541

542+
no_transpose = all(i == j for i, j in enumerate(transposition))
543+
if no_transpose:
544+
545+
@numba_basic.numba_njit
546+
def transpose(x):
547+
return x
548+
549+
else:
550+
551+
@numba_basic.numba_njit
552+
def transpose(x):
553+
return np.transpose(x, transposition)
554+
555+
shape_template = (1,) * ndim_new_shape
556+
557+
# When `len(shuffle) == 0`, the `array_shape[j]` expression below is
558+
# typed as `getitem(Tuple(), int)`, which has no implementation
559+
# (since getting an item from an empty sequence doesn't make sense).
560+
# To avoid this compile-time error, we omit the expression altogether.
542561
if len(shuffle) > 0:
543562

544563
@numba_basic.numba_njit
545-
def populate_new_shape(i, j, new_shape, shuffle_shape):
546-
if i in augment:
547-
new_shape = numba_basic.tuple_setitem(new_shape, i, 1)
548-
return j, new_shape
549-
else:
550-
new_shape = numba_basic.tuple_setitem(new_shape, i, shuffle_shape[j])
551-
return j + 1, new_shape
564+
def find_shape(array_shape):
565+
shape = shape_template
566+
j = 0
567+
for i in range(ndim_new_shape):
568+
if i not in augment:
569+
length = array_shape[j]
570+
shape = numba_basic.tuple_setitem(shape, i, length)
571+
j = j + 1
572+
return shape
552573

553574
else:
554-
# When `len(shuffle) == 0`, the `shuffle_shape[j]` expression above is
555-
# typed as `getitem(Tuple(), int)`, which has no implementation
556-
# (since getting an item from an empty sequence doesn't make sense).
557-
# To avoid this compile-time error, we omit the expression altogether.
558-
@numba_basic.numba_njit(inline="always")
559-
def populate_new_shape(i, j, new_shape, shuffle_shape):
560-
return j, numba_basic.tuple_setitem(new_shape, i, 1)
575+
576+
@numba_basic.numba_njit
577+
def find_shape(array_shape):
578+
return shape_template
561579

562580
if ndim_new_shape > 0:
563-
create_zeros_tuple = numba_basic.create_tuple_creator(
564-
lambda _: 0, ndim_new_shape
565-
)
566581

567582
@numba_basic.numba_njit
568583
def dimshuffle_inner(x, shuffle):
569-
res = np.transpose(x, transposition)
570-
shuffle_shape = res.shape[: len(shuffle)]
571-
572-
new_shape = create_zeros_tuple()
573-
574-
j = 0
575-
for i in range(len(new_shape)):
576-
j, new_shape = populate_new_shape(i, j, new_shape, shuffle_shape)
584+
x = transpose(x)
585+
shuffle_shape = x.shape[: len(shuffle)]
586+
new_shape = find_shape(shuffle_shape)
577587

578588
# FIXME: Numba's `array.reshape` only accepts C arrays.
579-
res_reshape = np.reshape(np.ascontiguousarray(res), new_shape)
589+
res_reshape = np.reshape(np.ascontiguousarray(x), new_shape)
580590

581591
if not inplace:
582592
return res_reshape.copy()

0 commit comments

Comments
 (0)