Skip to content

Commit 7d09fb0

Browse files
committed
Refine unit test case
1 parent a158a0e commit 7d09fb0

File tree

1 file changed

+22
-6
lines changed

1 file changed

+22
-6
lines changed

tests/cpu/test_lazy_reorder.py

Lines changed: 22 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -252,6 +252,22 @@ def _test_conv_add_relu_(self, device, rand_seed):
252252

253253
return conv_op_output, conv_op_input, add_src
254254

255+
def _test_conv_relu_(self, device, rand_seed):
256+
ipex.enable_auto_dnnl()
257+
torch.manual_seed(rand_seed)
258+
conv_op = torch.nn.Conv2d(1, 1, (7, 7)).to(device=device)
259+
conv_op_input = torch.rand((1, 1, 10, 10)).to(device=device)
260+
conv_op_output = conv_op(conv_op_input)
261+
conv_op_output.relu_()
262+
return conv_op_output
263+
264+
def test_conv_relu_(self):
265+
rand_seed = int(get_rand_seed())
266+
res_dcpp_dnnl = self._test_conv_relu_("dpcpp:0", rand_seed)
267+
self.assertTrue(ipex.is_dil_tensor(res_dcpp_dnnl))
268+
res_cpu = self._test_conv_relu_("cpu", rand_seed)
269+
self.assertEqual(res_cpu, res_dcpp_dnnl.to('cpu'))
270+
255271
def test_conv_add_relu_(self):
256272
ipex.enable_auto_dnnl()
257273
rand_seed = int(get_rand_seed())
@@ -260,18 +276,18 @@ def test_conv_add_relu_(self):
260276

261277
ipex.disable_auto_dnnl()
262278
res_dcpp_cpu, input_dpcpp_cpu, _ = self._test_conv_add_relu_("dpcpp:0", rand_seed)
263-
279+
264280
res_cpu, input_cpu, _ = self._test_conv_add_relu_("cpu", rand_seed)
265281
self.assertEqual(res_cpu, res_dcpp_cpu.to('cpu'))
266282
self.assertEqual(res_cpu, res_dcpp_dnnl.to('cpu'))
267283

268284
ipex.enable_auto_dnnl()
269-
res_dcpp_dnnl.sum()#.backward()
270-
res_dcpp_cpu.sum()#.backward()
271-
res_cpu.sum()#.backward()
285+
res_dcpp_dnnl.sum().backward()
286+
res_dcpp_cpu.sum().backward()
287+
res_cpu.sum().backward()
272288

273-
#self.assertEqual(input_dpcpp_dnnl.grad.to('cpu'), input_cpu.grad, prec=0.0)
274-
#self.assertEqual(input_dpcpp_cpu.grad.to('cpu'), input_cpu.grad, prec=0.0)
289+
self.assertEqual(input_dpcpp_dnnl.grad.to('cpu'), input_cpu.grad, prec=0.0)
290+
self.assertEqual(input_dpcpp_cpu.grad.to('cpu'), input_cpu.grad, prec=0.0)
275291

276292
class TestLinearAlgebraOps(TestCase):
277293
def test_mm(self):

0 commit comments

Comments
 (0)