Skip to content

Commit 27459b6

Browse files
brandonwillardricardoV94
authored andcommitted
Make scalar Categorical sampling work in recent versions of Numba
1 parent 4eded29 commit 27459b6

File tree

1 file changed

+20
-2
lines changed

1 file changed

+20
-2
lines changed

pytensor/link/numba/dispatch/random.py

Lines changed: 20 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,15 @@
55
import numpy as np
66
from numba import _helperlib, types
77
from numba.core import cgutils
8-
from numba.extending import NativeValue, box, models, register_model, typeof_impl, unbox
8+
from numba.extending import (
9+
NativeValue,
10+
box,
11+
models,
12+
overload,
13+
register_model,
14+
typeof_impl,
15+
unbox,
16+
)
917
from numpy.random import RandomState
1018

1119
import pytensor.tensor.random.basic as aer
@@ -78,6 +86,16 @@ def box_random_state(typ, val, c):
7886
return class_obj
7987

8088

89+
@overload(np.random.uniform)
90+
def uniform_empty_size(a, b, size):
91+
if isinstance(size, types.Tuple) and size.count == 0:
92+
93+
def uniform_no_size(a, b, size):
94+
return np.random.uniform(a, b)
95+
96+
return uniform_no_size
97+
98+
8199
@numba_typify.register(RandomState)
82100
def numba_typify_RandomState(state, **kwargs):
83101
# The numba_typify in this case is just an passthrough function
@@ -321,7 +339,7 @@ def categorical_rv(rng, size, dtype, p):
321339
size_tpl = numba_ndarray.to_fixed_tuple(size, size_len)
322340
p = np.broadcast_to(p, size_tpl + p.shape[-1:])
323341

324-
unif_samples = np.random.uniform(0, 1, size_tpl)
342+
unif_samples = np.asarray(np.random.uniform(0, 1, size_tpl))
325343

326344
res = np.empty(size_tpl, dtype=out_dtype)
327345
for idx in np.ndindex(*size_tpl):

0 commit comments

Comments
 (0)