Skip to content

Commit ecfa471

Browse files
committed
Update special case parsing for >2022.12 specs
1 parent e851b60 commit ecfa471

File tree

1 file changed

+21
-5
lines changed

1 file changed

+21
-5
lines changed

array_api_tests/test_special_cases.py

Lines changed: 21 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -32,8 +32,7 @@
3232
from . import hypothesis_helpers as hh
3333
from . import pytest_helpers as ph
3434
from . import shape_helpers as sh
35-
from . import xps
36-
from . import xp
35+
from . import xp, xps
3736
from .stubs import category_to_funcs
3837

3938
pytestmark = pytest.mark.ci
@@ -126,6 +125,8 @@ def abs_cond(i: float) -> bool:
126125
"infinity": float("inf"),
127126
"0": 0.0,
128127
"1": 1.0,
128+
"False": 0.0,
129+
"True": 1.0,
129130
}
130131
r_value = re.compile(r"([+-]?)(.+)")
131132
r_pi = re.compile(r"(\d?)π(?:/(\d))?")
@@ -507,7 +508,10 @@ def __repr__(self) -> str:
507508
return f"{self.__class__.__name__}(<{self}>)"
508509

509510

510-
r_case_block = re.compile(r"\*\*Special [Cc]ases\*\*\n+((?:(.*\n)+))\n+\s*Parameters")
511+
r_case_block = re.compile(
512+
r"\*\*Special [Cc]ases\*\*\n+((?:(.*\n)+))\n+\s*"
513+
r"(?:.+\n--+)?(?:\.\. versionchanged.*)?"
514+
)
511515
r_case = re.compile(r"\s+-\s*(.*)\.")
512516

513517

@@ -1121,6 +1125,9 @@ def parse_binary_case_block(case_block: str) -> List[BinaryCase]:
11211125
iop_params = []
11221126
func_to_op: Dict[str, str] = {v: k for k, v in dh.op_to_func.items()}
11231127
for stub in category_to_funcs["elementwise"]:
1128+
# if stub.__name__ == "abs":
1129+
# import ipdb; ipdb.set_trace()
1130+
11241131
if stub.__doc__ is None:
11251132
warn(f"{stub.__name__}() stub has no docstring")
11261133
continue
@@ -1167,6 +1174,8 @@ def parse_binary_case_block(case_block: str) -> List[BinaryCase]:
11671174
op = getattr(operator, op_name)
11681175
name_to_func[op_name] = op
11691176
# We collect inplace operator test cases seperately
1177+
if stub.__name__ == "equal":
1178+
break
11701179
iop_name = "__i" + op_name[2:]
11711180
iop = getattr(operator, iop_name)
11721181
for case in cases:
@@ -1259,7 +1268,12 @@ def test_binary(func_name, func, case, x1, x2, data):
12591268

12601269
res = func(x1, x2)
12611270
# sanity check
1262-
ph.assert_result_shape(func_name, in_shapes=[x1.shape, x2.shape], out_shape=res.shape, expected=result_shape)
1271+
ph.assert_result_shape(
1272+
func_name,
1273+
in_shapes=[x1.shape, x2.shape],
1274+
out_shape=res.shape,
1275+
expected=result_shape,
1276+
)
12631277

12641278
good_example = False
12651279
for l_idx, r_idx, o_idx in all_indices:
@@ -1311,7 +1325,9 @@ def test_iop(iop_name, iop, case, oneway_dtypes, oneway_shapes, data):
13111325
res = xp.asarray(x1, copy=True)
13121326
res = iop(res, x2)
13131327
# sanity check
1314-
ph.assert_result_shape(iop_name, in_shapes=[x1.shape, x2.shape], out_shape=res.shape)
1328+
ph.assert_result_shape(
1329+
iop_name, in_shapes=[x1.shape, x2.shape], out_shape=res.shape
1330+
)
13151331

13161332
good_example = False
13171333
for l_idx, r_idx, o_idx in all_indices:

0 commit comments

Comments
 (0)