
Commit a8af68a

refine all mm and binary ops
1 parent 1d432c2 commit a8af68a

File tree

7 files changed: +287, -195 lines

intel_pytorch_extension_py/ops/jit_script.py

Lines changed: 0 additions & 3 deletions
@@ -13,10 +13,7 @@ def script_(obj, optimize=None, _frames_up=0, _rcb=None):
     torch.jit.script = script_
 
     if core.get_jit_opt():
-        # bypass buggy broadcastable ops in dnnl during folding
-        core.disable_auto_dnnl()
         jit_m = wrap_cpp_module(torch._C._jit_pass_fold_convbn(jit_m._c))
-        core.enable_auto_dnnl()
 
     return jit_m
 
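With the binary ops refined to handle broadcastable shapes in DNNL, the conv-bn folding pass no longer needs to drop out of auto-DNNL mode. For illustration, a minimal sketch of the guard pattern the removed lines implemented; the no_auto_dnnl context manager here is hypothetical glue, only core.disable_auto_dnnl() / core.enable_auto_dnnl() come from this diff:

    from contextlib import contextmanager

    @contextmanager
    def no_auto_dnnl(core):
        # hypothetical helper: temporarily leave auto-DNNL mode, restore on exit
        core.disable_auto_dnnl()
        try:
            yield
        finally:
            core.enable_auto_dnnl()

Before this commit the folding call ran inside such a guard because broadcastable ops in dnnl were buggy during folding; after the refinements it runs with auto-DNNL left enabled.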

scripts/cpu/gen-dense-cpu-ops.py

Lines changed: 1 addition & 0 deletions
@@ -63,6 +63,7 @@
     'aten::addbmm.out(Tensor self, Tensor batch1, Tensor batch2, *, Scalar beta=1, Scalar alpha=1, Tensor(a!) out) -> Tensor(a!)',
     'aten::convolution_overrideable(Tensor input, Tensor weight, Tensor? bias, int[] stride, int[] padding, int[] dilation, bool transposed, int[] output_padding, int groups) -> Tensor',
     'aten::convolution_backward_overrideable(Tensor grad_output, Tensor input, Tensor weight, int[] stride, int[] padding, int[] dilation, bool transposed, int[] output_padding, int groups, bool[3] output_mask) -> (Tensor grad_input, Tensor grad_weight, Tensor grad_bias)',
+    'aten::size.int(Tensor self, int dim) -> int',
     'aten::clone(Tensor self, *, MemoryFormat? memory_format=None) -> Tensor',
     'aten::gelu(Tensor self) -> Tensor',
     'aten::gelu_backward(Tensor grad, Tensor self) -> Tensor',
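The newly registered schema 'aten::size.int(Tensor self, int dim) -> int' is the per-dimension size query, i.e. what Tensor.size(dim) lowers to in TorchScript, presumably added here so the script generates a dense CPU binding for it alongside the other listed ops. A minimal illustration of the call it covers:

    import torch

    t = torch.randn(2, 3, 4, 5)
    # aten::size.int is the schema behind per-dimension size queries:
    c = t.size(1)            # -> 3, a plain Python int
    assert c == t.shape[1]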

tests/cpu/test_bf16_lazy_reorder.py

Lines changed: 2 additions & 3 deletions
@@ -463,7 +463,7 @@ def test_mm_out(self):
     def test_bmm(self):
         rand_seed = int(get_rand_seed())
         print("{} rand sed: {}".format(sys._getframe().f_code.co_name, rand_seed))
-        x_auto_mix_a, x_auto_mix_b, _, x_man_bf16_a, x_man_bf16_b, _ = self._gen_mm_tensor(rand_seed)
+        x_auto_mix_a, x_auto_mix_b, _, x_man_bf16_a, x_man_bf16_b, _ = self._gen_mm_tensor(rand_seed, batches=16)
 
         with AutoDNNL(True), AutoMixPrecision(False):
             res_man_bf16 = torch.bmm(x_man_bf16_a, x_man_bf16_b)
@@ -477,8 +477,7 @@ def test_bmm(self):
     def test_bmm_out(self):
         rand_seed = int(get_rand_seed())
         print("{} rand sed: {}".format(sys._getframe().f_code.co_name, rand_seed))
-        x_auto_mix_a, x_auto_mix_b, res_auto_mix, x_man_bf16_a, x_man_bf16_b, res_man_bf16 = self._gen_mm_tensor(rand_seed)
-
+        x_auto_mix_a, x_auto_mix_b, res_auto_mix, x_man_bf16_a, x_man_bf16_b, res_man_bf16 = self._gen_mm_tensor(rand_seed, batches=16)
         with AutoDNNL(True), AutoMixPrecision(False):
             torch.bmm(x_man_bf16_a, x_man_bf16_b, out=res_man_bf16)
             self.assertEqual(res_man_bf16.dtype, torch.bfloat16)
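Both tests now pass batches=16 to the tensor generator, since torch.bmm only accepts 3-D batched operands. A minimal CPU-only illustration of the shape contract (sizes arbitrary):

    import torch

    B, M, K, N = 16, 8, 12, 10
    a = torch.randn(B, M, K)
    b = torch.randn(B, K, N)
    res = torch.bmm(a, b)        # batched matmul over the leading dim
    assert res.shape == (B, M, N)

The _gen_mm_tensor helper itself is defined elsewhere in the test suite; the new keyword suggests it emits batched operands when asked.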

tests/cpu/test_lazy_reorder.py

Lines changed: 91 additions & 0 deletions
@@ -33,6 +33,13 @@ def get_rand_seed():
     return int(time.time() * 1000000000)
 
 device = ipex.DEVICE
+
+def convert_blocked(t):
+    assert t.dim() == 4, "only support converting 4d tensor"
+    c = t.size(1)
+    t = t.clone().to(device)
+    return F.conv2d(t, torch.ones(c, 1, 1, 1).to(device), groups=c)
+
 class TestConv(TestCase):
     def test_Conv2d_with_cpu(self):
         rand_seed = int(get_rand_seed())
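convert_blocked is a small trick for forcing a tensor into DNNL's blocked layout without a dedicated API: a 1x1 depthwise convolution whose weights are all ones multiplies every element by 1.0, so values are preserved, but because the conv executes in DNNL its output is left in blocked format. The value-preserving identity can be checked on plain CPU tensors (the blocked layout itself only materializes on the ipex device):

    import torch
    import torch.nn.functional as F

    t = torch.randn(2, 3, 4, 5)
    c = t.size(1)
    # depthwise 1x1 conv with all-ones weights: each output element
    # equals 1.0 * the corresponding input element
    out = F.conv2d(t, torch.ones(c, 1, 1, 1), groups=c)
    assert torch.equal(out, t)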
@@ -202,6 +209,78 @@ def test_mul_(self):
         a2 = self._test_mul_('cpu', rand_seed)
         self.assertEqual(a2, a1.to('cpu'))
 
+    def test_mixed_format(self):
+        ipex.core.enable_auto_dnnl()
+        rand_seed = int(get_rand_seed())
+        print("{} rand sed: {}".format(sys._getframe().f_code.co_name, rand_seed))
+        torch.manual_seed(rand_seed)
+
+        shape = (2, 3, 4, 5)
+
+        for fname in ['add', 'mul']:
+
+            x_cpu = torch.ones(shape) * 5
+            y_cpu = torch.ones(shape) * 4
+
+            # block tensor is a dpcpp tensor
+            x_plain = x_cpu.clone().to(device)
+            y_plain = y_cpu.clone().to(device)
+            x_block = convert_blocked(x_cpu.clone())
+            y_block = convert_blocked(y_cpu.clone())
+
+            fn = getattr(torch, fname)
+            ref = fn(x_cpu, y_cpu)
+
+            # test add, mul
+            def test_outplace(a, b):
+                a = a.clone()
+                b = b.clone()
+                self.assertEqual(fn(a, b), ref)
+
+            test_outplace(x_plain, y_plain)
+            test_outplace(x_plain, y_block)
+            test_outplace(y_block, x_plain)
+            test_outplace(x_block, y_block)
+
+            # test add_out, mul_out
+            def test_out(a, b, o):
+                a = a.clone()
+                b = b.clone()
+                o = o.clone()
+                y = fn(a, b, out=o)
+                self.assertEqual(y, ref)
+                self.assertEqual(o, ref)
+
+            out = torch.ones(shape).to(device)
+            test_out(x_plain, y_plain, out)
+            test_out(x_plain, y_block, out)
+            test_out(y_block, x_plain, out)
+            test_out(x_block, y_block, out)
+            out = torch.ones(1).to(device)
+            test_out(x_plain, y_plain, out)
+            test_out(x_plain, y_block, out)
+            test_out(y_block, x_plain, out)
+            test_out(x_block, y_block, out)
+
+            # test add_, mul_
+            def test_inplace(a, b):
+                a = a.clone()
+                b = b.clone()
+                y = getattr(a, fname + '_')(b)
+                self.assertEqual(a, ref)
+                self.assertEqual(y, ref)
+
+            test_inplace(x_plain, y_plain)
+            test_inplace(x_plain, y_block)
+            test_inplace(y_block, x_plain)
+            test_inplace(x_block, y_block)
+
+            # test broadcast
+            scalar = torch.ones(1).to(device)
+            self.assertEqual(fn(x_plain, scalar), fn(x_cpu, scalar))
+            self.assertEqual(fn(scalar, x_plain), fn(scalar, x_cpu))
+
+
 class TestRelu(TestCase):
     def _test_relu_(self, device, rand_seed):
         torch.manual_seed(rand_seed)
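test_mixed_format drives add/mul through every pairing of plain and blocked operands, the out= variants (including a 1-element out buffer the op must resize), the in-place variants, and scalar broadcasting. With the chosen constants the expected values are easy to verify on plain CPU tensors:

    import torch

    shape = (2, 3, 4, 5)
    x = torch.ones(shape) * 5
    y = torch.ones(shape) * 4
    assert torch.equal(torch.add(x, y), torch.full(shape, 9.0))
    assert torch.equal(torch.mul(x, y), torch.full(shape, 20.0))

    # the broadcast cases pair a (2, 3, 4, 5) tensor with a 1-element
    # tensor, which broadcasting expands to the full shape
    s = torch.ones(1)
    assert torch.equal(torch.add(x, s), torch.full(shape, 6.0))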
@@ -388,6 +467,11 @@ def test_addmm(self):
             torch.addmm(input=res_dpcpp, mat1=b1_dpcpp, mat2=b2_dpcpp, alpha=alpha, beta=beta, out=y_dpcpp)
             self.assertEqual(y_cpu, y_dpcpp)
 
+            res_cpu.addmm_(mat1=b1_cpu, mat2=b2_cpu, alpha=alpha, beta=beta)
+            res_dpcpp.addmm_(mat1=b1_cpu, mat2=b2_cpu, alpha=alpha, beta=beta)
+            self.assertEqual(res_cpu, res_dpcpp)
+
+
     def test_addbmm(self):
         ipex.core.enable_auto_dnnl()
         rand_seed = int(get_rand_seed())
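The new in-place checks rely on addmm_ following the same formula as torch.addmm: self = beta * self + alpha * (mat1 @ mat2), mutating self. A CPU-only sanity check of that identity (sizes arbitrary):

    import torch

    res = torch.randn(4, 6)
    m1, m2 = torch.randn(4, 5), torch.randn(5, 6)
    alpha, beta = 2.0, 3.0

    expected = beta * res + alpha * torch.mm(m1, m2)
    res.addmm_(mat1=m1, mat2=m2, alpha=alpha, beta=beta)  # mutates res
    assert torch.allclose(res, expected)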
@@ -415,6 +499,10 @@ def test_addbmm(self):
             torch.addbmm(res_dpcpp, b1_dpcpp, b2_dpcpp, beta=beta, alpha=alpha, out=y_dpcpp)
             self.assertEqual(y_cpu, y_dpcpp, 1e-4)
 
+            res_cpu.addbmm_(b1_cpu, b2_cpu, beta=beta, alpha=alpha)
+            res_dpcpp.addbmm_(b1_dpcpp, b2_dpcpp, beta=beta, alpha=alpha)
+            self.assertEqual(res_cpu, res_dpcpp, 1e-4)
+
     def test_baddbmm(self):
         ipex.core.enable_auto_dnnl()
         rand_seed = int(get_rand_seed())
@@ -441,6 +529,9 @@ def test_baddbmm(self):
             torch.baddbmm(res_cpu, b1_cpu, b2_cpu, alpha=alpha, beta=beta, out=y_cpu),
             torch.baddbmm(res_dpcpp, b1_dpcpp, b2_dpcpp, alpha=alpha, beta=beta, out=y_dpcpp),
             self.assertEqual(y_cpu, y_dpcpp)
+            res_cpu.baddbmm_(b1_cpu, b2_cpu, alpha=alpha, beta=beta)
+            res_dpcpp.baddbmm_(b1_cpu, b2_cpu, alpha=alpha, beta=beta)
+            self.assertEqual(res_cpu, res_dpcpp)
 
 class TestLinear(TestCase):
     def test_linear(self):
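The batched in-place variants differ only in how the batch dimension is treated: baddbmm_ keeps per-batch results, self[i] = beta * self[i] + alpha * (batch1[i] @ batch2[i]), while addbmm_ sums the batch of products into a single matrix. A CPU-only check of both identities (sizes arbitrary):

    import torch

    B, M, K, N = 16, 4, 5, 6
    b1, b2 = torch.randn(B, M, K), torch.randn(B, K, N)
    alpha, beta = 2.0, 3.0

    # baddbmm_: per-batch accumulate, result stays (B, M, N)
    res = torch.randn(B, M, N)
    expected = beta * res + alpha * torch.bmm(b1, b2)
    res.baddbmm_(b1, b2, alpha=alpha, beta=beta)
    assert torch.allclose(res, expected)

    # addbmm_: the batch of products is summed, result is (M, N)
    res2 = torch.randn(M, N)
    expected2 = beta * res2 + alpha * torch.bmm(b1, b2).sum(dim=0)
    res2.addbmm_(b1, b2, alpha=alpha, beta=beta)
    assert torch.allclose(res2, expected2)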
