Skip to content

Commit 7b8555e

Browse files
fix default floating dtype of paddle.assaray
1 parent 8d2425e commit 7b8555e

File tree

1 file changed

+14
-1
lines changed

1 file changed

+14
-1
lines changed

array_api_compat/paddle/_aliases.py

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1027,7 +1027,16 @@ def is_complex(dtype):
10271027
elif kind == "integral":
10281028
return dtype in _int_dtypes
10291029
elif kind == "real floating":
1030-
return paddle.is_floating_point(dtype)
1030+
return dtype in [
1031+
paddle.framework.core.VarDesc.VarType.FP32,
1032+
paddle.framework.core.VarDesc.VarType.FP64,
1033+
paddle.framework.core.VarDesc.VarType.FP16,
1034+
paddle.framework.core.VarDesc.VarType.BF16,
1035+
paddle.framework.core.DataType.FLOAT32,
1036+
paddle.framework.core.DataType.FLOAT64,
1037+
paddle.framework.core.DataType.FLOAT16,
1038+
paddle.framework.core.DataType.BFLOAT16,
1039+
]
10311040
elif kind == "complex floating":
10321041
return is_complex(dtype)
10331042
elif kind == "numeric":
@@ -1109,10 +1118,14 @@ def asarray(
11091118
)
11101119
elif copy is True:
11111120
obj = np.array(obj, copy=True)
1121+
if np.issubdtype(obj.dtype, np.floating):
1122+
obj = obj.astype(paddle.get_default_dtype())
11121123
return paddle.to_tensor(obj, dtype=dtype, place=device)
11131124
else:
11141125
if not paddle.is_tensor(obj) or (dtype is not None and obj.dtype != dtype):
11151126
obj = np.array(obj, copy=False)
1127+
if np.issubdtype(obj.dtype, np.floating):
1128+
obj = obj.astype(paddle.get_default_dtype())
11161129
if dtype != paddle.bool and dtype != "bool":
11171130
obj = paddle.from_dlpack(obj.__dlpack__(), **kwargs).to(dtype)
11181131
else:

0 commit comments

Comments
 (0)