@@ -709,19 +709,13 @@ def test_equal(
709
709
_left = ah .asarray (_left , dtype = promoted_dtype )
710
710
_right = ah .asarray (_right , dtype = promoted_dtype )
711
711
712
- if dh .is_int_dtype (promoted_dtype ):
713
- scalar_func = int
714
- elif dh .is_float_dtype (promoted_dtype ):
715
- scalar_func = float
716
- else :
717
- scalar_func = bool
718
-
712
+ scalar_type = dh .get_scalar_type (promoted_dtype )
719
713
for idx in ah .ndindex (shape ):
720
714
x1_idx = _left [idx ]
721
715
x2_idx = _right [idx ]
722
716
out_idx = out [idx ]
723
717
assert out_idx .shape == x1_idx .shape == x2_idx .shape # sanity check
724
- assert bool (out_idx ) == (scalar_func (x1_idx ) == scalar_func (x2_idx ))
718
+ assert bool (out_idx ) == (scalar_type (x1_idx ) == scalar_type (x2_idx ))
725
719
726
720
727
721
@given (xps .arrays (dtype = xps .floating_dtypes (), shape = hh .shapes ()))
@@ -840,18 +834,13 @@ def test_greater(
840
834
_left = ah .asarray (_left , dtype = promoted_dtype )
841
835
_right = ah .asarray (_right , dtype = promoted_dtype )
842
836
843
- if dh .is_int_dtype (promoted_dtype ):
844
- scalar_func = int
845
- elif dh .is_float_dtype (promoted_dtype ):
846
- scalar_func = float
847
- else :
848
- scalar_func = bool
837
+ scalar_type = dh .get_scalar_type (promoted_dtype )
849
838
for idx in ah .ndindex (shape ):
850
839
out_idx = out [idx ]
851
840
x1_idx = _left [idx ]
852
841
x2_idx = _right [idx ]
853
842
assert out_idx .shape == x1_idx .shape == x2_idx .shape # sanity check
854
- assert bool (out_idx ) == (scalar_func (x1_idx ) > scalar_func (x2_idx ))
843
+ assert bool (out_idx ) == (scalar_type (x1_idx ) > scalar_type (x2_idx ))
855
844
856
845
857
846
@pytest .mark .parametrize (
@@ -886,18 +875,13 @@ def test_greater_equal(
886
875
_left = ah .asarray (_left , dtype = promoted_dtype )
887
876
_right = ah .asarray (_right , dtype = promoted_dtype )
888
877
889
- if dh .is_int_dtype (promoted_dtype ):
890
- scalar_func = int
891
- elif dh .is_float_dtype (promoted_dtype ):
892
- scalar_func = float
893
- else :
894
- scalar_func = bool
878
+ scalar_type = dh .get_scalar_type (promoted_dtype )
895
879
for idx in ah .ndindex (shape ):
896
880
out_idx = out [idx ]
897
881
x1_idx = _left [idx ]
898
882
x2_idx = _right [idx ]
899
883
assert out_idx .shape == x1_idx .shape == x2_idx .shape # sanity check
900
- assert bool (out_idx ) == (scalar_func (x1_idx ) >= scalar_func (x2_idx ))
884
+ assert bool (out_idx ) == (scalar_type (x1_idx ) >= scalar_type (x2_idx ))
901
885
902
886
903
887
@given (xps .arrays (dtype = xps .numeric_dtypes (), shape = hh .shapes ()))
@@ -983,19 +967,13 @@ def test_less(
983
967
_left = ah .asarray (_left , dtype = promoted_dtype )
984
968
_right = ah .asarray (_right , dtype = promoted_dtype )
985
969
986
- if dh .is_int_dtype (promoted_dtype ):
987
- scalar_func = int
988
- elif dh .is_float_dtype (promoted_dtype ):
989
- scalar_func = float
990
- else :
991
- scalar_func = bool
992
-
970
+ scalar_type = dh .get_scalar_type (promoted_dtype )
993
971
for idx in ah .ndindex (shape ):
994
972
x1_idx = _left [idx ]
995
973
x2_idx = _right [idx ]
996
974
out_idx = out [idx ]
997
975
assert out_idx .shape == x1_idx .shape == x2_idx .shape # sanity check
998
- assert bool (out_idx ) == (scalar_func (x1_idx ) < scalar_func (x2_idx ))
976
+ assert bool (out_idx ) == (scalar_type (x1_idx ) < scalar_type (x2_idx ))
999
977
1000
978
1001
979
@pytest .mark .parametrize (
@@ -1030,19 +1008,13 @@ def test_less_equal(
1030
1008
_left = ah .asarray (_left , dtype = promoted_dtype )
1031
1009
_right = ah .asarray (_right , dtype = promoted_dtype )
1032
1010
1033
- if dh .is_int_dtype (promoted_dtype ):
1034
- scalar_func = int
1035
- elif dh .is_float_dtype (promoted_dtype ):
1036
- scalar_func = float
1037
- else :
1038
- scalar_func = bool
1039
-
1011
+ scalar_type = dh .get_scalar_type (promoted_dtype )
1040
1012
for idx in ah .ndindex (shape ):
1041
1013
x1_idx = _left [idx ]
1042
1014
x2_idx = _right [idx ]
1043
1015
out_idx = out [idx ]
1044
1016
assert out_idx .shape == x1_idx .shape == x2_idx .shape # sanity check
1045
- assert bool (out_idx ) == (scalar_func (x1_idx ) <= scalar_func (x2_idx ))
1017
+ assert bool (out_idx ) == (scalar_type (x1_idx ) <= scalar_type (x2_idx ))
1046
1018
1047
1019
1048
1020
@given (xps .arrays (dtype = xps .floating_dtypes (), shape = hh .shapes ()))
@@ -1241,19 +1213,13 @@ def test_not_equal(
1241
1213
_left = ah .asarray (_left , dtype = promoted_dtype )
1242
1214
_right = ah .asarray (_right , dtype = promoted_dtype )
1243
1215
1244
- if dh .is_int_dtype (promoted_dtype ):
1245
- scalar_func = int
1246
- elif dh .is_float_dtype (promoted_dtype ):
1247
- scalar_func = float
1248
- else :
1249
- scalar_func = bool
1250
-
1216
+ scalar_type = dh .get_scalar_type (promoted_dtype )
1251
1217
for idx in ah .ndindex (shape ):
1252
1218
out_idx = out [idx ]
1253
1219
x1_idx = _left [idx ]
1254
1220
x2_idx = _right [idx ]
1255
1221
assert out_idx .shape == x1_idx .shape == x2_idx .shape # sanity check
1256
- assert bool (out_idx ) == (scalar_func (x1_idx ) != scalar_func (x2_idx ))
1222
+ assert bool (out_idx ) == (scalar_type (x1_idx ) != scalar_type (x2_idx ))
1257
1223
1258
1224
1259
1225
@pytest .mark .parametrize (
0 commit comments