Skip to content

Commit b10273b

Browse files
update code
1 parent 5ae8ec8 commit b10273b

File tree

2 files changed

+69
-20
lines changed

2 files changed

+69
-20
lines changed

array_api_compat/paddle/_aliases.py

Lines changed: 51 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -420,7 +420,9 @@ def _normalize_axes(axis, ndim):
420420
for a in axis:
421421
if a < lower or a > upper:
422422
# Match paddle error message (e.g., from sum())
423-
raise IndexError(f"Dimension out of range (expected to be in range of [{lower}, {upper}], but got {a}")
423+
raise IndexError(
424+
f"Dimension out of range (expected to be in range of [{lower}, {upper}], but got {a}"
425+
)
424426
if a < 0:
425427
a = a + ndim
426428
if a in axes:
@@ -480,7 +482,9 @@ def prod(
480482

481483
# paddle.prod doesn't support multiple axes
482484
if isinstance(axis, tuple):
483-
return _reduce_multiple_axes(paddle.prod, x, axis, keepdim=keepdims, dtype=dtype, **kwargs)
485+
return _reduce_multiple_axes(
486+
paddle.prod, x, axis, keepdim=keepdims, dtype=dtype, **kwargs
487+
)
484488
if axis is None:
485489
# paddle doesn't support keepdims with axis=None
486490
res = paddle.prod(x, dtype=dtype, **kwargs)
@@ -610,7 +614,9 @@ def std(
610614
if isinstance(correction, float):
611615
_correction = int(correction)
612616
if correction != _correction:
613-
raise NotImplementedError("float correction in paddle std() is not yet supported")
617+
raise NotImplementedError(
618+
"float correction in paddle std() is not yet supported"
619+
)
614620
elif isinstance(correction, int):
615621
if correction not in [0, 1]:
616622
raise NotImplementedError("correction only can be 0 or 1")
@@ -648,7 +654,9 @@ def var(
648654
if isinstance(correction, float):
649655
_correction = int(correction)
650656
if correction != _correction:
651-
raise NotImplementedError("float correction in paddle std() is not yet supported")
657+
raise NotImplementedError(
658+
"float correction in paddle std() is not yet supported"
659+
)
652660
elif isinstance(correction, int):
653661
if correction not in [0, 1]:
654662
raise NotImplementedError("correction only can be 0 or 1")
@@ -709,7 +717,9 @@ def permute_dims(x: array, /, axes: Tuple[int, ...]) -> array:
709717

710718
# The axis parameter doesn't work for flip() and roll()
711719
# accept axis=None
712-
def flip(x: array, /, *, axis: Optional[Union[int, Tuple[int, ...]]] = None, **kwargs) -> array:
720+
def flip(
721+
x: array, /, *, axis: Optional[Union[int, Tuple[int, ...]]] = None, **kwargs
722+
) -> array:
713723
if axis is None:
714724
axis = tuple(range(x.ndim))
715725
# paddle.flip doesn't accept dim as an int but the method does
@@ -738,21 +748,27 @@ def where(condition: array, x1: array, x2: array, /) -> array:
738748
return paddle.where(condition, x1, x2)
739749

740750

741-
def empty_like(x: array, /, *, dtype: Optional[Dtype] = None, device: Optional[Device] = None) -> array:
751+
def empty_like(
752+
x: array, /, *, dtype: Optional[Dtype] = None, device: Optional[Device] = None
753+
) -> array:
742754
out = paddle.empty_like(x, dtype=dtype)
743755
if device is not None:
744756
out = out.to(device)
745757
return out
746758

747759

748-
def zeros_like(x: array, /, *, dtype: Optional[Dtype] = None, device: Optional[Device] = None) -> array:
760+
def zeros_like(
761+
x: array, /, *, dtype: Optional[Dtype] = None, device: Optional[Device] = None
762+
) -> array:
749763
out = paddle.zeros_like(x, dtype=dtype)
750764
if device is not None:
751765
out = out.to(device)
752766
return out
753767

754768

755-
def ones_like(x: array, /, *, dtype: Optional[Dtype] = None, device: Optional[Device] = None) -> array:
769+
def ones_like(
770+
x: array, /, *, dtype: Optional[Dtype] = None, device: Optional[Device] = None
771+
) -> array:
756772
out = paddle.ones_like(x, dtype=dtype)
757773
if device is not None:
758774
out = out.to(device)
@@ -774,7 +790,9 @@ def full_like(
774790

775791

776792
# paddle.reshape doesn't have the copy keyword
777-
def reshape(x: array, /, shape: Tuple[int, ...], copy: Optional[bool] = None, **kwargs) -> array:
793+
def reshape(
794+
x: array, /, shape: Tuple[int, ...], copy: Optional[bool] = None, **kwargs
795+
) -> array:
778796
return paddle.reshape(x, shape, **kwargs)
779797

780798

@@ -825,7 +843,9 @@ def linspace(
825843
**kwargs,
826844
) -> array:
827845
if not endpoint:
828-
return paddle.linspace(start, stop, num + 1, dtype=dtype, **kwargs).to(device)[:-1]
846+
return paddle.linspace(start, stop, num + 1, dtype=dtype, **kwargs).to(device)[
847+
:-1
848+
]
829849
return paddle.linspace(start, stop, num, dtype=dtype, **kwargs).to(device)
830850

831851

@@ -890,7 +910,9 @@ def expand_dims(x: array, /, *, axis: int = 0) -> array:
890910
return paddle.unsqueeze(x, axis)
891911

892912

893-
def astype(x: array, dtype: Dtype, /, *, copy: bool = True, device: Optional[Device] = None) -> array:
913+
def astype(
914+
x: array, dtype: Dtype, /, *, copy: bool = True, device: Optional[Device] = None
915+
) -> array:
894916
# if copy is not None:
895917
# raise NotImplementedError("paddle.astype doesn't yet support the copy keyword")
896918
t = x.to(dtype, device=device)
@@ -1036,7 +1058,7 @@ def sign(x: array, /) -> array:
10361058
else:
10371059
out = paddle.sign(x)
10381060
if paddle.is_floating_point(x):
1039-
out = paddle.where(paddle.isnan(x), paddle.nan, out)
1061+
out = paddle.where(paddle.isnan(x), paddle.full(x.shape, paddle.nan), out)
10401062
return out
10411063

10421064

@@ -1083,7 +1105,8 @@ def asarray(
10831105
return obj
10841106
else:
10851107
raise NotImplementedError(
1086-
"asarray(obj, ..., copy=False) is not supported " "for obj do not has '__dlpack__()' method"
1108+
"asarray(obj, ..., copy=False) is not supported "
1109+
"for obj do not has '__dlpack__()' method"
10871110
)
10881111
elif copy is True:
10891112
obj = np.array(obj, copy=True)
@@ -1164,11 +1187,18 @@ def _isscalar(a):
11641187

11651188

11661189
def cumulative_sum(
1167-
x: array, /, *, axis: Optional[int] = None, dtype: Optional[Dtype] = None, include_initial: bool = False
1190+
x: array,
1191+
/,
1192+
*,
1193+
axis: Optional[int] = None,
1194+
dtype: Optional[Dtype] = None,
1195+
include_initial: bool = False,
11681196
) -> array:
11691197
if axis is None:
11701198
if x.ndim > 1:
1171-
raise ValueError("axis must be specified in cumulative_sum for more than one dimension")
1199+
raise ValueError(
1200+
"axis must be specified in cumulative_sum for more than one dimension"
1201+
)
11721202
axis = 0
11731203

11741204
res = paddle.cumsum(x, axis=axis, dtype=dtype)
@@ -1185,7 +1215,12 @@ def cumulative_sum(
11851215

11861216

11871217
def searchsorted(
1188-
x1: array, x2: array, /, *, side: Literal["left", "right"] = "left", sorter: array | None = None
1218+
x1: array,
1219+
x2: array,
1220+
/,
1221+
*,
1222+
side: Literal["left", "right"] = "left",
1223+
sorter: array | None = None,
11891224
) -> array:
11901225
if sorter is None:
11911226
return paddle.searchsorted(x1, x2, right=(side == "right"))

array_api_compat/paddle/_info.py

Lines changed: 18 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -154,8 +154,16 @@ def default_dtypes(self, *, device=None):
154154
# value here because this error doesn't represent a different default
155155
# per-device.
156156
default_floating = paddle.get_default_dtype()
157-
default_complex = "complex64" if default_floating == "float32" else "complex128"
158-
default_integral = "int64"
157+
if default_floating in ["float16", "float32", "float64", "bfloat16"]:
158+
default_floating = getattr(paddle, default_floating)
159+
else:
160+
raise ValueError(f"Unsupported default floating: {default_floating}")
161+
default_complex = (
162+
paddle.complex64
163+
if default_floating == paddle.float32
164+
else paddle.complex128
165+
)
166+
default_integral = paddle.int64
159167
return {
160168
"real floating": default_floating,
161169
"complex floating": default_complex,
@@ -336,8 +344,14 @@ def devices(self):
336344
except ValueError as e:
337345
# The error message is something like:
338346
# ValueError: The device must be a string which is like 'cpu', 'gpu', 'gpu:x', 'xpu', 'xpu:x', 'npu', 'npu:x
339-
devices_names = e.args[0].split("The device must be a string which is like ")[1].split(", ")
340-
devices_names = [name.strip("'") for name in devices_names if ":" not in name]
347+
devices_names = (
348+
e.args[0]
349+
.split("The device must be a string which is like ")[1]
350+
.split(", ")
351+
)
352+
devices_names = [
353+
name.strip("'") for name in devices_names if ":" not in name
354+
]
341355

342356
# Next we need to check for different indices for different devices.
343357
# device(device_name, index=index) doesn't actually check if the

0 commit comments

Comments
 (0)