Skip to content

Commit da5835d

Browse files
committed
Fix local func definitions causing pass-by-reference problems
1 parent 9ee5375 commit da5835d

File tree

1 file changed

+89
-46
lines changed

1 file changed

+89
-46
lines changed

array_api_tests/test_special_cases.py

Lines changed: 89 additions & 46 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
import re
44
from dataclasses import dataclass
55
from decimal import ROUND_HALF_EVEN, Decimal
6+
from enum import Enum, auto
67
from typing import Callable, List, Match, Protocol, Tuple
78
from warnings import warn
89

@@ -371,49 +372,88 @@ class BinaryCase(Case):
371372
r_both_inputs_are_value = re.compile("are both (.+)")
372373

373374

375+
class BinaryCondInput(Enum):
376+
FIRST = auto()
377+
SECOND = auto()
378+
BOTH = auto()
379+
EITHER = auto()
380+
381+
382+
def noop(obj):
383+
return obj
384+
385+
386+
def make_partial_cond(
387+
input_: BinaryCondInput, unary_check: UnaryCheck, *, input_wrapper=None
388+
) -> BinaryCond:
389+
if input_wrapper is None:
390+
input_wrapper = noop
391+
if input_ == BinaryCondInput.FIRST:
392+
393+
def partial_cond(i1: float, i2: float) -> bool:
394+
return unary_check(input_wrapper(i1))
395+
396+
elif input_ == BinaryCondInput.SECOND:
397+
398+
def partial_cond(i1: float, i2: float) -> bool:
399+
return unary_check(input_wrapper(i2))
400+
401+
elif input_ == BinaryCondInput.BOTH:
402+
403+
def partial_cond(i1: float, i2: float) -> bool:
404+
return unary_check(input_wrapper(i1)) and unary_check(input_wrapper(i2))
405+
406+
else:
407+
408+
def partial_cond(i1: float, i2: float) -> bool:
409+
return unary_check(input_wrapper(i1)) or unary_check(input_wrapper(i2))
410+
411+
return partial_cond
412+
413+
374414
def parse_binary_case(case_m: Match) -> BinaryCase:
375415
cond_strs = r_cond_sep.split(case_m.group(1))
376-
conds = []
377-
cond_exprs = []
416+
partial_conds = []
417+
partial_exprs = []
378418
for cond_str in cond_strs:
379419
if m := r_input_is_array_element.match(cond_str):
380420
in_sign, input_array, value_sign, value_array = m.groups()
381421
assert in_sign == "" and value_array != input_array # sanity check
382-
expr = f"{in_sign}x{input_array}ᵢ == {value_sign}x{value_array}ᵢ"
422+
partial_expr = f"{in_sign}x{input_array}ᵢ == {value_sign}x{value_array}ᵢ"
383423
if value_array == "1":
384424
if value_sign != "-":
385425

386-
def cond(i1: float, i2: float) -> bool:
426+
def partial_cond(i1: float, i2: float) -> bool:
387427
eq = make_eq(i1)
388428
return eq(i2)
389429

390430
else:
391431

392-
def cond(i1: float, i2: float) -> bool:
432+
def partial_cond(i1: float, i2: float) -> bool:
393433
eq = make_eq(-i1)
394434
return eq(i2)
395435

396436
else:
397437
if value_sign != "-":
398438

399-
def cond(i1: float, i2: float) -> bool:
439+
def partial_cond(i1: float, i2: float) -> bool:
400440
eq = make_eq(i2)
401441
return eq(i1)
402442

403443
else:
404444

405-
def cond(i1: float, i2: float) -> bool:
445+
def partial_cond(i1: float, i2: float) -> bool:
406446
eq = make_eq(-i2)
407447
return eq(i1)
408448

409449
elif m := r_both_inputs_are_value.match(cond_str):
410450
unary_cond, expr_template = parse_cond(m.group(1))
411451
left_expr = expr_template.replace("{}", "x1ᵢ")
412452
right_expr = expr_template.replace("{}", "x2ᵢ")
413-
expr = f"({left_expr}) and ({right_expr})"
414-
415-
def cond(i1: float, i2: float) -> bool:
416-
return unary_cond(i1) and unary_cond(i2)
453+
partial_expr = f"({left_expr}) and ({right_expr})"
454+
partial_cond = make_partial_cond( # type: ignore
455+
BinaryCondInput.BOTH, unary_cond
456+
)
417457

418458
else:
419459
cond_m = r_cond.match(cond_str)
@@ -422,57 +462,58 @@ def cond(i1: float, i2: float) -> bool:
422462
input_str, value_str = cond_m.groups()
423463

424464
if value_str == "the same mathematical sign":
425-
expr = "copysign(1, x1ᵢ) == copysign(1, x2ᵢ)"
465+
partial_expr = "copysign(1, x1ᵢ) == copysign(1, x2ᵢ)"
426466

427-
def cond(i1: float, i2: float) -> bool:
467+
def partial_cond(i1: float, i2: float) -> bool:
428468
return math.copysign(1, i1) == math.copysign(1, i2)
429469

430470
elif value_str == "different mathematical signs":
431-
expr = "copysign(1, x1ᵢ) != copysign(1, x2ᵢ)"
471+
partial_expr = "copysign(1, x1ᵢ) != copysign(1, x2ᵢ)"
432472

433-
def cond(i1: float, i2: float) -> bool:
473+
def partial_cond(i1: float, i2: float) -> bool:
434474
return math.copysign(1, i1) != math.copysign(1, i2)
435475

436476
else:
437-
unary_cond, expr_template = parse_cond(value_str)
438-
477+
unary_check, expr_template = parse_cond(value_str)
478+
# Do not define partial_cond via the def keyword, as one
479+
# partial_cond definition can mess up previous definitions
480+
# in the partial_conds list. This is a hard-limitation of
481+
# using local functions with the same name and that use the same
482+
# outer variables (i.e. unary_cond).
483+
input_wrapper = None
439484
if m := r_input.match(input_str):
440485
x_no = m.group(1)
441-
args_i = int(x_no) - 1
442-
expr = expr_template.replace("{}", f"x{x_no}ᵢ")
443-
444-
def cond(*inputs) -> bool:
445-
return unary_cond(inputs[args_i])
446-
486+
partial_expr = expr_template.replace("{}", f"x{x_no}ᵢ")
487+
if x_no == "1":
488+
input_ = BinaryCondInput.FIRST
489+
else:
490+
input_ = BinaryCondInput.SECOND
447491
elif m := r_abs_input.match(input_str):
448492
x_no = m.group(1)
449-
args_i = int(x_no) - 1
450-
expr = expr_template.replace("{}", f"abs(x{x_no}ᵢ)")
451-
452-
def cond(*inputs) -> bool:
453-
return unary_cond(abs(inputs[args_i]))
454-
493+
partial_expr = expr_template.replace("{}", f"abs(x{x_no}ᵢ)")
494+
if x_no == "1":
495+
input_ = BinaryCondInput.FIRST
496+
else:
497+
input_ = BinaryCondInput.SECOND
498+
input_wrapper = abs
455499
elif r_and_input.match(input_str):
456500
left_expr = expr_template.replace("{}", "x1ᵢ")
457501
right_expr = expr_template.replace("{}", "x2ᵢ")
458-
expr = f"({left_expr}) and ({right_expr})"
459-
460-
def cond(i1: float, i2: float) -> bool:
461-
return unary_cond(i1) and unary_cond(i2)
462-
502+
partial_expr = f"({left_expr}) and ({right_expr})"
503+
input_ = BinaryCondInput.BOTH
463504
elif r_or_input.match(input_str):
464505
left_expr = expr_template.replace("{}", "x1ᵢ")
465506
right_expr = expr_template.replace("{}", "x2ᵢ")
466-
expr = f"({left_expr}) or ({right_expr})"
467-
468-
def cond(i1: float, i2: float) -> bool:
469-
return unary_cond(i1) or unary_cond(i2)
470-
507+
partial_expr = f"({left_expr}) or ({right_expr})"
508+
input_ = BinaryCondInput.EITHER
471509
else:
472510
raise ValueParseError(input_str)
511+
partial_cond = make_partial_cond( # type: ignore
512+
input_, unary_check, input_wrapper=input_wrapper
513+
)
473514

474-
conds.append(cond)
475-
cond_exprs.append(expr)
515+
partial_conds.append(partial_cond)
516+
partial_exprs.append(partial_expr)
476517

477518
result_m = r_result.match(case_m.group(2))
478519
if result_m is None:
@@ -513,10 +554,10 @@ def check_result(i1: float, i2: float, result: float) -> bool:
513554
def check_result(i1: float, i2: float, result: float) -> bool:
514555
return _check_result(result)
515556

516-
expr = " and ".join(cond_exprs) + " -> " + result_expr
557+
expr = " and ".join(partial_exprs) + " -> " + result_expr
517558

518559
def cond(i1: float, i2: float) -> bool:
519-
return all(c(i1, i2) for c in conds)
560+
return all(pc(i1, i2) for pc in partial_conds)
520561

521562
return BinaryCase(expr, cond, check_result)
522563

@@ -639,8 +680,10 @@ def test_binary(func_name, func, cases, x1, x2):
639680
f_left = f"{sh.fmt_idx('x1', l_idx)}={l}"
640681
f_right = f"{sh.fmt_idx('x2', r_idx)}={r}"
641682
f_out = f"{sh.fmt_idx('out', o_idx)}={o}"
642-
assert case.check_result(
643-
l, r, o
644-
), f"{f_out} not good [{func_name}()]\n{f_left}, {f_right}"
683+
assert case.check_result(l, r, o), (
684+
f"{f_out} not good [{func_name}()]\n"
685+
f"{case.expr}\n"
686+
f"{f_left}, {f_right}"
687+
)
645688
break
646689
assume(good_example)

0 commit comments

Comments
 (0)