8
8
Dict ,
9
9
List ,
10
10
Literal ,
11
+ Match ,
11
12
NamedTuple ,
12
13
Pattern ,
13
14
Protocol ,
@@ -617,6 +618,87 @@ def __call__(self, groups: Tuple[str, ...]) -> BinaryCase:
617
618
): BinaryCaseFactory (ValueCondFactory ("i2" , 0 ), ResultCheckFactory (1 )),
618
619
}
619
620
621
+ r_binary_case = re .compile ("If (.+), the result (.+)" )
622
+
623
+ r_cond_sep = re .compile (", | and " )
624
+ r_cond = re .compile ("(.+) (?:is|have) (.+)" )
625
+
626
+ r_element = re .compile ("x([12])_i" )
627
+ r_input = re .compile (rf"``{ r_element .pattern } ``" )
628
+ r_abs_input = re .compile (r"``abs\({r_element.pattern}\)``" )
629
+ r_and_input = re .compile (f"{ r_input .pattern } and { r_input .pattern } " )
630
+ r_or_input = re .compile (f"either { r_input .pattern } or { r_input .pattern } " )
631
+
632
+ r_result = re .compile (r"(?:is|has a) (.+)" )
633
+
634
+ r_both_inputs_are_value = re .compile ("are both (.+)" )
635
+
636
+
637
+ def parse_binary_case (case_m : Match ) -> BinaryCase :
638
+ cond_strs = r_cond_sep .split (case_m .group (1 ))
639
+ conds = []
640
+ cond_exprs = []
641
+ for cond_str in cond_strs :
642
+ if m := r_both_inputs_are_value .match (cond_str ):
643
+ raise ValueParseError (cond_str )
644
+ else :
645
+ cond_m = r_cond .match (cond_str )
646
+ if cond_m is None :
647
+ raise ValueParseError (cond_str )
648
+ input_str , value_str = cond_m .groups ()
649
+
650
+ unary_cond , expr_template = parse_cond (value_str )
651
+
652
+ if m := r_input .match (input_str ):
653
+ x_no = m .group (1 )
654
+ args_i = int (x_no ) - 1
655
+ expr = expr_template .replace ("{}" , f"x{ x_no } ᵢ" )
656
+
657
+ def cond (* inputs ) -> bool :
658
+ return unary_cond (inputs [args_i ])
659
+
660
+ elif m := r_abs_input .match (input_str ):
661
+ x_no = m .group (1 )
662
+ args_i = int (x_no ) - 1
663
+ expr = expr_template .replace ("{}" , f"abs(x{ x_no } ᵢ)" )
664
+
665
+ def cond (* inputs ) -> bool :
666
+ return unary_cond (abs (inputs [args_i ]))
667
+
668
+ elif r_and_input .match (input_str ):
669
+ left_expr = expr_template .replace ("{}" , "x1ᵢ" )
670
+ right_expr = expr_template .replace ("{}" , "x2ᵢ" )
671
+ expr = f"({ left_expr } ) and ({ right_expr } )"
672
+
673
+ def cond (i1 : float , i2 : float ) -> bool :
674
+ return unary_cond (i1 ) and unary_cond (i2 )
675
+
676
+ elif r_or_input .match (input_str ):
677
+ left_expr = expr_template .replace ("{}" , "x1ᵢ" )
678
+ right_expr = expr_template .replace ("{}" , "x2ᵢ" )
679
+ expr = f"({ left_expr } ) and ({ right_expr } )"
680
+
681
+ def cond (i1 : float , i2 : float ) -> bool :
682
+ return unary_cond (i1 ) or unary_cond (i2 )
683
+
684
+ else :
685
+ raise ValueParseError (input_str )
686
+
687
+ conds .append (cond )
688
+ cond_exprs .append (expr )
689
+
690
+ result_m = r_result .match (case_m .group (2 ))
691
+ if result_m is None :
692
+ raise ValueParseError (case_m .group (2 ))
693
+ check_result , result_expr = parse_result (result_m .group (1 ))
694
+
695
+ expr = " and " .join (f"({ expr } )" for expr in cond_exprs ) + " -> " + result_expr
696
+
697
+ def cond (i1 : float , i2 : float ) -> bool :
698
+ return all (c (i1 , i2 ) for c in conds )
699
+
700
+ return BinaryCase (expr , cond , lambda l , r , o : check_result (o ))
701
+
620
702
621
703
r_redundant_case = re .compile ("result.+determined by the rule already stated above" )
622
704
@@ -629,24 +711,23 @@ def parse_binary_docstring(docstring: str) -> List[BinaryCase]:
629
711
cases = []
630
712
for line in lines :
631
713
if m := r_case .match (line ):
632
- case = m .group (1 )
714
+ case_str = m .group (1 )
633
715
else :
634
716
warn (f"line not machine-readable: '{ line } '" )
635
717
continue
636
- if r_redundant_case .search (case ):
718
+ if r_redundant_case .search (case_str ):
637
719
continue
638
- for pattern , make_case in binary_pattern_to_case_factory .items ():
639
- if m := pattern .search (case ):
640
- try :
641
- case = make_case (m .groups ())
642
- except ValueParseError as e :
643
- warn (f"not machine-readable: '{ e .value } '" )
644
- break
645
- cases .append (case )
720
+ if m := r_binary_case .search (case_str ):
721
+ try :
722
+ case = parse_binary_case (m )
723
+ except ValueParseError as e :
724
+ warn (f"not machine-readable: '{ e .value } '" )
646
725
break
726
+ cases .append (case )
727
+ break
647
728
else :
648
- if not r_remaining_case .search (case ):
649
- warn (f"case not machine-readable: '{ case } '" )
729
+ if not r_remaining_case .search (case_str ):
730
+ warn (f"case not machine-readable: '{ case_str } '" )
650
731
return cases
651
732
652
733
0 commit comments