Skip to content

Commit 873eeff

Browse files
committed
Update function stubs from the spec
1 parent 075150b commit 873eeff

File tree

4 files changed

+9
-11
lines changed

4 files changed

+9
-11
lines changed

array_api_tests/function_stubs/creation_functions.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,6 @@
1212

1313
from ._types import (List, NestedSequence, Optional, SupportsBufferProtocol, SupportsDLPack, Tuple,
1414
Union, array, device, dtype)
15-
from collections.abc import Sequence
1615

1716
def arange(start: Union[int, float], /, stop: Optional[Union[int, float]] = None, step: Union[int, float] = 1, *, dtype: Optional[dtype] = None, device: Optional[device] = None) -> array:
1817
pass
@@ -26,7 +25,7 @@ def empty(shape: Union[int, Tuple[int, ...]], *, dtype: Optional[dtype] = None,
2625
def empty_like(x: array, /, *, dtype: Optional[dtype] = None, device: Optional[device] = None) -> array:
2726
pass
2827

29-
def eye(n_rows: int, n_cols: Optional[int] = None, /, *, k: Optional[int] = 0, dtype: Optional[dtype] = None, device: Optional[device] = None) -> array:
28+
def eye(n_rows: int, n_cols: Optional[int] = None, /, *, k: int = 0, dtype: Optional[dtype] = None, device: Optional[device] = None) -> array:
3029
pass
3130

3231
def from_dlpack(x: object, /) -> array:
@@ -41,7 +40,7 @@ def full_like(x: array, /, fill_value: Union[int, float], *, dtype: Optional[dty
4140
def linspace(start: Union[int, float], stop: Union[int, float], /, num: int, *, dtype: Optional[dtype] = None, device: Optional[device] = None, endpoint: bool = True) -> array:
4241
pass
4342

44-
def meshgrid(*arrays: Sequence[array], indexing: str = 'xy') -> List[array, ...]:
43+
def meshgrid(*arrays: array, indexing: str = 'xy') -> List[array, ...]:
4544
pass
4645

4746
def ones(shape: Union[int, Tuple[int, ...]], *, dtype: Optional[dtype] = None, device: Optional[device] = None) -> array:

array_api_tests/function_stubs/data_type_functions.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -11,9 +11,8 @@
1111
from __future__ import annotations
1212

1313
from ._types import List, Tuple, Union, array, dtype, finfo_object, iinfo_object
14-
from collections.abc import Sequence
1514

16-
def broadcast_arrays(*arrays: Sequence[array]) -> List[array]:
15+
def broadcast_arrays(*arrays: array) -> List[array]:
1716
pass
1817

1918
def broadcast_to(x: array, /, shape: Tuple[int, ...]) -> array:
@@ -28,7 +27,7 @@ def finfo(type: Union[dtype, array], /) -> finfo_object:
2827
def iinfo(type: Union[dtype, array], /) -> iinfo_object:
2928
pass
3029

31-
def result_type(*arrays_and_dtypes: Sequence[Union[array, dtype]]) -> dtype:
30+
def result_type(*arrays_and_dtypes: Union[array, dtype]) -> dtype:
3231
pass
3332

3433
__all__ = ['broadcast_arrays', 'broadcast_to', 'can_cast', 'finfo', 'iinfo', 'result_type']

array_api_tests/function_stubs/linalg.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -68,7 +68,7 @@ def solve(x1: array, x2: array, /) -> array:
6868
def svd(x: array, /, *, full_matrices: bool = True) -> Tuple[array, array, array]:
6969
pass
7070

71-
def svdvals(x: array, /) -> Union[array, Tuple[array, ...]]:
71+
def svdvals(x: array, /) -> array:
7272
pass
7373

7474
def tensordot(x1: array, x2: array, /, *, axes: Union[int, Tuple[Sequence[int], Sequence[int]]] = 2) -> array:
@@ -77,10 +77,10 @@ def tensordot(x1: array, x2: array, /, *, axes: Union[int, Tuple[Sequence[int],
7777
def trace(x: array, /, *, offset: int = 0) -> array:
7878
pass
7979

80-
def vecdot(x1: array, x2: array, /, *, axis: Optional[int] = None) -> array:
80+
def vecdot(x1: array, x2: array, /, *, axis: int = None) -> array:
8181
pass
8282

83-
def vector_norm(x: array, /, *, axis: Optional[Union[int, Tuple[int, int]]] = None, keepdims: bool = False, ord: Optional[Union[int, float, Literal[inf, -inf]]] = 2) -> array:
83+
def vector_norm(x: array, /, *, axis: Optional[Union[int, Tuple[int, ...]]] = None, keepdims: bool = False, ord: Union[int, float, Literal[inf, -inf]] = 2) -> array:
8484
pass
8585

8686
__all__ = ['cholesky', 'cross', 'det', 'diagonal', 'eigh', 'eigvalsh', 'inv', 'matmul', 'matrix_norm', 'matrix_power', 'matrix_rank', 'matrix_transpose', 'outer', 'pinv', 'qr', 'slogdet', 'solve', 'svd', 'svdvals', 'tensordot', 'trace', 'vecdot', 'vector_norm']

array_api_tests/function_stubs/linear_algebra_functions.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010

1111
from __future__ import annotations
1212

13-
from ._types import Optional, Tuple, Union, array
13+
from ._types import Tuple, Union, array
1414
from collections.abc import Sequence
1515

1616
def matmul(x1: array, x2: array, /) -> array:
@@ -22,7 +22,7 @@ def matrix_transpose(x: array, /) -> array:
2222
def tensordot(x1: array, x2: array, /, *, axes: Union[int, Tuple[Sequence[int], Sequence[int]]] = 2) -> array:
2323
pass
2424

25-
def vecdot(x1: array, x2: array, /, *, axis: Optional[int] = None) -> array:
25+
def vecdot(x1: array, x2: array, /, *, axis: int = -1) -> array:
2626
pass
2727

2828
__all__ = ['matmul', 'matrix_transpose', 'tensordot', 'vecdot']

0 commit comments

Comments
 (0)