Skip to content

Commit aaeb88a

Browse files
brandonwillardtwiecki
authored andcommitted
Change numba_typify to numba_const_convert
1 parent 2724936 commit aaeb88a

File tree

5 files changed

+13
-12
lines changed

5 files changed

+13
-12
lines changed

pytensor/link/numba/dispatch/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
# isort: off
2-
from pytensor.link.numba.dispatch.basic import numba_funcify, numba_typify
2+
from pytensor.link.numba.dispatch.basic import numba_funcify, numba_const_convert
33

44
# Load dispatch specializations
55
import pytensor.link.numba.dispatch.scalar

pytensor/link/numba/dispatch/basic.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -331,7 +331,8 @@ def use_optimized_cheap_pass(*args, **kwargs):
331331

332332

333333
@singledispatch
334-
def numba_typify(data, dtype=None, **kwargs):
334+
def numba_const_convert(data, dtype=None, **kwargs):
335+
"""Create a Numba compatible constant from an PyTensor `Constant`."""
335336
return data
336337

337338

@@ -423,7 +424,7 @@ def numba_funcify_FunctionGraph(
423424
return fgraph_to_python(
424425
fgraph,
425426
numba_funcify,
426-
const_conversion_fn=numba_typify,
427+
const_conversion_fn=numba_const_convert,
427428
fgraph_name=fgraph_name,
428429
**kwargs,
429430
)

pytensor/link/numba/dispatch/random.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@
2020
from pytensor.graph.basic import Apply
2121
from pytensor.graph.op import Op
2222
from pytensor.link.numba.dispatch import basic as numba_basic
23-
from pytensor.link.numba.dispatch.basic import numba_funcify, numba_typify
23+
from pytensor.link.numba.dispatch.basic import numba_const_convert, numba_funcify
2424
from pytensor.link.utils import (
2525
compile_function_src,
2626
get_name_for_object,
@@ -96,11 +96,11 @@ def uniform_no_size(a, b, size):
9696
return uniform_no_size
9797

9898

99-
@numba_typify.register(RandomState)
100-
def numba_typify_RandomState(state, **kwargs):
101-
# The numba_typify in this case is just an passthrough function
99+
@numba_const_convert.register(RandomState)
100+
def numba_const_convert_RandomState(state, **kwargs):
101+
# The `numba_const_convert` in this case is just a passthrough function
102102
# that synchronizes Numba's internal random state with the current
103-
# RandomState object
103+
# `RandomState` object.
104104
ints, index = state.get_state()[1:3]
105105
ptr = _helperlib.rnd_get_np_state_ptr()
106106
_helperlib.rnd_set_state(ptr, (index, [int(x) for x in ints]))

pytensor/link/numba/linker.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -35,13 +35,13 @@ def jit_compile(self, fn):
3535
def create_thunk_inputs(self, storage_map):
3636
from numpy.random import RandomState
3737

38-
from pytensor.link.numba.dispatch import numba_typify
38+
from pytensor.link.numba.dispatch import numba_const_convert
3939

4040
thunk_inputs = []
4141
for n in self.fgraph.inputs:
4242
sinput = storage_map[n]
4343
if isinstance(sinput[0], RandomState):
44-
new_value = numba_typify(
44+
new_value = numba_const_convert(
4545
sinput[0], dtype=getattr(sinput[0], "dtype", None)
4646
)
4747
# We need to remove the reference-based connection to the

tests/link/numba/test_basic.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@
2424
from pytensor.graph.type import Type
2525
from pytensor.ifelse import ifelse
2626
from pytensor.link.numba.dispatch import basic as numba_basic
27-
from pytensor.link.numba.dispatch import numba_typify
27+
from pytensor.link.numba.dispatch import numba_const_convert
2828
from pytensor.link.numba.linker import NumbaLinker
2929
from pytensor.raise_op import assert_op
3030
from pytensor.tensor import blas
@@ -321,7 +321,7 @@ def test_create_numba_signature(v, expected, force_scalar):
321321
[
322322
(
323323
np.random.RandomState(1),
324-
numba_typify,
324+
numba_const_convert,
325325
lambda x, y: np.all(x.get_state()[1] == y.get_state()[1]),
326326
)
327327
],

0 commit comments

Comments
 (0)