Skip to content

Commit b86278f

Browse files
committed
adding the option in _compiler.py
1 parent b66350e commit b86278f

File tree

1 file changed

+9
-9
lines changed

1 file changed

+9
-9
lines changed

py/torch_tensorrt/dynamo/_compiler.py

Lines changed: 9 additions & 9 deletions
Original file line number | Diff line number | Diff line change
@@ -98,7 +98,7 @@ def cross_compile_for_windows(
9898
enable_weight_streaming: bool = _defaults.ENABLE_WEIGHT_STREAMING,
9999
tiling_optimization_level: str = _defaults.TILING_OPTIMIZATION_LEVEL,
100100
l2_limit_for_tiling: int = _defaults.L2_LIMIT_FOR_TILING,
101-
use_aot_joint_export: bool = _defaults.USE_AOT_JOINT_EXPORT,
101+
use_distributed_mode_trace: bool = _defaults.USE_DISTRIBUTED_MODE_TRACE,
102102
**kwargs: Any,
103103
) -> torch.fx.GraphModule:
104104
"""Compile an ExportedProgram module using TensorRT in Linux for Inference in Windows
@@ -174,7 +174,7 @@ def cross_compile_for_windows(
174174
enable_weight_streaming (bool): Enable weight streaming.
175175
tiling_optimization_level (str): The optimization level of tiling strategies. A higher level allows TensorRT to spend more time searching for better tiling strategy. We currently support ["none", "fast", "moderate", "full"].
176176
l2_limit_for_tiling (int): The target L2 cache usage limit (in bytes) for tiling optimization (default is -1 which means no limit).
177-
use_aot_joint_export (bool): Use aot_export_joint_simple, else wrap backend with AOT_autograd, required for distributed tensors
177+
use_distributed_mode_trace (bool): Using aot_autograd to trace the graph. This is enabled when DTensors or distributed tensors are present in distributed model
178178
**kwargs: Any,
179179
Returns:
180180
torch.fx.GraphModule: Compiled FX Module, when run it will execute via TensorRT
@@ -334,7 +334,7 @@ def cross_compile_for_windows(
334334
"enable_weight_streaming": enable_weight_streaming,
335335
"tiling_optimization_level": tiling_optimization_level,
336336
"l2_limit_for_tiling": l2_limit_for_tiling,
337-
"use_aot_joint_export": use_aot_joint_export,
337+
"use_distributed_mode_trace": use_distributed_mode_trace,
338338
}
339339

340340
# disable the following settings is not supported for cross compilation for windows feature
@@ -424,7 +424,7 @@ def compile(
424424
enable_weight_streaming: bool = _defaults.ENABLE_WEIGHT_STREAMING,
425425
tiling_optimization_level: str = _defaults.TILING_OPTIMIZATION_LEVEL,
426426
l2_limit_for_tiling: int = _defaults.L2_LIMIT_FOR_TILING,
427-
use_aot_joint_export: bool = _defaults.USE_AOT_JOINT_EXPORT,
427+
use_distributed_mode_trace: bool = _defaults.USE_DISTRIBUTED_MODE_TRACE,
428428
**kwargs: Any,
429429
) -> torch.fx.GraphModule:
430430
"""Compile an ExportedProgram module for NVIDIA GPUs using TensorRT
@@ -502,7 +502,7 @@ def compile(
502502
enable_weight_streaming (bool): Enable weight streaming.
503503
tiling_optimization_level (str): The optimization level of tiling strategies. A higher level allows TensorRT to spend more time searching for better tiling strategy. We currently support ["none", "fast", "moderate", "full"].
504504
l2_limit_for_tiling (int): The target L2 cache usage limit (in bytes) for tiling optimization (default is -1 which means no limit).
505-
use_aot_joint_export (bool): Use aot_export_joint_simple, else wrap backend with AOT_autograd, required for distributed tensors
505+
use_distributed_mode_trace (bool): Using aot_autograd to trace the graph. This is enabled when DTensors or distributed tensors are present in distributed model
506506
**kwargs: Any,
507507
Returns:
508508
torch.fx.GraphModule: Compiled FX Module, when run it will execute via TensorRT
@@ -679,7 +679,7 @@ def compile(
679679
"enable_weight_streaming": enable_weight_streaming,
680680
"tiling_optimization_level": tiling_optimization_level,
681681
"l2_limit_for_tiling": l2_limit_for_tiling,
682-
"use_aot_joint_export": use_aot_joint_export,
682+
"use_distributed_mode_trace": use_distributed_mode_trace,
683683
}
684684

685685
settings = CompilationSettings(**compilation_options)
@@ -970,7 +970,7 @@ def convert_exported_program_to_serialized_trt_engine(
970970
enable_weight_streaming: bool = _defaults.ENABLE_WEIGHT_STREAMING,
971971
tiling_optimization_level: str = _defaults.TILING_OPTIMIZATION_LEVEL,
972972
l2_limit_for_tiling: int = _defaults.L2_LIMIT_FOR_TILING,
973-
use_aot_joint_export: bool = _defaults.USE_AOT_JOINT_EXPORT,
973+
use_distributed_mode_trace: bool = _defaults.USE_DISTRIBUTED_MODE_TRACE,
974974
**kwargs: Any,
975975
) -> bytes:
976976
"""Convert an ExportedProgram to a serialized TensorRT engine
@@ -1036,7 +1036,7 @@ def convert_exported_program_to_serialized_trt_engine(
10361036
enable_weight_streaming (bool): Enable weight streaming.
10371037
tiling_optimization_level (str): The optimization level of tiling strategies. A higher level allows TensorRT to spend more time searching for better tiling strategy. We currently support ["none", "fast", "moderate", "full"].
10381038
l2_limit_for_tiling (int): The target L2 cache usage limit (in bytes) for tiling optimization (default is -1 which means no limit).
1039-
use_aot_joint_export (bool): Use aot_export_joint_simple, else wrap backend with AOT_autograd, required for distributed tensors
1039+
use_distributed_mode_trace: bool = _defaults.USE_DISTRIBUTED_MODE_TRACE,
10401040
Returns:
10411041
bytes: Serialized TensorRT engine, can either be saved to a file or deserialized via TensorRT APIs
10421042
"""
@@ -1155,7 +1155,7 @@ def convert_exported_program_to_serialized_trt_engine(
11551155
"enable_weight_streaming": enable_weight_streaming,
11561156
"tiling_optimization_level": tiling_optimization_level,
11571157
"l2_limit_for_tiling": l2_limit_for_tiling,
1158-
"use_aot_joint_export": use_aot_joint_export,
1158+
"use_distributed_mode_trace": use_distributed_mode_trace,
11591159
}
11601160

11611161
settings = CompilationSettings(**compilation_options)

0 commit comments

Comments
 (0)