Skip to content

Commit e24f42d

Browse files
committed
refine binary ops
- support resizing behavior of out... parameter - throw exception on unbroadcastable inputs - remove workaround of broadcast in jit
1 parent b0e83a0 commit e24f42d

File tree

4 files changed

+145
-69
lines changed

4 files changed

+145
-69
lines changed

intel_pytorch_extension_py/ops/jit_script.py

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -13,10 +13,7 @@ def script_(obj, optimize=None, _frames_up=0, _rcb=None):
1313
torch.jit.script = script_
1414

1515
if core.get_jit():
16-
# bypass buggy broadcastable ops in dnnl during folding
17-
core.disable_auto_dnnl()
1816
jit_m = wrap_cpp_module(torch._C._jit_pass_fold_convbn(jit_m._c))
19-
core.enable_auto_dnnl()
2017

2118
return jit_m
2219

tests/cpu/test_lazy_reorder.py

Lines changed: 79 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,13 @@ def get_rand_seed():
3333
return int(time.time() * 1000000000)
3434

3535
device = ipex.DEVICE
36+
37+
def convert_blocked(t):
38+
assert t.dim() == 4, "only support converting 4d tensor"
39+
c = t.size(1)
40+
t = t.clone().to(device)
41+
return F.conv2d(t, torch.ones(c, 1, 1, 1).to(device), groups=c)
42+
3643
class TestConv(TestCase):
3744
def test_Conv2d_with_cpu(self):
3845
rand_seed = int(get_rand_seed())
@@ -202,6 +209,78 @@ def test_mul_(self):
202209
a2 = self._test_mul_('cpu', rand_seed)
203210
self.assertEqual(a2, a1.to('cpu'))
204211

212+
def test_mixed_format(self):
213+
ipex.core.enable_auto_dnnl()
214+
rand_seed = int(get_rand_seed())
215+
print("{} rand sed: {}".format(sys._getframe().f_code.co_name, rand_seed))
216+
torch.manual_seed(rand_seed)
217+
218+
shape = (2, 3, 4, 5)
219+
220+
for fname in ['add', 'mul']:
221+
222+
x_cpu = torch.ones(shape) * 5
223+
y_cpu = torch.ones(shape) * 4
224+
225+
# block tensor is a dpcpp tensor
226+
x_plain = x_cpu.clone().to(device)
227+
y_plain = y_cpu.clone().to(device)
228+
x_block = convert_blocked(x_cpu.clone())
229+
y_block = convert_blocked(y_cpu.clone())
230+
231+
fn = getattr(torch, fname)
232+
ref = fn(x_cpu, y_cpu)
233+
234+
# test add, mul
235+
def test_outplace(a, b):
236+
a = a.clone()
237+
b = b.clone()
238+
self.assertEqual(fn(a, b), ref)
239+
240+
test_outplace(x_plain, y_plain)
241+
test_outplace(x_plain, y_block)
242+
test_outplace(y_block, x_plain)
243+
test_outplace(x_block, y_block)
244+
245+
# test add_out, mul_out
246+
def test_out(a, b, o):
247+
a = a.clone()
248+
b = b.clone()
249+
o = o.clone()
250+
y = fn(a, b, out=o)
251+
self.assertEqual(y, ref)
252+
self.assertEqual(o, ref)
253+
254+
out = torch.ones(shape).to(device)
255+
test_out(x_plain, y_plain, out)
256+
test_out(x_plain, y_block, out)
257+
test_out(y_block, x_plain, out)
258+
test_out(x_block, y_block, out)
259+
out = torch.ones(1).to(device)
260+
test_out(x_plain, y_plain, out)
261+
test_out(x_plain, y_block, out)
262+
test_out(y_block, x_plain, out)
263+
test_out(x_block, y_block, out)
264+
265+
# test add_, mul_
266+
def test_inplace(a, b):
267+
a = a.clone()
268+
b = b.clone()
269+
y = getattr(a, fname + '_')(b)
270+
self.assertEqual(a, ref)
271+
self.assertEqual(y, ref)
272+
273+
test_inplace(x_plain, y_plain)
274+
test_inplace(x_plain, y_block)
275+
test_inplace(y_block, x_plain)
276+
test_inplace(x_block, y_block)
277+
278+
# test broadcast
279+
scalar = torch.ones(1).to(device)
280+
self.assertEqual(fn(x_plain, scalar), fn(x_cpu, scalar))
281+
self.assertEqual(fn(scalar, x_plain), fn(scalar, x_cpu))
282+
283+
205284
class TestRelu(TestCase):
206285
def _test_relu_(self, device, rand_seed):
207286
torch.manual_seed(rand_seed)

torch_ipex/csrc/cpu/DevOPs.cpp

Lines changed: 47 additions & 59 deletions
Original file line numberDiff line numberDiff line change
@@ -254,107 +254,95 @@ std::tuple<at::Tensor,at::Tensor,at::Tensor> AtenIpexCPUDev::mkldnn_convolution_
254254
return std::tuple<at::Tensor,at::Tensor,at::Tensor>(bridge::shallowUpgradeToDPCPPTensor(std::get<0>(_ipex_result)), bridge::shallowUpgradeToDPCPPTensor(std::get<1>(_ipex_result)), bridge::shallowUpgradeToDPCPPTensor(std::get<2>(_ipex_result)));
255255
}
256256

257-
at::Tensor& AtenIpexCPUDev::dil_add_out(
257+
template<bool inplace>
258+
at::Tensor& dil_add_common(
258259
at::Tensor& result,
259260
const at::Tensor& self,
260261
const at::Tensor& other,
261262
at::Scalar alpha) {
262-
DEBUG("AtenIpexCPUDev::dil_add_out\n");
263263
CHECK_DNNL_OP_PRE_COND(self);
264264
CHECK_DNNL_OP_PRE_COND(other);
265265

266+
TORCH_CHECK(self.sizes().equals(other.sizes()),
267+
"dil add not support broadcast yet");
268+
266269
dbl::comm::reorder_to_bf16_for_mix_prec(self);
267270
dbl::comm::reorder_to_bf16_for_mix_prec(other);
268-
dbl::comm::reorder_to_bf16_for_mix_prec(result);
269271

270-
dil::tensor x = dbl::comm::try_gen_dil_tensor(self);
271-
dil::tensor y = dbl::comm::try_gen_dil_tensor(other);
272+
auto x = dbl::comm::try_gen_dil_tensor(self);
273+
auto y = dbl::comm::try_gen_dil_tensor(other);
274+
auto z = inplace ? x : dil::tensor();
272275

273-
dil::tensor z = dbl::comm::try_gen_dil_tensor(result);
274-
const std::vector<float> scales{1.0, alpha.to<float>()};
275-
dil::sum::compute(scales, {x, y}, z);
276+
dil::sum::compute({1.0, alpha.to<float>()}, {x, y}, z);
276277

277-
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(z.is_public_format() || check_tensor_own_whole_storage(result));
278-
dbl::comm::sync_shape_from_dil_to_aten(result, z);
278+
if (!inplace) {
279+
dbl::comm::equip_dil_buffer(result, z);
280+
}
279281
return result;
280282
}
281283

282-
at::Tensor AtenIpexCPUDev::dil_add(const at::Tensor& self, const at::Tensor& other, at::Scalar alpha) {
283-
DEBUG("AtenIpexCPUDev::dil_add\n");
284-
CHECK_DNNL_OP_PRE_COND(self);
285-
CHECK_DNNL_OP_PRE_COND(other);
286-
287-
dbl::comm::reorder_to_bf16_for_mix_prec(self);
288-
dbl::comm::reorder_to_bf16_for_mix_prec(other);
284+
at::Tensor& AtenIpexCPUDev::dil_add_out(at::Tensor& result, const at::Tensor& self, const at::Tensor& other, at::Scalar alpha) {
285+
DEBUG("AtenIpexCPUDev::dil_add_out\n");
289286

290-
dil::tensor x = dbl::comm::try_gen_dil_tensor(self);
291-
dil::tensor y = dbl::comm::try_gen_dil_tensor(other);
287+
return dil_add_common</*inplace=*/false>(result, self, other, alpha);
288+
}
292289

293-
dil::tensor z;
294-
const std::vector<float> scales{1.0, alpha.to<float>()};
295-
dil::sum::compute(scales, {x, y}, z);
290+
at::Tensor AtenIpexCPUDev::dil_add(const at::Tensor& self, const at::Tensor& other, at::Scalar alpha) {
291+
DEBUG("AtenIpexCPUDev::dil_add\n");
296292

297-
return dbl::comm::gen_aten_tensor_by(std::move(z));
293+
auto result = dbl::comm::empty_dil_tensor({0}, self.options());
294+
return dil_add_common</*inplace=*/false>(result, self, other, alpha);
298295
}
299296

300297
at::Tensor & AtenIpexCPUDev::dil_add_(at::Tensor& self, const at::Tensor& other, at::Scalar alpha) {
301298
DEBUG("AtenIpexCPUDev::dil_add_\n");
299+
300+
return dil_add_common</*inplace=*/true>(self, self, other, alpha);
301+
}
302+
303+
template<bool inplace>
304+
at::Tensor& dil_mul_common(
305+
at::Tensor& result,
306+
const at::Tensor& self,
307+
const at::Tensor& other) {
302308
CHECK_DNNL_OP_PRE_COND(self);
303309
CHECK_DNNL_OP_PRE_COND(other);
304310

311+
TORCH_CHECK(self.sizes().equals(other.sizes()),
312+
"dil mul not support broadcast yet");
313+
305314
dbl::comm::reorder_to_bf16_for_mix_prec(self);
306315
dbl::comm::reorder_to_bf16_for_mix_prec(other);
307316

308-
auto dil_self = dbl::comm::try_gen_dil_tensor(self);
309-
auto dil_other = dbl::comm::try_gen_dil_tensor(other);
317+
auto x = dbl::comm::try_gen_dil_tensor(self);
318+
auto y = dbl::comm::try_gen_dil_tensor(other);
319+
auto z = inplace ? x : dil::tensor();
310320

311-
const std::vector<float> scales{1.0, alpha.to<float>()};
312-
dil::sum::compute(scales, {dil_self, dil_other}, dil_self);
321+
dil::binary::compute(x, y, z, dil::algorithm::binary_mul);
313322

314-
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(dil_self.is_public_format() || check_tensor_own_whole_storage(self));
315-
dbl::comm::sync_shape_from_dil_to_aten(self, dil_self);
316-
return self;
323+
if (!inplace) {
324+
dbl::comm::equip_dil_buffer(result, z);
325+
}
326+
return result;
317327
}
318328

319329
at::Tensor& AtenIpexCPUDev::dil_mul_out(at::Tensor& result, const at::Tensor& self, const at::Tensor& other) {
320330
DEBUG("AtenIpexCPUDev::dil_mul_out\n");
321-
CHECK_DNNL_OP_PRE_COND(result);
322-
CHECK_DNNL_OP_PRE_COND(self);
323-
CHECK_DNNL_OP_PRE_COND(other);
324-
325-
dbl::comm::reorder_to_bf16_for_mix_prec(self);
326-
dbl::comm::reorder_to_bf16_for_mix_prec(other);
327-
dbl::comm::reorder_to_bf16_for_mix_prec(result);
328-
329-
auto dil_result = dbl::comm::try_gen_dil_tensor(result);
330-
auto dil_self = dbl::comm::try_gen_dil_tensor(self);
331-
auto dil_other = dbl::comm::try_gen_dil_tensor(other);
332-
333-
dil::binary::compute(dil_self, dil_other, dil_result, dil::algorithm::binary_mul);
334331

335-
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(dil_result.is_public_format() || check_tensor_own_whole_storage(result));
336-
dbl::comm::sync_shape_from_dil_to_aten(result, dil_result);
337-
return result;
332+
return dil_mul_common</*inplace=*/false>(result, self, other);
338333
}
339334

340335
at::Tensor AtenIpexCPUDev::dil_mul(const at::Tensor& self, const at::Tensor& other) {
341336
DEBUG("AtenIpexCPUDev::dil_mul\n");
342337

343-
dbl::comm::reorder_to_bf16_for_mix_prec(self);
344-
dbl::comm::reorder_to_bf16_for_mix_prec(other);
345-
346-
at::Tensor result = dbl::comm::empty_dil_tensor(self.sizes(), self.options());
347-
348-
return dil_mul_out(result, self, other);
338+
auto result = dbl::comm::empty_dil_tensor({0}, self.options());
339+
return dil_mul_common</*inplace=*/false>(result, self, other);
349340
}
350341

351342
at::Tensor& AtenIpexCPUDev::dil_mul_(at::Tensor& self, const at::Tensor& other) {
352343
DEBUG("AtenIpexCPUDev::dil_mul_\n");
353344

354-
dbl::comm::reorder_to_bf16_for_mix_prec(self);
355-
dbl::comm::reorder_to_bf16_for_mix_prec(other);
356-
357-
return dil_mul_out(self, self, other);
345+
return dil_mul_common</*inplace=*/true>(self, self, other);
358346
}
359347

360348
void matmul_common(
@@ -472,7 +460,7 @@ at::Tensor& AtenIpexCPUDev::dil_baddbmm_out(
472460
result.resize_(inferred_size);
473461
}
474462
TORCH_CHECK(self.sizes().equals(inferred_size),
475-
"baddbmm not support broadcast yet");
463+
"dil_baddbmm not support broadcast yet");
476464

477465
dbl::comm::reorder_to_bf16_for_mix_prec(result);
478466
dbl::comm::reorder_to_bf16_for_mix_prec(self);
@@ -541,7 +529,7 @@ at::Tensor& AtenIpexCPUDev::dil_addmm_out(
541529
result.resize_(inferred_size);
542530
}
543531
TORCH_CHECK(self.sizes().equals(inferred_size),
544-
"addmm not support broadcast yet");
532+
"dil_addmm not support broadcast yet");
545533

546534
dbl::comm::reorder_to_bf16_for_mix_prec(result);
547535
dbl::comm::reorder_to_bf16_for_mix_prec(self);
@@ -610,7 +598,7 @@ at::Tensor& AtenIpexCPUDev::dil_addbmm_out(
610598
result.resize_(inferred_size);
611599
}
612600
TORCH_CHECK(self.sizes().equals(inferred_size),
613-
"addbmm not support broadcast yet");
601+
"dil_addbmm not support broadcast yet");
614602

615603
dbl::comm::reorder_to_bf16_for_mix_prec(result);
616604
dbl::comm::reorder_to_bf16_for_mix_prec(self);

torch_ipex/csrc/cpu/dbl/Common.cpp

Lines changed: 19 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -89,6 +89,15 @@ void reorder_to_desc(const at::Tensor& tensor, const dil::tensor::desc& expected
8989
}
9090

9191
void equip_dil_buffer(const at::Tensor& tensor, dil::tensor dil_tensor_buffer) {
92+
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(
93+
tensor.device().is_dpcpp(),
94+
"dil buffer can only be equipped to dpcpp tensor");
95+
96+
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(
97+
check_tensor_own_whole_storage(tensor),
98+
"dil buffer can only be equipped to tensors that own the whole storage, "
99+
"as dil buffer is going to replace the original storage");
100+
92101
// Build new shade data context
93102
cpu::ShadeDataContext *new_shade_data_context = cpu::ShadeDataContext::allocShadeDataContext();
94103
new_shade_data_context->data_type = cpu::SHADE_DATA_TYPE::DIL;
@@ -97,13 +106,10 @@ void equip_dil_buffer(const at::Tensor& tensor, dil::tensor dil_tensor_buffer) {
97106
void *tensor_data = nullptr;
98107
if (dil_tensor_buffer.get_data_type() != get_dil_data_type(tensor.scalar_type())) {
99108
new_shade_data_context->mix_prec_type = cpu::MIX_PREC_TYPE::MIX_BF16_FP32;
100-
} else {
101-
if (dil_tensor_buffer.is_public_format()) {
102-
tensor_data = dil_tensor_buffer.get_data_handle();
103-
new_shade_data_context->cpu_raw_data = tensor_data;
104-
new_shade_data_context->cpu_del_fun = &(c10::detail::deleteNothing);
105-
sync_shape_from_dil_to_aten(tensor, dil_tensor_buffer);
106-
}
109+
} else if (dil_tensor_buffer.is_public_format()) {
110+
tensor_data = dil_tensor_buffer.get_data_handle();
111+
new_shade_data_context->cpu_raw_data = tensor_data;
112+
new_shade_data_context->cpu_del_fun = &(c10::detail::deleteNothing);
107113
}
108114

109115
// Create a new DataPtr instances because the DataPtr class does not support set
@@ -116,6 +122,12 @@ void equip_dil_buffer(const at::Tensor& tensor, dil::tensor dil_tensor_buffer) {
116122

117123
IPEXTensorImpl* ipex_tensor_impl = (IPEXTensorImpl *)tensor.unsafeGetTensorImpl();
118124
ipex_tensor_impl->storage().set_data_ptr(std::move(shade_data_ptr));
125+
126+
// After equip_dil_buffer(), whole storage should be managed by dil tensor,
127+
// and thus storage metadata should be overwritten by dil tensor
128+
// Note: Storage::set_numel() might be removed later
129+
ipex_tensor_impl->storage().set_numel(dil_tensor_buffer.get_nelems());
130+
cpu::dbl::comm::sync_shape_from_dil_to_aten(tensor, dil_tensor_buffer);
119131
}
120132

121133
dil::tensor try_gen_dil_tensor(const at::Tensor &input) {

0 commit comments

Comments
 (0)