@@ -33,6 +33,13 @@ def get_rand_seed():
33
33
return int (time .time () * 1000000000 )
34
34
35
35
device = ipex .DEVICE
36
+
37
def convert_blocked(t):
    """Pass a 4-d tensor through a value-preserving depthwise conv2d on ``device``.

    A copy of ``t`` is moved to the module-level ``device`` and convolved
    with an all-ones 1x1 kernel using one group per channel, so every
    element keeps its value and the input tensor itself is left untouched.

    NOTE(review): presumably the point is that the convolution output comes
    back in the backend's blocked layout (hence the name) -- confirm.

    Args:
        t: tensor with exactly four dimensions; dimension 1 is used as the
            channel count.

    Returns:
        The result of ``F.conv2d`` on the device copy of ``t``.

    Raises:
        AssertionError: if ``t`` is not 4-dimensional.
    """
    assert t.dim() == 4, "only support converting 4d tensor"
    channels = t.size(1)
    on_device = t.clone().to(device)
    # One 1x1 filter of ones per channel: out[:, k] == in[:, k] * 1.
    unit_kernel = torch.ones(channels, 1, 1, 1).to(device)
    return F.conv2d(on_device, unit_kernel, groups=channels)
42
+
36
43
class TestConv (TestCase ):
37
44
def test_Conv2d_with_cpu (self ):
38
45
rand_seed = int (get_rand_seed ())
@@ -202,6 +209,78 @@ def test_mul_(self):
202
209
a2 = self ._test_mul_ ('cpu' , rand_seed )
203
210
self .assertEqual (a2 , a1 .to ('cpu' ))
204
211
212
def test_mixed_format(self):
    """Check add/mul across every mix of plain and blocked device tensors.

    For each of ``torch.add`` and ``torch.mul`` the out-of-place, ``out=``
    and in-place variants are run on plain/plain, plain/blocked,
    blocked/plain and blocked/blocked operand pairs, and each result is
    compared with the value computed from the CPU tensors. A final check
    covers broadcasting against a one-element tensor.
    """
    ipex.core.enable_auto_dnnl()
    rand_seed = int(get_rand_seed())
    print("{} rand sed: {}".format(sys._getframe().f_code.co_name, rand_seed))
    torch.manual_seed(rand_seed)

    shape = (2, 3, 4, 5)

    for fname in ['add', 'mul']:

        x_cpu = torch.ones(shape) * 5
        y_cpu = torch.ones(shape) * 4

        # block tensor is a dpcpp tensor
        x_plain = x_cpu.clone().to(device)
        y_plain = y_cpu.clone().to(device)
        x_block = convert_blocked(x_cpu.clone())
        y_block = convert_blocked(y_cpu.clone())

        # Every layout combination exercised below, in a fixed order.
        operand_pairs = [
            (x_plain, y_plain),
            (x_plain, y_block),
            (y_block, x_plain),
            (x_block, y_block),
        ]

        fn = getattr(torch, fname)
        ref = fn(x_cpu, y_cpu)

        # test add, mul
        def test_outplace(a, b):
            a = a.clone()
            b = b.clone()
            self.assertEqual(fn(a, b), ref)

        for lhs, rhs in operand_pairs:
            test_outplace(lhs, rhs)

        # test add_out, mul_out
        def test_out(a, b, o):
            a = a.clone()
            b = b.clone()
            o = o.clone()
            y = fn(a, b, out=o)
            self.assertEqual(y, ref)
            self.assertEqual(o, ref)

        # First an `out` that already has the result shape, then a
        # one-element `out` whose shape differs from the result.
        for out_shape in (shape, 1):
            out = torch.ones(out_shape).to(device)
            for lhs, rhs in operand_pairs:
                test_out(lhs, rhs, out)

        # test add_, mul_
        def test_inplace(a, b):
            a = a.clone()
            b = b.clone()
            y = getattr(a, fname + '_')(b)
            self.assertEqual(a, ref)
            self.assertEqual(y, ref)

        for lhs, rhs in operand_pairs:
            test_inplace(lhs, rhs)

        # test broadcast
        # The expected value is computed purely on CPU so that it does not
        # itself go through a mixed-device op.
        scalar_cpu = torch.ones(1)
        scalar = torch.ones(1).to(device)
        self.assertEqual(fn(x_plain, scalar), fn(x_cpu, scalar_cpu))
        self.assertEqual(fn(scalar, x_plain), fn(scalar_cpu, x_cpu))
282
+
283
+
205
284
class TestRelu (TestCase ):
206
285
def _test_relu_ (self , device , rand_seed ):
207
286
torch .manual_seed (rand_seed )
@@ -388,6 +467,11 @@ def test_addmm(self):
388
467
torch .addmm (input = res_dpcpp , mat1 = b1_dpcpp , mat2 = b2_dpcpp , alpha = alpha , beta = beta , out = y_dpcpp )
389
468
self .assertEqual (y_cpu , y_dpcpp )
390
469
470
+ res_cpu .addmm_ (mat1 = b1_cpu , mat2 = b2_cpu , alpha = alpha , beta = beta )
471
+ res_dpcpp .addmm_ (mat1 = b1_cpu , mat2 = b2_cpu , alpha = alpha , beta = beta )
472
+ self .assertEqual (res_cpu , res_dpcpp )
473
+
474
+
391
475
def test_addbmm (self ):
392
476
ipex .core .enable_auto_dnnl ()
393
477
rand_seed = int (get_rand_seed ())
@@ -415,6 +499,10 @@ def test_addbmm(self):
415
499
torch .addbmm (res_dpcpp , b1_dpcpp , b2_dpcpp , beta = beta , alpha = alpha , out = y_dpcpp )
416
500
self .assertEqual (y_cpu , y_dpcpp , 1e-4 )
417
501
502
+ res_cpu .addbmm_ (b1_cpu , b2_cpu , beta = beta , alpha = alpha )
503
+ res_dpcpp .addbmm_ (b1_dpcpp , b2_dpcpp , beta = beta , alpha = alpha )
504
+ self .assertEqual (res_cpu , res_dpcpp , 1e-4 )
505
+
418
506
def test_baddbmm (self ):
419
507
ipex .core .enable_auto_dnnl ()
420
508
rand_seed = int (get_rand_seed ())
@@ -441,6 +529,9 @@ def test_baddbmm(self):
441
529
torch .baddbmm (res_cpu , b1_cpu , b2_cpu , alpha = alpha , beta = beta , out = y_cpu ),
442
530
torch .baddbmm (res_dpcpp , b1_dpcpp , b2_dpcpp , alpha = alpha , beta = beta , out = y_dpcpp ),
443
531
self .assertEqual (y_cpu , y_dpcpp )
532
+ res_cpu .baddbmm_ (b1_cpu , b2_cpu , alpha = alpha , beta = beta )
533
+ res_dpcpp .baddbmm_ (b1_cpu , b2_cpu , alpha = alpha , beta = beta )
534
+ self .assertEqual (res_cpu , res_dpcpp )
444
535
445
536
class TestLinear (TestCase ):
446
537
def test_linear (self ):
0 commit comments