Skip to content

Commit 4be5517

Browse files
committed
fix more tests
1 parent 6841758 commit 4be5517

File tree

2 files changed

+22
-20
lines changed

2 files changed

+22
-20
lines changed

array_api_compat/dask/array/_aliases.py

Lines changed: 22 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -103,6 +103,15 @@ def dask_arange(
103103
matrix_transpose = get_xp(da)(_aliases.matrix_transpose)
104104
vecdot = get_xp(da)(_aliases.vecdot)
105105

106+
nonzero = get_xp(da)(_aliases.nonzero)
107+
sum = get_xp(np)(_aliases.sum)
108+
prod = get_xp(np)(_aliases.prod)
109+
ceil = get_xp(np)(_aliases.ceil)
110+
floor = get_xp(np)(_aliases.floor)
111+
trunc = get_xp(np)(_aliases.trunc)
112+
matmul = get_xp(np)(_aliases.matmul)
113+
tensordot = get_xp(np)(_aliases.tensordot)
114+
106115
from dask.array import (
107116
# Element wise aliases
108117
arccos as acos,
@@ -120,9 +129,17 @@ def dask_arange(
120129
concatenate as concat,
121130
)
122131

123-
del da, partial
132+
# exclude these from all since
133+
_da_unsupported = ['sort', 'argsort']
134+
135+
common_aliases = [alias for alias in _aliases.__all__ if alias not in _da_unsupported]
136+
137+
__all__ = common_aliases + ['asarray', 'bool', 'acos',
138+
'acosh', 'asin', 'asinh', 'atan', 'atan2',
139+
'atanh', 'bitwise_left_shift', 'bitwise_invert',
140+
'bitwise_right_shift', 'concat', 'pow',
141+
'e', 'inf', 'nan', 'pi', 'newaxis', 'float32', 'float64', 'int8',
142+
'int16', 'int32', 'int64', 'uint8', 'uint16', 'uint32', 'uint64',
143+
'complex64', 'complex128', 'iinfo', 'finfo', 'can_cast', 'result_type']
124144

125-
__all__ = _aliases.__all__ + ['asarray', 'bool', 'acos',
126-
'acosh', 'asin', 'asinh', 'atan', 'atan2',
127-
'atanh', 'bitwise_left_shift', 'bitwise_invert',
128-
'bitwise_right_shift', 'concat', 'pow']
145+
del da, partial, common_aliases, _da_unsupported,

dask-xfails.txt

Lines changed: 0 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -59,16 +59,6 @@ array_api_tests/test_has_names.py::test_has_names[array_method-to_device]
5959
array_api_tests/test_has_names.py::test_has_names[array_attribute-device]
6060
array_api_tests/test_has_names.py::test_has_names[array_attribute-mT]
6161

62-
# dask doesn't return int when input is already int for ceil/floor/trunc
63-
# Use $ to denote end of regex so we don't xfail other tests accidentally
64-
array_api_tests/test_operators_and_elementwise_functions.py::test_ceil
65-
# TODO: this xfails more than it should ... (e.g. test_floor_divide works)
66-
array_api_tests/test_operators_and_elementwise_functions.py::test_floor
67-
array_api_tests/test_operators_and_elementwise_functions.py::test_trunc
68-
69-
# Dask doesn't raise an error for this test
70-
array_api_tests/test_searching_functions.py::test_nonzero_zerodim_error
71-
7262
# Fails because shape is NaN since we don't materialize it yet
7363
array_api_tests/test_searching_functions.py::test_nonzero
7464
array_api_tests/test_set_functions.py::test_unique_all
@@ -80,7 +70,6 @@ array_api_tests/test_set_functions.py::test_unique_values
8070

8171
# Linalg failures (signature failures/missing methods)
8272

83-
8473
# fails for ndim > 2
8574
array_api_tests/test_linalg.py::test_svdvals
8675
array_api_tests/test_linalg.py::test_cholesky
@@ -139,9 +128,5 @@ array_api_tests/test_has_names.py::test_has_names[array_method-__dlpack_device__
139128
# Some cases unsupported by dask
140129
array_api_tests/test_manipulation_functions.py::test_roll
141130

142-
# Dtype doesn't match (output is float32 but should be float64)
143-
array_api_tests/test_statistical_functions.py::test_prod
144-
array_api_tests/test_statistical_functions.py::test_sum
145-
146131
# No mT on dask array
147132
array_api_tests/meta/test_hypothesis_helpers.py::test_symmetric_matrices

0 commit comments

Comments
 (0)