Skip to content

Commit b143d02

Browse files
authored
Merge pull request #14 from asmeurer/pytorch
WIP: PyTorch compatibility layer
2 parents 4af45d0 + 3470b36 commit b143d02

25 files changed

+1367
-127
lines changed
Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,9 @@
1+
name: Array API Tests (NumPy)
2+
3+
on: [push, pull_request]
4+
5+
jobs:
6+
array-api-tests-numpy:
7+
uses: ./.github/workflows/array-api-tests.yml
8+
with:
9+
package-name: numpy
Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,10 @@
1+
name: Array API Tests (PyTorch)
2+
3+
on: [push, pull_request]
4+
5+
jobs:
6+
array-api-tests-torch:
7+
uses: ./.github/workflows/array-api-tests.yml
8+
with:
9+
package-name: torch
10+
pytest-extra-args: "--disable-extension linalg"

.github/workflows/array-api-tests.yml

Lines changed: 15 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,18 @@
11
name: Array API Tests
22

3-
on: [push, pull_request]
3+
on:
4+
workflow_call:
5+
inputs:
6+
package-name:
7+
required: true
8+
type: string
9+
pytest-extra-args:
10+
required: false
11+
type: string
12+
413

514
env:
6-
PYTEST_ARGS: "--max-examples 200 -v -rxXfE --ci"
15+
PYTEST_ARGS: "--max-examples 200 -v -rxXfE --ci ${{ inputs.pytest-extra-args }}"
716

817
jobs:
918
tests:
@@ -34,15 +43,15 @@ jobs:
3443
- name: Install dependencies
3544
run: |
3645
python -m pip install --upgrade pip
37-
python -m pip install numpy
46+
python -m pip install ${{ inputs.package-name }}
3847
python -m pip install -r ${GITHUB_WORKSPACE}/array-api-tests/requirements.txt
39-
- name: Run the array API testsuite (NumPy)
48+
- name: Run the array API testsuite (${{ inputs.package-name }})
4049
env:
41-
ARRAY_API_TESTS_MODULE: array_api_compat.numpy
50+
ARRAY_API_TESTS_MODULE: array_api_compat.${{ inputs.package-name }}
4251
# This enables the NEP 50 type promotion behavior (without it a lot of
4352
# tests fail on bad scalar type promotion behavior)
4453
NPY_PROMOTION_STATE: weak
4554
run: |
4655
export PYTHONPATH="${GITHUB_WORKSPACE}/array-api-compat"
4756
cd ${GITHUB_WORKSPACE}/array-api-tests
48-
pytest ${PYTEST_ARGS} --xfails-file ${GITHUB_WORKSPACE}/array-api-compat/numpy-xfails.txt array_api_tests/
57+
pytest ${PYTEST_ARGS} --xfails-file ${GITHUB_WORKSPACE}/array-api-compat/${{ inputs.package-name }}-xfails.txt --skips-file ${GITHUB_WORKSPACE}/array-api-compat/${{ inputs.package-name }}-skips.txt array_api_tests/

.github/workflows/tests.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@ jobs:
1515
- name: Install Dependencies
1616
run: |
1717
python -m pip install --upgrade pip
18-
python -m pip install pytest numpy
18+
python -m pip install pytest numpy torch
1919
2020
- name: Run Tests
2121
run: |

CHANGELOG.md

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,24 @@
1+
# 1.1 (2023-02-24)
2+
3+
## Major Changes
4+
5+
- Added support for PyTorch.
6+
7+
- Add helper function `size()` (required if torch is used as
8+
`torch.Tensor.size` is a method that is incompatible with the array API
9+
[`.size`](https://data-apis.org/array-api/latest/API_specification/generated/array_api.array.size.html#array_api.array.size)).
10+
11+
- All wrapper functions that wrap existing library functions now pass through
12+
arbitrary `**kwargs`.
13+
14+
## Minor Changes
15+
16+
- Added CI to run against the [array API testsuite](https://github.com/data-apis/array-api-tests).
17+
18+
- Fix `sort(stable=False)` and `argsort(stable=False)` with CuPy.
19+
20+
# 1.0 (2022-12-05)
21+
22+
## Major Changes
23+
24+
- Initial release. Includes support for NumPy and CuPy.

0 commit comments

Comments
 (0)