Skip to content

Commit 3682b1c

Browse files
committed
Fix None in slice for numba boxing
1 parent 004281a commit 3682b1c

File tree

1 file changed

+29
-5
lines changed

1 file changed

+29
-5
lines changed

pytensor/link/numba/dispatch/basic.py

Lines changed: 29 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
import sys
2+
13
import operator
24
import warnings
35
from contextlib import contextmanager
@@ -10,7 +12,7 @@
1012
import numpy as np
1113
import scipy
1214
import scipy.special
13-
from llvmlite.ir import Type as llvm_Type
15+
from llvmlite import ir
1416
from numba import types
1517
from numba.core.errors import TypingError
1618
from numba.cpython.unsafe.tuple import tuple_setitem # noqa: F401
@@ -131,7 +133,7 @@ def create_numba_signature(
131133

132134

133135
def slice_new(self, start, stop, step):
134-
fnty = llvm_Type.function(self.pyobj, [self.pyobj, self.pyobj, self.pyobj])
136+
fnty = ir.FunctionType(self.pyobj, [self.pyobj, self.pyobj, self.pyobj])
135137
fn = self._get_function(fnty, name="PySlice_New")
136138
return self.builder.call(fn, [start, stop, step])
137139

@@ -150,11 +152,33 @@ def box_slice(typ, val, c):
150152
This makes it possible to return an Numba's internal representation of a
151153
``slice`` object as a proper ``slice`` to Python.
152154
"""
155+
start = c.builder.extract_value(val, 0)
156+
stop = c.builder.extract_value(val, 1)
157+
158+
none_val = ir.Constant(ir.IntType(64), sys.maxsize)
159+
160+
start_is_none = c.builder.icmp_signed("==", start, none_val)
161+
start = c.builder.select(
162+
start_is_none,
163+
c.pyapi.get_null_object(),
164+
c.box(types.int64, start),
165+
)
166+
167+
stop_is_none = c.builder.icmp_signed("==", stop, none_val)
168+
stop = c.builder.select(
169+
stop_is_none,
170+
c.pyapi.get_null_object(),
171+
c.box(types.int64, stop),
172+
)
153173

154-
start = c.box(types.int64, c.builder.extract_value(val, 0))
155-
stop = c.box(types.int64, c.builder.extract_value(val, 1))
156174
if typ.has_step:
157-
step = c.box(types.int64, c.builder.extract_value(val, 2))
175+
step = c.builder.extract_value(val, 2)
176+
step_is_none = c.builder.icmp_signed("==", step, none_val)
177+
step = c.builder.select(
178+
step_is_none,
179+
c.pyapi.get_null_object(),
180+
c.box(types.int64, step),
181+
)
158182
else:
159183
step = c.pyapi.get_null_object()
160184

0 commit comments

Comments
 (0)