|
5 | 5 | import numpy as np
|
6 | 6 | from numba import _helperlib, types
|
7 | 7 | from numba.core import cgutils
|
8 |
| -from numba.extending import NativeValue, box, models, register_model, typeof_impl, unbox |
| 8 | +from numba.extending import ( |
| 9 | + NativeValue, |
| 10 | + box, |
| 11 | + models, |
| 12 | + overload, |
| 13 | + register_model, |
| 14 | + typeof_impl, |
| 15 | + unbox, |
| 16 | +) |
9 | 17 | from numpy.random import RandomState
|
10 | 18 |
|
11 | 19 | import pytensor.tensor.random.basic as aer
|
@@ -78,6 +86,16 @@ def box_random_state(typ, val, c):
|
78 | 86 | return class_obj
|
79 | 87 |
|
80 | 88 |
|
| 89 | +@overload(np.random.uniform) |
| 90 | +def uniform_empty_size(a, b, size): |
| 91 | + if isinstance(size, types.Tuple) and size.count == 0: |
| 92 | + |
| 93 | + def uniform_no_size(a, b, size): |
| 94 | + return np.random.uniform(a, b) |
| 95 | + |
| 96 | + return uniform_no_size |
| 97 | + |
| 98 | + |
81 | 99 | @numba_typify.register(RandomState)
|
82 | 100 | def numba_typify_RandomState(state, **kwargs):
|
83 | 101 | # The numba_typify in this case is just an passthrough function
|
@@ -321,7 +339,7 @@ def categorical_rv(rng, size, dtype, p):
|
321 | 339 | size_tpl = numba_ndarray.to_fixed_tuple(size, size_len)
|
322 | 340 | p = np.broadcast_to(p, size_tpl + p.shape[-1:])
|
323 | 341 |
|
324 |
| - unif_samples = np.random.uniform(0, 1, size_tpl) |
| 342 | + unif_samples = np.asarray(np.random.uniform(0, 1, size_tpl)) |
325 | 343 |
|
326 | 344 | res = np.empty(size_tpl, dtype=out_dtype)
|
327 | 345 | for idx in np.ndindex(*size_tpl):
|
|
0 commit comments