@@ -486,6 +486,90 @@ def test_batch_norm3d_backward(self):
486
486
self .assertTrue (ipex .core .is_bf16_dil_tensor (x_auto_mix_bf16 .grad ))
487
487
self .assertEqual (x_man_bf16 .grad .float (), x_auto_mix_bf16 .grad )
488
488
489
class TestLayerNorm(TestCase):
    """Tests for torch.nn.LayerNorm under IPEX auto-mixed-precision (bf16).

    Relies on the module-level helpers ``_gen_tensor`` / ``_gen_op`` to
    produce matched input/op variants: plain CPU fp32 reference, auto-mix
    inference, auto-mix training, manual bf16, and auto-mix training with
    a bf16 DIL input.
    """

    def test_layer_norm(self):
        """Forward pass: inference and training, fp32 and bf16 DIL inputs."""
        rand_seed = int(get_rand_seed())
        # Log the seed so a failing run can be reproduced.
        # (Fixed typo: message previously read "rand sed".)
        print("{} rand seed: {}".format(sys._getframe().f_code.co_name, rand_seed))

        x_cpu, x_auto_mix_inference, x_auto_mix_train, x_man_bf16, x_auto_mix_train_bf16 = _gen_tensor(
            rand_seed, (2, 5, 10, 10))

        # NOTE(review): is_bn=True is passed for LayerNorm as well —
        # presumably _gen_op treats any normalization op this way; confirm.
        op_cpu, op_auto_mix_inference, op_auto_mix_train, op_man_bf16, op_auto_mix_train_bf16 = _gen_op(
            rand_seed, torch.nn.LayerNorm([10, 10]), is_bn=True)

        ref_cpu = op_cpu(x_cpu)
        with AutoDNNL(True), AutoMixPrecision(False):
            # Manual bf16 path: the result itself is bf16.
            res_bf16 = op_man_bf16(x_man_bf16)
            self.assertEqual(res_bf16.dtype, torch.bfloat16)

            # FW inference: both input and output are reordered to bf16
            # DIL tensors, while their visible dtype stays torch.float.
            with AutoMixPrecision(True, train=False):
                self.assertEqual(x_auto_mix_inference.dtype, torch.float)
                self.assertFalse(ipex.core.is_bf16_dil_tensor(x_auto_mix_inference))
                res_auto_mix_inference = op_auto_mix_inference(x_auto_mix_inference)
                self.assertEqual(res_auto_mix_inference.dtype, torch.float)
                self.assertEqual(x_auto_mix_inference.dtype, torch.float)
                self.assertTrue(ipex.core.is_bf16_dil_tensor(res_auto_mix_inference))
                self.assertTrue(ipex.core.is_bf16_dil_tensor(x_auto_mix_inference))
                self.assertEqual(res_bf16.float(), res_auto_mix_inference)

            # FW train, fp32 input: nothing is converted to bf16, and the
            # result matches the fp32 CPU reference exactly.
            with AutoMixPrecision(True, train=True):
                self.assertEqual(x_auto_mix_train.dtype, torch.float)
                self.assertFalse(ipex.core.is_bf16_dil_tensor(x_auto_mix_train))
                res_auto_mix_train = op_auto_mix_train(x_auto_mix_train)
                self.assertEqual(res_auto_mix_train.dtype, torch.float)
                self.assertEqual(x_auto_mix_train.dtype, torch.float)
                self.assertFalse(ipex.core.is_bf16_dil_tensor(res_auto_mix_train))
                self.assertFalse(ipex.core.is_bf16_dil_tensor(x_auto_mix_train))
                self.assertEqual(ref_cpu, res_auto_mix_train)

            # FW train, bf16 DIL input: stays bf16 end to end and matches
            # the manual bf16 result.
            with AutoMixPrecision(True, train=True):
                self.assertEqual(x_auto_mix_train_bf16.dtype, torch.float)
                self.assertTrue(ipex.core.is_bf16_dil_tensor(x_auto_mix_train_bf16))
                res_auto_mix_train_bf16 = op_auto_mix_train_bf16(x_auto_mix_train_bf16)
                self.assertEqual(res_auto_mix_train_bf16.dtype, torch.float)
                self.assertEqual(x_auto_mix_train_bf16.dtype, torch.float)
                self.assertTrue(ipex.core.is_bf16_dil_tensor(res_auto_mix_train_bf16))
                self.assertTrue(ipex.core.is_bf16_dil_tensor(x_auto_mix_train_bf16))
                self.assertEqual(res_bf16.float(), res_auto_mix_train_bf16)

    def test_layer_norm_backward(self):
        """Backward pass: input gradients for fp32 and bf16 DIL inputs."""
        rand_seed = int(get_rand_seed())
        # Log the seed so a failing run can be reproduced.
        # (Fixed typo: message previously read "rand sed".)
        print("{} rand seed: {}".format(sys._getframe().f_code.co_name, rand_seed))
        x_cpu, _, x_auto_mix, x_man_bf16, x_auto_mix_bf16 = _gen_tensor(
            rand_seed, (2, 5, 10, 10), is_forward=False)

        op_cpu, _, op_auto_mix, op_man_bf16, op_auto_mix_bf16 = _gen_op(
            rand_seed, torch.nn.LayerNorm([10, 10]), is_bn=True, is_forward=False)

        # fp32 CPU reference gradient.
        out_cpu = op_cpu(x_cpu).sum()
        out_cpu.backward()
        with AutoDNNL(True), AutoMixPrecision(False, train=True):
            # Manual bf16 path: grad is bf16; compare against the CPU
            # reference rounded through bf16, with a loose 1e-2 tolerance.
            out_man_bf16 = op_man_bf16(x_man_bf16).sum()
            out_man_bf16.backward()
            self.assertEqual(x_man_bf16.grad.dtype, torch.bfloat16)
            self.assertEqual(x_cpu.grad.bfloat16().float(), x_man_bf16.grad, 1e-2)

            # BW train, fp32 input: grad stays fp32 and matches the CPU
            # reference exactly.
            with AutoMixPrecision(True, train=True):
                self.assertEqual(x_auto_mix.dtype, torch.float)
                self.assertFalse(ipex.core.is_bf16_dil_tensor(x_auto_mix))
                out_auto_mix = op_auto_mix(x_auto_mix).sum()
                out_auto_mix.backward()
                self.assertEqual(x_auto_mix.grad.dtype, torch.float)
                self.assertFalse(ipex.core.is_bf16_dil_tensor(x_auto_mix.grad))
                self.assertEqual(x_cpu.grad, x_auto_mix.grad)

            # BW train, bf16 DIL input: grad is a bf16 DIL tensor and
            # matches the manual bf16 grad.
            with AutoMixPrecision(True, train=True):
                self.assertEqual(x_auto_mix_bf16.dtype, torch.float)
                self.assertTrue(ipex.core.is_bf16_dil_tensor(x_auto_mix_bf16))
                out_auto_mix_bf16 = op_auto_mix_bf16(x_auto_mix_bf16).sum()
                out_auto_mix_bf16.backward()
                self.assertEqual(x_auto_mix_bf16.grad.dtype, torch.float)
                self.assertTrue(ipex.core.is_bf16_dil_tensor(x_auto_mix_bf16.grad))
                self.assertEqual(x_man_bf16.grad.float(), x_auto_mix_bf16.grad)
489
573
class TestRelu (TestCase ):
490
574
def test_relu (self ):
491
575
rand_seed = int (get_rand_seed ())
0 commit comments