Skip to content

Commit 1fb929b

Browse files
committed
TYP: fix typing errors in numpy._info
1 parent 18870dc commit 1fb929b

File tree

1 file changed

+22
-11
lines changed

1 file changed

+22
-11
lines changed

array_api_compat/numpy/_info.py

Lines changed: 22 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -7,24 +7,26 @@
77
more details.
88
99
"""
10+
from numpy import bool_ as bool
1011
from numpy import (
12+
complex64,
13+
complex128,
1114
dtype,
12-
bool_ as bool,
13-
intp,
15+
float32,
16+
float64,
1417
int8,
1518
int16,
1619
int32,
1720
int64,
21+
intp,
1822
uint8,
1923
uint16,
2024
uint32,
2125
uint64,
22-
float32,
23-
float64,
24-
complex64,
25-
complex128,
2626
)
2727

28+
from ._typing import Device, DType
29+
2830

2931
class __array_namespace_info__:
3032
"""
@@ -131,7 +133,11 @@ def default_device(self):
131133
"""
132134
return "cpu"
133135

134-
def default_dtypes(self, *, device=None):
136+
def default_dtypes(
137+
self,
138+
*,
139+
device: Device | None = None,
140+
) -> dict[str, dtype[intp | float64 | complex128]]:
135141
"""
136142
The default data types used for new NumPy arrays.
137143
@@ -183,7 +189,12 @@ def default_dtypes(self, *, device=None):
183189
"indexing": dtype(intp),
184190
}
185191

186-
def dtypes(self, *, device=None, kind=None):
192+
def dtypes(
193+
self,
194+
*,
195+
device: Device | None = None,
196+
kind: str | tuple[str, ...] | None = None,
197+
) -> dict[str, DType]:
187198
"""
188199
The array API data types supported by NumPy.
189200
@@ -260,7 +271,7 @@ def dtypes(self, *, device=None, kind=None):
260271
"complex128": dtype(complex128),
261272
}
262273
if kind == "bool":
263-
return {"bool": bool}
274+
return {"bool": dtype(bool)}
264275
if kind == "signed integer":
265276
return {
266277
"int8": dtype(int8),
@@ -312,13 +323,13 @@ def dtypes(self, *, device=None, kind=None):
312323
"complex128": dtype(complex128),
313324
}
314325
if isinstance(kind, tuple):
315-
res = {}
326+
res: dict[str, DType] = {}
316327
for k in kind:
317328
res.update(self.dtypes(kind=k))
318329
return res
319330
raise ValueError(f"unsupported kind: {kind!r}")
320331

321-
def devices(self):
332+
def devices(self) -> list[Device]:
322333
"""
323334
The devices supported by NumPy.
324335

0 commit comments

Comments
 (0)