@@ -714,10 +714,30 @@ def partial_cond(i1: float, i2: float) -> bool:
714
714
x1_cond_from_dtypes .append (cond_from_dtype )
715
715
elif cond_arg == BinaryCondArg .SECOND :
716
716
x2_cond_from_dtypes .append (cond_from_dtype )
717
- else :
718
- # TODO: xor scenarios
717
+ elif cond_arg == BinaryCondArg .BOTH :
719
718
x1_cond_from_dtypes .append (cond_from_dtype )
720
719
x2_cond_from_dtypes .append (cond_from_dtype )
720
+ else :
721
+ use_x1_or_x2_strat = st .shared (
722
+ st .sampled_from ([(True , False ), (True , False ), (True , True )])
723
+ )
724
+
725
+ def x1_cond_from_dtype (dtype , ** kw ) -> st .SearchStrategy [float ]:
726
+ return use_x1_or_x2_strat .flatmap (
727
+ lambda t : cond_from_dtype (dtype )
728
+ if t [0 ]
729
+ else xps .from_dtype (dtype )
730
+ )
731
+
732
+ def x2_cond_from_dtype (dtype , ** kw ) -> st .SearchStrategy [float ]:
733
+ return use_x1_or_x2_strat .flatmap (
734
+ lambda t : cond_from_dtype (dtype )
735
+ if t [1 ]
736
+ else xps .from_dtype (dtype )
737
+ )
738
+
739
+ x1_cond_from_dtypes .append (x1_cond_from_dtype )
740
+ x2_cond_from_dtypes .append (x2_cond_from_dtype )
721
741
722
742
partial_conds .append (partial_cond )
723
743
partial_exprs .append (partial_expr )
@@ -750,15 +770,15 @@ def cond(i1: float, i2: float) -> bool:
750
770
else :
751
771
# sanity check
752
772
assert all (isinstance (fd , BoundFromDtype ) for fd in x1_cond_from_dtypes )
753
- x1_cond_from_dtype = sum (x1_cond_from_dtypes , start = BoundFromDtype ({}, None ))
773
+ x1_cond_from_dtype = sum (x1_cond_from_dtypes , start = BoundFromDtype ({}))
754
774
if len (x2_cond_from_dtypes ) == 0 :
755
775
x2_cond_from_dtype = xps .from_dtype
756
776
elif len (x2_cond_from_dtypes ) == 1 :
757
777
x2_cond_from_dtype = x2_cond_from_dtypes [0 ]
758
778
else :
759
779
# sanity check
760
780
assert all (isinstance (fd , BoundFromDtype ) for fd in x2_cond_from_dtypes )
761
- x2_cond_from_dtype = sum (x2_cond_from_dtypes , start = BoundFromDtype ({}, None ))
781
+ x2_cond_from_dtype = sum (x2_cond_from_dtypes , start = BoundFromDtype ({}))
762
782
763
783
return BinaryCase (
764
784
cond_expr = cond_expr ,
@@ -896,11 +916,11 @@ def test_binary(func_name, func, case, x1, x2, data):
896
916
897
917
indices_strat = st .shared (st .sampled_from (all_indices ))
898
918
set_x1_idx = data .draw (indices_strat .map (lambda t : t [0 ]), label = "set x1 idx" )
899
- set_x2_idx = data .draw (indices_strat .map (lambda t : t [1 ]), label = "set x2 idx" )
900
919
set_x1_value = data .draw (case .x1_cond_from_dtype (x1 .dtype ), label = "set x1 value" )
901
- set_x2_value = data .draw (case .x2_cond_from_dtype (x2 .dtype ), label = "set x2 value" )
902
920
x1 [set_x1_idx ] = set_x1_value
903
921
note (f"{ x1 = } " )
922
+ set_x2_idx = data .draw (indices_strat .map (lambda t : t [1 ]), label = "set x2 idx" )
923
+ set_x2_value = data .draw (case .x2_cond_from_dtype (x2 .dtype ), label = "set x2 value" )
904
924
x2 [set_x2_idx ] = set_x2_value
905
925
note (f"{ x2 = } " )
906
926
0 commit comments