Commit 04cf2f7

Merge pull request #16 from EikanWang/master
Refine the unit test cases and write the shape meta info back to the aten tensor when the tensor is reordered to public format
2 parents f7fbd8a + f1c4d61 commit 04cf2f7

File tree: 8 files changed, +61 −26 lines


scripts/cpu/gen-dense-cpu-ops.py

Lines changed: 4 additions & 2 deletions
@@ -306,15 +306,17 @@ def is_out_func(fname):
       if param_var == 'out' and is_out_func(fname):
         code += '  TORCH_INTERNAL_ASSERT({}.is_contiguous());\n'.format(param_var)
       else:
-        # param_seq_str = '{}.is_contiguous() ? {} : {}.contiguous()'.format(param_var, param_var, param_var)
-        None
+        param_seq_str = '{}.is_contiguous() ? {} : {}.contiguous()'.format(param_var, param_var, param_var)
       param_seq_str_vec.append(param_seq_str)
     code += '  if (dbl::chk::dnnl_support_the_tensors(dnnl_input_tensors))\n'
     code += '    return AtenIpexCPUDev::dil_{}({});\n'.format(fname, ', '.join(param_seq_str_vec))
 
     code += '  }\n'
 
     code += '  } catch (std::exception& e) {\n'
+    code += '#if defined(_DEBUG)\n'
+    code += '  TORCH_WARN(e.what());\n'
+    code += '#endif\n'
     code += '  }\n\n'
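For illustration, a minimal runnable sketch of what the restored generator line emits (build_param_seq is a hypothetical helper name, not part of the script):

def build_param_seq(param_var):
    # Emit a C++ conditional: use the tensor as-is when it is already
    # contiguous, otherwise materialize a contiguous copy first.
    return '{}.is_contiguous() ? {} : {}.contiguous()'.format(param_var, param_var, param_var)

print(build_param_seq('self'))
# -> self.is_contiguous() ? self : self.contiguous()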

tests/cpu/test_lazy_reorder.py

Lines changed: 12 additions & 6 deletions
@@ -368,12 +368,12 @@ def test_addbmm(self):
 
         addbmm_cpu = torch.addbmm(res_cpu, b1_cpu, b2_cpu, beta=beta, alpha=alpha)
         addbmm_dpcpp = torch.addbmm(res_dpcpp, b1_dpcpp, b2_dpcpp, beta=beta, alpha=alpha)
-        self.assertEqual(addbmm_cpu, addbmm_dpcpp)
+        self.assertEqual(addbmm_cpu, addbmm_dpcpp, 1e-4)
         y_cpu = torch.randn(M, O, dtype=torch.float32)
         y_dpcpp = y_cpu.to(device=device)
         torch.addbmm(res_cpu, b1_cpu, b2_cpu, beta=beta, alpha=alpha, out=y_cpu)
         torch.addbmm(res_dpcpp, b1_dpcpp, b2_dpcpp, beta=beta, alpha=alpha, out=y_dpcpp)
-        self.assertEqual(y_cpu, y_dpcpp)
+        self.assertEqual(y_cpu, y_dpcpp, 1e-4)
 
     def test_baddbmm(self):
         ipex.enable_auto_dnnl()

@@ -683,7 +683,6 @@ def test_batch_norm2d_backward(self):
 
         bn = torch.nn.BatchNorm2d(3)
         bn_dpcpp = copy.deepcopy(bn).to(device=device)
-
         y_cpu = bn(x_cpu).sum()
         y_dpcpp = bn_dpcpp(x_dpcpp).sum()
         y_cpu.backward()

@@ -756,17 +755,24 @@ def test_view(self):
 
         x_cpu = torch.randn(old_shape)
         x_dpcpp = x_cpu.to(device=device).clone()
-        print(x_dpcpp.size())
+        self.assertTrue(ipex.is_dil_tensor(x_dpcpp))
+        self.assertEqual(ipex.get_dil_tensor_sizes(x_dpcpp), [4, 16])
+        self.assertEqual(ipex.get_dil_tensor_strides(x_dpcpp), [16, 1])
 
         x_cpu_view = x_cpu.view(new_shape)
-        print(x_cpu_view.size())
+        self.assertEqual(x_cpu_view.size(), [1, 4, 4, 4])
+        self.assertEqual(x_cpu_view.stride(), [64, 16, 4, 1])
+
         x_dpcpp_view = x_dpcpp.view(new_shape)
-        print(x_dpcpp_view.size())
+        self.assertTrue(ipex.is_dil_tensor(x_dpcpp_view))
 
         y = torch.randn(new_shape)
         out_cpu = x_cpu_view * y
         # test if the shape of x_dpcpp_view is compatible with y
         out_dpcpp = x_dpcpp_view * y
+        self.assertTrue(ipex.is_dil_tensor(out_dpcpp))
+        self.assertEqual(ipex.get_dil_tensor_sizes(out_dpcpp), [1, 4, 4, 4])
+        self.assertEqual(ipex.get_dil_tensor_strides(out_dpcpp), [64, 16, 4, 1])
         self.assertEqual(out_cpu, out_dpcpp)
 
         # test if metadata of x_dpcpp has not been altered

torch_ipex/csrc/aten_ipex_bridge.cpp

Lines changed: 2 additions & 2 deletions
@@ -11,6 +11,7 @@
 
 #include "ipex_tensor_impl.h"
 #include "ipex_sparse_tensor_impl.h"
+#include "cpu/dbl/Common.h"
 #include "cpu/ShadeDataContext.h"
 #include "cpu/bf16/Converter.h"
 #include "utils.h"

@@ -65,7 +66,6 @@ void reorderDilTensorToPublic(const at::Tensor& ipexTensor) {
   void *data_ctx = ipexTensor.unsafeGetTensorImpl()->storage().data_ptr().get_context();
   cpu::ShadeDataContext *shade_data_context = (cpu::ShadeDataContext*)data_ctx;
 #if defined(_DEBUG)
-  TORCH_WARN(ipexTensor.is_contiguous());
   TORCH_INTERNAL_ASSERT(! (shade_data_context->dil_tensor.is_empty()));
 #endif
   dil::tensor &dil_tensor = shade_data_context->dil_tensor;

@@ -101,7 +101,7 @@ void reorderDilTensorToPublic(const at::Tensor& ipexTensor) {
     ipexTensor.device().type());
 
   ipexTensor.unsafeGetTensorImpl()->storage().set_data_ptr(std::move(shade_data_ptr));
-  TORCH_INTERNAL_ASSERT(ipexTensor.is_contiguous());
+  cpu::dbl::comm::sync_shape_from_dil_to_aten(ipexTensor, pub_tensor);
 }
}
 
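In effect, a DIL tensor leaving the DNNL domain now has its aten shape metadata synchronized instead of being asserted contiguous. A minimal sketch of the observable behavior (the module import name, DEVICE constant, and op choice follow the test-suite conventions and are assumptions; reorderDilTensorToPublic runs internally during the copy back to CPU):

import torch
import intel_pytorch_extension as ipex  # import name is an assumption

ipex.enable_auto_dnnl()
conv = torch.nn.Conv2d(3, 4, kernel_size=3).to(device=ipex.DEVICE)  # DEVICE is an assumption
x = torch.randn(1, 3, 8, 8).to(device=ipex.DEVICE)
y = conv(x)           # the result may be held in a blocked DNNL format
y_cpu = y.to('cpu')   # reorders to public format and syncs sizes/strides
assert y_cpu.size() == y.size()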

torch_ipex/csrc/cpu/ShadeDataContext.h

Lines changed: 4 additions & 15 deletions
@@ -90,9 +90,6 @@ struct ShadeDataContext {
     TORCH_INTERNAL_ASSERT((data_type == SHADE_DATA_TYPE::CPU_RAW) || (data_type == SHADE_DATA_TYPE::DIL));
 
     if (data_type == SHADE_DATA_TYPE::DIL) {
-#if defined(_DEBUG)
-      TORCH_WARN(tensor.is_contiguous());
-#endif
       auto raw_cpu_data = tensor.storage().data_ptr().get();
       if (raw_cpu_data == nullptr) {
         // the dnnl tensor does not share data with raw tensor data.

@@ -113,15 +110,11 @@
       // C = A[4:7, :]
       // All these tensors share same buffer of Tensor A with different storge offsets and elements.
       // So the context modification will impact all these tensors.
-      if ((shade_data_context->dil_tensor.get_data_handle() == raw_cpu_data) &&
-          (shade_data_context->dil_tensor.get_nelems() == tensor.storage().numel()) &&
-          (shade_data_context->dil_tensor.get_data_type() == get_dil_data_type(tensor.scalar_type()))) {
-        //TODO: Do we need to check strides here?
+      if (check_tensor_own_whole_storage(tensor)) {
         TORCH_INTERNAL_ASSERT(shade_data_context->dil_tensor.get_size() == tensor.storage().capacity());
         return true;
       }
     }
-    TORCH_INTERNAL_ASSERT(false);
   }
 
   return false;

@@ -148,13 +141,9 @@
   TORCH_INTERNAL_ASSERT(tensor.has_storage());
   void *raw_context = tensor.storage().data_ptr().get_context();
   TORCH_INTERNAL_ASSERT(raw_context != nullptr);
-  if (isDilTensor(tensor)) {
-    ShadeDataContext *shade_data_context = (ShadeDataContext*)raw_context;
-    return shade_data_context->dil_tensor;
-  } else {
-    TORCH_INTERNAL_ASSERT(false);
-    return dil::tensor();
-  }
+  TORCH_INTERNAL_ASSERT(isDilTensor(tensor));
+  ShadeDataContext *shade_data_context = (ShadeDataContext*)raw_context;
+  return shade_data_context->dil_tensor;
 }
 
 /**
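The buffer-sharing situation described by the comment retained in the second hunk can be reproduced with plain PyTorch; a small self-contained check:

import torch

A = torch.randn(8, 4)
B = A[2:5, :]   # view of rows 2..4
C = A[4:7, :]   # view of rows 4..6
# All three tensors share one storage buffer and differ only in their
# storage offsets, which is why a context change on the buffer affects
# every view at once.
assert B.storage().data_ptr() == A.storage().data_ptr()
assert C.storage().data_ptr() == A.storage().data_ptr()
print(B.storage_offset(), C.storage_offset())  # 8 16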

torch_ipex/csrc/cpu/dbl/Common.cpp

Lines changed: 1 addition & 1 deletion
@@ -37,7 +37,7 @@ at::Tensor dil_tensor_to_dense(const at::Tensor& tensor) {
 dil::tensor try_gen_dil_tensor(const at::Tensor &input) {
   if (cpu::ShadeDataContext::isDilTensor(input)) {
     auto dil_tensor = cpu::ShadeDataContext::getDilTensor(input);
-    if (dil_tensor.is_public_format()) {
+    if ((!check_aten_dil_shape_info(input, dil_tensor)) && dil_tensor.is_public_format()) {
       dil_tensor.set_dims_and_strides(input.sizes().vec(), input.strides().vec());
     }
     return dil_tensor;
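The updated guard rewrites the DIL tensor's dims/strides from the aten side only when the recorded shape info already disagrees and the tensor is in public format. A hypothetical Python model of that logic (the function and the dil_tensor stand-in are illustrative, not the real API; the agreement check mirrors check_aten_dil_shape_info in torch_ipex/csrc/utils.cpp below):

def maybe_sync_dims(aten_sizes, aten_strides, dil_tensor):
    # Simplified to the public-format case: rewrite dims/strides only when
    # the aten metadata and the DIL tensor's recorded shape disagree.
    shapes_agree = (list(aten_sizes) == list(dil_tensor.get_dims()) and
                    list(aten_strides) == list(dil_tensor.get_strides()))
    if not shapes_agree and dil_tensor.is_public_format():
        dil_tensor.set_dims_and_strides(aten_sizes, aten_strides)
    return dil_tensor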

torch_ipex/csrc/init_python_bindings.cpp

Lines changed: 28 additions & 0 deletions
@@ -12,6 +12,8 @@
 
 #include "aten_ipex_type.h"
 #include "auto_opt_config.h"
+#include "cpu/dil/dil.hpp"
+#include "cpu/ShadeDataContext.h"
 #include "cpu/ExtendOPs.h"
 #include "cpu/MlpOPs.h"
 
@@ -29,6 +31,28 @@ void setAutoDNNL(bool val) {
   AutoOptConfig::singleton().set_auto_dnnl(val);
 }
 
+/// **** Only for unit test ****
+bool isDilTensor(const at::Tensor &tensor) {
+  return cpu::ShadeDataContext::isDilTensor(tensor);
+}
+
+dil::dims getDilTensorSizes(const at::Tensor &tensor) {
+  if (isDilTensor(tensor)) {
+    auto dil_tensor = cpu::ShadeDataContext::getDilTensor(tensor);
+    return dil_tensor.get_dims();
+  }
+  return dil::dims();
+}
+
+dil::dims getDilTensorStrides(const at::Tensor &tensor) {
+  if (isDilTensor(tensor)) {
+    auto dil_tensor = cpu::ShadeDataContext::getDilTensor(tensor);
+    return dil_tensor.get_strides();
+  }
+  return dil::dims();
+}
+/// ****************************
+
 void InitIpexModuleBindings(py::module m) {
   m.def("_initialize_aten_bindings",
       []() { AtenIpexType::InitializeAtenBindings(); });

@@ -97,6 +121,10 @@ void InitIpexModuleBindings(py::module m) {
   m.def("mlp_create_handle", &AtenIpexTypeMLPExt::create_handle);
   m.def("mlp_set_relu_mask", &AtenIpexTypeMLPExt::set_relu_mask);
   m.def("mlp_release_handle", &AtenIpexTypeMLPExt::release_handle);
+
+  m.def("is_dil_tensor", &isDilTensor);
+  m.def("get_dil_tensor_sizes", &getDilTensorSizes);
+  m.def("get_dil_tensor_strides", &getDilTensorStrides);
 }
 
 } // namespace
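For reference, a minimal sketch of how these test-only bindings are used from Python, mirroring tests/cpu/test_lazy_reorder.py (the import name and DEVICE constant are assumptions):

import torch
import intel_pytorch_extension as ipex  # import name is an assumption

ipex.enable_auto_dnnl()
x = torch.randn(4, 16).to(device=ipex.DEVICE)  # DEVICE is an assumption

if ipex.is_dil_tensor(x):
    print(ipex.get_dil_tensor_sizes(x))    # e.g. [4, 16]
    print(ipex.get_dil_tensor_strides(x))  # e.g. [16, 1]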

torch_ipex/csrc/utils.cpp

Lines changed: 9 additions & 0 deletions
@@ -127,4 +127,13 @@ bool check_tensor_own_shade_context(const at::Tensor& tensor) {
   return (data_ptr != data_ctx) && (data_ctx != nullptr);
 }
 
+bool check_aten_dil_shape_info(const at::Tensor& ipex_tensor, const dil::tensor &dil_tensor) {
+  if (dil_tensor.is_public_format()) {
+    return ipex_tensor.sizes().vec() == dil_tensor.get_dims() &&
+           ipex_tensor.strides().vec() == dil_tensor.get_strides();
+  } else {
+    return ipex_tensor.sizes().vec() == dil_tensor.get_dims();
+  }
+}
+
 } // namespace torch_ipex
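A hypothetical Python model of this check (illustrative names): for a public-format DIL tensor both dims and strides must match the aten metadata, while a blocked format exposes only comparable logical dims, since blocked layouts carry no aten-style strides:

def shape_info_matches(aten_sizes, aten_strides, dil_dims, dil_strides, is_public_format):
    # Public format: full agreement on sizes and strides is required.
    if is_public_format:
        return list(aten_sizes) == list(dil_dims) and list(aten_strides) == list(dil_strides)
    # Blocked format: only the logical dims are comparable.
    return list(aten_sizes) == list(dil_dims)

print(shape_info_matches([4, 16], [16, 1], [4, 16], [16, 1], True))  # True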

torch_ipex/csrc/utils.h

Lines changed: 1 addition & 0 deletions
@@ -20,5 +20,6 @@ at::ScalarType get_at_data_type(dil::data_type);
 bool check_auto_dnnl();
 bool check_tensor_own_whole_storage(const at::Tensor& tensor);
 bool check_tensor_own_shade_context(const at::Tensor& tensor);
+bool check_aten_dil_shape_info(const at::Tensor& ipex_tensor, const dil::tensor &dil_tensor);
 
 } // namespace torch_ipex
