Skip to content

Commit 1de6bae

Browse files
committed
Actually check types and dtypes match in numba testing helper
NOTE: CI failing at this point
1 parent 5ffe17a commit 1de6bae

File tree

1 file changed

+8
-5
lines changed

1 file changed

+8
-5
lines changed

tests/link/numba/test_basic.py

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -260,9 +260,12 @@ def compare_numba_and_py(
260260
if assert_fn is None:
261261

262262
def assert_fn(x, y):
263-
return np.testing.assert_allclose(x, y, rtol=1e-4) and compare_shape_dtype(
264-
x, y
265-
)
263+
np.testing.assert_allclose(x, y, rtol=1e-4, strict=True)
264+
# Make sure we don't have one input be a np.ndarray while the other is not
265+
if isinstance(x, np.ndarray):
266+
assert isinstance(y, np.ndarray), "y is not a NumPy array, but x is"
267+
else:
268+
assert not isinstance(y, np.ndarray), "y is a NumPy array, but x is not"
266269

267270
if any(
268271
inp.owner is not None
@@ -295,8 +298,8 @@ def assert_fn(x, y):
295298
test_inputs_copy = (inp.copy() for inp in test_inputs) if inplace else test_inputs
296299
numba_res = pytensor_numba_fn(*test_inputs_copy)
297300
if isinstance(graph_outputs, tuple | list):
298-
for j, p in zip(numba_res, py_res, strict=True):
299-
assert_fn(j, p)
301+
for numba_res_i, python_res_i in zip(numba_res, py_res, strict=True):
302+
assert_fn(numba_res_i, python_res_i)
300303
else:
301304
assert_fn(numba_res, py_res)
302305

0 commit comments

Comments
 (0)