-
Notifications
You must be signed in to change notification settings - Fork 35
MAINT: dask: cosmetic tweaks #236
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
Merged
Changes from all commits
Commits
File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,7 +1,6 @@ | ||
from __future__ import annotations | ||
|
||
from ...common import _aliases | ||
from ...common._helpers import _check_device | ||
|
||
from ..._internal import get_xp | ||
|
||
|
@@ -40,19 +39,25 @@ | |
isdtype = get_xp(np)(_aliases.isdtype) | ||
unstack = get_xp(da)(_aliases.unstack) | ||
|
||
# da.astype doesn't respect copy=True | ||
def astype( | ||
x: Array, | ||
dtype: Dtype, | ||
/, | ||
*, | ||
copy: bool = True, | ||
device: Device | None = None | ||
device: Optional[Device] = None | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. this change only matters if someone runs mypy/pyright in python 3.9. Runtime is unaffected due to |
||
) -> Array: | ||
""" | ||
Array API compatibility wrapper for astype(). | ||
|
||
See the corresponding documentation in the array library and/or the array API | ||
specification for more details. | ||
""" | ||
# TODO: respect device keyword? | ||
|
||
if not copy and dtype == x.dtype: | ||
return x | ||
# dask astype doesn't respect copy=True, | ||
# so call copy manually afterwards | ||
x = x.astype(dtype) | ||
return x.copy() if copy else x | ||
|
||
|
@@ -61,20 +66,24 @@ def astype( | |
# This arange func is modified from the common one to | ||
# not pass stop/step as keyword arguments, which will cause | ||
# an error with dask | ||
|
||
# TODO: delete the xp stuff, it shouldn't be necessary | ||
def _dask_arange( | ||
def arange( | ||
start: Union[int, float], | ||
/, | ||
stop: Optional[Union[int, float]] = None, | ||
step: Union[int, float] = 1, | ||
*, | ||
xp, | ||
dtype: Optional[Dtype] = None, | ||
device: Optional[Device] = None, | ||
**kwargs, | ||
) -> Array: | ||
_check_device(xp, device) | ||
""" | ||
Array API compatibility wrapper for arange(). | ||
|
||
See the corresponding documentation in the array library and/or the array API | ||
specification for more details. | ||
""" | ||
# TODO: respect device keyword? | ||
ev-br marked this conversation as resolved.
Show resolved
Hide resolved
|
||
|
||
args = [start] | ||
if stop is not None: | ||
args.append(stop) | ||
|
@@ -83,13 +92,12 @@ def _dask_arange( | |
# prepend the default value for start which is 0 | ||
args.insert(0, 0) | ||
args.append(step) | ||
return xp.arange(*args, dtype=dtype, **kwargs) | ||
|
||
arange = get_xp(da)(_dask_arange) | ||
eye = get_xp(da)(_aliases.eye) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. duplicate |
||
return da.arange(*args, dtype=dtype, **kwargs) | ||
|
||
|
||
linspace = get_xp(da)(_aliases.linspace) | ||
eye = get_xp(da)(_aliases.eye) | ||
linspace = get_xp(da)(_aliases.linspace) | ||
UniqueAllResult = get_xp(da)(_aliases.UniqueAllResult) | ||
UniqueCountsResult = get_xp(da)(_aliases.UniqueCountsResult) | ||
UniqueInverseResult = get_xp(da)(_aliases.UniqueInverseResult) | ||
|
@@ -112,7 +120,6 @@ def _dask_arange( | |
reshape = get_xp(da)(_aliases.reshape) | ||
matrix_transpose = get_xp(da)(_aliases.matrix_transpose) | ||
vecdot = get_xp(da)(_aliases.vecdot) | ||
|
||
nonzero = get_xp(da)(_aliases.nonzero) | ||
ceil = get_xp(np)(_aliases.ceil) | ||
floor = get_xp(np)(_aliases.floor) | ||
|
@@ -121,6 +128,7 @@ def _dask_arange( | |
tensordot = get_xp(np)(_aliases.tensordot) | ||
sign = get_xp(np)(_aliases.sign) | ||
|
||
|
||
# asarray also adds the copy keyword, which is not present in numpy 1.0. | ||
def asarray( | ||
obj: Union[ | ||
|
@@ -135,7 +143,7 @@ def asarray( | |
*, | ||
dtype: Optional[Dtype] = None, | ||
device: Optional[Device] = None, | ||
copy: "Optional[Union[bool, np._CopyMode]]" = None, | ||
copy: Optional[Union[bool, np._CopyMode]] = None, | ||
**kwargs, | ||
) -> Array: | ||
""" | ||
|
@@ -144,6 +152,8 @@ def asarray( | |
See the corresponding documentation in the array library and/or the array API | ||
specification for more details. | ||
""" | ||
# TODO: respect device keyword? | ||
|
||
if isinstance(obj, da.Array): | ||
if dtype is not None and dtype != obj.dtype: | ||
if copy is False: | ||
|
@@ -183,38 +193,40 @@ def asarray( | |
# Furthermore, the masking workaround in common._aliases.clip cannot work with | ||
# dask (meaning uint64 promoting to float64 is going to just be unfixed for | ||
# now). | ||
@get_xp(da) | ||
def clip( | ||
x: Array, | ||
/, | ||
min: Optional[Union[int, float, Array]] = None, | ||
max: Optional[Union[int, float, Array]] = None, | ||
*, | ||
xp, | ||
) -> Array: | ||
""" | ||
Array API compatibility wrapper for clip(). | ||
|
||
See the corresponding documentation in the array library and/or the array API | ||
specification for more details. | ||
""" | ||
def _isscalar(a): | ||
return isinstance(a, (int, float, type(None))) | ||
min_shape = () if _isscalar(min) else min.shape | ||
max_shape = () if _isscalar(max) else max.shape | ||
|
||
# TODO: This won't handle dask unknown shapes | ||
import numpy as np | ||
result_shape = np.broadcast_shapes(x.shape, min_shape, max_shape) | ||
|
||
if min is not None: | ||
min = xp.broadcast_to(xp.asarray(min), result_shape) | ||
min = da.broadcast_to(da.asarray(min), result_shape) | ||
if max is not None: | ||
max = xp.broadcast_to(xp.asarray(max), result_shape) | ||
max = da.broadcast_to(da.asarray(max), result_shape) | ||
|
||
if min is None and max is None: | ||
return xp.positive(x) | ||
return da.positive(x) | ||
|
||
if min is None: | ||
return astype(xp.minimum(x, max), x.dtype) | ||
return astype(da.minimum(x, max), x.dtype) | ||
if max is None: | ||
return astype(xp.maximum(x, min), x.dtype) | ||
return astype(da.maximum(x, min), x.dtype) | ||
|
||
return astype(xp.minimum(xp.maximum(x, min), max), x.dtype) | ||
return astype(da.minimum(da.maximum(x, min), max), x.dtype) | ||
|
||
# exclude these from all since dask.array has no sorting functions | ||
_da_unsupported = ['sort', 'argsort'] | ||
|
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Does nothing for dask