File tree Expand file tree Collapse file tree 1 file changed +11
-2
lines changed
py/torch_tensorrt/dynamo/lowering/passes Expand file tree Collapse file tree 1 file changed +11
-2
lines changed Original file line number Diff line number Diff line change @@ -99,8 +99,17 @@ class _TorchTensorRTConstantFolder(ConstantFolder): # type: ignore[misc]
99
99
def __init__(self, *args: Any, **kwargs: Any) -> None:
    """Forward every positional and keyword argument to the parent constructor."""
    super().__init__(*args, **kwargs)
101
101
102
- # TODO: Update this function when quantization is added
103
102
def is_impure(self, node: torch.fx.node.Node) -> bool:
    """Report whether ``node`` must be excluded from constant folding.

    A node is treated as impure when its target is one of the known
    quantization ops. Currently this is only
    ``torch.ops.tensorrt.quantize_op.default``, which is made available by
    importing the modelopt library.

    Args:
        node: The FX graph node being considered for constant folding.

    Returns:
        True if ``node.target`` is a known quantization op, False otherwise
        (including when modelopt is unavailable or the op is not registered).
    """
    # Must be a real set: ``{}`` is an empty dict and has no ``.add``; the
    # resulting AttributeError was previously swallowed below, which made
    # this method return False for every node.
    quantization_ops = set()
    try:
        # The modelopt import ensures torch.ops.tensorrt.quantize_op.default
        # is registered.
        import modelopt.torch.quantization  # noqa: F401

        # Attribute lookup itself fails if the op is not registered, so no
        # ``assert`` (which is stripped under ``python -O``) is needed.
        quantization_ops.add(torch.ops.tensorrt.quantize_op.default)
    except Exception:
        # Deliberate best-effort: modelopt is optional. If it cannot be
        # imported or the op is missing, there is nothing to exclude.
        pass

    return node.target in quantization_ops
You can’t perform that action at this time.
0 commit comments