Skip to content

Commit c905a27

Browse files
committed
Rudimentary generic condition parsing for binary cases
1 parent d557a1b commit c905a27

File tree

1 file changed

+93
-12
lines changed

1 file changed

+93
-12
lines changed

array_api_tests/test_special_cases.py

Lines changed: 93 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
Dict,
99
List,
1010
Literal,
11+
Match,
1112
NamedTuple,
1213
Pattern,
1314
Protocol,
@@ -617,6 +618,87 @@ def __call__(self, groups: Tuple[str, ...]) -> BinaryCase:
617618
): BinaryCaseFactory(ValueCondFactory("i2", 0), ResultCheckFactory(1)),
618619
}
619620

621+
r_binary_case = re.compile("If (.+), the result (.+)")
622+
623+
r_cond_sep = re.compile(", | and ")
624+
r_cond = re.compile("(.+) (?:is|have) (.+)")
625+
626+
r_element = re.compile("x([12])_i")
627+
r_input = re.compile(rf"``{r_element.pattern}``")
628+
r_abs_input = re.compile(r"``abs\({r_element.pattern}\)``")
629+
r_and_input = re.compile(f"{r_input.pattern} and {r_input.pattern}")
630+
r_or_input = re.compile(f"either {r_input.pattern} or {r_input.pattern}")
631+
632+
r_result = re.compile(r"(?:is|has a) (.+)")
633+
634+
r_both_inputs_are_value = re.compile("are both (.+)")
635+
636+
637+
def parse_binary_case(case_m: Match) -> BinaryCase:
638+
cond_strs = r_cond_sep.split(case_m.group(1))
639+
conds = []
640+
cond_exprs = []
641+
for cond_str in cond_strs:
642+
if m := r_both_inputs_are_value.match(cond_str):
643+
raise ValueParseError(cond_str)
644+
else:
645+
cond_m = r_cond.match(cond_str)
646+
if cond_m is None:
647+
raise ValueParseError(cond_str)
648+
input_str, value_str = cond_m.groups()
649+
650+
unary_cond, expr_template = parse_cond(value_str)
651+
652+
if m := r_input.match(input_str):
653+
x_no = m.group(1)
654+
args_i = int(x_no) - 1
655+
expr = expr_template.replace("{}", f"x{x_no}ᵢ")
656+
657+
def cond(*inputs) -> bool:
658+
return unary_cond(inputs[args_i])
659+
660+
elif m := r_abs_input.match(input_str):
661+
x_no = m.group(1)
662+
args_i = int(x_no) - 1
663+
expr = expr_template.replace("{}", f"abs(x{x_no}ᵢ)")
664+
665+
def cond(*inputs) -> bool:
666+
return unary_cond(abs(inputs[args_i]))
667+
668+
elif r_and_input.match(input_str):
669+
left_expr = expr_template.replace("{}", "x1ᵢ")
670+
right_expr = expr_template.replace("{}", "x2ᵢ")
671+
expr = f"({left_expr}) and ({right_expr})"
672+
673+
def cond(i1: float, i2: float) -> bool:
674+
return unary_cond(i1) and unary_cond(i2)
675+
676+
elif r_or_input.match(input_str):
677+
left_expr = expr_template.replace("{}", "x1ᵢ")
678+
right_expr = expr_template.replace("{}", "x2ᵢ")
679+
expr = f"({left_expr}) and ({right_expr})"
680+
681+
def cond(i1: float, i2: float) -> bool:
682+
return unary_cond(i1) or unary_cond(i2)
683+
684+
else:
685+
raise ValueParseError(input_str)
686+
687+
conds.append(cond)
688+
cond_exprs.append(expr)
689+
690+
result_m = r_result.match(case_m.group(2))
691+
if result_m is None:
692+
raise ValueParseError(case_m.group(2))
693+
check_result, result_expr = parse_result(result_m.group(1))
694+
695+
expr = " and ".join(f"({expr})" for expr in cond_exprs) + " -> " + result_expr
696+
697+
def cond(i1: float, i2: float) -> bool:
698+
return all(c(i1, i2) for c in conds)
699+
700+
return BinaryCase(expr, cond, lambda l, r, o: check_result(o))
701+
620702

621703
r_redundant_case = re.compile("result.+determined by the rule already stated above")
622704

@@ -629,24 +711,23 @@ def parse_binary_docstring(docstring: str) -> List[BinaryCase]:
629711
cases = []
630712
for line in lines:
631713
if m := r_case.match(line):
632-
case = m.group(1)
714+
case_str = m.group(1)
633715
else:
634716
warn(f"line not machine-readable: '{line}'")
635717
continue
636-
if r_redundant_case.search(case):
718+
if r_redundant_case.search(case_str):
637719
continue
638-
for pattern, make_case in binary_pattern_to_case_factory.items():
639-
if m := pattern.search(case):
640-
try:
641-
case = make_case(m.groups())
642-
except ValueParseError as e:
643-
warn(f"not machine-readable: '{e.value}'")
644-
break
645-
cases.append(case)
720+
if m := r_binary_case.search(case_str):
721+
try:
722+
case = parse_binary_case(m)
723+
except ValueParseError as e:
724+
warn(f"not machine-readable: '{e.value}'")
646725
break
726+
cases.append(case)
727+
break
647728
else:
648-
if not r_remaining_case.search(case):
649-
warn(f"case not machine-readable: '{case}'")
729+
if not r_remaining_case.search(case_str):
730+
warn(f"case not machine-readable: '{case_str}'")
650731
return cases
651732

652733

0 commit comments

Comments
 (0)