
Commit 5773951

feat: Saving modules using the AOTI format
1 parent 60863a3 commit 5773951

File tree

9 files changed (+327, -459 lines)

.gitignore

Lines changed: 2 additions & 1 deletion
@@ -78,4 +78,5 @@ MODULE.bazel.lock
 *.whl
 .coverage
 coverage.xml
-*.log
+*.log
+*.pt2

docsrc/user_guide/saving_models.rst

Lines changed: 29 additions & 5 deletions
@@ -14,12 +14,13 @@ Saving models compiled with Torch-TensorRT can be done using `torch_tensorrt.sav
 Dynamo IR
 -------------
 
-The output type of `ir=dynamo` compilation of Torch-TensorRT is `torch.fx.GraphModule` object by default.
-We can save this object in either `TorchScript` (`torch.jit.ScriptModule`) or `ExportedProgram` (`torch.export.ExportedProgram`) formats by
+The output type of `ir=dynamo` compilation of Torch-TensorRT is `torch.fx.GraphModule` object by default.
+We can save this object in either `TorchScript` (`torch.jit.ScriptModule`), `ExportedProgram` (`torch.export.ExportedProgram`) or `PT2` formats by
 specifying the `output_format` flag. Here are the options `output_format` will accept
 
 * `exported_program` : This is the default. We perform transformations on the graphmodule first and use `torch.export.save` to save the module.
 * `torchscript` : We trace the graphmodule via `torch.jit.trace` and save it via `torch.jit.save`.
+* `PT2 Format` : This is a next-generation runtime format for PyTorch models, allowing them to run in Python and in C++.
 
 a) ExportedProgram
 ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
@@ -52,8 +53,8 @@ b) Torchscript
     model = MyModel().eval().cuda()
     inputs = [torch.randn((1, 3, 224, 224)).cuda()]
     # trt_gm is a torch.fx.GraphModule object
-    trt_gm = torch_tensorrt.compile(model, ir="dynamo", inputs=inputs)
-    torch_tensorrt.save(trt_gm, "trt.ts", output_format="torchscript", inputs=inputs)
+    trt_gm = torch_tensorrt.compile(model, ir="dynamo", arg_inputs=inputs)
+    torch_tensorrt.save(trt_gm, "trt.ts", output_format="torchscript", arg_inputs=inputs)
 
     # Later, you can load it and run inference
     model = torch.jit.load("trt.ts").cuda()
@@ -73,7 +74,7 @@ For `ir=ts`, this behavior stays the same in 2.X versions as well.
 
     model = MyModel().eval().cuda()
     inputs = [torch.randn((1, 3, 224, 224)).cuda()]
-    trt_ts = torch_tensorrt.compile(model, ir="ts", inputs=inputs) # Output is a ScriptModule object
+    trt_ts = torch_tensorrt.compile(model, ir="ts", arg_inputs=inputs) # Output is a ScriptModule object
     torch.jit.save(trt_ts, "trt_model.ts")
 
     # Later, you can load it and run inference
@@ -98,3 +99,26 @@ Here's an example usage
     inputs = [torch.randn((1, 3, 224, 224)).cuda()]
     model = torch_tensorrt.load(<file_path>).module()
     model(*inputs)
+
+b) PT2 Format
+^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
+
+PT2 is a new format that allows models to be run outside of Python in the future. It utilizes `AOTInductor <https://docs.pytorch.org/docs/main/torch.compiler_aot_inductor.html>`_
+to generate kernels for components that will not be run in TensorRT.
+
+Here's an example of how to save and load a Torch-TensorRT module using AOTInductor in Python
+
+.. code-block:: python
+
+    import torch
+    import torch_tensorrt
+
+    model = MyModel().eval().cuda()
+    inputs = [torch.randn((1, 3, 224, 224)).cuda()]
+    # trt_gm is a torch.fx.GraphModule object
+    trt_gm = torch_tensorrt.compile(model, ir="dynamo", arg_inputs=inputs)
+    torch_tensorrt.save(trt_gm, "trt.pt2", arg_inputs=inputs, output_format="aot_inductor", retrace=True)
+
+    # Later, you can load it and run inference
+    model = torch._inductor.aoti_load_package("trt.pt2")
+    model(*inputs)
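
The updated `save` path later in this commit (`py/torch_tensorrt/_compile.py`) also forwards an optional `inductor_configs` dictionary through `**kwargs` to `torch._inductor.aoti_compile_and_package`. A minimal sketch of how a caller might use that, assuming the same placeholder `MyModel` as in the examples above; the specific config key is illustrative only:

    import torch
    import torch_tensorrt

    model = MyModel().eval().cuda()
    inputs = [torch.randn((1, 3, 224, 224)).cuda()]
    trt_gm = torch_tensorrt.compile(model, ir="dynamo", arg_inputs=inputs)

    # Any dict passed as inductor_configs is handed to
    # torch._inductor.aoti_compile_and_package unchanged;
    # "max_autotune" is a standard Inductor config key, shown here for illustration.
    torch_tensorrt.save(
        trt_gm,
        "trt.pt2",
        output_format="aot_inductor",
        arg_inputs=inputs,
        retrace=True,
        inductor_configs={"max_autotune": True},
    )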

examples/torchtrt_runtime_example/Makefile

Lines changed: 8 additions & 3 deletions

@@ -1,14 +1,19 @@
 CXX=g++
 DEP_DIR=$(PWD)/deps
-INCLUDE_DIRS=-I$(DEP_DIR)/libtorch/include -I$(DEP_DIR)/torch_tensorrt/include
-LIB_DIRS=-L$(DEP_DIR)/torch_tensorrt/lib -L$(DEP_DIR)/libtorch/lib # -Wl,-rpath $(DEP_DIR)/tensorrt/lib
-LIBS=-Wl,--no-as-needed -ltorchtrt_runtime -Wl,--as-needed -ltorch -ltorch_cuda -ltorch_cpu -ltorch_global_deps -lbackend_with_compiler -lc10 -lc10_cuda
+CUDA_HOME=/usr/local/cuda
+INCLUDE_DIRS=-I$(DEP_DIR)/libtorch/include -I$(DEP_DIR)/torch_tensorrt/include -I$(CUDA_HOME)/include -I$(DEP_DIR)/libtorch/include/torch/csrc/api/include
+LIB_DIRS=-L$(DEP_DIR)/torch_tensorrt/lib -L$(DEP_DIR)/libtorch/lib -Wl,-rpath $(DEP_DIR)/tensorrt/lib
+LIBS=-Wl,--no-as-needed -ltorchtrt_runtime -ltorchtrt_plugins -Wl,--as-needed -ltorch -ltorch_cuda -ltorch_cpu -ltorch_global_deps -lbackend_with_compiler -lc10 -lc10_cuda
 SRCS=main.cpp
 
 TARGET=torchtrt_runtime_example
 
 $(TARGET):
 	$(CXX) $(SRCS) $(INCLUDE_DIRS) $(LIB_DIRS) $(LIBS) -o $(TARGET)
+	echo "Add to LD_LIBRARY_PATH: $(DEP_DIR)/torch_tensorrt/lib:$(DEP_DIR)/libtorch/lib:$(DEP_DIR)/tensorrt/lib:$(CUDA_HOME)/lib64"
+
+generate_pt2:
+	uv run network.py
 
 clean:
 	$(RM) $(TARGET)

examples/torchtrt_runtime_example/main.cpp

Lines changed: 48 additions & 16 deletions
@@ -3,31 +3,63 @@
 #include <memory>
 #include <sstream>
 #include <vector>
-#include "torch/script.h"
+#include "torch/torch.h"
+#include "torch/csrc/inductor/aoti_package/model_package_loader.h"
+#include "torch/csrc/inductor/aoti_runner/model_container_runner_cuda.h"
+
+/*
+ * This example demonstrates how to load and run a pre-built Torch-TensorRT
+ * AOTInductor (AOTI) model package using the PyTorch C++ API.
+ *
+ * Usage:
+ *   torchtrt_runtime_example <path-to-pre-built-trt-aoti module>
+ *
+ * Steps:
+ * 1. Parse the path to the AOTI model package from the command line.
+ * 2. Load the model package using AOTIModelPackageLoader.
+ * 3. Prepare a random CUDA tensor as input.
+ * 4. Run inference using the loaded model.
+ * 5. Print the output tensor(s) or an error message if inference fails.
+ */
 
 int main(int argc, const char* argv[]) {
+  // Check for correct number of command-line arguments
   if (argc < 2) {
-    std::cerr << "usage: samplertapp <path-to-pre-built-trt-ts module>\n";
+    std::cerr << "usage: torchtrt_runtime_example <path-to-pre-built-trt-aoti module>\n";
     return -1;
   }
 
-  std::string trt_ts_module_path = argv[1];
+  // Get the path to the TRT AOTI model package from the command line
+  std::string trt_aoti_module_path = argv[1];
 
-  torch::jit::Module trt_ts_mod;
+  // Enable inference mode for thread-local optimizations
+  c10::InferenceMode mode;
   try {
-    // Deserialize the ScriptModule from a file using torch::jit::load().
-    trt_ts_mod = torch::jit::load(trt_ts_module_path);
+    // Load the AOTI model package
+    torch::inductor::AOTIModelPackageLoader runner(trt_aoti_module_path);
+
+    // Create a random input tensor on CUDA with shape [1, 3, 5, 5] and type float32
+    std::vector<at::Tensor> inputs = {at::randn({1, 3, 5, 5}, {at::kCUDA}).to(torch::kFloat32)};
+
+    // Run inference using the loaded model
+    std::vector<at::Tensor> outputs = runner.run(inputs);
+
+    // Process and print the output tensor(s)
+    if (!outputs.empty()) {
+      std::cout << "Model output: " << outputs[0] << std::endl;
+    } else {
+      std::cerr << "No output tensors received!" << std::endl;
+    }
+
   } catch (const c10::Error& e) {
-    std::cerr << "error loading the model from : " << trt_ts_module_path << std::endl;
-    return -1;
+    // Handle errors from the PyTorch C++ API
+    std::cerr << "Error running model: " << e.what() << std::endl;
+    return 1;
+  } catch (const std::exception& e) {
+    // Handle other standard exceptions
+    std::cerr << "An unexpected error occurred: " << e.what() << std::endl;
+    return 1;
   }
 
-  std::cout << "Running TRT engine" << std::endl;
-  std::vector<torch::jit::IValue> trt_inputs_ivalues;
-  trt_inputs_ivalues.push_back(at::randint(-5, 5, {1, 3, 5, 5}, {at::kCUDA}).to(torch::kFloat32));
-  torch::jit::IValue trt_results_ivalues = trt_ts_mod.forward(trt_inputs_ivalues);
-  std::cout << "==================TRT outputs================" << std::endl;
-  std::cout << trt_results_ivalues << std::endl;
-  std::cout << "=============================================" << std::endl;
-  std::cout << "TRT engine execution completed. " << std::endl;
+  return 0;
 }

examples/torchtrt_runtime_example/network.py

Lines changed: 30 additions & 10 deletions
@@ -1,6 +1,6 @@
 import torch
 import torch.nn as nn
-import torch_tensorrt as torchtrt
+import torch_tensorrt
 
 
 # create a simple norm layer.
@@ -29,21 +29,41 @@ def forward(self, x):
 
 def main():
     model = ConvGelu().eval().cuda()
-    scripted_model = torch.jit.script(model)
-
+    torch_ex_input = torch.randn([1, 3, 5, 5], device="cuda")
     compile_settings = {
-        "inputs": [torchtrt.Input([1, 3, 5, 5])],
+        "arg_inputs": [torch_ex_input],
+        "ir": "dynamo",
         "enabled_precisions": {torch.float32},
+        "min_block_size": 1,
     }
 
-    trt_ts_module = torchtrt.compile(scripted_model, **compile_settings)
-    torch.jit.save(trt_ts_module, "conv_gelu.jit")
+    cg_trt_module = torch_tensorrt.compile(model, **compile_settings)
+    torch_tensorrt.save(
+        cg_trt_module,
+        file_path="torchtrt_aoti_conv_gelu.pt2",
+        output_format="aot_inductor",
+        retrace=True,
+        arg_inputs=[torch_ex_input],
+    )
 
     norm_model = Norm().eval().cuda()
-    norm_ts_module = torch.jit.script(norm_model)
-    norm_trt_ts = torchtrt.compile(norm_ts_module, **compile_settings)
-    torch.jit.save(norm_trt_ts, "norm.jit")
-    print("Generated Torchscript-TRT models.")
+    norm_trt_module = torch_tensorrt.compile(norm_model, **compile_settings)
+    torch_tensorrt.save(
+        norm_trt_module,
+        file_path="torchtrt_aoti_norm.pt2",
+        output_format="aot_inductor",
+        retrace=True,
+        arg_inputs=[torch_ex_input],
+    )
+    print("Generated TorchTRT-AOTI models.")
+
+    loaded_cg_trt_module = torch._inductor.aoti_load_package(
+        "torchtrt_aoti_conv_gelu.pt2"
+    )
+    loaded_norm_trt_module = torch._inductor.aoti_load_package("torchtrt_aoti_norm.pt2")
+    with torch.inference_mode():
+        print(loaded_cg_trt_module(torch_ex_input))
+        print(loaded_norm_trt_module(torch_ex_input))
 
 
 if __name__ == "__main__":

py/torch_tensorrt/_compile.py

Lines changed: 58 additions & 11 deletions
@@ -586,6 +586,7 @@ def save(
     kwarg_inputs: Optional[dict[str, Any]] = None,
     retrace: bool = False,
     pickle_protocol: int = 2,
+    **kwargs: Any,
 ) -> None:
     """
     Save the model to disk in the specified output format.
@@ -595,15 +596,15 @@ def save(
         inputs (torch.Tensor): Torch input tensors
         arg_inputs (Tuple[Any, ...]): Same as inputs. Alias for better understanding with kwarg_inputs.
         kwarg_inputs (dict[Any, ...]): Optional, kwarg inputs to the module forward function.
-        output_format (str): Format to save the model. Options include exported_program | torchscript.
+        output_format (str): Format to save the model. Options include exported_program | torchscript | aot_inductor.
         retrace (bool): When the module type is a fx.GraphModule, this option re-exports the graph using torch.export.export(strict=False) to save it.
             This flag is experimental for now.
         pickle_protocol (int): The pickle protocol to use to save the model. Default is 2. Increase this to 4 or higher for large models
     """
     if isinstance(module, CudaGraphsTorchTensorRTModule):
         module = module.compiled_module
     module_type = _parse_module_type(module)
-    accepted_formats = {"exported_program", "torchscript"}
+    accepted_formats = {"exported_program", "torchscript", "aot_inductor"}
     if arg_inputs is not None and not all(
         isinstance(input, torch.Tensor) for input in arg_inputs
     ):
@@ -634,9 +635,9 @@ def save(
             "Input model is of type nn.Module. Saving nn.Module directly is not supported. Supported model types torch.jit.ScriptModule | torch.fx.GraphModule | torch.export.ExportedProgram."
         )
     elif module_type == _ModuleType.ts:
-        if output_format == "exported_program":
+        if output_format in ("exported_program", "aot_inductor"):
             raise ValueError(
-                "Provided model is a torch.jit.ScriptModule but the output_format specified is exported_program. Please verify the output_format"
+                "Provided model is a torch.jit.ScriptModule but the output_format specified is not torchscript. Other output formats are not supported"
             )
         else:
            if arg_inputs is not None:
@@ -654,7 +655,22 @@ def save(
         logger.warning(
             "Provided model is a torch.export.ExportedProgram, inputs or arg_inputs is not necessary during save, it uses the inputs or arg_inputs provided during export and compile"
         )
-        torch.export.save(module, file_path)
+        if output_format == "exported_program":
+            torch.export.save(module, file_path, pickle_protocol=pickle_protocol)
+        elif output_format == "aot_inductor":
+            inductor_configs = {}
+            if "inductor_configs" in kwargs:
+                inductor_configs = kwargs["inductor_configs"]
+
+            torch._inductor.aoti_compile_and_package(
+                exp_program,
+                inductor_configs=inductor_configs,
+                package_path=file_path,
+            )
+        else:
+            raise RuntimeError(
+                "Attempted to serialize an exported program with an unsupported format. Exported programs support exported_program and aot_inductor"
+            )
     elif module_type == _ModuleType.fx:
         # The module type is torch.fx.GraphModule
         if output_format == "torchscript":
@@ -671,9 +687,24 @@ def save(
                 "Provided model is a torch.fx.GraphModule and retrace is False, inputs or arg_inputs is not necessary during save."
             )
             exp_program = export(module)
-            torch.export.save(
-                exp_program, file_path, pickle_protocol=pickle_protocol
-            )
+            if output_format == "exported_program":
+                torch.export.save(
+                    exp_program, file_path, pickle_protocol=pickle_protocol
+                )
+            elif output_format == "aot_inductor":
+                inductor_configs = {}
+                if "inductor_configs" in kwargs:
+                    inductor_configs = kwargs["inductor_configs"]
+
+                torch._inductor.aoti_compile_and_package(
+                    exp_program,
+                    inductor_configs=inductor_configs,
+                    package_path=file_path,
+                )
+            else:
+                raise RuntimeError(
+                    "Attempted to serialize an exported program with an unsupported format. Exported programs support exported_program and aot_inductor"
+                )
         else:
             if arg_inputs is None:
                 raise ValueError(
@@ -685,6 +716,22 @@ def save(
                 kwargs=kwarg_inputs,
                 strict=False,
             )
-            torch.export.save(
-                exp_program, file_path, pickle_protocol=pickle_protocol
-            )
+
+            if output_format == "exported_program":
+                torch.export.save(
+                    exp_program, file_path, pickle_protocol=pickle_protocol
+                )
+            elif output_format == "aot_inductor":
+                inductor_configs = {}
+                if "inductor_configs" in kwargs:
+                    inductor_configs = kwargs["inductor_configs"]
+
+                torch._inductor.aoti_compile_and_package(
+                    exp_program,
+                    inductor_configs=inductor_configs,
+                    package_path=file_path,
+                )
+            else:
+                raise RuntimeError(
+                    "Attempted to serialize an exported program with an unsupported format. Exported programs support exported_program and aot_inductor"
+                )
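
For a compiled `torch.fx.GraphModule`, the two non-TorchScript branches above differ only in how the `ExportedProgram` handed to `torch.export.save` or `torch._inductor.aoti_compile_and_package` is obtained: with `retrace=False` an `ExportedProgram` is built directly from the compiled module via the library's `export(module)` helper and no inputs are required, while `retrace=True` re-exports with `torch.export.export(..., strict=False)` and therefore needs `arg_inputs`. A rough sketch of the two calls, assuming the `trt_gm` and `inputs` from the documentation examples above; the output file names are illustrative:

    # retrace=False: reuse the compiled graph module's own export; inputs are not needed
    torch_tensorrt.save(trt_gm, "trt_noretrace.pt2", output_format="aot_inductor")

    # retrace=True: re-export with torch.export.export(strict=False); inputs are required
    torch_tensorrt.save(
        trt_gm,
        "trt_retrace.pt2",
        output_format="aot_inductor",
        retrace=True,
        arg_inputs=inputs,
    )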
