Refine unit test case and write the shape meta info back to aten tensor when the tensor is reordered to public format #16

Merged · 3 commits · May 20, 2020

6 changes: 4 additions & 2 deletions scripts/cpu/gen-dense-cpu-ops.py
@@ -306,15 +306,17 @@ def is_out_func(fname):
if param_var == 'out' and is_out_func(fname):
code += ' TORCH_INTERNAL_ASSERT({}.is_contiguous());\n'.format(param_var)
else:
# param_seq_str = '{}.is_contiguous() ? {} : {}.contiguous()'.format(param_var, param_var, param_var)
None
param_seq_str = '{}.is_contiguous() ? {} : {}.contiguous()'.format(param_var, param_var, param_var)
param_seq_str_vec.append(param_seq_str)
code += ' if (dbl::chk::dnnl_support_the_tensors(dnnl_input_tensors))\n'
code += ' return AtenIpexCPUDev::dil_{}({});\n'.format(fname, ', '.join(param_seq_str_vec))

code += ' }\n'

code += ' } catch (std::exception& e) {\n'
code += '#if defined(_DEBUG)\n'
code += ' TORCH_WARN(e.what());\n'
code += '#endif\n'
code += ' }\n\n'
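
For reference, the strings assembled above expand into a per-op dispatch wrapper. Below is a rough C++ sketch of that output for a hypothetical op "add" (the signature, the enclosing condition, and the non-DNNL fallback are simplified; only the fragments emitted by the generator above are verbatim):

#include <ATen/ATen.h>
#include <vector>

// Sketch only: "add" / dil_add stand in for the dil_{fname} pattern above;
// dbl::chk::dnnl_support_the_tensors and AtenIpexCPUDev are torch_ipex internals.
at::Tensor add(const at::Tensor& self, const at::Tensor& other) {
  try {
    std::vector<at::Tensor> dnnl_input_tensors{self, other};
    if (dbl::chk::dnnl_support_the_tensors(dnnl_input_tensors))
      // Non-"out" arguments only pay for .contiguous() when needed.
      return AtenIpexCPUDev::dil_add(
          self.is_contiguous() ? self : self.contiguous(),
          other.is_contiguous() ? other : other.contiguous());
  } catch (std::exception& e) {
#if defined(_DEBUG)
    // Exceptions from the DNNL path are surfaced only in debug builds.
    TORCH_WARN(e.what());
#endif
  }
  return at::Tensor();  // placeholder for the elided reference-CPU fallback
}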


18 changes: 12 additions & 6 deletions tests/cpu/test_lazy_reorder.py
@@ -368,12 +368,12 @@ def test_addbmm(self):

addbmm_cpu = torch.addbmm(res_cpu, b1_cpu, b2_cpu, beta=beta, alpha=alpha)
addbmm_dpcpp = torch.addbmm(res_dpcpp, b1_dpcpp, b2_dpcpp, beta=beta, alpha=alpha)
self.assertEqual(addbmm_cpu, addbmm_dpcpp)
self.assertEqual(addbmm_cpu, addbmm_dpcpp, 1e-4)
y_cpu = torch.randn(M, O, dtype=torch.float32)
y_dpcpp = y_cpu.to(device=device)
torch.addbmm(res_cpu, b1_cpu, b2_cpu, beta=beta, alpha=alpha, out=y_cpu)
torch.addbmm(res_dpcpp, b1_dpcpp, b2_dpcpp, beta=beta, alpha=alpha, out=y_dpcpp)
self.assertEqual(y_cpu, y_dpcpp)
self.assertEqual(y_cpu, y_dpcpp, 1e-4)

def test_baddbmm(self):
ipex.enable_auto_dnnl()
@@ -683,7 +683,6 @@ def test_batch_norm2d_backward(self):

bn = torch.nn.BatchNorm2d(3)
bn_dpcpp = copy.deepcopy(bn).to(device=device)

y_cpu = bn(x_cpu).sum()
y_dpcpp = bn_dpcpp(x_dpcpp).sum()
y_cpu.backward()
@@ -756,17 +755,24 @@ def test_view(self):

x_cpu = torch.randn(old_shape)
x_dpcpp = x_cpu.to(device=device).clone()
print(x_dpcpp.size())
self.assertTrue(ipex.is_dil_tensor(x_dpcpp))
self.assertEqual(ipex.get_dil_tensor_sizes(x_dpcpp), [4, 16])
self.assertEqual(ipex.get_dil_tensor_strides(x_dpcpp), [16, 1])

x_cpu_view = x_cpu.view(new_shape)
print(x_cpu_view.size())
self.assertEqual(x_cpu_view.size(), [1, 4, 4, 4])
self.assertEqual(x_cpu_view.stride(), [64, 16, 4, 1])

x_dpcpp_view = x_dpcpp.view(new_shape)
print(x_dpcpp_view.size())
self.assertTrue(ipex.is_dil_tensor(x_dpcpp_view))

y = torch.randn(new_shape)
out_cpu = x_cpu_view * y
# test if the shape of x_dpcpp_view is compatible with y
out_dpcpp = x_dpcpp_view * y
self.assertTrue(ipex.is_dil_tensor(out_dpcpp))
self.assertEqual(ipex.get_dil_tensor_sizes(out_dpcpp), [1, 4, 4, 4])
self.assertEqual(ipex.get_dil_tensor_strides(out_dpcpp), [64, 16, 4, 1])
self.assertEqual(out_cpu, out_dpcpp)

# test if metadata of x_dpcpp has not been altered
4 changes: 2 additions & 2 deletions torch_ipex/csrc/aten_ipex_bridge.cpp
@@ -11,6 +11,7 @@

#include "ipex_tensor_impl.h"
#include "ipex_sparse_tensor_impl.h"
#include "cpu/dbl/Common.h"
#include "cpu/ShadeDataContext.h"
#include "cpu/bf16/Converter.h"
#include "utils.h"
@@ -65,7 +66,6 @@ void reorderDilTensorToPublic(const at::Tensor& ipexTensor) {
void *data_ctx = ipexTensor.unsafeGetTensorImpl()->storage().data_ptr().get_context();
cpu::ShadeDataContext *shade_data_context = (cpu::ShadeDataContext*)data_ctx;
#if defined(_DEBUG)
TORCH_WARN(ipexTensor.is_contiguous());
TORCH_INTERNAL_ASSERT(! (shade_data_context->dil_tensor.is_empty()));
#endif
dil::tensor &dil_tensor = shade_data_context->dil_tensor;
@@ -101,7 +101,7 @@ void reorderDilTensorToPublic(const at::Tensor& ipexTensor) {
ipexTensor.device().type());

ipexTensor.unsafeGetTensorImpl()->storage().set_data_ptr(std::move(shade_data_ptr));
TORCH_INTERNAL_ASSERT(ipexTensor.is_contiguous());
cpu::dbl::comm::sync_shape_from_dil_to_aten(ipexTensor, pub_tensor);
}
}
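
The body of cpu::dbl::comm::sync_shape_from_dil_to_aten is not part of this diff; judging from the call site, it replaces the old contiguity assert by writing the reordered (now public-format) shape metadata back onto the ATen tensor. A minimal sketch of such a helper, assuming it goes through TensorImpl::set_sizes_and_strides (an assumption, not the PR's actual implementation):

#include <ATen/ATen.h>
#include <vector>

// Hypothetical sketch: propagate the DIL tensor's public dims/strides back to
// the ATen TensorImpl so size()/stride() on the aten side stay consistent.
void sync_shape_from_dil_to_aten(const at::Tensor& ipex_tensor,
                                 const dil::tensor& dil_tensor) {
  auto dims = dil_tensor.get_dims();
  auto strides = dil_tensor.get_strides();
  std::vector<int64_t> sizes(dims.begin(), dims.end());
  std::vector<int64_t> strds(strides.begin(), strides.end());
  ipex_tensor.unsafeGetTensorImpl()->set_sizes_and_strides(sizes, strds);
}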

19 changes: 4 additions & 15 deletions torch_ipex/csrc/cpu/ShadeDataContext.h
@@ -90,9 +90,6 @@ struct ShadeDataContext {
TORCH_INTERNAL_ASSERT((data_type == SHADE_DATA_TYPE::CPU_RAW) || (data_type == SHADE_DATA_TYPE::DIL));

if (data_type == SHADE_DATA_TYPE::DIL) {
#if defined(_DEBUG)
TORCH_WARN(tensor.is_contiguous());
#endif
auto raw_cpu_data = tensor.storage().data_ptr().get();
if (raw_cpu_data == nullptr) {
// the dnnl tensor does not share data with raw tensor data.
@@ -113,15 +110,11 @@
// C = A[4:7, :]
// All these tensors share the same buffer of Tensor A with different storage offsets and elements.
// So the context modification will impact all these tensors.
if ((shade_data_context->dil_tensor.get_data_handle() == raw_cpu_data) &&
(shade_data_context->dil_tensor.get_nelems() == tensor.storage().numel()) &&
(shade_data_context->dil_tensor.get_data_type() == get_dil_data_type(tensor.scalar_type()))) {
//TODO: Do we need to check strides here?
if (check_tensor_own_whole_storage(tensor)) {
TORCH_INTERNAL_ASSERT(shade_data_context->dil_tensor.get_size() == tensor.storage().capacity());
return true;
}
}
TORCH_INTERNAL_ASSERT(false);
}

return false;
@@ -148,13 +141,9 @@
TORCH_INTERNAL_ASSERT(tensor.has_storage());
void *raw_context = tensor.storage().data_ptr().get_context();
TORCH_INTERNAL_ASSERT(raw_context != nullptr);
if (isDilTensor(tensor)) {
ShadeDataContext *shade_data_context = (ShadeDataContext*)raw_context;
return shade_data_context->dil_tensor;
} else {
TORCH_INTERNAL_ASSERT(false);
return dil::tensor();
}
TORCH_INTERNAL_ASSERT(isDilTensor(tensor));
ShadeDataContext *shade_data_context = (ShadeDataContext*)raw_context;
return shade_data_context->dil_tensor;
}

/**
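The simplified isDilTensor path above delegates the buffer-aliasing question to check_tensor_own_whole_storage (declared in utils.h below). Its implementation is not shown in this diff; a minimal sketch of what such a check could look like, assuming it only verifies that the tensor starts at offset zero and spans its entire storage (an assumption, not the PR's actual code):

#include <ATen/ATen.h>

// Hypothetical sketch: a tensor "owns" its whole storage when it starts at the
// beginning of the buffer and its element count matches the storage's.
bool check_tensor_own_whole_storage(const at::Tensor& tensor) {
  if (!tensor.has_storage())
    return false;
  return tensor.storage_offset() == 0 &&
         tensor.numel() == static_cast<int64_t>(tensor.storage().numel());
}
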
2 changes: 1 addition & 1 deletion torch_ipex/csrc/cpu/dbl/Common.cpp
@@ -37,7 +37,7 @@ at::Tensor dil_tensor_to_dense(const at::Tensor& tensor) {
dil::tensor try_gen_dil_tensor(const at::Tensor &input) {
if (cpu::ShadeDataContext::isDilTensor(input)) {
auto dil_tensor = cpu::ShadeDataContext::getDilTensor(input);
if (dil_tensor.is_public_format()) {
if ((!check_aten_dil_shape_info(input, dil_tensor)) && dil_tensor.is_public_format()) {
dil_tensor.set_dims_and_strides(input.sizes().vec(), input.strides().vec());
}
return dil_tensor;
28 changes: 28 additions & 0 deletions torch_ipex/csrc/init_python_bindings.cpp
@@ -12,6 +12,8 @@

#include "aten_ipex_type.h"
#include "auto_opt_config.h"
#include "cpu/dil/dil.hpp"
#include "cpu/ShadeDataContext.h"
#include "cpu/ExtendOPs.h"
#include "cpu/MlpOPs.h"

@@ -29,6 +31,28 @@ void setAutoDNNL(bool val) {
AutoOptConfig::singleton().set_auto_dnnl(val);
}

/// **** Only for unit test ****
bool isDilTensor(const at::Tensor &tensor) {
return cpu::ShadeDataContext::isDilTensor(tensor);
}

dil::dims getDilTensorSizes(const at::Tensor &tensor) {
if (isDilTensor(tensor)) {
auto dil_tensor = cpu::ShadeDataContext::getDilTensor(tensor);
return dil_tensor.get_dims();
}
return dil::dims();
}

dil::dims getDilTensorStrides(const at::Tensor &tensor) {
if (isDilTensor(tensor)) {
auto dil_tensor = cpu::ShadeDataContext::getDilTensor(tensor);
return dil_tensor.get_strides();
}
return dil::dims();
}
/// ****************************

void InitIpexModuleBindings(py::module m) {
m.def("_initialize_aten_bindings",
[]() { AtenIpexType::InitializeAtenBindings(); });
@@ -97,6 +121,10 @@ void InitIpexModuleBindings(py::module m) {
m.def("mlp_create_handle", &AtenIpexTypeMLPExt::create_handle);
m.def("mlp_set_relu_mask", &AtenIpexTypeMLPExt::set_relu_mask);
m.def("mlp_release_handle", &AtenIpexTypeMLPExt::release_handle);

m.def("is_dil_tensor", &isDilTensor);
m.def("get_dil_tensor_sizes", &getDilTensorSizes);
m.def("get_dil_tensor_strides", &getDilTensorStrides);
}

} // namespace
9 changes: 9 additions & 0 deletions torch_ipex/csrc/utils.cpp
@@ -127,4 +127,13 @@ bool check_tensor_own_shade_context(const at::Tensor& tensor) {
return (data_ptr != data_ctx) && (data_ctx != nullptr);
}

bool check_aten_dil_shape_info(const at::Tensor& ipex_tensor, const dil::tensor &dil_tensor) {
if (dil_tensor.is_public_format()) {
return ipex_tensor.sizes().vec() == dil_tensor.get_dims() &&
ipex_tensor.strides().vec() == dil_tensor.get_strides();
} else {
return ipex_tensor.sizes().vec() == dil_tensor.get_dims();
}
}

} // namespace torch_ipex
1 change: 1 addition & 0 deletions torch_ipex/csrc/utils.h
@@ -20,5 +20,6 @@ at::ScalarType get_at_data_type(dil::data_type);
bool check_auto_dnnl();
bool check_tensor_own_whole_storage(const at::Tensor& tensor);
bool check_tensor_own_shade_context(const at::Tensor& tensor);
bool check_aten_dil_shape_info(const at::Tensor& ipex_tensor, const dil::tensor &dil_tensor);

} // namespace torch_ipex