Skip to content

Commit b18b4e3

Browse files
committed
verbose Python scalar types
1 parent d855dce commit b18b4e3

File tree

8 files changed

+70
-57
lines changed

8 files changed

+70
-57
lines changed

array_api_strict/_array_object.py

Lines changed: 40 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -199,7 +199,7 @@ def __array__(
199199
# NumPy behavior
200200

201201
def _check_allowed_dtypes(
202-
self, other: Array | complex, dtype_category: str, op: str
202+
self, other: Array | bool | int | float | complex, dtype_category: str, op: str
203203
) -> Array:
204204
"""
205205
Helper function for operators to only allow specific input dtypes
@@ -241,7 +241,7 @@ def _check_allowed_dtypes(
241241

242242
return other
243243

244-
def _check_device(self, other: Array | complex) -> None:
244+
def _check_device(self, other: Array | bool | int | float | complex) -> None:
245245
"""Check that other is on a device compatible with the current array"""
246246
if isinstance(other, (bool, int, float, complex)):
247247
return
@@ -252,7 +252,7 @@ def _check_device(self, other: Array | complex) -> None:
252252
raise TypeError(f"Expected Array | python scalar; got {type(other)}")
253253

254254
# Helper function to match the type promotion rules in the spec
255-
def _promote_scalar(self, scalar: complex) -> Array:
255+
def _promote_scalar(self, scalar: bool | int | float | complex) -> Array:
256256
"""
257257
Returns a promoted version of a Python scalar appropriate for use with
258258
operations on self.
@@ -546,7 +546,7 @@ def __abs__(self) -> Array:
546546
res = self._array.__abs__()
547547
return self.__class__._new(res, device=self.device)
548548

549-
def __add__(self, other: Array | complex, /) -> Array:
549+
def __add__(self, other: Array | int | float | complex, /) -> Array:
550550
"""
551551
Performs the operation __add__.
552552
"""
@@ -558,7 +558,7 @@ def __add__(self, other: Array | complex, /) -> Array:
558558
res = self._array.__add__(other._array)
559559
return self.__class__._new(res, device=self.device)
560560

561-
def __and__(self, other: Array | int, /) -> Array:
561+
def __and__(self, other: Array | bool | int, /) -> Array:
562562
"""
563563
Performs the operation __and__.
564564
"""
@@ -655,7 +655,7 @@ def __dlpack_device__(self) -> tuple[IntEnum, int]:
655655
# Note: device support is required for this
656656
return self._array.__dlpack_device__()
657657

658-
def __eq__(self, other: Array | complex, /) -> Array: # type: ignore[override]
658+
def __eq__(self, other: Array | bool | int | float | complex, /) -> Array: # type: ignore[override]
659659
"""
660660
Performs the operation __eq__.
661661
"""
@@ -681,7 +681,7 @@ def __float__(self) -> float:
681681
res = self._array.__float__()
682682
return res
683683

684-
def __floordiv__(self, other: Array | float, /) -> Array:
684+
def __floordiv__(self, other: Array | int | float, /) -> Array:
685685
"""
686686
Performs the operation __floordiv__.
687687
"""
@@ -693,7 +693,7 @@ def __floordiv__(self, other: Array | float, /) -> Array:
693693
res = self._array.__floordiv__(other._array)
694694
return self.__class__._new(res, device=self.device)
695695

696-
def __ge__(self, other: Array | float, /) -> Array:
696+
def __ge__(self, other: Array | int | float, /) -> Array:
697697
"""
698698
Performs the operation __ge__.
699699
"""
@@ -729,7 +729,7 @@ def __getitem__(
729729
res = self._array.__getitem__(np_key)
730730
return self._new(res, device=self.device)
731731

732-
def __gt__(self, other: Array | float, /) -> Array:
732+
def __gt__(self, other: Array | int | float, /) -> Array:
733733
"""
734734
Performs the operation __gt__.
735735
"""
@@ -784,7 +784,7 @@ def __iter__(self) -> Iterator[Array]:
784784
# implemented, which implies iteration on 1-D arrays.
785785
return (Array._new(i, device=self.device) for i in self._array)
786786

787-
def __le__(self, other: Array | float, /) -> Array:
787+
def __le__(self, other: Array | int | float, /) -> Array:
788788
"""
789789
Performs the operation __le__.
790790
"""
@@ -808,7 +808,7 @@ def __lshift__(self, other: Array | int, /) -> Array:
808808
res = self._array.__lshift__(other._array)
809809
return self.__class__._new(res, device=self.device)
810810

811-
def __lt__(self, other: Array | float, /) -> Array:
811+
def __lt__(self, other: Array | int | float, /) -> Array:
812812
"""
813813
Performs the operation __lt__.
814814
"""
@@ -833,7 +833,7 @@ def __matmul__(self, other: Array, /) -> Array:
833833
res = self._array.__matmul__(other._array)
834834
return self.__class__._new(res, device=self.device)
835835

836-
def __mod__(self, other: Array | float, /) -> Array:
836+
def __mod__(self, other: Array | int | float, /) -> Array:
837837
"""
838838
Performs the operation __mod__.
839839
"""
@@ -845,7 +845,7 @@ def __mod__(self, other: Array | float, /) -> Array:
845845
res = self._array.__mod__(other._array)
846846
return self.__class__._new(res, device=self.device)
847847

848-
def __mul__(self, other: Array | complex, /) -> Array:
848+
def __mul__(self, other: Array | int | float | complex, /) -> Array:
849849
"""
850850
Performs the operation __mul__.
851851
"""
@@ -857,7 +857,7 @@ def __mul__(self, other: Array | complex, /) -> Array:
857857
res = self._array.__mul__(other._array)
858858
return self.__class__._new(res, device=self.device)
859859

860-
def __ne__(self, other: Array | complex, /) -> Array: # type: ignore[override]
860+
def __ne__(self, other: Array | bool | int | float | complex, /) -> Array: # type: ignore[override]
861861
"""
862862
Performs the operation __ne__.
863863
"""
@@ -878,7 +878,7 @@ def __neg__(self) -> Array:
878878
res = self._array.__neg__()
879879
return self.__class__._new(res, device=self.device)
880880

881-
def __or__(self, other: Array | int, /) -> Array:
881+
def __or__(self, other: Array | bool | int, /) -> Array:
882882
"""
883883
Performs the operation __or__.
884884
"""
@@ -899,7 +899,7 @@ def __pos__(self) -> Array:
899899
res = self._array.__pos__()
900900
return self.__class__._new(res, device=self.device)
901901

902-
def __pow__(self, other: Array | complex, /) -> Array:
902+
def __pow__(self, other: Array | int | float | complex, /) -> Array:
903903
"""
904904
Performs the operation __pow__.
905905
"""
@@ -936,7 +936,7 @@ def __setitem__(
936936
| Array
937937
| tuple[int | slice | EllipsisType, ...]
938938
),
939-
value: Array | complex,
939+
value: Array | bool | int | float | complex,
940940
/,
941941
) -> None:
942942
"""
@@ -949,7 +949,7 @@ def __setitem__(
949949
np_key = key._array if isinstance(key, Array) else key
950950
self._array.__setitem__(np_key, asarray(value)._array)
951951

952-
def __sub__(self, other: Array | complex, /) -> Array:
952+
def __sub__(self, other: Array | int | float | complex, /) -> Array:
953953
"""
954954
Performs the operation __sub__.
955955
"""
@@ -963,7 +963,7 @@ def __sub__(self, other: Array | complex, /) -> Array:
963963

964964
# PEP 484 requires int to be a subtype of float, but __truediv__ should
965965
# not accept int.
966-
def __truediv__(self, other: Array | complex, /) -> Array:
966+
def __truediv__(self, other: Array | int | float | complex, /) -> Array:
967967
"""
968968
Performs the operation __truediv__.
969969
"""
@@ -975,7 +975,7 @@ def __truediv__(self, other: Array | complex, /) -> Array:
975975
res = self._array.__truediv__(other._array)
976976
return self.__class__._new(res, device=self.device)
977977

978-
def __xor__(self, other: Array | int, /) -> Array:
978+
def __xor__(self, other: Array | bool | int, /) -> Array:
979979
"""
980980
Performs the operation __xor__.
981981
"""
@@ -987,7 +987,7 @@ def __xor__(self, other: Array | int, /) -> Array:
987987
res = self._array.__xor__(other._array)
988988
return self.__class__._new(res, device=self.device)
989989

990-
def __iadd__(self, other: Array | complex, /) -> Array:
990+
def __iadd__(self, other: Array | int | float | complex, /) -> Array:
991991
"""
992992
Performs the operation __iadd__.
993993
"""
@@ -998,7 +998,7 @@ def __iadd__(self, other: Array | complex, /) -> Array:
998998
self._array.__iadd__(other._array)
999999
return self
10001000

1001-
def __radd__(self, other: Array | complex, /) -> Array:
1001+
def __radd__(self, other: Array | int | float | complex, /) -> Array:
10021002
"""
10031003
Performs the operation __radd__.
10041004
"""
@@ -1010,7 +1010,7 @@ def __radd__(self, other: Array | complex, /) -> Array:
10101010
res = self._array.__radd__(other._array)
10111011
return self.__class__._new(res, device=self.device)
10121012

1013-
def __iand__(self, other: Array | int, /) -> Array:
1013+
def __iand__(self, other: Array | bool | int, /) -> Array:
10141014
"""
10151015
Performs the operation __iand__.
10161016
"""
@@ -1021,7 +1021,7 @@ def __iand__(self, other: Array | int, /) -> Array:
10211021
self._array.__iand__(other._array)
10221022
return self
10231023

1024-
def __rand__(self, other: Array | int, /) -> Array:
1024+
def __rand__(self, other: Array | bool | int, /) -> Array:
10251025
"""
10261026
Performs the operation __rand__.
10271027
"""
@@ -1033,7 +1033,7 @@ def __rand__(self, other: Array | int, /) -> Array:
10331033
res = self._array.__rand__(other._array)
10341034
return self.__class__._new(res, device=self.device)
10351035

1036-
def __ifloordiv__(self, other: Array | float, /) -> Array:
1036+
def __ifloordiv__(self, other: Array | int | float, /) -> Array:
10371037
"""
10381038
Performs the operation __ifloordiv__.
10391039
"""
@@ -1044,7 +1044,7 @@ def __ifloordiv__(self, other: Array | float, /) -> Array:
10441044
self._array.__ifloordiv__(other._array)
10451045
return self
10461046

1047-
def __rfloordiv__(self, other: Array | float, /) -> Array:
1047+
def __rfloordiv__(self, other: Array | int | float, /) -> Array:
10481048
"""
10491049
Performs the operation __rfloordiv__.
10501050
"""
@@ -1105,7 +1105,7 @@ def __rmatmul__(self, other: Array, /) -> Array:
11051105
res = self._array.__rmatmul__(other._array)
11061106
return self.__class__._new(res, device=self.device)
11071107

1108-
def __imod__(self, other: Array | float, /) -> Array:
1108+
def __imod__(self, other: Array | int | float, /) -> Array:
11091109
"""
11101110
Performs the operation __imod__.
11111111
"""
@@ -1115,7 +1115,7 @@ def __imod__(self, other: Array | float, /) -> Array:
11151115
self._array.__imod__(other._array)
11161116
return self
11171117

1118-
def __rmod__(self, other: Array | float, /) -> Array:
1118+
def __rmod__(self, other: Array | int | float, /) -> Array:
11191119
"""
11201120
Performs the operation __rmod__.
11211121
"""
@@ -1127,7 +1127,7 @@ def __rmod__(self, other: Array | float, /) -> Array:
11271127
res = self._array.__rmod__(other._array)
11281128
return self.__class__._new(res, device=self.device)
11291129

1130-
def __imul__(self, other: Array | complex, /) -> Array:
1130+
def __imul__(self, other: Array | int | float | complex, /) -> Array:
11311131
"""
11321132
Performs the operation __imul__.
11331133
"""
@@ -1137,7 +1137,7 @@ def __imul__(self, other: Array | complex, /) -> Array:
11371137
self._array.__imul__(other._array)
11381138
return self
11391139

1140-
def __rmul__(self, other: Array | complex, /) -> Array:
1140+
def __rmul__(self, other: Array | int | float | complex, /) -> Array:
11411141
"""
11421142
Performs the operation __rmul__.
11431143
"""
@@ -1149,7 +1149,7 @@ def __rmul__(self, other: Array | complex, /) -> Array:
11491149
res = self._array.__rmul__(other._array)
11501150
return self.__class__._new(res, device=self.device)
11511151

1152-
def __ior__(self, other: Array | int, /) -> Array:
1152+
def __ior__(self, other: Array | bool | int, /) -> Array:
11531153
"""
11541154
Performs the operation __ior__.
11551155
"""
@@ -1159,7 +1159,7 @@ def __ior__(self, other: Array | int, /) -> Array:
11591159
self._array.__ior__(other._array)
11601160
return self
11611161

1162-
def __ror__(self, other: Array | int, /) -> Array:
1162+
def __ror__(self, other: Array | bool | int, /) -> Array:
11631163
"""
11641164
Performs the operation __ror__.
11651165
"""
@@ -1171,7 +1171,7 @@ def __ror__(self, other: Array | int, /) -> Array:
11711171
res = self._array.__ror__(other._array)
11721172
return self.__class__._new(res, device=self.device)
11731173

1174-
def __ipow__(self, other: Array | complex, /) -> Array:
1174+
def __ipow__(self, other: Array | int | float | complex, /) -> Array:
11751175
"""
11761176
Performs the operation __ipow__.
11771177
"""
@@ -1181,7 +1181,7 @@ def __ipow__(self, other: Array | complex, /) -> Array:
11811181
self._array.__ipow__(other._array)
11821182
return self
11831183

1184-
def __rpow__(self, other: Array | complex, /) -> Array:
1184+
def __rpow__(self, other: Array | int | float | complex, /) -> Array:
11851185
"""
11861186
Performs the operation __rpow__.
11871187
"""
@@ -1216,7 +1216,7 @@ def __rrshift__(self, other: Array | int, /) -> Array:
12161216
res = self._array.__rrshift__(other._array)
12171217
return self.__class__._new(res, device=self.device)
12181218

1219-
def __isub__(self, other: Array | complex, /) -> Array:
1219+
def __isub__(self, other: Array | int | float | complex, /) -> Array:
12201220
"""
12211221
Performs the operation __isub__.
12221222
"""
@@ -1226,7 +1226,7 @@ def __isub__(self, other: Array | complex, /) -> Array:
12261226
self._array.__isub__(other._array)
12271227
return self
12281228

1229-
def __rsub__(self, other: Array | complex, /) -> Array:
1229+
def __rsub__(self, other: Array | int | float | complex, /) -> Array:
12301230
"""
12311231
Performs the operation __rsub__.
12321232
"""
@@ -1238,7 +1238,7 @@ def __rsub__(self, other: Array | complex, /) -> Array:
12381238
res = self._array.__rsub__(other._array)
12391239
return self.__class__._new(res, device=self.device)
12401240

1241-
def __itruediv__(self, other: Array | complex, /) -> Array:
1241+
def __itruediv__(self, other: Array | int | float | complex, /) -> Array:
12421242
"""
12431243
Performs the operation __itruediv__.
12441244
"""
@@ -1248,7 +1248,7 @@ def __itruediv__(self, other: Array | complex, /) -> Array:
12481248
self._array.__itruediv__(other._array)
12491249
return self
12501250

1251-
def __rtruediv__(self, other: Array | complex, /) -> Array:
1251+
def __rtruediv__(self, other: Array | int | float | complex, /) -> Array:
12521252
"""
12531253
Performs the operation __rtruediv__.
12541254
"""
@@ -1260,7 +1260,7 @@ def __rtruediv__(self, other: Array | complex, /) -> Array:
12601260
res = self._array.__rtruediv__(other._array)
12611261
return self.__class__._new(res, device=self.device)
12621262

1263-
def __ixor__(self, other: Array | int, /) -> Array:
1263+
def __ixor__(self, other: Array | bool | int, /) -> Array:
12641264
"""
12651265
Performs the operation __ixor__.
12661266
"""
@@ -1270,7 +1270,7 @@ def __ixor__(self, other: Array | int, /) -> Array:
12701270
self._array.__ixor__(other._array)
12711271
return self
12721272

1273-
def __rxor__(self, other: Array | int, /) -> Array:
1273+
def __rxor__(self, other: Array | bool | int, /) -> Array:
12741274
"""
12751275
Performs the operation __rxor__.
12761276
"""

0 commit comments

Comments
 (0)