Skip to content

Commit 429a9a5

Browse files
committed
Test xor scenarios for either special cases
1 parent b9800b1 commit 429a9a5

File tree

1 file changed

+26
-6
lines changed

1 file changed

+26
-6
lines changed

array_api_tests/test_special_cases.py

Lines changed: 26 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -714,10 +714,30 @@ def partial_cond(i1: float, i2: float) -> bool:
714714
x1_cond_from_dtypes.append(cond_from_dtype)
715715
elif cond_arg == BinaryCondArg.SECOND:
716716
x2_cond_from_dtypes.append(cond_from_dtype)
717-
else:
718-
# TODO: xor scenarios
717+
elif cond_arg == BinaryCondArg.BOTH:
719718
x1_cond_from_dtypes.append(cond_from_dtype)
720719
x2_cond_from_dtypes.append(cond_from_dtype)
720+
else:
721+
use_x1_or_x2_strat = st.shared(
722+
st.sampled_from([(True, False), (True, False), (True, True)])
723+
)
724+
725+
def x1_cond_from_dtype(dtype, **kw) -> st.SearchStrategy[float]:
726+
return use_x1_or_x2_strat.flatmap(
727+
lambda t: cond_from_dtype(dtype)
728+
if t[0]
729+
else xps.from_dtype(dtype)
730+
)
731+
732+
def x2_cond_from_dtype(dtype, **kw) -> st.SearchStrategy[float]:
733+
return use_x1_or_x2_strat.flatmap(
734+
lambda t: cond_from_dtype(dtype)
735+
if t[1]
736+
else xps.from_dtype(dtype)
737+
)
738+
739+
x1_cond_from_dtypes.append(x1_cond_from_dtype)
740+
x2_cond_from_dtypes.append(x2_cond_from_dtype)
721741

722742
partial_conds.append(partial_cond)
723743
partial_exprs.append(partial_expr)
@@ -750,15 +770,15 @@ def cond(i1: float, i2: float) -> bool:
750770
else:
751771
# sanity check
752772
assert all(isinstance(fd, BoundFromDtype) for fd in x1_cond_from_dtypes)
753-
x1_cond_from_dtype = sum(x1_cond_from_dtypes, start=BoundFromDtype({}, None))
773+
x1_cond_from_dtype = sum(x1_cond_from_dtypes, start=BoundFromDtype({}))
754774
if len(x2_cond_from_dtypes) == 0:
755775
x2_cond_from_dtype = xps.from_dtype
756776
elif len(x2_cond_from_dtypes) == 1:
757777
x2_cond_from_dtype = x2_cond_from_dtypes[0]
758778
else:
759779
# sanity check
760780
assert all(isinstance(fd, BoundFromDtype) for fd in x2_cond_from_dtypes)
761-
x2_cond_from_dtype = sum(x2_cond_from_dtypes, start=BoundFromDtype({}, None))
781+
x2_cond_from_dtype = sum(x2_cond_from_dtypes, start=BoundFromDtype({}))
762782

763783
return BinaryCase(
764784
cond_expr=cond_expr,
@@ -896,11 +916,11 @@ def test_binary(func_name, func, case, x1, x2, data):
896916

897917
indices_strat = st.shared(st.sampled_from(all_indices))
898918
set_x1_idx = data.draw(indices_strat.map(lambda t: t[0]), label="set x1 idx")
899-
set_x2_idx = data.draw(indices_strat.map(lambda t: t[1]), label="set x2 idx")
900919
set_x1_value = data.draw(case.x1_cond_from_dtype(x1.dtype), label="set x1 value")
901-
set_x2_value = data.draw(case.x2_cond_from_dtype(x2.dtype), label="set x2 value")
902920
x1[set_x1_idx] = set_x1_value
903921
note(f"{x1=}")
922+
set_x2_idx = data.draw(indices_strat.map(lambda t: t[1]), label="set x2 idx")
923+
set_x2_value = data.draw(case.x2_cond_from_dtype(x2.dtype), label="set x2 value")
904924
x2[set_x2_idx] = set_x2_value
905925
note(f"{x2=}")
906926

0 commit comments

Comments
 (0)