Skip to content

Commit 06be02e

Browse files
committed
Parse even round case
1 parent 244c9a6 commit 06be02e

File tree

1 file changed

+20
-8
lines changed

1 file changed

+20
-8
lines changed

array_api_tests/test_special_cases.py

Lines changed: 20 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
import math
33
import re
44
from dataclasses import dataclass
5+
from decimal import ROUND_HALF_EVEN, Decimal
56
from typing import (
67
Callable,
78
Dict,
@@ -236,7 +237,7 @@ def parse_inline_code(inline_code: str) -> float:
236237

237238

238239
r_special_cases = re.compile(
239-
r"\*\*Special [Cc]ases\*\*\n+\s*"
240+
r"\*\*Special [Cc]ases\*\*(?:\n.*)+"
240241
r"For floating-point operands,\n+"
241242
r"((?:\s*-\s*.*\n)+)"
242243
)
@@ -342,8 +343,8 @@ class UnaryResultCheck:
342343
check_result: Callable
343344
expr: str
344345

345-
def __call__(self, result: float) -> bool:
346-
return self.check_result(result)
346+
def __call__(self, i: float, result: float) -> bool:
347+
return self.check_result(i, result)
347348

348349

349350
class Case(Protocol):
@@ -366,17 +367,26 @@ def from_strings(cls, cond_str: str, result_str: str):
366367
check_result, check_result_expr = parse_result(result_str)
367368
return cls(
368369
UnaryCond(cond, cond_expr),
369-
UnaryResultCheck(check_result, check_result_expr),
370+
UnaryResultCheck(lambda _, r: check_result(r), check_result_expr),
370371
)
371372

372373
def __repr__(self):
373374
return f"UnaryCase(<{self.cond.expr} -> {self.check_result.expr}>)"
374375

375376

376377
r_unary_case = re.compile("If ``x_i`` is (.+), the result is (.+)")
377-
# re.compile(
378-
# "If two integers are equally close to ``x_i``, the result is (.+)"
379-
# ): lambda: (lambda i: (abs(i) - math.floor(abs(i))) == 0.5),
378+
r_even_int_round_case = re.compile(
379+
"If two integers are equally close to ``x_i``, "
380+
"the result is the even integer closest to ``x_i``"
381+
)
382+
383+
even_int_round_case = UnaryCase(
384+
cond=UnaryCond(lambda i: i % 0.5 == 0, "i % 0.5 == 0"),
385+
check_result=UnaryResultCheck(
386+
lambda i, r: r == float(Decimal(i).to_integral_exact(ROUND_HALF_EVEN)),
387+
"Decimal(i).to_integral_exact(ROUND_HALF_EVEN)",
388+
),
389+
)
380390

381391

382392
def parse_unary_docstring(docstring: str) -> List[UnaryCase]:
@@ -398,6 +408,8 @@ def parse_unary_docstring(docstring: str) -> List[UnaryCase]:
398408
warn(f"not machine-readable: '{e.value}'")
399409
continue
400410
cases.append(case)
411+
elif m := r_even_int_round_case.search(case):
412+
cases.append(even_int_round_case)
401413
else:
402414
if not r_remaining_case.search(case):
403415
warn(f"case not machine-readable: '{case}'")
@@ -795,7 +807,7 @@ def test_unary(func_name, func, cases, x):
795807
f_in = f"{sh.fmt_idx('x', idx)}={in_}"
796808
f_out = f"{sh.fmt_idx('out', idx)}={out}"
797809
assert case.check_result(
798-
out
810+
in_, out
799811
), f"{f_out} not good [{func_name}()]\n{f_in}"
800812
break
801813
assume(good_example)

0 commit comments

Comments
 (0)