@@ -88,7 +88,7 @@ def test_permute_dims_2d_3d(shapes):
88
88
def test_expand_dims_incorrect_type ():
89
89
X_list = [1 , 2 , 3 , 4 , 5 ]
90
90
with pytest .raises (TypeError ):
91
- dpt .permute_dims (X_list , 1 )
91
+ dpt .permute_dims (X_list , axis = 1 )
92
92
93
93
94
94
def test_expand_dims_0d ():
@@ -97,16 +97,16 @@ def test_expand_dims_0d():
97
97
Xnp = np .array (1 , dtype = "int64" )
98
98
X = dpt .asarray (Xnp , sycl_queue = q )
99
99
100
- Y = dpt .expand_dims (X , 0 )
101
- Ynp = np .expand_dims (Xnp , 0 )
100
+ Y = dpt .expand_dims (X , axis = 0 )
101
+ Ynp = np .expand_dims (Xnp , axis = 0 )
102
102
assert_array_equal (Ynp , dpt .asnumpy (Y ))
103
103
104
- Y = dpt .expand_dims (X , - 1 )
105
- Ynp = np .expand_dims (Xnp , - 1 )
104
+ Y = dpt .expand_dims (X , axis = - 1 )
105
+ Ynp = np .expand_dims (Xnp , axis = - 1 )
106
106
assert_array_equal (Ynp , dpt .asnumpy (Y ))
107
107
108
- pytest .raises (np .AxisError , dpt .expand_dims , X , 1 )
109
- pytest .raises (np .AxisError , dpt .expand_dims , X , - 2 )
108
+ pytest .raises (np .AxisError , dpt .expand_dims , X , axis = 1 )
109
+ pytest .raises (np .AxisError , dpt .expand_dims , X , axis = - 2 )
110
110
111
111
112
112
@pytest .mark .parametrize ("shapes" , [(3 ,), (3 , 3 ), (3 , 3 , 3 )])
@@ -119,12 +119,12 @@ def test_expand_dims_1d_3d(shapes):
119
119
X = dpt .asarray (Xnp , sycl_queue = q )
120
120
shape_len = len (shapes )
121
121
for axis in range (- shape_len - 1 , shape_len ):
122
- Y = dpt .expand_dims (X , axis )
123
- Ynp = np .expand_dims (Xnp , axis )
122
+ Y = dpt .expand_dims (X , axis = axis )
123
+ Ynp = np .expand_dims (Xnp , axis = axis )
124
124
assert_array_equal (Ynp , dpt .asnumpy (Y ))
125
125
126
- pytest .raises (np .AxisError , dpt .expand_dims , X , shape_len + 1 )
127
- pytest .raises (np .AxisError , dpt .expand_dims , X , - shape_len - 2 )
126
+ pytest .raises (np .AxisError , dpt .expand_dims , X , axis = shape_len + 1 )
127
+ pytest .raises (np .AxisError , dpt .expand_dims , X , axis = - shape_len - 2 )
128
128
129
129
130
130
@pytest .mark .parametrize (
@@ -135,8 +135,8 @@ def test_expand_dims_tuple(axes):
135
135
136
136
Xnp = np .empty ((3 , 3 , 3 ), dtype = "u1" )
137
137
X = dpt .asarray (Xnp , sycl_queue = q )
138
- Y = dpt .expand_dims (X , axes )
139
- Ynp = np .expand_dims (Xnp , axes )
138
+ Y = dpt .expand_dims (X , axis = axes )
139
+ Ynp = np .expand_dims (Xnp , axis = axes )
140
140
assert_array_equal (Ynp , dpt .asnumpy (Y ))
141
141
142
142
@@ -146,12 +146,12 @@ def test_expand_dims_incorrect_tuple():
146
146
except dpctl .SyclDeviceCreationError :
147
147
pytest .skip ("No SYCL devices available" )
148
148
with pytest .raises (np .AxisError ):
149
- dpt .expand_dims (X , (0 , - 6 ))
149
+ dpt .expand_dims (X , axis = (0 , - 6 ))
150
150
with pytest .raises (np .AxisError ):
151
- dpt .expand_dims (X , (0 , 5 ))
151
+ dpt .expand_dims (X , axis = (0 , 5 ))
152
152
153
153
with pytest .raises (ValueError ):
154
- dpt .expand_dims (X , (1 , 1 ))
154
+ dpt .expand_dims (X , axis = (1 , 1 ))
155
155
156
156
157
157
def test_squeeze_incorrect_type ():
@@ -456,9 +456,9 @@ def test_flip_0d():
456
456
Y = dpt .flip (X )
457
457
assert_array_equal (Ynp , dpt .asnumpy (Y ))
458
458
459
- pytest .raises (np .AxisError , dpt .flip , X , 0 )
460
- pytest .raises (np .AxisError , dpt .flip , X , 1 )
461
- pytest .raises (np .AxisError , dpt .flip , X , - 1 )
459
+ pytest .raises (np .AxisError , dpt .flip , X , axis = 0 )
460
+ pytest .raises (np .AxisError , dpt .flip , X , axis = 1 )
461
+ pytest .raises (np .AxisError , dpt .flip , X , axis = - 1 )
462
462
463
463
464
464
def test_flip_1d ():
@@ -468,12 +468,12 @@ def test_flip_1d():
468
468
X = dpt .asarray (Xnp , sycl_queue = q )
469
469
470
470
for ax in range (- X .ndim , X .ndim ):
471
- Ynp = np .flip (Xnp , ax )
472
- Y = dpt .flip (X , ax )
471
+ Ynp = np .flip (Xnp , axis = ax )
472
+ Y = dpt .flip (X , axis = ax )
473
473
assert_array_equal (Ynp , dpt .asnumpy (Y ))
474
474
475
- Ynp = np .flip (Xnp , 0 )
476
- Y = dpt .flip (X , 0 )
475
+ Ynp = np .flip (Xnp , axis = 0 )
476
+ Y = dpt .flip (X , axis = 0 )
477
477
assert_array_equal (Ynp , dpt .asnumpy (Y ))
478
478
479
479
@@ -497,8 +497,8 @@ def test_flip_2d_3d(shapes):
497
497
Xnp = np .arange (Xnp_size ).reshape (shapes )
498
498
X = dpt .asarray (Xnp , sycl_queue = q )
499
499
for ax in range (- X .ndim , X .ndim ):
500
- Y = dpt .flip (X , ax )
501
- Ynp = np .flip (Xnp , ax )
500
+ Y = dpt .flip (X , axis = ax )
501
+ Ynp = np .flip (Xnp , axis = ax )
502
502
assert_array_equal (Ynp , dpt .asnumpy (Y ))
503
503
504
504
@@ -569,8 +569,8 @@ def test_flip_multiple_axes(data):
569
569
Xnp_size = np .prod (shape )
570
570
Xnp = np .arange (Xnp_size ).reshape (shape )
571
571
X = dpt .asarray (Xnp , sycl_queue = q )
572
- Y = dpt .flip (X , axes )
573
- Ynp = np .flip (Xnp , axes )
572
+ Y = dpt .flip (X , axis = axes )
573
+ Ynp = np .flip (Xnp , axis = axes )
574
574
assert_array_equal (Ynp , dpt .asnumpy (Y ))
575
575
576
576
@@ -583,8 +583,10 @@ def test_roll_empty():
583
583
Y = dpt .roll (X , 1 )
584
584
Ynp = np .roll (Xnp , 1 )
585
585
assert_array_equal (Ynp , dpt .asnumpy (Y ))
586
- pytest .raises (np .AxisError , dpt .roll , X , 1 , 0 )
587
- pytest .raises (np .AxisError , dpt .roll , X , 1 , 1 )
586
+ with pytest .raises (np .AxisError ):
587
+ dpt .roll (X , 1 , axis = 0 )
588
+ with pytest .raises (np .AxisError ):
589
+ dpt .roll (X , 1 , axis = 1 )
588
590
589
591
590
592
@pytest .mark .parametrize (
@@ -605,12 +607,12 @@ def test_roll_1d(data):
605
607
X = dpt .asarray (Xnp , sycl_queue = q )
606
608
sh , ax = data
607
609
608
- Y = dpt .roll (X , sh , ax )
609
- Ynp = np .roll (Xnp , sh , ax )
610
+ Y = dpt .roll (X , sh , axis = ax )
611
+ Ynp = np .roll (Xnp , sh , axis = ax )
610
612
assert_array_equal (Ynp , dpt .asnumpy (Y ))
611
613
612
- Y = dpt .roll (X , sh , ax )
613
- Ynp = np .roll (Xnp , sh , ax )
614
+ Y = dpt .roll (X , sh , axis = ax )
615
+ Ynp = np .roll (Xnp , sh , axis = ax )
614
616
assert_array_equal (Ynp , dpt .asnumpy (Y ))
615
617
616
618
@@ -644,8 +646,8 @@ def test_roll_2d(data):
644
646
X = dpt .asarray (Xnp , sycl_queue = q )
645
647
sh , ax = data
646
648
647
- Y = dpt .roll (X , sh , ax )
648
- Ynp = np .roll (Xnp , sh , ax )
649
+ Y = dpt .roll (X , sh , axis = ax )
650
+ Ynp = np .roll (Xnp , sh , axis = ax )
649
651
assert_array_equal (Ynp , dpt .asnumpy (Y ))
650
652
651
653
@@ -664,10 +666,14 @@ def test_roll_validation():
664
666
665
667
def test_concat_incorrect_type ():
666
668
Xnp = np .ones ((2 , 2 ))
667
- pytest .raises (TypeError , dpt .concat )
668
- pytest .raises (TypeError , dpt .concat , [])
669
- pytest .raises (TypeError , dpt .concat , Xnp )
670
- pytest .raises (TypeError , dpt .concat , [Xnp , Xnp ])
669
+ with pytest .raises (TypeError ):
670
+ dpt .concat ()
671
+ with pytest .raises (TypeError ):
672
+ dpt .concat ([])
673
+ with pytest .raises (TypeError ):
674
+ dpt .concat (Xnp )
675
+ with pytest .raises (TypeError ):
676
+ dpt .concat ([Xnp , Xnp ])
671
677
672
678
673
679
def test_concat_incorrect_queue ():
@@ -719,7 +725,7 @@ def test_concat_incorrect_shape(data):
719
725
X = dpt .ones (Xshape , sycl_queue = q )
720
726
Y = dpt .ones (Yshape , sycl_queue = q )
721
727
722
- pytest .raises (ValueError , dpt .concat , [X , Y ], axis )
728
+ pytest .raises (ValueError , dpt .concat , [X , Y ], axis = axis )
723
729
724
730
725
731
@pytest .mark .parametrize (
@@ -827,7 +833,8 @@ def test_stack_incorrect_shape():
827
833
X = dpt .ones ((1 ,), sycl_queue = q )
828
834
Y = dpt .ones ((2 ,), sycl_queue = q )
829
835
830
- pytest .raises (ValueError , dpt .stack , [X , Y ], 0 )
836
+ with pytest .raises (ValueError ):
837
+ dpt .stack ([X , Y ], axis = 0 )
831
838
832
839
833
840
@pytest .mark .parametrize (
@@ -1111,7 +1118,7 @@ def test_unstack_axis1():
1111
1118
except dpctl .SyclDeviceCreationError :
1112
1119
pytest .skip ("No SYCL devices available" )
1113
1120
y = dpt .reshape (x_flat , (2 , 3 ))
1114
- res = dpt .unstack (y , 1 )
1121
+ res = dpt .unstack (y , axis = 1 )
1115
1122
1116
1123
assert_array_equal (dpt .asnumpy (y [:, 0 , ...]), dpt .asnumpy (res [0 ]))
1117
1124
assert_array_equal (dpt .asnumpy (y [:, 1 , ...]), dpt .asnumpy (res [1 ]))
@@ -1124,7 +1131,7 @@ def test_unstack_axis2():
1124
1131
except dpctl .SyclDeviceCreationError :
1125
1132
pytest .skip ("No SYCL devices available" )
1126
1133
y = dpt .reshape (x_flat , (4 , 5 , 3 ))
1127
- res = dpt .unstack (y , 2 )
1134
+ res = dpt .unstack (y , axis = 2 )
1128
1135
1129
1136
assert_array_equal (dpt .asnumpy (y [:, :, 0 , ...]), dpt .asnumpy (res [0 ]))
1130
1137
assert_array_equal (dpt .asnumpy (y [:, :, 1 , ...]), dpt .asnumpy (res [1 ]))
0 commit comments