2
2
import math
3
3
import re
4
4
from dataclasses import dataclass
5
+ from decimal import ROUND_HALF_EVEN , Decimal
5
6
from typing import (
6
7
Callable ,
7
8
Dict ,
@@ -236,7 +237,7 @@ def parse_inline_code(inline_code: str) -> float:
236
237
237
238
238
239
r_special_cases = re .compile (
239
- r"\*\*Special [Cc]ases\*\*\n+\s* "
240
+ r"\*\*Special [Cc]ases\*\*(?:\n.*)+ "
240
241
r"For floating-point operands,\n+"
241
242
r"((?:\s*-\s*.*\n)+)"
242
243
)
@@ -342,8 +343,8 @@ class UnaryResultCheck:
342
343
check_result : Callable
343
344
expr : str
344
345
345
- def __call__ (self , result : float ) -> bool :
346
- return self .check_result (result )
346
+ def __call__ (self , i : float , result : float ) -> bool :
347
+ return self .check_result (i , result )
347
348
348
349
349
350
class Case (Protocol ):
@@ -366,17 +367,26 @@ def from_strings(cls, cond_str: str, result_str: str):
366
367
check_result , check_result_expr = parse_result (result_str )
367
368
return cls (
368
369
UnaryCond (cond , cond_expr ),
369
- UnaryResultCheck (check_result , check_result_expr ),
370
+ UnaryResultCheck (lambda _ , r : check_result ( r ) , check_result_expr ),
370
371
)
371
372
372
373
def __repr__ (self ):
373
374
return f"UnaryCase(<{ self .cond .expr } -> { self .check_result .expr } >)"
374
375
375
376
376
377
r_unary_case = re .compile ("If ``x_i`` is (.+), the result is (.+)" )
377
- # re.compile(
378
- # "If two integers are equally close to ``x_i``, the result is (.+)"
379
- # ): lambda: (lambda i: (abs(i) - math.floor(abs(i))) == 0.5),
378
+ r_even_int_round_case = re .compile (
379
+ "If two integers are equally close to ``x_i``, "
380
+ "the result is the even integer closest to ``x_i``"
381
+ )
382
+
383
+ even_int_round_case = UnaryCase (
384
+ cond = UnaryCond (lambda i : i % 0.5 == 0 , "i % 0.5 == 0" ),
385
+ check_result = UnaryResultCheck (
386
+ lambda i , r : r == float (Decimal (i ).to_integral_exact (ROUND_HALF_EVEN )),
387
+ "Decimal(i).to_integral_exact(ROUND_HALF_EVEN)" ,
388
+ ),
389
+ )
380
390
381
391
382
392
def parse_unary_docstring (docstring : str ) -> List [UnaryCase ]:
@@ -398,6 +408,8 @@ def parse_unary_docstring(docstring: str) -> List[UnaryCase]:
398
408
warn (f"not machine-readable: '{ e .value } '" )
399
409
continue
400
410
cases .append (case )
411
+ elif m := r_even_int_round_case .search (case ):
412
+ cases .append (even_int_round_case )
401
413
else :
402
414
if not r_remaining_case .search (case ):
403
415
warn (f"case not machine-readable: '{ case } '" )
@@ -795,7 +807,7 @@ def test_unary(func_name, func, cases, x):
795
807
f_in = f"{ sh .fmt_idx ('x' , idx )} ={ in_ } "
796
808
f_out = f"{ sh .fmt_idx ('out' , idx )} ={ out } "
797
809
assert case .check_result (
798
- out
810
+ in_ , out
799
811
), f"{ f_out } not good [{ func_name } ()]\n { f_in } "
800
812
break
801
813
assume (good_example )
0 commit comments