Commit ae97845

Merge branch 'master' into tensor-asarray-support-for-usm-ndarray-protocol
2 parents 21d6cb6 + 6cc2348 commit ae97845

30 files changed, +1590 -72 lines changed

.github/workflows/array-api-skips.txt

Lines changed: 68 additions & 0 deletions
@@ -0,0 +1,68 @@
+# array API tests to be skipped
+
+# no 'uint8' dtype
+array_api_tests/test_array_object.py::test_getitem_masking
+
+# no 'isdtype' function
+array_api_tests/test_data_type_functions.py::test_isdtype
+array_api_tests/test_has_names.py::test_has_names[data_type-isdtype]
+array_api_tests/test_signatures.py::test_func_signature[isdtype]
+
+# missing unique-like functions
+array_api_tests/test_has_names.py::test_has_names[set-unique_all]
+array_api_tests/test_has_names.py::test_has_names[set-unique_counts]
+array_api_tests/test_has_names.py::test_has_names[set-unique_inverse]
+array_api_tests/test_has_names.py::test_has_names[set-unique_values]
+array_api_tests/test_set_functions.py::test_unique_all
+array_api_tests/test_set_functions.py::test_unique_counts
+array_api_tests/test_set_functions.py::test_unique_inverse
+array_api_tests/test_set_functions.py::test_unique_values
+array_api_tests/test_signatures.py::test_func_signature[unique_all]
+array_api_tests/test_signatures.py::test_func_signature[unique_counts]
+array_api_tests/test_signatures.py::test_func_signature[unique_inverse]
+array_api_tests/test_signatures.py::test_func_signature[unique_values]
+
+# no '__array_namespace_info__' function
+array_api_tests/test_has_names.py::test_has_names[info-__array_namespace_info__]
+array_api_tests/test_inspection_functions.py::test_array_namespace_info
+array_api_tests/test_inspection_functions.py::test_array_namespace_info_dtypes
+array_api_tests/test_searching_functions.py::test_searchsorted
+array_api_tests/test_signatures.py::test_func_signature[__array_namespace_info__]
+array_api_tests/test_signatures.py::test_info_func_signature[capabilities]
+array_api_tests/test_signatures.py::test_info_func_signature[default_device]
+array_api_tests/test_signatures.py::test_info_func_signature[default_dtypes]
+array_api_tests/test_signatures.py::test_info_func_signature[devices]
+array_api_tests/test_signatures.py::test_info_func_signature[dtypes]
+
+# do not return a namedtuple
+array_api_tests/test_linalg.py::test_eigh
+array_api_tests/test_linalg.py::test_slogdet
+array_api_tests/test_linalg.py::test_svd
+
+# hypothesis found failures
+array_api_tests/test_linalg.py::test_qr
+array_api_tests/test_operators_and_elementwise_functions.py::test_clip
+
+# unexpected result is returned
+array_api_tests/test_operators_and_elementwise_functions.py::test_asin
+array_api_tests/test_operators_and_elementwise_functions.py::test_asinh
+
+# missing 'descending' keyword argument
+array_api_tests/test_signatures.py::test_func_signature[argsort]
+array_api_tests/test_signatures.py::test_func_signature[sort]
+array_api_tests/test_sorting_functions.py::test_argsort
+array_api_tests/test_sorting_functions.py::test_sort
+
+# missing 'correction' keyword argument
+array_api_tests/test_signatures.py::test_func_signature[std]
+array_api_tests/test_signatures.py::test_func_signature[var]
+
+# wrong shape is returned
+array_api_tests/test_linalg.py::test_vecdot
+array_api_tests/test_linalg.py::test_linalg_vecdot
+
+# tuple index out of range
+array_api_tests/test_linalg.py::test_linalg_matmul
+
+# arrays have different values
+array_api_tests/test_linalg.py::test_linalg_tensordot
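
Each non-comment line above is a fully qualified pytest test ID. Purely as an illustration of the file format (this helper is hypothetical, not how array-api-tests itself consumes the file), such a list can be read into a set of IDs like this:

from pathlib import Path

def read_skips(path):
    # Collect pytest test IDs, ignoring blank lines and '#' comments.
    ids = set()
    for line in Path(path).read_text().splitlines():
        line = line.strip()
        if line and not line.startswith("#"):
            ids.add(line)
    return ids

# Hypothetical usage with the file added by this commit.
skips = read_skips(".github/workflows/array-api-skips.txt")
print(len(skips), "tests to skip")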

.github/workflows/conda-package.yml

Lines changed: 186 additions & 2 deletions
@@ -210,7 +210,7 @@ jobs:
 
       - name: Smoke test
         run: |
-          python -c "import dpnp, dpctl; dpctl.lsplatform()"
+          python -c "import dpctl; dpctl.lsplatform()"
           python -c "import dpnp; print(dpnp.__version__)"
 
       - name: Run tests
@@ -350,7 +350,7 @@ jobs:
 
       - name: Smoke test
         run: |
-          python -c "import dpnp, dpctl; dpctl.lsplatform()"
+          python -c "import dpctl; dpctl.lsplatform()"
           python -c "import dpnp; print(dpnp.__version__)"
 
       - name: Run tests
@@ -432,6 +432,190 @@ jobs:
         env:
           ANACONDA_TOKEN: ${{ secrets.ANACONDA_TOKEN }}
 
+  array-api-conformity:
+    name: Array API conformity
+
+    needs: build
+
+    permissions:
+      # Needed to add a comment to a pull request's issue
+      pull-requests: write
+
+    strategy:
+      matrix:
+        python: ['3.12']
+        os: [ubuntu-22.04]
+
+    runs-on: ${{ matrix.os }}
+
+    defaults:
+      run:
+        shell: bash -el {0}
+
+    continue-on-error: true
+
+    env:
+      array-api-tests-path: '${{ github.workspace }}/array-api-tests/'
+      json-report-file: '${{ github.workspace }}/.report.json'
+      dpnp-repo-path: '${{ github.workspace }}/dpnp/'
+      array-api-skips-file: '${{ github.workspace }}/dpnp/.github/workflows/array-api-skips.txt'
+      channel-path: '${{ github.workspace }}/channel/'
+      pkg-path-in-channel: '${{ github.workspace }}/channel/linux-64/'
+      extracted-pkg-path: '${{ github.workspace }}/pkg/'
+      ver-json-path: '${{ github.workspace }}/version.json'
+
+    steps:
+      - name: Download artifact
+        uses: actions/download-artifact@fa0a91b85d4f404e444e00e005971372dc801d16 # v4.1.8
+        with:
+          name: ${{ env.PACKAGE_NAME }} ${{ runner.os }} Python ${{ matrix.python }}
+          path: ${{ env.pkg-path-in-channel }}
+
+      - name: Extract package archive
+        run: |
+          mkdir -p ${{ env.extracted-pkg-path }}
+          tar -xvf ${{ env.pkg-path-in-channel }}/${{ env.PACKAGE_NAME }}-*.tar.bz2 -C ${{ env.extracted-pkg-path }}
+
+      - name: Setup miniconda
+        id: setup_miniconda
+        continue-on-error: true
+        uses: conda-incubator/setup-miniconda@d2e6a045a86077fb6cad6f5adf368e9076ddaa8d # v3.1.0
+        with:
+          miniforge-version: latest
+          use-mamba: 'true'
+          channels: conda-forge
+          conda-remove-defaults: 'true'
+          python-version: ${{ matrix.python }}
+          activate-environment: 'array-api-conformity'
+
+      - name: ReSetup miniconda
+        if: steps.setup_miniconda.outcome == 'failure'
+        uses: conda-incubator/setup-miniconda@d2e6a045a86077fb6cad6f5adf368e9076ddaa8d # v3.1.0
+        with:
+          miniforge-version: latest
+          use-mamba: 'true'
+          channels: conda-forge
+          conda-remove-defaults: 'true'
+          python-version: ${{ matrix.python }}
+          activate-environment: 'array-api-conformity'
+
+      - name: Install conda-index
+        id: install_conda_index
+        continue-on-error: true
+        run: mamba install conda-index=${{ env.CONDA_INDEX_VERSION }}
+
+      - name: ReInstall conda-index
+        if: steps.install_conda_index.outcome == 'failure'
+        run: mamba install conda-index=${{ env.CONDA_INDEX_VERSION }}
+
+      - name: Create conda channel
+        run: |
+          python -m conda_index ${{ env.channel-path }}
+
+      - name: Test conda channel
+        run: |
+          conda search ${{ env.PACKAGE_NAME }} -c ${{ env.channel-path }} --override-channels --info --json > ${{ env.ver-json-path }}
+          cat ${{ env.ver-json-path }}
+
+      - name: Get package version
+        run: |
+          export PACKAGE_VERSION=$(python -c "${{ env.VER_SCRIPT1 }} ${{ env.VER_SCRIPT2 }}")
+
+          echo PACKAGE_VERSION=${PACKAGE_VERSION}
+          echo "PACKAGE_VERSION=$PACKAGE_VERSION" >> $GITHUB_ENV
+
+      - name: Install dpnp
+        id: install_dpnp
+        continue-on-error: true
+        run: |
+          mamba install ${{ env.PACKAGE_NAME }}=${{ env.PACKAGE_VERSION }} pytest python=${{ matrix.python }} ${{ env.TEST_CHANNELS }}
+        env:
+          TEST_CHANNELS: '-c ${{ env.channel-path }} ${{ env.CHANNELS }}'
+
+      - name: ReInstall dpnp
+        if: steps.install_dpnp.outcome == 'failure'
+        run: |
+          mamba install ${{ env.PACKAGE_NAME }}=${{ env.PACKAGE_VERSION }} pytest python=${{ matrix.python }} ${{ env.TEST_CHANNELS }}
+        env:
+          TEST_CHANNELS: '-c ${{ env.channel-path }} ${{ env.CHANNELS }}'
+
+      - name: List installed packages
+        run: mamba list
+
+      - name: Smoke test
+        run: |
+          python -c "import dpctl; dpctl.lsplatform()"
+          python -c "import dpnp; print(dpnp.__version__)"
+
+      - name: Clone array API tests repo
+        uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
+        with:
+          repository: 'data-apis/array-api-tests'
+          path: ${{ env.array-api-tests-path }}
+          fetch-depth: 0
+          submodules: 'recursive'
+
+      - name: Install array API test dependencies
+        run: |
+          pip install -r requirements.txt
+        working-directory: ${{ env.array-api-tests-path }}
+
+      - name: Install jq
+        run: |
+          sudo apt-get install jq
+
+      - name: List installed packages
+        run: mamba list
+
+      - name: Smoke test
+        run: |
+          python -c "import dpctl; dpctl.lsplatform()"
+          python -c "import dpnp; print(dpnp.__version__)"
+
+      # need to fetch array-api-skips.txt
+      - name: Checkout DPNP repo
+        uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
+        with:
+          path: ${{ env.dpnp-repo-path }}
+
+      - name: Run array API conformance tests
+        run: |
+          python -m pytest --json-report --json-report-file=${{ env.json-report-file }} --disable-deadline --skips-file ${{ env.array-api-skips-file }} array_api_tests || true
+        env:
+          ARRAY_API_TESTS_MODULE: 'dpnp'
+          SYCL_CACHE_PERSISTENT: 1
+        working-directory: ${{ env.array-api-tests-path }}
+
+      - name: Set Github environment variables
+        run: |
+          FILE=${{ env.json-report-file }}
+          if test -f "$FILE"; then
+            PASSED_TESTS=$(jq '.summary | .passed // 0' $FILE)
+            FAILED_TESTS=$(jq '.summary | .failed // 0' $FILE)
+            SKIPPED_TESTS=$(jq '.summary | .skipped // 0' $FILE)
+            MESSAGE="Array API standard conformance tests for dpnp=$PACKAGE_VERSION ran successfully.
+            Passed: $PASSED_TESTS
+            Failed: $FAILED_TESTS
+            Skipped: $SKIPPED_TESTS"
+            echo "MESSAGE<<EOF" >> $GITHUB_ENV
+            echo "$MESSAGE" >> $GITHUB_ENV
+            echo "EOF" >> $GITHUB_ENV
+          else
+            echo "Array API standard conformance tests failed to run for dpnp=$PACKAGE_VERSION."
+            exit 1
+          fi
+
+      - name: Output API summary
+        run: echo "::notice ${{ env.MESSAGE }}"
+
+      - name: Post result to PR
+        if: ${{ github.event.pull_request && !github.event.pull_request.head.repo.fork }}
+        uses: mshick/add-pr-comment@b8f338c590a895d50bcbfa6c5859251edc8952fc # v2.8.2
+        with:
+          message: |
+            ${{ env.MESSAGE }}
+          refresh-message-position: true
+
   cleanup_packages:
     name: Clean up anaconda packages
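
The "Set Github environment variables" step above pulls the pass/fail/skip counts out of the pytest-json-report file with jq. As a rough sketch only (not part of this change), the same summary could be read in Python; the '.report.json' path and the defaulting to 0 mirror the shell step:

import json

# json-report-file used by the workflow above; adjust the path as needed.
with open(".report.json") as f:
    summary = json.load(f).get("summary", {})

# jq's '.summary | .passed // 0' falls back to 0 for missing keys; dict.get does the same.
passed = summary.get("passed", 0)
failed = summary.get("failed", 0)
skipped = summary.get("skipped", 0)

print(f"Passed: {passed}\nFailed: {failed}\nSkipped: {skipped}")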

dpnp/dpnp_array.py

Lines changed: 42 additions & 7 deletions
@@ -153,13 +153,6 @@ def mT(self):
 
         return dpnp_array._create_from_usm_ndarray(self._array_obj.mT)
 
-    def to_device(self, target_device):
-        """Transfer array to target device."""
-
-        return dpnp_array(
-            shape=self.shape, buffer=self.get_array().to_device(target_device)
-        )
-
     @property
     def sycl_queue(self):
         return self._array_obj.sycl_queue
@@ -1712,6 +1705,48 @@ def take(self, indices, axis=None, out=None, mode="wrap"):
 
         return dpnp.take(self, indices, axis=axis, out=out, mode=mode)
 
+    def to_device(self, device, /, *, stream=None):
+        """
+        Transfers this array to specified target device.
+
+        Parameters
+        ----------
+        device : {string, SyclDevice, SyclQueue}
+            Array API concept of target device. It can be an OneAPI filter
+            selector string, an instance of :class:`dpctl.SyclDevice`
+            corresponding to a non-partitioned SYCL device, an instance of
+            :class:`dpctl.SyclQueue`, or a :class:`dpctl.tensor.Device` object
+            returned by :obj:`dpnp.dpnp_array.dpnp_array.device` property.
+        stream : {SyclQueue, None}, optional
+            Execution queue to synchronize with. If ``None``, synchronization
+            is not performed.
+            Default: ``None``.
+
+        Returns
+        -------
+        out : dpnp.ndarray
+            A view if data copy is not required, and a copy otherwise.
+            If copying is required, it is done by copying from the original
+            allocation device to the host, followed by copying from host
+            to the target device.
+
+        Examples
+        --------
+        >>> import dpnp as np, dpctl
+        >>> x = np.full(100, 2, dtype=np.int64)
+        >>> q_prof = dpctl.SyclQueue(x.sycl_device, property="enable_profiling")
+        >>> # return a view with profile-enabled queue
+        >>> y = x.to_device(q_prof)
+        >>> timer = dpctl.SyclTimer()
+        >>> with timer(q_prof):
+        ...     z = y * y
+        >>> print(timer.dt)
+
+        """
+
+        usm_res = self._array_obj.to_device(device, stream=stream)
+        return dpnp_array._create_from_usm_ndarray(usm_res)
+
     # 'tobytes',
     # 'tofile',
     # 'tolist',
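
As a quick usage sketch for the reworked method (not part of the diff; the "gpu" filter string is an assumption and requires a GPU device on the machine):

import dpnp

x = dpnp.arange(5)

# Transfer to a device picked by a SYCL filter selector string;
# any valid selector ("cpu", "gpu", "level_zero:gpu:0", ...) works.
y = x.to_device("gpu")
print(y.device)

# A same-device transfer needs no copy, so a view of the original data is expected.
z = x.to_device(x.sycl_device)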

dpnp/dpnp_iface.py

Lines changed: 36 additions & 0 deletions
@@ -68,6 +68,7 @@
     "get_result_array",
     "get_usm_ndarray",
     "get_usm_ndarray_or_scalar",
+    "is_cuda_backend",
     "is_supported_array_or_scalar",
     "is_supported_array_type",
     "synchronize_array_data",
@@ -681,6 +682,41 @@ def get_usm_ndarray_or_scalar(a):
     return a if dpnp.isscalar(a) else get_usm_ndarray(a)
 
 
+def is_cuda_backend(obj=None):
+    """
+    Checks that object has a CUDA backend.
+
+    Parameters
+    ----------
+    obj : {Device, SyclDevice, SyclQueue, dpnp.ndarray, usm_ndarray, None},
+          optional
+        An input object with sycl_device property to check device backend.
+        If `obj` is ``None``, device backend will be checked for the default
+        queue.
+        Default: ``None``.
+
+    Returns
+    -------
+    out : bool
+        Return ``True`` if data of the input object resides on a CUDA backend,
+        otherwise ``False``.
+
+    """
+
+    if obj is None:
+        sycl_device = dpctl.select_default_device()
+    elif isinstance(obj, dpctl.SyclDevice):
+        sycl_device = obj
+    else:
+        sycl_device = getattr(obj, "sycl_device", None)
+    if (
+        sycl_device is not None
+        and sycl_device.backend == dpctl.backend_type.cuda
+    ):
+        return True
+    return False
+
+
 def is_supported_array_or_scalar(a):
     """
     Return ``True`` if `a` is a scalar or an array of either
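
For illustration (a sketch, not part of the diff, and assuming the helper is re-exported on the top-level dpnp namespace through this module's __all__), the new check accepts an array, a dpctl object, or no argument at all:

import dpnp

a = dpnp.ones(3)

# Backend of the array's allocation device.
print(dpnp.is_cuda_backend(a))

# Backend of the default-selected device when no argument is given.
print(dpnp.is_cuda_backend())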
