@@ -1027,7 +1027,16 @@ def is_complex(dtype):
1027
1027
elif kind == "integral" :
1028
1028
return dtype in _int_dtypes
1029
1029
elif kind == "real floating" :
1030
- return paddle .is_floating_point (dtype )
1030
+ return dtype in [
1031
+ paddle .framework .core .VarDesc .VarType .FP32 ,
1032
+ paddle .framework .core .VarDesc .VarType .FP64 ,
1033
+ paddle .framework .core .VarDesc .VarType .FP16 ,
1034
+ paddle .framework .core .VarDesc .VarType .BF16 ,
1035
+ paddle .framework .core .DataType .FLOAT32 ,
1036
+ paddle .framework .core .DataType .FLOAT64 ,
1037
+ paddle .framework .core .DataType .FLOAT16 ,
1038
+ paddle .framework .core .DataType .BFLOAT16 ,
1039
+ ]
1031
1040
elif kind == "complex floating" :
1032
1041
return is_complex (dtype )
1033
1042
elif kind == "numeric" :
@@ -1109,10 +1118,14 @@ def asarray(
1109
1118
)
1110
1119
elif copy is True :
1111
1120
obj = np .array (obj , copy = True )
1121
+ if np .issubdtype (obj .dtype , np .floating ):
1122
+ obj = obj .astype (paddle .get_default_dtype ())
1112
1123
return paddle .to_tensor (obj , dtype = dtype , place = device )
1113
1124
else :
1114
1125
if not paddle .is_tensor (obj ) or (dtype is not None and obj .dtype != dtype ):
1115
1126
obj = np .array (obj , copy = False )
1127
+ if np .issubdtype (obj .dtype , np .floating ):
1128
+ obj = obj .astype (paddle .get_default_dtype ())
1116
1129
if dtype != paddle .bool and dtype != "bool" :
1117
1130
obj = paddle .from_dlpack (obj .__dlpack__ (), ** kwargs ).to (dtype )
1118
1131
else :
0 commit comments