@@ -420,7 +420,9 @@ def _normalize_axes(axis, ndim):
420
420
for a in axis :
421
421
if a < lower or a > upper :
422
422
# Match paddle error message (e.g., from sum())
423
- raise IndexError (f"Dimension out of range (expected to be in range of [{ lower } , { upper } ], but got { a } " )
423
+ raise IndexError (
424
+ f"Dimension out of range (expected to be in range of [{ lower } , { upper } ], but got { a } "
425
+ )
424
426
if a < 0 :
425
427
a = a + ndim
426
428
if a in axes :
@@ -480,7 +482,9 @@ def prod(
480
482
481
483
# paddle.prod doesn't support multiple axes
482
484
if isinstance (axis , tuple ):
483
- return _reduce_multiple_axes (paddle .prod , x , axis , keepdim = keepdims , dtype = dtype , ** kwargs )
485
+ return _reduce_multiple_axes (
486
+ paddle .prod , x , axis , keepdim = keepdims , dtype = dtype , ** kwargs
487
+ )
484
488
if axis is None :
485
489
# paddle doesn't support keepdims with axis=None
486
490
res = paddle .prod (x , dtype = dtype , ** kwargs )
@@ -610,7 +614,9 @@ def std(
610
614
if isinstance (correction , float ):
611
615
_correction = int (correction )
612
616
if correction != _correction :
613
- raise NotImplementedError ("float correction in paddle std() is not yet supported" )
617
+ raise NotImplementedError (
618
+ "float correction in paddle std() is not yet supported"
619
+ )
614
620
elif isinstance (correction , int ):
615
621
if correction not in [0 , 1 ]:
616
622
raise NotImplementedError ("correction only can be 0 or 1" )
@@ -648,7 +654,9 @@ def var(
648
654
if isinstance (correction , float ):
649
655
_correction = int (correction )
650
656
if correction != _correction :
651
- raise NotImplementedError ("float correction in paddle std() is not yet supported" )
657
+ raise NotImplementedError (
658
+ "float correction in paddle std() is not yet supported"
659
+ )
652
660
elif isinstance (correction , int ):
653
661
if correction not in [0 , 1 ]:
654
662
raise NotImplementedError ("correction only can be 0 or 1" )
@@ -709,7 +717,9 @@ def permute_dims(x: array, /, axes: Tuple[int, ...]) -> array:
709
717
710
718
# The axis parameter doesn't work for flip() and roll()
711
719
# accept axis=None
712
- def flip (x : array , / , * , axis : Optional [Union [int , Tuple [int , ...]]] = None , ** kwargs ) -> array :
720
+ def flip (
721
+ x : array , / , * , axis : Optional [Union [int , Tuple [int , ...]]] = None , ** kwargs
722
+ ) -> array :
713
723
if axis is None :
714
724
axis = tuple (range (x .ndim ))
715
725
# paddle.flip doesn't accept dim as an int but the method does
@@ -738,21 +748,27 @@ def where(condition: array, x1: array, x2: array, /) -> array:
738
748
return paddle .where (condition , x1 , x2 )
739
749
740
750
741
- def empty_like (x : array , / , * , dtype : Optional [Dtype ] = None , device : Optional [Device ] = None ) -> array :
751
+ def empty_like (
752
+ x : array , / , * , dtype : Optional [Dtype ] = None , device : Optional [Device ] = None
753
+ ) -> array :
742
754
out = paddle .empty_like (x , dtype = dtype )
743
755
if device is not None :
744
756
out = out .to (device )
745
757
return out
746
758
747
759
748
- def zeros_like (x : array , / , * , dtype : Optional [Dtype ] = None , device : Optional [Device ] = None ) -> array :
760
+ def zeros_like (
761
+ x : array , / , * , dtype : Optional [Dtype ] = None , device : Optional [Device ] = None
762
+ ) -> array :
749
763
out = paddle .zeros_like (x , dtype = dtype )
750
764
if device is not None :
751
765
out = out .to (device )
752
766
return out
753
767
754
768
755
- def ones_like (x : array , / , * , dtype : Optional [Dtype ] = None , device : Optional [Device ] = None ) -> array :
769
+ def ones_like (
770
+ x : array , / , * , dtype : Optional [Dtype ] = None , device : Optional [Device ] = None
771
+ ) -> array :
756
772
out = paddle .ones_like (x , dtype = dtype )
757
773
if device is not None :
758
774
out = out .to (device )
@@ -774,7 +790,9 @@ def full_like(
774
790
775
791
776
792
# paddle.reshape doesn't have the copy keyword
777
- def reshape (x : array , / , shape : Tuple [int , ...], copy : Optional [bool ] = None , ** kwargs ) -> array :
793
+ def reshape (
794
+ x : array , / , shape : Tuple [int , ...], copy : Optional [bool ] = None , ** kwargs
795
+ ) -> array :
778
796
return paddle .reshape (x , shape , ** kwargs )
779
797
780
798
@@ -825,7 +843,9 @@ def linspace(
825
843
** kwargs ,
826
844
) -> array :
827
845
if not endpoint :
828
- return paddle .linspace (start , stop , num + 1 , dtype = dtype , ** kwargs ).to (device )[:- 1 ]
846
+ return paddle .linspace (start , stop , num + 1 , dtype = dtype , ** kwargs ).to (device )[
847
+ :- 1
848
+ ]
829
849
return paddle .linspace (start , stop , num , dtype = dtype , ** kwargs ).to (device )
830
850
831
851
@@ -890,7 +910,9 @@ def expand_dims(x: array, /, *, axis: int = 0) -> array:
890
910
return paddle .unsqueeze (x , axis )
891
911
892
912
893
- def astype (x : array , dtype : Dtype , / , * , copy : bool = True , device : Optional [Device ] = None ) -> array :
913
+ def astype (
914
+ x : array , dtype : Dtype , / , * , copy : bool = True , device : Optional [Device ] = None
915
+ ) -> array :
894
916
# if copy is not None:
895
917
# raise NotImplementedError("paddle.astype doesn't yet support the copy keyword")
896
918
t = x .to (dtype , device = device )
@@ -1036,7 +1058,7 @@ def sign(x: array, /) -> array:
1036
1058
else :
1037
1059
out = paddle .sign (x )
1038
1060
if paddle .is_floating_point (x ):
1039
- out = paddle .where (paddle .isnan (x ), paddle .nan , out )
1061
+ out = paddle .where (paddle .isnan (x ), paddle .full ( x . shape , paddle . nan ) , out )
1040
1062
return out
1041
1063
1042
1064
@@ -1083,7 +1105,8 @@ def asarray(
1083
1105
return obj
1084
1106
else :
1085
1107
raise NotImplementedError (
1086
- "asarray(obj, ..., copy=False) is not supported " "for obj do not has '__dlpack__()' method"
1108
+ "asarray(obj, ..., copy=False) is not supported "
1109
+ "for obj do not has '__dlpack__()' method"
1087
1110
)
1088
1111
elif copy is True :
1089
1112
obj = np .array (obj , copy = True )
@@ -1164,11 +1187,18 @@ def _isscalar(a):
1164
1187
1165
1188
1166
1189
def cumulative_sum (
1167
- x : array , / , * , axis : Optional [int ] = None , dtype : Optional [Dtype ] = None , include_initial : bool = False
1190
+ x : array ,
1191
+ / ,
1192
+ * ,
1193
+ axis : Optional [int ] = None ,
1194
+ dtype : Optional [Dtype ] = None ,
1195
+ include_initial : bool = False ,
1168
1196
) -> array :
1169
1197
if axis is None :
1170
1198
if x .ndim > 1 :
1171
- raise ValueError ("axis must be specified in cumulative_sum for more than one dimension" )
1199
+ raise ValueError (
1200
+ "axis must be specified in cumulative_sum for more than one dimension"
1201
+ )
1172
1202
axis = 0
1173
1203
1174
1204
res = paddle .cumsum (x , axis = axis , dtype = dtype )
@@ -1185,7 +1215,12 @@ def cumulative_sum(
1185
1215
1186
1216
1187
1217
def searchsorted (
1188
- x1 : array , x2 : array , / , * , side : Literal ["left" , "right" ] = "left" , sorter : array | None = None
1218
+ x1 : array ,
1219
+ x2 : array ,
1220
+ / ,
1221
+ * ,
1222
+ side : Literal ["left" , "right" ] = "left" ,
1223
+ sorter : array | None = None ,
1189
1224
) -> array :
1190
1225
if sorter is None :
1191
1226
return paddle .searchsorted (x1 , x2 , right = (side == "right" ))
0 commit comments