Skip to content

MAINT: dask: cosmetic tweaks #236

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Jan 18, 2025
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
62 changes: 37 additions & 25 deletions array_api_compat/dask/array/_aliases.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
from __future__ import annotations

from ...common import _aliases
from ...common._helpers import _check_device
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Does nothing for dask


from ..._internal import get_xp

Expand Down Expand Up @@ -40,19 +39,25 @@
isdtype = get_xp(np)(_aliases.isdtype)
unstack = get_xp(da)(_aliases.unstack)

# da.astype doesn't respect copy=True
def astype(
x: Array,
dtype: Dtype,
/,
*,
copy: bool = True,
device: Device | None = None
device: Optional[Device] = None
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this change only matters if someone runs mypy/pyright in python 3.9. Runtime is unaffected due to from __future__ import annotations.

) -> Array:
"""
Array API compatibility wrapper for astype().

See the corresponding documentation in the array library and/or the array API
specification for more details.
"""
# TODO: respect device keyword?

if not copy and dtype == x.dtype:
return x
# dask astype doesn't respect copy=True,
# so call copy manually afterwards
x = x.astype(dtype)
return x.copy() if copy else x

Expand All @@ -61,20 +66,24 @@ def astype(
# This arange func is modified from the common one to
# not pass stop/step as keyword arguments, which will cause
# an error with dask

# TODO: delete the xp stuff, it shouldn't be necessary
def _dask_arange(
def arange(
start: Union[int, float],
/,
stop: Optional[Union[int, float]] = None,
step: Union[int, float] = 1,
*,
xp,
dtype: Optional[Dtype] = None,
device: Optional[Device] = None,
**kwargs,
) -> Array:
_check_device(xp, device)
"""
Array API compatibility wrapper for arange().

See the corresponding documentation in the array library and/or the array API
specification for more details.
"""
# TODO: respect device keyword?

args = [start]
if stop is not None:
args.append(stop)
Expand All @@ -83,13 +92,12 @@ def _dask_arange(
# prepend the default value for start which is 0
args.insert(0, 0)
args.append(step)
return xp.arange(*args, dtype=dtype, **kwargs)

arange = get_xp(da)(_dask_arange)
eye = get_xp(da)(_aliases.eye)
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

duplicate

return da.arange(*args, dtype=dtype, **kwargs)


linspace = get_xp(da)(_aliases.linspace)
eye = get_xp(da)(_aliases.eye)
linspace = get_xp(da)(_aliases.linspace)
UniqueAllResult = get_xp(da)(_aliases.UniqueAllResult)
UniqueCountsResult = get_xp(da)(_aliases.UniqueCountsResult)
UniqueInverseResult = get_xp(da)(_aliases.UniqueInverseResult)
Expand All @@ -112,7 +120,6 @@ def _dask_arange(
reshape = get_xp(da)(_aliases.reshape)
matrix_transpose = get_xp(da)(_aliases.matrix_transpose)
vecdot = get_xp(da)(_aliases.vecdot)

nonzero = get_xp(da)(_aliases.nonzero)
ceil = get_xp(np)(_aliases.ceil)
floor = get_xp(np)(_aliases.floor)
Expand All @@ -121,6 +128,7 @@ def _dask_arange(
tensordot = get_xp(np)(_aliases.tensordot)
sign = get_xp(np)(_aliases.sign)


# asarray also adds the copy keyword, which is not present in numpy 1.0.
def asarray(
obj: Union[
Expand All @@ -135,7 +143,7 @@ def asarray(
*,
dtype: Optional[Dtype] = None,
device: Optional[Device] = None,
copy: "Optional[Union[bool, np._CopyMode]]" = None,
copy: Optional[Union[bool, np._CopyMode]] = None,
**kwargs,
) -> Array:
"""
Expand All @@ -144,6 +152,8 @@ def asarray(
See the corresponding documentation in the array library and/or the array API
specification for more details.
"""
# TODO: respect device keyword?

if isinstance(obj, da.Array):
if dtype is not None and dtype != obj.dtype:
if copy is False:
Expand Down Expand Up @@ -183,38 +193,40 @@ def asarray(
# Furthermore, the masking workaround in common._aliases.clip cannot work with
# dask (meaning uint64 promoting to float64 is going to just be unfixed for
# now).
@get_xp(da)
def clip(
x: Array,
/,
min: Optional[Union[int, float, Array]] = None,
max: Optional[Union[int, float, Array]] = None,
*,
xp,
) -> Array:
"""
Array API compatibility wrapper for clip().

See the corresponding documentation in the array library and/or the array API
specification for more details.
"""
def _isscalar(a):
return isinstance(a, (int, float, type(None)))
min_shape = () if _isscalar(min) else min.shape
max_shape = () if _isscalar(max) else max.shape

# TODO: This won't handle dask unknown shapes
import numpy as np
result_shape = np.broadcast_shapes(x.shape, min_shape, max_shape)

if min is not None:
min = xp.broadcast_to(xp.asarray(min), result_shape)
min = da.broadcast_to(da.asarray(min), result_shape)
if max is not None:
max = xp.broadcast_to(xp.asarray(max), result_shape)
max = da.broadcast_to(da.asarray(max), result_shape)

if min is None and max is None:
return xp.positive(x)
return da.positive(x)

if min is None:
return astype(xp.minimum(x, max), x.dtype)
return astype(da.minimum(x, max), x.dtype)
if max is None:
return astype(xp.maximum(x, min), x.dtype)
return astype(da.maximum(x, min), x.dtype)

return astype(xp.minimum(xp.maximum(x, min), max), x.dtype)
return astype(da.minimum(da.maximum(x, min), max), x.dtype)

# exclude these from all since dask.array has no sorting functions
_da_unsupported = ['sort', 'argsort']
Expand Down
Loading