Skip to content

Refine all matmul and binary ops #54

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Jun 22, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 0 additions & 3 deletions intel_pytorch_extension_py/ops/jit_script.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,10 +13,7 @@ def script_(obj, optimize=None, _frames_up=0, _rcb=None):
torch.jit.script = script_

if core.get_jit_opt():
# bypass buggy broadcastable ops in dnnl during folding
core.disable_auto_dnnl()
jit_m = wrap_cpp_module(torch._C._jit_pass_fold_convbn(jit_m._c))
core.enable_auto_dnnl()

return jit_m

Expand Down
1 change: 1 addition & 0 deletions scripts/cpu/gen-dense-cpu-ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,7 @@
'aten::addbmm.out(Tensor self, Tensor batch1, Tensor batch2, *, Scalar beta=1, Scalar alpha=1, Tensor(a!) out) -> Tensor(a!)',
'aten::convolution_overrideable(Tensor input, Tensor weight, Tensor? bias, int[] stride, int[] padding, int[] dilation, bool transposed, int[] output_padding, int groups) -> Tensor',
'aten::convolution_backward_overrideable(Tensor grad_output, Tensor input, Tensor weight, int[] stride, int[] padding, int[] dilation, bool transposed, int[] output_padding, int groups, bool[3] output_mask) -> (Tensor grad_input, Tensor grad_weight, Tensor grad_bias)',
'aten::size.int(Tensor self, int dim) -> int',
'aten::clone(Tensor self, *, MemoryFormat? memory_format=None) -> Tensor',
'aten::gelu(Tensor self) -> Tensor',
'aten::gelu_backward(Tensor grad, Tensor self) -> Tensor',
Expand Down
5 changes: 2 additions & 3 deletions tests/cpu/test_bf16_lazy_reorder.py
Original file line number Diff line number Diff line change
Expand Up @@ -463,7 +463,7 @@ def test_mm_out(self):
def test_bmm(self):
rand_seed = int(get_rand_seed())
print("{} rand sed: {}".format(sys._getframe().f_code.co_name, rand_seed))
x_auto_mix_a, x_auto_mix_b, _, x_man_bf16_a, x_man_bf16_b, _ = self._gen_mm_tensor(rand_seed)
x_auto_mix_a, x_auto_mix_b, _, x_man_bf16_a, x_man_bf16_b, _ = self._gen_mm_tensor(rand_seed, batches=16)

with AutoDNNL(True), AutoMixPrecision(False):
res_man_bf16 = torch.bmm(x_man_bf16_a, x_man_bf16_b)
Expand All @@ -477,8 +477,7 @@ def test_bmm(self):
def test_bmm_out(self):
rand_seed = int(get_rand_seed())
print("{} rand sed: {}".format(sys._getframe().f_code.co_name, rand_seed))
x_auto_mix_a, x_auto_mix_b, res_auto_mix, x_man_bf16_a, x_man_bf16_b, res_man_bf16 = self._gen_mm_tensor(rand_seed)

x_auto_mix_a, x_auto_mix_b, res_auto_mix, x_man_bf16_a, x_man_bf16_b, res_man_bf16 = self._gen_mm_tensor(rand_seed, batches=16)
with AutoDNNL(True), AutoMixPrecision(False):
torch.bmm(x_man_bf16_a, x_man_bf16_b, out=res_man_bf16)
self.assertEqual(res_man_bf16.dtype, torch.bfloat16)
Expand Down
91 changes: 91 additions & 0 deletions tests/cpu/test_lazy_reorder.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,13 @@ def get_rand_seed():
return int(time.time() * 1000000000)

device = ipex.DEVICE

def convert_blocked(t):
assert t.dim() == 4, "only support converting 4d tensor"
c = t.size(1)
t = t.clone().to(device)
return F.conv2d(t, torch.ones(c, 1, 1, 1).to(device), groups=c)

class TestConv(TestCase):
def test_Conv2d_with_cpu(self):
rand_seed = int(get_rand_seed())
Expand Down Expand Up @@ -202,6 +209,78 @@ def test_mul_(self):
a2 = self._test_mul_('cpu', rand_seed)
self.assertEqual(a2, a1.to('cpu'))

def test_mixed_format(self):
ipex.core.enable_auto_dnnl()
rand_seed = int(get_rand_seed())
print("{} rand sed: {}".format(sys._getframe().f_code.co_name, rand_seed))
torch.manual_seed(rand_seed)

shape = (2, 3, 4, 5)

for fname in ['add', 'mul']:

x_cpu = torch.ones(shape) * 5
y_cpu = torch.ones(shape) * 4

# block tensor is a dpcpp tensor
x_plain = x_cpu.clone().to(device)
y_plain = y_cpu.clone().to(device)
x_block = convert_blocked(x_cpu.clone())
y_block = convert_blocked(y_cpu.clone())

fn = getattr(torch, fname)
ref = fn(x_cpu, y_cpu)

# test add, mul
def test_outplace(a, b):
a = a.clone()
b = b.clone()
self.assertEqual(fn(a, b), ref)

test_outplace(x_plain, y_plain)
test_outplace(x_plain, y_block)
test_outplace(y_block, x_plain)
test_outplace(x_block, y_block)

# test add_out, mul_out
def test_out(a, b, o):
a = a.clone()
b = b.clone()
o = o.clone()
y = fn(a, b, out=o)
self.assertEqual(y, ref)
self.assertEqual(o, ref)

out = torch.ones(shape).to(device)
test_out(x_plain, y_plain, out)
test_out(x_plain, y_block, out)
test_out(y_block, x_plain, out)
test_out(x_block, y_block, out)
out = torch.ones(1).to(device)
test_out(x_plain, y_plain, out)
test_out(x_plain, y_block, out)
test_out(y_block, x_plain, out)
test_out(x_block, y_block, out)

# test add_, mul_
def test_inplace(a, b):
a = a.clone()
b = b.clone()
y = getattr(a, fname + '_')(b)
self.assertEqual(a, ref)
self.assertEqual(y, ref)

test_inplace(x_plain, y_plain)
test_inplace(x_plain, y_block)
test_inplace(y_block, x_plain)
test_inplace(x_block, y_block)

# test broadcast
scalar = torch.ones(1).to(device)
self.assertEqual(fn(x_plain, scalar), fn(x_cpu, scalar))
self.assertEqual(fn(scalar, x_plain), fn(scalar, x_cpu))


class TestRelu(TestCase):
def _test_relu_(self, device, rand_seed):
torch.manual_seed(rand_seed)
Expand Down Expand Up @@ -388,6 +467,11 @@ def test_addmm(self):
torch.addmm(input=res_dpcpp, mat1=b1_dpcpp, mat2=b2_dpcpp, alpha=alpha, beta=beta, out=y_dpcpp)
self.assertEqual(y_cpu, y_dpcpp)

res_cpu.addmm_(mat1=b1_cpu, mat2=b2_cpu, alpha=alpha, beta=beta)
res_dpcpp.addmm_(mat1=b1_cpu, mat2=b2_cpu, alpha=alpha, beta=beta)
self.assertEqual(res_cpu, res_dpcpp)


def test_addbmm(self):
ipex.core.enable_auto_dnnl()
rand_seed = int(get_rand_seed())
Expand Down Expand Up @@ -415,6 +499,10 @@ def test_addbmm(self):
torch.addbmm(res_dpcpp, b1_dpcpp, b2_dpcpp, beta=beta, alpha=alpha, out=y_dpcpp)
self.assertEqual(y_cpu, y_dpcpp, 1e-4)

res_cpu.addbmm_(b1_cpu, b2_cpu, beta=beta, alpha=alpha)
res_dpcpp.addbmm_(b1_dpcpp, b2_dpcpp, beta=beta, alpha=alpha)
self.assertEqual(res_cpu, res_dpcpp, 1e-4)

def test_baddbmm(self):
ipex.core.enable_auto_dnnl()
rand_seed = int(get_rand_seed())
Expand All @@ -441,6 +529,9 @@ def test_baddbmm(self):
torch.baddbmm(res_cpu, b1_cpu, b2_cpu, alpha=alpha, beta=beta, out=y_cpu),
torch.baddbmm(res_dpcpp, b1_dpcpp, b2_dpcpp, alpha=alpha, beta=beta, out=y_dpcpp),
self.assertEqual(y_cpu, y_dpcpp)
res_cpu.baddbmm_(b1_cpu, b2_cpu, alpha=alpha, beta=beta)
res_dpcpp.baddbmm_(b1_cpu, b2_cpu, alpha=alpha, beta=beta)
self.assertEqual(res_cpu, res_dpcpp)

class TestLinear(TestCase):
def test_linear(self):
Expand Down
Loading