Skip to content

PyTorch compatibility layer #14

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 95 commits into from
Feb 25, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
95 commits
Select commit Hold shift + click to select a range
f49f42b
Start pytorch compatibility layer
asmeurer Jan 7, 2023
4b95748
Add vendor tests for torch
asmeurer Jan 7, 2023
1ecb7ca
Replace torch expand_dims wrapper to a wrapper around unsqueeze
asmeurer Jan 9, 2023
52b6054
Add torch support to the helper functions
asmeurer Jan 9, 2023
c484dbf
Add max and min wrappers for torch
asmeurer Jan 9, 2023
3023a89
Add a wrapper for torch.prod
asmeurer Jan 9, 2023
a1bbd9b
Add the torch prod wrapper to __all__
asmeurer Jan 9, 2023
dba7fa7
Return a copy from max and min with axis=()
asmeurer Jan 10, 2023
44d91e1
Add a size() helper function
asmeurer Jan 10, 2023
1faea7b
Add any and all torch wrappers and fix some issues with prod
asmeurer Jan 10, 2023
ad4484d
Add astype torch wrapper
asmeurer Jan 11, 2023
db3241d
Cast the input to prod/all/any to tensor
asmeurer Jan 11, 2023
3f0d913
More logical order for some functions
asmeurer Jan 11, 2023
2d25dd2
Add wrappers for two-argument elementwise functions
asmeurer Jan 24, 2023
c4c0cfa
Add bitwise_invert to torch
asmeurer Jan 24, 2023
a1917f8
Add torch wrappers for broadcast_to and can_cast
asmeurer Jan 24, 2023
ae28ce0
Add torch arange wrapper
asmeurer Jan 24, 2023
c3d9334
Add a wrapper for torch.eye
asmeurer Jan 24, 2023
f4d6df1
Add pytorch linspace wrapper
asmeurer Jan 24, 2023
14b6519
Add torch squeeze wrapper
asmeurer Jan 24, 2023
db3f579
Add torch flip and roll wrappers
asmeurer Jan 24, 2023
27b7e8c
Add a torch wrapper for nonzero
asmeurer Jan 24, 2023
e886644
Fix pyflakes warning
asmeurer Jan 24, 2023
fd5d179
Add torch wrapper for where
asmeurer Jan 24, 2023
a5f3253
Add sort wrapper to torch
asmeurer Feb 2, 2023
4a71e63
Pass kwargs through some torch wrappers
asmeurer Feb 2, 2023
157fc1e
Add torch mean(), std(), and var() wrappers
asmeurer Feb 2, 2023
c3efe6a
Add torch sum() and prod() wrappers
asmeurer Feb 3, 2023
ed46247
Add unique_* wrappers to torch
asmeurer Feb 4, 2023
eaf5358
Just raise NotImplementedError in pytorch unique_all()
asmeurer Feb 4, 2023
cd25f47
Fix to_device for pytorch tensors
asmeurer Feb 10, 2023
98ed0b2
Restrict the names imported from torch into the compat submodule
asmeurer Feb 10, 2023
be4b534
Allow torch sum and prod to upcast uint8 to int64
asmeurer Feb 10, 2023
ecda017
Don't unnecessarily flip the axes in flip()
asmeurer Feb 10, 2023
3ccee1b
Comment out dead code in the torch unique_all() wrapper
asmeurer Feb 10, 2023
7699755
Use flatten instead of ravel
asmeurer Feb 10, 2023
d38bad5
Improve some error messages
asmeurer Feb 10, 2023
24c0ea3
Use a better function name and use unsqueeze instead of None indexing
asmeurer Feb 13, 2023
48d1ae1
Add pytorch-xfails.txt (still need to validate)
asmeurer Feb 18, 2023
866647d
Move main namespace linear algebra helpers to _aliases.py
asmeurer Feb 18, 2023
85b71de
Merge branch 'master' into pytorch
asmeurer Feb 18, 2023
c441f33
Fix main namespace linalg functions in numpy and cupy
asmeurer Feb 18, 2023
04eef18
Add main namespace linalg functions to the torch wrapper
asmeurer Feb 18, 2023
453ecb8
Add torch wrapper for matmul
asmeurer Feb 18, 2023
5a3bbbe
Finish torch wrappers for matmul, vecdot, and tensordot
asmeurer Feb 20, 2023
b8bbdc8
Clean up pytorch-xfails file
asmeurer Feb 21, 2023
7d176b9
Update pytorch-xfails.txt
asmeurer Feb 21, 2023
17c0a91
Merge branch 'main' into pytorch
asmeurer Feb 21, 2023
1ffcb15
Install pytorch in on CI
asmeurer Feb 21, 2023
0db034d
Make the GitHub Actions workflow reusable so that we can test pytorch
asmeurer Feb 21, 2023
39cbfd4
Fix workflow path
asmeurer Feb 21, 2023
c3d0d8e
Fix variable interpolation syntax
asmeurer Feb 21, 2023
cf21cea
Allow specifying extra pytest args in the test yamls
asmeurer Feb 21, 2023
a95eeb6
Enable verbose output for the torch tests
asmeurer Feb 21, 2023
4737dc0
Revert "Enable verbose output for the torch tests"
asmeurer Feb 21, 2023
ea9c1e2
Skip the torch test that crashes the CI
asmeurer Feb 21, 2023
4904411
Skip another test that crashes on CI
asmeurer Feb 22, 2023
039af59
Disable linalg in the torch CI tests
asmeurer Feb 22, 2023
367c4b6
Add missing torch xfails
asmeurer Feb 22, 2023
f3ee38c
Do a verbose CI run for the pytorch array API tests
asmeurer Feb 22, 2023
c631fa3
Revert "Do a verbose CI run for the pytorch array API tests"
asmeurer Feb 22, 2023
33dacf9
Add some missing torch xfails
asmeurer Feb 22, 2023
82b8def
Do a verbose output run of the torch array API tests (with the correc…
asmeurer Feb 22, 2023
b014c1b
Revert "Do a verbose output run of the torch array API tests (with th…
asmeurer Feb 22, 2023
124d6c3
Add a missing torch xfail
asmeurer Feb 22, 2023
847a9e5
Add a missing torch xfail
asmeurer Feb 22, 2023
0857b86
Skip test_floor_divide, which core dumps on CI
asmeurer Feb 23, 2023
6354cd9
Add a missing torch xfail
asmeurer Feb 23, 2023
0565dee
Update the README
asmeurer Feb 23, 2023
e9b447c
Fix some formatting in the README
asmeurer Feb 23, 2023
bb4d3af
Typo fix
asmeurer Feb 23, 2023
5545635
Update torch reduction functions that don't support multiple axes
asmeurer Feb 23, 2023
a78f733
Add a test skip that crashes on CI
asmeurer Feb 23, 2023
0cefa7b
Add a CHANGELOG for the upcoming 1.1 release
asmeurer Feb 23, 2023
1aceff5
Add more skips for tests that crash on CI
asmeurer Feb 23, 2023
b016d4c
Skip a torch test that crashes CI
asmeurer Feb 23, 2023
1cd43f2
Add a torch xfail
asmeurer Feb 23, 2023
c8a5a70
Add a script to manually run the cupy tests
asmeurer Feb 23, 2023
7261dae
Add a torch skip
asmeurer Feb 23, 2023
795dbea
Merge branch 'pytorch' of github.com:asmeurer/array-api-compat into p…
asmeurer Feb 23, 2023
0368c8f
Use cupy specific skips and xfails
asmeurer Feb 24, 2023
2b456c2
Allow passing pytest args through in test_cupy.sh
asmeurer Feb 24, 2023
046ffd0
Add a shebang to test_cupy.sh
asmeurer Feb 24, 2023
c3eb0d5
Make the hypothesis examples database persistent in test_cupy.sh
asmeurer Feb 24, 2023
111a122
Fix sort() and argsort() with cupy
asmeurer Feb 24, 2023
09b5a6f
Add comments for the rest of the cupy xfails
asmeurer Feb 24, 2023
b780b9a
Merge branch 'pytorch' of github.com:asmeurer/array-api-compat into p…
asmeurer Feb 24, 2023
1908a00
Fix argument quoting in test_cupy.sh
asmeurer Feb 25, 2023
5cbd1c0
Update cupy skips and xfails
asmeurer Feb 25, 2023
6a63d5c
Update cupy xfails
asmeurer Feb 25, 2023
175f195
Update test_cupy.sh to run the vendoring tests
asmeurer Feb 25, 2023
d1c7999
Add a minor CHANGELOG entry
asmeurer Feb 25, 2023
0e923b6
Merge branch 'pytorch' of github.com:asmeurer/array-api-compat into p…
asmeurer Feb 25, 2023
2fb0a0a
Bump the version to 1.1
asmeurer Feb 25, 2023
3470b36
Add a missing torch skip
asmeurer Feb 25, 2023
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 9 additions & 0 deletions .github/workflows/array-api-tests-numpy.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
name: Array API Tests (NumPy)

on: [push, pull_request]

jobs:
array-api-tests-numpy:
uses: ./.github/workflows/array-api-tests.yml
with:
package-name: numpy
10 changes: 10 additions & 0 deletions .github/workflows/array-api-tests-torch.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
name: Array API Tests (PyTorch)

on: [push, pull_request]

jobs:
array-api-tests-torch:
uses: ./.github/workflows/array-api-tests.yml
with:
package-name: torch
pytest-extra-args: "--disable-extension linalg"
21 changes: 15 additions & 6 deletions .github/workflows/array-api-tests.yml
Original file line number Diff line number Diff line change
@@ -1,9 +1,18 @@
name: Array API Tests

on: [push, pull_request]
on:
workflow_call:
inputs:
package-name:
required: true
type: string
pytest-extra-args:
required: false
type: string


env:
PYTEST_ARGS: "--max-examples 200 -v -rxXfE --ci"
PYTEST_ARGS: "--max-examples 200 -v -rxXfE --ci ${{ inputs.pytest-extra-args }}"

jobs:
tests:
Expand Down Expand Up @@ -34,15 +43,15 @@ jobs:
- name: Install dependencies
run: |
python -m pip install --upgrade pip
python -m pip install numpy
python -m pip install ${{ inputs.package-name }}
python -m pip install -r ${GITHUB_WORKSPACE}/array-api-tests/requirements.txt
- name: Run the array API testsuite (NumPy)
- name: Run the array API testsuite (${{ inputs.package-name }})
env:
ARRAY_API_TESTS_MODULE: array_api_compat.numpy
ARRAY_API_TESTS_MODULE: array_api_compat.${{ inputs.package-name }}
# This enables the NEP 50 type promotion behavior (without it a lot of
# tests fail on bad scalar type promotion behavior)
NPY_PROMOTION_STATE: weak
run: |
export PYTHONPATH="${GITHUB_WORKSPACE}/array-api-compat"
cd ${GITHUB_WORKSPACE}/array-api-tests
pytest ${PYTEST_ARGS} --xfails-file ${GITHUB_WORKSPACE}/array-api-compat/numpy-xfails.txt array_api_tests/
pytest ${PYTEST_ARGS} --xfails-file ${GITHUB_WORKSPACE}/array-api-compat/${{ inputs.package-name }}-xfails.txt --skips-file ${GITHUB_WORKSPACE}/array-api-compat/${{ inputs.package-name }}-skips.txt array_api_tests/
2 changes: 1 addition & 1 deletion .github/workflows/tests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ jobs:
- name: Install Dependencies
run: |
python -m pip install --upgrade pip
python -m pip install pytest numpy
python -m pip install pytest numpy torch

- name: Run Tests
run: |
Expand Down
24 changes: 24 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
# 1.1 (2023-02-24)

## Major Changes

- Added support for PyTorch.

- Add helper function `size()` (required if torch is used as
`torch.Tensor.size` is a method that is incompatible with the array API
[`.size`](https://data-apis.org/array-api/latest/API_specification/generated/array_api.array.size.html#array_api.array.size)).

- All wrapper functions that wrap existing library functions now pass through
arbitrary `**kwargs`.

## Minor Changes

- Added CI to run against the [array API testsuite](https://github.com/data-apis/array-api-tests).

- Fix `sort(stable=False)` and `argsort(stable=False)` with CuPy.

# 1.0 (2022-12-05)

## Major Changes

- Initial release. Includes support for NumPy and CuPy.
Loading