@@ -212,12 +212,10 @@ def test_inner_composite(mode):
212
212
y16 = op (n_steps , x16 )
213
213
assert y16 .type .dtype == "float16"
214
214
215
- fn32 = function ([n_steps , x16 ], y16 , mode = mode )
216
- np .testing .assert_allclose (
217
- fn32 (n_steps = 9 , x16 = np .array (4.73 , dtype = "float16" )),
218
- 4.73 + 9 ,
219
- rtol = 1e-3 ,
220
- )
215
+ fn16 = function ([n_steps , x16 ], y16 , mode = mode )
216
+ out16 = fn16 (n_steps = 9 , x16 = np .array (4.73 , dtype = "float16" ))
217
+ assert out16 .dtype == "float16"
218
+ assert np .isnan (out16 )
221
219
222
220
223
221
@mode
@@ -243,8 +241,10 @@ def test_inner_loop(mode):
243
241
y16 = outer_loop_op (n_steps , x16 , n_steps )
244
242
assert y16 .type .dtype == "float16"
245
243
246
- fn32 = function ([n_steps , x16 ], y16 , mode = mode )
244
+ fn16 = function ([n_steps , x16 ], y16 , mode = mode )
245
+ out16 = fn16 (n_steps = 3 , x16 = np .array (2.5 , dtype = "float16" ))
246
+ assert out16 .dtype == "float16"
247
247
np .testing .assert_allclose (
248
- fn32 ( n_steps = 3 , x16 = np . array ( 2.5 , dtype = "float16" )) ,
248
+ out16 ,
249
249
3 ** 2 + 2.5 ,
250
250
)
0 commit comments