
Commit f47ad5b

Merge pull request #5 from EikanWang/fix_dil_at_strides_issue
Sync the strides and size of DNNL tensor to its aten::tensor wrapper
2 parents 2f37927 + e91c403 · commit f47ad5b

12 files changed: +164 additions, −66 deletions

cmake/CPU.cmake

Lines changed: 1 addition & 1 deletion

@@ -24,7 +24,7 @@ ENDIF()
 # Define build type
 IF(CMAKE_BUILD_TYPE MATCHES Debug)
   message("Debug build.")
-  set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -O0 -g")
+  set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -O0 -g -D_DEBUG")
 ELSE()
   message("Release build.")
   set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -O2")

tests/cpu/test_lazy_reorder.py

Lines changed: 27 additions & 0 deletions

@@ -749,6 +749,33 @@ def test_transpose(self):
             x_dpcpp.transpose(dim1, dim2),
         )
 
+    def test_view(self):
+        ipex.enable_auto_dnnl()
+        old_shape = (4, 16)
+        new_shape = (1, 4, 4, 4)
+
+        x_cpu = torch.randn(old_shape)
+        x_dpcpp = x_cpu.to(device=device).clone()
+        print(x_dpcpp.size())
+
+        x_cpu_view = x_cpu.view(new_shape)
+        print(x_cpu_view.size())
+        x_dpcpp_view = x_dpcpp.view(new_shape)
+        print(x_dpcpp_view.size())
+
+        y = torch.randn(new_shape)
+        out_cpu = x_cpu_view * y
+        # test if the shape of x_dpcpp_view is compatible with y
+        out_dpcpp = x_dpcpp_view * y
+        self.assertEqual(out_cpu, out_dpcpp)
+
+        # test if metadata of x_dpcpp has not been altered
+        y = torch.randn(old_shape)
+        out_cpu = x_cpu * y
+        out_dpcpp = x_dpcpp * y
+        self.assertEqual(out_cpu, out_dpcpp)
+
 class TestSoftMax(TestCase):
     def test_softmax(self):
         ipex.enable_auto_dnnl()
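
The new test pins down the invariant this PR restores: viewing a DPCPP tensor must report the new shape while leaving the base tensor's metadata untouched. A minimal sketch of the same invariant, using a plain CPU tensor as a stand-in for the suite's `device` fixture (illustration only, not part of the diff):

import torch

# Assumption: plain CPU tensors stand in for the ipex DPCPP device,
# where ipex.enable_auto_dnnl() would be active.
x = torch.randn(4, 16)
x_view = x.view(1, 4, 4, 4)

# The view must report the new shape...
assert x_view.size() == torch.Size([1, 4, 4, 4])
# ...while the base tensor's metadata stays untouched.
assert x.size() == torch.Size([4, 16])

# Elementwise ops against same-shaped tensors must succeed for both.
_ = x_view * torch.randn(1, 4, 4, 4)
_ = x * torch.randn(4, 16)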

tests/cpu/test_rn50_cpu_ops.py

Lines changed: 7 additions & 0 deletions

@@ -416,6 +416,9 @@ def test_view(self):
         self.assertRaises(RuntimeError, lambda: tensor.view(7, -1))
         self.assertRaises(RuntimeError, lambda: tensor.view(15, -1, -1))
 
+        # TODO(Eikan): DNNL ops do not support tensors with more than 6 dims,
+        # so disable auto-DNNL temporarily and re-enable it once that is fixed.
+        old_dnnl_conf = ipex.get_auto_dnnl()
+        ipex.disable_auto_dnnl()
         # test view when tensor is not contiguous in every dimension, but only
         # contiguous dimensions are touched.
         tensor = torch.rand(4, 2, 5, 1, 6, 2, 9, 3, device=device).transpose(-1, 2).transpose(-2, 3)
@@ -441,6 +444,10 @@ def test_view(self):
         # adding size 1 dims
         view_size = [1, 1, 2, 1, 4, 3, 1, 1, 9, 1, 2, 1, 2, 3, 1, 5, 1, 1]
         self.assertEqual(tensor.view(*view_size), contig_tensor.view(*view_size))
+        if old_dnnl_conf:
+            ipex.enable_auto_dnnl()
+        else:
+            ipex.disable_auto_dnnl()
 
         # invalid views
         self.assertRaises(RuntimeError, lambda: tensor.view(-1))
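
The manual save/restore of the auto-DNNL flag works, but it leaves the flag disabled for later tests in the same process if any assertion between the save and the restore fails. A context-manager variant would make the restore exception-safe; this is a hypothetical refactor, assuming the `intel_pytorch_extension` import name and only the `get_auto_dnnl`/`enable_auto_dnnl`/`disable_auto_dnnl` calls already used in the diff:

from contextlib import contextmanager

import intel_pytorch_extension as ipex  # assumption: the module these tests alias as ipex

@contextmanager
def auto_dnnl_disabled():
    # Temporarily disable auto-DNNL; restore the previous setting on exit,
    # even if the body raises (e.g. a failed assertion).
    old_conf = ipex.get_auto_dnnl()
    ipex.disable_auto_dnnl()
    try:
        yield
    finally:
        if old_conf:
            ipex.enable_auto_dnnl()
        else:
            ipex.disable_auto_dnnl()

# Usage in the test above:
# with auto_dnnl_disabled():
#     ...  # the >6-dim view checks that DNNL ops cannot handle yet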

tests/cpu/test_torch.py

Lines changed: 36 additions & 12 deletions

@@ -81,7 +81,7 @@
 from multiprocessing.reduction import ForkingPickler
 from common_device_type import instantiate_device_type_tests, \
     skipIf, skipCPUIfNoLapack, skipCUDAIfNoMagma, skipCUDAIfRocm, onlyCUDA, onlyCPU, \
-    dtypes, dtypesIfCUDA, deviceCountAtLeast, skipCUDAIf, precisionOverride
+    dtypes, dtypesIfCUDA, deviceCountAtLeast, skipCUDAIf, precisionOverride, ipex
 import torch.backends.quantized
 
 
@@ -8725,7 +8725,10 @@ def test_diagflat(self, device):
 
         # Noncontig input
         x = torch.randn((2, 3, 4), dtype=dtype, device=device).transpose(2, 0)
-        self.assertFalse(x.is_contiguous())
+        if ipex.get_auto_dnnl():
+            self.assertTrue(x.is_contiguous())
+        else:
+            self.assertFalse(x.is_contiguous())
         result = torch.diagflat(x)
         expected = torch.diag(x.contiguous().view(-1))
         self.assertEqual(result, expected)
@@ -9773,8 +9776,12 @@ def test_cdist_non_contiguous(self, device):
         y = torch.randn(5, 3, device=device).transpose(-1, -2)
         actual = torch.cdist(x, y, p=1, compute_mode=cm)
         expected = brute_cdist(x, y, p=1)
-        self.assertFalse(x.is_contiguous())
-        self.assertFalse(y.is_contiguous())
+        if ipex.get_auto_dnnl():
+            self.assertTrue(x.is_contiguous())
+            self.assertTrue(y.is_contiguous())
+        else:
+            self.assertFalse(x.is_contiguous())
+            self.assertFalse(y.is_contiguous())
         self.assertTrue(torch.allclose(expected, actual))
 
         x = torch.randn(7, 5, device=device)
@@ -9799,23 +9806,33 @@ def test_cdist_non_contiguous_batch(self, device):
         y = torch.randn(4, 3, 2, 5, 3, device=device).transpose(-1, -2)
         actual = torch.cdist(x, y, p=1, compute_mode=cm)
         expected = brute_cdist(x, y, p=1)
-        self.assertFalse(x.is_contiguous())
-        self.assertFalse(y.is_contiguous())
+        if ipex.get_auto_dnnl():
+            self.assertTrue(x.is_contiguous())
+            self.assertTrue(y.is_contiguous())
+        else:
+            self.assertFalse(x.is_contiguous())
+            self.assertFalse(y.is_contiguous())
         self.assertTrue(torch.allclose(expected, actual))
 
         x = torch.randn(7, 2, 7, 5, device=device)
         y = torch.randn(7, 2, 5, 3, device=device).transpose(-1, -2)
         actual = torch.cdist(x, y, p=1, compute_mode=cm)
         expected = brute_cdist(x, y, p=1)
         self.assertTrue(x.is_contiguous())
-        self.assertFalse(y.is_contiguous())
+        if ipex.get_auto_dnnl():
+            self.assertTrue(y.is_contiguous())
+        else:
+            self.assertFalse(y.is_contiguous())
         self.assertTrue(torch.allclose(expected, actual))
 
         x = torch.randn(4, 5, 7, device=device).transpose(-1, -2)
         y = torch.randn(4, 3, 5, device=device)
         actual = torch.cdist(x, y, p=1, compute_mode=cm)
         expected = brute_cdist(x, y, p=1)
-        self.assertFalse(x.is_contiguous())
+        if ipex.get_auto_dnnl():
+            self.assertTrue(x.is_contiguous())
+        else:
+            self.assertFalse(x.is_contiguous())
         self.assertTrue(y.is_contiguous())
         self.assertTrue(torch.allclose(expected, actual))
 
@@ -10249,6 +10266,7 @@ def test_unfold_scalars(self, device):
 
     def test_copy_all_dtypes_and_devices(self, device):
         from copy import copy
+        ipex.enable_auto_dnnl()
         for dt in torch.testing.get_all_dtypes():
             x = torch.tensor([1, 2, 3, 4], dtype=dt, device=device)
             x_clone = x.clone()
@@ -10264,6 +10282,7 @@ def test_copy_all_dtypes_and_devices(self, device):
             # copy is a shallow copy, only copies the tensor view,
             # not the data
             self.assertEqual(x, y)
+        ipex.enable_auto_dnnl()
 
     def test_resize_all_dtypes_and_devices(self, device):
         shape = (2, 2)
@@ -10761,7 +10780,8 @@ def test_tensor_shape_empty(self, device):
         self.assertEqual([(0, 1, 0, 0), (0, 1, 1, 0), (0, 1, 2, 0)],
                          [z.shape for z in torch.split(x, (0, 1, 2), dim=2)])
 
-        self.assertRaises(RuntimeError, lambda: torch.split(x, 0, dim=1))
+        with self.assertRaises(RuntimeError):
+            torch.split(x, 0, dim=1)
         # This is strange because the split size is larger than the dim size, but consistent with
         # how split handles that case generally (when no 0s are involved).
         self.assertEqual([(0, 1, 3, 0)], [z.shape for z in torch.split(x, 1, dim=0)])
@@ -12764,8 +12784,12 @@ def _test_memory_format_transformations(self, device, input_generator_fn, transf
         clone = transformation_fn(xc)
 
         if default_is_preserve:
-            self.assertFalse(clone.is_contiguous())
-            self.assertTrue(clone.is_contiguous(memory_format=memory_format))
+            if ipex.get_auto_dnnl():
+                self.assertTrue(clone.is_contiguous())
+                self.assertFalse(clone.is_contiguous(memory_format=memory_format))
+            else:
+                self.assertFalse(clone.is_contiguous())
+                self.assertTrue(clone.is_contiguous(memory_format=memory_format))
         else:
             self.assertTrue(clone.is_contiguous())
             self.assertFalse(clone.is_contiguous(memory_format=memory_format))
@@ -14398,7 +14422,6 @@ def fn(self, device, dtype):
         # Runs the tensor op on CPU and device
         cpu_result = getattr(cpu_tensor, op_str)(*cpu_args)
         device_result = getattr(device_tensor, op_str)(*device_args)
-
        # Compares CPU and device inputs and outputs
        precision = half_precision if dtype == torch.half else float_precision
 
@@ -14512,4 +14535,5 @@ class TestTorch(TestCase, _TestTorchMixin):
 instantiate_device_type_tests(TestTensorDeviceOps, globals(), except_for='cpu')
 
 if __name__ == '__main__':
+    ipex.enable_auto_dnnl()
     run_tests()
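
Every contiguity edit above follows one branch: with auto-DNNL enabled, lazy reorder hands back a plain-format tensor, so even a transposed input reports is_contiguous() == True. A sketch of that pattern as a small helper (hypothetical; the diff deliberately inlines the branch at each call site):

def check_contiguity(test_case, t, stock_expectation, auto_dnnl_enabled):
    # Under auto-DNNL, lazily reordered tensors always report contiguous,
    # so the stock-PyTorch expectation only applies when the backend is off.
    if auto_dnnl_enabled:
        test_case.assertTrue(t.is_contiguous())
    else:
        test_case.assertEqual(t.is_contiguous(), stock_expectation)

# Call-site form matching the diff, e.g. in test_diagflat:
#     check_contiguity(self, x, False, ipex.get_auto_dnnl())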

torch_ipex/csrc/aten_ipex_bridge.cpp

Lines changed: 40 additions & 46 deletions

@@ -18,6 +18,7 @@
 namespace torch_ipex {
 namespace bridge {
 
+#if defined(_DEBUG)
 #define CHECK_TENSOR(a, b) \
   TORCH_INTERNAL_ASSERT(a.numel() == b.numel()); \
   TORCH_INTERNAL_ASSERT(a.dtype() == b.dtype()); \
@@ -30,13 +31,21 @@ namespace bridge {
   TORCH_INTERNAL_ASSERT(a.unsafeGetTensorImpl()->is_wrapped_number() == b.unsafeGetTensorImpl()->is_wrapped_number()); \
   TORCH_INTERNAL_ASSERT(a.unsafeGetTensorImpl()->version_counter().current_version() == b.unsafeGetTensorImpl()->version_counter().current_version()); \
   TORCH_INTERNAL_ASSERT(a.unsafeGetTensorImpl()->allow_tensor_metadata_change() == b.unsafeGetTensorImpl()->allow_tensor_metadata_change())
+#else
+#define CHECK_TENSOR(a, b) ((void) 0)
+#endif
 
+#if defined(_DEBUG)
 #define CHECK_TENSOR_CRITICAL(a, b, may_alias) \
   TORCH_INTERNAL_ASSERT(!may_alias || a.data_ptr() == b.data_ptr()); \
   TORCH_INTERNAL_ASSERT(a.unsafeGetTensorImpl()->strides() == b.unsafeGetTensorImpl()->strides()); \
   TORCH_INTERNAL_ASSERT(a.unsafeGetTensorImpl()->storage_offset() == b.unsafeGetTensorImpl()->storage_offset()); \
   CHECK_TENSOR(a, b)
+#else
+#define CHECK_TENSOR_CRITICAL(a, b, may_alias) ((void) 0)
+#endif
 
+#if defined(_DEBUG)
 #define CHECK_SPARSE_TENSOR_CRITICAL(a, b, may_alias) \
   TORCH_INTERNAL_ASSERT(!may_alias || a._indices().data_ptr() == b._indices().data_ptr()); \
   TORCH_INTERNAL_ASSERT(!may_alias || a._values().data_ptr() == b._values().data_ptr()); \
@@ -46,43 +55,54 @@ namespace bridge {
   TORCH_INTERNAL_ASSERT(a.is_coalesced() == b.is_coalesced()); \
   CHECK_TENSOR(a._indices(), b._indices()); \
   CHECK_TENSOR(a._values(), b._values())
-
+#else
+#define CHECK_SPARSE_TENSOR_CRITICAL(a, b, may_alias) ((void) 0)
+#endif
 
 at::Tensor shallowFallbackToCPUTensorImpl(const at::Tensor& ipexTensor);
 
 void reorderDilTensorToPublic(const at::Tensor& ipexTensor) {
   void *data_ctx = ipexTensor.unsafeGetTensorImpl()->storage().data_ptr().get_context();
   cpu::ShadeDataContext *shade_data_context = (cpu::ShadeDataContext*)data_ctx;
-  // All aten::tensor with dnnl::tensor should be contiguous
+#if defined(_DEBUG)
   TORCH_WARN(ipexTensor.is_contiguous());
   TORCH_INTERNAL_ASSERT(! (shade_data_context->dil_tensor.is_empty()));
+#endif
   dil::tensor &dil_tensor = shade_data_context->dil_tensor;
 
-  dil::dims sizes = dil_tensor.get_dims();
-  dil::dims strides;
-
   if (dil_tensor.is_public_format()) {
+#if defined(_DEBUG)
     TORCH_INTERNAL_ASSERT(shade_data_context->cpu_raw_data == shade_data_context->dil_tensor.get_data_handle());
     TORCH_INTERNAL_ASSERT(shade_data_context->cpu_raw_data != nullptr);
     TORCH_INTERNAL_ASSERT(shade_data_context->cpu_del_fun != nullptr);
-    strides = dil_tensor.get_strides();
+#endif
  } else {
-    auto dims = dil_tensor.get_dims();
-    // NOTE: int32_t dims from ideep::tensor but sizes needs int64_t
-    at::Tensor cpu_tensor = at::empty(
-      sizes, ipexTensor.options().device(c10::kCPU).layout(c10::kStrided));
-    TORCH_INTERNAL_ASSERT(cpu_tensor.scalar_type() == get_at_data_type(dil_tensor.get_data_type()));
-    auto pub_tensor = dil_tensor.to_public(cpu_tensor.data_ptr(), dil_tensor.get_data_type());
-    strides = pub_tensor.get_strides();
-    at::DataPtr& cpu_tensor_data_ptr = cpu_tensor.unsafeGetTensorImpl()->storage().unsafeGetStorageImpl()->data_ptr();
-    ipexTensor.unsafeGetTensorImpl()->storage().set_data_ptr(std::move(cpu_tensor_data_ptr));
-    // The tensor has been reset to new DataPtr, then we need to attach new shade data context.
-    attachShadeDataConext(ipexTensor);
+#if defined(_DEBUG)
+    auto& data_ptr = ipexTensor.storage().unsafeGetStorageImpl()->data_ptr();
+    TORCH_INTERNAL_ASSERT(data_ptr.get_deleter() == &(cpu::ShadeDataContext::freeShadeDataContext));
+    TORCH_INTERNAL_ASSERT(shade_data_context->cpu_del_fun == nullptr);
+#endif
+    auto pub_tensor = dil_tensor.to_public(nullptr, dil_tensor.get_data_type());
+
+    cpu::ShadeDataContext *new_shade_data_context = cpu::ShadeDataContext::allocShadeDataContext();
+    new_shade_data_context->data_type = cpu::SHADE_DATA_TYPE::DIL;
+    new_shade_data_context->dil_tensor = pub_tensor;
+    // Share the raw data with DNNL because the tensor is in plain format now
+    new_shade_data_context->cpu_raw_data = pub_tensor.get_data_handle();
+    // Cannot free the CPU data because it is owned by DNNL
+    new_shade_data_context->cpu_del_fun = &(c10::detail::deleteNothing);
+
+    // Create a new DataPtr instance because the DataPtr class does not support
+    // setting its data or context directly
+    c10::DataPtr shade_data_ptr(
+      pub_tensor.get_data_handle(),
+      new_shade_data_context,
+      &(cpu::ShadeDataContext::freeShadeDataContext),
+      ipexTensor.device().type());
+
+    ipexTensor.unsafeGetTensorImpl()->storage().set_data_ptr(std::move(shade_data_ptr));
     TORCH_INTERNAL_ASSERT(ipexTensor.is_contiguous());
   }
-
-  auto* ipexTensorImpl = (IPEXTensorImpl *)ipexTensor.unsafeGetTensorImpl();
-  ipexTensorImpl->force_set_strided(sizes, strides);
 }
 
 
@@ -279,32 +299,6 @@ at::Tensor upgradeToDPCPPTensor(const at::Tensor& cpuTensor) {
   return _tensor;
 }
 
-at::Tensor shallowUpgradeToDPCPPShadeTensor(const at::Tensor& cpuTensor) {
-  if (!(cpuTensor.defined())) {
-    return at::Tensor();
-  }
-  TORCH_INTERNAL_ASSERT(cpuTensor.device().type() == at::DeviceType::CPU);
-  if (cpuTensor.is_sparse()) shallowUpgradeToDPCPPTensor(cpuTensor);
-
-  auto cpu_storage_impl = cpuTensor.storage().unsafeGetStorageImpl();
-  auto& data_ptr = cpu_storage_impl->data_ptr();
-  auto cur_del_fn = data_ptr.get_deleter();
-  bool res = data_ptr.compare_exchange_deleter(cur_del_fn, &(c10::detail::deleteNothing));
-  TORCH_INTERNAL_ASSERT(res);
-  // Make sure that does not triger free resource for set_ptr
-  cpu::ShadeDataContext *shade_data_context = cpu::ShadeDataContext::allocShadeDataContext();
-  shade_data_context->cpu_raw_data = data_ptr.get();
-  shade_data_context->cpu_del_fun = cur_del_fn;
-  shade_data_context->data_type = cpu::SHADE_DATA_TYPE::CPU_RAW;
-  c10::DataPtr shade_data_ptr(
-    data_ptr.get(),
-    shade_data_context,
-    cpu::ShadeDataContext::freeShadeDataContext,
-    at::DeviceType::CPU);
-  cpuTensor.unsafeGetTensorImpl()->storage().set_data_ptr(std::move(shade_data_ptr));
-  return shallowUpgradeToDPCPPTensor(cpuTensor);
-}
-
 // Upgrade CPU tensor to DPCPP Tensor with shallow copy
 // It will create a new DPCPP tensor but shares the CPU tensor buffer
 // [NOTE]: Device info of Dense CPU tensor is polluted.
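
The heart of the bridge change: instead of copying the blocked DNNL tensor into a freshly allocated CPU tensor and then force-setting sizes/strides on the IPEX TensorImpl, `to_public(nullptr, ...)` lets DNNL produce the plain-format buffer, and the storage is re-pointed at it through a new ShadeDataContext, so the aten wrapper's existing metadata is already correct. The ownership rule behind the added comments can be modeled in a few lines (a toy Python sketch with invented names, not the extension's API):

class ShadeDataContext:
    # Toy model of the context built in reorderDilTensorToPublic().
    def __init__(self, dil_tensor, cpu_raw_data, cpu_del_fun):
        self.data_type = "DIL"            # the context now carries a DNNL tensor
        self.dil_tensor = dil_tensor      # plain-format tensor; owns the buffer
        self.cpu_raw_data = cpu_raw_data  # aliases that buffer, does not own it
        self.cpu_del_fun = cpu_del_fun    # must be a no-op deleter

def delete_nothing(ptr):
    # Mirrors c10::detail::deleteNothing: the CPU side must never free
    # a buffer that the DNNL tensor owns.
    pass

# After dil_tensor.to_public(nullptr, ...), the CPU raw data and the DNNL
# buffer are the same allocation, so exactly one owner may free it:
ctx = ShadeDataContext(dil_tensor="pub_tensor",
                       cpu_raw_data="pub_tensor.get_data_handle()",
                       cpu_del_fun=delete_nothing)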
