feat: Saving modules using the AOTI format #3567
base: main
Conversation
There are some changes that do not conform to Python style guidelines:
--- /home/runner/work/TensorRT/TensorRT/examples/torchtrt_runtime_example/network.py 2025-06-11 21:24:19.021193+00:00
+++ /home/runner/work/TensorRT/TensorRT/examples/torchtrt_runtime_example/network.py 2025-06-11 21:24:42.171661+00:00
@@ -32,27 +32,31 @@
torch_ex_input = torch.randn([1, 3, 5, 5])
compile_settings = {
"arg_inputs": [torch_ex_input],
"ir": "dynamo",
"enabled_precisions": {torch.float32},
- "min_block_size": 1
+ "min_block_size": 1,
}
cg_trt_module = torch_tensorrt.compile(model, **compile_settings)
- torch_tensorrt.save(cg_trt_module,
- file_path="torchtrt_aoti_conv_gelu.pt2",
- output_format="aot_inductor",
- retrace=True,
- arg_inputs=[torch_ex_input])
+ torch_tensorrt.save(
+ cg_trt_module,
+ file_path="torchtrt_aoti_conv_gelu.pt2",
+ output_format="aot_inductor",
+ retrace=True,
+ arg_inputs=[torch_ex_input],
+ )
norm_model = Norm().eval().cuda()
norm_trt_module = torch_tensorrt.compile(norm_model, **compile_settings)
- torch_tensorrt.save(norm_trt_module,
- file_path="torchtrt_aoti_norm.pt2",
- output_format="aot_inductor",
- retrace=True,
- arg_inputs=[torch_ex_input])
+ torch_tensorrt.save(
+ norm_trt_module,
+ file_path="torchtrt_aoti_norm.pt2",
+ output_format="aot_inductor",
+ retrace=True,
+ arg_inputs=[torch_ex_input],
+ )
print("Generated TorchTRT-AOTI models.")
if __name__ == "__main__":
main()
--- /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/_compile.py 2025-06-11 21:24:19.028193+00:00
+++ /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/_compile.py 2025-06-11 21:24:42.611815+00:00
@@ -662,11 +662,13 @@
# [Optional] Specify the generated shared library path. If not specified,
# the generated artifact is stored in your system temp directory.
package_path=file_path,
)
else:
- raise RuntimeError("Attempted to serialize an exported program with an unsupported format. Exported programs support exported_program and aot_inductor")
+ raise RuntimeError(
+ "Attempted to serialize an exported program with an unsupported format. Exported programs support exported_program and aot_inductor"
+ )
elif module_type == _ModuleType.fx:
# The module type is torch.fx.GraphModule
if output_format == "torchscript":
module_ts = torch.jit.trace(
module, arg_inputs, example_kwarg_inputs=kwarg_inputs
@@ -680,20 +682,24 @@
logger.warning(
"Provided model is a torch.fx.GraphModule and retrace is False, inputs or arg_inputs is not necessary during save."
)
exp_program = export(module)
if output_format == "exported_program":
- torch.export.save(exp_program, file_path, pickle_protocol=pickle_protocol)
+ torch.export.save(
+ exp_program, file_path, pickle_protocol=pickle_protocol
+ )
elif output_format == "aot_inductor":
torch._inductor.aoti_compile_and_package(
exp_program,
# [Optional] Specify the generated shared library path. If not specified,
# the generated artifact is stored in your system temp directory.
package_path=file_path,
)
else:
- raise RuntimeError("Attempted to serialize an exported program with an unsupported format. Exported programs support exported_program and aot_inductor")
+ raise RuntimeError(
+ "Attempted to serialize an exported program with an unsupported format. Exported programs support exported_program and aot_inductor"
+ )
else:
if arg_inputs is None:
raise ValueError(
"Provided model is a torch.fx.GraphModule and retrace is True, however the inputs or arg_inputs are empty. Please provide valid torch.tensors as inputs or arg_inputs to trace and save the model"
)
@@ -703,11 +709,13 @@
kwargs=kwarg_inputs,
strict=False,
)
if output_format == "exported_program":
- torch.export.save(exp_program, file_path, pickle_protocol=pickle_protocol)
+ torch.export.save(
+ exp_program, file_path, pickle_protocol=pickle_protocol
+ )
elif output_format == "aot_inductor":
inductor_configs = {}
if "inductor_configs" in kwargs:
inductor_configs = kwargs["inductor_configs"]
@@ -715,6 +723,8 @@
exp_program,
inductor_configs=inductor_configs,
package_path=file_path,
)
else:
- raise RuntimeError("Attempted to serialize an exported program with an unsupported format. Exported programs support exported_program and aot_inductor")
+ raise RuntimeError(
+ "Attempted to serialize an exported program with an unsupported format. Exported programs support exported_program and aot_inductor"
+ )
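For context, the aot_inductor branch added in _compile.py above pulls an optional inductor_configs entry out of **kwargs and forwards it to torch._inductor.aoti_compile_and_package. A minimal sketch of exercising that pass-through follows; the toy model, file name, and the max_autotune key are illustrative assumptions, not part of this PR:

import torch
import torch_tensorrt

# Toy module purely for illustration; any torch.export-able module works.
model = torch.nn.Sequential(torch.nn.Conv2d(3, 8, 3), torch.nn.GELU()).eval().cuda()
example_input = torch.randn([1, 3, 5, 5]).cuda()

trt_module = torch_tensorrt.compile(
    model,
    ir="dynamo",
    arg_inputs=[example_input],
    enabled_precisions={torch.float32},
    min_block_size=1,
)

# save() reads inductor_configs from **kwargs (see the diff above) and hands
# it to aoti_compile_and_package; max_autotune is assumed here to be a valid
# torch._inductor config key.
torch_tensorrt.save(
    trt_module,
    file_path="trt_conv_gelu.pt2",
    output_format="aot_inductor",
    retrace=True,
    arg_inputs=[example_input],
    inductor_configs={"max_autotune": True},
)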
Force-pushed from 8a88c8c to 151d4aa.
Force-pushed from 151d4aa to 3158232.
Force-pushed from b6e9520 to f9280a4.
Force-pushed from f9280a4 to 6a81cb3.
There are some changes that do not conform to Python style guidelines:
--- /home/runner/work/TensorRT/TensorRT/examples/torchtrt_runtime_example/network.py 2025-06-11 22:07:25.926913+00:00
+++ /home/runner/work/TensorRT/TensorRT/examples/torchtrt_runtime_example/network.py 2025-06-11 22:07:49.213965+00:00
@@ -55,15 +55,16 @@
retrace=True,
arg_inputs=[torch_ex_input],
)
print("Generated TorchTRT-AOTI models.")
- loaded_cg_trt_module = torch._inductor.aoti_load_package("torchtrt_aoti_conv_gelu.pt2")
+ loaded_cg_trt_module = torch._inductor.aoti_load_package(
+ "torchtrt_aoti_conv_gelu.pt2"
+ )
loaded_norm_trt_module = torch._inductor.aoti_load_package("torchtrt_aoti_norm.pt2")
with torch.inference_mode():
print(loaded_cg_trt_module(torch_ex_input))
print(loaded_norm_trt_module(torch_ex_input))
-
if __name__ == "__main__":
main()
Force-pushed from 6a81cb3 to 079dd81.
Force-pushed from 079dd81 to 607df61.
Force-pushed from b23bab2 to ef51eac.
There are some changes that do not conform to Python style guidelines:
--- /home/runner/work/TensorRT/TensorRT/tests/py/dynamo/models/test_export_serde.py 2025-06-11 23:42:57.102541+00:00
+++ /home/runner/work/TensorRT/TensorRT/tests/py/dynamo/models/test_export_serde.py 2025-06-11 23:43:25.570042+00:00
@@ -595,10 +595,11 @@
assertions.assertTrue(
cos_sim > COSINE_THRESHOLD,
msg=f"test_save_load_ts TRT outputs don't match with the original model. Cosine sim score: {cos_sim} Threshold: {COSINE_THRESHOLD}",
)
+
@pytest.mark.unit
def test_save_load_aoti(ir, tmp_path):
"""
This tests save/load API on the AOTI format
"""
@@ -630,11 +631,17 @@
isinstance(trt_gm, torch.fx.GraphModule),
msg=f"test_save_load_ts output type does not match with torch.fx.GraphModule",
)
outputs_trt = trt_gm(input)
print(f"{tmp_path}/trt.pt2")
- torchtrt.save(trt_gm, f"{tmp_path}/trt.pt2", output_format="aot_inductor", arg_inputs=[input], retrace=True)
+ torchtrt.save(
+ trt_gm,
+ f"{tmp_path}/trt.pt2",
+ output_format="aot_inductor",
+ arg_inputs=[input],
+ retrace=True,
+ )
trt_ts_module = torch._inductor.aoti_load_package(f"{tmp_path}/trt.pt2")
outputs_trt_deser = trt_ts_module(input)
cos_sim = cosine_similarity(outputs_trt, outputs_trt_deser)
Force-pushed from 8bb15a8 to 74b7a97.
Force-pushed from 74b7a97 to c2e52bf.
Force-pushed from c2e52bf to 5773951.
LGTM
[[tool.uv.index]]
name = "pytorch-nightly-cu128"
url = "https://download.pytorch.org/whl/nightly/cu128"
explicit = false

[[tool.uv.index]]
name = "jetson-containers"
url = "https://pypi.jetson-ai-lab.dev/jp6/cu126"
There's a page at https://pypi.jetson-ai-lab.dev/jp6/cu128. Should this be cu128?
Description
Adds an option to torch_tensorrt.save to save a module using the aot_inductor workflow.
Fixes # (issue)
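A minimal end-to-end sketch of the workflow, condensed from the example and tests in this PR (the toy model and file name here are illustrative):

import torch
import torch_tensorrt

model = torch.nn.Sequential(torch.nn.Conv2d(3, 8, 3), torch.nn.GELU()).eval().cuda()
x = torch.randn([1, 3, 5, 5]).cuda()

# Compile with the dynamo frontend, then package the result as an
# AOTInductor .pt2 archive.
trt_module = torch_tensorrt.compile(model, ir="dynamo", arg_inputs=[x])
torch_tensorrt.save(
    trt_module,
    file_path="trt_model.pt2",
    output_format="aot_inductor",
    retrace=True,
    arg_inputs=[x],
)

# Reload the packaged artifact and run inference on it.
loaded = torch._inductor.aoti_load_package("trt_model.pt2")
with torch.inference_mode():
    print(loaded(x))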