Skip to content

Commit 3292b3b

Browse files
committed
fix test
1 parent 5a58da4 commit 3292b3b

File tree

1 file changed

+8
-8
lines changed

1 file changed

+8
-8
lines changed

tests/scalar/test_loop.py

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -212,12 +212,10 @@ def test_inner_composite(mode):
212212
y16 = op(n_steps, x16)
213213
assert y16.type.dtype == "float16"
214214

215-
fn32 = function([n_steps, x16], y16, mode=mode)
216-
np.testing.assert_allclose(
217-
fn32(n_steps=9, x16=np.array(4.73, dtype="float16")),
218-
4.73 + 9,
219-
rtol=1e-3,
220-
)
215+
fn16 = function([n_steps, x16], y16, mode=mode)
216+
out16 = fn16(n_steps=9, x16=np.array(4.73, dtype="float16"))
217+
assert out16.dtype == "float16"
218+
assert np.isnan(out16)
221219

222220

223221
@mode
@@ -243,8 +241,10 @@ def test_inner_loop(mode):
243241
y16 = outer_loop_op(n_steps, x16, n_steps)
244242
assert y16.type.dtype == "float16"
245243

246-
fn32 = function([n_steps, x16], y16, mode=mode)
244+
fn16 = function([n_steps, x16], y16, mode=mode)
245+
out16 = fn16(n_steps=3, x16=np.array(2.5, dtype="float16"))
246+
assert out16.dtype == "float16"
247247
np.testing.assert_allclose(
248-
fn32(n_steps=3, x16=np.array(2.5, dtype="float16")),
248+
out16,
249249
3**2 + 2.5,
250250
)

0 commit comments

Comments
 (0)