Skip to content

Commit 2f0b424

Browse files
committed
Test on numba>=0.57
1 parent f0fda41 commit 2f0b424

File tree

5 files changed

+9
-10
lines changed

5 files changed

+9
-10
lines changed

.github/workflows/test.yml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -139,7 +139,7 @@ jobs:
139139
shell: bash -l {0}
140140
run: |
141141
mamba install --yes -q "python~=${PYTHON_VERSION}=*_cpython" mkl numpy scipy pip mkl-service graphviz cython pytest coverage pytest-cov pytest-benchmark sympy
142-
if [[ $INSTALL_NUMBA == "1" ]]; then mamba install --yes -q -c conda-forge "python~=${PYTHON_VERSION}=*_cpython" "numba>=0.55" numba-scipy; fi
142+
if [[ $INSTALL_NUMBA == "1" ]]; then mamba install --yes -q -c conda-forge "python~=${PYTHON_VERSION}=*_cpython" "numba>=0.57" numba-scipy; fi
143143
mamba install --yes -q -c conda-forge "python~=${PYTHON_VERSION}=*_cpython" jax jaxlib numpyro
144144
pip install -e ./
145145
mamba list && pip freeze
@@ -192,7 +192,7 @@ jobs:
192192
- name: Install dependencies
193193
shell: bash -l {0}
194194
run: |
195-
mamba install --yes -q -c conda-forge "python~=${PYTHON_VERSION}=*_cpython" mkl numpy scipy pip mkl-service cython pytest "numba>=0.55" numba-scipy jax jaxlib pytest-benchmark
195+
mamba install --yes -q -c conda-forge "python~=${PYTHON_VERSION}=*_cpython" mkl numpy scipy pip mkl-service cython pytest "numba>=0.57" numba-scipy jax jaxlib pytest-benchmark
196196
pip install -e ./
197197
mamba list && pip freeze
198198
python -c 'import pytensor; print(pytensor.config.__str__(print_doc=False))'

environment.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@ dependencies:
2222
- mkl-service
2323
- libblas=*=*mkl
2424
# numba backend
25-
- numba>=0.55
25+
- numba>=0.57
2626
- numba-scipy
2727
# For testing
2828
- coveralls

pytensor/link/numba/dispatch/random.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -312,6 +312,7 @@ def body_fn(a):
312312
def numba_funcify_CategoricalRV(op, node, **kwargs):
313313
out_dtype = node.outputs[1].type.numpy_dtype
314314
size_len = int(get_vector_length(node.inputs[1]))
315+
p_ndim = node.inputs[-1].ndim
315316

316317
@numba_basic.numba_njit
317318
def categorical_rv(rng, size, dtype, p):
@@ -321,7 +322,11 @@ def categorical_rv(rng, size, dtype, p):
321322
size_tpl = numba_ndarray.to_fixed_tuple(size, size_len)
322323
p = np.broadcast_to(p, size_tpl + p.shape[-1:])
323324

324-
unif_samples = np.random.uniform(0, 1, size_tpl)
325+
# Workaround https://github.com/numba/numba/issues/8975
326+
if not size_len and p_ndim == 1:
327+
unif_samples = np.asarray(np.random.uniform(0, 1))
328+
else:
329+
unif_samples = np.random.uniform(0, 1, size_tpl)
325330

326331
res = np.empty(size_tpl, dtype=out_dtype)
327332
for idx in np.ndindex(*size_tpl):

tests/link/numba/test_basic.py

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -530,9 +530,6 @@ def test_AdvancedIncSubtensor1(x, y, indices):
530530
at.as_tensor(np.arange(3 * 4 * 5).reshape((3, 4, 5))),
531531
at.as_tensor(rng.poisson(size=(2, 5))),
532532
([1, 1], [2, 2]),
533-
marks=pytest.mark.xfail(
534-
reason="Duplicate index handling hasn't been implemented, yet."
535-
),
536533
),
537534
],
538535
)

tests/link/numba/test_extra_ops.py

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -459,9 +459,6 @@ def test_UnravelIndex(arr, shape, order, exc):
459459
"left",
460460
None,
461461
None,
462-
marks=pytest.mark.xfail(
463-
reason="This won't work until https://github.com/numba/numba/pull/7005 is merged"
464-
),
465462
),
466463
(
467464
set_test_value(at.vector(), np.array([1.0, 2.0, 3.0], dtype=config.floatX)),

0 commit comments

Comments
 (0)