@@ -260,9 +260,12 @@ def compare_numba_and_py(
260
260
if assert_fn is None :
261
261
262
262
def assert_fn (x , y ):
263
- return np .testing .assert_allclose (x , y , rtol = 1e-4 ) and compare_shape_dtype (
264
- x , y
265
- )
263
+ np .testing .assert_allclose (x , y , rtol = 1e-4 , strict = True )
264
+ # Make sure we don't have one input be a np.ndarray while the other is not
265
+ if isinstance (x , np .ndarray ):
266
+ assert isinstance (y , np .ndarray ), "y is not a NumPy array, but x is"
267
+ else :
268
+ assert not isinstance (y , np .ndarray ), "y is a NumPy array, but x is not"
266
269
267
270
if any (
268
271
inp .owner is not None
@@ -295,8 +298,8 @@ def assert_fn(x, y):
295
298
test_inputs_copy = (inp .copy () for inp in test_inputs ) if inplace else test_inputs
296
299
numba_res = pytensor_numba_fn (* test_inputs_copy )
297
300
if isinstance (graph_outputs , tuple | list ):
298
- for j , p in zip (numba_res , py_res , strict = True ):
299
- assert_fn (j , p )
301
+ for numba_res_i , python_res_i in zip (numba_res , py_res , strict = True ):
302
+ assert_fn (numba_res_i , python_res_i )
300
303
else :
301
304
assert_fn (numba_res , py_res )
302
305
0 commit comments