@@ -90,14 +90,25 @@ def mock_int_dtype(n: int, dtype: DataType) -> int:
90
90
return n
91
91
92
92
93
- def isclose (a : float , b : float , * , rel_tol : float = 0.25 , abs_tol : float = 1 ) -> bool :
93
+ def isclose (
94
+ a : float ,
95
+ b : float ,
96
+ M : float ,
97
+ * ,
98
+ rel_tol : float = 0.25 ,
99
+ abs_tol : float = 1 ,
100
+ ) -> bool :
94
101
"""Wraps math.isclose with very generous defaults.
95
102
96
103
This is useful for many floating-point operations where the spec does not
97
104
make accuracy requirements.
98
105
"""
99
- if not (math .isfinite (a ) and math .isfinite (b )):
100
- raise ValueError (f"{ a = } and { b = } , but input must be finite" )
106
+ if math .isnan (a ) or math .isnan (b ):
107
+ raise ValueError (f"{ a = } and { b = } , but input must be non-NaN" )
108
+ if math .isinf (a ):
109
+ return math .isinf (b ) or abs (b ) > math .log (M )
110
+ elif math .isinf (b ):
111
+ return math .isinf (a ) or abs (a ) > math .log (M )
101
112
return math .isclose (a , b , rel_tol = rel_tol , abs_tol = abs_tol )
102
113
103
114
@@ -288,10 +299,10 @@ def unary_assert_against_refimpl(
288
299
f"{ f_i } ={ scalar_i } "
289
300
)
290
301
if res .dtype in dh .complex_dtypes :
291
- assert isclose (scalar_o .real , expected .real ), msg
292
- assert isclose (scalar_o .imag , expected .imag ), msg
302
+ assert isclose (scalar_o .real , expected .real , M ), msg
303
+ assert isclose (scalar_o .imag , expected .imag , M ), msg
293
304
else :
294
- assert isclose (scalar_o , expected ), msg
305
+ assert isclose (scalar_o , expected , M ), msg
295
306
else :
296
307
assert scalar_o == expected , (
297
308
f"{ f_o } ={ scalar_o } , but should be { expr } [{ func_name } ()]\n "
@@ -364,10 +375,10 @@ def binary_assert_against_refimpl(
364
375
f"{ f_l } ={ scalar_l } , { f_r } ={ scalar_r } "
365
376
)
366
377
if res .dtype in dh .complex_dtypes :
367
- assert isclose (scalar_o .real , expected .real ), msg
368
- assert isclose (scalar_o .imag , expected .imag ), msg
378
+ assert isclose (scalar_o .real , expected .real , M ), msg
379
+ assert isclose (scalar_o .imag , expected .imag , M ), msg
369
380
else :
370
- assert isclose (scalar_o , expected ), msg
381
+ assert isclose (scalar_o , expected , M ), msg
371
382
else :
372
383
assert scalar_o == expected , (
373
384
f"{ f_o } ={ scalar_o } , but should be { expr } [{ func_name } ()]\n "
@@ -437,10 +448,10 @@ def right_scalar_assert_against_refimpl(
437
448
f"{ f_l } ={ scalar_l } "
438
449
)
439
450
if res .dtype in dh .complex_dtypes :
440
- assert isclose (scalar_o .real , expected .real ), msg
441
- assert isclose (scalar_o .imag , expected .imag ), msg
451
+ assert isclose (scalar_o .real , expected .real , M ), msg
452
+ assert isclose (scalar_o .imag , expected .imag , M ), msg
442
453
else :
443
- assert isclose (scalar_o , expected ), msg
454
+ assert isclose (scalar_o , expected , M ), msg
444
455
else :
445
456
assert scalar_o == expected , (
446
457
f"{ f_o } ={ scalar_o } , but should be { expr } [{ func_name } ()]\n "
0 commit comments