Skip to content

Check DNNL Buffer #19

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
May 25, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 22 additions & 6 deletions tests/cpu/test_lazy_reorder.py
Original file line number Diff line number Diff line change
Expand Up @@ -252,6 +252,22 @@ def _test_conv_add_relu_(self, device, rand_seed):

return conv_op_output, conv_op_input, add_src

def _test_conv_relu_(self, device, rand_seed):
ipex.enable_auto_dnnl()
torch.manual_seed(rand_seed)
conv_op = torch.nn.Conv2d(1, 1, (7, 7)).to(device=device)
conv_op_input = torch.rand((1, 1, 10, 10)).to(device=device)
conv_op_output = conv_op(conv_op_input)
conv_op_output.relu_()
return conv_op_output

def test_conv_relu_(self):
rand_seed = int(get_rand_seed())
res_dcpp_dnnl = self._test_conv_relu_("dpcpp:0", rand_seed)
self.assertTrue(ipex.is_dil_tensor(res_dcpp_dnnl))
res_cpu = self._test_conv_relu_("cpu", rand_seed)
self.assertEqual(res_cpu, res_dcpp_dnnl.to('cpu'))

def test_conv_add_relu_(self):
ipex.enable_auto_dnnl()
rand_seed = int(get_rand_seed())
Expand All @@ -260,18 +276,18 @@ def test_conv_add_relu_(self):

ipex.disable_auto_dnnl()
res_dcpp_cpu, input_dpcpp_cpu, _ = self._test_conv_add_relu_("dpcpp:0", rand_seed)

res_cpu, input_cpu, _ = self._test_conv_add_relu_("cpu", rand_seed)
self.assertEqual(res_cpu, res_dcpp_cpu.to('cpu'))
self.assertEqual(res_cpu, res_dcpp_dnnl.to('cpu'))

ipex.enable_auto_dnnl()
res_dcpp_dnnl.sum()#.backward()
res_dcpp_cpu.sum()#.backward()
res_cpu.sum()#.backward()
res_dcpp_dnnl.sum().backward()
res_dcpp_cpu.sum().backward()
res_cpu.sum().backward()

#self.assertEqual(input_dpcpp_dnnl.grad.to('cpu'), input_cpu.grad, prec=0.0)
#self.assertEqual(input_dpcpp_cpu.grad.to('cpu'), input_cpu.grad, prec=0.0)
self.assertEqual(input_dpcpp_dnnl.grad.to('cpu'), input_cpu.grad, prec=0.0)
self.assertEqual(input_dpcpp_cpu.grad.to('cpu'), input_cpu.grad, prec=0.0)

class TestLinearAlgebraOps(TestCase):
def test_mm(self):
Expand Down
14 changes: 14 additions & 0 deletions torch_ipex/csrc/cpu/DevOPs.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -239,6 +239,7 @@ at::Tensor& AtenIpexCPUDev::dil_add_out(
const std::vector<float> scales{1.0, alpha.to<float>()};
dil::sum::compute(scales, {x, y}, z);

TORCH_INTERNAL_ASSERT_DEBUG_ONLY(z.is_public_format() || check_tensor_own_whole_storage(result));
dbl::comm::sync_shape_from_dil_to_aten(result, z);
return result;
}
Expand Down Expand Up @@ -267,6 +268,7 @@ at::Tensor & AtenIpexCPUDev::dil_add_(at::Tensor& self, const at::Tensor& other,
const std::vector<float> scales{1.0, alpha.to<float>()};
dil::sum::compute(scales, {dil_self, dil_other}, dil_self);

TORCH_INTERNAL_ASSERT_DEBUG_ONLY(dil_self.is_public_format() || check_tensor_own_whole_storage(self));
dbl::comm::sync_shape_from_dil_to_aten(self, dil_self);
return self;
}
Expand All @@ -283,6 +285,7 @@ at::Tensor& AtenIpexCPUDev::dil_mul_out(at::Tensor& result, const at::Tensor& se

dil::binary::compute(dil_self, dil_other, dil_result, dil::algorithm::binary_mul);

TORCH_INTERNAL_ASSERT_DEBUG_ONLY(dil_result.is_public_format() || check_tensor_own_whole_storage(result));
dbl::comm::sync_shape_from_dil_to_aten(result, dil_result);
return result;
}
Expand Down Expand Up @@ -346,6 +349,7 @@ at::Tensor& AtenIpexCPUDev::dil_bmm_out(
dil::tensor y = dbl::comm::try_gen_dil_tensor(result);
matmul_common(x, w, dil::tensor(), y);

TORCH_INTERNAL_ASSERT_DEBUG_ONLY(y.is_public_format() || check_tensor_own_whole_storage(result));
dbl::comm::sync_shape_from_dil_to_aten(result, y);
return result;
}
Expand Down Expand Up @@ -390,6 +394,8 @@ at::Tensor& AtenIpexCPUDev::dil_baddbmm_out(
dil::tensor y = dbl::comm::try_gen_dil_tensor(result);
auto attr_ = dil::attr_t::fuse_sum();
matmul_common(x, w, bias, y, beta, alpha, attr_);

TORCH_INTERNAL_ASSERT_DEBUG_ONLY(y.is_public_format() || check_tensor_own_whole_storage(result));
dbl::comm::sync_shape_from_dil_to_aten(result, y);
return result;
}
Expand Down Expand Up @@ -489,6 +495,8 @@ at::Tensor& AtenIpexCPUDev::dil_addbmm_out(
}
}
matmul_common(x_, w_, bias, y, beta, alpha, attr_);

TORCH_INTERNAL_ASSERT_DEBUG_ONLY(y.is_public_format() || check_tensor_own_whole_storage(result));
dbl::comm::sync_shape_from_dil_to_aten(result, y);
return result;
}
Expand Down Expand Up @@ -974,6 +982,8 @@ at::Tensor& AtenIpexCPUDev::dil_relu_(at::Tensor& input) {
dil::algorithm::eltwise_relu,
dil::prop_kind::forward_training,
/*alpha*/ 0.0);

TORCH_INTERNAL_ASSERT_DEBUG_ONLY(dil_self.is_public_format() || check_tensor_own_whole_storage(input));
dbl::comm::sync_shape_from_dil_to_aten(input, dil_self);
return input;
}
Expand Down Expand Up @@ -1041,6 +1051,8 @@ at::Tensor& AtenIpexCPUDev::dil_sigmoid_(at::Tensor& self) {
dil::tensor x = dbl::comm::try_gen_dil_tensor(self);
dil::eltwise_forward::compute(
x, x, dil::algorithm::eltwise_logistic_use_dst_for_bwd, dil::prop_kind::forward);

TORCH_INTERNAL_ASSERT_DEBUG_ONLY(x.is_public_format() || check_tensor_own_whole_storage(self));
dbl::comm::sync_shape_from_dil_to_aten(self, x);
return self;
}
Expand Down Expand Up @@ -1122,6 +1134,8 @@ at::Tensor& AtenIpexCPUDev::dil_cat_out(at::Tensor& result, at::TensorList tenso
}
dil::tensor y = dbl::comm::try_gen_dil_tensor(result);
dil::concat::compute(x, dim, y);

TORCH_INTERNAL_ASSERT_DEBUG_ONLY(y.is_public_format() || check_tensor_own_whole_storage(result));
dbl::comm::sync_shape_from_dil_to_aten(result, y);
return result;
}
Expand Down
2 changes: 2 additions & 0 deletions torch_ipex/csrc/cpu/ShadeDataContext.h
Original file line number Diff line number Diff line change
Expand Up @@ -94,6 +94,8 @@ struct ShadeDataContext {
if (raw_cpu_data == nullptr) {
// the dnnl tensor does not share data with raw tensor data.
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(! (shade_data_context->dil_tensor.is_empty()));
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(! (shade_data_context->dil_tensor.is_public_format()));
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(check_tensor_own_whole_storage(tensor));
return true;
} else {
// The dnnl tensor shares some data with raw tensor.
Expand Down
26 changes: 14 additions & 12 deletions torch_ipex/csrc/cpu/dbl/Common.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,8 @@ dil::tensor try_gen_dil_tensor(const at::Tensor &input) {
if ((!check_aten_dil_shape_info(input, dil_tensor)) && dil_tensor.is_public_format()) {
dil_tensor.set_dims_and_strides(input.sizes().vec(), input.strides().vec());
}
// Does not support the case if the dil tensor is block format but it is just a part of tensor buffer
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(dil_tensor.is_public_format() || check_tensor_own_whole_storage(input));
return dil_tensor;
} else {
return dil_tensor_from_dense(input);
Expand Down Expand Up @@ -71,14 +73,7 @@ at::Tensor gen_aten_tensor_by(dil::tensor dil_tensor) {
nullptr,
/*resizeable=*/false);
auto _tensor = at::detail::make_tensor<torch_ipex::IPEXTensorImpl>(storage_impl, at::DispatchKey::DPCPPTensorId);
if (dil_tensor.is_public_format()) {
dbl::comm::sync_shape_from_dil_to_aten(_tensor, dil_tensor);
} else {
// Blockformat does not inlcude stride information
auto tensor_sizes = dil_tensor.get_dims();
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(tensor_sizes.size() != 1 || tensor_sizes[0] != 0);
_tensor.unsafeGetTensorImpl()->set_sizes_contiguous(tensor_sizes);
}
dbl::comm::sync_shape_from_dil_to_aten(_tensor, dil_tensor);
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(_tensor.layout() == c10::kStrided);
return _tensor;
}
Expand All @@ -94,10 +89,17 @@ at::Tensor empty_dil_tensor(at::IntArrayRef sizes, const at::TensorOptions& opti

void sync_shape_from_dil_to_aten(const at::Tensor& ipex_tensor, const dil::tensor &dil_tensor) {
dil::dims sizes = dil_tensor.get_dims();
dil::dims strides = dil_tensor.get_strides();
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(ipex_tensor.device().type() == at::DeviceType::DPCPP);
auto* _tensor_impl = (IPEXTensorImpl *)ipex_tensor.unsafeGetTensorImpl();
_tensor_impl->force_set_strided(sizes, strides);
if (dil_tensor.is_public_format()) {
dil::dims strides = dil_tensor.get_strides();
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(ipex_tensor.device().type() == at::DeviceType::DPCPP);
auto* _tensor_impl = (IPEXTensorImpl *)ipex_tensor.unsafeGetTensorImpl();
_tensor_impl->force_set_strided(sizes, strides);
} else {
// Blockformat does not inlcude stride information
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(sizes.size() != 1 || sizes[0] != 0);
ipex_tensor.unsafeGetTensorImpl()->set_sizes_contiguous(sizes);
}

}

} // namespace comm
Expand Down