
Refine assert IPEX #17

Merged · 2 commits · May 20, 2020

This pull request replaces always-on TORCH_INTERNAL_ASSERT checks with TORCH_INTERNAL_ASSERT_DEBUG_ONLY in the op-generation scripts and CPU fallback paths, defines NDEBUG for release builds so those checks compile out of release binaries, and removes the non-shallow fallbackToCPUTensor/upgradeToDPCPPTensor variants from the bridge.
Changes from all commits
2 changes: 1 addition & 1 deletion cmake/CPU.cmake
@@ -27,7 +27,7 @@ IF(CMAKE_BUILD_TYPE MATCHES Debug)
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -O0 -g -D_DEBUG")
ELSE()
message("Release build.")
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -O2")
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -O2 -DNDEBUG")
ENDIF()

# ---[ Build flags
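Defining NDEBUG for release builds is what makes the debug-only asserts introduced in the rest of this PR drop out of release binaries; the checked condition is then no longer evaluated at run time. Below is a minimal sketch of the mechanism, using a simplified stand-in macro (an illustration, not the real c10 definition of TORCH_INTERNAL_ASSERT_DEBUG_ONLY):

    // Simplified stand-in for a debug-only assert (illustration only).
    // Build with -O2 -DNDEBUG and the check expands to nothing; build
    // without -DNDEBUG and a failing condition aborts the program.
    #include <cstdio>
    #include <cstdlib>

    #ifdef NDEBUG
    #define ASSERT_DEBUG_ONLY(cond) ((void)0)  // release: compiled out
    #else
    #define ASSERT_DEBUG_ONLY(cond)                                   \
      do {                                                            \
        if (!(cond)) {                                                \
          std::fprintf(stderr, "debug assert failed: %s\n", #cond);   \
          std::abort();                                               \
        }                                                             \
      } while (0)
    #endif

    int main() {
      [[maybe_unused]] bool tensor_is_contiguous = false;  // pretend precondition
      ASSERT_DEBUG_ONLY(tensor_is_contiguous);             // checked only in debug builds
      std::printf("release build skips the check\n");
      return 0;
    }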
8 changes: 4 additions & 4 deletions scripts/cpu/gen-dense-cpu-ops.py
@@ -304,7 +304,7 @@ def is_out_func(fname):
param_seq_str = param_var
if param_var in dnnl_tensor_param_vars:
if param_var == 'out' and is_out_func(fname):
code += ' TORCH_INTERNAL_ASSERT({}.is_contiguous());\n'.format(param_var)
code += ' TORCH_INTERNAL_ASSERT_DEBUG_ONLY({}.is_contiguous());\n'.format(param_var)
else:
param_seq_str = '{}.is_contiguous() ? {} : {}.contiguous()'.format(param_var, param_var, param_var)
param_seq_str_vec.append(param_seq_str)
@@ -334,10 +334,10 @@ def gen_fallback_prepare_code(self, cpp_sig):
ipex_name = '_ipex_{}'.format(param.name)
param.ipex_name = ipex_name
check_cond = '{}.device().type() == at::DeviceType::DPCPP'.format(param.name)
op_check_code += ' TORCH_INTERNAL_ASSERT({});\n'.format(check_cond)
op_check_code += ' TORCH_INTERNAL_ASSERT_DEBUG_ONLY({});\n'.format(check_cond)
code += ' at::TensorOptions {} = {}.device(at::DeviceType::CPU);\n'.format(ipex_name, param.name)
elif param.core_type == 'Storage':
code += ' TORCH_INTERNAL_ASSERT({}.device_type() == c10::DeviceType::DPCPP);\n'.format(param.name)
code += ' TORCH_INTERNAL_ASSERT_DEBUG_ONLY({}.device_type() == c10::DeviceType::DPCPP);\n'.format(param.name)
elif param.core_type == 'MemoryFormat':
if param.is_optional:
check_cond = '{}.value_or(c10::MemoryFormat::Contiguous) != c10::MemoryFormat::Contiguous'.format(param.name)
@@ -352,7 +352,7 @@ def gen_fallback_prepare_code(self, cpp_sig):
assert param.core_type == 'Tensor'
ipex_name = '_ipex_{}'.format(param.name)
check_cond = '{}.layout() == c10::kStrided'.format(param.name)
op_check_code += ' TORCH_INTERNAL_ASSERT({});\n'.format(check_cond)
op_check_code += ' TORCH_INTERNAL_ASSERT_DEBUG_ONLY({});\n'.format(check_cond)
code += ' auto&& {} = bridge::{}({});\n'.format(ipex_name, _SHALLOW_FALLBACK_TO_CPU_TENSOR, param.name)
param.ipex_name = ipex_name
return op_check_code + code
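For reference, the C++ these format strings emit changes in the same way. For a tensor parameter named, say, self (a hypothetical name; real names come from the op signature, and the bridge call is the shallowFallbackToCPUTensor declared in aten_ipex_bridge.h), the generated fallback prologue now looks roughly like:

    // Hypothetical generated output (illustration only, not copied from the
    // generator): the layout check becomes a debug-only assert, while the
    // shallow CPU fallback itself is unchanged.
    TORCH_INTERNAL_ASSERT_DEBUG_ONLY(self.layout() == c10::kStrided);
    auto&& _ipex_self = bridge::shallowFallbackToCPUTensor(self);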
4 changes: 2 additions & 2 deletions scripts/cpu/gen-sparse-cpu-ops.py
@@ -260,10 +260,10 @@ def gen_fallback_prepare_code(self, cpp_sig):
ipex_name = '_ipex_{}'.format(param.name)
param.ipex_name = ipex_name
check_cond = '{}.device().type() == at::DeviceType::DPCPP'.format(param.name)
op_check_code += ' TORCH_INTERNAL_ASSERT({});\n'.format(check_cond)
op_check_code += ' TORCH_INTERNAL_ASSERT_DEBUG_ONLY({});\n'.format(check_cond)
code += ' at::TensorOptions {} = {}.device(at::DeviceType::CPU);\n'.format(ipex_name, param.name)
elif param.core_type == 'Storage':
code += ' TORCH_INTERNAL_ASSERT({}.device_type() == c10::DeviceType::DPCPP);\n'.format(param.name)
code += ' TORCH_INTERNAL_ASSERT_DEBUG_ONLY({}.device_type() == c10::DeviceType::DPCPP);\n'.format(param.name)
elif param.core_type == 'MemoryFormat':
None
elif param.core_type != 'Tensor':
251 changes: 76 additions & 175 deletions torch_ipex/csrc/aten_ipex_bridge.cpp

Large diffs are not rendered by default.

4 changes: 0 additions & 4 deletions torch_ipex/csrc/aten_ipex_bridge.h
@@ -10,9 +10,7 @@ namespace torch_ipex {
namespace bridge {

// Convert DPCPP tensor to CPU tensor
at::Tensor fallbackToCPUTensor(const at::Tensor& ipexTensor);
at::Tensor shallowFallbackToCPUTensor(const at::Tensor& ipexTensor);
std::vector<at::Tensor> fallbackToCPUTensorList(const at::TensorList&);
std::vector<at::Tensor> shallowFallbackToCPUTensorList(const at::TensorList&);

void attachShadeDataConext(const at::Tensor& tensor);
@@ -51,9 +49,7 @@ void reorderTensorToScalarTypeForDNNL(const at::Tensor& ipexTensor, at::ScalarTy
void reorderTensorToScalaraType(const at::Tensor& ipexTensor, at::ScalarType dstScalarType);

// Convert CPU tensor to DPCPP tensor
at::Tensor upgradeToDPCPPTensor(const at::Tensor& ipexTensor);
at::Tensor shallowUpgradeToDPCPPTensor(const at::Tensor& ipexTensor);
std::vector<at::Tensor> upgradeToDPCPPTensorVec(const std::vector<at::Tensor> &);
std::vector<at::Tensor> shallowUpgradeToDPCPPTensorVec(const std::vector<at::Tensor> &);

// The last character A means alias. This function is for aten alias
35 changes: 18 additions & 17 deletions torch_ipex/csrc/cpu/DevOPs.cpp
@@ -28,10 +28,10 @@ namespace cpu {
#define DEBUG(fmt)
#endif

#define CHECK_DNNL_OP_PRE_COND(tensor) \
TORCH_INTERNAL_ASSERT(tensor.defined()); \
TORCH_INTERNAL_ASSERT(tensor.is_contiguous()); \
TORCH_INTERNAL_ASSERT(tensor.layout() == c10::kStrided)
#define CHECK_DNNL_OP_PRE_COND(tensor) \
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(tensor.defined()); \
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(tensor.is_contiguous()); \
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(tensor.layout() == c10::kStrided)

at::Tensor AtenIpexCPUDev::dil_convolution(
const at::Tensor & input,
@@ -41,6 +41,7 @@ at::Tensor AtenIpexCPUDev::dil_convolution(
at::IntArrayRef padding,
at::IntArrayRef dilation,
int64_t groups) {
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(false);
DEBUG("AtenIpexCPUDev::dil_convolution\n");
dil::tensor dil_input;
dil::tensor dil_weight;
@@ -175,18 +176,18 @@ at::Tensor AtenIpexCPUDev::dil_convolution_overrideable(const at::Tensor & input

at::Tensor AtenIpexCPUDev::mkldnn_convolution(const at::Tensor & self, const at::Tensor & weight, const at::Tensor & bias, at::IntArrayRef padding, at::IntArrayRef stride, at::IntArrayRef dilation, int64_t groups) {
DEBUG("AtenIpexCPUDev::mkldnn_convolution\n");
TORCH_INTERNAL_ASSERT(self.defined());
TORCH_INTERNAL_ASSERT(weight.defined());
TORCH_INTERNAL_ASSERT(self.layout() == c10::kStrided);
TORCH_INTERNAL_ASSERT(weight.layout() == c10::kStrided);
TORCH_INTERNAL_ASSERT(!(bias.defined()) || (bias.defined() && bias.layout() == c10::kStrided));
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(self.defined());
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(weight.defined());
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(self.layout() == c10::kStrided);
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(weight.layout() == c10::kStrided);
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(!(bias.defined()) || (bias.defined() && bias.layout() == c10::kStrided));
auto&& _ipex_self = bridge::shallowFallbackToCPUTensor(self);
auto&& _ipex_weight = bridge::shallowFallbackToCPUTensor(weight);
auto&& _ipex_bias = bridge::shallowFallbackToCPUTensor(bias);
auto&& _ipex_result = at::mkldnn_convolution(_ipex_self.contiguous(), _ipex_weight.contiguous(), _ipex_bias.contiguous(), padding, stride, dilation, groups);
static_cast<void>(_ipex_result); // Avoid warnings in case not used
TORCH_INTERNAL_ASSERT(_ipex_result.is_contiguous());
TORCH_INTERNAL_ASSERT(_ipex_result.layout() == c10::kStrided);
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(_ipex_result.is_contiguous());
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(_ipex_result.layout() == c10::kStrided);
return bridge::shallowUpgradeToDPCPPTensor(_ipex_result);
}

@@ -210,12 +211,12 @@ std::tuple<at::Tensor,at::Tensor,at::Tensor> AtenIpexCPUDev::dil_convolution_bac

std::tuple<at::Tensor,at::Tensor,at::Tensor> AtenIpexCPUDev::mkldnn_convolution_backward(const at::Tensor & self, const at::Tensor & grad_output, const at::Tensor & weight, at::IntArrayRef padding, at::IntArrayRef stride, at::IntArrayRef dilation, int64_t groups, std::array<bool,3> output_mask) {
DEBUG("AtenIpexCPUDev::mkldnn_convolution_backward\n");
TORCH_INTERNAL_ASSERT(self.defined());
TORCH_INTERNAL_ASSERT(grad_output.defined());
TORCH_INTERNAL_ASSERT(weight.defined());
TORCH_INTERNAL_ASSERT(self.layout() == c10::kStrided);
TORCH_INTERNAL_ASSERT(grad_output.layout() == c10::kStrided);
TORCH_INTERNAL_ASSERT(weight.layout() == c10::kStrided);
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(self.defined());
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(grad_output.defined());
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(weight.defined());
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(self.layout() == c10::kStrided);
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(grad_output.layout() == c10::kStrided);
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(weight.layout() == c10::kStrided);
auto&& _ipex_self = bridge::shallowFallbackToCPUTensor(self);
auto&& _ipex_grad_output = bridge::shallowFallbackToCPUTensor(grad_output);
auto&& _ipex_weight = bridge::shallowFallbackToCPUTensor(weight);
70 changes: 35 additions & 35 deletions torch_ipex/csrc/cpu/ExtendOPs.cpp
@@ -13,19 +13,19 @@
namespace torch_ipex {

void AtenIpexTypeExt::packed_add_(at::Tensor & top_half, at::Tensor & bot_half, const at::Tensor & grad, float alpha) {
TORCH_INTERNAL_ASSERT(grad.scalar_type() == at::ScalarType::BFloat16);
TORCH_INTERNAL_ASSERT(top_half.scalar_type() == at::ScalarType::BFloat16);
TORCH_INTERNAL_ASSERT(bot_half.scalar_type() == at::ScalarType::BFloat16);
TORCH_INTERNAL_ASSERT(grad.device().type() == at::DeviceType::DPCPP);
TORCH_INTERNAL_ASSERT(top_half.device().type() == at::DeviceType::DPCPP);
TORCH_INTERNAL_ASSERT(bot_half.device().type() == at::DeviceType::DPCPP);
TORCH_INTERNAL_ASSERT(top_half.sizes() == bot_half.sizes());
TORCH_INTERNAL_ASSERT(top_half.is_contiguous());
TORCH_INTERNAL_ASSERT(bot_half.is_contiguous());
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(grad.scalar_type() == at::ScalarType::BFloat16);
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(top_half.scalar_type() == at::ScalarType::BFloat16);
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(bot_half.scalar_type() == at::ScalarType::BFloat16);
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(grad.device().type() == at::DeviceType::DPCPP);
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(top_half.device().type() == at::DeviceType::DPCPP);
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(bot_half.device().type() == at::DeviceType::DPCPP);
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(top_half.sizes() == bot_half.sizes());
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(top_half.is_contiguous());
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(bot_half.is_contiguous());

RECORD_FUNCTION("packed_add_", std::vector<c10::IValue>({top_half, bot_half, grad, alpha}), torch::autograd::Node::peek_at_next_sequence_nr());
if (grad.is_sparse()) {
TORCH_INTERNAL_ASSERT(top_half.dim() == 2);
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(top_half.dim() == 2);
auto sparse_nnz = grad._nnz();
auto sparse_dim = grad.sparse_dim();
auto values = grad._values();
@@ -34,14 +34,14 @@ void AtenIpexTypeExt::packed_add_(at::Tensor & top_half, at::Tensor & bot_half,
auto feature_size = values.stride(0);
auto indices_accessor = indices.accessor<int64_t, 2>();

TORCH_INTERNAL_ASSERT(values.is_contiguous());
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(values.is_contiguous());
auto value_ptr = values.data_ptr<at::BFloat16>();
auto top_half_ptr = top_half.data_ptr<at::BFloat16>();
auto bot_half_ptr = bot_half.data_ptr<at::BFloat16>();

TORCH_INTERNAL_ASSERT(value_ptr != nullptr);
TORCH_INTERNAL_ASSERT(top_half_ptr != nullptr);
TORCH_INTERNAL_ASSERT(bot_half_ptr != nullptr);
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(value_ptr != nullptr);
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(top_half_ptr != nullptr);
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(bot_half_ptr != nullptr);

std::vector<int64_t> sparse_stride(sparse_dim);
for (int64_t d = 0; d < sparse_dim; d++) {
@@ -80,7 +80,7 @@ void AtenIpexTypeExt::packed_add_(at::Tensor & top_half, at::Tensor & bot_half,
}
});
} else {
TORCH_INTERNAL_ASSERT(grad.is_contiguous());
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(grad.is_contiguous());
//TODO: vector implementation basing on vector size
union packed_bf16 {
unsigned short s[2];
@@ -201,15 +201,15 @@ inline at::Tensor _interaction_forward(const std::vector<at::Tensor> & input) {
std::vector<uint32_t> feature_sizes(input.size());
std::vector<T *> input_data(input.size());
for (int i = 0; i < input.size(); i++) {
TORCH_INTERNAL_ASSERT(input[i].is_contiguous());
TORCH_INTERNAL_ASSERT(input[i].device().is_dpcpp());
TORCH_INTERNAL_ASSERT(input[i].dim() == 2);
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(input[i].is_contiguous());
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(input[i].device().is_dpcpp());
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(input[i].dim() == 2);
feature_sizes[i] = input[i].sizes()[1];
total_feature_size += input[i].sizes()[1];
input_data[i] = input[i].data_ptr<T>();
}
auto vector_nums = total_feature_size / vector_size;
TORCH_INTERNAL_ASSERT(total_feature_size % vector_size == 0);
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(total_feature_size % vector_size == 0);
auto interact_feature_size = vector_nums * (vector_nums - 1) / 2;
auto tr_vector_size = sizeof(T) == 4 ? vector_size : vector_size / 2;
auto out = at::empty({batch_size, interact_feature_size + vector_size}, input[0].options());
@@ -239,7 +239,7 @@ inline at::Tensor _interaction_forward(const std::vector<at::Tensor> & input) {

template<typename T>
inline std::vector<at::Tensor> _interaction_backward(const at::Tensor & grad_out, const std::vector<at::Tensor> & input) {
TORCH_INTERNAL_ASSERT(grad_out.is_contiguous());
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(grad_out.is_contiguous());
RECORD_FUNCTION("_interaction_backward", std::vector<c10::IValue>({grad_out, input}), torch::autograd::Node::peek_at_next_sequence_nr());
uint32_t total_feature_size = 0;
int64_t batch_size = input[0].sizes()[0];
@@ -257,7 +257,7 @@ inline std::vector<at::Tensor> _interaction_backward(const at::Tensor & grad_out
output_data[i] = output[i].data_ptr<T>();
}
auto vector_nums = total_feature_size / vector_size;
TORCH_INTERNAL_ASSERT(total_feature_size % vector_size == 0);
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(total_feature_size % vector_size == 0);
auto interact_feature_size = vector_nums * (vector_nums - 1) / 2;
auto grad_out_data = grad_out.data_ptr<T>();

@@ -305,11 +305,11 @@ inline std::vector<at::Tensor> _interaction_backward(const at::Tensor & grad_out

at::Tensor AtenIpexTypeExt::interaction_forward(const std::vector<at::Tensor> & input) {
if (input[0].scalar_type() == at::kFloat) {
for (const auto &in : input) { TORCH_INTERNAL_ASSERT(in.scalar_type() == at::kFloat); }
Review comment (Contributor, PR author): @hongzhen1, do you think it should be replaced by TORCH_CHECK?

Reply (Contributor): It's OK for me.
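For context on the reviewer's question: TORCH_CHECK stays active in release builds and raises a catchable error intended for invalid user input, while TORCH_INTERNAL_ASSERT_DEBUG_ONLY is compiled out once NDEBUG is defined and is meant for internal invariants. Below is a minimal sketch of the trade-off, using simplified stand-in macros (an illustration, not the real c10 definitions):

    // Contrast an always-on user-facing check with a debug-only internal assert.
    #include <cstdio>
    #include <sstream>
    #include <stdexcept>

    #define MY_CHECK(cond, msg)                                        \
      do {                                                             \
        if (!(cond)) {                                                 \
          std::ostringstream oss;                                      \
          oss << "Check failed: " #cond ". " << (msg);                 \
          throw std::runtime_error(oss.str());                         \
        }                                                              \
      } while (0)

    #ifdef NDEBUG
    #define MY_ASSERT_DEBUG_ONLY(cond) ((void)0)  // gone in release builds
    #else
    #define MY_ASSERT_DEBUG_ONLY(cond) MY_CHECK(cond, "internal invariant violated")
    #endif

    int main() {
      bool inputs_are_float = false;  // pretend the caller passed the wrong dtype
      try {
        MY_CHECK(inputs_are_float, "interaction_forward expects float inputs");
      } catch (const std::exception& e) {
        std::printf("still reported in release: %s\n", e.what());  // always enforced
      }
      [[maybe_unused]] bool buffers_contiguous = true;
      MY_ASSERT_DEBUG_ONLY(buffers_contiguous);  // no-op when NDEBUG is defined
      return 0;
    }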

for (const auto &in : input) { TORCH_INTERNAL_ASSERT_DEBUG_ONLY(in.scalar_type() == at::kFloat); }
return _interaction_forward<float>(input);
} else {
TORCH_INTERNAL_ASSERT(input[0].scalar_type() == at::kBFloat16);
for (const auto &in : input) { TORCH_INTERNAL_ASSERT(in.scalar_type() == at::kBFloat16); }
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(input[0].scalar_type() == at::kBFloat16);
for (const auto &in : input) { TORCH_INTERNAL_ASSERT_DEBUG_ONLY(in.scalar_type() == at::kBFloat16); }
return _interaction_forward<at::BFloat16>(input);
}
}
@@ -318,18 +318,18 @@ std::vector<at::Tensor> AtenIpexTypeExt::interaction_backward(const at::Tensor &
if (grad_out.scalar_type() == at::kFloat) {
return _interaction_backward<float>(grad_out, input);
} else {
TORCH_INTERNAL_ASSERT(grad_out.scalar_type() == at::kBFloat16);
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(grad_out.scalar_type() == at::kBFloat16);
return _interaction_backward<at::BFloat16>(grad_out, input);
}
}

template<typename T>
static inline at::Tensor _embedding_bag_forward(const at::Tensor &weights, const at::Tensor &inputs, const at::Tensor &offsets) {
TORCH_INTERNAL_ASSERT(weights.is_contiguous());
TORCH_INTERNAL_ASSERT(inputs.is_contiguous());
TORCH_INTERNAL_ASSERT(offsets.is_contiguous());
TORCH_INTERNAL_ASSERT(inputs.dim() == 1);
TORCH_INTERNAL_ASSERT(weights.dim() == 2);
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(weights.is_contiguous());
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(inputs.is_contiguous());
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(offsets.is_contiguous());
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(inputs.dim() == 1);
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(weights.dim() == 2);
RECORD_FUNCTION("_embedding_bag_forward", std::vector<c10::IValue>({weights, inputs, offsets}), torch::autograd::Node::peek_at_next_sequence_nr());
auto batch_size = offsets.size(0);
auto num_input = inputs.size(0);
@@ -345,7 +345,7 @@ static inline at::Tensor _embedding_bag_forward(const at::Tensor &weights, const
auto inputs_start = offsets_data[i];
auto inputs_end = (i < batch_size - 1) ? offsets_data[i + 1] : num_input;
// TODO: add acc_t support for bag size larger than 1
TORCH_INTERNAL_ASSERT(inputs_end - inputs_start == 1);
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(inputs_end - inputs_start == 1);
auto out_data_ptr = &output_data[i * vector_size];
#pragma omp simd
for (int64_t v = 0; v < vector_size; v++) out_data_ptr[v] = 0.0;
@@ -361,8 +361,8 @@ static inline at::Tensor _embedding_bag_forward(const at::Tensor &weights, const
template<typename T>
static inline at::Tensor _embedding_bag_backward(const at::Tensor &grad_out,
const at::Tensor &weights, const at::Tensor &inputs, const at::Tensor offsets) {
TORCH_INTERNAL_ASSERT(inputs.dim() == 1);
TORCH_INTERNAL_ASSERT(grad_out.dim() == 2);
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(inputs.dim() == 1);
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(grad_out.dim() == 2);
RECORD_FUNCTION("_embedding_bag_backward", std::vector<c10::IValue>({grad_out, weights, inputs, offsets}), torch::autograd::Node::peek_at_next_sequence_nr());
auto batch_size = offsets.size(0);
auto num_input = inputs.size(0);
@@ -408,7 +408,7 @@ at::Tensor AtenIpexTypeExt::embedding_bag_forward(const at::Tensor &weights, con
if (weights.scalar_type() == at::kFloat) {
return _embedding_bag_forward<float>(weights, inputs, offsets);
} else {
TORCH_INTERNAL_ASSERT(weights.scalar_type() == at::kBFloat16);
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(weights.scalar_type() == at::kBFloat16);
return _embedding_bag_forward<at::BFloat16>(weights, inputs, offsets);
}
}
@@ -418,7 +418,7 @@ at::Tensor AtenIpexTypeExt::embedding_bag_backward(const at::Tensor &grad_out,
if (grad_out.scalar_type() == at::kFloat) {
return _embedding_bag_backward<float>(grad_out, weights, inputs, offsets);
} else {
TORCH_INTERNAL_ASSERT(grad_out.scalar_type() == at::kBFloat16);
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(grad_out.scalar_type() == at::kBFloat16);
return _embedding_bag_backward<at::BFloat16>(grad_out, weights, inputs, offsets);
}
}