
Sync the strides and size of DNNL tensor to its aten::tensor wrapper #5

Merged (3 commits, May 13, 2020)
2 changes: 1 addition & 1 deletion cmake/CPU.cmake
@@ -16,7 +16,7 @@ add_subdirectory(${DPCPP_THIRD_PARTY_ROOT}/mkl-dnn)
# Define build type
IF(CMAKE_BUILD_TYPE MATCHES Debug)
message("Debug build.")
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -O0 -g")
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -O0 -g -D_DEBUG")
ELSE()
message("Release build.")
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -O2")
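For context, the -D_DEBUG define added above is what gates the debug-only checks introduced in torch_ipex/csrc/aten_ipex_bridge.cpp later in this diff. A minimal sketch of the pattern, with a hypothetical macro name:

    // Compiled to a no-op unless the Debug build adds -D_DEBUG to CMAKE_CXX_FLAGS.
    #if defined(_DEBUG)
    #define DEBUG_ONLY_ASSERT(cond) TORCH_INTERNAL_ASSERT(cond)
    #else
    #define DEBUG_ONLY_ASSERT(cond) ((void) 0)
    #endif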
27 changes: 27 additions & 0 deletions tests/cpu/test_lazy_reorder.py
@@ -749,6 +749,33 @@ def test_transpose(self):
x_dpcpp.transpose(dim1, dim2),
)

def test_view(self):
ipex.enable_auto_dnnl()
old_shape = (4, 16)
new_shape = (1, 4, 4, 4)

x_cpu = torch.randn(old_shape)
x_dpcpp = x_cpu.to(device=device).clone()
print(x_dpcpp.size())

x_cpu_view = x_cpu.view(new_shape)
print(x_cpu_view.size())
x_dpcpp_view = x_dpcpp.view(new_shape)
print(x_dpcpp_view.size())

y = torch.randn(new_shape)
out_cpu = x_cpu_view * y
# test if the shape of x_dpcpp_view is compatible with y
out_dpcpp = x_dpcpp_view * y
self.assertEqual(out_cpu, out_dpcpp)

# test if metadata of x_dpcpp has not been altered
y = torch.randn(old_shape)
out_cpu = x_cpu * y
out_dpcpp = x_dpcpp * y
self.assertEqual(out_cpu, out_dpcpp)


class TestSoftMax(TestCase):
def test_softmax(self):
ipex.enable_auto_dnnl()
7 changes: 7 additions & 0 deletions tests/cpu/test_rn50_cpu_ops.py
@@ -416,6 +416,9 @@ def test_view(self):
self.assertRaises(RuntimeError, lambda: tensor.view(7, -1))
self.assertRaises(RuntimeError, lambda: tensor.view(15, -1, -1))

# TODO(Eikan): DNNL OPs do not support tensors with more than 6 dims, so we disable auto-DNNL here temporarily. Re-enable it once this is fixed.
old_dnnl_conf = ipex.get_auto_dnnl()
Review comment: Consider using a context manager here to avoid the complexity of saving and restoring the original conf.

Contributor Author: Great! I will fix it.

Contributor Author: It seems that "with" does not work for the native C++ API.
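For reference, a minimal sketch of such a context manager on the Python side. It assumes only the ipex module imported by these tests and the ipex.get_auto_dnnl / ipex.enable_auto_dnnl / ipex.disable_auto_dnnl calls already used here; the helper name is hypothetical.

    import contextlib

    @contextlib.contextmanager
    def auto_dnnl_disabled():
        # Save the current conf, disable auto-DNNL, then restore the saved conf on exit.
        old_conf = ipex.get_auto_dnnl()
        ipex.disable_auto_dnnl()
        try:
            yield
        finally:
            if old_conf:
                ipex.enable_auto_dnnl()
            else:
                ipex.disable_auto_dnnl()

    # Usage in a test:
    #   with auto_dnnl_disabled():
    #       ...  # run the >6-dim view checks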

ipex.disable_auto_dnnl()
# test view when tensor is not contiguous in every dimension, but only
# contiguous dimensions are touched.
tensor = torch.rand(4, 2, 5, 1, 6, 2, 9, 3, device=device).transpose(-1, 2).transpose(-2, 3)
@@ -441,6 +444,10 @@ def test_view(self):
# adding size 1 dims
view_size = [1, 1, 2, 1, 4, 3, 1, 1, 9, 1, 2, 1, 2, 3, 1, 5, 1, 1]
self.assertEqual(tensor.view(*view_size), contig_tensor.view(*view_size))
if old_dnnl_conf:
ipex.enable_auto_dnnl()
else:
ipex.disable_auto_dnnl()

# invalid views
self.assertRaises(RuntimeError, lambda: tensor.view(-1))
48 changes: 36 additions & 12 deletions tests/cpu/test_torch.py
@@ -81,7 +81,7 @@
from multiprocessing.reduction import ForkingPickler
from common_device_type import instantiate_device_type_tests, \
skipIf, skipCPUIfNoLapack, skipCUDAIfNoMagma, skipCUDAIfRocm, onlyCUDA, onlyCPU, \
dtypes, dtypesIfCUDA, deviceCountAtLeast, skipCUDAIf, precisionOverride
dtypes, dtypesIfCUDA, deviceCountAtLeast, skipCUDAIf, precisionOverride, ipex
import torch.backends.quantized


@@ -8725,7 +8725,10 @@ def test_diagflat(self, device):

# Noncontig input
x = torch.randn((2, 3, 4), dtype=dtype, device=device).transpose(2, 0)
self.assertFalse(x.is_contiguous())
if ipex.get_auto_dnnl():
self.assertTrue(x.is_contiguous())
else:
self.assertFalse(x.is_contiguous())
result = torch.diagflat(x)
expected = torch.diag(x.contiguous().view(-1))
self.assertEqual(result, expected)
@@ -9773,8 +9776,12 @@ def test_cdist_non_contiguous(self, device):
y = torch.randn(5, 3, device=device).transpose(-1, -2)
actual = torch.cdist(x, y, p=1, compute_mode=cm)
expected = brute_cdist(x, y, p=1)
self.assertFalse(x.is_contiguous())
self.assertFalse(y.is_contiguous())
if ipex.get_auto_dnnl():
self.assertTrue(x.is_contiguous())
self.assertTrue(y.is_contiguous())
else:
self.assertFalse(x.is_contiguous())
self.assertFalse(y.is_contiguous())
self.assertTrue(torch.allclose(expected, actual))

x = torch.randn(7, 5, device=device)
@@ -9799,23 +9806,33 @@ def test_cdist_non_contiguous_batch(self, device):
y = torch.randn(4, 3, 2, 5, 3, device=device).transpose(-1, -2)
actual = torch.cdist(x, y, p=1, compute_mode=cm)
expected = brute_cdist(x, y, p=1)
self.assertFalse(x.is_contiguous())
self.assertFalse(y.is_contiguous())
if ipex.get_auto_dnnl():
self.assertTrue(x.is_contiguous())
self.assertTrue(y.is_contiguous())
else:
self.assertFalse(x.is_contiguous())
self.assertFalse(y.is_contiguous())
self.assertTrue(torch.allclose(expected, actual))

x = torch.randn(7, 2, 7, 5, device=device)
y = torch.randn(7, 2, 5, 3, device=device).transpose(-1, -2)
actual = torch.cdist(x, y, p=1, compute_mode=cm)
expected = brute_cdist(x, y, p=1)
self.assertTrue(x.is_contiguous())
self.assertFalse(y.is_contiguous())
if ipex.get_auto_dnnl():
self.assertTrue(y.is_contiguous())
else:
self.assertFalse(y.is_contiguous())
self.assertTrue(torch.allclose(expected, actual))

x = torch.randn(4, 5, 7, device=device).transpose(-1, -2)
y = torch.randn(4, 3, 5, device=device)
actual = torch.cdist(x, y, p=1, compute_mode=cm)
expected = brute_cdist(x, y, p=1)
self.assertFalse(x.is_contiguous())
if ipex.get_auto_dnnl():
self.assertTrue(x.is_contiguous())
else:
self.assertFalse(x.is_contiguous())
self.assertTrue(y.is_contiguous())
self.assertTrue(torch.allclose(expected, actual))

@@ -10249,6 +10266,7 @@ def test_unfold_scalars(self, device):

def test_copy_all_dtypes_and_devices(self, device):
from copy import copy
ipex.enable_auto_dnnl()
for dt in torch.testing.get_all_dtypes():
x = torch.tensor([1, 2, 3, 4], dtype=dt, device=device)
x_clone = x.clone()
@@ -10264,6 +10282,7 @@ def test_copy_all_dtypes_and_devices(self, device):
# copy is a shallow copy, only copies the tensor view,
# not the data
self.assertEqual(x, y)
ipex.enable_auto_dnnl()

def test_resize_all_dtypes_and_devices(self, device):
shape = (2, 2)
@@ -10761,7 +10780,8 @@ def test_tensor_shape_empty(self, device):
self.assertEqual([(0, 1, 0, 0), (0, 1, 1, 0), (0, 1, 2, 0)],
[z.shape for z in torch.split(x, (0, 1, 2), dim=2)])

self.assertRaises(RuntimeError, lambda: torch.split(x, 0, dim=1))
with self.assertRaises(RuntimeError):
torch.split(x, 0, dim=1)
# This is strange because the split size is larger than the dim size, but consistent with
# how split handles that case generally (when no 0s are involved).
self.assertEqual([(0, 1, 3, 0)], [z.shape for z in torch.split(x, 1, dim=0)])
@@ -12764,8 +12784,12 @@ def _test_memory_format_transformations(self, device, input_generator_fn, transf
clone = transformation_fn(xc)

if default_is_preserve:
self.assertFalse(clone.is_contiguous())
self.assertTrue(clone.is_contiguous(memory_format=memory_format))
if ipex.get_auto_dnnl():
self.assertTrue(clone.is_contiguous())
self.assertFalse(clone.is_contiguous(memory_format=memory_format))
else:
self.assertFalse(clone.is_contiguous())
self.assertTrue(clone.is_contiguous(memory_format=memory_format))
else:
self.assertTrue(clone.is_contiguous())
self.assertFalse(clone.is_contiguous(memory_format=memory_format))
@@ -14398,7 +14422,6 @@ def fn(self, device, dtype):
# Runs the tensor op on CPU and device
cpu_result = getattr(cpu_tensor, op_str)(*cpu_args)
device_result = getattr(device_tensor, op_str)(*device_args)

# Compares CPU and device inputs and outputs
precision = half_precision if dtype == torch.half else float_precision

@@ -14512,4 +14535,5 @@ class TestTorch(TestCase, _TestTorchMixin):
instantiate_device_type_tests(TestTensorDeviceOps, globals(), except_for='cpu')

if __name__ == '__main__':
ipex.enable_auto_dnnl()
run_tests()
86 changes: 40 additions & 46 deletions torch_ipex/csrc/aten_ipex_bridge.cpp
@@ -18,6 +18,7 @@
namespace torch_ipex {
namespace bridge {

#if defined(_DEBUG)
#define CHECK_TENSOR(a, b) \
TORCH_INTERNAL_ASSERT(a.numel() == b.numel()); \
TORCH_INTERNAL_ASSERT(a.dtype() == b.dtype()); \
@@ -30,13 +31,21 @@ namespace bridge {
TORCH_INTERNAL_ASSERT(a.unsafeGetTensorImpl()->is_wrapped_number() == b.unsafeGetTensorImpl()->is_wrapped_number()); \
TORCH_INTERNAL_ASSERT(a.unsafeGetTensorImpl()->version_counter().current_version() == b.unsafeGetTensorImpl()->version_counter().current_version()); \
TORCH_INTERNAL_ASSERT(a.unsafeGetTensorImpl()->allow_tensor_metadata_change() == b.unsafeGetTensorImpl()->allow_tensor_metadata_change())
#else
#define CHECK_TENSOR(a, b) ((void) 0)
#endif

#if defined(_DEBUG)
#define CHECK_TENSOR_CRITICAL(a, b, may_alias) \
TORCH_INTERNAL_ASSERT(!may_alias || a.data_ptr() == b.data_ptr()); \
TORCH_INTERNAL_ASSERT(a.unsafeGetTensorImpl()->strides() == b.unsafeGetTensorImpl()->strides()); \
TORCH_INTERNAL_ASSERT(a.unsafeGetTensorImpl()->storage_offset() == b.unsafeGetTensorImpl()->storage_offset()); \
CHECK_TENSOR(a, b)
#else
#define CHECK_TENSOR_CRITICAL(a, b, may_alias) ((void) 0)
#endif

#if defined(_DEBUG)
#define CHECK_SPARSE_TENSOR_CRITICAL(a, b, may_alias) \
TORCH_INTERNAL_ASSERT(!may_alias || a._indices().data_ptr() == b._indices().data_ptr()); \
TORCH_INTERNAL_ASSERT(!may_alias || a._values().data_ptr() == b._values().data_ptr()); \
@@ -46,43 +55,54 @@ namespace bridge {
TORCH_INTERNAL_ASSERT(a.is_coalesced() == b.is_coalesced()); \
CHECK_TENSOR(a._indices(), b._indices()); \
CHECK_TENSOR(a._values(), b._values())

#else
#define CHECK_SPARSE_TENSOR_CRITICAL(a, b, may_alias) ((void) 0)
#endif

at::Tensor shallowFallbackToCPUTensorImpl(const at::Tensor& ipexTensor);

void reorderDilTensorToPublic(const at::Tensor& ipexTensor) {
void *data_ctx = ipexTensor.unsafeGetTensorImpl()->storage().data_ptr().get_context();
cpu::ShadeDataContext *shade_data_context = (cpu::ShadeDataContext*)data_ctx;
// All aten::tensors backed by a dnnl::tensor should be contiguous
#if defined(_DEBUG)
TORCH_WARN(ipexTensor.is_contiguous());
TORCH_INTERNAL_ASSERT(! (shade_data_context->dil_tensor.is_empty()));
#endif
dil::tensor &dil_tensor = shade_data_context->dil_tensor;

dil::dims sizes = dil_tensor.get_dims();
dil::dims strides;

if (dil_tensor.is_public_format()) {
#if defined(_DEBUG)
TORCH_INTERNAL_ASSERT(shade_data_context->cpu_raw_data == shade_data_context->dil_tensor.get_data_handle());
TORCH_INTERNAL_ASSERT(shade_data_context->cpu_raw_data != nullptr);
TORCH_INTERNAL_ASSERT(shade_data_context->cpu_del_fun != nullptr);
#endif
strides = dil_tensor.get_strides();
} else {
auto dims = dil_tensor.get_dims();
// NOTE: int32_t dims from ideep::tensor but sizes needs int64_t
at::Tensor cpu_tensor = at::empty(
sizes, ipexTensor.options().device(c10::kCPU).layout(c10::kStrided));
TORCH_INTERNAL_ASSERT(cpu_tensor.scalar_type() == get_at_data_type(dil_tensor.get_data_type()));
auto pub_tensor = dil_tensor.to_public(cpu_tensor.data_ptr(), dil_tensor.get_data_type());
strides = pub_tensor.get_strides();
at::DataPtr& cpu_tensor_data_ptr = cpu_tensor.unsafeGetTensorImpl()->storage().unsafeGetStorageImpl()->data_ptr();
ipexTensor.unsafeGetTensorImpl()->storage().set_data_ptr(std::move(cpu_tensor_data_ptr));
// The tensor has been reset to a new DataPtr, so we need to attach a new shade data context.
attachShadeDataConext(ipexTensor);
#if defined(_DEBUG)
auto& data_ptr = ipexTensor.storage().unsafeGetStorageImpl()->data_ptr();
TORCH_INTERNAL_ASSERT(data_ptr.get_deleter() == &(cpu::ShadeDataContext::freeShadeDataContext));
TORCH_INTERNAL_ASSERT(shade_data_context->cpu_del_fun == nullptr);
#endif
auto pub_tensor = dil_tensor.to_public(nullptr, dil_tensor.get_data_type());

cpu::ShadeDataContext *new_shade_data_context = cpu::ShadeDataContext::allocShadeDataContext();
Contributor: Is this useful? Supposing this tensor is temporary; if so, the ShadeDataContext will be useless, right?

Contributor Author (@EikanWang, May 11, 2020): The returned tensor is a DNNL tensor; it should be the same as dil_tensor. @pinzhenx, is that correct?

Contributor: From my personal view, the shade data context should only be attached in the upgrade-to-DPCPP related interfaces.

Contributor Author: Agree with your point. In this case, the DNNL tensor is reordered from blocked format to plain format, and the buffer of the reordered DNNL tensor can be shared with the CPU. But DataPtr does not expose an interface to modify its "data" field, so we replace the old DataPtr to share data between the CPU buffer and the DNNL buffer, while attaching a ShadeDataContext to keep the DNNL tensor alive and avoid releasing its resources too early.

Contributor: We should discuss "device exchange" and "data type conversion" further and make them simpler and clearer; the current implementation may cause the data-type conversion path to attach a shade context too.

new_shade_data_context->data_type = cpu::SHADE_DATA_TYPE::DIL;
new_shade_data_context->dil_tensor = pub_tensor;
// Share the DNNL raw data because it is in plain format now
new_shade_data_context->cpu_raw_data = pub_tensor.get_data_handle();
// Cannot free the CPU data because the data is owned by DNNL
new_shade_data_context->cpu_del_fun = &(c10::detail::deleteNothing);

// Create a new DataPtr instance because the DataPtr class does not support setting
// its data or context directly
c10::DataPtr shade_data_ptr(
pub_tensor.get_data_handle(),
new_shade_data_context,
&(cpu::ShadeDataContext::freeShadeDataContext),
ipexTensor.device().type());

ipexTensor.unsafeGetTensorImpl()->storage().set_data_ptr(std::move(shade_data_ptr));
TORCH_INTERNAL_ASSERT(ipexTensor.is_contiguous());
}

auto* ipexTensorImpl = (IPEXTensorImpl *)ipexTensor.unsafeGetTensorImpl();
ipexTensorImpl->force_set_strided(sizes, strides);
}
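For readers unfamiliar with c10::DataPtr, here is a minimal sketch of the ownership pattern discussed in the review thread above: the storage's DataPtr owns a context that keeps the plain-format DNNL buffer alive while its raw data is shared with the CPU side. The names are hypothetical, the headers already included by this file are assumed, and the real code uses cpu::ShadeDataContext and records more state.

    // Hypothetical context type: owns the plain-format DNNL buffer.
    struct KeepAliveCtx {
      dil::tensor dnnl_buffer;
    };

    static void freeKeepAliveCtx(void* ctx) {
      delete static_cast<KeepAliveCtx*>(ctx);
    }

    static void shareDnnlBufferWithStorage(const at::Tensor& t, dil::tensor plain) {
      auto* ctx = new KeepAliveCtx{std::move(plain)};
      void* raw = ctx->dnnl_buffer.get_data_handle();
      // The context deleter runs when the storage is released, which frees the DNNL buffer.
      c10::DataPtr ptr(raw, ctx, &freeKeepAliveCtx, t.device().type());
      t.unsafeGetTensorImpl()->storage().set_data_ptr(std::move(ptr));
    }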


@@ -279,32 +299,6 @@ at::Tensor upgradeToDPCPPTensor(const at::Tensor& cpuTensor) {
return _tensor;
}

at::Tensor shallowUpgradeToDPCPPShadeTensor(const at::Tensor& cpuTensor) {
if (!(cpuTensor.defined())) {
return at::Tensor();
}
TORCH_INTERNAL_ASSERT(cpuTensor.device().type() == at::DeviceType::CPU);
if (cpuTensor.is_sparse()) shallowUpgradeToDPCPPTensor(cpuTensor);

auto cpu_storage_impl = cpuTensor.storage().unsafeGetStorageImpl();
auto& data_ptr = cpu_storage_impl->data_ptr();
auto cur_del_fn = data_ptr.get_deleter();
bool res = data_ptr.compare_exchange_deleter(cur_del_fn, &(c10::detail::deleteNothing));
TORCH_INTERNAL_ASSERT(res);
// Make sure that this does not trigger freeing the resource for set_ptr
cpu::ShadeDataContext *shade_data_context = cpu::ShadeDataContext::allocShadeDataContext();
shade_data_context->cpu_raw_data = data_ptr.get();
shade_data_context->cpu_del_fun = cur_del_fn;
shade_data_context->data_type = cpu::SHADE_DATA_TYPE::CPU_RAW;
c10::DataPtr shade_data_ptr(
data_ptr.get(),
shade_data_context,
cpu::ShadeDataContext::freeShadeDataContext,
at::DeviceType::CPU);
cpuTensor.unsafeGetTensorImpl()->storage().set_data_ptr(std::move(shade_data_ptr));
return shallowUpgradeToDPCPPTensor(cpuTensor);
}

// Upgrade CPU tensor to DPCPP Tensor with shallow copy
// It will create a new DPCPP tensor but share the CPU tensor buffer
// [NOTE]: Device info of Dense CPU tensor is polluted.