Skip to content

Commit 9442237

Browse files
authored
Merge pull request #235 from crusaderky/dask_asarray
BUG: dask: `asarray` should not materialize the graph
2 parents adbb6ef + c94ec0b commit 9442237

File tree

4 files changed

+142
-27
lines changed

4 files changed

+142
-27
lines changed

array_api_compat/dask/array/_aliases.py

Lines changed: 16 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -144,24 +144,23 @@ def asarray(
144144
See the corresponding documentation in the array library and/or the array API
145145
specification for more details.
146146
"""
147+
if isinstance(obj, da.Array):
148+
if dtype is not None and dtype != obj.dtype:
149+
if copy is False:
150+
raise ValueError("Unable to avoid copy when changing dtype")
151+
obj = obj.astype(dtype)
152+
return obj.copy() if copy else obj
153+
147154
if copy is False:
148-
# copy=False is not yet implemented in dask
149-
raise NotImplementedError("copy=False is not yet implemented")
150-
elif copy is True:
151-
if isinstance(obj, da.Array) and dtype is None:
152-
return obj.copy()
153-
# Go through numpy, since dask copy is no-op by default
154-
obj = np.array(obj, dtype=dtype, copy=True)
155-
return da.array(obj, dtype=dtype)
156-
else:
157-
if not isinstance(obj, da.Array) or dtype is not None and obj.dtype != dtype:
158-
# copy=True to be uniform across dask < 2024.12 and >= 2024.12
159-
# see https://github.com/dask/dask/pull/11524/
160-
obj = np.array(obj, dtype=dtype, copy=True)
161-
return da.from_array(obj)
162-
return obj
163-
164-
return da.asarray(obj, dtype=dtype, **kwargs)
155+
raise NotImplementedError(
156+
"Unable to avoid copy when converting a non-dask object to dask"
157+
)
158+
159+
# copy=None to be uniform across dask < 2024.12 and >= 2024.12
160+
# see https://github.com/dask/dask/pull/11524/
161+
obj = np.array(obj, dtype=dtype, copy=True)
162+
return da.from_array(obj)
163+
165164

166165
from dask.array import (
167166
# Element wise aliases

tests/test_all.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -40,5 +40,7 @@ def test_all(library):
4040
all_names = module.__all__
4141

4242
if set(dir_names) != set(all_names):
43-
assert set(dir_names) - set(all_names) == set(), f"Some dir() names not included in __all__ for {mod_name}"
44-
assert set(all_names) - set(dir_names) == set(), f"Some __all__ names not in dir() for {mod_name}"
43+
extra_dir = set(dir_names) - set(all_names)
44+
extra_all = set(all_names) - set(dir_names)
45+
assert not extra_dir, f"Some dir() names not included in __all__ for {mod_name}: {extra_dir}"
46+
assert not extra_all, f"Some __all__ names not in dir() for {mod_name}: {extra_all}"

tests/test_common.py

Lines changed: 14 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -226,11 +226,17 @@ def test_asarray_copy(library):
226226
all = xp.all if library != 'dask.array' else lambda x: xp.all(x).compute()
227227

228228
if library == 'numpy' and xp.__version__[0] < '2' and not hasattr(xp, '_CopyMode') :
229-
supports_copy_false = False
230-
elif library in ['cupy', 'dask.array']:
231-
supports_copy_false = False
229+
supports_copy_false_other_ns = False
230+
supports_copy_false_same_ns = False
231+
elif library == 'cupy':
232+
supports_copy_false_other_ns = False
233+
supports_copy_false_same_ns = False
234+
elif library == 'dask.array':
235+
supports_copy_false_other_ns = False
236+
supports_copy_false_same_ns = True
232237
else:
233-
supports_copy_false = True
238+
supports_copy_false_other_ns = True
239+
supports_copy_false_same_ns = True
234240

235241
a = asarray([1])
236242
b = asarray(a, copy=True)
@@ -240,7 +246,7 @@ def test_asarray_copy(library):
240246
assert all(a[0] == 0)
241247

242248
a = asarray([1])
243-
if supports_copy_false:
249+
if supports_copy_false_same_ns:
244250
b = asarray(a, copy=False)
245251
assert is_lib_func(b)
246252
a[0] = 0
@@ -249,7 +255,7 @@ def test_asarray_copy(library):
249255
pytest.raises(NotImplementedError, lambda: asarray(a, copy=False))
250256

251257
a = asarray([1])
252-
if supports_copy_false:
258+
if supports_copy_false_same_ns:
253259
pytest.raises(ValueError, lambda: asarray(a, copy=False,
254260
dtype=xp.float64))
255261
else:
@@ -281,7 +287,7 @@ def test_asarray_copy(library):
281287
for obj in [True, 0, 0.0, 0j, [0], [[0]]]:
282288
asarray(obj, copy=True) # No error
283289
asarray(obj, copy=None) # No error
284-
if supports_copy_false:
290+
if supports_copy_false_other_ns:
285291
pytest.raises(ValueError, lambda: asarray(obj, copy=False))
286292
else:
287293
pytest.raises(NotImplementedError, lambda: asarray(obj, copy=False))
@@ -294,7 +300,7 @@ def test_asarray_copy(library):
294300
assert all(b[0] == 1.0)
295301

296302
a = array.array('f', [1.0])
297-
if supports_copy_false:
303+
if supports_copy_false_other_ns:
298304
b = asarray(a, copy=False)
299305
assert is_lib_func(b)
300306
a[0] = 0.0

tests/test_dask.py

Lines changed: 108 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,108 @@
1+
from contextlib import contextmanager
2+
3+
import dask
4+
import numpy as np
5+
import pytest
6+
import dask.array as da
7+
8+
from array_api_compat import array_namespace
9+
10+
11+
@pytest.fixture
12+
def xp():
13+
"""Fixture returning the wrapped dask namespace"""
14+
return array_namespace(da.empty(0))
15+
16+
17+
@contextmanager
18+
def assert_no_compute():
19+
"""
20+
Context manager that raises if at any point inside it anything calls compute()
21+
or persist(), e.g. as it can be triggered implicitly by __bool__, __array__, etc.
22+
"""
23+
def get(dsk, *args, **kwargs):
24+
raise AssertionError("Called compute() or persist()")
25+
26+
with dask.config.set(scheduler=get):
27+
yield
28+
29+
30+
def test_assert_no_compute():
31+
"""Test the assert_no_compute context manager"""
32+
a = da.asarray(True)
33+
with pytest.raises(AssertionError, match="Called compute"):
34+
with assert_no_compute():
35+
bool(a)
36+
37+
# Exiting the context manager restores the original scheduler
38+
assert bool(a) is True
39+
40+
41+
# Test no_compute for functions that use generic _aliases with xp=np
42+
43+
def test_unary_ops_no_compute(xp):
44+
with assert_no_compute():
45+
a = xp.asarray([1.5, -1.5])
46+
xp.ceil(a)
47+
xp.floor(a)
48+
xp.trunc(a)
49+
xp.sign(a)
50+
51+
52+
def test_matmul_tensordot_no_compute(xp):
53+
A = da.ones((4, 4), chunks=2)
54+
B = da.zeros((4, 4), chunks=2)
55+
with assert_no_compute():
56+
xp.matmul(A, B)
57+
xp.tensordot(A, B)
58+
59+
60+
# Test no_compute for functions that are fully bespoke for dask
61+
62+
def test_asarray_no_compute(xp):
63+
with assert_no_compute():
64+
a = xp.arange(10)
65+
xp.asarray(a)
66+
xp.asarray(a, dtype=np.int16)
67+
xp.asarray(a, dtype=a.dtype)
68+
xp.asarray(a, copy=True)
69+
xp.asarray(a, copy=True, dtype=np.int16)
70+
xp.asarray(a, copy=True, dtype=a.dtype)
71+
xp.asarray(a, copy=False)
72+
xp.asarray(a, copy=False, dtype=a.dtype)
73+
74+
75+
@pytest.mark.parametrize("copy", [True, False])
76+
def test_astype_no_compute(xp, copy):
77+
with assert_no_compute():
78+
a = xp.arange(10)
79+
xp.astype(a, np.int16, copy=copy)
80+
xp.astype(a, a.dtype, copy=copy)
81+
82+
83+
def test_clip_no_compute(xp):
84+
with assert_no_compute():
85+
a = xp.arange(10)
86+
xp.clip(a)
87+
xp.clip(a, 1)
88+
xp.clip(a, 1, 8)
89+
90+
91+
def test_generators_are_lazy(xp):
92+
"""
93+
Test that generator functions are fully lazy, e.g. that
94+
da.ones(n) is not implemented as da.asarray(np.ones(n))
95+
"""
96+
size = 100_000_000_000 # 800 GB
97+
chunks = size // 10 # 10x 80 GB chunks
98+
99+
with assert_no_compute():
100+
xp.zeros(size, chunks=chunks)
101+
xp.ones(size, chunks=chunks)
102+
xp.empty(size, chunks=chunks)
103+
xp.full(size, fill_value=123, chunks=chunks)
104+
a = xp.arange(size, chunks=chunks)
105+
xp.zeros_like(a)
106+
xp.ones_like(a)
107+
xp.empty_like(a)
108+
xp.full_like(a, fill_value=123)

0 commit comments

Comments
 (0)