Skip to content

Commit d685a58

Browse files
committed
enable bf16 layernorm
1 parent 009b5d8 commit d685a58

File tree

2 files changed

+87
-5
lines changed

2 files changed

+87
-5
lines changed

tests/cpu/test_bf16_lazy_reorder.py

Lines changed: 84 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -486,6 +486,90 @@ def test_batch_norm3d_backward(self):
486486
self.assertTrue(ipex.core.is_bf16_dil_tensor(x_auto_mix_bf16.grad))
487487
self.assertEqual(x_man_bf16.grad.float(), x_auto_mix_bf16.grad)
488488

489+
# Tests torch.nn.LayerNorm([10, 10]) under the auto-mix-precision path, for
# both forward and backward, against a CPU op and a "manual bf16" op.
# NOTE(review): this text comes from a rendered diff, so the original
# indentation (including how the `with` blocks nest) is not visible here;
# confirm the structure against the real file before relying on it.
class TestLayerNorm(TestCase):
490+
# Forward test: runs the op in three auto-mix scenarios (inference, train
# with a non-bf16 input, train with a bf16 DIL input) and compares each
# result with either ref_cpu or the manual bf16 result.
def test_layer_norm(self):
491+
rand_seed = int(get_rand_seed())
492+
# Prints the current function name together with the seed.
print("{} rand sed: {}".format(sys._getframe().f_code.co_name, rand_seed))
493+
494+
# _gen_tensor and _gen_op are defined outside this view. Presumably each
# returns five matching variants (CPU, auto-mix inference, auto-mix train,
# manual bf16, auto-mix train with bf16 input) and (2, 5, 10, 10) is the
# input shape; verify against their definitions.
x_cpu, x_auto_mix_inference, x_auto_mix_train, x_man_bf16, x_auto_mix_train_bf16 = _gen_tensor(
495+
rand_seed, (2, 5, 10, 10))
496+
497+
op_cpu, op_auto_mix_inference, op_auto_mix_train, op_man_bf16, op_auto_mix_train_bf16 = _gen_op(
498+
rand_seed, torch.nn.LayerNorm([10, 10]), is_bn=True)
499+
500+
# Reference results: the CPU op output, and the manual bf16 op output, which
# is asserted to have dtype torch.bfloat16.
ref_cpu = op_cpu(x_cpu)
501+
with AutoDNNL(True), AutoMixPrecision(False):
502+
res_bf16 = op_man_bf16(x_man_bf16)
503+
self.assertEqual(res_bf16.dtype, torch.bfloat16)
504+
505+
# FW inference: before the op the input reports dtype torch.float and is not
# a bf16 DIL tensor. After the op, both the result and the input are asserted
# to be bf16 DIL tensors while still reporting dtype torch.float, and the
# result must equal the manual bf16 result converted to float.
506+
with AutoMixPrecision(True, train=False):
507+
self.assertEqual(x_auto_mix_inference.dtype, torch.float)
508+
self.assertFalse(ipex.core.is_bf16_dil_tensor(x_auto_mix_inference))
509+
res_auto_mix_inference = op_auto_mix_inference(x_auto_mix_inference)
510+
self.assertEqual(res_auto_mix_inference.dtype, torch.float)
511+
self.assertEqual(x_auto_mix_inference.dtype, torch.float)
512+
self.assertTrue(ipex.core.is_bf16_dil_tensor(res_auto_mix_inference))
513+
self.assertTrue(ipex.core.is_bf16_dil_tensor(x_auto_mix_inference))
514+
self.assertEqual(res_bf16.float(), res_auto_mix_inference)
515+
516+
# FW train (input is not bf16 dil tensor): neither the input nor the result
# is expected to be a bf16 DIL tensor after the op, and the result is
# compared with the CPU reference.
517+
with AutoMixPrecision(True, train=True):
518+
self.assertEqual(x_auto_mix_train.dtype, torch.float)
519+
self.assertFalse(ipex.core.is_bf16_dil_tensor(x_auto_mix_train))
520+
res_auto_mix_train = op_auto_mix_train(x_auto_mix_train)
521+
self.assertEqual(res_auto_mix_train.dtype, torch.float)
522+
self.assertEqual(x_auto_mix_train.dtype, torch.float)
523+
self.assertFalse(ipex.core.is_bf16_dil_tensor(res_auto_mix_train))
524+
self.assertFalse(ipex.core.is_bf16_dil_tensor(x_auto_mix_train))
525+
self.assertEqual(ref_cpu, res_auto_mix_train)
526+
527+
# FW train (input is bf16 dil tensor): the input is already a bf16 DIL tensor
# before the op; the result is asserted to be one too and must equal the
# manual bf16 result converted to float.
528+
with AutoMixPrecision(True, train=True):
529+
self.assertEqual(x_auto_mix_train_bf16.dtype, torch.float)
530+
self.assertTrue(ipex.core.is_bf16_dil_tensor(x_auto_mix_train_bf16))
531+
res_auto_mix_train_bf16 = op_auto_mix_train_bf16(x_auto_mix_train_bf16)
532+
self.assertEqual(res_auto_mix_train_bf16.dtype, torch.float)
533+
self.assertEqual(x_auto_mix_train_bf16.dtype, torch.float)
534+
self.assertTrue(ipex.core.is_bf16_dil_tensor(res_auto_mix_train_bf16))
535+
self.assertTrue(ipex.core.is_bf16_dil_tensor(x_auto_mix_train_bf16))
536+
self.assertEqual(res_bf16.float(), res_auto_mix_train_bf16)
537+
538+
# Backward test: sums each op output, calls backward(), and checks the
# gradient that lands on the input for the CPU, manual bf16 and auto-mix
# variants.
def test_layer_norm_backward(self):
539+
rand_seed = int(get_rand_seed())
540+
# Prints the current function name together with the seed.
print("{} rand sed: {}".format(sys._getframe().f_code.co_name, rand_seed))
541+
# The second returned variant is discarded here (`_`) for both tensors and
# ops. Presumably is_forward=False makes the helpers produce inputs that
# accumulate .grad; verify against _gen_tensor / _gen_op.
x_cpu, _, x_auto_mix, x_man_bf16, x_auto_mix_bf16 = _gen_tensor(rand_seed, (2, 5, 10, 10), is_forward=False)
542+
543+
op_cpu, _, op_auto_mix, op_man_bf16, op_auto_mix_bf16 = _gen_op(rand_seed, torch.nn.LayerNorm([10, 10]), is_bn=True, is_forward=False)
544+
545+
# CPU reference gradient.
out_cpu = op_cpu(x_cpu).sum()
546+
out_cpu.backward()
547+
# Manual bf16 gradient: must have dtype torch.bfloat16 and is compared with
# the CPU gradient rounded through bfloat16. The trailing 1e-2 is presumably
# a tolerance accepted by this TestCase's assertEqual; confirm.
with AutoDNNL(True), AutoMixPrecision(False, train=True):
548+
out_man_bf16 = op_man_bf16(x_man_bf16).sum()
549+
out_man_bf16.backward()
550+
self.assertEqual(x_man_bf16.grad.dtype, torch.bfloat16)
551+
self.assertEqual(x_cpu.grad.bfloat16().float(), x_man_bf16.grad, 1e-2)
552+
553+
# BW train (input is not bf16 dil tensor): the resulting input gradient is
# asserted to be a float, non-bf16-DIL tensor equal to the CPU gradient.
554+
with AutoMixPrecision(True, train=True):
555+
self.assertEqual(x_auto_mix.dtype, torch.float)
556+
self.assertFalse(ipex.core.is_bf16_dil_tensor(x_auto_mix))
557+
out_auto_mix = op_auto_mix(x_auto_mix).sum()
558+
out_auto_mix.backward()
559+
self.assertEqual(x_auto_mix.grad.dtype, torch.float)
560+
self.assertFalse(ipex.core.is_bf16_dil_tensor(x_auto_mix.grad))
561+
self.assertEqual(x_cpu.grad, x_auto_mix.grad)
562+
563+
# BW train (input is bf16 dil tensor): the input gradient reports dtype
# torch.float, is asserted to be a bf16 DIL tensor, and must equal the manual
# bf16 gradient converted to float.
564+
with AutoMixPrecision(True, train=True):
565+
self.assertEqual(x_auto_mix_bf16.dtype, torch.float)
566+
self.assertTrue(ipex.core.is_bf16_dil_tensor(x_auto_mix_bf16))
567+
out_auto_mix_bf16 = op_auto_mix_bf16(x_auto_mix_bf16).sum()
568+
out_auto_mix_bf16.backward()
569+
self.assertEqual(x_auto_mix_bf16.grad.dtype, torch.float)
570+
self.assertTrue(ipex.core.is_bf16_dil_tensor(x_auto_mix_bf16.grad))
571+
self.assertEqual(x_man_bf16.grad.float(), x_auto_mix_bf16.grad)
572+
489573
class TestRelu(TestCase):
490574
def test_relu(self):
491575
rand_seed = int(get_rand_seed())

torch_ipex/csrc/cpu/DevOPs.cpp

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -2467,8 +2467,7 @@ std::tuple<at::Tensor, at::Tensor, at::Tensor> AtenIpexCPUDev::dil_native_layer_
24672467
double eps) {
24682468
DEBUG("AtenIpexCPUDev::dil_native_layer_norm\n");
24692469
CHECK_DNNL_OP_PRE_COND(X);
2470-
//It's a temporary solution to fall back to fp32 since bf16 layer_norm is not ready for dnnl path now.
2471-
dbl::comm::reorder_to_dtype(X, at::kFloat);
2470+
dbl::comm::reorder_to_bf16_for_mix_prec(X, true);
24722471
dil::tensor x = dbl::comm::try_gen_dil_tensor(X);
24732472
const dil::tensor scale = dbl::comm::try_gen_dil_tensor(gamma);
24742473
const dil::tensor shift = dbl::comm::try_gen_dil_tensor(beta);
@@ -2508,9 +2507,8 @@ std::tuple<at::Tensor, at::Tensor, at::Tensor> AtenIpexCPUDev::dil_native_layer_
25082507
DEBUG("AtenIpexCPUDev::dil_native_layer_norm_backward\n");
25092508
CHECK_DNNL_OP_PRE_COND(dY);
25102509
CHECK_DNNL_OP_PRE_COND(X);
2511-
//it's a temporary solution to fall back to fp32 since bf16 layer_norm is not ready for dnnl path now.
2512-
dbl::comm::reorder_to_dtype(dY, at::kFloat);
2513-
dbl::comm::reorder_to_dtype(X, at::kFloat);
2510+
dbl::comm::reorder_to_bf16_for_mix_prec(dY, true);
2511+
dbl::comm::reorder_to_bf16_for_mix_prec(X, true);
25142512
dil::tensor dy = dbl::comm::try_gen_dil_tensor(dY);
25152513
dil::tensor x = dbl::comm::try_gen_dil_tensor(X);
25162514
dil::tensor m = dbl::comm::try_gen_dil_tensor(mean);

0 commit comments

Comments
 (0)