@@ -263,16 +263,22 @@ at::Tensor& AtenIpexCPUDev::dil_add_out(
263
263
CHECK_DNNL_OP_PRE_COND (self);
264
264
CHECK_DNNL_OP_PRE_COND (other);
265
265
266
+ TORCH_CHECK (self.sizes ().equals (other.sizes ()),
267
+ " dil_add not support broadcast yet" );
268
+ auto inferred_size = self.sizes ();
269
+ if (!result.sizes ().equals (inferred_size)) {
270
+ result.resize_ (inferred_size);
271
+ }
272
+
266
273
dbl::comm::reorder_to_bf16_for_mix_prec (self);
267
274
dbl::comm::reorder_to_bf16_for_mix_prec (other);
268
275
dbl::comm::reorder_to_bf16_for_mix_prec (result);
269
276
270
- dil::tensor x = dbl::comm::try_gen_dil_tensor (self);
271
- dil::tensor y = dbl::comm::try_gen_dil_tensor (other);
277
+ auto x = dbl::comm::try_gen_dil_tensor (self);
278
+ auto y = dbl::comm::try_gen_dil_tensor (other);
279
+ auto z = dbl::comm::try_gen_dil_tensor (result);
272
280
273
- dil::tensor z = dbl::comm::try_gen_dil_tensor (result);
274
- const std::vector<float > scales{1.0 , alpha.to <float >()};
275
- dil::sum::compute (scales, {x, y}, z);
281
+ dil::sum::compute ({1.0 , alpha.to <float >()}, {x, y}, z);
276
282
277
283
TORCH_INTERNAL_ASSERT_DEBUG_ONLY (z.is_public_format () || check_tensor_own_whole_storage (result));
278
284
dbl::comm::sync_shape_from_dil_to_aten (result, z);
@@ -284,36 +290,25 @@ at::Tensor AtenIpexCPUDev::dil_add(const at::Tensor& self, const at::Tensor& oth
284
290
CHECK_DNNL_OP_PRE_COND (self);
285
291
CHECK_DNNL_OP_PRE_COND (other);
286
292
293
+ TORCH_CHECK (self.sizes ().equals (other.sizes ()),
294
+ " dil_add not support broadcast yet" );
295
+
287
296
dbl::comm::reorder_to_bf16_for_mix_prec (self);
288
297
dbl::comm::reorder_to_bf16_for_mix_prec (other);
289
298
290
- dil::tensor x = dbl::comm::try_gen_dil_tensor (self);
291
- dil::tensor y = dbl::comm::try_gen_dil_tensor (other);
292
-
299
+ auto x = dbl::comm::try_gen_dil_tensor (self);
300
+ auto y = dbl::comm::try_gen_dil_tensor (other);
293
301
dil::tensor z;
294
- const std::vector< float > scales{ 1.0 , alpha. to < float >()};
295
- dil::sum::compute (scales , {x, y}, z);
302
+
303
+ dil::sum::compute ({ 1.0 , alpha. to < float >()} , {x, y}, z);
296
304
297
305
return dbl::comm::gen_aten_tensor_by (std::move (z));
298
306
}
299
307
300
308
at::Tensor & AtenIpexCPUDev::dil_add_ (at::Tensor& self, const at::Tensor& other, at::Scalar alpha) {
301
309
DEBUG (" AtenIpexCPUDev::dil_add_\n " );
302
- CHECK_DNNL_OP_PRE_COND (self);
303
- CHECK_DNNL_OP_PRE_COND (other);
304
310
305
- dbl::comm::reorder_to_bf16_for_mix_prec (self);
306
- dbl::comm::reorder_to_bf16_for_mix_prec (other);
307
-
308
- auto dil_self = dbl::comm::try_gen_dil_tensor (self);
309
- auto dil_other = dbl::comm::try_gen_dil_tensor (other);
310
-
311
- const std::vector<float > scales{1.0 , alpha.to <float >()};
312
- dil::sum::compute (scales, {dil_self, dil_other}, dil_self);
313
-
314
- TORCH_INTERNAL_ASSERT_DEBUG_ONLY (dil_self.is_public_format () || check_tensor_own_whole_storage (self));
315
- dbl::comm::sync_shape_from_dil_to_aten (self, dil_self);
316
- return self;
311
+ return dil_add_out (self, self, other, alpha);
317
312
}
318
313
319
314
at::Tensor& AtenIpexCPUDev::dil_mul_out (at::Tensor& result, const at::Tensor& self, const at::Tensor& other) {
@@ -322,6 +317,13 @@ at::Tensor& AtenIpexCPUDev::dil_mul_out(at::Tensor& result, const at::Tensor& se
322
317
CHECK_DNNL_OP_PRE_COND (self);
323
318
CHECK_DNNL_OP_PRE_COND (other);
324
319
320
+ TORCH_CHECK (self.sizes ().equals (other.sizes ()),
321
+ " dil_mul not support broadcast yet" );
322
+ auto inferred_size = self.sizes ();
323
+ if (!result.sizes ().equals (inferred_size)) {
324
+ result.resize_ (inferred_size);
325
+ }
326
+
325
327
dbl::comm::reorder_to_bf16_for_mix_prec (self);
326
328
dbl::comm::reorder_to_bf16_for_mix_prec (other);
327
329
dbl::comm::reorder_to_bf16_for_mix_prec (result);
@@ -339,21 +341,27 @@ at::Tensor& AtenIpexCPUDev::dil_mul_out(at::Tensor& result, const at::Tensor& se
339
341
340
342
at::Tensor AtenIpexCPUDev::dil_mul (const at::Tensor& self, const at::Tensor& other) {
341
343
DEBUG (" AtenIpexCPUDev::dil_mul\n " );
344
+ CHECK_DNNL_OP_PRE_COND (self);
345
+ CHECK_DNNL_OP_PRE_COND (other);
346
+
347
+ TORCH_CHECK (self.sizes ().equals (other.sizes ()),
348
+ " dil_mul not support broadcast yet" );
342
349
343
350
dbl::comm::reorder_to_bf16_for_mix_prec (self);
344
351
dbl::comm::reorder_to_bf16_for_mix_prec (other);
345
352
346
- at::Tensor result = dbl::comm::empty_dil_tensor (self.sizes (), self.options ());
353
+ auto x = dbl::comm::try_gen_dil_tensor (self);
354
+ auto y = dbl::comm::try_gen_dil_tensor (other);
355
+ dil::tensor z;
356
+
357
+ dil::binary::compute (x, y, z, dil::algorithm::binary_mul);
347
358
348
- return dil_mul_out (result, self, other );
359
+ return dbl::comm::gen_aten_tensor_by ( std::move (z) );
349
360
}
350
361
351
362
at::Tensor& AtenIpexCPUDev::dil_mul_ (at::Tensor& self, const at::Tensor& other) {
352
363
DEBUG (" AtenIpexCPUDev::dil_mul_\n " );
353
364
354
- dbl::comm::reorder_to_bf16_for_mix_prec (self);
355
- dbl::comm::reorder_to_bf16_for_mix_prec (other);
356
-
357
365
return dil_mul_out (self, self, other);
358
366
}
359
367
@@ -472,7 +480,7 @@ at::Tensor& AtenIpexCPUDev::dil_baddbmm_out(
472
480
result.resize_ (inferred_size);
473
481
}
474
482
TORCH_CHECK (self.sizes ().equals (inferred_size),
475
- " baddbmm not support broadcast yet" );
483
+ " dil_baddbmm not support broadcast yet" );
476
484
477
485
dbl::comm::reorder_to_bf16_for_mix_prec (result);
478
486
dbl::comm::reorder_to_bf16_for_mix_prec (self);
@@ -541,7 +549,7 @@ at::Tensor& AtenIpexCPUDev::dil_addmm_out(
541
549
result.resize_ (inferred_size);
542
550
}
543
551
TORCH_CHECK (self.sizes ().equals (inferred_size),
544
- " addmm not support broadcast yet" );
552
+ " dil_addmm not support broadcast yet" );
545
553
546
554
dbl::comm::reorder_to_bf16_for_mix_prec (result);
547
555
dbl::comm::reorder_to_bf16_for_mix_prec (self);
@@ -610,7 +618,7 @@ at::Tensor& AtenIpexCPUDev::dil_addbmm_out(
610
618
result.resize_ (inferred_size);
611
619
}
612
620
TORCH_CHECK (self.sizes ().equals (inferred_size),
613
- " addbmm not support broadcast yet" );
621
+ " dil_addbmm not support broadcast yet" );
614
622
615
623
dbl::comm::reorder_to_bf16_for_mix_prec (result);
616
624
dbl::comm::reorder_to_bf16_for_mix_prec (self);
0 commit comments