Commit 68fb1cf

refine binary ops

- support resizing behavior of the out= parameter
- fall back to the ATen implementation on broadcastable inputs
- remove the broadcast workaround in jit

1 parent 89c4149 · commit 68fb1cf
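A minimal sketch of the two user-visible behaviors this commit targets, mirroring the new tests in test_lazy_reorder.py below. The import name intel_pytorch_extension and the ipex.DEVICE attribute are assumptions taken from the repo's test suite:

    import torch
    import intel_pytorch_extension as ipex  # import name assumed from the test suite

    x = torch.ones(2, 3, 4, 5).to(ipex.DEVICE)
    y = torch.ones(2, 3, 4, 5).to(ipex.DEVICE)

    # 1) an out= tensor with a mismatching shape is resized to the result shape
    out = torch.ones(1).to(ipex.DEVICE)
    torch.add(x, y, out=out)
    assert out.shape == x.shape

    # 2) broadcastable inputs (e.g. a scalar tensor) no longer need the jit
    #    workaround; the op falls back to the ATen implementation
    scalar = torch.ones(1).to(ipex.DEVICE)
    z = torch.mul(x, scalar)
    assert z.shape == x.shape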

File tree

4 files changed: +136 -68 lines changed


intel_pytorch_extension_py/ops/jit_script.py

Lines changed: 0 additions & 3 deletions
@@ -13,10 +13,7 @@ def script_(obj, optimize=None, _frames_up=0, _rcb=None):
     torch.jit.script = script_

     if core.get_jit():
-        # bypass buggy broadcastable ops in dnnl during folding
-        core.disable_auto_dnnl()
         jit_m = wrap_cpp_module(torch._C._jit_pass_fold_convbn(jit_m._c))
-        core.enable_auto_dnnl()

     return jit_m


tests/cpu/test_lazy_reorder.py

Lines changed: 79 additions & 0 deletions
@@ -33,6 +33,13 @@ def get_rand_seed():
     return int(time.time() * 1000000000)

 device = ipex.DEVICE
+
+def convert_blocked(t):
+    assert t.dim() == 4, "only support converting 4d tensor"
+    c = t.size(1)
+    t = t.clone().to(device)
+    return F.conv2d(t, torch.ones(c, 1, 1, 1).to(device), groups=c)
+
 class TestConv(TestCase):
     def test_Conv2d_with_cpu(self):
         rand_seed = int(get_rand_seed())
@@ -202,6 +209,78 @@ def test_mul_(self):
         a2 = self._test_mul_('cpu', rand_seed)
         self.assertEqual(a2, a1.to('cpu'))

+    def test_mixed_format(self):
+        ipex.core.enable_auto_dnnl()
+        rand_seed = int(get_rand_seed())
+        print("{} rand seed: {}".format(sys._getframe().f_code.co_name, rand_seed))
+        torch.manual_seed(rand_seed)
+
+        shape = (2, 3, 4, 5)
+
+        for fname in ['add', 'mul']:
+
+            x_cpu = torch.ones(shape) * 5
+            y_cpu = torch.ones(shape) * 4
+
+            # block tensor is a dpcpp tensor
+            x_plain = x_cpu.clone().to(device)
+            y_plain = y_cpu.clone().to(device)
+            x_block = convert_blocked(x_cpu.clone())
+            y_block = convert_blocked(y_cpu.clone())
+
+            fn = getattr(torch, fname)
+            ref = fn(x_cpu, y_cpu)
+
+            # test add, mul
+            def test_outplace(a, b):
+                a = a.clone()
+                b = b.clone()
+                self.assertEqual(fn(a, b), ref)
+
+            test_outplace(x_plain, y_plain)
+            test_outplace(x_plain, y_block)
+            test_outplace(y_block, x_plain)
+            test_outplace(x_block, y_block)
+
+            # test add_out, mul_out
+            def test_out(a, b, o):
+                a = a.clone()
+                b = b.clone()
+                o = o.clone()
+                y = fn(a, b, out=o)
+                self.assertEqual(y, ref)
+                self.assertEqual(o, ref)
+
+            out = torch.ones(shape).to(device)
+            test_out(x_plain, y_plain, out)
+            test_out(x_plain, y_block, out)
+            test_out(y_block, x_plain, out)
+            test_out(x_block, y_block, out)
+            out = torch.ones(1).to(device)
+            test_out(x_plain, y_plain, out)
+            test_out(x_plain, y_block, out)
+            test_out(y_block, x_plain, out)
+            test_out(x_block, y_block, out)
+
+            # test add_, mul_
+            def test_inplace(a, b):
+                a = a.clone()
+                b = b.clone()
+                y = getattr(a, fname + '_')(b)
+                self.assertEqual(a, ref)
+                self.assertEqual(y, ref)
+
+            test_inplace(x_plain, y_plain)
+            test_inplace(x_plain, y_block)
+            test_inplace(y_block, x_plain)
+            test_inplace(x_block, y_block)
+
+            # test broadcast
+            scalar = torch.ones(1).to(device)
+            self.assertEqual(fn(x_plain, scalar), fn(x_cpu, scalar))
+            self.assertEqual(fn(scalar, x_plain), fn(scalar, x_cpu))
+
+
 class TestRelu(TestCase):
     def _test_relu_(self, device, rand_seed):
         torch.manual_seed(rand_seed)
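The convert_blocked helper above relies on a depthwise 1x1 convolution with all-ones weights: numerically it is an identity, but running it through DNNL leaves the output in a blocked (non-public) layout. A quick sanity sketch under the same import assumptions as the tests:

    import torch
    import torch.nn.functional as F
    import intel_pytorch_extension as ipex  # import name assumed from the test suite

    t = torch.randn(2, 3, 4, 5)
    # depthwise 1x1 conv with weights of ones: identity in values,
    # blocked in layout (one group per channel, c == 3 here)
    blocked = F.conv2d(t.clone().to(ipex.DEVICE),
                       torch.ones(3, 1, 1, 1).to(ipex.DEVICE), groups=3)
    assert torch.allclose(blocked.to('cpu'), t)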

torch_ipex/csrc/cpu/DevOPs.cpp

Lines changed: 38 additions & 58 deletions
@@ -254,115 +254,95 @@ std::tuple<at::Tensor,at::Tensor,at::Tensor> AtenIpexCPUDev::mkldnn_convolution_
   return std::tuple<at::Tensor,at::Tensor,at::Tensor>(bridge::shallowUpgradeToDPCPPTensor(std::get<0>(_ipex_result)), bridge::shallowUpgradeToDPCPPTensor(std::get<1>(_ipex_result)), bridge::shallowUpgradeToDPCPPTensor(std::get<2>(_ipex_result)));
 }

-at::Tensor& AtenIpexCPUDev::dil_add_out(
+template<bool inplace>
+at::Tensor& dil_add_common(
     at::Tensor& result,
     const at::Tensor& self,
     const at::Tensor& other,
     at::Scalar alpha) {
-  DEBUG("AtenIpexCPUDev::dil_add_out\n");
   CHECK_DNNL_OP_PRE_COND(self);
   CHECK_DNNL_OP_PRE_COND(other);

   TORCH_CHECK(self.sizes().equals(other.sizes()),
-      "dil_add not support broadcast yet");
-  auto inferred_size = self.sizes();
-  if (!result.sizes().equals(inferred_size)) {
-    result.resize_(inferred_size);
-  }
+      "dil add not support broadcast yet");

   dbl::comm::reorder_to_bf16_for_mix_prec(self);
   dbl::comm::reorder_to_bf16_for_mix_prec(other);
-  dbl::comm::reorder_to_bf16_for_mix_prec(result);

   auto x = dbl::comm::try_gen_dil_tensor(self);
   auto y = dbl::comm::try_gen_dil_tensor(other);
-  auto z = dbl::comm::try_gen_dil_tensor(result);
+  auto z = inplace ? x : dil::tensor();

   dil::sum::compute({1.0, alpha.to<float>()}, {x, y}, z);

-  TORCH_INTERNAL_ASSERT_DEBUG_ONLY(z.is_public_format() || check_tensor_own_whole_storage(result));
-  dbl::comm::sync_shape_from_dil_to_aten(result, z);
+  if (!inplace) {
+    dbl::comm::equip_dil_buffer(result, z);
+  }
   return result;
 }

-at::Tensor AtenIpexCPUDev::dil_add(const at::Tensor& self, const at::Tensor& other, at::Scalar alpha) {
-  DEBUG("AtenIpexCPUDev::dil_add\n");
-  CHECK_DNNL_OP_PRE_COND(self);
-  CHECK_DNNL_OP_PRE_COND(other);
-
-  TORCH_CHECK(self.sizes().equals(other.sizes()),
-      "dil_add not support broadcast yet");
-
-  dbl::comm::reorder_to_bf16_for_mix_prec(self);
-  dbl::comm::reorder_to_bf16_for_mix_prec(other);
+at::Tensor& AtenIpexCPUDev::dil_add_out(at::Tensor& result, const at::Tensor& self, const at::Tensor& other, at::Scalar alpha) {
+  DEBUG("AtenIpexCPUDev::dil_add_out\n");

-  auto x = dbl::comm::try_gen_dil_tensor(self);
-  auto y = dbl::comm::try_gen_dil_tensor(other);
-  dil::tensor z;
+  return dil_add_common</*inplace=*/false>(result, self, other, alpha);
+}

-  dil::sum::compute({1.0, alpha.to<float>()}, {x, y}, z);
+at::Tensor AtenIpexCPUDev::dil_add(const at::Tensor& self, const at::Tensor& other, at::Scalar alpha) {
+  DEBUG("AtenIpexCPUDev::dil_add\n");

-  return dbl::comm::gen_aten_tensor_by(std::move(z));
+  auto result = dbl::comm::empty_dil_tensor({0}, self.options());
+  return dil_add_common</*inplace=*/false>(result, self, other, alpha);
 }

 at::Tensor & AtenIpexCPUDev::dil_add_(at::Tensor& self, const at::Tensor& other, at::Scalar alpha) {
   DEBUG("AtenIpexCPUDev::dil_add_\n");

-  return dil_add_out(self, self, other, alpha);
+  return dil_add_common</*inplace=*/true>(self, self, other, alpha);
 }

-at::Tensor& AtenIpexCPUDev::dil_mul_out(at::Tensor& result, const at::Tensor& self, const at::Tensor& other) {
-  DEBUG("AtenIpexCPUDev::dil_mul_out\n");
-  CHECK_DNNL_OP_PRE_COND(result);
+template<bool inplace>
+at::Tensor& dil_mul_common(
+    at::Tensor& result,
+    const at::Tensor& self,
+    const at::Tensor& other) {
   CHECK_DNNL_OP_PRE_COND(self);
   CHECK_DNNL_OP_PRE_COND(other);

   TORCH_CHECK(self.sizes().equals(other.sizes()),
-      "dil_mul not support broadcast yet");
-  auto inferred_size = self.sizes();
-  if (!result.sizes().equals(inferred_size)) {
-    result.resize_(inferred_size);
-  }
+      "dil mul not support broadcast yet");

   dbl::comm::reorder_to_bf16_for_mix_prec(self);
   dbl::comm::reorder_to_bf16_for_mix_prec(other);
-  dbl::comm::reorder_to_bf16_for_mix_prec(result);

-  auto dil_result = dbl::comm::try_gen_dil_tensor(result);
-  auto dil_self = dbl::comm::try_gen_dil_tensor(self);
-  auto dil_other = dbl::comm::try_gen_dil_tensor(other);
+  auto x = dbl::comm::try_gen_dil_tensor(self);
+  auto y = dbl::comm::try_gen_dil_tensor(other);
+  auto z = inplace ? x : dil::tensor();

-  dil::binary::compute(dil_self, dil_other, dil_result, dil::algorithm::binary_mul);
+  dil::binary::compute(x, y, z, dil::algorithm::binary_mul);

-  TORCH_INTERNAL_ASSERT_DEBUG_ONLY(dil_result.is_public_format() || check_tensor_own_whole_storage(result));
-  dbl::comm::sync_shape_from_dil_to_aten(result, dil_result);
+  if (!inplace) {
+    dbl::comm::equip_dil_buffer(result, z);
+  }
   return result;
 }

-at::Tensor AtenIpexCPUDev::dil_mul(const at::Tensor& self, const at::Tensor& other) {
-  DEBUG("AtenIpexCPUDev::dil_mul\n");
-  CHECK_DNNL_OP_PRE_COND(self);
-  CHECK_DNNL_OP_PRE_COND(other);
-
-  TORCH_CHECK(self.sizes().equals(other.sizes()),
-      "dil_mul not support broadcast yet");
-
-  dbl::comm::reorder_to_bf16_for_mix_prec(self);
-  dbl::comm::reorder_to_bf16_for_mix_prec(other);
+at::Tensor& AtenIpexCPUDev::dil_mul_out(at::Tensor& result, const at::Tensor& self, const at::Tensor& other) {
+  DEBUG("AtenIpexCPUDev::dil_mul_out\n");

-  auto x = dbl::comm::try_gen_dil_tensor(self);
-  auto y = dbl::comm::try_gen_dil_tensor(other);
-  dil::tensor z;
+  return dil_mul_common</*inplace=*/false>(result, self, other);
+}

-  dil::binary::compute(x, y, z, dil::algorithm::binary_mul);
+at::Tensor AtenIpexCPUDev::dil_mul(const at::Tensor& self, const at::Tensor& other) {
+  DEBUG("AtenIpexCPUDev::dil_mul\n");

-  return dbl::comm::gen_aten_tensor_by(std::move(z));
+  auto result = dbl::comm::empty_dil_tensor({0}, self.options());
+  return dil_mul_common</*inplace=*/false>(result, self, other);
 }

 at::Tensor& AtenIpexCPUDev::dil_mul_(at::Tensor& self, const at::Tensor& other) {
   DEBUG("AtenIpexCPUDev::dil_mul_\n");

-  return dil_mul_out(self, self, other);
+  return dil_mul_common</*inplace=*/true>(self, self, other);
 }

 void matmul_common(
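With this refactor, dil_add, dil_add_out, and dil_add_ (and the mul counterparts) all funnel into one dil_*_common<inplace> helper, so the three PyTorch entry points should agree on the result. A hedged usage sketch, under the same intel_pytorch_extension import assumption as above:

    import torch
    import intel_pytorch_extension as ipex  # import name assumed from the test suite

    a = torch.full((2, 3, 4, 5), 5.).to(ipex.DEVICE)
    b = torch.full((2, 3, 4, 5), 4.).to(ipex.DEVICE)
    ref = torch.full((2, 3, 4, 5), 9.)

    c = torch.add(a, b)                 # out-of-place -> dil_add
    o = torch.ones(1).to(ipex.DEVICE)
    torch.add(a, b, out=o)              # out= variant -> dil_add_out (resizes o)
    a.add_(b)                           # in-place -> dil_add_ (inplace=true)
    for t in (c, o, a):
        assert torch.equal(t.to('cpu'), ref)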

torch_ipex/csrc/cpu/dbl/Common.cpp

Lines changed: 19 additions & 7 deletions
@@ -89,6 +89,15 @@ void reorder_to_desc(const at::Tensor& tensor, const dil::tensor::desc& expected
 }

 void equip_dil_buffer(const at::Tensor& tensor, dil::tensor dil_tensor_buffer) {
+  TORCH_INTERNAL_ASSERT_DEBUG_ONLY(
+      tensor.device().is_dpcpp(),
+      "dil buffer can only be equipped to dpcpp tensor");
+
+  TORCH_INTERNAL_ASSERT_DEBUG_ONLY(
+      check_tensor_own_whole_storage(tensor),
+      "dil buffer can only be equipped to tensors that own the whole storage, "
+      "as dil buffer is going to replace the original storage");
+
   // Build new shade data context
   cpu::ShadeDataContext *new_shade_data_context = cpu::ShadeDataContext::allocShadeDataContext();
   new_shade_data_context->data_type = cpu::SHADE_DATA_TYPE::DIL;
@@ -97,13 +106,10 @@ void equip_dil_buffer(const at::Tensor& tensor, dil::tensor dil_tensor_buffer) {
   void *tensor_data = nullptr;
   if (dil_tensor_buffer.get_data_type() != get_dil_data_type(tensor.scalar_type())) {
     new_shade_data_context->mix_prec_type = cpu::MIX_PREC_TYPE::MIX_BF16_FP32;
-  } else {
-    if (dil_tensor_buffer.is_public_format()) {
-      tensor_data = dil_tensor_buffer.get_data_handle();
-      new_shade_data_context->cpu_raw_data = tensor_data;
-      new_shade_data_context->cpu_del_fun = &(c10::detail::deleteNothing);
-      sync_shape_from_dil_to_aten(tensor, dil_tensor_buffer);
-    }
+  } else if (dil_tensor_buffer.is_public_format()) {
+    tensor_data = dil_tensor_buffer.get_data_handle();
+    new_shade_data_context->cpu_raw_data = tensor_data;
+    new_shade_data_context->cpu_del_fun = &(c10::detail::deleteNothing);
   }

   // Create a new DataPtr instances because the DataPtr class does not support set
@@ -116,6 +122,12 @@ void equip_dil_buffer(const at::Tensor& tensor, dil::tensor dil_tensor_buffer) {

   IPEXTensorImpl* ipex_tensor_impl = (IPEXTensorImpl *)tensor.unsafeGetTensorImpl();
   ipex_tensor_impl->storage().set_data_ptr(std::move(shade_data_ptr));
+
+  // After equip_dil_buffer(), whole storage should be managed by dil tensor,
+  // and thus storage metadata should be overwritten by dil tensor
+  // Note: Storage::set_numel() might be removed later
+  ipex_tensor_impl->storage().set_numel(dil_tensor_buffer.get_nelems());
+  cpu::dbl::comm::sync_shape_from_dil_to_aten(tensor, dil_tensor_buffer);
 }

 dil::tensor try_gen_dil_tensor(const at::Tensor &input) {
