Skip to content

Commit 72919ed

Browse files
committed
Merge branch 'main' into lithomas1-fix-dask
2 parents 2e4c796 + 40603a9 commit 72919ed

25 files changed

+466
-1100
lines changed

.github/workflows/ruff.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,4 +16,4 @@ jobs:
1616
pip install ruff
1717
# Update output format to enable automatic inline annotations.
1818
- name: Run Ruff
19-
run: ruff check --output-format=github --select F822,PLC0414,RUF022 --preview .
19+
run: ruff check --output-format=github .

README.md

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -71,23 +71,24 @@ namespace, except that functions that are part of the array API are wrapped so
7171
that they have the correct array API behavior. In each case, the array object
7272
used will be the same array object from the wrapped library.
7373

74-
## Difference between `array_api_compat` and `numpy.array_api`
74+
## Difference between `array_api_compat` and `array_api_strict`
7575

76-
`numpy.array_api` is a strict minimal implementation of the Array API (see
76+
`array_api_strict` is a strict minimal implementation of the array API standard, formerly
77+
known as `numpy.array_api` (see
7778
[NEP 47](https://numpy.org/neps/nep-0047-array-api-standard.html)). For
78-
example, `numpy.array_api` does not include any functions that are not part of
79+
example, `array_api_strict` does not include any functions that are not part of
7980
the array API specification, and will explicitly disallow behaviors that are
8081
not required by the spec (e.g., [cross-kind type
8182
promotions](https://data-apis.org/array-api/latest/API_specification/type_promotion.html)).
82-
(`cupy.array_api` is similar to `numpy.array_api`)
83+
(`cupy.array_api` is similar to `array_api_strict`)
8384

8485
`array_api_compat`, on the other hand, is just an extension of the
8586
corresponding array library namespaces with changes needed to be compliant
8687
with the array API. It includes all additional library functions not mentioned
8788
in the spec, and allows any library behaviors not explicitly disallowed by it,
8889
such as cross-kind casting.
8990

90-
In particular, unlike `numpy.array_api`, this package does not use a separate
91+
In particular, unlike `array_api_strict`, this package does not use a separate
9192
`Array` object, but rather just uses the corresponding array library array
9293
objects (`numpy.ndarray`, `cupy.ndarray`, `torch.Tensor`, etc.) directly. This
9394
is because those are the objects that are going to be passed as inputs to
@@ -96,7 +97,7 @@ functions by end users. This does mean that a few behaviors cannot be wrapped
9697
most things.
9798

9899
Array consuming library authors coding against the array API may wish to test
99-
against `numpy.array_api` to ensure they are not using functionality outside
100+
against `array_api_strict` to ensure they are not using functionality outside
100101
of the standard, but prefer this implementation for the default behavior for
101102
end-users.
102103

array_api_compat/__init__.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -5,14 +5,14 @@
55
Array API standard https://data-apis.org/array-api/latest/. See also NEP 47
66
https://numpy.org/neps/nep-0047-array-api-standard.html.
77
8-
Unlike numpy.array_api, this is not a strict minimal implementation of the
8+
Unlike array_api_strict, this is not a strict minimal implementation of the
99
Array API, but rather just an extension of the main NumPy namespace with
1010
changes needed to be compliant with the Array API. See
1111
https://numpy.org/doc/stable/reference/array_api.html for a full list of
12-
changes. In particular, unlike numpy.array_api, this package does not use a
12+
changes. In particular, unlike array_api_strict, this package does not use a
1313
separate Array object, but rather just uses numpy.ndarray directly.
1414
15-
Library authors using the Array API may wish to test against numpy.array_api
15+
Library authors using the Array API may wish to test against array_api_strict
1616
to ensure they are not using functionality outside of the standard, but prefer
1717
this implementation for the default when working with NumPy arrays.
1818

array_api_compat/_internal.py

Lines changed: 0 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,6 @@
55
from functools import wraps
66
from inspect import signature
77

8-
98
def get_xp(xp):
109
"""
1110
Decorator to automatically replace xp with the corresponding array module.
@@ -45,31 +44,3 @@ def wrapped_f(*args, **kwargs):
4544
return wrapped_f
4645

4746
return inner
48-
49-
50-
def _get_all_public_members(module, exclude=None, extend_all=False):
51-
"""Get all public members of a module.
52-
53-
Parameters
54-
----------
55-
module : module
56-
The module to get members from.
57-
exclude : callable, optional
58-
A callable that takes a name and returns True if the name should be
59-
excluded from the list of members.
60-
extend_all : bool, optional
61-
If True, extend the module's __all__ attribute with the members of the
62-
module derived from dir(module). To be used for libraries that do not have a complete __all__ list.
63-
"""
64-
members = getattr(module, "__all__", [])
65-
66-
if members and not extend_all:
67-
return members
68-
69-
if exclude is None:
70-
exclude = lambda name: name.startswith("_") # noqa: E731
71-
72-
members = members + [_ for _ in dir(module) if not exclude(_)]
73-
74-
# remove duplicates
75-
return list(set(members))

array_api_compat/common/__init__.py

Lines changed: 1 addition & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -1,27 +1 @@
1-
from ._helpers import (
2-
array_namespace,
3-
device,
4-
get_namespace,
5-
is_array_api_obj,
6-
is_cupy_array,
7-
is_dask_array,
8-
is_jax_array,
9-
is_numpy_array,
10-
is_torch_array,
11-
size,
12-
to_device,
13-
)
14-
15-
__all__ = [
16-
"array_namespace",
17-
"device",
18-
"get_namespace",
19-
"is_array_api_obj",
20-
"is_cupy_array",
21-
"is_dask_array",
22-
"is_jax_array",
23-
"is_numpy_array",
24-
"is_torch_array",
25-
"size",
26-
"to_device",
27-
]
1+
from ._helpers import * # noqa: F403

array_api_compat/common/_aliases.py

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -146,6 +146,9 @@ def zeros_like(
146146

147147
# The functions here return namedtuples (np.unique() returns a normal
148148
# tuple).
149+
150+
# Note that these named tuples aren't actually part of the standard namespace,
151+
# but I don't see any issue with exporting the names here regardless.
149152
class UniqueAllResult(NamedTuple):
150153
values: ndarray
151154
indices: ndarray
@@ -543,5 +546,13 @@ def isdtype(
543546
# This will allow things that aren't required by the spec, like
544547
# isdtype(np.float64, float) or isdtype(np.int64, 'l'). Should we be
545548
# more strict here to match the type annotation? Note that the
546-
# numpy.array_api implementation will be very strict.
549+
# array_api_strict implementation will be very strict.
547550
return dtype == kind
551+
552+
__all__ = ['arange', 'empty', 'empty_like', 'eye', 'full', 'full_like',
553+
'linspace', 'ones', 'ones_like', 'zeros', 'zeros_like',
554+
'UniqueAllResult', 'UniqueCountsResult', 'UniqueInverseResult',
555+
'unique_all', 'unique_counts', 'unique_inverse', 'unique_values',
556+
'astype', 'std', 'var', 'permute_dims', 'reshape', 'argsort',
557+
'sort', 'nonzero', 'sum', 'prod', 'ceil', 'floor', 'trunc',
558+
'matmul', 'matrix_transpose', 'tensordot', 'vecdot', 'isdtype']

array_api_compat/common/_helpers.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -307,3 +307,19 @@ def size(x):
307307
if None in x.shape:
308308
return None
309309
return math.prod(x.shape)
310+
311+
__all__ = [
312+
"array_namespace",
313+
"device",
314+
"get_namespace",
315+
"is_array_api_obj",
316+
"is_cupy_array",
317+
"is_dask_array",
318+
"is_jax_array",
319+
"is_numpy_array",
320+
"is_torch_array",
321+
"size",
322+
"to_device",
323+
]
324+
325+
_all_ignore = ['sys', 'math', 'inspect']

array_api_compat/common/_linalg.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111
else:
1212
from numpy.core.numeric import normalize_axis_tuple
1313

14-
from ._aliases import matrix_transpose, isdtype
14+
from ._aliases import matmul, matrix_transpose, tensordot, vecdot, isdtype
1515
from .._internal import get_xp
1616

1717
# These are in the main NumPy namespace but not in numpy.linalg
@@ -149,4 +149,10 @@ def trace(x: ndarray, /, xp, *, offset: int = 0, dtype=None, **kwargs) -> ndarra
149149
dtype = xp.float64
150150
elif x.dtype == xp.complex64:
151151
dtype = xp.complex128
152-
return xp.asarray(xp.trace(x, offset=offset, dtype=dtype, axis1=-2, axis2=-1, **kwargs))
152+
return xp.asarray(xp.trace(x, offset=offset, dtype=dtype, axis1=-2, axis2=-1, **kwargs))
153+
154+
__all__ = ['cross', 'matmul', 'outer', 'tensordot', 'EighResult',
155+
'QRResult', 'SlogdetResult', 'SVDResult', 'eigh', 'qr', 'slogdet',
156+
'svd', 'cholesky', 'matrix_rank', 'pinv', 'matrix_norm',
157+
'matrix_transpose', 'svdvals', 'vecdot', 'vector_norm', 'diagonal',
158+
'trace']

array_api_compat/common/_typing.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,4 +20,4 @@ def __len__(self, /) -> int: ...
2020
SupportsBufferProtocol = Any
2121

2222
Array = Any
23-
Device = Any
23+
Device = Any

array_api_compat/cupy/__init__.py

Lines changed: 7 additions & 146 deletions
Original file line numberDiff line numberDiff line change
@@ -1,153 +1,14 @@
1-
import cupy as _cp
2-
from cupy import * # noqa: F401, F403
1+
from cupy import * # noqa: F403
32

43
# from cupy import * doesn't overwrite these builtin names
5-
from cupy import abs, max, min, round
6-
7-
from .._internal import _get_all_public_members
8-
from ..common._helpers import (
9-
array_namespace,
10-
device,
11-
get_namespace,
12-
is_array_api_obj,
13-
size,
14-
to_device,
15-
)
4+
from cupy import abs, max, min, round # noqa: F401
165

176
# These imports may overwrite names from the import * above.
18-
from ._aliases import (
19-
UniqueAllResult,
20-
UniqueCountsResult,
21-
UniqueInverseResult,
22-
acos,
23-
acosh,
24-
arange,
25-
argsort,
26-
asarray,
27-
asarray_cupy,
28-
asin,
29-
asinh,
30-
astype,
31-
atan,
32-
atan2,
33-
atanh,
34-
bitwise_invert,
35-
bitwise_left_shift,
36-
bitwise_right_shift,
37-
bool,
38-
ceil,
39-
concat,
40-
empty,
41-
empty_like,
42-
eye,
43-
floor,
44-
full,
45-
full_like,
46-
isdtype,
47-
linspace,
48-
matmul,
49-
matrix_transpose,
50-
nonzero,
51-
ones,
52-
ones_like,
53-
permute_dims,
54-
pow,
55-
prod,
56-
reshape,
57-
sort,
58-
std,
59-
sum,
60-
tensordot,
61-
trunc,
62-
unique_all,
63-
unique_counts,
64-
unique_inverse,
65-
unique_values,
66-
var,
67-
vecdot,
68-
zeros,
69-
zeros_like,
70-
)
71-
72-
__all__ = []
73-
74-
__all__ += _get_all_public_members(_cp)
75-
76-
__all__ += [
77-
"abs",
78-
"max",
79-
"min",
80-
"round",
81-
]
82-
83-
__all__ += [
84-
"array_namespace",
85-
"device",
86-
"get_namespace",
87-
"is_array_api_obj",
88-
"size",
89-
"to_device",
90-
]
91-
92-
__all__ += [
93-
"UniqueAllResult",
94-
"UniqueCountsResult",
95-
"UniqueInverseResult",
96-
"acos",
97-
"acosh",
98-
"arange",
99-
"argsort",
100-
"asarray",
101-
"asarray_cupy",
102-
"asin",
103-
"asinh",
104-
"astype",
105-
"atan",
106-
"atan2",
107-
"atanh",
108-
"bitwise_invert",
109-
"bitwise_left_shift",
110-
"bitwise_right_shift",
111-
"bool",
112-
"ceil",
113-
"concat",
114-
"empty",
115-
"empty_like",
116-
"eye",
117-
"floor",
118-
"full",
119-
"full_like",
120-
"isdtype",
121-
"linspace",
122-
"matmul",
123-
"matrix_transpose",
124-
"nonzero",
125-
"ones",
126-
"ones_like",
127-
"permute_dims",
128-
"pow",
129-
"prod",
130-
"reshape",
131-
"sort",
132-
"std",
133-
"sum",
134-
"tensordot",
135-
"trunc",
136-
"unique_all",
137-
"unique_counts",
138-
"unique_inverse",
139-
"unique_values",
140-
"var",
141-
"zeros",
142-
"zeros_like",
143-
]
144-
145-
__all__ += [
146-
"matrix_transpose",
147-
"vecdot",
148-
]
7+
from ._aliases import * # noqa: F403
1498

1509
# See the comment in the numpy __init__.py
151-
__import__(__package__ + ".linalg")
10+
__import__(__package__ + '.linalg')
11+
12+
from ..common._helpers import * # noqa: F401,F403
15213

153-
__array_api_version__ = "2022.12"
14+
__array_api_version__ = '2022.12'

0 commit comments

Comments
 (0)