feat: Saving modules using the AOTI format #3567


Open · narendasan wants to merge 1 commit into main from push-mlwrqkkovmoz

Conversation

narendasan (Collaborator)

Description

Adds an option to torch_tensorrt.save to save a module using the aot_inductor workflow.
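For reference, here is a minimal sketch of the resulting workflow, adapted from the examples/torchtrt_runtime_example/network.py changes in this PR (the ConvGelu module definition is assumed to come from that example; any module compiled with the dynamo frontend should work the same way):

import torch
import torch_tensorrt

model = ConvGelu().eval().cuda()  # assumed: the conv + GELU example model
torch_ex_input = torch.randn([1, 3, 5, 5])

trt_module = torch_tensorrt.compile(
    model,
    arg_inputs=[torch_ex_input],
    ir="dynamo",
    enabled_precisions={torch.float32},
    min_block_size=1,
)

# New in this PR: output_format="aot_inductor" packages the compiled module
# into a .pt2 archive via AOTInductor.
torch_tensorrt.save(
    trt_module,
    file_path="torchtrt_aoti_conv_gelu.pt2",
    output_format="aot_inductor",
    retrace=True,
    arg_inputs=[torch_ex_input],
)

# The package can then be loaded and run without the torch_tensorrt frontend.
loaded = torch._inductor.aoti_load_package("torchtrt_aoti_conv_gelu.pt2")
with torch.inference_mode():
    print(loaded(torch_ex_input))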

Fixes # (issue)

Type of change

Please delete options that are not relevant and/or add your own.

  • New feature (non-breaking change which adds functionality)
  • This change requires a documentation update

Checklist:

  • My code follows the style guidelines of this project (You can use the linters)
  • I have performed a self-review of my own code
  • I have commented my code, particularly in hard-to-understand areas and hacks
  • I have made corresponding changes to the documentation
  • I have added tests to verify my fix or my feature
  • New and existing unit tests pass locally with my changes
  • I have added the relevant labels to my PR so that relevant reviewers are notified

github-actions bot added the component: build system and component: api [Python] labels on Jun 11, 2025
github-actions bot requested a review from peri044 on Jun 11, 2025
github-actions bot left a comment

There are some changes that do not conform to Python style guidelines:

--- /home/runner/work/TensorRT/TensorRT/examples/torchtrt_runtime_example/network.py	2025-06-11 21:24:19.021193+00:00
+++ /home/runner/work/TensorRT/TensorRT/examples/torchtrt_runtime_example/network.py	2025-06-11 21:24:42.171661+00:00
@@ -32,27 +32,31 @@
    torch_ex_input = torch.randn([1, 3, 5, 5])
    compile_settings = {
        "arg_inputs": [torch_ex_input],
        "ir": "dynamo",
        "enabled_precisions": {torch.float32},
-        "min_block_size": 1
+        "min_block_size": 1,
    }

    cg_trt_module = torch_tensorrt.compile(model, **compile_settings)
-    torch_tensorrt.save(cg_trt_module,
-                        file_path="torchtrt_aoti_conv_gelu.pt2",
-                        output_format="aot_inductor",
-                        retrace=True,
-                        arg_inputs=[torch_ex_input])
+    torch_tensorrt.save(
+        cg_trt_module,
+        file_path="torchtrt_aoti_conv_gelu.pt2",
+        output_format="aot_inductor",
+        retrace=True,
+        arg_inputs=[torch_ex_input],
+    )

    norm_model = Norm().eval().cuda()
    norm_trt_module = torch_tensorrt.compile(norm_model, **compile_settings)
-    torch_tensorrt.save(norm_trt_module,
-                        file_path="torchtrt_aoti_norm.pt2",
-                        output_format="aot_inductor",
-                        retrace=True,
-                        arg_inputs=[torch_ex_input])
+    torch_tensorrt.save(
+        norm_trt_module,
+        file_path="torchtrt_aoti_norm.pt2",
+        output_format="aot_inductor",
+        retrace=True,
+        arg_inputs=[torch_ex_input],
+    )
    print("Generated TorchTRT-AOTI models.")


if __name__ == "__main__":
    main()
--- /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/_compile.py	2025-06-11 21:24:19.028193+00:00
+++ /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/_compile.py	2025-06-11 21:24:42.611815+00:00
@@ -662,11 +662,13 @@
                    # [Optional] Specify the generated shared library path. If not specified,
                    # the generated artifact is stored in your system temp directory.
                    package_path=file_path,
                )
            else:
-                raise RuntimeError("Attempted to serialize an exported program with an unsupported format. Exported programs support exported_program and aot_inductor")
+                raise RuntimeError(
+                    "Attempted to serialize an exported program with an unsupported format. Exported programs support exported_program and aot_inductor"
+                )
    elif module_type == _ModuleType.fx:
        # The module type is torch.fx.GraphModule
        if output_format == "torchscript":
            module_ts = torch.jit.trace(
                module, arg_inputs, example_kwarg_inputs=kwarg_inputs
@@ -680,20 +682,24 @@
                    logger.warning(
                        "Provided model is a torch.fx.GraphModule and retrace is False, inputs or arg_inputs is not necessary during save."
                    )
                exp_program = export(module)
                if output_format == "exported_program":
-                    torch.export.save(exp_program, file_path, pickle_protocol=pickle_protocol)
+                    torch.export.save(
+                        exp_program, file_path, pickle_protocol=pickle_protocol
+                    )
                elif output_format == "aot_inductor":
                    torch._inductor.aoti_compile_and_package(
                        exp_program,
                        # [Optional] Specify the generated shared library path. If not specified,
                        # the generated artifact is stored in your system temp directory.
                        package_path=file_path,
                    )
                else:
-                    raise RuntimeError("Attempted to serialize an exported program with an unsupported format. Exported programs support exported_program and aot_inductor")
+                    raise RuntimeError(
+                        "Attempted to serialize an exported program with an unsupported format. Exported programs support exported_program and aot_inductor"
+                    )
            else:
                if arg_inputs is None:
                    raise ValueError(
                        "Provided model is a torch.fx.GraphModule and retrace is True, however the inputs or arg_inputs are empty. Please provide valid torch.tensors as inputs or arg_inputs to trace and save the model"
                    )
@@ -703,11 +709,13 @@
                    kwargs=kwarg_inputs,
                    strict=False,
                )

                if output_format == "exported_program":
-                    torch.export.save(exp_program, file_path, pickle_protocol=pickle_protocol)
+                    torch.export.save(
+                        exp_program, file_path, pickle_protocol=pickle_protocol
+                    )
                elif output_format == "aot_inductor":
                    inductor_configs = {}
                    if "inductor_configs" in kwargs:
                        inductor_configs = kwargs["inductor_configs"]

@@ -715,6 +723,8 @@
                        exp_program,
                        inductor_configs=inductor_configs,
                        package_path=file_path,
                    )
                else:
-                    raise RuntimeError("Attempted to serialize an exported program with an unsupported format. Exported programs support exported_program and aot_inductor")
+                    raise RuntimeError(
+                        "Attempted to serialize an exported program with an unsupported format. Exported programs support exported_program and aot_inductor"
+                    )


github-actions bot added the documentation label on Jun 11, 2025

narendasan force-pushed the push-mlwrqkkovmoz branch 2 times, most recently from b6e9520 to f9280a4 on Jun 11, 2025

github-actions bot left a comment

There are some changes that do not conform to Python style guidelines:

--- /home/runner/work/TensorRT/TensorRT/examples/torchtrt_runtime_example/network.py	2025-06-11 22:07:25.926913+00:00
+++ /home/runner/work/TensorRT/TensorRT/examples/torchtrt_runtime_example/network.py	2025-06-11 22:07:49.213965+00:00
@@ -55,15 +55,16 @@
        retrace=True,
        arg_inputs=[torch_ex_input],
    )
    print("Generated TorchTRT-AOTI models.")

-    loaded_cg_trt_module = torch._inductor.aoti_load_package("torchtrt_aoti_conv_gelu.pt2")
+    loaded_cg_trt_module = torch._inductor.aoti_load_package(
+        "torchtrt_aoti_conv_gelu.pt2"
+    )
    loaded_norm_trt_module = torch._inductor.aoti_load_package("torchtrt_aoti_norm.pt2")
    with torch.inference_mode():
        print(loaded_cg_trt_module(torch_ex_input))
        print(loaded_norm_trt_module(torch_ex_input))


-
if __name__ == "__main__":
    main()

narendasan force-pushed the push-mlwrqkkovmoz branch 2 times, most recently from b23bab2 to ef51eac on Jun 11, 2025
github-actions bot added the component: tests label on Jun 11, 2025
github-actions bot left a comment

There are some changes that do not conform to Python style guidelines:

--- /home/runner/work/TensorRT/TensorRT/tests/py/dynamo/models/test_export_serde.py	2025-06-11 23:42:57.102541+00:00
+++ /home/runner/work/TensorRT/TensorRT/tests/py/dynamo/models/test_export_serde.py	2025-06-11 23:43:25.570042+00:00
@@ -595,10 +595,11 @@
    assertions.assertTrue(
        cos_sim > COSINE_THRESHOLD,
        msg=f"test_save_load_ts TRT outputs don't match with the original model. Cosine sim score: {cos_sim} Threshold: {COSINE_THRESHOLD}",
    )

+
@pytest.mark.unit
def test_save_load_aoti(ir, tmp_path):
    """
    This tests save/load API on the AOTI format
    """
@@ -630,11 +631,17 @@
        isinstance(trt_gm, torch.fx.GraphModule),
        msg=f"test_save_load_ts output type does not match with torch.fx.GraphModule",
    )
    outputs_trt = trt_gm(input)
    print(f"{tmp_path}/trt.pt2")
-    torchtrt.save(trt_gm, f"{tmp_path}/trt.pt2", output_format="aot_inductor", arg_inputs=[input], retrace=True)
+    torchtrt.save(
+        trt_gm,
+        f"{tmp_path}/trt.pt2",
+        output_format="aot_inductor",
+        arg_inputs=[input],
+        retrace=True,
+    )

    trt_ts_module = torch._inductor.aoti_load_package(f"{tmp_path}/trt.pt2")
    outputs_trt_deser = trt_ts_module(input)

    cos_sim = cosine_similarity(outputs_trt, outputs_trt_deser)

narendasan force-pushed the push-mlwrqkkovmoz branch 2 times, most recently from 8bb15a8 to 74b7a97 on Jun 11, 2025

peri044 (Collaborator) left a comment

LGTM


[[tool.uv.index]]
name = "pytorch-nightly-cu128"
url = "https://download.pytorch.org/whl/nightly/cu128"
explicit = false

[[tool.uv.index]]
name = "jetson-containers"
url = "https://pypi.jetson-ai-lab.dev/jp6/cu126"
A collaborator replied:

There's a page: https://pypi.jetson-ai-lab.dev/jp6/cu128. Should this be cu128?

Labels: cla signed, component: api [Python], component: build system, component: tests, documentation