7
7
from tests .third_party .cupy import testing
8
8
9
9
10
- @testing .gpu
11
10
class TestDims (unittest .TestCase ):
12
11
def check_atleast (self , func , xp ):
13
12
a = testing .shaped_arange ((), xp )
@@ -55,6 +54,13 @@ def test_broadcast_to(self, xp, dtype):
55
54
b = xp .broadcast_to (a , (2 , 3 , 3 , 4 ))
56
55
return b
57
56
57
+ @testing .numpy_cupy_array_equal ()
58
+ def test_broadcast_to_int (self , xp ):
59
+ # Broadcast 0-dim array to 1-dim.
60
+ a = xp .array (10 , dtype = xp .float32 )
61
+ b = xp .broadcast_to (a , 10 )
62
+ return b
63
+
58
64
@testing .for_all_dtypes ()
59
65
def test_broadcast_to_fail (self , dtype ):
60
66
for xp in (numpy , cupy ):
@@ -286,7 +292,6 @@ def test_external_squeeze(self, xp):
286
292
{"shapes" : [(0 , 1 , 1 , 3 ), (2 , 1 , 0 , 0 , 3 )]},
287
293
{"shapes" : [(0 , 1 , 1 , 0 , 3 ), (5 , 2 , 0 , 1 , 0 , 0 , 3 ), (2 , 1 , 0 , 0 , 0 , 3 )]},
288
294
)
289
- @testing .gpu
290
295
class TestBroadcast (unittest .TestCase ):
291
296
def _broadcast (self , xp , dtype , shapes ):
292
297
arrays = [testing .shaped_arange (s , xp , dtype ) for s in shapes ]
@@ -296,9 +301,9 @@ def _broadcast(self, xp, dtype, shapes):
296
301
def test_broadcast (self , dtype ):
297
302
broadcast_np = self ._broadcast (numpy , dtype , self .shapes )
298
303
broadcast_cp = self ._broadcast (cupy , dtype , self .shapes )
299
- self . assertEqual ( broadcast_np .shape , broadcast_cp .shape )
300
- self . assertEqual ( broadcast_np .size , broadcast_cp .size )
301
- self . assertEqual ( broadcast_np .nd , broadcast_cp .nd )
304
+ assert broadcast_np .shape == broadcast_cp .shape
305
+ assert broadcast_np .size == broadcast_cp .size
306
+ assert broadcast_np .nd == broadcast_cp .nd
302
307
303
308
@testing .for_all_dtypes ()
304
309
@testing .numpy_cupy_array_equal ()
@@ -329,7 +334,6 @@ def test_broadcast_arrays(self, xp, dtype):
329
334
},
330
335
{"shapes" : [(0 ,), (2 ,)]},
331
336
)
332
- @testing .gpu
333
337
class TestInvalidBroadcast (unittest .TestCase ):
334
338
@testing .for_all_dtypes ()
335
339
def test_invalid_broadcast (self , dtype ):
0 commit comments