Skip to content

Commit 83fef36

Browse files
authored
Implement python array API Inspection namespace (#2275)
The PR proposes to implement `dpnp.__array_namespace_info__`, what is a python array API Inspection namespace. It is required to achieve the compliance with python array API. The implementation leverages on appropriate namespace exposed by dpctl.tensor. In addition, this PR makes an `__array_api_version__` attribute available in dpnp. It also borrowed from dpctl. The PR adds a dedication documentation page describing `Array API standard compatibility`, including reference on new `dpnp.__array_namespace_info__`.
1 parent b72f953 commit 83fef36

File tree

8 files changed

+249
-21
lines changed

8 files changed

+249
-21
lines changed

.github/workflows/array-api-skips.txt

Lines changed: 0 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -17,18 +17,6 @@ array_api_tests/test_signatures.py::test_func_signature[unique_counts]
1717
array_api_tests/test_signatures.py::test_func_signature[unique_inverse]
1818
array_api_tests/test_signatures.py::test_func_signature[unique_values]
1919

20-
# no '__array_namespace_info__' function
21-
array_api_tests/test_has_names.py::test_has_names[info-__array_namespace_info__]
22-
array_api_tests/test_inspection_functions.py::test_array_namespace_info
23-
array_api_tests/test_inspection_functions.py::test_array_namespace_info_dtypes
24-
array_api_tests/test_searching_functions.py::test_searchsorted
25-
array_api_tests/test_signatures.py::test_func_signature[__array_namespace_info__]
26-
array_api_tests/test_signatures.py::test_info_func_signature[capabilities]
27-
array_api_tests/test_signatures.py::test_info_func_signature[default_device]
28-
array_api_tests/test_signatures.py::test_info_func_signature[default_dtypes]
29-
array_api_tests/test_signatures.py::test_info_func_signature[devices]
30-
array_api_tests/test_signatures.py::test_info_func_signature[dtypes]
31-
3220
# do not return a namedtuple
3321
array_api_tests/test_linalg.py::test_eigh
3422
array_api_tests/test_linalg.py::test_slogdet

doc/reference/array_api.rst

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,26 @@
1+
.. _array-api-standard-compatibility:
2+
3+
.. https://numpy.org/doc/stable/reference/array_api.html
4+
5+
********************************
6+
Array API standard compatibility
7+
********************************
8+
9+
DPNP's main namespace as well as the :mod:`dpnp.fft` and :mod:`dpnp.linalg`
10+
namespaces are compatible with the
11+
`2023.12 version <https://data-apis.org/array-api/2023.12/index.html>`__
12+
of the Python array API standard.
13+
14+
Inspection
15+
==========
16+
17+
DPNP implements the `array API inspection utilities
18+
<https://data-apis.org/array-api/latest/API_specification/inspection.html>`__.
19+
These functions can be accessed via the ``__array_namespace_info__()``
20+
function, which returns a namespace containing the inspection utilities.
21+
22+
.. autosummary::
23+
:toctree: generated/
24+
:nosignatures:
25+
26+
dpnp.__array_namespace_info__

doc/reference/fft.rst

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,7 @@
1+
.. _routines.fft:
2+
3+
.. py:module:: dpnp.fft
4+
15
Discrete Fourier Transform
26
==========================
37

doc/reference/index.rst

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,3 +33,4 @@ API reference of the Data Parallel Extension for NumPy*
3333
dtypes_table
3434
comparison
3535
misc
36+
array_api

doc/reference/linalg.rst

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,7 @@
1+
.. _routines.linalg:
2+
3+
.. py:module:: dpnp.linalg
4+
15
Linear algebra
26
==============
37

dpnp/__init__.py

Lines changed: 10 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -63,16 +63,17 @@
6363
)
6464

6565
# Borrowed from DPCTL
66-
from dpctl.tensor import DLDeviceType
66+
from dpctl.tensor import __array_api_version__, DLDeviceType
6767

68-
from dpnp.dpnp_array import dpnp_array as ndarray
69-
from dpnp.dpnp_flatiter import flatiter as flatiter
70-
from dpnp.dpnp_iface_types import *
71-
from dpnp.dpnp_iface import *
72-
from dpnp.dpnp_iface import __all__ as _iface__all__
73-
from dpnp.dpnp_iface_utils import *
74-
from dpnp.dpnp_iface_utils import __all__ as _ifaceutils__all__
75-
from dpnp._version import get_versions
68+
from .dpnp_array import dpnp_array as ndarray
69+
from .dpnp_array_api_info import __array_namespace_info__
70+
from .dpnp_flatiter import flatiter as flatiter
71+
from .dpnp_iface_types import *
72+
from .dpnp_iface import *
73+
from .dpnp_iface import __all__ as _iface__all__
74+
from .dpnp_iface_utils import *
75+
from .dpnp_iface_utils import __all__ as _ifaceutils__all__
76+
from ._version import get_versions
7677

7778
__all__ = _iface__all__
7879
__all__ += _ifaceutils__all__

dpnp/dpnp_array_api_info.py

Lines changed: 71 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,71 @@
1+
# -*- coding: utf-8 -*-
2+
# *****************************************************************************
3+
# Copyright (c) 2025, Intel Corporation
4+
# All rights reserved.
5+
#
6+
# Redistribution and use in source and binary forms, with or without
7+
# modification, are permitted provided that the following conditions are met:
8+
# - Redistributions of source code must retain the above copyright notice,
9+
# this list of conditions and the following disclaimer.
10+
# - Redistributions in binary form must reproduce the above copyright notice,
11+
# this list of conditions and the following disclaimer in the documentation
12+
# and/or other materials provided with the distribution.
13+
#
14+
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
15+
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
16+
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
17+
# ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
18+
# LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
19+
# CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
20+
# SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
21+
# INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
22+
# CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
23+
# ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF
24+
# THE POSSIBILITY OF SUCH DAMAGE.
25+
# *****************************************************************************
26+
27+
"""
28+
Array API Inspection namespace
29+
30+
This is the namespace for inspection functions as defined by the array API
31+
standard. See
32+
https://data-apis.org/array-api/latest/API_specification/inspection.html for
33+
more details.
34+
35+
"""
36+
37+
import dpctl.tensor as dpt
38+
39+
__all__ = ["__array_namespace_info__"]
40+
41+
42+
def __array_namespace_info__():
43+
"""
44+
Returns a namespace with Array API namespace inspection utilities.
45+
46+
The array API inspection namespace defines the following functions:
47+
48+
- capabilities()
49+
- default_device()
50+
- default_dtypes()
51+
- dtypes()
52+
- devices()
53+
54+
Returns
55+
-------
56+
info : ModuleType
57+
The array API inspection namespace for DPNP.
58+
59+
Examples
60+
--------
61+
>>> import dpnp as np
62+
>>> info = np.__array_namespace_info__()
63+
>>> info.default_dtypes() # may vary and depends on default device
64+
{'real floating': dtype('float64'),
65+
'complex floating': dtype('complex128'),
66+
'integral': dtype('int64'),
67+
'indexing': dtype('int64')}
68+
69+
"""
70+
71+
return dpt.__array_namespace_info__()

dpnp/tests/test_array_api_info.py

Lines changed: 133 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,133 @@
1+
import numpy
2+
import pytest
3+
from dpctl import SyclDeviceCreationError, get_devices, select_default_device
4+
from dpctl.tensor._tensor_impl import default_device_complex_type
5+
6+
import dpnp
7+
from dpnp.tests.helper import (
8+
has_support_aspect64,
9+
is_win_platform,
10+
numpy_version,
11+
)
12+
13+
info = dpnp.__array_namespace_info__()
14+
default_device = select_default_device()
15+
16+
17+
def test_capabilities():
18+
caps = info.capabilities()
19+
assert caps["boolean indexing"] is True
20+
assert caps["data-dependent shapes"] is True
21+
assert caps["max dimensions"] == 64
22+
23+
24+
def test_default_device():
25+
assert info.default_device() == default_device
26+
27+
28+
def test_default_dtypes():
29+
dtypes = info.default_dtypes()
30+
assert (
31+
dtypes["real floating"]
32+
== dpnp.default_float_type()
33+
== dpnp.asarray(0.0).dtype
34+
)
35+
# TODO: add dpnp.default_complex_type() function
36+
assert (
37+
dtypes["complex floating"]
38+
== default_device_complex_type(default_device)
39+
== dpnp.asarray(0.0j).dtype
40+
)
41+
if not is_win_platform() or numpy_version() >= "2.0.0":
42+
# numpy changed default integer on Windows since 2.0
43+
assert dtypes["integral"] == dpnp.intp == dpnp.asarray(0).dtype
44+
assert (
45+
dtypes["indexing"] == dpnp.intp == dpnp.argmax(dpnp.zeros(10)).dtype
46+
)
47+
48+
with pytest.raises(
49+
TypeError, match="Unsupported type for device argument:"
50+
):
51+
info.default_dtypes(device=1)
52+
53+
54+
def test_dtypes_all():
55+
dtypes = info.dtypes()
56+
assert dtypes == (
57+
{
58+
"bool": dpnp.bool_,
59+
"int8": numpy.int8, # TODO: replace with dpnp.int8
60+
"int16": numpy.int16, # TODO: replace with dpnp.int16
61+
"int32": dpnp.int32,
62+
"int64": dpnp.int64,
63+
"uint8": numpy.uint8, # TODO: replace with dpnp.uint8
64+
"uint16": numpy.uint16, # TODO: replace with dpnp.uint16
65+
"uint32": numpy.uint32, # TODO: replace with dpnp.uint32
66+
"uint64": numpy.uint64, # TODO: replace with dpnp.uint64
67+
"float32": dpnp.float32,
68+
}
69+
| ({"float64": dpnp.float64} if has_support_aspect64() else {})
70+
| {"complex64": dpnp.complex64}
71+
| ({"complex128": dpnp.complex128} if has_support_aspect64() else {})
72+
)
73+
74+
75+
dtype_categories = {
76+
"bool": {"bool": dpnp.bool_},
77+
"signed integer": {
78+
"int8": numpy.int8, # TODO: replace with dpnp.int8
79+
"int16": numpy.int16, # TODO: replace with dpnp.int16
80+
"int32": dpnp.int32,
81+
"int64": dpnp.int64,
82+
},
83+
"unsigned integer": { # TODO: replace with dpnp dtypes once available
84+
"uint8": numpy.uint8,
85+
"uint16": numpy.uint16,
86+
"uint32": numpy.uint32,
87+
"uint64": numpy.uint64,
88+
},
89+
"integral": ("signed integer", "unsigned integer"),
90+
"real floating": {"float32": dpnp.float32}
91+
| ({"float64": dpnp.float64} if has_support_aspect64() else {}),
92+
"complex floating": {"complex64": dpnp.complex64}
93+
| ({"complex128": dpnp.complex128} if has_support_aspect64() else {}),
94+
"numeric": ("integral", "real floating", "complex floating"),
95+
}
96+
97+
98+
@pytest.mark.parametrize("kind", dtype_categories)
99+
def test_dtypes_kind(kind):
100+
expected = dtype_categories[kind]
101+
if isinstance(expected, tuple):
102+
assert info.dtypes(kind=kind) == info.dtypes(kind=expected)
103+
else:
104+
assert info.dtypes(kind=kind) == expected
105+
106+
107+
def test_dtypes_tuple():
108+
dtypes = info.dtypes(kind=("bool", "integral"))
109+
assert dtypes == {
110+
"bool": dpnp.bool_,
111+
"int8": numpy.int8, # TODO: replace with dpnp.int8
112+
"int16": numpy.int16, # TODO: replace with dpnp.int16
113+
"int32": dpnp.int32,
114+
"int64": dpnp.int64,
115+
"uint8": numpy.uint8, # TODO: replace with dpnp.uint8
116+
"uint16": numpy.uint16, # TODO: replace with dpnp.uint16
117+
"uint32": numpy.uint32, # TODO: replace with dpnp.uint32
118+
"uint64": numpy.uint64, # TODO: replace with dpnp.uint64
119+
}
120+
121+
122+
def test_dtypes_invalid_kind():
123+
with pytest.raises(ValueError, match="Unrecognized data type kind"):
124+
info.dtypes(kind="invalid")
125+
126+
127+
def test_dtypes_invalid_device():
128+
with pytest.raises(SyclDeviceCreationError, match="Could not create"):
129+
info.dtypes(device="str")
130+
131+
132+
def test_devices():
133+
assert info.devices() == get_devices()

0 commit comments

Comments
 (0)