|
7 | 7 | more details.
|
8 | 8 |
|
9 | 9 | """
|
| 10 | +from numpy import bool_ as bool |
10 | 11 | from numpy import (
|
| 12 | + complex64, |
| 13 | + complex128, |
11 | 14 | dtype,
|
12 |
| - bool_ as bool, |
13 |
| - intp, |
| 15 | + float32, |
| 16 | + float64, |
14 | 17 | int8,
|
15 | 18 | int16,
|
16 | 19 | int32,
|
17 | 20 | int64,
|
| 21 | + intp, |
18 | 22 | uint8,
|
19 | 23 | uint16,
|
20 | 24 | uint32,
|
21 | 25 | uint64,
|
22 |
| - float32, |
23 |
| - float64, |
24 |
| - complex64, |
25 |
| - complex128, |
26 | 26 | )
|
27 | 27 |
|
| 28 | +from ._typing import Device, DType |
| 29 | + |
28 | 30 |
|
29 | 31 | class __array_namespace_info__:
|
30 | 32 | """
|
@@ -131,7 +133,11 @@ def default_device(self):
|
131 | 133 | """
|
132 | 134 | return "cpu"
|
133 | 135 |
|
134 |
| - def default_dtypes(self, *, device=None): |
| 136 | + def default_dtypes( |
| 137 | + self, |
| 138 | + *, |
| 139 | + device: Device | None = None, |
| 140 | + ) -> dict[str, dtype[intp | float64 | complex128]]: |
135 | 141 | """
|
136 | 142 | The default data types used for new NumPy arrays.
|
137 | 143 |
|
@@ -183,7 +189,12 @@ def default_dtypes(self, *, device=None):
|
183 | 189 | "indexing": dtype(intp),
|
184 | 190 | }
|
185 | 191 |
|
186 |
| - def dtypes(self, *, device=None, kind=None): |
| 192 | + def dtypes( |
| 193 | + self, |
| 194 | + *, |
| 195 | + device: Device | None = None, |
| 196 | + kind: str | tuple[str, ...] | None = None, |
| 197 | + ) -> dict[str, DType]: |
187 | 198 | """
|
188 | 199 | The array API data types supported by NumPy.
|
189 | 200 |
|
@@ -260,7 +271,7 @@ def dtypes(self, *, device=None, kind=None):
|
260 | 271 | "complex128": dtype(complex128),
|
261 | 272 | }
|
262 | 273 | if kind == "bool":
|
263 |
| - return {"bool": bool} |
| 274 | + return {"bool": dtype(bool)} |
264 | 275 | if kind == "signed integer":
|
265 | 276 | return {
|
266 | 277 | "int8": dtype(int8),
|
@@ -312,13 +323,13 @@ def dtypes(self, *, device=None, kind=None):
|
312 | 323 | "complex128": dtype(complex128),
|
313 | 324 | }
|
314 | 325 | if isinstance(kind, tuple):
|
315 |
| - res = {} |
| 326 | + res: dict[str, DType] = {} |
316 | 327 | for k in kind:
|
317 | 328 | res.update(self.dtypes(kind=k))
|
318 | 329 | return res
|
319 | 330 | raise ValueError(f"unsupported kind: {kind!r}")
|
320 | 331 |
|
321 |
| - def devices(self): |
| 332 | + def devices(self) -> list[Device]: |
322 | 333 | """
|
323 | 334 | The devices supported by NumPy.
|
324 | 335 |
|
|
0 commit comments