@@ -400,6 +400,77 @@ def forward(self, x):
         self.assertFused(graph, ['aten::_convolution', 'aten::relu', 'aten::quantize_per_channel'])
         self.checkPatterns(graph, patterns)

+    @llga_test_env
+    def test_bmm_div_scalar(self):
+        class M(nn.Module):
+            def __init__(self, div_value):
+                super(M, self).__init__()
+                self.div_value = div_value
+
+            def forward(self, x, y):
+                mm_res = torch.matmul(x, y)
+                return mm_res.div(self.div_value)
+
+        x = torch.randn(128, 16, 384, 64)
+        y = torch.randn(128, 1, 64, 384)
+        patterns = [
+            ["aten::dequantize", "aten::matmul", "aten::div"],
+        ]
+        m = M(8.)
+        graph = self.checkQuantizeTrace(m, [x, y], atol=2e-1, config_name="bmm_div_scalar", qscheme=torch.per_tensor_affine)
+        # TODO: enable the below checks when matmul-div fusion is supported in the backend
+        # self.assertGraphContainsExactly(graph, LLGA_FUSION_GROUP, 1)
+        # self.assertFused(graph, ['aten::matmul', 'aten::div'])
+        # self.checkPatterns(graph, patterns)
+
+    @llga_test_env
+    def test_bmm_div_identity(self):
+        class M(nn.Module):
+            def __init__(self, div_value):
+                super(M, self).__init__()
+                self.div_value = div_value
+
+            def forward(self, x, y):
+                mm_res = torch.matmul(x, y)
+                return mm_res.div(self.div_value)
+
+        x = torch.randn(128, 16, 384, 64)
+        y = torch.randn(128, 1, 64, 384)
+        patterns = [
+            ["aten::dequantize", "aten::matmul"],
+        ]
+        m = M(1.)
+        graph = self.checkQuantizeTrace(m, [x, y], atol=2e-1, config_name="bmm_div_identity", qscheme=torch.per_tensor_affine)
+        # divide by 1 should be removed by Constant Propagation
+        self.assertGraphContainsExactly(graph, "aten::div", 0, consider_subgraphs=True)
+        self.assertGraphContainsExactly(graph, LLGA_FUSION_GROUP, 1)
+        self.assertFused(graph, ['aten::matmul'])
+        # TODO: enable this check when int8 matmul is supported in the backend
+        # self.checkPatterns(graph, patterns)
+
+    @llga_test_env
+    def test_bmm_div_tensor(self):
+        class M(nn.Module):
+            def __init__(self):
+                super(M, self).__init__()
+
+            def forward(self, x, y, z):
+                mm_res = torch.matmul(x, y)
+                return mm_res.div(z)
+
+        x = torch.randn(128, 16, 384, 64)
+        y = torch.randn(128, 1, 64, 384)
+        patterns = [
+            ["aten::dequantize", "aten::matmul", "aten::div"],
+        ]
+        for z in [torch.randn(384), torch.randn(128, 16, 384, 384)]:
+            m = M()
+            graph = self.checkQuantizeTrace(m, [x, y, z], atol=2e-1, config_name="bmm_div_tensor", qscheme=torch.per_tensor_affine)
+            # TODO: enable the below checks when matmul-div fusion is supported in the backend
+            # self.assertGraphContainsExactly(graph, LLGA_FUSION_GROUP, 1)
+            # self.assertFused(graph, ['aten::matmul', 'aten::div'])
+            # self.checkPatterns(graph, patterns)
+
 class TestShapeFallback(JitLlgaTestCase):
     @unittest.skipIf(True, 'Size peephole optimization not enabled yet')
     @llga_test_env