
Commit a51b709: Merge pull request #17 from EikanWang/master
2 parents 04cf2f7 + 313dc87

Refine assert IPEX in case of performance penalty

Summary: this merge replaces TORCH_INTERNAL_ASSERT with TORCH_INTERNAL_ASSERT_DEBUG_ONLY across the IPEX CPU paths, removes the deep-copy bridge functions, and adds -DNDEBUG to release builds, so the internal consistency checks are compiled out of release binaries instead of being evaluated on every call.

14 files changed: +183 additions, -285 deletions

cmake/CPU.cmake

Lines changed: 1 addition & 1 deletion
@@ -27,7 +27,7 @@ IF(CMAKE_BUILD_TYPE MATCHES Debug)
   set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -O0 -g -D_DEBUG")
 ELSE()
   message("Release build.")
-  set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -O2")
+  set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -O2 -DNDEBUG")
 ENDIF()
 
 # ---[ Build flags
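Defining NDEBUG in release builds is what gives the new debug-only asserts their zero cost: TORCH_INTERNAL_ASSERT_DEBUG_ONLY is gated on NDEBUG and compiles to nothing when it is set. A minimal standalone sketch of that gating pattern (simplified; the real macro lives in PyTorch's c10/util/Exception.h):

#include <cstdio>
#include <cstdlib>

// Debug-only assert: a full check in debug builds, a no-op when the
// build defines NDEBUG (as the Release branch above now does).
#ifdef NDEBUG
  #define ASSERT_DEBUG_ONLY(cond) ((void)0)
#else
  #define ASSERT_DEBUG_ONLY(cond)                           \
    do {                                                    \
      if (!(cond)) {                                        \
        std::fprintf(stderr, "assert failed: %s\n", #cond); \
        std::abort();                                       \
      }                                                     \
    } while (0)
#endif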

scripts/cpu/gen-dense-cpu-ops.py

Lines changed: 4 additions & 4 deletions
@@ -304,7 +304,7 @@ def is_out_func(fname):
       param_seq_str = param_var
       if param_var in dnnl_tensor_param_vars:
         if param_var == 'out' and is_out_func(fname):
-          code += '  TORCH_INTERNAL_ASSERT({}.is_contiguous());\n'.format(param_var)
+          code += '  TORCH_INTERNAL_ASSERT_DEBUG_ONLY({}.is_contiguous());\n'.format(param_var)
       else:
         param_seq_str = '{}.is_contiguous() ? {} : {}.contiguous()'.format(param_var, param_var, param_var)
       param_seq_str_vec.append(param_seq_str)
@@ -334,10 +334,10 @@ def gen_fallback_prepare_code(self, cpp_sig):
         ipex_name = '_ipex_{}'.format(param.name)
         param.ipex_name = ipex_name
         check_cond = '{}.device().type() == at::DeviceType::DPCPP'.format(param.name)
-        op_check_code += '  TORCH_INTERNAL_ASSERT({});\n'.format(check_cond)
+        op_check_code += '  TORCH_INTERNAL_ASSERT_DEBUG_ONLY({});\n'.format(check_cond)
         code += '  at::TensorOptions {} = {}.device(at::DeviceType::CPU);\n'.format(ipex_name, param.name)
       elif param.core_type == 'Storage':
-        code += '  TORCH_INTERNAL_ASSERT({}.device_type() == c10::DeviceType::DPCPP);\n'.format(param.name)
+        code += '  TORCH_INTERNAL_ASSERT_DEBUG_ONLY({}.device_type() == c10::DeviceType::DPCPP);\n'.format(param.name)
       elif param.core_type == 'MemoryFormat':
         if param.is_optional:
           check_cond = '{}.value_or(c10::MemoryFormat::Contiguous) != c10::MemoryFormat::Contiguous'.format(param.name)
@@ -352,7 +352,7 @@ def gen_fallback_prepare_code(self, cpp_sig):
         assert param.core_type == 'Tensor'
         ipex_name = '_ipex_{}'.format(param.name)
         check_cond = '{}.layout() == c10::kStrided'.format(param.name)
-        op_check_code += '  TORCH_INTERNAL_ASSERT({});\n'.format(check_cond)
+        op_check_code += '  TORCH_INTERNAL_ASSERT_DEBUG_ONLY({});\n'.format(check_cond)
         code += '  auto&& {} = bridge::{}({});\n'.format(ipex_name, _SHALLOW_FALLBACK_TO_CPU_TENSOR, param.name)
         param.ipex_name = ipex_name
     return op_check_code + code
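These generator changes propagate into every emitted fallback wrapper. A hedged sketch of what the generated C++ might look like after this change (the operator name and wrapper class are illustrative; only the assert and bridge calls mirror the format strings above):

// Illustrative codegen output; AtenIpexCPUDefault::some_op is hypothetical.
at::Tensor AtenIpexCPUDefault::some_op(const at::Tensor & self) {
  // Device and layout invariants are now checked in debug builds only.
  TORCH_INTERNAL_ASSERT_DEBUG_ONLY(self.device().type() == at::DeviceType::DPCPP);
  TORCH_INTERNAL_ASSERT_DEBUG_ONLY(self.layout() == c10::kStrided);
  auto&& _ipex_self = bridge::shallowFallbackToCPUTensor(self);
  auto&& _ipex_result = at::some_op(_ipex_self);
  return bridge::shallowUpgradeToDPCPPTensor(_ipex_result);
}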

scripts/cpu/gen-sparse-cpu-ops.py

Lines changed: 2 additions & 2 deletions
@@ -260,10 +260,10 @@ def gen_fallback_prepare_code(self, cpp_sig):
         ipex_name = '_ipex_{}'.format(param.name)
         param.ipex_name = ipex_name
         check_cond = '{}.device().type() == at::DeviceType::DPCPP'.format(param.name)
-        op_check_code += '  TORCH_INTERNAL_ASSERT({});\n'.format(check_cond)
+        op_check_code += '  TORCH_INTERNAL_ASSERT_DEBUG_ONLY({});\n'.format(check_cond)
         code += '  at::TensorOptions {} = {}.device(at::DeviceType::CPU);\n'.format(ipex_name, param.name)
       elif param.core_type == 'Storage':
-        code += '  TORCH_INTERNAL_ASSERT({}.device_type() == c10::DeviceType::DPCPP);\n'.format(param.name)
+        code += '  TORCH_INTERNAL_ASSERT_DEBUG_ONLY({}.device_type() == c10::DeviceType::DPCPP);\n'.format(param.name)
       elif param.core_type == 'MemoryFormat':
         None
       elif param.core_type != 'Tensor':

torch_ipex/csrc/aten_ipex_bridge.cpp

Lines changed: 76 additions & 175 deletions
Large diffs are not rendered by default.

torch_ipex/csrc/aten_ipex_bridge.h

Lines changed: 0 additions & 4 deletions
@@ -10,9 +10,7 @@ namespace torch_ipex {
 namespace bridge {
 
 // Convert DPCPP tensor to CPU tensor
-at::Tensor fallbackToCPUTensor(const at::Tensor& ipexTensor);
 at::Tensor shallowFallbackToCPUTensor(const at::Tensor& ipexTensor);
-std::vector<at::Tensor> fallbackToCPUTensorList(const at::TensorList&);
 std::vector<at::Tensor> shallowFallbackToCPUTensorList(const at::TensorList&);
 
 void attachShadeDataConext(const at::Tensor& tensor);
@@ -51,9 +49,7 @@ void reorderTensorToScalarTypeForDNNL(const at::Tensor& ipexTensor, at::ScalarTy
 void reorderTensorToScalaraType(const at::Tensor& ipexTensor, at::ScalarType dstScalarType);
 
 // Convert CPU tensor to DPCPP tensor
-at::Tensor upgradeToDPCPPTensor(const at::Tensor& ipexTensor);
 at::Tensor shallowUpgradeToDPCPPTensor(const at::Tensor& ipexTensor);
-std::vector<at::Tensor> upgradeToDPCPPTensorVec(const std::vector<at::Tensor> &);
 std::vector<at::Tensor> shallowUpgradeToDPCPPTensorVec(const std::vector<at::Tensor> &);
 
 // The last character A means alias. This function is for aten alias
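The deleted declarations are the deep-copy conversion variants; only the shallow, alias-preserving bridge functions remain in the public header. A hedged usage sketch, assuming (as the naming suggests) that the shallow helpers re-wrap the underlying storage rather than copy it; the include path is taken from this diff and the CPU op stands in for any ATen call:

#include <ATen/ATen.h>
#include "torch_ipex/csrc/aten_ipex_bridge.h"

// Route a DPCPP tensor through a CPU ATen op and back, without deep copies.
at::Tensor run_on_cpu(const at::Tensor& dpcpp_tensor) {
  auto cpu_view = torch_ipex::bridge::shallowFallbackToCPUTensor(dpcpp_tensor);
  auto cpu_result = cpu_view.abs();  // any CPU-side ATen op stands in here
  return torch_ipex::bridge::shallowUpgradeToDPCPPTensor(cpu_result);
}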

torch_ipex/csrc/cpu/DevOPs.cpp

Lines changed: 18 additions & 17 deletions
@@ -28,10 +28,10 @@ namespace cpu {
 #define DEBUG(fmt)
 #endif
 
-#define CHECK_DNNL_OP_PRE_COND(tensor) \
-  TORCH_INTERNAL_ASSERT(tensor.defined()); \
-  TORCH_INTERNAL_ASSERT(tensor.is_contiguous()); \
-  TORCH_INTERNAL_ASSERT(tensor.layout() == c10::kStrided)
+#define CHECK_DNNL_OP_PRE_COND(tensor)                      \
+  TORCH_INTERNAL_ASSERT_DEBUG_ONLY(tensor.defined());       \
+  TORCH_INTERNAL_ASSERT_DEBUG_ONLY(tensor.is_contiguous()); \
+  TORCH_INTERNAL_ASSERT_DEBUG_ONLY(tensor.layout() == c10::kStrided)
 
 at::Tensor AtenIpexCPUDev::dil_convolution(
     const at::Tensor & input,
@@ -41,6 +41,7 @@ at::Tensor AtenIpexCPUDev::dil_convolution(
     at::IntArrayRef padding,
     at::IntArrayRef dilation,
     int64_t groups) {
+  TORCH_INTERNAL_ASSERT_DEBUG_ONLY(false);
   DEBUG("AtenIpexCPUDev::dil_convolution\n");
   dil::tensor dil_input;
   dil::tensor dil_weight;
@@ -175,18 +176,18 @@ at::Tensor AtenIpexCPUDev::dil_convolution_overrideable(const at::Tensor & input
 
 at::Tensor AtenIpexCPUDev::mkldnn_convolution(const at::Tensor & self, const at::Tensor & weight, const at::Tensor & bias, at::IntArrayRef padding, at::IntArrayRef stride, at::IntArrayRef dilation, int64_t groups) {
   DEBUG("AtenIpexCPUDev::mkldnn_convolution\n");
-  TORCH_INTERNAL_ASSERT(self.defined());
-  TORCH_INTERNAL_ASSERT(weight.defined());
-  TORCH_INTERNAL_ASSERT(self.layout() == c10::kStrided);
-  TORCH_INTERNAL_ASSERT(weight.layout() == c10::kStrided);
-  TORCH_INTERNAL_ASSERT(!(bias.defined()) || (bias.defined() && bias.layout() == c10::kStrided));
+  TORCH_INTERNAL_ASSERT_DEBUG_ONLY(self.defined());
+  TORCH_INTERNAL_ASSERT_DEBUG_ONLY(weight.defined());
+  TORCH_INTERNAL_ASSERT_DEBUG_ONLY(self.layout() == c10::kStrided);
+  TORCH_INTERNAL_ASSERT_DEBUG_ONLY(weight.layout() == c10::kStrided);
+  TORCH_INTERNAL_ASSERT_DEBUG_ONLY(!(bias.defined()) || (bias.defined() && bias.layout() == c10::kStrided));
   auto&& _ipex_self = bridge::shallowFallbackToCPUTensor(self);
   auto&& _ipex_weight = bridge::shallowFallbackToCPUTensor(weight);
   auto&& _ipex_bias = bridge::shallowFallbackToCPUTensor(bias);
   auto&& _ipex_result = at::mkldnn_convolution(_ipex_self.contiguous(), _ipex_weight.contiguous(), _ipex_bias.contiguous(), padding, stride, dilation, groups);
   static_cast<void>(_ipex_result); // Avoid warnings in case not used
-  TORCH_INTERNAL_ASSERT(_ipex_result.is_contiguous());
-  TORCH_INTERNAL_ASSERT(_ipex_result.layout() == c10::kStrided);
+  TORCH_INTERNAL_ASSERT_DEBUG_ONLY(_ipex_result.is_contiguous());
+  TORCH_INTERNAL_ASSERT_DEBUG_ONLY(_ipex_result.layout() == c10::kStrided);
   return bridge::shallowUpgradeToDPCPPTensor(_ipex_result);
 }
 
@@ -210,12 +211,12 @@ std::tuple<at::Tensor,at::Tensor,at::Tensor> AtenIpexCPUDev::dil_convolution_bac
 
 std::tuple<at::Tensor,at::Tensor,at::Tensor> AtenIpexCPUDev::mkldnn_convolution_backward(const at::Tensor & self, const at::Tensor & grad_output, const at::Tensor & weight, at::IntArrayRef padding, at::IntArrayRef stride, at::IntArrayRef dilation, int64_t groups, std::array<bool,3> output_mask) {
   DEBUG("AtenIpexCPUDev::mkldnn_convolution_backward\n");
-  TORCH_INTERNAL_ASSERT(self.defined());
-  TORCH_INTERNAL_ASSERT(grad_output.defined());
-  TORCH_INTERNAL_ASSERT(weight.defined());
-  TORCH_INTERNAL_ASSERT(self.layout() == c10::kStrided);
-  TORCH_INTERNAL_ASSERT(grad_output.layout() == c10::kStrided);
-  TORCH_INTERNAL_ASSERT(weight.layout() == c10::kStrided);
+  TORCH_INTERNAL_ASSERT_DEBUG_ONLY(self.defined());
+  TORCH_INTERNAL_ASSERT_DEBUG_ONLY(grad_output.defined());
+  TORCH_INTERNAL_ASSERT_DEBUG_ONLY(weight.defined());
+  TORCH_INTERNAL_ASSERT_DEBUG_ONLY(self.layout() == c10::kStrided);
+  TORCH_INTERNAL_ASSERT_DEBUG_ONLY(grad_output.layout() == c10::kStrided);
+  TORCH_INTERNAL_ASSERT_DEBUG_ONLY(weight.layout() == c10::kStrided);
   auto&& _ipex_self = bridge::shallowFallbackToCPUTensor(self);
   auto&& _ipex_grad_output = bridge::shallowFallbackToCPUTensor(grad_output);
   auto&& _ipex_weight = bridge::shallowFallbackToCPUTensor(weight);
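Note the one behavioral addition in this file: dil_convolution now opens with TORCH_INTERNAL_ASSERT_DEBUG_ONLY(false), a debug-only trap that fires if this entry point is ever reached in a debug build while costing nothing in release builds. A minimal sketch of the same pattern using the standard assert macro:

#include <cassert>

// Debug-only trap: flags an execution path believed to be unreachable.
// Under -DNDEBUG (release) the assert disappears entirely.
int dispatch(int path) {
  switch (path) {
    case 0: return 1;
    case 1: return 2;
    default:
      assert(false && "unexpected dispatch path");
      return -1;
  }
}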

torch_ipex/csrc/cpu/ExtendOPs.cpp

Lines changed: 35 additions & 35 deletions
@@ -13,19 +13,19 @@
 namespace torch_ipex {
 
 void AtenIpexTypeExt::packed_add_(at::Tensor & top_half, at::Tensor & bot_half, const at::Tensor & grad, float alpha) {
-  TORCH_INTERNAL_ASSERT(grad.scalar_type() == at::ScalarType::BFloat16);
-  TORCH_INTERNAL_ASSERT(top_half.scalar_type() == at::ScalarType::BFloat16);
-  TORCH_INTERNAL_ASSERT(bot_half.scalar_type() == at::ScalarType::BFloat16);
-  TORCH_INTERNAL_ASSERT(grad.device().type() == at::DeviceType::DPCPP);
-  TORCH_INTERNAL_ASSERT(top_half.device().type() == at::DeviceType::DPCPP);
-  TORCH_INTERNAL_ASSERT(bot_half.device().type() == at::DeviceType::DPCPP);
-  TORCH_INTERNAL_ASSERT(top_half.sizes() == bot_half.sizes());
-  TORCH_INTERNAL_ASSERT(top_half.is_contiguous());
-  TORCH_INTERNAL_ASSERT(bot_half.is_contiguous());
+  TORCH_INTERNAL_ASSERT_DEBUG_ONLY(grad.scalar_type() == at::ScalarType::BFloat16);
+  TORCH_INTERNAL_ASSERT_DEBUG_ONLY(top_half.scalar_type() == at::ScalarType::BFloat16);
+  TORCH_INTERNAL_ASSERT_DEBUG_ONLY(bot_half.scalar_type() == at::ScalarType::BFloat16);
+  TORCH_INTERNAL_ASSERT_DEBUG_ONLY(grad.device().type() == at::DeviceType::DPCPP);
+  TORCH_INTERNAL_ASSERT_DEBUG_ONLY(top_half.device().type() == at::DeviceType::DPCPP);
+  TORCH_INTERNAL_ASSERT_DEBUG_ONLY(bot_half.device().type() == at::DeviceType::DPCPP);
+  TORCH_INTERNAL_ASSERT_DEBUG_ONLY(top_half.sizes() == bot_half.sizes());
+  TORCH_INTERNAL_ASSERT_DEBUG_ONLY(top_half.is_contiguous());
+  TORCH_INTERNAL_ASSERT_DEBUG_ONLY(bot_half.is_contiguous());
 
   RECORD_FUNCTION("packed_add_", std::vector<c10::IValue>({top_half, bot_half, grad, alpha}), torch::autograd::Node::peek_at_next_sequence_nr());
   if (grad.is_sparse()) {
-    TORCH_INTERNAL_ASSERT(top_half.dim() == 2);
+    TORCH_INTERNAL_ASSERT_DEBUG_ONLY(top_half.dim() == 2);
     auto sparse_nnz = grad._nnz();
     auto sparse_dim = grad.sparse_dim();
     auto values = grad._values();
@@ -34,14 +34,14 @@ void AtenIpexTypeExt::packed_add_(at::Tensor & top_half, at::Tensor & bot_half,
     auto feature_size = values.stride(0);
     auto indices_accessor = indices.accessor<int64_t, 2>();
 
-    TORCH_INTERNAL_ASSERT(values.is_contiguous());
+    TORCH_INTERNAL_ASSERT_DEBUG_ONLY(values.is_contiguous());
     auto value_ptr = values.data_ptr<at::BFloat16>();
     auto top_half_ptr = top_half.data_ptr<at::BFloat16>();
     auto bot_half_ptr = bot_half.data_ptr<at::BFloat16>();
 
-    TORCH_INTERNAL_ASSERT(value_ptr != nullptr);
-    TORCH_INTERNAL_ASSERT(top_half_ptr != nullptr);
-    TORCH_INTERNAL_ASSERT(bot_half_ptr != nullptr);
+    TORCH_INTERNAL_ASSERT_DEBUG_ONLY(value_ptr != nullptr);
+    TORCH_INTERNAL_ASSERT_DEBUG_ONLY(top_half_ptr != nullptr);
+    TORCH_INTERNAL_ASSERT_DEBUG_ONLY(bot_half_ptr != nullptr);
 
     std::vector<int64_t> sparse_stride(sparse_dim);
     for (int64_t d = 0; d < sparse_dim; d++) {
@@ -80,7 +80,7 @@ void AtenIpexTypeExt::packed_add_(at::Tensor & top_half, at::Tensor & bot_half,
       }
     });
   } else {
-    TORCH_INTERNAL_ASSERT(grad.is_contiguous());
+    TORCH_INTERNAL_ASSERT_DEBUG_ONLY(grad.is_contiguous());
     //TODO: vector implementation basing on vector size
     union packed_bf16 {
       unsigned short s[2];
@@ -201,15 +201,15 @@ inline at::Tensor _interaction_forward(const std::vector<at::Tensor> & input) {
   std::vector<uint32_t> feature_sizes(input.size());
   std::vector<T *> input_data(input.size());
   for (int i = 0; i < input.size(); i++) {
-    TORCH_INTERNAL_ASSERT(input[i].is_contiguous());
-    TORCH_INTERNAL_ASSERT(input[i].device().is_dpcpp());
-    TORCH_INTERNAL_ASSERT(input[i].dim() == 2);
+    TORCH_INTERNAL_ASSERT_DEBUG_ONLY(input[i].is_contiguous());
+    TORCH_INTERNAL_ASSERT_DEBUG_ONLY(input[i].device().is_dpcpp());
+    TORCH_INTERNAL_ASSERT_DEBUG_ONLY(input[i].dim() == 2);
     feature_sizes[i] = input[i].sizes()[1];
     total_feature_size += input[i].sizes()[1];
     input_data[i] = input[i].data_ptr<T>();
   }
   auto vector_nums = total_feature_size / vector_size;
-  TORCH_INTERNAL_ASSERT(total_feature_size % vector_size == 0);
+  TORCH_INTERNAL_ASSERT_DEBUG_ONLY(total_feature_size % vector_size == 0);
   auto interact_feature_size = vector_nums * (vector_nums - 1) / 2;
   auto tr_vector_size = sizeof(T) == 4 ? vector_size : vector_size / 2;
   auto out = at::empty({batch_size, interact_feature_size + vector_size}, input[0].options());
@@ -239,7 +239,7 @@ inline at::Tensor _interaction_forward(const std::vector<at::Tensor> & input) {
 
 template<typename T>
 inline std::vector<at::Tensor> _interaction_backward(const at::Tensor & grad_out, const std::vector<at::Tensor> & input) {
-  TORCH_INTERNAL_ASSERT(grad_out.is_contiguous());
+  TORCH_INTERNAL_ASSERT_DEBUG_ONLY(grad_out.is_contiguous());
   RECORD_FUNCTION("_interaction_backward", std::vector<c10::IValue>({grad_out, input}), torch::autograd::Node::peek_at_next_sequence_nr());
   uint32_t total_feature_size = 0;
   int64_t batch_size = input[0].sizes()[0];
@@ -257,7 +257,7 @@ inline std::vector<at::Tensor> _interaction_backward(const at::Tensor & grad_out
     output_data[i] = output[i].data_ptr<T>();
   }
   auto vector_nums = total_feature_size / vector_size;
-  TORCH_INTERNAL_ASSERT(total_feature_size % vector_size == 0);
+  TORCH_INTERNAL_ASSERT_DEBUG_ONLY(total_feature_size % vector_size == 0);
   auto interact_feature_size = vector_nums * (vector_nums - 1) / 2;
   auto grad_out_data = grad_out.data_ptr<T>();
 
@@ -305,11 +305,11 @@ inline std::vector<at::Tensor> _interaction_backward(const at::Tensor & grad_out
 
 at::Tensor AtenIpexTypeExt::interaction_forward(const std::vector<at::Tensor> & input) {
   if (input[0].scalar_type() == at::kFloat) {
-    for (const auto &in : input) { TORCH_INTERNAL_ASSERT(in.scalar_type() == at::kFloat); }
+    for (const auto &in : input) { TORCH_INTERNAL_ASSERT_DEBUG_ONLY(in.scalar_type() == at::kFloat); }
     return _interaction_forward<float>(input);
   } else {
-    TORCH_INTERNAL_ASSERT(input[0].scalar_type() == at::kBFloat16);
-    for (const auto &in : input) { TORCH_INTERNAL_ASSERT(in.scalar_type() == at::kBFloat16); }
+    TORCH_INTERNAL_ASSERT_DEBUG_ONLY(input[0].scalar_type() == at::kBFloat16);
+    for (const auto &in : input) { TORCH_INTERNAL_ASSERT_DEBUG_ONLY(in.scalar_type() == at::kBFloat16); }
     return _interaction_forward<at::BFloat16>(input);
   }
 }
@@ -318,18 +318,18 @@ std::vector<at::Tensor> AtenIpexTypeExt::interaction_backward(const at::Tensor &
   if (grad_out.scalar_type() == at::kFloat) {
     return _interaction_backward<float>(grad_out, input);
   } else {
-    TORCH_INTERNAL_ASSERT(grad_out.scalar_type() == at::kBFloat16);
+    TORCH_INTERNAL_ASSERT_DEBUG_ONLY(grad_out.scalar_type() == at::kBFloat16);
     return _interaction_backward<at::BFloat16>(grad_out, input);
   }
 }
 
 template<typename T>
 static inline at::Tensor _embedding_bag_forward(const at::Tensor &weights, const at::Tensor &inputs, const at::Tensor &offsets) {
-  TORCH_INTERNAL_ASSERT(weights.is_contiguous());
-  TORCH_INTERNAL_ASSERT(inputs.is_contiguous());
-  TORCH_INTERNAL_ASSERT(offsets.is_contiguous());
-  TORCH_INTERNAL_ASSERT(inputs.dim() == 1);
-  TORCH_INTERNAL_ASSERT(weights.dim() == 2);
+  TORCH_INTERNAL_ASSERT_DEBUG_ONLY(weights.is_contiguous());
+  TORCH_INTERNAL_ASSERT_DEBUG_ONLY(inputs.is_contiguous());
+  TORCH_INTERNAL_ASSERT_DEBUG_ONLY(offsets.is_contiguous());
+  TORCH_INTERNAL_ASSERT_DEBUG_ONLY(inputs.dim() == 1);
+  TORCH_INTERNAL_ASSERT_DEBUG_ONLY(weights.dim() == 2);
   RECORD_FUNCTION("_embedding_bag_forward", std::vector<c10::IValue>({weights, inputs, offsets}), torch::autograd::Node::peek_at_next_sequence_nr());
   auto batch_size = offsets.size(0);
   auto num_input = inputs.size(0);
@@ -345,7 +345,7 @@ static inline at::Tensor _embedding_bag_forward(const at::Tensor &weights, const
     auto inputs_start = offsets_data[i];
     auto inputs_end = (i < batch_size - 1) ? offsets_data[i + 1] : num_input;
     // TODO: add acc_t support for bag size larger than 1
-    TORCH_INTERNAL_ASSERT(inputs_end - inputs_start == 1);
+    TORCH_INTERNAL_ASSERT_DEBUG_ONLY(inputs_end - inputs_start == 1);
     auto out_data_ptr = &output_data[i * vector_size];
 #pragma omp simd
     for (int64_t v = 0; v < vector_size; v++) out_data_ptr[v] = 0.0;
@@ -361,8 +361,8 @@ static inline at::Tensor _embedding_bag_forward(const at::Tensor &weights, const
 template<typename T>
 static inline at::Tensor _embedding_bag_backward(const at::Tensor &grad_out,
     const at::Tensor &weights, const at::Tensor &inputs, const at::Tensor offsets) {
-  TORCH_INTERNAL_ASSERT(inputs.dim() == 1);
-  TORCH_INTERNAL_ASSERT(grad_out.dim() == 2);
+  TORCH_INTERNAL_ASSERT_DEBUG_ONLY(inputs.dim() == 1);
+  TORCH_INTERNAL_ASSERT_DEBUG_ONLY(grad_out.dim() == 2);
   RECORD_FUNCTION("_embedding_bag_backward", std::vector<c10::IValue>({grad_out, weights, inputs, offsets}), torch::autograd::Node::peek_at_next_sequence_nr());
   auto batch_size = offsets.size(0);
   auto num_input = inputs.size(0);
@@ -408,7 +408,7 @@ at::Tensor AtenIpexTypeExt::embedding_bag_forward(const at::Tensor &weights, con
   if (weights.scalar_type() == at::kFloat) {
     return _embedding_bag_forward<float>(weights, inputs, offsets);
   } else {
-    TORCH_INTERNAL_ASSERT(weights.scalar_type() == at::kBFloat16);
+    TORCH_INTERNAL_ASSERT_DEBUG_ONLY(weights.scalar_type() == at::kBFloat16);
     return _embedding_bag_forward<at::BFloat16>(weights, inputs, offsets);
   }
 }
@@ -418,7 +418,7 @@ at::Tensor AtenIpexTypeExt::embedding_bag_backward(const at::Tensor &grad_out,
   if (grad_out.scalar_type() == at::kFloat) {
     return _embedding_bag_backward<float>(grad_out, weights, inputs, offsets);
   } else {
-    TORCH_INTERNAL_ASSERT(grad_out.scalar_type() == at::kBFloat16);
+    TORCH_INTERNAL_ASSERT_DEBUG_ONLY(grad_out.scalar_type() == at::kBFloat16);
     return _embedding_bag_backward<at::BFloat16>(grad_out, weights, inputs, offsets);
   }
 }
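The pattern throughout this file is that internal invariants (contiguity, dtype, device) become free in release builds, so any user-facing validation must happen before these kernels are reached. A hedged sketch of the resulting division of labor, pairing the always-on TORCH_CHECK from c10 with the debug-only assert (the function itself is illustrative, not from this commit):

#include <ATen/ATen.h>
#include <c10/util/Exception.h>

// Illustrative only: user-facing argument errors stay on TORCH_CHECK,
// which survives release builds; internal invariants use the
// debug-only assert and vanish under NDEBUG.
void scaled_add_(at::Tensor& out, const at::Tensor& src, float alpha) {
  TORCH_CHECK(out.sizes() == src.sizes(), "scaled_add_: size mismatch");
  TORCH_INTERNAL_ASSERT_DEBUG_ONLY(out.is_contiguous());
  TORCH_INTERNAL_ASSERT_DEBUG_ONLY(src.is_contiguous());
  out.add_(src, alpha);
}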
