
Commit f1c4d61

1. Refine the test_view unit test case. 2. Revert code to make sure all the input tensors of DNNL are contiguous.
1 parent: ff47102 · commit: f1c4d61

7 files changed, +14 −23 lines

scripts/cpu/gen-dense-cpu-ops.py

Lines changed: 1 addition & 2 deletions
@@ -306,8 +306,7 @@ def is_out_func(fname):
         if param_var == 'out' and is_out_func(fname):
           code += ' TORCH_INTERNAL_ASSERT({}.is_contiguous());\n'.format(param_var)
         else:
-          # param_seq_str = '{}.is_contiguous() ? {} : {}.contiguous()'.format(param_var, param_var, param_var)
-          None
+          param_seq_str = '{}.is_contiguous() ? {} : {}.contiguous()'.format(param_var, param_var, param_var)
         param_seq_str_vec.append(param_seq_str)
     code += ' if (dbl::chk::dnnl_support_the_tensors(dnnl_input_tensors))\n'
     code += ' return AtenIpexCPUDev::dil_{}({});\n'.format(fname, ', '.join(param_seq_str_vec))
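For reference, the re-enabled line builds a C++ ternary expression for every non-out tensor argument, so the generated dil_<op> call always receives a contiguous tensor. A minimal sketch of what that format string produces (the parameter name "self" is chosen purely for illustration):

    # Sketch only: reproduces the single format string re-enabled above.
    param_var = 'self'  # illustrative parameter name
    param_seq_str = '{}.is_contiguous() ? {} : {}.contiguous()'.format(param_var, param_var, param_var)
    print(param_seq_str)
    # prints: self.is_contiguous() ? self : self.contiguous()

The resulting expression is emitted as an argument of AtenIpexCPUDev::dil_<op>(...), so non-contiguous inputs are normalized with .contiguous() before they reach the DNNL path.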

tests/cpu/test_lazy_reorder.py

Lines changed: 6 additions & 6 deletions
@@ -368,12 +368,12 @@ def test_addbmm(self):
 
         addbmm_cpu = torch.addbmm(res_cpu, b1_cpu, b2_cpu, beta=beta, alpha=alpha)
         addbmm_dpcpp = torch.addbmm(res_dpcpp, b1_dpcpp, b2_dpcpp, beta=beta, alpha=alpha)
-        self.assertEqual(addbmm_cpu, addbmm_dpcpp)
+        self.assertEqual(addbmm_cpu, addbmm_dpcpp, 1e-4)
         y_cpu = torch.randn(M, O, dtype=torch.float32)
         y_dpcpp = y_cpu.to(device=device)
         torch.addbmm(res_cpu, b1_cpu, b2_cpu, beta=beta, alpha=alpha, out=y_cpu)
         torch.addbmm(res_dpcpp, b1_dpcpp, b2_dpcpp, beta=beta, alpha=alpha, out=y_dpcpp)
-        self.assertEqual(y_cpu, y_dpcpp)
+        self.assertEqual(y_cpu, y_dpcpp, 1e-4)
 
     def test_baddbmm(self):
         ipex.enable_auto_dnnl()
@@ -683,7 +683,6 @@ def test_batch_norm2d_backward(self):
 
         bn = torch.nn.BatchNorm2d(3)
         bn_dpcpp = copy.deepcopy(bn).to(device=device)
-
         y_cpu = bn(x_cpu).sum()
         y_dpcpp = bn_dpcpp(x_dpcpp).sum()
         y_cpu.backward()
@@ -758,21 +757,22 @@ def test_view(self):
         x_dpcpp = x_cpu.to(device=device).clone()
         self.assertTrue(ipex.is_dil_tensor(x_dpcpp))
         self.assertEqual(ipex.get_dil_tensor_sizes(x_dpcpp), [4, 16])
-        self.assertEqual(ipex.is_dil_tensor_strides(x_dpcpp), [16, 1])
+        self.assertEqual(ipex.get_dil_tensor_strides(x_dpcpp), [16, 1])
 
         x_cpu_view = x_cpu.view(new_shape)
         self.assertEqual(x_cpu_view.size(), [1, 4, 4, 4])
         self.assertEqual(x_cpu_view.stride(), [64, 16, 4, 1])
 
         x_dpcpp_view = x_dpcpp.view(new_shape)
         self.assertTrue(ipex.is_dil_tensor(x_dpcpp_view))
-        self.assertEqual(ipex.get_dil_tensor_sizes(x_dpcpp_view), [1, 4, 4, 4])
-        self.assertEqual(ipex.is_dil_tensor_strides(x_dpcpp_view), [64, 16, 4, 1])
 
         y = torch.randn(new_shape)
         out_cpu = x_cpu_view * y
         # test if the shape of x_dpcpp_view is compatible with y
         out_dpcpp = x_dpcpp_view * y
+        self.assertTrue(ipex.is_dil_tensor(out_dpcpp))
+        self.assertEqual(ipex.get_dil_tensor_sizes(out_dpcpp), [1, 4, 4, 4])
+        self.assertEqual(ipex.get_dil_tensor_strides(out_dpcpp), [64, 16, 4, 1])
         self.assertEqual(out_cpu, out_dpcpp)
 
         # test if metadata of x_dpcpp has not been altered
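The refined test now checks the DIL sizes and strides on the result of the multiplication (out_dpcpp) rather than on the intermediate view. The expected values follow from ordinary PyTorch stride arithmetic; a standalone sanity check that needs no ipex at all:

    import torch

    # Viewing a contiguous 4x16 tensor as 1x4x4x4 keeps the same storage and
    # yields row-major strides (64, 16, 4, 1), matching the assertions above.
    x = torch.randn(4, 16)
    assert x.stride() == (16, 1)
    v = x.view(1, 4, 4, 4)
    assert v.stride() == (64, 16, 4, 1)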

torch_ipex/csrc/aten_ipex_bridge.cpp

Lines changed: 0 additions & 1 deletion
@@ -75,7 +75,6 @@ void reorderDilTensorToPublic(const at::Tensor& ipexTensor) {
     TORCH_INTERNAL_ASSERT(shade_data_context->cpu_raw_data == shade_data_context->dil_tensor.get_data_handle());
     TORCH_INTERNAL_ASSERT(shade_data_context->cpu_raw_data != nullptr);
     TORCH_INTERNAL_ASSERT(shade_data_context->cpu_del_fun != nullptr);
-    TORCH_INTERNAL_ASSERT(check_aten_dil_shape_info(ipexTensor, dil_tensor));
 #endif
   } else {
 #if defined(_DEBUG)

torch_ipex/csrc/cpu/DevOPs.cpp

Lines changed: 1 addition & 0 deletions
@@ -30,6 +30,7 @@ namespace cpu {
 
 #define CHECK_DNNL_OP_PRE_COND(tensor) \
   TORCH_INTERNAL_ASSERT(tensor.defined()); \
+  TORCH_INTERNAL_ASSERT(tensor.is_contiguous()); \
   TORCH_INTERNAL_ASSERT(tensor.layout() == c10::kStrided)
 
 at::Tensor AtenIpexCPUDev::dil_convolution(
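With the added assertion, every tensor entering a dil_* kernel must already be contiguous; the generated glue code from gen-dense-cpu-ops.py above is what guarantees this. A hedged Python-level sketch of the same normalization pattern:

    import torch

    # A transposed tensor is a non-contiguous view; it is normalized the same
    # way the generated C++ ternary does before a DNNL op is called.
    x = torch.randn(4, 16).t()
    assert not x.is_contiguous()
    x_ready = x if x.is_contiguous() else x.contiguous()
    assert x_ready.is_contiguous()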

torch_ipex/csrc/cpu/ShadeDataContext.h

Lines changed: 4 additions & 12 deletions
@@ -110,15 +110,11 @@ struct ShadeDataContext {
       // C = A[4:7, :]
       // All these tensors share same buffer of Tensor A with different storge offsets and elements.
       // So the context modification will impact all these tensors.
-      if ((shade_data_context->dil_tensor.get_data_handle() == raw_cpu_data) &&
-          (shade_data_context->dil_tensor.get_nelems() == tensor.storage().numel()) &&
-          (shade_data_context->dil_tensor.get_data_type() == get_dil_data_type(tensor.scalar_type()))) {
-        //TODO: Do we need to check strides here?
+      if (check_tensor_own_whole_storage(tensor)) {
         TORCH_INTERNAL_ASSERT(shade_data_context->dil_tensor.get_size() == tensor.storage().capacity());
         return true;
       }
     }
-    TORCH_INTERNAL_ASSERT(false);
   }
 
   return false;
@@ -145,13 +141,9 @@
     TORCH_INTERNAL_ASSERT(tensor.has_storage());
     void *raw_context = tensor.storage().data_ptr().get_context();
     TORCH_INTERNAL_ASSERT(raw_context != nullptr);
-    if (isDilTensor(tensor)) {
-      ShadeDataContext *shade_data_context = (ShadeDataContext*)raw_context;
-      return shade_data_context->dil_tensor;
-    } else {
-      TORCH_INTERNAL_ASSERT(false);
-      return dil::tensor();
-    }
+    TORCH_INTERNAL_ASSERT(isDilTensor(tensor));
+    ShadeDataContext *shade_data_context = (ShadeDataContext*)raw_context;
+    return shade_data_context->dil_tensor;
   }
 
   /**
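The first hunk folds the inline data-handle, element-count, and dtype comparison into the helper check_tensor_own_whole_storage(tensor): the DIL tensor is only treated as backing the ATen tensor when that tensor covers its whole storage, because views and slices share the buffer of their base. A small PyTorch illustration of why the check is needed, mirroring the C = A[4:7, :] example in the comment above:

    import torch

    A = torch.randn(10, 10)
    C = A[4:7, :]                                             # a view into A's storage
    assert C.storage().data_ptr() == A.storage().data_ptr()  # same underlying buffer
    assert C.data_ptr() != A.data_ptr()                       # different start offset
    assert C.numel() < A.storage().size()                     # C does not cover the whole storage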

torch_ipex/csrc/cpu/dbl/Common.cpp

Lines changed: 1 addition & 1 deletion
@@ -37,7 +37,7 @@ at::Tensor dil_tensor_to_dense(const at::Tensor& tensor) {
 dil::tensor try_gen_dil_tensor(const at::Tensor &input) {
   if (cpu::ShadeDataContext::isDilTensor(input)) {
     auto dil_tensor = cpu::ShadeDataContext::getDilTensor(input);
-    if (dil_tensor.is_public_format()) {
+    if ((!check_aten_dil_shape_info(input, dil_tensor)) && dil_tensor.is_public_format()) {
       dil_tensor.set_dims_and_strides(input.sizes().vec(), input.strides().vec());
     }
     return dil_tensor;
torch_ipex/csrc/init_python_bindings.cpp

Lines changed: 1 addition & 1 deletion
@@ -124,7 +124,7 @@ void InitIpexModuleBindings(py::module m) {
 
   m.def("is_dil_tensor", &isDilTensor);
   m.def("get_dil_tensor_sizes", &getDilTensorSizes);
-  m.def("is_dil_tensor_strides", &getDilTensorStrides);
+  m.def("get_dil_tensor_strides", &getDilTensorStrides);
 }
 
 } // namespace
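After the rename, the Python-side accessor for strides matches the naming of the sizes accessor. A short usage sketch, assuming ipex is the built extension module and x_dpcpp is a tensor on the extension device, as in tests/cpu/test_lazy_reorder.py above:

    # Sketch only: how the renamed binding is consumed from Python.
    # (ipex and x_dpcpp are set up as in the test file; names assumed from it.)
    if ipex.is_dil_tensor(x_dpcpp):
        sizes = ipex.get_dil_tensor_sizes(x_dpcpp)
        strides = ipex.get_dil_tensor_strides(x_dpcpp)  # previously exposed as "is_dil_tensor_strides"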
