@@ -254,107 +254,95 @@ std::tuple<at::Tensor,at::Tensor,at::Tensor> AtenIpexCPUDev::mkldnn_convolution_
   return std::tuple<at::Tensor,at::Tensor,at::Tensor>(bridge::shallowUpgradeToDPCPPTensor(std::get<0>(_ipex_result)), bridge::shallowUpgradeToDPCPPTensor(std::get<1>(_ipex_result)), bridge::shallowUpgradeToDPCPPTensor(std::get<2>(_ipex_result)));
 }

-at::Tensor& AtenIpexCPUDev::dil_add_out(
+template <bool inplace>
+at::Tensor& dil_add_common(
     at::Tensor& result,
     const at::Tensor& self,
     const at::Tensor& other,
     at::Scalar alpha) {
-  DEBUG("AtenIpexCPUDev::dil_add_out\n");
   CHECK_DNNL_OP_PRE_COND(self);
   CHECK_DNNL_OP_PRE_COND(other);

+  TORCH_CHECK(self.sizes().equals(other.sizes()),
+    "dil add not support broadcast yet");
+
   dbl::comm::reorder_to_bf16_for_mix_prec(self);
   dbl::comm::reorder_to_bf16_for_mix_prec(other);
-  dbl::comm::reorder_to_bf16_for_mix_prec(result);

-  dil::tensor x = dbl::comm::try_gen_dil_tensor(self);
-  dil::tensor y = dbl::comm::try_gen_dil_tensor(other);
+  auto x = dbl::comm::try_gen_dil_tensor(self);
+  auto y = dbl::comm::try_gen_dil_tensor(other);
+  auto z = inplace ? x : dil::tensor();

-  dil::tensor z = dbl::comm::try_gen_dil_tensor(result);
-  const std::vector<float> scales{1.0, alpha.to<float>()};
-  dil::sum::compute(scales, {x, y}, z);
+  dil::sum::compute({1.0, alpha.to<float>()}, {x, y}, z);

-  TORCH_INTERNAL_ASSERT_DEBUG_ONLY(z.is_public_format() || check_tensor_own_whole_storage(result));
-  dbl::comm::sync_shape_from_dil_to_aten(result, z);
+  if (!inplace) {
+    dbl::comm::equip_dil_buffer(result, z);
+  }
   return result;
 }

-at::Tensor AtenIpexCPUDev::dil_add(const at::Tensor& self, const at::Tensor& other, at::Scalar alpha) {
-  DEBUG("AtenIpexCPUDev::dil_add\n");
-  CHECK_DNNL_OP_PRE_COND(self);
-  CHECK_DNNL_OP_PRE_COND(other);
-
-  dbl::comm::reorder_to_bf16_for_mix_prec(self);
-  dbl::comm::reorder_to_bf16_for_mix_prec(other);
+at::Tensor& AtenIpexCPUDev::dil_add_out(at::Tensor& result, const at::Tensor& self, const at::Tensor& other, at::Scalar alpha) {
+  DEBUG("AtenIpexCPUDev::dil_add_out\n");

-  dil::tensor x = dbl::comm::try_gen_dil_tensor(self);
-  dil::tensor y = dbl::comm::try_gen_dil_tensor(other);
+  return dil_add_common</*inplace=*/false>(result, self, other, alpha);
+}

-  dil::tensor z;
-  const std::vector<float> scales{1.0, alpha.to<float>()};
-  dil::sum::compute(scales, {x, y}, z);
+at::Tensor AtenIpexCPUDev::dil_add(const at::Tensor& self, const at::Tensor& other, at::Scalar alpha) {
+  DEBUG("AtenIpexCPUDev::dil_add\n");

-  return dbl::comm::gen_aten_tensor_by(std::move(z));
+  auto result = dbl::comm::empty_dil_tensor({0}, self.options());
+  return dil_add_common</*inplace=*/false>(result, self, other, alpha);
 }

 at::Tensor& AtenIpexCPUDev::dil_add_(at::Tensor& self, const at::Tensor& other, at::Scalar alpha) {
   DEBUG("AtenIpexCPUDev::dil_add_\n");
+
+  return dil_add_common</*inplace=*/true>(self, self, other, alpha);
+}
+
+template <bool inplace>
+at::Tensor& dil_mul_common(
+    at::Tensor& result,
+    const at::Tensor& self,
+    const at::Tensor& other) {
   CHECK_DNNL_OP_PRE_COND(self);
   CHECK_DNNL_OP_PRE_COND(other);

+  TORCH_CHECK(self.sizes().equals(other.sizes()),
+    "dil mul not support broadcast yet");
+
   dbl::comm::reorder_to_bf16_for_mix_prec(self);
   dbl::comm::reorder_to_bf16_for_mix_prec(other);

-  auto dil_self = dbl::comm::try_gen_dil_tensor(self);
-  auto dil_other = dbl::comm::try_gen_dil_tensor(other);
+  auto x = dbl::comm::try_gen_dil_tensor(self);
+  auto y = dbl::comm::try_gen_dil_tensor(other);
+  auto z = inplace ? x : dil::tensor();

-  const std::vector<float> scales{1.0, alpha.to<float>()};
-  dil::sum::compute(scales, {dil_self, dil_other}, dil_self);
+  dil::binary::compute(x, y, z, dil::algorithm::binary_mul);

-  TORCH_INTERNAL_ASSERT_DEBUG_ONLY(dil_self.is_public_format() || check_tensor_own_whole_storage(self));
-  dbl::comm::sync_shape_from_dil_to_aten(self, dil_self);
-  return self;
+  if (!inplace) {
+    dbl::comm::equip_dil_buffer(result, z);
+  }
+  return result;
 }

 at::Tensor& AtenIpexCPUDev::dil_mul_out(at::Tensor& result, const at::Tensor& self, const at::Tensor& other) {
   DEBUG("AtenIpexCPUDev::dil_mul_out\n");
-  CHECK_DNNL_OP_PRE_COND(result);
-  CHECK_DNNL_OP_PRE_COND(self);
-  CHECK_DNNL_OP_PRE_COND(other);
-
-  dbl::comm::reorder_to_bf16_for_mix_prec(self);
-  dbl::comm::reorder_to_bf16_for_mix_prec(other);
-  dbl::comm::reorder_to_bf16_for_mix_prec(result);
-
-  auto dil_result = dbl::comm::try_gen_dil_tensor(result);
-  auto dil_self = dbl::comm::try_gen_dil_tensor(self);
-  auto dil_other = dbl::comm::try_gen_dil_tensor(other);
-
-  dil::binary::compute(dil_self, dil_other, dil_result, dil::algorithm::binary_mul);

-  TORCH_INTERNAL_ASSERT_DEBUG_ONLY(dil_result.is_public_format() || check_tensor_own_whole_storage(result));
-  dbl::comm::sync_shape_from_dil_to_aten(result, dil_result);
-  return result;
+  return dil_mul_common</*inplace=*/false>(result, self, other);
 }

 at::Tensor AtenIpexCPUDev::dil_mul(const at::Tensor& self, const at::Tensor& other) {
   DEBUG("AtenIpexCPUDev::dil_mul\n");

-  dbl::comm::reorder_to_bf16_for_mix_prec(self);
-  dbl::comm::reorder_to_bf16_for_mix_prec(other);
-
-  at::Tensor result = dbl::comm::empty_dil_tensor(self.sizes(), self.options());
-
-  return dil_mul_out(result, self, other);
+  auto result = dbl::comm::empty_dil_tensor({0}, self.options());
+  return dil_mul_common</*inplace=*/false>(result, self, other);
 }

 at::Tensor& AtenIpexCPUDev::dil_mul_(at::Tensor& self, const at::Tensor& other) {
   DEBUG("AtenIpexCPUDev::dil_mul_\n");

-  dbl::comm::reorder_to_bf16_for_mix_prec(self);
-  dbl::comm::reorder_to_bf16_for_mix_prec(other);
-
-  return dil_mul_out(self, self, other);
+  return dil_mul_common</*inplace=*/true>(self, self, other);
 }

 void matmul_common(
@@ -472,7 +460,7 @@ at::Tensor& AtenIpexCPUDev::dil_baddbmm_out(
     result.resize_(inferred_size);
   }
   TORCH_CHECK(self.sizes().equals(inferred_size),
-    "baddbmm not support broadcast yet");
+    "dil_baddbmm not support broadcast yet");

   dbl::comm::reorder_to_bf16_for_mix_prec(result);
   dbl::comm::reorder_to_bf16_for_mix_prec(self);
@@ -541,7 +529,7 @@ at::Tensor& AtenIpexCPUDev::dil_addmm_out(
     result.resize_(inferred_size);
   }
   TORCH_CHECK(self.sizes().equals(inferred_size),
-    "addmm not support broadcast yet");
+    "dil_addmm not support broadcast yet");

   dbl::comm::reorder_to_bf16_for_mix_prec(result);
   dbl::comm::reorder_to_bf16_for_mix_prec(self);
@@ -610,7 +598,7 @@ at::Tensor& AtenIpexCPUDev::dil_addbmm_out(
     result.resize_(inferred_size);
   }
   TORCH_CHECK(self.sizes().equals(inferred_size),
-    "addbmm not support broadcast yet");
+    "dil_addbmm not support broadcast yet");

   dbl::comm::reorder_to_bf16_for_mix_prec(result);
   dbl::comm::reorder_to_bf16_for_mix_prec(self);
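The refactor above funnels each binary op through a single `*_common` helper whose in-place vs. out-of-place behavior is selected by a compile-time `template <bool inplace>` parameter, leaving `dil_add_out` / `dil_add` / `dil_add_` as one-line dispatches. Below is a minimal self-contained sketch of that dispatch pattern. It is not IPEX code: `Buffer`, `compute_sum`, and the three entry points are hypothetical stand-ins for `dil::tensor`, `dil::sum::compute`, and the `dil_add*` trio.

```cpp
#include <iostream>
#include <utility>
#include <vector>

// Hypothetical stand-in for dil::tensor: a plain float buffer.
using Buffer = std::vector<float>;

// Stand-in for dil::sum::compute: z = scales[0]*x + scales[1]*y.
// Safe when z aliases x, since x[i] is read before z[i] is written.
void compute_sum(const std::vector<float>& scales,
                 const Buffer& x, const Buffer& y, Buffer& z) {
  z.resize(x.size());
  for (size_t i = 0; i < x.size(); ++i)
    z[i] = scales[0] * x[i] + scales[1] * y[i];
}

// The shared helper: in-place accumulates straight into self's storage;
// out-of-place computes into a scratch buffer and then hands it to
// result (the analogue of dbl::comm::equip_dil_buffer in the diff).
template <bool inplace>
Buffer& add_common(Buffer& result, Buffer& self, const Buffer& other,
                   float alpha) {
  Buffer scratch;
  Buffer& z = inplace ? self : scratch;
  compute_sum({1.0f, alpha}, self, other, z);
  if (!inplace) {
    result = std::move(z);
  }
  return result;
}

// Public entry points collapse to one-line dispatches, mirroring
// dil_add_out / dil_add / dil_add_ above.
Buffer& add_out(Buffer& result, Buffer& self, const Buffer& other, float alpha) {
  return add_common</*inplace=*/false>(result, self, other, alpha);
}

Buffer add(Buffer& self, const Buffer& other, float alpha) {
  Buffer result;
  return add_common</*inplace=*/false>(result, self, other, alpha);
}

Buffer& add_(Buffer& self, const Buffer& other, float alpha) {
  return add_common</*inplace=*/true>(self, self, other, alpha);
}

int main() {
  Buffer a{1, 2, 3}, b{10, 20, 30};
  add_(a, b, 2.0f);                        // a = a + 2*b, in place
  for (float v : a) std::cout << v << ' '; // prints: 21 42 63
  std::cout << '\n';
}
```

Because `inplace` is a template parameter rather than a runtime argument, the `if (!inplace)` branch is a compile-time constant the compiler can fold away, so the in-place path carries no extra buffer hand-off.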