Skip to content

Commit be4addb

Browse files
committed
MAINT: dask.array: cosmetic tweaks
1 parent adbb6ef commit be4addb

File tree

1 file changed

+37
-25
lines changed

1 file changed

+37
-25
lines changed

array_api_compat/dask/array/_aliases.py

Lines changed: 37 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,6 @@
11
from __future__ import annotations
22

33
from ...common import _aliases
4-
from ...common._helpers import _check_device
54

65
from ..._internal import get_xp
76

@@ -40,19 +39,25 @@
4039
isdtype = get_xp(np)(_aliases.isdtype)
4140
unstack = get_xp(da)(_aliases.unstack)
4241

42+
# da.astype doesn't respect copy=True
4343
def astype(
4444
x: Array,
4545
dtype: Dtype,
4646
/,
4747
*,
4848
copy: bool = True,
49-
device: Device | None = None
49+
device: Optional[Device] = None
5050
) -> Array:
51+
"""
52+
Array API compatibility wrapper for astype().
53+
54+
See the corresponding documentation in the array library and/or the array API
55+
specification for more details.
56+
"""
5157
# TODO: respect device keyword?
58+
5259
if not copy and dtype == x.dtype:
5360
return x
54-
# dask astype doesn't respect copy=True,
55-
# so call copy manually afterwards
5661
x = x.astype(dtype)
5762
return x.copy() if copy else x
5863

@@ -61,20 +66,24 @@ def astype(
6166
# This arange func is modified from the common one to
6267
# not pass stop/step as keyword arguments, which will cause
6368
# an error with dask
64-
65-
# TODO: delete the xp stuff, it shouldn't be necessary
66-
def _dask_arange(
69+
def arange(
6770
start: Union[int, float],
6871
/,
6972
stop: Optional[Union[int, float]] = None,
7073
step: Union[int, float] = 1,
7174
*,
72-
xp,
7375
dtype: Optional[Dtype] = None,
7476
device: Optional[Device] = None,
7577
**kwargs,
7678
) -> Array:
77-
_check_device(xp, device)
79+
"""
80+
Array API compatibility wrapper for arange().
81+
82+
See the corresponding documentation in the array library and/or the array API
83+
specification for more details.
84+
"""
85+
# TODO: respect device keyword?
86+
7887
args = [start]
7988
if stop is not None:
8089
args.append(stop)
@@ -83,13 +92,12 @@ def _dask_arange(
8392
# prepend the default value for start which is 0
8493
args.insert(0, 0)
8594
args.append(step)
86-
return xp.arange(*args, dtype=dtype, **kwargs)
8795

88-
arange = get_xp(da)(_dask_arange)
89-
eye = get_xp(da)(_aliases.eye)
96+
return da.arange(*args, dtype=dtype, **kwargs)
97+
9098

91-
linspace = get_xp(da)(_aliases.linspace)
9299
eye = get_xp(da)(_aliases.eye)
100+
linspace = get_xp(da)(_aliases.linspace)
93101
UniqueAllResult = get_xp(da)(_aliases.UniqueAllResult)
94102
UniqueCountsResult = get_xp(da)(_aliases.UniqueCountsResult)
95103
UniqueInverseResult = get_xp(da)(_aliases.UniqueInverseResult)
@@ -112,7 +120,6 @@ def _dask_arange(
112120
reshape = get_xp(da)(_aliases.reshape)
113121
matrix_transpose = get_xp(da)(_aliases.matrix_transpose)
114122
vecdot = get_xp(da)(_aliases.vecdot)
115-
116123
nonzero = get_xp(da)(_aliases.nonzero)
117124
ceil = get_xp(np)(_aliases.ceil)
118125
floor = get_xp(np)(_aliases.floor)
@@ -121,6 +128,7 @@ def _dask_arange(
121128
tensordot = get_xp(np)(_aliases.tensordot)
122129
sign = get_xp(np)(_aliases.sign)
123130

131+
124132
# asarray also adds the copy keyword, which is not present in numpy 1.0.
125133
def asarray(
126134
obj: Union[
@@ -135,7 +143,7 @@ def asarray(
135143
*,
136144
dtype: Optional[Dtype] = None,
137145
device: Optional[Device] = None,
138-
copy: "Optional[Union[bool, np._CopyMode]]" = None,
146+
copy: Optional[Union[bool, np._CopyMode]] = None,
139147
**kwargs,
140148
) -> Array:
141149
"""
@@ -144,6 +152,8 @@ def asarray(
144152
See the corresponding documentation in the array library and/or the array API
145153
specification for more details.
146154
"""
155+
# TODO: respect device keyword?
156+
147157
if copy is False:
148158
# copy=False is not yet implemented in dask
149159
raise NotImplementedError("copy=False is not yet implemented")
@@ -184,38 +194,40 @@ def asarray(
184194
# Furthermore, the masking workaround in common._aliases.clip cannot work with
185195
# dask (meaning uint64 promoting to float64 is going to just be unfixed for
186196
# now).
187-
@get_xp(da)
188197
def clip(
189198
x: Array,
190199
/,
191200
min: Optional[Union[int, float, Array]] = None,
192201
max: Optional[Union[int, float, Array]] = None,
193-
*,
194-
xp,
195202
) -> Array:
203+
"""
204+
Array API compatibility wrapper for clip().
205+
206+
See the corresponding documentation in the array library and/or the array API
207+
specification for more details.
208+
"""
196209
def _isscalar(a):
197210
return isinstance(a, (int, float, type(None)))
198211
min_shape = () if _isscalar(min) else min.shape
199212
max_shape = () if _isscalar(max) else max.shape
200213

201214
# TODO: This won't handle dask unknown shapes
202-
import numpy as np
203215
result_shape = np.broadcast_shapes(x.shape, min_shape, max_shape)
204216

205217
if min is not None:
206-
min = xp.broadcast_to(xp.asarray(min), result_shape)
218+
min = da.broadcast_to(da.asarray(min), result_shape)
207219
if max is not None:
208-
max = xp.broadcast_to(xp.asarray(max), result_shape)
220+
max = da.broadcast_to(da.asarray(max), result_shape)
209221

210222
if min is None and max is None:
211-
return xp.positive(x)
223+
return da.positive(x)
212224

213225
if min is None:
214-
return astype(xp.minimum(x, max), x.dtype)
226+
return astype(da.minimum(x, max), x.dtype)
215227
if max is None:
216-
return astype(xp.maximum(x, min), x.dtype)
228+
return astype(da.maximum(x, min), x.dtype)
217229

218-
return astype(xp.minimum(xp.maximum(x, min), max), x.dtype)
230+
return astype(da.minimum(da.maximum(x, min), max), x.dtype)
219231

220232
# exclude these from all since dask.array has no sorting functions
221233
_da_unsupported = ['sort', 'argsort']

0 commit comments

Comments
 (0)