
Commit 8bb15a8

feat: Saving modules using the AOTI format
1 parent f09be72 commit 8bb15a8

9 files changed: +320 -459 lines changed

.gitignore

Lines changed: 2 additions & 1 deletion
@@ -78,4 +78,5 @@ MODULE.bazel.lock
 *.whl
 .coverage
 coverage.xml
-*.log
+*.log
+*.pt2

docsrc/user_guide/saving_models.rst

Lines changed: 29 additions & 5 deletions
@@ -14,12 +14,13 @@ Saving models compiled with Torch-TensorRT can be done using `torch_tensorrt.save`
 Dynamo IR
 -------------
 
-The output type of `ir=dynamo` compilation of Torch-TensorRT is `torch.fx.GraphModule` object by default.
-We can save this object in either `TorchScript` (`torch.jit.ScriptModule`) or `ExportedProgram` (`torch.export.ExportedProgram`) formats by
+The output type of `ir=dynamo` compilation of Torch-TensorRT is a `torch.fx.GraphModule` object by default.
+We can save this object in either `TorchScript` (`torch.jit.ScriptModule`), `ExportedProgram` (`torch.export.ExportedProgram`) or `PT2` formats by
 specifying the `output_format` flag. Here are the options `output_format` will accept
 
 * `exported_program` : This is the default. We perform transformations on the graphmodule first and use `torch.export.save` to save the module.
 * `torchscript` : We trace the graphmodule via `torch.jit.trace` and save it via `torch.jit.save`.
+* `PT2 Format` : This is a next-generation format for PyTorch models, allowing them to run both in Python and in C++.
 
 a) ExportedProgram
 ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
@@ -52,8 +53,8 @@ b) Torchscript
     model = MyModel().eval().cuda()
     inputs = [torch.randn((1, 3, 224, 224)).cuda()]
     # trt_gm is a torch.fx.GraphModule object
-    trt_gm = torch_tensorrt.compile(model, ir="dynamo", inputs=inputs)
-    torch_tensorrt.save(trt_gm, "trt.ts", output_format="torchscript", inputs=inputs)
+    trt_gm = torch_tensorrt.compile(model, ir="dynamo", arg_inputs=inputs)
+    torch_tensorrt.save(trt_gm, "trt.ts", output_format="torchscript", arg_inputs=inputs)
 
     # Later, you can load it and run inference
     model = torch.jit.load("trt.ts").cuda()
@@ -73,7 +74,7 @@ For `ir=ts`, this behavior stays the same in 2.X versions as well.
 
     model = MyModel().eval().cuda()
     inputs = [torch.randn((1, 3, 224, 224)).cuda()]
-    trt_ts = torch_tensorrt.compile(model, ir="ts", inputs=inputs) # Output is a ScriptModule object
+    trt_ts = torch_tensorrt.compile(model, ir="ts", arg_inputs=inputs) # Output is a ScriptModule object
     torch.jit.save(trt_ts, "trt_model.ts")
 
     # Later, you can load it and run inference
@@ -98,3 +99,26 @@ Here's an example usage
     inputs = [torch.randn((1, 3, 224, 224)).cuda()]
     model = torch_tensorrt.load(<file_path>).module()
     model(*inputs)
+
+b) PT2 Format
+^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
+
+PT2 is a new format that allows models to be run outside of Python in the future. It utilizes `AOTInductor <https://docs.pytorch.org/docs/main/torch.compiler_aot_inductor.html>`_
+to generate kernels for components that will not be run in TensorRT.
+
+Here's an example of how to save and load a Torch-TensorRT module using AOTInductor in Python
+
+.. code-block:: python
+
+    import torch
+    import torch_tensorrt
+
+    model = MyModel().eval().cuda()
+    inputs = [torch.randn((1, 3, 224, 224)).cuda()]
+    # trt_gm is a torch.fx.GraphModule object
+    trt_gm = torch_tensorrt.compile(model, ir="dynamo", arg_inputs=inputs)
+    torch_tensorrt.save(trt_gm, "trt.pt2", arg_inputs=inputs, output_format="aot_inductor", retrace=True)
+
+    # Later, you can load it and run inference
+    model = torch._inductor.aoti_load_package("trt.pt2")
+    model(*inputs)
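
For reference, the new PT2 workflow documented above can be exercised end to end with a minimal sketch. This assumes the `MyModel` placeholder used throughout the docs and a CUDA-enabled Torch-TensorRT build; the `torch.testing.assert_close` comparison is only an illustrative sanity check, not part of the documented example:

    import torch
    import torch_tensorrt

    model = MyModel().eval().cuda()  # placeholder module, as in the docs above
    inputs = [torch.randn((1, 3, 224, 224)).cuda()]

    # Compile with the dynamo frontend, then package the result with AOTInductor
    trt_gm = torch_tensorrt.compile(model, ir="dynamo", arg_inputs=inputs)
    torch_tensorrt.save(
        trt_gm, "trt.pt2", arg_inputs=inputs, output_format="aot_inductor", retrace=True
    )

    # Reload the package and compare it against the in-memory compiled module
    loaded = torch._inductor.aoti_load_package("trt.pt2")
    with torch.inference_mode():
        torch.testing.assert_close(loaded(*inputs), trt_gm(*inputs), rtol=1e-3, atol=1e-3)
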
examples/torchtrt_runtime_example/Makefile

Lines changed: 8 additions & 3 deletions

@@ -1,14 +1,19 @@
 CXX=g++
 DEP_DIR=$(PWD)/deps
-INCLUDE_DIRS=-I$(DEP_DIR)/libtorch/include -I$(DEP_DIR)/torch_tensorrt/include
-LIB_DIRS=-L$(DEP_DIR)/torch_tensorrt/lib -L$(DEP_DIR)/libtorch/lib # -Wl,-rpath $(DEP_DIR)/tensorrt/lib
-LIBS=-Wl,--no-as-needed -ltorchtrt_runtime -Wl,--as-needed -ltorch -ltorch_cuda -ltorch_cpu -ltorch_global_deps -lbackend_with_compiler -lc10 -lc10_cuda
+CUDA_HOME=/usr/local/cuda
+INCLUDE_DIRS=-I$(DEP_DIR)/libtorch/include -I$(DEP_DIR)/torch_tensorrt/include -I$(CUDA_HOME)/include -I$(DEP_DIR)/libtorch/include/torch/csrc/api/include
+LIB_DIRS=-L$(DEP_DIR)/torch_tensorrt/lib -L$(DEP_DIR)/libtorch/lib -Wl,-rpath $(DEP_DIR)/tensorrt/lib
+LIBS=-Wl,--no-as-needed -ltorchtrt_runtime -ltorchtrt_plugins -Wl,--as-needed -ltorch -ltorch_cuda -ltorch_cpu -ltorch_global_deps -lbackend_with_compiler -lc10 -lc10_cuda
 SRCS=main.cpp
 
 TARGET=torchtrt_runtime_example
 
 $(TARGET):
 	$(CXX) $(SRCS) $(INCLUDE_DIRS) $(LIB_DIRS) $(LIBS) -o $(TARGET)
+	echo "Add to LD_LIBRARY_PATH: $(DEP_DIR)/torch_tensorrt/lib:$(DEP_DIR)/libtorch/lib:$(DEP_DIR)/tensorrt/lib:$(CUDA_HOME)/lib64"
+
+generate_pt2:
+	uv run network.py
 
 clean:
 	$(RM) $(TARGET)

examples/torchtrt_runtime_example/main.cpp

Lines changed: 48 additions & 16 deletions
@@ -3,31 +3,63 @@
 #include <memory>
 #include <sstream>
 #include <vector>
-#include "torch/script.h"
+#include "torch/torch.h"
+#include "torch/csrc/inductor/aoti_package/model_package_loader.h"
+#include "torch/csrc/inductor/aoti_runner/model_container_runner_cuda.h"
+
+/*
+ * This example demonstrates how to load and run a pre-built Torch-TensorRT
+ * AOTInductor (AOTI) model package using the PyTorch C++ API.
+ *
+ * Usage:
+ *   torchtrt_runtime_example <path-to-pre-built-trt-aoti module>
+ *
+ * Steps:
+ * 1. Parse the path to the AOTI model package from the command line.
+ * 2. Load the model package using AOTIModelPackageLoader.
+ * 3. Prepare a random CUDA tensor as input.
+ * 4. Run inference using the loaded model.
+ * 5. Print the output tensor(s) or an error message if inference fails.
+ */
 
 int main(int argc, const char* argv[]) {
+  // Check for correct number of command-line arguments
   if (argc < 2) {
-    std::cerr << "usage: samplertapp <path-to-pre-built-trt-ts module>\n";
+    std::cerr << "usage: torchtrt_runtime_example <path-to-pre-built-trt-aoti module>\n";
     return -1;
   }
 
-  std::string trt_ts_module_path = argv[1];
+  // Get the path to the TRT AOTI model package from the command line
+  std::string trt_aoti_module_path = argv[1];
 
-  torch::jit::Module trt_ts_mod;
+  // Enable inference mode for thread-local optimizations
+  c10::InferenceMode mode;
   try {
-    // Deserialize the ScriptModule from a file using torch::jit::load().
-    trt_ts_mod = torch::jit::load(trt_ts_module_path);
+    // Load the AOTI model package
+    torch::inductor::AOTIModelPackageLoader runner(trt_aoti_module_path);
+
+    // Create a random input tensor on CUDA with shape [1, 3, 5, 5] and type float32
+    std::vector<at::Tensor> inputs = {at::randn({1, 3, 5, 5}, {at::kCUDA}).to(torch::kFloat32)};
+
+    // Run inference using the loaded model
+    std::vector<at::Tensor> outputs = runner.run(inputs);
+
+    // Process and print the output tensor(s)
+    if (!outputs.empty()) {
+      std::cout << "Model output: " << outputs[0] << std::endl;
+    } else {
+      std::cerr << "No output tensors received!" << std::endl;
+    }
+
   } catch (const c10::Error& e) {
-    std::cerr << "error loading the model from : " << trt_ts_module_path << std::endl;
-    return -1;
+    // Handle errors from the PyTorch C++ API
+    std::cerr << "Error running model: " << e.what() << std::endl;
+    return 1;
+  } catch (const std::exception& e) {
+    // Handle other standard exceptions
+    std::cerr << "An unexpected error occurred: " << e.what() << std::endl;
+    return 1;
   }
 
-  std::cout << "Running TRT engine" << std::endl;
-  std::vector<torch::jit::IValue> trt_inputs_ivalues;
-  trt_inputs_ivalues.push_back(at::randint(-5, 5, {1, 3, 5, 5}, {at::kCUDA}).to(torch::kFloat32));
-  torch::jit::IValue trt_results_ivalues = trt_ts_mod.forward(trt_inputs_ivalues);
-  std::cout << "==================TRT outputs================" << std::endl;
-  std::cout << trt_results_ivalues << std::endl;
-  std::cout << "=============================================" << std::endl;
-  std::cout << "TRT engine execution completed. " << std::endl;
+  return 0;
 }
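
The C++ runner above expects a package whose forward takes a single [1, 3, 5, 5] float32 CUDA tensor. A short Python-side sketch of the same load-and-run path can confirm a package before invoking the C++ binary; it assumes the `torchtrt_aoti_conv_gelu.pt2` file produced by `network.py` below:

    import torch

    # Load the AOTInductor package produced by network.py (see the next file)
    runner = torch._inductor.aoti_load_package("torchtrt_aoti_conv_gelu.pt2")

    # Same input contract as the C++ example: [1, 3, 5, 5], float32, on CUDA
    x = torch.randn(1, 3, 5, 5, device="cuda", dtype=torch.float32)
    with torch.inference_mode():
        out = runner(x)
    print(out)
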

examples/torchtrt_runtime_example/network.py

Lines changed: 30 additions & 10 deletions
@@ -1,6 +1,6 @@
 import torch
 import torch.nn as nn
-import torch_tensorrt as torchtrt
+import torch_tensorrt
 
 
 # create a simple norm layer.
@@ -29,21 +29,41 @@ def forward(self, x):
 
 def main():
     model = ConvGelu().eval().cuda()
-    scripted_model = torch.jit.script(model)
-
+    torch_ex_input = torch.randn([1, 3, 5, 5], device="cuda")
     compile_settings = {
-        "inputs": [torchtrt.Input([1, 3, 5, 5])],
+        "arg_inputs": [torch_ex_input],
+        "ir": "dynamo",
         "enabled_precisions": {torch.float32},
+        "min_block_size": 1,
     }
 
-    trt_ts_module = torchtrt.compile(scripted_model, **compile_settings)
-    torch.jit.save(trt_ts_module, "conv_gelu.jit")
+    cg_trt_module = torch_tensorrt.compile(model, **compile_settings)
+    torch_tensorrt.save(
+        cg_trt_module,
+        file_path="torchtrt_aoti_conv_gelu.pt2",
+        output_format="aot_inductor",
+        retrace=True,
+        arg_inputs=[torch_ex_input],
+    )
 
     norm_model = Norm().eval().cuda()
-    norm_ts_module = torch.jit.script(norm_model)
-    norm_trt_ts = torchtrt.compile(norm_ts_module, **compile_settings)
-    torch.jit.save(norm_trt_ts, "norm.jit")
-    print("Generated Torchscript-TRT models.")
+    norm_trt_module = torch_tensorrt.compile(norm_model, **compile_settings)
+    torch_tensorrt.save(
+        norm_trt_module,
+        file_path="torchtrt_aoti_norm.pt2",
+        output_format="aot_inductor",
+        retrace=True,
+        arg_inputs=[torch_ex_input],
+    )
+    print("Generated TorchTRT-AOTI models.")
+
+    loaded_cg_trt_module = torch._inductor.aoti_load_package(
+        "torchtrt_aoti_conv_gelu.pt2"
+    )
+    loaded_norm_trt_module = torch._inductor.aoti_load_package("torchtrt_aoti_norm.pt2")
+    with torch.inference_mode():
+        print(loaded_cg_trt_module(torch_ex_input))
+        print(loaded_norm_trt_module(torch_ex_input))
 
 
 if __name__ == "__main__":

py/torch_tensorrt/_compile.py

Lines changed: 58 additions & 11 deletions
@@ -585,6 +585,7 @@ def save(
     kwarg_inputs: Optional[dict[str, Any]] = None,
     retrace: bool = False,
     pickle_protocol: int = 2,
+    **kwargs: Any,
 ) -> None:
     """
     Save the model to disk in the specified output format.
@@ -594,15 +595,15 @@ def save(
         inputs (torch.Tensor): Torch input tensors
         arg_inputs (Tuple[Any, ...]): Same as inputs. Alias for better understanding with kwarg_inputs.
         kwarg_inputs (dict[Any, ...]): Optional, kwarg inputs to the module forward function.
-        output_format (str): Format to save the model. Options include exported_program | torchscript.
+        output_format (str): Format to save the model. Options include exported_program | torchscript | aot_inductor.
         retrace (bool): When the module type is a fx.GraphModule, this option re-exports the graph using torch.export.export(strict=False) to save it.
             This flag is experimental for now.
         pickle_protocol (int): The pickle protocol to use to save the model. Default is 2. Increase this to 4 or higher for large models
     """
     if isinstance(module, CudaGraphsTorchTensorRTModule):
         module = module.compiled_module
     module_type = _parse_module_type(module)
-    accepted_formats = {"exported_program", "torchscript"}
+    accepted_formats = {"exported_program", "torchscript", "aot_inductor"}
     if arg_inputs is not None and not all(
         isinstance(input, torch.Tensor) for input in arg_inputs
     ):
@@ -633,9 +634,9 @@ def save(
             "Input model is of type nn.Module. Saving nn.Module directly is not supported. Supported model types torch.jit.ScriptModule | torch.fx.GraphModule | torch.export.ExportedProgram."
         )
     elif module_type == _ModuleType.ts:
-        if output_format == "exported_program":
+        if output_format != "torchscript":
             raise ValueError(
-                "Provided model is a torch.jit.ScriptModule but the output_format specified is exported_program. Please verify the output_format"
+                "Provided model is a torch.jit.ScriptModule but the output_format specified is not torchscript. Other output formats are not supported"
             )
         else:
             if arg_inputs is not None:
@@ -653,7 +654,22 @@ def save(
                 logger.warning(
                     "Provided model is a torch.export.ExportedProgram, inputs or arg_inputs is not necessary during save, it uses the inputs or arg_inputs provided during export and compile"
                 )
-            torch.export.save(module, file_path)
+            if output_format == "exported_program":
+                torch.export.save(module, file_path, pickle_protocol=pickle_protocol)
+            elif output_format == "aot_inductor":
+                inductor_configs = {}
+                if "inductor_configs" in kwargs:
+                    inductor_configs = kwargs["inductor_configs"]
+
+                torch._inductor.aoti_compile_and_package(
+                    module,
+                    inductor_configs=inductor_configs,
+                    package_path=file_path,
+                )
+            else:
+                raise RuntimeError(
+                    "Attempted to serialize an exported program with an unsupported format. Exported programs support exported_program and aot_inductor"
+                )
     elif module_type == _ModuleType.fx:
         # The module type is torch.fx.GraphModule
         if output_format == "torchscript":
@@ -670,9 +686,24 @@ def save(
                     "Provided model is a torch.fx.GraphModule and retrace is False, inputs or arg_inputs is not necessary during save."
                 )
             exp_program = export(module)
-            torch.export.save(
-                exp_program, file_path, pickle_protocol=pickle_protocol
-            )
+            if output_format == "exported_program":
+                torch.export.save(
+                    exp_program, file_path, pickle_protocol=pickle_protocol
+                )
+            elif output_format == "aot_inductor":
+                inductor_configs = {}
+                if "inductor_configs" in kwargs:
+                    inductor_configs = kwargs["inductor_configs"]
+
+                torch._inductor.aoti_compile_and_package(
+                    exp_program,
+                    inductor_configs=inductor_configs,
+                    package_path=file_path,
+                )
+            else:
+                raise RuntimeError(
+                    "Attempted to serialize an exported program with an unsupported format. Exported programs support exported_program and aot_inductor"
+                )
         else:
             if arg_inputs is None:
                 raise ValueError(
@@ -684,6 +715,22 @@ def save(
                 kwargs=kwarg_inputs,
                 strict=False,
             )
-            torch.export.save(
-                exp_program, file_path, pickle_protocol=pickle_protocol
-            )
+
+            if output_format == "exported_program":
+                torch.export.save(
+                    exp_program, file_path, pickle_protocol=pickle_protocol
+                )
+            elif output_format == "aot_inductor":
+                inductor_configs = {}
+                if "inductor_configs" in kwargs:
+                    inductor_configs = kwargs["inductor_configs"]
+
+                torch._inductor.aoti_compile_and_package(
+                    exp_program,
+                    inductor_configs=inductor_configs,
+                    package_path=file_path,
+                )
+            else:
+                raise RuntimeError(
+                    "Attempted to serialize an exported program with an unsupported format. Exported programs support exported_program and aot_inductor"
+                )
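
The new `**kwargs` path in `save` forwards an optional `inductor_configs` dict to `torch._inductor.aoti_compile_and_package`. A hedged usage sketch of this option from the public API (`MyModel` is a placeholder; `max_autotune` is a standard Inductor config shown only as an example, not a requirement):

    import torch
    import torch_tensorrt

    model = MyModel().eval().cuda()  # placeholder module
    inputs = [torch.randn((1, 3, 224, 224)).cuda()]

    trt_gm = torch_tensorrt.compile(model, ir="dynamo", arg_inputs=inputs)

    # inductor_configs is picked up from **kwargs and passed through to AOTInductor
    torch_tensorrt.save(
        trt_gm,
        "trt.pt2",
        output_format="aot_inductor",
        arg_inputs=inputs,
        retrace=True,
        inductor_configs={"max_autotune": True},  # example Inductor option, not required
    )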
