@@ -254,115 +254,95 @@ std::tuple<at::Tensor,at::Tensor,at::Tensor> AtenIpexCPUDev::mkldnn_convolution_
   return std::tuple<at::Tensor,at::Tensor,at::Tensor>(bridge::shallowUpgradeToDPCPPTensor(std::get<0>(_ipex_result)), bridge::shallowUpgradeToDPCPPTensor(std::get<1>(_ipex_result)), bridge::shallowUpgradeToDPCPPTensor(std::get<2>(_ipex_result)));
 }
 
-at::Tensor& AtenIpexCPUDev::dil_add_out(
+template<bool inplace>
+at::Tensor& dil_add_common(
     at::Tensor& result,
     const at::Tensor& self,
     const at::Tensor& other,
     at::Scalar alpha) {
-  DEBUG("AtenIpexCPUDev::dil_add_out\n");
   CHECK_DNNL_OP_PRE_COND(self);
   CHECK_DNNL_OP_PRE_COND(other);
 
   TORCH_CHECK(self.sizes().equals(other.sizes()),
-      "dil_add not support broadcast yet");
-  auto inferred_size = self.sizes();
-  if (!result.sizes().equals(inferred_size)) {
-    result.resize_(inferred_size);
-  }
+      "dil add not support broadcast yet");
 
   dbl::comm::reorder_to_bf16_for_mix_prec(self);
   dbl::comm::reorder_to_bf16_for_mix_prec(other);
-  dbl::comm::reorder_to_bf16_for_mix_prec(result);
 
   auto x = dbl::comm::try_gen_dil_tensor(self);
   auto y = dbl::comm::try_gen_dil_tensor(other);
-  auto z = dbl::comm::try_gen_dil_tensor(result);
+  auto z = inplace ? x : dil::tensor();
 
   dil::sum::compute({1.0, alpha.to<float>()}, {x, y}, z);
 
-  TORCH_INTERNAL_ASSERT_DEBUG_ONLY(z.is_public_format() || check_tensor_own_whole_storage(result));
-  dbl::comm::sync_shape_from_dil_to_aten(result, z);
+  if (!inplace) {
+    dbl::comm::equip_dil_buffer(result, z);
+  }
   return result;
 }
 
-at::Tensor AtenIpexCPUDev::dil_add(const at::Tensor& self, const at::Tensor& other, at::Scalar alpha) {
-  DEBUG("AtenIpexCPUDev::dil_add\n");
-  CHECK_DNNL_OP_PRE_COND(self);
-  CHECK_DNNL_OP_PRE_COND(other);
-
-  TORCH_CHECK(self.sizes().equals(other.sizes()),
-      "dil_add not support broadcast yet");
-
-  dbl::comm::reorder_to_bf16_for_mix_prec(self);
-  dbl::comm::reorder_to_bf16_for_mix_prec(other);
+at::Tensor& AtenIpexCPUDev::dil_add_out(at::Tensor& result, const at::Tensor& self, const at::Tensor& other, at::Scalar alpha) {
+  DEBUG("AtenIpexCPUDev::dil_add_out\n");
 
-  auto x = dbl::comm::try_gen_dil_tensor(self);
-  auto y = dbl::comm::try_gen_dil_tensor(other);
-  dil::tensor z;
+  return dil_add_common</*inplace=*/false>(result, self, other, alpha);
+}
 
-  dil::sum::compute({1.0, alpha.to<float>()}, {x, y}, z);
+at::Tensor AtenIpexCPUDev::dil_add(const at::Tensor& self, const at::Tensor& other, at::Scalar alpha) {
+  DEBUG("AtenIpexCPUDev::dil_add\n");
 
-  return dbl::comm::gen_aten_tensor_by(std::move(z));
+  auto result = dbl::comm::empty_dil_tensor({0}, self.options());
+  return dil_add_common</*inplace=*/false>(result, self, other, alpha);
 }
 
 at::Tensor& AtenIpexCPUDev::dil_add_(at::Tensor& self, const at::Tensor& other, at::Scalar alpha) {
   DEBUG("AtenIpexCPUDev::dil_add_\n");
 
-  return dil_add_out(self, self, other, alpha);
+  return dil_add_common</*inplace=*/true>(self, self, other, alpha);
 }
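
The pattern this hunk introduces — one templated worker behind the out-of-place, functional, and in-place entry points — can be seen in isolation. Below is a minimal, self-contained sketch of that dispatch; Buf and all function names here are hypothetical stand-ins for illustration, not the dil::tensor/IPEX API:

#include <cassert>
#include <cstddef>
#include <utility>
#include <vector>

using Buf = std::vector<float>;  // hypothetical stand-in for dil::tensor

// One worker serves all three variants; the compile-time flag selects
// whether we write into existing storage or hand a fresh buffer to 'result'.
template <bool inplace>
Buf& add_common(Buf& result, const Buf& self, const Buf& other, float alpha) {
  assert(self.size() == other.size());  // mirrors the "not support broadcast yet" check
  if (inplace) {
    // In the in-place entry point, 'result' aliases 'self', so this writes
    // straight into self's storage, like z aliasing x in dil_add_common above.
    for (std::size_t i = 0; i < result.size(); ++i)
      result[i] = self[i] + alpha * other[i];
  } else {
    // Out-of-place: compute into a fresh buffer and move it into 'result',
    // analogous to equip_dil_buffer(result, z).
    Buf z(self.size());
    for (std::size_t i = 0; i < z.size(); ++i)
      z[i] = self[i] + alpha * other[i];
    result = std::move(z);
  }
  return result;
}

// Thin entry points, shaped like dil_add_out / dil_add / dil_add_:
Buf& add_out(Buf& result, const Buf& self, const Buf& other, float alpha) {
  return add_common</*inplace=*/false>(result, self, other, alpha);
}
Buf add(const Buf& self, const Buf& other, float alpha) {
  Buf result;  // starts empty, like empty_dil_tensor({0}, ...)
  return add_common</*inplace=*/false>(result, self, other, alpha);
}
Buf& add_(Buf& self, const Buf& other, float alpha) {
  return add_common</*inplace=*/true>(self, self, other, alpha);
}

The dil_mul changes below apply exactly the same treatment, with dil::binary::compute in place of dil::sum::compute.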
 
-at::Tensor& AtenIpexCPUDev::dil_mul_out(at::Tensor& result, const at::Tensor& self, const at::Tensor& other) {
-  DEBUG("AtenIpexCPUDev::dil_mul_out\n");
-  CHECK_DNNL_OP_PRE_COND(result);
+template<bool inplace>
+at::Tensor& dil_mul_common(
+    at::Tensor& result,
+    const at::Tensor& self,
+    const at::Tensor& other) {
   CHECK_DNNL_OP_PRE_COND(self);
   CHECK_DNNL_OP_PRE_COND(other);
 
   TORCH_CHECK(self.sizes().equals(other.sizes()),
-      "dil_mul not support broadcast yet");
-  auto inferred_size = self.sizes();
-  if (!result.sizes().equals(inferred_size)) {
-    result.resize_(inferred_size);
-  }
+      "dil mul not support broadcast yet");
 
   dbl::comm::reorder_to_bf16_for_mix_prec(self);
   dbl::comm::reorder_to_bf16_for_mix_prec(other);
-  dbl::comm::reorder_to_bf16_for_mix_prec(result);
 
-  auto dil_result = dbl::comm::try_gen_dil_tensor(result);
-  auto dil_self = dbl::comm::try_gen_dil_tensor(self);
-  auto dil_other = dbl::comm::try_gen_dil_tensor(other);
+  auto x = dbl::comm::try_gen_dil_tensor(self);
+  auto y = dbl::comm::try_gen_dil_tensor(other);
+  auto z = inplace ? x : dil::tensor();
 
-  dil::binary::compute(dil_self, dil_other, dil_result, dil::algorithm::binary_mul);
+  dil::binary::compute(x, y, z, dil::algorithm::binary_mul);
 
-  TORCH_INTERNAL_ASSERT_DEBUG_ONLY(dil_result.is_public_format() || check_tensor_own_whole_storage(result));
-  dbl::comm::sync_shape_from_dil_to_aten(result, dil_result);
+  if (!inplace) {
+    dbl::comm::equip_dil_buffer(result, z);
+  }
   return result;
 }
 
-at::Tensor AtenIpexCPUDev::dil_mul(const at::Tensor& self, const at::Tensor& other) {
-  DEBUG("AtenIpexCPUDev::dil_mul\n");
-  CHECK_DNNL_OP_PRE_COND(self);
-  CHECK_DNNL_OP_PRE_COND(other);
-
-  TORCH_CHECK(self.sizes().equals(other.sizes()),
-      "dil_mul not support broadcast yet");
-
-  dbl::comm::reorder_to_bf16_for_mix_prec(self);
-  dbl::comm::reorder_to_bf16_for_mix_prec(other);
+at::Tensor& AtenIpexCPUDev::dil_mul_out(at::Tensor& result, const at::Tensor& self, const at::Tensor& other) {
+  DEBUG("AtenIpexCPUDev::dil_mul_out\n");
 
-  auto x = dbl::comm::try_gen_dil_tensor(self);
-  auto y = dbl::comm::try_gen_dil_tensor(other);
-  dil::tensor z;
+  return dil_mul_common</*inplace=*/false>(result, self, other);
+}
 
-  dil::binary::compute(x, y, z, dil::algorithm::binary_mul);
+at::Tensor AtenIpexCPUDev::dil_mul(const at::Tensor& self, const at::Tensor& other) {
+  DEBUG("AtenIpexCPUDev::dil_mul\n");
 
-  return dbl::comm::gen_aten_tensor_by(std::move(z));
+  auto result = dbl::comm::empty_dil_tensor({0}, self.options());
+  return dil_mul_common</*inplace=*/false>(result, self, other);
 }
 
 at::Tensor& AtenIpexCPUDev::dil_mul_(at::Tensor& self, const at::Tensor& other) {
   DEBUG("AtenIpexCPUDev::dil_mul_\n");
 
-  return dil_mul_out(self, self, other);
+  return dil_mul_common</*inplace=*/true>(self, self, other);
 }
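
Two properties of this refactor are worth noting. Templating on inplace resolves the branch at compile time, and the out-of-place wrappers now allocate only a placeholder (empty_dil_tensor({0}, ...)) that receives real storage once the DNNL result exists, via equip_dil_buffer — replacing the old resize_ plus sync_shape_from_dil_to_aten handshake. For completeness, usage of the hypothetical sketch above:

Buf a{1.f, 2.f}, b{3.f, 4.f}, out;
add_out(out, a, b, /*alpha=*/2.f);  // out == {7, 10}; a is untouched
add_(a, b, /*alpha=*/1.f);          // a == {4, 6}, computed in place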
 
 void matmul_common(