Commit fcab70d

Liangan1/embedding bf16 enable (#36)
* Enable BF16 embedding convert type in utils.py. Background: the Embedding input index is long type and its output is FP32 by default, even when a BF16 embedding table is enabled. In the BERT residual block, the FP32 embedding result is added to the BF16 linear output, yielding an FP32 output; this introduces many dtype conversions and lets only some ops run in BF16.
* Add LayerNorm to the autocast white list. Background: LayerNorm is a fallthrough op by default and its weight/bias are FP32, so a dtype error occurs when the input is BF16.
* 1) Add LayerNorm to the module dtype-convert list. 2) Refine code.
1 parent 35a8b91 commit fcab70d
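
To make the motivation concrete, here is a minimal sketch (not part of this commit) of the dtype promotion the message describes: an FP32 Embedding output added to a BF16 Linear output promotes the result back to FP32, while converting the embedding table to BF16 keeps the residual add in BF16.

# Minimal sketch of the residual-block dtype promotion; modules here are
# illustrative only, not taken from the patched code.
import torch

emb = torch.nn.Embedding(10, 4)                   # weight is FP32 by default
linear = torch.nn.Linear(4, 4).to(torch.bfloat16)

idx = torch.tensor([1, 2, 3])                     # embedding input is long type
x = torch.randn(3, 4, dtype=torch.bfloat16)

# Before the change: FP32 embedding output + BF16 linear output promotes to FP32.
print((emb(idx) + linear(x)).dtype)               # torch.float32

# After converting the embedding table to BF16, the residual add stays in BF16.
emb.weight.data = emb.weight.detach().clone().to(torch.bfloat16)
print((emb(idx) + linear(x)).dtype)               # torch.bfloat16
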

File tree

2 files changed: +14 −6 lines changed


intel_pytorch_extension_py/utils.py

Lines changed: 13 additions & 6 deletions
@@ -33,12 +33,19 @@ def _replace_lstm_with_ipex_lstm(model):
         _replace_lstm_with_ipex_lstm(child)

 def convert_module_data_type(module, dtype):
-    if isinstance(module, torch.nn.Conv2d) or isinstance(module, torch.nn.Linear):
-        weight_data = module.weight.detach().clone().to(dtype)
-        module.weight.data = weight_data
-        if module.bias is not None:
-            bias_data = module.bias.detach().clone().to(dtype)
-            module.bias.data = bias_data
+    # convert weights (and bias) of the module to dtype to reduce dtype reorders
+    module_convert_list = [torch.nn.Conv2d,
+                           torch.nn.Linear,
+                           torch.nn.Embedding,
+                           torch.nn.LayerNorm]
+    for module_cls in module_convert_list:
+        if isinstance(module, module_cls):
+            weight_data = module.weight.detach().clone().to(dtype)
+            module.weight.data = weight_data
+            if hasattr(module, 'bias') and module.bias is not None:
+                bias_data = module.bias.detach().clone().to(dtype)
+                module.bias.data = bias_data
+            break
     for child in module.children():
         convert_module_data_type(child, dtype)
     return module
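
A minimal usage sketch of the patched helper, assuming convert_module_data_type is importable from intel_pytorch_extension_py.utils (the file patched above); the exact import path may differ depending on how the package is installed.

import torch
from intel_pytorch_extension_py.utils import convert_module_data_type

model = torch.nn.Sequential(
    torch.nn.Embedding(100, 16),
    torch.nn.LayerNorm(16),
    torch.nn.Linear(16, 16),
)
# Recursively converts weights/biases of Conv2d, Linear, Embedding and
# LayerNorm children to BF16; other module types are left untouched.
convert_module_data_type(model, torch.bfloat16)

print(model[0].weight.dtype, model[1].weight.dtype, model[2].bias.dtype)
# torch.bfloat16 torch.bfloat16 torch.bfloat16
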

torch_ipex/csrc/autocast_mode.cpp

Lines changed: 1 addition & 0 deletions
@@ -180,6 +180,7 @@ MAKE_REGISTER_FUNC(ADD_NS(conv_transpose1d), "conv_transpose1d", Tensor (const T
                   IntArrayRef, IntArrayRef, IntArrayRef, int64_t, IntArrayRef), user_defined_dtype)
 MAKE_REGISTER_FUNC(ADD_NS(conv_transpose2d), "conv_transpose2d.input", Tensor (const Tensor &, const Tensor &, const c10::optional<Tensor>&,
                   IntArrayRef, IntArrayRef, IntArrayRef, int64_t, IntArrayRef), user_defined_dtype)
+MAKE_REGISTER_FUNC(ADD_NS(layer_norm), "layer_norm", Tensor (const Tensor &, IntArrayRef, const c10::optional<Tensor>&, const c10::optional<Tensor>&, double, bool), user_defined_dtype)

 // fp32 cast policy
 MAKE_REGISTER_FUNC(ADD_NS(avg_pool2d), "avg_pool2d", Tensor (const Tensor &, IntArrayRef, IntArrayRef, IntArrayRef, bool, bool, c10::optional<int64_t>), fp32)
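
A rough Python analogue (an assumption for illustration, not the actual IPEX dispatcher code) of what registering layer_norm under the user_defined_dtype policy achieves: the input and the optional weight/bias are cast to the autocast dtype before the op runs, so FP32 LayerNorm parameters no longer clash with BF16 activations.

import torch
import torch.nn.functional as F

def autocast_layer_norm(x, normalized_shape, weight=None, bias=None,
                        eps=1e-5, dtype=torch.bfloat16):
    # Hypothetical helper: cast all tensor arguments to the autocast dtype.
    cast = lambda t: t.to(dtype) if t is not None else None
    return F.layer_norm(cast(x), normalized_shape, cast(weight), cast(bias), eps)

x = torch.randn(2, 8, dtype=torch.bfloat16)
ln = torch.nn.LayerNorm(8)          # weight/bias remain FP32
print(autocast_layer_norm(x, (8,), ln.weight, ln.bias).dtype)   # torch.bfloat16
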
