Skip to content

Commit b44aad4

Browse files
committed
Make parse_cond() only return BoundFromDtype
Also add some more granular documentation to it
1 parent f63a6a7 commit b44aad4

File tree

1 file changed

+46
-63
lines changed

1 file changed

+46
-63
lines changed

array_api_tests/test_special_cases.py

Lines changed: 46 additions & 63 deletions
Original file line numberDiff line numberDiff line change
@@ -319,7 +319,7 @@ def from_dtype(dtype: DataType, **kw) -> st.SearchStrategy[float]:
319319
return from_dtype
320320

321321

322-
def parse_cond(cond_str: str) -> Tuple[UnaryCheck, str, FromDtypeFunc]:
322+
def parse_cond(cond_str: str) -> Tuple[UnaryCheck, str, BoundFromDtype]:
323323
"""
324324
Parses a Sphinx-formatted condition string to return:
325325
@@ -348,22 +348,30 @@ def parse_cond(cond_str: str) -> Tuple[UnaryCheck, str, FromDtypeFunc]:
348348
124.978
349349
350350
"""
351+
# We first identify whether the condition starts with "not". If so, we note
352+
# this but parse the condition as if it was not negated.
351353
if m := r_not.match(cond_str):
352354
cond_str = m.group(1)
353355
not_cond = True
354356
else:
355357
not_cond = False
356358

359+
# We parse the condition to identify the condition function, expression
360+
# template, and xps.from_dtype()-like condition strategy.
357361
kwargs = {}
358362
filter_ = None
359363
from_dtype = None # type: ignore
360-
strat = None
361364
if m := r_code.match(cond_str):
362365
value = parse_value(m.group(1))
363366
cond = make_strict_eq(value)
364367
expr_template = "{} == " + m.group(1)
365-
if not not_cond:
366-
strat = st.just(value)
368+
from_dtype = wrap_strat_as_from_dtype(st.just(value))
369+
elif m := r_either_code.match(cond_str):
370+
v1 = parse_value(m.group(1))
371+
v2 = parse_value(m.group(2))
372+
cond = make_or(make_strict_eq(v1), make_strict_eq(v2))
373+
expr_template = "({} == " + m.group(1) + " or {} == " + m.group(2) + ")"
374+
from_dtype = wrap_strat_as_from_dtype(st.sampled_from([v1, v2]))
367375
elif m := r_equal_to.match(cond_str):
368376
value = parse_value(m.group(1))
369377
if math.isnan(value):
@@ -374,97 +382,73 @@ def parse_cond(cond_str: str) -> Tuple[UnaryCheck, str, FromDtypeFunc]:
374382
value = parse_value(m.group(1))
375383
cond = make_gt(value)
376384
expr_template = "{} > " + m.group(1)
377-
if not not_cond:
378-
kwargs = {"min_value": value, "exclude_min": True}
385+
kwargs = {"min_value": value, "exclude_min": True}
379386
elif m := r_lt.match(cond_str):
380387
value = parse_value(m.group(1))
381388
cond = make_lt(value)
382389
expr_template = "{} < " + m.group(1)
383-
if not not_cond:
384-
kwargs = {"max_value": value, "exclude_max": True}
385-
elif m := r_either_code.match(cond_str):
386-
v1 = parse_value(m.group(1))
387-
v2 = parse_value(m.group(2))
388-
cond = make_or(make_strict_eq(v1), make_strict_eq(v2))
389-
expr_template = "({} == " + m.group(1) + " or {} == " + m.group(2) + ")"
390-
if not not_cond:
391-
strat = st.sampled_from([v1, v2])
390+
kwargs = {"max_value": value, "exclude_max": True}
392391
elif cond_str in ["finite", "a finite number"]:
393392
cond = math.isfinite
394393
expr_template = "isfinite({})"
395-
if not not_cond:
396-
kwargs = {"allow_nan": False, "allow_infinity": False}
394+
kwargs = {"allow_nan": False, "allow_infinity": False}
397395
elif cond_str in "a positive (i.e., greater than ``0``) finite number":
398396
cond = lambda i: math.isfinite(i) and i > 0
399397
expr_template = "isfinite({}) and {} > 0"
400-
if not not_cond:
401-
kwargs = {
402-
"allow_nan": False,
403-
"allow_infinity": False,
404-
"min_value": 0,
405-
"exclude_min": True,
406-
}
398+
kwargs = {
399+
"allow_nan": False,
400+
"allow_infinity": False,
401+
"min_value": 0,
402+
"exclude_min": True,
403+
}
407404
elif cond_str == "a negative (i.e., less than ``0``) finite number":
408405
cond = lambda i: math.isfinite(i) and i < 0
409406
expr_template = "isfinite({}) and {} < 0"
410-
if not not_cond:
411-
kwargs = {
412-
"allow_nan": False,
413-
"allow_infinity": False,
414-
"max_value": 0,
415-
"exclude_max": True,
416-
}
407+
kwargs = {
408+
"allow_nan": False,
409+
"allow_infinity": False,
410+
"max_value": 0,
411+
"exclude_max": True,
412+
}
417413
elif cond_str == "positive":
418414
cond = lambda i: math.copysign(1, i) == 1
419415
expr_template = "copysign(1, {}) == 1"
420-
if not not_cond:
421-
# We assume (positive) zero is special cased seperately
422-
kwargs = {"min_value": 0, "exclude_min": True}
416+
# We assume (positive) zero is special cased seperately
417+
kwargs = {"min_value": 0, "exclude_min": True}
423418
elif cond_str == "negative":
424419
cond = lambda i: math.copysign(1, i) == -1
425420
expr_template = "copysign(1, {}) == -1"
426-
if not not_cond:
427-
# We assume (negative) zero is special cased seperately
428-
kwargs = {"max_value": 0, "exclude_max": True}
421+
# We assume (negative) zero is special cased seperately
422+
kwargs = {"max_value": 0, "exclude_max": True}
429423
elif "nonzero finite" in cond_str:
430424
cond = lambda i: math.isfinite(i) and i != 0
431425
expr_template = "isfinite({}) and {} != 0"
432-
if not not_cond:
433-
kwargs = {"allow_nan": False, "allow_infinity": False}
434-
filter_ = lambda n: n != 0
426+
kwargs = {"allow_nan": False, "allow_infinity": False}
427+
filter_ = lambda n: n != 0
435428
elif cond_str == "an integer value":
436429
cond = lambda i: i.is_integer()
437430
expr_template = "{}.is_integer()"
438-
if not not_cond:
439-
from_dtype = integers_from_dtype # type: ignore
431+
from_dtype = integers_from_dtype # type: ignore
440432
elif cond_str == "an odd integer value":
441433
cond = lambda i: i.is_integer() and i % 2 == 1
442434
expr_template = "{}.is_integer() and {} % 2 == 1"
443-
if not not_cond:
435+
from_dtype = integers_from_dtype # type: ignore
444436

445-
def from_dtype(dtype: DataType, **kw) -> st.SearchStrategy[float]:
446-
return integers_from_dtype(dtype, **kw).filter(lambda n: n % 2 == 1)
437+
def from_dtype(dtype: DataType, **kw) -> st.SearchStrategy[float]:
438+
return integers_from_dtype(dtype, **kw).filter(lambda n: n % 2 == 1)
447439

448440
else:
449441
raise ParseError(cond_str)
450442

451-
if strat is not None:
452-
if (
453-
not_cond
454-
or len(kwargs) != 0
455-
or filter_ is not None
456-
or from_dtype is not None
457-
):
458-
raise ParseError(cond_str)
459-
return cond, expr_template, wrap_strat_as_from_dtype(strat)
460-
461443
if not_cond:
462-
expr_template = f"not {expr_template}"
444+
# We handle negated conitions by simply negating the condition function
445+
# and using it as a filter for xps.from_dtype() (or an equivalent).
463446
cond = make_not_cond(cond)
464-
kwargs = {}
447+
expr_template = f"not {expr_template}"
465448
filter_ = cond
466-
assert kwargs is not None
467-
return cond, expr_template, BoundFromDtype(kwargs, filter_, from_dtype)
449+
return cond, expr_template, BoundFromDtype(filter_=filter_)
450+
else:
451+
return cond, expr_template, BoundFromDtype(kwargs, filter_, from_dtype)
468452

469453

470454
def parse_result(result_str: str) -> Tuple[UnaryCheck, str]:
@@ -838,6 +822,9 @@ def check_result(i1: float, i2: float, result: float) -> bool:
838822

839823

840824
def integers_from_dtype(dtype: DataType, **kw) -> st.SearchStrategy[float]:
825+
"""
826+
Returns a strategy that generates float-casted integers within the bounds of dtype.
827+
"""
841828
for k in kw.keys():
842829
# sanity check
843830
assert k in ["min_value", "max_value", "exclude_min", "exclude_max"]
@@ -1036,16 +1023,12 @@ def cond(i1: float, i2: float) -> bool:
10361023
elif len(x1_cond_from_dtypes) == 1:
10371024
x1_cond_from_dtype = x1_cond_from_dtypes[0]
10381025
else:
1039-
if not all(isinstance(fd, BoundFromDtype) for fd in x1_cond_from_dtypes):
1040-
raise ParseError(case_str)
10411026
x1_cond_from_dtype = sum(x1_cond_from_dtypes, start=BoundFromDtype())
10421027
if len(x2_cond_from_dtypes) == 0:
10431028
x2_cond_from_dtype = xps.from_dtype
10441029
elif len(x2_cond_from_dtypes) == 1:
10451030
x2_cond_from_dtype = x2_cond_from_dtypes[0]
10461031
else:
1047-
if not all(isinstance(fd, BoundFromDtype) for fd in x2_cond_from_dtypes):
1048-
raise ParseError(case_str)
10491032
x2_cond_from_dtype = sum(x2_cond_from_dtypes, start=BoundFromDtype())
10501033

10511034
return BinaryCase(

0 commit comments

Comments
 (0)