Skip to content

Commit 6841758

Browse files
committed
address more comments
1 parent 565666a commit 6841758

File tree

5 files changed

+27
-11
lines changed

5 files changed

+27
-11
lines changed

array_api_compat/common/_aliases.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -325,17 +325,25 @@ def _asarray(
325325
# copy=False is not yet implemented in xp.asarray
326326
raise NotImplementedError("copy=False is not yet implemented")
327327
if (hasattr(xp, "ndarray") and isinstance(obj, xp.ndarray)) or hasattr(obj, "__array__"):
328+
#print('hit me')
328329
if dtype is not None and obj.dtype != dtype:
329330
copy = True
331+
#print(copy)
330332
if copy in COPY_TRUE:
331333
copy_kwargs = {}
332334
if namespace != "dask.array":
333335
copy_kwargs["copy"] = True
334336
else:
335337
# No copy kw in dask.asarray so we go thorugh np.asarray first
336338
# (like dask also does) but copy after
339+
if dtype is None:
340+
# Same dtype copy is no-op in dask
341+
#print("in here?")
342+
return obj.copy()
337343
import numpy as np
344+
#print(obj)
338345
obj = np.asarray(obj).copy()
346+
#print(obj)
339347
return xp.array(obj, dtype=dtype, **copy_kwargs)
340348
return obj
341349

array_api_compat/common/_linalg.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -77,8 +77,10 @@ def matrix_rank(x: ndarray,
7777
# dimensional arrays.
7878
if x.ndim < 2:
7979
raise xp.linalg.LinAlgError("1-dimensional array given. Array must be at least two-dimensional")
80-
S = xp.linalg.svdvals(x, **kwargs)
81-
#S = xp.linalg.svd(x, compute_uv=False, **kwargs)
80+
if hasattr(xp.linalg, "svdvals"):
81+
S = xp.linalg.svdvals(x, **kwargs)
82+
else:
83+
S = xp.linalg.svd(x, compute_uv=False, **kwargs)
8284
if rtol is None:
8385
tol = S.max(axis=-1, keepdims=True) * max(x.shape[-2:]) * xp.finfo(S.dtype).eps
8486
else:

array_api_compat/dask/array/_aliases.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -120,4 +120,9 @@ def dask_arange(
120120
concatenate as concat,
121121
)
122122

123-
del da
123+
del da, partial
124+
125+
__all__ = _aliases.__all__ + ['asarray', 'bool', 'acos',
126+
'acosh', 'asin', 'asinh', 'atan', 'atan2',
127+
'atanh', 'bitwise_left_shift', 'bitwise_invert',
128+
'bitwise_right_shift', 'concat', 'pow']

array_api_compat/dask/array/linalg.py

Lines changed: 0 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -13,17 +13,13 @@
1313
from typing import Optional, Union, Tuple
1414
from ...common._typing import ndarray, Device, Dtype
1515

16-
#cross = get_xp(da)(_linalg.cross)
17-
#outer = get_xp(da)(_linalg.outer)
1816
EighResult = _linalg.EighResult
1917
QRResult = _linalg.QRResult
2018
SlogdetResult = _linalg.SlogdetResult
2119
SVDResult = _linalg.SVDResult
2220
qr = get_xp(da)(_linalg.qr)
23-
#svd = get_xp(da)(_linalg.svd)
2421
cholesky = get_xp(da)(_linalg.cholesky)
2522
matrix_rank = get_xp(da)(_linalg.matrix_rank)
26-
#pinv = get_xp(da)(_linalg.pinv)
2723
matrix_norm = get_xp(da)(_linalg.matrix_norm)
2824

2925
def svdvals(x: ndarray) -> Union[ndarray, Tuple[ndarray, ...]]:
@@ -34,9 +30,6 @@ def svdvals(x: ndarray) -> Union[ndarray, Tuple[ndarray, ...]]:
3430
vector_norm = get_xp(da)(_linalg.vector_norm)
3531
diagonal = get_xp(da)(_linalg.diagonal)
3632

37-
#__all__ = linalg_all + _linalg.__all__
38-
3933
del get_xp
4034
del da
41-
#del linalg_all
4235
del _linalg

dask-xfails.txt

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,9 +9,17 @@
99
#| Draw 1 (key): (slice(None, None, None), slice(None, None, None))
1010
#| Draw 2 (value): dask.array<zeros_like, shape=(0, 2), dtype=bool, chunksize=(0, 2), chunktype=numpy.ndarray>
1111

12-
# TODO: this also skips test_setitem_masking unnecessarily
12+
# Various shape mismatches e.g.
13+
ValueError: shape mismatch: value array of shape (0, 2) could not be broadcast to indexing result of shape (0, 2)
1314
array_api_tests/test_array_object.py::test_setitem
1415

16+
# Fails since bad upcast from uint8 -> int64
17+
# MRE:
18+
# a = da.array(0, dtype="uint8")
19+
# b = da.array(False)
20+
# a[b] = 0
21+
array_api_tests/test_array_object.py::test_setitem_masking
22+
1523
# Various indexing errors
1624
array_api_tests/test_array_object.py::test_getitem_masking
1725

0 commit comments

Comments
 (0)