diff --git a/.github/workflows/array-api-tests-numpy.yml b/.github/workflows/array-api-tests-numpy.yml new file mode 100644 index 00000000..1ba54a84 --- /dev/null +++ b/.github/workflows/array-api-tests-numpy.yml @@ -0,0 +1,9 @@ +name: Array API Tests (NumPy) + +on: [push, pull_request] + +jobs: + array-api-tests-numpy: + uses: ./.github/workflows/array-api-tests.yml + with: + package-name: numpy diff --git a/.github/workflows/array-api-tests-torch.yml b/.github/workflows/array-api-tests-torch.yml new file mode 100644 index 00000000..e8caeffa --- /dev/null +++ b/.github/workflows/array-api-tests-torch.yml @@ -0,0 +1,10 @@ +name: Array API Tests (PyTorch) + +on: [push, pull_request] + +jobs: + array-api-tests-torch: + uses: ./.github/workflows/array-api-tests.yml + with: + package-name: torch + pytest-extra-args: "--disable-extension linalg" diff --git a/.github/workflows/array-api-tests.yml b/.github/workflows/array-api-tests.yml index 79c8b0b7..a1f9c223 100644 --- a/.github/workflows/array-api-tests.yml +++ b/.github/workflows/array-api-tests.yml @@ -1,9 +1,18 @@ name: Array API Tests -on: [push, pull_request] +on: + workflow_call: + inputs: + package-name: + required: true + type: string + pytest-extra-args: + required: false + type: string + env: - PYTEST_ARGS: "--max-examples 200 -v -rxXfE --ci" + PYTEST_ARGS: "--max-examples 200 -v -rxXfE --ci ${{ inputs.pytest-extra-args }}" jobs: tests: @@ -34,15 +43,15 @@ jobs: - name: Install dependencies run: | python -m pip install --upgrade pip - python -m pip install numpy + python -m pip install ${{ inputs.package-name }} python -m pip install -r ${GITHUB_WORKSPACE}/array-api-tests/requirements.txt - - name: Run the array API testsuite (NumPy) + - name: Run the array API testsuite (${{ inputs.package-name }}) env: - ARRAY_API_TESTS_MODULE: array_api_compat.numpy + ARRAY_API_TESTS_MODULE: array_api_compat.${{ inputs.package-name }} # This enables the NEP 50 type promotion behavior (without it a lot of # tests fail on bad scalar type promotion behavior) NPY_PROMOTION_STATE: weak run: | export PYTHONPATH="${GITHUB_WORKSPACE}/array-api-compat" cd ${GITHUB_WORKSPACE}/array-api-tests - pytest ${PYTEST_ARGS} --xfails-file ${GITHUB_WORKSPACE}/array-api-compat/numpy-xfails.txt array_api_tests/ + pytest ${PYTEST_ARGS} --xfails-file ${GITHUB_WORKSPACE}/array-api-compat/${{ inputs.package-name }}-xfails.txt --skips-file ${GITHUB_WORKSPACE}/array-api-compat/${{ inputs.package-name }}-skips.txt array_api_tests/ diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index b014304b..4ff074c4 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -15,7 +15,7 @@ jobs: - name: Install Dependencies run: | python -m pip install --upgrade pip - python -m pip install pytest numpy + python -m pip install pytest numpy torch - name: Run Tests run: | diff --git a/CHANGELOG.md b/CHANGELOG.md new file mode 100644 index 00000000..b32a2d33 --- /dev/null +++ b/CHANGELOG.md @@ -0,0 +1,24 @@ +# 1.1 (2023-02-24) + +## Major Changes + +- Added support for PyTorch. + +- Add helper function `size()` (required if torch is used as + `torch.Tensor.size` is a method that is incompatible with the array API + [`.size`](https://data-apis.org/array-api/latest/API_specification/generated/array_api.array.size.html#array_api.array.size)). + +- All wrapper functions that wrap existing library functions now pass through + arbitrary `**kwargs`. + +## Minor Changes + +- Added CI to run against the [array API testsuite](https://github.com/data-apis/array-api-tests). + +- Fix `sort(stable=False)` and `argsort(stable=False)` with CuPy. + +# 1.0 (2022-12-05) + +## Major Changes + +- Initial release. Includes support for NumPy and CuPy. diff --git a/README.md b/README.md index a49fcb5b..90eee1d2 100644 --- a/README.md +++ b/README.md @@ -1,93 +1,124 @@ # Array API compatibility library -This is a small wrapper around NumPy and CuPy that is compatible with the -[Array API standard](https://data-apis.org/array-api/latest/). See also [NEP -47](https://numpy.org/neps/nep-0047-array-api-standard.html). - -Unlike `numpy.array_api`, this is not a strict minimal implementation of the -Array API, but rather just an extension of the main NumPy and CuPy namespaces -with changes needed to be compliant with the Array API. - -Library authors using the Array API may wish to test against `numpy.array_api` -to ensure they are not using functionality outside of the standard, but prefer -this implementation for the default when working with NumPy or CuPy arrays. - -See https://numpy.org/doc/stable/reference/array_api.html for a full list of -changes. In particular, unlike `numpy.array_api`, this package does not use a -separate Array object, but rather just uses `numpy.ndarray` directly. +This is a small wrapper around common array libraries that is compatible with +the [Array API standard](https://data-apis.org/array-api/latest/). Currently, +NumPy, CuPy, and PyTorch are supported. If you want support for other array +libraries, or if you encounter any issues, please [open an +issue](https://github.com/data-apis/array-api-compat/issues). Note that some of the functionality in this library is backwards incompatible -with NumPy. - -This library also supports CuPy in addition to NumPy. If you want support for -other array libraries, please [open an -issue](https://github.com/data-apis/array-api-compat/issues). +with the corresponding wrapped libraries. The end-goal is to eventually make +each array library itself fully compatible with the array API, but this +requires making backwards incompatible changes in many cases, so this will +take some time. -Library authors using the Array API may wish to test against `numpy.array_api` -to ensure they are not using functionality outside of the standard, but prefer -this implementation for end users who use NumPy arrays. +Currently all libraries here are implemented against the 2021.12 version of +the standard. Support for the [2022.12 +version](https://data-apis.org/array-api/2022.12/changelog.html), which adds +complex number support as well as several additional functions, will be added +later this year. ## Usage -To use this library replace +The typical usage of this library will be to get the corresponding array API +compliant namespace from the input arrays using `get_namespace()`, like ```py -import numpy as np +def your_function(x, y): + xp = array_api_compat.get_namespace(x, y) + # Now use xp as the array library namespace + return xp.mean(x, axis=0) + 2*xp.std(y, axis=0) ``` -with +If you wish to have library-specific code-paths, you can import the +corresponding wrapped namespace for each library, like ```py import array_api_compat.numpy as np ``` -and replace - ```py -import cupy as cp +import array_api_compat.cupy as cp ``` -with - ```py -import array_api_compat.cupy as cp +import array_api_compat.torch as torch ``` -Each will include all the functions from the normal NumPy/CuPy namespace, -except that functions that are part of the array API are wrapped so that they -have the correct array API behavior. In each case, the array object used will -be the same array object from the wrapped library. - +Each will include all the functions from the normal NumPy/CuPy/PyTorch +namespace, except that functions that are part of the array API are wrapped so +that they have the correct array API behavior. In each case, the array object +used will be the same array object from the wrapped library. + +## Difference between `array_api_compat` and `numpy.array_api` + +`numpy.array_api` is a strict minimal implementation of the Array API (see +[NEP 47](https://numpy.org/neps/nep-0047-array-api-standard.html)). For +example, `numpy.array_api` does not include any functions that are not part of +the array API specification, and will explicitly disallow behaviors that are +not required by the spec (e.g., [cross-kind type +promotions](https://data-apis.org/array-api/latest/API_specification/type_promotion.html)). +(`cupy.array_api` is similar to `numpy.array_api`) + +`array_api_compat`, on the other hand, is just an extension of the +corresponding array library namespaces with changes needed to be compliant +with the array API. It includes all additional library functions not mentioned +in the spec, and allows any library behaviors not explicitly disallowed by it, +such as cross-kind casting. + +In particular, unlike `numpy.array_api`, this package does not use a separate +`Array` object, but rather just uses the corresponding array library array +objects (`numpy.ndarray`, `cupy.ndarray`, `torch.Tensor`, etc.) directly. This +is because those are the objects that are going to be passed as inputs to +functions by end users. This does mean that a few behaviors cannot be wrapped +(see below), but most of the array API functional, so this does not affect +most things. + +Array consuming library authors coding against the array API may wish to test +against `numpy.array_api` to ensure they are not using functionality outside +of the standard, but prefer this implementation for the default behavior for +end-users. ## Helper Functions -In addition to the default NumPy/CuPy namespace and functions in the array API -specification, there are several helper functions -included that aren't part of the specification but which are useful for using -the array API: +In addition to the wrapped library namespaces and functions in the array API +specification, there are several helper functions included here that aren't +part of the specification but which are useful for using the array API: - `is_array_api_obj(x)`: Return `True` if `x` is an array API compatible array object. - `get_namespace(*xs)`: Get the corresponding array API namespace for the - arrays `xs`. If the arrays are NumPy or CuPy arrays, the returned namespace - will be `array_api_compat.numpy` or `array_api_compat.cupy` so that it is - array API compatible. + arrays `xs`. For example, if the arrays are NumPy arrays, the returned + namespace will be `array_api_compat.numpy`. Note that this function will + also work for namespaces that aren't supported by this compat library but + which do support the array API (i.e., arrays that have the + `__array_namespace__` attribute). - `device(x)`: Equivalent to [`x.device`](https://data-apis.org/array-api/latest/API_specification/generated/signatures.array_object.array.device.html) in the array API specification. Included because `numpy.ndarray` does not include the `device` attribute and this library does not wrap or extend the - array object. Note that for NumPy, `device` is always `"cpu"`. + array object. Note that for NumPy, `device(x)` is always `"cpu"`. - `to_device(x, device, /, *, stream=None)`: Equivalent to [`x.to_device`](https://data-apis.org/array-api/latest/API_specification/generated/signatures.array_object.array.to_device.html). - Included because neither NumPy's nor CuPy's ndarray objects include this - method. For NumPy, this function effectively does nothing since the only - supported device is the CPU, but for CuPy, this method supports CuPy CUDA + Included because neither NumPy's, CuPy's, nor PyTorch's array objects + include this method. For NumPy, this function effectively does nothing since + the only supported device is the CPU, but for CuPy, this method supports + CuPy CUDA [Device](https://docs.cupy.dev/en/stable/reference/generated/cupy.cuda.Device.html) and [Stream](https://docs.cupy.dev/en/stable/reference/generated/cupy.cuda.Stream.html) + objects. For PyTorch, this is the same as + [`x.to(device)`](https://pytorch.org/docs/stable/generated/torch.Tensor.to.html) + (the `stream` argument is not supported in PyTorch). + +- `size(x)`: Equivalent to + [`x.size`](https://data-apis.org/array-api/latest/API_specification/generated/array_api.array.size.html#array_api.array.size), + i.e., the number of elements in the array. Included because PyTorch's + `Tensor` defines `size` as a method which returns the shape, and this cannot + be wrapped because this compat library doesn't wrap or extend the array objects. ## Known Differences from the Array API Specification @@ -95,6 +126,8 @@ the array API: There are some known differences between this library and the array API specification: +### NumPy and CuPy + - The array methods `__array_namespace__`, `device` (for NumPy), `to_device`, and `mT` are not defined. This reuses `np.ndarray` and `cp.ndarray` and we don't want to monkeypatch or wrap it. The helper functions `device()` and @@ -102,16 +135,61 @@ specification: `x.mT` can be replaced with `xp.linalg.matrix_transpose(x)`. `get_namespace(x)` should be used instead of `x.__array_namespace__`. -- NumPy value-based casting for scalars will be in effect unless explicitly - disabled with the environment variable NPY_PROMOTION_STATE=weak or - np._set_promotion_state('weak') (requires NumPy 1.24 or newer, see NEP 50 - and https://github.com/numpy/numpy/issues/22341) +- Value-based casting for scalars will be in effect unless explicitly disabled + with the environment variable `NPY_PROMOTION_STATE=weak` or + `np._set_promotion_state('weak')` (requires NumPy 1.24 or newer, see [NEP + 50](https://numpy.org/neps/nep-0050-scalar-promotion.html) and + https://github.com/numpy/numpy/issues/22341) - Functions which are not wrapped may not have the same type annotations as the spec. - Functions which are not wrapped may not use positional-only arguments. +### PyTorch + +- Like NumPy/CuPy, we do not wrap the `torch.Tensor` object. It is missing the + `__array_namespace__` and `to_device` methods, so the corresponding helper + functions `get_namespace()` and `to_device()` in this library should be + used instead (see above). + +- The `x.size` attribute on `torch.Tensor` is a function that behaves + differently from + [`x.size`](https://data-apis.org/array-api/draft/API_specification/generated/array_api.array.size.html) + in the spec. Use the `size(x)` helper function as a portable workaround (see + above). + +- The `linalg` extension is not yet implemented. + +- PyTorch does not have unsigned integer types other than `uint8`, and no + attempt is made to implement them here. + +- PyTorch has type promotion semantics that differ from the array API + specification for 0-D tensor objects. The array functions in this wrapper + library do work around this, but the operators on the Tensor object do not, + as no operators or methods on the Tensor object are modified. If this is a + concern, use the functional form instead of the operator form, e.g., `add(x, + y)` instead of `x + y`. + +- [`unique_all()`](https://data-apis.org/array-api/latest/API_specification/generated/array_api.unique_all.html#array_api.unique_all) + is not implemented, due to the fact that `torch.unique` does not support + returning the `indices` array. The other + [`unique_*`](https://data-apis.org/array-api/latest/API_specification/set_functions.html) + functions are implemented. + +- Slices do not support negative steps. + +- [`std()`](https://data-apis.org/array-api/latest/API_specification/generated/array_api.std.html#array_api.std) + and + [`var()`](https://data-apis.org/array-api/latest/API_specification/generated/array_api.var.html#array_api.var) + do not support floating-point `correction`. + +- The `stream` argument of the `to_device()` helper (see above) is not + supported. + +- As with NumPy, type annotations and positional-only arguments may not + exactly match the spec for functions that are not wrapped at all. + ## Vendoring This library supports vendoring as an installation method. To vendor the @@ -127,17 +205,21 @@ references the name "array_api_compat"). Alternatively, the library may be installed as dependency on PyPI. -## Implementation +## Implementation Notes As noted before, the goal of this library is to reuse the NumPy and CuPy array objects, rather than wrapping or extending them. This means that the functions need to accept and return `np.ndarray` for NumPy and `cp.ndarray` for CuPy. -Each namespace (`array_api_compat.numpy` and `array_api_compat.cupy`) is -populated with the normal library namespace (like `from numpy import *`). Then -specific functions are replaced with wrapped variants. Wrapped functions that -have the same logic between NumPy and CuPy (which is most functions) are in -`array_api_compat/common/`. These functions are defined like +Each namespace (`array_api_compat.numpy`, `array_api_compat.cupy`, and +`array_api_compat.torch`) is populated with the normal library namespace (like +`from numpy import *`). Then specific functions are replaced with wrapped +variants. + +Since NumPy and CuPy are nearly identical in behavior, most wrapping logic can +be shared between them. Wrapped functions that have the same logic between +NumPy and CuPy are in `array_api_compat/common/`. +These functions are defined like ```py # In array_api_compat/common/_aliases.py @@ -147,10 +229,10 @@ def acos(x, /, xp): ``` The `xp` argument refers to the original array namespace (either `numpy` or -`cupy`). Then in the specific `array_api_compat/numpy` and -`array_api_compat/cupy` namespace, the `get_xp` decorator is applied to these -functions, which automatically removes the `xp` argument from the function -signature and replaces it with the corresponding array library, like +`cupy`). Then in the specific `array_api_compat/numpy/` and +`array_api_compat/cupy/` namespaces, the `@get_xp` decorator is applied to +these functions, which automatically removes the `xp` argument from the +function signature and replaces it with the corresponding array library, like ```py # In array_api_compat/numpy/_aliases.py @@ -177,6 +259,15 @@ acos = get_xp(cp)(_aliases.acos) ``` Since NumPy and CuPy are nearly identical in their behaviors, this allows -writing the wrapping logic for both libraries only once. If support is added -for other libraries which differ significantly from NumPy, their wrapper code -should go in their specific sub-namespace instead of `common/`. +writing the wrapping logic for both libraries only once. + +PyTorch uses a similar layout in `array_api_compat/torch/`, but it differs +enough from NumPy/CuPy that very few common wrappers for those libraries are +reused. + +See https://numpy.org/doc/stable/reference/array_api.html for a full list of +changes from the base NumPy (the differences for CuPy are nearly identical). A +corresponding document does not yet exist for PyTorch, but you can examine the +various comments in the +[implementation](https://github.com/data-apis/array-api-compat/blob/main/array_api_compat/torch/_aliases.py) +to see what functions and behaviors have been wrapped. diff --git a/array_api_compat/__init__.py b/array_api_compat/__init__.py index ca195443..1b054683 100644 --- a/array_api_compat/__init__.py +++ b/array_api_compat/__init__.py @@ -17,6 +17,6 @@ this implementation for the default when working with NumPy arrays. """ -__version__ = '1.0' +__version__ = '1.1' from .common import * diff --git a/array_api_compat/common/_aliases.py b/array_api_compat/common/_aliases.py index 12544bd3..72bb3f2d 100644 --- a/array_api_compat/common/_aliases.py +++ b/array_api_compat/common/_aliases.py @@ -6,7 +6,7 @@ from typing import TYPE_CHECKING if TYPE_CHECKING: - from typing import Optional, Tuple, Union, List + from typing import Optional, Sequence, Tuple, Union, List from ._typing import ndarray, Device, Dtype, NestedSequence, SupportsBufferProtocol from typing import NamedTuple @@ -332,17 +332,19 @@ def argsort( **kwargs, ) -> ndarray: # Note: this keyword argument is different, and the default is different. - kind = "stable" if stable else "quicksort" + # We set it in kwargs like this because numpy.sort uses kind='quicksort' + # as the default whereas cupy.sort uses kind=None. + if stable: + kwargs['kind'] = "stable" if not descending: - res = xp.argsort(x, axis=axis, kind=kind, **kwargs) + res = xp.argsort(x, axis=axis, **kwargs) else: # As NumPy has no native descending sort, we imitate it here. Note that # simply flipping the results of xp.argsort(x, ...) would not # respect the relative order like it would in native descending sorts. res = xp.flip( - xp.argsort(xp.flip(x, axis=axis), axis=axis, kind=kind), + xp.argsort(xp.flip(x, axis=axis), axis=axis, **kwargs), axis=axis, - **kwargs, ) # Rely on flip()/argsort() to validate axis normalised_axis = axis if axis >= 0 else x.ndim + axis @@ -355,8 +357,11 @@ def sort( **kwargs, ) -> ndarray: # Note: this keyword argument is different, and the default is different. - kind = "stable" if stable else "quicksort" - res = xp.sort(x, axis=axis, kind=kind, **kwargs) + # We set it in kwargs like this because numpy.sort uses kind='quicksort' + # as the default whereas cupy.sort uses kind=None. + if stable: + kwargs['kind'] = "stable" + res = xp.sort(x, axis=axis, **kwargs) if descending: res = xp.flip(res, axis=axis) return res @@ -408,8 +413,50 @@ def trunc(x: ndarray, /, xp, **kwargs) -> ndarray: return x return xp.trunc(x, **kwargs) +# linear algebra functions + +def matmul(x1: ndarray, x2: ndarray, /, xp, **kwargs) -> ndarray: + return xp.matmul(x1, x2, **kwargs) + +# Unlike transpose, matrix_transpose only transposes the last two axes. +def matrix_transpose(x: ndarray, /, xp) -> ndarray: + if x.ndim < 2: + raise ValueError("x must be at least 2-dimensional for matrix_transpose") + return xp.swapaxes(x, -1, -2) + +def tensordot(x1: ndarray, + x2: ndarray, + /, + xp, + *, + axes: Union[int, Tuple[Sequence[int], Sequence[int]]] = 2, + **kwargs, +) -> ndarray: + return xp.tensordot(x1, x2, axes=axes, **kwargs) + +def vecdot(x1: ndarray, x2: ndarray, /, xp, *, axis: int = -1) -> ndarray: + ndim = max(x1.ndim, x2.ndim) + x1_shape = (1,)*(ndim - x1.ndim) + tuple(x1.shape) + x2_shape = (1,)*(ndim - x2.ndim) + tuple(x2.shape) + if x1_shape[axis] != x2_shape[axis]: + raise ValueError("x1 and x2 must have the same size along the given axis") + + if hasattr(xp, 'broadcast_tensors'): + _broadcast = xp.broadcast_tensors + else: + _broadcast = xp.broadcast_arrays + + x1_, x2_ = _broadcast(x1, x2) + x1_ = xp.moveaxis(x1_, axis, -1) + x2_ = xp.moveaxis(x2_, axis, -1) + + res = x1_[..., None, :] @ x2_[..., None] + return res[..., 0, 0] + __all__ = ['arange', 'empty', 'empty_like', 'eye', 'full', 'full_like', - 'linspace', 'ones', 'ones_like', 'zeros', 'zeros_like', 'UniqueAllResult', 'UniqueCountsResult', 'UniqueInverseResult', + 'linspace', 'ones', 'ones_like', 'zeros', 'zeros_like', + 'UniqueAllResult', 'UniqueCountsResult', 'UniqueInverseResult', 'unique_all', 'unique_counts', 'unique_inverse', 'unique_values', 'astype', 'std', 'var', 'permute_dims', 'reshape', 'argsort', - 'sort', 'sum', 'prod', 'ceil', 'floor', 'trunc'] + 'sort', 'sum', 'prod', 'ceil', 'floor', 'trunc', 'matmul', + 'matrix_transpose', 'tensordot', 'vecdot'] diff --git a/array_api_compat/common/_helpers.py b/array_api_compat/common/_helpers.py index a1310b1c..6a4a43fd 100644 --- a/array_api_compat/common/_helpers.py +++ b/array_api_compat/common/_helpers.py @@ -8,6 +8,7 @@ from __future__ import annotations import sys +import math def _is_numpy_array(x): # Avoid importing NumPy if it isn't already @@ -29,11 +30,24 @@ def _is_cupy_array(x): # TODO: Should we reject ndarray subclasses? return isinstance(x, (cp.ndarray, cp.generic)) +def _is_torch_array(x): + # Avoid importing torch if it isn't already + if 'torch' not in sys.modules: + return False + + import torch + + # TODO: Should we reject ndarray subclasses? + return isinstance(x, torch.Tensor) + def is_array_api_obj(x): """ Check if x is an array API compatible array object. """ - return _is_numpy_array(x) or _is_cupy_array(x) or hasattr(x, '__array_namespace__') + return _is_numpy_array(x) \ + or _is_cupy_array(x) \ + or _is_torch_array(x) \ + or hasattr(x, '__array_namespace__') def get_namespace(*xs, _use_compat=True): """ @@ -139,7 +153,12 @@ def _cupy_to_device(x, device, /, stream=None): prev_stream.use() return arr -def to_device(x: "Array", device: "Device", /, *, stream: Optional[Union[int, Any]] = None) -> "Array": +def _torch_to_device(x, device, /, stream=None): + if stream is not None: + raise NotImplementedError + return x.to(device) + +def to_device(x: "Array", device: "Device", /, *, stream: "Optional[Union[int, Any]]" = None) -> "Array": """ Copy the array from the device on which it currently resides to the specified ``device``. @@ -169,7 +188,16 @@ def to_device(x: "Array", device: "Device", /, *, stream: Optional[Union[int, An elif _is_cupy_array(x): # cupy does not yet have to_device return _cupy_to_device(x, device, stream=stream) - + elif _is_torch_array(x): + return _torch_to_device(x, device, stream=stream) return x.to_device(device, stream=stream) -__all__ = ['is_array_api_obj', 'get_namespace', 'device', 'to_device'] +def size(x): + """ + Return the total number of elements of x + """ + if None in x.shape: + return None + return math.prod(x.shape) + +__all__ = ['is_array_api_obj', 'get_namespace', 'device', 'to_device', 'size'] diff --git a/array_api_compat/common/_linalg.py b/array_api_compat/common/_linalg.py index c42879d6..07daefd9 100644 --- a/array_api_compat/common/_linalg.py +++ b/array_api_compat/common/_linalg.py @@ -7,28 +7,16 @@ from numpy.core.numeric import normalize_axis_tuple +from ._aliases import matmul, matrix_transpose, tensordot, vecdot from .._internal import get_xp # These are in the main NumPy namespace but not in numpy.linalg def cross(x1: ndarray, x2: ndarray, /, xp, *, axis: int = -1, **kwargs) -> ndarray: return xp.cross(x1, x2, axis=axis, **kwargs) -def matmul(x1: ndarray, x2: ndarray, /, xp, **kwargs) -> ndarray: - return xp.matmul(x1, x2, **kwargs) - def outer(x1: ndarray, x2: ndarray, /, xp, **kwargs) -> ndarray: return xp.outer(x1, x2, **kwargs) -def tensordot(x1: ndarray, - x2: ndarray, - /, - xp, - *, - axes: Union[int, Tuple[Sequence[int], Sequence[int]]] = 2, - **kwargs, -) -> ndarray: - return xp.tensordot(x1, x2, axes=axes, **kwargs) - class EighResult(NamedTuple): eigenvalues: ndarray eigenvectors: ndarray @@ -103,31 +91,11 @@ def pinv(x: ndarray, /, xp, *, rtol: Optional[Union[float, ndarray]] = None, **k def matrix_norm(x: ndarray, /, xp, *, keepdims: bool = False, ord: Optional[Union[int, float, Literal['fro', 'nuc']]] = 'fro') -> ndarray: return xp.linalg.norm(x, axis=(-2, -1), keepdims=keepdims, ord=ord) -# Unlike transpose, matrix_transpose only transposes the last two axes. -def matrix_transpose(x: ndarray, /, xp) -> ndarray: - if x.ndim < 2: - raise ValueError("x must be at least 2-dimensional for matrix_transpose") - return xp.swapaxes(x, -1, -2) - # svdvals is not in NumPy (but it is in SciPy). It is equivalent to # xp.linalg.svd(compute_uv=False). def svdvals(x: ndarray, /, xp) -> Union[ndarray, Tuple[ndarray, ...]]: return xp.linalg.svd(x, compute_uv=False) -def vecdot(x1: ndarray, x2: ndarray, /, xp, *, axis: int = -1) -> ndarray: - ndim = max(x1.ndim, x2.ndim) - x1_shape = (1,)*(ndim - x1.ndim) + tuple(x1.shape) - x2_shape = (1,)*(ndim - x2.ndim) + tuple(x2.shape) - if x1_shape[axis] != x2_shape[axis]: - raise ValueError("x1 and x2 must have the same size along the given axis") - - x1_, x2_ = xp.broadcast_arrays(x1, x2) - x1_ = xp.moveaxis(x1_, axis, -1) - x2_ = xp.moveaxis(x2_, axis, -1) - - res = x1_[..., None, :] @ x2_[..., None] - return res[..., 0, 0] - def vector_norm(x: ndarray, /, xp, *, axis: Optional[Union[int, Tuple[int, ...]]] = None, keepdims: bool = False, ord: Optional[Union[int, float]] = 2) -> ndarray: # xp.linalg.norm tries to do a matrix norm whenever axis is a 2-tuple or # when axis=None and the input is 2-D, so to force a vector norm, we make diff --git a/array_api_compat/cupy/_aliases.py b/array_api_compat/cupy/_aliases.py index 8cace32e..ce7f3780 100644 --- a/array_api_compat/cupy/_aliases.py +++ b/array_api_compat/cupy/_aliases.py @@ -57,6 +57,10 @@ ceil = get_xp(cp)(_aliases.ceil) floor = get_xp(cp)(_aliases.floor) trunc = get_xp(cp)(_aliases.trunc) +matmul = get_xp(cp)(_aliases.matmul) +matrix_transpose = get_xp(cp)(_aliases.matrix_transpose) +tensordot = get_xp(cp)(_aliases.tensordot) +vecdot = get_xp(cp)(_aliases.vecdot) __all__ = _aliases.__all__ + ['asarray', 'asarray_cupy', 'bool', 'acos', 'acosh', 'asin', 'asinh', 'atan', 'atan2', diff --git a/array_api_compat/cupy/linalg.py b/array_api_compat/cupy/linalg.py index 04c71dec..99c4cc68 100644 --- a/array_api_compat/cupy/linalg.py +++ b/array_api_compat/cupy/linalg.py @@ -10,13 +10,12 @@ from ..common import _linalg from .._internal import get_xp +from ._aliases import (matmul, matrix_transpose, tensordot, vecdot) import cupy as cp cross = get_xp(cp)(_linalg.cross) -matmul = get_xp(cp)(_linalg.matmul) outer = get_xp(cp)(_linalg.outer) -tensordot = get_xp(cp)(_linalg.tensordot) EighResult = _linalg.EighResult QRResult = _linalg.QRResult SlogdetResult = _linalg.SlogdetResult @@ -29,9 +28,7 @@ matrix_rank = get_xp(cp)(_linalg.matrix_rank) pinv = get_xp(cp)(_linalg.pinv) matrix_norm = get_xp(cp)(_linalg.matrix_norm) -matrix_transpose = get_xp(cp)(_linalg.matrix_transpose) svdvals = get_xp(cp)(_linalg.svdvals) -vecdot = get_xp(cp)(_linalg.vecdot) vector_norm = get_xp(cp)(_linalg.vector_norm) diagonal = get_xp(cp)(_linalg.diagonal) trace = get_xp(cp)(_linalg.trace) diff --git a/array_api_compat/numpy/_aliases.py b/array_api_compat/numpy/_aliases.py index 9df2c3fb..2022b842 100644 --- a/array_api_compat/numpy/_aliases.py +++ b/array_api_compat/numpy/_aliases.py @@ -57,6 +57,10 @@ ceil = get_xp(np)(_aliases.ceil) floor = get_xp(np)(_aliases.floor) trunc = get_xp(np)(_aliases.trunc) +matmul = get_xp(np)(_aliases.matmul) +matrix_transpose = get_xp(np)(_aliases.matrix_transpose) +tensordot = get_xp(np)(_aliases.tensordot) +vecdot = get_xp(np)(_aliases.vecdot) __all__ = _aliases.__all__ + ['asarray', 'asarray_numpy', 'bool', 'acos', 'acosh', 'asin', 'asinh', 'atan', 'atan2', diff --git a/array_api_compat/numpy/linalg.py b/array_api_compat/numpy/linalg.py index ac04b055..26d6e88e 100644 --- a/array_api_compat/numpy/linalg.py +++ b/array_api_compat/numpy/linalg.py @@ -3,13 +3,12 @@ from ..common import _linalg from .._internal import get_xp +from ._aliases import (matmul, matrix_transpose, tensordot, vecdot) import numpy as np cross = get_xp(np)(_linalg.cross) -matmul = get_xp(np)(_linalg.matmul) outer = get_xp(np)(_linalg.outer) -tensordot = get_xp(np)(_linalg.tensordot) EighResult = _linalg.EighResult QRResult = _linalg.QRResult SlogdetResult = _linalg.SlogdetResult @@ -22,9 +21,7 @@ matrix_rank = get_xp(np)(_linalg.matrix_rank) pinv = get_xp(np)(_linalg.pinv) matrix_norm = get_xp(np)(_linalg.matrix_norm) -matrix_transpose = get_xp(np)(_linalg.matrix_transpose) svdvals = get_xp(np)(_linalg.svdvals) -vecdot = get_xp(np)(_linalg.vecdot) vector_norm = get_xp(np)(_linalg.vector_norm) diagonal = get_xp(np)(_linalg.diagonal) trace = get_xp(np)(_linalg.trace) diff --git a/array_api_compat/torch/__init__.py b/array_api_compat/torch/__init__.py new file mode 100644 index 00000000..7dfdf482 --- /dev/null +++ b/array_api_compat/torch/__init__.py @@ -0,0 +1,19 @@ +from torch import * + +# Several names are not included in the above import * +import torch +for n in dir(torch): + if (n.startswith('_') + or n.endswith('_') + or 'cuda' in n + or 'cpu' in n + or 'backward' in n): + continue + exec(n + ' = torch.' + n) + +# These imports may overwrite names from the import * above. +from ._aliases import * + +from ..common._helpers import * + +__array_api_version__ = '2021.12' diff --git a/array_api_compat/torch/_aliases.py b/array_api_compat/torch/_aliases.py new file mode 100644 index 00000000..ecd0ba10 --- /dev/null +++ b/array_api_compat/torch/_aliases.py @@ -0,0 +1,591 @@ +from __future__ import annotations + +from functools import wraps +from builtins import all as builtin_all + +from ..common._aliases import (UniqueAllResult, UniqueCountsResult, + UniqueInverseResult, + matrix_transpose as _aliases_matrix_transpose, + vecdot as _aliases_vecdot) +from .._internal import get_xp + +from typing import TYPE_CHECKING +if TYPE_CHECKING: + from typing import List, Optional, Sequence, Tuple, Union + from ..common._typing import Device + from torch import dtype as Dtype + +import torch +array = torch.Tensor + +_array_api_dtypes = { + torch.bool, + torch.uint8, + torch.int8, + torch.int16, + torch.int32, + torch.int64, + torch.float32, + torch.float64, +} + +_promotion_table = { + # bool + (torch.bool, torch.bool): torch.bool, + # ints + (torch.int8, torch.int8): torch.int8, + (torch.int8, torch.int16): torch.int16, + (torch.int8, torch.int32): torch.int32, + (torch.int8, torch.int64): torch.int64, + (torch.int16, torch.int8): torch.int16, + (torch.int16, torch.int16): torch.int16, + (torch.int16, torch.int32): torch.int32, + (torch.int16, torch.int64): torch.int64, + (torch.int32, torch.int8): torch.int32, + (torch.int32, torch.int16): torch.int32, + (torch.int32, torch.int32): torch.int32, + (torch.int32, torch.int64): torch.int64, + (torch.int64, torch.int8): torch.int64, + (torch.int64, torch.int16): torch.int64, + (torch.int64, torch.int32): torch.int64, + (torch.int64, torch.int64): torch.int64, + # uints + (torch.uint8, torch.uint8): torch.uint8, + # ints and uints (mixed sign) + (torch.int8, torch.uint8): torch.int16, + (torch.int16, torch.uint8): torch.int16, + (torch.int32, torch.uint8): torch.int32, + (torch.int64, torch.uint8): torch.int64, + (torch.uint8, torch.int8): torch.int16, + (torch.uint8, torch.int16): torch.int16, + (torch.uint8, torch.int32): torch.int32, + (torch.uint8, torch.int64): torch.int64, + # floats + (torch.float32, torch.float32): torch.float32, + (torch.float32, torch.float64): torch.float64, + (torch.float64, torch.float32): torch.float64, + (torch.float64, torch.float64): torch.float64, +} + + +def _two_arg(f): + @wraps(f) + def _f(x1, x2, /, **kwargs): + x1, x2 = _fix_promotion(x1, x2) + return f(x1, x2, **kwargs) + if _f.__doc__ is None: + _f.__doc__ = f"""\ +Array API compatibility wrapper for torch.{f.__name__}. + +See the corresponding PyTorch documentation and/or the array API specification +for more details. + +""" + return _f + +def _fix_promotion(x1, x2, only_scalar=True): + if x1.dtype not in _array_api_dtypes or x2.dtype not in _array_api_dtypes: + return x1, x2 + # If an argument is 0-D pytorch downcasts the other argument + if not only_scalar or x1.shape == (): + dtype = result_type(x1, x2) + x2 = x2.to(dtype) + if not only_scalar or x2.shape == (): + dtype = result_type(x1, x2) + x1 = x1.to(dtype) + return x1, x2 + +def result_type(*arrays_and_dtypes: Union[array, Dtype]) -> Dtype: + if len(arrays_and_dtypes) == 0: + raise TypeError("At least one array or dtype must be provided") + if len(arrays_and_dtypes) == 1: + x = arrays_and_dtypes[0] + if isinstance(x, torch.dtype): + return x + return x.dtype + if len(arrays_and_dtypes) > 2: + return result_type(arrays_and_dtypes[0], result_type(*arrays_and_dtypes[1:])) + + x, y = arrays_and_dtypes + xdt = x.dtype if not isinstance(x, torch.dtype) else x + ydt = y.dtype if not isinstance(y, torch.dtype) else y + + if (xdt, ydt) in _promotion_table: + return _promotion_table[xdt, ydt] + + # This doesn't result_type(dtype, dtype) for non-array API dtypes + # because torch.result_type only accepts tensors. This does however, allow + # cross-kind promotion. + return torch.result_type(x, y) + +def can_cast(from_: Union[Dtype, array], to: Dtype, /) -> bool: + if not isinstance(from_, torch.dtype): + from_ = from_.dtype + return torch.can_cast(from_, to) + +# Basic renames +permute_dims = torch.permute +bitwise_invert = torch.bitwise_not + +# Two-arg elementwise functions +# These require a wrapper to do the correct type promotion on 0-D tensors +add = _two_arg(torch.add) +atan2 = _two_arg(torch.atan2) +bitwise_and = _two_arg(torch.bitwise_and) +bitwise_left_shift = _two_arg(torch.bitwise_left_shift) +bitwise_or = _two_arg(torch.bitwise_or) +bitwise_right_shift = _two_arg(torch.bitwise_right_shift) +bitwise_xor = _two_arg(torch.bitwise_xor) +divide = _two_arg(torch.divide) +# Also a rename. torch.equal does not broadcast +equal = _two_arg(torch.eq) +floor_divide = _two_arg(torch.floor_divide) +greater = _two_arg(torch.greater) +greater_equal = _two_arg(torch.greater_equal) +less = _two_arg(torch.less) +less_equal = _two_arg(torch.less_equal) +logaddexp = _two_arg(torch.logaddexp) +# logical functions are not included here because they only accept bool in the +# spec, so type promotion is irrelevant. +multiply = _two_arg(torch.multiply) +not_equal = _two_arg(torch.not_equal) +pow = _two_arg(torch.pow) +remainder = _two_arg(torch.remainder) +subtract = _two_arg(torch.subtract) + +# These wrappers are mostly based on the fact that pytorch uses 'dim' instead +# of 'axis'. + +# torch.min and torch.max return a tuple and don't support multiple axes https://github.com/pytorch/pytorch/issues/58745 +def max(x: array, /, *, axis: Optional[Union[int, Tuple[int, ...]]] = None, keepdims: bool = False) -> array: + # https://github.com/pytorch/pytorch/issues/29137 + if axis == (): + return torch.clone(x) + return torch.amax(x, axis, keepdims=keepdims) + +def min(x: array, /, *, axis: Optional[Union[int, Tuple[int, ...]]] = None, keepdims: bool = False) -> array: + # https://github.com/pytorch/pytorch/issues/29137 + if axis == (): + return torch.clone(x) + return torch.amin(x, axis, keepdims=keepdims) + +# torch.sort also returns a tuple +# https://github.com/pytorch/pytorch/issues/70921 +def sort(x: array, /, *, axis: int = -1, descending: bool = False, stable: bool = True, **kwargs) -> array: + return torch.sort(x, dim=axis, descending=descending, stable=stable, **kwargs).values + +def _normalize_axes(axis, ndim): + axes = [] + if ndim == 0 and axis: + # Better error message in this case + raise IndexError(f"Dimension out of range: {axis[0]}") + lower, upper = -ndim, ndim - 1 + for a in axis: + if a < lower or a > upper: + # Match torch error message (e.g., from sum()) + raise IndexError(f"Dimension out of range (expected to be in range of [{lower}, {upper}], but got {a}") + if a < 0: + a = a + ndim + if a in axes: + # Use IndexError instead of RuntimeError, and "axis" instead of "dim" + raise IndexError(f"Axis {a} appears multiple times in the list of axes") + axes.append(a) + return sorted(axes) + +def _axis_none_keepdims(x, ndim, keepdims): + # Apply keepdims when axis=None + # (https://github.com/pytorch/pytorch/issues/71209) + # Note that this is only valid for the axis=None case. + if keepdims: + for i in range(ndim): + x = torch.unsqueeze(x, 0) + return x + +def _reduce_multiple_axes(f, x, axis, keepdims=False, **kwargs): + # Some reductions don't support multiple axes + # (https://github.com/pytorch/pytorch/issues/56586). + axes = _normalize_axes(axis, x.ndim) + for a in reversed(axes): + x = torch.movedim(x, a, -1) + x = torch.flatten(x, -len(axes)) + + out = f(x, -1, **kwargs) + + if keepdims: + for a in axes: + out = torch.unsqueeze(out, a) + return out + +def prod(x: array, + /, + *, + axis: Optional[Union[int, Tuple[int, ...]]] = None, + dtype: Optional[Dtype] = None, + keepdims: bool = False, + **kwargs) -> array: + x = torch.asarray(x) + ndim = x.ndim + + # https://github.com/pytorch/pytorch/issues/29137. Separate from the logic + # below because it still needs to upcast. + if axis == (): + if dtype is None: + # We can't upcast uint8 according to the spec because there is no + # torch.uint64, so at least upcast to int64 which is what sum does + # when axis=None. + if x.dtype in [torch.int8, torch.int16, torch.int32, torch.uint8]: + return x.to(torch.int64) + return x.clone() + return x.to(dtype) + + # torch.prod doesn't support multiple axes + # (https://github.com/pytorch/pytorch/issues/56586). + if isinstance(axis, tuple): + return _reduce_multiple_axes(torch.prod, x, axis, keepdims=keepdims, dtype=dtype, **kwargs) + if axis is None: + # torch doesn't support keepdims with axis=None + # (https://github.com/pytorch/pytorch/issues/71209) + res = torch.prod(x, dtype=dtype, **kwargs) + res = _axis_none_keepdims(res, ndim, keepdims) + return res + + return torch.prod(x, axis, dtype=dtype, keepdims=keepdims, **kwargs) + + +def sum(x: array, + /, + *, + axis: Optional[Union[int, Tuple[int, ...]]] = None, + dtype: Optional[Dtype] = None, + keepdims: bool = False, + **kwargs) -> array: + x = torch.asarray(x) + ndim = x.ndim + + # https://github.com/pytorch/pytorch/issues/29137. + # Make sure it upcasts. + if axis == (): + if dtype is None: + # We can't upcast uint8 according to the spec because there is no + # torch.uint64, so at least upcast to int64 which is what sum does + # when axis=None. + if x.dtype in [torch.int8, torch.int16, torch.int32, torch.uint8]: + return x.to(torch.int64) + return x.clone() + return x.to(dtype) + + if axis is None: + # torch doesn't support keepdims with axis=None + # (https://github.com/pytorch/pytorch/issues/71209) + res = torch.sum(x, dtype=dtype, **kwargs) + res = _axis_none_keepdims(res, ndim, keepdims) + return res + + return torch.sum(x, axis, dtype=dtype, keepdims=keepdims, **kwargs) + +def any(x: array, + /, + *, + axis: Optional[Union[int, Tuple[int, ...]]] = None, + keepdims: bool = False, + **kwargs) -> array: + x = torch.asarray(x) + ndim = x.ndim + if axis == (): + return x.to(torch.bool) + # torch.any doesn't support multiple axes + # (https://github.com/pytorch/pytorch/issues/56586). + if isinstance(axis, tuple): + res = _reduce_multiple_axes(torch.any, x, axis, keepdims=keepdims, **kwargs) + return res.to(torch.bool) + if axis is None: + # torch doesn't support keepdims with axis=None + # (https://github.com/pytorch/pytorch/issues/71209) + res = torch.any(x, **kwargs) + res = _axis_none_keepdims(res, ndim, keepdims) + return res.to(torch.bool) + + # torch.any doesn't return bool for uint8 + return torch.any(x, axis, keepdims=keepdims).to(torch.bool) + +def all(x: array, + /, + *, + axis: Optional[Union[int, Tuple[int, ...]]] = None, + keepdims: bool = False, + **kwargs) -> array: + x = torch.asarray(x) + ndim = x.ndim + if axis == (): + return x.to(torch.bool) + # torch.all doesn't support multiple axes + # (https://github.com/pytorch/pytorch/issues/56586). + if isinstance(axis, tuple): + res = _reduce_multiple_axes(torch.all, x, axis, keepdims=keepdims, **kwargs) + return res.to(torch.bool) + if axis is None: + # torch doesn't support keepdims with axis=None + # (https://github.com/pytorch/pytorch/issues/71209) + res = torch.all(x, **kwargs) + res = _axis_none_keepdims(res, ndim, keepdims) + return res.to(torch.bool) + + # torch.all doesn't return bool for uint8 + return torch.all(x, axis, keepdims=keepdims).to(torch.bool) + +def mean(x: array, + /, + *, + axis: Optional[Union[int, Tuple[int, ...]]] = None, + keepdims: bool = False, + **kwargs) -> array: + # https://github.com/pytorch/pytorch/issues/29137 + if axis == (): + return torch.clone(x) + if axis is None: + # torch doesn't support keepdims with axis=None + # (https://github.com/pytorch/pytorch/issues/71209) + res = torch.mean(x, **kwargs) + res = _axis_none_keepdims(res, x.ndim, keepdims) + return res + return torch.mean(x, axis, keepdims=keepdims, **kwargs) + +def std(x: array, + /, + *, + axis: Optional[Union[int, Tuple[int, ...]]] = None, + correction: Union[int, float] = 0.0, + keepdims: bool = False, + **kwargs) -> array: + # Note, float correction is not supported + # https://github.com/pytorch/pytorch/issues/61492. We don't try to + # implement it here for now. + + # if isinstance(correction, float): + # correction = int(correction) + + # https://github.com/pytorch/pytorch/issues/29137 + if axis == (): + return torch.zeros_like(x) + if isinstance(axis, int): + axis = (axis,) + if axis is None: + # torch doesn't support keepdims with axis=None + # (https://github.com/pytorch/pytorch/issues/71209) + res = torch.std(x, tuple(range(x.ndim)), correction=correction, **kwargs) + res = _axis_none_keepdims(res, x.ndim, keepdims) + return res + return torch.std(x, axis, correction=correction, keepdims=keepdims, **kwargs) + +def var(x: array, + /, + *, + axis: Optional[Union[int, Tuple[int, ...]]] = None, + correction: Union[int, float] = 0.0, + keepdims: bool = False, + **kwargs) -> array: + # Note, float correction is not supported + # https://github.com/pytorch/pytorch/issues/61492. We don't try to + # implement it here for now. + + # if isinstance(correction, float): + # correction = int(correction) + + # https://github.com/pytorch/pytorch/issues/29137 + if axis == (): + return torch.zeros_like(x) + if isinstance(axis, int): + axis = (axis,) + if axis is None: + # torch doesn't support keepdims with axis=None + # (https://github.com/pytorch/pytorch/issues/71209) + res = torch.var(x, tuple(range(x.ndim)), correction=correction, **kwargs) + res = _axis_none_keepdims(res, x.ndim, keepdims) + return res + return torch.var(x, axis, correction=correction, keepdims=keepdims, **kwargs) + +# torch.concat doesn't support dim=None +# https://github.com/pytorch/pytorch/issues/70925 +def concat(arrays: Union[Tuple[array, ...], List[array]], + /, + *, + axis: Optional[int] = 0, + **kwargs) -> array: + if axis is None: + arrays = tuple(ar.flatten() for ar in arrays) + axis = 0 + return torch.concat(arrays, axis, **kwargs) + +# torch.squeeze only accepts int dim and doesn't require it +# https://github.com/pytorch/pytorch/issues/70924. Support for tuple dim was +# added at https://github.com/pytorch/pytorch/pull/89017. +def squeeze(x: array, /, axis: Union[int, Tuple[int, ...]]) -> array: + if isinstance(axis, int): + axis = (axis,) + for a in axis: + if x.shape[a] != 1: + raise ValueError("squeezed dimensions must be equal to 1") + axes = _normalize_axes(axis, x.ndim) + # Remove this once pytorch 1.14 is released with the above PR #89017. + sequence = [a - i for i, a in enumerate(axes)] + for a in sequence: + x = torch.squeeze(x, a) + return x + +# The axis parameter doesn't work for flip() and roll() +# https://github.com/pytorch/pytorch/issues/71210. Also torch.flip() doesn't +# accept axis=None +def flip(x: array, /, *, axis: Optional[Union[int, Tuple[int, ...]]] = None) -> array: + if axis is None: + axis = tuple(range(x.ndim)) + # torch.flip doesn't accept dim as an int but the method does + # https://github.com/pytorch/pytorch/issues/18095 + return x.flip(axis) + +def roll(x: array, /, shift: Union[int, Tuple[int, ...]], *, axis: Optional[Union[int, Tuple[int, ...]]] = None) -> array: + return torch.roll(x, shift, axis) + +def nonzero(x: array, /, **kwargs) -> Tuple[array, ...]: + return torch.nonzero(x, as_tuple=True, **kwargs) + +def where(condition: array, x1: array, x2: array, /) -> array: + x1, x2 = _fix_promotion(x1, x2) + return torch.where(condition, x1, x2) + +# torch.arange doesn't support returning empty arrays +# (https://github.com/pytorch/pytorch/issues/70915), and doesn't support some +# keyword argument combinations +# (https://github.com/pytorch/pytorch/issues/70914) +def arange(start: Union[int, float], + /, + stop: Optional[Union[int, float]] = None, + step: Union[int, float] = 1, + *, + dtype: Optional[Dtype] = None, + device: Optional[Device] = None, + **kwargs) -> array: + if stop is None: + start, stop = 0, start + if step > 0 and stop <= start or step < 0 and stop >= start: + if dtype is None: + if builtin_all(isinstance(i, int) for i in [start, stop, step]): + dtype = torch.int64 + else: + dtype = torch.float32 + return torch.empty(0, dtype=dtype, device=device, **kwargs) + return torch.arange(start, stop, step, dtype=dtype, device=device, **kwargs) + +# torch.eye does not accept None as a default for the second argument and +# doesn't support off-diagonals (https://github.com/pytorch/pytorch/issues/70910) +def eye(n_rows: int, + n_cols: Optional[int] = None, + /, + *, + k: int = 0, + dtype: Optional[Dtype] = None, + device: Optional[Device] = None, + **kwargs) -> array: + if n_cols is None: + n_cols = n_rows + z = torch.zeros(n_rows, n_cols, dtype=dtype, device=device, **kwargs) + if abs(k) <= n_rows + n_cols: + z.diagonal(k).fill_(1) + return z + +# torch.linspace doesn't have the endpoint parameter +def linspace(start: Union[int, float], + stop: Union[int, float], + /, + num: int, + *, + dtype: Optional[Dtype] = None, + device: Optional[Device] = None, + endpoint: bool = True, + **kwargs) -> array: + if not endpoint: + return torch.linspace(start, stop, num+1, dtype=dtype, device=device, **kwargs)[:-1] + return torch.linspace(start, stop, num, dtype=dtype, device=device, **kwargs) + +# torch.full does not accept an int size +# https://github.com/pytorch/pytorch/issues/70906 +def full(shape: Union[int, Tuple[int, ...]], + fill_value: Union[bool, int, float, complex], + *, + dtype: Optional[Dtype] = None, + device: Optional[Device] = None, + **kwargs) -> array: + if isinstance(shape, int): + shape = (shape,) + + return torch.full(shape, fill_value, dtype=dtype, device=device, **kwargs) + +# Functions that aren't in torch https://github.com/pytorch/pytorch/issues/58742 +def expand_dims(x: array, /, *, axis: int = 0) -> array: + return torch.unsqueeze(x, axis) + +def astype(x: array, dtype: Dtype, /, *, copy: bool = True) -> array: + return x.to(dtype, copy=copy) + +def broadcast_arrays(*arrays: array) -> List[array]: + shape = torch.broadcast_shapes(*[a.shape for a in arrays]) + return [torch.broadcast_to(a, shape) for a in arrays] + +# https://github.com/pytorch/pytorch/issues/70920 +def unique_all(x: array) -> UniqueAllResult: + # torch.unique doesn't support returning indices. + # https://github.com/pytorch/pytorch/issues/36748. The workaround + # suggested in that issue doesn't actually function correctly (it relies + # on non-deterministic behavior of scatter()). + raise NotImplementedError("unique_all() not yet implemented for pytorch (see https://github.com/pytorch/pytorch/issues/36748)") + + # values, inverse_indices, counts = torch.unique(x, return_counts=True, return_inverse=True) + # # torch.unique incorrectly gives a 0 count for nan values. + # # https://github.com/pytorch/pytorch/issues/94106 + # counts[torch.isnan(values)] = 1 + # return UniqueAllResult(values, indices, inverse_indices, counts) + +def unique_counts(x: array) -> UniqueCountsResult: + values, counts = torch.unique(x, return_counts=True) + + # torch.unique incorrectly gives a 0 count for nan values. + # https://github.com/pytorch/pytorch/issues/94106 + counts[torch.isnan(values)] = 1 + return UniqueCountsResult(values, counts) + +def unique_inverse(x: array) -> UniqueInverseResult: + values, inverse = torch.unique(x, return_inverse=True) + return UniqueInverseResult(values, inverse) + +def unique_values(x: array) -> array: + return torch.unique(x) + +def matmul(x1: array, x2: array, /, **kwargs) -> array: + # torch.matmul doesn't type promote (but differently from _fix_promotion) + x1, x2 = _fix_promotion(x1, x2, only_scalar=False) + return torch.matmul(x1, x2, **kwargs) + +matrix_transpose = get_xp(torch)(_aliases_matrix_transpose) +_vecdot = get_xp(torch)(_aliases_vecdot) + +def vecdot(x1: array, x2: array, /, *, axis: int = -1) -> array: + x1, x2 = _fix_promotion(x1, x2, only_scalar=False) + return _vecdot(x1, x2, axis=axis) + +# torch.tensordot uses dims instead of axes +def tensordot(x1: array, x2: array, /, *, axes: Union[int, Tuple[Sequence[int], Sequence[int]]] = 2, **kwargs) -> array: + # Note: torch.tensordot fails with integer dtypes when there is only 1 + # element in the axis (https://github.com/pytorch/pytorch/issues/84530). + x1, x2 = _fix_promotion(x1, x2, only_scalar=False) + return torch.tensordot(x1, x2, dims=axes, **kwargs) + +__all__ = ['result_type', 'can_cast', 'permute_dims', 'bitwise_invert', 'add', + 'atan2', 'bitwise_and', 'bitwise_left_shift', 'bitwise_or', + 'bitwise_right_shift', 'bitwise_xor', 'divide', 'equal', + 'floor_divide', 'greater', 'greater_equal', 'less', 'less_equal', + 'logaddexp', 'multiply', 'not_equal', 'pow', 'remainder', + 'subtract', 'max', 'min', 'sort', 'prod', 'sum', 'any', 'all', + 'mean', 'std', 'var', 'concat', 'squeeze', 'flip', 'roll', + 'nonzero', 'where', 'arange', 'eye', 'linspace', 'full', + 'expand_dims', 'astype', 'broadcast_arrays', 'unique_all', + 'unique_counts', 'unique_inverse', 'unique_values', + 'matmul', 'matrix_transpose', 'vecdot', 'tensordot'] diff --git a/array_api_compat/torch/linalg.py b/array_api_compat/torch/linalg.py new file mode 100644 index 00000000..8d223fd4 --- /dev/null +++ b/array_api_compat/torch/linalg.py @@ -0,0 +1 @@ +raise ImportError("The array api compat torch.linalg module extension is not yet implemented") diff --git a/cupy-skips.txt b/cupy-skips.txt new file mode 100644 index 00000000..e618603f --- /dev/null +++ b/cupy-skips.txt @@ -0,0 +1,3 @@ +# Hangs +array_api_tests/test_linalg.py::test_qr +array_api_tests/test_linalg.py::test_matrix_rank diff --git a/cupy-xfails.txt b/cupy-xfails.txt new file mode 100644 index 00000000..86b24611 --- /dev/null +++ b/cupy-xfails.txt @@ -0,0 +1,172 @@ +# cupy doesn't have __index__ (and we cannot wrap the ndarray object) +array_api_tests/test_array_object.py::test_scalar_casting[__index__(uint8)] +array_api_tests/test_array_object.py::test_scalar_casting[__index__(uint16)] +array_api_tests/test_array_object.py::test_scalar_casting[__index__(uint32)] +array_api_tests/test_array_object.py::test_scalar_casting[__index__(uint64)] +array_api_tests/test_array_object.py::test_scalar_casting[__index__(int8)] +array_api_tests/test_array_object.py::test_scalar_casting[__index__(int16)] +array_api_tests/test_array_object.py::test_scalar_casting[__index__(int32)] +array_api_tests/test_array_object.py::test_scalar_casting[__index__(int64)] + +# testsuite bug (https://github.com/data-apis/array-api-tests/issues/172) +array_api_tests/test_array_object.py::test_getitem + +# copy=False is not yet implemented +array_api_tests/test_creation_functions.py::test_asarray_arrays + +# finfo test is testing that the result is a float instead of float32 (see +# also https://github.com/data-apis/array-api/issues/405) +array_api_tests/test_data_type_functions.py::test_finfo[float32] + +# Some array attributes are missing, and we do not wrap the array object +array_api_tests/test_has_names.py::test_has_names[array_method-__array_namespace__] +array_api_tests/test_has_names.py::test_has_names[array_method-__index__] +array_api_tests/test_has_names.py::test_has_names[array_method-to_device] +array_api_tests/test_has_names.py::test_has_names[array_attribute-mT] + +# Some linalg tests depend on .mT instead of matrix_transpose() +# and some require https://github.com/data-apis/array-api-tests/pull/101 to +array_api_tests/test_linalg.py::test_eigvalsh +array_api_tests/test_linalg.py::test_matrix_norm +array_api_tests/test_linalg.py::test_matrix_power +array_api_tests/test_linalg.py::test_solve +array_api_tests/test_linalg.py::test_svd +array_api_tests/test_linalg.py::test_svdvals +# cupy uses 2023.12 trace() behavior https://github.com/data-apis/array-api/pull/502 +array_api_tests/test_linalg.py::test_trace +# We cannot modify array methods +array_api_tests/test_operators_and_elementwise_functions.py::test_divide[__truediv__(x, s)] +array_api_tests/test_operators_and_elementwise_functions.py::test_floor_divide[__floordiv__(x, s)] +array_api_tests/test_operators_and_elementwise_functions.py::test_remainder[__mod__(x1, x2)] +array_api_tests/test_operators_and_elementwise_functions.py::test_remainder[__imod__(x1, x2)] +array_api_tests/test_operators_and_elementwise_functions.py::test_remainder[__mod__(x, s)] +array_api_tests/test_operators_and_elementwise_functions.py::test_subtract[__sub__(x, s)] +# floating point inaccuracy +array_api_tests/test_operators_and_elementwise_functions.py::test_remainder[remainder(x1, x2)] + +# cupy (arg)min/max wrong with infinities +# https://github.com/cupy/cupy/issues/7424 +array_api_tests/test_searching_functions.py::test_argmax +array_api_tests/test_searching_functions.py::test_argmin +array_api_tests/test_statistical_functions.py::test_min +array_api_tests/test_statistical_functions.py::test_max + +# testsuite incorrectly thinks meshgrid doesn't have indexing argument +# (https://github.com/data-apis/array-api-tests/issues/171) +array_api_tests/test_signatures.py::test_func_signature[meshgrid] + +# We cannot add array attributes +array_api_tests/test_signatures.py::test_array_method_signature[__array_namespace__] +array_api_tests/test_signatures.py::test_array_method_signature[__index__] +array_api_tests/test_signatures.py::test_array_method_signature[to_device] + +# We do not attempt to workaround special cases (and the operator method ones + +array_api_tests/test_special_cases.py::test_unary[abs(x_i is -0) -> +0] +array_api_tests/test_special_cases.py::test_unary[__abs__(x_i is -0) -> +0] +array_api_tests/test_special_cases.py::test_unary[asin(x_i is -0) -> -0] +array_api_tests/test_special_cases.py::test_unary[asinh(x_i is -0) -> -0] +array_api_tests/test_special_cases.py::test_unary[atan(x_i is -0) -> -0] +array_api_tests/test_special_cases.py::test_unary[atanh(x_i is -0) -> -0] +array_api_tests/test_special_cases.py::test_unary[ceil(x_i is -0) -> -0] +array_api_tests/test_special_cases.py::test_unary[cos(x_i is -0) -> 1] +array_api_tests/test_special_cases.py::test_unary[cosh(x_i is -0) -> 1] +array_api_tests/test_special_cases.py::test_unary[exp(x_i is -0) -> 1] +array_api_tests/test_special_cases.py::test_unary[expm1(x_i is -0) -> -0] +array_api_tests/test_special_cases.py::test_unary[floor(x_i is -0) -> -0] +array_api_tests/test_special_cases.py::test_unary[log1p(x_i is -0) -> -0] +array_api_tests/test_special_cases.py::test_unary[round(x_i is -0) -> -0] +array_api_tests/test_special_cases.py::test_unary[sin(x_i is -0) -> -0] +array_api_tests/test_special_cases.py::test_unary[sinh(x_i is -0) -> -0] +array_api_tests/test_special_cases.py::test_unary[sqrt(x_i is -0) -> -0] +array_api_tests/test_special_cases.py::test_unary[tan(x_i is -0) -> -0] +array_api_tests/test_special_cases.py::test_unary[tanh(x_i is -0) -> -0] +array_api_tests/test_special_cases.py::test_unary[trunc(x_i is -0) -> -0] +array_api_tests/test_special_cases.py::test_binary[add(x1_i is -0 and x2_i is -0) -> -0] +array_api_tests/test_special_cases.py::test_binary[add(x1_i is -0 and x2_i is +0) -> +0] +array_api_tests/test_special_cases.py::test_binary[add(x1_i is +0 and x2_i is -0) -> +0] +array_api_tests/test_special_cases.py::test_binary[add(isfinite(x1_i) and x1_i != 0 and x2_i == -x1_i) -> +0] +array_api_tests/test_special_cases.py::test_binary[__add__(x1_i is -0 and x2_i is -0) -> -0] +array_api_tests/test_special_cases.py::test_binary[__add__(x1_i is -0 and x2_i is +0) -> +0] +array_api_tests/test_special_cases.py::test_binary[__add__(x1_i is +0 and x2_i is -0) -> +0] +array_api_tests/test_special_cases.py::test_binary[__add__(isfinite(x1_i) and x1_i != 0 and x2_i == -x1_i) -> +0] +array_api_tests/test_special_cases.py::test_binary[atan2(x1_i > 0 and x2_i is -0) -> roughly +pi/2] +array_api_tests/test_special_cases.py::test_binary[atan2(x1_i is +0 and x2_i is -0) -> roughly +pi] +array_api_tests/test_special_cases.py::test_binary[atan2(x1_i is -0 and x2_i > 0) -> -0] +array_api_tests/test_special_cases.py::test_binary[atan2(x1_i is -0 and x2_i is +0) -> -0] +array_api_tests/test_special_cases.py::test_binary[atan2(x1_i is -0 and x2_i is -0) -> roughly -pi] +array_api_tests/test_special_cases.py::test_binary[atan2(x1_i is -0 and x2_i < 0) -> roughly -pi] +array_api_tests/test_special_cases.py::test_binary[atan2(x1_i < 0 and x2_i is -0) -> roughly -pi/2] +array_api_tests/test_special_cases.py::test_binary[divide(x1_i is -0 and x2_i > 0) -> -0] +array_api_tests/test_special_cases.py::test_binary[divide(x1_i is -0 and x2_i < 0) -> +0] +array_api_tests/test_special_cases.py::test_binary[divide(x1_i > 0 and x2_i is -0) -> -infinity] +array_api_tests/test_special_cases.py::test_binary[divide(x1_i < 0 and x2_i is -0) -> +infinity] +array_api_tests/test_special_cases.py::test_binary[__truediv__(x1_i is -0 and x2_i > 0) -> -0] +array_api_tests/test_special_cases.py::test_binary[__truediv__(x1_i is -0 and x2_i < 0) -> +0] +array_api_tests/test_special_cases.py::test_binary[__truediv__(x1_i > 0 and x2_i is -0) -> -infinity] +array_api_tests/test_special_cases.py::test_binary[__truediv__(x1_i < 0 and x2_i is -0) -> +infinity] +array_api_tests/test_special_cases.py::test_binary[floor_divide(x1_i is -0 and x2_i > 0) -> -0] +array_api_tests/test_special_cases.py::test_binary[floor_divide(x1_i is -0 and x2_i < 0) -> +0] +array_api_tests/test_special_cases.py::test_binary[floor_divide(x1_i > 0 and x2_i is -0) -> -infinity] +array_api_tests/test_special_cases.py::test_binary[floor_divide(x1_i < 0 and x2_i is -0) -> +infinity] +array_api_tests/test_special_cases.py::test_binary[__floordiv__(x1_i is -0 and x2_i > 0) -> -0] +array_api_tests/test_special_cases.py::test_binary[__floordiv__(x1_i is -0 and x2_i < 0) -> +0] +array_api_tests/test_special_cases.py::test_binary[__floordiv__(x1_i > 0 and x2_i is -0) -> -infinity] +array_api_tests/test_special_cases.py::test_binary[__floordiv__(x1_i < 0 and x2_i is -0) -> +infinity] +array_api_tests/test_special_cases.py::test_binary[pow(x2_i is -0) -> 1] +array_api_tests/test_special_cases.py::test_binary[pow(x1_i is -0 and x2_i > 0 and x2_i.is_integer() and x2_i % 2 == 1) -> -0] +array_api_tests/test_special_cases.py::test_binary[pow(x1_i is -0 and x2_i > 0 and not (x2_i.is_integer() and x2_i % 2 == 1)) -> +0] +array_api_tests/test_special_cases.py::test_binary[pow(x1_i is -0 and x2_i < 0 and x2_i.is_integer() and x2_i % 2 == 1) -> -infinity] +array_api_tests/test_special_cases.py::test_binary[pow(x1_i is -0 and x2_i < 0 and not (x2_i.is_integer() and x2_i % 2 == 1)) -> +infinity] +array_api_tests/test_special_cases.py::test_binary[__pow__(x2_i is -0) -> 1] +array_api_tests/test_special_cases.py::test_binary[__pow__(x1_i is -0 and x2_i > 0 and x2_i.is_integer() and x2_i % 2 == 1) -> -0] +array_api_tests/test_special_cases.py::test_binary[__pow__(x1_i is -0 and x2_i > 0 and not (x2_i.is_integer() and x2_i % 2 == 1)) -> +0] +array_api_tests/test_special_cases.py::test_binary[__pow__(x1_i is -0 and x2_i < 0 and x2_i.is_integer() and x2_i % 2 == 1) -> -infinity] +array_api_tests/test_special_cases.py::test_binary[__pow__(x1_i is -0 and x2_i < 0 and not (x2_i.is_integer() and x2_i % 2 == 1)) -> +infinity] +array_api_tests/test_special_cases.py::test_binary[remainder(x1_i is +0 and x2_i > 0) -> +0] +array_api_tests/test_special_cases.py::test_binary[remainder(x1_i is -0 and x2_i > 0) -> +0] +array_api_tests/test_special_cases.py::test_binary[remainder(x1_i is +0 and x2_i < 0) -> -0] +array_api_tests/test_special_cases.py::test_binary[remainder(x1_i is -0 and x2_i < 0) -> -0] +array_api_tests/test_special_cases.py::test_binary[remainder(x1_i > 0 and x2_i is -0) -> NaN] +array_api_tests/test_special_cases.py::test_binary[remainder(x1_i < 0 and x2_i is -0) -> NaN] +array_api_tests/test_special_cases.py::test_binary[remainder(isfinite(x1_i) and x1_i > 0 and x2_i is +infinity) -> x1_i] +array_api_tests/test_special_cases.py::test_binary[remainder(isfinite(x1_i) and x1_i > 0 and x2_i is -infinity) -> x2_i] +array_api_tests/test_special_cases.py::test_binary[remainder(isfinite(x1_i) and x1_i < 0 and x2_i is +infinity) -> x2_i] +array_api_tests/test_special_cases.py::test_binary[remainder(isfinite(x1_i) and x1_i < 0 and x2_i is -infinity) -> x1_i] +array_api_tests/test_special_cases.py::test_binary[__mod__(x1_i is +0 and x2_i > 0) -> +0] +array_api_tests/test_special_cases.py::test_binary[__mod__(x1_i is -0 and x2_i > 0) -> +0] +array_api_tests/test_special_cases.py::test_binary[__mod__(x1_i is +0 and x2_i < 0) -> -0] +array_api_tests/test_special_cases.py::test_binary[__mod__(x1_i is -0 and x2_i < 0) -> -0] +array_api_tests/test_special_cases.py::test_binary[__mod__(x1_i > 0 and x2_i is -0) -> NaN] +array_api_tests/test_special_cases.py::test_binary[__mod__(x1_i < 0 and x2_i is -0) -> NaN] +array_api_tests/test_special_cases.py::test_binary[__mod__(isfinite(x1_i) and x1_i > 0 and x2_i is +infinity) -> x1_i] +array_api_tests/test_special_cases.py::test_binary[__mod__(isfinite(x1_i) and x1_i > 0 and x2_i is -infinity) -> x2_i] +array_api_tests/test_special_cases.py::test_binary[__mod__(isfinite(x1_i) and x1_i < 0 and x2_i is +infinity) -> x2_i] +array_api_tests/test_special_cases.py::test_binary[__mod__(isfinite(x1_i) and x1_i < 0 and x2_i is -infinity) -> x1_i] +array_api_tests/test_special_cases.py::test_iop[__iadd__(x1_i is -0 and x2_i is -0) -> -0] +array_api_tests/test_special_cases.py::test_iop[__iadd__(x1_i is -0 and x2_i is +0) -> +0] +array_api_tests/test_special_cases.py::test_iop[__iadd__(x1_i is +0 and x2_i is -0) -> +0] +array_api_tests/test_special_cases.py::test_iop[__iadd__(isfinite(x1_i) and x1_i != 0 and x2_i == -x1_i) -> +0] +array_api_tests/test_special_cases.py::test_iop[__itruediv__(x1_i is -0 and x2_i > 0) -> -0] +array_api_tests/test_special_cases.py::test_iop[__itruediv__(x1_i is -0 and x2_i < 0) -> +0] +array_api_tests/test_special_cases.py::test_iop[__itruediv__(x1_i > 0 and x2_i is -0) -> -infinity] +array_api_tests/test_special_cases.py::test_iop[__itruediv__(x1_i < 0 and x2_i is -0) -> +infinity] +array_api_tests/test_special_cases.py::test_iop[__ifloordiv__(x1_i is -0 and x2_i > 0) -> -0] +array_api_tests/test_special_cases.py::test_iop[__ifloordiv__(x1_i is -0 and x2_i < 0) -> +0] +array_api_tests/test_special_cases.py::test_iop[__ifloordiv__(x1_i > 0 and x2_i is -0) -> -infinity] +array_api_tests/test_special_cases.py::test_iop[__ifloordiv__(x1_i < 0 and x2_i is -0) -> +infinity] +array_api_tests/test_special_cases.py::test_iop[__ipow__(x2_i is -0) -> 1] +array_api_tests/test_special_cases.py::test_iop[__ipow__(x1_i is -0 and x2_i > 0 and x2_i.is_integer() and x2_i % 2 == 1) -> -0] +array_api_tests/test_special_cases.py::test_iop[__ipow__(x1_i is -0 and x2_i > 0 and not (x2_i.is_integer() and x2_i % 2 == 1)) -> +0] +array_api_tests/test_special_cases.py::test_iop[__ipow__(x1_i is -0 and x2_i < 0 and x2_i.is_integer() and x2_i % 2 == 1) -> -infinity] +array_api_tests/test_special_cases.py::test_iop[__ipow__(x1_i is -0 and x2_i < 0 and not (x2_i.is_integer() and x2_i % 2 == 1)) -> +infinity] +array_api_tests/test_special_cases.py::test_iop[__imod__(x1_i is +0 and x2_i > 0) -> +0] +array_api_tests/test_special_cases.py::test_iop[__imod__(x1_i is -0 and x2_i > 0) -> +0] +array_api_tests/test_special_cases.py::test_iop[__imod__(x1_i is +0 and x2_i < 0) -> -0] +array_api_tests/test_special_cases.py::test_iop[__imod__(x1_i is -0 and x2_i < 0) -> -0] +array_api_tests/test_special_cases.py::test_iop[__imod__(x1_i > 0 and x2_i is -0) -> NaN] +array_api_tests/test_special_cases.py::test_iop[__imod__(x1_i < 0 and x2_i is -0) -> NaN] +array_api_tests/test_special_cases.py::test_iop[__imod__(isfinite(x1_i) and x1_i > 0 and x2_i is +infinity) -> x1_i] +array_api_tests/test_special_cases.py::test_iop[__imod__(isfinite(x1_i) and x1_i > 0 and x2_i is -infinity) -> x2_i] +array_api_tests/test_special_cases.py::test_iop[__imod__(isfinite(x1_i) and x1_i < 0 and x2_i is +infinity) -> x2_i] +array_api_tests/test_special_cases.py::test_iop[__imod__(isfinite(x1_i) and x1_i < 0 and x2_i is -infinity) -> x1_i] diff --git a/numpy-skips.txt b/numpy-skips.txt new file mode 100644 index 00000000..e69de29b diff --git a/test_cupy.sh b/test_cupy.sh new file mode 100755 index 00000000..d023a860 --- /dev/null +++ b/test_cupy.sh @@ -0,0 +1,35 @@ +#!/usr/bin/env bash +# We cannot test cupy on CI so this script will test it manually. Assumes it +# is being run in an environment that has cupy and the array-api-tests +# dependencies installed +set -x +set -e + +# Run the vendoring tests in this repo +pytest + +tmpdir=$(mktemp -d) +SCRIPT_DIR=$( cd -- "$( dirname -- "${BASH_SOURCE[0]}" )" &> /dev/null && pwd ) +export PYTHONPATH=$SCRIPT_DIR + +PYTEST_ARGS="--max-examples 200 -v -rxXfE --ci" + +cd $tmpdir +git clone https://github.com/data-apis/array-api-tests +cd array-api-tests + +# Remove this once https://github.com/data-apis/array-api-tests/pull/157 is +# merged +git remote add asmeurer https://github.com/asmeurer/array-api-tests +git fetch asmeurer +git checkout asmeurer/xfails-file + +git submodule update --init + +# store the hypothesis examples database in this directory, so that failures +# will be remembered across runs +mkdir -p $SCRIPT_DIR/.hypothesis +ln -s $SCRIPT_DIR/.hypothesis .hypothesis + +export ARRAY_API_TESTS_MODULE=array_api_compat.cupy +pytest ${PYTEST_ARGS} --xfails-file $SCRIPT_DIR/cupy-xfails.txt --skips-file $SCRIPT_DIR/cupy-skips.txt "$@" diff --git a/tests/test_vendoring.py b/tests/test_vendoring.py index 85f68626..93a961aa 100644 --- a/tests/test_vendoring.py +++ b/tests/test_vendoring.py @@ -13,3 +13,7 @@ def test_vendoring_cupy(): from vendor_test import uses_cupy uses_cupy._test_cupy() + +def test_vendoring_torch(): + from vendor_test import uses_torch + uses_torch._test_torch() diff --git a/torch-skips.txt b/torch-skips.txt new file mode 100644 index 00000000..f2d0f202 --- /dev/null +++ b/torch-skips.txt @@ -0,0 +1,10 @@ +# These tests cause a core dump on CI, so we have to skip them entirely +array_api_tests/test_operators_and_elementwise_functions.py::test_remainder[remainder(x1, x2)] +array_api_tests/test_operators_and_elementwise_functions.py::test_remainder[__mod__(x, s)] +array_api_tests/test_operators_and_elementwise_functions.py::test_remainder[__mod__(x1, x2)] +array_api_tests/test_operators_and_elementwise_functions.py::test_remainder[__imod__(x1, x2)] +array_api_tests/test_operators_and_elementwise_functions.py::test_remainder[__imod__(x, s)] +array_api_tests/test_operators_and_elementwise_functions.py::test_floor_divide[floor_divide(x1, x2)] +array_api_tests/test_operators_and_elementwise_functions.py::test_floor_divide[__floordiv__(x1, x2)] +array_api_tests/test_operators_and_elementwise_functions.py::test_floor_divide[__ifloordiv__(x1, x2)] +array_api_tests/test_operators_and_elementwise_functions.py::test_floor_divide[__ifloordiv__(x, s)] diff --git a/torch-xfails.txt b/torch-xfails.txt new file mode 100644 index 00000000..f67ca189 --- /dev/null +++ b/torch-xfails.txt @@ -0,0 +1,195 @@ +# Note: see array_api_compat/torch/_aliases.py for links to corresponding +# pytorch issues + +# We cannot wrap the array object + +# Indexing does not support negative step +array_api_tests/test_array_object.py::test_getitem +array_api_tests/test_array_object.py::test_setitem +# Masking doesn't suport 0 dimensions in the mask +array_api_tests/test_array_object.py::test_getitem_masking +# torch doesn't have uint dtypes other than uint8 +array_api_tests/test_array_object.py::test_scalar_casting[__int__(uint16)] +array_api_tests/test_array_object.py::test_scalar_casting[__int__(uint32)] +array_api_tests/test_array_object.py::test_scalar_casting[__int__(uint64)] +array_api_tests/test_array_object.py::test_scalar_casting[__index__(uint16)] +array_api_tests/test_array_object.py::test_scalar_casting[__index__(uint32)] +array_api_tests/test_array_object.py::test_scalar_casting[__index__(uint64)] + +# Overflow error from large inputs +array_api_tests/test_creation_functions.py::test_arange +# pytorch linspace bug (should be fixed in torch 2.0) +array_api_tests/test_creation_functions.py::test_linspace + +# torch doesn't have higher uint dtypes +array_api_tests/test_data_type_functions.py::test_iinfo[uint16] +array_api_tests/test_data_type_functions.py::test_iinfo[uint32] +array_api_tests/test_data_type_functions.py::test_iinfo[uint64] + +# --disable-extension broken with test_has_names.py +# https://github.com/data-apis/array-api-tests/issues/169 +array_api_tests/test_has_names.py::test_has_names[linalg-matrix_transpose] +array_api_tests/test_has_names.py::test_has_names[linalg-outer] +array_api_tests/test_has_names.py::test_has_names[linalg-tensordot] +array_api_tests/test_has_names.py::test_has_names[linalg-trace] + +# We cannot wrap the tensor object +array_api_tests/test_has_names.py::test_has_names[array_method-__array_namespace__] +array_api_tests/test_has_names.py::test_has_names[array_method-to_device] + + +# tensordot doesn't allow integer dtypes in some corner cases +array_api_tests/test_linalg.py::test_tensordot + +# A numerical difference in stacking (will be fixed by +# https://github.com/data-apis/array-api-tests/pull/101) +array_api_tests/test_linalg.py::test_matmul +# We cannot wrap the tensor object +array_api_tests/test_operators_and_elementwise_functions.py::test_add[__add__(x1, x2)] +array_api_tests/test_operators_and_elementwise_functions.py::test_bitwise_and[__and__(x1, x2)] +array_api_tests/test_operators_and_elementwise_functions.py::test_bitwise_left_shift[__lshift__(x1, x2)] +array_api_tests/test_operators_and_elementwise_functions.py::test_bitwise_or[__or__(x1, x2)] +array_api_tests/test_operators_and_elementwise_functions.py::test_bitwise_right_shift[__rshift__(x1, x2)] +array_api_tests/test_operators_and_elementwise_functions.py::test_bitwise_right_shift[__irshift__(x1, x2)] +array_api_tests/test_operators_and_elementwise_functions.py::test_bitwise_xor[__xor__(x1, x2)] +array_api_tests/test_operators_and_elementwise_functions.py::test_divide[__truediv__(x1, x2)] +array_api_tests/test_operators_and_elementwise_functions.py::test_equal[__eq__(x1, x2)] +array_api_tests/test_operators_and_elementwise_functions.py::test_not_equal[__ne__(x1, x2)] +array_api_tests/test_operators_and_elementwise_functions.py::test_floor_divide[__floordiv__(x1, x2)] +array_api_tests/test_operators_and_elementwise_functions.py::test_greater[__gt__(x1, x2)] +array_api_tests/test_operators_and_elementwise_functions.py::test_greater_equal[__ge__(x1, x2)] +array_api_tests/test_operators_and_elementwise_functions.py::test_less[__lt__(x1, x2)] +array_api_tests/test_operators_and_elementwise_functions.py::test_less_equal[__le__(x1, x2)] +array_api_tests/test_operators_and_elementwise_functions.py::test_multiply[__mul__(x1, x2)] +array_api_tests/test_operators_and_elementwise_functions.py::test_pow[__pow__(x1, x2)] +# This test is skipped instead of xfailed because it causes core dumps on CI +# array_api_tests/test_operators_and_elementwise_functions.py::test_remainder[__mod__(x1, x2)] +array_api_tests/test_operators_and_elementwise_functions.py::test_remainder[__imod__(x1, x2)] +array_api_tests/test_operators_and_elementwise_functions.py::test_subtract[__sub__(x1, x2)] + + +# Mac-only bug (overflow near float max) +# array_api_tests/test_operators_and_elementwise_functions.py::test_log1p + +# torch doesn't handle shifting by more than the bit size correctly +# https://github.com/pytorch/pytorch/issues/70904 +array_api_tests/test_operators_and_elementwise_functions.py::test_bitwise_right_shift[bitwise_right_shift(x1, x2)] +# Torch bug for remainder in some cases with large values +array_api_tests/test_operators_and_elementwise_functions.py::test_remainder[remainder(x1, x2)] + +# unique_all cannot be implemented because torch's unique does not support +# returning indices +array_api_tests/test_set_functions.py::test_unique_all +# unique_inverse incorrectly counts nan values +# (https://github.com/pytorch/pytorch/issues/94106) +array_api_tests/test_set_functions.py::test_unique_inverse + +# The test suite incorrectly divides by 0 here +# (https://github.com/data-apis/array-api-tests/issues/170) +array_api_tests/test_signatures.py::test_func_signature[floor_divide] +array_api_tests/test_signatures.py::test_func_signature[remainder] +array_api_tests/test_signatures.py::test_array_method_signature[__mod__] + +# We cannot add attributes to the tensor object +array_api_tests/test_signatures.py::test_array_method_signature[__array_namespace__] +array_api_tests/test_signatures.py::test_array_method_signature[to_device] + + +# We do not attempt to work around special-case differences (most are on +# tensor methods which we couldn't fix anyway). +array_api_tests/test_special_cases.py::test_binary[add(isfinite(x1_i) and x1_i != 0 and x2_i == -x1_i) -> +0] +array_api_tests/test_special_cases.py::test_binary[__add__(x1_i is +infinity and isfinite(x2_i)) -> +infinity] +array_api_tests/test_special_cases.py::test_binary[__add__(x1_i is -infinity and isfinite(x2_i)) -> -infinity] +array_api_tests/test_special_cases.py::test_binary[__add__(isfinite(x1_i) and x2_i is +infinity) -> +infinity] +array_api_tests/test_special_cases.py::test_binary[__add__(isfinite(x1_i) and x2_i is -infinity) -> -infinity] +array_api_tests/test_special_cases.py::test_binary[__add__((x1_i is +0 or x1_i == -0) and isfinite(x2_i) and x2_i != 0) -> x2_i] +array_api_tests/test_special_cases.py::test_binary[__add__(isfinite(x1_i) and x1_i != 0 and (x2_i is +0 or x2_i == -0)) -> x1_i] +array_api_tests/test_special_cases.py::test_binary[__add__(isfinite(x1_i) and x1_i != 0 and x2_i == -x1_i) -> +0] +array_api_tests/test_special_cases.py::test_binary[__truediv__(x1_i is +0 and x2_i > 0) -> +0] +array_api_tests/test_special_cases.py::test_binary[__truediv__(x1_i is -0 and x2_i > 0) -> -0] +array_api_tests/test_special_cases.py::test_binary[__truediv__(x1_i is +0 and x2_i < 0) -> -0] +array_api_tests/test_special_cases.py::test_binary[__truediv__(x1_i is -0 and x2_i < 0) -> +0] +array_api_tests/test_special_cases.py::test_binary[__truediv__(x1_i > 0 and x2_i is +0) -> +infinity] +array_api_tests/test_special_cases.py::test_binary[__truediv__(x1_i > 0 and x2_i is -0) -> -infinity] +array_api_tests/test_special_cases.py::test_binary[__truediv__(x1_i < 0 and x2_i is +0) -> -infinity] +array_api_tests/test_special_cases.py::test_binary[__truediv__(x1_i < 0 and x2_i is -0) -> +infinity] +array_api_tests/test_special_cases.py::test_binary[__truediv__(x1_i is +infinity and isfinite(x2_i) and x2_i < 0) -> -infinity] +array_api_tests/test_special_cases.py::test_binary[__truediv__(x1_i is -infinity and isfinite(x2_i) and x2_i < 0) -> +infinity] +array_api_tests/test_special_cases.py::test_binary[__truediv__(x1_i is -infinity and isfinite(x2_i) and x2_i > 0) -> -infinity] +array_api_tests/test_special_cases.py::test_binary[__truediv__(x1_i is +infinity and isfinite(x2_i) and x2_i > 0) -> +infinity] +array_api_tests/test_special_cases.py::test_binary[__truediv__(isfinite(x1_i) and x1_i > 0 and x2_i is -infinity) -> -0] +array_api_tests/test_special_cases.py::test_binary[__truediv__(isfinite(x1_i) and x1_i < 0 and x2_i is +infinity) -> -0] +array_api_tests/test_special_cases.py::test_binary[__truediv__(isfinite(x1_i) and x1_i < 0 and x2_i is -infinity) -> +0] +array_api_tests/test_special_cases.py::test_binary[__truediv__(isfinite(x1_i) and x1_i > 0 and x2_i is +infinity) -> +0] +array_api_tests/test_special_cases.py::test_binary[floor_divide(x1_i is +infinity and isfinite(x2_i) and x2_i > 0) -> +infinity] +array_api_tests/test_special_cases.py::test_binary[floor_divide(x1_i is +infinity and isfinite(x2_i) and x2_i < 0) -> -infinity] +array_api_tests/test_special_cases.py::test_binary[floor_divide(x1_i is -infinity and isfinite(x2_i) and x2_i > 0) -> -infinity] +array_api_tests/test_special_cases.py::test_binary[floor_divide(x1_i is -infinity and isfinite(x2_i) and x2_i < 0) -> +infinity] +array_api_tests/test_special_cases.py::test_binary[floor_divide(isfinite(x1_i) and x1_i > 0 and x2_i is -infinity) -> -0] +array_api_tests/test_special_cases.py::test_binary[floor_divide(isfinite(x1_i) and x1_i < 0 and x2_i is +infinity) -> -0] +array_api_tests/test_special_cases.py::test_binary[__floordiv__(x1_i is +0 and x2_i > 0) -> +0] +array_api_tests/test_special_cases.py::test_binary[__floordiv__(x1_i is -0 and x2_i > 0) -> -0] +array_api_tests/test_special_cases.py::test_binary[__floordiv__(x1_i is +0 and x2_i < 0) -> -0] +array_api_tests/test_special_cases.py::test_binary[__floordiv__(x1_i is -0 and x2_i < 0) -> +0] +array_api_tests/test_special_cases.py::test_binary[__floordiv__(x1_i > 0 and x2_i is +0) -> +infinity] +array_api_tests/test_special_cases.py::test_binary[__floordiv__(x1_i > 0 and x2_i is -0) -> -infinity] +array_api_tests/test_special_cases.py::test_binary[__floordiv__(x1_i < 0 and x2_i is +0) -> -infinity] +array_api_tests/test_special_cases.py::test_binary[__floordiv__(x1_i < 0 and x2_i is -0) -> +infinity] +array_api_tests/test_special_cases.py::test_binary[__floordiv__(x1_i is +infinity and isfinite(x2_i) and x2_i > 0) -> +infinity] +array_api_tests/test_special_cases.py::test_binary[__floordiv__(x1_i is +infinity and isfinite(x2_i) and x2_i < 0) -> -infinity] +array_api_tests/test_special_cases.py::test_binary[__floordiv__(x1_i is -infinity and isfinite(x2_i) and x2_i > 0) -> -infinity] +array_api_tests/test_special_cases.py::test_binary[__floordiv__(x1_i is -infinity and isfinite(x2_i) and x2_i < 0) -> +infinity] +array_api_tests/test_special_cases.py::test_binary[__floordiv__(isfinite(x1_i) and x1_i > 0 and x2_i is +infinity) -> +0] +array_api_tests/test_special_cases.py::test_binary[__floordiv__(isfinite(x1_i) and x1_i > 0 and x2_i is -infinity) -> -0] +array_api_tests/test_special_cases.py::test_binary[__floordiv__(isfinite(x1_i) and x1_i < 0 and x2_i is +infinity) -> -0] +array_api_tests/test_special_cases.py::test_binary[__floordiv__(isfinite(x1_i) and x1_i < 0 and x2_i is -infinity) -> +0] +array_api_tests/test_special_cases.py::test_binary[__pow__(abs(x1_i) > 1 and x2_i is +infinity) -> +infinity] +array_api_tests/test_special_cases.py::test_binary[__pow__(abs(x1_i) > 1 and x2_i is -infinity) -> +0] +array_api_tests/test_special_cases.py::test_binary[__pow__(abs(x1_i) < 1 and x2_i is +infinity) -> +0] +array_api_tests/test_special_cases.py::test_binary[__pow__(abs(x1_i) < 1 and x2_i is -infinity) -> +infinity] +array_api_tests/test_special_cases.py::test_binary[__pow__(x1_i is +infinity and x2_i > 0) -> +infinity] +array_api_tests/test_special_cases.py::test_binary[__pow__(x1_i is +infinity and x2_i < 0) -> +0] +array_api_tests/test_special_cases.py::test_binary[__pow__(x1_i is NaN and not x2_i == 0) -> NaN] +array_api_tests/test_special_cases.py::test_binary[__pow__(x1_i is -infinity and x2_i > 0 and x2_i.is_integer() and x2_i % 2 == 1) -> -infinity] +array_api_tests/test_special_cases.py::test_binary[__pow__(x1_i is -infinity and x2_i > 0 and not (x2_i.is_integer() and x2_i % 2 == 1)) -> +infinity] +array_api_tests/test_special_cases.py::test_binary[__pow__(x1_i is -infinity and x2_i < 0 and x2_i.is_integer() and x2_i % 2 == 1) -> -0] +array_api_tests/test_special_cases.py::test_binary[__pow__(x1_i is -infinity and x2_i < 0 and not (x2_i.is_integer() and x2_i % 2 == 1)) -> +0] +array_api_tests/test_special_cases.py::test_binary[__pow__(x1_i is +0 and x2_i > 0) -> +0] +array_api_tests/test_special_cases.py::test_binary[__pow__(x1_i is +0 and x2_i < 0) -> +infinity] +array_api_tests/test_special_cases.py::test_binary[__pow__(x1_i is -0 and x2_i > 0 and x2_i.is_integer() and x2_i % 2 == 1) -> -0] +array_api_tests/test_special_cases.py::test_binary[__pow__(x1_i is -0 and x2_i < 0 and x2_i.is_integer() and x2_i % 2 == 1) -> -infinity] +array_api_tests/test_special_cases.py::test_binary[__pow__(x1_i is -0 and x2_i < 0 and not (x2_i.is_integer() and x2_i % 2 == 1)) -> +infinity] +array_api_tests/test_special_cases.py::test_binary[__pow__(x1_i is -0 and x2_i > 0 and not (x2_i.is_integer() and x2_i % 2 == 1)) -> +0] +array_api_tests/test_special_cases.py::test_binary[__pow__(x1_i < 0 and isfinite(x1_i) and isfinite(x2_i) and not x2_i.is_integer()) -> NaN] +array_api_tests/test_special_cases.py::test_binary[remainder(x1_i is -0 and x2_i > 0) -> +0] +array_api_tests/test_special_cases.py::test_binary[remainder(x1_i is +0 and x2_i < 0) -> -0] +array_api_tests/test_special_cases.py::test_binary[__mod__(x1_i is +0 and x2_i > 0) -> +0] +array_api_tests/test_special_cases.py::test_binary[__mod__(x1_i is -0 and x2_i > 0) -> +0] +array_api_tests/test_special_cases.py::test_binary[__mod__(x1_i is +0 and x2_i < 0) -> -0] +array_api_tests/test_special_cases.py::test_binary[__mod__(x1_i is -0 and x2_i < 0) -> -0] +array_api_tests/test_special_cases.py::test_binary[__mod__(isfinite(x1_i) and x1_i > 0 and x2_i is +infinity) -> x1_i] +array_api_tests/test_special_cases.py::test_binary[__mod__(isfinite(x1_i) and x1_i > 0 and x2_i is -infinity) -> x2_i] +array_api_tests/test_special_cases.py::test_binary[__mod__(isfinite(x1_i) and x1_i < 0 and x2_i is -infinity) -> x1_i] +array_api_tests/test_special_cases.py::test_binary[__mod__(isfinite(x1_i) and x1_i < 0 and x2_i is +infinity) -> x2_i] +array_api_tests/test_special_cases.py::test_iop[__iadd__(isfinite(x1_i) and x1_i != 0 and x2_i == -x1_i) -> +0] +array_api_tests/test_special_cases.py::test_iop[__ifloordiv__(x1_i is +infinity and isfinite(x2_i) and x2_i > 0) -> +infinity] +array_api_tests/test_special_cases.py::test_iop[__ifloordiv__(x1_i is +infinity and isfinite(x2_i) and x2_i < 0) -> -infinity] +array_api_tests/test_special_cases.py::test_iop[__ifloordiv__(x1_i is -infinity and isfinite(x2_i) and x2_i > 0) -> -infinity] +array_api_tests/test_special_cases.py::test_iop[__ifloordiv__(x1_i is -infinity and isfinite(x2_i) and x2_i < 0) -> +infinity] +array_api_tests/test_special_cases.py::test_iop[__ifloordiv__(isfinite(x1_i) and x1_i > 0 and x2_i is -infinity) -> -0] +array_api_tests/test_special_cases.py::test_iop[__ifloordiv__(isfinite(x1_i) and x1_i < 0 and x2_i is +infinity) -> -0] +array_api_tests/test_special_cases.py::test_iop[__imod__(x1_i is -0 and x2_i > 0) -> +0] +array_api_tests/test_special_cases.py::test_iop[__imod__(x1_i is +0 and x2_i < 0) -> -0] + +# Float correction is not supported by pytorch +# (https://github.com/data-apis/array-api-tests/issues/168) +array_api_tests/test_special_cases.py::test_empty_arrays[std] +array_api_tests/test_special_cases.py::test_empty_arrays[var] +array_api_tests/test_special_cases.py::test_nan_propagation[std] +array_api_tests/test_special_cases.py::test_nan_propagation[var] +array_api_tests/test_statistical_functions.py::test_std +array_api_tests/test_statistical_functions.py::test_var + +# The test suite is incorrectly checking sums that have loss of significance +# (https://github.com/data-apis/array-api-tests/issues/168) +array_api_tests/test_statistical_functions.py::test_sum diff --git a/vendor_test/uses_torch.py b/vendor_test/uses_torch.py new file mode 100644 index 00000000..b828ad33 --- /dev/null +++ b/vendor_test/uses_torch.py @@ -0,0 +1,22 @@ +# Basic test that vendoring works + +from .vendored._compat import torch as torch_compat + +import torch + +def _test_torch(): + a = torch_compat.asarray([1., 2., 3.]) + b = torch_compat.arange(3, dtype=torch_compat.float64) + assert a.dtype == torch_compat.float32 == torch.float32 + assert b.dtype == torch_compat.float64 == torch.float64 + + # torch.expand_dims does not exist. Update this to use something else if it is added + res = torch_compat.expand_dims(a, axis=0) + assert res.dtype == torch_compat.float32 == torch.float32 + assert res.shape == (1, 3) + assert isinstance(res.shape, torch.Size) + assert isinstance(a, torch.Tensor) + assert isinstance(b, torch.Tensor) + assert isinstance(res, torch.Tensor) + + torch.testing.assert_allclose(res, [[1., 2., 3.]])