@@ -33,13 +33,6 @@ def boolean_and_all_integer_dtypes() -> st.SearchStrategy[DataType]:
33
33
return xps .boolean_dtypes () | all_integer_dtypes ()
34
34
35
35
36
- def all_floating_dtypes () -> st .SearchStrategy [DataType ]:
37
- strat = xps .floating_dtypes ()
38
- if api_version >= "2022.12" :
39
- strat |= xps .complex_dtypes ()
40
- return strat
41
-
42
-
43
36
def mock_int_dtype (n : int , dtype : DataType ) -> int :
44
37
"""Returns equivalent of `n` that mocks `dtype` behaviour."""
45
38
nbits = dh .dtype_nbits [dtype ]
@@ -714,7 +707,7 @@ def test_abs(ctx, data):
714
707
)
715
708
716
709
717
- @given (xps .arrays (dtype = all_floating_dtypes (), shape = hh .shapes ()))
710
+ @given (xps .arrays (dtype = hh . all_floating_dtypes (), shape = hh .shapes ()))
718
711
def test_acos (x ):
719
712
out = xp .acos (x )
720
713
ph .assert_dtype ("acos" , in_dtype = x .dtype , out_dtype = out .dtype )
@@ -724,7 +717,7 @@ def test_acos(x):
724
717
)
725
718
726
719
727
- @given (xps .arrays (dtype = all_floating_dtypes (), shape = hh .shapes ()))
720
+ @given (xps .arrays (dtype = hh . all_floating_dtypes (), shape = hh .shapes ()))
728
721
def test_acosh (x ):
729
722
out = xp .acosh (x )
730
723
ph .assert_dtype ("acosh" , in_dtype = x .dtype , out_dtype = out .dtype )
@@ -748,7 +741,7 @@ def test_add(ctx, data):
748
741
binary_param_assert_against_refimpl (ctx , left , right , res , "+" , operator .add )
749
742
750
743
751
- @given (xps .arrays (dtype = all_floating_dtypes (), shape = hh .shapes ()))
744
+ @given (xps .arrays (dtype = hh . all_floating_dtypes (), shape = hh .shapes ()))
752
745
def test_asin (x ):
753
746
out = xp .asin (x )
754
747
ph .assert_dtype ("asin" , in_dtype = x .dtype , out_dtype = out .dtype )
@@ -758,15 +751,15 @@ def test_asin(x):
758
751
)
759
752
760
753
761
- @given (xps .arrays (dtype = all_floating_dtypes (), shape = hh .shapes ()))
754
+ @given (xps .arrays (dtype = hh . all_floating_dtypes (), shape = hh .shapes ()))
762
755
def test_asinh (x ):
763
756
out = xp .asinh (x )
764
757
ph .assert_dtype ("asinh" , in_dtype = x .dtype , out_dtype = out .dtype )
765
758
ph .assert_shape ("asinh" , out_shape = out .shape , expected = x .shape )
766
759
unary_assert_against_refimpl ("asinh" , x , out , math .asinh )
767
760
768
761
769
- @given (xps .arrays (dtype = all_floating_dtypes (), shape = hh .shapes ()))
762
+ @given (xps .arrays (dtype = hh . all_floating_dtypes (), shape = hh .shapes ()))
770
763
def test_atan (x ):
771
764
out = xp .atan (x )
772
765
ph .assert_dtype ("atan" , in_dtype = x .dtype , out_dtype = out .dtype )
@@ -782,7 +775,7 @@ def test_atan2(x1, x2):
782
775
binary_assert_against_refimpl ("atan2" , x1 , x2 , out , math .atan2 )
783
776
784
777
785
- @given (xps .arrays (dtype = all_floating_dtypes (), shape = hh .shapes ()))
778
+ @given (xps .arrays (dtype = hh . all_floating_dtypes (), shape = hh .shapes ()))
786
779
def test_atanh (x ):
787
780
out = xp .atanh (x )
788
781
ph .assert_dtype ("atanh" , in_dtype = x .dtype , out_dtype = out .dtype )
@@ -932,15 +925,15 @@ def test_conj(x):
932
925
unary_assert_against_refimpl ("conj" , x , out , operator .methodcaller ("conjugate" ))
933
926
934
927
935
- @given (xps .arrays (dtype = all_floating_dtypes (), shape = hh .shapes ()))
928
+ @given (xps .arrays (dtype = hh . all_floating_dtypes (), shape = hh .shapes ()))
936
929
def test_cos (x ):
937
930
out = xp .cos (x )
938
931
ph .assert_dtype ("cos" , in_dtype = x .dtype , out_dtype = out .dtype )
939
932
ph .assert_shape ("cos" , out_shape = out .shape , expected = x .shape )
940
933
unary_assert_against_refimpl ("cos" , x , out , math .cos )
941
934
942
935
943
- @given (xps .arrays (dtype = all_floating_dtypes (), shape = hh .shapes ()))
936
+ @given (xps .arrays (dtype = hh . all_floating_dtypes (), shape = hh .shapes ()))
944
937
def test_cosh (x ):
945
938
out = xp .cosh (x )
946
939
ph .assert_dtype ("cosh" , in_dtype = x .dtype , out_dtype = out .dtype )
@@ -1001,15 +994,15 @@ def test_equal(ctx, data):
1001
994
)
1002
995
1003
996
1004
- @given (xps .arrays (dtype = all_floating_dtypes (), shape = hh .shapes ()))
997
+ @given (xps .arrays (dtype = hh . all_floating_dtypes (), shape = hh .shapes ()))
1005
998
def test_exp (x ):
1006
999
out = xp .exp (x )
1007
1000
ph .assert_dtype ("exp" , in_dtype = x .dtype , out_dtype = out .dtype )
1008
1001
ph .assert_shape ("exp" , out_shape = out .shape , expected = x .shape )
1009
1002
unary_assert_against_refimpl ("exp" , x , out , math .exp )
1010
1003
1011
1004
1012
- @given (xps .arrays (dtype = all_floating_dtypes (), shape = hh .shapes ()))
1005
+ @given (xps .arrays (dtype = hh . all_floating_dtypes (), shape = hh .shapes ()))
1013
1006
def test_expm1 (x ):
1014
1007
out = xp .expm1 (x )
1015
1008
ph .assert_dtype ("expm1" , in_dtype = x .dtype , out_dtype = out .dtype )
@@ -1158,7 +1151,7 @@ def test_less_equal(ctx, data):
1158
1151
)
1159
1152
1160
1153
1161
- @given (xps .arrays (dtype = all_floating_dtypes (), shape = hh .shapes ()))
1154
+ @given (xps .arrays (dtype = hh . all_floating_dtypes (), shape = hh .shapes ()))
1162
1155
def test_log (x ):
1163
1156
out = xp .log (x )
1164
1157
ph .assert_dtype ("log" , in_dtype = x .dtype , out_dtype = out .dtype )
@@ -1168,7 +1161,7 @@ def test_log(x):
1168
1161
)
1169
1162
1170
1163
1171
- @given (xps .arrays (dtype = all_floating_dtypes (), shape = hh .shapes ()))
1164
+ @given (xps .arrays (dtype = hh . all_floating_dtypes (), shape = hh .shapes ()))
1172
1165
def test_log1p (x ):
1173
1166
out = xp .log1p (x )
1174
1167
ph .assert_dtype ("log1p" , in_dtype = x .dtype , out_dtype = out .dtype )
@@ -1178,7 +1171,7 @@ def test_log1p(x):
1178
1171
)
1179
1172
1180
1173
1181
- @given (xps .arrays (dtype = all_floating_dtypes (), shape = hh .shapes ()))
1174
+ @given (xps .arrays (dtype = hh . all_floating_dtypes (), shape = hh .shapes ()))
1182
1175
def test_log2 (x ):
1183
1176
out = xp .log2 (x )
1184
1177
ph .assert_dtype ("log2" , in_dtype = x .dtype , out_dtype = out .dtype )
@@ -1188,7 +1181,7 @@ def test_log2(x):
1188
1181
)
1189
1182
1190
1183
1191
- @given (xps .arrays (dtype = all_floating_dtypes (), shape = hh .shapes ()))
1184
+ @given (xps .arrays (dtype = hh . all_floating_dtypes (), shape = hh .shapes ()))
1192
1185
def test_log10 (x ):
1193
1186
out = xp .log10 (x )
1194
1187
ph .assert_dtype ("log10" , in_dtype = x .dtype , out_dtype = out .dtype )
@@ -1379,15 +1372,15 @@ def test_sign(x):
1379
1372
)
1380
1373
1381
1374
1382
- @given (xps .arrays (dtype = all_floating_dtypes (), shape = hh .shapes ()))
1375
+ @given (xps .arrays (dtype = hh . all_floating_dtypes (), shape = hh .shapes ()))
1383
1376
def test_sin (x ):
1384
1377
out = xp .sin (x )
1385
1378
ph .assert_dtype ("sin" , in_dtype = x .dtype , out_dtype = out .dtype )
1386
1379
ph .assert_shape ("sin" , out_shape = out .shape , expected = x .shape )
1387
1380
unary_assert_against_refimpl ("sin" , x , out , math .sin )
1388
1381
1389
1382
1390
- @given (xps .arrays (dtype = all_floating_dtypes (), shape = hh .shapes ()))
1383
+ @given (xps .arrays (dtype = hh . all_floating_dtypes (), shape = hh .shapes ()))
1391
1384
def test_sinh (x ):
1392
1385
out = xp .sinh (x )
1393
1386
ph .assert_dtype ("sinh" , in_dtype = x .dtype , out_dtype = out .dtype )
@@ -1405,7 +1398,7 @@ def test_square(x):
1405
1398
)
1406
1399
1407
1400
1408
- @given (xps .arrays (dtype = all_floating_dtypes (), shape = hh .shapes ()))
1401
+ @given (xps .arrays (dtype = hh . all_floating_dtypes (), shape = hh .shapes ()))
1409
1402
def test_sqrt (x ):
1410
1403
out = xp .sqrt (x )
1411
1404
ph .assert_dtype ("sqrt" , in_dtype = x .dtype , out_dtype = out .dtype )
@@ -1429,15 +1422,15 @@ def test_subtract(ctx, data):
1429
1422
binary_param_assert_against_refimpl (ctx , left , right , res , "-" , operator .sub )
1430
1423
1431
1424
1432
- @given (xps .arrays (dtype = all_floating_dtypes (), shape = hh .shapes ()))
1425
+ @given (xps .arrays (dtype = hh . all_floating_dtypes (), shape = hh .shapes ()))
1433
1426
def test_tan (x ):
1434
1427
out = xp .tan (x )
1435
1428
ph .assert_dtype ("tan" , in_dtype = x .dtype , out_dtype = out .dtype )
1436
1429
ph .assert_shape ("tan" , out_shape = out .shape , expected = x .shape )
1437
1430
unary_assert_against_refimpl ("tan" , x , out , math .tan )
1438
1431
1439
1432
1440
- @given (xps .arrays (dtype = all_floating_dtypes (), shape = hh .shapes ()))
1433
+ @given (xps .arrays (dtype = hh . all_floating_dtypes (), shape = hh .shapes ()))
1441
1434
def test_tanh (x ):
1442
1435
out = xp .tanh (x )
1443
1436
ph .assert_dtype ("tanh" , in_dtype = x .dtype , out_dtype = out .dtype )
0 commit comments