Skip to content

Commit dac3224

Browse files
authored
fix emb under int8 autocast (#14)
1 parent 8f146ca commit dac3224

File tree

1 file changed

+4
-1
lines changed

1 file changed

+4
-1
lines changed

torch_ipex/csrc/cpu/embeddingbag.cpp

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -404,7 +404,10 @@ at::Tensor embedding_bag(
404404
#if defined(ENABLE_AUTOCAST_VERBOSE)
405405
verbose::OpNameGuard op_name("embedding_bag");
406406
#endif
407-
auto casted_weight = at::GradMode::is_enabled() ? weight : cpu_cached_cast(at::kBFloat16, weight);
407+
auto target_type = get_autocast_dtype();
408+
// only have bf16 support now, keep fp32 for other target_type
409+
bool cast_to_bfloat16 = !at::GradMode::is_enabled() && at::kBFloat16 == target_type;
410+
auto casted_weight = cast_to_bfloat16 ? cpu_cached_cast(at::kBFloat16, weight) : weight;
408411
return op.call(casted_weight, indices, offsets, sparse, include_last_offset);
409412
}
410413

0 commit comments

Comments
 (0)