@@ -252,6 +252,22 @@ def _test_conv_add_relu_(self, device, rand_seed):
252
252
253
253
return conv_op_output , conv_op_input , add_src
254
254
255
+ def _test_conv_relu_ (self , device , rand_seed ):
256
+ ipex .enable_auto_dnnl ()
257
+ torch .manual_seed (rand_seed )
258
+ conv_op = torch .nn .Conv2d (1 , 1 , (7 , 7 )).to (device = device )
259
+ conv_op_input = torch .rand ((1 , 1 , 10 , 10 )).to (device = device )
260
+ conv_op_output = conv_op (conv_op_input )
261
+ conv_op_output .relu_ ()
262
+ return conv_op_output
263
+
264
+ def test_conv_relu_ (self ):
265
+ rand_seed = int (get_rand_seed ())
266
+ res_dcpp_dnnl = self ._test_conv_relu_ ("dpcpp:0" , rand_seed )
267
+ self .assertTrue (ipex .is_dil_tensor (res_dcpp_dnnl ))
268
+ res_cpu = self ._test_conv_relu_ ("cpu" , rand_seed )
269
+ self .assertEqual (res_cpu , res_dcpp_dnnl .to ('cpu' ))
270
+
255
271
def test_conv_add_relu_ (self ):
256
272
ipex .enable_auto_dnnl ()
257
273
rand_seed = int (get_rand_seed ())
@@ -260,18 +276,18 @@ def test_conv_add_relu_(self):
260
276
261
277
ipex .disable_auto_dnnl ()
262
278
res_dcpp_cpu , input_dpcpp_cpu , _ = self ._test_conv_add_relu_ ("dpcpp:0" , rand_seed )
263
-
279
+
264
280
res_cpu , input_cpu , _ = self ._test_conv_add_relu_ ("cpu" , rand_seed )
265
281
self .assertEqual (res_cpu , res_dcpp_cpu .to ('cpu' ))
266
282
self .assertEqual (res_cpu , res_dcpp_dnnl .to ('cpu' ))
267
283
268
284
ipex .enable_auto_dnnl ()
269
- res_dcpp_dnnl .sum ()# .backward()
270
- res_dcpp_cpu .sum ()# .backward()
271
- res_cpu .sum ()# .backward()
285
+ res_dcpp_dnnl .sum ().backward ()
286
+ res_dcpp_cpu .sum ().backward ()
287
+ res_cpu .sum ().backward ()
272
288
273
- # self.assertEqual(input_dpcpp_dnnl.grad.to('cpu'), input_cpu.grad, prec=0.0)
274
- # self.assertEqual(input_dpcpp_cpu.grad.to('cpu'), input_cpu.grad, prec=0.0)
289
+ self .assertEqual (input_dpcpp_dnnl .grad .to ('cpu' ), input_cpu .grad , prec = 0.0 )
290
+ self .assertEqual (input_dpcpp_cpu .grad .to ('cpu' ), input_cpu .grad , prec = 0.0 )
275
291
276
292
class TestLinearAlgebraOps (TestCase ):
277
293
def test_mm (self ):
0 commit comments