
Commit dd06bd8

fix: Fix constant folding failure due to modelopt (#3565)
1 parent 2a93df3 commit dd06bd8

File tree: 1 file changed, +11 -2 lines


py/torch_tensorrt/dynamo/lowering/passes/constant_folding.py

Lines changed: 11 additions & 2 deletions
@@ -99,8 +99,17 @@ class _TorchTensorRTConstantFolder(ConstantFolder):  # type: ignore[misc]
     def __init__(self, *args: Any, **kwargs: Any) -> None:
         super().__init__(*args, **kwargs)
 
-    # TODO: Update this function when quantization is added
     def is_impure(self, node: torch.fx.node.Node) -> bool:
-        if node.target in (torch.ops.tensorrt.quantize_op.default,):
+        # Set of known quantization ops to be excluded from constant folding.
+        # Currently, we exclude all quantization ops coming from the modelopt library.
+        quantization_ops = set()
+        try:
+            # The modelopt import ensures torch.ops.tensorrt.quantize_op.default is registered.
+            import modelopt.torch.quantization as mtq  # noqa: F401
+            assert torch.ops.tensorrt.quantize_op.default
+            quantization_ops.add(torch.ops.tensorrt.quantize_op.default)
+        except Exception:
+            pass
+        if quantization_ops and node.target in quantization_ops:
             return True
         return False
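
For context, here is a minimal sketch of how the overridden is_impure() hook comes into play during constant folding. This is not the torch_tensorrt pipeline itself: the subclass name SkipQuantFolder and the Toy module are hypothetical, and it assumes the behavior of torch._inductor.constant_folding.ConstantFolder (the base class the patched file subclasses), whose run() seeds graph placeholders with an unknown value and records foldable nodes in node_replacements. Nodes for which is_impure() returns True are skipped, so quantize ops survive folding.

import torch
from torch._inductor.constant_folding import ConstantFolder


class SkipQuantFolder(ConstantFolder):  # hypothetical subclass for illustration
    def is_impure(self, node: torch.fx.node.Node) -> bool:
        # Same pattern as the patch: build the exclusion set lazily so a
        # missing modelopt install degrades to "exclude nothing".
        quantization_ops = set()
        try:
            import modelopt.torch.quantization as mtq  # noqa: F401

            assert torch.ops.tensorrt.quantize_op.default
            quantization_ops.add(torch.ops.tensorrt.quantize_op.default)
        except Exception:
            pass
        return bool(quantization_ops) and node.target in quantization_ops


class Toy(torch.nn.Module):  # hypothetical module with a foldable subgraph
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        w = torch.ones(3) * 2.0  # constant subgraph: candidate for folding
        return x + w


gm = torch.fx.symbolic_trace(Toy())
folder = SkipQuantFolder(gm, skip_constructors=False)
folder.run()
# Foldable nodes (here, the ones/mul chain) end up in node_replacements;
# anything is_impure() flagged would have been left untouched in the graph.
print(folder.node_replacements)

Building the exclusion set inside is_impure() keeps modelopt an optional dependency: when the import fails, the set stays empty and folding behaves exactly as before.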
