Skip to content

Commit 89c4149

Browse files
committed
Throw an exception on unbroadcastable input for binary ops
1 parent b0e83a0 commit 89c4149

File tree

1 file changed

+40
-32
lines changed

1 file changed

+40
-32
lines changed

torch_ipex/csrc/cpu/DevOPs.cpp

Lines changed: 40 additions & 32 deletions
Original file line number | Diff line number | Diff line change
@@ -263,16 +263,22 @@ at::Tensor& AtenIpexCPUDev::dil_add_out(
263263
CHECK_DNNL_OP_PRE_COND(self);
264264
CHECK_DNNL_OP_PRE_COND(other);
265265

266+
TORCH_CHECK(self.sizes().equals(other.sizes()),
267+
"dil_add not support broadcast yet");
268+
auto inferred_size = self.sizes();
269+
if (!result.sizes().equals(inferred_size)) {
270+
result.resize_(inferred_size);
271+
}
272+
266273
dbl::comm::reorder_to_bf16_for_mix_prec(self);
267274
dbl::comm::reorder_to_bf16_for_mix_prec(other);
268275
dbl::comm::reorder_to_bf16_for_mix_prec(result);
269276

270-
dil::tensor x = dbl::comm::try_gen_dil_tensor(self);
271-
dil::tensor y = dbl::comm::try_gen_dil_tensor(other);
277+
auto x = dbl::comm::try_gen_dil_tensor(self);
278+
auto y = dbl::comm::try_gen_dil_tensor(other);
279+
auto z = dbl::comm::try_gen_dil_tensor(result);
272280

273-
dil::tensor z = dbl::comm::try_gen_dil_tensor(result);
274-
const std::vector<float> scales{1.0, alpha.to<float>()};
275-
dil::sum::compute(scales, {x, y}, z);
281+
dil::sum::compute({1.0, alpha.to<float>()}, {x, y}, z);
276282

277283
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(z.is_public_format() || check_tensor_own_whole_storage(result));
278284
dbl::comm::sync_shape_from_dil_to_aten(result, z);
@@ -284,36 +290,25 @@ at::Tensor AtenIpexCPUDev::dil_add(const at::Tensor& self, const at::Tensor& oth
284290
CHECK_DNNL_OP_PRE_COND(self);
285291
CHECK_DNNL_OP_PRE_COND(other);
286292

293+
TORCH_CHECK(self.sizes().equals(other.sizes()),
294+
"dil_add not support broadcast yet");
295+
287296
dbl::comm::reorder_to_bf16_for_mix_prec(self);
288297
dbl::comm::reorder_to_bf16_for_mix_prec(other);
289298

290-
dil::tensor x = dbl::comm::try_gen_dil_tensor(self);
291-
dil::tensor y = dbl::comm::try_gen_dil_tensor(other);
292-
299+
auto x = dbl::comm::try_gen_dil_tensor(self);
300+
auto y = dbl::comm::try_gen_dil_tensor(other);
293301
dil::tensor z;
294-
const std::vector<float> scales{1.0, alpha.to<float>()};
295-
dil::sum::compute(scales, {x, y}, z);
302+
303+
dil::sum::compute({1.0, alpha.to<float>()}, {x, y}, z);
296304

297305
return dbl::comm::gen_aten_tensor_by(std::move(z));
298306
}
299307

300308
at::Tensor & AtenIpexCPUDev::dil_add_(at::Tensor& self, const at::Tensor& other, at::Scalar alpha) {
301309
DEBUG("AtenIpexCPUDev::dil_add_\n");
302-
CHECK_DNNL_OP_PRE_COND(self);
303-
CHECK_DNNL_OP_PRE_COND(other);
304310

305-
dbl::comm::reorder_to_bf16_for_mix_prec(self);
306-
dbl::comm::reorder_to_bf16_for_mix_prec(other);
307-
308-
auto dil_self = dbl::comm::try_gen_dil_tensor(self);
309-
auto dil_other = dbl::comm::try_gen_dil_tensor(other);
310-
311-
const std::vector<float> scales{1.0, alpha.to<float>()};
312-
dil::sum::compute(scales, {dil_self, dil_other}, dil_self);
313-
314-
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(dil_self.is_public_format() || check_tensor_own_whole_storage(self));
315-
dbl::comm::sync_shape_from_dil_to_aten(self, dil_self);
316-
return self;
311+
return dil_add_out(self, self, other, alpha);
317312
}
318313

319314
at::Tensor& AtenIpexCPUDev::dil_mul_out(at::Tensor& result, const at::Tensor& self, const at::Tensor& other) {
@@ -322,6 +317,13 @@ at::Tensor& AtenIpexCPUDev::dil_mul_out(at::Tensor& result, const at::Tensor& se
322317
CHECK_DNNL_OP_PRE_COND(self);
323318
CHECK_DNNL_OP_PRE_COND(other);
324319

320+
TORCH_CHECK(self.sizes().equals(other.sizes()),
321+
"dil_mul not support broadcast yet");
322+
auto inferred_size = self.sizes();
323+
if (!result.sizes().equals(inferred_size)) {
324+
result.resize_(inferred_size);
325+
}
326+
325327
dbl::comm::reorder_to_bf16_for_mix_prec(self);
326328
dbl::comm::reorder_to_bf16_for_mix_prec(other);
327329
dbl::comm::reorder_to_bf16_for_mix_prec(result);
@@ -339,21 +341,27 @@ at::Tensor& AtenIpexCPUDev::dil_mul_out(at::Tensor& result, const at::Tensor& se
339341

340342
at::Tensor AtenIpexCPUDev::dil_mul(const at::Tensor& self, const at::Tensor& other) {
341343
DEBUG("AtenIpexCPUDev::dil_mul\n");
344+
CHECK_DNNL_OP_PRE_COND(self);
345+
CHECK_DNNL_OP_PRE_COND(other);
346+
347+
TORCH_CHECK(self.sizes().equals(other.sizes()),
348+
"dil_mul not support broadcast yet");
342349

343350
dbl::comm::reorder_to_bf16_for_mix_prec(self);
344351
dbl::comm::reorder_to_bf16_for_mix_prec(other);
345352

346-
at::Tensor result = dbl::comm::empty_dil_tensor(self.sizes(), self.options());
353+
auto x = dbl::comm::try_gen_dil_tensor(self);
354+
auto y = dbl::comm::try_gen_dil_tensor(other);
355+
dil::tensor z;
356+
357+
dil::binary::compute(x, y, z, dil::algorithm::binary_mul);
347358

348-
return dil_mul_out(result, self, other);
359+
return dbl::comm::gen_aten_tensor_by(std::move(z));
349360
}
350361

351362
at::Tensor& AtenIpexCPUDev::dil_mul_(at::Tensor& self, const at::Tensor& other) {
352363
DEBUG("AtenIpexCPUDev::dil_mul_\n");
353364

354-
dbl::comm::reorder_to_bf16_for_mix_prec(self);
355-
dbl::comm::reorder_to_bf16_for_mix_prec(other);
356-
357365
return dil_mul_out(self, self, other);
358366
}
359367

@@ -472,7 +480,7 @@ at::Tensor& AtenIpexCPUDev::dil_baddbmm_out(
472480
result.resize_(inferred_size);
473481
}
474482
TORCH_CHECK(self.sizes().equals(inferred_size),
475-
"baddbmm not support broadcast yet");
483+
"dil_baddbmm not support broadcast yet");
476484

477485
dbl::comm::reorder_to_bf16_for_mix_prec(result);
478486
dbl::comm::reorder_to_bf16_for_mix_prec(self);
@@ -541,7 +549,7 @@ at::Tensor& AtenIpexCPUDev::dil_addmm_out(
541549
result.resize_(inferred_size);
542550
}
543551
TORCH_CHECK(self.sizes().equals(inferred_size),
544-
"addmm not support broadcast yet");
552+
"dil_addmm not support broadcast yet");
545553

546554
dbl::comm::reorder_to_bf16_for_mix_prec(result);
547555
dbl::comm::reorder_to_bf16_for_mix_prec(self);
@@ -610,7 +618,7 @@ at::Tensor& AtenIpexCPUDev::dil_addbmm_out(
610618
result.resize_(inferred_size);
611619
}
612620
TORCH_CHECK(self.sizes().equals(inferred_size),
613-
"addbmm not support broadcast yet");
621+
"dil_addbmm not support broadcast yet");
614622

615623
dbl::comm::reorder_to_bf16_for_mix_prec(result);
616624
dbl::comm::reorder_to_bf16_for_mix_prec(self);

0 commit comments

Comments
 (0)