3
3
import re
4
4
from dataclasses import dataclass
5
5
from decimal import ROUND_HALF_EVEN , Decimal
6
+ from enum import Enum , auto
6
7
from typing import Callable , List , Match , Protocol , Tuple
7
8
from warnings import warn
8
9
@@ -371,49 +372,88 @@ class BinaryCase(Case):
371
372
r_both_inputs_are_value = re .compile ("are both (.+)" )
372
373
373
374
375
+ class BinaryCondInput (Enum ):
376
+ FIRST = auto ()
377
+ SECOND = auto ()
378
+ BOTH = auto ()
379
+ EITHER = auto ()
380
+
381
+
382
+ def noop (obj ):
383
+ return obj
384
+
385
+
386
+ def make_partial_cond (
387
+ input_ : BinaryCondInput , unary_check : UnaryCheck , * , input_wrapper = None
388
+ ) -> BinaryCond :
389
+ if input_wrapper is None :
390
+ input_wrapper = noop
391
+ if input_ == BinaryCondInput .FIRST :
392
+
393
+ def partial_cond (i1 : float , i2 : float ) -> bool :
394
+ return unary_check (input_wrapper (i1 ))
395
+
396
+ elif input_ == BinaryCondInput .SECOND :
397
+
398
+ def partial_cond (i1 : float , i2 : float ) -> bool :
399
+ return unary_check (input_wrapper (i2 ))
400
+
401
+ elif input_ == BinaryCondInput .BOTH :
402
+
403
+ def partial_cond (i1 : float , i2 : float ) -> bool :
404
+ return unary_check (input_wrapper (i1 )) and unary_check (input_wrapper (i2 ))
405
+
406
+ else :
407
+
408
+ def partial_cond (i1 : float , i2 : float ) -> bool :
409
+ return unary_check (input_wrapper (i1 )) or unary_check (input_wrapper (i2 ))
410
+
411
+ return partial_cond
412
+
413
+
374
414
def parse_binary_case (case_m : Match ) -> BinaryCase :
375
415
cond_strs = r_cond_sep .split (case_m .group (1 ))
376
- conds = []
377
- cond_exprs = []
416
+ partial_conds = []
417
+ partial_exprs = []
378
418
for cond_str in cond_strs :
379
419
if m := r_input_is_array_element .match (cond_str ):
380
420
in_sign , input_array , value_sign , value_array = m .groups ()
381
421
assert in_sign == "" and value_array != input_array # sanity check
382
- expr = f"{ in_sign } x{ input_array } ᵢ == { value_sign } x{ value_array } ᵢ"
422
+ partial_expr = f"{ in_sign } x{ input_array } ᵢ == { value_sign } x{ value_array } ᵢ"
383
423
if value_array == "1" :
384
424
if value_sign != "-" :
385
425
386
- def cond (i1 : float , i2 : float ) -> bool :
426
+ def partial_cond (i1 : float , i2 : float ) -> bool :
387
427
eq = make_eq (i1 )
388
428
return eq (i2 )
389
429
390
430
else :
391
431
392
- def cond (i1 : float , i2 : float ) -> bool :
432
+ def partial_cond (i1 : float , i2 : float ) -> bool :
393
433
eq = make_eq (- i1 )
394
434
return eq (i2 )
395
435
396
436
else :
397
437
if value_sign != "-" :
398
438
399
- def cond (i1 : float , i2 : float ) -> bool :
439
+ def partial_cond (i1 : float , i2 : float ) -> bool :
400
440
eq = make_eq (i2 )
401
441
return eq (i1 )
402
442
403
443
else :
404
444
405
- def cond (i1 : float , i2 : float ) -> bool :
445
+ def partial_cond (i1 : float , i2 : float ) -> bool :
406
446
eq = make_eq (- i2 )
407
447
return eq (i1 )
408
448
409
449
elif m := r_both_inputs_are_value .match (cond_str ):
410
450
unary_cond , expr_template = parse_cond (m .group (1 ))
411
451
left_expr = expr_template .replace ("{}" , "x1ᵢ" )
412
452
right_expr = expr_template .replace ("{}" , "x2ᵢ" )
413
- expr = f"({ left_expr } ) and ({ right_expr } )"
414
-
415
- def cond ( i1 : float , i2 : float ) -> bool :
416
- return unary_cond ( i1 ) and unary_cond ( i2 )
453
+ partial_expr = f"({ left_expr } ) and ({ right_expr } )"
454
+ partial_cond = make_partial_cond ( # type: ignore
455
+ BinaryCondInput . BOTH , unary_cond
456
+ )
417
457
418
458
else :
419
459
cond_m = r_cond .match (cond_str )
@@ -422,57 +462,58 @@ def cond(i1: float, i2: float) -> bool:
422
462
input_str , value_str = cond_m .groups ()
423
463
424
464
if value_str == "the same mathematical sign" :
425
- expr = "copysign(1, x1ᵢ) == copysign(1, x2ᵢ)"
465
+ partial_expr = "copysign(1, x1ᵢ) == copysign(1, x2ᵢ)"
426
466
427
- def cond (i1 : float , i2 : float ) -> bool :
467
+ def partial_cond (i1 : float , i2 : float ) -> bool :
428
468
return math .copysign (1 , i1 ) == math .copysign (1 , i2 )
429
469
430
470
elif value_str == "different mathematical signs" :
431
- expr = "copysign(1, x1ᵢ) != copysign(1, x2ᵢ)"
471
+ partial_expr = "copysign(1, x1ᵢ) != copysign(1, x2ᵢ)"
432
472
433
- def cond (i1 : float , i2 : float ) -> bool :
473
+ def partial_cond (i1 : float , i2 : float ) -> bool :
434
474
return math .copysign (1 , i1 ) != math .copysign (1 , i2 )
435
475
436
476
else :
437
- unary_cond , expr_template = parse_cond (value_str )
438
-
477
+ unary_check , expr_template = parse_cond (value_str )
478
+ # Do not define partial_cond via the def keyword, as one
479
+ # partial_cond definition can mess up previous definitions
480
+ # in the partial_conds list. This is a hard-limitation of
481
+ # using local functions with the same name and that use the same
482
+ # outer variables (i.e. unary_cond).
483
+ input_wrapper = None
439
484
if m := r_input .match (input_str ):
440
485
x_no = m .group (1 )
441
- args_i = int (x_no ) - 1
442
- expr = expr_template .replace ("{}" , f"x{ x_no } ᵢ" )
443
-
444
- def cond (* inputs ) -> bool :
445
- return unary_cond (inputs [args_i ])
446
-
486
+ partial_expr = expr_template .replace ("{}" , f"x{ x_no } ᵢ" )
487
+ if x_no == "1" :
488
+ input_ = BinaryCondInput .FIRST
489
+ else :
490
+ input_ = BinaryCondInput .SECOND
447
491
elif m := r_abs_input .match (input_str ):
448
492
x_no = m .group (1 )
449
- args_i = int ( x_no ) - 1
450
- expr = expr_template . replace ( "{}" , f"abs(x { x_no } ᵢ)" )
451
-
452
- def cond ( * inputs ) -> bool :
453
- return unary_cond ( abs ( inputs [ args_i ]))
454
-
493
+ partial_expr = expr_template . replace ( "{}" , f"abs(x { x_no } ᵢ)" )
494
+ if x_no == "1" :
495
+ input_ = BinaryCondInput . FIRST
496
+ else :
497
+ input_ = BinaryCondInput . SECOND
498
+ input_wrapper = abs
455
499
elif r_and_input .match (input_str ):
456
500
left_expr = expr_template .replace ("{}" , "x1ᵢ" )
457
501
right_expr = expr_template .replace ("{}" , "x2ᵢ" )
458
- expr = f"({ left_expr } ) and ({ right_expr } )"
459
-
460
- def cond (i1 : float , i2 : float ) -> bool :
461
- return unary_cond (i1 ) and unary_cond (i2 )
462
-
502
+ partial_expr = f"({ left_expr } ) and ({ right_expr } )"
503
+ input_ = BinaryCondInput .BOTH
463
504
elif r_or_input .match (input_str ):
464
505
left_expr = expr_template .replace ("{}" , "x1ᵢ" )
465
506
right_expr = expr_template .replace ("{}" , "x2ᵢ" )
466
- expr = f"({ left_expr } ) or ({ right_expr } )"
467
-
468
- def cond (i1 : float , i2 : float ) -> bool :
469
- return unary_cond (i1 ) or unary_cond (i2 )
470
-
507
+ partial_expr = f"({ left_expr } ) or ({ right_expr } )"
508
+ input_ = BinaryCondInput .EITHER
471
509
else :
472
510
raise ValueParseError (input_str )
511
+ partial_cond = make_partial_cond ( # type: ignore
512
+ input_ , unary_check , input_wrapper = input_wrapper
513
+ )
473
514
474
- conds .append (cond )
475
- cond_exprs .append (expr )
515
+ partial_conds .append (partial_cond )
516
+ partial_exprs .append (partial_expr )
476
517
477
518
result_m = r_result .match (case_m .group (2 ))
478
519
if result_m is None :
@@ -513,10 +554,10 @@ def check_result(i1: float, i2: float, result: float) -> bool:
513
554
def check_result (i1 : float , i2 : float , result : float ) -> bool :
514
555
return _check_result (result )
515
556
516
- expr = " and " .join (cond_exprs ) + " -> " + result_expr
557
+ expr = " and " .join (partial_exprs ) + " -> " + result_expr
517
558
518
559
def cond (i1 : float , i2 : float ) -> bool :
519
- return all (c (i1 , i2 ) for c in conds )
560
+ return all (pc (i1 , i2 ) for pc in partial_conds )
520
561
521
562
return BinaryCase (expr , cond , check_result )
522
563
@@ -639,8 +680,10 @@ def test_binary(func_name, func, cases, x1, x2):
639
680
f_left = f"{ sh .fmt_idx ('x1' , l_idx )} ={ l } "
640
681
f_right = f"{ sh .fmt_idx ('x2' , r_idx )} ={ r } "
641
682
f_out = f"{ sh .fmt_idx ('out' , o_idx )} ={ o } "
642
- assert case .check_result (
643
- l , r , o
644
- ), f"{ f_out } not good [{ func_name } ()]\n { f_left } , { f_right } "
683
+ assert case .check_result (l , r , o ), (
684
+ f"{ f_out } not good [{ func_name } ()]\n "
685
+ f"{ case .expr } \n "
686
+ f"{ f_left } , { f_right } "
687
+ )
645
688
break
646
689
assume (good_example )
0 commit comments